# Load Dataset

In [None]:
import torch
from PIL import Image
import os
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

# Custom PyTorch Dataset class for Steganography Image Classification
class StegoDataset(torch.utils.data.Dataset):
    """Grayscale image dataset for 4-way steganalysis (cover vs. LSB / WOW / HILL).

    Each sample is ``(image, label)``: the image is loaded as a PIL 'L'-mode
    (grayscale) image and passed through ``transform`` when one is given.
    """

    # Accepted image file extensions, matched case-insensitively (so 'IMG.PNG' is kept).
    IMAGE_EXTENSIONS = ('.pgm', '.png', '.jpg', '.jpeg', '.bmp')

    def __init__(self, base_dir, transform=None):
        """
        Index every image under ``base_dir`` and assign class labels.

        Args:
            base_dir (string): Root directory containing:
                - 'cover/' folder for clean images (label 0)
                - 'stego/LSB/', 'stego/WOW/', 'stego/HILL/' for stego images (labels 1, 2, 3)
              A missing folder only triggers a printed warning; that class is skipped.
            transform (callable, optional): Applied to each PIL image in ``__getitem__``
                (e.g. resizing, tensor conversion, normalization).
        """
        self.transform = transform
        self.data = []  # (image_path, label) tuples, ordered by class, then by file name

        # Class name -> (directory, label)
        self.class_dict = {
            'cover': (os.path.join(base_dir, 'cover'), 0),             # clean images
            'lsb': (os.path.join(base_dir, 'stego', 'LSB'), 1),        # LSB stego images
            'wow': (os.path.join(base_dir, 'stego', 'WOW'), 2),        # WOW stego images
            'hill': (os.path.join(base_dir, 'stego', 'HILL'), 3),      # HILL stego images
        }

        for class_dir, label in self.class_dict.values():
            if not os.path.isdir(class_dir):
                print(f"[WARNING] Missing directory: {class_dir}")
                continue
            # sorted() keeps the index order deterministic, which is what makes a
            # seeded random_split over this dataset reproducible.
            for img_file in sorted(os.listdir(class_dir)):
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.data.append((os.path.join(class_dir, img_file), label))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load the sample at ``idx``.

        Args:
            idx (int): Index into ``self.data``.

        Returns:
            tuple: (image, label) where image is the transformed grayscale image
            (a PIL image when no transform is set).

        Raises:
            Any error from indexing, decoding or the transform is printed with the
            index and then re-raised unchanged.
        """
        try:
            img_path, label = self.data[idx]
            # Open inside a context manager so the file handle is released right away.
            # convert('L') loads the pixels and returns a new image (a copy when the
            # source is already 'L'), so `img` stays valid after the file is closed.
            with Image.open(img_path) as img:
                img = img.convert('L')

            if self.transform is not None:
                img = self.transform(img)

            return img, label
        except Exception as e:
            print(f"Error loading idx={idx}: {e}")
            raise  # bare raise preserves the original traceback


In [None]:
from google.colab import drive
# Mount Google Drive to access datasets stored in your Drive
drive.mount('/content/drive')

import copy

from torchvision import transforms
from torch.utils.data import DataLoader, Subset, random_split

# Use CUDA (GPU) if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# -------------------------------
# Define Image Transform Pipelines
# -------------------------------
# NOTE(review): Resize interpolates pixel values, which can blur the low-amplitude
# embedding noise steganalysis relies on. If the source images are not already
# 256x256, consider cropping instead of resizing -- TODO confirm the source image size.

# Transformations for training set (with data augmentation)
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),                        # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[0.5], std=[0.5]),  # [0, 1] -> [-1, 1]
])

# Transformations for validation and test sets (no augmentation)
val_test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# -------------------------------
# Load Dataset
# -------------------------------

# Root directory where dataset folders ('cover', 'stego/LSB', etc.) are stored
base_path = '/content/drive/MyDrive/Stego-Images-Dataset'

# Index all images once, without a transform; per-split transforms are attached below.
dataset = StegoDataset(base_path)
if len(dataset) == 0:
    raise RuntimeError(f"No images found under {base_path}; check the Drive path and folder layout.")

# 70% / 15% / 15% split (the test split absorbs rounding leftovers)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Seeded split for reproducibility
train_split, val_split, test_split = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

# Give each split its own transform. All three Subsets from random_split wrap the SAME
# dataset object, so assigning `.dataset.transform` per subset would overwrite it each
# time and leave every split (training included) with the last transform assigned.
# Shallow copies share the (path, label) index but carry independent `transform`s.
train_base = copy.copy(dataset)
train_base.transform = train_transforms
eval_base = copy.copy(dataset)
eval_base.transform = val_test_transforms

train_dataset = Subset(train_base, train_split.indices)
val_dataset = Subset(eval_base, val_split.indices)
test_dataset = Subset(eval_base, test_split.indices)

# -------------------------------
# Prepare Data Loaders
# -------------------------------

batch_size = 32
use_pinned_memory = device.type == 'cuda'  # pinned host memory only helps host->GPU copies

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,                   # reshuffle every epoch
    num_workers=2,
    pin_memory=use_pinned_memory,
    persistent_workers=True,        # keep workers alive between epochs
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pinned_memory,
)


# ResNet34 Model

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pretrained ResNet34 (ImageNet weights)
model = models.resnet34(weights='ResNet34_Weights.IMAGENET1K_V1')

# Adapt the first convolution layer to accept grayscale input (1 channel),
# keeping every other hyper-parameter of the original stem conv.
original_first_conv = model.conv1
new_first_conv = nn.Conv2d(
    in_channels=1,                                   # Grayscale input
    out_channels=original_first_conv.out_channels,   # 64 for ResNet34
    kernel_size=original_first_conv.kernel_size,
    stride=original_first_conv.stride,
    padding=original_first_conv.padding,
    bias=False,
)

# Start the grayscale stem from the pretrained RGB filters averaged over the colour
# axis, so it begins with ImageNet edge/texture detectors instead of random weights.
# (A Kaiming re-init applied right after this copy would silently discard them.)
# NOTE(review): summing instead of averaging would reproduce the pretrained response
# to a gray image replicated into 3 channels; the mean gives 1/3 of that response.
with torch.no_grad():
    new_first_conv.weight.copy_(original_first_conv.weight.mean(dim=1, keepdim=True))

# Replace the original conv1 with the new grayscale version
model.conv1 = new_first_conv

# -----------------------------------------------
#  Modify the Final Classification Head
# -----------------------------------------------
num_classes = 4  # Steganography detection: Cover, LSB, WOW, HILL

# Replace the fully connected layer with a custom head
model.fc = nn.Sequential(
    nn.Dropout(0.5),                            # Strong dropout to reduce overfitting
    nn.Linear(model.fc.in_features, 256),       # 512 backbone features -> 256
    nn.BatchNorm1d(256),
    nn.LeakyReLU(0.1, inplace=True),
    nn.Dropout(0.3),                            # Additional regularization
    # bias=False is a design choice (no BatchNorm follows this layer). Keep it in sync
    # with phase1_model() in the Phase 2 cell, which rebuilds this head to load checkpoints.
    nn.Linear(256, num_classes, bias=False),
)

# -----------------------------------------------
#  Initialize Classifier Weights
# -----------------------------------------------
def init_weights(m):
    """Initialise one classifier-head module in place; intended for ``Module.apply``.

    - ``nn.BatchNorm1d``: scale (gamma) set to 1 and shift (beta) to 0.
    - ``nn.Linear``: weights drawn from N(0, 0.01^2); the bias, if present, is
      left at PyTorch's default initialisation.
    Any other module type is left untouched.
    """
    if isinstance(m, nn.BatchNorm1d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)
        return
    if isinstance(m, nn.Linear):
        # Small std keeps initial logits near zero.
        nn.init.normal_(m.weight, mean=0.0, std=0.01)

# Run init_weights on every submodule of the classifier head (and on the head itself).
model.fc.apply(fn=init_weights)

# -----------------------------------------------
# Wrap Model with LogitScaler
# -----------------------------------------------

class LogitScaler(nn.Module):
    """Wrap a network and multiply its output logits by a constant ``scale``.

    Multiplying logits by ``scale`` before softmax/cross-entropy is the same as a
    softmax temperature of ``1 / scale`` (default 0.5 -> temperature 2), which
    softens the predicted probabilities. The wrapped network is stored as
    ``self.model``, so its parameters appear in the state_dict under ``model.``.
    """

    def __init__(self, model, scale=0.5):
        super().__init__()
        self.model = model
        self.scale = scale

    def forward(self, x):
        logits = self.model(x)
        return logits * self.scale

model = LogitScaler(model=model, scale=0.5).to(device=device)  # halve logits, then move to device


# Training and Evaluation Functions

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score, cohen_kappa_score,
    confusion_matrix, classification_report, precision_score, recall_score
)
import psutil
import time

# ---------------------------------------
# Utility: Format Seconds into MM:SS
# ---------------------------------------
def format_time(seconds):
    """Render a duration in seconds as ``"<minutes>m <seconds>s"``.

    Both parts are truncated to integers, e.g. 125.7 -> "2m 5s".
    """
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}m {int(leftover)}s"

# ---------------------------------------
# Model Training Function
# ---------------------------------------
def _log_batch_progress(epoch, num_epochs, batch_idx, total_batches, loss_value,
                        batch_start_time, epoch_start_time, device, show_gpu):
    """Print one progress block: loss, batch time / ETA, and memory usage."""
    batch_time = time.time() - batch_start_time
    elapsed_time = time.time() - epoch_start_time
    percent_complete = (batch_idx + 1) / total_batches * 100
    eta = (elapsed_time / (batch_idx + 1)) * (total_batches - batch_idx - 1)
    cpu_mem = psutil.virtual_memory().used / 1024**2
    cpu_percent = psutil.cpu_percent(interval=None)

    print("\n" + "-" * 60)
    print(f"Epoch [{epoch+1}/{num_epochs}] | Batch [{batch_idx+1}/{total_batches}] "
          f"({percent_complete:.1f}%)")
    print(f"Loss        : {loss_value:.4f}")
    print(f"Batch Time  : {batch_time:.2f}s | ETA: {format_time(eta)}")
    if show_gpu:
        # torch.cuda memory queries only accept CUDA devices.
        gpu_mem_alloc = torch.cuda.memory_allocated(device) / 1024**2
        gpu_mem_max = torch.cuda.max_memory_allocated(device) / 1024**2
        print(f"GPU Memory  : {gpu_mem_alloc:.1f} MB (Max: {gpu_mem_max:.1f} MB)")
    print(f"CPU Memory  : {cpu_mem:.1f} MB | CPU Usage: {cpu_percent:.1f}%")
    print("-" * 60)


def _print_epoch_summary(epoch, num_epochs, epoch_loss, train_acc, val_metrics):
    """Print the end-of-epoch training loss/accuracy and validation metrics."""
    print("\n" + "=" * 60)
    print(f"Epoch {epoch+1}/{num_epochs} Summary")
    print("-" * 60)
    print(f"Train Loss     : {epoch_loss:.4f}")
    print(f"Train Accuracy : {train_acc:.4f}")
    print(f"Val Accuracy   : {val_metrics['accuracy']:.4f}")
    print(f"Val F1-Macro   : {val_metrics['f1_macro']:.4f}")
    print(f"Val AUC (OVO)  : {val_metrics['auc_ovo']:.4f}")
    print(f"Cohen’s Kappa  : {val_metrics['kappa']:.4f}")
    print("\nClassification Report:\n")
    print(val_metrics['classification_report'])
    print("=" * 60)


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler,
                num_epochs=10, checkpoint_path='best_model.pth', device='cuda', patience=5):
    """Train a classifier with mixed precision, per-epoch validation, checkpointing
    and early stopping.

    Args:
        model: network to train; moved to ``device``.
        train_loader, val_loader: DataLoaders yielding ``(inputs, labels)`` batches.
        criterion: loss callable, ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler stepped once per epoch. A ReduceLROnPlateau scheduler
            is stepped with validation *accuracy*, so create it with ``mode='max'``.
        num_epochs (int): maximum number of epochs.
        checkpoint_path (str): file the best-validation-accuracy checkpoint is saved to.
        device: 'cuda', 'cpu' or a ``torch.device``. AMP (autocast + GradScaler) is
            only enabled on CUDA.
        patience (int): stop after this many consecutive epochs without a new best
            validation accuracy.

    Returns:
        tuple: ``(best_metrics, history)``. ``best_metrics`` is the ``evaluate`` output
        of the best epoch (``{}`` if no epoch beat 0.0 accuracy). ``history`` holds
        per-epoch 'train_loss' and 'val_metrics' lists, 'best_epoch' and 'class_names'.
    """
    print(f"Training device: {device}")
    model = model.to(device)

    # autocast needs a device *type* string ('cuda'/'cpu'), not a torch.device, and
    # loss-scaled mixed precision is only meaningful on CUDA.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'

    history = {
        'train_loss': [],
        'val_metrics': [],
        'best_epoch': 0,
        'class_names': ['Cover', 'LSB', 'WOW', 'HILL']
    }

    best_val_acc = 0.0
    best_metrics = {}
    epochs_without_improvement = 0
    scaler = GradScaler(enabled=use_amp)  # a disabled scaler is a transparent pass-through

    total_batches = len(train_loader)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0
        epoch_start_time = time.time()

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            batch_start_time = time.time()
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            # Warn only when a GPU was requested but the batch did not get there.
            if use_amp and (inputs.device.type == 'cpu' or labels.device.type == 'cpu'):
                print(f"⚠️ Warning: Inputs or labels still on CPU — inputs: {inputs.device}, labels: {labels.device}")

            optimizer.zero_grad()

            with torch.autocast(device_type=device_type, enabled=use_amp):
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Scale -> backward -> unscale so clipping sees true gradient norms.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            loss_value = loss.item()
            running_loss += loss_value * inputs.size(0)
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

            if batch_idx % 10 == 0 or batch_idx == total_batches - 1:
                _log_batch_progress(epoch, num_epochs, batch_idx, total_batches, loss_value,
                                    batch_start_time, epoch_start_time, device, use_amp)

        # Average over the samples actually seen (max() guards an empty loader).
        epoch_loss = running_loss / max(total_train, 1)
        train_acc = correct_train / max(total_train, 1)
        history['train_loss'].append(epoch_loss)

        val_metrics = evaluate(model, val_loader)
        history['val_metrics'].append(val_metrics)

        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(val_metrics['accuracy'])
        else:
            scheduler.step()

        # Checkpoint on strictly better validation accuracy.
        if val_metrics['accuracy'] > best_val_acc:
            best_val_acc = val_metrics['accuracy']
            best_metrics = val_metrics
            history['best_epoch'] = epoch + 1
            epochs_without_improvement = 0

            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_metrics': val_metrics,
                'train_loss': epoch_loss,
                'train_acc': train_acc,
            }, checkpoint_path)
        else:
            epochs_without_improvement += 1

        _print_epoch_summary(epoch, num_epochs, epoch_loss, train_acc, val_metrics)

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered at epoch {epoch + 1}. "
                  f"No improvement for {patience} consecutive epochs.")
            break

    print("\n" + "#" * 60)
    print(f"Training complete. Best validation at epoch {history['best_epoch']}")
    if best_metrics:
        print(f"Best Val Accuracy: {best_val_acc:.4f} | F1-Macro: {best_metrics['f1_macro']:.4f}")
        print("Confusion Matrix:")
        print(best_metrics['confusion_matrix'])
    else:
        print("No epoch improved on 0.0 validation accuracy; no checkpoint was saved.")
    print("#" * 60)

    return best_metrics, history


# Evaluation Function: Compute performance metrics on validation/test set
def evaluate(model, loader, num_classes=4, device=None, class_names=None):
    """Run ``model`` over ``loader`` and compute classification metrics.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        loader: iterable of ``(inputs, labels)`` batches.
        num_classes (int): number of classes. Metrics are computed over labels
            0..num_classes-1 even when a class is missing from this split.
        device: where to run inference. Defaults to the device of the model's first
            parameter (``nn.Module`` has no ``.device`` attribute), or CPU if it has none.
        class_names (list[str] | None): names for the classification report. Defaults
            to ['Cover', 'LSB', 'WOW', 'HILL'] for 4 classes, else the class indices.

    Returns:
        dict with 'accuracy', 'f1_macro', 'f1_per_class', 'precision_macro',
        'recall_macro', 'auc_ovo' (NaN when it cannot be computed), 'kappa',
        'confusion_matrix' and 'classification_report' (str).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    if class_names is None:
        class_names = (['Cover', 'LSB', 'WOW', 'HILL'] if num_classes == 4
                       else [str(c) for c in range(num_classes)])
    label_ids = list(range(num_classes))

    model.eval()
    prob_batches, pred_batches, label_batches = [], [], []

    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            # float() so softmax runs in full precision even if the logits are fp16.
            prob_batches.append(torch.softmax(outputs.float(), dim=1).cpu().numpy())
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not label_batches:
        raise ValueError("evaluate() received an empty loader; there are no samples to score.")

    y_true = np.concatenate(label_batches)
    y_pred = np.concatenate(pred_batches)
    y_probs = np.concatenate(prob_batches)

    # Passing `labels` fixes the class set, so a class absent from this split does not
    # shrink the confusion matrix or break `target_names` in the report.
    acc = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, labels=label_ids, average='macro', zero_division=0)
    f1_per_class = f1_score(y_true, y_pred, labels=label_ids, average=None, zero_division=0)
    precision_macro = precision_score(y_true, y_pred, labels=label_ids, average='macro', zero_division=0)
    recall_macro = recall_score(y_true, y_pred, labels=label_ids, average='macro', zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    try:
        auc = roc_auc_score(y_true, y_probs, multi_class='ovo', average='macro', labels=label_ids)
    except ValueError:
        auc = float('nan')  # e.g. a class has no positive samples in this split

    kappa = cohen_kappa_score(y_true, y_pred)
    clf_report = classification_report(y_true, y_pred, labels=label_ids,
                                       target_names=class_names, zero_division=0)

    return {
        'accuracy': acc,
        'f1_macro': f1_macro,
        'f1_per_class': f1_per_class,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'auc_ovo': auc,
        'kappa': kappa,
        'confusion_matrix': cm,
        'classification_report': clf_report
    }


# Visualization Functions


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot Training History Metrics
def plot_history_metrics(history_dict):
    """
    Draw three side-by-side line plots from a training history.

    Panels, left to right: training loss, validation accuracy, validation F1-macro.
    The x axis is the position in each list (0-based epoch index).

    Parameters:
    - history_dict (dict): may contain 'train_loss' (list of floats) and
      'val_metrics' (list of dicts with 'accuracy' and 'f1_macro' keys).
      Missing keys are treated as empty lists.
    """
    epoch_metrics = history_dict.get("val_metrics", [])
    # (series, legend label, colour, panel title, y-axis label) for each subplot
    panels = [
        (history_dict.get("train_loss", []), 'Train Loss', 'blue', 'Training Loss', 'Loss'),
        ([m['accuracy'] for m in epoch_metrics], 'Val Accuracy', 'green',
         'Validation Accuracy', 'Accuracy'),
        ([m['f1_macro'] for m in epoch_metrics], 'Val F1-Macro', 'orange',
         'Validation F1-Macro', 'F1-Macro'),
    ]

    plt.figure(figsize=(15, 4))
    for position, (series, label, color, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(series, label=label, color=color)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.grid(True)
        plt.legend()

    plt.tight_layout()
    plt.show()


# Confusion Matrix Heatmap
def plot_confusion_matrix(cm, class_names, title="Confusion Matrix", output_path=None):
    """
    Render a confusion matrix as an annotated seaborn heatmap, optionally saving it.

    Parameters:
    - cm (ndarray): confusion matrix with true labels on rows and predictions on
      columns. Cells are annotated with fmt='d', so integer counts are expected.
    - class_names (list): tick labels for both axes, in label order.
    - title (str): figure title.
    - output_path (str): if truthy, the figure is written there before being shown.
    """
    heatmap_options = dict(
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
    )

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, **heatmap_options)
    plt.title(title)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    # Save before show(): with inline backends show() may release the figure,
    # and saving afterwards could write an empty image.
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
    plt.show()
    plt.close()  # free the figure so repeated calls don't accumulate open figures


# Phase 1: Fine-Tune the Classifier Head and the Final Residual Block (layer4)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.optim import AdamW
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

# -----------------------------------------------
# Phase 1: Partial Fine-Tuning of Pretrained Model
# -----------------------------------------------

# Number of Phase 1 epochs. It is also the cosine schedule length, so the LR decays
# exactly once over the run; keep both uses tied to this single value.
phase1_epochs = 10

# Freeze every backbone and head parameter first...
for param in model.model.parameters():
    param.requires_grad = False

# ...then unfreeze the final residual stage and the classifier head.
# NOTE(review): the grayscale `conv1` created in the model cell stays frozen here, so it
# keeps whatever weights it was initialised with for this phase -- confirm intended.
for trainable in (model.model.layer4, model.model.fc):
    for param in trainable.parameters():
        param.requires_grad = True

# Lower LR for the pretrained layer4, higher LR for the freshly initialised head.
optimizer = AdamW([
    {'params': model.model.layer4.parameters(), 'lr': 3e-5},
    {'params': model.model.fc.parameters(), 'lr': 1e-4}
], weight_decay=1e-4)

# Cosine annealing from the base LRs down to eta_min over the whole phase.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=phase1_epochs,
    eta_min=1e-6,
    last_epoch=-1   # start fresh
)

