In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/nsl-kdd-augmented/smote_augmented.csv
/kaggle/input/nslkdd/KDDTest+.arff
/kaggle/input/nslkdd/KDDTest-21.arff
/kaggle/input/nslkdd/KDDTest1.jpg
/kaggle/input/nslkdd/KDDTrain+.txt
/kaggle/input/nslkdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/KDDTest-21.txt
/kaggle/input/nslkdd/KDDTest+.txt
/kaggle/input/nslkdd/KDDTrain+.arff
/kaggle/input/nslkdd/index.html
/kaggle/input/nslkdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/KDDTrain1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTest1.jpg
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest-21.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTest+.txt
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+.arff
/kaggle/input/nslkdd/nsl-kdd/index.html
/kaggle/input/nslkdd/nsl-kdd/KDDTrain+_20Percent.arff
/kaggle/input/nslkdd/nsl-kdd/KDDTrain1.jpg


In [9]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import QuantileTransformer, LabelEncoder
from sklearn.metrics import classification_report, f1_score
from sklearn.utils.class_weight import compute_class_weight
from tqdm import tqdm
import math

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

columns = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
    'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
    'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells',
    'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count',
    'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
    'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count',
    'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
    'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
    'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
    'outcome', 'level'
]

# ===========================================
# DATA PREPROCESSING
# ===========================================
df_train = pd.read_csv("/kaggle/input/nsl-kdd-augmented/smote_augmented.csv") 
df_test = pd.read_csv("/kaggle/input/nslkdd/KDDTest+.txt", header=None)
df_test.columns = columns

train_labels = set(df_train['outcome'].unique())
df_test = df_test[df_test['outcome'].isin(train_labels)].reset_index(drop=True)

cat_cols = ['protocol_type', 'service', 'flag']
num_cols = [c for c in df_train.columns if c not in cat_cols + ['outcome', 'level']]

# Label encoding with careful test set handling
cat_dims = []
for col in cat_cols:
    le_c = LabelEncoder()
    df_train[col] = le_c.fit_transform(df_train[col].astype(str))
    train_classes = {cls: i for i, cls in enumerate(le_c.classes_)}
    df_test[col] = df_test[col].map(lambda x: train_classes.get(str(x), 0))
    cat_dims.append(len(le_c.classes_))

# Quantile transformation
qt = QuantileTransformer(output_distribution='normal', random_state=42)
X_train_num = qt.fit_transform(df_train[num_cols]).astype(np.float32)
X_test_num = qt.transform(df_test[num_cols]).astype(np.float32)

le_target = LabelEncoder()
y_train = le_target.fit_transform(df_train['outcome'])
y_test = le_target.transform(df_test['outcome'])

# Compute class weights for focal loss
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

print(f"Number of classes: {len(le_target.classes_)}")
print(f"Training samples: {len(y_train)}, Test samples: {len(y_test)}")

# ===========================================
# NOVEL ARCHITECTURE: HCAN
# Hierarchical Class-Aware Network with Multi-Scale Attention
# ===========================================

class AdaptiveFeatureGate(nn.Module):
    """Dynamic feature selection with learnable thresholds"""
    def __init__(self, dim):
        super().__init__()
        self.gate_weight = nn.Parameter(torch.ones(dim))
        self.gate_bias = nn.Parameter(torch.zeros(dim))
        self.threshold = nn.Parameter(torch.tensor(0.5))
        
    def forward(self, x):
        importance = torch.sigmoid(x * self.gate_weight + self.gate_bias)
        mask = (importance > self.threshold).float()
        # Straight-through estimator for gradients
        mask = mask - importance.detach() + importance
        return x * mask

class MultiScaleFeatureExtractor(nn.Module):
    """Extract features at multiple temporal scales"""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.scales = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.2)
            ) for _ in range(3)
        ])
        self.fusion = nn.Linear(hidden_dim * 3, hidden_dim)
        
    def forward(self, x):
        features = [scale(x) for scale in self.scales]
        combined = torch.cat(features, dim=-1)
        return self.fusion(combined)

class ClassPrototypeAttention(nn.Module):
    """Learn class prototypes and compute attention-weighted representations"""
    def __init__(self, dim, num_classes):
        super().__init__()
        # Learnable class prototypes
        self.prototypes = nn.Parameter(torch.randn(num_classes, dim))
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.scale = math.sqrt(dim)
        
    def forward(self, x):
        # x: [batch, dim]
        q = self.query(x)  # [batch, dim]
        k = self.key(self.prototypes)  # [num_classes, dim]
        v = self.value(self.prototypes)  # [num_classes, dim]
        
        # Compute attention scores
        scores = torch.matmul(q, k.t()) / self.scale  # [batch, num_classes]
        attn_weights = F.softmax(scores, dim=-1)
        
        # Weighted combination of prototypes
        attended = torch.matmul(attn_weights, v)  # [batch, dim]
        return attended, attn_weights

