In [None]:
from torch.optim import Adam
from sklearn.utils.class_weight import compute_class_weight
from pathlib import Path


import math
from torch.optim import AdamW

import numpy as np
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision, BinaryF1Score, BinaryPrecision, BinaryRecall, BinaryHammingDistance

# Import model
from Transformer_Archs.BiLTSM import BiLSTM

# Release cached, currently unused GPU memory held by PyTorch's CUDA allocator
torch.cuda.empty_cache()


In [None]:
# Suffix selecting which preprocessed attack-data variant is loaded.
# NOTE(review): this comment previously listed only 5% / 10% variants while
# '15' is selected — confirm the *_15.pt files are the intended ones.
option = '15'


def _load_split(split):
    """Load one data split from disk and wrap it in a DataLoader.

    Args:
        split (str): File-name prefix of the split ('train', 'val' or 'test').

    Returns:
        tuple: (dataset, config, loader), where `config` holds the keyword
        arguments that are forwarded to `DataLoader`.
    """
    # weights_only=False unpickles arbitrary Python objects, so only load
    # files that were produced by a trusted preprocessing step.
    dataset = torch.load('./Preprocessed_data/{}_dataset_{}.pt'.format(split, option), weights_only=False)
    config = torch.load('./Preprocessed_data/{}_config_{}.pt'.format(split, option), weights_only=False)
    return dataset, config, DataLoader(dataset, **config)


# Get data loaders
train_dataset, train_config, train_loader = _load_split('train')
val_dataset, val_config, val_loader = _load_split('val')
test_dataset, test_config, test_loader = _load_split('test')

# Set features and target size
num_features = 86
out_features = 10
seq_len = 12

In [None]:
# Pull a single batch so the sizes below come from the data itself
xb, yb = next(iter(train_loader))

# Sizes read off the sample batch
# (assumes inputs are (batch, seq, features) and targets (batch, labels) — TODO confirm)
seq_len = xb.size(1)
in_channels = xb.size(2)
out_features = yb.size(1)

# Fixed hyperparameters
out_channels = 32
kernel_size = 3
hidden_size = 64
lstm_layers = 2
output_size = out_features

# Prefer the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


In [None]:
# Build the network from the hyperparameters defined in the configuration
# cell instead of repeating their literal values, so that changing
# `hidden_size` / `lstm_layers` there actually changes the model.
model = BiLSTM(
    input_size=num_features,
    hidden_size=hidden_size,
    num_layers=lstm_layers,
    num_labels=out_features,
).to(device)


In [None]:
def prepare_batch(x, y):
    """
    Bring a batch into a fixed layout/dtype and move it to `device`.

    x is returned as (batch, num_features, seq_len): an input laid out as
    (batch, seq_len, num_features) is permuted, one already in the target
    layout is passed through.
    y is returned as int64: when its last axis has `out_features` entries the
    argmax over that axis is taken, otherwise it is cast directly.

    Raises:
        ValueError: if x is not 3D, or its last two dims match neither layout.
    """
    if x.dim() != 3:
        raise ValueError(f"Expected 3D input (B, S, F) or (B, F, S), got {x.shape}")

    _, dim1, dim2 = x.shape
    seq_major = (dim1, dim2) == (seq_len, num_features)
    feat_major = (dim1, dim2) == (num_features, seq_len)

    if not (seq_major or feat_major):
        # Neither known layout fits: fail loudly instead of guessing
        raise ValueError(f"Input shape {x.shape} doesn't match either (B,{seq_len},{num_features}) or (B,{num_features},{seq_len}).")
    if seq_major:
        # (B, S, F) -> (B, F, S)
        x = x.permute(0, 2, 1)

    # Targets that are not already int64 are converted
    if y.dtype != torch.long:
        last_axis_is_labels = y.dim() > 1 and y.size(-1) == out_features
        y = y.argmax(dim=-1) if last_axis_is_labels else y.long()

    return x.to(device), y.to(device)


In [None]:
# Loss, optimiser and learning-rate schedule
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)  # other candidate values noted: 0.001, 5e-4, 1e-3

# Multiply the learning rate by 0.5 when the monitored value (mode='min')
# stops decreasing, with a patience of 5 scheduler steps
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.5, patience=5)