criterion = nn.CrossEntropyLoss()

best_metrics, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=phase1_epochs,
    checkpoint_path='phase1_best_model.pth',
    device=device,
    patience=7
)

# Save the last-epoch weights and the pickled module
# (the best-validation weights are in phase1_best_model.pth).
torch.save(model.state_dict(), 'phase1_model_state_dict.pth')
torch.save(model, 'phase1_full_model.pth')

In [None]:
import os
from google.colab import drive
import matplotlib.pyplot as plt
import seaborn as sns
import torch.nn.functional as F
import time
import psutil

# Best-epoch validation confusion matrix (also saved to Drive), then the training curves.
plot_confusion_matrix(
    cm=best_metrics['confusion_matrix'],
    class_names=['Cover', 'LSB', 'WOW', 'HILL'],
    title='Validation Confusion Matrix',
    output_path='/content/drive/MyDrive/Stego-Images-Dataset/phase1_cm.png',
)
plot_history_metrics(history_dict=history)

# Phase 2: Add SRM + HighFreqPath


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ------------------------------------------------------
# SRM Layer: Extracts spatial residuals via fixed filters
# ------------------------------------------------------
class SRMLayer(nn.Module):
    """
    Fixed (non-trainable) bank of three 5x5 residual filters for grayscale input.

    One output channel per filter:
      0. Laplacian-style high-pass: -1 everywhere, +24 at the centre (taps sum to 0).
      1. Sparse pattern with opposite-sign diagonal taps and +6 centre, described as
         "Gabor"-like (directional inconsistencies); its taps sum to 2.
      2. Cross-shaped "hybrid edge" detector: +12 centre, -1 along the middle row and
         column (taps sum to 4).

    Each kernel is divided by its largest absolute tap (its centre value), so every
    filter has centre weight 1 and filter 0 stays exactly zero-sum.

    Input: (B, 1, H, W)
    Output: (B, 3, H, W)  -- padding=2 preserves the spatial size.
    """
    def __init__(self):
        super().__init__()
        self.srm = nn.Conv2d(1, 3, kernel_size=5, padding=2, bias=False)

        kernels = torch.tensor([
            [[-1, -1, -1, -1, -1],
             [-1, -1, -1, -1, -1],
             [-1, -1, 24, -1, -1],
             [-1, -1, -1, -1, -1],
             [-1, -1, -1, -1, -1]],

            [[ 0,  0, -1,  0,  0],
             [ 0, -1,  0,  1,  0],
             [-1,  0,  6,  0, -1],
             [ 0,  1,  0, -1,  0],
             [ 0,  0, -1,  0,  0]],

            [[ 0,  0, -1,  0,  0],
             [ 0,  0, -1,  0,  0],
             [-1, -1, 12, -1, -1],
             [ 0,  0, -1,  0,  0],
             [ 0,  0, -1,  0,  0]],
        ], dtype=torch.float32)  # (3, 5, 5)

        # Normalise by peak magnitude. Dividing by the kernel *sum* is undefined for
        # zero-sum high-pass kernels: filter 0 sums to 0 and would become all inf.
        peak = kernels.abs().amax(dim=(1, 2), keepdim=True)
        kernels = kernels / peak

        # Install as (out_channels=3, in_channels=1, 5, 5) and freeze.
        with torch.no_grad():
            self.srm.weight.copy_(kernels.unsqueeze(1))
        self.srm.requires_grad_(False)

    def forward(self, x):
        return self.srm(x)