class ResidualBlock(nn.Module):
    """Residual block with layer normalization"""
    def __init__(self, dim, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.LayerNorm(dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim * 2, dim),
            nn.LayerNorm(dim),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return x + self.net(x)

class HCAN(nn.Module):
    """
    Hierarchical Class-Aware Network
    
    Novel contributions:
    1. Multi-scale feature extraction at different granularities
    2. Class prototype attention mechanism for better rare class recognition
    3. Adaptive feature gating to handle noisy features
    4. Hierarchical fusion of categorical and numerical streams
    5. Deep residual pathway for gradient flow
    """
    def __init__(self, cat_dims, num_feat_dim, num_classes, emb_dim=128, hidden_dim=256):
        super().__init__()
        self.num_classes = num_classes
        
        # === Categorical Stream ===
        self.cat_embeddings = nn.ModuleList([
            nn.Embedding(d, emb_dim) for d in cat_dims
        ])
        self.cat_projection = nn.Sequential(
            nn.Linear(emb_dim * len(cat_dims), hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        
        # === Numerical Stream with Multi-Scale Processing ===
        self.num_gate = AdaptiveFeatureGate(num_feat_dim)
        self.multi_scale = MultiScaleFeatureExtractor(num_feat_dim, hidden_dim)
        
        # === Feature Fusion ===
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        
        # === Deep Residual Pathway ===
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(hidden_dim, dropout=0.3) for _ in range(3)
        ])
        
        # === Class Prototype Attention ===
        self.prototype_attn = ClassPrototypeAttention(hidden_dim, num_classes)
        
        # === Final Classification Head ===
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, num_classes)
        )
        
        self._init_weights()
        
    def _init_weights(self):
        """Custom initialization for better convergence"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, mean=0, std=0.02)
                
    def forward(self, x_cat, x_num):
        # Categorical stream
        cat_embeds = [emb(x_cat[:, i]) for i, emb in enumerate(self.cat_embeddings)]
        cat_features = torch.cat(cat_embeds, dim=-1)
        cat_features = self.cat_projection(cat_features)
        
        # Numerical stream with adaptive gating
        num_gated = self.num_gate(x_num)
        num_features = self.multi_scale(num_gated)
        
        # Fusion
        fused = self.fusion(torch.cat([cat_features, num_features], dim=-1))
        
        # Deep residual processing
        for block in self.residual_blocks:
            fused = block(fused)
        
        # Class prototype attention
        prototype_features, attn_weights = self.prototype_attn(fused)
        
        # Combine original features with prototype-attended features
        final_features = torch.cat([fused, prototype_features], dim=-1)
        
        # Classification
        logits = self.classifier(final_features)
        
        return logits, attn_weights

# ===========================================
# ADVANCED LOSS FUNCTIONS
# ===========================================

class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance"""
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class PrototypeLoss(nn.Module):
    """Auxiliary loss to ensure prototypes are well-separated"""
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin
        
    def forward(self, prototypes):
        # Compute pairwise distances
        dist_matrix = torch.cdist(prototypes, prototypes, p=2)
        # Mask diagonal
        mask = ~torch.eye(len(prototypes), dtype=bool, device=prototypes.device)
        distances = dist_matrix[mask]
        # Encourage minimum margin between prototypes
        loss = F.relu(self.margin - distances).mean()
        return loss

class CombinedLoss(nn.Module):
    """Combined loss with multiple objectives"""
    def __init__(self, alpha, num_classes):
        super().__init__()
        self.focal = FocalLoss(alpha=alpha, gamma=2.5)
        self.proto = PrototypeLoss(margin=1.5)
        
    def forward(self, logits, targets, prototypes):
        focal_loss = self.focal(logits, targets)
        proto_loss = self.proto(prototypes)
        return focal_loss + 0.1 * proto_loss, focal_loss, proto_loss

# ===========================================
# DATASET AND TRAINING
# ===========================================

class NSLDataset(Dataset):
    def __init__(self, c, n, y):
        self.c = torch.tensor(c, dtype=torch.long)
        self.n = torch.tensor(n, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, i):
        return self.c[i], self.n[i], self.y[i]

# Create data loaders
train_dataset = NSLDataset(df_train[cat_cols].values, X_train_num, y_train)
test_dataset = NSLDataset(df_test[cat_cols].values, X_test_num, y_test)

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=2)

# Initialize model
model = HCAN(
    cat_dims=cat_dims,
    num_feat_dim=X_train_num.shape[1],
    num_classes=len(le_target.classes_),
    emb_dim=128,
    hidden_dim=256
).to(DEVICE)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer with weight decay
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
    betas=(0.9, 0.999)
)

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

# Loss function
criterion = CombinedLoss(alpha=class_weights, num_classes=len(le_target.classes_))

# ===========================================
# TRAINING LOOP
# ===========================================

best_macro_f1 = 0.0
patience = 15
patience_counter = 0

for epoch in range(50):
    # Training
    model.train()
    train_loss = 0
    train_focal = 0
    train_proto = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/50")
    for xc, xn, y in pbar:
        xc, xn, y = xc.to(DEVICE), xn.to(DEVICE), y.to(DEVICE)
        
        optimizer.zero_grad()
        logits, _ = model(xc, xn)
        
        loss, focal_loss, proto_loss = criterion(
            logits, y, model.prototype_attn.prototypes
        )
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_loss += loss.item()
        train_focal += focal_loss.item()
        train_proto += proto_loss.item()
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'lr': f'{optimizer.param_groups[0]["lr"]:.6f}'
        })
    
    scheduler.step()
    
    avg_loss = train_loss / len(train_loader)
    avg_focal = train_focal / len(train_loader)
    avg_proto = train_proto / len(train_loader)
    
    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for xc, xn, y in test_loader:
            xc, xn = xc.to(DEVICE), xn.to(DEVICE)
            logits, _ = model(xc, xn)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.numpy())
    
    # Calculate metrics
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    
    print(f"\nEpoch {epoch+1}:")
    print(f"  Train Loss: {avg_loss:.4f} (Focal: {avg_focal:.4f}, Proto: {avg_proto:.4f})")
    print(f"  Macro F1: {macro_f1:.4f} | Weighted F1: {weighted_f1:.4f}")
    
    # Early stopping
    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        patience_counter = 0
        torch.save(model.state_dict(), 'best_hcan_model.pth')
        print(f"  ✓ New best Macro F1: {best_macro_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

# ===========================================
# FINAL EVALUATION
# ===========================================

print("\n" + "="*60)
print("LOADING BEST MODEL FOR FINAL EVALUATION")
print("="*60)

model.load_state_dict(torch.load('best_hcan_model.pth'))
model.eval()

all_preds = []
all_labels = []
all_probs = []

with torch.no_grad():
    for xc, xn, y in test_loader:
        xc, xn = xc.to(DEVICE), xn.to(DEVICE)
        logits, _ = model(xc, xn)
        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)
        
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.numpy())
        all_probs.extend(probs.cpu().numpy())

print("\nFINAL CLASSIFICATION REPORT:")
print("="*60)
print(classification_report(
    all_labels, 
    all_preds, 
    target_names=le_target.classes_, 
    zero_division=0,
    digits=4
))

# Calculate per-class F1 scores
from sklearn.metrics import f1_score
per_class_f1 = f1_score(all_labels, all_preds, average=None, zero_division=0)
print("\nPER-CLASS F1 SCORES:")
print("="*60)
for class_name, f1 in zip(le_target.classes_, per_class_f1):
    print(f"{class_name:20s}: {f1:.4f}")

print(f"\n{'='*60}")
print(f"BEST MACRO F1 ACHIEVED: {best_macro_f1:.4f}")
print(f"{'='*60}")

Using device: cuda
Number of classes: 23
Training samples: 557934, Test samples: 18794
Model parameters: 1,604,452


