# Malware Detection using ResNet - Full Training Notebook

This notebook contains the complete pipeline for training a ResNet model to detect malware from executable files using image-based representations.

**Image Channels:**
- Channel 0: Sparse bigram frequency image
- Channel 1: DCT-transformed bigram image  
- Channel 2: Byteplot image

## 1. Setup and Imports

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix, 
    roc_curve, 
    auc, 
    accuracy_score,
    precision_score,
    recall_score,
    f1_score
)
from tqdm import tqdm
import time
import random
from typing import Dict, Tuple, Optional
from scipy.fft import dctn
from scipy.ndimage import zoom
import math
from torchvision import models

# Prepend the parent of the current working directory to the import path
# (presumably so project modules living one level up can be imported — confirm).
sys.path.insert(0, os.path.dirname(os.path.abspath(os.getcwd())))

# Report the runtime environment so results can be traced back to it.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only the first CUDA device (index 0) is reported.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# --- Data ---------------------------------------------------------------
# Root folder; the dataset class looks for 'malware' and 'benign' subfolders.
DATA_DIR = '../data'
# Fractions passed to create_data_loaders(); the test partition actually
# receives whatever remains after the train and validation sizes are taken.
TRAIN_SPLIT = 0.7
VAL_SPLIT = 0.2
TEST_SPLIT = 0.1
MAX_SAMPLES = None  # Set to integer to limit samples, None for all data

# --- Model --------------------------------------------------------------
# NOTE(review): IMAGE_SIZE is not referenced by the visible image functions,
# which hard-code 256 — confirm before changing it.
IMAGE_SIZE = 256
RESNET_VARIANT = 'resnet50'  # 'resnet18' or 'resnet50'
PRETRAINED = True
FREEZE_BACKBONE = False

# --- Optimisation -------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 25
LEARNING_RATE = 0.001
PATIENCE = 10  # Epochs without validation-loss improvement before stopping
USE_AMP = True  # Automatic Mixed Precision
GRADIENT_ACCUMULATION_STEPS = 1

# --- DataLoader workers -------------------------------------------------
NUM_WORKERS = 4
PREFETCH_FACTOR = 2
PERSISTENT_WORKERS = True

# --- Output locations ---------------------------------------------------
CHECKPOINT_DIR = '../checkpoints'
RESULTS_DIR = '../results'

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Create the output folders up front; existing folders are left untouched.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

## 3. Image Generation Functions

Functions to convert binary executables into image representations:
- **Bigram Frequency Image**: 256x256 image of normalized byte-pair (bigram) frequencies
- **DCT Image**: 2-D DCT of the bigram image (absolute values, scaled by the maximum)
- **Byteplot Image**: Raw bytes reshaped into a square image and resized to 256x256

In [None]:
def read_binary_file(file_path: str) -> bytes:
    """Return the complete contents of ``file_path`` as raw bytes."""
    with open(file_path, mode='rb') as handle:
        contents = handle.read()
    return contents


def extract_bigrams(byte_data: bytes) -> np.ndarray:
    """Count overlapping byte pairs (bigrams) in ``byte_data``.

    Each bigram is encoded as ``(first_byte << 8) | second_byte``, giving an
    index in ``[0, 65535]``.

    Args:
        byte_data: Raw bytes-like content to analyse.

    Returns:
        A float64 array of length 65536 where entry ``k`` is the number of
        occurrences of bigram ``k``. Inputs shorter than two bytes yield an
        all-zero array.
    """
    data = np.frombuffer(byte_data, dtype=np.uint8)
    if data.size < 2:
        return np.zeros(65536, dtype=np.float64)

    # Vectorised equivalent of looping over every adjacent byte pair: widen
    # the leading byte to 16 bits before shifting so the high byte is kept.
    bigrams = (data[:-1].astype(np.uint16) << 8) | data[1:]

    return np.bincount(bigrams, minlength=65536).astype(np.float64)


def create_bigram_image(bigram_freq: np.ndarray, zero_out_0000: bool = True) -> np.ndarray:
    """Turn a 65536-entry bigram count vector into a 256x256 image.

    The input is copied, so the caller's array is never modified. When
    ``zero_out_0000`` is true the count at index 0 (bigram 0x0000) is cleared
    before the counts are divided by their sum; an all-zero vector is left
    unscaled.
    """
    counts = bigram_freq.copy()
    if zero_out_0000:
        counts[0] = 0

    mass = counts.sum()
    normalised = counts / mass if mass > 0 else counts

    return normalised.reshape(256, 256)


