# Startified KFold for the OM dataset

This code is provided to split a dataset in COCO (+ YOLO later) format given labels distributions across the dataset. This algorithms is indended for preserving the samples percentages for each class in order to provide a smoother generalization and adress class imbalance. 

## Setup

In [1]:
import json
from sklearn.model_selection import StratifiedKFold
import numpy as np
import os

## Load the dataset

In [2]:
dataset_path = './roboflow_datasets/xmm_om_artefacts_512-7-COCO/'
json_file_path = dataset_path+'train/_annotations.coco.json'
dest_train_path = dataset_path+'train/'
dest_valid_path = dataset_path+'valid/'

with open(json_file_path) as f:
    data_in = json.load(f)
data_in['categories']

[{'id': 0, 'name': 'artefacts', 'supercategory': 'none'},
 {'id': 1, 'name': 'central-ring', 'supercategory': 'artefacts'},
 {'id': 2, 'name': 'smoke-ring', 'supercategory': 'artefacts'},
 {'id': 3, 'name': 'star-loop', 'supercategory': 'artefacts'}]

In [3]:
# running the script multiple times may change the redistribution, 
# thus we need to make sure that we won't keep the old files

import glob

files = glob.glob(f'{dest_valid_path}/*')

for file in files:
    os.remove(file)

## Stratified KFold

In [4]:
images, labels = [], []

for k in range(len(data_in['images'])):
    img_id = data_in['images'][k]['id']
    annotations = [data_in['annotations'][j] for j in range(len(data_in['annotations'])) if data_in['annotations'][j]['image_id'] == img_id]
    categories = set(str(annot['category_id']) for annot in annotations)
    images.append([img_id])
    labels.append(''.join(list(categories)))
        
images, labels = np.array(images), np.array(labels)

In [5]:
labels

