In [None]:
# Project base folder; every dataset and checkpoint path in this notebook is
# built by appending to it, so it must keep its trailing slash.
root_dir = "/content/drive/MyDrive/Zero_Shot_DeepFake_Image_Classification/"

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import vit_b_16
import torch.optim as optim
from tqdm import tqdm
import os
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import random
import torch.fft as fft

In [None]:
# Seed every random number generator the notebook touches with one value.
seed = 43
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)

# cuDNN flags: deterministic mode on, benchmark mode off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Train/val splits are read from the `dataset_small` tree, while the test
# splits are read from the `DeepfakeEmpiricalStudy/dataset` tree.
dataset_root_dir = f"{root_dir}DeepfakeEmpiricalStudy/dataset/"
temp_dataset_root_dir = f"{root_dir}dataset_small/"

train_dir = f"{temp_dataset_root_dir}CELEB/train"
val_dir = f"{temp_dataset_root_dir}CELEB/val"

# One test folder per benchmark, evaluated in this order.
test_dirs = [
    f"{dataset_root_dir}{benchmark}/test"
    for benchmark in ("CELEB-M", "DF", "DFD", "F2F", "FS-I", "NT-I")
]

# Folder where model checkpoints are written and read.
models_root_dir = f"{root_dir}DeepfakeEmpiricalStudy_Models/"

In [None]:
# Training hyper-parameters.
batch_size, num_epochs, learning_rate = 64, 5, 1e-4

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pre-processing shared by every split: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std values given below.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

# ImageFolder derives the class labels from the sub-folder names.
train_dataset, val_dataset = (
    datasets.ImageFolder(split_dir, transform=transform)
    for split_dir in (train_dir, val_dir)
)

# Only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

# class TransformerBasedModel(nn.Module):
#     def __init__(self, num_classes=2):
#         super(TransformerBasedModel, self).__init__()
#         self.vit = vit_b_16(pretrained=True)
#         #self.vit.heads = nn.Linear(self.vit.heads.in_features, num_classes)

#     def forward(self, x):
#         return self.vit(x)

