In [2]:
import os
import torch
import torch.nn as nn
import torch.fft
import math
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
import shutil
import random
from torchvision.utils import make_grid
import torch.nn.functional as F
from tqdm.notebook import tqdm
import cv2

In [3]:
batch_size = 16
epochs = 10
learning_rate = 0.001
weight_decay = 1e-5
freeze_backbone = True
img_size = 256
depth = 4
dropout = 0.5

In [4]:
# Define data paths
root_path = '/kaggle/working'
data_dir = "/kaggle/input/deepfake/DFWILD"
train_fake_dir = os.path.join(data_dir, "train_fake", "fake")
train_real_dir = os.path.join(data_dir, "train_real")
test_fake_dir = os.path.join(data_dir, "valid_fake", "fake")
test_real_dir = os.path.join(data_dir, "valid_real", "real")

In [5]:
# Paths to the working directories
working_fake_dir = os.path.join(root_path, "fake")
working_real_dir = os.path.join(root_path, "real")

if not (os.path.exists(working_fake_dir) and os.path.exists(working_real_dir)):
    # Create the working directories if they don't exist
    os.makedirs(working_fake_dir, exist_ok=True)
    os.makedirs(working_real_dir, exist_ok=True)

    # Get the list of all images in train_fake and train_real directories
    fake_images = [f for f in os.listdir(train_fake_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
    real_images = [f for f in os.listdir(train_real_dir) if f.endswith((".png", ".jpg", ".jpeg"))]

    # Determine the number of real images
    num_real_images = len(real_images)

    # Randomly sample a subset of fake images equal to the number of real images
    sampled_fake_images = random.sample(fake_images, num_real_images)

    # Copy the sampled fake images to the working_fake_dir
    for image in tqdm(sampled_fake_images, desc="Copying sampled fake images"):
        src_path = os.path.join(train_fake_dir, image)
        dst_path = os.path.join(working_fake_dir, image)
        shutil.copy(src_path, dst_path)

    # Copy all real images to the working_real_dir
    for image in tqdm(real_images, desc="Copying real images"):
        src_path = os.path.join(train_real_dir, image)
        dst_path = os.path.join(working_real_dir, image)
        shutil.copy(src_path, dst_path)

train_fake_dir = working_fake_dir
train_real_dir = working_real_dir

In [6]:
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

### Without face cropping

In [7]:
# Custom Dataset Loassertader for Real and Fake Images
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, label, transform=None):
        self.image_paths = [os.path.join(image_dir, fname) for fname in os.listdir(image_dir) if fname.endswith(('.png', '.jpg', '.jpeg'))]
        self.label = label
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.label

### With face cropping

In [8]:
# class CustomDataset(torch.utils.data.Dataset):
#     def __init__(self, image_dir, label, transform=None, cascade_path="haarcascade_frontalface_default.xml"):
#         self.image_paths = [os.path.join(image_dir, fname) for fname in os.listdir(image_dir) if fname.endswith(('.png', '.jpg', '.jpeg'))]
#         self.label = label
#         self.transform = transform
#         self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + cascade_path)
#         self.default_image = Image.new('RGB', (224, 224), color='gray')  # Default placeholder image

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         image_path = self.image_paths[idx]
#         # image = Image.open(image_path).convert("RGB")
#         # if self.transform:
#         #     image = self.transform(image)
#         image = cv2.imread(image_path)
#         if image is None:
#             print(f"Error: Unable to read the image file at {image_path}. Returning default image.")
#             return self.default_image, self.label

#         # Remove the background and make it black
#         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#         _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
#         black_background = np.zeros_like(image)
#         image = cv2.bitwise_and(image, image, mask=mask)

#         # Convert the image to grayscale for face detection
#         gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

#         # Detect faces in the image
#         faces = self.face_cascade.detectMultiScale(gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

#         if len(faces) == 0:
#             print(f"No faces detected in the image at {image_path}. Returning default image.")
#             if self.transform:
#                 return self.transform(self.default_image), self.label
#             return self.default_image, self.label

#         # Assume the largest detected face is the main face
#         x, y, w, h = max(faces, key=lambda rect: rect[2] * rect[3])

#         # Crop the face from the image
#         cropped_face = image[y:y+h, x:x+w]

#         # Convert the cropped face to PIL Image
#         cropped_face = cv2.cvtColor(cropped_face, cv2.COLOR_BGR2RGB)
#         cropped_face_image = Image.fromarray(cropped_face)

#         # Apply transformations if any
#         if self.transform:
#             cropped_face_image = self.transform(cropped_face_image)

#         return cropped_face_image, self.label


In [9]:
# Load datasets using CustomDataset for both real and fake images
train_real_dataset = CustomDataset(image_dir=train_real_dir, label=1, transform=transform)
train_fake_dataset = CustomDataset(image_dir=train_fake_dir, label=0, transform=transform)

train_dataset = torch.utils.data.ConcatDataset([train_fake_dataset, train_real_dataset])

# Split train_dataset into training and validation datasets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Create test dataset
test_real_dataset = CustomDataset(image_dir=test_real_dir, label=1, transform=transform)
test_fake_dataset = CustomDataset(image_dir=test_fake_dir, label=0, transform=transform)

test_dataset = torch.utils.data.ConcatDataset([test_fake_dataset, test_real_dataset])

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

In [10]:
# def show_images_from_dataloader(dataloader, grids=3, grid_size=3, title="Dataset Images"):
#     """
#     Display a specified number of grids, each showing a random selection of images from the dataloader.

#     Parameters:
#     - dataloader (DataLoader): The PyTorch DataLoader to sample images from.
#     - grids (int): Number of grids to display.
#     - grid_size (int): Size of the grid (grid_size x grid_size images per grid).
#     - title (str): Title for the grids.
#     """
#     for grid in range(grids):
#         # Sample a batch of images
#         data_iter = iter(dataloader)
#         images, labels = next(data_iter)

#         # Select random images for the grid
#         selected_indices = random.sample(range(len(images)), grid_size * grid_size)
#         selected_images = [images[idx] for idx in selected_indices]

#         # Create a grid of images
#         grid_images = make_grid(selected_images, nrow=grid_size, normalize=True, pad_value=1)

#         # Convert to numpy for display
#         np_grid_images = grid_images.permute(1, 2, 0).cpu().numpy()

#         # Display the grid
#         plt.figure(figsize=(8, 8))
#         plt.imshow(np_grid_images)
#         plt.axis('off')
#         plt.title(f"{title} - Grid {grid + 1}")
#         plt.show()

# # Display 9 random images in 3 grids for train, validation, and test loaders
# print("Train Loader Grids:")
# show_images_from_dataloader(train_loader, grids=3, grid_size=3, title="Train Loader")

# print("Validation Loader Grids:")
# show_images_from_dataloader(val_loader, grids=3, grid_size=3, title="Validation Loader")

# print("Test Loader Grids:")
# show_images_from_dataloader(test_loader, grids=3, grid_size=3, title="Test Loader")


In [11]:
class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).'
            % (type(self).__name__, num_params / 1000000)
        )

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (
                classname.find('Conv') != -1 or classname.find('Linear') != -1
            ):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented'
                        % init_type
                    )
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

