In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedGroupKFold
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import copy

# === Focal Loss ===
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Each sample's log-probability is scaled by ``(1 - p) ** gamma`` before the
    negative log-likelihood is taken, so well-classified samples (``p`` close
    to 1) contribute less to the loss. With ``gamma=0`` this reduces to
    ordinary cross-entropy.

    Args:
        gamma: Focusing exponent applied to ``(1 - p)``.
        weight: Optional per-class rescaling weights (1-D, one entry per
            class). Stored as a buffer so that ``.to(device)`` / ``.double()``
            on the loss module move or cast it together with the module.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``; forwarded to
            ``nll_loss``.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, gamma=2.0, weight=None, reduction="mean"):
        super().__init__()
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(f"Unsupported reduction: {reduction!r}")
        self.gamma = gamma
        self.reduction = reduction
        # Accept plain sequences as well as tensors for convenience.
        if weight is not None and not torch.is_tensor(weight):
            weight = torch.as_tensor(weight, dtype=torch.get_default_dtype())
        # A buffer (unlike a plain attribute) follows the module across
        # devices and dtypes; registering ``None`` is allowed.
        self.register_buffer("weight", weight)

    def forward(self, input, target):
        """Return the focal loss for ``input`` logits and ``target`` class indices.

        ``input`` is reduced with ``log_softmax`` over ``dim=1``; ``target``
        is passed to ``nll_loss`` unchanged.
        """
        log_probs = nn.functional.log_softmax(input, dim=1)
        probs = torch.exp(log_probs)
        # Down-weight confident predictions before taking the NLL.
        modulated = (1 - probs) ** self.gamma * log_probs
        return nn.functional.nll_loss(
            modulated, target, weight=self.weight, reduction=self.reduction
        )

# === Config ===
# Root folder of the image dataset (one sub-folder per class).
data_dir = r"C:\Users\assen\Downloads\My_preposed\Preprocessed"

# Optimisation hyper-parameters.
batch_size = 32
num_epochs = 50
patience = 7            # early-stopping budget, in epochs without improvement
learning_rate = 1e-3

# Prefer the GPU when one is available, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")

# === Transforms ===
# The ImageNet-pretrained torchvision weights are meant to be fed inputs
# normalised with the ImageNet channel statistics; without this step the
# pretrained backbone sees inputs in a different range than it was trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, tensor conversion, normalise.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: identical preprocessing but no random augmentation.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === Dataset & Split ===
dataset = datasets.ImageFolder(root=data_dir)
targets = dataset.targets


def _participant_id(path):
    """Return the integer participant ID encoded at the start of a filename.

    Filenames look like 'P6_Sad_FLIR1001234_face1': the text before the first
    underscore is a one-character prefix followed by the numeric ID.

    Raises:
        ValueError: If the part after the prefix character is not an integer;
            the offending path is included in the message.
    """
    prefix = os.path.basename(path).split('_')[0]
    try:
        return int(prefix[1:])
    except ValueError as err:
        raise ValueError(
            f"Cannot parse participant ID from filename: {path!r}"
        ) from err


def _make_split(base, indices, transform):
    """Return a copy of `base` restricted to `indices`, using `transform`."""
    split = copy.deepcopy(base)
    split.samples = [base.samples[i] for i in indices]
    # Keep the `imgs` alias in sync so it does not keep exposing the full,
    # unsplit sample list.
    split.imgs = split.samples
    split.targets = [base.targets[i] for i in indices]
    split.transform = transform
    return split


# Extract group ID from filename like 'P6_Sad_FLIR1001234_face1'
group_labels = [os.path.basename(path).split('_')[0] for path, _ in dataset.samples]  # e.g., 'P6'
groups = [_participant_id(path) for path, _ in dataset.samples]  # e.g., 6

# Stratified Group K-Fold: all images of one participant land on the same
# side of the split. Only the first of the five folds is used.
skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
train_idx, val_idx = next(skf.split(np.zeros(len(targets)), targets, groups))

train_dataset = _make_split(dataset, train_idx, train_transform)
val_dataset = _make_split(dataset, val_idx, val_transform)

# === Weighted Sampling ===
# Every training sample is drawn with probability proportional to the inverse
# of its class frequency in the training split.
class_counts = np.bincount(train_dataset.targets)
inverse_freq = {
    cls: 1.0 / count
    for cls, count in enumerate(class_counts)
    if count > 0
}
sample_weights = [inverse_freq[label] for label in train_dataset.targets]
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

# === Dataloaders ===
# The training loader draws through the weighted sampler; the validation
# loader walks the split in order.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True)
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# === Model ===
num_classes = len(dataset.classes)
model = models.googlenet(pretrained=True, aux_logits=True)  # use aux_logits=True for training

