## Multiclass에서 Stratified 하게 샘플을 분리하는 방법


- 데이터 셋: VOC2012 데이터 셋에서 해당 이미지에 사물들이 있는지 여부

In [1]:
import os 
import pandas as pd
import pickle

from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

In [2]:
def get_labedict():
    label_dict_loc = "./label_dict.pkl"
    with open(label_dict_loc, "rb") as f:
        labeldict = pickle.load(f)
    return labeldict


In [3]:
train_df = pd.read_csv('./train_df.csv', index_col=0)


In [4]:
labeldict = get_labedict()
label2num = labeldict['label2num']
num2label = labeldict['num2label']
label_names = label2num.keys()
label_cnt_cols = [i for i in train_df.columns if i.endswith('_cnt')]
label_cnt_cols = label_cnt_cols[:-2]


In [5]:
train_df.head()

Unnamed: 0,id,aeroplane_cnt,bicycle_cnt,bird_cnt,boat_cnt,bottle_cnt,bus_cnt,car_cnt,cat_cnt,chair_cnt,cow_cnt,diningtable_cnt,dog_cnt,horse_cnt,motorbike_cnt,person_cnt,pottedplant_cnt,sheep_cnt,sofa_cnt,train_cnt,tvmonitor_cnt
0,2007_000032,True,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False
1,2007_000039,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True
2,2007_000063,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False
3,2007_000068,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
4,2007_000121,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True


In [6]:


NUM_SPLIT = 5

for rs in range(1,6):
    mskf = MultilabelStratifiedKFold(n_splits=NUM_SPLIT, random_state=rs, shuffle=True)
    label_onehot_values = train_df[label_cnt_cols].astype(int).values

    for fold, (train_idx, val_idx) in enumerate(mskf.split(train_df.index, label_onehot_values)):        
        train_df.loc[train_df.index.isin(train_idx), f'mskf_rstate{rs}_fold{fold}'] = 'train'
        train_df.loc[train_df.index.isin(val_idx), f'mskf_rstate{rs}_fold{fold}'] = 'val'
        
mskf_cols = [i for i in train_df.columns if 'mskf' in i]



In [7]:
label_cnt_dict = {}
for rs in range(1,6):
    for fold in range(0,5):
        cur_col = f'mskf_rstate{rs}_fold{fold}'
        train_cnt = train_df[train_df[cur_col] == 'train'][label_cnt_cols].sum()
        val_cnt = train_df[train_df[cur_col] == 'val'][label_cnt_cols].sum()
        label_cnt_dict[f'{cur_col}_train'] = train_cnt
        label_cnt_dict[f'{cur_col}_val'] = val_cnt

In [8]:
mskf_label_cnt_df = pd.DataFrame(label_cnt_dict).T

In [9]:
mskf_label_cnt_df

Unnamed: 0,aeroplane_cnt,bicycle_cnt,bird_cnt,boat_cnt,bottle_cnt,bus_cnt,car_cnt,cat_cnt,chair_cnt,cow_cnt,diningtable_cnt,dog_cnt,horse_cnt,motorbike_cnt,person_cnt,pottedplant_cnt,sheep_cnt,sofa_cnt
mskf_rstate1_fold0_train,70,52,84,62,70,62,102,105,118,52,66,97,55,65,354,66,50,75
mskf_rstate1_fold0_val,18,13,21,16,17,16,26,26,30,12,16,24,13,16,88,16,13,18
mskf_rstate1_fold1_train,71,52,84,62,70,62,103,105,118,51,65,97,54,64,354,65,50,74
mskf_rstate1_fold1_val,17,13,21,16,17,16,25,26,30,13,17,24,14,17,88,17,13,19
mskf_rstate1_fold2_train,70,52,84,62,69,62,102,104,119,51,65,96,54,65,353,65,51,74
mskf_rstate1_fold2_val,18,13,21,16,18,16,26,27,29,13,17,25,14,16,89,17,12,19
mskf_rstate1_fold3_train,70,52,84,63,70,63,102,105,118,51,66,97,55,65,354,66,50,74
mskf_rstate1_fold3_val,18,13,21,15,17,15,26,26,30,13,16,24,13,16,88,16,13,19
mskf_rstate1_fold4_train,71,52,84,63,69,63,103,105,119,51,66,97,54,65,353,66,51,75
mskf_rstate1_fold4_val,17,13,21,15,18,15,25,26,29,13,16,24,14,16,89,16,12,18


In [10]:
train_df[['mskf_rstate1_fold0', 'mskf_rstate1_fold1', 'mskf_rstate2_fold0']].head(10)

Unnamed: 0,mskf_rstate1_fold0,mskf_rstate1_fold1,mskf_rstate2_fold0
0,train,train,train
1,val,train,val
2,train,val,val
3,train,train,train
4,train,train,val
5,train,val,train
6,val,train,train
7,train,train,train
8,train,val,train
9,val,train,train
