In [None]:
# Mount Google Drive when running inside Colab. Outside Colab the
# `google.colab` package does not exist, so skip the mount instead of failing
# the very first cell; `drive` stays defined either way.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab is not available; skipping Google Drive mount.")
else:
    drive.mount('/content/drive')

In [None]:
!pip install timm

In [None]:
!pip install patchify

In [None]:
pip install lightning

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import time
import os
from patchify import patchify, unpatchify
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, random_split
import timm
import lightning as L

In [None]:
# Produce fixed 384x384 tensors. Resize(384) alone only scales the *shorter*
# edge to 384 and keeps the aspect ratio, so non-square images would come out
# as 384xN and break both batching and the 24x24 grid of 16-pixel patches used
# by PatchModule. CenterCrop(384) is a no-op for images that are already square.
transform = transforms.Compose([
    transforms.Resize(384),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
])

In [None]:
# Number of samples per batch for the train/val/test DataLoaders built below.
batch_size = 8
# Folder of manipulated (fake) frames. CustomImageDataset labels a file 1.0
# when its path does NOT contain the substring "Original".
fake_dir = '../Datasets/FF++/C23/Manipulated/DeepFake'
# Folder of pristine (real) frames; the "Original" substring in this path is
# what makes CustomImageDataset label these files 0.0.
real_dir = '../Datasets/FF++/C23/Original/Original_images/'

