In [46]:
from src.modules.analysis_modules import *
from src.modules.vid_classifier import *
from src.modules.img_classifier import *
from src.modules.img_dataset import *

In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence

# **Training Loop**

In [38]:
# --- CONFIGURATION ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
BATCH_SIZE = 16
LR = 1e-3
CHUNK_SIZE = 32 

### **Image Classification Model Architecture: Multi-Stream Deepfake Detection Network**

The system employs a **Multi-Stream Late Fusion** approach, processing visual and frequency data in parallel before synthesizing a final verdict.

**1. Preprocessing**
- Input images are resized to **256×256**, normalized, and cloned into three parallel tensors.

**2. Specialized Feature Extraction (The Three Streams)**
- **Stream A: Visual Context (EfficientNet B0) $\rightarrow$ [512 Features]**
    - Extracts high-level spatial semantics (facial geometry, lighting, warping).
        
- **Stream B: Frequency Heatmap (HFRI) $\rightarrow$ [128 Features]**
    - Analyzes noise residuals to detect upsampling artifacts and "checkerboard" patterns typical of GANs.
        
- **Stream C: Texture Consistency (HFRFS) $\rightarrow$ [128 Features]**
    - Examines the radial frequency spectrum to identify unnatural smoothness or lack of skin micro-texture.
        

**3. Fusion & Classification Head**
- Concatenation: The three vectors are merged into a single global descriptor:$$[512] \oplus [128] \oplus [128] = \mathbf{[768]}$$
- Linear Sequential Network:
    A Multi-Layer Perceptron (MLP) maps the fused features to a probability score:
    - **Layer 1:** Linear (768 $\to$ 256) + ReLU + Dropout (0.4)
    - **Layer 2:** Linear (256 $\to$ 64) + ReLU
    - **Layer 3:** Linear (64 $\to$ 1) (Final Logit Output)

In [39]:
def train_image_epoch(model, dataloader, criterion, optimizer, device):
    model.train() #Sets model to training mode [Dropout enabled]
    #Statistics
    running_loss = 0.0 
    correct_preds = 0 
    total_preds = 0
    loop = tqdm(dataloader, leave=False, desc="Training")
    
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device).float().unsqueeze(1)

        logits = model(images) 
        loss = criterion(logits, labels)
        
        # Backward Pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Loss Metric
        running_loss += loss.item()
        
        # Convert logits to 0 or 1 predictions (Sigmoid > 0.5)
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).float()
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)

        loop.set_postfix(loss=loss.item())
        
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_preds / total_preds
    
    return epoch_loss, epoch_acc

def validate_image_epoch(model, dataloader, criterion, device):
    model.eval()
    
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0
    
    with torch.no_grad(): # No gradients needed for validation (saves RAM)
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device).float().unsqueeze(1)
            
            logits = model(images)
            loss = criterion(logits, labels)
            
            running_loss += loss.item()
            
            probs = torch.sigmoid(logits)
            preds = (probs > 0.5).float()
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
            
    epoch_loss = running_loss / len(dataloader)
    epoch_acc = correct_preds / total_preds
    
    return epoch_loss, epoch_acc


In [51]:
from dotenv import load_dotenv
load_dotenv()
IMG_DATASET_PATH = os.environ.get("IMG_DATASET_PATH")
IMG_WEIGHTS_PATH = os.environ.get("IMG_WEIGHTS_PATH")


classifier = IMG_Classifier().to(DEVICE)
train_ds = DF_Dataset(IMG_DATASET_PATH, training = True)
val_ds = DF_Dataset(IMG_DATASET_PATH, training = False)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

if IMG_WEIGHTS_PATH and os.path.exists(IMG_WEIGHTS_PATH):
    print(f"Pre-trained weights found at {IMG_WEIGHTS_PATH}. Loading...")
    
    classifier.load_state_dict(torch.load(IMG_WEIGHTS_PATH, map_location=DEVICE))
    print("Resuming from saved state.")
else:
    print("No pre-trained weights found. Starting from scratch.")

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(classifier.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
best_val_loss = 10.0



Initializing EfficientNet-B0...
Pre-trained weights found at /home/arun/Desktop/deep/deepfake-detection-model/src/models/img-weights/IMG-WEIGHT.pth. Loading...
Resuming from saved state.


In [None]:
for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")
    
    #Train
    train_loss, train_acc = train_image_epoch(classifier, train_loader, criterion, optimizer, DEVICE)
    
    #Validation
    val_loss, val_acc = validate_image_epoch(classifier, val_loader, criterion, DEVICE)
    
    #Logging
    print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc*100:.2f}%")
    print(f"Val Loss:   {val_loss:.4f} | Acc: {val_acc*100:.2f}%")
    
    #Scheduler
    scheduler.step(val_loss)
    
    # Save Weights if Improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(classifier.state_dict(), IMG_WEIGHTS_PATH)
        print(f"Model Weights Saved at {IMG_WEIGHTS_PATH}")



