## Model

In [17]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models
import time

class OURS(nn.Module):
    def __init__(self, num_class=7,num_head=4, pretrained=True):
        super(OURS, self).__init__()
        resnet = models.resnet18(pretrained)
        
        if pretrained:
            # resnet 18 og model
            checkpoint = torch.load('./latent-ofer/models/resnet18_msceleb.pth')
            resnet.load_state_dict(checkpoint['state_dict'],strict=True)

            

        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.num_head = num_head
        for i in range(num_head):
            setattr(self,"cat_head%d" %i, CrossAttentionHead())
        self.sig = nn.Sigmoid()
        self.hh_layer = nn.Linear(39936, 10000)
        self.hh_layer2 = nn.Linear(10000, 5000)
        self.hh_batch1 = nn.BatchNorm1d(5000)
        self.hh_layer3 = nn.Linear(5000, 1024)
        self.hh_layer4 = nn.Linear(1024, 256)
        self.batch_norm1 = nn.BatchNorm1d(256)


        self.fc = nn.Linear(512+256, 256)
        self.fc2 = nn.Linear(256, 128)
        self.batch_norm = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, num_class)
        self.bn = nn.BatchNorm1d(num_class)


    def forward(self, x, latent):

        x = self.features(x)
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self,"cat_head%d" %i)(x))
        
        heads = torch.stack(heads).permute([1,0,2])
        if heads.size(1)>1:
            heads = F.log_softmax(heads,dim=1)
            heads = heads.sum(dim=1)

        latent = self.hh_layer(latent)
        latent = self.hh_layer2(latent)
        latent = self.hh_batch1(latent)
        latent = self.hh_layer3(latent)
        latent = self.hh_layer4(latent)
        latent = self.batch_norm1(latent)
        
        out = torch.cat([heads, latent], dim=1)
        out = self.fc(out)
        out = self.fc2(out)
        out = self.batch_norm(out)
        out = self.fc4(out)
        out = self.bn(out)
   
        return out, x, heads

class CrossAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)         
    def forward(self, x):
        ca = self.ca(x)
        sa = self.sa(ca)

        return sa


class SpatialAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=1),
            nn.BatchNorm2d(256),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3,padding=1),
            nn.BatchNorm2d(512),
        )
        self.conv_1x3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(1,3),padding=(0,1)),
            nn.BatchNorm2d(512),
        )
        self.conv_3x1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=(3,1),padding=(1,0)),
            nn.BatchNorm2d(512),
        )
        self.relu = nn.ReLU()
        self.gap = nn.AdaptiveAvgPool2d(1)


    def forward(self, x):
        y = self.conv1x1(x)
        a = self.conv_3x3(y)
        b = self.conv_1x3(y)
        c = self.conv_3x1(y)
        
        y = self.relu(a + b + c)
        y = y.sum(dim=1,keepdim=True) 

        out = x*y
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        return out 

class ChannelAttention(nn.Module):

    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid()    
        )


    def forward(self, sa):
        sa2 = self.gap(sa)
        sa2 = sa2.view(sa2.size(0),-1)
        y = self.attention(sa2)
        y = y.unsqueeze(dim = -1)
        y = y.unsqueeze(dim = -1)

        out = sa * y
        return out

## Latent-OFER

In [2]:
import os
import warnings
from tqdm import tqdm

from PIL import Image
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from sklearn.metrics import balanced_accuracy_score

def warn(*args, **kwargs):
    pass
warnings.warn = warn

class FERPlusDataset(data.Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)
        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 1인 라벨 0으로 만들기


        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        image = Image.open(path).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label



# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.1
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

model = OURS(num_class=8 ,num_head=4, pretrained=True)  # YOUR MODEL CLASS SHOULD BE DEFINED ELSEWHERE
model.to(device)

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

criterion_cls = torch.nn.CrossEntropyLoss()

params = list(model.parameters())
optimizer = torch.optim.SGD(params, lr=lr, weight_decay=1e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

best_acc = 0
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for imgs, targets in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device)) # Pass None for latent

        loss = criterion_cls(out, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss
        _, predicts = torch.max(out, 1)
        correct_num = torch.eq(predicts, targets).sum()
        correct_sum += correct_num

    acc = correct_sum.float() / float(train_dataset.__len__())
    running_loss = running_loss / iter_cnt
    tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        bingo_cnt = 0
        sample_cnt = 0
        baccs = []

        y_true = []
        y_pred = []

        model.eval()
        for imgs, targets in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))   # Pass None for latent
            loss = criterion_cls(out, targets)
            running_loss += loss
            iter_cnt += 1
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets)
            bingo_cnt += correct_num.sum().cpu()
            sample_cnt += out.size(0)
            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

            baccs.append(balanced_accuracy_score(targets.cpu().numpy(), predicts.cpu().numpy()))
        running_loss = running_loss / iter_cnt
        scheduler.step()

        acc = bingo_cnt.float() / float(sample_cnt)
        acc = np.around(acc.numpy(), 4)
        best_acc = max(acc, best_acc)

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)

        bacc = np.around(np.mean(baccs), 4)
        tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, bacc, running_loss))
        tqdm.write("best_acc:" + str(best_acc))

        #if acc > 0.86 and acc == best_acc:
        #    torch.save({'iter': epoch,
        #                'model_state_dict': model.state_dict(),
        #                 'optimizer_state_dict': optimizer.state_dict()},
        #                os.path.join('checkpoints', "ferplus_epoch" + str(epoch) + "_acc" + str(acc) + "_bacc" + str(bacc) + ".pth"))
        #    tqdm.write('Model saved.')

Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


  0%|          | 0/60 [00:22<?, ?it/s]

[Epoch 1] Training accuracy: 0.7271. Loss: 0.781. LR 0.100000


  2%|▏         | 1/60 [00:24<24:13, 24.63s/it]

[Epoch 1] Validation accuracy:0.7922. bacc:0.6051. Loss:0.588
best_acc:0.7922


  2%|▏         | 1/60 [00:45<24:13, 24.63s/it]

[Epoch 2] Training accuracy: 0.8178. Loss: 0.520. LR 0.100000


  3%|▎         | 2/60 [00:47<22:43, 23.51s/it]

[Epoch 2] Validation accuracy:0.8358. bacc:0.7079. Loss:0.471
best_acc:0.8358


  3%|▎         | 2/60 [01:07<22:43, 23.51s/it]

[Epoch 3] Training accuracy: 0.8412. Loss: 0.451. LR 0.100000


  5%|▌         | 3/60 [01:09<21:50, 22.98s/it]

[Epoch 3] Validation accuracy:0.8377. bacc:0.6602. Loss:0.462
best_acc:0.8377


  5%|▌         | 3/60 [01:29<21:50, 22.98s/it]

[Epoch 4] Training accuracy: 0.8567. Loss: 0.410. LR 0.100000


  7%|▋         | 4/60 [01:31<21:09, 22.66s/it]

[Epoch 4] Validation accuracy:0.8565. bacc:0.6831. Loss:0.419
best_acc:0.8565


  7%|▋         | 4/60 [01:52<21:09, 22.66s/it]

[Epoch 5] Training accuracy: 0.8691. Loss: 0.373. LR 0.100000


  8%|▊         | 5/60 [01:54<20:40, 22.56s/it]

[Epoch 5] Validation accuracy:0.8276. bacc:0.6476. Loss:0.455
best_acc:0.8565


  8%|▊         | 5/60 [02:14<20:40, 22.56s/it]

[Epoch 6] Training accuracy: 0.8747. Loss: 0.360. LR 0.100000


 10%|█         | 6/60 [02:16<20:12, 22.46s/it]

[Epoch 6] Validation accuracy:0.8233. bacc:0.6521. Loss:0.498
best_acc:0.8565


 10%|█         | 6/60 [02:36<20:12, 22.46s/it]

[Epoch 7] Training accuracy: 0.8840. Loss: 0.331. LR 0.100000


 12%|█▏        | 7/60 [02:38<19:46, 22.38s/it]

[Epoch 7] Validation accuracy:0.8496. bacc:0.7302. Loss:0.425
best_acc:0.8565


 12%|█▏        | 7/60 [02:59<19:46, 22.38s/it]

[Epoch 8] Training accuracy: 0.8844. Loss: 0.324. LR 0.100000


 13%|█▎        | 8/60 [03:01<19:22, 22.36s/it]

[Epoch 8] Validation accuracy:0.8242. bacc:0.7073. Loss:0.511
best_acc:0.8565


 13%|█▎        | 8/60 [03:21<19:22, 22.36s/it]

[Epoch 9] Training accuracy: 0.8918. Loss: 0.304. LR 0.100000


 15%|█▌        | 9/60 [03:23<18:57, 22.31s/it]

[Epoch 9] Validation accuracy:0.8618. bacc:0.7193. Loss:0.405
best_acc:0.8618


 15%|█▌        | 9/60 [03:43<18:57, 22.31s/it]

[Epoch 10] Training accuracy: 0.8957. Loss: 0.295. LR 0.100000


 17%|█▋        | 10/60 [03:45<18:33, 22.27s/it]

[Epoch 10] Validation accuracy:0.8546. bacc:0.7338. Loss:0.437
best_acc:0.8618


 17%|█▋        | 10/60 [04:05<18:33, 22.27s/it]

[Epoch 11] Training accuracy: 0.9251. Loss: 0.219. LR 0.010000


 18%|█▊        | 11/60 [04:07<18:11, 22.27s/it]

[Epoch 11] Validation accuracy:0.8772. bacc:0.7396. Loss:0.353
best_acc:0.8772


 18%|█▊        | 11/60 [04:28<18:11, 22.27s/it]

[Epoch 12] Training accuracy: 0.9395. Loss: 0.181. LR 0.010000


 20%|██        | 12/60 [04:29<17:48, 22.26s/it]

[Epoch 12] Validation accuracy:0.8765. bacc:0.7555. Loss:0.359
best_acc:0.8772


 20%|██        | 12/60 [04:50<17:48, 22.26s/it]

[Epoch 13] Training accuracy: 0.9438. Loss: 0.167. LR 0.010000


 22%|██▏       | 13/60 [04:52<17:25, 22.24s/it]

[Epoch 13] Validation accuracy:0.8825. bacc:0.7603. Loss:0.360
best_acc:0.8825


 22%|██▏       | 13/60 [05:12<17:25, 22.24s/it]

[Epoch 14] Training accuracy: 0.9454. Loss: 0.162. LR 0.010000


 23%|██▎       | 14/60 [05:14<17:02, 22.23s/it]

[Epoch 14] Validation accuracy:0.8781. bacc:0.7637. Loss:0.370
best_acc:0.8825


 23%|██▎       | 14/60 [05:34<17:02, 22.23s/it]

[Epoch 15] Training accuracy: 0.9504. Loss: 0.147. LR 0.010000


 25%|██▌       | 15/60 [05:36<16:41, 22.26s/it]

[Epoch 15] Validation accuracy:0.8790. bacc:0.7640. Loss:0.375
best_acc:0.8825


 25%|██▌       | 15/60 [05:56<16:41, 22.26s/it]

[Epoch 16] Training accuracy: 0.9522. Loss: 0.141. LR 0.010000


 27%|██▋       | 16/60 [05:58<16:19, 22.26s/it]

[Epoch 16] Validation accuracy:0.8784. bacc:0.7614. Loss:0.375
best_acc:0.8825


 27%|██▋       | 16/60 [06:19<16:19, 22.26s/it]

[Epoch 17] Training accuracy: 0.9550. Loss: 0.138. LR 0.010000


 28%|██▊       | 17/60 [06:21<15:57, 22.26s/it]

[Epoch 17] Validation accuracy:0.8743. bacc:0.7703. Loss:0.378
best_acc:0.8825


 28%|██▊       | 17/60 [06:41<15:57, 22.26s/it]

