In [None]:
'''
QNI + adversarial (claude)

'''

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from collections import Counter
import numpy as np
import random
import os
from torchvision.datasets import ImageFolder
from matplotlib import pyplot as plt
import pennylane as qml
from pennylane.qnn import TorchLayer
from tqdm.notebook import tqdm
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Set seeds for reproducibility
def seed_all(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_all(42)

# ========== DEVICE ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== PARAMETERS ==========
n_qubits = 6
batch_size = 16
num_classes = 25
num_epochs = 50
lr = 0.0005

# ========== TRANSFORMS WITH DATA AUGMENTATION ==========
train_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

eval_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# ========== DATASETS ==========
train_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/train', transform=train_transform)
val_dataset   = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/val', transform=eval_transform)
test_dataset  = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/test', transform=eval_transform)
print("**dataset loaded**")

# ========== CLASS WEIGHTS ==========
labels = [label for _, label in train_dataset.samples]
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(labels),
                                     y=labels)
class_wts = torch.tensor(class_weights, dtype=torch.float)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# ========== QUANTUM CIRCUIT ==========
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")
def quantum_circuit(inputs, weights):
    for i in range(n_qubits):
        qml.RY(inputs[i], wires=i)
    
    for l in range(weights.shape[0]):
        for i in range(n_qubits):
            qml.RY(weights[l][i], wires=i)
        for i in range(n_qubits - 1):
            qml.CNOT(wires=[i, i+1])
    
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

weight_shapes = {"weights": (6, n_qubits)}

