# Installing required libraries

In [1]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.6.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
Downloading lightning-2.6.1-py3-none-any.whl (853 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m853.6/853.6 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: lightning
Successfully installed lightning-2.6.1


In [2]:
!pip install mlflow

Collecting mlflow
  Downloading mlflow-3.9.0-py3-none-any.whl.metadata (31 kB)
Collecting mlflow-skinny==3.9.0 (from mlflow)
  Downloading mlflow_skinny-3.9.0-py3-none-any.whl.metadata (32 kB)
Collecting mlflow-tracing==3.9.0 (from mlflow)
  Downloading mlflow_tracing-3.9.0-py3-none-any.whl.metadata (19 kB)
Collecting Flask-CORS<7 (from mlflow)
  Downloading flask_cors-6.0.2-py3-none-any.whl.metadata (5.3 kB)
Collecting graphene<4 (from mlflow)
  Downloading graphene-3.4.3-py2.py3-none-any.whl.metadata (6.9 kB)
Collecting gunicorn<24 (from mlflow)
  Downloading gunicorn-23.0.0-py3-none-any.whl.metadata (4.4 kB)
Collecting huey<3,>=2.5.4 (from mlflow)
  Downloading huey-2.6.0-py3-none-any.whl.metadata (4.3 kB)
Collecting skops<1 (from mlflow)
  Downloading skops-0.13.0-py3-none-any.whl.metadata (5.6 kB)
Collecting databricks-sdk<1,>=0.20.0 (from mlflow-skinny==3.9.0->mlflow)
  Downloading databricks_sdk-0.85.0-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Imports

In [3]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.optim as optim
from lightning.pytorch.callbacks import EarlyStopping
from torch.utils.data import DataLoader
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
    BinaryF1Score
)

import pytorch_lightning as pl
from torchvision import datasets, transforms
import torchvision.models as models
from pytorch_lightning.callbacks import ModelCheckpoint
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
import os

import seaborn as sns
from sklearn.metrics import (
    confusion_matrix,
    classification_report,
    roc_auc_score
)
import mlflow
from lightning.pytorch.loggers import MLFlowLogger
import random
from PIL import Image, ImageFile

torch.set_float32_matmul_precision('medium')
# Lowering float32 matmul precision trades a little numerical accuracy for faster iteration.
# The training code and logic stay the same; only the precision of matmul results changes.
# Useful for experimentation and for building a minimal end-to-end pipeline; can be removed for final training.

# Configuring MLFlow

In [6]:
# MLflow logger handed to the Trainer by initialize_trainer().
# Runs are written to a local file store under /kaggle/working/mlruns.
# The run name reflects the model actually trained in this notebook
# (ResNet-34 on FFT log-magnitude input), not ResNet-50.
mlf_logger = MLFlowLogger(
    experiment_name="AI generated image detector",
    run_name="resnet34_fft_kaggle",
    tracking_uri="file:/kaggle/working/mlruns"
)


  return FileStore(store_uri, store_uri)


# Dataset Preparation

In [7]:
# Dataset root. MultiGenDataset looks for <generator>/<split>/<ai|nature>/ folders beneath it.
data_dir = "/kaggle/input/tiny-genimage"

## RGB to Frequency domain

In [8]:
def rgb_to_fft(x):
    """Return the centered log-magnitude spectrum of ``x``.

    A 2-D FFT is taken over the last two dimensions, the zero-frequency
    component is shifted to the center of those dimensions, and the
    magnitude is compressed with ``log1p``.

    Args:
        x: tensor whose last two dimensions are spatial; this notebook passes
            already-normalized [B, 3, H, W] batches.

    Returns:
        Real-valued tensor with the same shape as ``x``.
    """
    spatial_dims = (-2, -1)
    spectrum = torch.fft.fftshift(
        torch.fft.fft2(x, dim=spatial_dims),
        dim=spatial_dims,
    )
    return torch.log1p(spectrum.abs())


## Custom Dataset

