In [12]:
from src.modules.analysis_modules import *
from src.modules.vid_classifier import *
from src.modules.img_classifier import *
from src.modules.img_dataset import *

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm

# **Training Loop**

In [None]:
# --- CONFIGURATION ---
# Seed torch's global RNG up front so weight initialisation and DataLoader
# shuffling are repeatable after "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 10      # number of full train + validation passes
LR = 1e-4        # optimiser learning rate
CHUNK_SIZE = 32  # NOTE(review): not referenced in the cells visible here - confirm where it is consumed

### **Image Classification Model Architecture: Multi-Stream Deepfake Detection Network**

The system employs a **Multi-Stream Late Fusion** approach, processing visual and frequency data in parallel before synthesizing a final verdict.

**1. Preprocessing**
- Input images are resized to **256×256**, normalized, and cloned into three parallel tensors.

**2. Specialized Feature Extraction (The Three Streams)**
- **Stream A: Visual Context (EfficientNet B0) $\rightarrow$ [512 Features]**
    - Extracts high-level spatial semantics (facial geometry, lighting, warping).
        
- **Stream B: Frequency Heatmap (HFRI) $\rightarrow$ [128 Features]**
    - Analyzes noise residuals to detect upsampling artifacts and "checkerboard" patterns typical of GANs.
        
- **Stream C: Texture Consistency (HFRFS) $\rightarrow$ [128 Features]**
    - Examines the radial frequency spectrum to identify unnatural smoothness or lack of skin micro-texture.
        

**3. Fusion & Classification Head**
- Concatenation: The three vectors are merged into a single global descriptor: $$[512] \oplus [128] \oplus [128] = \mathbf{[768]}$$
- Linear Sequential Network:
    A Multi-Layer Perceptron (MLP) maps the fused features to a probability score:
    - **Layer 1:** Linear (768 $\to$ 256) + ReLU + Dropout (0.4)
    - **Layer 2:** Linear (256 $\to$ 64) + ReLU
    - **Layer 3:** Linear (64 $\to$ 1) (Final Logit Output)

In [None]:
def train_image_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and report epoch metrics.

    Args:
        model: Network called as ``model(images)``; its output is treated as
            raw logits (a sigmoid is applied here to obtain predictions).
        dataloader: Iterable of ``(images, labels)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Optimiser whose gradients are zeroed and stepped per batch.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sum of per-batch loss values divided
        by the number of batches, and the fraction of predictions that match
        their label.

    Raises:
        ZeroDivisionError: If ``dataloader`` contains no batches.
    """
    model.train()  # training mode (e.g. dropout active)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    progress = tqdm(dataloader, leave=False, desc="Training")

    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        # Labels become float with an added dim 1 to line up with the logits
        # (presumably labels arrive as shape (batch,) - confirm against DF_Dataset).
        targets = batch_labels.to(device).float().unsqueeze(1)

        # Forward pass
        outputs = model(batch_images)
        batch_loss = criterion(outputs, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Accumulate the scalar loss for the epoch average
        batch_loss_value = batch_loss.item()
        loss_sum += batch_loss_value

        # Hard 0/1 predictions: sigmoid probability strictly above 0.5
        hard_preds = (torch.sigmoid(outputs) > 0.5).float()
        n_correct += (hard_preds == targets).sum().item()
        n_seen += targets.size(0)

        progress.set_postfix(loss=batch_loss_value)

    return loss_sum / len(dataloader), n_correct / n_seen

def validate_image_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Network called as ``model(images)``; its output is treated as
            raw logits (a sigmoid is applied here to obtain predictions).
        dataloader: Iterable of ``(images, labels)`` batches that also
            supports ``len()``.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device that each batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sum of per-batch loss values divided
        by the number of batches, and the fraction of predictions that match
        their label.

    Raises:
        ZeroDivisionError: If ``dataloader`` contains no batches.
    """
    model.eval()  # evaluation mode; the model is left in this mode on return

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # Gradient tracking is switched off for the whole evaluation pass.
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            batch_images = batch_images.to(device)
            # Labels become float with an added dim 1 to line up with the logits.
            targets = batch_labels.to(device).float().unsqueeze(1)

            outputs = model(batch_images)
            loss_sum += criterion(outputs, targets).item()

            # Hard 0/1 predictions: sigmoid probability strictly above 0.5
            hard_preds = (torch.sigmoid(outputs) > 0.5).float()
            n_correct += (hard_preds == targets).sum().item()
            n_seen += targets.size(0)

    mean_loss = loss_sum / len(dataloader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy


In [None]:
from dotenv import load_dotenv
load_dotenv()
# NOTE(review): load_dotenv() is called but no environment variable is read
# below; both paths are empty strings and must be filled in before running.
# Presumably they were meant to come from the .env file - confirm.
IMG_DATASET_PATH = ""
IMG_WEIGHTS_PATH = ""

# Model and the train / validation views of the same dataset root.
classifier = IMG_Classifier().to(DEVICE)
train_ds = DF_Dataset(IMG_DATASET_PATH, training=True)
val_ds = DF_Dataset(IMG_DATASET_PATH, training=False)

# NOTE(review): BATCH_SIZE and DataLoader are not defined or imported in any
# cell visible here; presumably they arrive via the star imports - confirm.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

criterion = nn.BCEWithLogitsLoss()
# Use the LR constant from the configuration cell instead of repeating the literal.
optimizer = torch.optim.AdamW(classifier.parameters(), lr=LR, weight_decay=1e-4)
# Cut the learning rate by 10x once validation loss has stopped improving for 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
# Start at +inf so the first epoch always counts as an improvement. Starting at 0
# would never save a checkpoint, because the BCE loss is never negative.
best_val_loss = float("inf")

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")

    # Train
    train_loss, train_acc = train_image_epoch(classifier, train_loader, criterion, optimizer, DEVICE)

    # Validation
    val_loss, val_acc = validate_image_epoch(classifier, val_loader, criterion, DEVICE)

    # Logging
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc*100:.2f}%")
    print(f"Val Loss:   {val_loss:.4f} | Acc: {val_acc*100:.2f}%")

    # The scheduler monitors validation loss (mode='min').
    scheduler.step(val_loss)

    # Save weights only when validation loss improves on the best seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(classifier.state_dict(), IMG_WEIGHTS_PATH)
        print(f"Model Weights Saved at {IMG_WEIGHTS_PATH}")