# ========== CNN + QNN MODEL ==========
class FeatureReduce(nn.Module):
    def __init__(self, final_dim, dropout=0.4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 128 -> 64
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 64 -> 32
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32 -> 16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 16 -> 8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # 8 -> 4
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))                # 4×4 -> 1×1
        )
        self.fc = nn.Linear(128, final_dim)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class HybridQNN(nn.Module):
    def __init__(self, n_qubits, num_classes):
        super().__init__()
        self.feature_extractor = FeatureReduce(final_dim=n_qubits)
        self.q_layer = TorchLayer(quantum_circuit, weight_shapes)
        self.classifier = nn.Sequential(
            nn.Linear(n_qubits, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.tanh(x)
        q_out = torch.stack([self.q_layer(f) for f in x])
        return self.classifier(q_out)

# ========== PROPER QNI-CCP IMPLEMENTATION ==========

def compute_quantum_epsilon(model, n_cnots=30, depth=6, alpha=1.0, beta=1.0):
    """
    Compute epsilon_q based on quantum circuit complexity.
    Higher complexity (more CNOTs, deeper) -> smaller epsilon_q
    """
    epsilon_q = 1.0 / (1 + alpha * n_cnots + beta * depth)
    return epsilon_q

def compute_class_centroids(model, loader, device, num_classes):
    """
    Compute class centroids in the FEATURE SPACE (before quantum layer)
    """
    model.eval()
    sums = torch.zeros(num_classes, n_qubits, device=device)
    counts = torch.zeros(num_classes, device=device)
    
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # Get features before quantum layer (after feature extractor)
            features = model.feature_extractor(x)
            features = torch.tanh(features)  # Apply tanh as in forward pass
            
            for c in range(num_classes):
                mask = (y == c)
                if mask.any():
                    sums[c] += features[mask].sum(0)
                    counts[c] += mask.sum()
    
    # Avoid division by zero
    counts[counts == 0] = 1
    centroids = sums / counts.unsqueeze(1)
    return centroids


def qni_ccp_perturbation(model, x, y, centroids, epsilon_q=0.1, target_class=None):
    """
    Proper QNI-CCP implementation as per the paper:
    x' = x + epsilon_q * [S ⊙ (μ_c' - x)]
    
    Where:
    - S is the gradient of loss w.r.t. input features
    - μ_c' is the centroid of target class c'
    - ⊙ is element-wise multiplication
    """
    model.eval()
    
    # Step 1: Get features and compute gradients w.r.t. features
    x_for_features = x.clone().detach().requires_grad_(False)  # Don't need grad for input
    
    # Forward pass to get features - need to track gradients
    features = model.feature_extractor(x_for_features)
    features = torch.tanh(features)
    features = features.detach().requires_grad_(True)  # Enable gradients for features
    
    # Continue forward pass through quantum and classifier
    q_out = torch.stack([model.q_layer(f) for f in features])
    logits = model.classifier(q_out)
    loss = F.cross_entropy(logits, y)
    
    # Backward pass to get gradients w.r.t. features
    loss.backward(retain_graph=True)
    
    # Check if gradients exist
    if features.grad is None:
        print("Warning: features.grad is None, using zero gradients")
        S = torch.zeros_like(features)
    else:
        S = features.grad.data  # Shape: [batch_size, n_qubits]
    
    # Step 2: Select target class and get its centroid
    if target_class is None:
        # Choose a random different class for each sample
        target_classes = []
        for i in range(y.size(0)):
            available_classes = [c for c in range(centroids.size(0)) if c != y[i].item()]
            if available_classes:
                target_classes.append(torch.randint(0, len(available_classes), (1,)).item())
                target_classes[-1] = available_classes[target_classes[-1]]
            else:
                target_classes.append((y[i].item() + 1) % centroids.size(0))  # fallback
        target_class = torch.tensor(target_classes, device=y.device)
    else:
        target_class = torch.full_like(y, target_class)
    
    # Get target centroids for each sample
    mu_c_prime = centroids[target_class]  # Shape: [batch_size, n_qubits]
    
    # Step 3: Compute perturbation in feature space
    current_features = features.detach()
    feature_direction = mu_c_prime - current_features  # Shape: [batch_size, n_qubits]
    
    # Element-wise multiplication with sensitivity (gradient)
    weighted_direction = S * feature_direction  # Element-wise multiplication
    
    # Now we need to map this back to input space
    # We'll use a simplified approach: perturb input in direction of input gradient
    # scaled by the magnitude of the feature-space perturbation
    
    # Get input gradients
    x_input = x.clone().detach().requires_grad_(True)
    logits_input = model(x_input)
    loss_input = F.cross_entropy(logits_input, y)
    loss_input.backward()
    
    if x_input.grad is not None:
        input_grad = x_input.grad.data
        # Scale by feature perturbation magnitude
        perturbation_magnitude = torch.norm(weighted_direction, dim=1, keepdim=True)
        # Reshape to match input dimensions
        perturbation_magnitude = perturbation_magnitude.unsqueeze(-1).unsqueeze(-1)
        input_direction = input_grad.sign() * perturbation_magnitude
        
        # Apply perturbation
        x_perturbed = x + epsilon_q * input_direction
        x_perturbed = torch.clamp(x_perturbed, -1, 1)  # Keep within normalized bounds
    else:
        print("Warning: input gradients are None, returning original input")
        x_perturbed = x
    
    return x_perturbed.detach()


def qni_ccp_feature_perturbation_fixed(model, x, y, centroids, epsilon_q=0.1, target_class=None):
    """
    Alternative: Direct feature-space perturbation for QNI-CCP
    This is more direct and avoids the input-space mapping issue
    """
    model.eval()
    
    # Get original features
    with torch.no_grad():
        original_features = model.feature_extractor(x)
        original_features = torch.tanh(original_features)
    
    # Create a copy that requires gradients
    perturbed_features = original_features.clone().detach().requires_grad_(True)
    
    # Forward pass through quantum and classifier
    q_out = torch.stack([model.q_layer(f) for f in perturbed_features])
    logits = model.classifier(q_out)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    
    # Check if gradients exist
    if perturbed_features.grad is None:
        print("Warning: perturbed_features.grad is None, using zero gradients")
        S = torch.zeros_like(perturbed_features)
    else:
        S = perturbed_features.grad.data
    
    # Step 2: Select target class and get centroid
    if target_class is None:
        # Choose random different class for each sample
        target_classes = []
        for i in range(y.size(0)):
            available_classes = [c for c in range(centroids.size(0)) if c != y[i].item()]
            if available_classes:
                target_classes.append(torch.randint(0, len(available_classes), (1,)).item())
                target_classes[-1] = available_classes[target_classes[-1]]
            else:
                target_classes.append((y[i].item() + 1) % centroids.size(0))  # fallback
        target_class = torch.tensor(target_classes, device=y.device)
    else:
        target_class = torch.full_like(y, target_class)
    
    mu_c_prime = centroids[target_class]
    
    # Step 3: Compute QNI-CCP perturbation
    # x' = x + epsilon_q * [S ⊙ (μ_c' - x)]
    perturbation_direction = mu_c_prime - original_features
    weighted_perturbation = S * perturbation_direction  # Element-wise multiplication
    
    perturbed_features_final = original_features + epsilon_q * weighted_perturbation
    
    return perturbed_features_final.detach()

def fgsm_attack(model, images, labels, eps_fgsm=0.03, device='cuda'):
    """
    Generate FGSM adversarial examples.
    
    Args:
        model     : your neural network
        images    : clean input batch, shape [B, C, H, W]
        labels    : true labels for images, shape [B]
        eps_fgsm  : perturbation magnitude (ε)
        device    : 'cuda' or 'cpu'
    
    Returns:
        images_adv: adversarial images in the same shape
    """
    model.eval()
    # make a copy that requires gradient
    images_adv = images.clone().detach().to(device).requires_grad_(True)
    labels = labels.to(device)

    # forward + backward
    logits = model(images_adv)
    loss = F.cross_entropy(logits, labels)
    model.zero_grad()
    loss.backward()

    # take a single step in the sign‐gradient direction
    images_adv = images_adv + eps_fgsm * images_adv.grad.sign()
    # clamp to valid range
    images_adv = torch.clamp(images_adv, min=-1.0, max=1.0)
    return images_adv.detach()


def pgd_attack(model,
               images,
               labels,
               pgd_eps=0.1,
               pgd_alpha=0.01,
               pgd_iters=7,
               device='cuda'):
    """
    Generate PGD adversarial examples.
    
    Args:
        model      : your neural network
        images     : clean input batch, shape [B, C, H, W]
        labels     : true labels for images, shape [B]
        pgd_eps    : maximum total perturbation (ℓ∞ radius)
        pgd_alpha  : step size per iteration
        pgd_iters  : number of iterations
        device     : 'cuda' or 'cpu'
    
    Returns:
        images_adv : adversarial images in the same shape
    """
    model.eval()
    images_orig = images.clone().detach().to(device)
    labels = labels.to(device)

    # start from a random point in the eps‐ball
    images_adv = images_orig + torch.empty_like(images_orig).uniform_(-pgd_eps, pgd_eps)
    images_adv = torch.clamp(images_adv, -1.0, 1.0).detach()

    for _ in range(pgd_iters):
        images_adv.requires_grad_(True)
        logits = model(images_adv)
        loss = F.cross_entropy(logits, labels)
        model.zero_grad()
        loss.backward()

        # gradient step
        perturb = pgd_alpha * images_adv.grad.sign()
        images_adv = images_adv + perturb

        # project back into the ℓ∞-ball around the original images
        delta = torch.clamp(images_adv - images_orig, min=-pgd_eps, max=pgd_eps)
        images_adv = torch.clamp(images_orig + delta, -1.0, 1.0).detach()

    return images_adv


# Modified training function that uses the feature-space perturbation
def train_with_feature_perturbation_and_adv(model,
                                            train_loader,
                                            val_loader,
                                            centroids,
                                            epsilon_q,
                                            epsilon_fgsm=0.03,
                                            eps_pgd=0.1,
                                            alpha_pgd=0.01,
                                            iters_pgd=7,
                                            num_epochs=50):
    """
    Training that combines:
      • Clean loss
      • QNI‑CCP feature‑space perturbation loss
      • FGSM adversarial loss
      • PGD adversarial loss
      • Optional centroid regularization
    """
    opt   = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=5e-3)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=5)
    best_val_acc = 0.0

    for epoch in range(1, num_epochs + 1):
        # Recompute centroids every 5 epochs
        if epoch % 5 == 0:
            centroids = compute_class_centroids(model, train_loader, device, num_classes)

        model.train()
        running_loss, running_corr, running_total = 0, 0, 0

        for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}"):
            xb, yb = xb.to(device), yb.to(device)

            # 1) Clean loss
            logits_clean = model(xb)
            loss_clean   = F.cross_entropy(logits_clean, yb)

            # 2) QNI‑CCP perturbation in feature space
            perturbed_features = qni_ccp_feature_perturbation_fixed(
                model, xb, yb, centroids, epsilon_q=epsilon_q
            )
            q_out = torch.stack([model.q_layer(f) for f in perturbed_features])
            logits_qni = model.classifier(q_out)
            loss_qni   = F.cross_entropy(logits_qni, yb)

            # 3) FGSM adversarial loss
            xb_fgsm   = fgsm_attack(model, xb, yb, eps_fgsm=epsilon_fgsm)

            logits_fgsm = model(xb_fgsm)
            loss_fgsm  = F.cross_entropy(logits_fgsm, yb)

            # 4) PGD adversarial loss
            xb_pgd = pgd_attack(
            model, xb, yb,
            pgd_eps=eps_pgd,
            pgd_alpha=alpha_pgd,
            pgd_iters=iters_pgd
        )

            logits_pgd  = model(xb_pgd)
            loss_pgd   = F.cross_entropy(logits_pgd, yb)

            # 5) Centroid regularization (optional)
            current_features = torch.tanh(model.feature_extractor(xb))
            centroid_reg     = ((current_features - centroids[yb])**2).mean()

            # Combine all losses with weights
            loss = (
                0.5 * loss_clean   +  # clean
                0.15 * loss_qni    +  # QNI‑CCP
                0.1  * loss_fgsm   +  # FGSM
                0.15 * loss_pgd    +  # PGD
                0.1  * centroid_reg   # centroid reg
            )

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            running_loss += loss.item() * xb.size(0)
            running_corr += (logits_clean.argmax(1) == yb).sum().item()
            running_total += xb.size(0)

        train_loss = running_loss / running_total
        train_acc  = running_corr / running_total
        sched.step(train_loss)

        # Validation
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch:2d} | Train L: {train_loss:.4f} | "
              f"Train A: {train_acc:.4f} | Val A: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_qni_ccp_adv_model.pth")
            print("💾 Saved best model.")

    return best_val_acc