In [9]:
class MultiGenDataset(Dataset):
    """Binary real-vs-AI image dataset gathered from several generator folders.

    Expected layout: ``root_dir/<generator>/<split>/<ai|nature>/<image file>``.
    Label mapping: ``nature`` -> 0 (real), ``ai`` -> 1 (fake).

    Args:
        root_dir: directory holding one sub-directory per generator.
        split: name of the split folder looked up under every generator
            (this notebook uses "train" and "val").
        transform: optional callable applied to the RGB PIL image.

    Raises:
        AssertionError: if no sample is found for ``split``.
    """

    def __init__(self, root_dir, split, transform=None):
        self.samples = []
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        for gen in sorted(os.listdir(root_dir)):
            gen_path = os.path.join(root_dir, gen, split)
            if not os.path.isdir(gen_path):
                continue

            # Mapping: nature=REAL(0) ai=FAKE(1)
            for cls, label in (("ai", 1), ("nature", 0)):
                class_dir = os.path.join(gen_path, cls)
                if not os.path.isdir(class_dir):
                    continue

                for img_name in sorted(os.listdir(class_dir)):
                    img_path = os.path.join(class_dir, img_name)
                    # Skip nested directories: they cannot be opened as images.
                    if not os.path.isfile(img_path):
                        continue
                    self.samples.append((img_path, label))

        assert len(self.samples) > 0, "Dataset is EMPTY"
        print(f"[{split}] Loaded {len(self.samples)} samples")

    def __len__(self):
        """Number of (path, label) pairs collected at construction time."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as RGB, apply the transform, return (image, label)."""
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file once the RGB copy exists.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform:
            img = self.transform(img)

        return img, label


## Transforms

In [10]:
# Training preprocessing: resize, convert to tensor, normalize. No augmentation is applied.
TRAIN_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),          # ResNet input size
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],          # ImageNet stats
        std=[0.229, 0.224, 0.225]
    )
])


# Validation preprocessing: the same steps and statistics as TRAIN_TRANSFORM.
VAL_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


# Lightning DataModule

In [11]:
class MultiGenDatasetModule(pl.LightningDataModule):
    """DataModule that serves the "train" and "val" splits of MultiGenDataset.

    Args:
        data_dir: dataset root passed to MultiGenDataset.
        batch_size: batch size used by both dataloaders.
        num_workers: worker process count used by both dataloaders.
    """

    def __init__(self, data_dir, batch_size=32, num_workers=2):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Filled in by setup(); these are the names the dataloader methods read.
        self.train_ds = None
        self.val_ds = None

    def setup(self, stage=None):
        """Build the datasets once; later calls reuse them instead of rescanning the disk."""
        if self.train_ds is None:
            self.train_ds = MultiGenDataset(
                self.data_dir, "train", TRAIN_TRANSFORM
            )
        if self.val_ds is None:
            self.val_ds = MultiGenDataset(
                self.data_dir, "val", VAL_TRANSFORM
            )

    def _make_loader(self, dataset, shuffle):
        """Create a DataLoader with the module-wide batch size and worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split.

        Uses the same ``num_workers`` as training; construct the module with
        ``num_workers=0`` if worker processes cause trouble.
        """
        return self._make_loader(self.val_ds, shuffle=False)


## DataModule Initialization

In [12]:
# Reuse the dataset root defined in the "Dataset Preparation" cell so the path lives in one place.
data_module = MultiGenDatasetModule(
    data_dir=data_dir,
    batch_size=64,
    num_workers=2
)

# Loading ResNet34 and Freezing

In [13]:
def load_resnet34_freq_model(
    num_classes=2,
    freeze_layer2=False
):
    """
    Build the ResNet-34 used for the frequency-domain branch.

    The network starts from ImageNet-pretrained weights, gets a fresh linear
    classifier, and has gradients disabled for the stem (conv1, bn1) and
    layer1. Only ``requires_grad`` is changed; the function does not alter
    any module's train/eval mode.

    Args:
        num_classes (int): size of the replacement classifier output.
        freeze_layer2 (bool): also disable gradients for layer2.

    Returns:
        model (nn.Module)
    """
    backbone = models.resnet34(
        weights=models.ResNet34_Weights.IMAGENET1K_V1
    )

    # New head sized for this task; it stays trainable.
    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

    # Stem and layer1 are always frozen; layer2 only on request.
    frozen_modules = [backbone.conv1, backbone.bn1, backbone.layer1]
    if freeze_layer2:
        frozen_modules.append(backbone.layer2)

    for module in frozen_modules:
        for param in module.parameters():
            param.requires_grad = False

    return backbone


