In [6]:
import os
import json
import shutil

import numpy as np
import pandas as pd

from pathlib import Path
from typing import Dict, List


In [7]:
def rm_n_mkdir(dir_path: str):
    """Recreate `dir_path` as an empty directory.

    If a directory already exists at `dir_path` it is deleted recursively;
    the path (including any missing parents) is then created from scratch.
    """
    already_present = os.path.isdir(dir_path)
    if already_present:
        # Wipe stale contents so every run starts from an empty folder.
        shutil.rmtree(dir_path)
    os.makedirs(dir_path)

def save_json(d: Dict, file: str):
  """Serialize `d` as JSON into `file`, creating its parent directories first."""
  target = Path(file)
  parent_dir = target.parents[0]
  parent_dir.mkdir(parents=True, exist_ok=True)
  with target.open('w') as handle:
    json.dump(d, handle)

In [8]:
from sklearn.model_selection import StratifiedShuffleSplit

def create_splits(
  patch_info_file: str,
  seed: int = 5,
  n_splits: int = 10,
  train_size: float = 0.8,
  test_size: float = 0.2
  ) -> List[Dict]:
  """Create source-level train/valid folds from the `patch_info` CSV file.

  Each patch's source image is the part of its file name before the first
  '-'. Whole source images are assigned to either train or valid, so patches
  of one source image never appear in both. The shuffle split over source
  images is stratified by cohort, taken as the part of the source name before
  the first '_'. The defaults match those used for the HoVerNet baseline
  splits.

  Args:
      patch_info_file (str): path of the 'patch_info.csv' file. It is read with
          `pd.read_csv`, so the first row is treated as a header. It is
          expected to hold a single column of patch file names.
      seed (int, optional): random seed for reproducibility. Defaults to 5.
      n_splits (int, optional): number of splits/folds to create. Defaults to 10.
      train_size (float, optional): fraction of source images used for
          training. Defaults to 0.8.
      test_size (float, optional): fraction of source images used for
          validation. Defaults to 0.2.

  Returns:
      List[Dict]: one dict per fold with keys 'train' and 'valid'. Each value
      is a sorted numpy array of unique patch file names.
  """
  info = pd.read_csv(patch_info_file)
  file_names = np.squeeze(info.to_numpy()).tolist()

  # Source image of every patch, computed once and reused for every fold.
  patch_sources = [name.split('-')[0] for name in file_names]
  img_sources = np.unique(patch_sources)

  # Integer cohort label per source image, used as the stratification target.
  cohort_names = [src.split('_')[0] for src in img_sources]
  _, cohort_labels = np.unique(cohort_names, return_inverse=True)

  splitter = StratifiedShuffleSplit(
    n_splits=n_splits,
    train_size=train_size,
    test_size=test_size,
    random_state=seed,
  )

  splits = []
  for train_idx, valid_idx in splitter.split(img_sources, cohort_labels):
    train_sources = set(img_sources[train_idx])
    valid_sources = set(img_sources[valid_idx])
    assert not (train_sources & valid_sources)

    # Set lookups keep this O(#patches) rather than O(#patches * #sources).
    train_names = np.unique(
      [name for name, src in zip(file_names, patch_sources) if src in train_sources])
    valid_names = np.unique(
      [name for name, src in zip(file_names, patch_sources) if src in valid_sources])
    print(f'Train: {len(train_names):04d} - Valid: {len(valid_names):04d}')
    assert np.intersect1d(train_names, valid_names).size == 0

    splits.append({
      'train': train_names,
      'valid': valid_names
    })
  return splits

In [13]:
# Dataset location (expected to contain `patch_info.csv`) and output directory.
data_root = "../dataset"
OUT_DIR = 'output'

# Per-task COCO-style JSON files under OUT_DIR; instances/panoptic are read by `save_folds`.
images_json, instances_json, panoptic_json = (
    f'{OUT_DIR}/{stem}.json' for stem in ('images', 'instances', 'panoptic')
)

In [14]:
# Patch manifest consumed by `create_splits`; `splits` holds one train/valid dict per fold.
info_path = f"{data_root}/patch_info.csv"
splits = create_splits(patch_info_file=info_path)

Train: 3963 - Valid: 1018
Train: 4053 - Valid: 0928
Train: 3952 - Valid: 1029
Train: 3988 - Valid: 0993
Train: 3997 - Valid: 0984
Train: 4002 - Valid: 0979
Train: 3894 - Valid: 1087
Train: 4012 - Valid: 0969
Train: 3988 - Valid: 0993
Train: 3964 - Valid: 1017


