In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## data loader

In [2]:
!unzip /content/drive/MyDrive/Colab/FERVT/data/FER2013Train.zip

[1;30;43m스트리밍 출력 내용이 길어서 마지막 5000줄이 삭제되었습니다.[0m
  inflating: FER2013Train/fer0023636.png  
  inflating: FER2013Train/fer0023637.png  
  inflating: FER2013Train/fer0023638.png  
  inflating: FER2013Train/fer0023639.png  
  inflating: FER2013Train/fer0023640.png  
  inflating: FER2013Train/fer0023641.png  
  inflating: FER2013Train/fer0023642.png  
  inflating: FER2013Train/fer0023643.png  
  inflating: FER2013Train/fer0023644.png  
  inflating: FER2013Train/fer0023645.png  
  inflating: FER2013Train/fer0023646.png  
  inflating: FER2013Train/fer0023647.png  
  inflating: FER2013Train/fer0023648.png  
  inflating: FER2013Train/fer0023649.png  
  inflating: FER2013Train/fer0023650.png  
  inflating: FER2013Train/fer0023651.png  
  inflating: FER2013Train/fer0023652.png  
  inflating: FER2013Train/fer0023653.png  
  inflating: FER2013Train/fer0023654.png  
  inflating: FER2013Train/fer0023655.png  
  inflating: FER2013Train/fer0023656.png  
  inflating: FER2013Train/fer0023657.png  
  in

In [3]:
!unzip /content/drive/MyDrive/Colab/FERVT/data/FER2013Test.zip

Archive:  /content/drive/MyDrive/Colab/FERVT/data/FER2013Test.zip
   creating: FER2013Test/
  inflating: FER2013Test/fer0032220.png  
  inflating: FER2013Test/fer0032222.png  
  inflating: FER2013Test/fer0032223.png  
  inflating: FER2013Test/fer0032224.png  
  inflating: FER2013Test/fer0032225.png  
  inflating: FER2013Test/fer0032226.png  
  inflating: FER2013Test/fer0032227.png  
  inflating: FER2013Test/fer0032228.png  
  inflating: FER2013Test/fer0032229.png  
  inflating: FER2013Test/fer0032230.png  
  inflating: FER2013Test/fer0032231.png  
  inflating: FER2013Test/fer0032232.png  
  inflating: FER2013Test/fer0032233.png  
  inflating: FER2013Test/fer0032234.png  
  inflating: FER2013Test/fer0032235.png  
  inflating: FER2013Test/fer0032236.png  
  inflating: FER2013Test/fer0032237.png  
  inflating: FER2013Test/fer0032238.png  
  inflating: FER2013Test/fer0032239.png  
  inflating: FER2013Test/fer0032240.png  
  inflating: FER2013Test/fer0032241.png  
  inflating: FER2013Test/f

In [4]:
!unzip /content/drive/MyDrive/Colab/FERVT/data/FER2013Valid.zip

Archive:  /content/drive/MyDrive/Colab/FERVT/data/FER2013Valid.zip
   creating: FER2013Valid/
  inflating: FER2013Valid/fer0028638.png  
  inflating: FER2013Valid/fer0028639.png  
  inflating: FER2013Valid/fer0028640.png  
  inflating: FER2013Valid/fer0028641.png  
  inflating: FER2013Valid/fer0028642.png  
  inflating: FER2013Valid/fer0028643.png  
  inflating: FER2013Valid/fer0028644.png  
  inflating: FER2013Valid/fer0028645.png  
  inflating: FER2013Valid/fer0028646.png  
  inflating: FER2013Valid/fer0028647.png  
  inflating: FER2013Valid/fer0028648.png  
  inflating: FER2013Valid/fer0028649.png  
  inflating: FER2013Valid/fer0028650.png  
  inflating: FER2013Valid/fer0028651.png  
  inflating: FER2013Valid/fer0028652.png  
  inflating: FER2013Valid/fer0028653.png  
  inflating: FER2013Valid/fer0028654.png  
  inflating: FER2013Valid/fer0028655.png  
  inflating: FER2013Valid/fer0028656.png  
  inflating: FER2013Valid/fer0028657.png  
  inflating: FER2013Valid/fer0028658.png  
  i

## Model

In [5]:
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)


def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)