Epoch 1/50: 100%|██████████| 2180/2180 [00:27<00:00, 80.21it/s, loss=0.0577, lr=0.001000]



Epoch 1:
  Train Loss: 0.1422 (Focal: 0.1422, Proto: 0.0000)
  Macro F1: 0.4212 | Weighted F1: 0.8251
  ✓ New best Macro F1: 0.4212


Epoch 2/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.99it/s, loss=0.0385, lr=0.000976]



Epoch 2:
  Train Loss: 0.0513 (Focal: 0.0513, Proto: 0.0000)
  Macro F1: 0.4248 | Weighted F1: 0.8129
  ✓ New best Macro F1: 0.4248


Epoch 3/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.44it/s, loss=0.0536, lr=0.000905]



Epoch 3:
  Train Loss: 0.0419 (Focal: 0.0419, Proto: 0.0000)
  Macro F1: 0.4026 | Weighted F1: 0.8094


Epoch 4/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.07it/s, loss=0.1007, lr=0.000794]



Epoch 4:
  Train Loss: 0.0427 (Focal: 0.0427, Proto: 0.0000)
  Macro F1: 0.4129 | Weighted F1: 0.8064


Epoch 5/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.35it/s, loss=0.0211, lr=0.000655]



Epoch 5:
  Train Loss: 0.0374 (Focal: 0.0374, Proto: 0.0000)
  Macro F1: 0.4280 | Weighted F1: 0.8167
  ✓ New best Macro F1: 0.4280


Epoch 6/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.50it/s, loss=0.0257, lr=0.000501]



Epoch 6:
  Train Loss: 0.0320 (Focal: 0.0320, Proto: 0.0000)
  Macro F1: 0.4047 | Weighted F1: 0.8092


Epoch 7/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.79it/s, loss=0.0243, lr=0.000346]



Epoch 7:
  Train Loss: 0.0274 (Focal: 0.0274, Proto: 0.0000)
  Macro F1: 0.4194 | Weighted F1: 0.7993


Epoch 8/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.23it/s, loss=0.0187, lr=0.000207]



Epoch 8:
  Train Loss: 0.0243 (Focal: 0.0243, Proto: 0.0000)
  Macro F1: 0.4205 | Weighted F1: 0.8087


Epoch 9/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.29it/s, loss=0.0055, lr=0.000096]



Epoch 9:
  Train Loss: 0.0226 (Focal: 0.0226, Proto: 0.0000)
  Macro F1: 0.4210 | Weighted F1: 0.8036


Epoch 10/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.31it/s, loss=0.0472, lr=0.000025]



Epoch 10:
  Train Loss: 0.0215 (Focal: 0.0215, Proto: 0.0000)
  Macro F1: 0.4248 | Weighted F1: 0.8077


Epoch 11/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.98it/s, loss=0.0564, lr=0.001000]



Epoch 11:
  Train Loss: 0.0375 (Focal: 0.0375, Proto: 0.0000)
  Macro F1: 0.4184 | Weighted F1: 0.7966


Epoch 12/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.56it/s, loss=0.0372, lr=0.000994]



Epoch 12:
  Train Loss: 0.0367 (Focal: 0.0367, Proto: 0.0000)
  Macro F1: 0.4132 | Weighted F1: 0.8043


Epoch 13/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.95it/s, loss=0.0400, lr=0.000976]



Epoch 13:
  Train Loss: 0.0301 (Focal: 0.0301, Proto: 0.0000)
  Macro F1: 0.4283 | Weighted F1: 0.8145
  ✓ New best Macro F1: 0.4283


Epoch 14/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.67it/s, loss=0.0095, lr=0.000946]



Epoch 14:
  Train Loss: 0.0278 (Focal: 0.0278, Proto: 0.0000)
  Macro F1: 0.4220 | Weighted F1: 0.8144


Epoch 15/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.71it/s, loss=0.0137, lr=0.000905]



Epoch 15:
  Train Loss: 0.0256 (Focal: 0.0256, Proto: 0.0000)
  Macro F1: 0.4129 | Weighted F1: 0.8007


Epoch 16/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.75it/s, loss=0.0184, lr=0.000854]



Epoch 16:
  Train Loss: 0.0245 (Focal: 0.0245, Proto: 0.0000)
  Macro F1: 0.4151 | Weighted F1: 0.8030


Epoch 17/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.41it/s, loss=0.0060, lr=0.000794]



Epoch 17:
  Train Loss: 0.0229 (Focal: 0.0229, Proto: 0.0000)
  Macro F1: 0.4165 | Weighted F1: 0.8024


Epoch 18/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.90it/s, loss=0.0163, lr=0.000727]



Epoch 18:
  Train Loss: 0.0212 (Focal: 0.0212, Proto: 0.0000)
  Macro F1: 0.4177 | Weighted F1: 0.8034


Epoch 19/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.15it/s, loss=0.0078, lr=0.000655]



Epoch 19:
  Train Loss: 0.0204 (Focal: 0.0204, Proto: 0.0000)
  Macro F1: 0.4236 | Weighted F1: 0.8070


Epoch 20/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.26it/s, loss=0.0026, lr=0.000579]



Epoch 20:
  Train Loss: 0.0193 (Focal: 0.0193, Proto: 0.0000)
  Macro F1: 0.4325 | Weighted F1: 0.7995
  ✓ New best Macro F1: 0.4325


Epoch 21/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.40it/s, loss=0.0069, lr=0.000501]



Epoch 21:
  Train Loss: 0.0181 (Focal: 0.0181, Proto: 0.0000)
  Macro F1: 0.4102 | Weighted F1: 0.7973


Epoch 22/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.79it/s, loss=0.0083, lr=0.000422]



Epoch 22:
  Train Loss: 0.0173 (Focal: 0.0173, Proto: 0.0000)
  Macro F1: 0.4221 | Weighted F1: 0.8045


Epoch 23/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.21it/s, loss=0.0086, lr=0.000346]



Epoch 23:
  Train Loss: 0.0163 (Focal: 0.0163, Proto: 0.0000)
  Macro F1: 0.4274 | Weighted F1: 0.8097


Epoch 24/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.71it/s, loss=0.0077, lr=0.000274]



Epoch 24:
  Train Loss: 0.0156 (Focal: 0.0156, Proto: 0.0000)
  Macro F1: 0.4221 | Weighted F1: 0.8082


Epoch 25/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.76it/s, loss=0.0026, lr=0.000207]



