In [None]:
## Pet Images Dataset

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from skimage.transform import resize
from torch.utils.data import Dataset, DataLoader
import matplotlib.image as img
import cv2
from sklearn.model_selection import train_test_split

import torchvision.models as models

In [None]:
# Fix every random seed the notebook draws from. Python's `random` drives the
# file-list shuffle, numpy's global RNG backs pandas `.sample` without a
# random_state, and torch seeds PyTorch-side randomness (e.g. DataLoader shuffling).
sd = 0
for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    _seed_fn(sd)
# Ask cuDNN for deterministic kernels (run-to-run stable, possibly slower).
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(sd)

In [None]:
#Split dataset
def equalize_all(dct):
    """Truncate every list in ``dct`` to the length of the shortest one.

    Returns a new dict with the same keys (same order) whose values are the
    leading slices ``dct[k][:shortest]``; the input dict is left untouched.
    Raises ValueError (from ``min``) when ``dct`` is empty.
    """
    shortest = min(len(items) for items in dct.values())
    return {key: items[:shortest] for key, items in dct.items()}

def test_lengths(dct):
    lens = []
    for k in dct.keys():
        lens.append(len(dct[k]))
    
    return lens


# Label encodings. PetDataset uses the animal mapping as the training target;
# the colour mapping is not referenced in the code shown here.
dict_animal_to_number = {'cats' : 0,'dogs': 1}
dict_color_to_number = {'dark' : 0,'light': 1}

# Images are expected under <data_path>/<animal>/<colour>/.
data_path = r'./data'
animal_list = ['cats','dogs']

col_list = ['dark','light']

sets = ['train','test']

# File lists per (colour, animal) group, keyed by colour+animal (e.g. 'darkcats').
train_data_dict = {}
val_data_dict = {}
test_data_dict = {}

for col in col_list:
    for animal in animal_list:
        f_path = data_path+'/'+animal+'/'+col+'/'
        # Sort before the seeded shuffle: os.listdir order is filesystem
        # dependent, so without sorting the split would differ between machines.
        all_f = [f_path + s for s in sorted(os.listdir(f_path))]
        random.shuffle(all_f)
        l_temp = len(all_f)

        # 75% train / 12.5% validation / 12.5% test.
        train_end = int(0.75*l_temp)
        val_end = int(0.875*l_temp)
        train_data_dict[col+animal] = all_f[:train_end]
        val_data_dict[col+animal] = all_f[train_end:val_end]
        test_data_dict[col+animal] = all_f[val_end:]

# Balance the groups inside each split by truncating to the smallest group.
train_data_dict = equalize_all(train_data_dict)
val_data_dict = equalize_all(val_data_dict)
test_data_dict = equalize_all(test_data_dict)

# Print per-group sizes for train, validation and test.
for split_dict in (train_data_dict, val_data_dict, test_data_dict):
    l1 = test_lengths(split_dict)
    print(l1)

In [None]:
#DataLoader
class PetDataset(Dataset):
    """Map-style dataset over a DataFrame of (file, colour, animal) rows.

    Each item is the largest even-sized centred square crop of the image,
    resized to 256x256 and converted from OpenCV's BGR channel order to RGB,
    optionally passed through ``transform``, together with the integer
    animal label (via ``dict_animal_to_number``) as a tensor.

    Args:
        data: DataFrame whose columns are, positionally, the file path, the
            colour group and the animal name (as built by
            ``resample_dataset_color`` / ``split_two``).
        path: data root; stored but not used when loading, since the file
            column already holds the full path.
        transform: optional callable applied to the RGB image array.
    """

    def __init__(self, data, path , transform = None):
        super().__init__()
        self.data = data.values
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Column 0 is the file path, column 2 the animal name (the target).
        img_path = self.data[index][0]
        label = torch.tensor(dict_animal_to_number[self.data[index][2]])

        image0 = cv2.imread(img_path)
        # cv2.imread returns None instead of raising when the file is missing,
        # unreadable or in an unsupported format; fail with the file name.
        if image0 is None:
            raise OSError('Could not read image: {}'.format(img_path))
        h, w = image0.shape[:2]

        # Largest even-sized centred square crop.
        mini = min(h, w)
        mini -= mini % 2
        top = (h - mini) // 2
        left = (w - mini) // 2
        image1 = image0[top:top + mini, left:left + mini, :]

        image = cv2.resize(image1, (256, 256))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Per-split transforms. All three currently only apply ToTensor to the RGB
