# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import GradScaler, autocast  # For mixed precision training
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Choose the compute device once; worker count and batch size both follow from it
_cuda_available = torch.cuda.is_available()
if _cuda_available:
    device = torch.device("cuda")
    num_workers = 0  # GPU run: batches are loaded in the main process
    batch_size = 32  # Kept small for the Kaggle P100 GPU to avoid OOM
else:
    device = torch.device("cpu")
    num_workers = 4  # CPU run: use 4 loader workers
    batch_size = 4  # CPU run: tiny batches

# The pretrained ViT-L/16 (IMAGENET1K_V1 weights) used below works on 224x224
# inputs, so every image is resized before it is turned into a tensor.
image_size = 224
imagenet_mean = [0.485, 0.456, 0.406]  # Pretrained ImageNet stats
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random augmentation, then ImageNet normalization
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),  # Horizontal flip for augmentation
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),  # Data augmentation
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation pipeline: same resize/normalization but no random augmentation,
# so validation loss and metrics are reproducible from epoch to epoch
val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Load dataset (train and validation directories)
train_data = datasets.ImageFolder(root='/input/train', transform=transform)
val_data = datasets.ImageFolder(root='/input/session_1/val', transform=val_transform)

# Convert images to float16
def convert_to_float16(image):
    """Return ``image`` cast to half precision (``torch.float16``)."""
    half_image = image.to(dtype=torch.float16)
    return half_image