def apply_2d_dct(image: np.ndarray) -> np.ndarray:
    """Return the magnitude of the orthonormal type-II DCT of ``image``.

    The coefficient magnitudes are divided by their maximum so the largest
    value becomes 1; if the maximum is not positive the magnitudes are
    returned unscaled.
    """
    magnitudes = np.abs(dctn(image, type=2, norm='ortho'))
    peak = np.max(magnitudes)
    return magnitudes / peak if peak > 0 else magnitudes


def resize_image(image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """Rescale a 2-D ``image`` to ``target_size`` (height, width).

    Uses ``scipy.ndimage.zoom`` with first-order (linear) spline interpolation.
    """
    rows, cols = image.shape
    out_rows, out_cols = target_size
    scale = (out_rows / rows, out_cols / cols)
    return zoom(image, scale, order=1)


def create_byteplot_from_bytes(byte_data: bytes, target_size: Tuple[int, int] = (256, 256)) -> np.ndarray:
    """Render raw bytes as a square grayscale image scaled to ``target_size``.

    The bytes are truncated to the largest perfect-square length, reshaped
    into a ``side x side`` uint8 image, resized with ``resize_image`` and
    divided by 255.

    Args:
        byte_data: Raw bytes-like content.
        target_size: Output ``(height, width)``.

    Returns:
        A float32 array of shape ``target_size``. Empty input produces an
        all-zero image instead of failing in the resize step.
    """
    byte_array = np.frombuffer(byte_data, dtype=np.uint8)

    # Integer square root avoids float rounding for very large inputs.
    side_length = math.isqrt(byte_array.size)
    if side_length == 0:
        # Nothing to plot: a 0x0 image cannot be zoomed (division by zero).
        return np.zeros(tuple(target_size), dtype=np.float32)

    # Drop the trailing bytes that do not fit into the square.
    byteplot = byte_array[:side_length * side_length].reshape(side_length, side_length)

    byteplot_resized = resize_image(byteplot, target_size)

    return byteplot_resized.astype(np.float32) / 255.0


def create_three_channel_image(file_path: str) -> np.ndarray:
    """Build the 3-channel float32 image for the file at ``file_path``.

    Channels, in order:
        0. Sparse bigram frequency image (0000 bigram zeroed, normalised).
        1. Bigram-DCT image: 2-D DCT of channel 0.
        2. Byteplot: raw bytes as a 2-D image resized to 256x256.
    """
    raw = read_binary_file(file_path)

    # Channel 0 is derived from the overlapping byte-pair counts.
    bigram_channel = create_bigram_image(extract_bigrams(raw), zero_out_0000=True)

    # Channel 1 applies the full-frame DCT to channel 0.
    dct_channel = apply_2d_dct(bigram_channel)

    # Channel 2 comes directly from the raw bytes.
    byteplot_channel = create_byteplot_from_bytes(raw, target_size=(256, 256))

    channels = [bigram_channel, dct_channel, byteplot_channel]
    return np.stack(channels, axis=0).astype(np.float32)


# Cell-completion marker followed by a short recap of the image pipeline
# implemented by the functions in this cell.
print("Image generation functions defined!")
print("\nPipeline summary (per paper):")
print("  1. Binary → Overlapping bigrams (e.g., 0a1bc48a → 0a1b, 1bc4, c48a)")
print("  2. Bigram frequency count → 256x256 sparse image (zero out 0000, normalize)")
print("  3. DCT transform → De-sparsified 'bigram-dct' image with textured patterns")
print("  4. Byteplot → Raw bytes as 256x256 image")

## 4. Dataset Class

In [None]:
class MalwareImageDataset(Dataset):
    """Dataset for malware detection using image representations of executables.

    Files are discovered under ``<data_dir>/malware`` (label 1) and
    ``<data_dir>/benign`` (label 0). Images are generated lazily in
    ``__getitem__`` via ``create_three_channel_image``.
    """

    def __init__(
        self,
        data_dir: str,
        max_samples: Optional[int] = None,
        seed: Optional[int] = 42
    ):
        """
        Args:
            data_dir: Root folder containing the 'malware' and 'benign' folders.
            max_samples: Keep at most this many samples after shuffling;
                ``None`` keeps everything.
            seed: Seed for the sample shuffle. A fixed seed together with the
                sorted directory listing keeps the sample order (and therefore
                any seeded split built on top of it) reproducible. Pass
                ``None`` to shuffle with the global ``random`` state.
        """
        self.data_dir = data_dir
        self.seed = seed
        self.samples = []  # List of (file_path, label)
        self._load_samples(max_samples)

    def _load_samples(self, max_samples: Optional[int]):
        """Collect (file_path, label) pairs, shuffle them and print statistics."""
        for subdir, label in (('malware', 1), ('benign', 0)):
            class_dir = os.path.join(self.data_dir, subdir)
            if not os.path.isdir(class_dir):
                continue
            # os.listdir order is arbitrary; sort so runs are comparable.
            for filename in sorted(os.listdir(class_dir)):
                file_path = os.path.join(class_dir, filename)
                if os.path.isfile(file_path):
                    self.samples.append((file_path, label))

        # Shuffle with a private generator so the order does not depend on
        # (or disturb) the global random state unless seed is None.
        shuffler = random.Random(self.seed) if self.seed is not None else random
        shuffler.shuffle(self.samples)

        # Limit samples if specified
        if max_samples is not None:
            self.samples = self.samples[:max_samples]

        # Print statistics
        malware_count = sum(1 for _, label in self.samples if label == 1)
        benign_count = len(self.samples) - malware_count
        print(f"Loaded {len(self.samples)} samples:")
        print(f"  Malware: {malware_count}")
        print(f"  Benign:  {benign_count}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(image_tensor, label)`` for sample ``idx``.

        Best effort: if image generation fails, the error is printed and an
        all-zero (3, 256, 256) tensor is returned with the original label so
        a single unreadable file does not abort a whole epoch.
        """
        file_path, label = self.samples[idx]

        try:
            # Create 3-channel image
            image = create_three_channel_image(file_path)
            image_tensor = torch.from_numpy(image).float()
            return image_tensor, label

        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            # Return zero tensor on error
            image_tensor = torch.zeros((3, 256, 256), dtype=torch.float32)
            return image_tensor, label


def create_data_loaders(
    data_dir: str,
    batch_size: int = 32,
    train_split: float = 0.7,
    val_split: float = 0.2,
    test_split: float = 0.1,
    max_samples: Optional[int] = None,
    num_workers: int = 0,
    prefetch_factor: Optional[int] = None,
    persistent_workers: bool = False
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Create train, validation, and test data loaders.

    Args:
        data_dir: Root folder handed to ``MalwareImageDataset``.
        batch_size: Batch size used by all three loaders.
        train_split: Fraction of samples for training.
        val_split: Fraction of samples for validation.
        test_split: Nominal test fraction (used for the printed summary only;
            the test set receives every sample left after train and val).
        max_samples: Optional cap on the number of samples loaded.
        num_workers: DataLoader worker processes.
        prefetch_factor: Batches prefetched per worker (only with workers).
        persistent_workers: Keep workers alive between epochs (only with workers).

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: If no samples were found or the training split is empty.
    """
    full_dataset = MalwareImageDataset(data_dir, max_samples=max_samples)

    total_size = len(full_dataset)
    train_size = int(train_split * total_size)
    val_size = int(val_split * total_size)
    test_size = total_size - train_size - val_size

    # Fail early with a clear message instead of an obscure sampler error
    # when the shuffled training loader is built on an empty dataset.
    if total_size == 0:
        raise ValueError(
            f"No samples found under {data_dir!r}; expected 'malware' and/or 'benign' subfolders"
        )
    if train_size == 0:
        raise ValueError(
            f"Training split is empty ({total_size} samples, train_split={train_split})"
        )

    print(f"\nDataset splits:")
    print(f"  Train: {train_size} ({train_split*100:.0f}%)")
    print(f"  Val:   {val_size} ({val_split*100:.0f}%)")
    print(f"  Test:  {test_size} ({test_split*100:.0f}%)")

    # Split dataset with a fixed seed so the partition is repeatable.
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )

    # DataLoader configuration. Pinned host memory only helps host-to-GPU
    # copies, so it is requested only when CUDA is available.
    cuda_available = torch.cuda.is_available()
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': cuda_available,
    }

    if num_workers > 0:
        # These options are only valid for multi-process loading.
        loader_kwargs['persistent_workers'] = persistent_workers
        if prefetch_factor is not None:
            loader_kwargs['prefetch_factor'] = prefetch_factor
        if cuda_available:
            loader_kwargs['pin_memory_device'] = 'cuda'

    # Only the training loader shuffles; evaluation order stays fixed.
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


