In [3]:
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from tqdm import tqdm

# Configuration
# ---- General settings and initial training phase ----
NUM_CLASSES: int = 8               # emotion classes kept after dataset filtering (labels 0-7)
IMG_SHAPE: tuple = (384, 384, 3)   # NOTE(review): the transforms in this cell resize to 224x224, not this shape
BATCH_SIZE: int = 64
TRAIN_EPOCH: int = 100             # upper bound; early stopping may end training sooner
TRAIN_LR: float = 1e-4             # initial Adam learning rate
TRAIN_ES_PATIENCE: int = 5         # epochs without val-acc improvement before early stopping
TRAIN_LR_PATIENCE: int = 3         # ReduceLROnPlateau patience, in epochs
TRAIN_MIN_LR: float = 1e-6         # ReduceLROnPlateau lower bound on the learning rate
TRAIN_DROPOUT: float = 0.1         # dropout before the final linear classifier

# ---- Presumably fine-tuning-phase settings (FT_*); not referenced in this cell ----
FT_EPOCH: int = 500
FT_LR: float = 1e-5
FT_LR_DECAY_STEP: int = 10
FT_LR_DECAY_RATE: float = 0.25
FT_DROPOUT: float = 0.2

# Presumably the minimum improvement for early stopping / LR reduction.
# NOTE(review): not used by the scheduler or early-stopping code in this cell.
ES_LR_MIN_DELTA: float = 0.003

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset
class FERPlusDataset(Dataset):
    """FER+ facial-expression dataset driven by a per-split vote CSV.

    Column 0 of the CSV holds the image file name and columns 2-11 hold ten
    per-class vote counts (column 1 is not read). Votes equal to 1 are zeroed,
    rows failing the checks in ``_apply_constraints`` are dropped, and each
    remaining sample is labelled with the argmax of its vote counts.

    Args:
        data_csv: Path to the split's label CSV.
        phase: One of ``'train'``, ``'val'`` or ``'test'``; selects the image
            sub-directory under ``image_root``.
        transform: Optional callable applied to each RGB PIL image.
        image_root: Directory that contains the ``FER2013Train``,
            ``FER2013Valid`` and ``FER2013Test`` image folders.

    Raises:
        ValueError: If ``phase`` is not one of the supported values.
    """

    # Image sub-directory used for each supported phase.
    _PHASE_DIRS = {
        'train': 'FER2013Train',
        'val': 'FER2013Valid',
        'test': 'FER2013Test',
    }

    def __init__(self, data_csv, phase, transform=None, image_root='/data/FER2013'):
        # Validate up front: an unknown phase used to surface only later, as an
        # UnboundLocalError inside __getitem__.
        if phase not in self._PHASE_DIRS:
            raise ValueError(
                f"phase must be one of {sorted(self._PHASE_DIRS)}, got {phase!r}"
            )
        self.phase = phase
        self.transform = transform
        self.image_root = image_root

        # Read the label CSV and zero out single votes (presumably to discard
        # single-annotator noise).
        self.data = pd.read_csv(data_csv)
        self.data.iloc[:, 2:12] = self.data.iloc[:, 2:12].replace(1, 0)

        self.file_paths = self.data.iloc[:, 0].values
        self.counts = self.data.iloc[:, 2:12].values  # per-class vote counts

        # Drop ambiguous / invalid samples (filters file_paths and counts).
        self._apply_constraints()

        # Label = class with the most votes among the surviving samples.
        self.labels = np.argmax(self.counts, axis=1)

        # Debugging aid: show which label ids survived filtering.
        print("Unique labels in dataset after filtering:", np.unique(self.labels))

    def _apply_constraints(self):
        """Keep only samples with a clear, valid majority label (in place)."""
        max_counts = self.counts.max(axis=1)
        counts_eq_max = (self.counts == max_counts[:, None])

        # Drop samples whose top vote includes the 'unknown-face' or
        # 'not-face' columns (vote columns 8 and 9).
        constraint1_violation = counts_eq_max[:, [8, 9]].any(axis=1)

        # Drop samples where more than 3 labels tie for the top vote.
        # (Implied by constraint 3 as well: a count above half the total can
        # only be held by a single label.)
        num_max_labels = counts_eq_max.sum(axis=1)
        constraint2_violation = num_max_labels > 3

        # Drop samples whose top vote is at most half of all votes; this also
        # removes rows whose votes are all zero.
        total_votes = self.counts.sum(axis=1)
        constraint3_violation = max_counts <= (total_votes / 2)

        valid_samples = ~(
            constraint1_violation | constraint2_violation | constraint3_violation
        )

        self.file_paths = self.file_paths[valid_samples]
        self.counts = self.counts[valid_samples]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed if a transform is set."""
        path = os.path.join(
            self.image_root, self._PHASE_DIRS[self.phase], self.file_paths[idx]
        )
        # Context manager ensures the underlying file handle is closed.
        with Image.open(path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label



# Data loading and preprocessing
# Standard ImageNet channel mean/std, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# NOTE(review): images are resized to 224x224 here, while the config declares
# IMG_SHAPE = (384, 384, 3); confirm which resolution is intended.
INPUT_SIZE = (224, 224)

transform_train = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(contrast=0.3),  # random-contrast augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

transform_test = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = FERPlusDataset('/data/FER2013/train_label.csv', phase='train', transform=transform_train)
valid_dataset = FERPlusDataset('/data/FER2013/valid_label.csv', phase='val', transform=transform_test)
test_dataset = FERPlusDataset('/data/FER2013/test_label.csv', phase='test', transform=transform_test)

# drop_last=True: the model head contains BatchNorm1d, which raises in training
# mode when a batch holds a single sample. Without it, that happens whenever
# len(train_dataset) % BATCH_SIZE == 1 (the last, partial batch).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Balanced class weights indexed by class id. Listing every class id
# explicitly keeps weight[i] aligned with logit i, and makes
# compute_class_weight raise if a class is missing from the training split.
# Previously the result would be a shorter vector that mis-aligns with the
# NUM_CLASSES logits.
class_weights = compute_class_weight('balanced', classes=np.arange(NUM_CLASSES), y=train_dataset.labels)
class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)


# Model definition
class CustomModel(nn.Module):
    """MobileNetV2 feature extractor followed by a small conv + attention head.

    Args:
        num_classes: Number of output logits.
        dropout: Dropout probability applied before the final linear layer.
        pretrained: If True (default), initialise the backbone with
            torchvision's ``MobileNet_V2_Weights.IMAGENET1K_V1``. Pass False to
            build the same architecture with random weights, e.g. before
            loading a saved ``state_dict`` without network access.
    """

    def __init__(self, num_classes, dropout, pretrained=True):
        super().__init__()
        weights = models.MobileNet_V2_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.mobilenet_v2(weights=weights)
        # Keep every stage of backbone.features except the last one; the head
        # below consumes the resulting 320-channel feature map.
        self.base_model = nn.Sequential(*list(backbone.features)[:-1])

        # Convolutional head; the second conv halves the spatial size (stride 2).
        self.patch_extraction = nn.Sequential(
            nn.Conv2d(320, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        # BatchNorm1d needs more than one sample per batch in training mode.
        self.fc_pre_classification = nn.Sequential(
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
        )
        # NOTE(review): forward() feeds this layer a sequence of length 1, so
        # the softmax attention weight is always 1 and the layer reduces to
        # its value and output projections. Kept unchanged so existing
        # state_dicts still load.
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=1, batch_first=True)
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(32, num_classes),
        )

    def forward(self, x):
        """Map an image batch to class logits.

        Args:
            x: Image batch, presumably (batch, 3, H, W) normalised with
                ImageNet statistics — confirm against the transforms.

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = self.base_model(x)
        x = self.patch_extraction(x)
        x = torch.flatten(self.gap(x), 1)                  # (batch, 256)
        x = self.fc_pre_classification(x).unsqueeze(1)   # (batch, 1, 32): length-1 sequence
        x, _ = self.attention(x, x, x)
        x = x.squeeze(1)                                  # (batch, 32)
        return self.classifier(x)