Epoch 25:
  Train Loss: 0.0148 (Focal: 0.0148, Proto: 0.0000)
  Macro F1: 0.4204 | Weighted F1: 0.8035


Epoch 26/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.20it/s, loss=0.0062, lr=0.000147]



Epoch 26:
  Train Loss: 0.0142 (Focal: 0.0142, Proto: 0.0000)
  Macro F1: 0.4294 | Weighted F1: 0.8076


Epoch 27/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.42it/s, loss=0.0077, lr=0.000096]



Epoch 27:
  Train Loss: 0.0137 (Focal: 0.0137, Proto: 0.0000)
  Macro F1: 0.4288 | Weighted F1: 0.8053


Epoch 28/50: 100%|██████████| 2180/2180 [00:26<00:00, 80.82it/s, loss=0.0317, lr=0.000055]



Epoch 28:
  Train Loss: 0.0146 (Focal: 0.0146, Proto: 0.0000)
  Macro F1: 0.4240 | Weighted F1: 0.8058


Epoch 29/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.53it/s, loss=0.0260, lr=0.000025]



Epoch 29:
  Train Loss: 0.0132 (Focal: 0.0132, Proto: 0.0000)
  Macro F1: 0.4224 | Weighted F1: 0.8068


Epoch 30/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.41it/s, loss=0.0050, lr=0.000007]



Epoch 30:
  Train Loss: 0.0129 (Focal: 0.0129, Proto: 0.0000)
  Macro F1: 0.4236 | Weighted F1: 0.8052


Epoch 31/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.23it/s, loss=0.0206, lr=0.001000]



Epoch 31:
  Train Loss: 0.0282 (Focal: 0.0282, Proto: 0.0000)
  Macro F1: 0.4184 | Weighted F1: 0.8078


Epoch 32/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.79it/s, loss=0.0089, lr=0.000998]



Epoch 32:
  Train Loss: 0.0245 (Focal: 0.0245, Proto: 0.0000)
  Macro F1: 0.4177 | Weighted F1: 0.8122


Epoch 33/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.95it/s, loss=0.0084, lr=0.000994]



Epoch 33:
  Train Loss: 0.0247 (Focal: 0.0247, Proto: 0.0000)
  Macro F1: 0.4296 | Weighted F1: 0.8156


Epoch 34/50: 100%|██████████| 2180/2180 [00:26<00:00, 81.10it/s, loss=0.0143, lr=0.000986]



Epoch 34:
  Train Loss: 0.0234 (Focal: 0.0234, Proto: 0.0000)
  Macro F1: 0.4294 | Weighted F1: 0.8145


Epoch 35/50: 100%|██████████| 2180/2180 [00:26<00:00, 82.12it/s, loss=0.0233, lr=0.000976]



Epoch 35:
  Train Loss: 0.0225 (Focal: 0.0225, Proto: 0.0000)
  Macro F1: 0.4249 | Weighted F1: 0.8199

Early stopping at epoch 35

LOADING BEST MODEL FOR FINAL EVALUATION

FINAL CLASSIFICATION REPORT:
                 precision    recall  f1-score   support

           back     0.9301    1.0000    0.9638       359
buffer_overflow     0.5714    0.2000    0.2963        20
      ftp_write     0.0000    0.0000    0.0000         3
   guess_passwd     0.8000    0.0032    0.0065      1231
           imap     0.0000    0.0000    0.0000         1
        ipsweep     0.9079    0.9787    0.9420       141
           land     1.0000    1.0000    1.0000         7
     loadmodule     0.0000    0.0000    0.0000         2
       multihop     0.0000    0.0000    0.0000        18
        neptune     0.9991    0.9811    0.9900      4657
           nmap     0.7826    0.9863    0.8727        73
         normal     0.8523    0.9139    0.8820      9711
           perl     0.2500    0.5000    0.3333         

In [10]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.preprocessing import QuantileTransformer, LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, f1_score, confusion_matrix
from tqdm import tqdm
import math

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

columns = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes', 'land',
    'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in', 'num_compromised',
    'root_shell', 'su_attempted', 'num_root', 'num_file_creations', 'num_shells',
    'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login', 'count',
    'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
    'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count',
    'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
    'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
    'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
    'outcome', 'level'
]

# ===========================================
# ENHANCED DATA PREPROCESSING
# ===========================================
df_train = pd.read_csv("/kaggle/input/nsl-kdd-augmented/smote_augmented.csv") 
df_test = pd.read_csv("/kaggle/input/nslkdd/KDDTest+.txt", header=None)
df_test.columns = columns

train_labels = set(df_train['outcome'].unique())
df_test = df_test[df_test['outcome'].isin(train_labels)].reset_index(drop=True)

cat_cols = ['protocol_type', 'service', 'flag']
num_cols = [c for c in df_train.columns if c not in cat_cols + ['outcome', 'level']]

# Label encoding
cat_dims = []
for col in cat_cols:
    le_c = LabelEncoder()
    df_train[col] = le_c.fit_transform(df_train[col].astype(str))
    train_classes = {cls: i for i, cls in enumerate(le_c.classes_)}
    df_test[col] = df_test[col].map(lambda x: train_classes.get(str(x), 0))
    cat_dims.append(len(le_c.classes_))

# CRITICAL: Use StandardScaler instead of QuantileTransformer for better rare class separation
scaler = StandardScaler()
X_train_num = scaler.fit_transform(df_train[num_cols]).astype(np.float32)
X_test_num = scaler.transform(df_test[num_cols]).astype(np.float32)

le_target = LabelEncoder()
y_train = le_target.fit_transform(df_train['outcome'])
y_test = le_target.transform(df_test['outcome'])

# Compute effective number of samples per class for CB loss
class_counts = np.bincount(y_train)
print("\nClass distribution in training:")
for i, (name, count) in enumerate(zip(le_target.classes_, class_counts)):
    print(f"{name:20s}: {count:8d} samples")

# ===========================================
# NOVEL ARCHITECTURE: MC-DPAN
# Multi-Expert Class-Discriminative Prototype Alignment Network
# ===========================================

class ExpertMixture(nn.Module):
    """Multiple expert networks, each specializing in different class groups"""
    def __init__(self, input_dim, hidden_dim, num_experts=4):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU()
            ) for _ in range(num_experts)
        ])
        # Gating network to route inputs to experts
        self.gate = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=-1)
        )
        
    def forward(self, x):
        # Compute gating weights
        gate_weights = self.gate(x)  # [batch, num_experts]
        
        # Get outputs from all experts
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)  # [batch, num_experts, hidden]
        
        # Weighted combination
        output = torch.einsum('be,beh->bh', gate_weights, expert_outputs)
        return output, gate_weights