# ========== TRAINING FUNCTIONS ==========
def evaluate(model, loader):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

# ========== TRAINING WITH PROPER QNI-CCP ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridQNN(n_qubits, num_classes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=5e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=5)

# Initialize best validation accuracy
best_val_acc = 0.0

# Compute quantum epsilon based on circuit complexity
epsilon_q = compute_quantum_epsilon(model, n_cnots=30, depth=6, alpha=1.0, beta=1.0)
print(f"Quantum epsilon: {epsilon_q:.4f}")

# Compute initial centroids
centroids = compute_class_centroids(model, train_loader, device, num_classes)
print("Initial centroids computed")

# Training loop
def qni_ccp_perturbation(model, x, y, centroids, epsilon_q=0.1, target_class=None):
    """
    Proper QNI-CCP implementation as per the paper:
    x' = x + epsilon_q * [S ⊙ (μ_c' - x)]
    
    Where:
    - S is the gradient of loss w.r.t. input features
    - μ_c' is the centroid of target class c'
    - ⊙ is element-wise multiplication
    """
    model.eval()
    
    # Step 1: Get features and compute gradients w.r.t. features
    x_for_features = x.clone().detach().requires_grad_(False)  # Don't need grad for input
    
    # Forward pass to get features - need to track gradients
    features = model.feature_extractor(x_for_features)
    features = torch.tanh(features)
    features = features.detach().requires_grad_(True)  # Enable gradients for features
    
    # Continue forward pass through quantum and classifier
    q_out = torch.stack([model.q_layer(f) for f in features])
    logits = model.classifier(q_out)
    loss = F.cross_entropy(logits, y)
    
    # Backward pass to get gradients w.r.t. features
    loss.backward(retain_graph=True)
    
    # Check if gradients exist
    if features.grad is None:
        print("Warning: features.grad is None, using zero gradients")
        S = torch.zeros_like(features)
    else:
        S = features.grad.data  # Shape: [batch_size, n_qubits]
    
    # Step 2: Select target class and get its centroid
    if target_class is None:
        # Choose a random different class for each sample
        target_classes = []
        for i in range(y.size(0)):
            available_classes = [c for c in range(centroids.size(0)) if c != y[i].item()]
            if available_classes:
                target_classes.append(torch.randint(0, len(available_classes), (1,)).item())
                target_classes[-1] = available_classes[target_classes[-1]]
            else:
                target_classes.append((y[i].item() + 1) % centroids.size(0))  # fallback
        target_class = torch.tensor(target_classes, device=y.device)
    else:
        target_class = torch.full_like(y, target_class)
    
    # Get target centroids for each sample
    mu_c_prime = centroids[target_class]  # Shape: [batch_size, n_qubits]
    
    # Step 3: Compute perturbation in feature space
    current_features = features.detach()
    feature_direction = mu_c_prime - current_features  # Shape: [batch_size, n_qubits]
    
    # Element-wise multiplication with sensitivity (gradient)
    weighted_direction = S * feature_direction  # Element-wise multiplication
    
    # Now we need to map this back to input space
    # We'll use a simplified approach: perturb input in direction of input gradient
    # scaled by the magnitude of the feature-space perturbation
    
    # Get input gradients
    x_input = x.clone().detach().requires_grad_(True)
    logits_input = model(x_input)
    loss_input = F.cross_entropy(logits_input, y)
    loss_input.backward()
    
    if x_input.grad is not None:
        input_grad = x_input.grad.data
        # Scale by feature perturbation magnitude
        perturbation_magnitude = torch.norm(weighted_direction, dim=1, keepdim=True)
        # Reshape to match input dimensions
        perturbation_magnitude = perturbation_magnitude.unsqueeze(-1).unsqueeze(-1)
        input_direction = input_grad.sign() * perturbation_magnitude
        
        # Apply perturbation
        x_perturbed = x + epsilon_q * input_direction
        x_perturbed = torch.clamp(x_perturbed, -1, 1)  # Keep within normalized bounds
    else:
        print("Warning: input gradients are None, returning original input")
        x_perturbed = x
    
    return x_perturbed.detach()


