# Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from torch.autograd import Variable
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import numpy as np
import torchvision
import random

In [2]:
from transformers import CLIPProcessor, CLIPModel

In [3]:
# Use the GPU when PyTorch reports one as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
import kagglehub
# Download the FreqNet model bundle through kagglehub and print where it landed.
path_to_clipped_freqnet = kagglehub.model_download('aayushpuri01/clipped-freqnet/PyTorch/default/1')
print(path_to_clipped_freqnet)
# Change into the bundle directory before importing from it.
# NOTE(review): presumably the import below is resolved relative to the current
# working directory — confirm; the working directory is reset in a later cell.
os.chdir(path_to_clipped_freqnet)
from clipped_freqnet import freqnet

/kaggle/input/clipped-freqnet/pytorch/default/1


In [5]:
# Download the data-loading pipeline bundle through kagglehub and print where it landed.
path_to_dataloader = kagglehub.model_download('aayushpuri01/dataloader-pipeline-for-dds/PyTorch/default/2')
print(path_to_dataloader)
# Change into the bundle directory before importing from it.
# NOTE(review): presumably the import below is resolved relative to the current
# working directory — confirm.
os.chdir(path_to_dataloader)
from dataloading_pipeline_fixleakage import DataPipeline

/kaggle/input/dataloader-pipeline-for-dds/pytorch/default/2


In [6]:
# Switch back to /kaggle/working (the directory the checkpoints are written to later)
# after the temporary directory changes made for the imports above.
os.chdir('/kaggle/working')

# Wandb config

In [27]:
import wandb

# Log in to Weights & Biases with the API key stored in Kaggle Secrets under the
# label "wandb_api". `anony` records the outcome: None after a successful login,
# "must" when the key could not be obtained or the login failed.
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_api")
    wandb.login(key=api_key)
    anony = None