class MultiHeadedSelfAttention(nn.Module):
    """Multi-Headed Dot Product Attention"""

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # for visualization

    def forward(self, x, mask):
        """
        x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
        mask : (B(batch_size) x S(seq_len))
        * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
        """
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
        scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
        if mask is not None:
            mask = mask[:, None, None, :].float()
            scores -= 10000.0 * (1.0 - mask)
        scores = self.drop(F.softmax(scores, dim=-1))
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
        h = (scores @ v).transpose(1, 2).contiguous()
        # -merge-> (B, S, D)
        h = merge_last(h, 2)
        self.scores = scores
        return h


class PositionWiseFeedForward(nn.Module):
    """FeedForward Neural Networks for each position"""

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, dim)

    def forward(self, x):
        # (B, S, D) -> (B, S, D_ff) -> (B, S, D)
        return self.fc2(self.gelu(self.fc1(x)))

    # GELU only support pytorch >=1.7.0,our model use 1.0.0,so define it.
    def gelu(self, x):
        return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Block(nn.Module):
    """Transformer Block"""

    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        h = self.drop(self.proj(self.attn(self.norm1(x), mask)))
        x = x + h
        h = self.drop(self.pwff(self.norm2(x)))
        x = x + h
        return x


class Transformer(nn.Module):
    """Transformer with Self-Attentive Blocks"""

    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for block in self.blocks:
            x = block(x, mask)
        return x

In [9]:
import torch.nn.functional as F
import torch.nn as nn
import torch
import torchvision
import torchvision.models as models


# backbone + token_embedding + position_embedding
class Backbone(nn.Module):
    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(Backbone, self).__init__()

        resnet = models.resnet34(weights=models.ResNet34_Weights.DEFAULT)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        #  feature resize networks
        # shape trans 128
        self.convtran1 = nn.Conv2d(128, 3, 21, 1)
        self.bntran1 = nn.BatchNorm2d(3)
        self.convtran2 = nn.Conv2d(256, 3, 7, 1)
        self.bntran2 = nn.BatchNorm2d(3)
        self.convtran3 = nn.Conv2d(512, 3, 2, 1,1)
        self.bntran3 = nn.BatchNorm2d(3)
        # Visual Token Embedding.
        self.layernorm = nn.LayerNorm(192)
        self.dropout = nn.Dropout(0.2)
        self.line = nn.Linear(192, 192)
        # class token init
        self.class_token = nn.Parameter(torch.zeros(1, 192))
        # position embedding
        self.pos_embedding = nn.Parameter(torch.zeros(4, 192))

        self.apply(self.weight_init)

    def forward(self, x):
        batchsize = x.shape[0]
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)

        x = self.layer2(x)

        # L1  feature transformation from the pyramid features
        l1 = F.leaky_relu(self.bntran1(self.convtran1(x)))
        # L1 reshape to (1 x c h w)
        l1 = l1.view(batchsize,1,-1)
        # L1 token_embedding to T1    L1(1xCHW)--->T1(1xD)
        # in this model D=128
        l1 = self.line(self.dropout(F.relu(self.layernorm(l1))))

        x = self.layer3(x)
        l2 = F.leaky_relu(self.bntran2(self.convtran2(x)))
        l2 = l2.view(batchsize,1,-1)
        l2 = self.line(self.dropout(F.relu(self.layernorm(l2))))

        x = self.layer4(x)
        l3 = F.leaky_relu(self.bntran3(self.convtran3(x)))
        l3 = l3.view(batchsize,1,-1)
        l3 = self.line(self.dropout(F.relu(self.layernorm(l3))))

        x = torch.cat((l1, l2), dim=1)
        x = torch.cat((x, l3), dim=1)
        x = torch.cat((self.class_token.expand(batchsize, 1, 192), x), dim=1)
        x = x + self.pos_embedding.expand(batchsize, 4, 192)

        return x


