In [None]:
import os
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torchvision import datasets, transforms, models
import torch.optim as optim

from matplotlib import pyplot as plt
import pandas as pd
from PIL import Image
import numpy as np
import cv2
import shutil

from sklearn.model_selection import train_test_split

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Using device: {device}')

In [None]:
# Source folders holding the real and the fake images.
real_images_path = '/kaggle/input/fake-or-real-dataset/train/real_images'
fake_images_path = '/kaggle/input/fake-or-real-dataset/train/fake_images'

# os.listdir returns entries in arbitrary order; sorting first makes the
# seeded splits below select the same files on every run.
real_images = sorted(os.listdir(real_images_path))
fake_images = sorted(os.listdir(fake_images_path))

# 50 files per class go to training; the remainder is split further below.
real_train, real_temp = train_test_split(real_images, train_size=50, random_state=42)
fake_train, fake_temp = train_test_split(fake_images, train_size=50, random_state=42)

# 10 files per class are held out for testing; validation keeps at most 100
# of the remaining files.
real_val, real_test = train_test_split(real_temp, test_size=10, random_state=42)
real_val = real_val[:100]  # Limit to 100 samples

fake_val, fake_test = train_test_split(fake_temp, test_size=10, random_state=42)
fake_val = fake_val[:100]  # Limit to 100 samples

