In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
dataset_dir = '/content/drive/MyDrive/datasets/Cat Emotions.v1-test.folder/Master folder'

In [3]:
import os
from PIL import Image
import torch
import torchvision
import numpy as np
import albumentations
from tqdm import tqdm

In [4]:
configs = {
    "train_batch_size": 64,
    "valid_batch_size": 1,
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "lr": 1e-4,
    "epochs": 30,
    "train_augmentations": albumentations.Compose([
        albumentations.HorizontalFlip(),
        albumentations.RandomCrop(480, 480, p=0.05),
        albumentations.Resize(224, 224),
        albumentations.Normalize()]),
    "valid_augmentations": albumentations.Compose([
        albumentations.Resize(224, 224),
        albumentations.Normalize()])
}

In [5]:
class EmotionDataset:
    def __init__(self, parent_folder: str, mode: str = 'train', shuffle=True, augmentations=None):
        self.parent_folder = parent_folder
        self.data = []
        self.mode = mode
        self.labels={'angry':0, 'disgusted':1, 'scared':2, 'happy':3, 'normal':4, 'sad':5, 'surprised':6}
        self.augmentations = augmentations

        subset = self.mode
        subset_path = os.path.join(parent_folder, subset)
        for emotion in os.listdir(subset_path):
            emotion_path = os.path.join(subset_path, emotion)
            if os.path.isdir(emotion_path):
                for file_name in os.listdir(emotion_path):
                    file_path = os.path.join(emotion_path, file_name)
                    if os.path.isfile(file_path):
                        image = Image.open(file_path)
                        self.data.append((image, self.labels[emotion.lower()]))

        if shuffle:
            np.random.shuffle(self.data)

    def __getitem__(self, index: int):
        image, label = self.data[index]
        image = np.array(image)
        if self.augmentations:
            image = self.augmentations(image=image)['image']
        image = torch.tensor(image.transpose(2, 0, 1))
        label_one_hot = np.zeros(len(self.labels))
        label_one_hot[label] = 1
        label_one_hot = torch.tensor(label_one_hot)


        return  image.float(), label_one_hot.float()

    def __len__(self) -> int:
        return len(self.data)

In [6]:
train_dataset = EmotionDataset(dataset_dir, augmentations=configs["train_augmentations"])
valid_dataset = EmotionDataset(dataset_dir, 'valid', augmentations=configs["valid_augmentations"])

In [7]:
train_dataloder = torch.utils.data.DataLoader(train_dataset, batch_size=configs["train_batch_size"], shuffle=True)
valid_dataloder = torch.utils.data.DataLoader(valid_dataset, batch_size=configs["valid_batch_size"], shuffle=True)

In [8]:
model = torchvision.models.resnet18(pretrained=True)
model.fc = torch.nn.Sequential(torch.nn.Linear(512, 7), torch.nn.Softmax(dim=1))
model.to(configs["device"])

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 132MB/s]


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
loss = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=configs["lr"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=4, factor=0.5)

In [10]:
best_acc = 0.5
for epoch in range(configs["epochs"]):
    model.train()
    for image, label in tqdm(train_dataloder):
        image = image.to(configs["device"])
        label = label.to(configs["device"])
        optimizer.zero_grad()
        output = model(image)
        loss_value = loss(output, label)
        loss_value.backward()
        optimizer.step()
        scheduler.step(loss_value)
    model.eval()
    with torch.no_grad():

        total_loss = 0
        predicts = []
        real_labels = []
        for image, label in tqdm(valid_dataloder):
            image = image.to(configs["device"])
            label = label.to(configs["device"])
            output = model(image)
            loss_value = loss(output, label)
            total_loss += loss_value.item()
            predicts.append(output.argmax(1).cpu().numpy())
            real_labels.append(label.argmax(1).cpu().numpy())
        print(f"Epoch: {epoch}, loss: {total_loss/len(valid_dataloder)}, acc: {np.sum(np.array(predicts) == np.array(real_labels))/len(valid_dataloder)}")

        if np.sum(np.array(predicts) == np.array(real_labels))/len(valid_dataloder) > best_acc:
            best_acc = np.sum(np.array(predicts) == np.array(real_labels))/len(valid_dataloder)
            torch.save(model.state_dict(), 'best_model.pt')

100%|██████████| 8/8 [02:41<00:00, 20.16s/it]
100%|██████████| 169/169 [00:21<00:00,  7.95it/s]


Epoch: 0, loss: 1.8477304560192944, acc: 0.38461538461538464


100%|██████████| 8/8 [02:25<00:00, 18.18s/it]
100%|██████████| 169/169 [00:19<00:00,  8.63it/s]


Epoch: 1, loss: 1.7490766986587343, acc: 0.4556213017751479


100%|██████████| 8/8 [02:30<00:00, 18.75s/it]
100%|██████████| 169/169 [00:19<00:00,  8.79it/s]


Epoch: 2, loss: 1.6767400510212374, acc: 0.5739644970414202


100%|██████████| 8/8 [02:29<00:00, 18.68s/it]
100%|██████████| 169/169 [00:19<00:00,  8.61it/s]


Epoch: 3, loss: 1.6412371164242896, acc: 0.6035502958579881


100%|██████████| 8/8 [02:26<00:00, 18.27s/it]
100%|██████████| 169/169 [00:18<00:00,  9.17it/s]


Epoch: 4, loss: 1.6224241334305713, acc: 0.6094674556213018


100%|██████████| 8/8 [02:23<00:00, 17.96s/it]
100%|██████████| 169/169 [00:19<00:00,  8.70it/s]


Epoch: 5, loss: 1.609674721779908, acc: 0.6153846153846154