def qni_ccp_feature_perturbation_fixed(model, x, y, centroids, epsilon_q=0.1, target_class=None):
    """
    Alternative: Direct feature-space perturbation for QNI-CCP
    This is more direct and avoids the input-space mapping issue
    """
    model.eval()
    
    # Get original features
    with torch.no_grad():
        original_features = model.feature_extractor(x)
        original_features = torch.tanh(original_features)
    
    # Create a copy that requires gradients
    perturbed_features = original_features.clone().detach().requires_grad_(True)
    
    # Forward pass through quantum and classifier
    q_out = torch.stack([model.q_layer(f) for f in perturbed_features])
    logits = model.classifier(q_out)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    
    # Check if gradients exist
    if perturbed_features.grad is None:
        print("Warning: perturbed_features.grad is None, using zero gradients")
        S = torch.zeros_like(perturbed_features)
    else:
        S = perturbed_features.grad.data
    
    # Step 2: Select target class and get centroid
    if target_class is None:
        # Choose random different class for each sample
        target_classes = []
        for i in range(y.size(0)):
            available_classes = [c for c in range(centroids.size(0)) if c != y[i].item()]
            if available_classes:
                target_classes.append(torch.randint(0, len(available_classes), (1,)).item())
                target_classes[-1] = available_classes[target_classes[-1]]
            else:
                target_classes.append((y[i].item() + 1) % centroids.size(0))  # fallback
        target_class = torch.tensor(target_classes, device=y.device)
    else:
        target_class = torch.full_like(y, target_class)
    
    mu_c_prime = centroids[target_class]
    
    # Step 3: Compute QNI-CCP perturbation
    # x' = x + epsilon_q * [S ⊙ (μ_c' - x)]
    perturbation_direction = mu_c_prime - original_features
    weighted_perturbation = S * perturbation_direction  # Element-wise multiplication
    
    perturbed_features_final = original_features + epsilon_q * weighted_perturbation
    
    return perturbed_features_final.detach()