#  refer to SubSection 3.3
# input: img(batchsize,c,h,w)--->output: img_feature_map(batchsize,c,h,w)
# in FER+ (b,3,48,48)
class GWA(nn.Module):
    @staticmethod
    def weight_init(m):
        if isinstance(m,nn.Conv2d):
            nn.init.kaiming_uniform_(m.weight)
        elif isinstance(m,nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(GWA, self).__init__()
        # low level feature extraction
        self.conv1 = nn.Conv2d(3, 64, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 3, 1)
        self.bn2 = nn.BatchNorm2d(3)
        # 图像分割，每块16x16，用一个卷积层实现
        self.patch_embeddings = nn.Conv2d(in_channels=3,
                                          out_channels=9408,
                                          kernel_size=(56, 56),
                                          stride=(56, 56))
        # 使用自适应pool压缩一维
        self.aap = nn.AdaptiveAvgPool2d((1, 1))

        self.apply(self.weight_init)

    def forward(self, x):
        img = x
        batchsize = x.shape[0]
        x = self.patch_embeddings(x)
        x = x.flatten(2).transpose(-1, -2).view(batchsize, 16, 3, 56, 56)  # （batchsize,9,768）（batchsize,9,3,256
        temp = []
        for i in range(x.shape[1]):
            temp.append(F.leaky_relu(self.bn2(self.conv2(
                F.leaky_relu(self.bn1(self.conv1(x[:, i, :, :, :])))))).unsqueeze(0).transpose(0, 1))

        # x = x.view(batchsize, 9, 3, 16, 16)
        # x = F.softmax(torch.matmul(x, torch.transpose(x, 3, 4)) / 3)
        x = torch.cat(tuple(temp), dim=1)
        query = x
        key = torch.transpose(query, 3, 4)
        attn = F.softmax(torch.matmul(query, key) / 56,dim=1)
        # nattn = torch.zeros(batchsize, 9, 3, 1, 1)
        temp = []
        for i in range(attn.shape[1]):
            temp.append(self.aap(attn[:, i, :, :, :]).unsqueeze(0).transpose(0, 1))
        pattn = torch.ones(56, 56).cuda() * torch.cat(tuple(temp), dim=1)
        pattn = pattn.permute(0,2,3,1,4).contiguous()
        pattn = pattn.view(batchsize, 3, 224, 224).cuda()
        map = pattn * img  # (b,3,48,48)
        return img, map


class GWA_Fusion(nn.Module):
    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            nn.init.xavier_uniform_(m.weight, 0.1)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def __init__(self):
        super(GWA_Fusion, self).__init__()
        # 原图特征转换网络
        self.convt1 = nn.Conv2d(3, 3, (3, 3), 1, 1)
        self.bnt1 = nn.BatchNorm2d(3)
        # map特征转换网络
        self.convt2 = nn.Conv2d(3, 3, (3, 3), 1, 1)
        self.bnt2 = nn.BatchNorm2d(3)
        # RFN参与特征融合网络
        self.convrfn1 = nn.Conv2d(3,3,(3,3),1,1)
        self.bnrfn1 = nn.BatchNorm2d(3)
        self.prelu1 = nn.PReLU(3)
        self.convrfn2 = nn.Conv2d(3, 3, (3, 3), 1, 1)
        self.bnrfn2 = nn.BatchNorm2d(3)
        self.prelu2 = nn.PReLU(3)
        self.convrfn3 = nn.Conv2d(3, 3, (3, 3), 1, 1)
        self.sigmod = nn.Sigmoid()

        self.apply(self.weight_init)

    def forward(self, img, map):
        img_trans = F.relu(self.bnt1(self.convt1(img)))
        map_trans = F.relu(self.bnt2(self.convt1(map)))
        result = self.prelu1(self.bnrfn1(self.convrfn1(img_trans + map_trans)))
        result = self.prelu2(self.bnrfn2(self.convrfn2(result)))
        result = self.sigmod(self.convrfn3(result+img_trans + map_trans))

        return result


class VTA(nn.Module):
    def __init__(self):
        super(VTA, self).__init__()

        self.transformer = Transformer(num_layers=12, dim=192, num_heads=8,
                                       ff_dim=768, dropout=0.1)
        self.layernorm = nn.LayerNorm(192)
        self.fc = nn.Linear(192, 8)


    def forward(self, x):
        x = self.transformer(x)
        # x = x.transpose(1, 2)
        x = self.layernorm(x)[:, 0, :]
        x = self.fc(x)
        return x


class FERVT(nn.Module):
    def __init__(self, device):
        super(FERVT, self).__init__()
        self.gwa = GWA()
        self.gwa.to(device)
        self.gwa_f = GWA_Fusion()
        self.gwa_f.to(device)
        self.backbone = Backbone()
        self.backbone.to(device)
        self.vta = VTA()
        self.vta.to(device)

        self.to(device)
        # Evaluation mode on

    def forward(self, x):
        img,map = self.gwa(x)
        emotions = self.vta(self.backbone(self.gwa_f(img,map)))
        return emotions


# CrossEntropyLoss with Label Smoothing is added in pytorch 1.7.0+,change it will be ok if your version >1.7
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # true_dist = pred.data.clone()
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            if len(true_dist.shape) == 1:
                true_dist.scatter_(1, target.data.unsqueeze(0), self.confidence)
            else:
                true_dist.scatter_(1, target.data, self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))

