# Split into Train and Val

In [None]:
import os
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cv2
from PIL import Image
import imageio, skimage

from collections import Counter

# train-val split
from sklearn.model_selection import train_test_split, GroupShuffleSplit

In [None]:
# List the entries of the ARCH dataset directory (path is relative to the notebook).
os.listdir('../datasets/ARCH')

In [None]:
# List the entries of the ARCH annotations directory, where the caption JSON files live.
os.listdir('../datasets/ARCH/annotations/')

In [None]:
# Seed handed to GroupShuffleSplit (`random_state`) so the split is reproducible.
RANDOM_STATE = 42
# Handed to GroupShuffleSplit as `test_size`: the share held out for validation.
TEST_SIZE = 0.2

## Unified Set

In [None]:
# Load the unified caption annotations as a plain dict.
# The encoding is pinned because JSON text is UTF-8, whereas `open` without an
# `encoding` argument falls back to the platform's locale encoding.
with open('../datasets/ARCH/annotations/captions_all.json', 'r', encoding='utf-8') as f:
    arch_captions = json.load(f)

### Check the unified dataset

In [None]:
# One row per annotation entry: after the transpose, the outer keys of the
# loaded dict form the index and the inner keys form the columns.
arch_captions_df = pd.DataFrame(arch_captions).T

# Sanity check: the number of distinct 'uuid' values must match the number of rows.
uuid_column = arch_captions_df.uuid
assert uuid_column.nunique() == len(uuid_column)

In [None]:
# Display the unified annotation table.
arch_captions_df

## Split the unified set into `books` and `pubmed` sets

In [None]:
# Split the unified table by its `source` column.
# `reset_index(drop=False)` turns the original index (the annotation keys) into a regular
# column instead of discarding it, so every row can still be connected with its entry
# in the file with all annotations. (`inplace` has no influence on that; it is `drop`
# that decides whether the old index is kept.)
books_captions_df = arch_captions_df[arch_captions_df.source == 'books'].reset_index(drop=False)

pubmed_captions_df = arch_captions_df[arch_captions_df.source == 'pubmed'].reset_index(drop=False)

In [None]:
# Display the rows whose source is 'books'.
books_captions_df

In [None]:
# Display the rows whose source is 'pubmed'.
pubmed_captions_df

In [None]:
# Count the distinct values in each column of the unified table.
arch_captions_df.nunique()

## Make the Train/Val splits

In [None]:
# Train/val split of the pubmed rows. Rows are grouped by caption text so that
# rows sharing a caption always end up on the same side of the split.
my_split = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=RANDOM_STATE)
pubmed_groups = pubmed_captions_df.caption
# n_splits=1, so the first (and only) pair of positional index arrays is taken
pubmed_train_idxs, pubmed_val_idxs = next(my_split.split(pubmed_captions_df, groups=pubmed_groups))

pubmed_train = pubmed_captions_df.iloc[pubmed_train_idxs]
pubmed_val = pubmed_captions_df.iloc[pubmed_val_idxs]

# a caption must never occur in both sets
assert set(pubmed_train.caption).isdisjoint(pubmed_val.caption)

In [None]:
# Display the pubmed training rows.
pubmed_train

In [None]:
# Display the pubmed validation rows.
pubmed_val

In [None]:
# Row-level share of pubmed that went to validation. The split is made over caption
# groups rather than single rows, so this may differ somewhat from TEST_SIZE.
len(pubmed_val)/ (len(pubmed_train) + len(pubmed_val))

In [None]:
# Train/val split of the books rows. Rows are grouped by caption text so that
# rows sharing a caption always end up on the same side of the split.
my_split = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=RANDOM_STATE)
books_groups = books_captions_df.caption
# n_splits=1, so the first (and only) pair of positional index arrays is taken
books_train_idxs, books_val_idxs = next(my_split.split(books_captions_df, groups=books_groups))

books_train = books_captions_df.iloc[books_train_idxs]
books_val = books_captions_df.iloc[books_val_idxs]