class PrototypeLayer(nn.Module):
    """Learnable prototypes with distance-based similarity"""
    def __init__(self, num_classes, feature_dim):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(num_classes, feature_dim))
        nn.init.orthogonal_(self.prototypes)
        self.scale = nn.Parameter(torch.tensor(10.0))
        
    def forward(self, x):
        # Normalize features and prototypes for cosine similarity
        x_norm = F.normalize(x, p=2, dim=1)
        p_norm = F.normalize(self.prototypes, p=2, dim=1)
        
        # Cosine similarity (higher is more similar)
        similarities = torch.matmul(x_norm, p_norm.t()) * self.scale
        return similarities

class ContrastiveHead(nn.Module):
    """Supervised contrastive learning head"""
    def __init__(self, input_dim, projection_dim=128):
        super().__init__()
        self.projector = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Linear(input_dim, projection_dim)
        )
        
    def forward(self, x):
        return F.normalize(self.projector(x), p=2, dim=1)

class MCDPAN(nn.Module):
    """
    Multi-Expert Class-Discriminative Prototype Alignment Network
    
    Novel Contributions:
    1. Mixture of Experts for capturing diverse attack patterns
    2. Orthogonal prototype initialization for better class separation
    3. Dual-head design: prototype matching + contrastive learning
    4. Multi-scale categorical embedding fusion
    5. Attention-based feature recalibration
    """
    def __init__(self, cat_dims, num_feat_dim, num_classes, emb_dim=64, hidden_dim=512):
        super().__init__()
        self.num_classes = num_classes
        
        # === Categorical Embeddings with Different Scales ===
        self.embeddings = nn.ModuleList([
            nn.Embedding(d, emb_dim) for d in cat_dims
        ])
        cat_total_dim = emb_dim * len(cat_dims)
        
        # === Categorical Processing ===
        self.cat_processor = nn.Sequential(
            nn.Linear(cat_total_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        
        # === Numerical Processing with Mixture of Experts ===
        self.num_experts = ExpertMixture(num_feat_dim, hidden_dim, num_experts=5)
        
        # === Feature Fusion ===
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        
        # === Self-Attention for Feature Recalibration ===
        self.self_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            dropout=0.2,
            batch_first=True
        )
        self.attn_norm = nn.LayerNorm(hidden_dim)
        
        # === Deep Feature Extractor ===
        self.deep_features = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.3)
        )
        
        # === Prototype Layer ===
        self.prototype_layer = PrototypeLayer(num_classes, hidden_dim)
        
        # === Contrastive Learning Head ===
        self.contrastive_head = ContrastiveHead(hidden_dim, projection_dim=256)
        
        # === Auxiliary Classifier for Regularization ===
        self.aux_classifier = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x_cat, x_num, return_features=False):
        # Categorical embedding
        cat_embeds = torch.cat([emb(x_cat[:, i]) for i, emb in enumerate(self.embeddings)], dim=-1)
        cat_features = self.cat_processor(cat_embeds)
        
        # Numerical processing with experts
        num_features, expert_weights = self.num_experts(x_num)
        
        # Fusion
        fused = self.fusion(torch.cat([cat_features, num_features], dim=-1))
        
        # Self-attention (treating features as sequence of length 1)
        fused_unsqueezed = fused.unsqueeze(1)
        attn_out, _ = self.self_attn(fused_unsqueezed, fused_unsqueezed, fused_unsqueezed)
        fused = self.attn_norm(fused + attn_out.squeeze(1))
        
        # Deep features
        features = self.deep_features(fused)
        
        # Prototype-based logits
        proto_logits = self.prototype_layer(features)
        
        # Auxiliary logits
        aux_logits = self.aux_classifier(features)
        
        if return_features:
            # Contrastive projections
            projections = self.contrastive_head(features)
            return proto_logits, aux_logits, features, projections
        
        return proto_logits, aux_logits

# ===========================================
# ADVANCED LOSS FUNCTIONS
# ===========================================

class ClassBalancedFocalLoss(nn.Module):
    """Focal loss with class-balanced weighting using effective number of samples"""
    def __init__(self, class_counts, beta=0.9999, gamma=2.0):
        super().__init__()
        effective_num = 1.0 - np.power(beta, class_counts)
        weights = (1.0 - beta) / np.array(effective_num)
        weights = weights / weights.sum() * len(weights)
        self.weights = torch.tensor(weights, dtype=torch.float32)
        self.gamma = gamma
        
    def forward(self, logits, targets):
        self.weights = self.weights.to(logits.device)
        ce_loss = F.cross_entropy(logits, targets, weight=self.weights, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()

class SupConLoss(nn.Module):
    """Supervised Contrastive Loss"""
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, features, labels):
        device = features.device
        batch_size = features.shape[0]
        
        # Normalize features
        features = F.normalize(features, dim=1)
        
        # Compute similarity matrix
        similarity_matrix = torch.matmul(features, features.T) / self.temperature
        
        # Create mask for positive pairs (same class)
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        # Mask out self-similarity
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask
        
        # Compute log_prob
        exp_logits = torch.exp(similarity_matrix) * logits_mask
        log_prob = similarity_matrix - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
        
        # Compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-12)
        
        loss = -mean_log_prob_pos.mean()
        return loss

class CombinedLoss(nn.Module):
    """Multi-objective loss combining focal, contrastive, and consistency"""
    def __init__(self, class_counts):
        super().__init__()
        self.focal_loss = ClassBalancedFocalLoss(class_counts, beta=0.9999, gamma=3.0)
        self.supcon_loss = SupConLoss(temperature=0.07)
        
    def forward(self, proto_logits, aux_logits, targets, projections=None):
        # Main prototype-based loss
        loss_proto = self.focal_loss(proto_logits, targets)
        
        # Auxiliary loss for regularization
        loss_aux = self.focal_loss(aux_logits, targets)
        
        # Consistency loss between two heads
        loss_consistency = F.kl_div(
            F.log_softmax(proto_logits, dim=1),
            F.softmax(aux_logits, dim=1),
            reduction='batchmean'
        )
        
        total_loss = loss_proto + 0.5 * loss_aux + 0.3 * loss_consistency
        
        # Contrastive loss if projections provided
        if projections is not None:
            loss_contrastive = self.supcon_loss(projections, targets)
            total_loss = total_loss + 0.5 * loss_contrastive
            return total_loss, loss_proto, loss_aux, loss_consistency, loss_contrastive
        
        return total_loss, loss_proto, loss_aux, loss_consistency, torch.tensor(0.0)

