<a href="https://colab.research.google.com/github/SridharSola/FER/blob/main/FERPlusClass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import pandas as pd
import cv2
import random

class FplusDataSet(data.Dataset):
    """FER+ (FERPlus) facial-expression dataset for PyTorch.

    Image names and labels are read from a space-separated, header-less list
    file (column 0 = image name without extension, column 1 = label). Images
    are loaded as ``<fplus_path>/<split dir>/<name>.png``. Labels are remapped
    to AffectNet's class ordering via
    :meth:`change_emotion_label_same_as_affectnet`.

    Args:
        fplus_path: Root directory containing the ``FER2013TrainValid`` and
            ``FER2013Test`` image folders.
        partition: ``'train'`` selects the train/valid list and image folder;
            any other value selects the test split.
        transform: Optional callable applied to the RGB image array returned
            by OpenCV. When ``None`` (default), a pipeline is built that
            resizes to 224x224 and applies ImageNet normalisation. For the
            ``'train'`` partition it also adds random flip, colour-jitter and
            affine augmentation.
        num_classes: Number of emotion classes. Stored on the instance but not
            otherwise used in this class.
        train_list: Path of the train/valid list file. ``None`` uses
            ``DEFAULT_TRAIN_LIST`` (the original hard-coded Colab path).
        test_list: Path of the test list file. ``None`` uses
            ``DEFAULT_TEST_LIST``.
    """

    DEFAULT_TRAIN_LIST = '/content/drive/MyDrive/FERPLUS/ferplus_trainvalid_list.txt'
    DEFAULT_TEST_LIST = '/content/drive/MyDrive/FERPLUS/ferplus_test.txt'

    NAME_COLUMN = 0
    LABEL_COLUMN = 1

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # FER+ -> AffectNet label swaps; labels not listed are kept unchanged.
    _AFFECTNET_LABEL_SWAP = {2: 3, 3: 2, 4: 6, 6: 4}

    def __init__(self, fplus_path, partition='train', transform=None, num_classes=8,
                 train_list=None, test_list=None):
        self.fplus_path = fplus_path
        self.num_classes = num_classes
        self.partition = partition
        is_train = partition == 'train'

        # A caller-supplied transform wins. The original code always overwrote it.
        self._default_normalization = transform is None
        self.transform = transform if transform is not None else self._default_transform(is_train)

        if is_train:
            list_path = self.DEFAULT_TRAIN_LIST if train_list is None else train_list
            image_dir = os.path.join(fplus_path, 'FER2013TrainValid')
        else:
            list_path = self.DEFAULT_TEST_LIST if test_list is None else test_list
            image_dir = os.path.join(fplus_path, 'FER2013Test')

        # Only the list for the requested split is read.
        df = pd.read_csv(list_path, sep=' ', header=None)

        self.label = [
            int(self.change_emotion_label_same_as_affectnet(raw))
            for raw in df.iloc[:, self.LABEL_COLUMN].values
        ]
        # Kept as an alias of ``label`` for backward compatibility.
        self.new_label = self.label

        # FER2013 image files for the selected split.
        self.file_paths = [
            os.path.join(image_dir, f'{name}.png')
            for name in df.iloc[:, self.NAME_COLUMN].values
        ]

    @classmethod
    def _default_transform(cls, is_train):
        """Build the default torchvision pipeline (augmented when ``is_train``)."""
        finish = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(cls.IMAGENET_MEAN), std=list(cls.IMAGENET_STD)),
        ]
        if not is_train:
            return transforms.Compose([transforms.ToPILImage()] + finish)
        augment = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(p=0.5),
            # Jitter and affine are applied together, with probability 0.5.
            transforms.RandomApply([
                transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
                # ``resample=Image.BILINEAR`` is deprecated/removed and ``Image`` was never imported.
                transforms.RandomAffine(degrees=0, translate=(.1, .1), scale=(1.0, 1.25),
                                        interpolation=transforms.InterpolationMode.BILINEAR),
            ], p=0.5),
        ]
        return transforms.Compose(augment + finish)

    def change_emotion_label_same_as_affectnet(self, emo_to_return):
        """
        Parse labels to make them compatible with AffectNet.
        #https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/blob/master/model/utils/udata.py

        Swaps 2 <-> 3 and 4 <-> 6; every other label is returned unchanged.
        """
        return self._AFFECTNET_LABEL_SWAP.get(emo_to_return, emo_to_return)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = self.file_paths[idx]
        image = cv2.imread(path)
        if image is None:  # cv2.imread signals failure by returning None.
            raise FileNotFoundError(f'Could not read image: {path}')
        # OpenCV loads BGR. The ImageNet statistics and colour jitter assume RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            image = self.transform(image)

        return image, self.label[idx]

    def show_img(self, index):
        """Plot sample ``index`` and print its label.

        Relies on ``plt`` (matplotlib.pyplot) being imported elsewhere in the
        notebook. Assumes the transform returns a CHW tensor.
        """
        image, label = self.__getitem__(index)
        image = image.permute(1, 2, 0)
        if getattr(self, '_default_normalization', False):
            # Undo ImageNet normalisation so imshow shows real colours instead of clipping.
            mean = torch.tensor(self.IMAGENET_MEAN)
            std = torch.tensor(self.IMAGENET_STD)
            image = (image * std + mean).clamp(0, 1)
        plt.axis("off")
        plt.imshow(image)
        print('Label:', label)