[Epoch 18] Training accuracy: 0.9564. Loss: 0.130. LR 0.010000


 30%|███       | 18/60 [06:43<15:35, 22.28s/it]

[Epoch 18] Validation accuracy:0.8778. bacc:0.7722. Loss:0.383
best_acc:0.8825


 30%|███       | 18/60 [07:03<15:35, 22.28s/it]

[Epoch 19] Training accuracy: 0.9577. Loss: 0.124. LR 0.010000


 32%|███▏      | 19/60 [07:05<15:13, 22.29s/it]

[Epoch 19] Validation accuracy:0.8778. bacc:0.7737. Loss:0.377
best_acc:0.8825


 32%|███▏      | 19/60 [07:26<15:13, 22.29s/it]

[Epoch 20] Training accuracy: 0.9616. Loss: 0.120. LR 0.010000


 33%|███▎      | 20/60 [07:28<14:50, 22.26s/it]

[Epoch 20] Validation accuracy:0.8762. bacc:0.7662. Loss:0.386
best_acc:0.8825


 33%|███▎      | 20/60 [07:48<14:50, 22.26s/it]

[Epoch 21] Training accuracy: 0.9619. Loss: 0.114. LR 0.001000


 35%|███▌      | 21/60 [07:50<14:27, 22.25s/it]

[Epoch 21] Validation accuracy:0.8800. bacc:0.7669. Loss:0.379
best_acc:0.8825


 35%|███▌      | 21/60 [08:10<14:27, 22.25s/it]

[Epoch 22] Training accuracy: 0.9635. Loss: 0.111. LR 0.001000


 37%|███▋      | 22/60 [08:12<14:05, 22.24s/it]

[Epoch 22] Validation accuracy:0.8809. bacc:0.7684. Loss:0.379
best_acc:0.8825


 37%|███▋      | 22/60 [08:32<14:05, 22.24s/it]

[Epoch 23] Training accuracy: 0.9616. Loss: 0.116. LR 0.001000


 38%|███▊      | 23/60 [08:34<13:41, 22.21s/it]

[Epoch 23] Validation accuracy:0.8787. bacc:0.7660. Loss:0.378
best_acc:0.8825


 38%|███▊      | 23/60 [08:54<13:41, 22.21s/it]

[Epoch 24] Training accuracy: 0.9596. Loss: 0.117. LR 0.001000


 40%|████      | 24/60 [08:56<13:19, 22.20s/it]

[Epoch 24] Validation accuracy:0.8809. bacc:0.7669. Loss:0.381
best_acc:0.8825


 40%|████      | 24/60 [09:17<13:19, 22.20s/it]

[Epoch 25] Training accuracy: 0.9636. Loss: 0.110. LR 0.001000


 42%|████▏     | 25/60 [09:19<12:57, 22.21s/it]

[Epoch 25] Validation accuracy:0.8790. bacc:0.7688. Loss:0.386
best_acc:0.8825


 42%|████▏     | 25/60 [09:39<12:57, 22.21s/it]

[Epoch 26] Training accuracy: 0.9631. Loss: 0.112. LR 0.001000


 43%|████▎     | 26/60 [09:41<12:34, 22.20s/it]

[Epoch 26] Validation accuracy:0.8768. bacc:0.7612. Loss:0.389
best_acc:0.8825


 43%|████▎     | 26/60 [10:01<12:34, 22.20s/it]

[Epoch 27] Training accuracy: 0.9618. Loss: 0.112. LR 0.001000


 45%|████▌     | 27/60 [10:03<12:13, 22.21s/it]

[Epoch 27] Validation accuracy:0.8793. bacc:0.7641. Loss:0.381
best_acc:0.8825


 45%|████▌     | 27/60 [10:23<12:13, 22.21s/it]

[Epoch 28] Training accuracy: 0.9612. Loss: 0.113. LR 0.001000


 47%|████▋     | 28/60 [10:25<11:50, 22.21s/it]

[Epoch 28] Validation accuracy:0.8775. bacc:0.7601. Loss:0.387
best_acc:0.8825


 47%|████▋     | 28/60 [10:45<11:50, 22.21s/it]

[Epoch 29] Training accuracy: 0.9655. Loss: 0.107. LR 0.001000


 48%|████▊     | 29/60 [10:47<11:28, 22.21s/it]

[Epoch 29] Validation accuracy:0.8781. bacc:0.7576. Loss:0.380
best_acc:0.8825


 48%|████▊     | 29/60 [11:08<11:28, 22.21s/it]

[Epoch 30] Training accuracy: 0.9629. Loss: 0.112. LR 0.001000


 50%|█████     | 30/60 [11:10<11:05, 22.20s/it]

[Epoch 30] Validation accuracy:0.8812. bacc:0.7681. Loss:0.382
best_acc:0.8825


 50%|█████     | 30/60 [11:30<11:05, 22.20s/it]

[Epoch 31] Training accuracy: 0.9632. Loss: 0.112. LR 0.000100


 52%|█████▏    | 31/60 [11:32<10:44, 22.22s/it]

[Epoch 31] Validation accuracy:0.8781. bacc:0.7587. Loss:0.384
best_acc:0.8825


 52%|█████▏    | 31/60 [11:52<10:44, 22.22s/it]

[Epoch 32] Training accuracy: 0.9642. Loss: 0.107. LR 0.000100


 53%|█████▎    | 32/60 [11:54<10:21, 22.21s/it]

[Epoch 32] Validation accuracy:0.8784. bacc:0.7587. Loss:0.378
best_acc:0.8825


 53%|█████▎    | 32/60 [12:14<10:21, 22.21s/it]

[Epoch 33] Training accuracy: 0.9633. Loss: 0.108. LR 0.000100


 55%|█████▌    | 33/60 [12:16<10:00, 22.23s/it]

[Epoch 33] Validation accuracy:0.8768. bacc:0.7670. Loss:0.385
best_acc:0.8825


 55%|█████▌    | 33/60 [12:37<10:00, 22.23s/it]

[Epoch 34] Training accuracy: 0.9646. Loss: 0.109. LR 0.000100


 57%|█████▋    | 34/60 [12:38<09:37, 22.23s/it]

[Epoch 34] Validation accuracy:0.8778. bacc:0.7619. Loss:0.389
best_acc:0.8825


 57%|█████▋    | 34/60 [12:59<09:37, 22.23s/it]

[Epoch 35] Training accuracy: 0.9633. Loss: 0.111. LR 0.000100


 58%|█████▊    | 35/60 [13:01<09:15, 22.22s/it]

[Epoch 35] Validation accuracy:0.8784. bacc:0.7688. Loss:0.385
best_acc:0.8825


 58%|█████▊    | 35/60 [13:21<09:15, 22.22s/it]

[Epoch 36] Training accuracy: 0.9628. Loss: 0.111. LR 0.000100


 60%|██████    | 36/60 [13:23<08:53, 22.23s/it]

[Epoch 36] Validation accuracy:0.8812. bacc:0.7656. Loss:0.379
best_acc:0.8825


 60%|██████    | 36/60 [13:43<08:53, 22.23s/it]

[Epoch 37] Training accuracy: 0.9655. Loss: 0.107. LR 0.000100


 62%|██████▏   | 37/60 [13:45<08:31, 22.26s/it]

[Epoch 37] Validation accuracy:0.8787. bacc:0.7602. Loss:0.378
best_acc:0.8825


 62%|██████▏   | 37/60 [14:06<08:31, 22.26s/it]

[Epoch 38] Training accuracy: 0.9622. Loss: 0.109. LR 0.000100


 63%|██████▎   | 38/60 [14:08<08:09, 22.27s/it]

[Epoch 38] Validation accuracy:0.8803. bacc:0.7646. Loss:0.381
best_acc:0.8825


 63%|██████▎   | 38/60 [14:28<08:09, 22.27s/it]

[Epoch 39] Training accuracy: 0.9634. Loss: 0.110. LR 0.000100


 65%|██████▌   | 39/60 [14:30<07:47, 22.27s/it]

[Epoch 39] Validation accuracy:0.8790. bacc:0.7669. Loss:0.385
best_acc:0.8825


 65%|██████▌   | 39/60 [14:50<07:47, 22.27s/it]

[Epoch 40] Training accuracy: 0.9640. Loss: 0.106. LR 0.000100


 67%|██████▋   | 40/60 [14:52<07:25, 22.26s/it]

[Epoch 40] Validation accuracy:0.8809. bacc:0.7544. Loss:0.377
best_acc:0.8825


 67%|██████▋   | 40/60 [15:12<07:25, 22.26s/it]

[Epoch 41] Training accuracy: 0.9631. Loss: 0.109. LR 0.000010


 68%|██████▊   | 41/60 [15:14<07:02, 22.26s/it]

[Epoch 41] Validation accuracy:0.8800. bacc:0.7608. Loss:0.380
best_acc:0.8825


 68%|██████▊   | 41/60 [15:35<07:02, 22.26s/it]

[Epoch 42] Training accuracy: 0.9641. Loss: 0.106. LR 0.000010


 70%|███████   | 42/60 [15:36<06:39, 22.22s/it]

[Epoch 42] Validation accuracy:0.8797. bacc:0.7664. Loss:0.383
best_acc:0.8825


 70%|███████   | 42/60 [15:57<06:39, 22.22s/it]

[Epoch 43] Training accuracy: 0.9621. Loss: 0.110. LR 0.000010


 72%|███████▏  | 43/60 [15:59<06:17, 22.21s/it]

[Epoch 43] Validation accuracy:0.8800. bacc:0.7619. Loss:0.381
best_acc:0.8825


 72%|███████▏  | 43/60 [16:19<06:17, 22.21s/it]

[Epoch 44] Training accuracy: 0.9629. Loss: 0.111. LR 0.000010


 73%|███████▎  | 44/60 [16:21<05:55, 22.20s/it]

[Epoch 44] Validation accuracy:0.8784. bacc:0.7624. Loss:0.383
best_acc:0.8825


 73%|███████▎  | 44/60 [16:41<05:55, 22.20s/it]

[Epoch 45] Training accuracy: 0.9640. Loss: 0.108. LR 0.000010


 75%|███████▌  | 45/60 [16:43<05:33, 22.21s/it]

[Epoch 45] Validation accuracy:0.8759. bacc:0.7594. Loss:0.384
best_acc:0.8825


 75%|███████▌  | 45/60 [17:03<05:33, 22.21s/it]

[Epoch 46] Training accuracy: 0.9625. Loss: 0.107. LR 0.000010


 77%|███████▋  | 46/60 [17:05<05:11, 22.22s/it]

[Epoch 46] Validation accuracy:0.8790. bacc:0.7623. Loss:0.381
best_acc:0.8825


 77%|███████▋  | 46/60 [17:26<05:11, 22.22s/it]

[Epoch 47] Training accuracy: 0.9653. Loss: 0.102. LR 0.000010


 78%|███████▊  | 47/60 [17:28<04:49, 22.23s/it]

[Epoch 47] Validation accuracy:0.8775. bacc:0.7595. Loss:0.382
best_acc:0.8825


 78%|███████▊  | 47/60 [17:48<04:49, 22.23s/it]

[Epoch 48] Training accuracy: 0.9645. Loss: 0.107. LR 0.000010


 80%|████████  | 48/60 [17:50<04:26, 22.22s/it]

[Epoch 48] Validation accuracy:0.8825. bacc:0.7674. Loss:0.378
best_acc:0.8825


 80%|████████  | 48/60 [18:10<04:26, 22.22s/it]

[Epoch 49] Training accuracy: 0.9642. Loss: 0.110. LR 0.000010


 82%|████████▏ | 49/60 [18:12<04:04, 22.21s/it]