# ===========================================
# DATASET WITH AGGRESSIVE RESAMPLING
# ===========================================

class NSLDataset(Dataset):
    def __init__(self, c, n, y):
        self.c = torch.tensor(c, dtype=torch.long)
        self.n = torch.tensor(n, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, i):
        return self.c[i], self.n[i], self.y[i]

# Create datasets
train_dataset = NSLDataset(df_train[cat_cols].values, X_train_num, y_train)
test_dataset = NSLDataset(df_test[cat_cols].values, X_test_num, y_test)

# CRITICAL: Aggressive resampling for rare classes
# Use square root of inverse frequency for more balanced sampling
class_sample_counts = np.bincount(y_train)
weights = 1.0 / np.sqrt(class_sample_counts[y_train])
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

train_loader = DataLoader(train_dataset, batch_size=128, sampler=sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, num_workers=2)

# ===========================================
# MODEL INITIALIZATION
# ===========================================

model = MCDPAN(
    cat_dims=cat_dims,
    num_feat_dim=X_train_num.shape[1],
    num_classes=len(le_target.classes_),
    emb_dim=64,
    hidden_dim=512
).to(DEVICE)

print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

# Optimizer with discriminative learning rates
param_groups = [
    {'params': model.prototype_layer.parameters(), 'lr': 2e-3},  # Higher LR for prototypes
    {'params': [p for n, p in model.named_parameters() if 'prototype' not in n], 'lr': 1e-3}
]
optimizer = torch.optim.AdamW(param_groups, weight_decay=1e-4)

# Warm-up + Cosine scheduler
def lr_lambda(epoch):
    if epoch < 5:
        return (epoch + 1) / 5
    else:
        return 0.5 * (1 + math.cos(math.pi * (epoch - 5) / 45))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Loss function
criterion = CombinedLoss(class_sample_counts)

# ===========================================
# TRAINING LOOP WITH CONTRASTIVE LEARNING
# ===========================================

best_macro_f1 = 0.0
patience = 20
patience_counter = 0

print("\n" + "="*60)
print("TRAINING STARTED")
print("="*60)

for epoch in range(60):
    # Training
    model.train()
    train_losses = {'total': 0, 'proto': 0, 'aux': 0, 'consistency': 0, 'contrastive': 0}
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/60")
    for batch_idx, (xc, xn, y) in enumerate(pbar):
        xc, xn, y = xc.to(DEVICE), xn.to(DEVICE), y.to(DEVICE)
        
        optimizer.zero_grad()
        
        # Forward with contrastive learning every other batch
        if batch_idx % 2 == 0:
            proto_logits, aux_logits, features, projections = model(xc, xn, return_features=True)
            loss, loss_p, loss_a, loss_c, loss_con = criterion(proto_logits, aux_logits, y, projections)
        else:
            proto_logits, aux_logits = model(xc, xn, return_features=False)
            loss, loss_p, loss_a, loss_c, loss_con = criterion(proto_logits, aux_logits, y)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        train_losses['total'] += loss.item()
        train_losses['proto'] += loss_p.item()
        train_losses['aux'] += loss_a.item()
        train_losses['consistency'] += loss_c.item()
        train_losses['contrastive'] += loss_con.item() if isinstance(loss_con, torch.Tensor) else 0
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    scheduler.step()
    
    # Average losses
    for k in train_losses:
        train_losses[k] /= len(train_loader)
    
    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    all_proto_preds = []
    
    with torch.no_grad():
        for xc, xn, y in test_loader:
            xc, xn = xc.to(DEVICE), xn.to(DEVICE)
            proto_logits, aux_logits = model(xc, xn, return_features=False)
            
            # Ensemble: average both heads
            ensemble_logits = (proto_logits + aux_logits) / 2
            preds = torch.argmax(ensemble_logits, dim=1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.numpy())
    
    # Metrics
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    weighted_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    per_class_f1 = f1_score(all_labels, all_preds, average=None, zero_division=0)
    
    # Count classes with F1 > 0.3
    good_classes = np.sum(per_class_f1 > 0.3)
    
    print(f"\nEpoch {epoch+1}:")
    print(f"  Loss - Total: {train_losses['total']:.4f}, Proto: {train_losses['proto']:.4f}, "
          f"Aux: {train_losses['aux']:.4f}, Consist: {train_losses['consistency']:.4f}, "
          f"Contr: {train_losses['contrastive']:.4f}")
    print(f"  Macro F1: {macro_f1:.4f} | Weighted F1: {weighted_f1:.4f}")
    print(f"  Classes with F1 > 0.3: {good_classes}/{len(le_target.classes_)}")
    
    # Early stopping based on macro F1
    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        patience_counter = 0
        torch.save({
            'model': model.state_dict(),
            'epoch': epoch,
            'macro_f1': macro_f1
        }, 'best_mcdpan_model.pth')
        print(f"  ✓ New best Macro F1: {best_macro_f1:.4f}")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

# ===========================================
# FINAL EVALUATION
# ===========================================

print("\n" + "="*70)
print("LOADING BEST MODEL FOR FINAL EVALUATION")
print("="*70)

