In [None]:
'''
Changed value of epsilon, rather taking it as a constant value, defined a formula based on circuit

'''

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from collections import Counter
import numpy as np
import random
import os
from torchvision.datasets import ImageFolder
from matplotlib import pyplot as plt
import pennylane as qml
from pennylane.qnn import TorchLayer
from tqdm.notebook import tqdm

#for loss function 
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * ((1 - pt) ** self.gamma) * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Set seeds for reproducibility
def seed_all(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_all(42)

# ========== DEVICE ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ========== PARAMETERS ==========
n_qubits = 6
batch_size = 16
num_classes = 25
num_epochs = 50
lr = 0.0005

# ========== TRANSFORMS WITH DATA AUGMENTATION ==========
# ✅ For training (with augmentation)
train_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# ✅ For validation and test (no augmentation)
eval_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


# ========== DATASETS ==========
train_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/train', transform=train_transform)
val_dataset   = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/val', transform=eval_transform)
test_dataset  = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/test', transform=eval_transform)
print("**dataset loaded**")
# ========== CLASS WEIGHTS ==========
from sklearn.utils.class_weight import compute_class_weight

labels = [label for _, label in train_dataset.samples]
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(labels),
                                     y=labels)
class_wts = torch.tensor(class_weights, dtype=torch.float)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# ========== QUANTUM CIRCUIT ==========
dev = qml.device("default.qubit", wires=n_qubits)

@qml.qnode(dev, interface="torch")

def quantum_circuit(inputs, weights):
    for i in range(n_qubits):
        qml.RY(inputs[i], wires=i)
    
    for l in range(weights.shape[0]):
        for i in range(n_qubits):
            qml.RY(weights[l][i], wires=i)
        for i in range(n_qubits - 1):
            qml.CNOT(wires=[i, i+1])
    
    return [qml.expval(qml.PauliZ(i)) for i in range(n_qubits)]

weight_shapes = {"weights": (6, n_qubits)}


# ========== CNN + QNN MODEL ==========
class FeatureReduce(nn.Module):
    def __init__(self, final_dim, dropout=0.4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),    # 128 -> 64
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(8, 16, 3, stride=2, padding=1),   # 64 -> 32
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # 32 -> 16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # 16 -> 8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(dropout),

            nn.Conv2d(64, 128, 3, stride=2, padding=1),  # ⬅️ Extra block: 8 -> 4
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))                # 4×4 -> 1×1
        )
        self.fc = nn.Linear(128, final_dim)  # ⬅️ Changed from 64 to 128

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


class HybridQNN(nn.Module):
    def __init__(self, n_qubits, num_classes):
        super().__init__()
        self.feature_extractor = FeatureReduce(final_dim=n_qubits)
        self.q_layer = TorchLayer(quantum_circuit, weight_shapes)

        # Adding 4-layer MLP after quantum layer
        self.classifier = nn.Sequential(
            nn.Linear(n_qubits, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.tanh(x)
        q_out = torch.stack([self.q_layer(f) for f in x])
        return self.classifier(q_out)

# ========== TRAINING ==========
print("Starting training")

# ── 2) Precompute class‐centroids in feature space ──────────────────────────
def compute_centroids(model, loader, device, num_classes):
    model.eval()
    sums = torch.zeros(num_classes, n_qubits, device=device)
    counts = torch.zeros(num_classes, device=device)
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            feats = model.feature_extractor(x)      # pre‐tanh features
            for c in range(num_classes):
                mask = (y==c)
                if mask.any():
                    sums[c] += feats[mask].sum(0)
                    counts[c] += mask.sum()
    return sums / counts.unsqueeze(1)

    
def compute_epsilon_q(num_cnots, depth, alpha=0.1, beta=0.1):
    return 1.0 / (1.0 + alpha * num_cnots + beta * depth)


# ── 3) QNI perturbation function ────────────────────────────────────────────
def class_conditional_noise(feats, labels, centroids, epsilon=0.1):
    """
    feats: [B, n_qubits] pre‐tanh
    labels: [B]
    centroids: [num_classes, n_qubits]
    """
    B = feats.size(0)
    noise = torch.zeros_like(feats)
    for i in range(B):
        c_true = centroids[labels[i]]
        # pick random target class ≠ true
        choices = list(range(centroids.size(0)))
        choices.remove(int(labels[i]))
        c_targ = centroids[random.choice(choices)]
        # direction and scale
        dir_vec = (c_targ - c_true)
        noise[i] = epsilon * dir_vec.sign()
    return feats + noise

# … everything above stays the same up to compute_centroids …

# ── 4) Training loop with QNI ──────────────────────────────────────────────
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridQNN(n_qubits, num_classes).to(device)
opt   = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=5e-3)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=5)

# Initialize best validation accuracy
best_val_acc = 0.0
best_model_path = "best_QNI_model.pth"


