In [1]:
import glob
import pandas as pd
import numpy as np
import pickle

In [2]:
!pip install iterative-stratification

Collecting iterative-stratification
  Downloading iterative_stratification-0.1.7-py3-none-any.whl (8.5 kB)
Installing collected packages: iterative-stratification
Successfully installed iterative-stratification-0.1.7


In [2]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

In [6]:
from google.colab import drive

drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [3]:
train = pd.read_csv("./project/Train.csv")
test = pd.read_csv("./project/Test.csv")

In [4]:
train

Unnamed: 0,Image_ID,class,xmin,ymin,width,height
0,ID_007FAIEI,fruit_woodiness,87.0,87.5,228.0,311.0
1,ID_00G8K1V3,fruit_brownspot,97.5,17.5,245.0,354.5
2,ID_00WROUT9,fruit_brownspot,156.5,209.5,248.0,302.5
3,ID_00ZJEEK3,fruit_healthy,125.0,193.0,254.5,217.0
4,ID_018UIENR,fruit_brownspot,79.5,232.5,233.5,182.0
...,...,...,...,...,...,...
3901,ID_ZZAB1FH1,fruit_healthy,96.0,175.5,289.5,222.0
3902,ID_ZZAB1FH1,fruit_healthy,330.0,241.0,182.0,180.5
3903,ID_ZZJZ2CV6,fruit_healthy,358.0,234.0,134.5,107.0
3904,ID_ZZJZ2CV6,fruit_healthy,98.5,135.0,275.5,213.5


In [5]:
class_dict = { 
    'fruit_healthy': 0, 
    'fruit_woodiness': 1, 
    'fruit_brownspot': 2  
}

In [6]:
X = set(train["Image_ID"])
classes = [[] for i in range(len(X))]
bbox_data = [{} for i in range(len(X))]

prev = "nothing"
idx = 0

for i in range(len(train)):
  bbox_info = train.iloc[i]
  
  if i == 0 or bbox_info["Image_ID"] == prev:
    classes[idx].append(class_dict[bbox_info["class"]])

    if i == 0:
      bbox_data[idx]["ann"] = dict()
      bbox_data[idx]["ann"]["bboxes"] = []
      bbox_data[idx]["ann"]["labels"] = []
      bbox_data[idx]["filename"] = bbox_info["Image_ID"] + '.jpg'
      bbox_data[idx]["height"] = 512
      bbox_data[idx]["width"] = 512
    
    bbox_data[idx]["ann"]["bboxes"].append([bbox_info["xmin"], bbox_info["ymin"], bbox_info["xmin"] + bbox_info["width"], bbox_info["ymin"] + bbox_info["height"]])
    bbox_data[idx]["ann"]["labels"].append(class_dict[bbox_info["class"]])

  else:
    idx += 1
    
    bbox_data[idx]["ann"] = dict()
    bbox_data[idx]["ann"]["bboxes"] = []
    bbox_data[idx]["ann"]["labels"] = []
    bbox_data[idx]["filename"] = bbox_info["Image_ID"] + '.jpg'
    bbox_data[idx]["height"] = 512
    bbox_data[idx]["width"] = 512

    if idx != len(classes):
      classes[idx].append(class_dict[bbox_info["class"]])
      bbox_data[idx]["ann"]["bboxes"].append([bbox_info["xmin"], bbox_info["ymin"], bbox_info["xmin"] + bbox_info["width"], bbox_info["ymin"] + bbox_info["height"]])
      bbox_data[idx]["ann"]["labels"].append(class_dict[bbox_info["class"]])

  prev = bbox_info["Image_ID"]

In [7]:
for i in range(len(X)):
  bbox_data[i]["ann"]["bboxes"] = np.array(bbox_data[i]["ann"]["bboxes"])
  bbox_data[i]["ann"]["labels"] = np.array(bbox_data[i]["ann"]["labels"])

In [8]:
class_counts = [[] for i in range(len(X))]

for i, class_list in enumerate(classes):
  for cls in range(3):
    class_counts[i].append((np.array(class_list) == cls).sum())

In [9]:
#stratify by the class counts

mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=0)
X = np.array(list(X))
y = np.array(class_counts)

fold = 0

for train_index, test_index in mskf.split(X, y):
   with open(f"dcm_folds/train_fold{fold}.pickle", 'wb') as f:
     pickle.dump(list(np.array(bbox_data)[train_index]), f, protocol=pickle.HIGHEST_PROTOCOL)
   
   with open(f"dcm_folds/val_fold{fold}.pickle", 'wb') as f:
     pickle.dump(list(np.array(bbox_data)[test_index]), f, protocol=pickle.HIGHEST_PROTOCOL)
   
   fold += 1

In [157]:
test

Unnamed: 0,Image_ID
0,ID_IUJJG62B
1,ID_ZPNDRD4T
2,ID_AHFYB64P
3,ID_L8JZLNTF
4,ID_IFMUXGPL
...,...
926,ID_77MDJGMZ
927,ID_PG3E6NQT
928,ID_038S0ONN
929,ID_D5YBIDDT


In [10]:
test_data = [{"filename": test.iloc[i]["Image_ID"] + '.jpg', "width": 512, "height": 512} for i in range(len(test))]

In [160]:
len(test_data)

931

In [11]:
with open(f"dcm_folds/test.pickle", 'wb') as f:
     pickle.dump(test_data, f, protocol=pickle.HIGHEST_PROTOCOL)