# Replace the original training loop with this:
print("Training with QNI-CCP (Feature-space perturbation)...")
# before training
epsilon_q = compute_quantum_epsilon(model, n_cnots=30, depth=6)
centroids = compute_class_centroids(model, train_loader, device, num_classes)

# now train
best_val = train_with_feature_perturbation_and_adv(
    model, train_loader, val_loader,
    centroids, epsilon_q,
    epsilon_fgsm=0.03,
    eps_pgd=0.1, alpha_pgd=0.01, iters_pgd=7,
    num_epochs=50
)
print("Best val acc:", best_val)


**dataset loaded**
Quantum epsilon: 0.0270
Initial centroids computed
Training with QNI-CCP (Feature-space perturbation)...


Epoch 1:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  1 | Train L: 2.1570 | Train A: 0.2920 | Val A: 0.3619
💾 Saved best model.


Epoch 2:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  2 | Train L: 1.5573 | Train A: 0.4309 | Val A: 0.4496
💾 Saved best model.


Epoch 3:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  3 | Train L: 1.1747 | Train A: 0.5916 | Val A: 0.4875
💾 Saved best model.


Epoch 4:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  4 | Train L: 0.9429 | Train A: 0.6927 | Val A: 0.6511
💾 Saved best model.


Epoch 5:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  5 | Train L: 0.7959 | Train A: 0.7355 | Val A: 0.7703
💾 Saved best model.