class CustomImageDataset(Dataset):
    """Image-folder dataset that derives a binary label from the file path.

    Args:
        path: Directory whose regular files are treated as images.
        transform: Optional callable applied to the loaded PIL image.

    Each item is ``(image, label)`` where ``label`` is ``0.0`` when the file
    path contains the substring ``"Original"`` and ``1.0`` otherwise.
    NOTE(review): the label therefore depends on directory/file naming; a fake
    file whose name contains "Original" would be labelled as real.
    """

    def __init__(self, path, transform=None):
        self.transform = transform
        # Sort so that indices map to the same files on every run (os.listdir
        # order is arbitrary), and skip sub-directories, which cannot be opened
        # as images.
        self.files = sorted(
            full_path
            for full_path in (os.path.join(path, name) for name in os.listdir(path))
            if os.path.isfile(full_path)
        )

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load the idx-th file as an RGB image and return ``(image, label)``."""
        file = self.files[idx]
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load so the handle can be closed right away. It
        # also guarantees 3 channels for grayscale / RGBA / palette files.
        with Image.open(file) as opened:
            image = opened.convert("RGB")
        label = 0.0 if "Original" in file else 1.0
        if self.transform:
            image = self.transform(image)
        return image, label

# Per-class split sizes. The same counts are drawn from the fake and the real
# folder, so every combined split holds equally many samples of each class.
train_size = 2938
val_size = 200
test_size = 100

# One seeded generator drives both splits so that restarting the kernel
# reproduces the same train/val/test membership.
split_generator = torch.Generator().manual_seed(42)


def _split_fixed(dataset, sizes, generator):
    """Randomly split `dataset` into subsets of the given sizes.

    Samples beyond ``sum(sizes)`` are placed in a leftover subset that is
    discarded.

    Raises:
        ValueError: if the dataset holds fewer than ``sum(sizes)`` samples.
    """
    leftover = len(dataset) - sum(sizes)
    if leftover < 0:
        raise ValueError(
            f"Dataset has {len(dataset)} samples but the requested splits need {sum(sizes)}"
        )
    *subsets, _unused = random_split(dataset, [*sizes, leftover], generator=generator)
    return subsets


custom_fake_dataset = CustomImageDataset(fake_dir, transform=transform)
fake_train_subset, fake_val_subset, fake_test_subset = _split_fixed(
    custom_fake_dataset, [train_size, val_size, test_size], split_generator
)

custom_real_dataset = CustomImageDataset(real_dir, transform=transform)
real_train_subset, real_val_subset, real_test_subset = _split_fixed(
    custom_real_dataset, [train_size, val_size, test_size], split_generator
)

# Kept as module-level names with the values they previously ended up with
# (computed from the real dataset), in case other cells read them.
total_samples = len(custom_real_dataset)
rem_size = total_samples - train_size - val_size - test_size

# Merge the per-class subsets into the final splits.
train_dataset = torch.utils.data.ConcatDataset([fake_train_subset, real_train_subset])
val_dataset = torch.utils.data.ConcatDataset([fake_val_subset, real_val_subset])
test_dataset = torch.utils.data.ConcatDataset([fake_test_subset, real_test_subset])

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                         shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False)

In [None]:
# Report both the sample count of each split and the number of batches its
# loader yields; the previous "Loader Length" labels were printing the dataset
# sample counts.
for split_name, split_dataset, split_loader in (
    ("Train", train_dataset, train_loader),
    ("Validation", val_dataset, val_loader),
    ("Test", test_dataset, test_loader),
):
    print(f"{split_name}: {len(split_dataset)} samples, {len(split_loader)} batches")

In [None]:
# Pull one batch from the training loader and preview its first sample.
first_batch = next(iter(train_loader))
train_features, train_labels = first_batch
for _kind, _batch_tensor in (("Feature", train_features), ("Labels", train_labels)):
    print(f"{_kind} batch shape: {_batch_tensor.size()}")

# Move channels last, (C, H, W) -> (H, W, C), before handing the image to imshow.
img, label = train_features[0].permute(1, 2, 0), train_labels[0]

plt.imshow(img, cmap="gray")
plt.axis('off')
plt.tight_layout()
plt.show()
print(f"Label: {label}")

In [None]:
xception_weight_path = '../Pretrained Weights/xception-43020ad28.pth'
class Xception(nn.Module):
  """timm Xception backbone followed by spatial average pooling.

  The backbone is created without downloading weights and then initialised
  from the local checkpoint at `xception_weight_path`.
  """

  def __init__(self):
        super(Xception, self).__init__()
        # Average over the whole feature map. For a 12x12 map this equals the
        # previous AvgPool2d(12, 12), but it no longer hard-codes the map size.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.xception_model = timm.create_model('xception', pretrained=False)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only host; load_state_dict copies values into the existing
        # parameters, so their device is unaffected.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(xception_weight_path, map_location='cpu')
        self.xception_model.load_state_dict(state_dict)

  def forward(self, input):
        """Return pooled backbone features for a batch of images.

        Args:
            input: image batch passed straight to the backbone's
                `forward_features`.

        Returns:
            Feature tensor with spatial size 1x1 (4-D; ViXNet flattens
            dims 1-3 of it).
        """
        output = self.xception_model.forward_features(input)
        output = self.global_avg_pool(output)
        return output

In [None]:
class PatchModule(nn.Module):
  """Per-patch convolutional re-weighting of an image.

  The image is cut into a `grid_size` x `grid_size` grid of non-overlapping
  `patch_size` x `patch_size` patches. Every grid cell owns its own 3->3
  channel 3x3 convolution; the (normalised) convolution response is multiplied
  element-wise with the (normalised) patch.

  Args:
      grid_size: number of patches per image side (default 24).
      patch_size: side length of a patch in pixels (default 16).

  With the defaults the expected input is (batch, 3, 384, 384).
  """

  def __init__(self, grid_size=24, patch_size=16):
    super(PatchModule, self).__init__()
    self.grid_size = grid_size
    self.patch_size = patch_size
    # conv[i][j] is the convolution dedicated to grid cell (row i, column j).
    self.conv = nn.ModuleList([
            nn.ModuleList([
                nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
                for _ in range(grid_size)
            ])
            for _ in range(grid_size)
        ])

  def extract_patches(self, img):
    """Cut (B, C, H, W) into patches shaped (rows, cols, B, C, patch, patch)."""
    size = self.patch_size
    # stride == size, so the patches tile the image without overlap.
    patches = img.unfold(2, size, size).unfold(3, size, size)
    return patches.permute(2, 3, 0, 1, 4, 5)

  def forward(self, x):
    """Return the re-weighted patches, shaped (rows, cols, B, 3, patch, patch).

    Raises:
        ValueError: if the input does not split into the configured grid.
    """
    patches = self.extract_patches(x)
    rows, cols = patches.shape[0], patches.shape[1]
    if rows != self.grid_size or cols != self.grid_size:
      raise ValueError(
          f"Input splits into a {rows}x{cols} patch grid, "
          f"expected {self.grid_size}x{self.grid_size}"
      )
    # Each cell's convolution processes the whole batch in one call instead of
    # one sample at a time; stacking replaces in-place writes into a
    # preallocated buffer.
    conv_patches = torch.stack([
        torch.stack([self.conv[i][j](patches[i, j]) for j in range(cols)])
        for i in range(rows)
    ])
    # Scale each tensor by its global maximum (epsilon guards against a zero max).
    patches = patches / (torch.max(patches) + 1e-6)
    conv_patches = conv_patches / (torch.max(conv_patches) + 1e-6)
    output_image = patches * conv_patches
    output_image = output_image / (torch.max(output_image) + 1e-6)
    return output_image

In [None]:
# Forward-only preview of an untrained PatchModule on the sample batch. The
# result is used for visualisation only, so no_grad avoids building and
# retaining an autograd graph for all the per-patch convolutions.
pm = PatchModule()
with torch.no_grad():
    outp = pm(train_features)

In [None]:
# Display the Self Attention Patched Images
# Permute the tensor to bring batch and channel to the front:
# (rows, cols, batch, channel, h, w) -> (batch, channel, rows, h, cols, w)
tensor = outp.permute(2, 3, 0, 4, 1, 5)

# Reshape to combine patches into full images. The batch dimension is unpacked
# into `num_images` rather than `batch_size` so the DataLoader configuration
# value of that name is not overwritten by this visualisation cell.
num_images, channels, patch_dim1, height, patch_dim2, width = tensor.shape
combined_images = tensor.reshape(num_images, channels, patch_dim1 * height, patch_dim2 * width)

# Display the images
def show_images(images, custom_width=2, custom_height=2):
    """Show a batch of (B, C, H, W) images in a grid of up to four columns.

    Args:
        images: tensor of shape (B, C, H, W); may live on any device and may
            require grad.
        custom_width: width in inches allotted to each grid cell.
        custom_height: height in inches allotted to each grid cell.
    """
    num_images = images.shape[0]
    if num_images == 0:
        print("No images to display.")
        return

    # Size the grid from the batch instead of assuming exactly 8 images.
    cols = min(4, num_images)
    rows = (num_images + cols - 1) // cols
    fig, axes = plt.subplots(
        rows, cols,
        figsize=(cols * custom_width, rows * custom_height),
        squeeze=False,  # keep `axes` 2-D even for a single row or column
    )

    for idx, ax in enumerate(axes.flat):
        ax.axis('off')
        if idx >= num_images:
            continue  # unused cell in the last row
        # [C, H, W] -> [H, W, C]; .cpu() makes this work for CUDA tensors.
        # imshow clips float RGB data to [0, 1]; clamping first does the same
        # without the clipping warning.
        img = images[idx].detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
        ax.imshow(img)
        ax.set_title(f"Image {idx + 1}")

    plt.tight_layout()  # Adjust spacing between subplots
    plt.show()

# Show the combined images
show_images(combined_images)


In [None]:
vit_weight_path = '../Pretrained Weights/vit_small_patch16_384.augreg_in1k'
class VIT_Encoder(nn.Module):
  """timm ViT (`vit_small_patch16_384.augreg_in1k`) used as a feature encoder.

  The model is created without downloading weights and then initialised from
  the local checkpoint at `vit_weight_path`.
  """

  def __init__(self):
    super(VIT_Encoder, self).__init__()
    self.vit_model = timm.create_model('vit_small_patch16_384.augreg_in1k', pretrained=False)
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # host; load_state_dict copies values into the existing parameters.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    state_dict = torch.load(vit_weight_path, map_location='cpu')
    self.vit_model.load_state_dict(state_dict)

  def forward(self, x):
    """Return the backbone's `forward_features` output for image batch `x`.

    NOTE(review): presumably (batch, tokens, embed_dim), since ViXNet flattens
    dims 1-2 of the result - confirm against the timm version in use.
    """
    output = self.vit_model.forward_features(x)
    return output

In [None]:
class Classification_Module(nn.Module):
  """MLP head mapping a flat feature vector to one probability per sample.

  Args:
      in_features: length of the input feature vector. The default, 223616, is
          the width this head was originally hard-coded to (presumably the
          length of ViXNet's concatenated features - confirm if the backbones
          or input size change).

  The output has shape (batch, 1) with values in (0, 1).
  """

  def __init__(self, in_features=223616):
    super(Classification_Module, self).__init__()
    self.model = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 1),
        # Sigmoid, not Softmax: softmax over a single output unit is always
        # exactly 1.0, so the head could never express anything but "class 1".
        nn.Sigmoid()
    )

  def forward(self, input):
    """Return per-sample probabilities of shape (batch, 1)."""
    return self.model(input)

In [None]:
class ViXNet(nn.Module):
    """Two-branch network: patch module + ViT on one side, Xception on the other.

    Both branches see the same input batch; their flattened features are
    concatenated and passed to the classification head.
    """

    def __init__(self):
        super().__init__()
        self.patch_module = PatchModule()
        self.vit_encoder = VIT_Encoder()
        self.xception = Xception()
        self.classification = Classification_Module()

    def combine_patches(self, patches):
        """Stitch (rows, cols, B, C, h, w) patches back into (B, C, rows*h, cols*w)."""
        # Bring batch and channel to the front and interleave each grid axis
        # with its pixel axis so a plain reshape merges them.
        grid = patches.permute(2, 3, 0, 4, 1, 5)
        n, c, n_rows, patch_h, n_cols, patch_w = grid.shape
        return grid.reshape(n, c, n_rows * patch_h, n_cols * patch_w)

    def forward(self, x):
        """Run both branches on `x` and classify their concatenated features."""
        # Branch 1: patch-wise re-weighted image -> ViT features, flattened over dims 1-2.
        reweighted = self.combine_patches(self.patch_module(x))
        vit_output = self.vit_encoder(reweighted).flatten(1, 2)

        # Branch 2: raw image -> pooled Xception features, flattened over dims 1-3.
        xception_output = self.xception(x).flatten(1, 3)

        fused = torch.cat((vit_output, xception_output), dim=1)
        return self.classification(fused)

In [None]:
class LitViXNet(L.LightningModule):
    """Lightning wrapper that trains a ViXNet-style binary classifier.

    Args:
        vixnet: network returning one value per sample, shape (batch, 1),
            interpreted as the probability of the positive class (label 1.0).
    """

    def __init__(self, vixnet):
        super().__init__()
        self.vixnet = vixnet
        # Binary cross-entropy on probabilities. CrossEntropyLoss is a
        # multi-class loss: on a single-column output its log-softmax is
        # identically 0, so the loss (and every gradient) would always be zero.
        self.criterion = nn.BCELoss()

    def training_step(self, batch, batch_idx):
        """Compute and log the binary cross-entropy for one training batch."""
        images, labels = batch
        outputs = self.vixnet(images)
        # Labels arrive as a 1-D tensor collated from Python floats (float64);
        # BCELoss needs them shaped and typed like the model output.
        targets = torch.unsqueeze(labels, 1).to(outputs.dtype)
        loss = self.criterion(outputs, targets)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with learning rate 1e-4."""
        optimizer = optim.Adam(self.parameters(), lr=0.0001)
        return optimizer