except Exception as login_error:
    # Deliberately best-effort: the notebook keeps running without a W&B account.
    # `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt and
    # SystemExit propagate so the cell can still be interrupted.
    anony = "must"

    # Only the exception type is shown so that nothing secret can end up in the output.
    print(f'W&B login skipped ({type(login_error).__name__}).')
    print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maayushpuri2486[0m ([33maayushpuri2486-pulchowk-campus[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Start the W&B run that the training / evaluation / checkpoint helpers log to.
# The values in `config` are recorded as run metadata only; they are literals here
# and are not read back by the training code in this notebook, so keep them in sync
# with the optimizer learning rate and `num_epochs` by hand.
# NOTE(review): the `anony` flag computed in the login cell is not passed to this
# call — confirm whether anonymous logging was intended.
wandb.init(
    # set the wandb project where this run will be logged
    project="before shipping notebook test",

    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.0001,
    "architecture": "Freqnet+CLIP",
    "dataset": "Custom Augmented Dataset",
    "epochs": 1,
    }
)

# Dataloader Configuration

`DataPipeline` takes the following arguments:

1. ***path_to_drive*** and ***data_dir*** are joined to form the complete dataset directory path,
   so if you have a single complete path, split it into these two parts.
2. ***num_images*** (default: 20,000): size of the subset of the dataset to use, i.e. 20,000 images
   in total are used for the training and validation sets.
3. ***val_split***: the fraction of that subset to use for validation.
4. ***batch_size***: the batch size of the returned data loaders.
5. ***test_size*** (default: 3000): creates a separate test set of 3000 images.

These values can be changed for experimentation.


In [16]:
# Small configuration for a quick end-to-end run: a 1000-image subset with 30% of it
# held out for validation, plus a separate 100-image test set.
# (A 20,000-image configuration is kept, commented out, in a later cell.)
pipeline = DataPipeline(path_to_drive='/kaggle/input',
                        data_dir='deepfake-dataset/Deepfake_Dataset',
                        val_split=0.3,
                        batch_size=32,
                        num_images=1000,
                        test_size =100)
# The three loaders are used by the training, validation and test helpers below.
train_loader, val_loader, test_loader = pipeline.get_loaders()

In [None]:
# Visual sanity check: preview a few images from a dataloader.
# Swap `val_loader` below for `train_loader` / `test_loader` to inspect the others.
# This cell is optional and can be skipped for full training runs.

# Per-channel statistics used to undo the input normalization for display
# (the same values the model class refers to as the ImageNet mean / std).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def denormalized(img_tensor, mean, std):
    """Undo a per-channel normalization and return the result as a new tensor.

    Computes ``img_tensor * std + mean`` channel-wise without modifying the input.

    Args:
        img_tensor: Floating-point tensor of shape (C, H, W) or (N, C, H, W).
        mean: Sequence of C per-channel means used by the normalization.
        std: Sequence of C per-channel standard deviations used by the normalization.

    Returns:
        A tensor with the same shape, dtype and device as ``img_tensor``.
    """
    # Reshape the statistics to (C, 1, 1) so they broadcast over the spatial
    # dimensions of a single image as well as over a leading batch dimension.
    mean_t = torch.as_tensor(mean, dtype=img_tensor.dtype, device=img_tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img_tensor.dtype, device=img_tensor.device).view(-1, 1, 1)
    # Out-of-place arithmetic already yields a fresh tensor, so no clone is needed.
    return img_tensor * std_t + mean_t

# Take one batch from the loader and undo the normalization image by image.
data_iter = iter(val_loader)
images, labels = next(data_iter)
images = torch.stack([denormalized(img, mean, std) for img in images])
# imshow expects float RGB data in [0, 1]; clamping here gives the same picture
# matplotlib would draw anyway, without its "clipping input data" warning.
# The transpose turns (N, C, H, W) into (N, H, W, C) for matplotlib.
images = images.clamp(0, 1).numpy().transpose((0, 2, 3, 1))

fig, axes = plt.subplots(3, 3, figsize=(8, 6))
# zip stops at the shorter input, so a batch with fewer than 9 images simply
# leaves the remaining panels empty instead of raising an IndexError.
for ax, image, label in zip(axes.flat, images, labels):
    ax.imshow(image)
    # Label convention used in this notebook: 0 is shown as "Fake", anything else as "Real".
    ax.set_title("Fake" if label.item() == 0 else "Real")
for ax in axes.flat:
    ax.axis('off')
plt.show()

In [7]:
# pipeline = DataPipeline(path_to_drive='/kaggle/input',
#                         data_dir='deepfake-dataset/Deepfake_Dataset',
#                         val_split=0.3,
#                         batch_size=32,
#                         num_images=20000)
# train_loader, val_loader, test_loader = pipeline.get_loaders()

# Model Architecture Config and Init

In [29]:
class HybridDeepfakeDetector(nn.Module):
    """Binary classifier that fuses FreqNet features with CLIP image embeddings.

    Both feature vectors are L2-normalized, concatenated and passed through a single
    linear layer that produces one logit per image (to be used with
    ``nn.BCEWithLogitsLoss``; apply a sigmoid to obtain a probability).

    Args:
        freq_model: The FreqNet module; it is moved to ``device``.
            Assumed to output 512 features per image (see ``self.fc``) — TODO confirm.
        clip_model_name: Hugging Face identifier of the CLIP checkpoint to load.
            The classifier head assumes 768-dimensional image features.
        device: Device the sub-models are moved to and CLIP inputs are sent to.

    Note:
        CLIP features are computed under ``torch.no_grad()``, so the CLIP weights
        receive no gradient even though they are part of ``parameters()`` and of
        ``state_dict()``.
    """

    def __init__(self, freq_model, clip_model_name="openai/clip-vit-large-patch14", device="cuda"):
        super(HybridDeepfakeDetector, self).__init__()
        self.device = device

        # Load FreqNet
        self.freqnet = freq_model.to(self.device)

        # Load CLIP Model
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

        # Define Fully Connected Classifier
        # NOTE: created on the default device; callers move the whole module with .to(device).
        self.fc = nn.Linear(512 + 768, 1)  # FreqNet (512) + CLIP (768)

        # Inverse of the ImageNet normalization: Normalize(-m/s, 1/s) maps
        # (x - m) / s back to x; the result is then clamped to [0, 1].
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.denormalize = transforms.Compose([
            transforms.Normalize(mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
                                 std=[1 / s for s in imagenet_std]),
            # A named static method is used instead of an inline lambda so the module
            # stays picklable (the notebook saves the entire model object).
            transforms.Lambda(HybridDeepfakeDetector.clamp_image)
        ])

    @staticmethod
    def clamp_image(tensor):
        """Clamp ``tensor`` to the [0, 1] range."""
        return tensor.clamp(0, 1)

    def extract_clip_features(self, images):
        """Return CLIP image embeddings for a batch of images.

        Args:
            images: Image batch whose pixel values are already floats in [0, 1]
                (``forward`` passes the denormalized, clamped tensors).

        Returns:
            Tensor of CLIP image features on ``self.device``; computed without
            gradient tracking.
        """
        # do_rescale=False: the inputs are already in [0, 1]. With the processor's
        # default (do_rescale=True) they would be multiplied by 1/255 a second time
        # before CLIP's own normalization, feeding CLIP almost constant images.
        # Checkpoints trained before this fix saw the doubly-rescaled inputs.
        inputs = self.clip_processor(images=images, return_tensors="pt", padding=True,
                                     do_rescale=False)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            embeddings = self.clip_model.get_image_features(**inputs)
        return embeddings.to(self.device)

    def forward(self, freq_input, clip_input):
        """Compute one logit per image.

        Args:
            freq_input: Image tensors normalized with ImageNet stats, for FreqNet.
            clip_input: The same normalized image tensors; they are denormalized
                here before being handed to CLIP.

        Returns:
            Tensor of shape (batch,) holding raw logits.
        """
        # Get frequency-based features from FreqNet
        freq_features = self.freqnet(freq_input)

        # Get semantic embeddings from CLIP
        denormalized_images = torch.stack([self.denormalize(image) for image in clip_input])
        clip_features = self.extract_clip_features(denormalized_images)

        # Normalize features and concatenate
        freq_features = F.normalize(freq_features, dim=-1)
        clip_features = F.normalize(clip_features, dim=-1)
        combined_features = torch.cat((freq_features, clip_features), dim=-1)

        # Classifier output
        logits = self.fc(combined_features)
        # Drop only the trailing feature dimension: a plain squeeze() would also
        # remove the batch dimension for a batch of one and return a 0-d tensor,
        # which no longer matches labels of shape (1,).
        return logits.squeeze(-1)

In [30]:
#Initialize the models
# Build the FreqNet backbone and wrap it in the hybrid detector.
# NOTE(review): `num_classes=2` is passed although the detector's linear head is sized
# for 512 FreqNet features — presumably the "clipped" FreqNet returns features rather
# than class scores; confirm against the `clipped_freqnet` module.
freqnet_model = freqnet(num_classes=2)  # Assuming `freqnet` function initializes FreqNet
# The trailing .to(device) matters: the constructor only moves the FreqNet and CLIP
# sub-models, so this call is what moves the linear classifier head.
model = HybridDeepfakeDetector(freqnet_model, device=device).to(device)

In [31]:
# Optimizer & Loss
# AdamW over every parameter of the hybrid model. The CLIP weights are included in
# model.parameters(), but CLIP features are computed under torch.no_grad(), so those
# weights never receive a gradient.
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001) #set lr = 0.001 optimally? 
# The model returns raw logits, so the sigmoid is folded into the loss here.
criterion = nn.BCEWithLogitsLoss()

# Helper Functions

In [33]:
def train_one_epoch(epoch, model, optimizer, criterion, train_loader, device):
    """Run one training pass over ``train_loader`` and report mean loss and accuracy.

    Args:
        epoch: Zero-based epoch index (shown as ``epoch + 1`` on the progress bar).
        model: Model called as ``model(images, images)`` that returns one logit per image.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to ``(logits, float labels)``; its value is weighted by
            the batch size, i.e. it is assumed to be a per-batch mean.
        train_loader: Iterable of ``(images, labels)`` batches that supports ``len()``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)`` averaged over the samples actually processed.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    running_corrects = 0
    samples_seen = 0

    loop = tqdm(enumerate(train_loader), total=len(train_loader), leave=True)
    # The description does not change within an epoch, so set it once up front.
    loop.set_description(f'Epoch [{epoch+1}]')

    for batch_idx, (images, labels) in loop:
        images, labels = images.to(device), labels.to(device).float()

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, images)  # Passing same images to both FreqNet & CLIP
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track loss and accuracy; read the loss value once per batch.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        running_loss += batch_loss * batch_size
        predicted = (torch.sigmoid(outputs) > 0.5).float()
        running_corrects += (predicted == labels).sum().item()
        samples_seen += batch_size

        loop.set_postfix(loss=batch_loss)

    # Normalize by the number of samples seen rather than len(train_loader.dataset):
    # the two differ when the loader drops the last batch or draws through a sampler.
    epoch_loss = running_loss / samples_seen
    epoch_acc = running_corrects / samples_seen
    print(f'Training Loss: {epoch_loss:.4f}, Training Accuracy: {epoch_acc:.4f}')
    return epoch_loss, epoch_acc

In [34]:
def validate_one_epoch(epoch, model, criterion, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Args:
        epoch: Epoch index; accepted for symmetry with ``train_one_epoch`` and not used.
        model: Model called as ``model(images, images)`` that returns one logit per image.
        criterion: Loss applied to ``(logits, float labels)``; its value is weighted by
            the batch size, i.e. it is assumed to be a per-batch mean.
        val_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, val_acc)`` averaged over the samples actually processed.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validating", leave=False)
        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images, images)
            batch_loss = criterion(outputs, labels).item()
            batch_size = labels.size(0)
            val_loss += batch_loss * batch_size

            # Predictions: threshold the sigmoid probability at 0.5.
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            total += batch_size
            correct += (predicted == labels).sum().item()
            progress_bar.set_postfix({'Val Loss': batch_loss})

    # Loss and accuracy share one denominator (samples seen); previously the loss was
    # divided by len(val_loader.dataset) while the accuracy used the running count.
    mean_loss = val_loss / total
    val_acc = correct / total
    print(f'Validation Loss: {mean_loss:.4f}, Validation Accuracy: {val_acc:.4f}')
    return mean_loss, val_acc

In [35]:
def save_checkpoint(epoch, model, optimizer, checkpoint_dir = '/kaggle/working'):
    """Write a resumable checkpoint for ``epoch`` and upload it to W&B.

    The file holds the one-based epoch number plus the model and optimizer state
    dicts (no pickled model object), which is what is needed to resume training.

    Args:
        epoch: Zero-based epoch index; stored and named as ``epoch + 1``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        checkpoint_dir: Target directory; created when missing.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    target_file = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")

    # State dicts only: enough to resume training if that becomes necessary.
    payload = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(payload, target_file)

    # Also register the file as a W&B model artifact.
    ckpt_artifact = wandb.Artifact(f"checkpoint_epoch_{epoch+1}", type="model")
    ckpt_artifact.add_file(target_file)
    wandb.log_artifact(ckpt_artifact)

    print(f"Checkpoint saved at {target_file}")

In [36]:
def save_entire_model(epoch, model, optimizer, checkpoint_dir):
    """Save the whole model object together with its weights and optimizer state.

    Unlike ``save_checkpoint``, the file also contains the pickled ``model`` itself, so
    it can be restored without rebuilding the architecture first. The file is then
    uploaded to W&B as a model artifact.

    Args:
        epoch: Number of epochs the model was trained for; used as-is in the file name,
            the stored ``'epoch'`` field and the artifact name.
        model: Module to save (both pickled and as a ``state_dict``).
        optimizer: Optimizer whose ``state_dict()`` is saved.
        checkpoint_dir: Target directory; created when missing.
    """
    # exist_ok avoids the check-then-create race and matches save_checkpoint.
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_path = os.path.join(checkpoint_dir, f'hybrid_model_epoch_{epoch}.pth')

    # Save the entire model.
    # NOTE: the pickled module can only be loaded where its class is importable, and
    # the weights are stored twice (inside 'model' and in 'model_state_dict').
    torch.save({
        'epoch': epoch,
        'model': model,  # Saving entire model (not just state_dict)
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, checkpoint_path)

    # Log to WandB. The artifact uses the same epoch number as the file and the stored
    # field (it was previously named with epoch + 1, one ahead of the file it holds).
    artifact = wandb.Artifact(f'hybrid_model_epoch_{epoch}', type='model')
    artifact.add_file(checkpoint_path)
    wandb.log_artifact(artifact)

    print(f'Model saved at {checkpoint_path}')


In [37]:
def train_model(num_epochs, model, optimizer, criterion, train_loader, val_loader, device, checkpoint_dir="/kaggle/working"):
    """Train for ``num_epochs`` epochs, validating, logging and checkpointing after each.

    Per epoch: one training pass, one validation pass, one ``wandb.log`` call with the
    zero-based epoch index and the four metrics, then ``save_checkpoint``.

    Args:
        num_epochs: Number of epochs to run.
        model: Model to train.
        optimizer: Optimizer used for the training pass.
        criterion: Loss function used for training and validation.
        train_loader: Loader for the training pass.
        val_loader: Loader for the validation pass.
        device: Device the batches are moved to.
        checkpoint_dir: Directory handed to ``save_checkpoint``.
    """
    for epoch in range(num_epochs):
        train_metrics = train_one_epoch(epoch, model, optimizer, criterion, train_loader, device)
        val_metrics = validate_one_epoch(epoch, model, criterion, val_loader, device)

        # Assemble the W&B record; each helper returns a (loss, accuracy) pair.
        epoch_log = {"epoch": epoch}
        epoch_log["training_loss"], epoch_log["training_accuracy"] = train_metrics
        epoch_log["validation_loss"], epoch_log["validation_accuracy"] = val_metrics
        wandb.log(epoch_log)

        save_checkpoint(epoch, model, optimizer, checkpoint_dir)

# Train the model
Set ***num_epochs*** to the desired number of epochs.

In [None]:
# Run training. One checkpoint per epoch is written to `checkpoint_dir` and uploaded
# to W&B, so a W&B run must have been started (see the wandb.init cell) first.
train_model(
    num_epochs=1, 
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    checkpoint_dir="/kaggle/working")

# Evaluation on Test Set

In [39]:
from sklearn.metrics import accuracy_score

def evaluate_test_set(model, test_loader, device):
    """Measure classification accuracy on the test set and log it to W&B.

    Args:
        model: Model called as ``model(images, images)`` that returns one logit per image.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Fraction of test samples whose thresholded prediction equals the label.
    """
    print('Evaluating the test set...')
    model.eval()  # evaluation mode for the whole pass

    n_seen, n_right = 0, 0

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Testing', leave=False):
            images = images.to(device)
            labels = labels.to(device).float()

            # The same batch feeds both the FreqNet and the CLIP branch.
            logits = model(images, images)

            # A sigmoid probability above 0.5 counts as the positive class.
            hard_preds = (torch.sigmoid(logits) > 0.5).float()

            n_seen += labels.size(0)
            n_right += (hard_preds == labels).sum().item()

    test_acc = n_right / n_seen
    print(f'Test Accuracy: {test_acc:.4f}')

    wandb.log({"Test_Accuracy": test_acc})

    return test_acc


In [None]:
# Evaluate the trained model on the test set
# Evaluate the trained model on the held-out test loader; the accuracy is printed,
# logged to W&B under "Test_Accuracy" and kept in `test_accuracy`.
test_accuracy = evaluate_test_set(model,
                                  test_loader, 
                                  device)

# Save Model

Set ***epoch*** to the number of epochs the model was trained for above.

In [42]:
# Save the full model object (plus state dicts) for later inference or analysis.
# NOTE(review): `epoch = 10` is hard-coded here while the training cell above uses
# num_epochs=1 — keep the two in sync so the file name reflects the actual run.
checkpoint_dir = '/kaggle/working'
save_entire_model(epoch = 10,
                  model = model,
                  optimizer = optimizer,
                  checkpoint_dir = checkpoint_dir)

Model saved at /kaggle/working/hybrid_model_epoch_10.pth


In [13]:
# def load_entire_model(checkpoint_path, device):
#     """
#     Loads the entire model from a checkpoint file.

#     Parameters:
#     - checkpoint_path (str): Path to the saved model file.
#     - device (str): 'cuda' or 'cpu'.

#     Returns:
#     - model (nn.Module): Loaded HybridDeepfakeDetector model.
#     - optimizer (torch.optim.Optimizer): Optimizer state.
#     - epoch (int): Last saved epoch.
    
#     """
    
#     checkpoint = torch.load(checkpoint_path, map_location=device)
    
#     # Load the saved model architecture and weights
#     model = checkpoint['model'].to(device)
#     model.load_state_dict(checkpoint['model_state_dict'])

#     # Recreate optimizer and load its state
#     optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

#     epoch = checkpoint['epoch']

#     print(f'Model loaded from {checkpoint_path}, last trained epoch: {epoch+1}')
    
#     return model, optimizer, epoch

In [None]:
# loaded_model, optimizer, epoch = load_entire_model("/kaggle/input/hybridmodel20k10ep/pytorch/default/1/hybrid_model_epoch_10.pth", device)

def load_entire_model(checkpoint_path, device):
    """
    Loads the entire model from a checkpoint file written by ``save_entire_model``.

    Parameters:
    - checkpoint_path (str): Path to the saved model file.
    - device (str or torch.device): Device to map the tensors and the model to.

    Returns:
    - model (nn.Module): Loaded HybridDeepfakeDetector model.
    - optimizer (torch.optim.Optimizer): AdamW optimizer with its saved state restored.
    - epoch (int): Epoch value stored in the checkpoint.

    """

    # SECURITY: this file contains a pickled model object, and unpickling can execute
    # arbitrary code — only load checkpoints from a trusted source.
    # weights_only=False is stated explicitly because PyTorch >= 2.6 defaults to
    # weights_only=True, which refuses to unpickle a full nn.Module.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Load the saved model architecture and weights
    model = checkpoint['model'].to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # The detector remembers the device it was built with and sends CLIP inputs there;
    # refresh it so a model saved on GPU also works when loaded onto another device.
    if hasattr(model, 'device'):
        model.device = device

    # Recreate optimizer and load its state. The lr given here is only a placeholder:
    # load_state_dict restores the saved param_groups, including the learning rate.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint['epoch']

    # The stored value is already the number of trained epochs, so report it as-is
    # (it was previously printed as epoch + 1).
    print(f'Model loaded from {checkpoint_path}, last trained epoch: {epoch}')

    return model, optimizer, epoch


# Visualization of embeddings

In [24]:
# from sklearn.decomposition import PCA
# from sklearn.manifold import TSNE

In [61]:
# def analyze_embeddings(model, train_loader, device):
#     """
#     Extracts and compares embeddings from FreqNet, CLIP, and Hybrid representation.
#     Returns embeddings and differences.
#     """

#     model.eval()  # Set model to evaluation mode

#     sampled_images = []
#     freq_embeddings = []
#     clip_embeddings = []
#     hybrid_embeddings = []

#     with torch.no_grad():
#         for images, _ in train_loader:
#             images = images.to(device)

#             # Forward pass to get embeddings
#             freq_features = model.freqnet(images)  # FreqNet embeddings

#             denormalized_images = torch.stack([model.denormalize(image) for image in images])
#             clip_features = model.extract_clip_features(denormalized_images)
            
#             # clip_features = model.extract_clip_features(images)  # CLIP embeddings

#             # Normalize and concatenate to get hybrid features
#             freq_features_norm = torch.nn.functional.normalize(freq_features, dim=-1)
#             clip_features_norm = torch.nn.functional.normalize(clip_features, dim=-1)
#             hybrid_features = torch.cat((freq_features_norm, clip_features_norm), dim=-1)

#             # Store embeddings
#             sampled_images.append(images.cpu())  # Store original images
#             freq_embeddings.append(freq_features.cpu().numpy())  # Store FreqNet embeddings
#             clip_embeddings.append(clip_features.cpu().numpy())  # Store CLIP embeddings
#             hybrid_embeddings.append(hybrid_features.cpu().numpy())  # Store Hybrid embeddings

#             # Stop after 6 images
#             if len(sampled_images) >= 6:
#                 break

#     # Convert lists to numpy arrays
#     freq_embeddings = np.vstack(freq_embeddings)[:6]
#     clip_embeddings = np.vstack(clip_embeddings)[:6]
#     hybrid_embeddings = np.vstack(hybrid_embeddings)[:6]

#     # Compute changes in embeddings
#     freq_hybrid_diff = np.linalg.norm(hybrid_embeddings[:, :512] - freq_embeddings, axis=1)
#     clip_hybrid_diff = np.linalg.norm(hybrid_embeddings[:, 512:] - clip_embeddings, axis=1)

#     return sampled_images, freq_embeddings, clip_embeddings, hybrid_embeddings, freq_hybrid_diff, clip_hybrid_diff


In [67]:
# # Define a function to reverse ImageNet normalization
# def denormalize_image(image):
#     mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 3)  # Shape (1, 1, 3)
#     std = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 3)
#     image = image * std + mean  # Reverse normalization
#     return image.clamp(0, 1)  # Ensure values are in [0, 1] range


In [68]:
# def visualize_embeddings(freq_embeddings, clip_embeddings, hybrid_embeddings, sampled_images):
#     """
#     Uses PCA and t-SNE to visualize embeddings of FreqNet, CLIP, and Hybrid embeddings.
#     """

#     # Reduce dimensionality using PCA
#     pca = PCA(n_components=2)
#     freq_pca = pca.fit_transform(freq_embeddings)
#     clip_pca = pca.fit_transform(clip_embeddings)
#     hybrid_pca = pca.fit_transform(hybrid_embeddings)

#     # Further reduce with t-SNE
#     tsne = TSNE(n_components=2, perplexity=5, random_state=42)
#     freq_tsne = tsne.fit_transform(freq_embeddings)
#     clip_tsne = tsne.fit_transform(clip_embeddings)
#     hybrid_tsne = tsne.fit_transform(hybrid_embeddings)

#     fig, axs = plt.subplots(2, 3, figsize=(15, 10))  # Ensure correct layout

#     #Flatten axs for safe indexing
#     axs = axs.flatten()

#     #Extract individual images before permute
#     for i in range(len(sampled_images)):  # Ensure we do not exceed available indices
#         image = sampled_images[i][0]  # Extract first image from batch
#         image = image.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)

#         image = denormalize_image(image)
        
#         # Convert embeddings to rounded text for display
#         clip_text = f"CLIP: {np.round(clip_embeddings[i][:3], 3)}"  # Show first 5 values
#         freq_text = f"FreqNet: {np.round(freq_embeddings[i][:3], 3)}"
#         hybrid_text = f"Hybrid: {np.round(hybrid_embeddings[i][:3], 3)}"

#         axs[i].imshow(image.cpu().numpy())  # Convert from tensor
#         axs[i].axis("off")
#         axs[i].set_title(f"{clip_text}\n{freq_text}\n{hybrid_text}", fontsize=8)

#     # Plot PCA visualization
#     plt.figure(figsize=(8, 6))
#     plt.scatter(freq_pca[:, 0], freq_pca[:, 1], label="FreqNet", color="blue")
#     plt.scatter(clip_pca[:, 0], clip_pca[:, 1], label="CLIP", color="red")
#     plt.scatter(hybrid_pca[:, 0], hybrid_pca[:, 1], label="Hybrid", color="green")
#     plt.title("PCA Projection")
#     plt.legend()
#     plt.show()

#     # Plot t-SNE visualization
#     plt.figure(figsize=(8, 6))
#     plt.scatter(freq_tsne[:, 0], freq_tsne[:, 1], label="FreqNet", color="blue")
#     plt.scatter(clip_tsne[:, 0], clip_tsne[:, 1], label="CLIP", color="red")
#     plt.scatter(hybrid_tsne[:, 0], hybrid_tsne[:, 1], label="Hybrid", color="green")
#     plt.title("t-SNE Projection")
#     plt.legend()
#     plt.show()
    
#     # Plot embedding changes
#     plt.figure(figsize=(8, 6))
#     plt.bar(range(len(sampled_images)), np.linalg.norm(hybrid_embeddings[:, :512] - freq_embeddings, axis=1), color="blue", label="FreqNet → Hybrid")
#     plt.bar(range(len(sampled_images)), np.linalg.norm(hybrid_embeddings[:, 512:] - clip_embeddings, axis=1), color="red", label="CLIP → Hybrid")
#     plt.title("Embedding Change Magnitude")
#     plt.legend()
#     plt.show()

In [None]:
# # Extract embeddings
# sampled_images, freq_embeddings, clip_embeddings, hybrid_embeddings, freq_hybrid_diff, clip_hybrid_diff = analyze_embeddings(model, train_loader, device)

# # Visualize embeddings
# visualize_embeddings(freq_embeddings, clip_embeddings, hybrid_embeddings, sampled_images)


def load_checkpoint(checkpoint_path, model, optimizer, device):
    """
    Restore model and optimizer state from a checkpoint so training can resume.

    Parameters:
    - checkpoint_path (str): Path to the saved checkpoint file.
    - model (HybridDeepfakeDetector): Model whose weights are overwritten in place.
    - optimizer (torch.optim.Optimizer): Optimizer whose state is overwritten in place.
    - device (str): 'cuda' or 'cpu'; used as the map location when reading the file.

    Returns:
    - int: The epoch number stored in the checkpoint.

    """
    saved_state = torch.load(checkpoint_path, map_location=device)

    # Restore the model first, then the optimizer, each from its own entry.
    restore_plan = (
        (model, "model_state_dict"),
        (optimizer, "optimizer_state_dict"),
    )
    for target, state_key in restore_plan:
        target.load_state_dict(saved_state[state_key])

    last_epoch = saved_state["epoch"]
    print(f"Resuming training from epoch {last_epoch}")
    return last_epoch

"""

checkpoint_path = "./checkpoints/checkpoint_epoch_5.pth"  # Change to the latest checkpoint
last_epoch = load_checkpoint(checkpoint_path, model, optimizer, device)
train_model(num_epochs=10 - last_epoch, model=model, optimizer=optimizer, criterion=criterion, train_loader=train_loader, val_loader=val_loader, device=device, checkpoint_dir="./checkpoints")

"""