# helper to evaluate on a loader
def evaluate(model, loader):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total   += y.size(0)
    return correct/total

# initial centroids before training
centroids = compute_centroids(model, train_loader, device, num_classes)

for epoch in range(1, 51):
    # every 5 epochs, recompute centroids on the *current* model:
    if epoch % 5 == 0:
        centroids = compute_centroids(model, train_loader, device, num_classes)
    
    model.train()
    running_loss, running_correct, running_total = 0, 0, 0

    for x, y in tqdm(train_loader, desc=f"Epoch {epoch} [train]"):
        x, y = x.to(device), y.to(device)

        # 1) clean path
        feats       = model.feature_extractor(x)            # [B, n_qubits]
        clean_logits = model(x)                             # [B, num_classes]
        loss_clean   = F.cross_entropy(clean_logits, y)

        # 2) perturbed path
        epsilon_q = compute_epsilon_q(num_cnots=5, depth=6)  # adjusted as per circuit
        feats_pert = class_conditional_noise(feats, y, centroids, epsilon=epsilon_q)

        feats_pert_t = torch.tanh(feats_pert)
        q_out_pert   = torch.stack([model.q_layer(feats_pert_t[i]) 
                                    for i in range(feats_pert_t.size(0))])
        pert_logits  = model.classifier(q_out_pert)
        loss_pert    = F.cross_entropy(pert_logits, y)

        # 3) joint loss & backward
        loss = 0.8 * loss_clean + 0.2 * loss_pert
        opt.zero_grad()
        loss.backward()
        opt.step()

        # track
        running_loss   += loss.item() * x.size(0)
        running_correct += (clean_logits.argmax(1) == y).sum().item()
        running_total   += y.size(0)

    # step scheduler on *average* training loss
    avg_train_loss = running_loss / running_total
    sched.step(avg_train_loss)

    train_acc = running_correct / running_total
    val_acc   = evaluate(model, val_loader)
    print(f"\nEpoch {epoch:2d} — train loss: {avg_train_loss:.4f}, "
      f"train acc: {train_acc:.4f}, val acc: {val_acc:.4f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'val_accuracy': val_acc
        }, best_model_path)
        print(f"✅ Best model saved at epoch {epoch} with val_acc: {val_acc:.4f}\n")
    else:
        print()


**dataset loaded**
Starting training


Epoch 1 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  1 — train loss: 2.3053, train acc: 0.2945, val acc: 0.4009
✅ Best model saved at epoch 1 with val_acc: 0.4009



Epoch 2 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  2 — train loss: 1.7316, train acc: 0.4056, val acc: 0.4420
✅ Best model saved at epoch 2 with val_acc: 0.4420



Epoch 3 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  3 — train loss: 1.5457, train acc: 0.4444, val acc: 0.4680
✅ Best model saved at epoch 3 with val_acc: 0.4680



Epoch 4 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  4 — train loss: 1.3824, train acc: 0.4951, val acc: 0.4377



Epoch 5 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  5 — train loss: 1.0802, train acc: 0.6333, val acc: 0.5298
✅ Best model saved at epoch 5 with val_acc: 0.5298



Epoch 6 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  6 — train loss: 0.9637, train acc: 0.6651, val acc: 0.5255



Epoch 7 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  7 — train loss: 0.9144, train acc: 0.6888, val acc: 0.5385
✅ Best model saved at epoch 7 with val_acc: 0.5385



Epoch 8 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  8 — train loss: 0.8614, train acc: 0.7142, val acc: 0.7703
✅ Best model saved at epoch 8 with val_acc: 0.7703



Epoch 9 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch  9 — train loss: 0.8254, train acc: 0.7313, val acc: 0.3196



Epoch 10 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 10 — train loss: 0.7400, train acc: 0.7529, val acc: 0.7844
✅ Best model saved at epoch 10 with val_acc: 0.7844



Epoch 11 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 11 — train loss: 0.7224, train acc: 0.7540, val acc: 0.7941
✅ Best model saved at epoch 11 with val_acc: 0.7941



Epoch 12 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 12 — train loss: 0.6812, train acc: 0.7682, val acc: 0.8212
✅ Best model saved at epoch 12 with val_acc: 0.8212



Epoch 13 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 13 — train loss: 0.6408, train acc: 0.7839, val acc: 0.7963



Epoch 14 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 14 — train loss: 0.6172, train acc: 0.7953, val acc: 0.8310
✅ Best model saved at epoch 14 with val_acc: 0.8310



Epoch 15 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 15 — train loss: 0.6025, train acc: 0.7973, val acc: 0.7898



Epoch 16 [train]:   0%|          | 0/467 [00:00<?, ?it/s]


Epoch 16 — train loss: 0.5848, train acc: 0.8048, val acc: 0.8299



Epoch 17 [train]:   0%|          | 0/467 [00:00<?, ?it/s]

KeyboardInterrupt: 