In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import f1_score, precision_score, recall_score
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import transforms, datasets
from PIL import Image
from tqdm import tqdm

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
# zip the files in colab on CPU
# !cd /content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/
# !zip -r /content/drive/MyDrive/Deepfake_data.zip data


In [3]:
#  unzip directly into local directory (on A100)
# !unzip /content/drive/MyDrive/Deepfake_data.zip -d /content/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0243.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0244.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0245.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0246.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0247.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0248.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0249.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data/fake/564_971.mp4_0250.jpg  
  inflating: /content/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepf

In [None]:
!mv /content/local_data/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data /content/fixed_data

# Torch Model

## Training modules

In [4]:
## After splitting, we need a way to load images + labels from our lists of X and y — that’s where FrameDataset comes in.

class FrameDataset(Dataset):
    def __init__(self, image_paths, labels, transform = None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self,idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)
        return image, label

In [5]:
# Model
class DeepF_CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),  # padding='same' → padding=1 when kernel=3
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32),  # because output channels = 32
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=0.2),

            # Block 2
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=64),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=0.3),

            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.AdaptiveAvgPool2d((1, 1)),  # → output shape [batch, 128, 1, 1]
            nn.Flatten(),                  # → [batch, 128]

            nn.Linear(128, 1),
            nn.Sigmoid()
            )
    def forward(self, x):
        return self.net(x)

In [6]:
#Create ES function
class EarlyStopping:
    def __init__(self, patience=4):
        self.patience = patience
        self.counter = 0
        self.best_loss = float('inf')
        self.best_model = None
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_model = model.state_dict() # saves best weight
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

In [17]:
from torch.amp import autocast, GradScaler

def train_one_fold(X_train, y_train, X_val, y_val, transform, device):
    scaler = GradScaler()

    # Datasets
    train_dataset = FrameDataset(X_train, y_train, transform)
    val_dataset   = FrameDataset(X_val, y_val, transform)


    # Weighted sampler
    class_counts = np.bincount(y[train_idx])
    weights = 1. / class_counts
    sample_weights = weights[y[train_idx]]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=128, sampler=sampler,
                              num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False,
                            num_workers=2, pin_memory=True)

    # Model, loss, optimizer
    model = DeepF_CNN().to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    early_stopper = EarlyStopping(patience=5)

    for epoch in range(20):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images = images.to(device, non_blocking=True)
            labels = labels.float().unsqueeze(1).to(device, non_blocking=True)

            optimizer.zero_grad()
            with autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * labels.size(0)
            preds = torch.sigmoid(outputs) >= 0.5
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_loss = running_loss / total
        train_acc = correct / total

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        all_preds, all_labels = [], []

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.float().unsqueeze(1).to(device, non_blocking=True)

                with autocast(device_type='cuda'):
                    outputs = model(images)
                    loss = criterion(outputs, labels)

                val_loss += loss.item() * labels.size(0)
                preds = torch.sigmoid(outputs) >= 0.5
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss /= val_total
        val_acc = val_correct / val_total
        f1 = f1_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds)

        print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.4f} | F1: {f1:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f}")

        scheduler.step(val_loss)
        early_stopper(val_loss, model)

        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    model.load_state_dict(early_stopper.best_model)
    torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'scaler_state_dict': scaler.state_dict(),  # AMP scaler
}, f'model_fold{fold_idx + 1}.pth')

    return model, all_preds, all_labels


# Prepare data

In [8]:
from collections import Counter

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42) # 3 fold TOO SLOW ON MY LAPTOP

# data = datasets.ImageFolder(root='/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data', transform=transform)
data = datasets.ImageFolder(root='/content/local_data/content/drive/MyDrive/Colab_Notebooks/GAIDI/Deepfake/data', transform=transform) #  <== after copying the dataset to colab local disk with unzip
class_counts = Counter(data.targets)

print(f'data has two classes: {data.classes}, there are {len(data)} images(frames) in data, {class_counts[1]} real video frames, {class_counts[0]} fake video frames')