# Cell-completion marker.
print("Dataset classes defined!")

## 5. Model Definition

In [None]:
class ResNetMalwareDetector(nn.Module):
    """Binary malware classifier built on a torchvision ResNet backbone.

    Supports 'resnet18' and 'resnet50'. The backbone's final fully connected
    layer is replaced by a fresh ``nn.Linear`` with ``num_classes`` outputs.
    """

    def __init__(
        self,
        model_name: str = 'resnet18',
        num_classes: int = 2,
        pretrained: bool = True,
        freeze_backbone: bool = False
    ):
        super().__init__()

        # Pick the constructor, optional ImageNet weights and the width of
        # the feature vector feeding the classification head.
        if model_name == 'resnet18':
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet18(weights=weights)
            num_features = 512
        elif model_name == 'resnet50':
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            self.backbone = models.resnet50(weights=weights)
            num_features = 2048
        else:
            raise ValueError(f"Unknown model: {model_name}. Choose 'resnet18' or 'resnet50'")

        # Freezing happens before the head is swapped, so the new head below
        # keeps requires_grad=True and remains trainable.
        if freeze_backbone:
            for weight in self.backbone.parameters():
                weight.requires_grad = False

        # New classification head.
        self.backbone.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return the backbone's output for input batch ``x``."""
        return self.backbone(x)