[Epoch 49] Validation accuracy:0.8790. bacc:0.7589. Loss:0.380
best_acc:0.8825


 82%|████████▏ | 49/60 [18:32<04:04, 22.21s/it]

[Epoch 50] Training accuracy: 0.9629. Loss: 0.108. LR 0.000010


 83%|████████▎ | 50/60 [18:34<03:42, 22.20s/it]

[Epoch 50] Validation accuracy:0.8809. bacc:0.7604. Loss:0.379
best_acc:0.8825


 83%|████████▎ | 50/60 [18:54<03:42, 22.20s/it]

[Epoch 51] Training accuracy: 0.9644. Loss: 0.108. LR 0.000001


 85%|████████▌ | 51/60 [18:56<03:19, 22.20s/it]

[Epoch 51] Validation accuracy:0.8787. bacc:0.7591. Loss:0.382
best_acc:0.8825


 85%|████████▌ | 51/60 [19:17<03:19, 22.20s/it]

[Epoch 52] Training accuracy: 0.9648. Loss: 0.106. LR 0.000001


 87%|████████▋ | 52/60 [19:19<02:57, 22.20s/it]

[Epoch 52] Validation accuracy:0.8759. bacc:0.7617. Loss:0.389
best_acc:0.8825


 87%|████████▋ | 52/60 [19:39<02:57, 22.20s/it]

[Epoch 53] Training accuracy: 0.9649. Loss: 0.106. LR 0.000001


 88%|████████▊ | 53/60 [19:41<02:35, 22.20s/it]

[Epoch 53] Validation accuracy:0.8784. bacc:0.7617. Loss:0.384
best_acc:0.8825


 88%|████████▊ | 53/60 [20:01<02:35, 22.20s/it]

[Epoch 54] Training accuracy: 0.9645. Loss: 0.106. LR 0.000001


 90%|█████████ | 54/60 [20:03<02:13, 22.21s/it]

[Epoch 54] Validation accuracy:0.8787. bacc:0.7624. Loss:0.385
best_acc:0.8825


 90%|█████████ | 54/60 [20:23<02:13, 22.21s/it]

[Epoch 55] Training accuracy: 0.9643. Loss: 0.107. LR 0.000001


 92%|█████████▏| 55/60 [20:25<01:51, 22.20s/it]

[Epoch 55] Validation accuracy:0.8784. bacc:0.7599. Loss:0.385
best_acc:0.8825


 92%|█████████▏| 55/60 [20:45<01:51, 22.20s/it]

[Epoch 56] Training accuracy: 0.9632. Loss: 0.107. LR 0.000001


 93%|█████████▎| 56/60 [20:47<01:28, 22.21s/it]

[Epoch 56] Validation accuracy:0.8790. bacc:0.7631. Loss:0.383
best_acc:0.8825


 93%|█████████▎| 56/60 [21:08<01:28, 22.21s/it]

[Epoch 57] Training accuracy: 0.9641. Loss: 0.108. LR 0.000001


 95%|█████████▌| 57/60 [21:10<01:06, 22.19s/it]

[Epoch 57] Validation accuracy:0.8778. bacc:0.7609. Loss:0.386
best_acc:0.8825


 95%|█████████▌| 57/60 [21:30<01:06, 22.19s/it]

[Epoch 58] Training accuracy: 0.9650. Loss: 0.106. LR 0.000001


 97%|█████████▋| 58/60 [21:32<00:44, 22.20s/it]

[Epoch 58] Validation accuracy:0.8793. bacc:0.7602. Loss:0.382
best_acc:0.8825


 97%|█████████▋| 58/60 [21:52<00:44, 22.20s/it]

[Epoch 59] Training accuracy: 0.9648. Loss: 0.105. LR 0.000001


 98%|█████████▊| 59/60 [21:54<00:22, 22.21s/it]

[Epoch 59] Validation accuracy:0.8793. bacc:0.7575. Loss:0.379
best_acc:0.8825


 98%|█████████▊| 59/60 [22:14<00:22, 22.21s/it]

[Epoch 60] Training accuracy: 0.9653. Loss: 0.104. LR 0.000001


100%|██████████| 60/60 [22:16<00:00, 22.28s/it]

[Epoch 60] Validation accuracy:0.8762. bacc:0.7587. Loss:0.386
best_acc:0.8825





In [15]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")  
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=data_transforms_test)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Test set size: 3136
Starting evaluation...


Test Accuracy: 0.8788265306122449
Test Accuracy: 0.8788265306122449
Test Accuracy: 0.8788265306122449
Final Test Accuracy: 0.8788


## Latent-OFER + dlib

In [24]:
import dlib
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset

class FERPlusDataset(Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

        # Dlib face detector and predictor initialization
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor('/root/FER2013/shape_predictor_68_face_landmarks.dat')  # Dlib 모델 파일 경로 필요

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        
        # Open image
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Align face using Dlib
        image = self._align_face(image)

        # Convert to PIL image for further processing
        image = Image.fromarray(image)

        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def _align_face(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self.detector(gray)

        if len(faces) == 0:
            return image  # 얼굴이 감지되지 않으면 원본 이미지를 반환

        for face in faces:
            landmarks = self.predictor(gray, face)

            # 좌우 눈의 중심 좌표 추출
            left_eye = (landmarks.part(36).x, landmarks.part(36).y)
            right_eye = (landmarks.part(45).x, landmarks.part(45).y)

            # 두 눈의 중심 계산
            eye_center = ((left_eye[0] + right_eye[0]) // 2, (left_eye[1] + right_eye[1]) // 2)

            # 눈 사이의 기울기 계산
            delta_x = right_eye[0] - left_eye[0]
            delta_y = right_eye[1] - left_eye[1]
            angle = np.degrees(np.arctan2(delta_y, delta_x))

            # 회전 행렬 계산
            rot_matrix = cv2.getRotationMatrix2D(eye_center, angle, 1.0)

            # 이미지 회전 및 정렬
            aligned_face = cv2.warpAffine(image, rot_matrix, (image.shape[1], image.shape[0]))

            return aligned_face  # 첫 번째 얼굴만 처리

        return image


In [25]:


# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.1
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

model = OURS(num_class=8 ,num_head=4, pretrained=True)  # YOUR MODEL CLASS SHOULD BE DEFINED ELSEWHERE
model.to(device)

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

criterion_cls = torch.nn.CrossEntropyLoss()

params = list(model.parameters())
optimizer = torch.optim.SGD(params, lr=lr, weight_decay=1e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

best_acc = 0
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for imgs, targets in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device)) # Pass None for latent

        loss = criterion_cls(out, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss
        _, predicts = torch.max(out, 1)
        correct_num = torch.eq(predicts, targets).sum()
        correct_sum += correct_num

    acc = correct_sum.float() / float(train_dataset.__len__())
    running_loss = running_loss / iter_cnt
    tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        bingo_cnt = 0
        sample_cnt = 0
        baccs = []

        y_true = []
        y_pred = []

        model.eval()
        for imgs, targets in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))   # Pass None for latent
            loss = criterion_cls(out, targets)
            running_loss += loss
            iter_cnt += 1
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets)
            bingo_cnt += correct_num.sum().cpu()
            sample_cnt += out.size(0)
            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

            baccs.append(balanced_accuracy_score(targets.cpu().numpy(), predicts.cpu().numpy()))
        running_loss = running_loss / iter_cnt
        scheduler.step()

        acc = bingo_cnt.float() / float(sample_cnt)
        acc = np.around(acc.numpy(), 4)
        best_acc = max(acc, best_acc)

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)

        bacc = np.around(np.mean(baccs), 4)
        tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, bacc, running_loss))
        tqdm.write("best_acc:" + str(best_acc))

Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


  0%|          | 0/60 [00:20<?, ?it/s]

[Epoch 1] Training accuracy: 0.7320. Loss: 0.769. LR 0.100000


  2%|▏         | 1/60 [00:22<22:15, 22.63s/it]

[Epoch 1] Validation accuracy:0.8270. bacc:0.6169. Loss:0.494
best_acc:0.827


  2%|▏         | 1/60 [00:43<22:15, 22.63s/it]

[Epoch 2] Training accuracy: 0.8208. Loss: 0.508. LR 0.100000


  3%|▎         | 2/60 [00:45<22:04, 22.84s/it]

[Epoch 2] Validation accuracy:0.8468. bacc:0.6736. Loss:0.442
best_acc:0.8468


  3%|▎         | 2/60 [01:06<22:04, 22.84s/it]

[Epoch 3] Training accuracy: 0.8430. Loss: 0.445. LR 0.100000


  5%|▌         | 3/60 [01:08<21:33, 22.69s/it]

[Epoch 3] Validation accuracy:0.8524. bacc:0.6888. Loss:0.418
best_acc:0.8524


  5%|▌         | 3/60 [01:28<21:33, 22.69s/it]

[Epoch 4] Training accuracy: 0.8568. Loss: 0.406. LR 0.100000


  7%|▋         | 4/60 [01:30<21:02, 22.55s/it]

[Epoch 4] Validation accuracy:0.8471. bacc:0.6719. Loss:0.437
best_acc:0.8524


  7%|▋         | 4/60 [01:50<21:02, 22.55s/it]

[Epoch 5] Training accuracy: 0.8683. Loss: 0.376. LR 0.100000


  8%|▊         | 5/60 [01:52<20:36, 22.48s/it]

[Epoch 5] Validation accuracy:0.8320. bacc:0.6906. Loss:0.445
best_acc:0.8524


  8%|▊         | 5/60 [02:13<20:36, 22.48s/it]

[Epoch 6] Training accuracy: 0.8786. Loss: 0.344. LR 0.100000


 10%|█         | 6/60 [02:15<20:12, 22.45s/it]

[Epoch 6] Validation accuracy:0.8327. bacc:0.6444. Loss:0.487
best_acc:0.8524


 10%|█         | 6/60 [02:35<20:12, 22.45s/it]

[Epoch 7] Training accuracy: 0.8852. Loss: 0.326. LR 0.100000


 12%|█▏        | 7/60 [02:37<19:48, 22.42s/it]

[Epoch 7] Validation accuracy:0.8511. bacc:0.6742. Loss:0.437
best_acc:0.8524


 12%|█▏        | 7/60 [02:57<19:48, 22.42s/it]

[Epoch 8] Training accuracy: 0.8868. Loss: 0.316. LR 0.100000


 13%|█▎        | 8/60 [02:59<19:25, 22.42s/it]

[Epoch 8] Validation accuracy:0.8602. bacc:0.7171. Loss:0.414
best_acc:0.8602


 13%|█▎        | 8/60 [03:20<19:25, 22.42s/it]

[Epoch 9] Training accuracy: 0.8955. Loss: 0.293. LR 0.100000


 15%|█▌        | 9/60 [03:22<19:01, 22.39s/it]

[Epoch 9] Validation accuracy:0.8518. bacc:0.7378. Loss:0.432
best_acc:0.8602


 15%|█▌        | 9/60 [03:42<19:01, 22.39s/it]

[Epoch 10] Training accuracy: 0.8970. Loss: 0.288. LR 0.100000


 17%|█▋        | 10/60 [03:44<18:38, 22.36s/it]

[Epoch 10] Validation accuracy:0.8499. bacc:0.7622. Loss:0.415
best_acc:0.8602


 17%|█▋        | 10/60 [04:04<18:38, 22.36s/it]

[Epoch 11] Training accuracy: 0.9274. Loss: 0.208. LR 0.010000


 18%|█▊        | 11/60 [04:06<18:14, 22.33s/it]

[Epoch 11] Validation accuracy:0.8765. bacc:0.7801. Loss:0.351
best_acc:0.8765


 18%|█▊        | 11/60 [04:27<18:14, 22.33s/it]