Epoch 6:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  6 | Train L: 0.7036 | Train A: 0.7647 | Val A: 0.7866
💾 Saved best model.


Epoch 7:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  7 | Train L: 0.6330 | Train A: 0.7926 | Val A: 0.8028
💾 Saved best model.


Epoch 8:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  8 | Train L: 0.5429 | Train A: 0.8320 | Val A: 0.8559
💾 Saved best model.


Epoch 9:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch  9 | Train L: 0.4764 | Train A: 0.8653 | Val A: 0.8830
💾 Saved best model.


Epoch 10:   0%|          | 0/467 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [6]:
# Resume training from saved model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
import numpy as np
from tqdm.notebook import tqdm
import os

# Make sure all your model classes and functions are defined before this
# (FocalLoss, HybridQNN, FeatureReduce, etc. - all from your original code)

# ========== RESUME TRAINING SETUP ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Recreate the model architecture (same as original)
n_qubits = 6
num_classes = 25
model = HybridQNN(n_qubits, num_classes).to(device)

# Load the saved model weights
model_path = "qni_ccp_adv_model_epoch_10.pth"
if os.path.exists(model_path):
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])  # ✅ fixed here
    print(f"✅ Model loaded from {model_path}")
else:
    print(f"❌ Model file {model_path} not found!")
    exit()


