In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Mount Google Drive at /content/drive (Colab-only dependency); the dataset
# archive is later read from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torchvision.models as models
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights

# Load EfficientNet-B4 pretrained on ImageNet.
# `model` is reused further down as the `backbone_model` argument of TFEM and
# deep_fake_detection, so keep this name bound to the raw torchvision network.
weights = EfficientNet_B4_Weights.IMAGENET1K_V1
model = efficientnet_b4(weights=weights)

# Modify EfficientNet-B4 to use as a feature extractor
class EfficientNetFeatureExtractor(torch.nn.Module):
    """Expose only the ``features`` sub-module of a backbone network.

    Args:
        model: any object with a ``features`` attribute that is callable on a
            tensor (here: a torchvision EfficientNet). Nothing else of the
            model is referenced, so its classification head is never run.
    """

    def __init__(self, model):
        super().__init__()
        # Share (not copy) the backbone's feature stack.
        self.features = model.features

    def forward(self, x):
        """Return whatever ``self.features`` produces for ``x``."""
        feature_maps = self.features(x)
        return feature_maps

In [None]:
class TFEM_M(nn.Module):
    """Two-branch convolutional block producing ``output_channels`` feature maps.

    Branch 1 runs four parallel convolutions (one 1x1, three 3x3) on the input,
    concatenates them along the channel axis and fuses them with a 1x1
    convolution followed by batch norm.

    Branch 2 (``adaptive_pooling``) subtracts a smoothed version of the
    convolved input from the input itself and projects the residual to
    ``output_channels``.

    Both branches are passed through the *same* ``self.enhance`` stack (shared
    weights) and summed.

    Args:
        input_channels: channel count of the incoming feature map.
        output_channels: channel count of the returned feature map.

    Shape:
        input ``[B, input_channels, H, W]`` -> output ``[B, output_channels, H, W]``
        (every convolution here preserves the spatial size).
    """

    def __init__(self, input_channels=1792, output_channels=10):
        super(TFEM_M, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # First-stage parallel convolutions on the raw input.
        self.conv_1_1 = nn.Conv2d(input_channels, output_channels, kernel_size=1)
        self.conv_1_2 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv_1_3 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv_1_4 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)

        # Fuses the concatenation of the four branches (4 * output_channels).
        self.conv_2_1 = nn.Conv2d(4 * output_channels, output_channels, kernel_size=1)

        # Adaptive (residual) branch.
        self.conv_adap = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1)
        self.channel_matcher = nn.Conv2d(input_channels, output_channels, kernel_size=1)

        self.BN_2_1 = nn.BatchNorm2d(output_channels)

        # Enhancement stack, applied to BOTH branches in forward().
        self.enhance = nn.Sequential(
            nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_channels),
            nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(output_channels)
        )

    def chanel_1(self, x):
        """Run the four parallel convolutions and concatenate them on dim 1.

        Returns a tensor of shape ``[B, 4 * output_channels, H, W]``.
        (Method name kept as-is, typo included, for existing callers.)
        """
        branches = [
            self.conv_1_1(x),
            self.conv_1_2(x),
            self.conv_1_3(x),
            self.conv_1_4(x),
        ]
        return torch.cat(branches, dim=1)

    def adaptive_pooling(self, x):
        """Return the projected residual between ``x`` and a smoothed copy.

        The smoothed copy is ``conv_adap(x)`` average-pooled to half resolution
        and bilinearly upsampled back to the size of ``x``.
        """
        logit = self.conv_adap(x)
        # Clamp to 1 so that 1-pixel-high/wide maps do not request an empty
        # (size-0) pooled tensor, which cannot be interpolated back.
        pooled_size = (max(1, x.size(2) // 2), max(1, x.size(3) // 2))
        f_d = F.adaptive_avg_pool2d(logit, output_size=pooled_size)
        F_d = F.interpolate(f_d, size=x.size()[2:], mode='bilinear', align_corners=True)
        F_R = x - F_d
        # Match channels to output_channels so the two branches can be summed.
        return self.channel_matcher(F_R)

    def forward(self, x):
        """Compute both branches, enhance each with the shared stack, and sum."""
        F_1 = self.chanel_1(x)
        F_1 = self.conv_2_1(F_1)
        F_1 = self.BN_2_1(F_1)

        F_D = self.adaptive_pooling(x)

        F_E = self.enhance(F_1)
        F_D = self.enhance(F_D)

        return F_E + F_D


In [None]:
class FasterAttentionModule(nn.Module):
    """Lightweight gating-style attention over the positions of a feature map.

    Each position is treated as a token of ``input_channels`` features. A single
    linear layer produces a key and a value per token; keys are soft-maxed
    across tokens (dim 1), multiplied element-wise with the values, and
    projected back to ``input_channels``.

    Args:
        input_channels: channel count ``C`` of the input (and of the output).
        output_channels: width of the key / value projections.

    Shape:
        input ``[B, C, ...]`` -> output of identical shape.
    """

    def __init__(self, input_channels, output_channels):
        super(FasterAttentionModule, self).__init__()
        self.key_value = nn.Linear(input_channels, 2 * output_channels)
        self.fc = nn.Linear(output_channels, input_channels)

    def forward(self, x):
        original_shape = x.shape

        # [B, C, *spatial] -> [B, N, C]. The channel axis must be *moved* to the
        # end; a bare view(B, -1, C) only reinterprets memory and would mix
        # channels with spatial positions.
        tokens = x.reshape(original_shape[0], original_shape[1], -1).transpose(1, 2)

        kv = self.key_value(tokens)
        k, v = torch.chunk(kv, 2, dim=-1)
        attention = torch.softmax(k, dim=1)  # normalise over the N tokens
        weighted = attention * v
        out = self.fc(weighted)              # [B, N, C]

        # Inverse of the input rearrangement: [B, N, C] -> [B, C, *spatial].
        return out.transpose(1, 2).reshape(original_shape)

In [None]:
# Smoke test for FasterAttentionModule.
# `device` and `random_tensor_m` are (re)defined here because the cells that
# originally create them come later in the notebook; without these two lines
# this cell raises NameError on Restart & Run All.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
random_tensor_m = torch.randn(1, 1792, 224, 224).to(device)

fam = FasterAttentionModule(input_channels=1792,output_channels=1792).to(device)
outputa=fam(random_tensor_m)
outputa.shape

torch.Size([4, 1792, 224, 224])

In [None]:
class TFEM_A(nn.Module):
    """Attention branch: gate the mean-centred input, fuse with the mean, project.

    Steps performed by ``forward``:
      1. ``A = sigmoid(attention_conv(F))`` - a gate with the shape of ``F``.
      2. ``Fr = (F - mean_hw(F)) * A`` - gated, mean-centred features.
      3. Concatenate ``Fr`` with the broadcast per-channel mean and fuse the
         ``2C`` channels back to ``C`` with a 1x1 convolution.
      4. Add the fused map's own global average and project to
         ``output_channels`` with a 1x1 convolution.

    Args:
        input_channels: channel count ``C`` of the input feature map.
        output_channels: channel count of the returned feature map.
        enhancement_coefficient: stored on the module. See the NOTE below:
            it does not currently influence the output.

    NOTE(review): the previous ``forward`` also computed
    ``F + coef * std(F) * (conv1(F) - mean(conv1(F))) / std(conv1(F))`` but
    overwrote the result before using it, so ``conv1`` and
    ``enhancement_coefficient`` never affected the output. That dead
    computation (a full 3x3 convolution plus a division that is undefined when
    the std is 0) has been dropped from ``forward``; outputs are unchanged.
    ``conv1`` and ``std`` stay registered so existing attribute access and
    state_dict keys keep working. Confirm whether the enhanced map was meant to
    feed the attention path.
    """

    def __init__(self, input_channels, output_channels, enhancement_coefficient=1.0):
        super(TFEM_A, self).__init__()
        self.enhancement_coefficient = enhancement_coefficient
        self.output_channels = output_channels

        # Registered but not used by forward() - see class NOTE.
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.attention_conv = FasterAttentionModule(input_channels, input_channels)

        # Global average pooling to [B, C, 1, 1]. `std` is the same pooling op
        # under a second name and is not used by forward().
        self.mean = nn.AdaptiveAvgPool2d(1)
        self.std = nn.AdaptiveAvgPool2d(1)

        # Fuses the [Fr, mean] concatenation (2C channels) back to C channels.
        self.channel_matcher = nn.Conv2d(2 * input_channels, input_channels, kernel_size=1)

        # Final projection to the requested channel count.
        self.channel_downsampler = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, F):
        """Map ``F`` of shape ``[B, C, H, W]`` to ``[B, output_channels, H, W]``.

        The parameter is named ``F`` for backward compatibility; inside this
        method it shadows any module-level ``F`` alias.
        """
        # Per-channel spatial mean, [B, C, 1, 1]; needed twice below.
        channel_mean = F.mean(dim=[2, 3], keepdim=True)

        # Attention gate applied to the mean-centred features.
        A = torch.sigmoid(self.attention_conv(F))
        Fr = (F - channel_mean) * A

        # Re-attach the mean as extra channels, then fuse 2C -> C.
        F2 = channel_mean.expand_as(F)
        Fa = self.channel_matcher(torch.cat([Fr, F2], dim=1))  # [B, C, H, W]

        # Add the fused map's global average back at every position.
        F3 = self.mean(Fa)
        f2 = Fa + F3.expand_as(Fa)

        # Project to output_channels.
        return self.channel_downsampler(f2)  # [B, output_channels, H, W]


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
random_tensor_m = torch.randn(1, 1792, 224, 224).to(device)

In [None]:
random_tensor_m.shape[0]

2

In [None]:
# Smoke-test TFEM_A on the random 1792-channel tensor and display the output shape.
# NOTE(review): the output recorded under this cell is a NameError for `device`,
# which is defined a few cells above; the recording looks stale, re-run to refresh it.
tfema = TFEM_A(input_channels=1792,output_channels=512, enhancement_coefficient=1.0).to(device)
outputa=tfema(random_tensor_m)
outputa.shape

NameError: name 'device' is not defined

In [None]:
# Smoke-test TFEM_M on the same random tensor and display the output shape.
# NOTE(review): the recorded batch size (4) differs from the batch size used
# when `random_tensor_m` is created (1), so the recorded output is from an older run.
tfemm=TFEM_M(input_channels=1792,output_channels=512).to(device)
output=tfemm(random_tensor_m)
output.shape

torch.Size([4, 512, 224, 224])

In [None]:
class TFEM(nn.Module):
    """Backbone feature extractor followed by the TFEM_A and TFEM_M branches.

    The backbone output is fed to both branches independently and the two
    results (each with ``output_channels`` channels) are summed.

    Args:
        output_channels: channel count produced by each branch and by the sum.
        backbone_model: network exposing a ``features`` sequence whose last
            element has an ``out_channels`` attribute (used to size the branches).
        enhancement_coefficient: forwarded to ``TFEM_A``.
        verbose: when True, print the intermediate tensor shapes on every
            forward pass (debug aid). Off by default so training loops are not
            flooded with output.
    """

    def __init__(self,output_channels,backbone_model,enhancement_coefficient=0.1,verbose=False):
        super(TFEM,self).__init__()
        self.backbone=EfficientNetFeatureExtractor(backbone_model)
        # Channel count of the backbone output, read from its last feature stage.
        self.input_channels = self.backbone.features[-1].out_channels
        self.enhancement_coefficient=enhancement_coefficient
        self.verbose=verbose
        self.TFEM_A=TFEM_A(self.input_channels,output_channels,enhancement_coefficient)
        self.TFEM_M=TFEM_M(self.input_channels,output_channels)

    def forward(self,x):
        """Return ``TFEM_A(backbone(x)) + TFEM_M(backbone(x))``."""
        l=self.backbone(x)
        f_1=self.TFEM_A(l)
        f_2=self.TFEM_M(l)

        if self.verbose:
            print("back_bone",l.shape)
            print("tfem_a",f_1.shape)
            # Report f_2 here; the label previously showed f_1's shape again.
            print("tfem_b",f_2.shape)

        return f_1 + f_2

In [None]:
# Random stand-in for a batch of RGB images, used to smoke-test the full TFEM module.
batch_size = 4
input_channels = 3  # for RGB image
height = 224
width = 224
random_tensor = torch.randn(batch_size, input_channels, height, width).to(device)

In [None]:
random_tensor.shape

torch.Size([4, 3, 224, 224])

In [None]:
tfem=TFEM(output_channels=1024,backbone_model=model,enhancement_coefficient=0.1).to(device)

In [None]:
output=tfem(random_tensor)

back_bone torch.Size([4, 1792, 7, 7])
tfem_a torch.Size([4, 1024, 7, 7])
tfem_b torch.Size([4, 1024, 7, 7])


In [None]:
output.shape

torch.Size([4, 1024, 7, 7])

## Multiscale Feature Extraction

In [None]:
class multi_scale_feature_enhancement_module(nn.Module):
    """Self-attention over non-overlapping patches, added back residually.

    The input ``[B, C, H, W]`` is cut into ``patch_size x patch_size`` patches,
    each flattened to ``C * patch_size**2`` values. Three linear maps
    (``f1``, ``f2``, ``f3``) act as query / key / value projections;
    softmax(``f1 . f2^T / sqrt(embed_dim)``) weights the values, and the result
    is folded back into a ``[B, C, H, W]`` map and added to the input.

    Because the attention output is folded straight back into patches, a
    forward pass only works when ``embed_dim == C * patch_size**2``. The linear
    layers are therefore built eagerly as ``embed_dim -> embed_dim`` so that
    they are registered parameters from construction (visible to
    ``.parameters()``, ``.to()``, optimizers and ``state_dict``), instead of
    being created on the first forward pass.

    Args:
        patch_size: side length of the square patches; must divide H and W.
        embed_dim: projection width; must equal ``C * patch_size**2``.

    Raises (from ``forward``):
        ValueError: if the input's ``C * patch_size**2`` differs from ``embed_dim``.
        AssertionError: if H or W is not a multiple of ``patch_size``.
    """

    def __init__(self,patch_size,embed_dim):
        super(multi_scale_feature_enhancement_module,self).__init__()
        self.patch_size=patch_size
        self.embed_dim=embed_dim
        self.f1 = nn.Linear(embed_dim, embed_dim)
        self.f2 = nn.Linear(embed_dim, embed_dim)
        self.f3 = nn.Linear(embed_dim, embed_dim)

    def _split_to_patches(self,x,patch_size):
        """Cut ``[B, C, H, W]`` into patches, returned as ``[B, N, C, p, p]``.

        Patches are ordered row-major over the patch grid.
        """
        B,C,H,W=x.size()

        assert H % patch_size == 0 and W % patch_size == 0
        # [B, C, H/p, W/p, p, p]
        patches=x.unfold(2,patch_size,patch_size).unfold(3,patch_size,patch_size)
        patches=patches.contiguous().view(B,C,-1,patch_size,patch_size)
        patches = patches.permute(0, 2, 1, 3, 4)

        return patches

    def _combine_patches(self, patches, height, width, patch_size, channels):
        """Inverse of ``_split_to_patches`` for flattened patches ``[B, N, E]``.

        ``E`` must equal ``channels * patch_size**2``.
        """
        B, N, E = patches.size()
        rows = height // patch_size
        cols = width // patch_size

        # reshape (not view): tolerates non-contiguous inputs.
        patches = patches.reshape(B, rows, cols, channels, patch_size, patch_size)
        # -> (B, C, rows, patch_h, cols, patch_w), then merge into (H, W).
        patches = patches.permute(0, 3, 1, 4, 2, 5).contiguous()
        return patches.view(B, channels, height, width)

    def forward(self,x):
        batch_size, channels, height, width = x.size()

        patch_dim = channels * self.patch_size * self.patch_size
        if patch_dim != self.embed_dim:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must equal channels * patch_size**2 "
                f"({channels} * {self.patch_size}**2 = {patch_dim})"
            )

        patches = self._split_to_patches(x, self.patch_size)
        num_patches = patches.size(1)

        # The permuted patch tensor is not contiguous; view() fails on it for
        # patch_size > 1, reshape() copies when needed.
        patches = patches.reshape(batch_size, num_patches, -1)

        f1_out = self.f1(patches)
        f2_out = self.f2(patches)
        f3_out = self.f3(patches)

        # Scaled dot-product weights between patches: [B, N, N].
        weight_matrix = torch.matmul(f1_out, f2_out.transpose(1, 2))
        weight_matrix = F.softmax(weight_matrix / (self.embed_dim ** 0.5), dim=-1)

        enhanced_patches = torch.matmul(weight_matrix, f3_out)

        enhanced_feature_map = self._combine_patches(enhanced_patches, height, width, self.patch_size, channels)

        # Residual connection.
        return x + enhanced_feature_map

