In [1]:
%%writefile ddp.py

import warnings
warnings.simplefilter("ignore", UserWarning)

import os
import sys
import tempfile
import torch
import json
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import pandas as pd
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from PIL import Image, ImageDraw
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import matplotlib.pyplot as plt
import numpy as np
import random
import torch.nn.functional as F
from tqdm import tqdm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

# #-------------------------------------------------------------------------------------------
# Initialize the distributed process
def setup_ddp(rank, world_size):
    """Join the default process group and bind this process to its GPU.

    Args:
        rank: Index of the current process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.

    NOTE(review): the "gloo" backend is used although the model lives on CUDA
    devices -- confirm this is intentional rather than "nccl".
    """
    dist.init_process_group(backend="gloo", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

# Cleanup after training
def cleanup_ddp():
    """Destroy the default process group if one is currently active.

    The function used to be defined twice with identical bodies; the single
    definition below replaces both. Calling it when no process group has been
    initialised (for example after a failed setup, or a second time) is a
    no-op instead of an error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

# #-------------------------------------------------------------------------------------------
def apply_gridmask(image, grid_unit_size_range=(96, 140), keep_ratio=0.6):
    """
    Applies GridMask to the image by blacking out a regular grid of squares.
    Args:
        image: A PIL Image to augment. Single-channel and multi-channel images
            are both supported (the mask is broadcast over the channel axis).
        grid_unit_size_range: Tuple (min, max), inclusive, from which the grid
            period `d` (in pixels) is drawn.
        keep_ratio: Fraction of the grid period used as the side length of each
            blacked-out square, i.e. squares of side int(d * keep_ratio) are
            set to zero. (The name is historical: a larger value masks MORE
            of the image, not less.)
    Returns:
        A new PIL Image with the masked pixels set to 0.
    """
    width, height = image.size

    # Random draws happen in the same order as before (period, x offset,
    # y offset) so seeded runs keep producing the same masks.
    unit = random.randint(*grid_unit_size_range)
    hole = int(unit * keep_ratio)  # Side length of each zeroed square
    offset_x = random.randint(0, unit - 1)
    offset_y = random.randint(0, unit - 1)

    # A pixel is zeroed when both its row and its column fall inside the first
    # `hole` pixels of a grid period that starts at the random offset.
    rows = np.arange(height)
    cols = np.arange(width)
    row_in_hole = (rows >= offset_y) & ((rows - offset_y) % unit < hole)
    col_in_hole = (cols >= offset_x) & ((cols - offset_x) % unit < hole)
    keep = (~(row_in_hole[:, np.newaxis] & col_in_hole[np.newaxis, :])).astype(np.uint8)

    pixels = np.array(image, dtype=np.uint8)
    if pixels.ndim == 3:
        # Broadcast the 2-D mask over however many channels the image has
        keep = keep[:, :, np.newaxis]

    return Image.fromarray(pixels * keep)

def apply_black_patches(image, facial_data, patch_size_range=(30, 50)):
    """
    Applies black patches on randomly selected facial features.
    Args:
        image: A PIL Image to augment. It is drawn on IN PLACE and also returned.
        facial_data: Dictionary with a 'keypoints' entry mapping feature names
            to (x, y) coordinates. May be None/empty, or lack keypoints.
        patch_size_range: Tuple (min, max), inclusive, for the patch side length.
    Returns:
        The same PIL Image object, with black patches applied. The image is
        returned untouched when there are no keypoints to cover.
    """
    keypoints = facial_data.get('keypoints') if facial_data else None
    if not keypoints:
        # No face data, no 'keypoints' entry, or an empty one: nothing to mask.
        # (An empty mapping used to crash in random.randint(1, 0).)
        return image

    draw = ImageDraw.Draw(image)

    # Randomly select one or more features to blackout
    feature_names = list(keypoints.keys())
    selected = random.sample(feature_names, random.randint(1, len(feature_names)))

    for feature in selected:
        x, y = keypoints[feature]

        # Determine patch size
        half = random.randint(*patch_size_range) // 2

        # Patch boundaries, clipped to the image
        left = max(0, x - half)
        top = max(0, y - half)
        right = min(image.width, x + half)
        bottom = min(image.height, y + half)

        # A keypoint far outside the image yields an inverted box after
        # clipping; recent Pillow versions raise on such boxes, so skip them.
        if right < left or bottom < top:
            continue

        draw.rectangle([left, top, right, bottom], fill=(0, 0, 0))

    return image

def get_grid_dropout_prob(epoch, max_epochs, initial_prob=0.2, final_prob=0.8):
    """Linearly interpolate the augmentation probability over training.

    Args:
        epoch: Current epoch index (0-based in the training loop).
        max_epochs: Number of epochs the ramp spans.
        initial_prob: Probability returned at epoch 0.
        final_prob: Probability reached when epoch == max_epochs.
    Returns:
        initial_prob + (final_prob - initial_prob) * epoch / max_epochs, with
        the progress fraction clamped to [0, 1]. When max_epochs is not
        positive there is no ramp and initial_prob is returned.
    """
    if max_epochs <= 0:
        # Avoid ZeroDivisionError for a degenerate schedule
        return initial_prob
    progress = min(max(epoch / max_epochs, 0.0), 1.0)
    return initial_prob + (final_prob - initial_prob) * progress

# Define Custom Dataset class
class CustomDataset(torch.utils.data.Dataset):
    """Image dataset backed by a dataframe, with optional occlusion augmentation."""

    def __init__(self, dataframe, transform=None, grid_dropout_prob=0.2, face_data_train_fake=None, face_data_train_real=None):
        """
        Args:
            dataframe: Pandas dataframe with 'image_path' and 'label' columns.
            transform: Optional Albumentations-style callable, invoked as
                transform(image=ndarray) and expected to return a dict with
                an 'image' entry.
            grid_dropout_prob: Probability of applying GridMask or black patches to an image.
            face_data_train_fake: Dictionary mapping image paths to facial feature
                data, consulted for rows with label 0. Occlusion augmentation is
                only active when this dictionary is non-empty.
            face_data_train_real: Dictionary mapping image paths to facial feature
                data, consulted for rows with label 1. May be None.
        """
        self.dataframe = dataframe
        self.transform = transform
        self.grid_dropout_prob = grid_dropout_prob
        self.face_data_train_fake = face_data_train_fake
        self.face_data_train_real = face_data_train_real

    def update_grid_dropout_prob(self, new_prob):
        """Replace the probability used for the occlusion augmentations."""
        self.grid_dropout_prob = new_prob

    def __len__(self):
        return len(self.dataframe)

    def _lookup_facial_data(self, image_path, label):
        """Return the facial data stored for `image_path`, or None.

        Label 0 is looked up in face_data_train_fake and label 1 in
        face_data_train_real. A missing dictionary (None) or a missing path
        yields None instead of raising.
        """
        if label == 0:
            source = self.face_data_train_fake
        elif label == 1:
            source = self.face_data_train_real
        else:
            return None
        if not source or image_path not in source:
            return None
        return source[image_path]

    def __getitem__(self, idx):
        """Load one sample and return (image, label).

        The image is returned as whatever `transform` produces, or as a PIL
        Image when no transform is configured.
        """
        row = self.dataframe.iloc[idx]
        image_path = row['image_path']
        label = row['label']
        image = Image.open(image_path).convert("RGB")

        # Occlusion augmentation is gated on the fake-face dictionary being present
        if self.face_data_train_fake:
            facial_data = self._lookup_facial_data(image_path, label)

            # Apply augmentation with probability
            if random.random() < self.grid_dropout_prob:
                if random.random() < 0.5:
                    image = apply_gridmask(image)
                else:
                    # No-op when no facial data is known for this image
                    image = apply_black_patches(image, facial_data)

        if self.transform:
            # Convert PIL Image to NumPy array for Albumentations
            image = np.array(image)

            # Apply Albumentations transformations
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label

# #-------------------------------------------------------------------------------------------
class FeatureExtractor(nn.Module):
    """EfficientNet-B4 convolutional trunk that can stop at an intermediate block."""

    def __init__(self, freeze_backbone = True):
        """Load the pretrained backbone and optionally freeze all of its weights."""
        super().__init__()
        backbone = models.efficientnet_b4(pretrained=True)
        if freeze_backbone:
            for weight in backbone.parameters():
                weight.requires_grad = False
        self.features = backbone.features

    def forward(self, x, target_block=None):
        """Return the activations of block `target_block` (block 8 when None).

        Raises:
            KeyError: if `self.features` has no block with the requested index.
        """
        # Fast path: block 8 is obtained by running the trunk end to end
        if target_block == 8:
            return self.features(x)

        stop_at = 8 if target_block is None else target_block
        for position, block in enumerate(self.features):
            x = block(x)
            if position == stop_at:
                return x
        raise KeyError(stop_at)

class Classifier(nn.Module):
    """Convolutional head mapping a 1792-channel feature map to one probability.

    Three conv/batch-norm/ReLU/max-pool/dropout stages reduce the channels
    1792 -> 1024 -> 512 -> 256 and halve the spatial size each time. The
    result is globally average-pooled to 1x1, so feature maps of any spatial
    size (at least 8x8) are accepted. For the 8x8 maps the head was originally
    sized for, the pooled map is already 1x1 and the global pool is an exact
    identity. The pool has no parameters, so existing state dicts load unchanged.
    """

    def __init__(self):
        super().__init__()
        # First Convolutional Block
        self.conv1 = nn.Conv2d(1792, 1024, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout1 = nn.Dropout(0.25)

        # Second Convolutional Block
        self.conv2 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(512)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout2 = nn.Dropout(0.25)

        # Third Convolutional Block
        self.conv3 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout3 = nn.Dropout(0.25)

        # Collapses whatever spatial extent is left so fc1 always sees 256 values
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Fully Connected Layers
        self.fc1 = nn.Linear(256, 128)
        self.dropout_fc = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)  # Binary classification

    def forward(self, x):
        """Return sigmoid probabilities of shape (batch, 1)."""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.dropout1(x)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.dropout2(x)
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.dropout3(x)
        # Global pool + flatten -> (batch, 256)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        # Fully Connected Layers
        x = F.relu(self.fc1(x))
        x = self.dropout_fc(x)
        x = torch.sigmoid(self.fc2(x))  # Sigmoid for binary classification
        return x

class DeepFakeClassifier(nn.Module):
    """Backbone feature extractor followed by the convolutional classifier head."""

    def __init__(self, texture_layer, freeze_backbone):
        """
        Args:
            texture_layer: Block index handed to the feature extractor on every
                forward pass to select which activations feed the head.
            freeze_backbone: Forwarded to FeatureExtractor.
        """
        super().__init__()
        self.texture_layer = texture_layer
        self.feature_extractor = FeatureExtractor(freeze_backbone)
        self.classifier = Classifier()

    def forward(self, x):
        """Extract features at `texture_layer` and classify them."""
        features = self.feature_extractor(x, self.texture_layer)
        return self.classifier(features)

# #-------------------------------------------------------------------------------------------
class CustomCriterion:
    """Small wrapper exposing binary cross-entropy through `compute_loss`."""

    def __init__(self):
        # BCELoss operates on probabilities (outputs already passed through a sigmoid)
        self.bce_loss = nn.BCELoss()

    def compute_loss(self, outputs, labels):
        """Return the binary cross-entropy between `outputs` and `labels`."""
        loss_fn = self.bce_loss
        return loss_fn(outputs, labels)

def train_one_epoch(model, dataloader, optimizer, criterion, device, epoch):
    """Run one optimisation pass over `dataloader`.

    Args:
        model: Module producing one probability per sample.
        dataloader: Iterable of (images, labels) batches.
        optimizer: Optimizer stepping the model parameters.
        criterion: Callable loss taking (outputs, labels) as float tensors.
        device: Device the batches are moved to.
        epoch: Epoch index, used only for the progress-bar caption.
    Returns:
        (mean loss per sample, accuracy at a 0.5 threshold) over the batches
        seen by this process, or (0.0, 0.0) if the loader yielded nothing.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader, desc=f"Training Epoch {epoch}"):
        images, labels = images.to(device), labels.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        # reshape(-1) rather than squeeze(): squeeze() turns a batch of one
        # into a 0-d tensor, which no longer matches labels of shape (1,)
        # and makes the loss raise.
        outputs = model(images).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size to get a per-sample mean
        total_loss += loss.item() * images.size(0)
        preds = (outputs > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    if total == 0:
        # Empty loader: avoid ZeroDivisionError
        return 0.0, 0.0
    return total_loss / total, correct / total


def validate_one_epoch(model, dataloader, criterion, device, epoch):
    """Evaluate `model` on `dataloader` without updating any weights.

    Args:
        model: Module producing one probability per sample.
        dataloader: Iterable of (images, labels) batches.
        criterion: Callable loss taking (outputs, labels) as float tensors.
        device: Device the batches are moved to.
        epoch: Epoch index, used only for the progress-bar caption.
    Returns:
        (mean loss per sample, accuracy at a 0.5 threshold) over the batches
        seen by this process, or (0.0, 0.0) if the loader yielded nothing.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"Validation Epoch {epoch}"):
            images, labels = images.to(device), labels.to(device, dtype=torch.float32)

            # reshape(-1) rather than squeeze(): squeeze() turns a batch of
            # one into a 0-d tensor, which no longer matches labels of shape
            # (1,) and makes the loss raise.
            outputs = model(images).reshape(-1)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by batch size to get a per-sample mean
            total_loss += loss.item() * images.size(0)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Empty loader: avoid ZeroDivisionError
        return 0.0, 0.0
    return total_loss / total, correct / total

# #-------------------------------------------------------------------------------------------
def create_df_from_dir(directory, label, train_val_test, split):
    """Build a dataframe with one row per entry found in `directory`.

    Every row carries the same `label`, `train_val_test` and `split` values;
    the 'image_path' column holds `directory` joined with each entry name, in
    the order returned by os.listdir.
    """
    image_paths = list(map(lambda name: os.path.join(directory, name), os.listdir(directory)))
    row_count = len(image_paths)

    columns = {"image_path": image_paths}
    constant_columns = (
        ("label", label),
        ("train_val_test", train_val_test),
        ("split", split),
    )
    for column_name, value in constant_columns:
        columns[column_name] = [value] * row_count

    return pd.DataFrame(columns)

def save_to_csv(df, filename, root_path='/kaggle/working/'):
    """Write `df` to `root_path/filename` as CSV without the index column.

    Args:
        df: DataFrame to persist.
        filename: Name of the CSV file to create.
        root_path: Directory the file is written to. The function used to read
            an undefined module-level `root_path` (a NameError on any call);
            it is now a parameter that defaults to the directory this script
            writes its other outputs to.
    """
    file_path = os.path.join(root_path, filename)
    df.to_csv(file_path, index=False)
    print(f"DataFrame saved to {file_path}")

# #-------------------------------------------------------------------------------------------
# Main function for distributed training
def _average_across_ranks(values, world_size, device):
    """Return each per-process scalar in `values` averaged over all processes."""
    stacked = torch.tensor(values, dtype=torch.float64, device=device)
    dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
    return (stacked / world_size).tolist()


def _save_metric_plot(epochs_range, values, label, title, ylabel, path, description):
    """Plot one metric curve, save it to `path`, and release the figure."""
    fig = plt.figure()
    plt.plot(epochs_range, values, label=label)
    plt.title(title)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(path)
    plt.close(fig)  # figures are not garbage-collected until closed
    print(f"{description} diagram saved to {path}")


def main(rank, world_size, train_dataset, val_dataset, split):
    """Per-process entry point for distributed training.

    Args:
        rank: Index of this process; also the CUDA device it trains on.
        world_size: Total number of processes.
        train_dataset: Dataset exposing `grid_dropout_prob` and
            `update_grid_dropout_prob`, sharded across processes.
        val_dataset: Dataset used for validation, sharded across processes.
        split: Identifier embedded in checkpoint and plot file names.

    Side effects: rank 0 writes one checkpoint per epoch and four metric plots
    to /kaggle/working/. If a checkpoint for `split` already exists there, the
    lexicographically last one is loaded before training starts.
    """
    batch_size = 32
    epochs = 3
    texture_layer = 8
    learning_rate = 0.001
    weight_decay = 1e-5
    freeze_backbone = False
    output_dir = '/kaggle/working/'

    setup_ddp(rank, world_size)
    try:
        device = torch.device(f"cuda:{rank}")

        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, num_replicas=world_size, rank=rank)

        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
        val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler)

        model = DeepFakeClassifier(texture_layer, freeze_backbone).to(device)

        prefix = f'model_checkpoint_{split}'
        # Match on the full '<prefix>_epoch_' stem so that e.g. split 1 does
        # not pick up checkpoints written for a split named 10.
        model_checkpoints = sorted(
            f for f in os.listdir(output_dir) if f.startswith(prefix + '_epoch_')
        )
        if len(model_checkpoints) > 0:
            # os.listdir returns bare names; join with the directory so the
            # load does not depend on the current working directory.
            checkpoint = os.path.join(output_dir, model_checkpoints[-1])
            print(f"Checkpoint found at {checkpoint}. Loading checkpoint: {rank}...")
            model.load_state_dict(torch.load(checkpoint, map_location=device))
            print("Checkpoint loaded successfully.")
        else:
            print(f"No checkpoint found. Starting training from scratch: {rank}")

        model = DDP(model, device_ids=[rank])

        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        for epoch in range(epochs):
            # Ramp up the occlusion-augmentation probability as training progresses
            train_dataset.update_grid_dropout_prob(get_grid_dropout_prob(epoch, epochs))
            train_sampler.set_epoch(epoch)
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
            val_loss, val_acc = validate_one_epoch(model, val_loader, criterion, device, epoch)

            # Each process only saw its own shard; average so the logged
            # numbers describe the whole dataset rather than rank 0's share.
            # Every rank must take part in this collective call.
            train_loss, train_acc, val_loss, val_acc = _average_across_ranks(
                [train_loss, train_acc, val_loss, val_acc], world_size, device
            )

            if rank == 0:  # Log and checkpoint only from the main process
                print(f"Epoch {epoch + 1}/{epochs}:")
                print(f"\tTrain Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
                print(f"\tVal Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

                train_losses.append(train_loss)
                train_accuracies.append(train_acc)
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)

                final_model_path = os.path.join(
                    output_dir,
                    f'{prefix}_epoch_{epoch:02d}_acc_{str(val_acc * 10000)[:4]}.pt',
                )
                torch.save(model.module.state_dict(), final_model_path)

        if rank == 0:  # Plot and save the metric curves once training is done
            epochs_range = range(1, epochs + 1)
            _save_metric_plot(
                epochs_range, train_losses, 'Train Loss', 'Training Loss Over Epochs', 'Loss',
                os.path.join(output_dir, f'training_loss_plot-{split}.png'), 'Training loss',
            )
            _save_metric_plot(
                epochs_range, train_accuracies, 'Train Accuracy', 'Training Accuracy Over Epochs', 'Accuracy',
                os.path.join(output_dir, f'training_accuracy_plot-{split}.png'), 'Training accuracy',
            )
            _save_metric_plot(
                epochs_range, val_losses, 'Validation Loss', 'Validation Loss Over Epochs', 'Loss',
                os.path.join(output_dir, f'validation_loss_plot-{split}.png'), 'Validation loss',
            )
            _save_metric_plot(
                epochs_range, val_accuracies, 'Validation Accuracy', 'Validation Accuracy Over Epochs', 'Accuracy',
                os.path.join(output_dir, f'validation_accuracy_plot-{split}.png'), 'Validation accuracy',
            )
    finally:
        # Always release the process group, even when training fails
        cleanup_ddp()

# #-------------------------------------------------------------------------------------------
if __name__ == "__main__":
    # Usage: python ddp.py SPLIT_INDEX
    # SPLIT_INDEX selects one of the training sets built below; index i pairs
    # the fake images of dataframe split i + 1 with the rows of split 0.
    argv = sys.argv[1:]
    num_splits = 5
    try:
        split_index = int(argv[0])
    except (IndexError, ValueError):
        sys.exit(f"Usage: python {sys.argv[0]} SPLIT_INDEX (integer in 0-{num_splits - 1})")
    if not 0 <= split_index < num_splits:
        sys.exit(f"SPLIT_INDEX must be in 0-{num_splits - 1}, got {argv[0]}")

    world_size = torch.cuda.device_count()
    if world_size < 1:
        # main() pins every process to cuda:<rank>, so there is nothing to spawn
        sys.exit("No CUDA devices found; this script needs at least one GPU.")
    print("\nTraining model on split", argv[0])

    transform = A.Compose([
        A.ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.1),
        A.GaussianBlur(blur_limit=3, p=0.05),
        A.HorizontalFlip(p=0.5),
        A.OneOf([
            A.Resize(256, 256, interpolation=cv2.INTER_AREA),
            A.Resize(256, 256, interpolation=cv2.INTER_LINEAR),
        ], p=1),
        A.PadIfNeeded(min_height=256, min_width=256, border_mode=cv2.BORDER_CONSTANT, value=0),  # Added `value=0`
        A.OneOf([
            A.RandomBrightnessContrast(),
            A.HueSaturationValue(),
        ], p=0.7),
        A.ToGray(p=0.1),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, border_mode=cv2.BORDER_CONSTANT, value=0),  # Added `value=0`
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    # NOTE(review): validation images are not resized here -- presumably they
    # are already stored at the resolution the model expects; confirm.
    transform_val = A.Compose([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    with open('/kaggle/input/face-data/faces_data_train_fake_new.json', 'r') as file:
        face_data_train_fake = json.load(file)

    with open('/kaggle/input/face-data/faces_data_train_real_new.json', 'r') as file:
        face_data_train_real = json.load(file)

    df_path = '/kaggle/input/face-data/image_paths_dataframe.csv'
    final_df = pd.read_csv(df_path)
    is_train = final_df['train_val_test'] == 'train'
    train_datasets = []
    for split in range(1, num_splits + 1):
        fake_split_df = final_df[is_train & (final_df['split'] == split)]
        real_split_df = final_df[is_train & (final_df['split'] == 0)]
        # Shuffle once in the parent; the spawned workers receive this same order
        combined_df = pd.concat([fake_split_df, real_split_df]).sample(frac=1).reset_index(drop=True)
        split_dataset = CustomDataset(combined_df, transform=transform, face_data_train_fake=face_data_train_fake, face_data_train_real=face_data_train_real)
        train_datasets.append(split_dataset)

    val_df = final_df[(final_df['train_val_test'] == 'val')]
    val_dataset = CustomDataset(val_df, transform=transform_val)

    # Rendezvous address for the process group created in setup_ddp
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "3242"

    # spawn passes the process index as the first positional argument (rank)
    torch.multiprocessing.spawn(
        main,
        args=(
            world_size,
            train_datasets[split_index],
            val_dataset,
            argv[0]
        ),
        nprocs=world_size,
        join=True,
    )


Writing ddp.py


In [2]:
!python ddp.py 0


Training model on split 0
Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-23ab8bcd.pth
Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-23ab8bcd.pth
100%|███████████████████████████████████████| 74.5M/74.5M [00:00<00:00, 148MB/s]
100%|███████████████████████████████████████| 74.5M/74.5M [00:00<00:00, 131MB/s]
No checkpoint found. Starting training from scratch: 0
No checkpoint found. Starting training from scratch: 1
Training Epoch 0: 100%|█████████████████████| 1335/1335 [34:08<00:00,  1.53s/it]
Training Epoch 0: 100%|█████████████████████| 1335/1335 [34:08<00:00,  1.53s/it]
Validation Epoch 0: 100%|███████████████████████| 32/32 [00:13<00:00,  2.45it/s]
Validation Epoch 0: 100%|███████████████████████| 32/32 [00:13<00:00,  2.36it/s]
Epoch 1/3:
	Train Loss: 0.3689, Train Acc: 