# Recreate data loaders (same as original)
batch_size = 16

train_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

eval_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/train', transform=train_transform)
val_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/val', transform=eval_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# ========== MODIFIED TRAINING FUNCTION FOR RESUME ==========
def resume_training_with_feature_perturbation_and_adv(model,
                                                     train_loader,
                                                     val_loader,
                                                     centroids,
                                                     epsilon_q,
                                                     epsilon_fgsm=0.03,
                                                     eps_pgd=0.1,
                                                     alpha_pgd=0.01,
                                                     iters_pgd=7,
                                                     num_epochs=15,
                                                     start_epoch=11):
    """
    Resume training with model saving after each epoch
    """
    # Setup optimizer and scheduler
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=5e-3)  # Slightly lower LR for resume
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=5)
    
    # Get initial validation accuracy
    initial_val_acc = evaluate(model, val_loader)
    best_val_acc = initial_val_acc
    print(f"Initial validation accuracy: {initial_val_acc:.4f}")

    for epoch in range(start_epoch, start_epoch + num_epochs):
        # Recompute centroids every 5 epochs
        if epoch % 5 == 0:
            centroids = compute_class_centroids(model, train_loader, device, num_classes)
            print(f"Recomputed centroids at epoch {epoch}")

        model.train()
        running_loss, running_corr, running_total = 0, 0, 0

        for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}"):
            xb, yb = xb.to(device), yb.to(device)

            # 1) Clean loss
            logits_clean = model(xb)
            loss_clean = F.cross_entropy(logits_clean, yb)

            # 2) QNI-CCP perturbation in feature space
            perturbed_features = qni_ccp_feature_perturbation_fixed(
                model, xb, yb, centroids, epsilon_q=epsilon_q
            )
            q_out = torch.stack([model.q_layer(f) for f in perturbed_features])
            logits_qni = model.classifier(q_out)
            loss_qni = F.cross_entropy(logits_qni, yb)

            # 3) FGSM adversarial loss
            xb_fgsm = fgsm_attack(model, xb, yb, eps_fgsm=epsilon_fgsm)
            logits_fgsm = model(xb_fgsm)
            loss_fgsm = F.cross_entropy(logits_fgsm, yb)

            # 4) PGD adversarial loss
            xb_pgd = pgd_attack(
                model, xb, yb,
                pgd_eps=eps_pgd,
                pgd_alpha=alpha_pgd,
                pgd_iters=iters_pgd
            )
            logits_pgd = model(xb_pgd)
            loss_pgd = F.cross_entropy(logits_pgd, yb)

            # 5) Centroid regularization
            current_features = torch.tanh(model.feature_extractor(xb))
            centroid_reg = ((current_features - centroids[yb])**2).mean()

            # Combine all losses with weights
            loss = (
                0.5 * loss_clean +    # clean
                0.15 * loss_qni +     # QNI-CCP
                0.1 * loss_fgsm +     # FGSM
                0.15 * loss_pgd +     # PGD
                0.1 * centroid_reg    # centroid reg
            )

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            running_loss += loss.item() * xb.size(0)
            running_corr += (logits_clean.argmax(1) == yb).sum().item()
            running_total += xb.size(0)

        train_loss = running_loss / running_total
        train_acc = running_corr / running_total
        sched.step(train_loss)

        # Validation
        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch:2d} | Train L: {train_loss:.4f} | "
              f"Train A: {train_acc:.4f} | Val A: {val_acc:.4f}")

        # Save model after each epoch
        epoch_model_path = f"qni_ccp_adv_model_epoch_{epoch}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'scheduler_state_dict': sched.state_dict(),
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }, epoch_model_path)
        print(f"💾 Saved model for epoch {epoch}")

        # Update best model if validation accuracy improved
       
        best_val_acc = val_acc
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'scheduler_state_dict': sched.state_dict(),
                'train_loss': train_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }, "best_qni_ccp_adv_model_resumed.pth")
        print(f"🏆 New best model saved with val_acc: {val_acc:.4f}")

    return best_val_acc

