In [1]:
# Base folder of the project; every dataset and checkpoint path below is built from it.
# (The path looks like a mounted Google Drive in Colab -- adjust when running elsewhere.)
root_dir = "/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/"

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vit_b_16
import torch.optim as optim
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import random
import torch.fft as fft

In [3]:
# Seed every random number generator the notebook relies on
# (Python, NumPy, PyTorch CPU and the current CUDA device) with one value.
seed = 43
for _seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seeder(seed)
del _seeder

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Dataset locations, all relative to root_dir.
dataset_root_dir = f"{root_dir}DeepfakeEmpiricalStudy/dataset/"
temp_dataset_root_dir = f"{root_dir}dataset_small/"

# Training and validation both come from the CELEB folder.
train_dir = f"{dataset_root_dir}CELEB/train"
val_dir = f"{dataset_root_dir}CELEB/val"

# One test split per dataset folder, evaluated in this order.
test_dirs = [
    f"{dataset_root_dir}{dataset_name}/test"
    for dataset_name in ("CELEB-M", "DF", "DFD", "F2F", "FS-I", "NT-I")
]

# Folder that receives the saved model checkpoints.
models_root_dir = f"{root_dir}DeepfakeEmpiricalStudy_Models/"

In [5]:
# Training hyper-parameters.
batch_size = 64
num_epochs = 5
learning_rate = 1e-4

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Per-channel normalisation statistics (these match the commonly used ImageNet values).
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Preprocessing shared by the train, validation and test images:
# fixed 224x224 resize -> tensor -> per-channel normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
)

# ImageFolder derives the labels from the sub-folder names of each split.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)

# Only the training data is shuffled; validation keeps the on-disk order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# class TransformerBasedModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super(TransformerBasedModel, self).__init__()
#         self.vit = vit_b_16(pretrained=True)
#         #self.vit.heads = nn.Linear(self.vit.heads.in_features, num_classes)

#     def forward(self, x):
#         return self.vit(x)