In [12]:
class FeedForward2D(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(FeedForward2D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channel, out_channel, kernel_size=3, padding=2, dilation=2
            ),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x

In [13]:
class GlobalFilter(nn.Module):
    def __init__(self, dim=32, h=80, w=41, fp32fft=True):
        super().__init__()
        self.complex_weight = nn.Parameter(
            torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02
        )
        self.w = w
        self.h = h
        self.fp32fft = fp32fft

    def forward(self, x):
        b, _, a, b = x.size()
        x = x.permute(0, 2, 3, 1).contiguous()

        if self.fp32fft:
            dtype = x.dtype
            x = x.to(torch.float32)

        x = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm="ortho")

        if self.fp32fft:
            x = x.to(dtype)

        x = x.permute(0, 3, 1, 2).contiguous()

        return x

In [14]:
class FreqBlock(nn.Module):
    def __init__(self, dim, h=80, w=41, fp32fft=True):
        super().__init__()
        self.filter = GlobalFilter(dim, h=h, w=w, fp32fft=fp32fft)
        self.feed_forward = FeedForward2D(in_channel=dim, out_channel=dim)

    def forward(self, x):
        x = x + self.feed_forward(self.filter(x))
        return x

In [15]:
def attention(query, key, value):
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
        query.size(-1)
    )
    p_attn = F.softmax(scores, dim=-1)
    p_val = torch.matmul(p_attn, value)
    return p_val, p_attn

