# ViT and DeiT Training for Malicious Payload Detection



### 1. Imports


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_auc_score
import timm
from tqdm.notebook import tqdm
import warnings
import os
import shutil
import kagglehub
from PIL import Image

# Ignore warnings of every category so the notebook output stays uncluttered.
warnings.simplefilter("ignore")

### 2. Configuration

We define a `Config` class that holds all hyperparameters and settings, including the paths of the project's directory layout for data, models, and results.

In [None]:
class Config:
    """Single home for paths, training hyperparameters and the model list."""

    # --- Directory layout: everything lives under BASE_DIR ---
    BASE_DIR = './payload_classification_project'
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    TRAIN_DATA_DIR = os.path.join(DATA_DIR, 'train')
    TEST_DATA_DIR = os.path.join(DATA_DIR, 'test')
    VAL_DATA_DIR = os.path.join(DATA_DIR, 'val')
    MODEL_DIR = os.path.join(BASE_DIR, 'models')
    RESULTS_DIR = os.path.join(BASE_DIR, 'results')

    # --- Input size and optimisation settings ---
    IMAGE_SIZE = 224
    BATCH_SIZE = 32
    EPOCHS = 10
    LEARNING_RATE = 1e-4
    NUM_FOLDS = 5
    # Prefer the GPU whenever torch can see one.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- Architectures: display key -> timm model identifier ---
    MODELS_TO_TRAIN = dict(
        ViT_224x224='vit_base_patch16_224',
        DeiT_224x224='deit_base_distilled_patch16_224',
    )

print(f"Using device: {Config.DEVICE}")

### 3. Data Setup and Loading


In [None]:
def segregate_data_by_payload(original_data_path, new_base_path):
    """
    Copy the downloaded dataset into one folder per class, for each split.

    For every split in ('train', 'test', 'val') this reads
    ``<original_data_path>/<split>/<split>/stego`` and ``.../clean``:

    * each stego file is copied into ``<new_base_path>/<split>/<token>``, where
      ``<token>`` is the third ``_``-separated part of its file name
      (presumably the payload type -- confirm against the dataset's naming);
    * each clean file is copied into ``<new_base_path>/<split>/Benign``.

    Only regular files are copied. Stego files whose names have fewer than
    three ``_``-separated parts are reported and skipped. A copy failure is
    printed and skipped so one bad file does not abort the whole run.

    Args:
        original_data_path: Root of the downloaded dataset.
        new_base_path: Destination root; one sub-folder per split is created
            on demand.
    """
    print("\nSegregating data by payload type...")
    for split in ['train', 'test', 'val']:
        split_src = os.path.join(original_data_path, split, split)
        stego_path = os.path.join(split_src, 'stego')
        clean_path = os.path.join(split_src, 'clean')
        new_split_path = os.path.join(new_base_path, split)

        if os.path.exists(stego_path):
            print(f"Processing stego images in {split} split...")
            created_dirs = set()  # avoid one makedirs call per file
            for img_name in os.listdir(stego_path):
                src = os.path.join(stego_path, img_name)
                # Same rule as the clean branch: only regular files are copied.
                if not os.path.isfile(src):
                    continue
                parts = img_name.split("_")
                if len(parts) <= 2:
                    print(f"Skipping file with unexpected name format in stego: {img_name}")
                    continue
                class_dir = os.path.join(new_split_path, parts[2])
                try:
                    if class_dir not in created_dirs:
                        os.makedirs(class_dir, exist_ok=True)
                        created_dirs.add(class_dir)
                    shutil.copy(src, class_dir)
                except OSError as e:  # shutil.Error is an OSError subclass
                    print(f"Error processing file {img_name}: {e}")

        if os.path.exists(clean_path):
            print(f"Processing clean images in {split} split...")
            new_clean_path = os.path.join(new_split_path, 'Benign')
            os.makedirs(new_clean_path, exist_ok=True)
            for img_name in os.listdir(clean_path):
                src = os.path.join(clean_path, img_name)
                if not os.path.isfile(src):
                    continue
                try:
                    shutil.copy(src, new_clean_path)
                except OSError as e:
                    print(f"Error processing file {img_name}: {e}")

    print("Data segregation complete.")

