<a href="https://colab.research.google.com/github/SridharSola/Knowledge-Distillation-FER/blob/main/FERPlusClass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
class KDFplusDataSet(data.Dataset):
    """Paired FER+ dataset of unmasked / masked face images.

    Each sample pairs an image from the FER2013 train/valid (or test) folder
    under ``f1`` with the image of the same file name in the matching
    ``withmask_*`` folder under ``f2``. Labels come from a space-separated,
    header-less list file (column 0 = file name without ``.png``, column 1 =
    label) and are remapped to the AffectNet class order.

    Args:
        f1: Root directory containing ``FER2013TrainValid`` / ``FER2013Test``.
        f2: Root directory containing ``withmask_FER2013TrainValid`` /
            ``withmask_FER2013Test``.
        partition: ``'train'`` selects the train/valid split with paired
            augmentation; any other value selects the test split.
        transform: Optional callable applied to each RGB image on the
            non-train path. When ``None`` (default) a resize-to-224x224 +
            ImageNet-normalisation pipeline is used.
        num_classes: Number of emotion classes (stored only).
        train_list: Path of the train/valid list file.
        test_list: Path of the test list file.
    """

    # FER+ label -> AffectNet label; any label not listed is returned unchanged.
    # https://github.com/siqueira-hc/Efficient-Facial-Feature-Learning-with-Wide-Ensemble-based-Convolutional-Neural-Networks/blob/master/model/utils/udata.py
    _FERPLUS_TO_AFFECTNET = {2: 3, 3: 2, 4: 6, 6: 4}

    # Statistics used by the default Normalize step (and undone by show_img).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, f1, f2, partition='train', transform=None, num_classes=8,
                 train_list='/content/drive/MyDrive/FERPLUS/ferplus_trainvalid_list.txt',
                 test_list='/content/drive/MyDrive/FERPLUS/ferplus_test.txt'):
        self.root = f1
        self.mask_file = f2
        self.num_classes = num_classes
        self.partition = partition
        is_train = partition == 'train'

        # Probability that __getitem__ returns the unmasked image in place of
        # the masked one; 0 disables the swap.
        self.rand = 0.5 if is_train else 0

        # numpy RGB image -> normalised 224x224 tensor; shared by both paths.
        to_tensor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])
        # Honour a caller-supplied transform (it used to be silently overwritten).
        self.transform = transform if transform is not None else to_tensor

        if is_train:
            # additional_targets lets the masked image ('image1') receive the
            # same random augmentation parameters as the unmasked one.
            # NOTE(review): A.RandomContrast and A.IAAAffine are deprecated /
            # removed in newer albumentations releases (see
            # RandomBrightnessContrast / Affine); left unchanged here.
            self.transform1 = A.Compose([
                A.HorizontalFlip(), A.HueSaturationValue(), A.RandomContrast(),
                A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=3, p=0.5),
                A.IAAAffine(scale=(1.0, 1.25), rotate=0.0, p=0.5),
            ], additional_targets={'image1': 'image'})
            self.transform2 = to_tensor

        name_column, label_column = 0, 1
        # Only the list file for the requested split is read.
        df = pd.read_csv(train_list if is_train else test_list, sep=' ', header=None)
        file_names = df.iloc[:, name_column].values
        self.label = [self.change_emotion_label_same_as_affectnet(lbl)
                      for lbl in df.iloc[:, label_column].values]
        self.new_label = self.label  # kept for backward compatibility

        # FER2013 images; masked copies live in a parallel "withmask_" folder.
        split = 'FER2013TrainValid' if is_train else 'FER2013Test'
        unmasked_dir = os.path.join(self.root, split)
        masked_dir = os.path.join(self.mask_file, 'withmask_' + split)

        self.unmasked_file_paths = []
        self.masked_file_paths = []
        self.dataset = []
        for name in file_names:
            fname = str(name) + '.png'
            upath = os.path.join(unmasked_dir, fname)
            mpath = os.path.join(masked_dir, fname)
            self.unmasked_file_paths.append(upath)
            self.masked_file_paths.append(mpath)
            self.dataset.append((upath, mpath))

    def change_emotion_label_same_as_affectnet(self, emo_to_return):
        """Return the AffectNet-ordered label for FER+ label ``emo_to_return``.

        Swaps 2<->3 and 4<->6; every other label is returned unchanged.
        """
        return self._FERPLUS_TO_AFFECTNET.get(emo_to_return, emo_to_return)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (imread -> None).
        """
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        # OpenCV decodes to BGR; the transforms expect RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Load and transform sample ``index``.

        Returns:
            ``(uimg, second, label)`` where ``uimg`` is the unmasked image and
            ``second`` is ``uimg`` with probability ``self.rand``, otherwise
            the masked image.
        """
        uimg_path, mimg_path = self.dataset[index]
        uimg = self._read_rgb(uimg_path)
        mimg = self._read_rgb(mimg_path)

        if self.partition == 'train':
            # Augment both images jointly, then convert each to a tensor.
            transformed = self.transform1(image=uimg, image1=mimg)
            uimg = self.transform2(transformed['image'])
            mimg = self.transform2(transformed['image1'])
        else:
            uimg = self.transform(uimg)
            mimg = self.transform(mimg)

        label = self.label[index]
        if random.uniform(0, 1) < self.rand:
            return uimg, uimg, label
        return uimg, mimg, label

    def __len__(self):
        """Number of (unmasked, masked) pairs."""
        return len(self.dataset)

    def change_rand(self, new_rand):
        """Set the probability of returning the unmasked image twice."""
        self.rand = new_rand

    def show_img(self, index):
        """Plot both images of sample ``index`` side by side and print its label.

        Images are un-normalised with the default mean/std before plotting;
        this assumes the default tensor pipeline (CHW tensors) produced them.
        """
        uimg, mimg, label = self.__getitem__(index)
        mean = np.asarray(self._MEAN)
        std = np.asarray(self._STD)

        fig = plt.figure()
        for pos, img in enumerate((uimg, mimg), start=1):
            fig.add_subplot(1, 2, pos)
            hwc = img.permute(1, 2, 0).numpy()
            plt.imshow(np.clip(hwc * std + mean, 0.0, 1.0))
        plt.show(block=True)
        print("Label: ", label)