In [1]:
"""
Author        : Aditya Jain
Date Started  : May 2, 2022
About         : Division of dataset into train, validation and test sets for non-moth classifier
"""
import os
import glob
import random
import pandas as pd

moth_data_dir    = '/home/mila/a/aditya.jain/scratch/GBIF_Data/moths/'               # root directory of moth data
nonmoth_data_dir = '/home/mila/a/aditya.jain/scratch/GBIF_Data/nonmoths/'            # root directory of nonmoth data
write_dir        = '/home/mila/a/aditya.jain/mothAI/classification_nonmoths/data/'   # split files to be written
TRAIN_SPLIT      = 0.75                                                              # train set ratio
VAL_SPLIT        = 0.10                                                              # validation set ratio
TEST_SPLIT       = 0.15                                                              # test set ratio (informational: the test set receives all files left after train/val)
SEED             = 42                                                                # seed for the random shuffles, so the split is reproducible

# the three ratios describe one partition of the data, so they must add up to 1
assert abs(TRAIN_SPLIT + VAL_SPLIT + TEST_SPLIT - 1.0) < 1e-9, 'split ratios must sum to 1'

# seed the stdlib RNG used by random.shuffle when splitting the image lists
random.seed(SEED)

In [2]:
def prepare_split_list(global_pd, new_list, fields, class_name):
    """
    prepares a global csv list for every type of data split

    Each image path is parsed into a row of (filename, family, genus, species, class),
    and the new rows are appended after the rows of ``global_pd``.

    Path layouts understood (only the trailing components are used):
        moth    : .../<family>/<genus>/<species>/<filename>
        nonmoth : .../<order>/<filename>
                  (the order goes into the 'family' column; genus and species are 'NA')

    Args:
        global_pd : DataFrame to which the new entries will be appended
        new_list  : list of '/'-separated image paths to be appended
        fields    : column names; assumed to be the 5 columns
                    [filename, family, genus, species, class] in that order
        class_name: 'moth' or 'nonmoth' (any value other than 'moth' is parsed
                    with the nonmoth layout)

    Returns:
        A DataFrame with the rows of ``global_pd`` followed by the new rows,
        indexed 0..n-1. ``global_pd`` itself is never modified.
    """
    new_data = []

    for path in new_list:
        path_split = path.split('/')
        filename   = path_split[-1]

        if class_name == 'moth':
            species = path_split[-2]
            genus   = path_split[-3]
            family  = path_split[-4]
        else:
            species = 'NA'
            genus   = 'NA'
            family  = path_split[-2]

        new_data.append([filename, family, genus, species, class_name])

    # nothing to add: hand back the existing table unchanged
    if not new_data:
        return global_pd

    new_data = pd.DataFrame(new_data, columns=fields)

    # DataFrame.append was removed in pandas 2.0; pd.concat is its replacement.
    # While the global table is still empty, return the new rows directly. This
    # also sidesteps pandas' FutureWarning about concatenating empty frames.
    if global_pd.empty:
        return new_data
    return pd.concat([global_pd, new_data], ignore_index=True)

Defining the (initially empty) tables for the train, validation and test splits

In [3]:
# column layout shared by all three split tables
fields = ['filename', 'family', 'genus', 'species', 'class']

# every split starts as its own empty table; rows are added later via prepare_split_list
train_data, val_data, test_data = (pd.DataFrame(columns=fields) for _split in ('train', 'val', 'test'))

Splitting the moth images into train, validation and test sets (separately for each species)


In [4]:
# Each moth species is split independently, so every species contributes to
# train / val / test in (roughly) the configured proportions.
# os.listdir and glob return entries in arbitrary order. Sorting them is what
# makes the shuffle, and therefore the split, reproducible once `random` is seeded.
moth_train_files, moth_val_files, moth_test_files = [], [], []

for family in sorted(os.listdir(moth_data_dir)):
    family_dir = os.path.join(moth_data_dir, family)
    # skip metadata files, notebook checkpoints and any other non-directory entries
    if family.endswith('.csv') or family.endswith('.ipynb_checkpoints') or not os.path.isdir(family_dir):
        continue
    for genus in sorted(os.listdir(family_dir)):
        genus_dir = os.path.join(family_dir, genus)
        if not os.path.isdir(genus_dir):
            continue
        for species in sorted(os.listdir(genus_dir)):
            file_data = sorted(glob.glob(os.path.join(genus_dir, species, '*.jpg')))
            random.shuffle(file_data)

            total     = len(file_data)
            train_amt = round(total*TRAIN_SPLIT)
            val_amt   = round(total*VAL_SPLIT)

            moth_train_files.extend(file_data[:train_amt])
            moth_val_files.extend(file_data[train_amt:train_amt+val_amt])
            moth_test_files.extend(file_data[train_amt+val_amt:])   # test set takes the remainder