def _split_is_segregated(split_dir):
    """Return True if *split_dir* holds a non-empty 'Benign' folder plus at least one other non-empty class folder."""
    if not os.path.isdir(split_dir):
        return False
    non_empty_classes = set()
    for name in os.listdir(split_dir):
        class_path = os.path.join(split_dir, name)
        if os.path.isdir(class_path) and os.listdir(class_path):
            non_empty_classes.add(name)
    return 'Benign' in non_empty_classes and len(non_empty_classes) > 1


def setup_environment():
    """
    Create the project directories and make sure the segregated dataset exists.

    If the train and test splits already contain segregated class folders,
    nothing is downloaded. Otherwise the dataset is fetched with ``kagglehub``
    and copied into ``Config.DATA_DIR`` by ``segregate_data_by_payload``.

    Raises:
        Re-raises any exception from the download or segregation step after
        printing it.
    """
    print("--- Setting up project environment ---")
    for dir_path in [Config.BASE_DIR, Config.DATA_DIR, Config.TRAIN_DATA_DIR, Config.TEST_DATA_DIR, Config.VAL_DATA_DIR, Config.MODEL_DIR, Config.RESULTS_DIR]:
        os.makedirs(dir_path, exist_ok=True)

    # Payload class folder names are derived from the downloaded file names, so
    # check the layout ('Benign' plus at least one payload class) instead of a
    # specific payload folder name. Both splits read by load_and_combine_data
    # must be populated.
    data_already_segregated = all(
        _split_is_segregated(split_dir)
        for split_dir in (Config.TRAIN_DATA_DIR, Config.TEST_DATA_DIR)
    )

    if data_already_segregated:
        print("Dataset already downloaded and segregated.")
        return

    print("Dataset not found or not organized. Starting download and segregation...")
    try:
        download_path = kagglehub.dataset_download('marcozuppelli/stegoimagesdataset')
        segregate_data_by_payload(download_path, Config.DATA_DIR)
    except Exception as e:
        print(f"Could not download or process from Kaggle Hub. Error: {e}")
        raise

    print("\nEnvironment setup complete.")

def load_and_combine_data():
    """
    Build one labelled DataFrame from the segregated train and test splits.

    Each sub-folder of ``Config.TRAIN_DATA_DIR`` / ``Config.TEST_DATA_DIR`` is
    treated as a class, and every regular file inside it becomes one row. The
    two splits are concatenated so the cross-validation can re-split them.
    The val split is not read here.

    Returns:
        (combined_df, label_encoder): ``combined_df`` has the columns
        ``image_path``, ``label`` (class folder name) and ``encoded_label``
        (integer code from the fitted ``LabelEncoder``).

    Raises:
        ValueError: if either split yields no rows.
    """
    print("\n--- Loading segregated data into DataFrames ---")

    def build_df_from_directory(directory):
        # Listings are sorted so row order -- and therefore the folds produced
        # by StratifiedKFold(random_state=...) -- does not depend on the
        # filesystem's os.listdir order.
        image_files, labels = [], []
        if not os.path.exists(directory):
            print(f"Warning: Directory not found: {directory}. Skipping.")
            return pd.DataFrame({'image_path': [], 'label': []})
        for class_folder in sorted(os.listdir(directory)):
            class_path = os.path.join(directory, class_folder)
            if not os.path.isdir(class_path):
                continue
            for img_file in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_file)
                # Only files can be opened as images by the Dataset later on.
                if os.path.isfile(img_path):
                    image_files.append(img_path)
                    labels.append(class_folder)
        return pd.DataFrame({'image_path': image_files, 'label': labels})

    train_df = build_df_from_directory(Config.TRAIN_DATA_DIR)
    test_df = build_df_from_directory(Config.TEST_DATA_DIR)

    if train_df.empty or test_df.empty:
        raise ValueError("Training or testing DataFrame is empty. Data segregation may have failed.")

    combined_df = pd.concat([train_df, test_df], ignore_index=True)
    label_encoder = LabelEncoder()
    combined_df['encoded_label'] = label_encoder.fit_transform(combined_df['label'])

    print(f"Total combined data shape: {combined_df.shape}")
    print("Class distribution in combined set:\n", combined_df['label'].value_counts())

    return combined_df, label_encoder

