In [1]:
"""
Author        : Aditya Jain
Date Started  : 7th May, 2021
About         : Division of dataset into train, validation and test sets
"""
import os
import glob
import random
import pandas as pd

# Input / output locations. Both strings end with '/' because other cells in this
# notebook build file paths from them by plain string concatenation.
data_dir    = '/home/mila/a/aditya.jain/scratch/GBIF_Data/moths_uk/'            # root directory of data
write_dir   = '/home/mila/a/aditya.jain/mothAI/classification_moths/data/'      # split files to be written
# Split ratios, applied separately to the images of every species directory.
# The splitting cell shown in this notebook reads only train_spt and val_spt; the
# test set receives whatever files remain, so test_spt documents the intended share.
train_spt   = 0.75                                                              # train set ratio
val_spt     = 0.10                                                              # validation set ratio
test_spt    = 0.15                                                              # test set ratio

# family_list = ['Apatelodidae', 'Bombycidae', 'Cossidae',
#                'Drepanidae', 'Erebidae', 'Geometridae', 'Hepialidae',
#                'Lasiocampidae', 'Limacodidae', 'Notodontidae', 'Noctuidae', 'Nolidae',
#                'Saturniidae', 'Sesiidae', 'Sphingidae', 'Uraniidae', 'Zygaenidae']

In [2]:
def prepare_split_list(global_pd, new_list, fields):
    """
    prepares a global csv list for every type of data split
    
    Every path is split on '/' and its last four components are read as
    <family>/<genus>/<species>/<filename>; one row is produced per path.
    
    Args:
        global_pd: a global list (DataFrame) to which the new entries are added;
                   it is not modified in place
        new_list : list of new entries (file paths) to be appended to global list
        fields   : contains the column names, in the order
                   filename, family, genus, species
    
    Returns:
        a new DataFrame holding the rows of global_pd followed by one row per
        path in new_list, with a fresh 0..n-1 index
    
    Raises:
        IndexError: if a path has fewer than four '/'-separated components
    """
    rows = []
    
    for path in new_list:
        parts = path.split('/')
        # stored column order is filename, family, genus, species
        rows.append([parts[-1], parts[-4], parts[-3], parts[-2]])
        
    new_data = pd.DataFrame(rows, columns=fields)
    
    # DataFrame.append was removed in pandas 2.0; pd.concat is the supported
    # equivalent and likewise returns a new frame
    return pd.concat([global_pd, new_data], ignore_index=True)

In [3]:
fields     = ['filename', 'family', 'genus', 'species']
split_seed = 42                           # fixed seed: re-running the notebook reproduces the same split
split_rng  = random.Random(split_seed)    # private generator, the global `random` state stays untouched


def _subdirs(parent):
    """Return the sorted names of the immediate sub-directories of `parent`."""
    return sorted(name for name in os.listdir(parent)
                  if os.path.isdir(os.path.join(parent, name)))


# File paths are gathered in plain lists and turned into DataFrames once at the
# end; growing a DataFrame once per species copies all earlier rows every time.
all_train, all_val, all_test = [], [], []

# Walks <data_dir>/<family>/<genus>/<species>/*.jpg. To sample only particular
# families, iterate over an explicit list of family names instead of _subdirs(data_dir).
for family in _subdirs(data_dir):
    family_dir = os.path.join(data_dir, family)

    for genus in _subdirs(family_dir):
        genus_dir = os.path.join(family_dir, genus)

        for species in _subdirs(genus_dir):
            species_dir = os.path.join(genus_dir, species)

            # glob.escape keeps '[', '?' or '*' inside directory names from being
            # treated as wildcards; sorting removes the filesystem-dependent
            # listing order before the seeded shuffle
            file_data  = sorted(glob.glob(os.path.join(glob.escape(species_dir), '*.jpg')))
            split_rng.shuffle(file_data)

            total      = len(file_data)
            train_amt  = round(total*train_spt)
            val_amt    = round(total*val_spt)

            train_list = file_data[:train_amt]
            val_list   = file_data[train_amt:train_amt+val_amt]
            test_list  = file_data[train_amt+val_amt:]      # test set = whatever remains

            all_train.extend(train_list)
            all_val.extend(val_list)
            all_test.extend(test_list)

# prepare_split_list returns a new frame, so one empty frame can seed all three splits
empty_split = pd.DataFrame(columns = fields)
train_data  = prepare_split_list(empty_split, all_train, fields)
val_data    = prepare_split_list(empty_split, all_val, fields)
test_data   = prepare_split_list(empty_split, all_test, fields)
            

In [4]:
# Write every split to its own CSV file (row index omitted), then report how
# many images each split holds and the overall total.
split_outputs = [
    ('Training data size: ',   train_data, '01-uk-train-split.csv'),
    ('Validation data size: ', val_data,   '01-uk-val-split.csv'),
    ('Testing data size: ',    test_data,  '01-uk-test-split.csv'),
]

for size_label, split_frame, csv_name in split_outputs:
    split_frame.to_csv(write_dir + csv_name, index=False)

for size_label, split_frame, csv_name in split_outputs:
    print(size_label, len(split_frame))

print('Total images: ', sum(len(entry[1]) for entry in split_outputs))

Training data size:  240114
Validation data size:  32007
Testing data size:  48009
Total images:  320130


In [5]:
# Uniqueness check on the training split: print the number of filename entries
# and then the number of distinct filenames; equal numbers mean no duplicates.
train_filenames = train_data['filename'].tolist()
print(len(train_filenames))
filelist = set(train_filenames)
print(len(filelist))

240114
240114