In [16]:
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0
        )
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.BatchNorm2d(d_model),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        b, c, h, w = x.size()
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        attentions = []
        for (width, height), query, key, value in zip(
            self.patchsize,
            torch.chunk(_query, len(self.patchsize), dim=1),
            torch.chunk(_key, len(self.patchsize), dim=1),
            torch.chunk(_value, len(self.patchsize), dim=1),
        ):
            out_w, out_h = w // width, h // height

            # 1) embedding and reshape
            query = query.view(b, d_k, out_h, height, out_w, width)
            query = (
                query.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            key = key.view(b, d_k, out_h, height, out_w, width)
            key = (
                key.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )
            value = value.view(b, d_k, out_h, height, out_w, width)
            value = (
                value.permute(0, 2, 4, 1, 3, 5)
                .contiguous()
                .view(b, out_h * out_w, d_k * height * width)
            )

            y, _ = attention(query, key, value)

            # 3) "Concat" using a view and apply a final linear.
            y = y.view(b, out_h, out_w, d_k, height, width)
            y = y.permute(0, 3, 1, 4, 2, 5).contiguous().view(b, d_k, h, w)
            attentions.append(y)
            output.append(y)

        output = torch.cat(output, 1)
        self_attention = self.output_linear(output)

        return self_attention

In [17]:
class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, in_channel=256):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=in_channel)
        self.feed_forward = FeedForward2D(
            in_channel=in_channel, out_channel=in_channel
        )

    def forward(self, rgb):
        self_attention = self.attention(rgb)
        output = rgb + self_attention
        output = output + self.feed_forward(output)
        return output

In [18]:
class CMA_Block(nn.Module):
    def __init__(self, in_channel, hidden_channel, out_channel):
        super(CMA_Block, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
        )
        self.conv2 = nn.Conv2d(
            in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
        )
        self.conv3 = nn.Conv2d(
            in_channel, hidden_channel, kernel_size=1, stride=1, padding=0
        )

        self.scale = hidden_channel ** -0.5

        self.conv4 = nn.Sequential(
            nn.Conv2d(
                hidden_channel, out_channel, kernel_size=1, stride=1, padding=0
            ),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, rgb, freq):
        _, _, h, w = rgb.size()

        q = self.conv1(rgb)
        k = self.conv2(freq)
        v = self.conv3(freq)

        q = q.view(q.size(0), q.size(1), q.size(2) * q.size(3)).transpose(
            -2, -1
        )
        k = k.view(k.size(0), k.size(1), k.size(2) * k.size(3))

        attn = torch.matmul(q, k) * self.scale
        m = attn.softmax(dim=-1)

        v = v.view(v.size(0), v.size(1), v.size(2) * v.size(3)).transpose(
            -2, -1
        )
        z = torch.matmul(m, v)
        z = z.view(z.size(0), h, w, -1)
        z = z.permute(0, 3, 1, 2).contiguous()

        output = rgb + self.conv4(z)

        return output

