# ECDD Model Training Notebook (Attention Pooling)

This notebook trains the Teacher and Student models using **Attention Pooling**, strictly following the project's architecture document.

### 1. Environment Setup
This cell installs dependencies and creates all necessary scripts and configuration files.

In [None]:
!pip install torch torchvision pyyaml scikit-learn matplotlib tqdm opencv-python -q

import os
from pathlib import Path

def write_file(path, content):
    """Write ``content`` to ``path`` as UTF-8 text, creating parent directories first.

    Args:
        path: Destination file path (``str`` or ``Path``).
        content: Complete text of the file; an existing file is overwritten.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    # Python reads source files as UTF-8, so do not depend on the locale's default encoding.
    target.write_text(content, encoding="utf-8")

# Every folder that the generated configs, models, datasets, losses and scripts are written to.
print("Creating directory structure...")
for project_dir in (
    'deepfake-patch-audit/config',
    'deepfake-patch-audit/models/student',
    'deepfake-patch-audit/datasets',
    'deepfake-patch-audit/losses',
    'deepfake-patch-audit/scripts',
    'ECDD_Experimentation/Training/models',
    'ECDD_Experimentation/Training/training',
):
    # parents/exist_ok make the cell safe to re-run.
    Path(project_dir).mkdir(parents=True, exist_ok=True)
print("Directory structure created.")

# Progress message for the file-generation phase that follows.
print("Writing Python scripts and configs...")

# --- Config Files ---
# YAML configuration holding the teacher/student architecture names, dataset loading
# options and the distillation weights.
# NOTE(review): none of the scripts written by this cell open base.yaml; several of its
# values (alpha_distill/alpha_task, resize_size, num_workers) also appear as hard-coded
# defaults in those scripts -- confirm whether this file is consumed elsewhere.
base_yaml_content = """
model:
  teacher:
    architecture: "LaDeDaResNet50"
  student:
    architecture: "TinyAttentionLaDeDa"
dataset:
  resize_size: 256
  num_workers: 2
  pin_memory: true
training:
  weight_decay: 0.0001
  distillation:
    alpha_distill: 0.2
    alpha_task: 0.8
"""
write_file('deepfake-patch-audit/config/base.yaml', base_yaml_content)

# --- Model Files ---
# Teacher network module. Docstrings inside the generated source use ''' because the source
# itself is wrapped in a triple-double-quoted string.
ladeda_resnet_content = """
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from typing import Tuple, Optional

class AttentionPooling(nn.Module):
    '''Pool per-patch logits into image-level logits using learned spatial attention.

    A two-layer 1x1-conv head scores every spatial location of the feature map; the scores
    are softmax-normalised over all locations and used as weights for the patch logits.
    '''

    def __init__(self, in_channels: int, hidden_dim: int = 128):
        super().__init__()
        self.attention_fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_dim, 1, kernel_size=1)
        )

    def forward(self, features: torch.Tensor, patch_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        '''Return (pooled_logit of shape (B, C), attention_weights of shape (B, 1, H, W)).

        features is (B, in_channels, H, W) and patch_logits is (B, C, H, W) on the same grid.
        '''
        attention_scores = self.attention_fc(features)
        B, _, H, W = attention_scores.shape
        # Softmax over all H*W locations so the weights of one image sum to 1.
        attention_weights_flat = F.softmax(attention_scores.reshape(B, -1), dim=1)
        attention_weights = attention_weights_flat.view(B, 1, H, W)
        # Keep the class axis separate from the patch axis so the weighting is also valid
        # when the patch classifier has more than one output channel.
        patch_logits_flat = patch_logits.reshape(B, patch_logits.shape[1], -1)
        pooled_logit = (patch_logits_flat * attention_weights_flat.unsqueeze(1)).sum(dim=2)
        return pooled_logit, attention_weights

class LaDeDaResNet50(nn.Module):
    '''ResNet-50 backbone with a stride-1 3x3 stem, a 1x1 patch classifier and attention pooling.

    Args:
        pretrained: Initialise the backbone from torchvision's IMAGENET1K_V2 weights.
        freeze_layers: Parameter-name prefixes (e.g. ['conv1', 'layer1']) to exclude from training.
        num_classes: Number of output channels of the patch classifier.
    '''

    def __init__(self, pretrained: bool = True, freeze_layers: Optional[list] = None, num_classes: int = 1):
        super().__init__()
        base_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        if pretrained:
            with torch.no_grad():
                # Copy the central 3x3 window of the pretrained 7x7 stem. copy_ keeps the
                # parameter contiguous instead of aliasing a strided view of the donor weight.
                self.conv1.weight.copy_(base_model.conv1.weight[:, :, 2:5, 2:5])
        self.bn1 = base_model.bn1
        self.relu = base_model.relu
        # torchvision's max-pool layer is not reused; the four residual stages follow the stem directly.
        self.layer1 = base_model.layer1
        self.layer2 = base_model.layer2
        self.layer3 = base_model.layer3
        self.layer4 = base_model.layer4
        self.patch_classifier = nn.Conv2d(2048, num_classes, kernel_size=1)
        self.attention_pool = AttentionPooling(in_channels=2048)
        if freeze_layers:
            for name, param in self.named_parameters():
                if any(name.startswith(prefix) for prefix in freeze_layers):
                    param.requires_grad = False

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        '''Return (pooled_logit, patch_logits, attention_map) for a batch of images.'''
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        features = self.layer4(x)
        patch_logits = self.patch_classifier(features)
        pooled_logit, attention_map = self.attention_pool(features, patch_logits)
        return pooled_logit, patch_logits, attention_map

def create_ladeda_model(pretrained: bool = True, freeze_layers: Optional[list] = None) -> LaDeDaResNet50:
    '''Build the single-logit teacher model used by the training scripts.'''
    return LaDeDaResNet50(pretrained=pretrained, freeze_layers=freeze_layers)
"""
write_file('ECDD_Experimentation/Training/models/ladeda_resnet.py', ladeda_resnet_content)

# Source of the student network module.
# The generated module defines TinyAttentionLaDeDa: three stride-2 3x3 convolutions (each
# followed by ReLU and then BatchNorm), a 1x1 patch classifier and an AttentionPooling head
# with hidden_dim=64. Its forward() returns the same (pooled_logit, patch_logits,
# attention_map) triple as the teacher's forward().
# With kernel 3 / stride 2 / padding 1 each convolution halves the spatial size, so the
# patch grid is 1/8 of the input resolution per side.
# The module imports AttentionPooling through the absolute package path
# ECDD_Experimentation.Training.models.ladeda_resnet, so the directory containing
# ECDD_Experimentation/ must be on sys.path when this module is imported.
tiny_attention_ladeda_content = """
import torch
import torch.nn as nn
from ECDD_Experimentation.Training.models.ladeda_resnet import AttentionPooling

class TinyAttentionLaDeDa(nn.Module):
    def __init__(self, num_classes=1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(32),
        )
        self.patch_classifier = nn.Conv2d(32, num_classes, kernel_size=1)
        # Attention pooling aligned with teacher, adapted for student's feature dimension
        self.attention_pool = AttentionPooling(in_channels=32, hidden_dim=64)

    def forward(self, x):
        features = self.features(x) 
        patch_logits = self.patch_classifier(features)
        pooled_logit, attention_map = self.attention_pool(features, patch_logits)
        return pooled_logit, patch_logits, attention_map
"""
# Persist the module next to the other deepfake-patch-audit student models.
write_file('deepfake-patch-audit/models/student/tiny_attention_ladeda.py', tiny_attention_ladeda_content)

# --- Dataset and Loss Files ---
# Image-folder dataset shared by the teacher and student training scripts.
base_dataset_content = """
import warnings
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from pathlib import Path

class BaseDataset(Dataset):
    '''Binary real/fake image-folder dataset.

    Reads root_dir/real/*.jpg (label 0) and root_dir/fake/*.jpg (label 1). Every item is a
    dict with a float32 CHW 'image' tensor and an int64 'label' tensor.

    Args:
        root_dir: Folder containing the 'real' and 'fake' sub-folders.
        resize_size: Images are resized to (resize_size, resize_size).
        normalize: Apply ImageNet mean/std normalisation.
    '''

    def __init__(self, root_dir: str, resize_size: int = 256, normalize: bool = True):
        self.root_dir = Path(root_dir)
        self.resize_size = resize_size
        self.normalize = normalize
        # float32 statistics keep the normalised image float32. float64 arrays would promote
        # it to float64 and produce DoubleTensors that the float32 models reject.
        self.normalize_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.normalize_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.samples = []
        self._load_from_directory()

    def _load_from_directory(self):
        '''Collect (path, label) pairs; a missing class folder contributes no samples.'''
        for label, class_name in enumerate(['real', 'fake']):
            class_dir = self.root_dir / class_name
            if class_dir.exists():
                for img_path in sorted(class_dir.glob("*.jpg")):
                    self.samples.append((str(img_path), label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        label_tensor = torch.tensor(label, dtype=torch.long)
        try:
            # The context manager closes the file handle once the pixels are decoded.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
            image = image.resize((self.resize_size, self.resize_size), Image.BICUBIC)
            image = np.array(image, dtype=np.float32) / 255.0
            if self.normalize:
                image = (image - self.normalize_mean) / self.normalize_std
            image = torch.from_numpy(image).permute(2, 0, 1).float()
            return {"image": image, "label": label_tensor}
        except Exception as exc:
            # Best effort: keep the batch shape intact, but make the failure visible instead
            # of silently training on a blank image.
            warnings.warn(f"Could not load {img_path} ({exc}); substituting a blank image.")
            blank = torch.zeros((3, self.resize_size, self.resize_size), dtype=torch.float32)
            return {"image": blank, "label": label_tensor}
"""
write_file('deepfake-patch-audit/datasets/base_dataset.py', base_dataset_content)

# Loss module used by the student training script.
distillation_loss_content = """
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    '''Weighted sum of a patch-level distillation loss and an image-level BCE task loss.

    Args:
        alpha_distill: Weight of the distillation term.
        alpha_task: Weight of the task (BCE) term.
        temperature: Softening temperature applied to student and teacher patch logits.
    '''

    def __init__(self, alpha_distill=0.2, alpha_task=0.8, temperature=4.0):
        super().__init__()
        self.alpha_distill = alpha_distill
        self.alpha_task = alpha_task
        self.temperature = temperature
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_pooled_logit, student_patches, teacher_patches, labels):
        '''Return (total_loss, distill_loss, task_loss).

        student_patches and teacher_patches are (B, C, H, W) patch-logit maps; the student
        map is average-pooled to the teacher's grid size before comparison.
        '''
        # Align patch grid sizes using adaptive pooling
        student_patches_aligned = F.adaptive_avg_pool2d(student_patches, teacher_patches.shape[2:])
        student_soft = student_patches_aligned / self.temperature
        teacher_soft = teacher_patches.detach() / self.temperature

        if teacher_patches.shape[1] == 1:
            # Single-logit (binary) patches: a softmax over the singleton channel axis is
            # constant, which would make the distillation term identically zero. Match the
            # per-patch sigmoid probabilities instead.
            distill_loss = F.binary_cross_entropy_with_logits(student_soft, torch.sigmoid(teacher_soft))
        else:
            # Multi-class patches: KLDivLoss expects log-probabilities as input and plain
            # probabilities as target (log_target defaults to False).
            distill_loss = self.kl_loss(
                F.log_softmax(student_soft, dim=1),
                F.softmax(teacher_soft, dim=1)
            )
        distill_loss = distill_loss * (self.temperature ** 2)

        # Task loss on the student's pooled output. reshape(-1) keeps the batch axis even for
        # a batch of one, where squeeze() would produce a 0-dim tensor.
        task_loss = self.bce_loss(student_pooled_logit.reshape(-1), labels.float().reshape(-1))

        total_loss = self.alpha_distill * distill_loss + self.alpha_task * task_loss
        return total_loss, distill_loss, task_loss
"""
write_file('deepfake-patch-audit/losses/distillation_loss.py', distillation_loss_content)

# --- Training Scripts ---
# Stand-alone fine-tuning script for the ECDD experimentation splits.
finetune_script_content = """
import os, sys, json, random, argparse
from pathlib import Path
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from PIL import Image, ImageOps
from tqdm import tqdm

# Running a script by path puts the script's own folder, not the working directory, on
# sys.path. Add the project root (three levels above this file) so the absolute package
# import below resolves.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from ECDD_Experimentation.Training.models.ladeda_resnet import create_ladeda_model

CONFIGS = {
    "finetune1": {"name": "finetune1_celeb_df", "dataset_path": Path("ECDD_Experimentation/ECDD_Training_Data/processed/splits/finetune1"), "epochs": 15, "batch_size": 16, "lr": 1e-4, "freeze_layers": ["conv1", "layer1"]},
    "finetune2": {"name": "finetune2_face_filtered", "dataset_path": Path("ECDD_Experimentation/ECDD_Training_Data/processed/splits/finetune2"), "epochs": 20, "batch_size": 8, "lr": 5e-5, "freeze_layers": ["conv1", "layer1"]}
}
TARGET_SIZE = (256, 256)
IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

class FinetuneDataset(Dataset):
    '''Images under data_dir/split/real (label 0) and data_dir/split/fake (label 1).'''

    def __init__(self, data_dir: Path, split: str = "train"):
        self.data_dir = Path(data_dir) / split
        self.images, self.labels = [], []
        for label, class_name in enumerate(['real', 'fake']):
            class_dir = self.data_dir / class_name
            if class_dir.exists():
                for image_path in sorted(class_dir.glob("*.jpg")):
                    self.images.append(image_path)
                    self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = Image.open(self.images[idx]).convert('RGB')
        img = img.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
        img_array = (np.array(img).astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
        return torch.from_numpy(img_array).permute(2, 0, 1).float(), torch.tensor(self.labels[idx], dtype=torch.float32)

def evaluate(model, loader, criterion, device):
    '''Return the sample-weighted mean loss over loader, or None when it yields no batches.'''
    model.eval()
    loss_sum, sample_count = 0.0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            pooled_logit, _, _ = model(images)
            loss_sum += criterion(pooled_logit.reshape(-1), labels).item() * labels.size(0)
            sample_count += labels.size(0)
    return loss_sum / sample_count if sample_count else None

def run_finetuning(config_name: str):
    '''Fine-tune the teacher on one configured split and write outputs/<config>/finetuned_best.pth.

    The checkpoint holds the epoch with the lowest validation loss; when the split has no
    validation images, the weights of the last epoch are kept instead.
    '''
    config = CONFIGS[config_name]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    output_dir = Path("outputs") / config_name
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = output_dir / 'finetuned_best.pth'
    train_loader = DataLoader(FinetuneDataset(config['dataset_path'], "train"), batch_size=config['batch_size'], shuffle=True, num_workers=2)
    val_loader = DataLoader(FinetuneDataset(config['dataset_path'], "val"), batch_size=config['batch_size'], shuffle=False, num_workers=2)
    model = create_ladeda_model(True, config['freeze_layers']).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=config['lr'])
    best_val_loss = float('inf')
    checkpoint_written = False
    for epoch in range(config['epochs']):
        model.train()
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            images, labels = images.to(device), labels.to(device)
            pooled_logit, _, _ = model(images)
            # reshape(-1) keeps the batch axis for a batch of one; squeeze() would drop it and
            # make BCEWithLogitsLoss reject the shape mismatch.
            loss = criterion(pooled_logit.reshape(-1), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        val_loss = evaluate(model, val_loader, criterion, device)
        if val_loss is None:
            # No validation data: the best available choice is the most recent weights.
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
        elif val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_written = True
            print(f"Epoch {epoch+1}, Val loss: {val_loss:.4f} (saved)")
        else:
            print(f"Epoch {epoch+1}, Val loss: {val_loss:.4f}")
    if not checkpoint_written:
        # Always leave a checkpoint behind, e.g. when every validation loss was NaN.
        torch.save(model.state_dict(), checkpoint_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, choices=['finetune1', 'finetune2'], required=True)
    args = parser.parse_args()
    run_finetuning(args.config)
"""
write_file('ECDD_Experimentation/Training/training/finetune_script.py', finetune_script_content)

# Teacher training script (run from the notebook's working directory).
train_teacher_content = """
import sys, yaml, torch, argparse
import importlib.util
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

# This file lives in <root>/deepfake-patch-audit/scripts/. Running it by path puts only the
# scripts folder on sys.path, so add the project root for the ECDD_Experimentation import.
PATCH_AUDIT_DIR = Path(__file__).resolve().parents[1]
PROJECT_ROOT = PATCH_AUDIT_DIR.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from ECDD_Experimentation.Training.models.ladeda_resnet import create_ladeda_model

def load_module_from_path(module_name, file_path):
    '''Import a module from an explicit file path.

    The folder name deepfake-patch-audit contains hyphens, so it can never appear in an
    import statement; its modules are loaded by path instead.
    '''
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the classes stay importable by name (e.g. for pickling).
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

BaseDataset = load_module_from_path(
    'patch_audit_base_dataset', PATCH_AUDIT_DIR / 'datasets' / 'base_dataset.py'
).BaseDataset

def main():
    '''Train the teacher and keep the checkpoint with the best validation AUC.'''
    parser = argparse.ArgumentParser(description='Train Teacher with Attention Pooling')
    parser.add_argument('--epochs', type=int, default=15, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--device', type=str, default='cuda', help='Device')
    parser.add_argument('--output-dir', type=str, default='outputs/teacher_attention', help='Output directory')
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_dataset = BaseDataset(root_dir="dataset/train", resize_size=256)
    val_dataset = BaseDataset(root_dir="dataset/val", resize_size=256)
    # Fail before any training starts: an empty training set breaks the shuffling sampler
    # and ROC AUC is undefined unless both classes are present in the validation set.
    if len(train_dataset) == 0:
        raise RuntimeError("No training images found under dataset/train/real or dataset/train/fake (*.jpg)")
    if {label for _, label in val_dataset.samples} != {0, 1}:
        raise RuntimeError("Validation AUC needs *.jpg images in both dataset/val/real and dataset/val/fake")
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

    model = create_ladeda_model(pretrained=True, freeze_layers=['conv1', 'layer1']).to(device)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)

    best_val_auc = float('-inf')
    for epoch in range(args.epochs):
        model.train()
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            # Force float32 so the input dtype always matches the model weights.
            images = batch["image"].to(device, dtype=torch.float32)
            labels = batch["label"].to(device).float()

            pooled_logit, _, _ = model(images)
            # reshape(-1) keeps the batch axis for a final batch of one sample.
            loss = criterion(pooled_logit.reshape(-1), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                images = batch["image"].to(device, dtype=torch.float32)
                labels = batch["label"].to(device).float()
                pooled_logit, _, _ = model(images)
                # 1-D per-batch tensors; torch.cat cannot concatenate 0-dim tensors.
                all_preds.append(torch.sigmoid(pooled_logit.reshape(-1)).cpu())
                all_labels.append(labels.cpu())

        val_auc = roc_auc_score(torch.cat(all_labels).numpy(), torch.cat(all_preds).numpy())
        print(f"Epoch {epoch+1}, Val AUC: {val_auc:.4f}")

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save(model.state_dict(), output_dir / "teacher_best.pth")
            print(f"Saved new best model with AUC: {best_val_auc:.4f}")

if __name__ == "__main__":
    main()
"""
write_file('deepfake-patch-audit/scripts/train_teacher.py', train_teacher_content)

# Student distillation script (run from the notebook's working directory).
train_student_content = """
import sys, yaml, torch, argparse, random, numpy
import importlib.util
from pathlib import Path
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# This file lives in <root>/deepfake-patch-audit/scripts/. Running it by path puts only the
# scripts folder on sys.path, so add the project root for the ECDD_Experimentation imports
# (used here and inside the student model module).
PATCH_AUDIT_DIR = Path(__file__).resolve().parents[1]
PROJECT_ROOT = PATCH_AUDIT_DIR.parent
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from ECDD_Experimentation.Training.models.ladeda_resnet import create_ladeda_model

def load_module_from_path(module_name, file_path):
    '''Import a module from an explicit file path.

    The folder name deepfake-patch-audit contains hyphens, so it can never appear in an
    import statement; its modules are loaded by path instead.
    '''
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    # Register before executing so the classes stay importable by name (e.g. for pickling).
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module

TinyAttentionLaDeDa = load_module_from_path(
    'patch_audit_tiny_attention_ladeda', PATCH_AUDIT_DIR / 'models' / 'student' / 'tiny_attention_ladeda.py'
).TinyAttentionLaDeDa
BaseDataset = load_module_from_path(
    'patch_audit_base_dataset', PATCH_AUDIT_DIR / 'datasets' / 'base_dataset.py'
).BaseDataset
DistillationLoss = load_module_from_path(
    'patch_audit_distillation_loss', PATCH_AUDIT_DIR / 'losses' / 'distillation_loss.py'
).DistillationLoss

def main():
    '''Distil the frozen teacher into the student and save the final student weights.'''
    parser = argparse.ArgumentParser(description='Train Student with Attention Distillation')
    parser.add_argument('--epochs', type=int, default=30, help='Training epochs')
    parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--device', type=str, default='cuda', help='Device')
    parser.add_argument('--output-dir', type=str, default='outputs/student_attention', help='Output directory')
    parser.add_argument('--teacher-weights', type=str, default='outputs/teacher_attention/teacher_best.pth', help='Path to teacher weights')
    args = parser.parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_dataset = BaseDataset(root_dir="dataset/train", resize_size=256)
    if len(train_dataset) == 0:
        raise RuntimeError("No training images found under dataset/train/real or dataset/train/fake (*.jpg)")
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2)

    # Load Teacher (Attention-based LaDeDaResNet50)
    teacher_model = create_ladeda_model(pretrained=False).to(device)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    teacher_model.load_state_dict(torch.load(args.teacher_weights, map_location=device))
    teacher_model.eval()
    for param in teacher_model.parameters():
        param.requires_grad = False

    # Load Student (TinyAttentionLaDeDa)
    student_model = TinyAttentionLaDeDa().to(device)

    criterion = DistillationLoss()
    optimizer = optim.Adam(student_model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        student_model.train()
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            # Force float32 so the input dtype always matches the model weights.
            images = batch["image"].to(device, dtype=torch.float32)
            labels = batch["label"].to(device)

            # Get outputs from both models; the teacher only provides targets.
            student_pooled_logit, student_patches, _ = student_model(images)
            with torch.no_grad():
                _, teacher_patches, _ = teacher_model(images)

            loss, _, _ = criterion(student_pooled_logit, student_patches, teacher_patches, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    torch.save(student_model.state_dict(), output_dir / "student_best.pth")
    print(f"Saved final student model to {output_dir / 'student_best.pth'}")

if __name__ == "__main__":
    main()
"""
write_file('deepfake-patch-audit/scripts/train_student_two_stage.py', train_student_content)

print("All files written successfully.")


### 2. Data and Pretrained Weights

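For this notebook to work, you must upload your data with the following directory structure. Only `.jpg` files are picked up by the data loaders. No pretrained weights need to be uploaded: torchvision fetches the ImageNet-pretrained ResNet-50 weights when the teacher is built with `pretrained=True`.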

```
./dataset/
├── train/
│   ├── real/
│   │   ├── 0001.jpg
│   │   └── ...
│   └── fake/
│       ├── 0001.jpg
│       └── ...
└── val/
    ├── real/
    │   └── ...
    └── fake/
        └── ...

./ECDD_Experimentation/ECDD_Training_Data/processed/splits/
├── finetune1/
│   ├── train/
│   │   ├── real/
│   │   └── fake/
│   └── val/
│       └── ...
└── finetune2/
    └── ...
```

You can zip these folders, upload the zip file to the Colab environment, and then run the following cell to unzip it.

In [None]:
# Example: !unzip data.zip

### 3. Training Commands

#### 3.1 Train the Teacher Model

This fine-tunes a pretrained `LaDeDaResNet50` on your custom dataset. The best model based on validation AUC will be saved.

In [None]:
!python deepfake-patch-audit/scripts/train_teacher.py --epochs 10 --batch-size 16

#### 3.2 Train the Student Model

This trains the new `TinyAttentionLaDeDa` student model to mimic the teacher you just trained, using the patch-distillation loss.

In [None]:
!python deepfake-patch-audit/scripts/train_student_two_stage.py --epochs 20 --batch-size 16

#### 3.3 Run Finetuning (ECDD Experimentation)

This runs the finetuning script from the `ECDD_Experimentation` folder, which also uses the attention-based `LaDeDaResNet50` model.

In [None]:
!python ECDD_Experimentation/Training/training/finetune_script.py --config finetune1