checkpoint = torch.load('best_mcdpan_model.pth')
model.load_state_dict(checkpoint['model'])
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for xc, xn, y in test_loader:
        xc, xn = xc.to(DEVICE), xn.to(DEVICE)
        proto_logits, aux_logits = model(xc, xn, return_features=False)
        ensemble_logits = (proto_logits + aux_logits) / 2
        preds = torch.argmax(ensemble_logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(y.numpy())

print("\n" + "="*70)
print("FINAL CLASSIFICATION REPORT")
print("="*70 + "\n")
print(classification_report(all_labels, all_preds, target_names=le_target.classes_, 
                          zero_division=0, digits=4))

per_class_f1 = f1_score(all_labels, all_preds, average=None, zero_division=0)
print("\n" + "="*70)
print("PER-CLASS F1 SCORES")
print("="*70)
for class_name, f1 in zip(le_target.classes_, per_class_f1):
    marker = "✓" if f1 > 0.5 else ("⚠" if f1 > 0.3 else "✗")
    print(f"{marker} {class_name:20s}: {f1:.4f}")

print(f"\n{'='*70}")
print(f"BEST MACRO F1 ACHIEVED: {best_macro_f1:.4f}")
print(f"Classes with F1 > 0.5: {np.sum(per_class_f1 > 0.5)}/{len(per_class_f1)}")
print(f"Classes with F1 > 0.3: {np.sum(per_class_f1 > 0.3)}/{len(per_class_f1)}")
print(f"{'='*70}")

Using device: cuda

Class distribution in training:
back                :    24258 samples
buffer_overflow     :    24258 samples
ftp_write           :    24258 samples
guess_passwd        :    24258 samples
imap                :    24258 samples
ipsweep             :    24258 samples
land                :    24258 samples
loadmodule          :    24258 samples
multihop            :    24258 samples
neptune             :    24258 samples
nmap                :    24258 samples
normal              :    24258 samples
perl                :    24258 samples
phf                 :    24258 samples
pod                 :    24258 samples
portsweep           :    24258 samples
rootkit             :    24258 samples
satan               :    24258 samples
smurf               :    24258 samples
spy                 :    24258 samples
teardrop            :    24258 samples
warezclient         :    24258 samples
warezmaster         :    24258 samples

Model parameters: 4,051,163

TRAINING STARTED


Epoch 1/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.73it/s, loss=0.8337]



Epoch 1:
  Loss - Total: 0.5964, Proto: 0.0684, Aux: 0.0723, Consist: 0.0097, Contr: 0.9778
  Macro F1: 0.4134 | Weighted F1: 0.8061
  Classes with F1 > 0.3: 11/23
  ✓ New best Macro F1: 0.4134


Epoch 2/60: 100%|██████████| 4359/4359 [01:07<00:00, 65.00it/s, loss=0.9549]



Epoch 2:
  Loss - Total: 0.5300, Proto: 0.0409, Aux: 0.0407, Consist: 0.0022, Contr: 0.9361
  Macro F1: 0.4064 | Weighted F1: 0.8055
  Classes with F1 > 0.3: 11/23


Epoch 3/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.26it/s, loss=0.8056]



Epoch 3:
  Loss - Total: 0.5143, Proto: 0.0350, Aux: 0.0351, Consist: 0.0022, Contr: 0.9222
  Macro F1: 0.4206 | Weighted F1: 0.8011
  Classes with F1 > 0.3: 12/23
  ✓ New best Macro F1: 0.4206


Epoch 4/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.76it/s, loss=0.8453]



Epoch 4:
  Loss - Total: 0.5043, Proto: 0.0312, Aux: 0.0317, Consist: 0.0027, Contr: 0.9129
  Macro F1: 0.3990 | Weighted F1: 0.7994
  Classes with F1 > 0.3: 12/23


Epoch 5/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.72it/s, loss=0.7841]



Epoch 5:
  Loss - Total: 0.4962, Proto: 0.0283, Aux: 0.0290, Consist: 0.0030, Contr: 0.9050
  Macro F1: 0.4498 | Weighted F1: 0.8038
  Classes with F1 > 0.3: 12/23
  ✓ New best Macro F1: 0.4498


Epoch 6/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.81it/s, loss=0.7530]



Epoch 6:
  Loss - Total: 0.4840, Proto: 0.0237, Aux: 0.0243, Consist: 0.0027, Contr: 0.8947
  Macro F1: 0.4252 | Weighted F1: 0.8093
  Classes with F1 > 0.3: 12/23


Epoch 7/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.87it/s, loss=0.8368]



Epoch 7:
  Loss - Total: 0.4763, Proto: 0.0213, Aux: 0.0219, Consist: 0.0026, Contr: 0.8866
  Macro F1: 0.4142 | Weighted F1: 0.8011
  Classes with F1 > 0.3: 12/23


Epoch 8/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.98it/s, loss=0.8307]



Epoch 8:
  Loss - Total: 0.4681, Proto: 0.0183, Aux: 0.0189, Consist: 0.0024, Contr: 0.8792
  Macro F1: 0.4276 | Weighted F1: 0.8166
  Classes with F1 > 0.3: 12/23


Epoch 9/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.10it/s, loss=0.7779]



Epoch 9:
  Loss - Total: 0.4639, Proto: 0.0171, Aux: 0.0176, Consist: 0.0023, Contr: 0.8747
  Macro F1: 0.4340 | Weighted F1: 0.8137
  Classes with F1 > 0.3: 12/23


Epoch 10/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.91it/s, loss=0.7577]



Epoch 10:
  Loss - Total: 0.4598, Proto: 0.0156, Aux: 0.0162, Consist: 0.0023, Contr: 0.8709
  Macro F1: 0.4202 | Weighted F1: 0.8224
  Classes with F1 > 0.3: 13/23


Epoch 11/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.07it/s, loss=0.8084]



Epoch 11:
  Loss - Total: 0.4562, Proto: 0.0145, Aux: 0.0150, Consist: 0.0022, Contr: 0.8669
  Macro F1: 0.4352 | Weighted F1: 0.8148
  Classes with F1 > 0.3: 12/23


Epoch 12/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.18it/s, loss=0.7097]



Epoch 12:
  Loss - Total: 0.4533, Proto: 0.0134, Aux: 0.0138, Consist: 0.0021, Contr: 0.8647
  Macro F1: 0.3929 | Weighted F1: 0.7942
  Classes with F1 > 0.3: 12/23


Epoch 13/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.98it/s, loss=0.8041]



Epoch 13:
  Loss - Total: 0.4506, Proto: 0.0127, Aux: 0.0131, Consist: 0.0021, Contr: 0.8614
  Macro F1: 0.4591 | Weighted F1: 0.8217
  Classes with F1 > 0.3: 12/23
  ✓ New best Macro F1: 0.4591


Epoch 14/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.60it/s, loss=0.7132]



Epoch 14:
  Loss - Total: 0.4474, Proto: 0.0117, Aux: 0.0120, Consist: 0.0020, Contr: 0.8582
  Macro F1: 0.4172 | Weighted F1: 0.8186
  Classes with F1 > 0.3: 12/23


Epoch 15/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.28it/s, loss=0.7743]