[Epoch 12] Training accuracy: 0.9389. Loss: 0.175. LR 0.010000


 20%|██        | 12/60 [04:29<17:50, 22.30s/it]

[Epoch 12] Validation accuracy:0.8781. bacc:0.7764. Loss:0.359
best_acc:0.8781


 20%|██        | 12/60 [04:49<17:50, 22.30s/it]

[Epoch 13] Training accuracy: 0.9463. Loss: 0.157. LR 0.010000


 22%|██▏       | 13/60 [04:51<17:26, 22.27s/it]

[Epoch 13] Validation accuracy:0.8837. bacc:0.7757. Loss:0.344
best_acc:0.8837


 22%|██▏       | 13/60 [05:11<17:26, 22.27s/it]

[Epoch 14] Training accuracy: 0.9483. Loss: 0.153. LR 0.010000


 23%|██▎       | 14/60 [05:13<17:05, 22.29s/it]

[Epoch 14] Validation accuracy:0.8775. bacc:0.7664. Loss:0.352
best_acc:0.8837


 23%|██▎       | 14/60 [05:33<17:05, 22.29s/it]

[Epoch 15] Training accuracy: 0.9525. Loss: 0.143. LR 0.010000


 25%|██▌       | 15/60 [05:35<16:43, 22.31s/it]

[Epoch 15] Validation accuracy:0.8828. bacc:0.7721. Loss:0.356
best_acc:0.8837


 25%|██▌       | 15/60 [05:56<16:43, 22.31s/it]

[Epoch 16] Training accuracy: 0.9548. Loss: 0.136. LR 0.010000


 27%|██▋       | 16/60 [05:58<16:19, 22.26s/it]

[Epoch 16] Validation accuracy:0.8806. bacc:0.7665. Loss:0.366
best_acc:0.8837


 27%|██▋       | 16/60 [06:18<16:19, 22.26s/it]

[Epoch 17] Training accuracy: 0.9573. Loss: 0.129. LR 0.010000


 28%|██▊       | 17/60 [06:20<15:56, 22.25s/it]

[Epoch 17] Validation accuracy:0.8800. bacc:0.7649. Loss:0.370
best_acc:0.8837


 28%|██▊       | 17/60 [06:40<15:56, 22.25s/it]

[Epoch 18] Training accuracy: 0.9605. Loss: 0.120. LR 0.010000


 30%|███       | 18/60 [06:42<15:35, 22.26s/it]

[Epoch 18] Validation accuracy:0.8822. bacc:0.7614. Loss:0.379
best_acc:0.8837


 30%|███       | 18/60 [07:02<15:35, 22.26s/it]

[Epoch 19] Training accuracy: 0.9576. Loss: 0.122. LR 0.010000


 32%|███▏      | 19/60 [07:04<15:12, 22.25s/it]

[Epoch 19] Validation accuracy:0.8803. bacc:0.7762. Loss:0.370
best_acc:0.8837


 32%|███▏      | 19/60 [07:25<15:12, 22.25s/it]

[Epoch 20] Training accuracy: 0.9605. Loss: 0.116. LR 0.010000


 33%|███▎      | 20/60 [07:27<14:50, 22.26s/it]

[Epoch 20] Validation accuracy:0.8787. bacc:0.7632. Loss:0.379
best_acc:0.8837


 33%|███▎      | 20/60 [07:47<14:50, 22.26s/it]

[Epoch 21] Training accuracy: 0.9615. Loss: 0.116. LR 0.001000


 35%|███▌      | 21/60 [07:49<14:28, 22.26s/it]

[Epoch 21] Validation accuracy:0.8815. bacc:0.7625. Loss:0.374
best_acc:0.8837


 35%|███▌      | 21/60 [08:09<14:28, 22.26s/it]

[Epoch 22] Training accuracy: 0.9648. Loss: 0.107. LR 0.001000


 37%|███▋      | 22/60 [08:11<14:05, 22.26s/it]

[Epoch 22] Validation accuracy:0.8800. bacc:0.7545. Loss:0.377
best_acc:0.8837


 37%|███▋      | 22/60 [08:31<14:05, 22.26s/it]

[Epoch 23] Training accuracy: 0.9601. Loss: 0.112. LR 0.001000


 38%|███▊      | 23/60 [08:33<13:43, 22.26s/it]

[Epoch 23] Validation accuracy:0.8812. bacc:0.7709. Loss:0.379
best_acc:0.8837


 38%|███▊      | 23/60 [08:54<13:43, 22.26s/it]

[Epoch 24] Training accuracy: 0.9638. Loss: 0.110. LR 0.001000


 40%|████      | 24/60 [08:56<13:21, 22.27s/it]

[Epoch 24] Validation accuracy:0.8815. bacc:0.7673. Loss:0.380
best_acc:0.8837


 40%|████      | 24/60 [09:16<13:21, 22.27s/it]

[Epoch 25] Training accuracy: 0.9629. Loss: 0.112. LR 0.001000


 42%|████▏     | 25/60 [09:18<12:59, 22.27s/it]

[Epoch 25] Validation accuracy:0.8834. bacc:0.7728. Loss:0.375
best_acc:0.8837


 42%|████▏     | 25/60 [09:38<12:59, 22.27s/it]

[Epoch 26] Training accuracy: 0.9644. Loss: 0.107. LR 0.001000


 43%|████▎     | 26/60 [09:40<12:36, 22.26s/it]

[Epoch 26] Validation accuracy:0.8840. bacc:0.7792. Loss:0.377
best_acc:0.884


 43%|████▎     | 26/60 [10:00<12:36, 22.26s/it]

[Epoch 27] Training accuracy: 0.9648. Loss: 0.108. LR 0.001000


 45%|████▌     | 27/60 [10:03<12:14, 22.27s/it]

[Epoch 27] Validation accuracy:0.8840. bacc:0.7697. Loss:0.376
best_acc:0.884


 45%|████▌     | 27/60 [10:23<12:14, 22.27s/it]

[Epoch 28] Training accuracy: 0.9652. Loss: 0.105. LR 0.001000


 47%|████▋     | 28/60 [10:25<11:52, 22.27s/it]

[Epoch 28] Validation accuracy:0.8831. bacc:0.7773. Loss:0.378
best_acc:0.884


 47%|████▋     | 28/60 [10:45<11:52, 22.27s/it]

[Epoch 29] Training accuracy: 0.9650. Loss: 0.107. LR 0.001000


 48%|████▊     | 29/60 [10:47<11:31, 22.29s/it]

[Epoch 29] Validation accuracy:0.8815. bacc:0.7726. Loss:0.380
best_acc:0.884


 48%|████▊     | 29/60 [11:07<11:31, 22.29s/it]

[Epoch 30] Training accuracy: 0.9661. Loss: 0.104. LR 0.001000


 50%|█████     | 30/60 [11:09<11:08, 22.28s/it]

[Epoch 30] Validation accuracy:0.8803. bacc:0.7769. Loss:0.382
best_acc:0.884


 50%|█████     | 30/60 [11:30<11:08, 22.28s/it]

[Epoch 31] Training accuracy: 0.9645. Loss: 0.108. LR 0.000100


 52%|█████▏    | 31/60 [11:32<10:45, 22.27s/it]

[Epoch 31] Validation accuracy:0.8840. bacc:0.7764. Loss:0.376
best_acc:0.884


 52%|█████▏    | 31/60 [11:52<10:45, 22.27s/it]

[Epoch 32] Training accuracy: 0.9651. Loss: 0.107. LR 0.000100


 53%|█████▎    | 32/60 [11:54<10:23, 22.28s/it]

[Epoch 32] Validation accuracy:0.8825. bacc:0.7764. Loss:0.374
best_acc:0.884


 53%|█████▎    | 32/60 [12:14<10:23, 22.28s/it]

[Epoch 33] Training accuracy: 0.9644. Loss: 0.108. LR 0.000100


 55%|█████▌    | 33/60 [12:16<10:00, 22.25s/it]

[Epoch 33] Validation accuracy:0.8831. bacc:0.7754. Loss:0.375
best_acc:0.884


 55%|█████▌    | 33/60 [12:36<10:00, 22.25s/it]

[Epoch 34] Training accuracy: 0.9641. Loss: 0.107. LR 0.000100


 57%|█████▋    | 34/60 [12:38<09:38, 22.24s/it]

[Epoch 34] Validation accuracy:0.8812. bacc:0.7718. Loss:0.377
best_acc:0.884


 57%|█████▋    | 34/60 [12:59<09:38, 22.24s/it]

[Epoch 35] Training accuracy: 0.9639. Loss: 0.106. LR 0.000100


 58%|█████▊    | 35/60 [13:01<09:15, 22.23s/it]

[Epoch 35] Validation accuracy:0.8831. bacc:0.7693. Loss:0.377
best_acc:0.884


 58%|█████▊    | 35/60 [13:21<09:15, 22.23s/it]

[Epoch 36] Training accuracy: 0.9638. Loss: 0.107. LR 0.000100


 60%|██████    | 36/60 [13:23<08:54, 22.26s/it]

[Epoch 36] Validation accuracy:0.8828. bacc:0.7804. Loss:0.378
best_acc:0.884


 60%|██████    | 36/60 [13:43<08:54, 22.26s/it]

[Epoch 37] Training accuracy: 0.9648. Loss: 0.104. LR 0.000100


 62%|██████▏   | 37/60 [13:45<08:32, 22.29s/it]

[Epoch 37] Validation accuracy:0.8800. bacc:0.7651. Loss:0.380
best_acc:0.884


 62%|██████▏   | 37/60 [14:05<08:32, 22.29s/it]

[Epoch 38] Training accuracy: 0.9637. Loss: 0.109. LR 0.000100


 63%|██████▎   | 38/60 [14:07<08:10, 22.28s/it]

[Epoch 38] Validation accuracy:0.8803. bacc:0.7696. Loss:0.381
best_acc:0.884


 63%|██████▎   | 38/60 [14:28<08:10, 22.28s/it]

[Epoch 39] Training accuracy: 0.9647. Loss: 0.104. LR 0.000100


 65%|██████▌   | 39/60 [14:30<07:47, 22.28s/it]

[Epoch 39] Validation accuracy:0.8812. bacc:0.7654. Loss:0.377
best_acc:0.884


 65%|██████▌   | 39/60 [14:50<07:47, 22.28s/it]

[Epoch 40] Training accuracy: 0.9653. Loss: 0.103. LR 0.000100


 67%|██████▋   | 40/60 [14:52<07:25, 22.29s/it]

[Epoch 40] Validation accuracy:0.8831. bacc:0.7759. Loss:0.377
best_acc:0.884


 67%|██████▋   | 40/60 [15:12<07:25, 22.29s/it]

[Epoch 41] Training accuracy: 0.9642. Loss: 0.103. LR 0.000010


 68%|██████▊   | 41/60 [15:14<07:03, 22.27s/it]

[Epoch 41] Validation accuracy:0.8803. bacc:0.7683. Loss:0.380
best_acc:0.884


 68%|██████▊   | 41/60 [15:35<07:03, 22.27s/it]

[Epoch 42] Training accuracy: 0.9661. Loss: 0.102. LR 0.000010


 70%|███████   | 42/60 [15:37<06:40, 22.28s/it]

[Epoch 42] Validation accuracy:0.8812. bacc:0.7688. Loss:0.377
best_acc:0.884


 70%|███████   | 42/60 [15:57<06:40, 22.28s/it]

[Epoch 43] Training accuracy: 0.9640. Loss: 0.105. LR 0.000010


 72%|███████▏  | 43/60 [15:59<06:18, 22.29s/it]

[Epoch 43] Validation accuracy:0.8812. bacc:0.7662. Loss:0.383
best_acc:0.884


 72%|███████▏  | 43/60 [16:19<06:18, 22.29s/it]