# Execute the data preparation pipeline:
# 1) create the project directories and, if not already done, download and
#    segregate the dataset into per-class folders;
# 2) load the train/test splits into one DataFrame and fit the label encoder.
setup_environment()
df, label_encoder = load_and_combine_data()
# As the last expression of the notebook cell, this displays the first rows.
df.head()

### 4. PyTorch Dataset

The `PayloadDataset` class wraps the DataFrame built above. Its `__getitem__` method loads each image from the file path stored in the DataFrame using PIL and converts it to 'RGB', so every sample has the 3 channels the models expect.

In [None]:
class PayloadDataset(Dataset):
    """
    Map-style dataset that reads the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``image_path`` column (file path) and an
            ``encoded_label`` column (integer class id). Rows are addressed
            positionally with ``iloc``, so a fold slice with a non-contiguous
            index works.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row *idx*; the label is a ``torch.long`` scalar tensor."""
        row = self.df.iloc[idx]  # one positional lookup serves both fields
        image_path = row['image_path']
        label = row['encoded_label']

        # Pillow recommends a context manager so the underlying file is closed
        # deterministically. convert() loads the pixels, so the result stays
        # valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

### 5. Model Creation

This helper function uses the `timm` library to easily create pretrained ViT or DeiT models. We specify the model name and the number of output classes, and `timm` handles the rest.

In [None]:
def create_model(model_name, num_classes, pretrained=True):
    """
    Instantiate a timm architecture by name with a ``num_classes``-way output.

    Args:
        model_name: timm model identifier, e.g. 'vit_base_patch16_224'.
        num_classes: Number of output classes requested from timm.
        pretrained: Whether timm should load pretrained weights.

    Returns:
        The module returned by ``timm.create_model``.
    """
    return timm.create_model(
        model_name,
        num_classes=num_classes,
        pretrained=pretrained,
    )

### 6. Training and Evaluation Functions

These are the core functions for the training loop:
-   `train_one_epoch`: Iterates through the training data, performs a forward pass, calculates the loss, and updates the model weights.
-   `evaluate`: Iterates through the validation data in `no_grad` mode, collects the softmax class probabilities together with the true labels, and returns both for metric calculation.

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Run one optimisation pass over *dataloader*.

    Args:
        model: Network to train; switched to train mode here.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function. Assumed to return the batch *mean* loss
            (the ``nn.CrossEntropyLoss()`` default used in this notebook).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        float: Per-sample average training loss for the epoch, or ``nan`` if
        the dataloader produced no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; call it once per batch.
        batch_loss = loss.item()
        batch_size = labels.size(0)
        # Weight each batch mean by its size so a short final batch does not
        # count as much as a full one.
        total_loss += batch_loss * batch_size
        num_samples += batch_size
        progress_bar.set_postfix(loss=batch_loss)

    return total_loss / num_samples if num_samples else float('nan')