# ------------------------------------------------------
# High-Frequency Path: Trainable branch for high-pass cues
# ------------------------------------------------------
class HighFreqPath(nn.Module):
    """
    Learnable two-stage conv branch: 1 -> 16 -> 32 channels, 3x3 kernels,
    ReLU after each convolution.

    Input: (B, 1, H, W)
    Output: (B, 32, H, W)  -- padding=1 keeps the spatial size unchanged.
    """
    def __init__(self):
        super().__init__()
        # Attribute names conv1/conv2 are part of the state_dict keys; keep them stable.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        features = F.relu(self.conv2(hidden))
        return features

# ------------------------------------------------------
# Full Model: SRM + High-Frequency Path + ResNet34 backbone
# ------------------------------------------------------
class StegoResNet34(nn.Module):
    def __init__(self, pretrained_model):
        """
        Phase 2 network: SRM residuals + a learnable high-frequency branch feeding a
        ResNet backbone, with a classifier over backbone and pooled branch features.

        Args:
            pretrained_model: a ResNet-style backbone, or a wrapper exposing it as
                ``.model`` (e.g. LogitScaler); a wrapper itself is not kept.

        Notes:
            - The backbone's ``conv1`` is replaced *in place* with a new 5-channel conv
              (3 SRM maps + 2 reduced high-frequency maps), so the object passed in is
              mutated.
            - The backbone's own ``fc`` stays registered (it is in the state_dict) but
              ``forward`` never calls it; the classifier actually used is ``self.fc``.
        """
        super().__init__()
        # Creation order is kept fixed: it determines RNG consumption for default inits
        # and the parameter/state_dict ordering.
        self.srm_layer = SRMLayer()
        self.high_freq_path = HighFreqPath()
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))
        self.high_freq_reduce = nn.Conv2d(32, 2, kernel_size=1)  # 32 -> 2 channels

        backbone = getattr(pretrained_model, 'model', pretrained_model)
        self.model = backbone
        self.model.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3, bias=False)

        pooled_hf_dim = 32 * 7 * 7  # HighFreqPath channels x adaptive_pool output (7x7)
        self.fc = nn.Sequential(
            nn.Linear(512 + pooled_hf_dim, 512),  # 512 = pooled backbone features
            nn.ReLU(),
            nn.Dropout(0.8),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 4),  # 4 classes: Cover, LSB, WOW, HILL
        )

    def _backbone_features(self, x):
        """Run the ResNet stem, the four residual stages and global pooling; flatten to (B, C)."""
        net = self.model
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        return torch.flatten(net.avgpool(x), 1)

    def forward(self, x):
        residuals = self.srm_layer(x)                        # (B, 3, H, W)
        high_freq = self.high_freq_path(x)                   # (B, 32, H, W)
        reduced = self.high_freq_reduce(high_freq)           # (B, 2, H, W)

        backbone_input = torch.cat((residuals, reduced), dim=1)   # (B, 5, H, W)
        resnet_features = self._backbone_features(backbone_input)

        # The un-reduced 32-channel branch is also pooled to 7x7 and fed to the head.
        hf_flat = torch.flatten(self.adaptive_pool(high_freq), 1)  # (B, 32*7*7)

        return self.fc(torch.cat((resnet_features, hf_flat), dim=1))