## Train

In [21]:
import math
import os
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import torch.nn as nn
import pandas as pd

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Dataset class class FERPlusDataset
def one_hot(x, class_count):
    return torch.eye(class_count)[x, :]

class FERPlusDataset(torch.utils.data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # Emotion scores

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

    def _apply_constraints(self):
        # Constraint : 'unknown-face' or 'not-face' label removal
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : Remove samples where max vote count <= 50% of total votes
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Constraint : Remove samples with more than 3 max-vote labels
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/content/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/content/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/content/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Training function
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    for imgs, targets in tqdm(dataloader, desc='Training', leave=False):
        imgs, targets = imgs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicts = torch.max(outputs, 1)
        correct_sum += torch.eq(predicts, targets).sum().item()
        iter_cnt += 1

    acc = correct_sum / len(dataloader.dataset)
    avg_loss = running_loss / iter_cnt
    return avg_loss, acc

# Validation function
def validate(dataloader, model, loss_fn):
    model.eval()
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    with torch.no_grad():
        for imgs, targets in tqdm(dataloader, desc='Validation', leave=False):
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = model(imgs)
            loss = loss_fn(outputs, targets)

            running_loss += loss.item()
            _, predicts = torch.max(outputs, 1)
            correct_sum += torch.eq(predicts, targets).sum().item()
            iter_cnt += 1

    acc = correct_sum / len(dataloader.dataset)
    avg_loss = running_loss / iter_cnt
    return avg_loss, acc

# Paths to data and labels
train_csv = '/content/drive/MyDrive/Colab/FERVT/data/train_label.csv'
test_csv = '/content/drive/MyDrive/Colab/FERVT/data/test_label.csv'
val_csv = '/content/drive/MyDrive/Colab/FERVT/data/valid_label.csv'

# Hyperparameters
batch_size = 32
lr = 0.001
epochs = 20

# Data transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Datasets and dataloaders
train_dataset = FERPlusDataset(train_csv, phase='train', transform=transform)
val_dataset = FERPlusDataset(val_csv, phase='val', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Model, loss, optimizer, and scheduler
model = FERVT(device).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - 0.2) + 0.2)