# Data Loaders with Image to float16 conversion
class MyDataLoader(DataLoader):
    """DataLoader that stacks ``(image, label)`` samples and can emit float16 images.

    ``DataLoader.__init__`` stores the collate function it was given (or its
    default) on the instance as ``self.collate_fn``. That instance attribute
    hides a method of the same name, so the custom collation below is only
    used if it is passed to the parent constructor explicitly.

    Attributes:
        half_inputs: when true, image batches are cast to ``torch.float16``.
            It defaults to whether CUDA is available, because the training
            script only runs mixed precision on CUDA and half inputs would not
            match a float32 model on CPU.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers):
        # Decided once in the constructing process, before the parent is initialised
        self.half_inputs = torch.cuda.is_available()
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,  # wire in the custom collation
        )

    def collate_fn(self, batch):
        """Collate a list of ``(image_tensor, label)`` samples into one batch.

        Args:
            batch: sequence of ``(image, label)`` pairs; the images must be
                tensors that ``torch.stack`` can combine.

        Returns:
            Tuple ``(images, labels)``: the stacked images (float16 when
            ``half_inputs`` is set) and the labels as a ``torch.long`` tensor.
        """
        images, labels = zip(*batch)
        images = torch.stack(images)
        if self.half_inputs:
            images = images.half()  # Convert the batch of images to float16
        labels = torch.tensor(labels, dtype=torch.long)
        return images, labels

# Both loaders share batch size and worker count; only the training split is shuffled
train_loader, val_loader = (
    MyDataLoader(split, batch_size=batch_size, shuffle=reshuffle, num_workers=num_workers)
    for split, reshuffle in ((train_data, True), (val_data, False))
)

# Model with dropout
class ViTWithDropout(nn.Module):
    """ViT-L/16 with IMAGENET1K_V1 weights and a dropout + 2-way linear head."""

    def __init__(self):
        super().__init__()
        backbone = models.vit_l_16(weights="IMAGENET1K_V1")
        # Reuse the input width of the stock head for the replacement classifier
        head_in_features = backbone.heads.head.in_features
        backbone.heads.head = nn.Sequential(
            nn.Dropout(0.5),  # Dropout before the final layer
            nn.Linear(head_in_features, 2),  # 2 output classes
        )
        self.model = backbone

    def forward(self, x):
        """Pass ``x`` through the wrapped model and return its output."""
        logits = self.model(x)
        return logits

model = ViTWithDropout().to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Halve the learning rate (factor=0.5) when the validation loss stops improving
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.5)
scaler = GradScaler()  # Mixed Precision Setup

# Paths for model saving and loading
previous_model_path = "input/deepfake_model_session1.pth"  # Path to the previously saved model
current_model_name = "vitL_deepfake_model_session2.pth"  # New name for the current session's model
current_model_path = os.path.join("/working/directory", current_model_name)  # Save new model in the working directory

# Load previous checkpoint if exists
if os.path.exists(previous_model_path):
    print(f"Loading model from {previous_model_path}...")
    # map_location lets a checkpoint written on a GPU be restored on whatever
    # device this session runs on (including CPU-only machines)
    checkpoint = torch.load(previous_model_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    best_val_loss = checkpoint['val_loss']
    print(f"Resumed from epoch {start_epoch + 1}")
else:
    start_epoch = 0
    best_val_loss = float("inf")

# Training settings
total_epochs = 30
accumulation_steps = 2  # Accumulate gradients over 2 mini-batches

# Track metrics
train_losses = []
val_losses = []
train_metrics = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}
val_metrics = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}

# Early stopping
# NOTE: best_val_loss is deliberately NOT reset here. It is set by the
# checkpoint branch above, and resetting it would discard the restored value.
early_stopping_patience = 3
patience_counter = 0


# Training loop
# Mixed precision is only used on CUDA; on any other device the autocast
# context is created disabled and acts as a no-op.
amp_enabled = device.type == "cuda"

for epoch in range(start_epoch, total_epochs):
    print(f"Epoch {epoch + 1}/{total_epochs}")
    train_loss = 0.0
    val_loss = 0.0
    all_train_preds = []
    all_train_labels = []
    all_val_preds = []
    all_val_labels = []

    # Training phase
    model.train()
    optimizer.zero_grad()  # Reset gradients before each training epoch
    train_progress = tqdm(train_loader, desc="Training", leave=False)

    for step, (images, labels) in enumerate(train_progress):
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        # Mixed precision: forward pass and loss computation
        with autocast(device_type=device.type, enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, labels)

        # Scale the loss and backpropagate. The loss is divided by the
        # accumulation window so the gradients summed over the window average
        # the mini-batches instead of adding them up.
        scaler.scale(loss / accumulation_steps).backward()

        # Step only at the end of an accumulation window (or on the last batch)
        if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_loader):
            scaler.step(optimizer)  # Update model parameters
            scaler.update()  # Update the scaler
            optimizer.zero_grad()  # Reset gradients after the update

        # Bookkeeping uses the undivided per-batch loss
        batch_loss = loss.item()
        train_loss += batch_loss * images.size(0)
        preds = torch.argmax(outputs, dim=1)
        all_train_preds.extend(preds.cpu().numpy())
        all_train_labels.extend(labels.cpu().numpy())

        train_progress.set_postfix({"Loss": batch_loss})

    train_loss /= len(train_loader.dataset)
    train_losses.append(train_loss)

    # Calculate training metrics
    train_accuracy = accuracy_score(all_train_labels, all_train_preds)
    train_precision = precision_score(all_train_labels, all_train_preds)
    train_recall = recall_score(all_train_labels, all_train_preds)
    train_f1 = f1_score(all_train_labels, all_train_preds)
    train_metrics['accuracy'].append(train_accuracy)
    train_metrics['precision'].append(train_precision)
    train_metrics['recall'].append(train_recall)
    train_metrics['f1'].append(train_f1)

    # Validation phase
    model.eval()
    with torch.no_grad():
        val_progress = tqdm(val_loader, desc="Validating", leave=False)
        for images, labels in val_progress:
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Mixed precision: forward pass and loss computation
            with autocast(device_type=device.type, enabled=amp_enabled):
                outputs = model(images)
                loss = criterion(outputs, labels)

            batch_loss = loss.item()
            val_loss += batch_loss * images.size(0)

            preds = torch.argmax(outputs, dim=1)
            all_val_preds.extend(preds.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())
            val_progress.set_postfix({"Loss": batch_loss})

    val_loss /= len(val_loader.dataset)
    val_losses.append(val_loss)
    scheduler.step(val_loss)

    # Calculate validation metrics
    val_accuracy = accuracy_score(all_val_labels, all_val_preds)
    val_precision = precision_score(all_val_labels, all_val_preds)
    val_recall = recall_score(all_val_labels, all_val_preds)
    val_f1 = f1_score(all_val_labels, all_val_preds)
    val_metrics['accuracy'].append(val_accuracy)
    val_metrics['precision'].append(val_precision)
    val_metrics['recall'].append(val_recall)
    val_metrics['f1'].append(val_f1)

    # Print epoch summary with metrics. This happens before the early-stopping
    # check so the epoch that triggers the stop is reported as well.
    print(f"Epoch [{epoch+1}/{total_epochs}], Train Loss: {train_loss:.4f}, "
          f"Train Accuracy: {train_accuracy:.4f}, Train F1: {train_f1:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save the updated model with the current session's name
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss
        }, current_model_path)
        print(f"Model saved as {current_model_name}")
    else:
        patience_counter += 1
        if patience_counter >= early_stopping_patience:
            print("Early stopping triggered.")
            break

# Standalone loss figure (x axis is the index of each recorded epoch)
plt.figure(figsize=(10, 5))
for curve, curve_label in ((train_losses, "Train Loss"), (val_losses, "Val Loss")):
    plt.plot(curve, label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

# One figure per metric, training curve first and validation curve second
metrics = ['accuracy', 'precision', 'recall', 'f1']
for metric in metrics:
    plt.figure(figsize=(10, 5))
    for split_name, history in (("Train", train_metrics), ("Val", val_metrics)):
        values = history[metric]
        plt.plot(range(1, len(values) + 1), values, label=f"{split_name} {metric}")
    plt.xlabel("Epochs")
    plt.ylabel(metric.capitalize())
    plt.legend()
    plt.show()

# Combined figure: losses in the top panel, all metrics in the bottom panel
plt.figure(figsize=(15, 8))
epochs_range = range(1, len(train_losses) + 1)

# Top panel: loss curves
plt.subplot(2, 1, 1)
for curve, curve_label in ((train_losses, "Train Loss"), (val_losses, "Validation Loss")):
    plt.plot(epochs_range, curve, label=curve_label, marker='o')
plt.title("Loss vs Epochs")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)

# Bottom panel: dashed lines for training, solid lines for validation
plt.subplot(2, 1, 2)
for metric in metrics:
    pretty_name = metric.capitalize()
    plt.plot(epochs_range, train_metrics[metric], label=f"Train {pretty_name}", linestyle='--')
    plt.plot(epochs_range, val_metrics[metric], label=f"Val {pretty_name}", linestyle='-')
plt.title("Metrics vs Epochs")
plt.xlabel("Epochs")
plt.ylabel("Metrics")
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# Confusion Matrix for Validation Set
model.eval()  # Ensure model is in evaluation mode
all_val_preds = []
all_val_labels = []

# Autocast is only enabled on CUDA; elsewhere the context is a disabled no-op
cm_amp_enabled = device.type == "cuda"

with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
        with autocast(device_type=device.type, enabled=cm_amp_enabled):
            outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_val_preds.extend(preds.cpu().numpy())
        all_val_labels.extend(labels.cpu().numpy())

# Compute confusion matrix. Passing the full list of class indices keeps the
# matrix one row/column per class even when a class never shows up in the
# labels or predictions, so it always matches the display labels below.
class_names = train_data.classes
class_indices = list(range(len(class_names)))
cm = confusion_matrix(all_val_labels, all_val_preds, labels=class_indices)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
disp.plot(cmap="Blues", values_format='d', ax=plt.gca())
plt.title("Confusion Matrix - Validation Set")
plt.show()