class EarlyStopping:
    """Track a validation loss and flag when it has stopped improving.

    Call the instance once per epoch with the current validation loss, then
    read `early_stop`. The flag is set once `patience` consecutive calls have
    failed to improve on the best loss by at least `min_delta`; it is never
    cleared again by this class.
    """

    def __init__(self, patience=5, min_delta=0):
        """
        Args:
            patience (int): How many epochs to wait after last time validation loss improved.
            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        """
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None   # lowest loss accepted so far (None until a non-NaN loss is seen)
        self.counter = 0        # consecutive calls without an improvement
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update `early_stop`.

        Args:
            val_loss: The epoch's validation loss. A NaN value is counted as
                an epoch without improvement and never becomes `best_loss`.
        """
        if math.isnan(val_loss):
            # NaN compares False against everything, so without this guard it
            # would be taken as an improvement, overwrite best_loss with NaN
            # and keep resetting the counter forever.
            improved = False
        elif self.best_loss is None:
            # First usable loss becomes the baseline.
            improved = True
        else:
            improved = val_loss <= self.best_loss - self.min_delta

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler
from torchmetrics.classification import (
    BinaryAccuracy, BinaryAveragePrecision, BinaryF1Score, BinaryPrecision,
    BinaryRecall, BinaryHammingDistance
)

# Use the device the model's first parameter lives on (replaces the earlier `device` value)
device = next(model.parameters()).device
# Gradient scaler for mixed precision; it is only enabled when CUDA is available
scaler = GradScaler(enabled=torch.cuda.is_available())
# Threshold handed to the thresholded metrics (accuracy, F1, precision, recall, Hamming)
THR = 0.5 

def ensure_channels_first(x, model):
    """Return x with its last two axes swapped when `model.conv1` needs it.

    If `model` has a `conv1` attribute, x is 3D, and `conv1.in_channels`
    matches x's last axis but not its middle axis, a contiguous copy with
    axes 1 and 2 transposed is returned. In every other case x is returned
    unchanged.
    """
    if not hasattr(model, "conv1"):
        return x

    expected = model.conv1.in_channels
    channels_last = (
        x.ndim == 3
        and x.shape[1] != expected
        and x.shape[2] == expected
    )
    return x.transpose(1, 2).contiguous() if channels_last else x

@torch.no_grad()
def compute_pos_weight(train_loader, device):
    """Compute per-label `pos_weight` values for `BCEWithLogitsLoss`.

    Makes a single pass over `train_loader`, counting positives per label
    column, and returns negatives / positives for each label.

    Args:
        train_loader: Iterable yielding `(inputs, targets)` batches whose
            targets are 2D with one column per label.
        device: Device the resulting weight tensor is moved to.

    Returns:
        tuple: `(pos_weight, num_labels)` — a 1D tensor with one weight per
        label on `device`, and the number of label columns.

    Raises:
        ValueError: if `train_loader` yields no batches.
    """
    pos = None
    num_labels = 0
    total = 0
    for _, yb in train_loader:
        if pos is None:
            # Size the accumulator from the first batch instead of opening a
            # second iterator over the loader just to read the label count.
            num_labels = yb.shape[1]
            pos = torch.zeros(num_labels)
        pos += yb.sum(dim=0)
        total += yb.size(0)

    if pos is None:
        raise ValueError("train_loader yielded no batches; cannot compute pos_weight")

    neg = total - pos
    # eps avoids division by zero; a label with no positives therefore gets a
    # very large weight (about neg / eps).
    eps = 1e-6
    return ((neg + eps) / (pos + eps)).to(device), num_labels

# Re-create the loss with per-label positive weights computed from the training data
pos_weight, NUM_LABELS = compute_pos_weight(train_loader=train_loader, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)


In [None]:
def _epoch_pass(loader, train_mode: bool):
    """Run the model once over `loader` and return loss and micro-averaged metrics.

    Uses the module-level `model`, `criterion`, `optimizer`, `scaler`,
    `device` and `THR`. When `train_mode` is true the model is put in
    training mode and an optimiser step is taken per batch; otherwise the
    model is put in eval mode and no parameters are updated.

    Args:
        loader: Iterable yielding `(inputs, targets)` batches.
        train_mode (bool): Whether to back-propagate and step the optimiser.

    Returns:
        dict: keys "loss" (sample-weighted mean loss), "acc", "ap", "f1",
        "precision", "recall" and "hamming", each a Python float. The metrics
        are computed over all (sample, label) pairs flattened into one
        binary problem.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    model.train(train_mode)

    total_loss, n = 0.0, 0
    logits_all, targets_all = [], []

    for x, y in loader:
        x = x.to(device)
        y = y.to(device).float()
        x = ensure_channels_first(x, model)

        if train_mode:
            optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = criterion(logits, y)

        if train_mode:
            scaler.scale(loss).backward()
            # The gradients are still multiplied by the loss scale at this
            # point; unscale them first so the 1.0 clipping threshold applies
            # to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

        bs = x.size(0)
        total_loss += loss.item() * bs
        n += bs

        # Move each batch's outputs to the CPU right away so a full epoch of
        # predictions is not kept in GPU memory; cast to float32 because the
        # logits can be half precision under autocast.
        logits_all.append(logits.detach().float().cpu())
        targets_all.append(y.detach().cpu())

    if not logits_all:
        raise ValueError("loader yielded no batches; nothing to evaluate")

    # --- epoch metrics (Binary, micro over all labels) ---
    logits_all = torch.cat(logits_all, dim=0)
    targets_all = torch.cat(targets_all, dim=0).int()
    probs_all = torch.sigmoid(logits_all)

    # flatten to treat as one big binary task ("micro")
    p_flat = probs_all.reshape(-1)
    t_flat = targets_all.reshape(-1)

    acc     = BinaryAccuracy(threshold=THR)(p_flat, t_flat).item()
    ap      = BinaryAveragePrecision()(p_flat, t_flat).item()
    f1      = BinaryF1Score(threshold=THR)(p_flat, t_flat).item()
    prec    = BinaryPrecision(threshold=THR)(p_flat, t_flat).item()
    rec     = BinaryRecall(threshold=THR)(p_flat, t_flat).item()
    hamming = BinaryHammingDistance(threshold=THR)(p_flat, t_flat).item()

    avg_loss = total_loss / max(1, n)

    return {
        "loss": avg_loss, "acc": acc, "ap": ap, "f1": f1,
        "precision": prec, "recall": rec, "hamming": hamming
    }