100%|██████████| 8/8 [02:29<00:00, 18.73s/it]
100%|██████████| 169/169 [00:18<00:00,  9.28it/s]


Epoch: 6, loss: 1.5938741215587369, acc: 0.6804733727810651


100%|██████████| 8/8 [02:28<00:00, 18.52s/it]
100%|██████████| 169/169 [00:18<00:00,  9.10it/s]


Epoch: 7, loss: 1.5875200562223175, acc: 0.6745562130177515


100%|██████████| 8/8 [02:23<00:00, 17.96s/it]
100%|██████████| 169/169 [00:20<00:00,  8.37it/s]


Epoch: 8, loss: 1.5861142281244491, acc: 0.6686390532544378


100%|██████████| 8/8 [02:24<00:00, 18.08s/it]
100%|██████████| 169/169 [00:18<00:00,  9.05it/s]


Epoch: 9, loss: 1.5841539816038144, acc: 0.6745562130177515


100%|██████████| 8/8 [02:31<00:00, 18.98s/it]
100%|██████████| 169/169 [00:20<00:00,  8.24it/s]


Epoch: 10, loss: 1.5840965050917406, acc: 0.6745562130177515


100%|██████████| 8/8 [02:31<00:00, 18.90s/it]
100%|██████████| 169/169 [00:19<00:00,  8.85it/s]


Epoch: 11, loss: 1.5838689627732045, acc: 0.6745562130177515


100%|██████████| 8/8 [02:28<00:00, 18.52s/it]
100%|██████████| 169/169 [00:19<00:00,  8.62it/s]


Epoch: 12, loss: 1.5826270185278717, acc: 0.6745562130177515


100%|██████████| 8/8 [02:28<00:00, 18.50s/it]
100%|██████████| 169/169 [00:20<00:00,  8.33it/s]


Epoch: 13, loss: 1.5835171507660455, acc: 0.6745562130177515


100%|██████████| 8/8 [02:36<00:00, 19.55s/it]
100%|██████████| 169/169 [00:18<00:00,  8.97it/s]


Epoch: 14, loss: 1.5833765091980703, acc: 0.6745562130177515


100%|██████████| 8/8 [02:33<00:00, 19.14s/it]
100%|██████████| 169/169 [00:18<00:00,  9.09it/s]


Epoch: 15, loss: 1.5827055499398497, acc: 0.6686390532544378


100%|██████████| 8/8 [02:31<00:00, 18.88s/it]
100%|██████████| 169/169 [00:18<00:00,  9.10it/s]


Epoch: 16, loss: 1.5828395581104346, acc: 0.6745562130177515


100%|██████████| 8/8 [02:32<00:00, 19.08s/it]
100%|██████████| 169/169 [00:18<00:00,  9.07it/s]


Epoch: 17, loss: 1.5824221019914164, acc: 0.6745562130177515


100%|██████████| 8/8 [02:31<00:00, 18.88s/it]
100%|██████████| 169/169 [00:18<00:00,  9.01it/s]


Epoch: 18, loss: 1.5831586616279105, acc: 0.6745562130177515


100%|██████████| 8/8 [02:26<00:00, 18.35s/it]
100%|██████████| 169/169 [00:18<00:00,  9.18it/s]


Epoch: 19, loss: 1.5811009752679859, acc: 0.6686390532544378


100%|██████████| 8/8 [02:24<00:00, 18.10s/it]
100%|██████████| 169/169 [00:18<00:00,  9.14it/s]


Epoch: 20, loss: 1.5809369729115412, acc: 0.6745562130177515


100%|██████████| 8/8 [02:20<00:00, 17.51s/it]
100%|██████████| 169/169 [00:18<00:00,  9.26it/s]


Epoch: 21, loss: 1.5826016608074571, acc: 0.6745562130177515


100%|██████████| 8/8 [02:26<00:00, 18.29s/it]
100%|██████████| 169/169 [00:18<00:00,  9.27it/s]


Epoch: 22, loss: 1.5823662824179294, acc: 0.6686390532544378


100%|██████████| 8/8 [02:22<00:00, 17.82s/it]
100%|██████████| 169/169 [00:20<00:00,  8.37it/s]


Epoch: 23, loss: 1.5843477347898764, acc: 0.6745562130177515


100%|██████████| 8/8 [02:28<00:00, 18.56s/it]
100%|██████████| 169/169 [00:19<00:00,  8.56it/s]


Epoch: 24, loss: 1.583178366429707, acc: 0.6745562130177515


100%|██████████| 8/8 [02:23<00:00, 17.91s/it]
100%|██████████| 169/169 [00:18<00:00,  9.21it/s]


Epoch: 25, loss: 1.582446090568452, acc: 0.6745562130177515


100%|██████████| 8/8 [02:27<00:00, 18.46s/it]
100%|██████████| 169/169 [00:19<00:00,  8.57it/s]


Epoch: 26, loss: 1.58232483920261, acc: 0.6745562130177515


100%|██████████| 8/8 [02:27<00:00, 18.40s/it]
100%|██████████| 169/169 [00:18<00:00,  9.26it/s]


Epoch: 27, loss: 1.5851143365780982, acc: 0.6745562130177515


100%|██████████| 8/8 [02:23<00:00, 17.98s/it]
100%|██████████| 169/169 [00:19<00:00,  8.68it/s]


Epoch: 28, loss: 1.5845604380206948, acc: 0.6745562130177515


100%|██████████| 8/8 [02:27<00:00, 18.39s/it]
100%|██████████| 169/169 [00:19<00:00,  8.59it/s]


Epoch: 29, loss: 1.5831904439531135, acc: 0.6745562130177515
