In [2]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from transformers import AutoImageProcessor, AutoModelForImageClassification, AdamW
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np

model_name    = "trpakov/vit-face-expression"
num_epochs    = 20
batch_size    = 16
learning_rate = 2e-5
weight_decay  = 0.01
num_folds     = 5
random_seed   = 42

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(random_seed)
np.random.seed(random_seed)



In [9]:
def load_datasets(directory):
    image_paths, labels = [], []
    for lbl in os.listdir(directory):
        lbl_dir = os.path.join(directory, lbl)
        if not os.path.isdir(lbl_dir):
            continue
        for fname in os.listdir(lbl_dir):
            image_paths.append(os.path.join(lbl_dir, fname))
            labels.append(lbl)
    return image_paths, labels

train_dir = "../data/train"
test_dir  = "../data/test"

train_paths, train_labels = load_datasets(train_dir)
test_paths,  test_labels  = load_datasets(test_dir)

le      = LabelEncoder()
y_train = le.fit_transform(train_labels)
label_names = le.classes_

# WeightedRandomSampler
class_counts   = np.bincount(y_train)
class_weights  = 1.0 / class_counts
samples_weight = class_weights[y_train]
base_sampler   = WeightedRandomSampler(
    weights=samples_weight,
    num_samples=len(samples_weight),
    replacement=True
)




In [10]:
#resize + normalize
processor = AutoImageProcessor.from_pretrained(model_name)
img_size   = processor.size["height"]

# transform
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(img_size, scale=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])
val_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
])

class FERDataset(Dataset):
    def __init__(self, paths, labels, processor, transform=None):
        self.paths     = paths
        self.labels    = labels
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        pv = self.processor(images=img, return_tensors="pt")["pixel_values"].squeeze(0)
        return pv, self.labels[idx]



In [11]:
kf   = KFold(n_splits=num_folds, shuffle=True, random_state=random_seed)
fold = 0

for train_idx, val_idx in kf.split(train_paths):
    fold += 1
    print(f"Fold {fold}/{num_folds}")

    tr_paths  = [train_paths[i] for i in train_idx]
    tr_labels = [y_train[i]     for i in train_idx]
    vl_paths  = [train_paths[i] for i in val_idx]
    vl_labels = [y_train[i]     for i in val_idx]

    fold_counts  = np.bincount(tr_labels)
    fold_wts     = 1.0 / fold_counts
    fold_sw      = fold_wts[np.array(tr_labels)]
    fold_sampler = WeightedRandomSampler(
        weights=fold_sw,
        num_samples=len(fold_sw),
        replacement=True
    )

    train_ds = FERDataset(tr_paths, tr_labels, processor, transform=train_transform)
    val_ds   = FERDataset(vl_paths, vl_labels, processor, transform=val_transform)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        sampler=fold_sampler,
        num_workers=0,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )



=== Fold 1/5 ===
=== Fold 2/5 ===
=== Fold 3/5 ===
=== Fold 4/5 ===
=== Fold 5/5 ===


In [12]:
    # freeze the body param
    model = AutoModelForImageClassification.from_pretrained(model_name)
    for param in model.vit.parameters():
        param.requires_grad = False
    # dropout
    model.config.classifier_dropout = 0.3
    model.classifier = nn.Sequential(
        nn.Dropout(model.config.classifier_dropout),
        model.classifier
    )
    model.to(device)

    optimizer = AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)





In [8]:
    for epoch in range(1, num_epochs+1):
        # train
        model.train()
        train_corr = train_tot = 0
        for imgs, lbls in tqdm(train_loader, desc=f"[Fold {fold}] Epoch {epoch} Train"):
            imgs, lbls = imgs.to(device), lbls.to(device)
            outputs    = model(pixel_values=imgs).logits
            loss       = criterion(outputs, lbls)
            loss.backward()

            # gradient cut
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            optimizer.zero_grad()

            preds      = outputs.argmax(dim=1)
            train_corr += (preds == lbls).sum().item()
            train_tot  += lbls.size(0)

        train_acc = train_corr / train_tot

        # val
        model.eval()
        val_corr = val_tot = 0
        val_loss = 0.0
        with torch.no_grad():
            for imgs, lbls in tqdm(val_loader, desc=f"[Fold {fold}] Epoch {epoch} Val"):
                imgs, lbls  = imgs.to(device), lbls.to(device)
                outs        = model(pixel_values=imgs).logits
                loss_v      = criterion(outs, lbls)
                val_loss   += loss_v.item()

                preds_v     = outs.argmax(dim=1)
                val_corr   += (preds_v == lbls).sum().item()
                val_tot    += lbls.size(0)

        val_acc = val_corr / val_tot
        val_loss /= len(val_loader)

        print(f"[Fold {fold}] Epoch {epoch}"
              f"Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")


    

[Fold 5] Epoch 1 Train: 100%|██████████| 1436/1436 [17:11<00:00,  1.39it/s]
[Fold 5] Epoch 1 Val: 100%|██████████| 359/359 [04:14<00:00,  1.41it/s]


[Fold 5] Epoch 1 → Train Acc: 0.7319, Val Acc: 0.9875


[Fold 5] Epoch 2 Train: 100%|██████████| 1436/1436 [16:34<00:00,  1.44it/s]
[Fold 5] Epoch 2 Val: 100%|██████████| 359/359 [04:10<00:00,  1.43it/s]


[Fold 5] Epoch 2 → Train Acc: 0.7506, Val Acc: 0.9866


[Fold 5] Epoch 3 Train: 100%|██████████| 1436/1436 [17:08<00:00,  1.40it/s]
[Fold 5] Epoch 3 Val: 100%|██████████| 359/359 [04:22<00:00,  1.37it/s]


[Fold 5] Epoch 3 → Train Acc: 0.7531, Val Acc: 0.9869


[Fold 5] Epoch 4 Train: 100%|██████████| 1436/1436 [17:43<00:00,  1.35it/s]
[Fold 5] Epoch 4 Val: 100%|██████████| 359/359 [03:58<00:00,  1.50it/s]


[Fold 5] Epoch 4 → Train Acc: 0.7544, Val Acc: 0.9869


[Fold 5] Epoch 5 Train: 100%|██████████| 1436/1436 [1:03:27<00:00,  2.65s/it] 
[Fold 5] Epoch 5 Val: 100%|██████████| 359/359 [04:15<00:00,  1.41it/s]


[Fold 5] Epoch 5 → Train Acc: 0.7601, Val Acc: 0.9869


[Fold 5] Epoch 6 Train: 100%|██████████| 1436/1436 [17:03<00:00,  1.40it/s]
[Fold 5] Epoch 6 Val: 100%|██████████| 359/359 [04:09<00:00,  1.44it/s]


[Fold 5] Epoch 6 → Train Acc: 0.7578, Val Acc: 0.9866


[Fold 5] Epoch 7 Train:  10%|█         | 148/1436 [01:53<15:17,  1.40it/s]