def count_parameters(model):
    """Return the total element count of parameters with ``requires_grad``."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Smoke-test model creation.
# The temporaries use underscore-prefixed names so they cannot collide with
# the test_model() function defined later in this notebook (re-running this
# cell would otherwise rebind and then delete that function).
print("Testing model creation...")
_smoke_model = ResNetMalwareDetector(RESNET_VARIANT, num_classes=2, pretrained=PRETRAINED)
print(f"Model: {RESNET_VARIANT}")
print(f"Trainable parameters: {count_parameters(_smoke_model):,}")

# Forward a random batch to confirm the output shape; no gradients needed.
_smoke_input = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    _smoke_output = _smoke_model(_smoke_input)
print(f"Input shape: {_smoke_input.shape}")
print(f"Output shape: {_smoke_output.shape}")
del _smoke_model, _smoke_input, _smoke_output

## 6. Metrics Tracking

In [None]:
class MetricsTracker:
    """Accumulate labels, predictions and scores, then compute metrics."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.y_true = []
        self.y_pred = []
        self.y_scores = []

    def update(self, labels, predictions, scores):
        """Append one batch.

        Args:
            labels: Tensor of ground-truth class indices.
            predictions: Tensor of predicted class indices.
            scores: Tensor of raw logits with the class dimension at dim 1;
                column 1 is treated as the positive (malware) class.
        """
        self.y_true.extend(labels.cpu().numpy().tolist())
        self.y_pred.extend(predictions.cpu().numpy().tolist())
        # Use softmax probabilities for the positive class (Malware).
        # Cast to float32 first so reduced-precision logits can still be
        # converted to numpy.
        probs = torch.softmax(scores.detach().float(), dim=1)[:, 1]
        self.y_scores.extend(probs.cpu().numpy().tolist())

    def compute_metrics(self) -> Dict[str, float]:
        """Compute metrics over everything accumulated since the last reset.

        Returns:
            Dict with 'accuracy', 'precision', 'recall', 'f1', 'auc' and
            'confusion_matrix' (always 2x2, rows = actual, columns =
            predicted, order [benign, malware]). 'fpr' and 'tpr' are included
            only when both classes occur in the labels; otherwise 'auc' is 0.0.
        """
        y_true = np.array(self.y_true)
        y_pred = np.array(self.y_pred)
        y_scores = np.array(self.y_scores)

        metrics = {
            'accuracy': accuracy_score(y_true, y_pred),
            'precision': precision_score(y_true, y_pred, zero_division=0),
            'recall': recall_score(y_true, y_pred, zero_division=0),
            'f1': f1_score(y_true, y_pred, zero_division=0),
        }

        # ROC AUC is undefined unless both classes are present.
        if len(np.unique(y_true)) > 1:
            fpr, tpr, _ = roc_curve(y_true, y_scores)
            metrics['auc'] = auc(fpr, tpr)
            metrics['fpr'] = fpr
            metrics['tpr'] = tpr
        else:
            metrics['auc'] = 0.0

        # Pin the label set so the matrix stays 2x2 even when only one class
        # appears in the data; consumers index it as cm[0, 1], cm[1, 0], ...
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
        metrics['confusion_matrix'] = cm

        return metrics


