In [6]:
!pip install kaggle --upgrade





In [7]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [8]:
!kaggle datasets download -d thienkhonghoc/affectnet -p /content

Dataset URL: https://www.kaggle.com/datasets/thienkhonghoc/affectnet
License(s): unknown


In [9]:
!unzip -q /content/affectnet.zip -d /content/affectnet > /dev/null 2>&1

In [10]:
!pip install torch torchvision timm matplotlib tqdm


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from collections import Counter
from PIL import Image
import os

# Set Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to Filter Corrupt Images
def filter_corrupt_images(dataset):
    valid_samples = []
    for path, label in dataset.samples:
        try:
            img = Image.open(path).convert("RGB")  # Ensure proper format
            valid_samples.append((path, label))
        except Exception as e:
            print(f"Corrupt image removed: {path} - {e}")
    dataset.samples = valid_samples

# Data Augmentation
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB") if isinstance(img, Image.Image) else img),  # Ensure RGB format
    transforms.ToTensor(),
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomErasing(p=0.3),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load Dataset
train_data_path = "/content/affectnet/AffectNet/train"
val_data_path = "/content/affectnet/AffectNet/val"

train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
val_dataset = datasets.ImageFolder(root=val_data_path, transform=transform)

# Filter Out Corrupt Images Before Training
filter_corrupt_images(train_dataset)
filter_corrupt_images(val_dataset)

# Compute Class Weights
class_counts = Counter(train_dataset.targets)
num_samples = sum(class_counts.values())
weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = torch.tensor(weights, dtype=torch.float).to(device)

# Load Data
batch_size = 16  # Reduced batch size to prevent OOM errors
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Load ConvNeXt-Large Model
model = models.convnext_large(weights=models.ConvNeXt_Large_Weights.IMAGENET1K_V1)

# Modify Classifier for 8 Classes
model.classifier = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.LayerNorm(model.classifier[2].in_features),
    nn.Dropout(0.5),
    nn.Linear(model.classifier[2].in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, 8)
)

# Load Checkpoint
checkpoint_path = "/content/affectnet_convnext_large_epoch10.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

# Remove "_orig_mod." keys if necessary
new_checkpoint = {}
for key in checkpoint.keys():
    new_key = key.replace("_orig_mod.", "")
    new_checkpoint[new_key] = checkpoint[key]

# Load State Dict
model.load_state_dict(new_checkpoint, strict=False)

print("Checkpoint successfully loaded! Resuming training from Epoch 11.")

# Move Model to Device
model = model.to(device)

# Define Loss, Optimizer & Scheduler
criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.1)
optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-7)

# Mixed Precision Training
scaler = GradScaler()

# Resume Training from 11th Epoch
print("\nContinuing Fine-tuning from Epoch 11...\n")

for epoch in range(11, 26):
    torch.cuda.empty_cache()  # Prevent memory fragmentation
    model.train()
    running_loss, correct_train, total_train = 0.0, 0, 0

    optimizer.zero_grad()
    batch_count = 0  # Track processed batches

    for images, labels in train_loader:
        try:
            if images is None or labels is None:
                continue  # Skip NoneType images

            images, labels = images.to(device), labels.to(device)

        except Exception as e:
            print(f"Skipping corrupt batch: {e}")
            continue

        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)

        batch_count += 1
        if batch_count % 50 == 0:  # Print every 50 batches
            print(f"Epoch {epoch}: Processed {batch_count} batches...")

    train_accuracy = 100 * correct_train / total_train
    scheduler.step()

    # Validation Phase
    model.eval()
    correct_val, total_val = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            try:
                if images is None or labels is None:
                    continue  # Skip NoneType images

                images, labels = images.to(device), labels.to(device)

            except Exception as e:
                print(f"Skipping corrupt batch in validation: {e}")
                continue

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct_val += (predicted == labels).sum().item()
            total_val += labels.size(0)

    val_accuracy = 100 * correct_val / total_val

    print(f"Epoch [{epoch}/25], Loss: {running_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%")

    # Save Model Every 5 Epochs and every 500 batches
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"affectnet_convnext_large_epoch{epoch}.pt")
        print(f"Model saved: affectnet_convnext_large_epoch{epoch}.pt")

# Save Final Model
torch.save(model.state_dict(), "affectnet_convnext_large_final.pt")
print("\nTraining complete! Final model saved.")


Checkpoint successfully loaded! Resuming training from Epoch 11.

Continuing Fine-tuning from Epoch 11...



  scaler = GradScaler()
  with autocast():


Epoch 11: Processed 50 batches...
Epoch 11: Processed 100 batches...
Epoch 11: Processed 150 batches...
Epoch 11: Processed 200 batches...
Epoch 11: Processed 250 batches...
Epoch 11: Processed 300 batches...
Epoch 11: Processed 350 batches...
Epoch 11: Processed 400 batches...
Epoch 11: Processed 450 batches...
Epoch 11: Processed 500 batches...
Epoch 11: Processed 550 batches...
Epoch 11: Processed 600 batches...
Epoch 11: Processed 650 batches...
Epoch 11: Processed 700 batches...
Epoch 11: Processed 750 batches...
Epoch 11: Processed 800 batches...
Epoch 11: Processed 850 batches...
Epoch 11: Processed 900 batches...
Epoch 11: Processed 950 batches...
Epoch 11: Processed 1000 batches...
Epoch 11: Processed 1050 batches...
Epoch 11: Processed 1100 batches...
Epoch 11: Processed 1150 batches...
Epoch 11: Processed 1200 batches...
Epoch 11: Processed 1250 batches...
Epoch 11: Processed 1300 batches...
Epoch 11: Processed 1350 batches...
Epoch 11: Processed 1400 batches...
Epoch 11: Pr

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from collections import Counter
from PIL import Image
import os