# a caption must never occur in both sets
assert set(books_train.caption).isdisjoint(books_val.caption)

In [None]:
# Display the books training rows.
books_train

In [None]:
# Display the books validation rows.
books_val

In [None]:
# Row-level share of books that went to validation. The split is made over caption
# groups rather than single rows, so this may differ somewhat from TEST_SIZE.
len(books_val)/ (len(books_train) + len(books_val))

## Record the splits in `.json` files

In [None]:
# Merge the books and pubmed training rows back into a single dict laid out as
# {annotation key: {field: value, ...}}.
# Fields are read by column name instead of being unpacked by position, so the result
# does not silently depend on the order (or number) of columns in the data frames.
caption_fields = ('figure_id', 'letter', 'caption', 'uuid', 'source')

arch_captions_train = {}
for split_df in (books_train, pubmed_train):
    for _, row in split_df.iterrows():
        # the 'index' column carries the original annotation key (added by `reset_index`)
        key = str(row['index'])
        # every annotation key may appear only once across both sources
        assert key not in arch_captions_train, f'duplicate annotation key: {key}'
        arch_captions_train[key] = {field: row[field] for field in caption_fields}

arch_captions_train

In [None]:
# Merge the books and pubmed validation rows back into a single dict laid out as
# {annotation key: {field: value, ...}}.
# Fields are read by column name instead of being unpacked by position, so the result
# does not silently depend on the order (or number) of columns in the data frames.
caption_fields = ('figure_id', 'letter', 'caption', 'uuid', 'source')

arch_captions_val = {}
for split_df in (books_val, pubmed_val):
    for _, row in split_df.iterrows():
        # the 'index' column carries the original annotation key (added by `reset_index`)
        key = str(row['index'])
        # every annotation key may appear only once across both sources
        assert key not in arch_captions_val, f'duplicate annotation key: {key}'
        arch_captions_val[key] = {field: row[field] for field in caption_fields}

arch_captions_val

In [None]:
# Sanity checks on the in-memory split.
train_keys = set(arch_captions_train)
val_keys = set(arch_captions_val)

# the two splits share no key
assert train_keys.isdisjoint(val_keys)
# together they cover exactly the keys of the whole dataset
assert train_keys | val_keys == set(arch_captions)
# merging the two dictionaries reproduces the overall dictionary
merged_captions = dict(arch_captions_train)
merged_captions.update(arch_captions_val)
assert merged_captions == arch_captions

## Save the Train and Val json files

In [None]:
# Write each split to its own JSON file in the annotations directory (train first, then val).
for out_path, split_captions in (
    ('../datasets/ARCH/annotations/captions_train.json', arch_captions_train),
    ('../datasets/ARCH/annotations/captions_val.json', arch_captions_val),
):
    with open(out_path, 'w') as f:
        json.dump(split_captions, f)

## Check the saved datasets

In [None]:
# Reload all three files from disk to verify what was actually written.
# The encoding is pinned because JSON text is UTF-8, whereas `open` without an
# `encoding` argument falls back to the platform's locale encoding.
with open('../datasets/ARCH/annotations/captions_all.json', 'r', encoding='utf-8') as f:
    arch_captions_all = json.load(f)

with open('../datasets/ARCH/annotations/captions_train.json', 'r', encoding='utf-8') as f:
    arch_captions_train = json.load(f)

with open('../datasets/ARCH/annotations/captions_val.json', 'r', encoding='utf-8') as f:
    arch_captions_val = json.load(f)

# check that the two splits together make up the whole dictionary
assert {**arch_captions_train, **arch_captions_val} == arch_captions_all
# non-intersecting indexes
assert len(set(arch_captions_train.keys()).intersection(set(arch_captions_val.keys()))) == 0
# union of indexes gives the indexes of the whole dataset
assert set(arch_captions_train.keys()).union(set(arch_captions_val.keys())) == set(arch_captions_all.keys())

In [None]:
# Share of all saved entries (books + pubmed) that ended up in the validation file.
len(arch_captions_val) / (len(arch_captions_val) + len(arch_captions_train))