# Model, class-weighted loss, optimiser and LR schedule for the initial phase.
model = CustomModel(num_classes=NUM_CLASSES, dropout=TRAIN_DROPOUT)
model = model.to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(params=model.parameters(), lr=TRAIN_LR)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',  # stepped with validation accuracy, so larger is better
    patience=TRAIN_LR_PATIENCE,
    min_lr=TRAIN_MIN_LR,
)
# NOTE(review): ES_LR_MIN_DELTA is not passed here (e.g. threshold=...,
# threshold_mode='abs'), so the scheduler uses its default threshold.
# Confirm whether that is intended.

# Early-stopping state; train_model reads and updates these via `global`.
early_stopping_patience = TRAIN_ES_PATIENCE
best_val_acc = 0
early_stopping_counter = 0


def _run_train_epoch(model, loader):
    """Run one optimisation pass over ``loader``.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``.

    Returns:
        ``(mean_loss, accuracy)``: each batch's loss weighted by its size and
        averaged per sample, and the fraction of correct top-1 predictions.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, labels in tqdm(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (outputs.argmax(1) == labels).sum().item()
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError("training loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def _evaluate(model, loader):
    """Compute ``(mean_loss, accuracy)`` over ``loader`` without gradients.

    Uses the module-level ``criterion`` and ``device``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            batch_size = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * batch_size
            n_correct += (outputs.argmax(1) == labels).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        raise ValueError("validation loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, train_loader, valid_loader, num_epochs, *,
                patience=None, min_delta=0.0, checkpoint_path="best_model.pth"):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Uses the module-level ``criterion``, ``optimizer``, ``scheduler`` and
    ``device``. Progress is kept in the module-level ``best_val_acc`` and
    ``early_stopping_counter``; they are not reset here, so a later call
    continues from the state left by earlier calls.

    Args:
        model: Model to train in place.
        train_loader: Loader for training batches.
        valid_loader: Loader for validation batches.
        num_epochs: Maximum number of epochs to run.
        patience: Consecutive non-improving epochs before stopping. Defaults to
            the module-level ``early_stopping_patience``, read at call time.
        min_delta: Minimum gain in validation accuracy over the best so far
            that counts as an improvement. The default 0.0 keeps the plain
            ``val_acc > best_val_acc`` rule.
        checkpoint_path: File that receives the best model's ``state_dict``.

    Note:
        On return, ``model`` holds the weights from the last epoch run, which
        are not necessarily the best ones; load ``checkpoint_path`` for those.
    """
    global best_val_acc, early_stopping_counter
    if patience is None:
        patience = early_stopping_patience

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_train_epoch(model, train_loader)
        val_loss, val_acc = _evaluate(model, valid_loader)
        # The scheduler tracks validation accuracy (higher is better).
        scheduler.step(val_acc)

        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        # Early stopping: checkpoint on improvement, otherwise count strikes.
        if val_acc > best_val_acc + min_delta:
            best_val_acc = val_acc
            early_stopping_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= patience:
                print("Early stopping triggered.")
                break


# Initial training run; train_model saves the best weights to best_model.pth.
train_model(model, train_loader, valid_loader, num_epochs=TRAIN_EPOCH)


Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]
Unique labels in dataset after filtering: [0 1 2 3 4 5 6 7]


100%|██████████| 392/392 [01:14<00:00,  5.26it/s]


Epoch 1/100, Train Loss: 1.5666, Train Acc: 0.5341, Val Loss: 1.2571, Val Acc: 0.6575


100%|██████████| 392/392 [01:12<00:00,  5.42it/s]


Epoch 2/100, Train Loss: 1.0894, Train Acc: 0.6774, Val Loss: 0.9854, Val Acc: 0.7073


100%|██████████| 392/392 [01:12<00:00,  5.42it/s]


Epoch 3/100, Train Loss: 0.8205, Train Acc: 0.7239, Val Loss: 0.7968, Val Acc: 0.7606


100%|██████████| 392/392 [01:12<00:00,  5.40it/s]


Epoch 4/100, Train Loss: 0.6660, Train Acc: 0.7649, Val Loss: 0.9108, Val Acc: 0.7249


100%|██████████| 392/392 [01:12<00:00,  5.38it/s]


Epoch 5/100, Train Loss: 0.6076, Train Acc: 0.7790, Val Loss: 1.0431, Val Acc: 0.7245


100%|██████████| 392/392 [01:12<00:00,  5.37it/s]


Epoch 6/100, Train Loss: 0.4951, Train Acc: 0.8133, Val Loss: 0.7527, Val Acc: 0.7944


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 7/100, Train Loss: 0.4449, Train Acc: 0.8278, Val Loss: 0.7922, Val Acc: 0.7888


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 8/100, Train Loss: 0.3853, Train Acc: 0.8384, Val Loss: 0.6691, Val Acc: 0.7994


100%|██████████| 392/392 [01:12<00:00,  5.39it/s]


Epoch 9/100, Train Loss: 0.3448, Train Acc: 0.8556, Val Loss: 0.8114, Val Acc: 0.8160


100%|██████████| 392/392 [01:12<00:00,  5.40it/s]


Epoch 10/100, Train Loss: 0.2984, Train Acc: 0.8674, Val Loss: 0.7973, Val Acc: 0.8132


100%|██████████| 392/392 [01:12<00:00,  5.38it/s]


Epoch 11/100, Train Loss: 0.2544, Train Acc: 0.8850, Val Loss: 0.9142, Val Acc: 0.8179


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 12/100, Train Loss: 0.2859, Train Acc: 0.8738, Val Loss: 0.9595, Val Acc: 0.7950


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 13/100, Train Loss: 0.2546, Train Acc: 0.8856, Val Loss: 0.8808, Val Acc: 0.8229


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 14/100, Train Loss: 0.2160, Train Acc: 0.9002, Val Loss: 0.8897, Val Acc: 0.8355


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 15/100, Train Loss: 0.1829, Train Acc: 0.9112, Val Loss: 0.8786, Val Acc: 0.8242


100%|██████████| 392/392 [01:13<00:00,  5.37it/s]


Epoch 16/100, Train Loss: 0.2044, Train Acc: 0.9094, Val Loss: 1.0033, Val Acc: 0.8248


100%|██████████| 392/392 [01:12<00:00,  5.39it/s]


Epoch 17/100, Train Loss: 0.1953, Train Acc: 0.9108, Val Loss: 0.8448, Val Acc: 0.8057


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 18/100, Train Loss: 0.1833, Train Acc: 0.9161, Val Loss: 0.9338, Val Acc: 0.8276


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 19/100, Train Loss: 0.1143, Train Acc: 0.9439, Val Loss: 0.9485, Val Acc: 0.8430


100%|██████████| 392/392 [01:13<00:00,  5.37it/s]


Epoch 20/100, Train Loss: 0.1055, Train Acc: 0.9500, Val Loss: 0.9274, Val Acc: 0.8361


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 21/100, Train Loss: 0.0919, Train Acc: 0.9563, Val Loss: 1.0223, Val Acc: 0.8490


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 22/100, Train Loss: 0.0842, Train Acc: 0.9598, Val Loss: 0.9958, Val Acc: 0.8417


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 23/100, Train Loss: 0.0792, Train Acc: 0.9629, Val Loss: 1.0390, Val Acc: 0.8449


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 24/100, Train Loss: 0.0749, Train Acc: 0.9650, Val Loss: 1.0831, Val Acc: 0.8511


100%|██████████| 392/392 [01:13<00:00,  5.34it/s]


Epoch 25/100, Train Loss: 0.0660, Train Acc: 0.9691, Val Loss: 1.1083, Val Acc: 0.8483


100%|██████████| 392/392 [01:13<00:00,  5.34it/s]


Epoch 26/100, Train Loss: 0.0644, Train Acc: 0.9698, Val Loss: 1.1349, Val Acc: 0.8515


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 27/100, Train Loss: 0.0600, Train Acc: 0.9723, Val Loss: 1.1654, Val Acc: 0.8464


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 28/100, Train Loss: 0.0592, Train Acc: 0.9745, Val Loss: 1.1701, Val Acc: 0.8515


100%|██████████| 392/392 [01:13<00:00,  5.37it/s]


Epoch 29/100, Train Loss: 0.0526, Train Acc: 0.9762, Val Loss: 1.1759, Val Acc: 0.8490


100%|██████████| 392/392 [01:13<00:00,  5.35it/s]


Epoch 30/100, Train Loss: 0.0509, Train Acc: 0.9778, Val Loss: 1.2056, Val Acc: 0.8562


100%|██████████| 392/392 [01:13<00:00,  5.37it/s]


Epoch 31/100, Train Loss: 0.0480, Train Acc: 0.9808, Val Loss: 1.2528, Val Acc: 0.8515


100%|██████████| 392/392 [01:13<00:00,  5.37it/s]


Epoch 32/100, Train Loss: 0.0471, Train Acc: 0.9794, Val Loss: 1.2619, Val Acc: 0.8493


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 33/100, Train Loss: 0.0481, Train Acc: 0.9812, Val Loss: 1.2699, Val Acc: 0.8499


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 34/100, Train Loss: 0.0437, Train Acc: 0.9818, Val Loss: 1.2918, Val Acc: 0.8515


100%|██████████| 392/392 [01:13<00:00,  5.36it/s]


Epoch 35/100, Train Loss: 0.0370, Train Acc: 0.9853, Val Loss: 1.3183, Val Acc: 0.8508
Early stopping triggered.