In [6]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    Inputs are laid out as ``(batch, sequence, embed_size)``. The embedding is
    split into ``heads`` chunks of ``head_dim = embed_size // heads`` features;
    each chunk is projected by the (head-shared) ``values`` / ``keys`` /
    ``queries`` linear layers, attended independently, and the heads are then
    concatenated and mixed by ``fc_out``.

    Args:
        embed_size: Size of the last dimension of the inputs.
        heads: Number of attention heads; must divide ``embed_size``.

    Raises:
        AssertionError: If ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` positions to ``keys`` / ``values`` positions.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to ``(N, heads, query_len, key_len)``;
                positions where it equals 0 are excluded from attention.
                Pass ``None`` for unmasked attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        # Apply the learned per-head projections. They were previously created
        # in __init__ but never used, which left them as dead parameters.
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # energy[n, h, q, k] = <queries[n, q, h, :], keys[n, k, h, :]>
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # The dot product runs over head_dim features, so that is the size
        # used for the 1/sqrt(d_k) scaling.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of the values, then merge the heads back together.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """One encoder block: attention followed by a position-wise feed-forward net.

    Each of the two sub-layers is wrapped as
    ``dropout(LayerNorm(sublayer_output + residual))``.

    Args:
        embed_size: Feature size of the tokens.
        heads: Number of attention heads handed to ``SelfAttention``.
        dropout: Dropout probability applied after each LayerNorm.
        forward_expansion: Width multiplier of the hidden feed-forward layer.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and the feed-forward net, each with a residual connection.

        The attention residual is taken from ``query``; the result has the
        same shape as ``query``.
        """
        attended = self.attention(value, key, query, mask)
        residual = self.dropout(self.norm1(attended + query))

        transformed = self.feed_forward(residual)
        return self.dropout(self.norm2(transformed + residual))

class FreqNetSimpleTransformer(nn.Module):
    """ViT-style classifier that operates on frequency-processed images.

    Pipeline implemented by ``forward``:
      1. high-pass filter the image in the 2-D Fourier domain;
      2. apply learnable 3x3 convolutions to the amplitude and phase spectra;
      3. cut the result into non-overlapping patches and embed them linearly;
      4. prepend a learnable class token and add learnable positional embeddings;
      5. run the tokens through a stack of ``TransformerBlock`` layers;
      6. classify from the class-token output.

    Args:
        num_classes: Number of output logits.
        patch_size: Side length of the square patches.
        im_width: Expected input width in pixels.
        im_height: Expected input height in pixels.
        d_model: Token embedding size.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        dropout: Dropout probability used inside the blocks.
    """

    def __init__(self, num_classes=2, patch_size=16, im_width=224, im_height=224, d_model=768, num_heads=8, num_layers=6, dropout=0.1):
        super(FreqNetSimpleTransformer, self).__init__()

        # Parameters
        self.patch_size = patch_size
        self.im_width = im_width
        self.im_height = im_height

        # High-pass masks. Registered as non-persistent buffers so they follow
        # the module across devices without adding keys to the state dict
        # (existing checkpoints keep loading unchanged).
        # `high_pass_filter` (patch-sized) is not used by forward(); it is kept
        # because it was part of the original attribute set.
        self.register_buffer("high_pass_filter", self.create_high_pass_filter(self.patch_size), persistent=False)
        self.register_buffer("high_pass_filter1", self.create_high_pass_filter(self.im_height, self.im_width), persistent=False)

        # Frequency convolutional layers for amplitude and phase
        self.freq_conv_amp = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.freq_conv_phase = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

        # Positional embedding and patch embedding.
        # One position per patch plus one for the class token; identical to the
        # previous (im_width // patch_size) ** 2 + 1 for square images.
        num_patches = (im_height // patch_size) * (im_width // patch_size)
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + 1, 1, d_model))
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.patch_to_embedding = nn.Linear(patch_size * patch_size * 3, d_model)

        # Transformer layers
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, dropout=dropout, forward_expansion=4) for _ in range(num_layers)]
        )

        # Final classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, im_height, im_width)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # Step 1: Convert input images to the frequency domain and apply high-pass filter
        x_freq = self.apply_fft_highpass(x)

        # Step 2: Apply frequency convolution to the high-frequency components
        x_freq_convolved = self.frequency_convolution(x_freq)

        # Step 3: Convert the image into patches and embed -> (n_patches, batch, d_model)
        x_patches = self.create_patches(x_freq_convolved)

        # Step 4: Prepend the class token and add positional encoding
        n_patches, batch_size, _ = x_patches.shape
        cls_tokens = self.cls_token.expand(-1, batch_size, -1)  # repeat the class token for each sample
        x_patches = torch.cat((cls_tokens, x_patches), dim=0)  # (n_patches + 1, batch, d_model)
        # pos_embedding is (n_patches + 1, 1, d_model) and broadcasts over the batch axis
        x_patches = x_patches + self.pos_embedding[:n_patches + 1, :, :]

        # Step 5: Pass through transformer layers.
        # The attention blocks treat dim 0 as the batch and dim 1 as the
        # sequence, so switch to (batch, tokens, d_model) first. Feeding the
        # sequence-first tensor made attention run across the samples of a
        # batch instead of across the patches of one image, so the class token
        # never saw any image content.
        x_patches = x_patches.transpose(0, 1).contiguous()
        for transformer_block in self.transformer_blocks:
            x_patches = transformer_block(x_patches, x_patches, x_patches, mask=None)

        # Step 6: Classification using the cls_token output (token index 0 of every sample)
        out = self.fc(x_patches[:, 0])

        return out

    def create_patches(self, x):
        """
        Convert input images to patches and flatten them for transformer input.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(n_patches, batch, d_model)`` with one embedded
            token per non-overlapping ``patch_size`` x ``patch_size`` patch.
        """
        batch_size, channels, height, width = x.shape
        # (batch, channels, n_rows, n_cols, patch_size, patch_size)
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        # (batch, channels, n_patches, patch_size * patch_size)
        patches = patches.contiguous().view(batch_size, channels, -1, self.patch_size * self.patch_size)  # Flatten patches
        # (n_patches, batch, channels * patch_size * patch_size)
        patches = patches.permute(2, 0, 1, 3).contiguous().view(-1, batch_size, self.patch_size * self.patch_size * channels)  # Rearrange for transformer
        patches = self.patch_to_embedding(patches)
        return patches

    def apply_fft_highpass(self, x):
        """
        Convert image to frequency domain, apply high-pass filter, and convert back.

        Args:
            x: Tensor whose last two dimensions are ``(im_height, im_width)``.

        Returns:
            Real tensor of the same shape as ``x``.
        """
        spatial_dims = (-2, -1)
        x_fft = fft.fftn(x, dim=spatial_dims)  # Apply FFT over spatial dimensions (height, width)
        # Shift zero frequency to the center. Only the spatial axes are shifted;
        # rolling the batch/channel axes as well was redundant work.
        x_fft_shift = fft.fftshift(x_fft, dim=spatial_dims)

        # Apply high-pass filter to remove low-frequency components
        x_fft_high = x_fft_shift * self.high_pass_filter1.to(x.device)

        # Inverse FFT: Convert back to the spatial domain
        x_fft_high_shifted = fft.ifftshift(x_fft_high, dim=spatial_dims)  # Shift frequencies back
        x_ifft = torch.real(fft.ifftn(x_fft_high_shifted, dim=spatial_dims))  # Inverse FFT

        return x_ifft

    def create_high_pass_filter(self, patch_size, width=None):
        """
        Create a binary high-pass mask for a centred (fft-shifted) spectrum.

        Args:
            patch_size: Height of the mask (and its width when ``width`` is omitted).
            width: Optional width of the mask; defaults to ``patch_size``.

        Returns:
            Float tensor of shape ``(patch_size, width)`` filled with ones,
            except for a centred block spanning half of each side, which is 0.
        """
        height = patch_size
        width = height if width is None else width
        mask = torch.ones(height, width)
        center_row, center_col = height // 2, width // 2
        mask[center_row - height // 4:center_row + height // 4,
             center_col - width // 4:center_col + width // 4] = 0  # Zero central region to keep high-frequencies
        return mask

    def frequency_convolution(self, x):
        """
        Apply convolutional layers in the frequency domain on amplitude and phase spectra.

        Args:
            x: Real tensor of shape ``(batch, 3, height, width)``.

        Returns:
            Real tensor of the same shape as ``x``.
        """
        x_fft = fft.fftn(x, dim=(-2, -1))  # FFT on spatial dimensions (height, width)

        amp = torch.abs(x_fft)  # Amplitude spectrum
        phase = torch.angle(x_fft)  # Phase spectrum

        # Apply convolutions in the frequency space
        amp_conv = self.freq_conv_amp(amp)  # Convolution on amplitude
        phase_conv = self.freq_conv_phase(phase)  # Convolution on phase

        # Reconstruct the feature maps using the modified amplitude and phase.
        # A convolution output can be negative, and torch.polar leaves negative
        # magnitudes undefined, so spell out amp * (cos(phase) + i * sin(phase)).
        x_fft_new = torch.complex(amp_conv * torch.cos(phase_conv), amp_conv * torch.sin(phase_conv))

        # Inverse FFT: Convert back to spatial domain
        x_ifft = torch.real(fft.ifftn(x_fft_new, dim=(-2, -1)))

        return x_ifft


# Build the network with its default-sized configuration and move it to the
# selected device.
model = FreqNetSimpleTransformer(
    num_classes=2,
    patch_size=16,
    im_width=224,
    im_height=224,
    d_model=768,
    num_heads=8,
    num_layers=6,
    dropout=0.1,
)
model = model.to(device)

# Cross-entropy over the class logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [7]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Relies on the notebook-level names ``device``, ``models_root_dir`` and
    ``evaluate_model``. Whenever the validation accuracy improves, the model's
    ``state_dict`` is written to
    ``models_root_dir + 'best_vit_model_freqnet2.pth'``.

    Args:
        model: Network to optimise; called as ``model(inputs)``.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` batches passed to ``evaluate_model``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser that updates the model's parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        None. Progress is printed once per epoch.
    """
    best_acc = 0.0

    for epoch in range(num_epochs):
        # evaluate_model() switches the model to eval mode at the end of every
        # epoch, so training mode (dropout etc.) has to be re-enabled per
        # epoch; enabling it once before the loop only covered the first epoch.
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # loss is a batch mean: weight it by the batch size so that
            # running_loss / total is the per-sample mean over the epoch.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        val_acc = evaluate_model(model, val_loader, criterion)[0]

        print("Epoch "+str(epoch+1)+", Loss: "+str(running_loss/total)+", Train Accuracy: "+str(train_acc)+", Val Accuracy: "+str(val_acc))

        # Keep only the weights that generalise best so far.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), models_root_dir + 'best_vit_model_freqnet2.pth')
            print('Model saved!')

    print(f"Training complete. Best validation accuracy: {best_acc:.4f}")

def evaluate_model(model, loader, criterion):
    """Compute the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode and run without gradient tracking. Its
    previous train/eval mode is restored afterwards, so calling this between
    training epochs no longer leaves the model in eval mode.

    Relies on the notebook-level ``device``.

    Args:
        model: Network called as ``model(inputs)``; must return per-class scores
            with the classes along dimension 1.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(accuracy, labels, predictions)`` where ``accuracy`` is the
        fraction of correct predictions and the other two are NumPy arrays in
        loader order.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                # Predicted class = index of the highest score.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    return correct / total, np.array(all_labels), np.array(all_preds)