In [15]:
def save_folds(splits: List[Dict], json_path: str, folds_dir: str, data_type: str):
  """Save the folds in `splits` as JSON files under `folds_dir`/`data_type`/<mode>/fold_<i>.json.

  For every fold and for each mode ('train', 'valid'), the input JSON is copied
  with two lists filtered:
  - 'images' keeps only entries whose 'id' is among that fold's names.
  - 'annotations' keeps only entries whose 'image_id' is among those names.
  All other top-level keys are copied through unchanged.

  Args:
      splits (List[Dict]): splits generated by the `create_splits` function;
          each dict maps 'train' and 'valid' to the image ids of that fold.
      json_path (str): path to the input JSON file; it must contain top-level
          'images' and 'annotations' lists.
      folds_dir (str): root directory in which the fold JSON files are stored.
      data_type (str): sub-directory name for this annotation flavour,
          e.g. 'instances' or 'panoptic'.
  """
  with open(json_path, 'r') as json_file:
    info = json.load(json_file)

  for fold_idx, split in enumerate(splits):
    for mode in ('train', 'valid'):
      keep_ids = set(split[mode])
      # Filter the raw records directly. Round-tripping through a DataFrame
      # would fill keys missing from some records with NaN (written as the
      # invalid JSON token `NaN`) and could coerce int fields to float.
      # A shallow copy per fold also keeps the loaded `info` unmodified.
      info_mode = {
        **info,
        'images': [img for img in info['images'] if img['id'] in keep_ids],
        'annotations': [ann for ann in info['annotations'] if ann['image_id'] in keep_ids],
      }

      out_path = f'{folds_dir}/{data_type}/{mode}/fold_{fold_idx}.json'
      save_json(info_mode, out_path)
      print(f'Saved {out_path}')

  print(f"JSON files for {len(splits)} folds created successfully.")

In [16]:
rm_n_mkdir(dir_path=f'{OUT_DIR}/folds')  # start from an empty folds directory

In [17]:
save_folds(splits=splits, json_path=instances_json, folds_dir=f'{OUT_DIR}/folds', data_type='instances')

Saved output/folds/instances/train/fold_0.json
Saved output/folds/instances/valid/fold_0.json
Saved output/folds/instances/train/fold_1.json
Saved output/folds/instances/valid/fold_1.json
Saved output/folds/instances/train/fold_2.json
Saved output/folds/instances/valid/fold_2.json
Saved output/folds/instances/train/fold_3.json
Saved output/folds/instances/valid/fold_3.json
Saved output/folds/instances/train/fold_4.json
Saved output/folds/instances/valid/fold_4.json
Saved output/folds/instances/train/fold_5.json
Saved output/folds/instances/valid/fold_5.json
Saved output/folds/instances/train/fold_6.json
Saved output/folds/instances/valid/fold_6.json
Saved output/folds/instances/train/fold_7.json
Saved output/folds/instances/valid/fold_7.json
Saved output/folds/instances/train/fold_8.json
Saved output/folds/instances/valid/fold_8.json
Saved output/folds/instances/train/fold_9.json
Saved output/folds/instances/valid/fold_9.json
JSON files for 10 folds created successfully.


In [18]:
save_folds(splits=splits, json_path=panoptic_json, folds_dir=f'{OUT_DIR}/folds', data_type='panoptic')

Saved output/folds/panoptic/train/fold_0.json
Saved output/folds/panoptic/valid/fold_0.json
Saved output/folds/panoptic/train/fold_1.json
Saved output/folds/panoptic/valid/fold_1.json
Saved output/folds/panoptic/train/fold_2.json
Saved output/folds/panoptic/valid/fold_2.json
Saved output/folds/panoptic/train/fold_3.json
Saved output/folds/panoptic/valid/fold_3.json
Saved output/folds/panoptic/train/fold_4.json
Saved output/folds/panoptic/valid/fold_4.json
Saved output/folds/panoptic/train/fold_5.json
Saved output/folds/panoptic/valid/fold_5.json
Saved output/folds/panoptic/train/fold_6.json
Saved output/folds/panoptic/valid/fold_6.json
Saved output/folds/panoptic/train/fold_7.json
Saved output/folds/panoptic/valid/fold_7.json
Saved output/folds/panoptic/train/fold_8.json
Saved output/folds/panoptic/valid/fold_8.json
Saved output/folds/panoptic/train/fold_9.json
Saved output/folds/panoptic/valid/fold_9.json
JSON files for 10 folds created successfully.
