<a href="https://colab.research.google.com/github/SridharSola/FER/blob/main/RAFDBClass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Imports
import torch
import torchvision
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.utils.data
import torch.optim
import os #for creating and removing directories
import torch.utils.data as data
import matplotlib.pyplot as plt #to view the image as images are defaulted to objects of Pillow lib in Python
%matplotlib inline 
import torchvision.transforms as transforms #to convert images to tensors
import argparse
from PIL import Image
import pandas as pd
import cv2
from torch.utils.data import random_split
from IPython import display

# ---- Dataset class ----
# Location of the RAF-DB data under the Colab Google Drive mount
# (NOTE(review): adjust if the dataset lives elsewhere).
rafdb_root = "/content/drive/MyDrive/RAFDB"

class RafDataset(data.Dataset):
    """RAF-DB facial-expression dataset with optional masked-face images.

    Labels are read eagerly in ``__init__``; images are read lazily in
    ``__getitem__``.
    """

    # ImageNet normalisation statistics used by the default transforms.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, root, mask_file, partition='train', mask=0, transform=None, num_classes=7):
        """
        Args:
            root: directory holding ``train_label.txt``, ``test_label.txt`` and
                the ``aligned`` folder of unmasked images.
            mask_file: directory holding the ``aligned_mask`` folder of masked
                images.
            partition: ``'train'`` selects the training split (with random
                augmentation); any other value selects the test split.
            mask: 0 -> unmasked images only, 1 -> masked images only,
                2 -> both (unmasked and masked copies interleaved per file).
            transform: optional callable applied to the RGB numpy image read
                with OpenCV. If None, a default pipeline for ``partition`` is used.
            num_classes: number of expression classes (7 in RAF-DB).
                Labels are the same for masked and unmasked images.

        Raises:
            ValueError: if ``mask`` is not 0, 1 or 2.
        """
        # Validate up front: any other value used to yield an empty dataset silently.
        if mask not in (0, 1, 2):
            raise ValueError(f"mask must be 0, 1 or 2, got {mask!r}")

        self.root = root
        self.mask_file = mask_file
        self.mask = mask
        self.num_classes = num_classes
        self.transform = transform if transform is not None else self._default_transform(partition)

        NAME_COLUMN = 0
        LABEL_COLUMN = 1
        # Only the label file for the requested split is needed.
        label_file = 'train_label.txt' if partition == 'train' else 'test_label.txt'
        labels_df = pd.read_csv(os.path.join(self.root, label_file), sep=' ', header=None)
        # Shift to 0-based: 0:Surprise, 1:Fear, 2:Disgust, 3:Happiness, 4:Sadness, 5:Anger, 6:Neutral
        self.label = labels_df.iloc[:, LABEL_COLUMN].values - 1
        file_names = labels_df.iloc[:, NAME_COLUMN].values

        self.file_paths = self._build_file_paths(file_names, self.root, self.mask_file, mask)

    @classmethod
    def _default_transform(cls, partition):
        """Return the default transform pipeline for ``partition``."""
        normalize = transforms.Normalize(mean=cls._MEAN, std=cls._STD)
        if partition != 'train':
            return transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                normalize,
            ])
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
                # `resample=` was removed from torchvision; `interpolation=` replaces it.
                transforms.RandomAffine(degrees=0, translate=(.1, .1), scale=(1.0, 1.25),
                                        interpolation=transforms.InterpolationMode.BILINEAR),
            ], p=0.5),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ])

    @staticmethod
    def _build_file_paths(file_names, root, mask_file, mask):
        """Map label-file names (e.g. ``x.jpg``) to ``<dir>/x_aligned.jpg`` paths.

        With ``mask == 2`` each name contributes two consecutive paths:
        unmasked first, then masked.
        """
        unmasked_dir = os.path.join(root, 'aligned')
        masked_dir = os.path.join(mask_file, 'aligned_mask')
        paths = []
        for name in file_names:
            aligned_name = name.split(".")[0] + "_aligned.jpg"
            if mask in (0, 2):
                paths.append(os.path.join(unmasked_dir, aligned_name))
            if mask in (1, 2):
                paths.append(os.path.join(masked_dir, aligned_name))
        return paths

    def _label_index(self, index):
        """Return the row in ``self.label`` for sample ``index``."""
        # With mask == 2, paths are interleaved (unmasked, masked) per file,
        # so both copies share label index // 2.
        return index // 2 if self.mask == 2 else index

    def __getitem__(self, index):
        """Read, convert to RGB, and transform image ``index``.

        Returns:
            (image, label): the transformed image and its 0-based label.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_path = self.file_paths[index]
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # OpenCV loads BGR; the transforms and ImageNet stats expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.transform(img)
        return img, self.label[self._label_index(index)]

    def __len__(self):
        return len(self.file_paths)

    def show_img(self, index):
        """Display sample ``index`` with matplotlib and print its label.

        NOTE: the image is shown as produced by ``self.transform``. With the
        default transforms it is still normalised, so colours look shifted and
        matplotlib clips values outside [0, 1].
        """
        image, label = self[index]
        plt.axis("off")
        plt.imshow(image.permute(1, 2, 0))  # CHW tensor -> HWC for imshow
        print('Label:', label)

# End of RafDataset

def get_default_device():
    """Return the CUDA device when a GPU is visible to torch, otherwise the CPU."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
    
def to_device(data, device):
    """Move ``data`` to ``device``, recursing into lists and tuples.

    Nested lists/tuples come back as lists; any other object must provide
    ``.to(device, non_blocking=...)`` (e.g. a tensor).
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader():
    """Wrap a dataloader so every batch it yields is moved to ``device``.

    (This replaces a copy whose last line was fused with a duplicated
    ``get_default_device`` / ``to_device`` / ``DeviceDataLoader`` cell,
    which was a SyntaxError; those helpers are already defined above.)
    """

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to the device."""
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)