# ------------------------------------------------------
# Focal Loss: Class imbalance-aware alternative to CE
# ------------------------------------------------------
class FocalLoss(nn.Module):
    """
    Multi-class focal loss: alpha_t * (1 - p_t)^gamma * CE, reduced over the batch.

    p_t is the softmax probability of the true class, taken from the *unweighted*
    cross-entropy (p_t = exp(-CE)). The class weights alpha are applied afterwards,
    so they scale the loss without distorting the focusing term. (Passing alpha as
    ``weight=`` to cross_entropy would make exp(-CE) no longer equal p_t.)

    Args:
        alpha (Tensor | sequence of float | None): per-class weights indexed by target.
        gamma (float): focusing parameter. gamma=0 with alpha=None equals plain CE.
        reduction (str): 'mean' (plain batch mean, default), 'sum' or 'none'.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(f"Unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha = torch.as_tensor(self.alpha, dtype=focal.dtype, device=focal.device)
            focal = alpha[targets] * focal

        if self.reduction == 'mean':
            return focal.mean()
        if self.reduction == 'sum':
            return focal.sum()
        return focal


In [None]:
# First recreate the Phase 1 architecture
def phase1_model(num_classes=4):
    """Rebuild the Phase 1 network so the Phase 1 checkpoint can be loaded into it.

    Mirrors the ResNet34 model cell: a 1-channel ``conv1`` and the
    Dropout/Linear/BatchNorm/LeakyReLU head, wrapped in LogitScaler. No pretrained
    weights are downloaded, because ``load_state_dict`` overwrites every parameter.

    Args:
        num_classes (int): output size of the final layer; must match the checkpoint.

    Returns:
        LogitScaler wrapping the ResNet34 backbone (not yet moved to a device).
    """
    model = models.resnet34(weights=None)

    # Grayscale stem: 1 input channel, otherwise identical to the ResNet34 conv1.
    model.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False
    )

    # Must match the Phase 1 head layer-for-layer (including bias=False on the last
    # Linear) or the checkpoint's state_dict keys will not line up.
    model.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(model.fc.in_features, 256),
        nn.BatchNorm1d(256),
        nn.LeakyReLU(0.1, inplace=True),
        nn.Dropout(0.3),
        nn.Linear(256, num_classes, bias=False)
    )

    return LogitScaler(model)

# Load Phase 1 model. Bound to `phase1_net` (not `phase1_model`) so the builder
# function is not shadowed by its own result and this cell can be re-run.
phase1_net = phase1_model().to(device)
# weights_only=False: the checkpoint also stores metric dicts containing NumPy arrays.
# Only load checkpoint files you created yourself.
checkpoint = torch.load('phase1_best_model.pth', map_location=device, weights_only=False)
phase1_net.load_state_dict(checkpoint['model_state_dict'])  # Phase 1 best weights

# Convert to Phase 2 architecture. StegoResNet34 swaps in a new 5-channel conv1
# (3 SRM + 2 reduced high-frequency maps) and adds its own classifier `model.fc`.
model = StegoResNet34(phase1_net).to(device)

# Re-initialise the new 5-channel conv1
nn.init.kaiming_normal_(model.model.conv1.weight, mode='fan_out', nonlinearity='relu')

# Freeze everything, then unfreeze only what Phase 2 trains.
for param in model.parameters():
    param.requires_grad = False

# (module, learning rate) pairs trained in Phase 2. StegoResNet34.forward classifies
# with `model.fc`; the Phase 1 head `model.model.fc` is never called, so it stays frozen.
# The freshly created `high_freq_reduce` and 5-channel `conv1` must train too, otherwise
# they keep their random initial weights.
phase2_param_groups = [
    (model.srm_layer, 1e-4),         # SRM filters (learnable in Phase 2)
    (model.high_freq_path, 1e-4),    # High-frequency path
    (model.high_freq_reduce, 1e-4),  # 32 -> 2 channel reduction feeding the backbone
    (model.model.conv1, 1e-4),       # New 5-channel stem conv
    (model.fc, 5e-5),                # Classifier head used in forward()
]
for module, _ in phase2_param_groups:
    for param in module.parameters():
        param.requires_grad = True

optimizer = torch.optim.AdamW(
    [{'params': module.parameters(), 'lr': lr} for module, lr in phase2_param_groups],
    weight_decay=1e-3,
)

def init_srm(m):
    """Xavier-uniform init of Conv2d weights (gain for LeakyReLU(0.1)); other modules untouched."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('leaky_relu', 0.1))