if ((class_counts[0] * 100) / class_counts[1]) < 45 or ((class_counts[0] * 100) / class_counts[1]) > 55:
    print('classes weights are imbalanced, WeightedRandomSampler is required')
else:
    print('classes weights are balanced, no WeightedRandomSampler required.')

Using device: cuda
data has two classes: ['fake', 'real'], there are 19061 images(frames) in data, 8332 real video frames, 10729 fake video frames
classes weights are imbalanced, WeightedRandomSampler is required


In [9]:
X = np.array([s[0] for s in data.samples])
y = np.array([s[1] for s in data.samples])

In [14]:
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, stratify=y)

In [15]:
print(f'{X_train.shape}\n{ X_val.shape}\n{y_train.shape}\n{y_val.shape}')

(15248,)
(3813,)
(15248,)
(3813,)


# Train

In [18]:
model = train_one_fold(X_train, y_train, X_val, y_val, transform, device)

Epoch 1: 100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


Epoch 1 | Train Loss: 0.7033 | Val Loss: 0.6777 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 2: 100%|██████████| 100/100 [01:11<00:00,  1.41it/s]


Epoch 2 | Train Loss: 0.6606 | Val Loss: 0.6541 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 3: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]


Epoch 3 | Train Loss: 0.6419 | Val Loss: 0.6441 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 4: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


Epoch 4 | Train Loss: 0.6283 | Val Loss: 0.6488 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 5: 100%|██████████| 100/100 [01:11<00:00,  1.40it/s]


Epoch 5 | Train Loss: 0.6220 | Val Loss: 0.6418 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 6: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


Epoch 6 | Train Loss: 0.6191 | Val Loss: 0.6282 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 7: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


Epoch 7 | Train Loss: 0.6153 | Val Loss: 0.6294 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 8: 100%|██████████| 100/100 [01:10<00:00,  1.43it/s]


Epoch 8 | Train Loss: 0.6145 | Val Loss: 0.6141 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 9: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


Epoch 9 | Train Loss: 0.6030 | Val Loss: 0.6379 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 10: 100%|██████████| 100/100 [01:09<00:00,  1.43it/s]


Epoch 10 | Train Loss: 0.5955 | Val Loss: 0.6025 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 11: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


Epoch 11 | Train Loss: 0.5897 | Val Loss: 0.5956 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 12: 100%|██████████| 100/100 [01:10<00:00,  1.43it/s]


Epoch 12 | Train Loss: 0.5876 | Val Loss: 0.6029 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 13: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


Epoch 13 | Train Loss: 0.5826 | Val Loss: 0.6012 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 14: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


Epoch 14 | Train Loss: 0.5754 | Val Loss: 0.6195 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 15: 100%|██████████| 100/100 [01:10<00:00,  1.43it/s]


Epoch 15 | Train Loss: 0.5660 | Val Loss: 0.5958 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 16: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]


Epoch 16 | Train Loss: 0.5650 | Val Loss: 0.5736 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 17: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


Epoch 17 | Train Loss: 0.5604 | Val Loss: 0.5904 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 18: 100%|██████████| 100/100 [01:09<00:00,  1.43it/s]


Epoch 18 | Train Loss: 0.5585 | Val Loss: 0.5719 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 19: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]


Epoch 19 | Train Loss: 0.5549 | Val Loss: 0.5768 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


Epoch 20: 100%|██████████| 100/100 [01:09<00:00,  1.43it/s]


Epoch 20 | Train Loss: 0.5569 | Val Loss: 0.5761 | Val Acc: 0.4372 | F1: 0.6084 | Precision: 0.4372 | Recall: 1.0000


In [20]:
print("Val labels distribution:", Counter(all_labels))
print("Val preds distribution:", Counter(all_preds))
print("Unique preds:", np.unique(all_preds)) #is it only predicting 0, or 1?

NameError: name 'all_labels' is not defined

## save model

In [21]:
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
}, 'model_fold1.pth')         # inside training function


NameError: name 'optimizer' is not defined

## Load model

In [None]:
checkpoint = torch.load('model_fold1.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
scaler.load_state_dict(checkpoint['scaler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
## for epoch in range(start_epoch, start_epoch + n_more_epochs):