[Epoch 44] Training accuracy: 0.9635. Loss: 0.106. LR 0.000010


 73%|███████▎  | 44/60 [16:21<05:56, 22.29s/it]

[Epoch 44] Validation accuracy:0.8806. bacc:0.7672. Loss:0.375
best_acc:0.884


 73%|███████▎  | 44/60 [16:41<05:56, 22.29s/it]

[Epoch 45] Training accuracy: 0.9654. Loss: 0.103. LR 0.000010


 75%|███████▌  | 45/60 [16:43<05:34, 22.27s/it]

[Epoch 45] Validation accuracy:0.8815. bacc:0.7659. Loss:0.376
best_acc:0.884


 75%|███████▌  | 45/60 [17:04<05:34, 22.27s/it]

[Epoch 46] Training accuracy: 0.9646. Loss: 0.106. LR 0.000010


 77%|███████▋  | 46/60 [17:06<05:11, 22.28s/it]

[Epoch 46] Validation accuracy:0.8815. bacc:0.7687. Loss:0.380
best_acc:0.884


 77%|███████▋  | 46/60 [17:26<05:11, 22.28s/it]

[Epoch 47] Training accuracy: 0.9651. Loss: 0.105. LR 0.000010


 78%|███████▊  | 47/60 [17:28<04:49, 22.28s/it]

[Epoch 47] Validation accuracy:0.8831. bacc:0.7818. Loss:0.378
best_acc:0.884


 78%|███████▊  | 47/60 [17:48<04:49, 22.28s/it]

[Epoch 48] Training accuracy: 0.9644. Loss: 0.106. LR 0.000010


 80%|████████  | 48/60 [17:50<04:27, 22.27s/it]

[Epoch 48] Validation accuracy:0.8815. bacc:0.7761. Loss:0.379
best_acc:0.884


 80%|████████  | 48/60 [18:10<04:27, 22.27s/it]

[Epoch 49] Training accuracy: 0.9651. Loss: 0.105. LR 0.000010


 82%|████████▏ | 49/60 [18:13<04:04, 22.27s/it]

[Epoch 49] Validation accuracy:0.8803. bacc:0.7716. Loss:0.385
best_acc:0.884


 82%|████████▏ | 49/60 [18:33<04:04, 22.27s/it]

[Epoch 50] Training accuracy: 0.9637. Loss: 0.107. LR 0.000010


 83%|████████▎ | 50/60 [18:35<03:42, 22.28s/it]

[Epoch 50] Validation accuracy:0.8819. bacc:0.7757. Loss:0.381
best_acc:0.884


 83%|████████▎ | 50/60 [18:55<03:42, 22.28s/it]

[Epoch 51] Training accuracy: 0.9644. Loss: 0.107. LR 0.000001


 85%|████████▌ | 51/60 [18:57<03:20, 22.29s/it]

[Epoch 51] Validation accuracy:0.8806. bacc:0.7700. Loss:0.380
best_acc:0.884


 85%|████████▌ | 51/60 [19:17<03:20, 22.29s/it]

[Epoch 52] Training accuracy: 0.9640. Loss: 0.107. LR 0.000001


 87%|████████▋ | 52/60 [19:19<02:58, 22.28s/it]

[Epoch 52] Validation accuracy:0.8809. bacc:0.7685. Loss:0.383
best_acc:0.884


 87%|████████▋ | 52/60 [19:40<02:58, 22.28s/it]

[Epoch 53] Training accuracy: 0.9639. Loss: 0.108. LR 0.000001


 88%|████████▊ | 53/60 [19:42<02:35, 22.26s/it]

[Epoch 53] Validation accuracy:0.8847. bacc:0.7700. Loss:0.374
best_acc:0.8847


 88%|████████▊ | 53/60 [20:02<02:35, 22.26s/it]

[Epoch 54] Training accuracy: 0.9615. Loss: 0.112. LR 0.000001


 90%|█████████ | 54/60 [20:04<02:13, 22.27s/it]

[Epoch 54] Validation accuracy:0.8809. bacc:0.7648. Loss:0.379
best_acc:0.8847


 90%|█████████ | 54/60 [20:24<02:13, 22.27s/it]

[Epoch 55] Training accuracy: 0.9671. Loss: 0.098. LR 0.000001


 92%|█████████▏| 55/60 [20:26<01:51, 22.26s/it]

[Epoch 55] Validation accuracy:0.8806. bacc:0.7690. Loss:0.377
best_acc:0.8847


 92%|█████████▏| 55/60 [20:46<01:51, 22.26s/it]

[Epoch 56] Training accuracy: 0.9652. Loss: 0.106. LR 0.000001


 93%|█████████▎| 56/60 [20:48<01:29, 22.26s/it]

[Epoch 56] Validation accuracy:0.8847. bacc:0.7774. Loss:0.381
best_acc:0.8847


 93%|█████████▎| 56/60 [21:09<01:29, 22.26s/it]

[Epoch 57] Training accuracy: 0.9650. Loss: 0.106. LR 0.000001


 95%|█████████▌| 57/60 [21:11<01:06, 22.27s/it]

[Epoch 57] Validation accuracy:0.8809. bacc:0.7673. Loss:0.378
best_acc:0.8847


 95%|█████████▌| 57/60 [21:31<01:06, 22.27s/it]

[Epoch 58] Training accuracy: 0.9645. Loss: 0.108. LR 0.000001


 97%|█████████▋| 58/60 [21:33<00:44, 22.26s/it]

[Epoch 58] Validation accuracy:0.8790. bacc:0.7751. Loss:0.384
best_acc:0.8847


 97%|█████████▋| 58/60 [21:53<00:44, 22.26s/it]

[Epoch 59] Training accuracy: 0.9666. Loss: 0.103. LR 0.000001


 98%|█████████▊| 59/60 [21:55<00:22, 22.29s/it]

[Epoch 59] Validation accuracy:0.8819. bacc:0.7769. Loss:0.382
best_acc:0.8847


 98%|█████████▊| 59/60 [22:16<00:22, 22.29s/it]

[Epoch 60] Training accuracy: 0.9630. Loss: 0.107. LR 0.000001


100%|██████████| 60/60 [22:18<00:00, 22.30s/it]

[Epoch 60] Validation accuracy:0.8812. bacc:0.7754. Loss:0.386
best_acc:0.8847





In [26]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")  
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=data_transforms_test)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Test set size: 3136
Starting evaluation...
Test Accuracy: 0.8807397959183674
Test Accuracy: 0.8807397959183674
Test Accuracy: 0.8807397959183674
Final Test Accuracy: 0.8807


# AffectNet Latent-OFER Dlib

In [1]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models
import time
from torchvision import transforms
from sklearn.metrics import balanced_accuracy_score

class OURS(nn.Module):
    def __init__(self, num_class=7, num_head=4, affectnet_weights_path=None):
        super(OURS, self).__init__()
        resnet = models.resnet50(pretrained=False)

        if affectnet_weights_path:
            checkpoint = torch.load(affectnet_weights_path, map_location=torch.device('cpu'))
            if 'state_dict' in checkpoint:
                resnet.load_state_dict(checkpoint['state_dict'], strict=False)
            elif 'model_state_dict' in checkpoint:
                resnet.load_state_dict(checkpoint['model_state_dict'], strict=False)
            else:
                resnet.load_state_dict(checkpoint, strict=False)
            print("AffectNet weights loaded successfully.")

        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.num_head = num_head
        for i in range(num_head):
            setattr(self, f"cat_head{i}", CrossAttentionHead())

        self.sig = nn.Sigmoid()
        self.hh_layer = nn.Linear(39936, 1024)  # latent 크기 조정
        self.fc = nn.Linear(2048 + 1024, 256)  # heads + latent 결합 크기 반영
        self.fc2 = nn.Linear(256, 128)
        self.batch_norm = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, x, latent):
        x = self.features(x)
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self, f"cat_head{i}")(x))

        # heads 결합
        heads = torch.stack(heads).permute([1, 0, 2])  # (batch_size, num_heads, features)
        heads = heads.sum(dim=1)  # (batch_size, features)

        # latent 변환
        latent = self.hh_layer(latent)

        # heads와 latent 결합
        out = torch.cat([heads, latent], dim=1)  # (batch_size, 2048 + 1024)
        out = self.fc(out)
        out = self.fc2(out)
        out = self.batch_norm(out)
        out = self.fc4(out)
        out = self.bn(out)

        return out, x, heads


class CrossAttentionHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        ca = self.ca(x)
        sa = self.sa(ca)
        return sa

class SpatialAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(2048, 1024, kernel_size=1),  # 2048 -> 1024로 축소
            nn.BatchNorm2d(1024),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(1024, 2048, kernel_size=3, padding=1),  # 1024 -> 2048로 복구
            nn.BatchNorm2d(2048),
        )
        self.conv_1x3 = nn.Sequential(
            nn.Conv2d(1024, 2048, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(2048),
        )
        self.conv_3x1 = nn.Sequential(
            nn.Conv2d(1024, 2048, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(2048),
        )
        self.relu = nn.ReLU()
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        y = self.conv1x1(x)
        a = self.conv_3x3(y)
        b = self.conv_1x3(y)
        c = self.conv_3x1(y)

        y = self.relu(a + b + c)
        y = y.sum(dim=1, keepdim=True)

        out = x * y
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        return out


class ChannelAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(2048, 512),  # 2048 -> 512로 축소
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 2048),  # 512 -> 2048로 복구
            nn.Sigmoid()
        )

    def forward(self, sa):
        sa2 = self.gap(sa)  # (batch_size, 2048, 1, 1)
        sa2 = sa2.view(sa2.size(0), -1)  # (batch_size, 2048)
        y = self.attention(sa2)  # (batch_size, 2048)
        y = y.unsqueeze(dim=-1).unsqueeze(dim=-1)  # (batch_size, 2048, 1, 1)

        out = sa * y
        return out



In [2]:
import dlib
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset
from tqdm import tqdm
class FERPlusDataset(Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

        # Dlib face detector and predictor initialization
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor('/root/FER2013/shape_predictor_68_face_landmarks.dat')  # Dlib 모델 파일 경로 필요

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        
        # Open image
        image = cv2.imread(path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Align face using Dlib
        image = self._align_face(image)

        # Convert to PIL image for further processing
        image = Image.fromarray(image)

        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def _align_face(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self.detector(gray)

        if len(faces) == 0:
            return image  # 얼굴이 감지되지 않으면 원본 이미지를 반환

        for face in faces:
            landmarks = self.predictor(gray, face)

            # 좌우 눈의 중심 좌표 추출
            left_eye = (landmarks.part(36).x, landmarks.part(36).y)
            right_eye = (landmarks.part(45).x, landmarks.part(45).y)

            # 두 눈의 중심 계산
            eye_center = ((left_eye[0] + right_eye[0]) // 2, (left_eye[1] + right_eye[1]) // 2)

            # 눈 사이의 기울기 계산
            delta_x = right_eye[0] - left_eye[0]
            delta_y = right_eye[1] - left_eye[1]
            angle = np.degrees(np.arctan2(delta_y, delta_x))

            # 회전 행렬 계산
            rot_matrix = cv2.getRotationMatrix2D(eye_center, angle, 1.0)

            # 이미지 회전 및 정렬
            aligned_face = cv2.warpAffine(image, rot_matrix, (image.shape[1], image.shape[0]))

            return aligned_face  # 첫 번째 얼굴만 처리

        return image


In [3]:


# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.1
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

model = OURS(num_class=8 ,num_head=4, affectnet_weights_path='/root/FER2013/FER_static_ResNet50_AffectNet.pt')# YOUR MODEL CLASS SHOULD BE DEFINED ELSEWHERE
model.to(device)

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

criterion_cls = torch.nn.CrossEntropyLoss()

params = list(model.parameters())
optimizer = torch.optim.SGD(params, lr=lr, weight_decay=1e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

best_acc = 0
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for imgs, targets in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs = imgs.to(device)
        targets = targets.to(device)
        out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device)) # Pass None for latent

        loss = criterion_cls(out, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss
        _, predicts = torch.max(out, 1)
        correct_num = torch.eq(predicts, targets).sum()
        correct_sum += correct_num

    acc = correct_sum.float() / float(train_dataset.__len__())
    running_loss = running_loss / iter_cnt
    tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        bingo_cnt = 0
        sample_cnt = 0
        baccs = []

        y_true = []
        y_pred = []

        model.eval()
        for imgs, targets in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)

            out, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))   # Pass None for latent
            loss = criterion_cls(out, targets)
            running_loss += loss
            iter_cnt += 1
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets)
            bingo_cnt += correct_num.sum().cpu()
            sample_cnt += out.size(0)
            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

            baccs.append(balanced_accuracy_score(targets.cpu().numpy(), predicts.cpu().numpy()))
        running_loss = running_loss / iter_cnt
        scheduler.step()

        acc = bingo_cnt.float() / float(sample_cnt)
        acc = np.around(acc.numpy(), 4)
        best_acc = max(acc, best_acc)

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)

        bacc = np.around(np.mean(baccs), 4)
        tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, bacc, running_loss))
        tqdm.write("best_acc:" + str(best_acc))

  checkpoint = torch.load(affectnet_weights_path, map_location=torch.device('cpu'))


