<a href="https://colab.research.google.com/github/SridharSola/Knowledge-Distillation-FER/blob/main/RAFDBClass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Imports
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.utils.data
import torch.optim
import os #for creating and removing directories
import torch.utils.data as data
import matplotlib.pyplot as plt #to view the image as images are defaulted to objects of Pillow lib in Python
%matplotlib inline 
import torchvision.transforms as transforms
#import torchvision.transforms.functional as TF #to convert images to tensors
import argparse
from PIL import Image
import pandas as pd
import cv2
from torch.utils.data import random_split
from IPython import display
import random
import albumentations as A #For applying same transforms
from torchvision.datasets.folder import default_loader
import random
# Root folder of the RAF-DB dataset on the mounted Google Drive.
rafdb_root = '/content/drive/MyDrive/RAFDB'
# Number of images in the RAF-DB train and test partitions.
rafdb_train_length, rafdb_test_length = 12271, 3068

class RandomChoice(torch.nn.Module):
    """Apply one transform chosen at random from a list, re-drawn on every call.

    Args:
        transforms: sequence of callables to choose from. If omitted, a
            subclass must provide ``transforms`` as a class attribute
            (otherwise attribute lookup raises AttributeError, as before).
            Entries are stored in a plain list, so ``nn.Module`` entries are
            not registered as submodules.

    Raises:
        ValueError: if the list of transforms is empty.
    """

    def __init__(self, transforms=None):
        super().__init__()
        if transforms is not None:
            self.transforms = list(transforms)
        if not self.transforms:
            raise ValueError("RandomChoice needs at least one transform")
        # Most recently chosen transform; kept as an attribute for callers
        # that inspect it.
        self.t = random.choice(self.transforms)

    def forward(self, img):
        """Draw a fresh transform for this call and apply it to ``img``."""
        # Drawing per call (rather than once in __init__) is what makes the
        # augmentation random across samples.
        self.t = random.choice(self.transforms)
        return self.t(img)