# Training and validation
best_acc = 0
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")
    train_loss, train_acc = train(train_loader, model, criterion, optimizer)
    val_loss, val_acc = validate(val_loader, model, criterion)
    scheduler.step()

    print(f"[Epoch {epoch}] Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"[Epoch {epoch}] Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        print(f"Best model saved with accuracy: {best_acc:.4f}")

print("Training Complete.")

Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Epoch 1/20




[Epoch 1] Train Loss: 1.5622, Train Acc: 0.3601
[Epoch 1] Validation Loss: 1.4819, Validation Acc: 0.3663
Best model saved with accuracy: 0.3663
Epoch 2/20




[Epoch 2] Train Loss: 1.2467, Train Acc: 0.5605
[Epoch 2] Validation Loss: 1.1326, Validation Acc: 0.6227
Best model saved with accuracy: 0.6227
Epoch 3/20




[Epoch 3] Train Loss: 1.0163, Train Acc: 0.6518
[Epoch 3] Validation Loss: 1.0287, Validation Acc: 0.6443
Best model saved with accuracy: 0.6443
Epoch 4/20




[Epoch 4] Train Loss: 0.8892, Train Acc: 0.6938
[Epoch 4] Validation Loss: 0.8514, Validation Acc: 0.7133
Best model saved with accuracy: 0.7133
Epoch 5/20




[Epoch 5] Train Loss: 0.8229, Train Acc: 0.7161
[Epoch 5] Validation Loss: 0.8703, Validation Acc: 0.6951
Epoch 6/20




[Epoch 6] Train Loss: 0.7687, Train Acc: 0.7300
[Epoch 6] Validation Loss: 0.8210, Validation Acc: 0.7186
Best model saved with accuracy: 0.7186
Epoch 7/20




[Epoch 7] Train Loss: 0.7373, Train Acc: 0.7457
[Epoch 7] Validation Loss: 0.8094, Validation Acc: 0.7274
Best model saved with accuracy: 0.7274
Epoch 8/20




[Epoch 8] Train Loss: 0.6772, Train Acc: 0.7642
[Epoch 8] Validation Loss: 0.7015, Validation Acc: 0.7587
Best model saved with accuracy: 0.7587
Epoch 9/20




[Epoch 9] Train Loss: 0.6235, Train Acc: 0.7847
[Epoch 9] Validation Loss: 0.7082, Validation Acc: 0.7606
Best model saved with accuracy: 0.7606
Epoch 10/20




[Epoch 10] Train Loss: 0.5753, Train Acc: 0.8052
[Epoch 10] Validation Loss: 0.7024, Validation Acc: 0.7662
Best model saved with accuracy: 0.7662
Epoch 11/20




[Epoch 11] Train Loss: 0.5265, Train Acc: 0.8191
[Epoch 11] Validation Loss: 0.6497, Validation Acc: 0.7869
Best model saved with accuracy: 0.7869
Epoch 12/20




[Epoch 12] Train Loss: 0.5009, Train Acc: 0.8287
[Epoch 12] Validation Loss: 0.6905, Validation Acc: 0.7734
Epoch 13/20




[Epoch 13] Train Loss: 0.4666, Train Acc: 0.8378
[Epoch 13] Validation Loss: 0.6633, Validation Acc: 0.7816
Epoch 14/20




[Epoch 14] Train Loss: 0.4177, Train Acc: 0.8567
[Epoch 14] Validation Loss: 0.6162, Validation Acc: 0.8032
Best model saved with accuracy: 0.8032
Epoch 15/20




[Epoch 15] Train Loss: 0.3562, Train Acc: 0.8811
[Epoch 15] Validation Loss: 0.6195, Validation Acc: 0.8029
Epoch 16/20




[Epoch 16] Train Loss: 0.3198, Train Acc: 0.8926
[Epoch 16] Validation Loss: 0.6676, Validation Acc: 0.7919
Epoch 17/20




[Epoch 17] Train Loss: 0.2754, Train Acc: 0.9078
[Epoch 17] Validation Loss: 0.6767, Validation Acc: 0.7985
Epoch 18/20




[Epoch 18] Train Loss: 0.2416, Train Acc: 0.9201
[Epoch 18] Validation Loss: 0.6873, Validation Acc: 0.8023
Epoch 19/20




[Epoch 19] Train Loss: 0.2139, Train Acc: 0.9306
[Epoch 19] Validation Loss: 0.7090, Validation Acc: 0.8019
Epoch 20/20


                                                             

[Epoch 20] Train Loss: 0.1840, Train Acc: 0.9417
[Epoch 20] Validation Loss: 0.8058, Validation Acc: 0.7863
Training Complete.




## Test

In [22]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs= model(imgs)  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=transform)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=4,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Test set size: 3136
Starting evaluation...
Test Accuracy: 0.7799744897959183
Test Accuracy: 0.7799744897959183
Test Accuracy: 0.7799744897959183
Final Test Accuracy: 0.7800
