In [None]:
import torch
import torchvision
from torchvision import transforms
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b0
from tqdm import tqdm

# 1. Data augmentation (SimCLR required; Tiny version offers lighter augmentation)
simclr_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# 2. Loading the dataset
dataset = torchvision.datasets.ImageFolder("YOLO_format_cls/train", transform=None)

def collate_fn(batch):
    imgs, labels = zip(*batch)
    return list(imgs), list(labels)

loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    drop_last=True,
    collate_fn=collate_fn
)

# 3. Backbone + Projection head（EfficientNet-B0）
class TinySimCLR(nn.Module):
    def __init__(self):
        super().__init__()
        base = efficientnet_b0(weights='IMAGENET1K_V1')
        base.classifier = nn.Identity()  # Remove classification layer
        self.encoder = base

        # Tiny Projection Head (Compact dimensions, minimal graphics memory)
        self.projector = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Linear(256, 64)
        )

    def forward(self, x):
        h = self.encoder(x)
        z = self.projector(h)
        return nn.functional.normalize(z, dim=1)


device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
model = TinySimCLR().to(device)

# 4. Tiny Contrastive Loss Function (No NxN Similarity Matrix Required)
def tiny_contrastive_loss(z1, z2):
    # cosine similarity (negative sign becomes distance)
    return 1 - torch.mean(torch.sum(z1 * z2, dim=1))

# 5. Training Cycle
optimizer = optim.Adam(model.parameters(), lr=2e-4)

print("Start TinySimCLR training...\n")
epochs = 15

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100)
    for imgs, _ in pbar:
        # Two randomly enhanced versions
        x1 = torch.stack([simclr_transform(img) for img in imgs]).to(device)
        x2 = torch.stack([simclr_transform(img) for img in imgs]).to(device)

        z1 = model(x1)
        z2 = model(x2)

        loss = tiny_contrastive_loss(z1, z2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    print(f"Epoch {epoch+1} finished | avg loss: {epoch_loss / len(loader):.4f}\n")


torch.save(model.state_dict(), "tinysimclr_effb0_mac.pth")
print("Training complete! Saved as tinysimclr_effb0_mac.pth")


Using device: mps
Start TinySimCLR training...



Epoch 1/15: 100%|████████████████████████████████████| 669/669 [05:44<00:00,  1.94it/s, loss=0.0003]


Epoch 1 finished | avg loss: 0.0044



Epoch 2/15: 100%|████████████████████████████████████| 669/669 [05:41<00:00,  1.96it/s, loss=0.0001]


Epoch 2 finished | avg loss: 0.0002



Epoch 3/15: 100%|████████████████████████████████████| 669/669 [05:48<00:00,  1.92it/s, loss=0.0000]


Epoch 3 finished | avg loss: 0.0001



Epoch 4/15: 100%|████████████████████████████████████| 669/669 [05:47<00:00,  1.93it/s, loss=0.0000]


Epoch 4 finished | avg loss: 0.0000



Epoch 5/15: 100%|████████████████████████████████████| 669/669 [05:42<00:00,  1.95it/s, loss=0.0000]


Epoch 5 finished | avg loss: 0.0000



Epoch 6/15: 100%|████████████████████████████████████| 669/669 [05:43<00:00,  1.95it/s, loss=0.0000]


Epoch 6 finished | avg loss: 0.0000



Epoch 7/15: 100%|████████████████████████████████████| 669/669 [05:46<00:00,  1.93it/s, loss=0.0000]


Epoch 7 finished | avg loss: 0.0000



Epoch 8/15: 100%|████████████████████████████████████| 669/669 [05:41<00:00,  1.96it/s, loss=0.0000]


Epoch 8 finished | avg loss: 0.0000



Epoch 9/15: 100%|████████████████████████████████████| 669/669 [05:45<00:00,  1.93it/s, loss=0.0000]


Epoch 9 finished | avg loss: 0.0000



Epoch 10/15: 100%|███████████████████████████████████| 669/669 [05:46<00:00,  1.93it/s, loss=0.0000]


Epoch 10 finished | avg loss: 0.0000



Epoch 11/15: 100%|███████████████████████████████████| 669/669 [05:50<00:00,  1.91it/s, loss=0.0000]


Epoch 11 finished | avg loss: 0.0000



Epoch 12/15: 100%|███████████████████████████████████| 669/669 [05:47<00:00,  1.92it/s, loss=0.0000]


Epoch 12 finished | avg loss: 0.0000



Epoch 13/15: 100%|███████████████████████████████████| 669/669 [05:49<00:00,  1.91it/s, loss=0.0000]


Epoch 13 finished | avg loss: 0.0000



Epoch 14/15: 100%|███████████████████████████████████| 669/669 [05:45<00:00,  1.94it/s, loss=0.0000]


Epoch 14 finished | avg loss: 0.0000



Epoch 15/15: 100%|███████████████████████████████████| 669/669 [05:45<00:00,  1.94it/s, loss=0.0000]

Epoch 15 finished | avg loss: 0.0000

Training complete! Saved as tinysimclr_effb0_mac.pth





In [None]:
import math
import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models import efficientnet_b0
from tqdm import tqdm

# Collate function — Permit output in PIL or Tensor format
def collate_fn(batch):
    imgs, labels = zip(*batch)  # imgs: PIL.Image list, labels: int list
    return list(imgs), list(labels)


# 1. Data augmentation (version tuned for real-time processing)
# Enhanced Standard Training: Slight Rotation + Flip
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Weaker strong_aug: Simulates only the slight variations of a real camera.
strong_aug = transforms.Compose([
    transforms.ColorJitter(0.15, 0.15, 0.15, 0.05),
    transforms.RandomGrayscale(p=0.05),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Verification enhancement: Perform only resizing and standardisation.
val_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# 2. dataset & DataLoader
# Train without applying the transformation, so that we can generate both weak and strong versions ourselves.
train_set = datasets.ImageFolder("YOLO_format_cls/train", transform=None)
val_set   = datasets.ImageFolder("YOLO_format_cls/valid", transform=val_aug)

BATCH_SIZE = 16 
EPOCHS     = 25

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)

print(f"Train size: {len(train_set)}, Val size: {len(val_set)}")

# 3. Load TinySimCLR EfficientNet-B0 weights
base = efficientnet_b0(weights=None)
base.classifier = nn.Identity()

simclr_weights = torch.load("tinysimclr_effb0_mac.pth", map_location="cpu")
missing, unexpected = base.load_state_dict(simclr_weights, strict=False)
print("Loaded SimCLR backbone. Missing keys:", missing)
print("Unexpected keys:", unexpected)

# 4. Add category headers
NUM_CLASSES = 7

model = nn.Sequential(
    base,
    nn.Linear(1280, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, NUM_CLASSES)
)

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
model = model.to(device)

# Label smoothing renders the output smoother and more conducive to real-time applications.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

BASE_LR = 3e-4
optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)

# consistency loss reduces the weight slightly for greater stability.
LAMBDA_CONS = 0.1


# 5. Consistency loss
def consistency_loss(logits1, logits2):
    p1 = torch.softmax(logits1, dim=1)
    p2 = torch.softmax(logits2, dim=1)
    return torch.mean((p1 - p2) ** 2)

# 6. lr: warmup + cosine decay
total_steps  = EPOCHS * len(train_loader)
warmup_ratio = 0.1          # top 10% step warmup
warmup_steps = int(total_steps * warmup_ratio)

def get_lr(step):
    if step < warmup_steps:
        # Linear warmup: from 0 to BASE_LR
        return BASE_LR * float(step + 1) / float(warmup_steps + 1)
    # Subsequently, the cosine decays to BASE_LR multiplied by 0.1.
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = BASE_LR * 0.1
    return min_lr + (BASE_LR - min_lr) * cosine


# 7. Training cycle (real-time tuned)
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    ce_loss_sum = 0.0
    cons_loss_sum = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)
    for imgs, labels in pbar:
        # imgs: PIL list, labels: Python int list
        # weak aug (used for supervised CE loss)
        x = torch.stack([train_aug(img) for img in imgs]).to(device)
        y = torch.tensor(labels, dtype=torch.long).to(device)

        # strong aug (used for consistency)
        x_strong = torch.stack([strong_aug(img) for img in imgs]).to(device)

        # Dynamically adjusting the learning rate（warmup + cosine）
        lr = get_lr(global_step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Forward
        logits = model(x)
        logits_strong = model(x_strong)

        ce = criterion(logits, y)
        cons = consistency_loss(logits, logits_strong)

        loss = ce + LAMBDA_CONS * cons

        optimizer.zero_grad()
        loss.backward()
        # Optional: Gradient cropping for further stabilisation.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss += loss.item()
        ce_loss_sum += ce.item()
        cons_loss_sum += cons.item()
        global_step += 1

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "ce":   f"{ce.item():.4f}",
            "cons": f"{cons.item():.4f}",
            "lr":   f"{lr:.1e}"
        })

    avg_loss = train_loss / len(train_loader)
    avg_ce   = ce_loss_sum / len(train_loader)
    avg_cons = cons_loss_sum / len(train_loader)
    print(f"[Train] Epoch {epoch+1} | loss={avg_loss:.4f} | ce={avg_ce:.4f} | cons={avg_cons:.4f}")

    # Simple Verification
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            if isinstance(imgs[0], torch.Tensor):
                x_val = torch.stack(imgs).to(device)
            else:
                x_val = torch.stack([val_aug(img) for img in imgs]).to(device)

            y_val = torch.tensor(labels, dtype=torch.long).to(device)

            logits_val = model(x_val)
            loss_val = criterion(logits_val, y_val)
            val_loss += loss_val.item()

            preds = logits_val.argmax(dim=1)
            correct += (preds == y_val).sum().item()
            total += len(y_val)

    val_loss /= len(val_loader)
    val_acc = correct / total if total > 0 else 0.0
    print(f"[Val]   Epoch {epoch+1} | loss={val_loss:.4f} | acc={val_acc*100:.2f}%\n")

torch.save(model.state_dict(), "affectnet_finetuned_consistency_mac_real_time.pth")
print("Saved: affectnet_finetuned_consistency_mac_real_time.pth")


