<a href="https://colab.research.google.com/github/DavidMachajewski/ResidualMaskingNetworkFER/blob/main/ResidualMaskingNetwork.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Residual Masking Network

This notebook uses the Residual Masking Network [1] for Facial Expression Recognition (FER). The authors' source code is available on GitHub [2].

[1] "Facial Expression Recognition Using Residual Masking Network", Pham et al.

[2] ResMaskingNet implementation, GitHub: https://github.com/phamquiluan/ResidualMaskingNetwork


---

**Install missing packages**

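Run the following cell to install the rmn package, which is the packaged implementation of the Residual Masking Network (see resource [2]), and the typing package, which provides the type hints used in this notebook. The cell also mounts Google Drive, which is used further below for the FER2013 files.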

In [None]:
!pip install rmn
!pip install typing

from google.colab import drive
drive.mount('/content/drive')

Collecting rmn
  Downloading rmn-3.0.2-py3-none-any.whl (109 kB)
[?25l[K     |███                             | 10 kB 18.9 MB/s eta 0:00:01[K     |██████                          | 20 kB 25.3 MB/s eta 0:00:01[K     |█████████                       | 30 kB 16.8 MB/s eta 0:00:01[K     |████████████                    | 40 kB 16.3 MB/s eta 0:00:01[K     |███████████████                 | 51 kB 9.7 MB/s eta 0:00:01[K     |██████████████████              | 61 kB 9.0 MB/s eta 0:00:01[K     |████████████████████▉           | 71 kB 9.5 MB/s eta 0:00:01[K     |███████████████████████▉        | 81 kB 10.5 MB/s eta 0:00:01[K     |██████████████████████████▉     | 92 kB 10.9 MB/s eta 0:00:01[K     |█████████████████████████████▉  | 102 kB 9.3 MB/s eta 0:00:01[K     |████████████████████████████████| 109 kB 9.3 MB/s 
Collecting pytorchcv
  Downloading pytorchcv-0.0.67-py2.py3-none-any.whl (532 kB)
Installing collect

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


**!! IF YOU DON'T NEED JAFFE SKIP THIS CELL AND GO FOR FER2013 DATASET !!**


**Uploading dataset (JAFFE)**


Import the dataset into the Colab runtime with the wget command or the upload option.
The JAFFE dataset is uploaded as a zipped "dataset" folder. Because it is a small dataset, the upload takes only 2 to 5 minutes.

*NOTE*: 
*The "jaffe.zip" should contain a "dataset" folder, with all JAFFE images inside, as the collectors provide.*

In [None]:
from google.colab import files
dataset_raw = files.upload()
!unzip "/content/jaffe.zip" -d "/content/jaffe/"
# resulting folder is /content/jaffe/dataset/
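
Before unzipping, it can help to confirm that the archive really has the layout described in the note above. The following is a minimal sketch using only the standard library. `has_dataset_folder` is a hypothetical helper name, and the in-memory archive with its file names is made up for the demonstration.

```python
import io
import zipfile


def has_dataset_folder(zip_source) -> bool:
    """Return True if every file in the archive sits under a top-level 'dataset/' folder."""
    with zipfile.ZipFile(zip_source) as archive:
        # Directory entries end with "/"; only real files matter here.
        files = [n for n in archive.namelist() if not n.endswith("/")]
    return bool(files) and all(n.startswith("dataset/") for n in files)


# Demonstration with a tiny in-memory archive (file names are placeholders).
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("dataset/example_1.tiff", b"")
    archive.writestr("dataset/example_2.tiff", b"")

print(has_dataset_folder(buffer))
```

In the notebook the same check could be applied to the uploaded file, for example `has_dataset_folder("/content/jaffe.zip")`.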

**Uploading dataset (FER2013)**

Import the FER2013 dataset to the colab runtime.

In [None]:
from google.colab import files
dataset_raw = files.upload()
#
# train_ids_0.csv and test_ids_0.csv and valid_ids_0.csv
#

Saving train_ids_0.csv to train_ids_0.csv


KeyboardInterrupt: ignored

**Import mandatory libraries**

In [None]:
from pathlib import Path
from PIL import Image
from typing import Dict, NamedTuple, List, Union
from skimage import transform
from torchvision import transforms
from rmn import RMN, models
import numpy as np
import torch
import datetime
import PIL
import os
import json
import pandas as pd
from tqdm import tqdm
from torch import nn
import rmn
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter

pretrained_ckpt does not exists!


Downloading pretrained_ckpt..: 100%|██████████| 552M/552M [00:58<00:00, 9.47MiB/s]


deploy.prototxt.txt does not exists!


Downloading deploy.prototxt.txt..: 100%|██████████| 28.1k/28.1k [00:00<00:00, 6.49MiB/s]


res10_300x300_ssd_iter_140000.caffemodel does not exists!


Downloading res10_300x300_ssd_iter_140000.caffemodel..: 100%|██████████| 10.7M/10.7M [00:00<00:00, 40.9MiB/s]


In [None]:
example_path = Path(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "test_ids_0.csv")
print(Path.cwd())
print(example_path)
print(example_path.is_file())

/content
/content/drive/MyDrive/fer2013/test_ids_0.csv
True


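Set the dataset locations. Run this cell for FER2013 as well: besides the JAFFE path it defines the FER2013 CSV paths and the Sample and transform helpers that both datasets use.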

In [None]:
# this "dataset_path" object will be needed as input for the JAFFE class
dataset_path = Path(Path.cwd() / "jaffe" / "dataset")

# dataset paths for train and test of FER2013 dataset if you upload it 
## dataset_path_fer_train = Path(Path.cwd() / "train_ids_0.csv")
## dataset_path_fer_test = Path(Path.cwd() / "test_ids_0.csv")
## dataset_path_fer_valid = Path(Path.cwd() / "valid_ids_0.csv")

# dataset paths for train and test of FER2013 dataset from google drive
dataset_path_fer_train = Path(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "train_ids_0.csv")
dataset_path_fer_test = Path(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "test_ids_0.csv")
dataset_path_fer_valid = Path(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "valid_ids_0.csv")



class Sample(dict):
    """Accessing dict keys by dot notation"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


class Scale(object):
  def __init__(self, new_size: List[int]):
    self.new_size = new_size
  
  def __call__(self, sample: Sample) -> Sample:
    h, w = self.new_size
    sample.image = transforms.Resize(size=(h,w))(sample.image)
    return sample


class ToTensor(object):
  def __call__(self, sample: Sample):
    # sample.image = np.asarray(sample.image)[:, :, np.newaxis]
    # from H x W x C to C x H x W
    sample.image = torch.from_numpy(np.asarray(sample.image))
    # sample.image = torch.unsqueeze(sample.image, 2)
    sample.image = sample.image.permute((2, 0, 1))
    # sample.image = torch.unsqueeze(sample.image, 2)
    sample.image = sample.image.float()
    sample.label = int(sample.label) # torch.Tensor(sample.label)
    return sample


def get_transforms(new_size: List[int]):
  transformer = transforms.Compose([
       Scale(new_size),
       ToTensor()])
  return transformer
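
The `Sample` class above maps attribute access onto dictionary access. The short sketch below repeats the class definition so that it runs on its own and shows what that means in practice; the stored values are placeholders.

```python
class Sample(dict):
    """Same definition as in the cell above: dict keys accessible as attributes."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


sample = Sample({"image": "placeholder", "label": "3"})
sample.label = int(sample.label)   # attribute assignment writes to the dict
print(sample["label"])             # 3
print(sample.missing_key)          # None, because dict.get returns None for absent keys
```

One consequence of this design is that a misspelled attribute silently yields `None` instead of raising an error, and keys that share a name with a dict method (such as `items`) are only reachable with bracket syntax.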

**Implementation of the FER2013 dataset and dataloader**

In [None]:
class FER2013(torch.utils.data.Dataset):
  def __init__(self, mode = None, transform: transforms=None):
    self._mode = mode  # "train", "test" or "valid"
    self._data_pd, self._data = self._load_data()
    self._transform = transform
  
  def _load_data(self):
    if self._mode == "train":
      print("Load training data")
      _data = pd.read_csv(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "train_ids_0.csv")
    elif self._mode == "test":
      print("Load testing data")
      _data = pd.read_csv(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "test_ids_0.csv")
    elif self._mode == "valid":
      print("Load valid data")
      _data = pd.read_csv(Path.cwd() / "drive" / "MyDrive" / "fer2013" / "valid_ids_0.csv")
    else:
      raise ValueError(f"mode must be 'train', 'test' or 'valid', got {self._mode!r}")

    _imgs = [[int(y) for y in x.split()] for x in _data['pixels']]
    _labels = _data['emotion'].tolist()

    return _data, list(zip(_imgs, _labels))
  
  def __len__(self):
    return len(self._data_pd)
  
  def __getitem__(self, idx: int) -> Sample:
    image = Image.fromarray(np.uint8(np.reshape(self._data[idx][0], (48, 48))))
    image = transforms.Grayscale(num_output_channels=3)(image)
    sample = Sample({
        "image": image, 
        "label": self._data[idx][1]})

    if self._transform:
      sample = self._transform(sample)

    return sample


def FER2013_dataloader(transforms_new_size):
  """Instantiate the FER2013 class and return a PyTorch Dataloader."""
  transformer = get_transforms(transforms_new_size)

  trainset = FER2013("train", transform=transformer)
  testset = FER2013("test", transform=transformer)
  valset = FER2013("valid", transform=transformer)

  print(f"created trainset of size: {len(trainset)}")
  print(f"created testset of size: {len(testset)}")
  print(f"created valset of size: {len(valset)}")

  traindl = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2, drop_last=True)
  testdl = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=2, drop_last=True)
  valdl = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=2, drop_last=True)
  return traindl, testdl, valdl


# f = FER2013(mode = "train")
#for idx, sample in enumerate(f):
#  print(f'image nr. {idx}')

#for idx in range(9700, 10000):
#  print(f'image nr. {idx} and row: \n {f[idx]}')
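
The `FER2013` class reads each image from the "pixels" column as one space-separated string and reshapes it to 48 x 48. The sketch below shows that decoding step in plain Python, without pandas or NumPy. `decode_pixels` is a hypothetical helper, and the row is a synthetic gradient rather than a real FER2013 face.

```python
def decode_pixels(pixel_string: str, side: int = 48):
    """Turn a space-separated pixel string into a list of `side` rows of `side` ints."""
    values = [int(v) for v in pixel_string.split()]
    if len(values) != side * side:
        raise ValueError(f"expected {side * side} values, got {len(values)}")
    # Row-major layout, the same order np.reshape uses by default.
    return [values[r * side:(r + 1) * side] for r in range(side)]


# Synthetic row: a 48x48 gradient instead of a real FER2013 face.
fake_row = " ".join(str(i % 256) for i in range(48 * 48))
image = decode_pixels(fake_row)
print(len(image), len(image[0]))   # 48 48
print(image[1][0])                 # value at row 1, column 0
```

The explicit length check is an addition of this sketch; it turns a malformed row into a clear error instead of a reshape failure later on.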

**Implementation of the JAFFE dataset and dataloader**

In [None]:
class JAFFE(torch.utils.data.Dataset):
  """Implementation of the JAFFE dataset by subclassing PyTorch's dataset class."""
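  def __init__(self, path_to_dataset: Path, mode = None, split_dataset: float = None, transform: transforms=None):
    """
    path_to_dataset: folder that contains the JAFFE image files
    mode: "train", "test" or "val"; usually set later via get_split
    split_dataset: fraction of the images used for training, in [0.0, 1.0]
    transform: optional transform applied to every sample
    """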
    self._mode = mode
    self._split_dataset = split_dataset
    self._path_dataset: Path = path_to_dataset
    self._dataset = [f for f in self._path_dataset.iterdir() if f.is_file()]
    # get train, test and validation split of image paths
    self._path_images_train, self._path_images_test, self._path_images_val = self._split(self._dataset, 0.5, self._split_dataset)
    print(f"inside init function: train {len(self._path_images_train)}, test {len(self._path_images_test)}")
    self._path_images = None
    self._transform = transform
    self._classes = {'AN': 0, 'DI': 1, 'FE': 2, 'HA': 3, 'SA': 4, 'SU': 5, 'NE': 6}

  def get_split(self, mode: str):
    """ Return the training, testing or validation JAFFE dataset
    :mode: "train", "test", "val"
    return: JAFFE dataset
    """
    splits = {"train": self._path_images_train,
              "test": self._path_images_test,
              "val": self._path_images_val}
    if mode not in splits:
      raise ValueError(f"mode must be 'train', 'test' or 'val', got {mode!r}")
    print(f"Returning {mode}-set.")
    self._mode = mode
    self._path_images = splits[mode]
    # return an independent copy so that later get_split calls do not change it
    return deepcopy(self)

  def _split(self, dataset: List, testval = 0.5, split_dataset = 0.8):
    """Splits the image paths into train, test and validation image paths
    split_dataset: fraction used for training, a float in range [0.0, 1.0]
    testval: fraction of the remaining images used for testing; the rest is used for validation"""
    lendata = len(dataset)
    np.random.shuffle(dataset)

    train_amount = split_dataset

    bound = int(train_amount * lendata)
    train, test = dataset[:bound], dataset[bound:]
    len_testset = len(test)
    newtest, val = test[:int(testval * len_testset)], test[int(testval * len_testset):]

    return train, newtest, val

  def show_sample(self, idx: int):
    """Given an index idx, this function shows the corresponding image sample"""
    with Image.open(self._path_images[idx]) as img:
      display(img)
  
  def _load_image(self, path: Path):
    """Load the image at path and return it as a 3-channel grayscale PIL image"""
    image = PIL.Image.open(path)
    image = transforms.Grayscale(num_output_channels=3)(image)
    return image
  
  def _get_label(self, filepath: Path) -> str:
    """Extracts the label from the filename of a JAFFE image sample"""
    return filepath.stem.split(".")[1][:2]

  def __len__(self):
    return len(self._path_images)
  
  def __getitem__(self, idx: int) -> Sample:
    tmp_path = self._path_images[idx]

    sample = Sample({
        "image": self._load_image(tmp_path), 
        "label": self._classes[self._get_label(tmp_path)]})

    if self._transform:
      sample = self._transform(sample)

    return sample

  def _download(self):
    """Placeholder: downloading the dataset is not implemented."""
    pass



def JAFFE_dataloader(path_to_dataset: Path, split_dataset, transforms_new_size):
  """Instantiate the jaffe class and return a PyTorch Dataloader."""
  transformer = get_transforms(transforms_new_size)
  dataset = JAFFE(path_to_dataset=path_to_dataset,
                  mode=None,
                  split_dataset=split_dataset,
                  transform=transformer)

  trainset = dataset.get_split("train")
  testset = dataset.get_split("test")
  valset = dataset.get_split("val")

  print(f"created trainset of size: {len(trainset)}")
  print(f"created testset of size: {len(testset)}")
  print(f"created valset of size: {len(valset)}")

  traindl = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8, drop_last=True)
  testdl = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=8, drop_last=True)
  valdl = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers=8, drop_last=True)
  return traindl, testdl, valdl
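
Two details of the JAFFE pipeline above are easy to check by hand: how the class label is parsed from a file name, and how many batches each split yields with `batch_size=32` and `drop_last=True`. The sketch below mirrors that logic in plain Python. The helper names are hypothetical, the file name `KA.HA1.29.tiff` is an assumed example of the JAFFE naming scheme, and 213 is the image count usually quoted for JAFFE.

```python
from pathlib import Path

CLASSES = {'AN': 0, 'DI': 1, 'FE': 2, 'HA': 3, 'SA': 4, 'SU': 5, 'NE': 6}


def label_from_filename(path: Path) -> int:
    """Mirror JAFFE._get_label: second dot-separated field, first two letters."""
    return CLASSES[path.stem.split(".")[1][:2]]


def split_sizes(n_images: int, train_fraction: float = 0.8, testval: float = 0.5):
    """Mirror JAFFE._split and return the (train, test, val) sizes."""
    n_train = int(train_fraction * n_images)
    n_rest = n_images - n_train
    n_test = int(testval * n_rest)
    return n_train, n_test, n_rest - n_test


def full_batches(n_samples: int, batch_size: int = 32) -> int:
    """Number of batches a DataLoader yields when drop_last=True."""
    return n_samples // batch_size


print(label_from_filename(Path("KA.HA1.29.tiff")))   # assumed file name
sizes = split_sizes(213)                             # assumed dataset size
print(sizes, [full_batches(n) for n in sizes])
```

Under these assumptions the test and validation splits hold fewer than 32 images each, so `drop_last=True` would leave both loaders without a single batch. A smaller batch size or `drop_last=False` for the evaluation loaders would avoid that.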

**Instantiate the JAFFE or FER2013 dataloader**

In [None]:
#
# make "transform_new_size" dependent on the .json config file
#

# traindl, testdl, valdl = JAFFE_dataloader(path_to_dataset=dataset_path, split_dataset = 0.8, transforms_new_size=[224, 224])

traindl, testdl, valdl = FER2013_dataloader(transforms_new_size=[224, 224])

#for idx, batch in enumerate(traindl):
#  if idx==0:
#    print(batch['image'].shape)
#    print(batch["label"])
#    break

Load training data
Load testing data
Load valid data
created trainset of size: 28709
created testset of size: 3589
created valset of size: 3589


**Coding the JAFFE/FER2013 trainer**

In [None]:
def accuracy(output, target):
    with torch.no_grad():
        batch_size = target.size(0)
        pred = torch.argmax(output, dim=1)
        correct = pred.eq(target).float().sum(0)
        acc = correct * 100 / batch_size
    return [acc]


class JAFFETrainer():
  def __init__(self, model, train_set, val_set, test_set, configs):
    """
    :train_set: dataloader of the train set
    :val_set: dataloader of the validation set
    :test_set: dataloader of the test set
    """
    # print start and configs
    #
    # load configurations like the author defines
    self._configs = configs
    self._lr = self._configs["lr"]
    self._batch_size = self._configs["batch_size"]
    self._momentum = self._configs["momentum"]
    self._weight_decay = self._configs["weight_decay"]
    self._distributed = self._configs["distributed"]
    self._num_workers = self._configs["num_workers"]
    self._device = torch.device(self._configs["device"])
    self._max_epoch_num = self._configs["max_epoch_num"]
    self._max_plateau_count = self._configs["max_plateau_count"]
    # model
    self._model = model(in_channels=configs["in_channels"], num_classes=configs["num_classes"])
    
    self._model.to(self._device)
    # datasets
    self._train_loader = train_set
    self._test_loader = test_set
    self._val_loader = val_set
    # Loss and optimizer
    self._criterion = nn.CrossEntropyLoss().to(self._device)
    self._optimizer = torch.optim.Adam(params=self._model.parameters(),
                                       lr=self._lr,
                                       weight_decay=self._weight_decay)
    self._scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self._optimizer, 
                                                                 patience=self._configs["plateau_patience"],
                                                                 min_lr=1e-6, 
                                                                 verbose=True)
    
    # training info
    self._start_time = datetime.datetime.now()
    self._start_time = self._start_time.replace(microsecond=0)

    #log_dir = os.path.join(
    #    self._configs["cwd"],
    #    self._configs["log_dir"],
    #    "{}_{}".format(
    #        self._configs["model_name"], self._start_time.strftime("%Y%b%d_%H.%M")
    #    ),
    #)

    log_dir = os.path.join(
        self._configs["log_dir"],
        "{}_{}".format(
            self._configs["model_name"], self._start_time.strftime("%Y%b%d_%H.%M")
        ),
    )

    self._writer = SummaryWriter(log_dir)
    self._train_loss = []
    self._train_acc = []
    self._val_loss = []
    self._val_acc = []
    self._best_loss = 1e9
    self._best_acc = 0
    self._test_acc = 0.0
    self._plateau_count = 0
    self._current_epoch_num = 0

    # for checkpoints
    #
    #
    self._checkpoint_dir = self._configs["checkpoint_dir"]
    # self._checkpoint_dir = os.path.join(self._configs["cwd"], "saved/checkpoints")
    #
    #
    if not os.path.exists(self._checkpoint_dir):
        os.makedirs(self._checkpoint_dir, exist_ok=True)

    self._checkpoint_path = os.path.join(
        self._checkpoint_dir,
        "{}_{}".format(
            self._configs["model_name"], self._start_time.strftime("%Y%b%d_%H.%M")
        ),
    )
  
  def _train(self):
    # print("training step")
    self._model.train()
    train_loss, train_acc = 0.0, 0.0

    for i, batch in tqdm(enumerate(self._train_loader), total=len(self._train_loader), leave=False):
      # print(f"size of train loader: {len(self._train_loader)}")
      images = batch["image"].to(self._device)  # .to(self._device)
      # print(f"shape of image tensor: {images.shape}")
      targets = batch["label"].to(self._device)  # tensor? .to ...

      outputs = self._model(images)

      loss = self._criterion(outputs, targets)
      acc = accuracy(outputs, targets)[0]

      train_loss += loss.item()
      train_acc += acc.item()

      self._optimizer.zero_grad()
      loss.backward()
      self._optimizer.step()
    
    i += 1
    self._train_loss.append(train_loss / i)
    self._train_acc.append(train_acc / i)
  
  def _val(self):
    print("validation")
    self._model.eval()
    val_loss, val_acc = 0.0, 0.0
    
    with torch.no_grad():
      for i, batch in tqdm(enumerate(self._val_loader), total=len(self._val_loader), leave=False):
        images = batch["image"].to(self._device)  # .cuda(non_blocking=True)
        targets = batch["label"].to(self._device)  #.cuda(non_blocking=True)

        # compute output, measure accuracy and record loss
        outputs = self._model(images)

        loss = self._criterion(outputs, targets)
        acc = accuracy(outputs, targets)[0]

        val_loss += loss.item()
        val_acc += acc.item()

      i += 1
      self._val_loss.append(val_loss / i)
      self._val_acc.append(val_acc / i)
  
  def _increase_epoch_num(self):
    self._current_epoch_num += 1
  
  def _is_stop(self):
    return (
      self._plateau_count > self._max_plateau_count
      or self._current_epoch_num > self._max_epoch_num
    )

  def _update_training_state(self):
    if self._val_acc[-1] > self._best_acc:
      self._save_weights()
      self._plateau_count = 0
      self._best_acc = self._val_acc[-1]
      self._best_loss = self._val_loss[-1]
    else:
      self._plateau_count += 1

    self._scheduler.step(100 - self._val_acc[-1])  

  def _save_weights(self, test_acc=0.0):
    if self._distributed == 0:
      state_dict = self._model.state_dict()
    else:
      state_dict = self._model.module.state_dict()

    state = {
        **self._configs,
        "net": state_dict,
        "best_loss": self._best_loss,
        "best_acc": self._best_acc,
        "train_losses": self._train_loss,
        "val_loss": self._val_loss,
        "train_acc": self._train_acc,
        "val_acc": self._val_acc,
        "test_acc": self._test_acc,
      }
    torch.save(state, self._checkpoint_path)

  def _logging(self):
    consume_time = str(datetime.datetime.now() - self._start_time)

    message = "\nE{:03d}  {:.3f}/{:.3f}/{:.3f} {:.3f}/{:.3f}/{:.3f} | p{:02d}  Time {}\n".format(
      self._current_epoch_num,
      self._train_loss[-1],
      self._val_loss[-1],
      self._best_loss,
      self._train_acc[-1],
      self._val_acc[-1],
      self._best_acc,
      self._plateau_count,
      consume_time[:-7],
    )

    self._writer.add_scalar(
        "Accuracy/Train", self._train_acc[-1], self._current_epoch_num
    )
    self._writer.add_scalar(
        "Accuracy/Val", self._val_acc[-1], self._current_epoch_num
    )
    self._writer.add_scalar(
        "Loss/Train", self._train_loss[-1], self._current_epoch_num
    )
    self._writer.add_scalar("Loss/Val", self._val_loss[-1], self._current_epoch_num)

    print(message)
  
  def _calc_acc_on_private_test(self):
      self._model.eval()
      test_acc = 0.0
      print("Calc acc on private test..")

      with torch.no_grad():
          for i, batch in tqdm(
              enumerate(self._test_loader), total=len(self._test_loader), leave=False
          ):

              # TODO: implement augment when predict
              # batches are dict-like samples; move them to the configured device
              images = batch["image"].to(self._device)
              targets = batch["label"].to(self._device)

              outputs = self._model(images)
              acc = accuracy(outputs, targets)[0]
              test_acc += acc.item()

          test_acc = test_acc / (i + 1)
      print("Accuracy on private test: {:.3f}".format(test_acc))
      return test_acc

  def _calc_acc_on_private_test_with_tta(self):
      self._model.eval()
      test_acc = 0.0
      print("Calc acc on private test with tta..")

      # for idx in len(self._test_set):
      #     image, label = self._test_set[idx]

      with torch.no_grad():
          for i, batch in tqdm(enumerate(self._test_loader), total=len(self._test_loader), leave=False):

              # TODO: implement augment when predict
              # images = images.to(self._device)
              
              # targets = targets.to(self._device) # .to(non_blocking=True)

              images = batch["image"].to(self._device)  # .cuda(non_blocking=True)
              targets = batch["label"].to(self._device)

              outputs = self._model(images)
              acc = accuracy(outputs, targets)[0]
              test_acc += acc.item()

          test_acc = test_acc / (i + 1)
      print("Accuracy on private test: {:.3f}".format(test_acc))
      return test_acc

  def train(self):
    print("start training")
    # print(self._model)
    while not self._is_stop():
      self._increase_epoch_num()
      self._train()
      self._val()

      self._update_training_state()
      self._logging()
    
    # training stop and then load the checkpoint
    # and produce masks
    try:
      state = torch.load(self._checkpoint_path)
      if self._distributed:
        self._model.module.load_state_dict(state["net"])
      else:
        self._model.load_state_dict(state["net"])
      # store the result so that _save_weights and the summary report it
      self._test_acc = self._calc_acc_on_private_test_with_tta()
      self._save_weights()
    except Exception as e:
      print("Testing error when training stop")
      print(e)
    
    self._writer.add_text(
        "Summary", "Converged after {} epochs".format(self._current_epoch_num)
    )
    self._writer.add_text(
        "Summary",
        "Best validation accuracy: {:.3f}".format(self._best_acc),
    )
    self._writer.add_text(
        "Summary", "Private test accuracy: {:.3f}".format(self._test_acc)
    )
    self._writer.close()
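
The stopping rule in `_is_stop` and `_update_training_state` uses strict `>` comparisons, which affects how many epochs actually run. The replay below is a pure-Python sketch of that rule only; it does not model the learning-rate scheduler or checkpointing. `epochs_until_stop` is a hypothetical helper and the accuracy curve is made up.

```python
def epochs_until_stop(val_accuracies, max_plateau_count, max_epoch_num):
    """Replay the trainer's stopping rule on a list of validation accuracies."""
    best_acc, plateau_count, epoch = 0.0, 0, 0
    for acc in val_accuracies:
        # Same test as _is_stop, evaluated before each epoch.
        if plateau_count > max_plateau_count or epoch > max_epoch_num:
            break
        epoch += 1
        if acc > best_acc:          # improvement: a checkpoint would be saved here
            best_acc, plateau_count = acc, 0
        else:                       # no improvement: one more plateau epoch
            plateau_count += 1
    return epoch, best_acc


# Made-up accuracy curve that peaks at epoch 3 and then stalls.
curve = [40.0, 52.0, 55.0, 54.0, 54.5, 53.0, 55.0, 54.0, 53.5, 52.0]
print(epochs_until_stop(curve, max_plateau_count=2, max_epoch_num=50))
```

With `max_plateau_count=2` the replay runs three stalled epochs before it stops, because the counter has to exceed the limit rather than reach it. The same holds for `max_epoch_num`, so one extra epoch runs in both cases.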

**Training the Residual Masking Network on JAFFE or FER2013**

In [None]:
# config file for JAFFE
train_config = {
	"data_path": "content/jaffe/dataset/",
	"image_size": 224,
	"in_channels": 3,
	"num_classes": 7,
	"arch": "resmasking_dropout1", # alexnet
	"lr":  0.0001,
	"weighted_loss": 0,
	"momentum": 0.9,
	"weight_decay": 0.001,
	"distributed": 0,
	"batch_size": 16, 
  "num_workers": 2,
  "device": "cuda:0",
  "max_epoch_num": 300,
  "max_plateau_count": 25,
  "plateau_patience": 5,
  "steplr": 50,
  "log_dir": "log",
  "checkpoint_dir": "content/drive/resmaskfer2013ckp/checkpoint/",
  "model_name": "test",
  "cwd": "content/"
}

train_fer_config = {
	"data_path": "content/fer2013/dataset/",
	"image_size": 224,
	"in_channels": 3,
	"num_classes": 7,
	"arch": "resmasking_dropout1", # alexnet
	"lr":  0.0001,
	"weighted_loss": 0,
	"momentum": 0.9,
	"weight_decay": 0.001,
	"distributed": 0,
	"batch_size": 32, 
  "num_workers": 2,
  "device": "cpu", # "cuda:0", # "cpu" 
  "max_epoch_num": 50,
  "max_plateau_count": 6,
  "plateau_patience": 2,
  "steplr": 50,
  "log_dir": "/drive/MyDrive/resmaskfer2013ckp/log/",
  "checkpoint_dir": "/drive/MyDrive/resmaskfer2013ckp/saved/checkpoints/",
  "model_name": "test",
  "cwd": "content/"
}

train_config2 = {
	"data_path": "content/fer2013/dataset/",
	"image_size": 224,
	"in_channels": 3,
	"num_classes": 8,
	"arch": "resmasking_dropout1",
	"lr":  0.0001,
	"momentum": 0.9,
	"weight_decay": 1e-3,
	"distributed": 0,
	"batch_size": 10,
	"num_workers": 5,
	"device": "cpu", #"cuda:0",
	"max_epoch_num": 100000,
	"max_plateau_count": 20,
	"plateau_patience": 4,
	"steplr": 50,
	"log_dir": "saved/logs",
	"checkpoint_dir": "saved/checkpoints",
	"model_name": "aug",
  "cwd": "content/"
}

def train(train_config):

  model = models.__dict__[train_config["arch"]]

  # load train, test, val data traindl, testdl, valdl
  print(f"SIZE, amount of batches of traindl: {len(traindl)}")
  print(f"SIZE, amount of batches of testdl: {len(testdl)}")
  print(f"SIZE, amount of batches of valdl: {len(valdl)}")
  trainer = JAFFETrainer(model=model, train_set=traindl, val_set=valdl, test_set=testdl, configs=train_config)
  trainer.train()


train(train_fer_config)

SIZE, amount of batches of traindl: 897
SIZE, amount of batches of testdl: 112
SIZE, amount of batches of valdl: 112
start training


  0%|          | 1/897 [01:16<18:55:30, 76.04s/it]
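
The figures in the log above can be reproduced with a little arithmetic, which gives a feeling for how long one epoch takes at this speed. The sketch assumes the 28709 training samples and the batch size of 32 reported earlier, and takes the 76.04 s per iteration from the progress bar. Both function names are hypothetical.

```python
import datetime


def full_batches(n_samples: int, batch_size: int) -> int:
    """Batches per epoch when the last incomplete batch is dropped."""
    return n_samples // batch_size


def remaining_time(total_batches: int, done_batches: int, seconds_per_batch: float):
    """Naive estimate: remaining batches times the current time per batch."""
    seconds = (total_batches - done_batches) * seconds_per_batch
    return datetime.timedelta(seconds=round(seconds))


batches = full_batches(28709, 32)
print(batches)
print(remaining_time(batches, 1, 76.04))
```

The estimate comes out at roughly 19 hours per epoch, in line with the progress bar. `train_fer_config` selects `"device": "cpu"`, which is a likely reason for the slow iterations; on a GPU runtime `"cuda:0"` would be the natural choice.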

The next cell visualizes the masks by extracting the activations of each stage and saving them as images.

In [None]:
#
# source: https://github.com/phamquiluan/ResidualMaskingNetwork/blob/b1b3bb0c8a19e4230357dbbbdc00689f2198c65a/_ar/masking_provement.py
#
import cv2
import glob
from rmn.models import resmasking_dropout1  # same models module that the training cell uses
from natsort import natsorted

transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])


def activations_mask(tensor): # (1, 64, 56, 56)
    # print("Input to activation_mask: ", tensor)
    tensor = torch.squeeze(tensor, 0) # e.g. (64, 56, 56)
    # print("Input to activation_mask after squeeze: ", tensor)
    tensor = torch.mean(tensor, 0)
    tensor = tensor.detach().cpu().numpy()
    tensor = np.maximum(tensor, 0)
    tensor = cv2.resize(tensor, (224, 224)) # (224, 224)
    # print("Input to activation_mask after resizing: ", tensor)
    tensor = tensor - np.min(tensor)
    tensor = tensor / max(np.max(tensor), 1e-8)  # guard against an all-zero map

    heatmap = cv2.applyColorMap(np.uint8(255 * tensor), cv2.COLORMAP_JET) # (224, 224, 3)
    # print("Input to activation_mask after colormap: ", heatmap)
    return heatmap


model = resmasking_dropout1(3, 7)

# state = torch.load("./content/saved/checkpoints/test_2021Dec19_23.13")
# state = torch.load("./drive/MyDrive/residualmaskingnetwork_checkpoints/content/saved/checkpoints/test_2021Dec19_23.13", map_location=torch.device('cpu'))
# state for FER2013 training
state = torch.load("./drive/MyDrive/fer2013/training3101/content/saved/checkpoints/test_2022Jan31_09.48", map_location=torch.device('cpu'))

# images jaffe 
# images_path = "./drive/MyDrive/residualmaskingnetwork_checkpoints/images/*.tiff"
# images fer2013
images_path = "./drive/MyDrive/fer2013/training3101/content/saved/checkpoints/imgs/*.tiff"

model.load_state_dict(state["net"])
# model.cuda()
model.eval()

in_images, names, feats, masks, concats = [], [], [], [], []

def add_to_arrays(name, inp_img, feat, mask, concat):
  names.append(name)
  in_images.append(inp_img)
  feats.append(activations_mask(feat))
  masks.append(activations_mask(mask))
  concats.append(activations_mask(concat))

# run inference on the images
print("infere images")
for image_path in natsorted(glob.glob(images_path, recursive=True)):
  image_name = os.path.basename(image_path)
  
  image = cv2.imread(image_path)
  image = cv2.resize(image, (224, 224))
  tensor = transform(image)
  tensor = torch.unsqueeze(tensor, 0)
  # tensor = tensor.cuda()

  # forward
  x = model.conv1(tensor)  # 112
  x = model.bn1(x)
  x = model.relu(x)
  x = model.maxpool(x)  # 56

  x = model.layer1(x)  # 56
  m = model.mask1(x)
  x_copy = x
  # feat_1 = activations_mask(x)
  # heat_1 = activations_mask(m)
  x = x * (1 + m)
  #print(f"Shape of \n layer1: \n{np.shape(x_copy)} \n mask1:\n {np.shape(m)} \n concat1:\n {np.shape(x)}")
  add_to_arrays(image_name, image, x_copy, m, x)
  # conc_1 = activations_mask(x)

  x = model.layer2(x)  # 28
  m = model.mask2(x)
  x_copy = x
  # feat_1 = activations_mask(x)
  # heat_1 = activations_mask(m)
  x = x * (1 + m)
  #print(f"Shape of \n layer2: \n{np.shape(x_copy)} \n mask2:\n {np.shape(m)} \n concat2:\n {np.shape(x)}")
  add_to_arrays(image_name, image, x_copy, m, x)
  # conc_1 = activations_mask(x)

  x = model.layer3(x)  # 14
  m = model.mask3(x)
  x_copy = x
  # feat_1 = activations_mask(x)
  # heat_1 = activations_mask(m)
  x = x * (1 + m)
  #print(f"Shape of \n layer3: \n{np.shape(x_copy)} \n mask3:\n {np.shape(m)} \n concat3:\n {np.shape(x)}")
  add_to_arrays(image_name, image, x_copy, m, x)
  # conc_1 = activations_mask(x)

  x = model.layer4(x)  # 7
  m = model.mask4(x)
  x_copy = x
  x = x * (1 + m)
  #print(f"Shape of \n layer4: \n{np.shape(x_copy)} \n mask4:\n {np.shape(m)} \n concat4:\n {np.shape(x)}")
  add_to_arrays(image_name, image, x_copy, m, x)
  

  x = model.avgpool(x)
  x = torch.flatten(x, 1)

  output = model.fc(x)

# ######################################################
#
# create the "masking" output folder automatically
# if it does not exist yet
#
# ######################################################
os.makedirs('./masking', exist_ok=True)

print("create images")
for idx, (name, image, feat, heat, conc) in enumerate(zip(names, in_images, feats, masks, concats)):
  #
  # superimposing heatmap and image
  #
  #print(f"name: \n{name}\n image: \n{np.shape(image)}\n feat: \n{np.shape(feat)}\n heat:\n {np.shape(heat)}\n conc: \n{np.shape(conc)}\n")

  gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  heatgrayimg = cv2.cvtColor(heat, cv2.COLOR_BGR2GRAY)
  heatmap_img = cv2.applyColorMap(heatgrayimg, cv2.COLORMAP_JET)
  # print(f"shape of heatmap_img {np.shape(heatmap_img)}")
  # supimp_image = cv2.addWeighted(image, 0.4, heatmap_img, 0.6, 0)
  
  r1 = Image.fromarray(image)
  r2 = Image.fromarray(heat)
  supimp_image = np.array(Image.blend(r1, r2, 0.5))
  
  # print(f"shape of input image {np.shape(image)}")
  # print(f"shape of heat {np.shape(heatmap_img)}")
  # print(f"shape of heat {np.shape(supimp_image)}")
  
  cv2.imwrite(
    f"./masking/{'supimp_'+str(idx % 4)+'_'+name}",
    supimp_image
  )

  cv2.imwrite(
      f"./masking/{'feat_'+str(idx % 4)+'_'+name}",
      np.concatenate((image, feat), axis=1),
  )
  # heat are masks
  cv2.imwrite(
      f"./masking/{'heat_'+str(idx % 4)+'_'+name}",
      np.concatenate((image, heat), axis=1),
  )
  cv2.imwrite(
      f"./masking/{'conc_'+str(idx % 4)+'_'+name}",
      np.concatenate((image, conc), axis=1),
  )

infere images
create images
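
The core of `activations_mask` is a reduction from a stack of feature channels to one map in the range [0, 1]. The sketch below isolates that step with NumPy only, leaving out the resize and the colour map. `normalize_activation` is a hypothetical name, the inputs are random or constant arrays rather than real network activations, and the `eps` term guards against a division by zero when every averaged activation is non-positive.

```python
import numpy as np


def normalize_activation(feature_map: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Collapse (C, H, W) activations to an (H, W) map scaled to [0, 1]."""
    heat = feature_map.mean(axis=0)      # average over channels
    heat = np.maximum(heat, 0)           # keep positive responses only
    heat = heat - heat.min()
    return heat / (heat.max() + eps)     # eps avoids 0/0 on an all-zero map


rng = np.random.default_rng(0)
random_map = normalize_activation(rng.normal(size=(64, 56, 56)))
dead_map = normalize_activation(-np.ones((64, 56, 56)))   # every activation negative

print(random_map.shape, float(random_map.min()))
print(float(dead_map.max()))
```

Multiplying the result by 255 and casting to `uint8` gives the single-channel image that a colour map can then be applied to.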


In [None]:
!zip -r /content/resmaking_output.zip /content/masking

  adding: content/masking/ (stored 0%)
  adding: content/masking/conc_0_14_3.tiff (deflated 1%)
  adding: content/masking/heat_2_141_3.tiff (deflated 0%)
  adding: content/masking/supimp_2_112_6.tiff (deflated 1%)
  adding: content/masking/conc_1_56_0.tiff (deflated 1%)
  adding: content/masking/conc_2_130_4.tiff (deflated 1%)
  adding: content/masking/feat_0_102_6.tiff (deflated 1%)
  adding: content/masking/conc_2_136_6.tiff (deflated 1%)
  adding: content/masking/supimp_2_135_6.tiff (deflated 1%)
  adding: content/masking/conc_1_145_3.tiff (deflated 1%)
  adding: content/masking/supimp_3_14_3.tiff (deflated 1%)
  adding: content/masking/heat_0_0_0.tiff (deflated 0%)
  adding: content/masking/conc_1_55_5.tiff (deflated 2%)
  adding: content/masking/supimp_3_97_3.tiff (deflated 1%)
  adding: content/masking/supimp_0_85_2.tiff (deflated 1%)
  adding: content/masking/supimp_1_34_3.tiff (deflated 1%)
  adding: content/masking/conc_1_17_2.tiff (deflated 1%)
  adding: content/masking/feat_

In [None]:
from google.colab import files
files.download("/content/resmaking_output.zip")