# NOTE(review): applying init_srm to `srm_layer` replaces its hand-crafted SRM kernels
# with random weights, so that branch starts as an ordinary learnable conv. Kept as in
# the original run; skip the srm_layer call to start from the fixed filters instead.
model.srm_layer.apply(init_srm)
model.high_freq_path.apply(init_srm)

# T_max equals the epoch count: CosineAnnealingLR is periodic, so T_max=10 over 20
# epochs would anneal to eta_min and then ramp the LR back up.
phase2_epochs = 20
scheduler = CosineAnnealingLR(optimizer, T_max=phase2_epochs, eta_min=1e-6, last_epoch=-1)

# Weighted focal loss (per-class weights for Cover, LSB, WOW, HILL)
class_weights = torch.tensor([0.30, 0.30, 0.20, 0.20]).to(device)
criterion = FocalLoss(alpha=class_weights, gamma=2.0)

best_metrics, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=phase2_epochs,
    checkpoint_path='phase2_best_model.pth',  # best-validation checkpoint
    device=device,
    patience=7  # early stopping
)

# Save the last-epoch weights and the pickled module
torch.save(model.state_dict(), 'full_model_state_dict.pth')
torch.save(model, 'full_model.pth')


In [None]:
# Best-epoch validation confusion matrix (also saved to Drive), then the training curves.
plot_confusion_matrix(
    cm=best_metrics['confusion_matrix'],
    class_names=['Cover', 'LSB', 'WOW', 'HILL'],
    title='Validation Confusion Matrix',
    output_path='/content/drive/MyDrive/Stego-Images-Dataset/phase2_cm.png',
)
plot_history_metrics(history_dict=history)

