In [None]:
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import os
from torch.utils.data import Dataset, ConcatDataset
import torch

# Record once whether PyTorch can see a CUDA device, and echo the answer.
use_gpu = bool(torch.cuda.is_available())
print(use_gpu)

In [None]:
# Custom Dataset Classes
# Root folder that holds every dataset, resolved against the current working directory.
dataset_dir = os.path.join(os.getcwd(), "datasets")

# Preprocessing parameters, named so they can be tuned in one place.
IMAGE_SIZE = (48, 48)
NORMALIZE_MEAN = [0.5, 0.5, 0.5]
NORMALIZE_STD = [0.5, 0.5, 0.5]

# Preprocessing pipeline shared by every ImageFolder created in this notebook.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),  # Resize images to 48x48
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),  # Per-channel (x - mean) / std
])

# Names of the two binary labels. Not referenced by the visible code, which
# uses the integers 0 and 1 directly.
category_mapping = {
    "not_happy": 0,
    "happy": 1
}

class RecategorizedDataset(Dataset):
    """Binary view of a labelled dataset: one class becomes 1, every other class 0.

    Samples (images) are passed through untouched; only the label is rewritten.
    """

    def __init__(self, original_dataset, target_class):
        """
        Args:
            original_dataset: The dataset to wrap (e.g., ImageFolder). It must support
                ``len()`` and integer indexing that returns ``(image, label)``.
            target_class: The class index to map to 1. All other classes will be mapped to 0.
        """
        self.original_dataset = original_dataset
        self.target_class = target_class

    def _to_binary(self, label):
        """Return 1 when ``label`` equals the target class, otherwise 0."""
        return 1 if label == self.target_class else 0

    @property
    def targets(self):
        """Binary labels of every sample, in index order.

        When the wrapped dataset exposes its own ``targets`` sequence (as
        torchvision's ImageFolder does), has no ``target_transform`` and the
        sequence length matches the dataset length, labels are derived from it
        directly so no image has to be loaded. Otherwise every sample is
        fetched through ``__getitem__``, which is slower but always correct.

        Returns:
            A new list of 0/1 integers with ``len(self)`` entries.
        """
        source = self.original_dataset
        raw_labels = getattr(source, "targets", None)
        can_use_raw = (
            raw_labels is not None
            # A target_transform would make __getitem__ labels differ from `targets`.
            and getattr(source, "target_transform", None) is None
            and len(raw_labels) == len(source)
        )
        if can_use_raw:
            return [self._to_binary(label) for label in raw_labels]
        return [self[idx][1] for idx in range(len(self))]

    def __len__(self):
        # Same number of samples as the wrapped dataset.
        return len(self.original_dataset)

    def __getitem__(self, idx):
        """
        Args:
            idx: Index of the sample to retrieve.
        Returns:
            A tuple (image, new_label), where new_label is 1 if the original label matches
            the target_class, otherwise 0.
        """
        image, label = self.original_dataset[idx]
        return image, self._to_binary(label)
    
import random