--- Epoch 1/20 ---


Training:   0%|          | 0/438 [00:00<?, ?it/s]

Train Loss: 0.5070 | Acc: 75.87%
Val Loss:   0.3998 | Acc: 82.33%
Model Weights Saved at C:/Users/Krishna/My Projects/deepfake-detection-model/src/models/img-weights/IMG-WEIGHT.pth

--- Epoch 2/20 ---


Training:   0%|          | 0/438 [00:00<?, ?it/s]

Train Loss: 0.4398 | Acc: 79.51%
Val Loss:   0.4063 | Acc: 81.40%

--- Epoch 3/20 ---


Training:   0%|          | 0/438 [00:00<?, ?it/s]

Train Loss: 0.4116 | Acc: 81.79%
Val Loss:   0.3688 | Acc: 83.37%
Model Weights Saved at C:/Users/Krishna/My Projects/deepfake-detection-model/src/models/img-weights/IMG-WEIGHT.pth

--- Epoch 4/20 ---


Training:   0%|          | 0/438 [00:00<?, ?it/s]

In [53]:
def train_stateful_epoch(model, dataloader, optimizer, device, chunk_size=32):
    model.train()
    total_loss = 0 
    total_correct = 0
    total_frames = 0
    
    # --- RECALL VARS ---
    total_fakes = 0
    correct_fakes = 0

    loop = tqdm(dataloader, leave=False, desc="Training Video")

    for video_batch, label_batch, mask_batch in loop:
        video_batch = video_batch.to(device)
        label_batch = label_batch.to(device).float()
        mask_batch  = mask_batch.to(device).float()

        hidden_state = None
        b, t_total, c, h, w = video_batch.shape

        for t in range(0, t_total, chunk_size):
            end_t = min(t + chunk_size, t_total)
            
            x_chunk = video_batch[:, t:end_t] 
            y_chunk = label_batch[:, t:end_t]
            m_chunk = mask_batch[:, t:end_t]
            
            if x_chunk.shape[1] == 0: break

            logits, hidden_state = model(x_chunk, hidden_state) 
            
            loss = masked_bce_loss(logits.squeeze(2), y_chunk, m_chunk)
            
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) 
            optimizer.step()
            
            hidden_state = (hidden_state[0].detach(), hidden_state[1].detach())
            
            total_loss += loss.item()
            
            # --- METRICS CALCULATION ---
            probs = torch.sigmoid(logits.squeeze(2))
            preds = (probs > 0.5).float()
            
            # 1. Standard Accuracy
            correct = ((preds == y_chunk) * m_chunk).sum().item()
            total_correct += correct
            total_frames += m_chunk.sum().item()
            
            # 2. RECALL CALCULATION (Focus on Fakes)
            # Find frames that are BOTH labeled 1 (Fake) AND Valid (mask=1)
            actual_fakes_mask = (y_chunk == 1) * m_chunk
            

            true_positives = ((preds == 1) * actual_fakes_mask).sum().item()
            
            correct_fakes += true_positives
            total_fakes += actual_fakes_mask.sum().item()
    
    avg_loss = total_loss / len(dataloader)
    avg_acc = total_correct / (total_frames + 1e-8)

    avg_recall = correct_fakes / (total_fakes + 1e-8) 
    
    return avg_loss, avg_acc, avg_recall

def validate_stateful_epoch(model, dataloader, device, chunk_size=32):
    model.eval()
    total_loss = 0
    total_correct = 0
    total_frames = 0
    
    # --- RECALL COUNTERS ---
    total_fakes = 0
    correct_fakes = 0
    
    with torch.no_grad(): # Disable Gradient Calculation
        for video_batch, label_batch, mask_batch in dataloader:
            video_batch = video_batch.to(device)
            label_batch = label_batch.to(device).float()
            mask_batch  = mask_batch.to(device).float()
            
            hidden_state = None 
            b, t_total, c, h, w = video_batch.shape
            
            for t in range(0, t_total, chunk_size):
                end_t = min(t + chunk_size, t_total)
                
                x_chunk = video_batch[:, t:end_t]
                y_chunk = label_batch[:, t:end_t]
                m_chunk = mask_batch[:, t:end_t]
                
                if x_chunk.shape[1] == 0: break

                logits, hidden_state = model(x_chunk, hidden_state)
                
                # Squeeze logits [B, T, 1] -> [B, T]
                loss = masked_bce_loss(logits.squeeze(2), y_chunk, m_chunk)
                total_loss += loss.item()
                
                # --- METRICS ---
                probs = torch.sigmoid(logits.squeeze(2))
                preds = (probs > 0.5).float()
                
                # 1. Accuracy (Overall correctness)
                correct = ((preds == y_chunk) * m_chunk).sum().item()
                total_correct += correct
                total_frames += m_chunk.sum().item()
                
                # 2. Recall (Catching the Fakes)
                # Identify frames that are actually FAKE (1) and VALID (mask=1)
                actual_fakes_mask = (y_chunk == 1) * m_chunk
                
                # Count how many of those we correctly predicted as 1
                true_positives = ((preds == 1) * actual_fakes_mask).sum().item()
                
                correct_fakes += true_positives
                total_fakes += actual_fakes_mask.sum().item()
                
    avg_loss = total_loss / len(dataloader)
    avg_acc = total_correct / (total_frames + 1e-8)
    avg_recall = correct_fakes / (total_fakes + 1e-8) # Avoid DivByZero
    
    return avg_loss, avg_acc, avg_recall