# Schedulers and Optimizers

In [14]:
def define_optimizer_and_scheduler(model, learning_rate, weight_decay):
    """
    Build an AdamW optimizer over the model's trainable parameters together
    with a ReduceLROnPlateau scheduler.

    Args:
        model (nn.Module): model whose parameters with ``requires_grad=True``
            are optimized; frozen parameters are left out.
        learning_rate (float): learning rate for AdamW.
        weight_decay (float): weight decay for AdamW.

    Returns:
        tuple: (optimizer, scheduler). The scheduler runs in "min" mode and
        multiplies the learning rate by 0.1 after 2 epochs without improvement.
    """
    # Frozen parameters are excluded from the optimizer.
    params_to_update = [p for p in model.parameters() if p.requires_grad]

    adamw = optim.AdamW(
        params_to_update,
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    plateau = optim.lr_scheduler.ReduceLROnPlateau(
        adamw,
        mode="min",
        factor=0.1,
        patience=2,
    )
    return adamw, plateau


# Lightning Module

In [15]:
class AIImageDetector(pl.LightningModule):
    """
    LightningModule for AI-generated vs real image classification.

    Input batches are converted to centered log-magnitude FFT spectra by
    ``rgb_to_fft`` and classified by the partially frozen ResNet-34 returned
    from ``load_resnet34_freq_model``. The network emits two logits; class 1
    is treated as the positive class for the binary metrics.

    Args:
        learning_rate (float): learning rate handed to the optimizer.
        weight_decay (float): weight decay handed to the optimizer.
    """

    def __init__(self, learning_rate=1e-3,weight_decay=1e-2):
        super().__init__()
        # Exposes learning_rate / weight_decay through self.hparams.
        self.save_hyperparameters()

        self.model = load_resnet34_freq_model(num_classes=2)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

        # Metrics
        self.train_acc = BinaryAccuracy()
        self.val_acc   = BinaryAccuracy()
        self.val_precision = BinaryPrecision()
        self.val_recall = BinaryRecall()
        self.val_f1 = BinaryF1Score()

    def forward(self, x):
        """Map an image batch to class logits via its FFT log-magnitude."""
        x = rgb_to_fft(x)
        return self.model(x)

    def _shared_step(self, batch):
        """Run the forward pass once and return (loss, hard predictions, targets)."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)

        # Probability of class 1, thresholded at 0.5.
        probs = torch.softmax(logits, dim=1)[:, 1]
        preds = (probs > 0.5).int()
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss/accuracy."""
        loss, preds, y = self._shared_step(batch)

        acc = self.train_acc(preds, y)

        self.log("train_loss", loss, prog_bar=True, on_epoch=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy and accumulate precision, recall and F1."""
        loss, preds, y = self._shared_step(batch)

        acc = self.val_acc(preds, y)

        self.val_precision.update(preds, y)
        self.val_recall.update(preds, y)
        self.val_f1.update(preds, y)

        self.log("val_loss", loss, prog_bar=True, on_epoch=True)
        self.log("val_acc", acc, prog_bar=True, on_epoch=True)

    def on_validation_epoch_end(self):
        """Log the epoch-level precision/recall/F1, then clear all validation metrics."""
        self.log("val_precision", self.val_precision.compute(), prog_bar=True)
        self.log("val_recall", self.val_recall.compute(), prog_bar=True)
        self.log("val_f1", self.val_f1.compute(), prog_bar=True)

        # val_acc accumulates state on every forward call, so it is reset
        # alongside the other validation metrics.
        self.val_acc.reset()
        self.val_precision.reset()
        self.val_recall.reset()
        self.val_f1.reset()

    def configure_optimizers(self):
        """Return the optimizer and a plateau scheduler driven by ``val_loss``."""
        optimizer, scheduler = define_optimizer_and_scheduler(
            self.model,
            self.hparams.learning_rate,
            self.hparams.weight_decay
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }



# Configuring Early Stopping

In [16]:
# Re-imported from pytorch_lightning so the callback comes from the same package
# that `pl` refers to (the imports cell rebinds `pl` to pytorch_lightning).
from pytorch_lightning.callbacks import EarlyStopping

# Stop training once val_loss has failed to improve by at least min_delta
# for `patience` consecutive validation checks.
early_stop_cb = EarlyStopping(
    monitor="val_loss",
    patience=4,
    mode="min",
    verbose=True,
    min_delta = 0.003
)

# Trainer Helper Function

In [17]:
def initialize_trainer(
    num_epochs,
    early_stop_callback,
    progress_bar=True,
    dry_run=False,
    logger=None,
    checkpoint_dir="/kaggle/working/checkpoints"
):
    """
    Build the Lightning Trainer used in this notebook.

    The trainer runs on a single device with 16-bit mixed precision, skips
    the sanity-validation pass, and keeps only the checkpoint with the
    lowest ``val_loss`` (saved as "best-model" in ``checkpoint_dir``).

    Args:
        num_epochs (int): maximum number of training epochs.
        early_stop_callback: callback added alongside the checkpoint callback.
        progress_bar (bool): whether to show the progress bar.
        dry_run (bool): forwarded to ``fast_dev_run``.
        logger: logger to attach; when None the module-level ``mlf_logger``
            is used.
        checkpoint_dir (str): directory where the best checkpoint is written.

    Returns:
        trainer (pl.Trainer)
    """

    # Save the best model based on validation loss
    checkpoint_cb = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=checkpoint_dir,
        filename="best-model"
    )

    callbacks = [early_stop_callback, checkpoint_cb]

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator="auto",
        devices=1,
        precision="16-mixed",
        callbacks=callbacks,
        # Resolved at call time so the default follows the current mlf_logger.
        logger=mlf_logger if logger is None else logger,
        enable_progress_bar=progress_bar,
        enable_model_summary=False,
        fast_dev_run=dry_run,
        num_sanity_val_steps = 0
    )

    return trainer


# Model Initialization 

In [18]:
# Instantiate the detector; both values are stored as hyperparameters and
# read back in configure_optimizers().
model = AIImageDetector(
    learning_rate=1e-3,
    weight_decay=1e-2
)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 194MB/s] 


In [19]:
model

AIImageDetector(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tra

# Trainer Configuration

In [20]:
# Train for at most 12 epochs; early stopping on val_loss may end the run sooner.
trainer = initialize_trainer(
    num_epochs=12,
    early_stop_callback=early_stop_cb,
    progress_bar=True,
    dry_run=False
)


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


# MLflow Logger

In [21]:
# Record run metadata on the logger that is already attached to the trainer.
# A second MLFlowLogger is deliberately not created here: it would never be
# attached to the trainer, so nothing would be written through it.
# Values are read from the live objects so they cannot drift from the actual run.
trainer.logger.log_hyperparams({
    "model": "resnet34_fft",
    "batch_size": data_module.batch_size,
    "lr": model.hparams.learning_rate,
    "optimizer": "AdamW",
    "img_size": 224,
    "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
})


Experiment with name AI generated image detector not found. Creating it.


# Model Training

In [22]:
# Run training and validation; the datamodule builds its datasets in setup().
trainer.fit(model, datamodule=data_module)

[train] Loaded 28000 samples
[val] Loaded 7000 samples


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

Metric val_loss improved. New best score: 0.593
Metric val_loss improved by 0.314 >= min_delta = 0.003. New best score: 0.279
Metric val_loss improved by 0.008 >= min_delta = 0.003. New best score: 0.271
Metric val_loss improved by 0.046 >= min_delta = 0.003. New best score: 0.225
Monitored metric val_loss did not improve in the last 4 records. Best score: 0.225. Signaling Trainer to stop.


# Saving Model

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback.
best_ckpt_path = trainer.checkpoint_callback.best_model_path

In [None]:
# Restore the best checkpoint's weights into the LightningModule.
ckpt = torch.load(best_ckpt_path, map_location="cpu")

# `model` is the AIImageDetector that wraps the ResNet as `self.model`, so its
# parameter names carry the "model." prefix, and so do the keys in the
# checkpoint written from it. Stripping that prefix would make every key
# mismatch in load_state_dict, so the state dict is loaded unmodified.
state_dict = ckpt["state_dict"]

model.load_state_dict(state_dict)
model.eval()

In [33]:
# Save the LightningModule's state_dict. Keys keep the "model." prefix because
# the ResNet is stored as `self.model`, so load this file into an AIImageDetector.
torch.save(
    model.state_dict(),
    "model_resnet34_frequency_best.pth"
)