# ========== RESUME TRAINING EXECUTION ==========
print("Resuming training...")

# Compute quantum epsilon and centroids
epsilon_q = compute_quantum_epsilon(model, n_cnots=30, depth=6, alpha=1.0, beta=1.0)
centroids = compute_class_centroids(model, train_loader, device, num_classes)

print(f"Quantum epsilon: {epsilon_q:.4f}")
print("Centroids computed")

# Resume training for 15 more epochs
best_val_resumed = resume_training_with_feature_perturbation_and_adv(
    model, train_loader, val_loader,
    centroids, epsilon_q,
    epsilon_fgsm=0.03,
    eps_pgd=0.1, alpha_pgd=0.01, iters_pgd=7,
    num_epochs=15,
    start_epoch=11  # You can adjust this if you know the exact epoch number
)

print(f"✅ Training resumed and completed!")
print(f"Final best validation accuracy: {best_val_resumed:.4f}")

# Optional: Clean up individual epoch files if you only want to keep the best model
import glob
epoch_files = glob.glob("qni_ccp_adv_model_epoch_*.pth")
print(f"Created {len(epoch_files)} epoch checkpoint files")
print("Files created:", epoch_files)

✅ Model loaded from qni_ccp_adv_model_epoch_10.pth
Resuming training...
Quantum epsilon: 0.0270
Centroids computed
Initial validation accuracy: 0.9339


Epoch 11:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 11 | Train L: 0.2296 | Train A: 0.9449 | Val A: 0.7692
💾 Saved model for epoch 11
🏆 New best model saved with val_acc: 0.7692


Epoch 12:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 12 | Train L: 0.2243 | Train A: 0.9434 | Val A: 0.9415
💾 Saved model for epoch 12
🏆 New best model saved with val_acc: 0.9415


Epoch 13:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 13 | Train L: 0.2097 | Train A: 0.9460 | Val A: 0.9469
💾 Saved model for epoch 13
🏆 New best model saved with val_acc: 0.9469


Epoch 14:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 14 | Train L: 0.2021 | Train A: 0.9513 | Val A: 0.9437
💾 Saved model for epoch 14
🏆 New best model saved with val_acc: 0.9437
Recomputed centroids at epoch 15


Epoch 15:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 15 | Train L: 0.2090 | Train A: 0.9512 | Val A: 0.7844
💾 Saved model for epoch 15
🏆 New best model saved with val_acc: 0.7844


Epoch 16:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 16 | Train L: 0.1935 | Train A: 0.9529 | Val A: 0.9393
💾 Saved model for epoch 16
🏆 New best model saved with val_acc: 0.9393


Epoch 17:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 17 | Train L: 0.1779 | Train A: 0.9548 | Val A: 0.9512
💾 Saved model for epoch 17
🏆 New best model saved with val_acc: 0.9512


Epoch 18:   0%|          | 0/467 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 21 | Train L: 0.1655 | Train A: 0.9598 | Val A: 0.9599
💾 Saved model for epoch 21
🏆 New best model saved with val_acc: 0.9599


Epoch 22:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 22 | Train L: 0.1626 | Train A: 0.9610 | Val A: 0.9393
💾 Saved model for epoch 22
🏆 New best model saved with val_acc: 0.9393


Epoch 23:   0%|          | 0/467 [00:00<?, ?it/s]

Epoch 23 | Train L: 0.1584 | Train A: 0.9603 | Val A: 0.9577
💾 Saved model for epoch 23
🏆 New best model saved with val_acc: 0.9577


Epoch 24:   0%|          | 0/467 [00:00<?, ?it/s]

KeyboardInterrupt: 