# array from PetDataset; they are separate objects so that augmentation can
# later be added to the training transform alone.
train_transform, test_transform, valid_transform = (
    transforms.Compose([transforms.ToTensor()]) for _ in range(3)
)

In [None]:
#Functions to resample and equalize datasets
def resample_dataset_color(data,frac,protected_var,task_var):
    """Build a shuffled training DataFrame with a chosen protected-group mix.

    For the first protected group the leading ``frac`` share of every
    ``data[pr + ta]`` list is kept; for the second group the leading
    ``1 - frac`` share. Only two protected groups are supported (a third
    would index past the two-element fraction list).

    Args:
        data: dict mapping ``protected + task`` keys (e.g. 'darkcats') to
            lists of file paths.
        frac: fraction in [0, 1] taken from the first protected group.
        protected_var: protected-attribute values, e.g. ``['dark', 'light']``.
        task_var: task-label values, e.g. ``['cats', 'dogs']``.

    Returns:
        DataFrame with columns 'file', 'color', 'animal', rows shuffled with
        ``DataFrame.sample`` (numpy global RNG, no explicit random_state).

    Raises:
        ValueError: if ``frac`` is outside [0, 1].
    """
    if not 0.0 <= frac <= 1.0:
        raise ValueError('frac must be in [0, 1], got {}'.format(frac))
    # Fractions relative to the first protected group.
    frac_list = [frac, 1 - frac]

    file_name = []
    protected_trait = []
    task_trait = []
    for ix, pr in enumerate(protected_var):
        for ta in task_var:
            list_temp = data[pr + ta]
            # Round before truncating: fractions coming from np.linspace
            # (e.g. 1 - 0.6000000000000001) would otherwise floor 39.999... to 39.
            n_keep = int(round(frac_list[ix] * len(list_temp), 6))
            kept = list_temp[:n_keep]
            file_name.extend(kept)
            protected_trait.extend([pr] * len(kept))
            task_trait.extend([ta] * len(kept))

    dat_split = pd.DataFrame({'file': file_name, 'color': protected_trait, 'animal': task_trait})
    return dat_split.sample(frac=1)

def split_two(data,protected_var,task_var):
    """Build one shuffled DataFrame per protected group (used for val/test).

    For each protected value ``pr``, every file of every ``data[pr + ta]``
    list is collected into a DataFrame with columns 'file', 'color',
    'animal', and its rows are shuffled with ``DataFrame.sample`` (numpy
    global RNG). Every group is shuffled in order, so RNG consumption does
    not depend on which frames are returned.

    Returns:
        ``(frame for protected_var[0], frame for protected_var[-1])`` — with
        ``['dark', 'light']`` that is (dark, light).

    Raises:
        ValueError: if fewer than two protected groups are given.
    """
    frames = []
    for pr in protected_var:
        file_name = []
        task_trait = []
        for ta in task_var:
            files = data[pr + ta]
            file_name.extend(files)
            task_trait.extend([ta] * len(files))
        df = pd.DataFrame({'file': file_name, 'color': [pr] * len(file_name), 'animal': task_trait})
        frames.append(df.sample(frac=1))

    if len(frames) < 2:
        raise ValueError('split_two needs at least two protected groups, got {}'.format(len(frames)))
    return frames[0], frames[-1]

In [None]:
# Sanity check of the resampling at a 50/50 dark/light mix (not used for training).
newD2 = resample_dataset_color(train_data_dict,0.5,col_list,animal_list)
print(newD2['color'].value_counts())
print(newD2)
batch_size = 30

# Validation uses only the light-coloured frame; val_split_D is not used below.
# NOTE(review): model selection is therefore driven by light images only — confirm intended.
val_split_D, val_split_L = split_two(val_data_dict, col_list, animal_list)
val_split = val_split_L

# Test data is kept separated by colour so each group is scored on its own.
test_split_D, test_split_L = split_two(test_data_dict, col_list, animal_list)

