In [None]:
# %%
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from collections import Counter, defaultdict
import numpy as np 
import random 
import os 
from torchvision.datasets import ImageFolder
import pennylane as qml 
import matplotlib.pyplot as plt 
from tqdm.notebook import tqdm 
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

# %%
class Focal_loss(nn.Module):
    """Multi-class focal loss computed on top of per-sample cross-entropy.

    Every sample's cross-entropy ``CE`` is re-weighted by
    ``alpha * (1 - p_t) ** gamma`` where ``p_t = exp(-CE)``, so samples the
    model already classifies confidently contribute less to the loss.

    Args:
        alpha: Multiplicative factor applied to every per-sample loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating term.
        reduction: ``"mean"`` or ``"sum"`` reduce over the batch; any other
            value returns the unreduced per-sample losses.
    """

    def __init__(self, alpha, gamma, reduction="mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction="none")
        # exp(-CE) recovers the probability assigned to the target class.
        target_prob = torch.exp(-per_sample_ce)
        modulation = (1 - target_prob) ** self.gamma
        weighted = self.alpha * modulation * per_sample_ce

        if self.reduction == "mean":
            return weighted.mean()
        if self.reduction == "sum":
            return weighted.sum()
        return weighted

# %%
class QNI_CCP:
    """
    Quantum-aware Noise Injection with Class-Conditional Perturbation.

    The helper keeps an exponential moving average of the tanh-squashed
    feature vector of every class.  A perturbation moves a sample's
    extracted features a small step towards the running mean of a randomly
    chosen *other* class; the step is scaled element-wise by the magnitude
    of the loss gradient with respect to those features, by an (optionally
    adaptive) epsilon and by a circuit-derived complexity factor.

    The ``model`` passed to the methods must expose ``featureextractor``,
    ``q_layer`` (applied to one sample at a time) and ``classifier``.
    """
    def __init__(self, n_qubits, n_layers, num_classes, epsilon_base=0.05,
                 update_freq=100, alpha=0.9, adaptive_epsilon=True):
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.num_classes = num_classes
        self.epsilon_base = epsilon_base
        self.update_freq = update_freq
        self.alpha = alpha  # EMA decay used for the running class means
        self.adaptive_epsilon = adaptive_epsilon

        # Class-conditional statistics
        self.class_means = {}
        self.class_counts = defaultdict(int)
        self.update_counter = 0

        # Quantum complexity metrics
        self.circuit_depth = n_layers
        # NOTE(review): (n_qubits - 1) gates per layer corresponds to a linear
        # CNOT chain; confirm this matches the circuit actually used.
        self.entangling_gates = (n_qubits - 1) * n_layers
        self.quantum_complexity = self._compute_quantum_complexity()

        # Adaptive perturbation tracking
        self.current_epsilon = epsilon_base
        self.performance_history = []

    def _compute_quantum_complexity(self):
        """Return a scaling factor in [0.1, 1.0] that shrinks as the circuit grows."""
        depth_factor = min(self.circuit_depth / 10.0, 1.0)
        entanglement_factor = min(self.entangling_gates / (self.n_qubits * 5), 1.0)

        complexity = 0.5 * depth_factor + 0.5 * entanglement_factor
        epsilon_q = 1.0 - complexity
        return max(epsilon_q, 0.1)

    def update_epsilon(self, train_acc, val_acc):
        """Adapt ``current_epsilon`` to the train/validation accuracy gap.

        The first call only records the accuracies.  Afterwards a gap above
        0.15 grows epsilon by 20% (capped at 0.3) and a gap below 0.05
        shrinks it by 10% (floored at 0.01).  Only the ten most recent
        (train_acc, val_acc) pairs are kept.
        """
        if len(self.performance_history) > 0:
            overfitting_gap = train_acc - val_acc

            if overfitting_gap > 0.15:  # Significant overfitting
                self.current_epsilon = min(self.current_epsilon * 1.2, 0.3)
            elif overfitting_gap < 0.05:  # Underfitting or good fit
                self.current_epsilon = max(self.current_epsilon * 0.9, 0.01)

        self.performance_history.append((train_acc, val_acc))
        if len(self.performance_history) > 10:
            self.performance_history.pop(0)

    def update_class_statistics(self, features, labels):
        """Fold a batch of feature vectors into the per-class running means.

        Samples are processed one after another, so several samples of the
        same class in one batch each apply their own EMA update.
        """
        features = features.detach().cpu()
        labels = labels.detach().cpu()

        for label_int, feature_vec in zip(labels.tolist(), features):
            if label_int not in self.class_means:
                # The first sample of a class initialises its mean.
                self.class_means[label_int] = feature_vec.clone()
            else:
                self.class_means[label_int] = (self.alpha * self.class_means[label_int] +
                                               (1 - self.alpha) * feature_vec)

            self.class_counts[label_int] += 1

        self.update_counter += 1

    @staticmethod
    def _run_quantum_layer(model, features):
        """Apply ``model.q_layer`` to each sample and stack the results."""
        return torch.stack([model.q_layer(sample) for sample in features])

    def _sensitivity_from_features(self, model, extracted_features, labels, loss_fn):
        """Return |d loss / d features| + 1e-8 for already extracted features."""
        # Gradients are re-enabled explicitly so this also works when the
        # caller is inside ``torch.no_grad()`` (e.g. during evaluation).
        with torch.enable_grad():
            leaf = extracted_features.detach().requires_grad_(True)
            q_out_batch = self._run_quantum_layer(model, torch.tanh(leaf))
            loss = loss_fn(model.classifier(q_out_batch), labels)
            # Only first-order gradients w.r.t. the features are needed and
            # the graph is not reused, so neither create_graph nor
            # retain_graph is requested.
            (gradients,) = torch.autograd.grad(loss, leaf)

        # Add small epsilon for numerical stability
        return gradients.abs().detach() + 1e-8

    def compute_gradient_sensitivity(self, model, features, labels, loss_fn):
        """Return the element-wise gradient magnitude of the loss w.r.t. features.

        ``features`` are the raw model inputs; they are passed through
        ``model.featureextractor`` without tracking gradients first.  The
        result is detached and has the shape of the extracted features.
        """
        with torch.no_grad():
            extracted_features = model.featureextractor(features)

        return self._sensitivity_from_features(model, extracted_features, labels, loss_fn)

    def generate_perturbation(self, features, labels, model, loss_fn, training=True):
        """Return tanh-squashed, class-conditionally perturbed features.

        If fewer than two class means are known, the unperturbed squashed
        features are returned.  Samples with no eligible target class (a
        different class seen more than five times) are left unperturbed.
        ``training`` is accepted for interface compatibility and not used.
        """
        device = features.device

        with torch.no_grad():
            extracted_features = model.featureextractor(features)

        if len(self.class_means) < 2:
            return torch.tanh(extracted_features)

        # Sensitivity is evaluated on exactly the features perturbed below,
        # which avoids a second forward pass through the feature extractor.
        sensitivity = self._sensitivity_from_features(model, extracted_features,
                                                      labels, loss_fn)

        # Use adaptive epsilon
        current_eps = self.current_epsilon if self.adaptive_epsilon else self.epsilon_base
        # Scale perturbation more conservatively
        perturbation_scale = current_eps * self.quantum_complexity * 0.1

        perturbed_features = []
        for current_label, current_feature, current_sensitivity in zip(
                labels.tolist(), extracted_features, sensitivity):
            available_classes = [c for c in self.class_means
                                 if c != current_label and self.class_counts[c] > 5]
            if not available_classes:
                perturbed_features.append(torch.tanh(current_feature))
                continue

            target_class = random.choice(available_classes)
            target_mean = self.class_means[target_class].to(device)

            direction = target_mean - current_feature
            direction = direction / (torch.norm(direction) + 1e-8)  # Normalize

            scaled_perturbation = perturbation_scale * current_sensitivity * direction
            perturbed_features.append(torch.tanh(current_feature + scaled_perturbation))

        return torch.stack(perturbed_features)

    def apply_qni_ccp(self, model, features, labels, loss_fn, training=True,
                     mix_ratio=0.3, epoch=0):
        """Run the model on a batch, optionally with QNI-CCP perturbation.

        Training (``training=True``): the class statistics are updated and,
        with a probability that grows with ``epoch`` (capped at 0.6), the
        returned loss is a weighted mix of the clean and perturbed losses.
        The returned outputs are always the clean outputs.

        Evaluation (``training=False``): statistics are left untouched.  With
        probability ``mix_ratio`` (and at least two known class means) the
        batch is perturbed and the *perturbed* loss and outputs are returned,
        so the caller can measure robustness; otherwise the clean loss and
        outputs are returned.

        Returns:
            Tuple ``(loss, outputs)``.
        """
        # Get clean features and outputs
        clean_features = torch.tanh(model.featureextractor(features))
        clean_outputs = model.classifier(self._run_quantum_layer(model, clean_features))
        clean_loss = loss_fn(clean_outputs, labels)

        if training:
            self.update_class_statistics(clean_features, labels)
            # Progressive mixing: start with lower ratio, increase over time
            progressive_ratio = min(mix_ratio * (1 + epoch * 0.02), 0.6)
        else:
            progressive_ratio = mix_ratio

        # Perturb only once class statistics exist, with the chosen probability.
        if len(self.class_means) < 2 or random.random() >= progressive_ratio:
            return clean_loss, clean_outputs

        perturbed_features = self.generate_perturbation(features, labels,
                                                        model, loss_fn, training)
        perturbed_outputs = model.classifier(
            self._run_quantum_layer(model, perturbed_features))
        perturbed_loss = loss_fn(perturbed_outputs, labels)

        if not training:
            # Robustness evaluation must see the perturbed predictions;
            # returning the clean ones would make it identical to clean accuracy.
            return perturbed_loss, perturbed_outputs

        # Adaptive loss mixing based on performance
        clean_weight = 0.9 if epoch < 10 else 0.7
        perturbed_weight = 1.0 - clean_weight
        total_loss = clean_weight * clean_loss + perturbed_weight * perturbed_loss

        return total_loss, clean_outputs

# %%
# Enhanced model with better regularization
class FeatureReduce(nn.Module):
    """Convolutional encoder mapping a 1-channel image to ``final_dim`` values.

    Five stride-2 convolution stages (8, 16, 32, 64, 128 channels), each
    followed by batch normalisation and ReLU.  The first four stages end
    with spatial dropout whose rate is ``dropout`` times 0.5, 0.6, 0.7 and
    0.8 respectively; the last stage ends with global average pooling.  A
    two-layer MLP then projects the 128 pooled channels to ``final_dim``.
    """

    def __init__(self, final_dim, dropout=0.5):
        super().__init__()

        widths = (1, 8, 16, 32, 64)
        dropout_scales = (0.5, 0.6, 0.7, 0.8)

        stages = []
        for in_ch, out_ch, scale in zip(widths, widths[1:], dropout_scales):
            stages.extend([
                nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.Dropout2d(dropout * scale),  # Spatial dropout
            ])
        # Final stage: no dropout, collapse the spatial dimensions instead.
        stages.extend([
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ])
        self.conv = nn.Sequential(*stages)

        hidden_dim = final_dim * 2
        self.fc = nn.Sequential(
            nn.Linear(128, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, final_dim),
        )

    def forward(self, x):
        """Encode a batch of images into ``(batch, final_dim)`` features."""
        pooled = self.conv(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class HybridQnn(nn.Module):
    """Hybrid classifier: CNN feature extractor -> quantum layer -> MLP head.

    The feature extractor reduces an image to ``n_qubits`` values, which are
    squashed with ``tanh`` and fed sample by sample to a PennyLane
    ``TorchLayer``; the layer's outputs are classified by a small MLP.

    NOTE: construction reads the module-level names ``circuit`` and
    ``weight_shapes``, so both must be defined before instantiating.
    """

    def __init__(self, n_qubits, num_classes, dropout=0.4):
        super().__init__()
        self.featureextractor = FeatureReduce(final_dim=n_qubits, dropout=dropout)
        self.q_layer = qml.qnn.TorchLayer(circuit, weight_shapes=weight_shapes)

        # Simplified classifier to reduce overfitting
        self.classifier = nn.Sequential(
            nn.Linear(n_qubits, 32),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Dropout(dropout * 0.8),
            nn.Linear(16, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = torch.tanh(self.featureextractor(x))
        # The quantum layer is evaluated one sample at a time and the
        # per-sample results are stacked back into a batch.
        q_out_batch = torch.stack([self.q_layer(sample) for sample in x])
        return self.classifier(q_out_batch)

# %%
# Enhanced training transforms with more augmentation
# Training-time preprocessing: convert to one grayscale channel, apply random
# flip / rotation / translation / brightness-contrast jitter, resize to
# 128x128, convert to a tensor, normalise with mean 0.5 and std 0.5, and
# finally add Gaussian noise.
train_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    # Add zero-mean Gaussian noise (scaled by 0.02) to the normalised tensor.
    # NOTE(review): a lambda cannot be pickled; if DataLoader workers are
    # started with the "spawn" method this transform may fail - confirm the
    # target platform uses "fork".
    transforms.Lambda(lambda x: x + torch.randn_like(x) * 0.02)
])

# Evaluation-time preprocessing: same grayscale / resize / normalise steps as
# training, without any random augmentation.
test_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# %%
# Enhanced training function with better monitoring
def train_with_qni_ccp(model, data_loader, loss_fn, optimizer, device, qni_ccp, epoch=0):
    """Train ``model`` for one epoch using QNI-CCP and return its metrics.

    Args:
        model: Network to optimise; put into training mode here.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        loss_fn: Loss applied to logits and labels.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        qni_ccp: Helper whose ``apply_qni_ccp`` produces the loss and the
            clean outputs for a batch.
        epoch: Zero-based epoch index; drives the progress-bar label, the
            QNI-CCP schedule and the gradient-clipping norm.

    Returns:
        ``(mean_batch_loss, accuracy)`` where accuracy is measured on the
        samples actually processed during the epoch.  Both are 0.0 when the
        loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    # Gradient clipping with adaptive norm (tighter after the first 10 epochs)
    max_norm = 1.0 if epoch < 10 else 0.5

    for inputs, labels in tqdm(data_loader, desc=f"Training Epoch {epoch+1}"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Apply QNI-CCP with epoch information
        loss, outputs = qni_ccp.apply_qni_ccp(model, inputs, labels, loss_fn,
                                             training=True, mix_ratio=0.3, epoch=epoch)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        # Count what was really seen: with a sampler or drop_last the number
        # of processed samples can differ from len(data_loader.dataset).
        seen += labels.size(0)

    return total_loss / max(len(data_loader), 1), correct / max(seen, 1)

# %%
# Enhanced evaluation function
def evaluate_with_qni_ccp(model, dataloader, loss_fn, device, qni_ccp, test_robustness=False):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        loss_fn: Loss applied to logits and labels.
        device: Device the batches are moved to.
        qni_ccp: Helper used for the optional robustness pass.
        test_robustness: When true, each batch is additionally scored with
            the outputs returned by ``qni_ccp.apply_qni_ccp`` called with
            ``training=False`` and ``mix_ratio=1.0``.

    Returns:
        ``(mean_batch_loss, accuracy, robust_accuracy)``; the last item is
        ``None`` when ``test_robustness`` is false.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_correct = 0
    n_robust_correct = 0

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_sum += loss_fn(outputs, labels).item()

            predicted = outputs.data.max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

            if not test_robustness:
                continue

            _, perturbed_outputs = qni_ccp.apply_qni_ccp(
                model, inputs, labels, loss_fn, training=False, mix_ratio=1.0
            )
            robust_predicted = perturbed_outputs.data.max(dim=1).indices
            n_robust_correct += (robust_predicted == labels).sum().item()

    accuracy = n_correct / n_seen
    if test_robustness:
        robust_accuracy = n_robust_correct / n_seen
    else:
        robust_accuracy = None

    return loss_sum / len(dataloader), accuracy, robust_accuracy

# %%
# Enhanced hyperparameters for better generalization
def get_enhanced_config():
    """Return a fresh dict holding the hyperparameters of the experiment."""
    return dict(
        # Quantum circuit / task size
        n_qubits=6,
        n_layers=2,
        num_classes=25,
        # Optimisation
        batch_size=32,  # Increased batch size
        num_epochs=100,
        lr=0.001,  # Reduced learning rate
        weight_decay=5e-4,  # Increased weight decay
        dropout=0.5,
        # Early stopping and learning-rate scheduling
        early_stopping_patience=10,
        lr_scheduler_patience=5,
        lr_scheduler_factor=0.7,
        # QNI-CCP
        qni_ccp_epsilon=0.05,  # Reduced epsilon
        warmup_epochs=8,
    )


# Module-level configuration shared by the rest of the script.
config = get_enhanced_config()

# %%
# Enhanced learning rate scheduler
class CosineAnnealingWarmupRestarts(torch.optim.lr_scheduler._LRScheduler):
    """Cosine annealing with linear warm-up and warm restarts.

    Within each cycle the learning rate rises linearly from the optimizer's
    base learning rate to the cycle's peak over ``T_up`` steps and then
    follows a half cosine back down to the base learning rate.  When a cycle
    ends a new one starts.

    Args:
        optimizer: Wrapped optimizer.
        T_0: Length (in steps) of the first cycle; must exceed ``T_up``.
        T_mult: Factor (>= 1) by which the annealing part of each subsequent
            cycle is lengthened.
        eta_max: Peak learning rate of the first cycle.
        T_up: Number of linear warm-up steps at the start of every cycle.
        gamma: Factor applied to the peak learning rate after every restart.

    Raises:
        ValueError: If ``T_up`` is negative, ``T_0 <= T_up`` or ``T_mult < 1``.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1.0):
        if T_up < 0 or T_0 <= T_up:
            # T_0 == T_up would divide by zero in the cosine phase.
            raise ValueError(f"Expected 0 <= T_up < T_0, got T_up={T_up}, T_0={T_0}")
        if T_mult < 1:
            # A shrinking cycle length could never catch up with T_cur.
            raise ValueError(f"Expected T_mult >= 1, got {T_mult}")
        self.T_0 = T_0
        self.T_mult = T_mult
        self.eta_max = eta_max
        self.T_up = T_up
        self.gamma = gamma
        self.T_cur = 0
        super().__init__(optimizer)

    def _locate_in_cycle(self):
        """Return ``(position, cycle_length, cycle_index)`` for ``T_cur``."""
        position, length, cycle = self.T_cur, self.T_0, 0
        while position >= length:
            position -= length
            cycle += 1
            length = (length - self.T_up) * self.T_mult + self.T_up
        return position, length, cycle

    def get_lr(self):
        """Compute the learning rate of every parameter group at ``T_cur``."""
        position, length, cycle = self._locate_in_cycle()
        peak = self.eta_max * (self.gamma ** cycle)

        if position < self.T_up:
            # Linear warm-up from base_lr towards the cycle's peak.
            return [(peak - base_lr) * position / self.T_up + base_lr
                    for base_lr in self.base_lrs]

        cosine = (1 + np.cos(np.pi * (position - self.T_up) / (length - self.T_up))) / 2
        return [float(base_lr + (peak - base_lr) * cosine)
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step, or jump to ``epoch`` if given."""
        if epoch is None:
            self.T_cur = self.last_epoch + 1
            # Let the base class advance last_epoch itself; this avoids
            # passing its deprecated ``epoch`` argument on the normal path.
            super().step()
        else:
            self.T_cur = epoch
            super().step(epoch)

# %%
# Enhanced training setup
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


def seed_all(seed=42):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available the generators of all CUDA devices are seeded too.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# Fix every seed once at start-up for reproducibility.
seed_all(42)

# %%
# Initialize model and training components
# Build the hybrid model from the shared configuration and move it to `device`.
# NOTE(review): HybridQnn.__init__ reads the module-level names `circuit` and
# `weight_shapes`; in this file they are defined in a later cell, so that cell
# has to be executed before this one.
model = HybridQnn(n_qubits=config['n_qubits'], 
                  num_classes=config['num_classes'], 
                  dropout=config['dropout'])
model.to(device)

# Enhanced optimizer with different learning rates for different parts:
# three AdamW parameter groups sharing one weight decay, with the quantum
# layer trained at half the base learning rate.
optimizer = torch.optim.AdamW([
    {'params': model.featureextractor.parameters(), 'lr': config['lr']},
    {'params': model.q_layer.parameters(), 'lr': config['lr'] * 0.5},  # Lower LR for quantum part
    {'params': model.classifier.parameters(), 'lr': config['lr']}
], weight_decay=config['weight_decay'])

# Enhanced scheduler: multiply the learning rates by `lr_scheduler_factor`
# when the monitored value (mode='min') stops improving for
# `lr_scheduler_patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=config['lr_scheduler_factor'], 
    patience=config['lr_scheduler_patience']
)

# Initialize enhanced QNI-CCP with the same circuit dimensions as the model
# and adaptive epsilon enabled.
qni_ccp = QNI_CCP(n_qubits=config['n_qubits'], 
                  n_layers=config['n_layers'], 
                  num_classes=config['num_classes'],
                  epsilon_base=config['qni_ccp_epsilon'], 
                  update_freq=100, 
                  alpha=0.9,
                  adaptive_epsilon=True)

# %%
# Enhanced training loop with better monitoring
def enhanced_training_loop(model, train_loader, val_loader, test_loader, 
                          loss_fn, optimizer, scheduler, qni_ccp, config, device):
    """Train with QNI-CCP, validate every epoch, checkpoint and stop early.

    Per epoch: one training pass, one validation pass including the
    robustness score, an update of the adaptive QNI-CCP epsilon, a scheduler
    step on the validation loss and a printed report.  Whenever the
    validation loss improves, a checkpoint is written to
    ``best_model_enhanced_qni_ccp.pth``.  Training stops once the validation
    loss has not improved for ``config['early_stopping_patience']`` epochs.
    ``test_loader`` is accepted but not used here.

    Returns:
        Tuple of per-epoch lists
        ``(train_losses, val_losses, train_accs, val_accs, robust_accs)``.
    """
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    robust_accs = []

    lowest_val_loss = float('inf')
    stale_epochs = 0
    num_epochs = config['num_epochs']

    print("Starting enhanced training with adaptive QNI-CCP...")
    print(f"Configuration: {config}")

    for epoch in range(num_epochs):
        # Training
        train_loss, train_acc = train_with_qni_ccp(
            model, train_loader, loss_fn, optimizer, device, qni_ccp, epoch
        )

        # Validation
        val_loss, val_acc, robust_acc = evaluate_with_qni_ccp(
            model, val_loader, loss_fn, device, qni_ccp, test_robustness=True
        )

        # Update adaptive epsilon in QNI-CCP
        qni_ccp.update_epsilon(train_acc, val_acc)

        # Record metrics
        for series, value in (
            (train_losses, train_loss),
            (val_losses, val_loss),
            (train_accs, train_acc),
            (val_accs, val_acc),
            (robust_accs, robust_acc),
        ):
            series.append(value)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Enhanced logging
        report = (
            f"Epoch {epoch+1}/{config['num_epochs']}",
            f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}",
            f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}",
            f"Robust Acc: {robust_acc:.4f}",
            f"Overfitting Gap: {(train_acc - val_acc):.4f}",
            f"Current QNI-CCP ε: {qni_ccp.current_epsilon:.4f}",
            f"LR: {optimizer.param_groups[0]['lr']:.6f}",
            "-" * 60,
        )
        for line in report:
            print(line)

        # Enhanced early stopping
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            stale_epochs = 0
            qni_ccp_state = {
                'class_means': qni_ccp.class_means,
                'class_counts': qni_ccp.class_counts,
                'quantum_complexity': qni_ccp.quantum_complexity,
                'current_epsilon': qni_ccp.current_epsilon
            }
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'qni_ccp_state': qni_ccp_state,
                'epoch': epoch,
                'best_val_loss': lowest_val_loss,
                'config': config
            }
            torch.save(checkpoint, "best_model_enhanced_qni_ccp.pth")
            print("💾 Best model saved.")
        else:
            stale_epochs += 1

        if stale_epochs >= config['early_stopping_patience']:
            print(f"⏹️ Early stopping triggered after {epoch+1} epochs.")
            break

    return train_losses, val_losses, train_accs, val_accs, robust_accs

# %%
# Announce that the definitions above are loaded and list what they provide.
for _banner_line in (
    "Enhanced QNI-CCP training implementation ready!",
    "Key improvements:",
    "1. Adaptive perturbation strength based on overfitting detection",
    "2. Enhanced regularization with spatial dropout",
    "3. Progressive perturbation mixing",
    "4. Better gradient clipping and learning rate scheduling",
    "5. Improved numerical stability",
    "6. Enhanced data augmentation",
    "7. Different learning rates for different model components",
    "8. Better early stopping and model checkpointing",
):
    print(_banner_line)

# %%
# Define quantum circuit and weight shapes (needed for the model)
# Simulator device with one wire per qubit; shots=None requests analytic
# (non-sampled) expectation values.
dev = qml.device("lightning.qubit", wires=config['n_qubits'], shots=None)

@qml.qnode(dev, interface="torch")
def circuit(inputs, weights):
    """Variational circuit used as the quantum layer of the hybrid model.

    ``inputs`` is indexed per qubit (``inputs[i]``) and ``weights`` is
    indexed as ``weights[layer][qubit][0 or 1]``.  Returns one Pauli-Z
    expectation value per qubit.
    """
    # Initial encoding: one RY rotation per qubit, angle taken from the input
    for i in range(config['n_qubits']):
        qml.RY(inputs[i], wires=i)
    
    # Variational layers with improved entanglement
    for l in range(weights.shape[0]):
        # Parameterized rotations: trainable RY followed by RZ on every qubit
        for i in range(config['n_qubits']):
            qml.RY(weights[l][i][0], wires=i)
            qml.RZ(weights[l][i][1], wires=i)
        
        # Circular entanglement for better connectivity: a ring of CNOTs in
        # which the last qubit controls the first (n_qubits gates per layer)
        for i in range(config['n_qubits']):
            qml.CNOT(wires=[i, (i + 1) % config['n_qubits']])
        
        # Data re-encoding on odd layer indices, with the input scaled by 0.1
        if l % 2 == 1:
            for i in range(config['n_qubits']):
                qml.RY(inputs[i] * 0.1, wires=i)
    
    return [qml.expval(qml.PauliZ(j)) for j in range(config['n_qubits'])]

# Shape of the trainable tensor: (layers, qubits, 2 rotation angles per qubit)
weight_shapes = {"weights": (config['n_layers'], config['n_qubits'], 2)}

# %%
# COMPLETE EXECUTION BLOCK - Run this to train the model
print("="*80)
print("🚀 STARTING ENHANCED QNI-CCP TRAINING")
print("="*80)

# Load datasets (update paths as needed). Each split is a directory tree read
# by ImageFolder; only the training split uses the augmenting transform.
train_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/train', transform=train_transform)
val_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/val', transform=test_transform)
test_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/test', transform=test_transform)

print(f"Dataset loaded:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Val: {len(val_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

# Create data loaders; only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], 
                         shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], 
                       shuffle=False, num_workers=8, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], 
                        shuffle=False, num_workers=8, pin_memory=True)

# Initialize loss function with class weights
# NOTE(review): the balanced class weights are computed and moved to `device`,
# but they are not passed to the loss below (Focal_loss is built with the
# scalar alpha=1); they are only used for the class count in the final print.
# Confirm whether the weights were meant to be applied.
labels = [label for _, label in train_dataset.samples]
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
loss_fn = Focal_loss(alpha=1, gamma=2)

print(f"Loss function initialized with {len(class_weights)} classes")

# %%
# Run the enhanced training
# Run the full training loop and keep the per-epoch histories it returns for
# the plots and the summary further down.
print("Starting enhanced training loop...")
train_losses, val_losses, train_accs, val_accs, robust_accs = enhanced_training_loop(
    model, train_loader, val_loader, test_loader, 
    loss_fn, optimizer, scheduler, qni_ccp, config, device
)

# %%
# Training completion and visualization
print("\n" + "="*80)
print("📈 TRAINING COMPLETED - GENERATING VISUALIZATIONS")
print("="*80)

# Plot comprehensive training results on a 2x3 grid of axes
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Loss curves (training vs validation, one point per epoch)
axes[0, 0].plot(train_losses, label='Train Loss', linewidth=2)
axes[0, 0].plot(val_losses, label='Val Loss', linewidth=2)
axes[0, 0].set_title('Loss Curves', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Accuracy curves (training vs validation)
axes[0, 1].plot(train_accs, label='Train Acc', linewidth=2)
axes[0, 1].plot(val_accs, label='Val Acc', linewidth=2)
axes[0, 1].set_title('Accuracy Curves', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Robustness analysis: validation accuracy next to the robust accuracy
axes[0, 2].plot(val_accs, label='Standard Acc', linewidth=2)
axes[0, 2].plot(robust_accs, label='Robust Acc', linewidth=2)
axes[0, 2].set_title('Robustness Analysis', fontsize=14, fontweight='bold')
axes[0, 2].set_xlabel('Epoch')
axes[0, 2].set_ylabel('Accuracy')
axes[0, 2].legend()
axes[0, 2].grid(True, alpha=0.3)

# Overfitting gap (train accuracy minus validation accuracy) with two
# horizontal reference lines at 0.1 and 0.15
overfitting_gap = [train_accs[i] - val_accs[i] for i in range(len(train_accs))]
axes[1, 0].plot(overfitting_gap, label='Overfitting Gap', linewidth=2, color='red')
axes[1, 0].axhline(y=0.1, color='orange', linestyle='--', alpha=0.7, label='Warning Threshold')
axes[1, 0].axhline(y=0.15, color='red', linestyle='--', alpha=0.7, label='Critical Threshold')
axes[1, 0].set_title('Overfitting Monitoring', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Train Acc - Val Acc')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# Learning rate evolution (if available)
# NOTE: no per-epoch learning rates were recorded during training; the list
# below repeats the scheduler's most recent value once per epoch, so the
# plotted curve is a flat line rather than the real schedule.
if hasattr(scheduler, 'get_last_lr'):
    lr_history = [scheduler.get_last_lr()[0] for _ in range(len(train_losses))]
    axes[1, 1].plot(lr_history, label='Learning Rate', linewidth=2, color='green')
    axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Learning Rate')
    axes[1, 1].set_yscale('log')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
else:
    axes[1, 1].text(0.5, 0.5, 'LR History\nNot Available', 
                   horizontalalignment='center', verticalalignment='center',
                   transform=axes[1, 1].transAxes, fontsize=12)
    axes[1, 1].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')

# QNI-CCP epsilon evolution
# NOTE: epsilon was not recorded per epoch either; the current (final) value
# is repeated for every epoch, so this is also a flat line.
if hasattr(qni_ccp, 'performance_history') and qni_ccp.performance_history:
    epsilon_history = [qni_ccp.current_epsilon] * len(train_losses)  # Simplified
    axes[1, 2].plot(epsilon_history, label='QNI-CCP ε', linewidth=2, color='purple')
    axes[1, 2].set_title('Adaptive Perturbation Strength', fontsize=14, fontweight='bold')
    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].set_ylabel('Epsilon')
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)
else:
    axes[1, 2].text(0.5, 0.5, f'Final ε: {qni_ccp.current_epsilon:.4f}', 
                   horizontalalignment='center', verticalalignment='center',
                   transform=axes[1, 2].transAxes, fontsize=12)
    axes[1, 2].set_title('Adaptive Perturbation Strength', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# %%
# Load best model and final evaluation
print("\n" + "="*80)
print("🏆 FINAL EVALUATION WITH BEST MODEL")
print("="*80)

# Load best model
if os.path.exists("best_model_enhanced_qni_ccp.pth"):
    # weights_only=False performs a full unpickle of the checkpoint (needed
    # here because it also stores plain Python objects such as the QNI-CCP
    # state and the config); only load checkpoint files you trust.
    checkpoint = torch.load("best_model_enhanced_qni_ccp.pth", weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Restore QNI-CCP state as it was when the best checkpoint was written
    qni_ccp.class_means = checkpoint['qni_ccp_state']['class_means']
    qni_ccp.class_counts = checkpoint['qni_ccp_state']['class_counts']
    qni_ccp.current_epsilon = checkpoint['qni_ccp_state']['current_epsilon']
    
    print(f"✅ Best model loaded from epoch {checkpoint['epoch']+1}")
    print(f"✅ Best validation loss: {checkpoint['best_val_loss']:.4f}")
    print(f"✅ Final QNI-CCP ε: {qni_ccp.current_epsilon:.4f}")
else:
    print("⚠️  No saved model found, using current model state")

# Final test evaluation on the held-out test split, including the robustness
# score
test_loss, test_acc, test_robust_acc = evaluate_with_qni_ccp(
    model, test_loader, loss_fn, device, qni_ccp, test_robustness=True
)

print(f"\n🎯 FINAL TEST RESULTS:")
print(f"   Test Loss: {test_loss:.4f}")
print(f"   Test Accuracy: {test_acc:.4f}")
print(f"   Test Robust Accuracy: {test_robust_acc:.4f}")
print(f"   Robustness Drop: {(test_acc - test_robust_acc):.4f}")

# %%
# Comprehensive robustness analysis
print("\n" + "="*80)
print("🛡️  COMPREHENSIVE ROBUSTNESS ANALYSIS")
print("="*80)

# Remember the current epsilon so it can be restored after the sweep
original_epsilon = qni_ccp.current_epsilon
epsilons = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3]
robustness_results = []

# Re-evaluate the test split once per epsilon value.
# NOTE(review): the sweep is only informative if the evaluation path of
# qni_ccp.apply_qni_ccp (training=False) actually uses current_epsilon -
# verify against that method.
print("Testing robustness across different perturbation strengths...")
for eps in epsilons:
    qni_ccp.current_epsilon = eps
    _, _, robust_acc = evaluate_with_qni_ccp(model, test_loader, loss_fn, 
                                           device, qni_ccp, test_robustness=True)
    robustness_results.append(robust_acc)
    print(f"   ε = {eps:.2f}: Robust Accuracy = {robust_acc:.4f}")

# Restore original epsilon
qni_ccp.current_epsilon = original_epsilon

# Plot robustness analysis (2x2 grid built with the stateful pyplot API)
plt.figure(figsize=(12, 8))
plt.subplot(2, 2, 1)
plt.plot(epsilons, robustness_results, 'b-o', linewidth=3, markersize=8)
plt.axhline(y=test_acc, color='red', linestyle='--', alpha=0.7, label=f'Clean Accuracy ({test_acc:.3f})')
plt.xlabel('Perturbation Strength (ε)')
plt.ylabel('Robust Accuracy')
plt.title('QNI-CCP Robustness vs Perturbation Strength')
plt.legend()
plt.grid(True, alpha=0.3)

# Robustness drop analysis: clean test accuracy minus robust accuracy
robustness_drop = [test_acc - acc for acc in robustness_results]
plt.subplot(2, 2, 2)
plt.plot(epsilons, robustness_drop, 'r-s', linewidth=3, markersize=8)
plt.xlabel('Perturbation Strength (ε)')
plt.ylabel('Accuracy Drop')
plt.title('Robustness Degradation Analysis')
plt.grid(True, alpha=0.3)

# Training summary: per-epoch accuracies on a 1-based epoch axis
plt.subplot(2, 2, 3)
final_epochs = len(train_losses)
x_epochs = range(1, final_epochs + 1)
plt.plot(x_epochs, train_accs, label='Train', linewidth=2)
plt.plot(x_epochs, val_accs, label='Validation', linewidth=2)
plt.plot(x_epochs, robust_accs, label='Robust', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Progress Summary')
plt.legend()
plt.grid(True, alpha=0.3)

# Final metrics comparison: train/val values come from the last completed
# epoch, the test values from the final test evaluation
plt.subplot(2, 2, 4)
metrics = ['Train Acc', 'Val Acc', 'Test Acc', 'Robust Acc']
values = [train_accs[-1], val_accs[-1], test_acc, test_robust_acc]
colors = ['blue', 'green', 'orange', 'red']
bars = plt.bar(metrics, values, color=colors, alpha=0.7)
plt.ylabel('Accuracy')
plt.title('Final Performance Metrics')
plt.ylim(0, 1)

# Add value labels on bars
for bar, value in zip(bars, values):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
             f'{value:.3f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# %%
# Summary report
print("\n" + "="*80)
print("📋 TRAINING SUMMARY REPORT")
print("="*80)

# Echo the configuration the run was started with
print(f"Configuration Used:")
for key, value in config.items():
    print(f"  {key}: {value}")

# Metrics of the last completed epoch (not necessarily the best one)
print(f"\nTraining Results:")
print(f"  Total epochs completed: {len(train_losses)}")
print(f"  Final train accuracy: {train_accs[-1]:.4f}")
print(f"  Final validation accuracy: {val_accs[-1]:.4f}")
print(f"  Final overfitting gap: {(train_accs[-1] - val_accs[-1]):.4f}")

# Share of the clean test accuracy that is kept under perturbation. A clean
# accuracy of 0 would make the ratio undefined, so NaN is reported instead of
# aborting the report with a ZeroDivisionError.
robustness_preservation = (test_robust_acc / test_acc * 100) if test_acc else float('nan')

print(f"\nTest Results:")
print(f"  Test accuracy: {test_acc:.4f}")
print(f"  Test robust accuracy: {test_robust_acc:.4f}")
print(f"  Robustness preservation: {robustness_preservation:.1f}%")

print(f"\nQNI-CCP Effectiveness:")
print(f"  Classes tracked: {len(qni_ccp.class_means)}")
print(f"  Final epsilon: {qni_ccp.current_epsilon:.4f}")
print(f"  Quantum complexity factor: {qni_ccp.quantum_complexity:.4f}")

# Epoch with the lowest validation loss (printed 1-based)
print(f"\nModel Performance:")
best_val_idx = np.argmin(val_losses)
print(f"  Best validation loss: {val_losses[best_val_idx]:.4f} (epoch {best_val_idx+1})")
print(f"  Best validation accuracy: {val_accs[best_val_idx]:.4f}")

print("\n" + "="*80)
print("🎉 ENHANCED QNI-CCP TRAINING COMPLETED SUCCESSFULLY!")
print("="*80)

In [None]:
# %%
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from collections import Counter, defaultdict
import numpy as np 
import random 
import os 
from torchvision.datasets import ImageFolder
import pennylane as qml 
import matplotlib.pyplot as plt 
from tqdm.notebook import tqdm 
import torch.nn.functional as F
from sklearn.utils.class_weight import compute_class_weight

# %%
class Focal_loss(nn.Module):
    """Focal loss built on top of the per-sample cross-entropy.

    Each sample's cross-entropy ``CE`` is scaled by ``alpha * (1 - pt) ** gamma``
    where ``pt = exp(-CE)`` is the probability the model assigns to the target
    class, so confidently classified samples contribute less.

    Args:
        alpha: multiplicative factor applied to every sample's loss.
        gamma: exponent of the ``(1 - pt)`` modulating term.
        reduction: ``"mean"`` or ``"sum"`` to reduce over the batch; any other
            value returns the unreduced per-sample losses.
    """

    def __init__(self, alpha, gamma, reduction="mean"):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against class indices ``targets``."""
        per_sample_ce = F.cross_entropy(inputs, targets, reduction="none")
        target_prob = torch.exp(-per_sample_ce)
        modulation = (1 - target_prob) ** self.gamma
        per_sample_loss = self.alpha * modulation * per_sample_ce

        if self.reduction == "mean":
            return per_sample_loss.mean()
        if self.reduction == "sum":
            return per_sample_loss.sum()
        return per_sample_loss

# %%
class QNI_CCP:
    """
    Quantum-aware Noise Injection with Class-Conditional Perturbation.

    Keeps an exponential moving average (EMA) of the feature vector of every
    class seen during training and uses those means to push a sample's features
    towards a randomly chosen *other* class. The size of the push is scaled by
    ``epsilon_base``, by a circuit-dependent factor (``quantum_complexity``)
    and by the per-feature gradient magnitude of the loss.

    Class means and perturbations both live in the *pre-tanh* feature space,
    i.e. the raw output of ``model.featureextractor``; ``tanh`` is applied only
    right before the quantum layer. (Class means stored by earlier revisions
    of this class were tanh-squashed and are not comparable.)

    The ``model`` handed to the methods must expose ``featureextractor``,
    ``q_layer`` and ``classifier`` attributes; ``q_layer`` is called on one
    sample at a time.
    """
    def __init__(self, n_qubits, n_layers, num_classes, epsilon_base=0.1,
                 update_freq=100, alpha=0.9):
        """
        Args:
            n_qubits: number of qubits (feature dimension fed to the circuit).
            n_layers: number of variational layers of the circuit.
            num_classes: number of target classes (stored, not used internally).
            epsilon_base: base perturbation strength.
            update_freq: stored for callers; not used inside this class.
            alpha: EMA decay for the class means (weight of the old mean).
        """
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.num_classes = num_classes
        self.epsilon_base = epsilon_base
        self.update_freq = update_freq
        self.alpha = alpha

        # Class-conditional statistics
        self.class_means = {}                 # class index -> EMA feature vector (CPU tensor)
        self.class_counts = defaultdict(int)  # class index -> number of samples seen
        self.update_counter = 0               # number of update_class_statistics() calls

        # Quantum complexity metrics
        self.circuit_depth = n_layers
        # Counts one linear chain of CNOTs (n_qubits - 1 gates) per layer.
        self.entangling_gates = (n_qubits - 1) * n_layers
        self.quantum_complexity = self._compute_quantum_complexity()

    def _compute_quantum_complexity(self):
        """Return the circuit-dependent perturbation scale, in ``[0.1, 1.0]``.

        Depth and entangling-gate count are each normalised to at most 1 and
        averaged; the scale is one minus that average, floored at 0.1, so
        deeper / more entangled circuits receive smaller perturbations.
        """
        depth_factor = min(self.circuit_depth / 10.0, 1.0)
        entanglement_factor = min(self.entangling_gates / (self.n_qubits * 5), 1.0)

        complexity = 0.5 * depth_factor + 0.5 * entanglement_factor
        return max(1.0 - complexity, 0.1)

    def update_class_statistics(self, features, labels):
        """Fold a batch of feature vectors into the per-class running means.

        The first vector seen for a class initialises its mean; every later
        one is blended in as ``alpha * mean + (1 - alpha) * vector``. Inputs
        are detached and moved to the CPU, so no autograd graph is retained.

        Args:
            features: ``(batch, dim)`` tensor; row ``i`` belongs to ``labels[i]``.
            labels: ``(batch,)`` tensor of integer class indices.
        """
        features = features.detach().cpu()
        labels = labels.detach().cpu()

        for feature_vec, label_int in zip(features, labels.tolist()):
            if label_int not in self.class_means:
                self.class_means[label_int] = feature_vec.clone()
            else:
                # Exponential moving average
                self.class_means[label_int] = (self.alpha * self.class_means[label_int] +
                                               (1 - self.alpha) * feature_vec)
            self.class_counts[label_int] += 1

        self.update_counter += 1

    @staticmethod
    def _quantum_head(model, encoded_features):
        """Run tanh-encoded features through ``q_layer`` (per sample) and the classifier."""
        q_out_batch = torch.stack([model.q_layer(sample) for sample in encoded_features])
        return model.classifier(q_out_batch)

    @staticmethod
    def _sensitivity_from_features(model, raw_features, labels, loss_fn):
        """Return ``|d loss / d raw_features|`` for already extracted pre-tanh features."""
        # enable_grad makes this usable from evaluation code that runs under
        # torch.no_grad(); a fresh leaf keeps the caller's tensor untouched.
        with torch.enable_grad():
            leaf = raw_features.detach().clone().requires_grad_(True)
            outputs = QNI_CCP._quantum_head(model, torch.tanh(leaf))
            loss = loss_fn(outputs, labels)
            # Only the first-order gradient is needed and it is detached right
            # away, so neither create_graph nor retain_graph is requested.
            (gradients,) = torch.autograd.grad(loss, leaf)
        return gradients.abs().detach()

    def compute_gradient_sensitivity(self, model, features, labels, loss_fn):
        """
        Compute gradient sensitivity for feature-level perturbations.

        Sensitivity is the absolute gradient of the loss w.r.t. the pre-tanh
        features produced by ``model.featureextractor``.

        Args:
            model: hybrid model (see class docstring).
            features: raw model inputs (they are passed to the feature extractor).
            labels: integer class indices for ``features``.
            loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar.

        Returns:
            Detached tensor with the same shape as the extracted features.
        """
        # The extractor itself is not differentiated, only what follows it.
        with torch.no_grad():
            extracted_features = model.featureextractor(features)
        return self._sensitivity_from_features(model, extracted_features, labels, loss_fn)

    def generate_perturbation(self, features, labels, model, loss_fn, training=True):
        """
        Generate QNI-CCP perturbations on pre-tanh features using class-conditional shifts.

        For every sample a target class is drawn at random among the tracked
        classes that differ from the sample's label and have more than 10
        recorded samples; the sample's features are moved towards that class
        mean by ``epsilon_base * quantum_complexity * sensitivity`` (element-wise).
        Samples without an eligible target are left unperturbed.

        Args:
            features: raw model inputs.
            labels: integer class indices for ``features``.
            model: hybrid model (see class docstring).
            loss_fn: loss used for the gradient sensitivity.
            training: accepted for interface compatibility; not used.

        Returns:
            ``tanh`` of the (possibly perturbed) features, detached from the
            feature extractor. Works under ``torch.no_grad()`` as well.
        """
        device = features.device

        # Get raw features (pre-tanh)
        with torch.no_grad():
            extracted_features = model.featureextractor(features)

        # Skip if not enough class statistics
        if len(self.class_means) < 2:
            return torch.tanh(extracted_features)

        # Sensitivity is taken on the very features that get perturbed, which
        # also saves one additional feature-extractor forward pass.
        sensitivity = self._sensitivity_from_features(model, extracted_features,
                                                      labels, loss_fn)

        scale = self.epsilon_base * self.quantum_complexity
        # Classes with enough recorded samples to serve as a target.
        eligible_classes = [c for c in self.class_means
                            if self.class_counts.get(c, 0) > 10]

        perturbed = extracted_features.clone()
        for i, current_label in enumerate(labels.tolist()):
            available_classes = [c for c in eligible_classes if c != current_label]
            if not available_classes:
                continue

            # Pick a target class randomly
            target_class = random.choice(available_classes)
            target_mean = self.class_means[target_class].to(device)

            direction = target_mean - extracted_features[i]
            perturbed[i] = extracted_features[i] + scale * sensitivity[i] * direction

        # Push through tanh only at the end (constrain to quantum input space)
        return torch.tanh(perturbed)

    def apply_qni_ccp(self, model, features, labels, loss_fn, training=True,
                      mix_ratio=0.5):
        """Run a clean forward pass and, while training, optionally mix in a perturbed one.

        Args:
            model: hybrid model (see class docstring).
            features: raw model inputs.
            labels: integer class indices for ``features``.
            loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar.
            training: when true, class statistics are updated and the batch may
                be perturbed; when false only the clean pass is performed.
            mix_ratio: probability that a training batch also gets the
                perturbed pass.

        Returns:
            ``(loss, clean_outputs)``. With perturbation the loss is
            ``0.8 * clean + 0.2 * perturbed``; the returned outputs are always
            those of the clean pass.
        """
        raw_features = model.featureextractor(features)
        clean_features = torch.tanh(raw_features)
        clean_outputs = self._quantum_head(model, clean_features)

        # Statistics are kept in pre-tanh space so that they are comparable
        # with the pre-tanh features generate_perturbation() shifts.
        if training:
            self.update_class_statistics(raw_features, labels)

        # Apply QNI-CCP if we have enough class statistics and training
        if training and len(self.class_means) >= 2 and random.random() < mix_ratio:
            perturbed_features = self.generate_perturbation(features, labels,
                                                            model, loss_fn, training)
            perturbed_outputs = self._quantum_head(model, perturbed_features)

            # Mixed loss
            clean_loss = loss_fn(clean_outputs, labels)
            perturbed_loss = loss_fn(perturbed_outputs, labels)
            total_loss = 0.8 * clean_loss + 0.2 * perturbed_loss
            return total_loss, clean_outputs

        # Standard forward pass
        return loss_fn(clean_outputs, labels), clean_outputs

# %%
# Set seeds for reproducibility
def seed_all(seed=42):
    """Seed PyTorch, NumPy and Python's ``random`` (plus all CUDA devices if present)."""
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Fix every random number generator before any model or loader is created.
seed_all(42)

# %%
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# %%
# Hyperparameters shared by the circuit, the model, the loaders and the
# training loop below.
# n_qubits doubles as the feature-extractor output size (one feature per qubit).
n_qubits = 6
# Number of variational layers (first dimension of the circuit weights).
n_layers = 2
# Size of the classifier output.
num_classes = 25
batch_size = 16
# Upper bound on epochs; the training loop may stop earlier.
num_epochs = 50
# Base learning rate handed to the optimizer.
lr = 0.0005

# %%
# Training pipeline: single-channel image, light augmentation (random flip and
# a rotation of up to 10 degrees), fixed 128x128 size, then scaling from
# [0, 1] to [-1, 1] via Normalize(mean=0.5, std=0.5).
train_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(10),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Evaluation pipeline: no augmentation, but the SAME normalisation as in
# training. Without it validation/test images would lie in [0, 1] while the
# network was trained on [-1, 1], skewing every evaluation metric.
test_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# %%
# Update paths as needed
# ImageFolder derives each image's class from its sub-directory. Only the
# training split uses the augmenting transform; test and validation share the
# deterministic one.
train_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/train', transform=train_transform)
test_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/test', transform=test_transform)
val_dataset = ImageFolder('/home/netsec1/dataset_folder/malimg_dataset/val', transform=test_transform)

# %%
# Class index of every training sample, used to derive "balanced" class
# weights, which are then moved to the training device as a float tensor.
labels = [label for _, label in train_dataset.samples]
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(labels), y=labels)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
# NOTE(review): class_weights is not passed to the loss; Focal_loss is built
# with a scalar alpha=1 — confirm whether the weights were meant to be used.
loss_fn = Focal_loss(alpha=1, gamma=2)

# %%
# Only the training loader shuffles; test and validation keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8,pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8,pin_memory=True)

# %%
# PennyLane simulator device with one wire per qubit; shots=None requests
# analytic expectation values rather than sampled ones.
dev = qml.device("default.qubit", wires=n_qubits,shots=None)

# %%
# @qml.qnode(dev, interface="torch")
# def circuit(inputs, weights):
#     # Encoding
#     for i in range(n_qubits):
#         qml.RY(inputs[i], wires=i)
    
#     # Variational layers
#     for l in range(weights.shape[0]):
#         for i in range(n_qubits):
#             qml.RY(weights[l][i], wires=i)
        
#         for i in range(n_qubits - 1):
#             qml.CNOT(wires=[i, i + 1])
    
#     return [qml.expval(qml.PauliZ(j)) for j in range(n_qubits)]

@qml.qnode(dev, interface="torch")
def circuit(inputs, weights):
    """Variational circuit: angle encoding, entangling layers, Pauli-Z readout.

    Args:
        inputs: indexable with ``n_qubits`` entries, each used as an RY angle.
        weights: indexed as ``weights[layer][qubit][k]`` where ``k = 0`` is the
            RY angle and ``k = 1`` the RZ angle; ``weights.shape[0]`` is the
            number of layers.

    Returns:
        List with the Pauli-Z expectation value of each of the ``n_qubits`` wires.
    """
    # Initial encoding
    for i in range(n_qubits):
        qml.RY(inputs[i], wires=i)
    
    # Variational layers with improved entanglement
    for l in range(weights.shape[0]):
        # Parameterized rotations
        for i in range(n_qubits):
            qml.RY(weights[l][i][0], wires=i)
            qml.RZ(weights[l][i][1], wires=i)
        
        # Circular entanglement for better connectivity
        # (the modulo wraps the last qubit back onto qubit 0, closing the ring)
        for i in range(n_qubits):
            qml.CNOT(wires=[i, (i + 1) % n_qubits])
        
        # Data re-encoding (every other layer)
        # Applied after odd-indexed layers only, with the inputs scaled by 0.1.
        if l % 2 == 1:
            for i in range(n_qubits):
                qml.RY(inputs[i] * 0.1, wires=i)
    
    return [qml.expval(qml.PauliZ(j)) for j in range(n_qubits)]

# %%
# Shape of the trainable circuit parameters: two rotation angles (RY, RZ) per
# qubit and layer, matching the weights[l][i][0/1] indexing in circuit().
weight_shapes = {"weights": (n_layers, n_qubits,2)}

# %%
class FeatureReduce(nn.Module):
    """Convolutional encoder mapping a 1-channel image to ``final_dim`` features.

    Five stride-2 3x3 convolutions (1 -> 8 -> 16 -> 32 -> 64 -> 128 channels),
    each followed by batch norm and ReLU. The first four stages end with
    dropout, the last one with global average pooling to 1x1. A linear layer
    then projects the 128 pooled channels to ``final_dim``.

    Args:
        final_dim: size of the output feature vector.
        dropout: dropout probability used after the first four stages.
    """

    def __init__(self, final_dim, dropout=0.4):
        super().__init__()
        widths = (1, 8, 16, 32, 64, 128)
        final_stage = len(widths) - 2

        stack = []
        for stage, (in_ch, out_ch) in enumerate(zip(widths, widths[1:])):
            # Every stage halves the spatial resolution (128 -> 64 -> ... -> 4
            # for a 128x128 input).
            stack.append(nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(out_ch))
            stack.append(nn.ReLU())
            if stage < final_stage:
                stack.append(nn.Dropout(dropout))
            else:
                # Collapse whatever spatial size is left to a single value
                # per channel.
                stack.append(nn.AdaptiveAvgPool2d((1, 1)))

        self.conv = nn.Sequential(*stack)
        self.fc = nn.Linear(128, final_dim)

    def forward(self, x):
        """Encode a ``(batch, 1, H, W)`` tensor into ``(batch, final_dim)``."""
        pooled = self.conv(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# %%
class HybridQnn(nn.Module):
    """Hybrid classifier: CNN feature extractor -> quantum layer -> MLP head.

    Args:
        n_qubits: number of features produced by the extractor and consumed by
            the quantum layer and the first linear layer of the head.
        num_classes: number of output logits.
    """

    def __init__(self, n_qubits, num_classes):
        super().__init__()
        self.featureextractor = FeatureReduce(final_dim=n_qubits)
        self.q_layer = qml.qnn.TorchLayer(circuit, weight_shapes=weight_shapes)
        # MLP head on top of the quantum layer: n_qubits -> 64 -> 32 -> 16 -> classes,
        # with dropout after the first two hidden layers.
        head_layers = [
            nn.Linear(n_qubits, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of input images."""
        # tanh bounds the extracted features before they are used as angles.
        encoded = torch.tanh(self.featureextractor(x))
        # The quantum layer is evaluated one sample at a time.
        quantum_out = torch.stack([self.q_layer(sample) for sample in encoded])
        return self.classifier(quantum_out)

# %%
# Build the hybrid model and move its parameters to the selected device.
model = HybridQnn(n_qubits=n_qubits, num_classes=num_classes)
model.to(device)
# Adam over all model parameters with weight decay as regularisation.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

# Initialize QNI-CCP
# The helper is configured with the same qubit/layer/class counts as the model.
qni_ccp = QNI_CCP(n_qubits=n_qubits, n_layers=n_layers, num_classes=num_classes,
                  epsilon_base=0.1, update_freq=100, alpha=0.9)

# %%
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, mode="min", factor=0.5, patience=5
# )

# Add warmup scheduler:
def get_lr_scheduler(optimizer, warmup_epochs=5):
    """Return a ``LambdaLR`` that linearly warms the learning rate up.

    The multiplier applied to the optimizer's base learning rate is
    ``(epoch + 1) / warmup_epochs`` while ``epoch < warmup_epochs`` and ``1.0``
    from then on.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        warmup_epochs: number of scheduler steps spent warming up.
    """
    def warmup_factor(epoch_idx):
        return (epoch_idx + 1) / warmup_epochs if epoch_idx < warmup_epochs else 1.0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)

# Use it:
# the learning rate ramps up linearly over the first 5 scheduler steps and
# stays at the optimizer's base rate afterwards.
scheduler = get_lr_scheduler(optimizer, warmup_epochs=5)

# %%
def train_with_qni_ccp(model, data_loader, loss_fn, optimizer, device, qni_ccp, epoch=None):
    """Train ``model`` for one pass over ``data_loader`` using the QNI-CCP loss.

    The probability of perturbing a batch follows the ramp
    ``min(0.2 + 0.02 * epoch, 0.5)``. Previously that value was computed only
    after the forward pass and never used, so every batch ran with a fixed
    probability of 0.5.

    Args:
        model: network to train (switched to train mode).
        data_loader: iterable of ``(inputs, labels)`` batches; must also
            support ``len()`` and expose ``.dataset``.
        loss_fn: loss passed on to ``qni_ccp.apply_qni_ccp``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.
        qni_ccp: object providing ``apply_qni_ccp``.
        epoch: zero-based epoch index for the ramp. When omitted, a
            module-level variable named ``epoch`` (as set by the notebook's
            training loop) is used if it exists; otherwise the probability
            stays at 0.5.

    Returns:
        ``(mean loss per batch, accuracy)`` where accuracy is the number of
        correct predictions divided by ``len(data_loader.dataset)``.
    """
    if epoch is None:
        # Existing callers do not pass the epoch; fall back to the loop
        # variable the notebook leaves at module level.
        epoch = globals().get("epoch")
    mix_ratio = 0.5 if epoch is None else min(0.2 + 0.02 * epoch, 0.5)

    model.train()
    total_loss = 0.0
    correct = 0

    for inputs, labels in tqdm(data_loader, desc="Training"):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        # Apply QNI-CCP during training
        loss, outputs = qni_ccp.apply_qni_ccp(model, inputs, labels, loss_fn,
                                              training=True, mix_ratio=mix_ratio)

        loss.backward()

        # Clip gradients for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        total_loss += loss.item()
        # Accuracy is measured on the outputs returned by apply_qni_ccp.
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    return total_loss / len(data_loader), correct / len(data_loader.dataset)

# %%
def evaluate_with_qni_ccp(model, dataloader, loss_fn, device, qni_ccp, test_robustness=False):
    """Evaluate ``model`` on ``dataloader``, optionally also under QNI-CCP perturbations.

    The robustness pass previously went through
    ``qni_ccp.apply_qni_ccp(..., training=False)``, which never perturbs and
    returns the clean outputs, so the reported robust accuracy was always equal
    to the standard accuracy. It now classifies the features returned by
    ``qni_ccp.generate_perturbation``.

    Args:
        model: network to evaluate (switched to eval mode); must expose
            ``q_layer`` and ``classifier`` when ``test_robustness`` is true.
        dataloader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        loss_fn: loss used for the reported loss and for the perturbation.
        device: device the batches are moved to.
        qni_ccp: object providing ``generate_perturbation``.
        test_robustness: also measure accuracy on perturbed features.

    Returns:
        ``(mean loss per batch, accuracy, robust_accuracy)``;
        ``robust_accuracy`` is ``None`` when ``test_robustness`` is false.
        Target classes are drawn with ``random``, so the robust accuracy is
        reproducible only under a fixed seed.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    robust_correct = 0  # For robustness testing

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs, labels = inputs.to(device), labels.to(device)

            # Standard evaluation
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Robustness testing with QNI-CCP perturbations
            if test_robustness:
                # The perturbation needs the gradient of the loss w.r.t. the
                # features, so autograd is re-enabled for just this call.
                with torch.enable_grad():
                    perturbed_features = qni_ccp.generate_perturbation(
                        inputs, labels, model, loss_fn, training=False)

                # Same per-sample quantum pass as the model's forward().
                perturbed_q_out = torch.stack(
                    [model.q_layer(sample) for sample in perturbed_features])
                perturbed_outputs = model.classifier(perturbed_q_out)

                robust_predicted = perturbed_outputs.argmax(dim=1)
                robust_correct += (robust_predicted == labels).sum().item()

    accuracy = correct / total
    robust_accuracy = robust_correct / total if test_robustness else None

    return total_loss / len(dataloader), accuracy, robust_accuracy

# %%
# Training loop with QNI-CCP
# Per-epoch history, appended to once per epoch by the training loop.
train_losses, val_losses = [], []
train_accs, val_accs = [], []
robust_accs = []

# Early Stopping variables
# Training stops after `early_stopping_patience` consecutive epochs without a
# new best (lowest) validation loss.
early_stopping_patience = 6
best_val_loss = float('inf')
epochs_without_improvement = 0

# Print the run configuration up front so it is visible in the training log.
print("Starting training with QNI-CCP...")
print(f"Quantum complexity factor: {qni_ccp.quantum_complexity:.3f}")
print(f"Dataset info:")
print(f"  Train samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Test samples: {len(test_dataset)}")
print(f"  Number of classes: {num_classes}")

# %%
# Main training loop: one training pass and one validation pass per epoch,
# with history tracking, periodic plots, checkpointing and early stopping.
for epoch in range(num_epochs):
    # Training with QNI-CCP
    train_loss, train_acc = train_with_qni_ccp(model, train_loader, loss_fn,
                                               optimizer, device, qni_ccp)

    # Validation (standard and robustness testing)
    val_loss, val_acc, robust_acc = evaluate_with_qni_ccp(model, val_loader, loss_fn,
                                                          device, qni_ccp,
                                                          test_robustness=True)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    robust_accs.append(robust_acc)

    # The scheduler is a LambdaLR warmup, which counts epochs on its own.
    # Passing val_loss (a ReduceLROnPlateau-style call) would be interpreted
    # as the deprecated `epoch` argument and make the learning rate a function
    # of the loss value, so step() is called without arguments.
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
    print(f"Robust Acc: {robust_acc:.4f}")
    print(f"Class means tracked: {len(qni_ccp.class_means)}")
    print(f"QNI-CCP active: {'Yes' if len(qni_ccp.class_means) >= 2 else 'No'}")
    print("-" * 50)

    # Visualization
    # Every 5th epoch: loss curves, accuracy curves, and standard vs. robust
    # validation accuracy side by side.
    if (epoch + 1) % 5 == 0:
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 3, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        plt.title('Training Progress')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(1, 3, 2)
        plt.plot(train_accs, label='Train Acc')
        plt.plot(val_accs, label='Val Acc')
        plt.title('Accuracy Progress')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.subplot(1, 3, 3)
        plt.plot(val_accs, label='Standard Acc')
        plt.plot(robust_accs, label='Robust Acc')
        plt.title('Robustness Analysis')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.show()

    # Early Stopping Logic
    # A new best validation loss resets the patience counter and checkpoints
    # the model weights together with the QNI-CCP class statistics.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'qni_ccp_state': {
                'class_means': qni_ccp.class_means,
                'class_counts': qni_ccp.class_counts,
                'quantum_complexity': qni_ccp.quantum_complexity
            }
        }, "best_model_qni_ccp.pth")
        print("💾 Best model with QNI-CCP saved.")
    else:
        epochs_without_improvement += 1
        print(f"🕒 No improvement for {epochs_without_improvement} epoch(s).")

    if epochs_without_improvement >= early_stopping_patience:
        print(f"⏹️ Early stopping triggered after {epoch+1} epochs.")
        break

# %%
# Final evaluation on test set
print("\n" + "="*50)
print("FINAL EVALUATION ON TEST SET")
print("="*50)

# Load best model
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects (the checkpoint stores a dict of tensors and a defaultdict besides
# the weights); only load checkpoint files from a trusted source.
checkpoint = torch.load("best_model_qni_ccp.pth", weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

# Restore QNI-CCP state
# so the class statistics match the checkpointed weights rather than the last epoch.
qni_ccp.class_means = checkpoint['qni_ccp_state']['class_means']
qni_ccp.class_counts = checkpoint['qni_ccp_state']['class_counts']

# Test evaluation
test_loss, test_acc, test_robust_acc = evaluate_with_qni_ccp(model, test_loader, loss_fn, 
                                                           device, qni_ccp, 
                                                           test_robustness=True)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test Robust Accuracy: {test_robust_acc:.4f}")
# Positive drop = accuracy lost relative to the standard evaluation.
print(f"Robustness Drop: {(test_acc - test_robust_acc):.4f}")

# %%
# Robustness analysis with different perturbation strengths
print("\n" + "="*50)
print("ROBUSTNESS ANALYSIS WITH DIFFERENT PERTURBATION STRENGTHS")
print("="*50)

epsilons = [0.05, 0.1, 0.15, 0.2, 0.25]
robustness_results = []

# The sweep temporarily overrides the perturbation strength on the shared
# qni_ccp object. The configured value is restored afterwards (even if an
# evaluation fails) so later cells do not silently run with the last swept
# epsilon.
original_epsilon = qni_ccp.epsilon_base
try:
    for eps in epsilons:
        qni_ccp.epsilon_base = eps
        _, _, robust_acc = evaluate_with_qni_ccp(model, test_loader, loss_fn,
                                                 device, qni_ccp, test_robustness=True)
        robustness_results.append(robust_acc)
        print(f"Epsilon: {eps:.2f} | Robust Accuracy: {robust_acc:.4f}")
finally:
    qni_ccp.epsilon_base = original_epsilon

# Plot robustness vs perturbation strength
plt.figure(figsize=(10, 6))
plt.plot(epsilons, robustness_results, 'b-o', linewidth=2, markersize=8)
plt.xlabel('Perturbation Strength (ε)')
plt.ylabel('Robust Accuracy')
plt.title('QNI-CCP Robustness Analysis')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Closing summary: repeats the headline test-set figures measured before the sweep.
print(f"\nQNI-CCP Integration Complete!")
print(f"Final Test Accuracy: {test_acc:.4f}")
print(f"Final Robust Accuracy: {test_robust_acc:.4f}")