class BalancedDataset(Dataset):
    """Random, class-balanced subset of a labelled dataset.

    Every class is down-sampled to the size of the rarest class and the chosen
    indices are shuffled. The selection is made once, at construction time.
    """

    def __init__(self, original_dataset, seed=None):
        """
        Args:
            original_dataset: Dataset whose items are ``(sample, label)`` pairs
                with hashable labels.
            seed: Optional seed for a private random generator, which makes the
                selection reproducible. With the default ``None`` the global
                ``random`` module state is used.
        """
        self.original_dataset = original_dataset
        # Only the seed is stored (not a generator or the module), so instances
        # stay picklable for DataLoader worker processes.
        self.seed = seed
        self.filtered_indices = self._balance_dataset()

    @staticmethod
    def _read_labels(dataset):
        """Return the label of every item of ``dataset``, in index order.

        Decoding every sample just to read its label is expensive for image
        datasets, so cheaper sources are tried first:

        * a ConcatDataset is handled member by member;
        * a dataset exposing a ``targets`` sequence of matching length and no
          ``target_transform`` is read from that sequence (assumes ``targets``
          holds the labels ``__getitem__`` returns, as for torchvision's
          ImageFolder);
        * anything else is indexed item by item.
        """
        if isinstance(dataset, ConcatDataset):
            labels = []
            for member in dataset.datasets:
                labels.extend(BalancedDataset._read_labels(member))
            return labels

        known = getattr(dataset, "targets", None)
        if (
            known is not None
            and getattr(dataset, "target_transform", None) is None
            and len(known) == len(dataset)
        ):
            # tolist() turns tensor/array targets into plain hashable Python values.
            return known.tolist() if hasattr(known, "tolist") else list(known)

        # Slow path: explicit index range rather than relying on IndexError to
        # terminate iteration.
        return [dataset[idx][1] for idx in range(len(dataset))]

    def _balance_dataset(self):
        """Pick an equal number of indices per class and return them shuffled."""
        rng = random if self.seed is None else random.Random(self.seed)

        # Group sample indices by label, keeping first-seen class order.
        class_indices = {}
        for idx, label in enumerate(self._read_labels(self.original_dataset)):
            class_indices.setdefault(label, []).append(idx)

        # Nothing to balance in an empty dataset.
        if not class_indices:
            return []

        # Every class is cut down to the size of the rarest one.
        min_class_size = min(len(indices) for indices in class_indices.values())

        balanced_indices = []
        for indices in class_indices.values():
            balanced_indices.extend(rng.sample(indices, min_class_size))

        rng.shuffle(balanced_indices)  # Avoid class-ordered output
        return balanced_indices

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        # Translate the position in the balanced subset to the wrapped dataset's index.
        actual_idx = self.filtered_indices[idx]
        return self.original_dataset[actual_idx]

In [None]:
# Load Datasets
# Class index that every split below relabels as 1; all other classes become 0.
# NOTE(review): ImageFolder numbers classes by sorted folder name, so index 3
# refers to the same emotion in AffectNet, FER2013 and RAF-DB only if their
# class folders sort the same way. Confirm against each dataset's `class_to_idx`.
TARGET_CLASS_INDEX = 3


def load_binary_split(relative_path, target_class=TARGET_CLASS_INDEX):
    """Load one image-folder split and relabel it as a binary problem.

    Args:
        relative_path: Folder of the split, relative to ``dataset_dir``
            (e.g. ``"fer2013/train"``).
        target_class: Original class index that becomes label 1.

    Returns:
        A RecategorizedDataset wrapping an ImageFolder that applies ``transform``.
    """
    folder = datasets.ImageFolder(os.path.join(dataset_dir, relative_path), transform=transform)
    return RecategorizedDataset(folder, target_class=target_class)


# Recategorize AffectNet datasets
affectnet_dataset_train = load_binary_split("AffectNet/train")
affectnet_dataset_test = load_binary_split("AffectNet/test")
# Loaded for completeness; nothing in the visible code consumes this split.
affectnet_dataset_val = load_binary_split("AffectNet/val")

# Recategorize FER2013 datasets
fer2013_dataset_train = load_binary_split("fer2013/train")
fer2013_dataset_test = load_binary_split("fer2013/test")

# Recategorize RAF-DB datasets
raf_db_dataset_train = load_binary_split("raf-db-dataset/train")
raf_db_dataset_test = load_binary_split("raf-db-dataset/test")

# Merge Datasets
# Merge training datasets, then down-sample so both labels are equally frequent
merged_dataset_train = BalancedDataset(
    ConcatDataset([affectnet_dataset_train, fer2013_dataset_train, raf_db_dataset_train])
)

# Merge test datasets the same way
merged_dataset_test = BalancedDataset(
    ConcatDataset([affectnet_dataset_test, fer2013_dataset_test, raf_db_dataset_test])
)

In [None]:
# Step 1: Load the Pre-Trained EfficientNet Model
efficientnet = torch.hub.load(
    'NVIDIA/DeepLearningExamples:torchhub',
    'nvidia_efficientnet_b0',
    pretrained=True,
)