In [8]:
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """Show a confusion matrix as an annotated heatmap.

    Args:
        cm: 2-D array of integer counts (cells are formatted with ``'d'``).
        classes: Tick labels used for both axes.
        title: Figure title.
    """
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show()

In [9]:
# Run the training loop; the best checkpoint is written by train_model itself.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

  0%|          | 1/250 [01:01<4:14:39, 61.37s/it]


KeyboardInterrupt: 

In [None]:
# Restore the best checkpoint. train_model() saves it as
# 'best_vit_model_freqnet2.pth'; the previously referenced 'best_vit_model.pth'
# is never written by this notebook.
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
checkpoint_path = models_root_dir + 'best_vit_model_freqnet2.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

all_labels_combined = []
all_preds_combined = []
test_accuracies = []

# NOTE(review): ImageFolder assigns class indices by sorted folder name.
# The tick labels below assume index 0 is 'real' and index 1 is 'fake' --
# confirm against test_dataset.class_to_idx.
for test_dir in test_dirs:
    test_dataset = datasets.ImageFolder(test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    test_acc, all_labels, all_preds = evaluate_model(model, test_loader, criterion)
    print(f"Test Accuracy for {test_dir}: {test_acc:.4f}")
    test_accuracies.append(test_acc)

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    plot_confusion_matrix(cm, classes=['real', 'fake'], title=f'Confusion Matrix for {test_dir}')

    all_labels_combined.extend(all_labels)
    all_preds_combined.extend(all_preds)

cm_combined = confusion_matrix(all_labels_combined, all_preds_combined, labels=[0, 1])
# Unweighted mean of the per-dataset accuracies, reusing the values computed in
# the loop instead of evaluating every test set a second time.
print(f"Average Accuracy: {np.mean(test_accuracies):.4f}")
plot_confusion_matrix(cm_combined, classes=['real', 'fake'], title='Combined Confusion Matrix')

  model.load_state_dict(torch.load(models_root_dir + 'best_vit_model.pth'))


KeyboardInterrupt: 