AffectNet weights loaded successfully.
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


  0%|          | 0/60 [00:50<?, ?it/s]

[Epoch 1] Training accuracy: 0.3567. Loss: 1.585. LR 0.100000


  2%|▏         | 1/60 [00:53<52:26, 53.34s/it]

[Epoch 1] Validation accuracy:0.3660. bacc:0.1962. Loss:1.593
best_acc:0.366


  2%|▏         | 1/60 [01:38<52:26, 53.34s/it]

[Epoch 2] Training accuracy: 0.4115. Loss: 1.480. LR 0.100000


  3%|▎         | 2/60 [01:40<48:14, 49.91s/it]

[Epoch 2] Validation accuracy:0.4704. bacc:0.2534. Loss:1.426
best_acc:0.4704


  3%|▎         | 2/60 [02:25<48:14, 49.91s/it]

[Epoch 3] Training accuracy: 0.5262. Loss: 1.290. LR 0.100000


  5%|▌         | 3/60 [02:28<46:15, 48.70s/it]

[Epoch 3] Validation accuracy:0.5964. bacc:0.3504. Loss:1.182
best_acc:0.5964


  5%|▌         | 3/60 [03:12<46:15, 48.70s/it]

[Epoch 4] Training accuracy: 0.6181. Loss: 1.064. LR 0.100000


  7%|▋         | 4/60 [03:15<44:55, 48.14s/it]

[Epoch 4] Validation accuracy:0.6409. bacc:0.3841. Loss:1.024
best_acc:0.6409


  7%|▋         | 4/60 [04:00<44:55, 48.14s/it]

[Epoch 5] Training accuracy: 0.6695. Loss: 0.930. LR 0.100000


  8%|▊         | 5/60 [04:02<43:52, 47.86s/it]

[Epoch 5] Validation accuracy:0.7017. bacc:0.4104. Loss:0.851
best_acc:0.7017


  8%|▊         | 5/60 [04:47<43:52, 47.86s/it]

[Epoch 6] Training accuracy: 0.6979. Loss: 0.856. LR 0.100000


 10%|█         | 6/60 [04:50<42:55, 47.69s/it]

[Epoch 6] Validation accuracy:0.6164. bacc:0.3506. Loss:1.111
best_acc:0.7017


 10%|█         | 6/60 [05:34<42:55, 47.69s/it]

[Epoch 7] Training accuracy: 0.7186. Loss: 0.792. LR 0.100000


 12%|█▏        | 7/60 [05:37<41:59, 47.54s/it]

[Epoch 7] Validation accuracy:0.6775. bacc:0.4489. Loss:0.905
best_acc:0.7017


 12%|█▏        | 7/60 [06:22<41:59, 47.54s/it]

[Epoch 8] Training accuracy: 0.7381. Loss: 0.727. LR 0.100000


 13%|█▎        | 8/60 [06:24<41:08, 47.46s/it]

[Epoch 8] Validation accuracy:0.7020. bacc:0.4345. Loss:0.923
best_acc:0.702


 13%|█▎        | 8/60 [07:09<41:08, 47.46s/it]

[Epoch 9] Training accuracy: 0.7482. Loss: 0.696. LR 0.100000


 15%|█▌        | 9/60 [07:11<40:18, 47.42s/it]

[Epoch 9] Validation accuracy:0.6243. bacc:0.4894. Loss:1.059
best_acc:0.702


 15%|█▌        | 9/60 [07:56<40:18, 47.42s/it]

[Epoch 10] Training accuracy: 0.7609. Loss: 0.666. LR 0.100000


 17%|█▋        | 10/60 [07:59<39:29, 47.39s/it]

[Epoch 10] Validation accuracy:0.7694. bacc:0.5266. Loss:0.687
best_acc:0.7694


 17%|█▋        | 10/60 [08:44<39:29, 47.39s/it]

[Epoch 11] Training accuracy: 0.8026. Loss: 0.552. LR 0.010000


 18%|█▊        | 11/60 [08:46<38:40, 47.36s/it]

[Epoch 11] Validation accuracy:0.8139. bacc:0.6180. Loss:0.512
best_acc:0.8139


 18%|█▊        | 11/60 [09:31<38:40, 47.36s/it]

[Epoch 12] Training accuracy: 0.8166. Loss: 0.517. LR 0.010000


 20%|██        | 12/60 [09:33<37:50, 47.30s/it]

[Epoch 12] Validation accuracy:0.8154. bacc:0.6291. Loss:0.512
best_acc:0.8154


 20%|██        | 12/60 [10:18<37:50, 47.30s/it]

[Epoch 13] Training accuracy: 0.8253. Loss: 0.491. LR 0.010000


 22%|██▏       | 13/60 [10:20<37:02, 47.29s/it]

[Epoch 13] Validation accuracy:0.8220. bacc:0.6466. Loss:0.488
best_acc:0.822


 22%|██▏       | 13/60 [11:05<37:02, 47.29s/it]

[Epoch 14] Training accuracy: 0.8282. Loss: 0.483. LR 0.010000


 23%|██▎       | 14/60 [11:08<36:14, 47.28s/it]

[Epoch 14] Validation accuracy:0.8189. bacc:0.6288. Loss:0.518
best_acc:0.822


 23%|██▎       | 14/60 [11:52<36:14, 47.28s/it]

[Epoch 15] Training accuracy: 0.8301. Loss: 0.475. LR 0.010000


 25%|██▌       | 15/60 [11:55<35:27, 47.27s/it]

[Epoch 15] Validation accuracy:0.8248. bacc:0.6459. Loss:0.477
best_acc:0.8248


 25%|██▌       | 15/60 [12:40<35:27, 47.27s/it]

[Epoch 16] Training accuracy: 0.8376. Loss: 0.458. LR 0.010000


 27%|██▋       | 16/60 [12:42<34:40, 47.29s/it]

[Epoch 16] Validation accuracy:0.8207. bacc:0.6467. Loss:0.489
best_acc:0.8248


 27%|██▋       | 16/60 [13:27<34:40, 47.29s/it]

[Epoch 17] Training accuracy: 0.8376. Loss: 0.453. LR 0.010000


 28%|██▊       | 17/60 [13:30<33:53, 47.30s/it]

[Epoch 17] Validation accuracy:0.8207. bacc:0.6456. Loss:0.503
best_acc:0.8248


 28%|██▊       | 17/60 [14:14<33:53, 47.30s/it]

[Epoch 18] Training accuracy: 0.8393. Loss: 0.444. LR 0.010000


 30%|███       | 18/60 [14:17<33:06, 47.30s/it]

[Epoch 18] Validation accuracy:0.8295. bacc:0.6488. Loss:0.471
best_acc:0.8295


 30%|███       | 18/60 [15:02<33:06, 47.30s/it]

[Epoch 19] Training accuracy: 0.8471. Loss: 0.431. LR 0.010000


 32%|███▏      | 19/60 [15:04<32:19, 47.30s/it]

[Epoch 19] Validation accuracy:0.8264. bacc:0.6404. Loss:0.479
best_acc:0.8295


 32%|███▏      | 19/60 [15:49<32:19, 47.30s/it]

[Epoch 20] Training accuracy: 0.8470. Loss: 0.424. LR 0.010000


 33%|███▎      | 20/60 [15:52<31:32, 47.31s/it]

[Epoch 20] Validation accuracy:0.8342. bacc:0.6657. Loss:0.471
best_acc:0.8342


 33%|███▎      | 20/60 [16:36<31:32, 47.31s/it]

[Epoch 21] Training accuracy: 0.8590. Loss: 0.396. LR 0.001000


 35%|███▌      | 21/60 [16:39<30:44, 47.30s/it]

[Epoch 21] Validation accuracy:0.8323. bacc:0.6681. Loss:0.454
best_acc:0.8342


 35%|███▌      | 21/60 [17:24<30:44, 47.30s/it]

[Epoch 22] Training accuracy: 0.8612. Loss: 0.393. LR 0.001000


 37%|███▋      | 22/60 [17:26<29:57, 47.29s/it]

[Epoch 22] Validation accuracy:0.8348. bacc:0.6775. Loss:0.453
best_acc:0.8348


 37%|███▋      | 22/60 [18:11<29:57, 47.29s/it]

[Epoch 23] Training accuracy: 0.8596. Loss: 0.394. LR 0.001000


 38%|███▊      | 23/60 [18:13<29:10, 47.30s/it]

[Epoch 23] Validation accuracy:0.8358. bacc:0.6798. Loss:0.457
best_acc:0.8358


 38%|███▊      | 23/60 [18:58<29:10, 47.30s/it]

[Epoch 24] Training accuracy: 0.8615. Loss: 0.390. LR 0.001000


 40%|████      | 24/60 [19:01<28:23, 47.31s/it]

[Epoch 24] Validation accuracy:0.8380. bacc:0.6753. Loss:0.451
best_acc:0.838


 40%|████      | 24/60 [19:46<28:23, 47.31s/it]

[Epoch 25] Training accuracy: 0.8654. Loss: 0.381. LR 0.001000


 42%|████▏     | 25/60 [19:48<27:36, 47.32s/it]

[Epoch 25] Validation accuracy:0.8364. bacc:0.6775. Loss:0.452
best_acc:0.838


 42%|████▏     | 25/60 [20:33<27:36, 47.32s/it]

[Epoch 26] Training accuracy: 0.8636. Loss: 0.387. LR 0.001000


 43%|████▎     | 26/60 [20:35<26:49, 47.33s/it]

[Epoch 26] Validation accuracy:0.8405. bacc:0.6799. Loss:0.452
best_acc:0.8405


 43%|████▎     | 26/60 [21:20<26:49, 47.33s/it]

[Epoch 27] Training accuracy: 0.8626. Loss: 0.385. LR 0.001000


 45%|████▌     | 27/60 [21:23<26:01, 47.31s/it]

[Epoch 27] Validation accuracy:0.8345. bacc:0.6767. Loss:0.458
best_acc:0.8405


 45%|████▌     | 27/60 [22:07<26:01, 47.31s/it]

[Epoch 28] Training accuracy: 0.8644. Loss: 0.387. LR 0.001000


 47%|████▋     | 28/60 [22:10<25:13, 47.31s/it]

