<a href="https://colab.research.google.com/github/AleksandreBakhtadze/ML-abakh22-facial-expression-recognition/blob/main/facial_expression_train3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [13]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [14]:
!pip install -q kaggle
!pip install -q wandb
!pip install -q torchmetrics
!pip install -q albumentations

In [15]:
import wandb
wandb.login()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle competitions download -c challenges-in-representation-learning-facial-expression-recognition-challenge
!unzip -q challenges-in-representation-learning-facial-expression-recognition-challenge.zip



challenges-in-representation-learning-facial-expression-recognition-challenge.zip: Skipping, found more recently modified local copy (use --force to force download)
replace example_submission.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: A


In [16]:
import os
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MulticlassConfusionMatrix, MulticlassPrecision, MulticlassRecall, MulticlassF1Score
import matplotlib.pyplot as plt
import seaborn as sns
from torchsummary import summary
import numpy as np
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [17]:
import pandas as pd

# Load the CSV
df = pd.read_csv("train.csv")

# Example row
print(df.head())

# Convert the pixel values to image tensors
def process_row(row):
    pixels = np.array([int(p) for p in row['pixels'].split()], dtype=np.uint8).reshape(48, 48)
    img = Image.fromarray(pixels)
    return img, int(row['emotion'])

images, labels = zip(*[process_row(row) for _, row in df.iterrows()])

   emotion                                             pixels
0        0  70 80 82 72 58 58 60 63 54 58 60 48 89 115 121...
1        0  151 150 147 155 148 133 111 140 170 174 182 15...
2        2  231 212 156 164 174 138 161 173 182 200 106 38...
3        4  24 32 36 30 32 23 19 20 30 41 21 22 32 34 21 1...
4        6  4 0 0 0 0 0 0 0 0 0 0 0 3 15 23 28 48 50 58 84...