# Build each split's rows in a single call. Appending once per species re-copied
# the growing table every time, which is quadratic.
train_data = prepare_split_list(train_data, moth_train_files, fields, 'moth')
val_data   = prepare_split_list(val_data, moth_val_files, fields, 'moth')
test_data  = prepare_split_list(test_data, moth_test_files, fields, 'moth')

moth_train_pts = len(train_data)
print('No. of moth training points: ', moth_train_pts)

No. of moth training points:  207420


Splitting the non-moth images into train, validation and test sets (separately for each order)


In [5]:
# Non-moth images sit one level deep (<order>/<image>.jpg), and each order is
# split independently. Listings are sorted so that the shuffle, and therefore
# the split, is reproducible once `random` is seeded.
nonmoth_train_files, nonmoth_val_files, nonmoth_test_files = [], [], []

for order in sorted(os.listdir(nonmoth_data_dir)):
    order_dir = os.path.join(nonmoth_data_dir, order)
    # skip metadata files, notebook checkpoints and any other non-directory entries
    if order.endswith('.csv') or order.endswith('.ipynb_checkpoints') or not os.path.isdir(order_dir):
        continue

    file_data = sorted(glob.glob(os.path.join(order_dir, '*.jpg')))
    random.shuffle(file_data)

    total     = len(file_data)
    train_amt = round(total*TRAIN_SPLIT)
    val_amt   = round(total*VAL_SPLIT)

    nonmoth_train_files.extend(file_data[:train_amt])
    nonmoth_val_files.extend(file_data[train_amt:train_amt+val_amt])
    nonmoth_test_files.extend(file_data[train_amt+val_amt:])   # test set takes the remainder

# one table build per split instead of one append per order
train_data = prepare_split_list(train_data, nonmoth_train_files, fields, 'nonmoth')
val_data   = prepare_split_list(val_data, nonmoth_val_files, fields, 'nonmoth')
test_data  = prepare_split_list(test_data, nonmoth_test_files, fields, 'nonmoth')

# train_data already holds the moth rows, so subtract them to count non-moth rows
nonmoth_train_pts = len(train_data)-moth_train_pts
print('No. of non-moth training points: ', nonmoth_train_pts)

No. of non-moth training points:  158123


In [6]:
# shuffling and saving the lists to disk
#
# Shuffling interleaves the moth and non-moth rows within each split. The pandas
# shuffle is seeded from the stdlib `random` generator, which also drives the
# per-class random.shuffle calls. Seeding `random` therefore makes this shuffle
# reproducible as well; left unseeded, it is as random as before.
train_data = train_data.sample(frac=1, random_state=random.getrandbits(32)).reset_index(drop=True)
val_data   = val_data.sample(frac=1, random_state=random.getrandbits(32)).reset_index(drop=True)
test_data  = test_data.sample(frac=1, random_state=random.getrandbits(32)).reset_index(drop=True)

# to_csv does not create a missing output directory
os.makedirs(write_dir, exist_ok=True)

train_data.to_csv(write_dir + '01-train_split.csv', index=False)
val_data.to_csv(write_dir + '01-val_split.csv', index=False)
test_data.to_csv(write_dir + '01-test_split.csv', index=False)

print('No. of total training points: ', len(train_data))
print('No. of total validation points: ', len(val_data))
print('No. of total testing points: ', len(test_data))

No. of total training points:  365543
No. of total validation points:  48736
No. of total testing points:  73119


In [7]:
# unique entry test
#
# Only the bare filename is stored, not the full path. A repeated filename may
# therefore be the same image filed under more than one taxon, or two distinct
# images that happen to share a name. Inspect `train_dup_rows` to tell which.
filelist = list(train_data['filename'])
print(len(filelist))
filelist = set(filelist)
print(len(filelist))

# the training rows whose filename occurs more than once, grouped together for inspection
train_dup_rows = train_data[train_data['filename'].duplicated(keep=False)].sort_values('filename')
print('No. of training rows with a shared filename: ', len(train_dup_rows))

# a filename present in several splits may leak training images into validation / test
print('Filenames shared by train and val : ', len(filelist & set(val_data['filename'])))
print('Filenames shared by train and test: ', len(filelist & set(test_data['filename'])))

365543
365360