In [None]:
random_tensor = torch.randn(4, 1024, 7, 7).to(device)

In [None]:
mfem=multi_scale_feature_enhancement_module(patch_size=1,embed_dim=1024).to(device)

In [None]:
output=mfem(random_tensor)

In [None]:
output.shape

torch.Size([4, 1024, 7, 7])

# FDM Block
## Extraction Layers

In [None]:
import torch
import torch.nn as nn

class ExtractionLayer(nn.Module):
    """Two stacked residual blocks that keep the spatial size.

    * start block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, plus an identity
      skip from the block input.
    * end block: conv3x3 -> ReLU -> conv1x1, plus a (conv1x1 -> BN) skip.

    The identity skip in the start block means the layer only runs when
    ``input_channels == output_channels``; otherwise ``start_block`` raises
    ``ValueError``.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # --- start block ---
        self.conv_S_1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        self.BN_S_1 = nn.BatchNorm2d(output_channels)
        self.ReLU_1 = nn.ReLU()
        self.conv_S_2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.BN_S_2 = nn.BatchNorm2d(output_channels)

        # --- end block ---
        self.conv_E_1 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.ReLU_2 = nn.ReLU()
        self.conv_E_2 = nn.Conv2d(output_channels, output_channels, kernel_size=1, padding=0)

        # --- skip path of the end block ---
        self.downsample = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(output_channels),
        )

    def start_block(self, x):
        """conv-BN-ReLU-conv-BN with an identity skip; raises on shape mismatch."""
        transformed = self.BN_S_1(self.conv_S_1(x))
        transformed = self.ReLU_1(transformed)
        transformed = self.BN_S_2(self.conv_S_2(transformed))

        # The identity skip needs matching shapes.
        if transformed.shape != x.shape:
            raise ValueError(f"Shape mismatch in start_block: l.shape={transformed.shape}, x.shape={x.shape}")

        return transformed + x

    def end_block(self, x):
        """conv-ReLU-conv with a projected (conv1x1 + BN) skip; raises on shape mismatch."""
        transformed = self.conv_E_2(self.ReLU_2(self.conv_E_1(x)))

        # Project the block input so it can be added to the transformed map.
        x_residual = self.downsample(x)

        if transformed.shape != x_residual.shape:
            raise ValueError(f"Shape mismatch in end_block: l.shape={transformed.shape}, x_residual.shape={x_residual.shape}")

        return transformed + x_residual

    def forward(self, x):
        """Apply the start block, then the end block."""
        return self.end_block(self.start_block(x))





In [None]:
el=ExtractionLayer(input_channels=1024,output_channels=1024).to(device)

In [None]:
output=el(random_tensor)
output.shape

torch.Size([4, 1024, 7, 7])

## Additional Layers

In [None]:
class additional_layers(nn.Module):
    """Four conv stages with kernel sizes 7, 5, 3 and 1.

    Stage 1 is conv -> ReLU; stages 2-4 are conv -> ReLU -> BatchNorm.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by every stage.
        same_padding: padding policy for the convolutions.
            * False (default, original behaviour): every conv uses
              ``padding=1`` regardless of kernel size, so the spatial size
              changes by -4, -2, 0, +2 across the four stages (7x7 -> 3x3 ->
              1x1 -> 1x1 -> 3x3). The 1x1 intermediate map makes BatchNorm
              fail in training mode when the batch holds a single sample.
            * True: each conv uses ``padding = kernel_size // 2`` so the
              spatial size is preserved end to end.
    """

    _KERNEL_SIZES = (7, 5, 3, 1)

    def __init__(self,input_channels,output_channels,same_padding=False):
        super(additional_layers,self).__init__()

        def padding_for(kernel_size):
            return kernel_size // 2 if same_padding else 1

        k1, k2, k3, k4 = self._KERNEL_SIZES

        self.additional_layer_1=nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=k1, padding=padding_for(k1)),
            nn.ReLU()
        )
        self.additional_layer_2=nn.Sequential(
            nn.Conv2d(output_channels, output_channels, kernel_size=k2, padding=padding_for(k2)),
            nn.ReLU(),
            nn.BatchNorm2d(output_channels)
        )
        self.additional_layer_3=nn.Sequential(
            nn.Conv2d(output_channels, output_channels, kernel_size=k3, padding=padding_for(k3)),
            nn.ReLU(),
            nn.BatchNorm2d(output_channels)
        )
        self.additional_layer_4=nn.Sequential(
            nn.Conv2d(output_channels, output_channels, kernel_size=k4, padding=padding_for(k4)),
            nn.ReLU(),
            nn.BatchNorm2d(output_channels)
        )

    def forward(self,x):
        """Run the four stages in order."""
        out=self.additional_layer_1(x)
        out=self.additional_layer_2(out)
        out=self.additional_layer_3(out)
        out=self.additional_layer_4(out)

        return out

In [None]:
al=additional_layers(input_channels=1024,output_channels=1024).to(device)

In [None]:
al(output).shape

torch.Size([4, 1024, 3, 3])

In [None]:
class FDM(nn.Module):
  """ExtractionLayer followed by additional_layers, both `output_channels` wide.

  The sub-layers are built once in ``__init__``. They used to be re-created
  inside every ``forward`` call, which gave them fresh random weights on each
  batch and kept them out of ``.parameters()``, so they were never trained.

  ExtractionLayer's identity skip requires equal input and output channel
  counts, so the only input this module can process has exactly
  ``output_channels`` channels; both sub-layers are sized accordingly.

  Args:
    output_channels: channel count of the input and of the output.

  Raises (from ``forward``):
    ValueError: if the input does not have ``output_channels`` channels.
  """

  def __init__(self,output_channels):
    super(FDM,self).__init__()
    self.output_channels=output_channels
    self.extraction_layer=ExtractionLayer(output_channels,output_channels)
    self.additional_layers=additional_layers(output_channels,output_channels)

  def forward(self,x):
    channels = x.size(1)
    if channels != self.output_channels:
      raise ValueError(
        f"FDM expects {self.output_channels} input channels, got {channels}"
      )

    l=self.extraction_layer(x)
    l=self.additional_layers(l)

    return l

In [None]:
random_tensor = torch.randn(4, 1024, 7, 7).to(device)

In [None]:
fdm=FDM(output_channels=1024).to(device)

In [None]:
out=fdm(random_tensor)

In [None]:
out.shape

torch.Size([4, 1024, 3, 3])

In [None]:
import torch
import torch.nn as nn

class SimpleBinaryClassifier(nn.Module):
    """Global-average-pool a feature map and map it to two class scores.

    The head is Linear(input_channels, 256) -> ReLU -> Dropout(0.3) ->
    Linear(256, 64) -> ReLU -> Dropout(0.3) -> Linear(64, 2). No softmax or
    sigmoid is applied, so the output is raw scores (logits) of shape
    ``[batch_size, 2]``.
    """

    def __init__(self, input_channels=1024):
        super().__init__()

        # [B, C, H, W] -> [B, C, 1, 1]
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        head = [
            nn.Flatten(),                    # [B, C, 1, 1] -> [B, C]
            nn.Linear(input_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),                # two-way output
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        pooled = self.global_pool(x)
        scores = self.classifier(pooled)
        return scores

# Example usage:
# model = SimpleBinaryClassifier(input_channels=1024)

# # Simulating deepfake feature output
# dummy_input = torch.randn(2, 1024, 3, 3)  # Batch size 2
# output = model(dummy_input)
# print(output.shape)  # Should output: torch.Size([2, 2])


In [None]:
from torchvision.models import efficientnet_b0
class deep_fake_detection(nn.Module):
  """End-to-end detector: TFEM -> patch attention -> FDM -> binary classifier.

  Args:
    output_channels: channel width used by TFEM, FDM and the classifier input.
    backbone_model: network passed to TFEM as its feature backbone.
    enhancement_coefficient: forwarded to TFEM.
    patch_size: patch side length for the patch-attention module.
    embed_dim: projection width of the patch-attention module. That module
      folds its ``embed_dim``-wide output back into
      ``[channels, patch_size, patch_size]`` patches, so this is expected to
      equal ``output_channels * patch_size**2``.
    verbose: when True, print a stage marker after each sub-module on every
      forward pass (debug aid). Off by default to keep training output readable.

  Returns (from ``forward``): the classifier output for ``x``, one row of two
  scores per sample.

  NOTE: a previously constructed ``self.base_model = efficientnet_b0(pretrained=True)``
  was never used in ``forward``; it only triggered an extra weight download and
  added unused parameters to the optimizer and state_dict, so it was removed.
  """

  def __init__(self,output_channels,backbone_model,enhancement_coefficient=0.1,patch_size=1,embed_dim=1024,verbose=False):
    super(deep_fake_detection,self).__init__()
    self.verbose=verbose
    self.tfem=TFEM(output_channels,backbone_model,enhancement_coefficient)
    self.multi_scale_feature_enhancement_module=multi_scale_feature_enhancement_module(patch_size,embed_dim)
    self.fdm=FDM(output_channels)
    self.classifier=SimpleBinaryClassifier(input_channels=output_channels)

  def _log(self,stage):
    # Emit a stage marker only when debugging was requested.
    if self.verbose:
      print(stage)

  def forward(self,x):
    l=self.tfem(x)
    self._log("tfem")
    l=self.multi_scale_feature_enhancement_module(l)
    self._log("msfe")
    l=self.fdm(l)
    self._log("fdm")
    l=self.classifier(l)
    self._log("classifier")

    return l


In [None]:
# Random batch of two 224x224 RGB "images" for an end-to-end smoke test of the detector.
# Note: this rebinds `random_tensor`, which earlier cells used for feature-map-shaped tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 2
input_channels = 3  # for RGB image
height = 224
width = 224
random_tensor = torch.randn(batch_size, input_channels, height, width).to(device)

In [None]:
dfd=deep_fake_detection(output_channels=1024,backbone_model=model,enhancement_coefficient=0.1,patch_size=1,embed_dim=1024).to(device)



In [None]:
out=dfd(random_tensor)

back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier


In [None]:
out.shape

torch.Size([2, 2])

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score

In [None]:

!unzip /content/drive/MyDrive/FaceForensic.zip
def txt_file(data_dir, output_file):
    """Write a ``<path> <label>`` manifest for every ``.png`` directly in ``data_dir``.

    The label is 1 when the file name contains ``"fake"`` and 0 otherwise.
    Entries are written in sorted file-name order so the manifest is identical
    from run to run (``os.listdir`` returns entries in arbitrary order).

    Args:
        data_dir: directory to scan (not recursive).
        output_file: path of the manifest to create or overwrite.
    """
    with open(output_file, 'w') as f:
        for filename in sorted(os.listdir(data_dir)):
            if filename.endswith(".png"):
                label = 1 if "fake" in filename else 0
                f.write(f"{os.path.join(data_dir, filename)} {label}\n")

# Locations of the extracted train / test images and of the manifests listing them.
data_dir_train = "/content/FaceForensic/train"
data_dir_test = "/content/FaceForensic/test"
output_file_train = "dataset.txt"
output_file_test = "dataset_test.txt"

# (Re)generate both manifests; the files are written relative to the current working directory.
txt_file(data_dir_train, output_file_train)
txt_file(data_dir_test, output_file_test)

class MyDataset(Dataset):
    """Image dataset backed by a text manifest of ``<path> <label>`` lines.

    Args:
        txt_path: manifest file; each non-blank line is an image path followed
            by whitespace and an integer label. The label is taken from the
            *last* whitespace-separated field, so paths may contain spaces.
        transform: optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Raises:
        ValueError: if a non-blank line has no label field or the label is not
            an integer.
    """

    def __init__(self, txt_path, transform=None):
        self.imgs = []
        with open(txt_path, 'r') as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate blank / trailing empty lines.
                    continue
                # Split from the right: only the final field is the label.
                path, label = line.rsplit(maxsplit=1)
                self.imgs.append((path, int(label)))
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the manifest entry at ``index``."""
        fn, label = self.imgs[index]
        # convert() yields a new image, so the underlying file handle can be
        # closed right away instead of leaking one per sample.
        with Image.open(fn) as handle:
            img = handle.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return len(self.imgs)

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel.
# NOTE(review): the mean/std values are presumably the ImageNet channel statistics
# matching the ImageNet-pretrained backbone - confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Both splits share the same (augmentation-free) transform.
train_dataset = MyDataset(txt_path=output_file_train, transform=transform)
test_dataset = MyDataset(txt_path=output_file_test, transform=transform)

