In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, Subset
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import copy

# === Focal Loss ===
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Each log-probability is scaled by ``(1 - p) ** gamma`` before the
    negative log-likelihood is taken, so samples the model already predicts
    with high probability contribute less to the loss.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor. ``0`` makes the
            factor equal to one.
        weight: Optional per-class weight tensor, passed straight through to
            ``nll_loss``.
    """

    def __init__(self, gamma=2.0, weight=None):
        super().__init__()
        self.weight = weight
        self.gamma = gamma

    def forward(self, input, target):
        """Return the reduced focal loss for a batch.

        Args:
            input: Logits; the softmax is taken over dimension 1.
            target: Class indices, as accepted by ``nll_loss``.
        """
        log_probs = nn.functional.log_softmax(input, dim=1)
        # Down-weighting factor, computed for every class entry.
        modulator = (1 - log_probs.exp()) ** self.gamma
        return nn.functional.nll_loss(
            modulator * log_probs,
            target,
            weight=self.weight,
        )

# === Config ===
# Dataset location: an ImageFolder-style root (one sub-directory per class).
data_dir = r"C:\Users\assen\Downloads\CROP\CROP"  # <<<--- Your RGB data directory

# Cross-validation
n_splits = 5  # Number of folds

# Optimisation
learning_rate = 0.001
batch_size = 32
num_epochs = 50  # upper bound on epochs per fold
patience = 7     # epochs without validation improvement before early stopping

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Transforms ===
# The torchvision ImageNet-pretrained weights were trained on inputs
# normalised with these per-channel statistics. The backbone is used as a
# frozen feature extractor, so both pipelines must reproduce that input
# distribution; otherwise the features are computed on out-of-range data.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: light geometric and colour augmentation, then tensor
# conversion and normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Note: No data augmentation in val_transform
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === Load Full Dataset Once ===
# No transform is attached here, so indexing the dataset yields the images as
# loaded from disk together with their class index.
full_dataset = datasets.ImageFolder(root=data_dir)

# One integer class label per sample, in dataset order.
targets = np.array(full_dataset.targets)

# One integer group ID per sample, parsed from the file name: the text before
# the first underscore, with its first character dropped, read as an int.
# NOTE(review): presumably file names look like "P12_xxx.png" (a one-letter
# prefix followed by a subject/plot number) - confirm against the data.
# A file name that does not follow this pattern raises ValueError here.
groups = np.array([
    int(os.path.basename(sample_path).partition('_')[0][1:])
    for sample_path, _class_index in full_dataset.samples
])

# === K-Fold Cross-Validation Setup ===
# Stratified on the class label while keeping every group inside one fold.
skf = StratifiedGroupKFold(
    n_splits=n_splits,
    shuffle=True,
    random_state=42,
)
fold_accuracies = []  # best validation accuracy of each fold

# === Main Loop for K-Fold Cross-Validation ===
class _TransformedSubset(torch.utils.data.Dataset):
    """Subset of a base dataset that applies its own transform to each image.

    Assigning ``subset.dataset.transform`` on two ``Subset`` objects that wrap
    the same base dataset mutates one shared attribute, so the last assignment
    wins for both splits. Keeping the transform on this wrapper lets the
    training and validation splits use different pipelines over one shared
    base dataset.

    Args:
        base: Dataset returning ``(image, label)`` pairs. It must not apply a
            transform of its own, because this wrapper's transform is applied
            to whatever the base returns.
        indices: Positions in ``base`` that belong to this split.
        transform: Callable applied to each image, or ``None`` to return the
            image unchanged.
    """

    def __init__(self, base, indices, transform=None):
        self.base = base
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        image, label = self.base[int(self.indices[i])]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# The per-split wrappers own the transforms; make sure the shared base
# dataset hands out untransformed images.
full_dataset.transform = None

# Page-locked host memory only speeds up host-to-GPU copies.
use_pinned_memory = device.type == "cuda"

for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets, groups)):
    print(f"\n{'='*20} FOLD {fold+1}/{n_splits} {'='*20}\n")

    # --- Create datasets and dataloaders for the current fold ---
    # Each split gets its own transform: augmentation for training only.
    train_subset = _TransformedSubset(full_dataset, train_idx, train_transform)
    val_subset = _TransformedSubset(full_dataset, val_idx, val_transform)

    # Weighted sampler for the training subset: every sample is drawn with
    # probability proportional to 1 / (size of its class in this fold), so
    # classes are drawn roughly equally often.
    train_targets = targets[train_idx]
    class_counts = np.bincount(train_targets)
    sample_weights = 1.0 / class_counts[train_targets]
    sampler = WeightedRandomSampler(
        torch.from_numpy(sample_weights).double(),
        num_samples=len(sample_weights),
        replacement=True,
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, sampler=sampler, pin_memory=use_pinned_memory)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, pin_memory=use_pinned_memory)

    # --- Initialize a fresh model and optimizer for this fold ---
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in
    # favour of `weights=`; kept as-is for compatibility with older versions.
    model = models.efficientnet_b0(pretrained=True)
    # Freeze the pretrained backbone; only the new classifier head is trained.
    for param in model.parameters():
        param.requires_grad = False

    num_ftrs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.4),
        nn.Linear(num_ftrs, len(full_dataset.classes))
    )
    model = model.to(device)

    optimizer = optim.Adam(model.classifier.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = FocalLoss()

    # --- Training Loop for the current fold ---
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch [{epoch+1}/{num_epochs}]")
        # NOTE(review): train() also switches the frozen backbone's BatchNorm
        # and dropout layers to training mode - confirm this is intended.
        model.train()
        for images, labels in tqdm(train_loader, desc="Training", leave=False):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the held-out fold without tracking gradients.
        model.eval()
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_acc = val_correct / val_total
        print(f"📊 Validation Accuracy: {val_acc*100:.2f}%")

        if val_acc > best_acc:
            # New best epoch for this fold: snapshot the weights in memory
            # and on disk, and reset the early-stopping counter.
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), f"best_model_fold_{fold+1}.pth")
            print("✅ Checkpoint saved.")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"🛑 Early stopping triggered after {patience} epochs with no improvement.")
                break

        scheduler.step()

    # Record the best validation accuracy reached in this fold.
    fold_accuracies.append(best_acc)

# === Final Results ===
# Summarise the per-fold scores. Each entry of fold_accuracies is the best
# validation accuracy reached in that fold, i.e. it was selected on the same
# data it is measured on, so the average is an optimistic estimate.
mean_accuracy = np.mean(fold_accuracies)
std_accuracy = np.std(fold_accuracies)  # NumPy default: population std (ddof=0)
print(f"\n\n{'='*20} K-FOLD CROSS-VALIDATION COMPLETE {'='*20}")
print(f"Individual Fold Accuracies: {[f'{acc*100:.2f}%' for acc in fold_accuracies]}")
print(f"✅ Average Validation Accuracy: {mean_accuracy*100:.2f}% ± {std_accuracy*100:.2f}%")