def train_epoch():
    """Run one training pass over `train_loader` and return its metrics dict."""
    metrics = _epoch_pass(loader=train_loader, train_mode=True)
    return metrics

@torch.no_grad()
def validate_epoch():
    """Run one gradient-free pass over `val_loader` and return its metrics dict."""
    metrics = _epoch_pass(loader=val_loader, train_mode=False)
    return metrics


In [None]:
# Upper bound on the number of epochs; early stopping may end training sooner
EPOCHS = 150

# Stop after 10 consecutive epochs whose validation loss fails to improve by 1e-4
early_stopping = EarlyStopping(min_delta=1e-4, patience=10)

# Per-epoch history
train_loss_list, train_acc_list = [], []
val_loss_list, val_acc_list = [], []

for epoch in range(1, EPOCHS + 1):
    tr = train_epoch()
    va = validate_epoch()

    train_loss_list.append(tr["loss"])
    train_acc_list.append(tr["acc"])
    val_loss_list.append(va["loss"])
    val_acc_list.append(va["acc"])

    # Both the LR scheduler and early stopping monitor the validation loss
    scheduler.step(va["loss"])
    early_stopping(va["loss"])

    # ---- EXACT FORMAT ----
    print(f"Train Epoch: {epoch} - Training Loss: {tr['loss']:.5f} Training accuracy: {tr['acc']*100:.3f}%")
    print(f"Valid  Epoch: {epoch} - Validation Loss: {va['loss']:.5f} Validation accuracy: {va['acc']*100:.3f}%")

    # (optional) print the rest of the metrics each epoch
    # print(f"  AP: {va['ap']:.6f}  F1: {va['f1']:.6f}  P: {va['precision']:.6f}  R: {va['recall']:.6f}  Hamming: {va['hamming']:.6f}")

    if not early_stopping.early_stop:
        continue
    print("Early stopping triggered. Stopping training.")
    break


In [None]:
@torch.no_grad()
def test_epoch():
    """Evaluate the model on `test_loader`.

    Delegates to `_epoch_pass` with `train_mode=False`, which puts the model
    in eval mode before the pass and computes the same micro-averaged binary
    metrics that are reported for training and validation.

    Returns:
        tuple: `(avg_loss, acc, ap, f1, precision, recall, hamming)`.
    """
    metrics = _epoch_pass(test_loader, train_mode=False)
    return (
        metrics["loss"],
        metrics["acc"],
        metrics["ap"],
        metrics["f1"],
        metrics["precision"],
        metrics["recall"],
        metrics["hamming"],
    )

# ---- Run after training (optionally load best weights first) ----
# model.load_state_dict(torch.load(best_path, map_location=device))

test_loss, test_acc, ap, f1, prec, rec, ham = test_epoch()

# Headline line first, then one line per remaining metric
print(f"Test set: Loss: {test_loss:.6f}, Accuracy: {test_acc*100:.3f}%")
for label, value in (
    ("Average Precision", ap),
    ("F1 Score", f1),
    ("Precision", prec),
    ("Recall", rec),
    ("Hamming Distance", ham),
):
    print(f"{label}: {value:.6f}")
