In [6]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from tqdm import tqdm
from collections import Counter
import pickle  # For saving metrics

# Use GPU if available, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define a Simplified Model (ResNet18)
class SimpleResNet(nn.Module):
    def __init__(self, num_classes=2):
        super(SimpleResNet, self).__init__()
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        return self.backbone(x)

# Initialize Model
model = SimpleResNet(num_classes=2).to(device)

# Data Transformations with Augmentation
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load Datasets
train_dataset = datasets.ImageFolder("C:/Users/Vikram/DFDC/data/final/train", transform=transform)
val_dataset = datasets.ImageFolder("C:/Users/Vikram/DFDC/data/final/val", transform=transform)

# Use DataLoader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

# Handle Class Imbalance with Weighted Loss
num_classes = 2  # Number of classes (real and fake)
train_counts = Counter([label for _, label in train_dataset])

# Ensure all classes are represented
class_sample_counts = [train_counts.get(cls, 0) for cls in range(num_classes)]

# Avoid division by zero
total_samples = sum(class_sample_counts)
class_weights = [total_samples / (count if count > 0 else 1) for count in class_sample_counts]

# Convert to tensor and move to the device
class_weights = torch.tensor(class_weights).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training Loop with Metrics Logging
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    # Initialize lists to store metrics
    train_losses, val_losses = [], []
    train_accuracies = []  # Correctly initialized
    val_accuracies = []    # Correctly initialized

    best_val_loss = float("inf")  # Initialize the best validation loss
    best_epoch = 0  # Initialize the best epoch

    for epoch in range(num_epochs):
        # Training Phase
        model.train()
        train_loss, train_acc = 0, 0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]"):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            train_acc += torch.sum(preds == labels).item()

        # Calculate average training loss and accuracy
        avg_train_loss = train_loss / len(train_loader)
        avg_train_acc = train_acc / len(train_loader.dataset)
        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)

        # Validation Phase
        model.eval()
        val_loss, val_acc = 0, 0
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]"):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                val_acc += torch.sum(preds == labels).item()

        # Calculate average validation loss and accuracy
        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader.dataset)
        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)

        # Print Metrics
        print(f"Epoch {epoch+1}/{num_epochs}: "
              f"Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, "
              f"Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}")

        # Save the best model checkpoint
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            torch.save(model.state_dict(), "m2tr_best_model.pth")
            print(f"New Best Model Saved at Epoch {best_epoch} with Val Loss: {best_val_loss:.4f}")

    # Save metrics to a file for plotting later
    with open("metrics.pkl", "wb") as f:
        pickle.dump((train_losses, train_accuracies, val_losses, val_accuracies), f)
        print("Metrics saved to 'metrics.pkl'")

    return train_losses, train_accuracies, val_losses, val_accuracies

# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)


Using device: cpu


Epoch 1/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.58s/it]
Epoch 1/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.53s/it]


Epoch 1/10: Train Loss: 0.7977, Train Acc: 0.6000, Val Loss: 0.4957, Val Acc: 0.8889
New Best Model Saved at Epoch 1 with Val Loss: 0.4957


Epoch 2/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.72s/it]
Epoch 2/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.78s/it]


Epoch 2/10: Train Loss: 0.7285, Train Acc: 0.6000, Val Loss: 0.4892, Val Acc: 1.0000
New Best Model Saved at Epoch 2 with Val Loss: 0.4892


Epoch 3/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.96s/it]
Epoch 3/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.61s/it]


Epoch 3/10: Train Loss: 0.6750, Train Acc: 0.7000, Val Loss: 0.5270, Val Acc: 1.0000


Epoch 4/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.73s/it]
Epoch 4/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.55s/it]


Epoch 4/10: Train Loss: 0.6953, Train Acc: 0.5000, Val Loss: 0.5476, Val Acc: 0.8889


Epoch 5/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.76s/it]
Epoch 5/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.33s/it]


Epoch 5/10: Train Loss: 0.5658, Train Acc: 0.9000, Val Loss: 0.5804, Val Acc: 0.7778


Epoch 6/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.29s/it]
Epoch 6/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.33s/it]


Epoch 6/10: Train Loss: 0.5626, Train Acc: 0.8000, Val Loss: 0.5975, Val Acc: 0.6667


Epoch 7/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.40s/it]
Epoch 7/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.62s/it]


Epoch 7/10: Train Loss: 0.4543, Train Acc: 0.8000, Val Loss: 0.7669, Val Acc: 0.2222


Epoch 8/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.88s/it]
Epoch 8/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.63s/it]


Epoch 8/10: Train Loss: 0.4522, Train Acc: 0.8000, Val Loss: 0.7577, Val Acc: 0.3333


Epoch 9/10 [Training]: 100%|█████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.66s/it]
Epoch 9/10 [Validation]: 100%|███████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.95s/it]


Epoch 9/10: Train Loss: 0.3772, Train Acc: 1.0000, Val Loss: 0.7434, Val Acc: 0.3333


Epoch 10/10 [Training]: 100%|████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.28s/it]
Epoch 10/10 [Validation]: 100%|██████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.65s/it]

Epoch 10/10: Train Loss: 0.3630, Train Acc: 1.0000, Val Loss: 0.8273, Val Acc: 0.4444
Metrics saved to 'metrics.pkl'





([0.7977234125137329,
  0.728481650352478,
  0.6749986410140991,
  0.6953437328338623,
  0.5657631754875183,
  0.5626345276832581,
  0.4542670249938965,
  0.4522484838962555,
  0.3771844208240509,
  0.3629794716835022],
 [0.6, 0.6, 0.7, 0.5, 0.9, 0.8, 0.8, 0.8, 1.0, 1.0],
 [0.49570000171661377,
  0.4892173409461975,
  0.5270419120788574,
  0.5475667119026184,
  0.5804303288459778,
  0.5975461006164551,
  0.7669001221656799,
  0.7577447295188904,
  0.7434298992156982,
  0.8272726535797119],
 [0.8888888888888888,
  1.0,
  1.0,
  0.8888888888888888,
  0.7777777777777778,
  0.6666666666666666,
  0.2222222222222222,
  0.3333333333333333,
  0.3333333333333333,
  0.4444444444444444])