# Set Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to Filter Corrupt Images
def filter_corrupt_images(dataset):
    valid_samples = []
    for path, label in dataset.samples:
        try:
            img = Image.open(path).convert("RGB")
            valid_samples.append((path, label))
        except Exception as e:
            print(f"Corrupt image removed: {path} - {e}")
    dataset.samples = valid_samples

# Data Augmentation
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB") if isinstance(img, Image.Image) else img),
    transforms.ToTensor(),
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomErasing(p=0.3),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load Dataset
train_data_path = "/content/affectnet/AffectNet/train"
val_data_path = "/content/affectnet/AffectNet/val"

train_dataset = datasets.ImageFolder(root=train_data_path, transform=transform)
val_dataset = datasets.ImageFolder(root=val_data_path, transform=transform)

# Filter Out Corrupt Images
filter_corrupt_images(train_dataset)
filter_corrupt_images(val_dataset)

# Compute Class Weights
class_counts = Counter(train_dataset.targets)
num_samples = sum(class_counts.values())
weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = torch.tensor(weights, dtype=torch.float).to(device)

# Load Data
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Load ConvNeXt-Large Model
model = models.convnext_large(weights=models.ConvNeXt_Large_Weights.IMAGENET1K_V1)

# Modify Classifier for 8 Classes
model.classifier = nn.Sequential(
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.LayerNorm(model.classifier[2].in_features),
    nn.Dropout(0.6),  # Increased dropout for regularization
    nn.Linear(model.classifier[2].in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, 8)
)

# Load Checkpoint (Resume from Epoch 22)
checkpoint_path = "/content/affectnet_convnext_large_epoch20.pt"
checkpoint = torch.load(checkpoint_path, map_location=device)

# Remove "_orig_mod." keys if necessary
new_checkpoint = {key.replace("_orig_mod.", ""): val for key, val in checkpoint.items()}

# Load Weights
model.load_state_dict(new_checkpoint, strict=False)

print("Checkpoint successfully loaded! Resuming training from Epoch 22.")

# Move Model to Device
model = model.to(device)

# Define Loss, Optimizer & Scheduler
criterion = nn.CrossEntropyLoss(weight=weights, label_smoothing=0.05)
optimizer = optim.AdamW(model.parameters(), lr=5e-6, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=3, verbose=True)

# Mixed Precision Training
scaler = GradScaler()

# Early Stopping Parameters
best_val_acc = 0.0
epochs_without_improvement = 0
early_stopping_patience = 3

# Resume Training from Epoch 22
print("\nContinuing Fine-tuning from Epoch 22...\n")

for epoch in range(22, 31):
    torch.cuda.empty_cache()
    model.train()
    running_loss, correct_train, total_train = 0.0, 0, 0

    optimizer.zero_grad()
    batch_count = 0

    for images, labels in train_loader:
        try:
            images, labels = images.to(device), labels.to(device)
        except Exception as e:
            print(f"Skipping corrupt batch: {e}")
            continue

        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Prevent exploding gradients
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct_train += (predicted == labels).sum().item()
        total_train += labels.size(0)

        batch_count += 1
        if batch_count % 50 == 0:
            print(f"Epoch {epoch}: Processed {batch_count} batches...")

    train_accuracy = 100 * correct_train / total_train
    scheduler.step(train_accuracy)

    # Validation Phase
    model.eval()
    correct_val, total_val = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            try:
                images, labels = images.to(device), labels.to(device)
            except Exception as e:
                print(f"Skipping corrupt batch in validation: {e}")
                continue

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct_val += (predicted == labels).sum().item()
            total_val += labels.size(0)

    val_accuracy = 100 * correct_val / total_val

    print(f"Epoch [{epoch}/30], Loss: {running_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Val Acc: {val_accuracy:.2f}%")

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f"affectnet_convnext_large_epoch{epoch}.pt")
        print(f"Model saved: affectnet_convnext_large_epoch{epoch}.pt")

    # Early Stopping Condition
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= early_stopping_patience:
        print(f"Early stopping triggered. Best validation accuracy: {best_val_acc:.2f}%")
        break

# Save Final Model
torch.save(model.state_dict(), "affectnet_convnext_large_final.pt")
print("\nTraining complete! Final model saved.")


Checkpoint successfully loaded! Resuming training from Epoch 22.

Continuing Fine-tuning from Epoch 22...



  scaler = GradScaler()
  with autocast():


Epoch 22: Processed 50 batches...
Epoch 22: Processed 100 batches...
Epoch 22: Processed 150 batches...
Epoch 22: Processed 200 batches...
Epoch 22: Processed 250 batches...
Epoch 22: Processed 300 batches...
Epoch 22: Processed 350 batches...
Epoch 22: Processed 400 batches...
Epoch 22: Processed 450 batches...
Epoch 22: Processed 500 batches...
Epoch 22: Processed 550 batches...
Epoch 22: Processed 600 batches...
Epoch 22: Processed 650 batches...
Epoch 22: Processed 700 batches...
Epoch 22: Processed 750 batches...
Epoch 22: Processed 800 batches...
Epoch 22: Processed 850 batches...
Epoch 22: Processed 900 batches...
Epoch 22: Processed 950 batches...
Epoch 22: Processed 1000 batches...
Epoch 22: Processed 1050 batches...
Epoch 22: Processed 1100 batches...
Epoch 22: Processed 1150 batches...
Epoch 22: Processed 1200 batches...
Epoch 22: Processed 1250 batches...
Epoch 22: Processed 1300 batches...
Epoch 22: Processed 1350 batches...
Epoch 22: Processed 1400 batches...
Epoch 22: Pr