In [None]:
class FreqNetViT(nn.Module):
    """ViT-B/16 classifier fed with frequency-processed images.

    The forward pass
      1. removes low spatial frequencies from the input with a fixed FFT mask,
      2. applies learnable 3x3 convolutions to the amplitude and phase spectra
         of the filtered image and transforms the result back, and
      3. classifies that image with a pre-trained torchvision ViT-B/16 whose
         head is replaced by a ``num_classes``-way linear layer.

    Args:
        num_classes: Number of output logits.
        patch_size: Side length of the small mask stored in
            ``high_pass_filter`` (built here but not used by ``forward``).
        im_width: Side length of the square mask applied to whole images in
            ``apply_fft_highpass``; it must match the input's spatial size.
        im_height: Stored on the instance only; the image mask is built from
            ``im_width`` alone, so square inputs are expected.
    """

    def __init__(self, num_classes=2, patch_size=16, im_width=224, im_height=224):
        super().__init__()

        # Pre-trained ViT-B/16 backbone. `weights=` replaces the deprecated
        # `pretrained=True` flag and selects the ImageNet-1k checkpoint.
        self.vit = vit_b_16(weights="IMAGENET1K_V1")
        # Swap the original head for a fresh `num_classes`-way classifier.
        self.vit.heads = nn.Linear(self.vit.heads.head.in_features, num_classes)

        self.patch_size = patch_size
        self.im_width = im_width
        self.im_height = im_height

        # Fixed (non-learned) masks. Non-persistent buffers follow the module
        # across devices with .to()/.cuda() but add no keys to state_dict(),
        # so checkpoints saved before this change still load.
        self.register_buffer(
            "high_pass_filter1",
            self.create_high_pass_filter(self.im_width),
            persistent=False,
        )
        self.register_buffer(
            "high_pass_filter",
            self.create_high_pass_filter(self.patch_size),
            persistent=False,
        )

        # Learnable 3x3 convolutions applied to the amplitude / phase spectra.
        self.freq_conv_amp = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.freq_conv_phase = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the ViT logits for a batch of images ``x``.

        ``x`` must have 3 channels (required by the frequency convolutions)
        and a spatial size matching the image mask, e.g. [N, 3, 224, 224]
        with the default constructor arguments.
        """
        # Step 1: keep only the high-frequency content of the input.
        x_freq = self.apply_fft_highpass(x)

        # Step 2: learnable convolutions on the amplitude and phase spectra.
        x_freq_convolved = self.frequency_convolution(x_freq)

        # Step 3: classify the frequency-processed image with the ViT.
        return self.vit(x_freq_convolved)

    def apply_fft_highpass(self, x):
        """High-pass filter ``x`` in the Fourier domain.

        The 2-D FFT over the last two dimensions is centred, multiplied by
        ``high_pass_filter1`` (broadcast over the leading dimensions),
        un-centred and inverted. Only the real part is returned, with the
        same shape as ``x``.
        """
        spatial_dims = (-2, -1)

        # Shift only the spatial axes. Without `dim`, fftshift also rolls the
        # batch and channel axes, which costs time and has no effect on the
        # output because the mask is identical for every sample and channel.
        spectrum = fft.fftshift(fft.fftn(x, dim=spatial_dims), dim=spatial_dims)

        # Zero the centred low-frequency region.
        spectrum_high = spectrum * self.high_pass_filter1

        # Undo the centring and go back to the spatial domain.
        spectrum_high = fft.ifftshift(spectrum_high, dim=spatial_dims)
        return torch.real(fft.ifftn(spectrum_high, dim=spatial_dims))

    def create_high_pass_filter(self, patch_size):
        """Build a square binary high-pass mask for a centred spectrum.

        Returns a ``(patch_size, patch_size)`` tensor of ones whose central
        square, extending ``patch_size // 4`` to each side of the centre
        index ``patch_size // 2``, is set to zero.
        """
        # `mask` rather than `filter`, to avoid shadowing the builtin.
        mask = torch.ones(patch_size, patch_size)

        center = patch_size // 2
        half_extent = patch_size // 4
        mask[center - half_extent : center + half_extent,
             center - half_extent : center + half_extent] = 0

        return mask

    def frequency_convolution(self, x):
        """Convolve the amplitude and phase spectra of ``x`` and invert.

        The 2-D FFT of ``x`` is split into amplitude and phase, each is passed
        through its own 3x3 convolution, the two outputs are recombined into
        a complex spectrum, and the real part of its inverse FFT is returned
        (same shape as ``x``).
        """
        spatial_dims = (-2, -1)
        x_fft = fft.fftn(x, dim=spatial_dims)

        amp = torch.abs(x_fft)      # amplitude spectrum
        phase = torch.angle(x_fft)  # phase spectrum

        amp_conv = self.freq_conv_amp(amp)
        phase_conv = self.freq_conv_phase(phase)

        # NOTE(review): the amplitude convolution is unconstrained, so
        # `amp_conv` can contain negative values. Confirm against the
        # torch.polar documentation that negative magnitudes are acceptable,
        # or clamp / rectify the amplitude before recombining.
        x_fft_new = torch.polar(amp_conv, phase_conv)

        return torch.real(fft.ifftn(x_fft_new, dim=spatial_dims))


# Build the two-class network on the selected device, together with the
# cross-entropy loss and the Adam optimiser used for training.
model = FreqNetViT(num_classes=2)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)



In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Relies on the notebook-level names ``device``, ``models_root_dir`` and
    ``evaluate_model``.

    Args:
        model: Network to optimise; updated in place.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches,
            scored with ``evaluate_model`` after every epoch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Side effects:
        Prints one summary line per epoch and writes the state dict to
        ``models_root_dir + 'best_vit_model.pth'`` whenever the validation
        accuracy strictly improves on the best value so far (starting at 0.0).
    """
    best_acc = 0.0

    for epoch in range(num_epochs):
        # Validation runs in eval mode, so training mode is re-enabled at the
        # start of every epoch instead of relying on it surviving
        # evaluate_model(). Previously epochs 2+ could train in eval mode.
        model.train()

        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # The loss is weighted by batch size so that dividing by `total`
            # below gives the per-sample average.
            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = correct / total
        val_acc = evaluate_model(model, val_loader, criterion)[0]

        print(f"Epoch {epoch+1}, Loss: {running_loss/total}, Train Accuracy: {train_acc}, Val Accuracy: {val_acc}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), models_root_dir + 'best_vit_model.pth')
            print('Model saved!')

    print(f"Training complete. Best validation accuracy: {best_acc:.4f}")

def evaluate_model(model, loader, criterion):
    """Compute the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module mapping a batch of inputs to scores with the classes
            along dimension 1.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(accuracy, labels, predictions)``: the fraction of correctly
        classified samples and two NumPy arrays holding, in iteration order,
        the true and the predicted class index of every sample.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this between
    training epochs does not silently leave the model in eval mode.
    """
    was_training = model.training

    # Evaluate on whatever device the model lives on rather than on a
    # notebook-level global. A parameter-free model leaves batches in place.
    first_param = next(model.parameters(), None)

    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in loader:
                if first_param is not None:
                    inputs = inputs.to(first_param.device)
                    labels = labels.to(first_param.device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
    finally:
        # Restore the caller's mode even if a batch raised.
        model.train(was_training)

    return correct / total, np.array(all_labels), np.array(all_preds)

In [None]:
def plot_confusion_matrix(cm, classes, title='Confusion Matrix'):
    """Draw ``cm`` as an annotated heatmap and display it.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with 'd').
        classes: Tick labels used on both the x (predicted) and y (true) axes.
        title: Figure title.
    """
    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=classes,
        yticklabels=classes,
        ax=ax,
    )
    ax.set_title(title)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show()

In [None]:
# Run training; the weights with the best validation accuracy are saved
# to disk by train_model() as it goes.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=num_epochs,
)