# Class balance of the validation and the two test frames.
for split_frame in (val_split, test_split_D, test_split_L):
    for column in ('color', 'animal'):
        print(split_frame[column].value_counts())

# Evaluation datasets and loaders (fixed order, no shuffling).
valid_data = PetDataset(val_split, data_path, valid_transform)
test_data_D = PetDataset(test_split_D, data_path, test_transform)
test_data_L = PetDataset(test_split_L, data_path, test_transform)
eval_loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=0)
valid_loader = DataLoader(dataset=valid_data, **eval_loader_kwargs)
test_loader_D = DataLoader(dataset=test_data_D, **eval_loader_kwargs)
test_loader_L = DataLoader(dataset=test_data_L, **eval_loader_kwargs)

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
def test_performance(model,dataL,criterion):
    """Evaluate ``model`` over every batch of ``dataL`` and print loss/accuracy.

    Runs on the module-level ``device`` with gradients disabled.

    Args:
        model: classifier whose output is indexed as (batch, classes); the
            prediction is the argmax over dim 1.
        dataL: DataLoader yielding ``(inputs, integer targets)`` batches.
        criterion: loss module, e.g. ``nn.CrossEntropyLoss()``.

    Returns:
        ``(summary_string, accuracy)``. Loss and accuracy are both averaged
        over ``len(dataL.sampler)`` samples.
    """
    model.eval()
    model.to(device)

    test_loss = 0.0
    n_correct = 0
    # Evaluation needs no autograd graph; without no_grad every forward pass
    # would keep activations alive for a backward pass that never happens.
    with torch.no_grad():
        for data, target in dataL:
            data = data.to(device)
            target = target.to(device)

            output = model(data)

            # Assuming a mean-reduced criterion (nn.CrossEntropyLoss default),
            # weight by batch size so the final division is a per-sample mean.
            test_loss += criterion(output, target).item() * data.size(0)
            n_correct += (output.argmax(dim=1) == target).sum().item()

    n_samples = len(dataL.sampler)
    ttacc = n_correct / n_samples
    test_loss_M = test_loss / n_samples

    test_print = 'Test Loss: {:.3f} \tTest Acc: {:.3f}'.format(
        test_loss_M, ttacc)

    print(test_print)
    return test_print, ttacc

def write_file(fname,string,act):
    """Write ``string`` followed by a newline to ``fname`` opened with mode ``act`` ('a' appends, 'w' truncates)."""
    with open(fname, act) as handle:
        handle.writelines([string, '\n'])