In [None]:
class PatchSelectionModule(nn.Module):
    """Pick the highest-scoring square windows of an activation map and express
    them as pixel boxes in the original image.

    Scores are the average activation inside a ``patch_size`` x ``patch_size``
    window that slides over the map with step ``stride``.

    Args:
        patch_size: Side length of the scoring window, in activation-map cells.
        num_patches: Fraction of the scored windows to keep (values above 1 are
            capped at "all windows").
        stride: Step of the scoring window, in activation-map cells.
    """

    def __init__(self, patch_size=8, num_patches=0.75, stride=2):
        super(PatchSelectionModule, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.stride = stride

    def calculate_patch_scores(self, activation_map):
        """Return the mean activation of every sliding window."""
        return F.avg_pool2d(activation_map, self.patch_size, stride=self.stride)

    def non_max_suppression(self, coords, scores, threshold=0.5):
        """Greedy non-maximum suppression.

        Args:
            coords: Boxes shaped (N, 2, 2) as ((x1, y1), (x2, y2)); a nested
                list or a tensor.
            scores: 1-D tensor with one score per box.
            threshold: A box is dropped when its IoU with an already kept box
                is greater than or equal to this value.

        Returns:
            Tuple ``(kept_boxes, kept_scores)`` ordered by decreasing score.
        """
        boxes = torch.as_tensor(coords, device=scores.device)
        if boxes.numel() == 0:
            return boxes, scores

        order = torch.argsort(scores, descending=True)
        kept = []
        while order.numel() > 0:
            best = order[0]
            kept.append(best)
            order = order[1:]
            if order.numel() == 0:
                break
            overlaps = self.iou(boxes[best], boxes[order])
            order = order[overlaps < threshold]

        # Index with one tensor: a Python list of 0-d tensors may be read as a
        # tuple of per-dimension indices instead of a list of rows.
        kept = torch.stack(kept)
        return boxes[kept], scores[kept]

    def iou(self, boxA, boxesB):
        """Intersection-over-union of one box (2, 2) against many (M, 2, 2)."""
        xA1, yA1 = boxA[0]
        xA2, yA2 = boxA[1]
        xB1, yB1 = boxesB[:, 0, 0], boxesB[:, 0, 1]
        xB2, yB2 = boxesB[:, 1, 0], boxesB[:, 1, 1]
        # Negative extents mean the boxes do not overlap along that axis.
        inter_w = (torch.min(xA2, xB2) - torch.max(xA1, xB1)).clamp(min=0)
        inter_h = (torch.min(yA2, yB2) - torch.max(yA1, yB1)).clamp(min=0)
        interArea = inter_w * inter_h
        boxAArea = (xA2 - xA1) * (yA2 - yA1)
        boxBArea = (xB2 - xB1) * (yB2 - yB1)
        return interArea / (boxAArea + boxBArea - interArea)

    def calc_topk_coords(self, scores, threshold=0.5):
        """Return the top-scoring windows of every image.

        Args:
            scores: Score map shaped (B, 1, H, W).
            threshold: Unused; kept so existing callers keep working.

        Returns:
            ``(topk_scores, topk_coords)`` where ``topk_coords[b]`` is a list
            of ``(row, col)`` positions in the score map, best first.
        """
        flattened_scores = scores.flatten(start_dim=1)
        num_cells = flattened_scores.shape[1]
        k = min(int(self.num_patches * num_cells), num_cells)
        if num_cells > 0:
            # Small score maps can truncate to zero patches; keep at least one
            # so downstream stacking never receives an empty list.
            k = max(k, 1)
        topk_scores, topk_indices = torch.topk(flattened_scores, k, dim=1)

        # A row-major flat index decodes as (index // width, index % width).
        width = scores.size(-1)
        # One tolist() transfer instead of one .item() call per index.
        topk_coords = [
            [divmod(flat_index, width) for flat_index in image_indices]
            for image_indices in topk_indices.tolist()
        ]
        return topk_scores, topk_coords

    def get_patch_coordinates(self, activation_map, original_image_size):
        """Map the selected windows to pixel boxes of the original image.

        Args:
            activation_map: Tensor shaped (B, 1, H, W).
            original_image_size: ``(height, width)`` of the input image.

        Returns:
            For every image a list of ``((x1, y1), (x2, y2))`` integer boxes.
        """
        scores = self.calculate_patch_scores(activation_map)
        _, topk_coords = self.calc_topk_coords(scores)
        scale_h = original_image_size[0] / activation_map.size(2)
        scale_w = original_image_size[1] / activation_map.size(3)
        patch_coords = []
        for coords in topk_coords:
            image_coords = []
            for (row, col) in coords:
                # Score cell (row, col) summarises the activation window that
                # starts at (row * stride, col * stride).
                top = row * self.stride
                left = col * self.stride
                top_left = (int(left * scale_w), int(top * scale_h))
                bottom_right = (
                    int((left + self.patch_size) * scale_w),
                    int((top + self.patch_size) * scale_h),
                )
                image_coords.append((top_left, bottom_right))
            patch_coords.append(image_coords)
        return patch_coords

    def get_patches(self, patch_coordinates, original_image):
        """Crop the boxes out of the image(s).

        Args:
            patch_coordinates: Per-image lists of ``((x1, y1), (x2, y2))``.
            original_image: Either a batch (B, C, H, W), in which case entry
                ``b`` of ``patch_coordinates`` is cropped from image ``b``, or
                a single (C, H, W) image shared by every entry.

        Returns:
            Tensor of stacked crops; all crops must have the same size.
        """
        batched = original_image.dim() == 4
        patches = []
        for image_index, coords in enumerate(patch_coordinates):
            image = original_image[image_index] if batched else original_image
            image_patches = [
                image[:, top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
                for (top_left, bottom_right) in coords
            ]
            patches.append(torch.stack(image_patches))
        return torch.stack(patches)

    def forward(self, activation_map, original_image_size):
        """Collapse the channel axis by summation and return the pixel boxes."""
        activation_map = activation_map.sum(dim=1, keepdim=True)
        return self.get_patch_coordinates(activation_map, original_image_size)

class GlobalBranch(nn.Module):
    """Whole-image branch: a truncated ResNet-50 followed by pooling and a
    linear projection to a 128-dimensional embedding.
    """

    def __init__(self):
        super(GlobalBranch, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the last two, so the network ends on
        # a spatial feature map rather than on class scores.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, 128)

    def forward(self, x):
        """Return ``(feature_map, global_embeddings)`` for the batch ``x``."""
        feature_map = self.resnet(x)
        pooled = self.avg_pool(feature_map)
        flat = pooled.view(pooled.size(0), -1)
        return feature_map, self.fc(flat)

class LocalEmbeddingExtractor(nn.Module):
    """Patch branch: crops boxes out of each image and embeds every crop with
    a truncated ResNet-50, pooling and a linear projection to 128 dimensions.
    """

    def __init__(self):
        super(LocalEmbeddingExtractor, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Drop the last two child modules so the backbone yields feature maps.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, 128)

    def forward(self, x, coordinates):
        """Embed the crops of every image.

        Args:
            x: Image batch shaped (B, C, H, W).
            coordinates: ``coordinates[b]`` lists the
                ``((x1, y1), (x2, y2))`` boxes to crop from image ``b``.

        Returns:
            Stack of per-image embedding matrices. Images without boxes are
            skipped, and every remaining image must yield the same number of
            equally sized crops for the stacking to succeed.
        """
        num_images, _, _, _ = x.size()
        per_image_embeddings = []
        for image_index in range(num_images):
            # Slicing with image_index:image_index + 1 keeps the batch axis,
            # so the crops can be concatenated into one mini-batch.
            crops = [
                x[image_index:image_index + 1, :, top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
                for (top_left, bottom_right) in coordinates[image_index]
            ]
            if not crops:
                continue
            crop_batch = torch.cat(crops, dim=0)
            pooled = self.avg_pool(self.resnet(crop_batch))
            flat = pooled.view(pooled.size(0), -1)
            per_image_embeddings.append(self.fc(flat))
        return torch.stack(per_image_embeddings)

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split over ``num_heads`` heads.

    Args:
        d_model: Size of the input and output feature axis; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, ..., d_model) to (B, heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K``/``V``.

        Inputs whose last axis is ``d_model`` are accepted; a 2-D input
        (B, d_model) is treated as a sequence of length one.

        Returns:
            Tensor shaped (B, seq, d_model).
        """
        batch_size = Q.size(0)
        queries = self._split_heads(self.Wq(Q), batch_size)
        keys = self._split_heads(self.Wk(K), batch_size)
        values = self._split_heads(self.Wv(V), batch_size)

        # Create the scaling factor on the same device as the activations so
        # the division never mixes devices.
        scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32, device=queries.device))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        weights = F.softmax(scores, dim=-1)

        # Merge the heads back into a single feature axis.
        context = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.fc(context)

class AFFM(nn.Module):
    """Fuse the global embedding with the patch embeddings and emit one logit.

    Args:
        global_dim: Width of the global embedding.
        local_dim: Width of each patch embedding.
        num_heads: Number of heads of the attention layer.
    """

    def __init__(self, global_dim, local_dim, num_heads):
        super(AFFM, self).__init__()
        fused_dim = global_dim + local_dim
        self.multihead_attention = MultiHeadAttention(fused_dim, num_heads)
        self.classifier = nn.Linear(fused_dim, 1)  # Assuming binary classification

    def forward(self, global_embeddings, local_embeddings):
        """Return the classifier output for the fused embeddings.

        Args:
            global_embeddings: Tensor shaped (B, global_dim).
            local_embeddings: Tensor shaped (B, num_patches, local_dim).
        """
        # Average over the patch axis first: the global embedding is 2-D, so
        # concatenating it with the 3-D patch tensor directly would fail.
        pooled_local = local_embeddings.mean(dim=1)
        fused_input = torch.cat((global_embeddings, pooled_local), dim=1)
        attended = self.multihead_attention(fused_input, fused_input, fused_input)
        return self.classifier(attended)

class FullArchitecture(nn.Module):
    """End-to-end model: global branch, patch selection, patch embedding and
    fusion head chained together.
    """

    def __init__(self):
        super(FullArchitecture, self).__init__()
        self.global_branch = GlobalBranch()
        self.patch_selection_module = PatchSelectionModule()
        self.local_embedding_extractor = LocalEmbeddingExtractor()
        self.affm = AFFM(global_dim=128, local_dim=128, num_heads=8)

    def forward(self, x):
        """Run the full pipeline on the image batch ``x`` (B, C, H, W)."""
        image_hw = x.shape[2:]
        # The global feature map doubles as the saliency signal used to pick
        # the patches fed to the local branch.
        feature_map, global_embeddings = self.global_branch(x)
        boxes = self.patch_selection_module(feature_map, image_hw)
        local_embeddings = self.local_embedding_extractor(x, boxes)
        return self.affm(global_embeddings, local_embeddings)


In [None]:
# Build the <split>/<class> folder layout that ImageFolder reads later on.
for split_name in ('train', 'val', 'test'):
    for class_name in ('real', 'fake'):
        os.makedirs(os.path.join('/kaggle/working/dataset', split_name, class_name), exist_ok=True)

In [None]:
# (file names, source folder, destination folder) for every split/class pair.
copy_plan = [
    (real_train, real_images_path, '/kaggle/working/dataset/train/real'),
    (fake_train, fake_images_path, '/kaggle/working/dataset/train/fake'),
    (real_val, real_images_path, '/kaggle/working/dataset/val/real'),
    (fake_val, fake_images_path, '/kaggle/working/dataset/val/fake'),
    (real_test, real_images_path, '/kaggle/working/dataset/test/real'),
    (fake_test, fake_images_path, '/kaggle/working/dataset/test/fake'),
]

for file_names, source_dir, target_dir in copy_plan:
    # shutil.copy treats a missing destination as a file name, so make sure
    # the folder exists before copying into it.
    os.makedirs(target_dir, exist_ok=True)
    for file in file_names:
        shutil.copy(os.path.join(source_dir, file), target_dir)

In [None]:
# Training hyperparameters: number of epochs, learning rate, mini-batch size.
epochs, lr, batch_size = 100, 0.003, 32

In [None]:
# Resize every image to 512x512, convert it to a tensor and shift/scale each
# channel with mean 0.5 and std 0.5.
# NOTE(review): the backbones are created with pretrained=True; torchvision
# documents ImageNet mean/std for its pretrained weights — confirm that the
# 0.5/0.5 normalisation is intentional.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dir = '/kaggle/working/dataset/train'
val_dir = '/kaggle/working/dataset/val'
test_dir = '/kaggle/working/dataset/test'

# ImageFolder derives the label of each image from its sub-folder name.
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Use the configured batch size instead of repeating the literal; only the
# training data is shuffled.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')
print(f'Testing samples: {len(test_dataset)}')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# Pick the compute device: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class PatchSelectionModule(nn.Module):
    """Pick the highest-scoring square windows of an activation map and express
    them as pixel boxes in the original image.

    Scores are the average activation inside a ``patch_size`` x ``patch_size``
    window that slides over the map with step ``stride``.

    Args:
        patch_size: Side length of the scoring window, in activation-map cells.
        num_patches: Fraction of the scored windows to keep (values above 1 are
            capped at "all windows").
        stride: Step of the scoring window, in activation-map cells.
    """

    def __init__(self, patch_size=8, num_patches=0.75, stride=2):
        super(PatchSelectionModule, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.stride = stride

    def calculate_patch_scores(self, activation_map):
        """Return the mean activation of every sliding window."""
        return F.avg_pool2d(activation_map, self.patch_size, stride=self.stride)

    def non_max_suppression(self, coords, scores, threshold=0.5):
        """Greedy non-maximum suppression.

        Args:
            coords: Boxes shaped (N, 2, 2) as ((x1, y1), (x2, y2)); a nested
                list or a tensor.
            scores: 1-D tensor with one score per box.
            threshold: A box is dropped when its IoU with an already kept box
                is greater than or equal to this value.

        Returns:
            Tuple ``(kept_boxes, kept_scores)`` ordered by decreasing score.
        """
        boxes = torch.as_tensor(coords, device=scores.device)
        if boxes.numel() == 0:
            return boxes, scores

        order = torch.argsort(scores, descending=True)
        kept = []
        while order.numel() > 0:
            best = order[0]
            kept.append(best)
            order = order[1:]
            if order.numel() == 0:
                break
            overlaps = self.iou(boxes[best], boxes[order])
            order = order[overlaps < threshold]

        # Index with one tensor: a Python list of 0-d tensors may be read as a
        # tuple of per-dimension indices instead of a list of rows.
        kept = torch.stack(kept)
        return boxes[kept], scores[kept]

    def iou(self, boxA, boxesB):
        """Intersection-over-union of one box (2, 2) against many (M, 2, 2)."""
        xA1, yA1 = boxA[0]
        xA2, yA2 = boxA[1]
        xB1, yB1 = boxesB[:, 0, 0], boxesB[:, 0, 1]
        xB2, yB2 = boxesB[:, 1, 0], boxesB[:, 1, 1]
        # Negative extents mean the boxes do not overlap along that axis.
        inter_w = (torch.min(xA2, xB2) - torch.max(xA1, xB1)).clamp(min=0)
        inter_h = (torch.min(yA2, yB2) - torch.max(yA1, yB1)).clamp(min=0)
        interArea = inter_w * inter_h
        boxAArea = (xA2 - xA1) * (yA2 - yA1)
        boxBArea = (xB2 - xB1) * (yB2 - yB1)
        return interArea / (boxAArea + boxBArea - interArea)

    def calc_topk_coords(self, scores, threshold=0.5):
        """Return the top-scoring windows of every image.

        Args:
            scores: Score map shaped (B, 1, H, W).
            threshold: Unused; kept so existing callers keep working.

        Returns:
            ``(topk_scores, topk_coords)`` where ``topk_coords[b]`` is a list
            of ``(row, col)`` positions in the score map, best first.
        """
        flattened_scores = scores.flatten(start_dim=1)
        num_cells = flattened_scores.shape[1]
        k = min(int(self.num_patches * num_cells), num_cells)
        if num_cells > 0:
            # Small score maps can truncate to zero patches; keep at least one
            # so downstream stacking never receives an empty list.
            k = max(k, 1)
        topk_scores, topk_indices = torch.topk(flattened_scores, k, dim=1)

        # A row-major flat index decodes as (index // width, index % width).
        width = scores.size(-1)
        # One tolist() transfer instead of one .item() call per index.
        topk_coords = [
            [divmod(flat_index, width) for flat_index in image_indices]
            for image_indices in topk_indices.tolist()
        ]
        return topk_scores, topk_coords

    def get_patch_coordinates(self, activation_map, original_image_size):
        """Map the selected windows to pixel boxes of the original image.

        Args:
            activation_map: Tensor shaped (B, 1, H, W).
            original_image_size: ``(height, width)`` of the input image.

        Returns:
            For every image a list of ``((x1, y1), (x2, y2))`` integer boxes.
        """
        scores = self.calculate_patch_scores(activation_map)
        _, topk_coords = self.calc_topk_coords(scores)
        scale_h = original_image_size[0] / activation_map.size(2)
        scale_w = original_image_size[1] / activation_map.size(3)
        patch_coords = []
        for coords in topk_coords:
            image_coords = []
            for (row, col) in coords:
                # Score cell (row, col) summarises the activation window that
                # starts at (row * stride, col * stride).
                top = row * self.stride
                left = col * self.stride
                top_left = (int(left * scale_w), int(top * scale_h))
                bottom_right = (
                    int((left + self.patch_size) * scale_w),
                    int((top + self.patch_size) * scale_h),
                )
                image_coords.append((top_left, bottom_right))
            patch_coords.append(image_coords)
        return patch_coords

    def get_patches(self, patch_coordinates, original_image):
        """Crop the boxes out of the image(s).

        Args:
            patch_coordinates: Per-image lists of ``((x1, y1), (x2, y2))``.
            original_image: Either a batch (B, C, H, W), in which case entry
                ``b`` of ``patch_coordinates`` is cropped from image ``b``, or
                a single (C, H, W) image shared by every entry.

        Returns:
            Tensor of stacked crops; all crops must have the same size.
        """
        batched = original_image.dim() == 4
        patches = []
        for image_index, coords in enumerate(patch_coordinates):
            image = original_image[image_index] if batched else original_image
            image_patches = [
                image[:, top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
                for (top_left, bottom_right) in coords
            ]
            patches.append(torch.stack(image_patches))
        return torch.stack(patches)

    def forward(self, activation_map, original_image_size):
        """Collapse the channel axis by summation and return the pixel boxes."""
        activation_map = activation_map.sum(dim=1, keepdim=True)
        return self.get_patch_coordinates(activation_map, original_image_size)

class GlobalBranch(nn.Module):
    """Whole-image branch: a truncated ResNet-50 followed by pooling and a
    linear projection to a 128-dimensional embedding.
    """

    def __init__(self):
        super(GlobalBranch, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the last two, so the network ends on
        # a spatial feature map rather than on class scores.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, 128)

    def forward(self, x):
        """Return ``(feature_map, global_embeddings)`` for the batch ``x``."""
        feature_map = self.resnet(x)
        pooled = self.avg_pool(feature_map)
        flat = pooled.view(pooled.size(0), -1)
        return feature_map, self.fc(flat)

class LocalEmbeddingExtractor(nn.Module):
    """Patch branch: crops boxes out of each image and embeds every crop with
    a truncated ResNet-50, pooling and a linear projection to 128 dimensions.
    """

    def __init__(self):
        super(LocalEmbeddingExtractor, self).__init__()
        backbone = models.resnet50(pretrained=True)
        # Drop the last two child modules so the backbone yields feature maps.
        self.resnet = nn.Sequential(*list(backbone.children())[:-2])
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, 128)

    def forward(self, x, coordinates):
        """Embed the crops of every image.

        Args:
            x: Image batch shaped (B, C, H, W).
            coordinates: ``coordinates[b]`` lists the
                ``((x1, y1), (x2, y2))`` boxes to crop from image ``b``.

        Returns:
            Stack of per-image embedding matrices. Images without boxes are
            skipped, and every remaining image must yield the same number of
            equally sized crops for the stacking to succeed.
        """
        num_images, _, _, _ = x.size()
        per_image_embeddings = []
        for image_index in range(num_images):
            # Slicing with image_index:image_index + 1 keeps the batch axis,
            # so the crops can be concatenated into one mini-batch.
            crops = [
                x[image_index:image_index + 1, :, top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]
                for (top_left, bottom_right) in coordinates[image_index]
            ]
            if not crops:
                continue
            crop_batch = torch.cat(crops, dim=0)
            pooled = self.avg_pool(self.resnet(crop_batch))
            flat = pooled.view(pooled.size(0), -1)
            per_image_embeddings.append(self.fc(flat))
        return torch.stack(per_image_embeddings)

class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split over ``num_heads`` heads.

    Args:
        d_model: Size of the input and output feature axis; must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, ..., d_model) to (B, heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V):
        """Attend from ``Q`` over ``K``/``V``.

        Inputs whose last axis is ``d_model`` are accepted; a 2-D input
        (B, d_model) is treated as a sequence of length one.

        Returns:
            Tensor shaped (B, seq, d_model).
        """
        batch_size = Q.size(0)
        queries = self._split_heads(self.Wq(Q), batch_size)
        keys = self._split_heads(self.Wk(K), batch_size)
        values = self._split_heads(self.Wv(V), batch_size)

        # The scaling factor lives on the same device as the activations.
        scale = torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32, device=queries.device))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        weights = F.softmax(scores, dim=-1)

        # Merge the heads back into a single feature axis.
        context = torch.matmul(weights, values).transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.fc(context)

class AFFM(nn.Module):
    """Fuse the global embedding with the patch embeddings and emit one logit.

    Args:
        global_dim: Width of the global embedding.
        local_dim: Width of each patch embedding.
        num_heads: Number of heads of the attention layer.
    """

    def __init__(self, global_dim, local_dim, num_heads):
        super(AFFM, self).__init__()
        fused_dim = global_dim + local_dim
        self.multihead_attention = MultiHeadAttention(fused_dim, num_heads)
        self.classifier = nn.Linear(fused_dim, 1)  # Assuming binary classification

    def forward(self, global_embeddings, local_embeddings):
        """Return the classifier output for the fused embeddings.

        The patch embeddings are averaged over axis 1 before being
        concatenated with the global embedding along axis 1.
        """
        pooled_local = local_embeddings.mean(dim=1)
        fused_input = torch.cat((global_embeddings, pooled_local), dim=1)
        # NOTE(review): if fused_input is 2-D (B, fused_dim), the attention
        # layer sees a single token per sample — confirm this is intended.
        attended = self.multihead_attention(fused_input, fused_input, fused_input)
        return self.classifier(attended)

class FullArchitecture(nn.Module):
    """End-to-end model: global branch, patch selection, patch embedding and
    fusion head chained together.
    """

    def __init__(self):
        super(FullArchitecture, self).__init__()
        self.global_branch = GlobalBranch()
        self.patch_selection_module = PatchSelectionModule()
        self.local_embedding_extractor = LocalEmbeddingExtractor()
        self.affm = AFFM(global_dim=128, local_dim=128, num_heads=8)

    def forward(self, x):
        """Run the full pipeline on the image batch ``x`` (B, C, H, W)."""
        image_hw = x.shape[2:]
        # The global feature map doubles as the saliency signal used to pick
        # the patches fed to the local branch.
        feature_map, global_embeddings = self.global_branch(x)
        boxes = self.patch_selection_module(feature_map, image_hw)
        local_embeddings = self.local_embedding_extractor(x, boxes)
        return self.affm(global_embeddings, local_embeddings)

# Example usage
# Assuming the input tensor 'images' has shape (B, C, H, W)
# images = torch.randn((3, 3, 512, 512)).to(device)  # Example batch of 3 images
# model = FullArchitecture().to(device)
# output = model(images)
# print(output)

In [None]:
# Build the model and move its parameters to the selected device.
model = FullArchitecture().to(device)

# Binary classification on raw logits: BCEWithLogitsLoss applies the sigmoid
# internally. The learning rate comes from the hyperparameter cell instead of
# a repeated literal.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network producing one logit per sample.
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, float_labels)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        float: Mean of the per-batch losses, or 0.0 when the loader yields no
        batches. The value is also printed.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device).float()

        optimizer.zero_grad()

        # Flatten to (batch,) instead of squeeze(): squeeze() also removes the
        # batch axis of a single-sample batch, which then no longer matches
        # the shape of ``labels``.
        outputs = model(images).reshape(-1)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    epoch_loss = running_loss / num_batches if num_batches else 0.0
    print(f"Training Loss: {epoch_loss:.4f}")
    return epoch_loss

def test(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device).float()
            
            # Forward pass
            outputs = model(images)
            outputs = outputs.squeeze()  # Remove extra dimension if necessary
            
            # Calculate loss
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            
            # Calculate accuracy
            predicted = torch.round(torch.sigmoid(outputs))
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(test_loader)
    accuracy = correct / total
    print(f"Test Loss: {epoch_loss:.4f}, Test Accuracy: {accuracy:.4f}")

# Training driver: every epoch runs one training pass and then one evaluation pass.
# `train_loader` and `test_loader` are the DataLoader objects created in the data-loading cell.
# NOTE(review): evaluation uses `test_loader`; `val_loader` is not used in this loop —
# confirm that monitoring on the test split is intended.
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    train(model, train_loader, criterion, optimizer, device)
    test(model, test_loader, criterion, device)

In [None]:
import datetime as dt
def save_model(model, model_name, model_evaluation_history, output_dir='.'):
    """Save ``model`` twice: as a state dict and as a whole pickled module.

    The file names embed the model name, the current date/time and the
    supplied loss/accuracy so that several saved models are easy to tell apart.

    Args:
        model: The ``nn.Module`` to persist.
        model_name: Prefix used for both file names.
        model_evaluation_history: ``(loss, accuracy)`` pair written into the
            file names.
        output_dir: Folder receiving the files; created when missing.

    Returns:
        tuple[str, str]: ``(state_dict_path, full_model_path)``.
    """
    model_evaluation_loss, model_evaluation_accuracy = model_evaluation_history

    # Timestamp such as 2024_01_31__13_45_10.
    timestamp = dt.datetime.now().strftime('%Y_%m_%d__%H_%M_%S')

    file_stem = (
        f'{model_name}___Date_Time_{timestamp}'
        f'___Loss_{model_evaluation_loss}___Accuracy_{model_evaluation_accuracy}'
    )

    os.makedirs(output_dir, exist_ok=True)
    # Append the suffix before the extension so both files end in ".pth".
    state_dict_path = os.path.join(output_dir, f'{file_stem}_entire_dict.pth')
    full_model_path = os.path.join(output_dir, f'{file_stem}_entire.pth')

    torch.save(model.state_dict(), state_dict_path)
    # The whole-module file is a pickle: loading it needs the class
    # definitions to be importable and must only be done with trusted files.
    torch.save(model, full_model_path)
    return state_dict_path, full_model_path

In [None]:
# Persist the model under the name "Fusion". The (0, 0) tuple fills the
# loss/accuracy slots of the file name with zeros rather than measured values.
save_model(model, "Fusion", (0, 0))