In [19]:
class PatchTrans(BaseNetwork):
    def __init__(self, in_channel, in_size):
        super(PatchTrans, self).__init__()
        self.in_size = in_size

        patchsize = [
            (in_size, in_size),
            (in_size // 2, in_size // 2),
            (in_size // 4, in_size // 4),
            (in_size // 8, in_size // 8),
        ]

        self.t = TransformerBlock(patchsize, in_channel=in_channel)

    def forward(self, enc_feat):
        output = self.t(enc_feat)
        return output

In [20]:
class Classifier2D(nn.Module):
    def __init__(
        self,
        dim_in,
        num_classes,
        dropout_rate=0.0,
        act_func="softmax",
    ):
        super(Classifier2D, self).__init__()
        if dropout_rate > 0.0:
            self.dropout = nn.Dropout(dropout_rate)
        self.projection = nn.Linear(dim_in, num_classes, bias=True)

        self.act = nn.Sigmoid()

    def forward(self, x):
        if hasattr(self, "dropout"):
            x = self.dropout(x)
        x = self.projection(x)
        x = self.act(x)
        return x

In [21]:
class M2TR(BaseNetwork):
    def __init__(self, img_size, depth, drop_ratio):
        super(M2TR, self).__init__()
        img_size = img_size
        depth = depth
        drop_ratio = drop_ratio
        num_classes = 1

        freq_h = img_size // 4
        freq_w = freq_h // 2 + 1

        self.model = models.efficientnet_b4(pretrained=True)

        texture_dim = 32
        feature_dim = 1792

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PatchTrans(in_channel=texture_dim, in_size=freq_h),
                        FreqBlock(dim=texture_dim, h=freq_h, w=freq_w),
                        CMA_Block(
                            in_channel=texture_dim,
                            hidden_channel=texture_dim,
                            out_channel=texture_dim,
                        ),
                    ]
                )
            )

        self.classifier = Classifier2D(
            feature_dim, num_classes, drop_ratio, "sigmoid"
        )

    def forward(self, x):
        rgb = x
        B = rgb.size(0)

        # Set rgb as the output of the second layer of self.model
        rgb = self.model.features[:3](rgb)

        for attn, filter, cma in self.layers:
            rgb = attn(rgb)
            freq = filter(rgb)
            rgb = cma(rgb, freq)

        # Get the last layer number dynamically
        last_layer = len(self.model.features)

        # Feed the current rgb value to the 3rd layer of self.model and get the last layer's output
        for layer_idx in range(3, last_layer):
            rgb = self.model.features[layer_idx](rgb)

        # Adaptive average pooling and reshaping
        features = F.adaptive_avg_pool2d(rgb, (1, 1))
        features = features.view(B, features.size(1))

        # Classification
        output = self.classifier(features)
        return output


In [22]:
class DeepFakeClassifier(nn.Module):
    def __init__(self, img_size, depth, drop_ratio):
        super(DeepFakeClassifier, self).__init__()
        self.classifier = M2TR(img_size, depth, drop_ratio)

    def forward(self, x):
        x = self.classifier(x)
        return x

In [23]:
class CustomCriterion:
    def __init__(self):
        self.bce_loss = nn.BCELoss()

    def compute_loss(self, outputs, labels):
        return self.bce_loss(outputs, labels)