In [None]:
# Build the full network, print its layer structure and report its size.
vixnet = ViXNet()
print(vixnet)

# Sum the element counts of every parameter tensor.
total_params = sum(map(torch.Tensor.numel, vixnet.parameters()))
print(f"Number of parameters: {total_params}")

In [None]:
# Wrap the plain nn.Module in its LightningModule so L.Trainer can drive it.
model = LitViXNet(vixnet=vixnet)

In [None]:
# Short run: two epochs, each capped at two training batches, on one device.
# Only the training loader is passed, so no validation loop runs.
trainer = L.Trainer(
    devices=1,
    max_epochs=2,
    limit_train_batches=2,
)
trainer.fit(model, train_dataloaders=train_loader)

In [None]:
def _plot_history(curves, ylabel, figsize, save_to):
    """Plot the available per-epoch curves in one figure, save it and show it.

    Args:
        curves: iterable of ``(values, color, label)``; entries whose
            ``values`` is None are skipped.
        ylabel: y-axis label.
        figsize: figure size in inches.
        save_to: path the figure is written to.
    """
    available = [curve for curve in curves if curve[0] is not None]
    if not available:
        print(f"No {ylabel.lower()} history recorded; skipping plot.")
        return
    plt.figure(figsize=figsize)
    for values, color, label in available:
        plt.plot(values, color=color, linestyle='-', label=label)
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(save_to)
    plt.show()


# save the final model
# Done first so that a plotting or Drive-path problem cannot prevent it.
save_path = 'model_res.pth'
torch.save(model.state_dict(), save_path)
print('MODEL SAVED...')

# The metric lists below are not produced by any cell of this notebook, so
# look them up defensively instead of raising NameError when they are absent.
_recorded = globals()

# accuracy plots
_plot_history(
    [
        (_recorded.get('train_acc'), 'green', 'train accuracy'),
        (_recorded.get('valid_acc'), 'blue', 'validation accuracy'),
    ],
    ylabel='Accuracy',
    figsize=(7, 7),
    save_to='/content/drive/MyDrive/accuracytl.png',
)
# loss plots
_plot_history(
    [
        (_recorded.get('train_loss'), 'orange', 'train loss'),
        (_recorded.get('valid_loss'), 'red', 'validation loss'),
    ],
    ylabel='Loss',
    figsize=(5, 5),
    save_to='/content/drive/MyDrive/losstl.png',
)