Archive:  /content/drive/MyDrive/FaceForensic.zip
replace FaceForensic/test/fake_004_982_frame154.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
# drop_last=True on the training loader: a trailing batch with a single sample
# makes BatchNorm fail in training mode once a feature map is reduced to 1x1
# ("Expected more than 1 value per channel when training").
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True)
# Evaluation runs in eval mode, so every test sample is kept.
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

In [None]:
len(train_dataset)

5876

In [None]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Keep the detector under its own name. Rebinding `model` (the EfficientNet
# backbone) to the detector made this cell fail on a second run, because the
# detector would then be passed in as its own backbone.
detector = deep_fake_detection(output_channels=1024,backbone_model=model,enhancement_coefficient=0.1,patch_size=1,embed_dim=1024).to(device)

# One dry-run forward pass (eval mode, no gradients) so that any layers created
# lazily on first use exist before the optimizer collects the parameters.
detector.eval()
with torch.no_grad():
    warmup_images, _ = next(iter(train_loader))
    detector(warmup_images.to(device))

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(detector.parameters(), lr=0.001)

num_epochs = 1
for epoch in range(num_epochs):
    detector.train()
    running_loss = 0.0
    samples_seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device).long()

        outputs = detector(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        samples_seen += images.size(0)

    # Report the sample-weighted mean loss of the epoch rather than the loss of
    # whichever batch happened to come last.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / max(samples_seen, 1):.4f}")

# ---- evaluation on the held-out split ----
detector.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        outputs = detector(images.to(device))
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())

accuracy = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {accuracy * 100:.2f}%")



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe
fdm
classifier
back_bone torch.Size([2, 1792, 7, 7])
tfem_a torch.Size([2, 1024, 7, 7])
tfem_b torch.Size([2, 1024, 7, 7])
tfem
msfe

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 1024, 1, 1])