In [24]:
class DeepFakeClassifierTrainer(pl.LightningModule):
    def __init__(self, learning_rate, weight_decay, img_size, depth, drop_ratio):
        super(DeepFakeClassifierTrainer, self).__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.model = DeepFakeClassifier(img_size, depth, drop_ratio)
        self.criterion = CustomCriterion()
        self.training_losses = []
        self.training_accuracies = []
        self.validation_losses = []
        self.validation_accuracies = []
        self.test_labels = []
        self.test_outputs = []

    def forward(self, x):
        return self.model(x)

    def display_metrics(self):
        plt.figure(figsize=(18, 12))

        plt.subplot(2, 2, 1)
        plt.plot(range(1, len(self.training_losses) + 1), self.training_losses, label="Training Loss")
        plt.title("Training Loss Over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()

        plt.subplot(2, 2, 2)
        plt.plot(range(1, len(self.validation_losses) + 1), self.validation_losses, label="Validation Loss")
        plt.title("Validation Loss Over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()

        plt.subplot(2, 2, 3)
        plt.plot(range(1, len(self.training_losses) + 1), self.training_accuracies, label="Training Accuracy")
        plt.title("Training Accuracy Over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.legend()

        plt.subplot(2, 2, 4)
        plt.plot(range(1, len(self.validation_losses) + 1), self.validation_accuracies, label="Validation Accuracy")
        plt.title("Validation Accuracy Over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Accuracy")
        plt.legend()

        plt.tight_layout()
        plt.show()

    def get_confusion_matrix(self):
        all_preds = np.concatenate(self.test_outputs)
        all_labels = np.concatenate(self.test_labels)
        cm = confusion_matrix(all_labels, all_preds > 0.5)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Fake", "Real"])
        disp.plot(cmap=plt.cm.Blues)
        plt.title("Confusion Matrix")
        plt.show()
        return cm

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(self.device)  # Ensure inputs are on the correct device
        labels = labels.float().unsqueeze(1).to(self.device)  # Ensure labels are on the correct device
        outputs = self.model(inputs)
        loss = self.criterion.compute_loss(outputs, labels)
        preds = (outputs > 0.5).float()
        acc = accuracy_score(labels.cpu(), preds.cpu())
        self.training_losses.append(loss.item())
        self.training_accuracies.append(acc)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        inputs = inputs.to(self.device)  # Ensure inputs are on the correct device
        labels = labels.float().unsqueeze(1).to(self.device)  # Ensure labels are on the correct device
        outputs = self(inputs)
        loss = self.criterion.compute_loss(outputs, labels)
        preds = (outputs > 0.5).float()
        acc = accuracy_score(labels.cpu(), preds.cpu())
        self.validation_losses.append(loss.item())
        self.validation_accuracies.append(acc)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        if batch_idx == 0:
            self.test_labels = []
            self.test_outputs = []

        inputs, labels = batch
        inputs = inputs.to(self.device)  # Ensure inputs are on the correct device
        labels = labels.float().unsqueeze(1).to(self.device)  # Ensure labels are on the correct device
        outputs = self(inputs)
        loss = self.criterion.compute_loss(outputs, labels)
        preds = (outputs > 0.5).float()
        acc = accuracy_score(labels.cpu(), preds.cpu())

        self.test_labels.append(labels.cpu().numpy())
        self.test_outputs.append(outputs.cpu().numpy())

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)  # Add L2 regularization

    # def on_fit_start(self):
    #     """Ensure compatibility with DDP setup."""
    #     if self.trainer.global_rank == 0:
    #         self.display_model_summary()

    def display_model_summary(self):
        print("Model Summary:")
        print(self.model)
        self.display_trainable_parameters()

    def display_trainable_parameters(self):
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        non_trainable_params = total_params - trainable_params
        print("\nParameter Summary Table:")
        print(f"{'Parameter Type':<25}{'Count':<15}")
        print(f"{'Trainable Parameters':<25}{trainable_params:<15}")
        print(f"{'Non-Trainable Parameters':<25}{non_trainable_params:<15}")
        print(f"{'Total Parameters':<25}{total_params:<15}\n")

    def save_weights(self, path):
        torch.save(self.model.state_dict(), path)

    def load_weights(self, path):
        self.model.load_state_dict(torch.load(path))

    def classify_image(self, image_path):
        self.model.eval()
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = Image.open(image_path).convert("RGB")
        image = transform(image).unsqueeze(0).to(self.device)  # Ensure image is on the correct device
        with torch.no_grad():
            output = self.model(image)
            prediction = "Real" if output.item() > 0.5 else "Fake"
        print(f"Prediction: {prediction}, Confidence: {output.item():.4f}")

In [25]:
# Instantiate the model
model = DeepFakeClassifierTrainer(learning_rate, weight_decay, img_size, depth, dropout)



In [26]:
# Define trainer
trainer = pl.Trainer(
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            monitor="val_loss",
            save_top_k=1,
            mode="min",
            dirpath=root_path,
            filename="best-checkpoint"
        )
    ],
    max_epochs=epochs,
    accelerator="gpu",
    devices=2,
    strategy='ddp_notebook'
)

In [27]:
# Path to the last checkpoint
checkpoint_path = os.path.join(root_path, "best-checkpoint.ckpt")

# Check if the checkpoint exists
if os.path.exists(checkpoint_path):
    print(f"Resuming training from checkpoint: {checkpoint_path}")
    trainer.fit(model, train_loader, val_loader, ckpt_path=checkpoint_path)
else:
    print("Starting training from scratch.")
    trainer.fit(model, train_loader, val_loader)


Starting training from scratch.


/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /kaggle/working exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
  self.pid = os.fork()


NameError: name 'exit' is not defined

In [None]:
# trainer.fit(model, train_loader, val_loader)

In [None]:
model.display_metrics()

In [None]:
trainer.test(model, test_loader)

In [None]:
model.get_confusion_matrix()