[Epoch 28] Validation accuracy:0.8427. bacc:0.6852. Loss:0.452
best_acc:0.8427


 47%|████▋     | 28/60 [22:55<25:13, 47.31s/it]

[Epoch 29] Training accuracy: 0.8618. Loss: 0.385. LR 0.001000


 48%|████▊     | 29/60 [22:57<24:26, 47.31s/it]

[Epoch 29] Validation accuracy:0.8395. bacc:0.6804. Loss:0.452
best_acc:0.8427


 48%|████▊     | 29/60 [23:42<24:26, 47.31s/it]

[Epoch 30] Training accuracy: 0.8641. Loss: 0.381. LR 0.001000


 50%|█████     | 30/60 [23:45<23:38, 47.29s/it]

[Epoch 30] Validation accuracy:0.8377. bacc:0.6866. Loss:0.457
best_acc:0.8427


 50%|█████     | 30/60 [24:29<23:38, 47.29s/it]

[Epoch 31] Training accuracy: 0.8666. Loss: 0.381. LR 0.000100


 52%|█████▏    | 31/60 [24:32<22:51, 47.29s/it]

[Epoch 31] Validation accuracy:0.8383. bacc:0.6877. Loss:0.451
best_acc:0.8427


 52%|█████▏    | 31/60 [25:17<22:51, 47.29s/it]

[Epoch 32] Training accuracy: 0.8678. Loss: 0.375. LR 0.000100


 53%|█████▎    | 32/60 [25:19<22:03, 47.27s/it]

[Epoch 32] Validation accuracy:0.8383. bacc:0.6796. Loss:0.453
best_acc:0.8427


 53%|█████▎    | 32/60 [26:04<22:03, 47.27s/it]

[Epoch 33] Training accuracy: 0.8673. Loss: 0.379. LR 0.000100


 55%|█████▌    | 33/60 [26:06<21:16, 47.29s/it]

[Epoch 33] Validation accuracy:0.8374. bacc:0.6886. Loss:0.449
best_acc:0.8427


 55%|█████▌    | 33/60 [26:51<21:16, 47.29s/it]

[Epoch 34] Training accuracy: 0.8660. Loss: 0.377. LR 0.000100


 57%|█████▋    | 34/60 [26:54<20:29, 47.31s/it]

[Epoch 34] Validation accuracy:0.8374. bacc:0.6825. Loss:0.453
best_acc:0.8427


 57%|█████▋    | 34/60 [27:38<20:29, 47.31s/it]

[Epoch 35] Training accuracy: 0.8669. Loss: 0.380. LR 0.000100


 58%|█████▊    | 35/60 [27:41<19:42, 47.31s/it]

[Epoch 35] Validation accuracy:0.8395. bacc:0.6796. Loss:0.451
best_acc:0.8427


 58%|█████▊    | 35/60 [28:26<19:42, 47.31s/it]

[Epoch 36] Training accuracy: 0.8685. Loss: 0.371. LR 0.000100


 60%|██████    | 36/60 [28:28<18:55, 47.31s/it]

[Epoch 36] Validation accuracy:0.8345. bacc:0.6775. Loss:0.456
best_acc:0.8427


 60%|██████    | 36/60 [29:13<18:55, 47.31s/it]

[Epoch 37] Training accuracy: 0.8676. Loss: 0.375. LR 0.000100


 62%|██████▏   | 37/60 [29:16<18:07, 47.29s/it]

[Epoch 37] Validation accuracy:0.8395. bacc:0.6943. Loss:0.454
best_acc:0.8427


 62%|██████▏   | 37/60 [30:00<18:07, 47.29s/it]

[Epoch 38] Training accuracy: 0.8681. Loss: 0.375. LR 0.000100


 63%|██████▎   | 38/60 [30:03<17:20, 47.30s/it]

[Epoch 38] Validation accuracy:0.8364. bacc:0.6762. Loss:0.456
best_acc:0.8427


 63%|██████▎   | 38/60 [30:48<17:20, 47.30s/it]

[Epoch 39] Training accuracy: 0.8698. Loss: 0.368. LR 0.000100


 65%|██████▌   | 39/60 [30:50<16:33, 47.31s/it]

[Epoch 39] Validation accuracy:0.8355. bacc:0.6735. Loss:0.453
best_acc:0.8427


 65%|██████▌   | 39/60 [31:35<16:33, 47.31s/it]

[Epoch 40] Training accuracy: 0.8690. Loss: 0.369. LR 0.000100


 67%|██████▋   | 40/60 [31:38<15:46, 47.31s/it]

[Epoch 40] Validation accuracy:0.8370. bacc:0.6845. Loss:0.454
best_acc:0.8427


 67%|██████▋   | 40/60 [32:22<15:46, 47.31s/it]

[Epoch 41] Training accuracy: 0.8679. Loss: 0.372. LR 0.000010


 68%|██████▊   | 41/60 [32:25<14:58, 47.29s/it]

[Epoch 41] Validation accuracy:0.8392. bacc:0.6885. Loss:0.449
best_acc:0.8427


 68%|██████▊   | 41/60 [33:10<14:58, 47.29s/it]

[Epoch 42] Training accuracy: 0.8675. Loss: 0.378. LR 0.000010


 70%|███████   | 42/60 [33:12<14:10, 47.27s/it]

[Epoch 42] Validation accuracy:0.8361. bacc:0.6802. Loss:0.451
best_acc:0.8427


 70%|███████   | 42/60 [33:57<14:10, 47.27s/it]

[Epoch 43] Training accuracy: 0.8683. Loss: 0.373. LR 0.000010


 72%|███████▏  | 43/60 [33:59<13:23, 47.29s/it]

[Epoch 43] Validation accuracy:0.8339. bacc:0.6865. Loss:0.451
best_acc:0.8427


 72%|███████▏  | 43/60 [34:44<13:23, 47.29s/it]

[Epoch 44] Training accuracy: 0.8672. Loss: 0.377. LR 0.000010


 73%|███████▎  | 44/60 [34:47<12:36, 47.30s/it]

[Epoch 44] Validation accuracy:0.8389. bacc:0.6876. Loss:0.452
best_acc:0.8427


 73%|███████▎  | 44/60 [35:31<12:36, 47.30s/it]

[Epoch 45] Training accuracy: 0.8676. Loss: 0.375. LR 0.000010


 75%|███████▌  | 45/60 [35:34<11:49, 47.29s/it]

[Epoch 45] Validation accuracy:0.8367. bacc:0.6715. Loss:0.454
best_acc:0.8427


 75%|███████▌  | 45/60 [36:19<11:49, 47.29s/it]

[Epoch 46] Training accuracy: 0.8690. Loss: 0.370. LR 0.000010


 77%|███████▋  | 46/60 [36:21<11:02, 47.33s/it]

[Epoch 46] Validation accuracy:0.8374. bacc:0.6833. Loss:0.451
best_acc:0.8427


 77%|███████▋  | 46/60 [37:06<11:02, 47.33s/it]

[Epoch 47] Training accuracy: 0.8660. Loss: 0.374. LR 0.000010


 78%|███████▊  | 47/60 [37:09<10:14, 47.30s/it]

[Epoch 47] Validation accuracy:0.8361. bacc:0.6752. Loss:0.450
best_acc:0.8427


 78%|███████▊  | 47/60 [37:53<10:14, 47.30s/it]

[Epoch 48] Training accuracy: 0.8689. Loss: 0.374. LR 0.000010


 80%|████████  | 48/60 [37:56<09:27, 47.29s/it]

[Epoch 48] Validation accuracy:0.8408. bacc:0.6866. Loss:0.450
best_acc:0.8427


 80%|████████  | 48/60 [38:41<09:27, 47.29s/it]

[Epoch 49] Training accuracy: 0.8670. Loss: 0.376. LR 0.000010


 82%|████████▏ | 49/60 [38:43<08:40, 47.27s/it]

[Epoch 49] Validation accuracy:0.8370. bacc:0.6852. Loss:0.450
best_acc:0.8427


 82%|████████▏ | 49/60 [39:28<08:40, 47.27s/it]

[Epoch 50] Training accuracy: 0.8681. Loss: 0.373. LR 0.000010


 83%|████████▎ | 50/60 [39:31<07:52, 47.29s/it]

[Epoch 50] Validation accuracy:0.8383. bacc:0.6749. Loss:0.452
best_acc:0.8427


 83%|████████▎ | 50/60 [40:15<07:52, 47.29s/it]

[Epoch 51] Training accuracy: 0.8673. Loss: 0.377. LR 0.000001


 85%|████████▌ | 51/60 [40:18<07:05, 47.29s/it]

[Epoch 51] Validation accuracy:0.8399. bacc:0.6856. Loss:0.455
best_acc:0.8427


 85%|████████▌ | 51/60 [41:03<07:05, 47.29s/it]

[Epoch 52] Training accuracy: 0.8708. Loss: 0.369. LR 0.000001


 87%|████████▋ | 52/60 [41:05<06:18, 47.30s/it]

[Epoch 52] Validation accuracy:0.8364. bacc:0.6758. Loss:0.452
best_acc:0.8427


 87%|████████▋ | 52/60 [41:50<06:18, 47.30s/it]

[Epoch 53] Training accuracy: 0.8675. Loss: 0.374. LR 0.000001


 88%|████████▊ | 53/60 [41:52<05:31, 47.29s/it]

[Epoch 53] Validation accuracy:0.8370. bacc:0.6825. Loss:0.453
best_acc:0.8427


 88%|████████▊ | 53/60 [42:37<05:31, 47.29s/it]

[Epoch 54] Training accuracy: 0.8663. Loss: 0.375. LR 0.000001


 90%|█████████ | 54/60 [42:40<04:43, 47.31s/it]

[Epoch 54] Validation accuracy:0.8405. bacc:0.6908. Loss:0.449
best_acc:0.8427


 90%|█████████ | 54/60 [43:25<04:43, 47.31s/it]

[Epoch 55] Training accuracy: 0.8703. Loss: 0.371. LR 0.000001


 92%|█████████▏| 55/60 [43:27<03:56, 47.32s/it]

[Epoch 55] Validation accuracy:0.8395. bacc:0.6872. Loss:0.448
best_acc:0.8427


 92%|█████████▏| 55/60 [44:12<03:56, 47.32s/it]

[Epoch 56] Training accuracy: 0.8696. Loss: 0.375. LR 0.000001


 93%|█████████▎| 56/60 [44:14<03:09, 47.31s/it]

[Epoch 56] Validation accuracy:0.8345. bacc:0.6774. Loss:0.452
best_acc:0.8427


 93%|█████████▎| 56/60 [44:59<03:09, 47.31s/it]

[Epoch 57] Training accuracy: 0.8702. Loss: 0.371. LR 0.000001


 95%|█████████▌| 57/60 [45:02<02:21, 47.29s/it]

[Epoch 57] Validation accuracy:0.8402. bacc:0.6923. Loss:0.452
best_acc:0.8427


 95%|█████████▌| 57/60 [45:46<02:21, 47.29s/it]

[Epoch 58] Training accuracy: 0.8713. Loss: 0.370. LR 0.000001


 97%|█████████▋| 58/60 [45:49<01:34, 47.29s/it]

[Epoch 58] Validation accuracy:0.8361. bacc:0.6746. Loss:0.451
best_acc:0.8427


 97%|█████████▋| 58/60 [46:34<01:34, 47.29s/it]

[Epoch 59] Training accuracy: 0.8684. Loss: 0.373. LR 0.000001


 98%|█████████▊| 59/60 [46:36<00:47, 47.30s/it]

[Epoch 59] Validation accuracy:0.8389. bacc:0.6863. Loss:0.451
best_acc:0.8427


 98%|█████████▊| 59/60 [47:21<00:47, 47.30s/it]

[Epoch 60] Training accuracy: 0.8686. Loss: 0.373. LR 0.000001