Train size: 10714, Val size: 3129
Loaded SimCLR backbone. Missing keys: ['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.block.0.0.weight', 'features.1.0.block.0.1.weight', 'features.1.0.block.0.1.bias', 'features.1.0.block.0.1.running_mean', 'features.1.0.block.0.1.running_var', 'features.1.0.block.1.fc1.weight', 'features.1.0.block.1.fc1.bias', 'features.1.0.block.1.fc2.weight', 'features.1.0.block.1.fc2.bias', 'features.1.0.block.2.0.weight', 'features.1.0.block.2.1.weight', 'features.1.0.block.2.1.bias', 'features.1.0.block.2.1.running_mean', 'features.1.0.block.2.1.running_var', 'features.2.0.block.0.0.weight', 'features.2.0.block.0.1.weight', 'features.2.0.block.0.1.bias', 'features.2.0.block.0.1.running_mean', 'features.2.0.block.0.1.running_var', 'features.2.0.block.1.0.weight', 'features.2.0.block.1.1.weight', 'features.2.0.block.1.1.bias', 'features.2.0.block.1.1.running_mean', 'features

Epoch 1/25: 100%|█| 670/670 [05:45<00:00,  1.94it/s, loss=1.8441, ce=1.8440, cons=0.0011, lr=1.2e-04


[Train] Epoch 1 | loss=1.8826 | ce=1.8825 | cons=0.0006
[Val]   Epoch 1 | loss=1.8635 | acc=21.06%



Epoch 2/25: 100%|█| 670/670 [05:28<00:00,  2.04it/s, loss=1.7727, ce=1.7726, cons=0.0011, lr=2.4e-04


[Train] Epoch 2 | loss=1.8644 | ce=1.8643 | cons=0.0008
[Val]   Epoch 2 | loss=1.9276 | acc=20.55%



Epoch 3/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=1.7422, ce=1.7421, cons=0.0009, lr=3.0e-04


[Train] Epoch 3 | loss=1.8489 | ce=1.8488 | cons=0.0012
[Val]   Epoch 3 | loss=1.8949 | acc=28.60%



Epoch 4/25: 100%|█| 670/670 [06:17<00:00,  1.77it/s, loss=1.7421, ce=1.7418, cons=0.0025, lr=3.0e-04


[Train] Epoch 4 | loss=1.7658 | ce=1.7655 | cons=0.0027
[Val]   Epoch 4 | loss=1.6767 | acc=34.13%



Epoch 5/25: 100%|█| 670/670 [06:55<00:00,  1.61it/s, loss=1.5939, ce=1.5935, cons=0.0048, lr=2.9e-04


[Train] Epoch 5 | loss=1.6771 | ce=1.6768 | cons=0.0033
[Val]   Epoch 5 | loss=1.6209 | acc=37.87%



Epoch 6/25: 100%|█| 670/670 [05:31<00:00,  2.02it/s, loss=1.5622, ce=1.5617, cons=0.0058, lr=2.8e-04


[Train] Epoch 6 | loss=1.6244 | ce=1.6240 | cons=0.0039
[Val]   Epoch 6 | loss=1.5410 | acc=42.60%



Epoch 7/25: 100%|█| 670/670 [05:28<00:00,  2.04it/s, loss=1.4158, ce=1.4155, cons=0.0038, lr=2.7e-04


[Train] Epoch 7 | loss=1.5639 | ce=1.5635 | cons=0.0045
[Val]   Epoch 7 | loss=1.4935 | acc=45.19%



Epoch 8/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=1.1338, ce=1.1328, cons=0.0095, lr=2.6e-04


[Train] Epoch 8 | loss=1.5171 | ce=1.5166 | cons=0.0053
[Val]   Epoch 8 | loss=1.4520 | acc=48.83%



Epoch 9/25: 100%|█| 670/670 [05:24<00:00,  2.06it/s, loss=1.6102, ce=1.6095, cons=0.0065, lr=2.5e-04


[Train] Epoch 9 | loss=1.4593 | ce=1.4586 | cons=0.0067
[Val]   Epoch 9 | loss=1.3800 | acc=53.21%



Epoch 10/25: 100%|█| 670/670 [05:27<00:00,  2.05it/s, loss=1.2689, ce=1.2686, cons=0.0027, lr=2.3e-0


[Train] Epoch 10 | loss=1.4031 | ce=1.4023 | cons=0.0079
[Val]   Epoch 10 | loss=1.3355 | acc=55.32%



Epoch 11/25: 100%|█| 670/670 [05:30<00:00,  2.03it/s, loss=1.4511, ce=1.4501, cons=0.0104, lr=2.2e-0


[Train] Epoch 11 | loss=1.3410 | ce=1.3401 | cons=0.0091
[Val]   Epoch 11 | loss=1.3341 | acc=55.70%



Epoch 12/25: 100%|█| 670/670 [05:35<00:00,  2.00it/s, loss=1.1549, ce=1.1540, cons=0.0098, lr=2.0e-0


[Train] Epoch 12 | loss=1.2810 | ce=1.2799 | cons=0.0105
[Val]   Epoch 12 | loss=1.2976 | acc=58.42%



Epoch 13/25: 100%|█| 670/670 [05:32<00:00,  2.02it/s, loss=1.0355, ce=1.0346, cons=0.0089, lr=1.8e-0


[Train] Epoch 13 | loss=1.2304 | ce=1.2293 | cons=0.0117
[Val]   Epoch 13 | loss=1.2415 | acc=60.24%



Epoch 14/25: 100%|█| 670/670 [05:35<00:00,  2.00it/s, loss=1.6070, ce=1.6063, cons=0.0066, lr=1.6e-0


[Train] Epoch 14 | loss=1.1831 | ce=1.1819 | cons=0.0122
[Val]   Epoch 14 | loss=1.2347 | acc=60.72%



Epoch 15/25: 100%|█| 670/670 [05:35<00:00,  2.00it/s, loss=1.0023, ce=1.0014, cons=0.0090, lr=1.4e-0


[Train] Epoch 15 | loss=1.1313 | ce=1.1299 | cons=0.0135
[Val]   Epoch 15 | loss=1.2146 | acc=62.48%



Epoch 16/25: 100%|█| 670/670 [05:30<00:00,  2.03it/s, loss=0.9972, ce=0.9957, cons=0.0149, lr=1.2e-0


[Train] Epoch 16 | loss=1.0914 | ce=1.0900 | cons=0.0138
[Val]   Epoch 16 | loss=1.2101 | acc=62.96%



Epoch 17/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.2916, ce=1.2895, cons=0.0207, lr=1.1e-0


[Train] Epoch 17 | loss=1.0532 | ce=1.0518 | cons=0.0147
[Val]   Epoch 17 | loss=1.1908 | acc=63.63%



Epoch 18/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=1.1685, ce=1.1680, cons=0.0057, lr=9.0e-0


[Train] Epoch 18 | loss=1.0163 | ce=1.0147 | cons=0.0156
[Val]   Epoch 18 | loss=1.1812 | acc=66.03%



Epoch 19/25: 100%|█| 670/670 [05:21<00:00,  2.08it/s, loss=0.8239, ce=0.8216, cons=0.0234, lr=7.5e-0


[Train] Epoch 19 | loss=0.9748 | ce=0.9732 | cons=0.0162
[Val]   Epoch 19 | loss=1.1938 | acc=65.77%



Epoch 20/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=0.9485, ce=0.9477, cons=0.0084, lr=6.2e-0


[Train] Epoch 20 | loss=0.9489 | ce=0.9472 | cons=0.0166
[Val]   Epoch 20 | loss=1.2000 | acc=65.58%



Epoch 21/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=0.6293, ce=0.6285, cons=0.0080, lr=5.1e-0


[Train] Epoch 21 | loss=0.9238 | ce=0.9221 | cons=0.0174
[Val]   Epoch 21 | loss=1.1981 | acc=65.77%



Epoch 22/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.0718, ce=1.0691, cons=0.0271, lr=4.2e-0


[Train] Epoch 22 | loss=0.8906 | ce=0.8888 | cons=0.0178
[Val]   Epoch 22 | loss=1.1975 | acc=66.83%



Epoch 23/25: 100%|█| 670/670 [05:18<00:00,  2.10it/s, loss=0.8362, ce=0.8352, cons=0.0101, lr=3.5e-0


[Train] Epoch 23 | loss=0.8732 | ce=0.8715 | cons=0.0177
[Val]   Epoch 23 | loss=1.1979 | acc=66.19%



Epoch 24/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.0682, ce=1.0672, cons=0.0097, lr=3.1e-0


[Train] Epoch 24 | loss=0.8566 | ce=0.8548 | cons=0.0177
[Val]   Epoch 24 | loss=1.2092 | acc=66.22%



Epoch 25/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=0.8065, ce=0.8053, cons=0.0117, lr=3.0e-0


[Train] Epoch 25 | loss=0.8365 | ce=0.8347 | cons=0.0175
[Val]   Epoch 25 | loss=1.2198 | acc=66.12%

Saved: affectnet_finetuned_consistency_mac_real_time.pth


In [None]:
# fine tuning v2(class weight + mouth-based auxiliary head + 5 classes)
import os
import math
from collections import Counter

import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import efficientnet_b0
from tqdm import tqdm

DATA_ROOT   = "YOLO_format_cls"
TRAIN_DIR   = os.path.join(DATA_ROOT, "train")
VAL_DIR     = os.path.join(DATA_ROOT, "valid")
SIMCLR_PATH = "tinysimclr_effb0_mac.pth"
SAVE_PATH   = "affectnet_fer5_mouth.pth"

BATCH_SIZE  = 16
EPOCHS      = 25
BASE_LR     = 3e-4
NUM_CLASSES = 5

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)


# 1. collate_fn: Preserve PIL images
def collate_fn(batch):
    imgs, labels = zip(*batch)    # imgs: list[PIL.Image], labels: list[int]
    return list(imgs), list(labels)


# 2. Data Augmentation (Lightweight Version for Real-Time Scenarios)
# Primary Branch: Micro-Rotation + Mirroring
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Mouth branch: Apply a similar enhancement to the mouth crop.
mouth_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Verification: Resize only + normalisation
val_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# 7 to 5 category mapping: original index to new 5 categories
# 0: angry     -> 0 (Angry)
# 1: disgusted -> 0 (Angry)
# 2: fearful   -> 3 (Surprise)
# 3: happy     -> 1 (Happy)
# 4: neutral   -> 4 (Neutral)
# 5: sad       -> 2 (Sad)
# 6: surprised -> 3 (Surprise)
mapping_5 = torch.tensor([0, 0, 3, 1, 4, 2, 3], dtype=torch.long)


# 3. Mouth Crop function (based on lower half of face ROI)
def crop_mouth_pil(img):
    """
    Simple Mouth ROI: Select the lower 45% of the face
    img: PIL.Image
    """
    w, h = img.size
    top = int(h * 0.55)
    # (left, upper, right, lower)
    return img.crop((0, top, w, h))


# 4. dataset + Balanced Sampler + class weight
train_set = datasets.ImageFolder(TRAIN_DIR, transform=None)
val_set   = datasets.ImageFolder(VAL_DIR, transform=None)

print(f"Train size (7-class): {len(train_set)}, Val size: {len(val_set)}")

# First, calculate the distribution of labels after the 7 to 5 mapping for use in the sampler/class weighting.
new_targets = []
for _, orig_label in train_set.samples:
    new_label = int(mapping_5[orig_label])
    new_targets.append(new_label)

counter_5 = Counter(new_targets)
print("5-class counts in train:", counter_5)

# class weight: 1 / count（(Renormalisation)
class_weights = []
for c in range(NUM_CLASSES):
    class_weights.append(1.0 / counter_5[c])

# Normalise (optional)
sum_w = sum(class_weights)
class_weights = [w / sum_w * NUM_CLASSES for w in class_weights]

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class weights (for CE):", class_weights_tensor.tolist())

# WeightedRandomSampler
sample_weights = [class_weights[label] for label in new_targets]
sample_weights = torch.tensor(sample_weights, dtype=torch.float32)

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)


# 5. Backbone + Mouth auxiliary head model
class FER5WithMouth(nn.Module):
    def __init__(self, backbone, num_classes=5):
        super().__init__()
        self.backbone = backbone  # EfficientNet-B0, classifier=Identity (output 1280)

        self.head_main = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.head_mouth = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        feat_full = self.backbone(x_full)   # [B, 1280]
        feat_mouth = self.backbone(x_mouth) # [B, 1280]

        logits_main = self.head_main(feat_full)
        logits_mouth = self.head_mouth(feat_mouth)
        return logits_main, logits_mouth

    def forward_main(self, x_full):
        feat_full = self.backbone(x_full)
        return self.head_main(feat_full)


# Load pre-trained backbone（SimCLR）
base = efficientnet_b0(weights=None)
base.classifier = nn.Identity()

simclr_weights = torch.load(SIMCLR_PATH, map_location="cpu")
missing, unexpected = base.load_state_dict(simclr_weights, strict=False)
print("Loaded SimCLR backbone. Missing keys:", missing)
print("Unexpected keys:", unexpected)

model = FER5WithMouth(base, num_classes=NUM_CLASSES).to(device)

# CE with class weight + label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_tensor,
    label_smoothing=0.1
)

# Loss weights for the auxiliary head
LAMBDA_MOUTH = 0.5

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)

# 6. LR warmup + cosine decay
total_steps  = EPOCHS * len(train_loader)
warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)

def get_lr(step):
    if step < warmup_steps:
        return BASE_LR * float(step + 1) / float(warmup_steps + 1)
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = BASE_LR * 0.1
    return min_lr + (BASE_LR - min_lr) * cosine

# 7. Training cycle
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    ce_main_sum = 0.0
    ce_mouth_sum = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)
    for imgs, labels in pbar:
        # Original 7-class label -> Tensor
        y_orig = torch.tensor(labels, dtype=torch.long)  # [B]
        # Mapped to the new 5 categories
        y_5 = mapping_5[y_orig]   # [B]
        y_5 = y_5.to(device)

        # Main branch input
        x_full = torch.stack([train_aug(img) for img in imgs]).to(device)

        # Mouth branch input (first crop the mouth, then apply mouth augmentation)
        mouth_imgs = [crop_mouth_pil(img) for img in imgs]
        x_mouth = torch.stack([mouth_aug(m) for m in mouth_imgs]).to(device)

        # Update learning rate
        lr = get_lr(global_step)
        for g in optimizer.param_groups:
            g['lr'] = lr

        # Forward
        logits_main, logits_mouth = model(x_full, x_mouth)

        ce_main = criterion(logits_main, y_5)
        ce_mouth = criterion(logits_mouth, y_5)

        loss = ce_main + LAMBDA_MOUTH * ce_mouth

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss += loss.item()
        ce_main_sum += ce_main.item()
        ce_mouth_sum += ce_mouth.item()
        global_step += 1

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "main": f"{ce_main.item():.4f}",
            "mouth": f"{ce_mouth.item():.4f}",
            "lr":   f"{lr:.1e}"
        })

    avg_loss = train_loss / len(train_loader)
    avg_main = ce_main_sum / len(train_loader)
    avg_mouth = ce_mouth_sum / len(train_loader)
    print(f"[Train] Epoch {epoch+1} | loss={avg_loss:.4f} | main={avg_main:.4f} | mouth={avg_mouth:.4f}")

    # Verification (main header only)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            y_orig = torch.tensor(labels, dtype=torch.long)
            y_5 = mapping_5[y_orig].to(device)

            x_val = torch.stack([val_aug(img) for img in imgs]).to(device)

            logits_val = model.forward_main(x_val)
            loss_val = criterion(logits_val, y_5)
            val_loss += loss_val.item()

            preds = logits_val.argmax(dim=1)
            correct += (preds == y_5).sum().item()
            total += len(y_5)

    val_loss /= len(val_loader)
    val_acc = correct / total if total > 0 else 0.0
    print(f"[Val]   Epoch {epoch+1} | loss={val_loss:.4f} | acc={val_acc*100:.2f}%\n")

torch.save(model.state_dict(), SAVE_PATH)
print(f"Saved model to {SAVE_PATH}")


Using device: mps
Train size (7-class): 10714, Val size: 3129
5-class counts in train: Counter({3: 4170, 0: 3891, 2: 1582, 4: 744, 1: 327})
Class weights (for CE): [0.23232516646385193, 2.764456272125244, 0.5714141726493835, 0.2167811095714569, 1.215023159980774]
Loaded SimCLR backbone. Missing keys: ['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.block.0.0.weight', 'features.1.0.block.0.1.weight', 'features.1.0.block.0.1.bias', 'features.1.0.block.0.1.running_mean', 'features.1.0.block.0.1.running_var', 'features.1.0.block.1.fc1.weight', 'features.1.0.block.1.fc1.bias', 'features.1.0.block.1.fc2.weight', 'features.1.0.block.1.fc2.bias', 'features.1.0.block.2.0.weight', 'features.1.0.block.2.1.weight', 'features.1.0.block.2.1.bias', 'features.1.0.block.2.1.running_mean', 'features.1.0.block.2.1.running_var', 'features.2.0.block.0.0.weight', 'features.2.0.block.0.1.weight', 'features.2.0.block.0.1

Epoch 1/25: 100%|█| 670/670 [05:38<00:00,  1.98it/s, loss=1.2741, main=0.8555, mouth=0.8373, lr=1.2e


[Train] Epoch 1 | loss=1.9957 | main=1.3257 | mouth=1.3401
[Val]   Epoch 1 | loss=3.1075 | acc=3.10%



Epoch 2/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=2.2458, main=1.4744, mouth=1.5428, lr=2.4e


[Train] Epoch 2 | loss=1.8396 | main=1.2229 | mouth=1.2335
[Val]   Epoch 2 | loss=2.7359 | acc=5.94%



Epoch 3/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=1.4137, main=0.8732, mouth=1.0810, lr=3.0e


[Train] Epoch 3 | loss=1.5659 | main=1.0332 | mouth=1.0653
[Val]   Epoch 3 | loss=3.2052 | acc=8.25%



Epoch 4/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=1.6353, main=1.0019, mouth=1.2668, lr=3.0e


[Train] Epoch 4 | loss=1.3591 | main=0.8874 | mouth=0.9434
[Val]   Epoch 4 | loss=2.9958 | acc=12.18%



Epoch 5/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=1.0680, main=0.6804, mouth=0.7752, lr=2.9e


[Train] Epoch 5 | loss=1.2728 | main=0.8224 | mouth=0.9007
[Val]   Epoch 5 | loss=2.6494 | acc=12.59%



Epoch 6/25: 100%|█| 670/670 [05:26<00:00,  2.06it/s, loss=0.9665, main=0.6410, mouth=0.6510, lr=2.8e


[Train] Epoch 6 | loss=1.1742 | main=0.7506 | mouth=0.8471
[Val]   Epoch 6 | loss=2.5499 | acc=19.14%



Epoch 7/25: 100%|█| 670/670 [05:19<00:00,  2.09it/s, loss=0.8011, main=0.5118, mouth=0.5787, lr=2.7e


[Train] Epoch 7 | loss=1.1077 | main=0.7056 | mouth=0.8042
[Val]   Epoch 7 | loss=2.7075 | acc=25.44%



Epoch 8/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.7768, main=1.1417, mouth=1.2701, lr=2.6e


[Train] Epoch 8 | loss=1.0556 | main=0.6644 | mouth=0.7824
[Val]   Epoch 8 | loss=2.5010 | acc=33.37%



Epoch 9/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=0.6689, main=0.4222, mouth=0.4934, lr=2.5e


[Train] Epoch 9 | loss=1.0401 | main=0.6521 | mouth=0.7760
[Val]   Epoch 9 | loss=2.4204 | acc=38.99%



Epoch 10/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=0.6260, main=0.4373, mouth=0.3775, lr=2.3


[Train] Epoch 10 | loss=0.9948 | main=0.6195 | mouth=0.7506
[Val]   Epoch 10 | loss=2.1748 | acc=46.31%



Epoch 11/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=1.1281, main=0.5776, mouth=1.1010, lr=2.2


[Train] Epoch 11 | loss=0.9364 | main=0.5772 | mouth=0.7185
[Val]   Epoch 11 | loss=2.3479 | acc=40.91%



Epoch 12/25: 100%|█| 670/670 [05:30<00:00,  2.03it/s, loss=1.7144, main=1.0189, mouth=1.3911, lr=2.0


[Train] Epoch 12 | loss=0.9350 | main=0.5739 | mouth=0.7220
[Val]   Epoch 12 | loss=2.0433 | acc=51.74%



Epoch 13/25: 100%|█| 670/670 [05:31<00:00,  2.02it/s, loss=0.6658, main=0.3551, mouth=0.6213, lr=1.8


[Train] Epoch 13 | loss=0.8955 | main=0.5501 | mouth=0.6907
[Val]   Epoch 13 | loss=2.1880 | acc=47.01%



Epoch 14/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=0.7295, main=0.4092, mouth=0.6406, lr=1.6


[Train] Epoch 14 | loss=0.8521 | main=0.5224 | mouth=0.6594
[Val]   Epoch 14 | loss=2.0875 | acc=51.36%



Epoch 15/25: 100%|█| 670/670 [05:24<00:00,  2.06it/s, loss=0.8950, main=0.5725, mouth=0.6449, lr=1.4


[Train] Epoch 15 | loss=0.8428 | main=0.5164 | mouth=0.6528
[Val]   Epoch 15 | loss=2.0318 | acc=56.92%



Epoch 16/25: 100%|█| 670/670 [05:27<00:00,  2.04it/s, loss=0.5221, main=0.3292, mouth=0.3858, lr=1.2


[Train] Epoch 16 | loss=0.8022 | main=0.4890 | mouth=0.6263
[Val]   Epoch 16 | loss=1.8374 | acc=61.43%



Epoch 17/25: 100%|█| 670/670 [05:34<00:00,  2.00it/s, loss=0.6582, main=0.3894, mouth=0.5375, lr=1.1


[Train] Epoch 17 | loss=0.7837 | main=0.4735 | mouth=0.6205
[Val]   Epoch 17 | loss=1.9848 | acc=57.88%



Epoch 18/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=0.6174, main=0.4186, mouth=0.3976, lr=9.0


[Train] Epoch 18 | loss=0.7667 | main=0.4669 | mouth=0.5995
[Val]   Epoch 18 | loss=1.9119 | acc=60.95%



Epoch 19/25: 100%|█| 670/670 [05:24<00:00,  2.07it/s, loss=1.0313, main=0.6433, mouth=0.7760, lr=7.5


[Train] Epoch 19 | loss=0.7194 | main=0.4365 | mouth=0.5659
[Val]   Epoch 19 | loss=1.8559 | acc=63.37%



Epoch 20/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=1.0193, main=0.6017, mouth=0.8352, lr=6.2


[Train] Epoch 20 | loss=0.7272 | main=0.4442 | mouth=0.5661
[Val]   Epoch 20 | loss=1.8626 | acc=62.83%



Epoch 21/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=0.5901, main=0.3841, mouth=0.4121, lr=5.1


[Train] Epoch 21 | loss=0.7072 | main=0.4310 | mouth=0.5523
[Val]   Epoch 21 | loss=1.8155 | acc=65.58%



Epoch 22/25: 100%|█| 670/670 [05:24<00:00,  2.07it/s, loss=0.4583, main=0.2941, mouth=0.3284, lr=4.2


[Train] Epoch 22 | loss=0.7119 | main=0.4349 | mouth=0.5540
[Val]   Epoch 22 | loss=1.8592 | acc=63.37%



Epoch 23/25: 100%|█| 670/670 [05:24<00:00,  2.06it/s, loss=0.4451, main=0.2781, mouth=0.3340, lr=3.5


[Train] Epoch 23 | loss=0.7177 | main=0.4359 | mouth=0.5635
[Val]   Epoch 23 | loss=1.7837 | acc=66.63%



Epoch 24/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=0.6317, main=0.3921, mouth=0.4793, lr=3.1


[Train] Epoch 24 | loss=0.6926 | main=0.4242 | mouth=0.5369
[Val]   Epoch 24 | loss=1.8484 | acc=65.26%



Epoch 25/25: 100%|█| 670/670 [05:24<00:00,  2.06it/s, loss=0.7146, main=0.4796, mouth=0.4700, lr=3.0


[Train] Epoch 25 | loss=0.6869 | main=0.4188 | mouth=0.5363
[Val]   Epoch 25 | loss=1.7563 | acc=68.20%

Saved model to affectnet_fer5_mouth.pth


In [None]:
# fine tuning v3(back to 7 classes)
# fine_tune_7cls_mouth.py
# 7 classes + SimCLR backbone + class weight + balanced sampler + mouth auxiliary head

import os
import math
from collections import Counter

import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import efficientnet_b0
from tqdm import tqdm

DATA_ROOT   = "YOLO_format_cls"
TRAIN_DIR   = os.path.join(DATA_ROOT, "train")
VAL_DIR     = os.path.join(DATA_ROOT, "valid")

SIMCLR_PATH = "tinysimclr_effb0_mac.pth"   # Previously trained TinySimCLR weights
SAVE_PATH   = "affectnet_7cls_mouth.pth"   # Output model filename

BATCH_SIZE  = 16
EPOCHS      = 25
BASE_LR     = 3e-4
NUM_CLASSES = 7

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)


# 1. collate_fn: Preserve PIL image
def collate_fn(batch):
    imgs, labels = zip(*batch)   # imgs: list[PIL.Image], labels: list[int]
    return list(imgs), list(labels)


# 2. Data Augmentation (Real-time Scenario-Friendly Edition)
# Primary branch: Mild rotation + Mirroring
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Mouth branch: Apply a similar enhancement to the mouth crop.
mouth_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Verification: Resize only + normalisation
val_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# 3. Simple 'mouth crop' function
def crop_mouth_pil(img):
    """
    Non-critical point version: Directly take the lower half of the face and treat it as the "mouth region".
    img: PIL.Image
    """
    w, h = img.size
    top = int(h * 0.55)  # Discard the top 55%, retaining the bottom 45%.
    return img.crop((0, top, w, h))  # (left, upper, right, lower)


# 4. Dataset + Category Statistics + Balanced Sampler + Class Weight
train_set = datasets.ImageFolder(TRAIN_DIR, transform=None)
val_set   = datasets.ImageFolder(VAL_DIR, transform=None)

print(f"Train size (7-class): {len(train_set)}, Val size: {len(val_set)}")
print("Class index mapping:", train_set.class_to_idx)  # Observe the sequence of angry/disgust/...

# ImageFolder will store all labels train_set.targets（list[int]）
orig_targets = train_set.targets

# Count the number of samples in each category
counter_7 = Counter(orig_targets)
print("7-class counts in train:", counter_7)

# Calculate class weight: 1 / count Then normalise
class_weights = []
for c in range(NUM_CLASSES):
    count_c = counter_7.get(c, 1)  # Theoretically, it cannot be zero.
    class_weights.append(1.0 / count_c)

sum_w = sum(class_weights)
class_weights = [w / sum_w * NUM_CLASSES for w in class_weights]

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class weights (for CE):", class_weights_tensor.tolist())

# Balanced Sampler（WeightedRandomSampler
sample_weights = [class_weights[label] for label in orig_targets]
sample_weights = torch.tensor(sample_weights, dtype=torch.float32)

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)


# 5. Backbone + Mouth auxiliary head model
class FER7WithMouth(nn.Module):
    """
    EfficientNet-B0 backbone + Primary Head (Full Face) + Secondary Head (Mouth)
    """
    def __init__(self, backbone, num_classes=7):
        super().__init__()
        self.backbone = backbone  # EfficientNet-B0, classifier = Identity（output 1280）

        # Main Head: Full FaceMain Head: Full Face
        self.head_main = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Auxiliary head: Mouth
        self.head_mouth = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        feat_full = self.backbone(x_full)    # [B, 1280]
        feat_mouth = self.backbone(x_mouth)  # [B, 1280]（Weight sharing

        logits_main = self.head_main(feat_full)
        logits_mouth = self.head_mouth(feat_mouth)
        return logits_main, logits_mouth

    def forward_main(self, x_full):
        feat_full = self.backbone(x_full)
        return self.head_main(feat_full)


# Loading SimCLR pre-trained backbone
base = efficientnet_b0(weights=None)
base.classifier = nn.Identity()

print(f"Loading SimCLR weights from {SIMCLR_PATH} ...")
simclr_weights = torch.load(SIMCLR_PATH, map_location="cpu")
missing, unexpected = base.load_state_dict(simclr_weights, strict=False)
print("Loaded SimCLR backbone.")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

model = FER7WithMouth(base, num_classes=NUM_CLASSES).to(device)

# CE with class weight + label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_tensor,
    label_smoothing=0.1
)

# Weight for auxiliary head loss
LAMBDA_MOUTH = 0.5

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)

# 6. LR warmup + cosine decay
total_steps  = EPOCHS * len(train_loader)
warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)

def get_lr(step):
    if step < warmup_steps:
        # Linear warmup: 0 to BASE_LR
        return BASE_LR * float(step + 1) / float(warmup_steps + 1)
    # Subsequently, the cosine decays to BASE_LR multiplied by 0.1.
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = BASE_LR * 0.1
    return min_lr + (BASE_LR - min_lr) * cosine

# 7. Training cycle
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    ce_main_sum = 0.0
    ce_mouth_sum = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)
    for imgs, labels in pbar:
        # labels: list[int]，0~6
        y = torch.tensor(labels, dtype=torch.long).to(device)  # [B]

        # Main branch (full face)
        x_full = torch.stack([train_aug(img) for img in imgs]).to(device)

        # mouth branch（mouth dection + mouth_aug）
        mouth_imgs = [crop_mouth_pil(img) for img in imgs]
        x_mouth = torch.stack([mouth_aug(m) for m in mouth_imgs]).to(device)

        # Update learning rate
        lr = get_lr(global_step)
        for g in optimizer.param_groups:
            g["lr"] = lr

        # Forward
        logits_main, logits_mouth = model(x_full, x_mouth)

        ce_main = criterion(logits_main, y)
        ce_mouth = criterion(logits_mouth, y)

        loss = ce_main + LAMBDA_MOUTH * ce_mouth

        optimizer.zero_grad()
        loss.backward()
        # Gradient cropping, steady now
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss += loss.item()
        ce_main_sum += ce_main.item()
        ce_mouth_sum += ce_mouth.item()
        global_step += 1

        pbar.set_postfix({
            "loss":  f"{loss.item():.4f}",
            "main":  f"{ce_main.item():.4f}",
            "mouth": f"{ce_mouth.item():.4f}",
            "lr":    f"{lr:.1e}"
        })

    avg_loss = train_loss / len(train_loader)
    avg_main = ce_main_sum / len(train_loader)
    avg_mouth = ce_mouth_sum / len(train_loader)
    print(f"[Train] Epoch {epoch+1} | loss={avg_loss:.4f} | main={avg_main:.4f} | mouth={avg_mouth:.4f}")

    # 8. Verification (main header only)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            y_val = torch.tensor(labels, dtype=torch.long).to(device)

            x_val = torch.stack([val_aug(img) for img in imgs]).to(device)

            logits_val = model.forward_main(x_val)
            loss_val = criterion(logits_val, y_val)
            val_loss += loss_val.item()

            preds = logits_val.argmax(dim=1)
            correct += (preds == y_val).sum().item()
            total += len(y_val)

    val_loss /= len(val_loader)
    val_acc = correct / total if total > 0 else 0.0
    print(f"[Val]   Epoch {epoch+1} | loss={val_loss:.4f} | acc={val_acc*100:.2f}%\n")

torch.save(model.state_dict(), SAVE_PATH)
print(f"Saved model to {SAVE_PATH}")


Using device: mps
Train size (7-class): 10714, Val size: 3129
Class index mapping: {'angry': 0, 'disgusted': 1, 'fearful': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprised': 6}
7-class counts in train: Counter({6: 2248, 0: 2026, 2: 1922, 1: 1865, 5: 1582, 4: 744, 3: 327})
Class weights (for CE): [0.4915323853492737, 0.5339649319648743, 0.5181293487548828, 3.045396327972412, 1.3385008573532104, 0.6294845938682556, 0.44299137592315674]
Loading SimCLR weights from tinysimclr_effb0_mac.pth ...
Loaded SimCLR backbone.
Missing keys: ['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.block.0.0.weight', 'features.1.0.block.0.1.weight', 'features.1.0.block.0.1.bias', 'features.1.0.block.0.1.running_mean', 'features.1.0.block.0.1.running_var', 'features.1.0.block.1.fc1.weight', 'features.1.0.block.1.fc1.bias', 'features.1.0.block.1.fc2.weight', 'features.1.0.block.1.fc2.bias', 'features.1.0.block.2.0.weight'

Epoch 1/25: 100%|█| 670/670 [05:54<00:00,  1.89it/s, loss=2.5194, main=1.6410, mouth=1.7567, lr=1.2e


[Train] Epoch 1 | loss=2.6546 | main=1.7679 | mouth=1.7733
[Val]   Epoch 1 | loss=2.6095 | acc=3.29%



Epoch 2/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=2.4057, main=1.5473, mouth=1.7167, lr=2.4e


[Train] Epoch 2 | loss=2.5659 | main=1.7102 | mouth=1.7114
[Val]   Epoch 2 | loss=3.2005 | acc=3.00%



Epoch 3/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=1.5711, main=1.0384, mouth=1.0656, lr=3.0e


[Train] Epoch 3 | loss=2.4401 | main=1.6169 | mouth=1.6463
[Val]   Epoch 3 | loss=2.4019 | acc=10.83%



Epoch 4/25: 100%|█| 670/670 [05:24<00:00,  2.07it/s, loss=1.2669, main=0.7880, mouth=0.9578, lr=3.0e


[Train] Epoch 4 | loss=2.2365 | main=1.4712 | mouth=1.5307
[Val]   Epoch 4 | loss=2.3826 | acc=11.38%



Epoch 5/25: 100%|█| 670/670 [05:35<00:00,  2.00it/s, loss=2.2532, main=1.5069, mouth=1.4927, lr=2.9e


[Train] Epoch 5 | loss=2.0434 | main=1.3310 | mouth=1.4247
[Val]   Epoch 5 | loss=2.2386 | acc=22.31%



Epoch 6/25: 100%|█| 670/670 [05:33<00:00,  2.01it/s, loss=1.7423, main=1.1333, mouth=1.2180, lr=2.8e


[Train] Epoch 6 | loss=1.9054 | main=1.2270 | mouth=1.3568
[Val]   Epoch 6 | loss=2.1749 | acc=29.56%



Epoch 7/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=1.9278, main=1.2412, mouth=1.3731, lr=2.7e


[Train] Epoch 7 | loss=1.8507 | main=1.1800 | mouth=1.3414
[Val]   Epoch 7 | loss=2.0285 | acc=31.32%



Epoch 8/25: 100%|█| 670/670 [05:32<00:00,  2.01it/s, loss=2.0048, main=1.2626, mouth=1.4845, lr=2.6e


[Train] Epoch 8 | loss=1.7812 | main=1.1286 | mouth=1.3052
[Val]   Epoch 8 | loss=2.1659 | acc=31.42%



Epoch 9/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=1.8605, main=1.2075, mouth=1.3059, lr=2.5e


[Train] Epoch 9 | loss=1.6461 | main=1.0354 | mouth=1.2213
[Val]   Epoch 9 | loss=1.9659 | acc=38.80%



Epoch 10/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.0644, main=0.6607, mouth=0.8074, lr=2.3


[Train] Epoch 10 | loss=1.6335 | main=1.0187 | mouth=1.2295
[Val]   Epoch 10 | loss=1.8887 | acc=41.77%



Epoch 11/25: 100%|█| 670/670 [05:31<00:00,  2.02it/s, loss=0.9832, main=0.6379, mouth=0.6906, lr=2.2


[Train] Epoch 11 | loss=1.5610 | main=0.9680 | mouth=1.1861
[Val]   Epoch 11 | loss=1.8435 | acc=45.51%



Epoch 12/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.5813, main=0.9896, mouth=1.1833, lr=2.0


[Train] Epoch 12 | loss=1.5407 | main=0.9523 | mouth=1.1767
[Val]   Epoch 12 | loss=1.8043 | acc=48.23%



Epoch 13/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=2.5062, main=1.5159, mouth=1.9805, lr=1.8


[Train] Epoch 13 | loss=1.4655 | main=0.8929 | mouth=1.1451
[Val]   Epoch 13 | loss=1.7486 | acc=50.21%



Epoch 14/25: 100%|█| 670/670 [05:32<00:00,  2.02it/s, loss=2.1882, main=1.3564, mouth=1.6635, lr=1.6


[Train] Epoch 14 | loss=1.4425 | main=0.8763 | mouth=1.1324
[Val]   Epoch 14 | loss=1.7682 | acc=51.55%



Epoch 15/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.0975, main=0.6789, mouth=0.8372, lr=1.4


[Train] Epoch 15 | loss=1.3624 | main=0.8229 | mouth=1.0790
[Val]   Epoch 15 | loss=1.7573 | acc=51.87%



Epoch 16/25: 100%|█| 670/670 [05:21<00:00,  2.08it/s, loss=0.8268, main=0.4247, mouth=0.8042, lr=1.2


[Train] Epoch 16 | loss=1.3509 | main=0.8108 | mouth=1.0803
[Val]   Epoch 16 | loss=1.6413 | acc=55.61%



Epoch 17/25: 100%|█| 670/670 [05:21<00:00,  2.08it/s, loss=0.7347, main=0.3320, mouth=0.8054, lr=1.1


[Train] Epoch 17 | loss=1.2931 | main=0.7679 | mouth=1.0505
[Val]   Epoch 17 | loss=1.6776 | acc=55.54%



Epoch 18/25: 100%|█| 670/670 [05:44<00:00,  1.95it/s, loss=1.8305, main=1.2329, mouth=1.1952, lr=9.0


[Train] Epoch 18 | loss=1.2612 | main=0.7497 | mouth=1.0230
[Val]   Epoch 18 | loss=1.6417 | acc=55.42%



Epoch 19/25: 100%|█| 670/670 [07:38<00:00,  1.46it/s, loss=1.4571, main=0.8050, mouth=1.3042, lr=7.5


[Train] Epoch 19 | loss=1.2527 | main=0.7397 | mouth=1.0260
[Val]   Epoch 19 | loss=1.6179 | acc=57.72%



Epoch 20/25: 100%|█| 670/670 [05:38<00:00,  1.98it/s, loss=1.0161, main=0.6445, mouth=0.7432, lr=6.2


[Train] Epoch 20 | loss=1.1929 | main=0.7005 | mouth=0.9848
[Val]   Epoch 20 | loss=1.6342 | acc=57.91%



Epoch 21/25: 100%|█| 670/670 [05:46<00:00,  1.93it/s, loss=1.7307, main=1.1403, mouth=1.1809, lr=5.1


[Train] Epoch 21 | loss=1.1764 | main=0.6894 | mouth=0.9739
[Val]   Epoch 21 | loss=1.6027 | acc=59.57%



Epoch 22/25: 100%|█| 670/670 [05:48<00:00,  1.92it/s, loss=0.9307, main=0.3804, mouth=1.1007, lr=4.2


[Train] Epoch 22 | loss=1.1700 | main=0.6857 | mouth=0.9686
[Val]   Epoch 22 | loss=1.6086 | acc=58.23%



Epoch 23/25: 100%|█| 670/670 [05:47<00:00,  1.93it/s, loss=1.1268, main=0.6671, mouth=0.9193, lr=3.5


[Train] Epoch 23 | loss=1.1929 | main=0.6981 | mouth=0.9896
[Val]   Epoch 23 | loss=1.5946 | acc=58.77%



Epoch 24/25: 100%|█| 670/670 [05:43<00:00,  1.95it/s, loss=1.0567, main=0.6026, mouth=0.9082, lr=3.1


[Train] Epoch 24 | loss=1.1559 | main=0.6752 | mouth=0.9614
[Val]   Epoch 24 | loss=1.6238 | acc=58.77%



Epoch 25/25: 100%|█| 670/670 [05:33<00:00,  2.01it/s, loss=0.8793, main=0.4335, mouth=0.8915, lr=3.0


[Train] Epoch 25 | loss=1.1193 | main=0.6544 | mouth=0.9299
[Val]   Epoch 25 | loss=1.6099 | acc=59.83%

Saved model to affectnet_7cls_mouth.pth


In [2]:
!pip install onnxscript

Collecting onnxscript
  Downloading onnxscript-0.5.6-py3-none-any.whl.metadata (13 kB)
Collecting onnx_ir<2,>=0.1.12 (from onnxscript)
  Downloading onnx_ir-0.1.12-py3-none-any.whl.metadata (3.2 kB)
Collecting onnx>=1.16 (from onnxscript)
  Downloading onnx-1.19.1-cp310-cp310-macosx_12_0_universal2.whl.metadata (7.0 kB)
Downloading onnxscript-0.5.6-py3-none-any.whl (683 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m683.0/683.0 kB[0m [31m15.0 MB/s[0m  [33m0:00:00[0m
[?25hDownloading onnx_ir-0.1.12-py3-none-any.whl (129 kB)
Downloading onnx-1.19.1-cp310-cp310-macosx_12_0_universal2.whl (18.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.3/18.3 MB[0m [31m20.6 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: onnx, onnx_ir, onnxscript
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [onnxscript]3[0m [onnxscript]
[1A[2KSuccessfully installed onnx-1.19.1 onnx_ir-0.1.12 onnxscript-0.5

In [None]:
import torch
import torch.nn as nn
from torchvision.models import efficientnet_b0

class FER7WithMouth(nn.Module):
    def __init__(self, backbone, num_classes=7):
        super().__init__()
        self.backbone = backbone

        self.head_main = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.head_mouth = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        feat_full = self.backbone(x_full)
        feat_mouth = self.backbone(x_mouth)
        return self.head_main(feat_full), self.head_mouth(feat_mouth)

MODEL_PATH = "affectnet_7cls_mouth.pth"
ONNX_PATH  = "affectnet_7cls_mouth_full.onnx"

base = efficientnet_b0(weights=None)
base.classifier = nn.Identity()

model = FER7WithMouth(base, num_classes=7)
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

print("Loaded model from", MODEL_PATH)

dummy_full  = torch.randn(1, 3, 224, 224)
dummy_mouth = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    (dummy_full, dummy_mouth),
    ONNX_PATH,
    input_names=["full_input", "mouth_input"],
    output_names=["main_logits", "mouth_logits"],
    opset_version=12,
    export_params=True,
    do_constant_folding=True,
    dynamic_axes={
        "full_input": {0: "batch"},
        "mouth_input": {0: "batch"},
        "main_logits": {0: "batch"},
        "mouth_logits": {0: "batch"}
    }
)

print("Exported full ONNX model to", ONNX_PATH)


Loaded model from affectnet_7cls_mouth.pth


  torch.onnx.export(
W1117 14:48:59.735000 24146 site-packages/torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 12 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features


[torch.onnx] Obtain model graph for `FER7WithMouth([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `FER7WithMouth([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...


The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 12).
Failed to convert the model to the target version 12 using the ONNX C API. The model was not modified
Traceback (most recent call last):
  File "/Users/zhangyizhou/miniconda3/envs/tf-env/lib/python3.10/site-packages/onnxscript/version_converter/__init__.py", line 127, in call
    converted_proto = _c_api_utils.call_onnx_api(
  File "/Users/zhangyizhou/miniconda3/envs/tf-env/lib/python3.10/site-packages/onnxscript/version_converter/_c_api_utils.py", line 65, in call_onnx_api
    result = func(proto)
  File "/Users/zhangyizhou/miniconda3/envs/tf-env/lib/python3.10/site-packages/onnxscript/version_converter/__init__.py", line 122, in _partial_convert_version
    return onnx.version_converter.convert_version(
  File "/Users/zhangyizhou/miniconda3/envs/tf-env/lib/python3.10/site-packages/onnx/version_converter.py", 

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 147 of general pattern rewrite rules.
Exported full ONNX model to affectnet_7cls_mouth_full.onnx


In [None]:
# fine_tune_7cls_mouth_v4.py
# 7 classes + SimCLR backbone + class weight（Angry/Disgust 加权）+ balanced sampler + mouth auxiliary head (de-emphasise)

import os
import math
from collections import Counter

import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import efficientnet_b0
from tqdm import tqdm

DATA_ROOT   = "YOLO_format_cls"
TRAIN_DIR   = os.path.join(DATA_ROOT, "train")
VAL_DIR     = os.path.join(DATA_ROOT, "valid")

SIMCLR_PATH = "tinysimclr_effb0_mac.pth"
SAVE_PATH   = "affectnet_7cls_mouth_v4_angry_boost.pth"

BATCH_SIZE  = 16
EPOCHS      = 25
BASE_LR     = 3e-4
NUM_CLASSES = 7

device = "mps" if torch.backends.mps.is_available() else "cpu"
print("Using device:", device)

# 1. collate_fn: Preserve PIL images
def collate_fn(batch):
    imgs, labels = zip(*batch)   # imgs: list[PIL.Image], labels: list[int]
    return list(imgs), list(labels)


# 2. Data augmentation (real-time compatible)
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

mouth_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


# 3. Simple 'mouth crop'
def crop_mouth_pil(img):
    w, h = img.size
    top = int(h * 0.55)
    return img.crop((0, top, w, h))


# 4. dataset + Balanced Sampler + class weight（enhance Angry/Disgust）
train_set = datasets.ImageFolder(TRAIN_DIR, transform=None)
val_set   = datasets.ImageFolder(VAL_DIR, transform=None)

print(f"Train size (7-class): {len(train_set)}, Val size: {len(val_set)}")
print("Class index mapping:", train_set.class_to_idx)

orig_targets = train_set.targets
counter_7 = Counter(orig_targets)
print("7-class counts in train:", counter_7)

# basic class weight: 1/count
base_class_weights = []
for c in range(NUM_CLASSES):
    count_c = counter_7.get(c, 1)
    base_class_weights.append(1.0 / count_c)

# find the index of angry / disgusted 
angry_idx     = train_set.class_to_idx.get("angry", 0)
disgust_idx   = train_set.class_to_idx.get("disgusted", 1)

# Apply amplified weighting specifically to Angry & Disgust (e.g. x3)
ANGRY_BOOST   = 3.0
DISGUST_BOOST = 2.0

base_class_weights[angry_idx]   *= ANGRY_BOOST
base_class_weights[disgust_idx] *= DISGUST_BOOST

# Normalised to the mean ~1
sum_w = sum(base_class_weights)
class_weights = [w / sum_w * NUM_CLASSES for w in base_class_weights]

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class weights (for CE, after angry boost):", class_weights_tensor.tolist())

# Balanced Sampler
sample_weights = [class_weights[label] for label in orig_targets]
sample_weights = torch.tensor(sample_weights, dtype=torch.float32)

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)


# 5.Backbone + Mouth Auxiliary Head
class FER7WithMouth(nn.Module):
    def __init__(self, backbone, num_classes=7):
        super().__init__()
        self.backbone = backbone

        self.head_main = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.head_mouth = nn.Sequential(
            nn.Linear(1280, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        feat_full  = self.backbone(x_full)
        feat_mouth = self.backbone(x_mouth)

        logits_main  = self.head_main(feat_full)
        logits_mouth = self.head_mouth(feat_mouth)
        return logits_main, logits_mouth

    def forward_main(self, x_full):
        feat_full = self.backbone(x_full)
        return self.head_main(feat_full)


base = efficientnet_b0(weights=None)
base.classifier = nn.Identity()

print(f"Loading SimCLR weights from {SIMCLR_PATH} ...")
simclr_weights = torch.load(SIMCLR_PATH, map_location="cpu")
missing, unexpected = base.load_state_dict(simclr_weights, strict=False)
print("Loaded SimCLR backbone.")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

model = FER7WithMouth(base, num_classes=NUM_CLASSES).to(device)

# CE with class weight + label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_tensor,
    label_smoothing=0.1
)

# Reduce the weight of the mouth auxiliary head to prevent it from overshadowing the eyebrow features
LAMBDA_MOUTH = 0.3

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)

# 6. LR warmup + cosine decay
total_steps  = EPOCHS * len(train_loader)
warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)

def get_lr(step):
    if step < warmup_steps:
        return BASE_LR * float(step + 1) / float(warmup_steps + 1)
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = BASE_LR * 0.1
    return min_lr + (BASE_LR - min_lr) * cosine

# 7. Training cycle
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    ce_main_sum = 0.0
    ce_mouth_sum = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)
    for imgs, labels in pbar:
        y = torch.tensor(labels, dtype=torch.long).to(device)  # [B]

        x_full = torch.stack([train_aug(img) for img in imgs]).to(device)

        mouth_imgs = [crop_mouth_pil(img) for img in imgs]
        x_mouth = torch.stack([mouth_aug(m) for m in mouth_imgs]).to(device)

        lr = get_lr(global_step)
        for g in optimizer.param_groups:
            g["lr"] = lr

        logits_main, logits_mouth = model(x_full, x_mouth)

        ce_main  = criterion(logits_main, y)
        ce_mouth = criterion(logits_mouth, y)

        loss = ce_main + LAMBDA_MOUTH * ce_mouth

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss  += loss.item()
        ce_main_sum += ce_main.item()
        ce_mouth_sum += ce_mouth.item()
        global_step += 1

        pbar.set_postfix({
            "loss":  f"{loss.item():.4f}",
            "main":  f"{ce_main.item():.4f}",
            "mouth": f"{ce_mouth.item():.4f}",
            "lr":    f"{lr:.1e}"
        })

    avg_loss  = train_loss / len(train_loader)
    avg_main  = ce_main_sum / len(train_loader)
    avg_mouth = ce_mouth_sum / len(train_loader)
    print(f"[Train] Epoch {epoch+1} | loss={avg_loss:.4f} | main={avg_main:.4f} | mouth={avg_mouth:.4f}")

    # Verification (Main Header Only)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            y_val = torch.tensor(labels, dtype=torch.long).to(device)
            x_val = torch.stack([val_aug(img) for img in imgs]).to(device)

            logits_val = model.forward_main(x_val)
            loss_val = criterion(logits_val, y_val)
            val_loss += loss_val.item()

            preds = logits_val.argmax(dim=1)
            correct += (preds == y_val).sum().item()
            total += len(y_val)

    val_loss /= len(val_loader)
    val_acc   = correct / total if total > 0 else 0.0
    print(f"[Val]   Epoch {epoch+1} | loss={val_loss:.4f} | acc={val_acc*100:.2f}%\n")

torch.save(model.state_dict(), SAVE_PATH)
print(f"Saved model to {SAVE_PATH}")


Using device: mps
Train size (7-class): 10714, Val size: 3129
Class index mapping: {'angry': 0, 'disgusted': 1, 'fearful': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprised': 6}
7-class counts in train: Counter({6: 2248, 0: 2026, 2: 1922, 1: 1865, 5: 1582, 4: 744, 3: 327})
Class weights (for CE, after angry boost): [1.2119460105895996, 0.8777132034301758, 0.4258415997028351, 2.5029587745666504, 1.100090742111206, 0.5173625349998474, 0.36408698558807373]
Loading SimCLR weights from tinysimclr_effb0_mac.pth ...
Loaded SimCLR backbone.
Missing keys: ['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var', 'features.1.0.block.0.0.weight', 'features.1.0.block.0.1.weight', 'features.1.0.block.0.1.bias', 'features.1.0.block.0.1.running_mean', 'features.1.0.block.0.1.running_var', 'features.1.0.block.1.fc1.weight', 'features.1.0.block.1.fc1.bias', 'features.1.0.block.1.fc2.weight', 'features.1.0.block.1.fc2.bias', 'features.1.

Epoch 1/25: 100%|█| 670/670 [06:26<00:00,  1.73it/s, loss=2.2602, main=1.7507, mouth=1.6984, lr=1.2e


[Train] Epoch 1 | loss=2.2742 | main=1.7446 | mouth=1.7653
[Val]   Epoch 1 | loss=3.0905 | acc=19.27%



Epoch 2/25: 100%|█| 670/670 [05:59<00:00,  1.86it/s, loss=2.5395, main=1.9797, mouth=1.8660, lr=2.4e


[Train] Epoch 2 | loss=2.1875 | main=1.6821 | mouth=1.6848
[Val]   Epoch 2 | loss=2.6559 | acc=13.97%



Epoch 3/25: 100%|█| 670/670 [05:36<00:00,  1.99it/s, loss=2.3538, main=1.7717, mouth=1.9402, lr=3.0e


[Train] Epoch 3 | loss=2.1675 | main=1.6622 | mouth=1.6844
[Val]   Epoch 3 | loss=2.4776 | acc=17.48%



Epoch 4/25: 100%|█| 670/670 [05:35<00:00,  2.00it/s, loss=1.5274, main=1.1319, mouth=1.3185, lr=3.0e


[Train] Epoch 4 | loss=2.0562 | main=1.5650 | mouth=1.6373
[Val]   Epoch 4 | loss=2.2726 | acc=20.74%



Epoch 5/25: 100%|█| 670/670 [05:33<00:00,  2.01it/s, loss=2.5429, main=1.9021, mouth=2.1360, lr=2.9e


[Train] Epoch 5 | loss=1.8633 | main=1.3997 | mouth=1.5455
[Val]   Epoch 5 | loss=2.2474 | acc=23.33%



Epoch 6/25: 100%|█| 670/670 [05:30<00:00,  2.03it/s, loss=1.6601, main=1.1925, mouth=1.5586, lr=2.8e


[Train] Epoch 6 | loss=1.7254 | main=1.2846 | mouth=1.4695
[Val]   Epoch 6 | loss=2.2257 | acc=27.90%



Epoch 7/25: 100%|█| 670/670 [12:10<00:00,  1.09s/it, loss=1.5363, main=1.0632, mouth=1.5773, lr=2.7e


[Train] Epoch 7 | loss=1.6358 | main=1.2076 | mouth=1.4277
[Val]   Epoch 7 | loss=2.1448 | acc=31.67%



Epoch 8/25: 100%|█| 670/670 [05:47<00:00,  1.93it/s, loss=1.6059, main=1.2091, mouth=1.3224, lr=2.6e


[Train] Epoch 8 | loss=1.5587 | main=1.1444 | mouth=1.3809
[Val]   Epoch 8 | loss=2.0952 | acc=34.10%



Epoch 9/25: 100%|█| 670/670 [11:14<00:00,  1.01s/it, loss=1.3554, main=0.9519, mouth=1.3450, lr=2.5e


[Train] Epoch 9 | loss=1.5074 | main=1.1020 | mouth=1.3512
[Val]   Epoch 9 | loss=1.9694 | acc=37.52%



Epoch 10/25: 100%|█| 670/670 [05:23<00:00,  2.07it/s, loss=1.6476, main=1.2935, mouth=1.1804, lr=2.3


[Train] Epoch 10 | loss=1.4315 | main=1.0386 | mouth=1.3098
[Val]   Epoch 10 | loss=2.0920 | acc=37.14%



Epoch 11/25: 100%|█| 670/670 [05:25<00:00,  2.06it/s, loss=1.8090, main=1.2546, mouth=1.8480, lr=2.2


[Train] Epoch 11 | loss=1.4014 | main=1.0111 | mouth=1.3007
[Val]   Epoch 11 | loss=2.0557 | acc=38.89%



Epoch 12/25: 100%|█| 670/670 [05:22<00:00,  2.08it/s, loss=1.9196, main=1.2796, mouth=2.1333, lr=2.0


[Train] Epoch 12 | loss=1.3523 | main=0.9693 | mouth=1.2767
[Val]   Epoch 12 | loss=1.9069 | acc=41.87%



Epoch 13/25: 100%|█| 670/670 [07:01<00:00,  1.59it/s, loss=1.3648, main=0.9538, mouth=1.3702, lr=1.8


[Train] Epoch 13 | loss=1.2967 | main=0.9240 | mouth=1.2422
[Val]   Epoch 13 | loss=1.8844 | acc=43.88%



Epoch 14/25: 100%|█| 670/670 [15:03<00:00,  1.35s/it, loss=1.3152, main=0.9791, mouth=1.1205, lr=1.6


[Train] Epoch 14 | loss=1.2485 | main=0.8850 | mouth=1.2115
[Val]   Epoch 14 | loss=1.7918 | acc=45.73%



Epoch 15/25: 100%|█| 670/670 [21:55<00:00,  1.96s/it, loss=1.6675, main=1.2600, mouth=1.3585, lr=1.4


[Train] Epoch 15 | loss=1.2130 | main=0.8550 | mouth=1.1934
[Val]   Epoch 15 | loss=1.8438 | acc=45.93%



Epoch 16/25: 100%|█| 670/670 [05:26<00:00,  2.05it/s, loss=1.2647, main=0.7964, mouth=1.5611, lr=1.2


[Train] Epoch 16 | loss=1.1731 | main=0.8227 | mouth=1.1682
[Val]   Epoch 16 | loss=1.7858 | acc=49.15%



Epoch 17/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=0.8631, main=0.5221, mouth=1.1368, lr=1.1


[Train] Epoch 17 | loss=1.1319 | main=0.7869 | mouth=1.1500
[Val]   Epoch 17 | loss=1.7948 | acc=48.87%



Epoch 18/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=0.8215, main=0.5214, mouth=1.0002, lr=9.0


[Train] Epoch 18 | loss=1.1077 | main=0.7641 | mouth=1.1456
[Val]   Epoch 18 | loss=1.8106 | acc=49.31%



Epoch 19/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.0230, main=0.6350, mouth=1.2935, lr=7.5


[Train] Epoch 19 | loss=1.0664 | main=0.7322 | mouth=1.1140
[Val]   Epoch 19 | loss=1.8072 | acc=50.30%



Epoch 20/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.4669, main=0.9529, mouth=1.7132, lr=6.2


[Train] Epoch 20 | loss=1.0511 | main=0.7210 | mouth=1.1002
[Val]   Epoch 20 | loss=1.7378 | acc=52.32%



Epoch 21/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.3527, main=0.9266, mouth=1.4204, lr=5.1


[Train] Epoch 21 | loss=1.0367 | main=0.7102 | mouth=1.0881
[Val]   Epoch 21 | loss=1.7187 | acc=53.08%



Epoch 22/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=0.8636, main=0.5541, mouth=1.0316, lr=4.2


[Train] Epoch 22 | loss=1.0034 | main=0.6824 | mouth=1.0699
[Val]   Epoch 22 | loss=1.7528 | acc=52.29%



Epoch 23/25: 100%|█| 670/670 [05:20<00:00,  2.09it/s, loss=1.2417, main=0.7842, mouth=1.5249, lr=3.5


[Train] Epoch 23 | loss=0.9960 | main=0.6727 | mouth=1.0778
[Val]   Epoch 23 | loss=1.7082 | acc=53.40%



Epoch 24/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=1.0739, main=0.6766, mouth=1.3243, lr=3.1


[Train] Epoch 24 | loss=0.9841 | main=0.6646 | mouth=1.0650
[Val]   Epoch 24 | loss=1.7482 | acc=52.96%



Epoch 25/25: 100%|█| 670/670 [05:19<00:00,  2.10it/s, loss=0.6576, main=0.4376, mouth=0.7334, lr=3.0


[Train] Epoch 25 | loss=0.9772 | main=0.6605 | mouth=1.0557
[Val]   Epoch 25 | loss=1.7296 | acc=53.60%

Saved model to affectnet_7cls_mouth_v4_angry_boost.pth


In [None]:
# fine_tune_7cls_mouth_mobilenet.py
# 7 classes + MobileNetV2 backbone + class weight（Angry/Disgust 加权）
# + balanced sampler + mouth auxiliary head (de-emphasised)

import os
import math
from collections import Counter

import torch
import torchvision
from torchvision import transforms, datasets
from torch import nn, optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.models import mobilenet_v2
try:
    from torchvision.models import MobileNet_V2_Weights
    HAS_WEIGHTS_ENUM = True
except Exception:
    HAS_WEIGHTS_ENUM = False

from tqdm import tqdm

DATA_ROOT   = "YOLO_format_cls"
TRAIN_DIR   = os.path.join(DATA_ROOT, "train")
VAL_DIR     = os.path.join(DATA_ROOT, "valid")

SAVE_PATH   = "affectnet_7cls_mouth_mobilenet.pth"

BATCH_SIZE  = 16
EPOCHS      = 25
BASE_LR     = 3e-4
NUM_CLASSES = 7

device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


# =========================
# 1. collate_fn: Preserve PIL image
# =========================
def collate_fn(batch):
    imgs, labels = zip(*batch)   # imgs: list[PIL.Image], labels: list[int]
    return list(imgs), list(labels)


# =========================
# 2. Data Augmentation (Real-Time Friendly)
# =========================
train_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

mouth_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(8),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_aug = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])


# =========================
# 3. 简单的“嘴部 crop”
# =========================
def crop_mouth_pil(img):
    """
    Directly remove 45% as the mouth area.
    """
    w, h = img.size
    top = int(h * 0.55)
    return img.crop((0, top, w, h))


# =========================
# 4. dataset + Balanced Sampler + class weight
# =========================
train_set = datasets.ImageFolder(TRAIN_DIR, transform=None)
val_set   = datasets.ImageFolder(VAL_DIR, transform=None)

print(f"Train size (7-class): {len(train_set)}, Val size: {len(val_set)}")
print("Class index mapping:", train_set.class_to_idx)

orig_targets = train_set.targets
counter_7 = Counter(orig_targets)
print("7-class counts in train:", counter_7)

# ----- class weight: 1/count -----
base_class_weights = []
for c in range(NUM_CLASSES):
    count_c = counter_7.get(c, 1)
    base_class_weights.append(1.0 / count_c)

# angry / disgusted index
angry_idx   = train_set.class_to_idx.get("angry", 0)
disgust_idx = train_set.class_to_idx.get("disgusted", 1)

# Increase the weight for Angry & Disgust (e.g., ×3, ×2)
ANGRY_BOOST   = 3.0
DISGUST_BOOST = 2.0

base_class_weights[angry_idx]   *= ANGRY_BOOST
base_class_weights[disgust_idx] *= DISGUST_BOOST

# Normalized to "approximately equal to 1"
sum_w = sum(base_class_weights)
class_weights = [w / sum_w * NUM_CLASSES for w in base_class_weights]

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)
print("Class weights (for CE, after angry/disgust boost):", class_weights_tensor.tolist())

# ----- Balanced Sampler -----
sample_weights = [class_weights[label] for label in orig_targets]
sample_weights = torch.tensor(sample_weights, dtype=torch.float32)

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)

train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=0,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn
)


# =========================
# 5. MobileNetV2 Backbone + mouth with head
# =========================
class FER7WithMouth(nn.Module):
    """
    MobileNetV2 backbone + primary head (full face) + auxiliary head (mouth)
    Since MobileNetV2's final feature dimension is also 1280, the original head structure can be reused.
    """
    def __init__(self, backbone, num_classes=7, feat_dim=1280):
        super().__init__()
        self.backbone = backbone

        # Main Head: Full Face
        self.head_main = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        # Mouth
        self.head_mouth = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        feat_full  = self.backbone(x_full)   # [B, 1280]
        feat_mouth = self.backbone(x_mouth)  # [B, 1280]

        logits_main  = self.head_main(feat_full)
        logits_mouth = self.head_mouth(feat_mouth)
        return logits_main, logits_mouth

    def forward_main(self, x_full):
        feat_full = self.backbone(x_full)
        return self.head_main(feat_full)

print("Building MobileNetV2 backbone...")

if HAS_WEIGHTS_ENUM:
    # base = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
    base = mobilenet_v2(weights=None)
else:
    base = mobilenet_v2(weights=None)

# MobileNetV2: features -> avgpool -> flatten -> classifier
base.classifier = nn.Identity()

model = FER7WithMouth(base, num_classes=NUM_CLASSES, feat_dim=1280).to(device)
print("Model built with MobileNetV2 backbone.")

# with class weight + label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights_tensor,
    label_smoothing=0.1
)

# Weighting for the auxiliary head loss (reduced to prevent masking upper facial features like eyebrows)
LAMBDA_MOUTH = 0.3

optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-4)


# =========================
# 6. LR warmup + cosine decay
# =========================
total_steps  = EPOCHS * len(train_loader)
warmup_ratio = 0.1
warmup_steps = int(total_steps * warmup_ratio)

def get_lr(step):
    if step < warmup_steps:
        return BASE_LR * float(step + 1) / float(warmup_steps + 1)
    progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = BASE_LR * 0.1
    return min_lr + (BASE_LR - min_lr) * cosine


# =========================
# 7. Training Cycle
# =========================
global_step = 0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    ce_main_sum = 0.0
    ce_mouth_sum = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", ncols=100)
    for imgs, labels in pbar:
        y = torch.tensor(labels, dtype=torch.long).to(device)  # [B]

        # Full Face
        x_full = torch.stack([train_aug(img) for img in imgs]).to(device)

        # Mouth
        mouth_imgs = [crop_mouth_pil(img) for img in imgs]
        x_mouth = torch.stack([mouth_aug(m) for m in mouth_imgs]).to(device)

        # Update learning rate
        lr = get_lr(global_step)
        for g in optimizer.param_groups:
            g["lr"] = lr

        # forward
        logits_main, logits_mouth = model(x_full, x_mouth)

        ce_main  = criterion(logits_main, y)
        ce_mouth = criterion(logits_mouth, y)

        loss = ce_main + LAMBDA_MOUTH * ce_mouth

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        train_loss  += loss.item()
        ce_main_sum += ce_main.item()
        ce_mouth_sum += ce_mouth.item()
        global_step += 1

        pbar.set_postfix({
            "loss":  f"{loss.item():.4f}",
            "main":  f"{ce_main.item():.4f}",
            "mouth": f"{ce_mouth.item():.4f}",
            "lr":    f"{lr:.1e}"
        })

    avg_loss  = train_loss / len(train_loader)
    avg_main  = ce_main_sum / len(train_loader)
    avg_mouth = ce_mouth_sum / len(train_loader)
    print(f"[Train] Epoch {epoch+1} | loss={avg_loss:.4f} | main={avg_main:.4f} | mouth={avg_mouth:.4f}")

    # =========================
    # 8. Verification (Main Header Only)
    # =========================
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in val_loader:
            y_val = torch.tensor(labels, dtype=torch.long).to(device)
            x_val = torch.stack([val_aug(img) for img in imgs]).to(device)

            logits_val = model.forward_main(x_val)
            loss_val = criterion(logits_val, y_val)
            val_loss += loss_val.item()

            preds = logits_val.argmax(dim=1)
            correct += (preds == y_val).sum().item()
            total += len(y_val)

    val_loss /= len(val_loader)
    val_acc   = correct / total if total > 0 else 0.0
    print(f"[Val]   Epoch {epoch+1} | loss={val_loss:.4f} | acc={val_acc*100:.2f}%\n")

# =========================
# 9. 保存模型
# =========================
torch.save(model.state_dict(), SAVE_PATH)
print(f"Saved MobileNetV2-based model to {SAVE_PATH}")


Using device: mps
Train size (7-class): 10714, Val size: 3129
Class index mapping: {'angry': 0, 'disgusted': 1, 'fearful': 2, 'happy': 3, 'neutral': 4, 'sad': 5, 'surprised': 6}
7-class counts in train: Counter({6: 2248, 0: 2026, 2: 1922, 1: 1865, 5: 1582, 4: 744, 3: 327})
Class weights (for CE, after angry/disgust boost): [1.2119460105895996, 0.8777132034301758, 0.4258415997028351, 2.5029587745666504, 1.100090742111206, 0.5173625349998474, 0.36408698558807373]
Building MobileNetV2 backbone...
Model built with MobileNetV2 backbone.


Epoch 1/25: 100%|█| 670/670 [03:52<00:00,  2.89it/s, loss=1.8538, main=1.4371, mouth=1.3890, lr=1.2e


[Train] Epoch 1 | loss=2.2459 | main=1.7258 | mouth=1.7338
[Val]   Epoch 1 | loss=2.4114 | acc=19.27%



Epoch 2/25: 100%|█| 670/670 [03:18<00:00,  3.37it/s, loss=2.1261, main=1.6438, mouth=1.6077, lr=2.4e


[Train] Epoch 2 | loss=2.2138 | main=1.7033 | mouth=1.7016
[Val]   Epoch 2 | loss=2.4714 | acc=19.27%



Epoch 3/25: 100%|█| 670/670 [03:20<00:00,  3.33it/s, loss=2.8448, main=2.1931, mouth=2.1725, lr=3.0e


[Train] Epoch 3 | loss=2.1755 | main=1.6728 | mouth=1.6754
[Val]   Epoch 3 | loss=2.4518 | acc=19.27%



Epoch 4/25: 100%|█| 670/670 [03:19<00:00,  3.36it/s, loss=2.1182, main=1.5575, mouth=1.8690, lr=3.0e


[Train] Epoch 4 | loss=2.0915 | main=1.6011 | mouth=1.6345
[Val]   Epoch 4 | loss=2.5299 | acc=18.98%



Epoch 5/25: 100%|█| 670/670 [03:19<00:00,  3.36it/s, loss=2.0827, main=1.5773, mouth=1.6847, lr=2.9e


[Train] Epoch 5 | loss=1.9352 | main=1.4697 | mouth=1.5514
[Val]   Epoch 5 | loss=2.5812 | acc=21.06%



Epoch 6/25: 100%|█| 670/670 [03:19<00:00,  3.36it/s, loss=2.0630, main=1.5981, mouth=1.5495, lr=2.8e


[Train] Epoch 6 | loss=1.8482 | main=1.3944 | mouth=1.5128
[Val]   Epoch 6 | loss=2.3072 | acc=23.01%



Epoch 7/25: 100%|█| 670/670 [03:13<00:00,  3.46it/s, loss=1.6361, main=1.2287, mouth=1.3581, lr=2.7e


[Train] Epoch 7 | loss=1.7346 | main=1.2995 | mouth=1.4501
[Val]   Epoch 7 | loss=2.3255 | acc=22.98%



Epoch 8/25: 100%|█| 670/670 [03:14<00:00,  3.44it/s, loss=1.3337, main=1.0493, mouth=0.9480, lr=2.6e


[Train] Epoch 8 | loss=1.6642 | main=1.2337 | mouth=1.4349
[Val]   Epoch 8 | loss=2.1592 | acc=29.15%



Epoch 9/25: 100%|█| 670/670 [03:14<00:00,  3.45it/s, loss=1.9590, main=1.4724, mouth=1.6221, lr=2.5e


[Train] Epoch 9 | loss=1.6147 | main=1.1947 | mouth=1.3999
[Val]   Epoch 9 | loss=2.2266 | acc=30.36%



Epoch 10/25: 100%|█| 670/670 [03:15<00:00,  3.43it/s, loss=0.8917, main=0.6126, mouth=0.9305, lr=2.3


[Train] Epoch 10 | loss=1.5419 | main=1.1316 | mouth=1.3675
[Val]   Epoch 10 | loss=2.1195 | acc=33.72%



Epoch 11/25: 100%|█| 670/670 [03:14<00:00,  3.45it/s, loss=1.8939, main=1.5141, mouth=1.2658, lr=2.2


[Train] Epoch 11 | loss=1.4856 | main=1.0847 | mouth=1.3362
[Val]   Epoch 11 | loss=2.0611 | acc=34.61%



Epoch 12/25: 100%|█| 670/670 [03:14<00:00,  3.45it/s, loss=1.7246, main=1.2488, mouth=1.5857, lr=2.0


[Train] Epoch 12 | loss=1.4344 | main=1.0426 | mouth=1.3058
[Val]   Epoch 12 | loss=2.1539 | acc=34.48%



Epoch 13/25: 100%|█| 670/670 [03:12<00:00,  3.48it/s, loss=1.1004, main=0.7799, mouth=1.0685, lr=1.8


[Train] Epoch 13 | loss=1.3928 | main=1.0059 | mouth=1.2896
[Val]   Epoch 13 | loss=2.0392 | acc=36.05%



Epoch 14/25: 100%|█| 670/670 [03:12<00:00,  3.48it/s, loss=1.2559, main=0.8797, mouth=1.2539, lr=1.6


[Train] Epoch 14 | loss=1.3578 | main=0.9743 | mouth=1.2783
[Val]   Epoch 14 | loss=1.8897 | acc=38.77%



Epoch 15/25: 100%|█| 670/670 [03:13<00:00,  3.46it/s, loss=1.5726, main=1.1192, mouth=1.5113, lr=1.4


[Train] Epoch 15 | loss=1.3063 | main=0.9340 | mouth=1.2411
[Val]   Epoch 15 | loss=2.0053 | acc=38.48%



Epoch 16/25: 100%|█| 670/670 [03:15<00:00,  3.43it/s, loss=1.1437, main=0.8625, mouth=0.9373, lr=1.2


[Train] Epoch 16 | loss=1.2882 | main=0.9183 | mouth=1.2330
[Val]   Epoch 16 | loss=1.9515 | acc=40.75%



Epoch 17/25: 100%|█| 670/670 [03:15<00:00,  3.43it/s, loss=2.1009, main=1.5995, mouth=1.6713, lr=1.1


[Train] Epoch 17 | loss=1.2503 | main=0.8833 | mouth=1.2230
[Val]   Epoch 17 | loss=1.9831 | acc=39.25%



Epoch 18/25: 100%|█| 670/670 [03:16<00:00,  3.40it/s, loss=0.7776, main=0.5358, mouth=0.8061, lr=9.0


[Train] Epoch 18 | loss=1.2085 | main=0.8481 | mouth=1.2014
[Val]   Epoch 18 | loss=1.9134 | acc=39.66%



Epoch 19/25: 100%|█| 670/670 [03:16<00:00,  3.41it/s, loss=0.9949, main=0.6548, mouth=1.1336, lr=7.5


[Train] Epoch 19 | loss=1.1508 | main=0.8072 | mouth=1.1452
[Val]   Epoch 19 | loss=1.9401 | acc=41.16%



Epoch 20/25: 100%|█| 670/670 [03:17<00:00,  3.39it/s, loss=1.4786, main=0.9644, mouth=1.7143, lr=6.2


[Train] Epoch 20 | loss=1.1401 | main=0.7940 | mouth=1.1536
[Val]   Epoch 20 | loss=1.8812 | acc=42.89%



Epoch 21/25: 100%|█| 670/670 [03:20<00:00,  3.35it/s, loss=1.2344, main=0.7820, mouth=1.5081, lr=5.1


[Train] Epoch 21 | loss=1.1203 | main=0.7758 | mouth=1.1483
[Val]   Epoch 21 | loss=1.9218 | acc=42.03%



Epoch 22/25: 100%|█| 670/670 [03:27<00:00,  3.22it/s, loss=0.8680, main=0.6317, mouth=0.7876, lr=4.2


[Train] Epoch 22 | loss=1.0945 | main=0.7556 | mouth=1.1296
[Val]   Epoch 22 | loss=1.9073 | acc=42.35%



Epoch 23/25: 100%|█| 670/670 [03:19<00:00,  3.36it/s, loss=1.2236, main=0.8284, mouth=1.3174, lr=3.5


[Train] Epoch 23 | loss=1.0888 | main=0.7515 | mouth=1.1244
[Val]   Epoch 23 | loss=1.8694 | acc=43.88%



Epoch 24/25: 100%|█| 670/670 [03:28<00:00,  3.22it/s, loss=0.9866, main=0.6708, mouth=1.0525, lr=3.1


[Train] Epoch 24 | loss=1.0717 | main=0.7361 | mouth=1.1186
[Val]   Epoch 24 | loss=1.9091 | acc=43.27%



Epoch 25/25: 100%|█| 670/670 [03:25<00:00,  3.26it/s, loss=1.4298, main=0.9390, mouth=1.6360, lr=3.0


[Train] Epoch 25 | loss=1.0629 | main=0.7309 | mouth=1.1066
[Val]   Epoch 25 | loss=1.8951 | acc=43.98%

Saved MobileNetV2-based model to affectnet_7cls_mouth_mobilenet.pth


In [None]:
import torch
from torch import nn
from torchvision.models import mobilenet_v2

class FER7WithMouth(nn.Module):
    """
    MobileNetV2 backbone + primary head (full face) + auxiliary head (mouth)
    Note: feat_dim=1280, consistent with the training script
    """
    def __init__(self, backbone, num_classes=7, feat_dim=1280):
        super().__init__()
        self.backbone = backbone

        self.head_main = nn.Sequential(
            nn.Linear(feat_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

        self.head_mouth = nn.Sequential(
            nn.Linear(feat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x_full, x_mouth):
        # x_full, x_mouth: [B,3,224,224]
        feat_full  = self.backbone(x_full)   # [B,1280]
        feat_mouth = self.backbone(x_mouth)  # [B,1280]

        logits_main  = self.head_main(feat_full)
        logits_mouth = self.head_mouth(feat_mouth)
        return logits_main, logits_mouth

    def forward_main(self, x_full):
        feat_full = self.backbone(x_full)
        return self.head_main(feat_full)


def build_model(checkpoint_path, device="cpu"):
    base = mobilenet_v2(weights=None)
    base.classifier = nn.Identity()

    model = FER7WithMouth(base, num_classes=7, feat_dim=1280)
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state, strict=True)

    model.to(device)
    model.eval()
    return model


if __name__ == "__main__":
    CHECKPOINT = "affectnet_7cls_mouth_mobilenet.pth"
    ONNX_PATH  = "fer7_mobilenet_mouth.onnx"

    device = "cpu"
    model = build_model(CHECKPOINT, device=device)

    dummy_full  = torch.randn(1, 3, 224, 224, dtype=torch.float32, device=device)
    dummy_mouth = torch.randn(1, 3, 224, 224, dtype=torch.float32, device=device)

    torch.onnx.export(
        model,
        (dummy_full, dummy_mouth), 
        ONNX_PATH,
        input_names   = ["full_input", "mouth_input"],
        output_names  = ["main_logits", "mouth_logits"],
        opset_version = 11, 
        dynamic_axes  = {
            "full_input":  {0: "batch"},
            "mouth_input": {0: "batch"},
            "main_logits": {0: "batch"},
            "mouth_logits":{0: "batch"},
        },
        do_constant_folding = True,
        verbose = False
    )

    print("Exported ONNX model to:", ONNX_PATH)

  torch.onnx.export(
W1129 22:10:51.158000 66496 site-packages/torch/onnx/_internal/exporter/_compat.py:114] Setting ONNX exporter to use operator set version 18 because the requested opset_version 11 is a lower version than we have implementations for. Automatic version conversion will be performed, which may not be successful at converting to the requested version. If version conversion is unsuccessful, the opset version of the exported model will be kept at 18. Please consider setting opset_version >=18 to leverage latest ONNX features
The model version conversion is not supported by the onnxscript version converter and fallback is enabled. The model will be converted using the onnx C API (target version: 11).
Failed to convert the model to the target version 11 using the ONNX C API. The model was not modified
Traceback (most recent call last):
  File "/Users/zhangyizhou/miniconda3/envs/tf-env/lib/python3.10/site-packages/onnxscript/version_converter/__init__.py", line 127, in call


Applied 156 of general pattern rewrite rules.
Exported ONNX model to: fer7_mobilenet_mouth.onnx