# Freeze the whole network, then re-enable gradients for the last inception
# block only.
model.requires_grad_(False)
model.inception5b.requires_grad_(True)

# Swap the main and auxiliary classification heads for fresh layers sized to
# this dataset; newly created layers are trainable by default.
# NOTE(review): everything in aux1/aux2 other than fc2 stays frozen here —
# confirm that training only the final layer of each auxiliary head is intended.
model.fc = nn.Linear(model.fc.in_features, num_classes)
for aux_head in (model.aux1, model.aux2):
    aux_head.fc2 = nn.Linear(aux_head.fc2.in_features, num_classes)

model = model.to(device)

# === Loss, Optimizer, Scheduler ===
# Focal loss with its default settings.
criterion = FocalLoss()

# Adam over every model parameter, with L2 weight decay.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)

# Multiply the learning rate by 0.1 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

# === Training Loop ===
# Trains for up to `num_epochs` epochs, validating after each one. The weights
# with the lowest validation loss are kept in memory and written to disk, and
# training stops early after `patience` consecutive epochs without improvement.
best_model_wts = copy.deepcopy(model.state_dict())
best_loss = float('inf')
epochs_no_improve = 0
# Per-epoch history, consumed by the plotting code after training.
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

print("ðŸš€ Starting training...\n")
for epoch in range(num_epochs):
    print(f"\nEpoch [{epoch+1}/{num_epochs}]")
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(train_loader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        # In training mode the model returns the main logits plus two
        # auxiliary outputs.
        # NOTE(review): torchvision documents the output order as
        # (logits, aux_logits2, aux_logits1); if so the aux1/aux2 names are
        # swapped, which is harmless because both get the same 0.3 weight.
        outputs, aux1, aux2 = model(images)

        loss_main = criterion(outputs, labels)
        loss_aux1 = criterion(aux1, labels)
        loss_aux2 = criterion(aux2, labels)
        # Auxiliary heads contribute with a reduced weight of 0.3 each.
        loss = loss_main + 0.3 * loss_aux1 + 0.3 * loss_aux2

        loss.backward()
        optimizer.step()

        # Scale the batch-mean loss by the batch size so that dividing by
        # `total` below yields a per-sample average.
        running_loss += loss.item() * images.size(0)
        # Accuracy is computed from the main head only.
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / total
    train_acc = correct / total
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # === Validation ===
    # Only a single output is used here, so the validation loss has no
    # auxiliary terms and is not directly comparable to the training loss.
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_loss /= total
    val_acc = correct / total
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f"ðŸ“Š Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f}, Acc: {train_acc*100:.2f}% | "
          f"Val Loss: {val_loss:.4f}, Acc: {val_acc*100:.2f}%")

    # === Checkpointing ===
    # A strictly lower validation loss resets the patience counter and
    # snapshots the weights both in memory and on disk.
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), "best_googlenet_model.pth")
        print("âœ… Checkpoint saved.")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("ðŸ›‘ Early stopping triggered.")
            # Leaves the loop before the scheduler step below.
            break

    # Advance the learning-rate schedule once per completed epoch.
    scheduler.step()
    torch.cuda.empty_cache()

# === Load Best Model ===
# Restore the weights captured at the lowest validation loss.
model.load_state_dict(best_model_wts)

# === Final Evaluation ===
# Collect predictions over the validation split with gradients disabled.
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        y_true.extend(labels.tolist())
        y_pred.extend(preds.cpu().tolist())

final_acc = np.mean(np.array(y_true) == np.array(y_pred))
print(f"\nâœ… Final Validation Accuracy: {final_acc*100:.2f}%")
print("\nðŸ“Š Final Confusion Matrix:")
# Pass the full label set explicitly: a group-based fold can lack a class
# entirely, which would otherwise shrink the confusion matrix and make
# classification_report reject `target_names` for not matching the labels
# it found.
class_ids = list(range(len(dataset.classes)))
print(confusion_matrix(y_true, y_pred, labels=class_ids))
print(classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=dataset.classes,
    zero_division=0,  # report 0 for classes with no predictions instead of warning
))

# === Plot Results ===
# Two side-by-side panels: loss curves on the left, accuracy curves on the
# right, each showing the training and validation history per epoch.
plt.figure(figsize=(10, 4))

curve_panels = [
    # (full metric name, legend suffix, training history, validation history)
    ('Loss', 'Loss', train_losses, val_losses),
    ('Accuracy', 'Acc', train_accuracies, val_accuracies),
]
for panel_idx, (metric, short, train_curve, val_curve) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.plot(train_curve, label=f'Train {short}')
    plt.plot(val_curve, label=f'Val {short}')
    plt.title(f'{metric} per Epoch')
    plt.xlabel("Epoch")
    plt.ylabel(metric)
    plt.legend()

plt.tight_layout()
plt.show()