class KDRafDataset(data.Dataset):
    """RAF-DB dataset that yields pairs of aligned face images of the same sample.

    By default each item pairs the unmasked image from ``root/aligned`` with
    the masked image from ``mask_file/aligned_mask``. ``UU=True`` pairs
    unmasked with unmasked, and ``MM=True`` pairs masked with masked. Both
    images share one emotion label.

    Labels are read eagerly in ``__init__``. Images are read lazily in
    ``__getitem__``.
    """

    # Normalization statistics used by every tensor pipeline in this class.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, mask_file, partition='train', transform=None,
                 num_classes=7, loader=default_loader, UU=False, MM=False):
        """
        Args:
            root: directory holding ``train_label.txt``, ``test_label.txt``
                and the unmasked images in ``aligned/``.
            mask_file: directory holding the masked images in ``aligned_mask/``.
            partition: ``'train'`` enables paired augmentation and random
                pair-swapping (``rand=0.5``). Any other value reads
                ``test_label.txt`` and disables both.
            transform: NOTE: ignored. The dataset always uses its own
                resize/normalize pipeline (kept for signature compatibility).
            num_classes: number of emotion classes. Stored only; not used by
                this class.
            loader: stored only. Images are read with OpenCV instead.
            UU: if true, both images of a pair are unmasked.
            MM: if true (and ``UU`` is false), both images are masked.
        """
        self.root = root
        self.mask_file = mask_file
        self.num_classes = num_classes
        self.partition = partition
        self.loader = loader
        # Probability that __getitem__ returns the first image in both slots.
        self.rand = 0.5 if partition == 'train' else 0
        self.transform = self._tensor_pipeline()

        if partition == 'train':
            # ``image1`` is registered as an extra image target, so the same
            # random augmentation parameters are applied to both images.
            # NOTE(review): A.RandomContrast and A.IAAAffine are deprecated or
            # removed in newer albumentations releases; pin the version or
            # migrate to A.RandomBrightnessContrast / A.Affine.
            self.transform1 = A.Compose(
                [
                    A.HorizontalFlip(),
                    A.HueSaturationValue(),
                    A.RandomContrast(),
                    A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1,
                                       rotate_limit=3, p=0.5),
                    A.IAAAffine(scale=(1.0, 1.25), rotate=0.0, p=0.5),
                ],
                additional_targets={'image1': 'image'},
            )
            self.transform2 = self._tensor_pipeline()

        self.label, file_names = self._read_labels(root, partition)

        first_dir, second_dir = self._image_dirs(UU, MM)
        # Despite their names, these lists hold the first/second image paths
        # of each pair, so with UU or MM both lists point to the same kind.
        self.unmasked_file_paths = []
        self.masked_file_paths = []
        self.dataset = []
        for name in file_names:
            # "<stem>.<ext>" in the label file -> "<stem>_aligned.jpg" on disk.
            aligned_name = os.path.splitext(name)[0] + '_aligned.jpg'
            upath = os.path.join(first_dir, aligned_name)
            mpath = os.path.join(second_dir, aligned_name)
            self.unmasked_file_paths.append(upath)
            self.masked_file_paths.append(mpath)
            self.dataset.append((upath, mpath))

    @classmethod
    def _tensor_pipeline(cls):
        """Build the RGB numpy image -> normalized 224x224 tensor pipeline."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=cls._MEAN, std=cls._STD),
        ])

    @staticmethod
    def _read_labels(root, partition):
        """Return (0-based labels, image file names) for the given partition."""
        name = 'train_label.txt' if partition == 'train' else 'test_label.txt'
        table = pd.read_csv(os.path.join(root, name), sep=' ', header=None)
        # Column 0 is the file name and column 1 the label. The file's labels
        # start at 1, so shift them to:
        # 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
        labels = table.iloc[:, 1].values - 1
        file_names = table.iloc[:, 0].values
        return labels, file_names

    def _image_dirs(self, UU, MM):
        """Return the (first, second) image directories for the requested pairing."""
        unmasked = os.path.join(self.root, 'aligned')
        masked = os.path.join(self.mask_file, 'aligned_mask')
        if UU:
            return unmasked, unmasked
        if MM:
            return masked, masked
        return unmasked, masked

    @staticmethod
    def _read_rgb(path):
        """Read an image with OpenCV and convert it from BGR to RGB.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Load, transform and return one pair.

        Returns:
            ``(first_img, second_img, label, first_path, second_path)``. With
            probability ``self.rand`` the first image is returned in both
            image slots; the two paths are always the real ones.
        """
        uimg_path, mimg_path = self.dataset[index]
        uimg = self._read_rgb(uimg_path)
        mimg = self._read_rgb(mimg_path)

        if self.partition == 'train':
            augmented = self.transform1(image=uimg, image1=mimg)
            uimg = self.transform2(augmented['image'])
            mimg = self.transform2(augmented['image1'])
        else:
            uimg = self.transform(uimg)
            mimg = self.transform(mimg)

        label = self.label[index]
        if random.uniform(0, 1) < self.rand:
            return uimg, uimg, label, uimg_path, mimg_path
        return uimg, mimg, label, uimg_path, mimg_path

    def change_rand(self, new_rand):
        """Set the probability of returning the first image in both slots."""
        self.rand = new_rand

    def __len__(self):
        return len(self.dataset)

    @classmethod
    def _to_displayable(cls, img):
        """Undo the normalization and turn a CxHxW tensor into an HxWxC array in [0, 1]."""
        mean = torch.tensor(cls._MEAN).view(-1, 1, 1)
        std = torch.tensor(cls._STD).view(-1, 1, 1)
        return (img * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    def show_img(self, index):
        """Plot the pair at ``index`` side by side, then print its first path and label."""
        uimg, mimg, label, uimg_path, _ = self.__getitem__(index)

        fig = plt.figure()
        for position, img in enumerate((uimg, mimg), start=1):
            fig.add_subplot(1, 2, position)
            # Undo the normalization first; otherwise imshow clips the values.
            plt.imshow(self._to_displayable(img))
        plt.show(block=True)
        print(uimg_path)
        print("Label: ", label)

#End of Class