array(['32', '321', '321', '3', '32', '31', '1', '32', '321', '32', '32',
       '321', '21', '3', '321', '31', '3', '31', '21', '1', '1', '31',
       '32', '321', '3', '321', '3', '321', '321', '321', '32', '32',
       '321', '3', '3', '32', '321', '21', '32', '32', '321', '31', '31',
       '32', '3', '321', '32', '321', '3', '321', '321', '1', '1', '31',
       '31', '3', '3', '321', '2', '1', '32', '31', '321', '321', '31',
       '32', '3', '321', '321', '31', '31', '1', '321', '3', '31', '3',
       '', '321', '321', '321', '321', '3', '321', '321', '321', '31',
       '31', '321', '31', '32', '1', '2', '321', '3', '21', '321', '2',
       '1', '32', '321', '31', '1', '31', '31', '3', '21', '321', '', '3',
       '32', '31', '31', '32', '32', '32', '32', '3', '32', '32', '31',
       '3', '3', '3', '31', '3', '32', '31', '2', '31', '1', '32', '31',
       '2', '321', '321', '32', '31', '3', '321', '3', '31', '3', '21',
       '32', '1', '32', '32', '321', '321', '32', '321', '3

The resulting arrays' size equals to the number of annotations because image ids are repeated for each label associated with them.

In [6]:
images.shape, labels.shape

((687, 1), (687,))

Run the Stratified KFold split and generate train and valid datasets given the number of splits. 

The split percentage is calculated depending on the `n_splits` parameter:

> train_percentage = 100 * 1/n_splits
>
> valid_percentage = 100 - train_percentage

In [34]:
n_splits = 3
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
skf_image_ids, skf_labels = {}, {}

for i, (train_index, valid_index) in enumerate(skf.split(images, labels)):
    # print(f"Fold {i}:")
    # print(f"  Train: Image index={images[train_index]}")
    # print(f"  valid:  Image index={images[valid_index]}")
    skf_image_ids[i] = {'train': images[train_index], 'valid': images[valid_index]}
    skf_labels[i] = {'train': labels[train_index], 'valid': labels[valid_index]}
    print(i, 'train', [idx for idx in train_index if idx==87])
    print(i, 'valid', [idx for idx in valid_index if idx==87])
    

0 train [87]
0 valid []
1 train [87]
1 valid []
2 train []
2 valid [87]




In [8]:
train_index[0]

3

In [9]:
len(train_index), len(valid_index)

(458, 229)

In [10]:
len(skf_image_ids), 'splits'

(3, 'splits')

**Ensure that there are no image ids present in both splits.**

In [11]:
for i in range(n_splits):
    print("intersection", len(np.intersect1d(skf_image_ids[i]['train'], skf_image_ids[i]['valid'])))

intersection 0
intersection 0
intersection 0


**Ensure that the labels distribution is roughly the same between splits.**

In [12]:
labels_percentages = {}

for i in range(n_splits):
        
    train_labels_counts = {'0':0, '1':0, '2':0, '3':0}
    valid_labels_counts = {'0':0, '1':0, '2':0, '3':0}
    
    for j in range(len(skf_image_ids[0]['train'])):
        for cat in list(skf_labels[0]['train'][j]):
            train_labels_counts[cat] += 1
    
    for j in range(len(skf_image_ids[0]['valid'])):
        for cat in list(skf_labels[0]['valid'][j]):
            valid_labels_counts[cat] += 1
            
    train_labels_counts = {cat:counts * 1.0/len(train_index) for cat, counts in train_labels_counts.items()}
    valid_labels_counts = {cat:counts * 1.0/len(valid_index) for cat, counts in valid_labels_counts.items()}
            
    labels_percentages[i] = {'train':train_labels_counts, 'valid':  valid_labels_counts}

In [13]:
labels_percentages

{0: {'train': {'0': 0.0,
   '1': 0.6703056768558951,
   '2': 0.5611353711790393,
   '3': 0.7336244541484717},
  'valid': {'0': 0.0,
   '1': 0.6724890829694323,
   '2': 0.5633187772925764,
   '3': 0.7379912663755459}},
 1: {'train': {'0': 0.0,
   '1': 0.6703056768558951,
   '2': 0.5611353711790393,
   '3': 0.7336244541484717},
  'valid': {'0': 0.0,
   '1': 0.6724890829694323,
   '2': 0.5633187772925764,
   '3': 0.7379912663755459}},
 2: {'train': {'0': 0.0,
   '1': 0.6703056768558951,
   '2': 0.5611353711790393,
   '3': 0.7336244541484717},
  'valid': {'0': 0.0,
   '1': 0.6724890829694323,
   '2': 0.5633187772925764,
   '3': 0.7379912663755459}}}

## Update the dataset and save new annotations files

In [27]:
data_in_train = data_in.copy()
data_in_valid = data_in.copy()

data_in_train['images'] = [data_in['images'][train_index[i]] for i in range(len(train_index))]
data_in_valid['images'] = [data_in['images'][valid_index[i]] for i in range(len(valid_index))]
train_annot_ids, valid_annot_ids = [], []

for img_i in data_in_train['images']:
    annotation_ids = [annot['id'] for annot in data_in_train['annotations'] if annot['image_id'] == img_i['id']]
    train_annot_ids +=annotation_ids
    
for img_i in data_in_valid['images']:
    annotation_ids = [annot['id'] for annot in data_in_valid['annotations'] if annot['image_id'] == img_i['id']]
    valid_annot_ids +=annotation_ids
    
# data_in_train['annotations'] = [data_in_train['annotations'][i] for i in range(len(data_in_train['annotations'])) if data_in_train['annotations'][i] in ]

len(data_in_train['images']), len(data_in_valid['images'])

(458, 229)

In [33]:
for img_i in data_in['images']:
    if 'S0784390401_L_png.rf.dd6fe66e3f159b820bf22a264f2bfdf3.jpg' == img_i['file_name']:
        print(img_i['id'])

87


**extract annotations given skf indices**

In [28]:
data_in_train['annotations'] = [data_in_train['annotations'][id] for id in train_annot_ids]
data_in_valid['annotations'] = [data_in_valid['annotations'][id] for id in valid_annot_ids]

In [29]:
data_in_train['annotations']

[{'id': 20,
  'image_id': 3,
  'category_id': 3,
  'bbox': [189, 264, 49.5, 72.5],
  'area': 3588.75,
  'segmentation': [[194.5,
    303,
    190.5,
    315,
    191.5,
    335,
    200,
    336.5,
    221.34,
    315.356,
    233.5,
    301,
    238.059,
    285.897,
    238.5,
    272,
    229,
    265.5,
    222.134,
    268.119,
    216,
    272.5,
    204.781,
    287.031]],
  'iscrowd': 0},
 {'id': 21,
  'image_id': 4,
  'category_id': 3,
  'bbox': [226, 204, 62.9, 70.9],
  'area': 4459.61,
  'segmentation': [[281.627,
    274.816,
    280.627,
    274.816,
    279.627,
    274.816,
    278.627,
    274.816,
    277.627,
    274.816,
    276.627,
    274.816,
    275.627,
    274.816,
    274.627,
    274.816,
    273.627,
    274.816,
    272.627,
    274.816,
    272.427,
    274.616,
    271.627,
    273.816,
    270.627,
    273.816,
    269.627,
    273.816,
    268.627,
    273.816,
    268.427,
    273.616,
    267.627,
    272.816,
    266.627,
    272.816,
    265.627,
 

In [30]:
len(train_annot_ids), len(valid_annot_ids), len(data_in_train['images']), len(data_in_valid['images'])

(1875, 968, 458, 229)

**save the new json data**

In [17]:
if not os.path.exists(dest_train_path):
    os.mkdir(dest_train_path)
if not os.path.exists(dest_valid_path):
    os.mkdir(dest_valid_path)

In [18]:
new_train_json_path = dest_train_path+'skf_train_annotations.coco.json'
new_valid_json_path = dest_valid_path+'skf_valid_annotations.coco.json'

with open(new_train_json_path, 'w') as f1, open(new_valid_json_path, 'w') as f2:
    json.dump(data_in_train, f1, indent=4)
    json.dump(data_in_valid, f2, indent=4)

**Move the filenames to the corresponding split directories**

In [19]:
import shutil
import os

filenames = [image['file_name'] for image in data_in_valid['images']]
filenames = list(set(filenames))
print(len(filenames), 'files')

# Iterate over the filenames and copy each one
for filename in filenames:
    source_path = os.path.join(dataset_path+'train/', filename)
    dest_path = os.path.join(dest_valid_path, filename)
    
    # Copy the file from source to destination
    shutil.copy(source_path, dest_path)

print("Files moved successfully.")

229 files
Files moved successfully.