100%|██████████| 8/8 [03:09<00:00, 23.74s/it]


Epoch 1, Loss: 0.8520173964500427, Train Accuracy: 0.528, Val Accuracy: 0.5
Model saved!


100%|██████████| 8/8 [00:11<00:00,  1.42s/it]


Epoch 2, Loss: 0.7063859400749206, Train Accuracy: 0.524, Val Accuracy: 0.5


100%|██████████| 8/8 [00:10<00:00,  1.37s/it]


Epoch 3, Loss: 0.7067483739852906, Train Accuracy: 0.472, Val Accuracy: 0.5


100%|██████████| 8/8 [00:10<00:00,  1.36s/it]


Epoch 4, Loss: 0.6962648062705994, Train Accuracy: 0.468, Val Accuracy: 0.5


100%|██████████| 8/8 [00:10<00:00,  1.37s/it]


Epoch 5, Loss: 0.6913810930252076, Train Accuracy: 0.514, Val Accuracy: 0.53
Model saved!
Training complete. Best validation accuracy: 0.5300


In [None]:
# NOTE(review): train_model() writes 'best_vit_model.pth', but this cell loads
# 'best_vit_model_freqnet1.pth'. Confirm which checkpoint is meant; on a fresh
# "Restart & Run All" the file below must already exist in models_root_dir.
checkpoint_path = models_root_dir + 'best_vit_model_freqnet1.pth'

# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True restricts unpickling to tensor data (a state dict).
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))

all_labels_combined = []
all_preds_combined = []
test_accuracies = []

# NOTE(review): the tick labels assume class index 0 is 'real' and 1 is 'fake'.
# ImageFolder assigns indices from the sorted sub-folder names, so verify this
# against `test_dataset.class_to_idx` for every test set.
class_names = ['real', 'fake']

for test_dir in test_dirs:
    test_dataset = datasets.ImageFolder(test_dir, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    test_acc, all_labels, all_preds = evaluate_model(model, test_loader, criterion)
    print(f"Test Accuracy for {test_dir}: {test_acc:.4f}")
    # Keep the per-dataset accuracy so the average below does not need a
    # second full inference pass over every test set.
    test_accuracies.append(test_acc)

    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    plot_confusion_matrix(cm, classes=class_names, title=f'Confusion Matrix for {test_dir}')

    all_labels_combined.extend(all_labels)
    all_preds_combined.extend(all_preds)

cm_combined = confusion_matrix(all_labels_combined, all_preds_combined, labels=[0, 1])
# Unweighted mean of the per-dataset accuracies (each test set counts equally).
print(f"Average Accuracy: {np.mean(test_accuracies):.4f}")
plot_confusion_matrix(cm_combined, classes=class_names, title='Combined Confusion Matrix')

  model.load_state_dict(torch.load(models_root_dir + 'best_vit_model.pth'))


KeyboardInterrupt: 