# Model Evaluation

In [None]:
# Rebuild the Phase 2 model object from the pickled module. weights_only=False is
# required to unpickle a full nn.Module (PyTorch >= 2.6 defaults to weights_only=True);
# only load files you created yourself.
model = torch.load('full_model.pth', map_location=device, weights_only=False)

# full_model.pth holds the last-epoch weights. Evaluate the epoch selected on
# validation accuracy instead, i.e. the checkpoint train_model saved during Phase 2.
best_checkpoint = torch.load('phase2_best_model.pth', map_location=device, weights_only=False)
model.load_state_dict(best_checkpoint['model_state_dict'])
model.eval()

# Evaluate model on the held-out test split
metrics = evaluate(model, test_loader, num_classes=4)

# results
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"F1 Macro: {metrics['f1_macro']:.4f}")
print(f"Precision Macro: {metrics['precision_macro']:.4f}")
print(f"Recall Macro: {metrics['recall_macro']:.4f}")
print(f"AUC (OVO): {metrics['auc_ovo']:.4f}")
print(f"Cohen's Kappa: {metrics['kappa']:.4f}")
print("\nF1 Score per Class:", metrics['f1_per_class'])
print("\nConfusion Matrix:\n", metrics['confusion_matrix'])
print("\nClassification Report:\n", metrics['classification_report'])