def masked_bce_loss(logits, labels, mask):
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    safe_labels = labels.clone()
    safe_labels[mask == -1] = 0
    # Calculate error everywhere (even on padding)
    raw_loss = criterion(logits, labels)

    masked_loss = raw_loss * mask# Zero out the error on padding frames
    
    loss = masked_loss.sum() / (mask.sum() + 1e-8)
    
    return loss

def variable_length_collate(batch):
    videos, labels = zip(*batch)
    videos = [v.permute(1, 0, 2, 3) for v in videos] 
    videos_padded = pad_sequence(videos, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=-1)
    padding_mask = (labels_padded != -1)
    
    return videos_padded, labels_padded, padding_mask


In [55]:

# Caution Cojoined Twins in the Making
def load_pretrained_modules(video_model, image_weights_path):
    full_state_dict = torch.load(image_weights_path, map_location='cpu')

    visual_bucket = {}
    hfri_bucket = {}

    for key, value in full_state_dict.items():
        if key.startswith("efficientnet."):
            clean_key = key.replace("efficientnet.", "")
            visual_bucket[clean_key] = value

        elif key.startswith("hfri."):
            clean_key = key.replace("hfri.", "")
            hfri_bucket[clean_key] = value
        
    try:
        # Load into 'eff_net_module' (The name you used in VideoClass)
        video_model.eff_net_module.load_state_dict(visual_bucket, strict=True)
        print(f"✅ Successfully transferred {len(visual_bucket)} weights to Visual Stream.")
        
        # Load into 'hfri_module'
        video_model.hfri_module.load_state_dict(hfri_bucket, strict=True)
        print(f"✅ Successfully transferred {len(hfri_bucket)} weights to Frequency Stream.")

    except RuntimeError as e:
        print(f"Error during Weight Join Making")
    
    return video_model

In [61]:
# --- CONFIG ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 4 # Smaller batch size because videos are huge!
CHUNK_SIZE = 32
EPOCHS = 20
LR = 1e-4

VIDEO_DATASET_PATH = os.environ.get("VID_DATASET_PATH")
VIDEO_WEIGHTS_PATH = os.environ.get("VID_WEIGHTS_PATH")

In [63]:
# 1. Load Data (Using the DeepfakeVideoDataset class)
import src.modules.vid_dataset as VID
train_ds = VID.DF_Dataset(VIDEO_DATASET_PATH,epoch_size=CHUNK_SIZE, training=True)
val_ds   = VID.DF_Dataset(VIDEO_DATASET_PATH,epoch_size=CHUNK_SIZE, training=False)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=variable_length_collate)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=variable_length_collate)

# 2. Init Model
video_model = VID_Classifier(LSTM_hidden_size=256, num_layers=2).to(DEVICE)

# Cojoin the HFRI and EFF_NET MODULES with Image Classifier
video_model = load_pretrained_modules(video_model, IMG_WEIGHTS_PATH)

optimizer = torch.optim.AdamW(video_model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

best_val_loss = float('inf')

print("Starting Video Training...")

Initializing EfficientNet-B0...
✅ Successfully transferred 360 weights to Visual Stream.
✅ Successfully transferred 30 weights to Frequency Stream.
Starting Video Training...


In [64]:
for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")
    
    # Train
    t_loss, t_acc, t_recall = train_stateful_epoch(video_model, train_loader, optimizer, DEVICE, CHUNK_SIZE)
    print(f"Train Loss: {t_loss:.4f} | Acc: {t_acc*100:.2f}%, Recall: {v_recall*100:.2f}%")
    
    # Validate
    v_loss, v_acc, v_recall = validate_stateful_epoch(video_model, val_loader, DEVICE, CHUNK_SIZE)
    print(f"Val Loss: {v_loss:.4f} | Acc: {v_acc*100:.2f}% | Recall: {v_recall*100:.2f}%")

    
    
    scheduler.step(v_loss)
    
    if v_loss < best_val_loss:
        best_val_loss = v_loss
        torch.save(video_model.state_dict(), "best_video_model.pth")
        print("✅ Video Model Saved!")



--- Epoch 1/20 ---


Training Video:   0%|          | 0/8 [00:00<?, ?it/s]

Creating VideoCreating VideoCreating Video


Creating Video


KeyboardInterrupt: 