def evaluate(model, dataloader, device):
    """
    Run gradient-free inference over *dataloader*.

    Returns:
        (probabilities, labels): softmax class probabilities with one row per
        sample, and the true labels in the same order. An empty loader gives
        two empty arrays.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc="Evaluating", leave=False):
            logits = model(batch_images.to(device))
            # Convert logits to per-class probabilities before leaving the device.
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_chunks.append(batch_labels.cpu().numpy())

    if not prob_chunks:
        return np.array([]), np.array([])
    return np.concatenate(prob_chunks), np.concatenate(label_chunks)

### 7. Main Execution: Training and Cross-Validation

This is the main training block.

1.  **Define Transforms**: Set up image transformations. Since we are not using data augmentation, this just includes resizing, converting to a tensor, and normalizing.
2.  **Model Loop**: Iterate through each model specified in the `Config`.
3.  **Cross-Validation**: Use `StratifiedKFold` to split the DataFrame. The indices from the split are used to create training and validation `PayloadDataset` instances for each fold.
4.  **Training Loop & Evaluation**: For each fold, the model is trained for a set number of epochs. After each epoch the macro one-vs-rest ROC-AUC is computed on the validation fold, and the best epoch's score is kept for that fold.
5.  **Report Results**: After all folds are complete, the mean and standard deviation of the per-fold best ROC-AUC scores are calculated and printed.

In [None]:
# Define transformations (no data augmentation)
# NOTE(review): mean/std below are the ImageNet statistics. timm's pretrained
# config for a given checkpoint may specify different values (e.g. 0.5/0.5 for
# some ViT weights) -- confirm with timm.data.resolve_data_config.
data_transform = transforms.Compose([
    transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

num_classes = len(label_encoder.classes_)


def _ovr_roc_auc(y_true, y_prob):
    """
    Compute the macro one-vs-rest ROC-AUC from per-class probabilities.

    ``roc_auc_score`` rejects an ``(n, 2)`` score matrix when ``y_true`` is
    binary, so with two classes only the positive-class column is scored. For
    softmax probabilities this equals the macro OvR value.
    """
    if y_prob.ndim == 2 and y_prob.shape[1] == 2:
        return roc_auc_score(y_true, y_prob[:, 1])
    return roc_auc_score(y_true, y_prob, multi_class='ovr', average='macro')


# The folds depend only on df and the fixed random_state, so build them once;
# every architecture then sees exactly the same train/validation splits.
skf = StratifiedKFold(n_splits=Config.NUM_FOLDS, shuffle=True, random_state=42)
# Use the encoded labels for stratified splitting
fold_splits = list(skf.split(df, df['encoded_label']))

# Loop through each model architecture
for model_key, model_name in Config.MODELS_TO_TRAIN.items():
    print(f"\n===== Starting Training for {model_key} =====")
    fold_results = []

    # Cross-Validation Loop
    for fold, (train_idx, val_idx) in enumerate(fold_splits):
        print(f"\n--- Fold {fold+1}/{Config.NUM_FOLDS} ---")

        # Split data for the current fold
        train_df = df.iloc[train_idx]
        val_df = df.iloc[val_idx]

        # Create datasets and dataloaders
        train_dataset = PayloadDataset(train_df, transform=data_transform)
        val_dataset = PayloadDataset(val_df, transform=data_transform)

        train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False)

        # Drop the previous fold's model and optimizer state before building new
        # ones, so two base-sized models (plus AdamW state) are never held on
        # the device at the same time.
        model = optimizer = None
        if Config.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()

        # Initialize model, criterion, and optimizer
        model = create_model(model_name, num_classes).to(Config.DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

        # Training loop for epochs; keep the best validation AUC of this fold
        best_auc = 0
        for epoch in range(Config.EPOCHS):
            train_loss = train_one_epoch(model, train_loader, criterion, optimizer, Config.DEVICE)
            val_preds, val_labels = evaluate(model, val_loader, Config.DEVICE)

            # Calculate metrics (macro one-vs-rest ROC-AUC; also valid for 2 classes)
            multi_auc = _ovr_roc_auc(val_labels, val_preds)

            print(f"Epoch {epoch+1}/{Config.EPOCHS} | Train Loss: {train_loss:.4f} | Val ROC-AUC: {multi_auc:.4f}")

            if multi_auc > best_auc:
                best_auc = multi_auc
                # You could save the best model here if needed
                # model_path = os.path.join(Config.MODEL_DIR, f'{model_key}_fold_{fold+1}_best.pth')
                # torch.save(model.state_dict(), model_path)

        fold_results.append(best_auc)
        print(f"Best ROC-AUC for Fold {fold+1}: {best_auc:.4f}")

    # Print average results for the model
    avg_auc = np.mean(fold_results)
    std_auc = np.std(fold_results)
    print("\n" + "="*30)
    print(f"Results for {model_key}:")
    print(f"Average ROC-AUC over {Config.NUM_FOLDS} folds: {avg_auc:.4f} (+/- {std_auc:.4f})")
    print("="*30 + "\n")

print("All training complete.")