# Cell-completion marker.
print("MetricsTracker defined!")

## 7. Training Functions

In [None]:
def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    scaler: Optional[torch.amp.GradScaler] = None,
    gradient_accumulation_steps: int = 1
) -> Tuple[float, float]:
    """Train for one epoch.

    Args:
        model: Network to train (switched to train mode here).
        train_loader: Yields ``(images, labels)`` batches.
        criterion: Loss function applied to ``(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        scaler: When given, mixed precision (bfloat16 autocast on CUDA) is
            enabled and the scaler drives backward/step.
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer step.

    Returns:
        ``(average_loss, accuracy)`` over all samples seen; ``(0.0, 0.0)``
        when the loader yields no samples.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be >= 1")

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    def _apply_optimizer_step():
        # One parameter update from the currently accumulated gradients.
        if scaler is not None:
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    # Start from clean gradients so nothing left over from a previous call
    # is mixed into this epoch's first update.
    optimizer.zero_grad(set_to_none=True)
    pending_batches = 0

    pbar = tqdm(train_loader, desc="Training", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True).long()

        # Forward pass with optional AMP
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=(scaler is not None)):
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Scale so the accumulated gradient averages over the group.
            loss = loss / gradient_accumulation_steps

        # Backward pass
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        pending_batches += 1
        if pending_batches == gradient_accumulation_steps:
            _apply_optimizer_step()
            pending_batches = 0

        # Track metrics (undo the accumulation scaling for reporting).
        batch_loss = loss.item() * gradient_accumulation_steps
        running_loss += batch_loss * images.size(0)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix({'loss': batch_loss})

    # Flush a trailing, incomplete accumulation group; otherwise its
    # gradients would be dropped or leak into the next epoch.
    if pending_batches:
        _apply_optimizer_step()

    if total == 0:
        return 0.0, 0.0

    avg_loss = running_loss / total
    accuracy = correct / total

    return avg_loss, accuracy


def evaluate(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, Dict[str, float]]:
    """Run ``model`` over ``data_loader`` without gradients.

    Returns:
        ``(average_loss, metrics)`` where the loss is the sample-weighted sum
        of batch losses divided by ``len(data_loader.dataset)`` and
        ``metrics`` comes from ``MetricsTracker.compute_metrics``.
    """
    model.eval()
    tracker = MetricsTracker()
    loss_sum = 0.0

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc="Evaluating", leave=False):
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True).long()

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight each batch loss by its batch size.
            loss_sum += batch_loss.item() * images.size(0)
            tracker.update(labels, torch.argmax(logits, dim=1), logits)

    mean_loss = loss_sum / len(data_loader.dataset)
    return mean_loss, tracker.compute_metrics()


# Cell-completion marker.
print("Training functions defined!")

## 8. Main Training Loop

In [None]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 50,
    learning_rate: float = 0.001,
    device: torch.device = None,
    save_path: Optional[str] = None,
    patience: int = 10,
    use_amp: bool = True,
    gradient_accumulation_steps: int = 1
) -> Dict:
    """Full training loop with early stopping on validation loss.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Training batches.
        val_loader: Validation batches, evaluated after every epoch.
        num_epochs: Maximum number of epochs.
        learning_rate: Adam learning rate.
        device: Target device; defaults to CUDA when available, else CPU.
        save_path: If set, the best weights and the history are saved here
            once training finishes.
        patience: Stop after this many consecutive epochs without a new best
            validation loss.
        use_amp: Enable mixed precision (only takes effect with CUDA).
        gradient_accumulation_steps: Forwarded to ``train_epoch``.

    Returns:
        History dict with per-epoch lists under 'train_loss', 'train_acc',
        'val_loss', 'val_acc' and 'val_auc'. On return ``model`` holds the
        weights of the epoch with the lowest validation loss.
    """

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = model.to(device)
    print(f"Training on device: {device}")

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Optional: AMP scaler
    scaler = torch.amp.GradScaler('cuda') if use_amp and torch.cuda.is_available() else None

    # History tracking
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'val_auc': []
    }

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    print(f"\nStarting training for {num_epochs} epochs...")
    print(f"Early stopping patience: {patience}")
    print(f"AMP enabled: {scaler is not None}")
    print("-" * 60)

    for epoch in range(num_epochs):
        start_time = time.time()

        # Training
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device,
            scaler=scaler, gradient_accumulation_steps=gradient_accumulation_steps
        )

        # Validation
        val_loss, val_metrics = evaluate(model, val_loader, criterion, device)
        # Store plain Python floats so the saved history contains no
        # library-specific scalar types.
        val_acc = float(val_metrics['accuracy'])
        val_auc = float(val_metrics.get('auc', 0.0))

        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc'].append(val_auc)

        epoch_time = time.time() - start_time

        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}] ({epoch_time:.1f}s)")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}, Val Acc:   {val_acc:.4f}, Val AUC: {val_auc:.4f}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Snapshot the weights by value. A shallow dict copy would keep
            # references to the live tensors, which later epochs overwrite,
            # so the "best" state would silently become the last state.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print(f"  -> New best validation loss!")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

        print()

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Loaded best model from training")

    # Save checkpoint
    if save_path:
        torch.save({
            'model_state_dict': model.state_dict(),
            'history': history
        }, save_path)
        print(f"Model saved to {save_path}")

    return history


# Cell-completion marker.
print("Main training function defined!")

## 9. Testing and Evaluation

In [None]:
def test_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device = None
) -> Dict:
    """Score ``model`` on the test set, print a report and return the metrics.

    The returned dict is the one produced by ``evaluate`` with the test loss
    added under 'test_loss'. ``device`` defaults to CUDA when available.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()

    rule = "=" * 60
    print("\n" + rule)
    print("EVALUATING ON TEST SET")
    print(rule)

    test_loss, metrics = evaluate(model, test_loader, loss_fn, device)

    # Scalar summary
    print("\nTest Results:")
    print(f"  Test Loss:      {test_loss:.4f}")
    print(f"  Accuracy:       {metrics['accuracy']:.4f} ({metrics['accuracy']*100:.2f}%)")
    print(f"  Precision:      {metrics['precision']:.4f}")
    print(f"  Recall:         {metrics['recall']:.4f}")
    print(f"  F1-Score:       {metrics['f1']:.4f}")
    print(f"  AUC:            {metrics.get('auc', 0):.4f}")

    # Confusion matrix: rows are actual classes, columns are predictions.
    cm = metrics['confusion_matrix']
    print("\nConfusion Matrix:")
    print(f"               Predicted")
    print(f"             Benign  Malware")
    print(f"  Actual Benign   {cm[0,0]:5d}    {cm[0,1]:5d}")
    print(f"         Malware  {cm[1,0]:5d}    {cm[1,1]:5d}")

    metrics['test_loss'] = test_loss
    return metrics