Epoch 15:
  Loss - Total: 0.4448, Proto: 0.0109, Aux: 0.0112, Consist: 0.0019, Contr: 0.8555
  Macro F1: 0.4214 | Weighted F1: 0.8133
  Classes with F1 > 0.3: 11/23


Epoch 16/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.07it/s, loss=0.8091]



Epoch 16:
  Loss - Total: 0.4430, Proto: 0.0104, Aux: 0.0107, Consist: 0.0018, Contr: 0.8533
  Macro F1: 0.4438 | Weighted F1: 0.8136
  Classes with F1 > 0.3: 12/23


Epoch 17/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.24it/s, loss=0.8368]



Epoch 17:
  Loss - Total: 0.4420, Proto: 0.0099, Aux: 0.0102, Consist: 0.0018, Contr: 0.8531
  Macro F1: 0.4449 | Weighted F1: 0.8171
  Classes with F1 > 0.3: 11/23


Epoch 18/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.14it/s, loss=0.7687]



Epoch 18:
  Loss - Total: 0.4388, Proto: 0.0090, Aux: 0.0094, Consist: 0.0017, Contr: 0.8492
  Macro F1: 0.3919 | Weighted F1: 0.8092
  Classes with F1 > 0.3: 12/23


Epoch 19/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.89it/s, loss=0.7218]



Epoch 19:
  Loss - Total: 0.4382, Proto: 0.0088, Aux: 0.0091, Consist: 0.0016, Contr: 0.8486
  Macro F1: 0.4314 | Weighted F1: 0.8166
  Classes with F1 > 0.3: 11/23


Epoch 20/60: 100%|██████████| 4359/4359 [01:07<00:00, 64.90it/s, loss=0.8767]



Epoch 20:
  Loss - Total: 0.4362, Proto: 0.0082, Aux: 0.0085, Consist: 0.0016, Contr: 0.8464
  Macro F1: 0.4206 | Weighted F1: 0.8063
  Classes with F1 > 0.3: 12/23


Epoch 21/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.42it/s, loss=0.8597]



Epoch 21:
  Loss - Total: 0.4349, Proto: 0.0079, Aux: 0.0081, Consist: 0.0015, Contr: 0.8450
  Macro F1: 0.4194 | Weighted F1: 0.8104
  Classes with F1 > 0.3: 12/23


Epoch 22/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.35it/s, loss=0.7385]



Epoch 22:
  Loss - Total: 0.4336, Proto: 0.0075, Aux: 0.0077, Consist: 0.0014, Contr: 0.8436
  Macro F1: 0.4217 | Weighted F1: 0.8209
  Classes with F1 > 0.3: 11/23


Epoch 23/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.45it/s, loss=0.8657]



Epoch 23:
  Loss - Total: 0.4317, Proto: 0.0070, Aux: 0.0073, Consist: 0.0014, Contr: 0.8412
  Macro F1: 0.4194 | Weighted F1: 0.8159
  Classes with F1 > 0.3: 11/23


Epoch 24/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.29it/s, loss=0.7939]



Epoch 24:
  Loss - Total: 0.4307, Proto: 0.0066, Aux: 0.0068, Consist: 0.0013, Contr: 0.8406
  Macro F1: 0.4195 | Weighted F1: 0.8173
  Classes with F1 > 0.3: 11/23


Epoch 25/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.59it/s, loss=0.7235]



Epoch 25:
  Loss - Total: 0.4293, Proto: 0.0062, Aux: 0.0063, Consist: 0.0013, Contr: 0.8391
  Macro F1: 0.4178 | Weighted F1: 0.8053
  Classes with F1 > 0.3: 11/23


Epoch 26/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.56it/s, loss=0.8943]



Epoch 26:
  Loss - Total: 0.4280, Proto: 0.0059, Aux: 0.0061, Consist: 0.0012, Contr: 0.8375
  Macro F1: 0.4185 | Weighted F1: 0.8068
  Classes with F1 > 0.3: 11/23


Epoch 27/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.32it/s, loss=0.7256]



Epoch 27:
  Loss - Total: 0.4269, Proto: 0.0058, Aux: 0.0059, Consist: 0.0011, Contr: 0.8357
  Macro F1: 0.4136 | Weighted F1: 0.8055
  Classes with F1 > 0.3: 11/23


Epoch 28/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.36it/s, loss=0.7142]



Epoch 28:
  Loss - Total: 0.4259, Proto: 0.0053, Aux: 0.0055, Consist: 0.0011, Contr: 0.8350
  Macro F1: 0.4174 | Weighted F1: 0.8145
  Classes with F1 > 0.3: 11/23


Epoch 29/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.76it/s, loss=0.7897]



Epoch 29:
  Loss - Total: 0.4244, Proto: 0.0050, Aux: 0.0052, Consist: 0.0010, Contr: 0.8330
  Macro F1: 0.4225 | Weighted F1: 0.8205
  Classes with F1 > 0.3: 11/23


Epoch 30/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.77it/s, loss=0.7251]



Epoch 30:
  Loss - Total: 0.4237, Proto: 0.0048, Aux: 0.0049, Consist: 0.0009, Contr: 0.8324
  Macro F1: 0.4221 | Weighted F1: 0.8177
  Classes with F1 > 0.3: 11/23


Epoch 31/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.82it/s, loss=0.7664]



Epoch 31:
  Loss - Total: 0.4229, Proto: 0.0047, Aux: 0.0048, Consist: 0.0009, Contr: 0.8310
  Macro F1: 0.4187 | Weighted F1: 0.8073
  Classes with F1 > 0.3: 11/23


Epoch 32/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.52it/s, loss=0.7124]



Epoch 32:
  Loss - Total: 0.4215, Proto: 0.0042, Aux: 0.0043, Consist: 0.0008, Contr: 0.8298
  Macro F1: 0.4401 | Weighted F1: 0.8201
  Classes with F1 > 0.3: 11/23


Epoch 33/60: 100%|██████████| 4359/4359 [01:06<00:00, 65.51it/s, loss=0.7916]



Epoch 33:
  Loss - Total: 0.4209, Proto: 0.0041, Aux: 0.0042, Consist: 0.0008, Contr: 0.8289
  Macro F1: 0.4389 | Weighted F1: 0.8022
  Classes with F1 > 0.3: 11/23

Early stopping at epoch 33

LOADING BEST MODEL FOR FINAL EVALUATION

FINAL CLASSIFICATION REPORT



ValueError: Number of classes, 22, does not match size of target_names, 23. Try specifying the labels parameter