<a href="https://colab.research.google.com/github/DavidMachajewski/ResidualMaskingNetworkFER/blob/main/ResidualMaskingNetwork.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Residual Masking Network

This notebook uses the Residual Masking Network [1] for Facial Expression Recognition (FER). The authors' source code is available on GitHub [2].

[1] "Facial Expression Recognition Using Residual Masking Network", Pham et al.

[2] ResMaskingNet implementation, GitHub: https://github.com/phamquiluan/ResidualMaskingNetwork


---

**Install missing packages**

Run the following cell to install `rmn`, the packaged implementation of the Residual Masking Network (see [2]), and `typing`, which provides type hints. On Python 3.5 and later, `typing` is part of the standard library, so the second install is only needed on older runtimes.

In [None]:
!pip install rmn
!pip install typing

Collecting rmn
  Downloading rmn-3.0.2-py3-none-any.whl (109 kB)
Collecting pytorchcv
  Downloading pytorchcv-0.0.67-py2.py3-none-any.whl (532 kB)
Installing collected packages: pytorchcv, rmn
Successfully installed pytorchcv-0.0.67 rmn-3.0.2
Collecting typing
  Downloading typing-3.7.4.3.tar.gz (78 kB)
Building wheels for collected packages: typing
  Building wheel for typing (setup.py) ... done
  Created wheel for typing: filename=typing-3.7.4.3-py3-none-any.whl size=26324 sha256=f21743392fcf27f718397b2acaa72ad49aa3d5f5ea212e3a9987a235dee740b9
  Stored in directory: /root/.cache/pip/wheels/35/f3/15/01aa6571f0a72ee6ae7b827c1491c37a1f72d686fd22b43b0e
Successfully built typing
Installing collected packages: typing
Successfully installed typing-3.7.4.3


**Uploading dataset**

Import the dataset into the Colab runtime using the `wget` command or the upload option.
For the JAFFE dataset, we upload it as a zipped "dataset" folder. The upload takes only 2 to 5 minutes because the dataset is small.

*NOTE*: 
*The "jaffe.zip" file should contain a "dataset" folder with all JAFFE images inside, as provided by the dataset's collectors.*

In [None]:
from google.colab import files
dataset_raw = files.upload()
!unzip "/content/jaffe.zip" -d "/content/jaffe/"
# resulting folder is /content/jaffe/dataset/

Saving jaffe.zip to jaffe.zip
Archive:  /content/jaffe.zip
  inflating: /content/jaffe/dataset/KA.AN1.39.tiff  
  inflating: /content/jaffe/dataset/KA.AN2.40.tiff  
  inflating: /content/jaffe/dataset/KA.AN3.41.tiff  
  inflating: /content/jaffe/dataset/KA.DI1.42.tiff  
  inflating: /content/jaffe/dataset/KA.DI2.43.tiff  
  inflating: /content/jaffe/dataset/KA.DI3.44.tiff  
  inflating: /content/jaffe/dataset/KA.FE1.45.tiff  
  inflating: /content/jaffe/dataset/KA.FE2.46.tiff  
  inflating: /content/jaffe/dataset/KA.FE3.47.tiff  
  inflating: /content/jaffe/dataset/KA.FE4.48.tiff  
  inflating: /content/jaffe/dataset/KA.HA1.29.tiff  
  inflating: /content/jaffe/dataset/KA.HA2.30.tiff  
  inflating: /content/jaffe/dataset/KA.HA3.31.tiff  
  inflating: /content/jaffe/dataset/KA.HA4.32.tiff  
  inflating: /content/jaffe/dataset/KA.NE1.26.tiff  
  inflating: /content/jaffe/dataset/KA.NE2.27.tiff  
  inflating: /content/jaffe/dataset/KA.NE3.28.tiff  
  inflating: /content/jaffe/dataset/KA.S
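
Before unzipping, it can help to confirm that the archive really has the layout described in the note above. The following is a minimal sketch using only the standard library; the helper name `check_jaffe_archive` is hypothetical, and the "dataset/" prefix and ".tiff" suffix are assumptions taken from the unzip output. The demonstration runs on a tiny in-memory archive that stands in for the real "jaffe.zip".

```python
import io
import zipfile


def check_jaffe_archive(archive) -> list:
    """Return the image entries of a JAFFE zip, or raise if the layout is unexpected.

    `archive` may be a path or a file-like object. The "dataset/" prefix and the
    ".tiff" suffix are assumptions based on the unzip output of this notebook.
    """
    with zipfile.ZipFile(archive) as zf:
        # Skip pure directory entries, which end with a slash.
        names = [n for n in zf.namelist() if not n.endswith("/")]
    images = [n for n in names
              if n.startswith("dataset/") and n.lower().endswith(".tiff")]
    if not images:
        raise ValueError('no .tiff images found under a top-level "dataset/" folder')
    return images


# Demonstration on a tiny in-memory archive (a stand-in for jaffe.zip).
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("dataset/KA.AN1.39.tiff", b"")
    zf.writestr("dataset/KA.HA1.29.tiff", b"")
buffer.seek(0)
demo_images = check_jaffe_archive(buffer)
print(demo_images)
```

In the notebook itself, the same helper could be called with the uploaded file path, for example `check_jaffe_archive("/content/jaffe.zip")`, before running `unzip`.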

**Import mandatory libraries**

In [None]:
from pathlib import Path
from PIL import Image
from typing import Dict, NamedTuple, List, Union
from skimage import transform
from torchvision import transforms
from rmn import RMN, models
import numpy as np
import torch
import datetime
import PIL
import os
import json
from tqdm import tqdm
from torch import nn
import rmn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

pretrained_ckpt does not exists!


Downloading pretrained_ckpt..: 100%|██████████| 552M/552M [00:12<00:00, 44.9MiB/s]


deploy.prototxt.txt does not exists!


Downloading deploy.prototxt.txt..: 100%|██████████| 28.1k/28.1k [00:00<00:00, 7.38MiB/s]


res10_300x300_ssd_iter_140000.caffemodel does not exists!


Downloading res10_300x300_ssd_iter_140000.caffemodel..: 100%|██████████| 212/212 [00:00<00:00, 77.4kiB/s]


**Set the location of the JAFFE dataset**

In [None]:
# this "dataset_path" object will be needed as input for the JAFFE class
dataset_path = Path(Path.cwd() / "jaffe" / "dataset")

In [None]:
# optionally, download the dataset from the server using your credentials

# ###############
#
# add code here
#
# ###############

**Implementation of the JAFFE dataset and dataloader**

In [None]:
class Sample(dict):
    """Accessing dict keys by dot notation"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class JAFFE(torch.utils.data.Dataset):
  """Implementation of the JAFFE dataset by subclassing PyTorch's dataset class."""
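  def __init__(self, path_to_dataset: Path, mode: str = None, split_dataset: float = 0.8, transform=None):
    """
    path_to_dataset: folder that contains the JAFFE image files
    mode: currently selected split ("train", "test" or "val"), or None
    split_dataset: fraction of the images used for training, in range [0.0, 1.0]
    transform: optional callable that is applied to every Sample
    """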
    self._mode = mode
    self._split_dataset = split_dataset
    self._path_dataset: Path = path_to_dataset
    self._dataset = [f for f in self._path_dataset.iterdir() if f.is_file()]
    # get train, test and validation split of image paths
    self._path_images_train, self._path_images_test, self._path_images_val = self._split(self._dataset, 0.5, self._split_dataset)
    print(f"inside init function: train {len(self._path_images_train)}, test {len(self._path_images_test)}")
    self._path_images = None
    self._transform = transform
    self._classes = {'AN': 0, 'DI': 1, 'FE': 2, 'HA': 3, 'SA': 4, 'SU': 5, 'NE': 6}

  def get_split(self, mode: str):
    """ Return the training, testing or validation JAFFE dataset
    :mode: "train", "test", "val"
    return: JAFFE dataset (a deep copy restricted to the requested split)
    """
    splits = {
        "train": self._path_images_train,
        "test": self._path_images_test,
        "val": self._path_images_val,
    }
    # fail early instead of hitting an undefined variable for unknown modes
    if mode not in splits:
      raise ValueError(f"mode must be one of {list(splits)}, got {mode!r}")
    print(f"Returning {mode}-set.")
    self._mode = mode
    self._path_images = splits[mode]
    return deepcopy(self)

  def _split(self, dataset: List, testval = 0.5, split_dataset = 0.8):
    """Shuffles the image paths in place and splits them into train, test and val image paths
    split_dataset: fraction used for training; has to be a float in range [0.0, 1.0]
    testval: fraction of the remaining images used for testing; the rest is used for validation"""
    lendata = len(dataset)
    np.random.shuffle(dataset)

    train_amount = split_dataset

    bound = int(train_amount * lendata)
    train, test = dataset[:bound], dataset[bound:]
    len_testset = len(test)
    newtest, val = test[:int(testval * len_testset)], test[int(testval * len_testset):]

    return train, newtest, val

  def show_sample(self, idx: int):
    """Given an index idx, this function shows the corresponding image sample"""
    with Image.open(self._path_images[idx]) as img:
      display(img)
  
  def _load_image(self, path: Path):
    """Load the image at path and return it as a 3-channel grayscale PIL image"""
    image = PIL.Image.open(path)
    image = transforms.Grayscale(num_output_channels=3)(image)
    return image
  
  def _get_label(self, filepath: Path) -> str:
    """Extracts the label from the filename of a JAFFE image sample, e.g. "AN" from KA.AN1.39.tiff"""
    return filepath.stem.split(".")[1][:2]

  def __len__(self):
    return len(self._path_images)
  
  def __getitem__(self, idx: int) -> Sample:
    tmp_path = self._path_images[idx]

    sample = Sample({
        "image": self._load_image(tmp_path), 
        "label": self._classes[self._get_label(tmp_path)]})

    if self._transform:
      sample = self._transform(sample)

    return sample

  def _download(self):
    """Placeholder for downloading the dataset; not implemented."""
    pass


class Scale(object):
  def __init__(self, new_size: List[Union[int, int]]):
    self.new_size = new_size
  
  def __call__(self, sample: Sample) -> Sample:
    h, w = self.new_size
    # sample.image = transform.resize(sample.image, (h, w))
    sample.image = transforms.Resize(size=(h,w))(sample.image)
    return sample


class ToTensor(object):
  def __call__(self, sample: Sample) -> Sample:
    # sample.image = np.asarray(sample.image)[:, :, np.newaxis]
    # from H x W x C to C x H x W
    sample.image = torch.from_numpy(np.asarray(sample.image))
    # sample.image = torch.unsqueeze(sample.image, 2)
    sample.image = sample.image.permute((2, 0, 1))
    # sample.image = torch.unsqueeze(sample.image, 2)
    sample.image = sample.image.float()
    sample.label = sample.label # torch.Tensor(sample.label)
    return sample


def get_transforms(new_size: List[Union[int, int]]):
  transformer = transforms.Compose([
       Scale(new_size),
       ToTensor()])
  return transformer


def JAFFE_dataloader(path_to_dataset: Path, split_dataset, transforms_new_size):
  """Instantiate the JAFFE class and return PyTorch DataLoaders for the train, test and val splits."""
  transformer = get_transforms(transforms_new_size)
  # pass the requested split through instead of hard-coding 0.8
  dataset = JAFFE(path_to_dataset=path_to_dataset,
                  mode=None,
                  split_dataset=split_dataset,
                  transform=transformer)

  trainset = dataset.get_split("train")
  testset = dataset.get_split("test")
  valset = dataset.get_split("val")

  print(f"created trainset of size: {len(trainset)}")
  print(f"created testset of size: {len(testset)}")
  print(f"created valset of size: {len(valset)}")

  traindl = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=1, drop_last=True)
  testdl = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1, drop_last=True)
  valdl = torch.utils.data.DataLoader(valset, batch_size=4, shuffle=True, num_workers=1, drop_last=True)
  return traindl, testdl, valdl
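
The label logic in `_get_label` relies on the JAFFE filename convention, in which the second dot-separated field starts with a two-letter expression code. The following standalone sketch replays that step outside the class so it can be checked without loading any images; the function name `label_from_filename` is hypothetical, while the class mapping and the example filenames are copied from this notebook.

```python
from pathlib import Path

# Same mapping as JAFFE._classes in the cell above.
CLASSES = {'AN': 0, 'DI': 1, 'FE': 2, 'HA': 3, 'SA': 4, 'SU': 5, 'NE': 6}


def label_from_filename(filepath: Path) -> int:
    """Map a JAFFE filename such as KA.AN1.39.tiff to its class index."""
    # stem drops only the final suffix: "KA.AN1.39.tiff" -> "KA.AN1.39"
    code = filepath.stem.split(".")[1][:2]
    return CLASSES[code]


# Filenames taken from the unzip output of this notebook.
example_files = [Path(name) for name in
                 ["KA.AN1.39.tiff", "KA.DI2.43.tiff", "KA.HA4.32.tiff", "KA.NE3.28.tiff"]]
labels = [label_from_filename(p) for p in example_files]
print(labels)  # [0, 1, 3, 6]
```

A filename with an unknown expression code raises a `KeyError`, which is usually preferable to silently assigning a wrong class.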

**Instantiate the JAFFE dataloader**

In [None]:
#
# TODO: make "transforms_new_size" dependent on the .json config file
#

traindl, testdl, valdl = JAFFE_dataloader(path_to_dataset=dataset_path, split_dataset = 0.8, transforms_new_size=[224, 224])

#for idx, batch in enumerate(traindl):
#  if idx==0:
#    print(batch['image'].shape)
#    break

inside init function: train 170, test 21
Returning train-set.
Returning test-set.
Returning val-set.
created trainset of size: 170
created testset of size: 21
created valset of size: 22
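
The sizes printed above follow directly from the integer arithmetic in `_split` and from `drop_last=True` in the dataloaders. As a rough check, the sketch below recomputes them for 213 images (170 + 21 + 22, the total implied by the output above); the helper names `split_sizes` and `n_batches` are hypothetical.

```python
def split_sizes(n_images: int, split_dataset: float = 0.8, testval: float = 0.5):
    """Reproduce the size arithmetic of JAFFE._split without touching any files."""
    n_train = int(split_dataset * n_images)   # images before the train bound
    n_rest = n_images - n_train               # images left for test and val
    n_test = int(testval * n_rest)            # first part of the remainder
    n_val = n_rest - n_test                   # everything that is left
    return n_train, n_test, n_val


def n_batches(n_samples: int, batch_size: int = 4) -> int:
    """Number of batches when drop_last=True discards the incomplete final batch."""
    return n_samples // batch_size


sizes = split_sizes(213)
batches = tuple(n_batches(s) for s in sizes)
print(sizes)    # (170, 21, 22)
print(batches)  # (42, 5, 5)
```

Because incomplete batches are dropped, only 20 of the 21 test images and 20 of the 22 validation images are evaluated per pass, which is why validation accuracy moves in steps of 5 percentage points.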


**Coding the JAFFE trainer**

In [None]:
def accuracy(output, target):
    with torch.no_grad():
        batch_size = target.size(0)
        pred = torch.argmax(output, dim=1)
        correct = pred.eq(target).float().sum(0)
        acc = correct * 100 / batch_size
    return [acc]


class JAFFETrainer():
  def __init__(self, model, train_set, val_set, test_set, configs):
    """
    :train_set: dataloader of the train set
    :val_set: dataloader of the validation set
    :test_set: dataloader of the test set
    """
    # print start and configs
    #
    # load configurations like the author defines
    self._configs = configs
    self._lr = self._configs["lr"]
    self._batch_size = self._configs["batch_size"]
    self._momentum = self._configs["momentum"]
    self._weight_decay = self._configs["weight_decay"]
    self._distributed = self._configs["distributed"]
    self._num_workers = self._configs["num_workers"]
    self._device = torch.device(self._configs["device"])
    self._max_epoch_num = self._configs["max_epoch_num"]
    self._max_plateau_count = self._configs["max_plateau_count"]
    # model
    self._model = model(in_channels=configs["in_channels"], num_classes=configs["num_classes"])
    
    self._model.to(self._device)
    # datasets
    self._train_loader = train_set
    self._test_loader = test_set
    self._val_loader = val_set
    # Loss and optimizer
    self._criterion = nn.CrossEntropyLoss().to(self._device)
    self._optimizer = torch.optim.Adam(params=self._model.parameters(),
                                       lr=self._lr,
                                       weight_decay=self._weight_decay)
    self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self._optimizer, 
                                                                 patience=self._configs["plateau_patience"],
                                                                 min_lr=1e-6, 
                                                                 verbose=True)
    
    # training info
    self._start_time = datetime.datetime.now()
    self._start_time = self._start_time.replace(microsecond=0)

    log_dir = os.path.join(
        self._configs["cwd"],
        self._configs["log_dir"],
        "{}_{}".format(
            self._configs["model_name"], self._start_time.strftime("%Y%b%d_%H.%M")
        ),
    )

    self._writer = SummaryWriter(log_dir)
    self._train_loss = []
    self._train_acc = []
    self._val_loss = []
    self._val_acc = []
    self._best_loss = 1e9
    self._best_acc = 0
    self._test_acc = 0.0
    self._plateau_count = 0
    self._current_epoch_num = 0

    # for checkpoints
    self._checkpoint_dir = os.path.join(self._configs["cwd"], "saved/checkpoints")
    if not os.path.exists(self._checkpoint_dir):
        os.makedirs(self._checkpoint_dir, exist_ok=True)

    self._checkpoint_path = os.path.join(
        self._checkpoint_dir,
        "{}_{}".format(
            self._configs["model_name"], self._start_time.strftime("%Y%b%d_%H.%M")
        ),
    )
  
  def _train(self):
    # print("training step")
    self._model.train()
    train_loss, train_acc = 0.0, 0.0

    for i, batch in tqdm(enumerate(self._train_loader), total=len(self._train_loader), leave=False):
      # print(f"size of train loader: {len(self._train_loader)}")
      # move the batch to the configured device so that GPU training works too
      images = batch["image"].to(self._device)
      # print(f"shape of image tensor: {images.shape}")
      targets = batch["label"].to(self._device)

      outputs = self._model(images)

      loss = self._criterion(outputs, targets)
      acc = accuracy(outputs, targets)[0]

      train_loss += loss.item()
      train_acc += acc.item()

      self._optimizer.zero_grad()
      loss.backward()
      self._optimizer.step()
    
    i += 1
    self._train_loss.append(train_loss / i)
    self._train_acc.append(train_acc / i)
  
  def _val(self):
    print("validation")
    self._model.eval()
    val_loss, val_acc = 0.0, 0.0
    
    with torch.no_grad():
      for i, batch in tqdm(enumerate(self._val_loader), total=len(self._val_loader), leave=False):
        images = batch["image"].to(self._device)
        targets = batch["label"].to(self._device)

        # compute output, measure accuracy and record loss
        outputs = self._model(images)

        loss = self._criterion(outputs, targets)
        acc = accuracy(outputs, targets)[0]

        val_loss += loss.item()
        val_acc += acc.item()

      i += 1
      self._val_loss.append(val_loss / i)
      self._val_acc.append(val_acc / i)
  
  def _increase_epoch_num(self):
    self._current_epoch_num += 1
  
  def _is_stop(self):
    return (
      self._plateau_count > self._max_plateau_count
      or self._current_epoch_num > self._max_epoch_num
    )

  def _update_training_state(self):
    if self._val_acc[-1] > self._best_acc:
      self._save_weights()
      self._plateau_count = 0
      self._best_acc = self._val_acc[-1]
      self._best_loss = self._val_loss[-1]
    else:
      self._plateau_count += 1

    self._scheduler.step(100 - self._val_acc[-1])  

  def _save_weights(self, test_acc=0.0):
    if self._distributed == 0:
      state_dict = self._model.state_dict()
    else:
      state_dict = self._model.module.state_dict()

    state = {
        **self._configs,
        "net": state_dict,
        "best_loss": self._best_loss,
        "best_acc": self._best_acc,
        "train_losses": self._train_loss,
        "val_loss": self._val_loss,
        "train_acc": self._train_acc,
        "val_acc": self._val_acc,
        "test_acc": self._test_acc,
      }
    torch.save(state, self._checkpoint_path)

  def _logging(self):
    consume_time = str(datetime.datetime.now() - self._start_time)

    message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
      self._current_epoch_num,
      self._train_loss[-1],
      self._val_loss[-1],
      self._best_loss,
      self._train_acc[-1],
      self._val_acc[-1],
      self._best_acc,
      self._plateau_count,
      consume_time[:-7],
    )

    self._writer.add_scalar(
        "Accuracy/Train", self._train_acc[-1], self._current_epoch_num
    )
    self._writer.add_scalar(
        "Accuracy/Val", self._val_acc[-1], self._current_epoch_num
    )
    self._writer.add_scalar(
        "Loss/Train", self._train_loss[-1], self._current_epoch_num
    )
    self._writer.add_scalar("Loss/Val", self._val_loss[-1], self._current_epoch_num)

    print(message)
  
  def _load_ckp(self):
    """Test"""
    pass

  def train(self):
    print("start training")
    # print(self._model)
    while not self._is_stop():
      self._increase_epoch_num()
      self._train()
      self._val()

      self._update_training_state()
      self._logging()
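
The stopping rule spread across `_update_training_state` and `_is_stop` can be hard to follow inside the class. The sketch below replays it on a plain list of validation accuracies; it is a simplified model, not the trainer itself, and the function name `simulate_early_stopping` as well as the accuracy values are made up for illustration.

```python
def simulate_early_stopping(val_accs, max_plateau_count, max_epoch_num):
    """Replay the trainer's stopping rule on a list of validation accuracies.

    Returns (last_epoch, best_epoch, best_acc, plateau_count).
    """
    best_acc, best_epoch = 0.0, None
    plateau_count, epoch = 0, 0
    accs = iter(val_accs)
    # Same condition as JAFFETrainer._is_stop, checked before every epoch.
    while not (plateau_count > max_plateau_count or epoch > max_epoch_num):
        acc = next(accs, None)
        if acc is None:  # ran out of simulated epochs
            break
        epoch += 1
        if acc > best_acc:
            # A strictly better accuracy resets the plateau counter.
            best_acc, best_epoch, plateau_count = acc, epoch, 0
        else:
            plateau_count += 1
    return epoch, best_epoch, best_acc, plateau_count


# Illustrative accuracies (not measured): the final 90 is never reached.
result = simulate_early_stopping([30, 25, 70, 80, 75, 70, 70, 80, 90],
                                 max_plateau_count=3, max_epoch_num=50)
print(result)  # (8, 4, 80, 4)
```

Two details follow from the comparisons used in the class: an epoch that only ties the best accuracy still counts as a plateau epoch, and training runs one epoch longer than `max_plateau_count` suggests because the check uses `>` rather than `>=`.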
  


**Training the Residual Masking Network on JAFFE**

In [None]:
# config file for JAFFE
train_config = {
  "data_path": "content/jaffe/dataset/",
  "image_size": 224,
  "in_channels": 3,
  "num_classes": 7,
  "arch": "resmasking_dropout1",  # alexnet
  "lr": 0.0001,
  "weighted_loss": 0,
  "momentum": 0.9,
  "weight_decay": 0.001,
  "distributed": 0,
  "batch_size": 16,
  "num_workers": 8,
  "device": "cpu",
  "max_epoch_num": 50,
  "max_plateau_count": 8,
  "plateau_patience": 2,
  "steplr": 50,
  "log_dir": "log",
  "checkpoint_dir": "checkpoint/",
  "model_name": "test",
  "cwd": "content/"
}

train_config2 = {
  "image_size": 224,
  "in_channels": 3,
  "num_classes": 8,
  "arch": "resmasking_dropout1",
  "lr": 0.0001,
  "momentum": 0.9,
  "weight_decay": 1e-3,
  "distributed": 0,
  "batch_size": 10,
  "num_workers": 5,
  "device": "cpu",  # "cuda:0",
  "max_epoch_num": 100000,
  "max_plateau_count": 20,
  "plateau_patience": 4,
  "steplr": 50,
  "log_dir": "saved/logs",
  "checkpoint_dir": "saved/checkpoints",
  "model_name": "aug",
  "cwd": "content/"
}

def train(train_config):

  model = models.__dict__[train_config["arch"]]

  # the dataloaders traindl, testdl and valdl are taken from the notebook's global scope
  print(f"SIZE, amount of batches of traindl: {len(traindl)}")
  print(f"SIZE, amount of batches of testdl: {len(testdl)}")
  print(f"SIZE, amount of batches of valdl: {len(valdl)}")
  trainer = JAFFETrainer(model=model, train_set=traindl, val_set=valdl, test_set=testdl, configs=train_config)
  trainer.train()


train(train_config)

SIZE, amount of batches of traindl: 42
SIZE, amount of batches of testdl: 5
SIZE, amount of batches of valdl: 5
start training
validation
E001  1.992/1.856/1.856 18.452/30.000/30.000 | p00  Time 0:08:50
validation
E002  1.678/1.796/1.856 31.548/25.000/30.000 | p01  Time 0:17:48
validation
E003  1.344/1.956/1.856 47.024/25.000/30.000 | p02  Time 0:26:49
validation
Epoch     4: reducing learning rate of group 0 to 1.0000e-05.
E004  1.118/4.686/1.856 62.500/30.000/30.000 | p03  Time 0:35:42
validation
E005  0.969/0.800/0.800 66.071/70.000/70.000 | p00  Time 0:44:37
validation
E006  0.753/0.616/0.616 74.405/80.000/80.000 | p00  Time 0:53:25
validation
E007  0.518/0.693/0.616 85.714/75.000/80.000 | p01  Time 1:02:18
validation
E008  0.404/0.591/0.616 91.071/70.000/80.000 | p02  Time 1:11:26
validation
Epoch     9: reducing learning rate of group 0 to 1.0000e-06.
E009  0.339/0.613/0.616 92.262/70.000/80.000 | p03  Time 1:20:35
validation
E010  0.240/0.638/0.616 96.429/70.000/80.000 | p04  Time 1:29:34
validation
E011  0.396/0.519/0.616 89.286/80.000/80.000 | p05  Time 1:38:36
validation
E012  0.264/0.576/0.576 94.048/85.000/85.000 | p00  Time 1:47:32
validation
E013  0.294/0.636/0.576 94.048/75.000/85.000 | p01  Time 1:56:19
validation
E014  0.269/0.461/0.576 92.857/85.000/85.000 | p02  Time 2:05:15
validation
E015  0.279/0.590/0.576 94.048/80.000/85.000 | p03  Time 2:14:16
validation
E016  0.256/0.576/0.576 92.857/75.000/85.000 | p04  Time 2:23:11
validation
E017  0.282/0.656/0.576 93.452/70.000/85.000 | p05  Time 2:32:03
validation
E018  0.325/0.633/0.576 91.667/75.000/85.000 | p06  Time 2:40:50
validation
E019  0.353/0.509/0.576 90.476/80.000/85.000 | p07  Time 2:49:43
validation
E020  0.228/0.648/0.576 97.619/70.000/85.000 | p08  Time 2:58:31
validation
E021  0.253/0.604/0.576 94.048/75.000/85.000 | p09  Time 3:07:24
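
The two learning-rate reductions in the log above can be reproduced from the validation accuracies alone. The sketch below is a simplified pure-Python re-implementation of a reduce-on-plateau rule in "min" mode, not PyTorch's scheduler itself; the function name `replay_plateau_schedule` is hypothetical, and `factor=0.1`, `threshold=1e-4` and `eps=1e-8` are assumed to match the defaults of `ReduceLROnPlateau`, since the trainer does not set them.

```python
def replay_plateau_schedule(val_accs, lr=1e-4, patience=2, factor=0.1,
                            min_lr=1e-6, threshold=1e-4, eps=1e-8):
    """Return [(epoch, new_lr), ...] for every epoch where the rate is reduced."""
    best = float("inf")
    bad_epochs = 0
    reductions = []
    for epoch, acc in enumerate(val_accs, start=1):
        metric = 100 - acc  # the trainer steps the scheduler with 100 - val_acc
        if metric < best * (1 - threshold):
            # Relative improvement beyond the threshold resets the counter.
            best, bad_epochs = metric, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            new_lr = max(lr * factor, min_lr)
            # Changes smaller than eps are ignored, so nothing happens at min_lr.
            if lr - new_lr > eps:
                lr = new_lr
                reductions.append((epoch, lr))
            bad_epochs = 0
    return reductions


# Validation accuracies of epochs 1 to 21, copied from the log above.
logged_val_acc = [30, 25, 25, 30, 70, 80, 75, 70, 70, 70, 80,
                  85, 75, 85, 80, 75, 70, 75, 80, 70, 75]
reductions = replay_plateau_schedule(logged_val_acc)
print([epoch for epoch, _ in reductions])
```

Under these assumptions the replay reduces the rate at epochs 4 and 9, in line with the two scheduler messages in the log. After epoch 9 the rate sits at `min_lr`, so later plateaus no longer change it and produce no further messages.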