# Step 2: Modify the Model
# Swap the final linear layer for one sized to this task's label count.
num_classes = 2  # Binary classification
head_in_features = efficientnet.classifier.fc.in_features
efficientnet.classifier.fc = nn.Linear(head_in_features, num_classes)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
efficientnet = efficientnet.to(device)

# Step 3: Prepare the Dataset and DataLoader
# Both loaders share every option except shuffling.
batch_size = 32
loader_options = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(merged_dataset_train, shuffle=True, **loader_options)
test_loader = DataLoader(merged_dataset_test, shuffle=False, **loader_options)

# Step 4: Define the Optimizer, Loss Function, and Scheduler
lr = 0.001
optimizer = optim.Adam(efficientnet.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Multiply the learning rate by 0.1 every 5 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Step 5: Initialize TensorBoard for Logging
writer = SummaryWriter("runs/efficientnet_finetune")

# Step 6: Define Validation Function
def validate(model, test_loader, criterion, device=None):
    """Evaluate ``model`` on ``test_loader`` and report mean loss and accuracy.

    The model is switched to evaluation mode and left in it; gradients are
    disabled for the duration of the pass.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        test_loader: Iterable of ``(images, labels)`` batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU for a parameter-free model).

    Returns:
        Tuple ``(average_loss, accuracy)``: the mean of the per-batch losses and
        the percentage of correctly classified samples. For an empty loader the
        loss is ``nan`` and the accuracy ``0.0``.
    """
    if device is None:
        # Follow the model instead of depending on a module-level variable.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            # Predicted class = index of the largest score along dim 1.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

    # Guard both divisions so an empty loader reports instead of raising.
    average_loss = test_loss / num_batches if num_batches else float("nan")
    accuracy = 100 * correct / total if total else 0.0
    print(f"Validation Loss: {average_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return average_loss, accuracy

# Step 7: Fine-Tune the Model
num_epochs = 10
best_accuracy = 0.0
early_stopping_patience = 3  # Stop after this many epochs without a new best accuracy
no_improvement_epochs = 0

# Mixed precision is only used on CUDA; on CPU the scaler and autocast are
# disabled and the loop runs in full precision.
# The device string is the first argument of the torch.amp API. Passing it to
# the older torch.cuda.amp classes would bind it to init_scale / enabled.
use_amp = device.type == "cuda"
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# NOTE(review): test_loader is built from the merged *test* splits, so the
# checkpoint below is selected on test data. affectnet_dataset_val is loaded
# but unused; consider validating on a held-out split instead.
try:
    for epoch in range(num_epochs):
        efficientnet.train()  # validate() leaves the model in eval mode
        running_loss = 0.0
        correct = 0
        total = 0

        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            # Forward pass with mixed precision
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = efficientnet(images)
                loss = criterion(outputs, labels)

            # Backward pass and optimization; the scaler is a pass-through when disabled
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Track loss and accuracy
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Print training progress every 10 steps (accuracy is the running epoch value)
            if (i + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], "
                      f"Loss: {loss.item():.4f}, Accuracy: {100 * correct / total:.2f}%")

        # Log training metrics
        train_loss = running_loss / len(train_loader)
        train_accuracy = 100 * correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
        writer.add_scalar("Loss/Train", train_loss, epoch)
        writer.add_scalar("Accuracy/Train", train_accuracy, epoch)

        # Validation
        val_loss, val_accuracy = validate(efficientnet, test_loader, criterion)
        writer.add_scalar("Loss/Validation", val_loss, epoch)
        writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)

        # Save the best model
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            no_improvement_epochs = 0
            os.makedirs("models", exist_ok=True)
            torch.save(efficientnet.state_dict(), "models/best_fine_tuned_efficientnet_b0.pth")
            print(f"Saved Best Model with Accuracy: {best_accuracy:.2f}%")
        else:
            no_improvement_epochs += 1

        # Early stopping
        if no_improvement_epochs >= early_stopping_patience:
            print("Early stopping triggered.")
            break

        # Step the scheduler once per completed epoch
        scheduler.step()
finally:
    # Close the TensorBoard writer even if training is interrupted or raises.
    writer.close()

# Final message
print("Training complete.")