In [None]:
#Combined iterations across minority training ratios
def train_model(dark_frac,train_data_dict,valid_loader,test_loader_D,test_loader_L):
    """Fine-tune a 2-class ResNet-34 on a training set with a given colour mix.

    The training frame is drawn from ``train_data_dict`` by
    ``resample_dataset_color`` so that ``dark_frac`` of it comes from the
    first colour group (``col_list[0]``). Starting from the weights stored in
    ``resnet34_imp_2class.pt``, the model is trained for ``num_epochs``
    epochs. After each epoch it is scored on ``valid_loader`` and on the dark
    and light test loaders.

    Outputs go to ``checkpoints/<col_list[0]><dark_frac>/``:
    - the weights with the best validation accuracy, as ``model_best.pt``;
    - per-epoch metrics, appended to ``log.txt``;
    - the validation and test accuracy histories, as ``.npy`` files.

    Relies on module-level ``col_list``, ``animal_list``, ``data_path``,
    ``train_transform`` and ``device``.
    """
    num_epochs = 60
    num_classes = 2  # for animals
    batch_size = 30
    learning_rate = 0.0006

    check_point_dir = col_list[0]+str(dark_frac)
    out_dir = "checkpoints/" + check_point_dir

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
        print("Output directory is created")

    # Start every run with a fresh log file (a missing file is fine).
    text_path = out_dir + "/" + "log.txt"
    try:
        os.remove(text_path)
    except OSError:
        pass

    write_file(text_path,'*********'+col_list[0]+'  fraction: {} *********'.format(dark_frac),'a')

    train_split = resample_dataset_color(train_data_dict,dark_frac,col_list,animal_list)
    print(train_split['color'].value_counts())
    print(train_split['animal'].value_counts())
    write_file(text_path,str(train_split['color'].value_counts()),'a')
    write_file(text_path,str(train_split['animal'].value_counts()),'a')

    # Dataloaders
    train_data = PetDataset(train_split, data_path, train_transform)
    train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle=True, num_workers=0)

    model = models.resnet34(pretrained=False)
    model.fc = nn.Linear(512, num_classes)
    # Load onto the CPU first so a checkpoint saved on a GPU also loads on a
    # CPU-only machine; the model is moved to `device` right after.
    model.load_state_dict(torch.load("resnet34_imp_2class.pt", map_location='cpu'))
    model.to(device)
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.5, 0.999),
        weight_decay=0.05
        )

    # NOTE(review): T_max=30 with 60 epochs means the cosine schedule reaches
    # eta_min at epoch 30 and then rises back towards the initial LR — confirm intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=30,
        eta_min=0.01 * learning_rate, verbose=True
        )

    train_losses = []
    valid_losses = []
    train_accuracies = []
    val_accuracies = []

    valid_accuracy = []
    test_accuracy_light = []
    test_accuracy_dark = []

    best_val_acc = 0

    print("Training model...")
    for epoch in range(1, num_epochs+1):
        train_loss = 0.0
        valid_loss = 0.0

        # Training pass.
        model.train()
        train_correct = 0
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size for a per-sample epoch mean.
            train_loss += loss.item() * data.size(0)
            train_correct += (output.argmax(dim=1) == target).sum().item()

        # Validation pass: no gradients needed, so skip building the autograd graph.
        model.eval()
        val_correct = 0
        with torch.no_grad():
            for data, target in valid_loader:
                data = data.to(device)
                target = target.to(device)

                output = model(data)
                valid_loss += criterion(output, target).item() * data.size(0)
                val_correct += (output.argmax(dim=1) == target).sum().item()

        # Per-sample averages over the whole epoch.
        train_loss = train_loss/len(train_loader.sampler)
        valid_loss = valid_loss/len(valid_loader.sampler)
        ttacc = train_correct/len(train_loader.sampler)
        tvacc = val_correct/len(valid_loader.sampler)

        # Keep the weights with the best validation accuracy seen so far.
        if tvacc > best_val_acc:
            best_val_acc = tvacc
            torch.save(model.state_dict(), out_dir + "/model_best.pt")
            print('Model saved')
            write_file(text_path,'Model saved','a')

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        train_accuracies.append(ttacc)
        val_accuracies.append(tvacc)

        scheduler.step()

        train_print = 'Epoch: {} \tTr Loss: {:.3f} \tTr Acc: {:.3f} \tVal Loss: {:.3f} \tVal Acc: {:.3f}'.format(
            epoch, train_loss, ttacc, valid_loss, tvacc)
        print(train_print)

        # Score each colour-specific test set every epoch.
        test_print_D, ttacc_D = test_performance(model,test_loader_D,criterion)
        test_print_L, ttacc_L = test_performance(model,test_loader_L,criterion)
        valid_accuracy.append(tvacc)
        test_accuracy_dark.append(ttacc_D)
        test_accuracy_light.append(ttacc_L)

        write_file(text_path,train_print,'a')
        write_file(text_path,test_print_D,'a')
        write_file(text_path,test_print_L,'a')

    # Per-epoch accuracy histories (np.save appends the .npy extension).
    np.save(out_dir + "/" + "validation_accuracy", np.array(valid_accuracy))
    np.save(out_dir + "/" + "test_accuracy_D", np.array(test_accuracy_dark))
    np.save(out_dir + "/" + "test_accuracy_L", np.array(test_accuracy_light))


# Sweep the share of training images taken from the first colour group
# (col_list[0]) from 0.0 to 1.0 in steps of 0.1, training one model each.
dark_fracs = np.linspace(start=0.0, stop=1.0, num=11)
for dark_frac in dark_fracs:
    banner = '********* dark fraction: {} *********'.format(dark_frac)
    print(banner)
    train_model(dark_frac, train_data_dict, valid_loader, test_loader_D, test_loader_L)