In [18]:
# Define separate transforms for training and validation
train_transform = A.Compose([
    A.Rotate(limit=15, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Normalize(mean=[0.5], std=[0.5]),
    ToTensorV2()
])

class FERDataset(torch.utils.data.Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        img = np.array(img)  # Convert PIL image to numpy for albumentations
        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']
        return img, label

# Split into train and val
from sklearn.model_selection import train_test_split
train_imgs, val_imgs, train_labels, val_labels = train_test_split(images, labels, test_size=0.1, stratify=labels)

train_dataset = FERDataset(train_imgs, train_labels, train_transform)
val_dataset = FERDataset(val_imgs, val_labels, val_transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)  # Increased batch size
val_loader = DataLoader(val_dataset, batch_size=128)

In [19]:
class ImprovedCNN(nn.Module):
    def __init__(self):
        super(ImprovedCNN, self).__init__()
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.dropout = nn.Dropout(0.4)  # Reduced dropout rate
        self.fc = nn.Sequential(
            nn.Linear(256 * 6 * 6, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 7)  # 7 emotion classes
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImprovedCNN().to(device)
summary(model, (1, 48, 48))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 48, 48]             640
       BatchNorm2d-2           [-1, 64, 48, 48]             128
              ReLU-3           [-1, 64, 48, 48]               0
            Conv2d-4           [-1, 64, 48, 48]          36,928
       BatchNorm2d-5           [-1, 64, 48, 48]             128
              ReLU-6           [-1, 64, 48, 48]               0
         MaxPool2d-7           [-1, 64, 24, 24]               0
            Conv2d-8          [-1, 128, 24, 24]          73,856
       BatchNorm2d-9          [-1, 128, 24, 24]             256
             ReLU-10          [-1, 128, 24, 24]               0
           Conv2d-11          [-1, 128, 24, 24]         147,584
      BatchNorm2d-12          [-1, 128, 24, 24]             256
             ReLU-13          [-1, 128, 24, 24]               0
        MaxPool2d-14          [-1, 128,

In [20]:
# Sample 20 training examples
small_dataset, _ = torch.utils.data.random_split(train_dataset, [20, len(train_dataset) - 20])
small_loader = torch.utils.data.DataLoader(small_dataset, batch_size=4, shuffle=True)

# Re-initialize the model
model = ImprovedCNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Start Wandb for overfitting test
wandb.init(project="FER-CNN", name="overfit_test")

print("Training on a tiny dataset to check overfitting...")
for epoch in range(20):
    model.train()
    total_loss, correct = 0, 0
    for imgs, labels in small_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

    acc = correct / len(small_loader.dataset)
    wandb.log({"overfit_epoch": epoch + 1, "overfit_loss": total_loss, "overfit_acc": acc})
    print(f"Epoch {epoch+1} - Loss: {total_loss:.4f}, Acc: {acc:.4f}")
    if acc >= 0.95:  # Relaxed threshold to account for deeper model
        print("✅ Model successfully overfit tiny dataset")
        break

wandb.finish()

Training on a tiny dataset to check overfitting...
Epoch 1 - Loss: 35.4556, Acc: 0.4000
Epoch 2 - Loss: 16.7016, Acc: 0.2500
Epoch 3 - Loss: 15.8338, Acc: 0.5500
Epoch 4 - Loss: 11.7996, Acc: 0.6000
Epoch 5 - Loss: 6.8095, Acc: 0.4000
Epoch 6 - Loss: 6.7574, Acc: 0.3500
Epoch 7 - Loss: 6.9952, Acc: 0.6000
Epoch 8 - Loss: 6.8201, Acc: 0.6500
Epoch 9 - Loss: 5.1638, Acc: 0.4500
Epoch 10 - Loss: 6.6056, Acc: 0.4000
Epoch 11 - Loss: 4.5113, Acc: 0.6500
Epoch 12 - Loss: 6.5261, Acc: 0.5000
Epoch 13 - Loss: 7.6441, Acc: 0.5500
Epoch 14 - Loss: 5.7156, Acc: 0.6000
Epoch 15 - Loss: 5.5162, Acc: 0.6500
Epoch 16 - Loss: 4.9310, Acc: 0.6500
Epoch 17 - Loss: 6.0052, Acc: 0.5500
Epoch 18 - Loss: 4.5700, Acc: 0.6500
Epoch 19 - Loss: 4.3062, Acc: 0.7000
Epoch 20 - Loss: 5.6484, Acc: 0.6500


0,1
overfit_acc,▃▁▆▆▃▃▆▇▄▃▇▅▆▆▇▇▆▇█▇
overfit_epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
overfit_loss,█▄▄▃▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁

0,1
overfit_acc,0.65
overfit_epoch,20.0
overfit_loss,5.64839


In [21]:
import wandb
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Start Wandb
wandb.init(project="FER-CNN", name="improved_cnn_run", config={
    "learning_rate": 0.001,
    "batch_size": 128,
    "epochs": 30,  # Increased epochs
    "architecture": "ImprovedCNN",
    "optimizer": "Adam"
})

# Watch model gradients and parameters
wandb.watch(model, criterion=criterion, log="all", log_freq=100)

# Loss, optimizer, and scheduler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

# Early stopping parameters
best_val_acc = 0.0
patience = 5
patience_counter = 0

# Metrics for logging
conf_matrix = MulticlassConfusionMatrix(num_classes=7).to(device)
precision = MulticlassPrecision(num_classes=7, average=None).to(device)
recall = MulticlassRecall(num_classes=7, average=None).to(device)
f1 = MulticlassF1Score(num_classes=7, average=None).to(device)

for epoch in range(30):
    model.train()
    train_loss, correct = 0, 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

    train_acc = correct / len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss, val_correct = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            preds = outputs.argmax(1)
            val_correct += (preds == labels).sum().item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Update metrics
            conf_matrix.update(preds, labels)
            precision.update(preds, labels)
            recall.update(preds, labels)
            f1.update(preds, labels)

    val_acc = val_correct / len(val_loader.dataset)

    # Compute per-class metrics
    cm = conf_matrix.compute().cpu().numpy()
    prec_scores = precision.compute().cpu().numpy()
    rec_scores = recall.compute().cpu().numpy()
    f1_scores = f1.compute().cpu().numpy()

    # Log metrics to Wandb
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss / len(train_loader),
        "train_acc": train_acc,
        "val_loss": val_loss / len(val_loader),
        "val_acc": val_acc,
        "learning_rate": optimizer.param_groups[0]['lr'],
        **{f"precision_class_{i}": prec_scores[i] for i in range(7)},
        **{f"recall_class_{i}": rec_scores[i] for i in range(7)},
        **{f"f1_class_{i}": f1_scores[i] for i in range(7)}
    })

    print(f"Epoch {epoch+1} - Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}, Val Loss: {val_loss / len(val_loader):.4f}")

    # Learning rate scheduling
    scheduler.step(val_loss / len(val_loader))

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

    # Reset metrics for next epoch
    conf_matrix.reset()
    precision.reset()
    recall.reset()
    f1.reset()

# Load best model for evaluation
model.load_state_dict(torch.load("best_model.pth"))

# Final evaluation
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
wandb.log({"confusion_matrix": wandb.Image(fig)})
plt.close(fig)

# Classification report
report = classification_report(all_labels, all_preds, digits=4, target_names=['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral'])
print(report)
wandb.run.summary["classification_report"] = report

# Sample predictions
model.eval()
for i in range(5):
    img, label = val_dataset[i]
    with torch.no_grad():
        pred = model(img.unsqueeze(0).to(device)).argmax(1).item()
    img_np = img.squeeze().numpy()
    fig, ax = plt.subplots()
    ax.imshow(img_np, cmap="gray")
    ax.set_title(f"Predicted: {pred}, True: {label}")
    ax.axis('off')
    wandb.log({f"Example_{i}": wandb.Image(fig)})
    plt.close(fig)

wandb.finish()

Epoch 1 - Train Acc: 0.2437, Val Acc: 0.2557, Val Loss: 1.8002
Epoch 2 - Train Acc: 0.2631, Val Acc: 0.2870, Val Loss: 1.7648
Epoch 3 - Train Acc: 0.3234, Val Acc: 0.3762, Val Loss: 1.5355
Epoch 4 - Train Acc: 0.3820, Val Acc: 0.4204, Val Loss: 1.4782
Epoch 5 - Train Acc: 0.4134, Val Acc: 0.4326, Val Loss: 1.4167
Epoch 6 - Train Acc: 0.4218, Val Acc: 0.4417, Val Loss: 1.4241
Epoch 7 - Train Acc: 0.4279, Val Acc: 0.4535, Val Loss: 1.3630
Epoch 8 - Train Acc: 0.4367, Val Acc: 0.4580, Val Loss: 1.3522
Epoch 9 - Train Acc: 0.4407, Val Acc: 0.4521, Val Loss: 1.3834
Epoch 10 - Train Acc: 0.4446, Val Acc: 0.4615, Val Loss: 1.3395
Epoch 11 - Train Acc: 0.4497, Val Acc: 0.4678, Val Loss: 1.3217
Epoch 12 - Train Acc: 0.4537, Val Acc: 0.4535, Val Loss: 1.2968
Epoch 13 - Train Acc: 0.4553, Val Acc: 0.4869, Val Loss: 1.2966
Epoch 14 - Train Acc: 0.4618, Val Acc: 0.4845, Val Loss: 1.2759
Epoch 15 - Train Acc: 0.4716, Val Acc: 0.5037, Val Loss: 1.2549
Epoch 16 - Train Acc: 0.4741, Val Acc: 0.5026, Va

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


              precision    recall  f1-score   support

       Angry     0.5654    0.3684    0.4461       399
     Disgust     0.0000    0.0000    0.0000        44
        Fear     0.4340    0.3049    0.3582       410
       Happy     0.8869    0.8255    0.8551       722
         Sad     0.4076    0.6853    0.5112       483
    Surprise     0.8000    0.6562    0.7210       317
     Neutral     0.5561    0.6492    0.5991       496

    accuracy                         0.6022      2871
   macro avg     0.5214    0.4985    0.4987      2871
weighted avg     0.6166    0.6022    0.5973      2871



0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
f1_class_0,▁▁▁▁▁▁▁▁▁▂▁▁▃▃▃▄▅▆▆▇▇▇▇██▇███▇
f1_class_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
f1_class_2,▁▁▃▂▄▅▅▄▄▄▆▅▅▅▆▄▆▆▇▆▆▆▇▇█▇████
f1_class_3,▁▂▄▄▆▆▆▇▇▇▇▇▇▇▇▇▇██▇██████████
f1_class_4,▁▂▅▆▆▄▇▄▇▄▇▁▄▇▇▇▇▇▇▇▇██▇█████▇
f1_class_5,▁▃▅▇▆▇▇▇▇▇▇▇▇██▇█▇▇▅█▄██▇██▇▇█
f1_class_6,▁▁▃▅▂▆▂▆▂▆▄▅▆▃▅▅▇▇▇▇█▇▇██▇█▇██
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
precision_class_0,▁▁▁▁▁▁▁█▂▆▇▄▃▃▄▄▅▄▄▅▄▄▅▄▄▅▅▄▅▅

0,1
classification_report,precis...
epoch,30
f1_class_0,0.43087
f1_class_1,0
f1_class_2,0.34932
f1_class_3,0.86171
f1_class_4,0.46911
f1_class_5,0.73684
f1_class_6,0.58889
learning_rate,0.001