# Cell-completion marker.
print("Test function defined!")

## 10. Visualization Functions

In [None]:
def plot_training_history(history: Dict, save_path: Optional[str] = None):
    """Draw loss, accuracy and validation-AUC curves side by side.

    ``history`` must provide per-epoch lists under 'train_loss', 'val_loss',
    'train_acc', 'val_acc' and 'val_auc'. The figure is written to
    ``save_path`` when given and then shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    epochs = range(1, len(history['train_loss']) + 1)

    # One entry per panel: curves as (history key, line format, legend label),
    # followed by the y-axis label and the panel title.
    panels = [
        ([('train_loss', 'b-', 'Train Loss'), ('val_loss', 'r-', 'Val Loss')],
         'Loss', 'Training and Validation Loss'),
        ([('train_acc', 'b-', 'Train Acc'), ('val_acc', 'r-', 'Val Acc')],
         'Accuracy', 'Training and Validation Accuracy'),
        ([('val_auc', 'g-', 'Val AUC')],
         'AUC', 'Validation AUC'),
    ]

    for ax, (curves, y_label, title) in zip(axes, panels):
        for key, line_format, legend_label in curves:
            ax.plot(epochs, history[key], line_format, label=legend_label, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Training history plot saved to {save_path}")

    plt.show()


def plot_roc_curve(metrics: Dict, save_path: Optional[str] = None):
    """Plot the ROC curve stored in ``metrics`` ('fpr', 'tpr', 'auc').

    Prints a notice and returns without plotting when the curve data is
    missing. The figure is written to ``save_path`` when given, then shown.
    """
    has_curve = 'fpr' in metrics and 'tpr' in metrics
    if not has_curve:
        print("ROC curve data not available")
        return

    false_pos = metrics['fpr']
    true_pos = metrics['tpr']
    curve_label = f'ROC curve (AUC = {metrics["auc"]:.4f})'

    plt.figure(figsize=(8, 6))
    plt.plot(false_pos, true_pos, 'b-', linewidth=2, label=curve_label)
    # Diagonal reference line for a chance-level classifier.
    plt.plot([0, 1], [0, 1], 'r--', linewidth=1, label='Random classifier')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"ROC curve saved to {save_path}")

    plt.show()


def plot_confusion_matrix(cm: np.ndarray, save_path: Optional[str] = None):
    """Render ``cm`` as an annotated heat map with Benign/Malware ticks.

    The figure is written to ``save_path`` when given and then shown.
    """
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.title('Confusion Matrix', fontsize=14)
    plt.colorbar()

    class_names = ['Benign', 'Malware']
    positions = np.arange(len(class_names))
    plt.xticks(positions, class_names, fontsize=12)
    plt.yticks(positions, class_names, fontsize=12)

    # Write each count into its cell; switch to white text on dark cells.
    midpoint = cm.max() / 2.
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        text_color = "white" if cm[row, col] > midpoint else "black"
        plt.text(col, row, format(cm[row, col], 'd'),
                 ha="center", va="center",
                 color=text_color,
                 fontsize=20)

    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")

    plt.show()


def visualize_sample_images(data_loader: DataLoader, num_samples: int = 4):
    """Show the three channels of the first samples in one batch.

    One row per sample, one column per channel. The class name is written as
    the y-label of the first column.

    Args:
        data_loader: Loader yielding ``(images, labels)`` batches; only the
            first batch is used.
        num_samples: Maximum number of samples (rows) to show. Fewer rows are
            drawn when the batch is smaller.
    """
    images, labels = next(iter(data_loader))

    # Never create more rows than there are images, so no blank rows appear.
    n_rows = min(num_samples, len(images))
    if n_rows <= 0:
        print("No samples available to visualize")
        return

    # squeeze=False keeps axes two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 3 * n_rows), squeeze=False)
    channel_names = ['Bigram Frequency', 'DCT Transform', 'Byteplot']

    for i in range(n_rows):
        label = "Malware" if labels[i] == 1 else "Benign"

        for j in range(3):
            ax = axes[i, j]
            ax.imshow(images[i, j].numpy(), cmap='viridis')
            if i == 0:
                ax.set_title(channel_names[j])
            # Hide ticks only: turning the whole axis off would also hide
            # the class label set below.
            ax.set_xticks([])
            ax.set_yticks([])

        axes[i, 0].set_ylabel(label, fontsize=12)

    plt.tight_layout()
    plt.show()


# Cell-completion marker.
print("Visualization functions defined!")

## 11. Load Data

In [None]:
# Create the train/validation/test loaders from the configuration constants.
train_loader, val_loader, test_loader = create_data_loaders(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    train_split=TRAIN_SPLIT,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    max_samples=MAX_SAMPLES,
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=PERSISTENT_WORKERS
)

# Report how many batches each loader will yield per pass.
print(f"\nBatch size: {BATCH_SIZE}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## 12. Visualize Sample Data

In [None]:
# Visualize up to four samples from the first batch of the training loader
# (the training loader shuffles, so the samples shown can differ per run).
visualize_sample_images(train_loader, num_samples=4)

## 13. Create and Train Model

In [None]:
# Create the model from the configuration constants (two output classes).
model = ResNetMalwareDetector(
    model_name=RESNET_VARIANT,
    num_classes=2,
    pretrained=PRETRAINED,
    freeze_backbone=FREEZE_BACKBONE
)

# Summarise the configuration and the number of trainable parameters.
print(f"Model: {RESNET_VARIANT}")
print(f"Pretrained: {PRETRAINED}")
print(f"Freeze backbone: {FREEZE_BACKBONE}")
print(f"Trainable parameters: {count_parameters(model):,}")

In [None]:
# Train the model; the checkpoint is written to CHECKPOINT_DIR/resnet_best.pth.
save_path = os.path.join(CHECKPOINT_DIR, 'resnet_best.pth')

# `history` holds the per-epoch curves returned by train_model().
history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    device=DEVICE,
    save_path=save_path,
    patience=PATIENCE,
    use_amp=USE_AMP,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS
)

## 14. Plot Training History

In [None]:
# Plot the training curves and save them to RESULTS_DIR/training_history.png.
plot_training_history(
    history, 
    save_path=os.path.join(RESULTS_DIR, 'training_history.png')
)

## 15. Evaluate on Test Set

In [None]:
# Evaluate on the held-out test loader; the returned metrics dict feeds the
# ROC and confusion-matrix plots in the following cells.
test_metrics = test_model(model, test_loader, device=DEVICE)

## 16. Plot ROC Curve

In [None]:
# Plot the test-set ROC curve and save it to RESULTS_DIR/roc_curve.png.
plot_roc_curve(
    test_metrics, 
    save_path=os.path.join(RESULTS_DIR, 'roc_curve.png')
)

## 17. Plot Confusion Matrix

In [None]:
# Plot the test-set confusion matrix and save it to
# RESULTS_DIR/confusion_matrix.png.
plot_confusion_matrix(
    test_metrics['confusion_matrix'], 
    save_path=os.path.join(RESULTS_DIR, 'confusion_matrix.png')
)

## 18. Load Saved Model (Optional)

Use this section to load a previously trained model for inference.

In [None]:
# Load a saved model checkpoint
def load_model(
    checkpoint_path: str,
    model_name: str = 'resnet50',
    device: Optional[torch.device] = None
):
    """Load a saved model from checkpoint.

    Args:
        checkpoint_path: File containing a dict with 'model_state_dict' and
            optionally 'history'.
        model_name: Backbone variant the checkpoint was trained with
            ('resnet18' or 'resnet50').
        device: Device to load onto; defaults to the notebook-wide ``DEVICE``.

    Returns:
        ``(model, history)`` with the model in eval mode on ``device``;
        ``history`` is ``None`` when the checkpoint does not contain one.
    """
    if device is None:
        device = DEVICE

    # Weights come from the checkpoint, so skip the pretrained download.
    model = ResNetMalwareDetector(
        model_name=model_name,
        num_classes=2,
        pretrained=False
    )

    # NOTE(review): torch.load's default unpickling restrictions differ
    # between PyTorch versions — confirm the checkpoint loads under the
    # installed version. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Model loaded from {checkpoint_path}")
    return model, checkpoint.get('history', None)


# Uncomment to load a saved model:
# loaded_model, saved_history = load_model(
#     os.path.join(CHECKPOINT_DIR, 'resnet_best.pth'),
#     model_name=RESNET_VARIANT
# )

## 19. Single File Prediction

In [None]:
def predict_file(model: nn.Module, file_path: str, device: torch.device = None) -> Dict:
    """
    Classify a single file as malware or benign.

    The file is converted with ``create_three_channel_image`` and passed
    through ``model`` (in eval mode, without gradients) as a batch of one.

    Returns:
        Dict with the file path, the predicted label name and class index,
        and the softmax confidence for each class.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = model.to(device)
    model.eval()

    # Build the input: (3, H, W) array -> (1, 3, H, W) float tensor on device.
    channels = create_three_channel_image(file_path)
    batch = torch.from_numpy(channels).float().unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        class_probs = torch.softmax(logits, dim=1)[0]
        predicted_class = torch.argmax(logits, dim=1).item()

    label_name = 'Malware' if predicted_class == 1 else 'Benign'
    return {
        'file': file_path,
        'prediction': label_name,
        'prediction_class': predicted_class,
        'confidence_benign': class_probs[0].item(),
        'confidence_malware': class_probs[1].item(),
    }


# Example usage:
# result = predict_file(model, '../data/benign/sample.exe', device=DEVICE)
# print(f"File: {result['file']}")
# print(f"Prediction: {result['prediction']}")
# print(f"Confidence - Benign: {result['confidence_benign']:.4f}")
# print(f"Confidence - Malware: {result['confidence_malware']:.4f}")

# Cell-completion marker.
print("Single file prediction function defined!")