100%|██████████| 60/60 [47:24<00:00, 47.40s/it]

[Epoch 60] Validation accuracy:0.8355. bacc:0.6845. Loss:0.452
best_acc:0.8427





In [4]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")  
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=data_transforms_test)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Test set size: 3136
Starting evaluation...
Test Accuracy: 0.8447066326530612
Test Accuracy: 0.8447066326530612
Test Accuracy: 0.8447066326530612
Final Test Accuracy: 0.8447


# MobilNet Latent dlib attention

In [1]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision.models import mobilenet_v2
from torchvision import transforms
from torchvision import models

from sklearn.metrics import balanced_accuracy_score
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.metrics")



In [2]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models

class ModifiedOURS(nn.Module):
    def __init__(self, num_class=7, num_head=4, pretrained=True):
        super(ModifiedOURS, self).__init__()

        # MobileNetV2 백본
        mobilenet = models.mobilenet_v2(pretrained=pretrained)
        self.features = mobilenet.features  # MobileNetV2의 특징 추출 레이어

        self.num_head = num_head
        for i in range(num_head):
            setattr(self, f"cat_head{i}", CrossAttentionHead(1280))  # 채널 크기를 1280으로 수정

        # Self-Attention
        self.self_attention = SelfAttention(1280)

        # 분류기
        self.fc = nn.Linear(1280, 256)
        self.fc2 = nn.Linear(256, 128)
        self.batch_norm = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, x, latent):
        # MobileNetV2 특징 추출
        x = self.features(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))  # GAP(Global Average Pooling)
        x = x.view(x.size(0), -1)  # (batch_size, 1280)

        # Cross Attention Heads
        heads = []
        for i in range(self.num_head):
            heads.append(getattr(self, f"cat_head{i}")(x.unsqueeze(-1).unsqueeze(-1)))  # 차원 추가

        heads = torch.stack(heads).permute([1, 0, 2])  # (batch_size, num_heads, features)
        heads = heads.sum(dim=1)  # (batch_size, features)

        # Self-Attention
        attention_out = self.self_attention(heads)

        # 분류기
        out = self.fc(attention_out)
        out = self.fc2(out)
        out = self.batch_norm(out)
        out = self.fc4(out)
        out = self.bn(out)

        return out, x, heads


class CrossAttentionHead(nn.Module):
    def __init__(self, channel_dim):
        super().__init__()
        self.sa = SpatialAttention(channel_dim)
        self.ca = ChannelAttention(channel_dim)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        ca = self.ca(x)
        sa = self.sa(ca)
        return sa


class SpatialAttention(nn.Module):
    def __init__(self, channel_dim):
        super().__init__()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channel_dim, channel_dim // 2, kernel_size=1),
            nn.BatchNorm2d(channel_dim // 2),
        )
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(channel_dim // 2, channel_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(channel_dim),
        )
        self.relu = nn.ReLU()
        self.gap = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        y = self.conv1x1(x)
        y = self.conv_3x3(y)
        y = self.relu(y)
        y = y.sum(dim=1, keepdim=True)

        out = x * y
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        return out


class ChannelAttention(nn.Module):
    def __init__(self, channel_dim):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
            nn.Linear(channel_dim, channel_dim // 4),
            nn.BatchNorm1d(channel_dim // 4),
            nn.ReLU(inplace=True),
            nn.Linear(channel_dim // 4, channel_dim),
            nn.Sigmoid()
        )

    def forward(self, sa):
        sa2 = self.gap(sa)
        sa2 = sa2.view(sa2.size(0), -1)
        y = self.attention(sa2)
        y = y.unsqueeze(dim=-1).unsqueeze(dim=-1)

        out = sa * y
        return out


class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1))
        attention_scores = self.softmax(attention_scores)

        attention_output = torch.matmul(attention_scores, V)
        return attention_output


In [3]:
import dlib
import pandas as pd
import numpy as np
from PIL import Image
import cv2
from torch.utils.data import Dataset
from tqdm import tqdm
class FERPlusDataset(Dataset):
    def __init__(self, data_csv, phase, transform=None):
        self.phase = phase
        self.transform = transform

        # Read the dataset CSV file
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        # Get file paths and labels
        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # 감정 점수들

        # Apply constraints to filter valid samples
        self._apply_constraints()

        # Use argmax to determine the emotion class
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging: Check label range
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

        # Dlib face detector and predictor initialization
        self.detector = dlib.get_frontal_face_detector()
        self.predictor = dlib.shape_predictor('/root/FER2013/shape_predictor_68_face_landmarks.dat')  # Dlib 모델 파일 경로 필요

    def _apply_constraints(self):
        # Constraint : 'unknown-face' 또는 'not-face' 레이블 제거
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Constraint : 최대 투표 수를 가진 레이블이 3개 초과 제거
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Constraint : 최대 투표 수가 전체 투표 수의 절반 이하인 경우 제거
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        # Combine constraints
        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        # Apply valid samples filter
        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        if self.phase == 'train':
            path = '/data/FER2013/FER2013Train/' + self.file_paths[idx]
        elif self.phase == 'val':
            path = '/data/FER2013/FER2013Valid/' + self.file_paths[idx]
        elif self.phase == 'test':
            path = '/data/FER2013/FER2013Test/' + self.file_paths[idx]
        
        # Open image
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  # Read as grayscale
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)  # Convert grayscale to RGB by duplicating the channel
        
        # Align face using Dlib
        image = self._align_face(image)

        # Convert to PIL image for further processing
        image = Image.fromarray(image)

        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def _align_face(self, image):
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        faces = self.detector(gray)

        if len(faces) == 0:
            return image  # 얼굴이 감지되지 않으면 원본 이미지를 반환

        for face in faces:
            landmarks = self.predictor(gray, face)

            # 좌우 눈의 중심 좌표 추출
            left_eye = (landmarks.part(36).x, landmarks.part(36).y)
            right_eye = (landmarks.part(45).x, landmarks.part(45).y)

            # 두 눈의 중심 계산
            eye_center = ((left_eye[0] + right_eye[0]) // 2, (left_eye[1] + right_eye[1]) // 2)

            # 눈 사이의 기울기 계산
            delta_x = right_eye[0] - left_eye[0]
            delta_y = right_eye[1] - left_eye[1]
            angle = np.degrees(np.arctan2(delta_y, delta_x))

            # 회전 행렬 계산
            rot_matrix = cv2.getRotationMatrix2D(eye_center, angle, 1.0)

            # 이미지 회전 및 정렬
            aligned_face = cv2.warpAffine(image, rot_matrix, (image.shape[1], image.shape[0]))

            return aligned_face  # 첫 번째 얼굴만 처리

        return image

In [4]:


# Define variables
train_csv = '/data/FER2013/train_label.csv'
val_csv = '/data/FER2013/valid_label.csv'
test_csv = '/data/FER2013/test_label.csv'
batch_size = 128
lr = 0.001
workers = 4
epochs = 60

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True

latent_dim = 1280  # MobileNetV2의 출력 채널 크기
latent_input = latent_dim  # `latent` 입력 크기 수정

model = ModifiedOURS(num_class=8 ,num_head=4)# YOUR MODEL CLASS SHOULD BE DEFINED ELSEWHERE
model.to(device)

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

train_dataset = FERPlusDataset(train_csv, phase='train', transform=data_transforms)
print('Whole train set size:', train_dataset.__len__())

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=True,
                                            pin_memory=True)

data_transforms_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

val_dataset = FERPlusDataset(val_csv, phase='val', transform=data_transforms_val)
print('Validation set size:', val_dataset.__len__())

val_loader = torch.utils.data.DataLoader(val_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Loss 함수 정의
def stable_cross_entropy_loss(output, targets, epsilon=1e-7):
    output = F.log_softmax(output, dim=1).clamp(min=epsilon, max=1 - epsilon)
    return F.nll_loss(output, targets)

params = list(model.parameters())
# 손실 함수 수정
criterion_cls = stable_cross_entropy_loss

# 학습률 및 옵티마이저 수정
optimizer = torch.optim.SGD(params, lr=0.01, weight_decay=1e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
best_acc = 0
latent_dim = 1280  # MobileNetV2 출력 채널 크기
for epoch in tqdm(range(1, epochs + 1)):
    running_loss = 0.0
    correct_sum = 0
    iter_cnt = 0
    model.train()

    for imgs, targets in train_loader:
        iter_cnt += 1
        optimizer.zero_grad()

        imgs, targets = imgs.to(device), targets.to(device)
        latent_input = torch.zeros(imgs.size(0), latent_dim).to(device)  # `latent` 초기화
        out, _, _ = model(imgs, latent_input)

        # 안정적인 CrossEntropyLoss 적용
        loss = stable_cross_entropy_loss(out, targets)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicts = torch.max(out, 1)
        correct_sum += (predicts == targets).sum().item()

    acc = correct_sum / len(train_dataset)
    running_loss /= iter_cnt
    tqdm.write(f'[Epoch {epoch}] Training accuracy: {acc:.4f}. Loss: {running_loss:.3f}. LR {optimizer.param_groups[0]["lr"]:.6f}')

    # Validation
    with torch.no_grad():
        running_loss = 0.0
        iter_cnt = 0
        correct_sum = 0
        total_samples = 0
        baccs = []

        y_true = []
        y_pred = []

        model.eval()
        for imgs, targets in val_loader:
            imgs, targets = imgs.to(device), targets.to(device)
            latent_input = torch.zeros(imgs.size(0), latent_dim).to(device)  # `latent` 초기화
            out, _, _ = model(imgs, latent_input)

            loss = stable_cross_entropy_loss(out, targets)
            running_loss += loss.item()
            iter_cnt += 1

            _, predicts = torch.max(out, 1)
            correct_sum += (predicts == targets).sum().item()
            total_samples += targets.size(0)

            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

            baccs.append(balanced_accuracy_score(targets.cpu().numpy(), predicts.cpu().numpy()))

        val_acc = correct_sum / total_samples
        running_loss /= iter_cnt
        scheduler.step()

        y_true = np.concatenate(y_true)
        y_pred = np.concatenate(y_pred)
        bacc = np.mean(baccs)

        tqdm.write(f'[Epoch {epoch}] Validation accuracy: {val_acc:.4f}. bacc: {bacc:.4f}. Loss: {running_loss:.3f}')
        tqdm.write(f'Best validation accuracy: {best_acc:.4f}')




Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Whole train set size: 25045
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Validation set size: 3191


TypeError: stable_cross_entropy_loss() missing 2 required positional arguments: 'output' and 'targets'

In [None]:
def evaluate_model(model, test_loader, device):
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():  # Disable gradient computation
        for i, (imgs, targets) in enumerate(test_loader):

            imgs = imgs.to(device)
            targets = targets.to(device)

            # Forward pass
            outputs, _, _ = model(imgs, torch.zeros(imgs.size(0), 39936).to(device))  # Provide dummy latent input
            _, predictions = torch.max(outputs, 1)  # Get predicted class

            # Collect results
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

            y_true.extend(targets.cpu().numpy())
            y_pred.extend(predictions.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total if total > 0 else 0  # Prevent division by zero
    print(f"Test Accuracy: {accuracy}")  
    print(f"Test Accuracy: {accuracy}")
    print(f"Test Accuracy: {accuracy}")

    return accuracy
data_transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
])

test_dataset = FERPlusDataset(test_csv, phase='test', transform=data_transforms_test)
print('Test set size:', test_dataset.__len__())

test_loader = torch.utils.data.DataLoader(test_dataset,
                                            batch_size=batch_size,
                                            num_workers=workers,
                                            shuffle=False,
                                            pin_memory=True)

# Evaluate model
acc = evaluate_model(model, test_loader, device)
print(f"Final Test Accuracy: {acc:.4f}")
