In [1]:
import torch
from torch.utils.data import Dataset
import glob
import os
import numpy as np
import pickle
from PIL import Image
from tqdm import tqdm
import cv2

In [2]:
def save_dict(di_, filename_):
    '''
    Pickles an object to disk, creating the destination folder when it is missing
    Arguments:
    di_: the object (typically a dict) to serialize
    filename_: path of the pickle file to write; an existing file is overwritten
    '''
    # open(..., 'wb') does not create intermediate folders, so a fresh checkout
    # without e.g. 'configs/splits' would otherwise fail with FileNotFoundError
    parent_dir = os.path.dirname(filename_)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filename_, 'wb') as f:
        pickle.dump(di_, f)

def load_dict(filename_):
    '''
    Reads back an object that was stored with pickle
    Arguments:
    filename_: path of the pickle file to read
    Returns:
    the unpickled object
    '''
    # NOTE: unpickling can execute arbitrary code - only load files you trust
    with open(filename_, 'rb') as handle:
        return pickle.load(handle)

In [3]:
def get_files(filepath,expression='*.json'):
    '''
    Walks over a directory and all of its sub-directories and collects the entries
    whose names match a glob expression
    Arguments:
    filepath: string that specifies the path to the data parent directory
    expression: glob pattern applied inside every visited directory (default '*.json')
    Returns:
    all_files: Sorted list of the absolute paths of all matching entries
    '''
    all_files = []
    for root, _dirs, _names in os.walk(filepath):
        # Escape the directory part so that '[', '*' or '?' in a folder name is
        # taken literally; only `expression` should act as a pattern
        pattern = os.path.join(glob.escape(root), expression)
        all_files.extend(os.path.abspath(match) for match in glob.glob(pattern))
    # Directory listing order depends on the filesystem; sorting keeps anything
    # derived from this list (e.g. a seeded shuffle) reproducible across machines
    all_files.sort()
    return all_files

In [4]:
cd ..\..

E:\CVprojects\Butterflies


In [5]:
# Root folder of the dataset, given relative to the current working directory
data_path = 'Data'

In [6]:
# Collect every file under the dataset root whose name ends in 'sub.pt',
# then display how many were found
files_paths = get_files(data_path,'*sub.pt')
len(files_paths)

25279

In [7]:
# Preview the image path implied by the first sample: drop the 6-character
# 'sub.pt' suffix and append the image extension
first_sample_path = files_paths[0]
first_sample_path[:-len('sub.pt')] + '.jpg'

'E:\\CVprojects\\Butterflies\\Data\\images_small\\001.Atrophaneura_horishanus\\001.jpg'

In [8]:

# Every sample sits in a folder named '<class number>.<class name>'
# (e.g. '061.Gonepteryx_rhamni'), so the label is parsed from the parent folder
# name. Splitting the whole path on '.' and '/' was fragile: on Windows the '/'
# split never removed the file name from the class name, and a '.' in any
# ancestor folder corrupted the class number.
images_classes = []    # class-number string of every sample, in files_paths order
classes_samples = {}   # zero-based class index -> list of sample paths
classes_names = {}     # zero-based class index -> class name
for path in files_paths:
    class_dir = os.path.basename(os.path.dirname(path))
    class_id, _sep, class_name = class_dir.partition('.')
    images_classes.append(class_id)
    # folder numbering starts at 001, class indices start at 0
    class_num = int(class_id)-1
    classes_samples.setdefault(class_num, []).append(path)
    classes_names[class_num] = class_name

## Downsampling
Each class is limited to at most `max_samples_per_class` samples (e.g. 50); the remaining samples of a class go into the slow testing dataset.

In [9]:
# share of each class's kept samples that goes to the validation split
val_rat = 0.1
# share intended for the test split
test_rat = 0.1
# zero-based class indices used for the ResNet fine-tuning
# (index 60 corresponds to the folder numbered 061)
classes_to_consider = [60, 55, 45, 121]
# per-class cap on the number of samples; a huge value keeps every example
max_samples_per_class = 10**6

In [10]:
train_paths = []
val_paths = []
test_paths = []
slow_test_paths = []
# fixed seed so the per-class shuffles (and therefore the splits) are repeatable
np.random.seed(0)

for key in classes_to_consider:
    cur_paths = np.array(classes_samples[key])
    num_samples = len(cur_paths)
    # split sizes are based on the number of samples kept after the per-class cap
    num_kept = min(num_samples,max_samples_per_class)
    num_val = int(num_kept*val_rat)
    # fix: the test share used to be computed from val_rat, leaving test_rat unused
    num_test = int(num_kept*test_rat)

    all_indeces = np.arange(num_samples)
    np.random.shuffle(all_indeces)

    # slicing past the end of an array is safe and yields an empty result,
    # so no explicit length check is needed
    all_indeces_down = all_indeces[:max_samples_per_class]
    indeces_slow_test = all_indeces[max_samples_per_class:]

    # layout of the kept, shuffled indices: [ val | test | train ]
    val_indeces = all_indeces_down[:num_val]
    test_indeces = all_indeces_down[num_val:num_val+num_test]
    train_indeces = all_indeces_down[num_val+num_test:]

    train_paths.extend(cur_paths[train_indeces])
    val_paths.extend(cur_paths[val_indeces])
    test_paths.extend(cur_paths[test_indeces])
    # fix: samples beyond the cap were selected but never stored
    slow_test_paths.extend(cur_paths[indeces_slow_test])


In [11]:
# Spot-check: first stored sample path of class index 60
classes_samples[60][0]

'E:\\CVprojects\\Butterflies\\Data\\images_small\\061.Gonepteryx_rhamni\\001sub.pt'

In [12]:
# Total number of samples across the train / val / test splits
sum(len(split_paths) for split_paths in (train_paths, val_paths, test_paths))

2152

In [13]:
# Re-express every split's paths relative to the current working directory
# (one pass over the three lists instead of three copy-pasted comprehensions)
train_paths, val_paths, test_paths = (
    [os.path.relpath(sample_path, '.') for sample_path in split_paths]
    for split_paths in (train_paths, val_paths, test_paths)
)

In [16]:
# Bundle the three path lists under their split names
split_dict = {}
split_dict['train'] = train_paths
split_dict['val'] = val_paths
split_dict['test'] = test_paths

# Target file inside configs/splits, joined with the platform's path separator,
# then pickle the bundle there
delim = os.sep
split_dict_name = delim.join(['configs', 'splits', "split_dict_res.pkl"])
save_dict(split_dict, split_dict_name)