In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler, RandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
from omegaconf import OmegaConf
import os
import torchmetrics
import warnings
from kaggle_secrets import UserSecretsClient

# Silence every Python warning for the rest of this session.
warnings.filterwarnings('ignore')

# Autoencoder experiment configuration for the CIC-ToN-IoT run.
# NOTE(review): `hidden_size`, `num_layers`, `dense_units` and `gpus` are not
# read by any code visible in this cell -- confirm whether they are still needed.
config = OmegaConf.create({
    "wandb": {
        "project": "DL-NIDS-2--cic-ton-iot",
        "entity": "mohammad-fleity-lebanese-university",
        "tags": ["AutoEncoderDecoder", "cic-ton-iot", "PyTorch"],
        "notes": "Optimized AutoEncoderDecoder for network intrusion detection with limited samples"
    },
    "model": {
        "hidden_size": 128,
        "num_layers": 2,
        "dropout": 0.4,
        "dense_units": [128, 64],
        "learning_rate": 0.0001,
        "weight_decay": 1e-4
    },
    "training": {
        "sequence_length": 4,
        "batch_size": 64,
        "max_epochs": 40,            # Upper bound on the number of training epochs
        "early_stopping_patience": 7,
        "oversample": True,
        "gpus": 1 if torch.cuda.is_available() else 0,
        "max_train_samples": 100000,  # Cap on training windows kept after sequencing
        "max_val_samples": 20000,     # Cap on validation windows
        "max_test_samples": 15000     # Cap on test windows
    },
    "data": {
        "raw": "cic_ton_iot.parquet",
        "num_workers": 2
    }
})


class AutoEncoderModel(pl.LightningModule):
    """Fully connected autoencoder with a classification head on the latent code.

    Each input window is flattened and encoded into a 32-dimensional latent
    vector. That vector feeds both a decoder (reconstruction of the window) and
    a classifier (class logits). The optimised loss is the unweighted sum of the
    MSE reconstruction loss and the cross-entropy classification loss.

    Accuracy is accumulated over all batches of an epoch and logged once, as a
    percentage, from the ``on_*_epoch_end`` hooks; the metric is reset right
    after so that one epoch never leaks into the next.

    Args:
        input_size: Length of the flattened input, i.e. ``seq_len * features``
            of the 3-D tensors passed to :meth:`forward`.
        num_classes: Number of target classes (width of the logits).
        config: Configuration exposing ``model.dropout``,
            ``model.learning_rate`` and ``model.weight_decay``.
    """

    def __init__(self, input_size, num_classes, config):
        super().__init__()
        self.save_hyperparameters({'config': config})
        self.config = config
        dropout = config.model.dropout

        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32)
        )

        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, input_size)  # match input size for reconstruction
        )

        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes)
        )

        self.recon_loss = nn.MSELoss()
        self.class_loss = nn.CrossEntropyLoss()

        # One accuracy metric per stage so their running states never mix.
        self.train_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return ``(x_hat, logits)`` for a 3-D batch ``(batch, seq_len, features)``.

        ``x_hat`` has the same shape as ``x``; ``logits`` is ``(batch, num_classes)``.
        """
        batch_size, seq_len, features = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(batch_size, -1)
        z = self.encoder(x_flat)

        x_hat = self.decoder(z).reshape(batch_size, seq_len, features)
        logits = self.classifier(z)
        return x_hat, logits

    def _shared_step(self, batch, metric):
        """Run one batch, update ``metric`` and return the combined loss."""
        x, y = batch
        x_hat, logits = self(x)
        loss = self.recon_loss(x_hat, x) + self.class_loss(logits, y)
        metric.update(torch.argmax(logits, dim=1), y)
        return loss

    def _log_epoch_accuracy(self, name, metric):
        """Log the accuracy accumulated over the epoch (in percent) and reset it."""
        self.log(name, metric.compute() * 100, prog_bar=True)
        metric.reset()

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.train_acc)
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self._log_epoch_accuracy("train_acc_epoch", self.train_acc)

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.val_acc)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self._log_epoch_accuracy("val_acc", self.val_acc)

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.test_acc)
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

    def on_test_epoch_end(self):
        self._log_epoch_accuracy("test_acc", self.test_acc)

    def configure_optimizers(self):
        """AdamW over all parameters, using the saved hyperparameters."""
        return optim.AdamW(
            self.parameters(),
            lr=self.hparams.config.model.learning_rate,
            weight_decay=self.hparams.config.model.weight_decay
        )

# class AutoEncoderModel(pl.LightningModule):
#     # model = AutoEncoderModel(total_input_size, num_classes, config)
    
#     def __init__(self, input_size, num_classes ,config):
#         super().__init__()
#         self.save_hyperparameters({'config': config})  # Save config to hparams
        
#         self.encoder = nn.Sequential(
#             nn.Linear(input_size, 128),
#             nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(128, 64),
#             nn.ReLU(),
#             # nn.Dropout(config.model.dropout),
#             # nn.Linear(128, 64),
#             # nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(64, 32)
#         )

#         self.decoder = nn.Sequential(
#             nn.Linear(32, 64),
#             nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(64, 128),
#             nn.ReLU(),
#             # nn.Dropout(config.model.dropout),
#             # nn.Linear(128, 128),
#             # nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(128, input_size)
#         )
#         self.classifier = nn.Sequential(
#             nn.Linear(32, 64),
#             nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(64, num_classes)
#         )


#         self.criterion = nn.MSELoss()

    
#     # def forward(self, x):
#     #     batch_size, seq_len, features = x.shape
#     #     x_flat = x.view(batch_size, -1)  # Flatten input
    
#     #     z = self.encoder(x_flat)
#     #     x_hat = self.decoder(z)
#     #     x_hat = x_hat.view(batch_size, seq_len, features)  # Reshape to original dimensions
        
#     #     return x_hat

#     def forward(self, x):
#         batch_size, seq_len, features = x.shape
#         x_flat = x.view(batch_size, -1)
        
#         z = self.encoder(x_flat)
#         x_hat = self.decoder(z).view(batch_size, seq_len, features)
#         logits = self.classifier(z)
    
#         return x_hat, logits

#     # def training_step(self, batch, batch_idx):
#     #     x, _ = batch  # Only x is used
#     #     x_hat = self(x)
#     #     loss = self.criterion(x_hat, x)
#     #     self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
#     #     return loss
#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         x_hat, logits = self(x)
#         loss_recon = self.criterion(x_hat, x)
#         loss_cls = nn.CrossEntropyLoss()(logits, y)
#         loss = loss_recon + loss_cls
#         self.log("train_loss", loss)
#         return loss


#     def validation_step(self, batch, batch_idx):
#         x, _ = batch
#         x_hat = self(x)
#         loss = self.criterion(x_hat, x)
#         self.log("val_loss", loss, prog_bar=True)
#         return loss

#     def test_step(self, batch, batch_idx):
#         x, y = batch
#         x_hat = self(x)
#         loss = self.criterion(x_hat, x)
#         self.log("test_loss", loss)
#         return {"test_loss": loss, "reconstruction_error": torch.mean((x_hat - x)**2, dim=1), "true_label": y}

#     def configure_optimizers(self):
#         return optim.AdamW(
#             self.parameters(),
#             lr=self.hparams.config.model.learning_rate,
#             weight_decay=self.hparams.config.model.weight_decay
#         )

class NIDSDataModule(pl.LightningDataModule):
    """Loads the parquet dataset, builds fixed-length windows and serves loaders.

    Pipeline (see :meth:`prepare_data`): read parquet -> drop inf/NaN/duplicate
    rows -> label-encode ``Label`` -> stratified 70/15/15 split -> standardise
    (scaler fitted on the training split only) -> sliding windows -> optional
    random down-sampling to the configured per-split caps.

    Args:
        config: Configuration exposing ``training.batch_size``,
            ``training.sequence_length``, ``training.oversample``,
            ``training.max_{train,val,test}_samples``, ``data.num_workers``
            and ``data.raw`` (parquet file name).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.training.batch_size
        self.sequence_length = config.training.sequence_length
        self.num_workers = config.data.num_workers
        self.oversample = config.training.oversample
        self.max_train_samples = config.training.max_train_samples
        self.max_val_samples = config.training.max_val_samples
        self.max_test_samples = config.training.max_test_samples
        # Set once prepare_data() has populated the X_*/y_* arrays.
        self._data_ready = False

    def prepare_data(self):
        """Load, clean, split, scale and window the dataset (idempotent).

        The Trainer calls this hook again even when the caller already invoked
        it manually; the guard avoids re-reading the parquet file and
        re-drawing a different random subset on the second call.
        """
        if self._data_ready:
            return

        # /kaggle/input/cic-ton-iot-parquet
        df = pd.read_parquet(os.path.join('/kaggle/input/cic-ton-iot-parquet', self.config.data.raw))

        # Clean data
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)

        # Columns excluded from the feature matrix (only those actually present).
        candidate_cols = ['Label', 'Timestamp', 'Flow ID', 'Src IP',
                          'Src Port', 'Attack', 'Dst IP', 'Dst Port', 'Protocol']
        self.non_numeric_cols = [col for col in candidate_cols if col in df.columns]

        # Encode labels
        self.label_encoder = LabelEncoder()
        df['Label_Num'] = self.label_encoder.fit_transform(df['Label'])
        self.classes = self.label_encoder.classes_

        # Initialize scaler
        self.scaler = StandardScaler()

        # Stratified split: 70% train, 30% held out.
        train_df, holdout_df = train_test_split(
            df,
            test_size=0.3,
            random_state=42,
            stratify=df['Label_Num']
        )

        # Held-out part is halved: 15% validation, 15% test.
        val_df, test_df = train_test_split(
            holdout_df,
            test_size=0.5,
            random_state=42,
            stratify=holdout_df['Label_Num']
        )

        # The scaler is fitted on the training split only.
        self.X_train, self.y_train = self._prepare_features(train_df, fit=True)
        self.X_val, self.y_val = self._prepare_features(val_df, fit=False)
        self.X_test, self.y_test = self._prepare_features(test_df, fit=False)

        # Apply sample limits
        self._limit_samples()
        self._data_ready = True

    @staticmethod
    def _subsample(X, y, limit):
        """Return at most ``limit`` randomly chosen (X, y) pairs, without replacement."""
        if len(X) <= limit:
            return X, y
        keep = np.random.choice(len(X), limit, replace=False)
        return X[keep], y[keep]

    def _limit_samples(self):
        """Limit samples according to configuration"""
        self.X_train, self.y_train = self._subsample(self.X_train, self.y_train, self.max_train_samples)
        self.X_val, self.y_val = self._subsample(self.X_val, self.y_val, self.max_val_samples)
        self.X_test, self.y_test = self._subsample(self.X_test, self.y_test, self.max_test_samples)

    def _prepare_features(self, df, fit=False):
        """Scale the feature columns of ``df`` and turn them into windows.

        Args:
            df: One split of the dataset, including ``Label_Num``.
            fit: Fit the scaler on this split before transforming it.

        Returns:
            ``(windows, labels)`` as produced by :meth:`create_sequences`.
        """
        X = df.drop(['Label_Num'] + self.non_numeric_cols, axis=1)
        y = df['Label_Num']
        if fit:
            X = self.scaler.fit_transform(X)
        else:
            X = self.scaler.transform(X)
        return self.create_sequences(X, y)

    def create_sequences(self, X, y):
        """Build every window of ``sequence_length`` consecutive rows.

        Window ``i`` covers rows ``i .. i + sequence_length - 1`` and carries
        the label of its last row. All ``len(X) - sequence_length + 1`` windows
        are produced, including the final one.

        NOTE(review): windows follow the row order of the split handed in;
        ``train_test_split`` shuffles by default, so rows inside a window were
        not necessarily adjacent in the source file -- confirm this is intended.

        Returns:
            ``(windows, labels)`` with shapes
            ``(n_windows, sequence_length, n_features)`` and ``(n_windows,)``.
        """
        features = np.asarray(X)
        targets = np.asarray(y)
        n_windows = len(features) - self.sequence_length + 1
        if n_windows <= 0:
            empty_windows = np.empty(
                (0, self.sequence_length) + features.shape[1:], dtype=features.dtype
            )
            return empty_windows, np.empty((0,), dtype=targets.dtype)

        # Row-index matrix: entry (i, k) is i + k, i.e. the k-th row of window i.
        row_idx = np.arange(n_windows)[:, None] + np.arange(self.sequence_length)
        return features[row_idx], targets[self.sequence_length - 1:]

    def setup(self, stage=None):
        """Wrap the prepared arrays in ``TensorDataset`` objects."""
        # Covers callers that invoke setup() without a prior prepare_data().
        if not self._data_ready:
            self.prepare_data()

        self.train_dataset = TensorDataset(torch.FloatTensor(self.X_train), torch.LongTensor(self.y_train))
        self.val_dataset = TensorDataset(torch.FloatTensor(self.X_val), torch.LongTensor(self.y_val))
        self.test_dataset = TensorDataset(torch.FloatTensor(self.X_test), torch.LongTensor(self.y_test))

        print(f"Training samples: {len(self.train_dataset)} (limited to {self.max_train_samples})")
        print(f"Validation samples: {len(self.val_dataset)} (limited to {self.max_val_samples})")
        print(f"Test samples: {len(self.test_dataset)} (limited to {self.max_test_samples})")

    def train_dataloader(self):
        """Training loader; with ``oversample`` each class is drawn equally often."""
        if self.oversample:
            # Per-sample weight = 1 / (size of the sample's class).
            class_counts = np.bincount(self.y_train)
            weights = 1. / class_counts[self.y_train]
            sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        else:
            sampler = RandomSampler(self.train_dataset)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

def init_wandb():
    """Authenticate with Weights & Biases and open a run for this experiment.

    The API key is read from the Kaggle secret ``mohammad_wandb_secret``; run
    metadata comes from the module-level ``config``.

    Returns:
        ``(wandb_logger, run)``: a ``WandbLogger`` bound to the new run, and the
        run object itself.
    """
    api_key = UserSecretsClient().get_secret("mohammad_wandb_secret")
    wandb.login(key=api_key)

    training_cfg = config.training
    # input_size / num_classes are unknown until the data has been loaded.
    run_config = {
        "input_size": None,
        "num_classes": None,
        "sequence_length": training_cfg.sequence_length,
        "train_samples": training_cfg.max_train_samples,
        "val_samples": training_cfg.max_val_samples,
        "test_samples": training_cfg.max_test_samples,
        "model_config": dict(config.model),
        "training_config": dict(config.training)
    }

    wandb_cfg = config.wandb
    run = wandb.init(
        project=wandb_cfg.project,
        entity=wandb_cfg.entity,
        tags=wandb_cfg.tags,
        notes=wandb_cfg.notes,
        config=run_config
    )

    return WandbLogger(experiment=run, log_model='all'), run

def main():
    """Train the autoencoder, evaluate it on the test split and log to wandb.

    Steps: seed RNGs, open the wandb run, prepare the data module, build the
    model from the observed feature width, fit with early stopping and
    checkpointing on ``val_loss``, run ``trainer.test``, then recompute
    predictions on the test loader to log accuracy, weighted F1, a confusion
    matrix and a per-class classification report.
    """
    # deterministic=True below is only meaningful with seeded RNGs.
    pl.seed_everything(42, workers=True)

    wandb_logger, run = init_wandb()

    data_module = NIDSDataModule(config)
    data_module.prepare_data()
    data_module.setup()

    # Peek at one batch to learn the number of features per timestep.
    sample_x, _ = next(iter(data_module.train_dataloader()))
    input_size_per_timestep = sample_x.shape[2]  # Features per timestep
    total_input_size = input_size_per_timestep * config.training.sequence_length
    num_classes = len(data_module.classes)

    run.config.update({
        "input_size_per_timestep": input_size_per_timestep,
        "total_input_size": total_input_size,
        "num_classes": num_classes
    })

    model = AutoEncoderModel(total_input_size, num_classes, config)

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=config.training.early_stopping_patience,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        dirpath='checkpoints',
        filename='best_model'
    )

    trainer = pl.Trainer(
        # Half precision is only requested when a CUDA device is present.
        precision=16 if torch.cuda.is_available() else 32,
        logger=wandb_logger,
        max_epochs=config.training.max_epochs,
        callbacks=[early_stopping, checkpoint_callback],
        deterministic=True,
        gradient_clip_val=1.0,
        enable_progress_bar=True,
        log_every_n_steps=1000
    )

    trainer.fit(model, datamodule=data_module)

    test_results = trainer.test(model, datamodule=data_module)

    # Collect all predictions and targets
    test_loader = data_module.test_dataloader()
    all_preds = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            # forward() returns (reconstruction, logits); only logits are needed.
            _, logits = model(x.to(model.device))
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    # Calculate metrics
    test_acc = accuracy_score(all_targets, all_preds)
    test_f1 = f1_score(all_targets, all_preds, average='weighted')

    # Log final metrics
    # NOTE(review): 'test_acc' here is a fraction in [0, 1], whereas the model
    # logs the same key as a percentage -- consider unifying the scale.
    wandb.log({
        'test_acc': test_acc,
        'test_f1': test_f1,
        'test_loss': test_results[0]['test_loss']
    })

    # Enhanced multiclass confusion matrix
    class_names = data_module.classes.tolist()
    # Explicit label ids keep the matrix/report aligned with class_names even
    # when a class is missing from the (sub-sampled) test split.
    label_ids = list(range(len(class_names)))
    conf_mat = confusion_matrix(all_targets, all_preds, labels=label_ids)

    # One table row per (actual, predicted) cell of the matrix.
    data = [
        [actual_name, predicted_name, conf_mat[i, j]]
        for i, actual_name in enumerate(class_names)
        for j, predicted_name in enumerate(class_names)
    ]

    # Maps the preset's field names to the table's column names.
    fields = {
        "Actual": "Actual",
        "Predicted": "Predicted",
        "nPredictions": "Count"
    }

    wandb.log({
        "multiclass_confusion_matrix": wandb.plot_table(
            "wandb/confusion_matrix/v1",
            wandb.Table(columns=["Actual", "Predicted", "Count"], data=data),
            fields,
            {"title": "Multiclass Confusion Matrix"}
        )
    })

    # Classification Report
    report = classification_report(
        all_targets, all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    report_table = wandb.Table(columns=["Class", "Precision", "Recall", "F1-Score", "Support"])
    for class_name in class_names:
        class_row = report[class_name]
        report_table.add_data(
            class_name,
            class_row["precision"],
            class_row["recall"],
            class_row["f1-score"],
            class_row["support"]
        )

    weighted = report["weighted avg"]
    report_table.add_data(
        "Weighted Avg",
        weighted["precision"],
        weighted["recall"],
        weighted["f1-score"],
        weighted["support"]
    )

    wandb.log({"classification_report": report_table})

    wandb.finish()

# Run the full train/evaluate pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()




Training samples: 100000 (limited to 100000)
Validation samples: 20000 (limited to 20000)
Test samples: 15000 (limited to 15000)
Training samples: 100000 (limited to 100000)
Validation samples: 20000 (limited to 20000)
Test samples: 15000 (limited to 15000)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

<h2>First Dataset</h2>

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler, RandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
from omegaconf import OmegaConf
import os
import torchmetrics
import warnings
from kaggle_secrets import UserSecretsClient

# Silence every Python warning for the rest of this session.
warnings.filterwarnings('ignore')

# Autoencoder experiment configuration for the CIC-IDS-2017 run.
# NOTE(review): `hidden_size`, `num_layers`, `dense_units` and `gpus` are not
# read by any code visible in this cell -- confirm whether they are still needed.
# https://wandb.ai/mohammad-fleity-lebanese-university/DL-NIDS-2--cic-ids-2017
config = OmegaConf.create({
    "wandb": {
        "project": "DL-NIDS-2--cic-ids-2017",
        "entity": "mohammad-fleity-lebanese-university",
        "tags": ["AutoEncoderDecoder", "cic-ids-2017", "PyTorch"],
        "notes": "Optimized AutoEncoderDecoder for network intrusion detection with limited samples"
    },
    "model": {
        "hidden_size": 128,
        "num_layers": 2,
        "dropout": 0.4,
        "dense_units": [128, 64],
        "learning_rate": 0.0001,
        "weight_decay": 1e-4
    },
    "training": {
        "sequence_length": 4,
        "batch_size": 128,
        "max_epochs": 60,            # Upper bound on the number of training epochs
        "early_stopping_patience": 7,
        "oversample": True,
        "gpus": 1 if torch.cuda.is_available() else 0,
        "max_train_samples": 200000,  # Cap on training windows kept after sequencing
        "max_val_samples": 30000,     # Cap on validation windows
        "max_test_samples": 30000     # Cap on test windows
    },
    "data": {
        # Full path on Kaggle: /kaggle/input/cic-ids-2017-parquet/cic_ids_2017.parquet
        "raw": "cic_ids_2017.parquet",
        "num_workers": 4
    }
})


class AutoEncoderModel(pl.LightningModule):
    """Fully connected autoencoder with a classification head on the latent code.

    Each input window is flattened and encoded into a 32-dimensional latent
    vector. That vector feeds both a decoder (reconstruction of the window) and
    a classifier (class logits). Training and validation use the unweighted sum
    of the MSE reconstruction loss and the cross-entropy classification loss;
    the test step reports the reconstruction loss only.

    Accuracy is accumulated over all batches of an epoch and logged once, as a
    percentage, from the ``on_*_epoch_end`` hooks; the metric is reset right
    after so that one epoch never leaks into the next.

    Args:
        input_size: Length of the flattened input, i.e. ``seq_len * features``
            of the 3-D tensors passed to :meth:`forward`.
        num_classes: Number of target classes (width of the logits).
        config: Configuration exposing ``model.dropout``,
            ``model.learning_rate`` and ``model.weight_decay``.
    """

    def __init__(self, input_size, num_classes, config):
        super().__init__()
        self.save_hyperparameters({'config': config})
        self.config = config
        dropout = config.model.dropout

        self.encoder = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 32)
        )

        self.decoder = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, input_size)  # match input size for reconstruction
        )

        self.classifier = nn.Sequential(
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_classes)
        )

        self.recon_loss = nn.MSELoss()
        self.class_loss = nn.CrossEntropyLoss()

        # One accuracy metric per stage so their running states never mix.
        self.train_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = torchmetrics.classification.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return ``(x_hat, logits)`` for a 3-D batch ``(batch, seq_len, features)``.

        ``x_hat`` has the same shape as ``x``; ``logits`` is ``(batch, num_classes)``.
        """
        batch_size, seq_len, features = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_flat = x.reshape(batch_size, -1)
        z = self.encoder(x_flat)

        x_hat = self.decoder(z).reshape(batch_size, seq_len, features)
        logits = self.classifier(z)
        return x_hat, logits

    def _shared_step(self, batch, metric):
        """Run one batch, update ``metric`` and return ``(loss_recon, loss_cls)``."""
        x, y = batch
        x_hat, logits = self(x)
        loss_recon = self.recon_loss(x_hat, x)
        loss_cls = self.class_loss(logits, y)
        metric.update(torch.argmax(logits, dim=1), y)
        return loss_recon, loss_cls

    def _log_epoch_accuracy(self, name, metric):
        """Log the accuracy accumulated over the epoch (in percent) and reset it."""
        self.log(name, metric.compute() * 100, prog_bar=True)
        metric.reset()

    def training_step(self, batch, batch_idx):
        loss_recon, loss_cls = self._shared_step(batch, self.train_acc)
        loss = loss_recon + loss_cls
        self.log("train_loss_epoch", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        self._log_epoch_accuracy("train_acc_epoch", self.train_acc)

    def validation_step(self, batch, batch_idx):
        loss_recon, loss_cls = self._shared_step(batch, self.val_acc)
        loss = loss_recon + loss_cls
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def on_validation_epoch_end(self):
        self._log_epoch_accuracy("val_acc", self.val_acc)

    def test_step(self, batch, batch_idx):
        # Test loss is the reconstruction term only; the classification term
        # is deliberately left out of the reported value.
        loss, _ = self._shared_step(batch, self.test_acc)
        self.log("test_loss", loss, on_step=False, on_epoch=True)
        return {"loss": loss}

    def on_test_epoch_end(self):
        self._log_epoch_accuracy("test_acc", self.test_acc)

    def configure_optimizers(self):
        """AdamW over all parameters, using the saved hyperparameters."""
        return optim.AdamW(
            self.parameters(),
            lr=self.hparams.config.model.learning_rate,
            weight_decay=self.hparams.config.model.weight_decay
        )


class NIDSDataModule(pl.LightningDataModule):
    """Loads the parquet dataset, builds fixed-length windows and serves loaders.

    Pipeline (see :meth:`prepare_data`): read parquet -> drop inf/NaN/duplicate
    rows -> label-encode ``Label`` -> stratified 70/15/15 split -> standardise
    (scaler fitted on the training split only) -> sliding windows -> optional
    random down-sampling to the configured per-split caps.

    Args:
        config: Configuration exposing ``training.batch_size``,
            ``training.sequence_length``, ``training.oversample``,
            ``training.max_{train,val,test}_samples``, ``data.num_workers``
            and ``data.raw`` (parquet file name).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.training.batch_size
        self.sequence_length = config.training.sequence_length
        self.num_workers = config.data.num_workers
        self.oversample = config.training.oversample
        self.max_train_samples = config.training.max_train_samples
        self.max_val_samples = config.training.max_val_samples
        self.max_test_samples = config.training.max_test_samples
        # Set once prepare_data() has populated the X_*/y_* arrays.
        self._data_ready = False

    def prepare_data(self):
        """Load, clean, split, scale and window the dataset (idempotent).

        The Trainer calls this hook again even when the caller already invoked
        it manually; the guard avoids re-reading the parquet file and
        re-drawing a different random subset on the second call.
        """
        if self._data_ready:
            return

        # /kaggle/input/cic-ids-2017-parquet
        df = pd.read_parquet(os.path.join('/kaggle/input/cic-ids-2017-parquet', self.config.data.raw))

        # Clean data
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)

        # Columns excluded from the feature matrix (only those actually present).
        candidate_cols = ['Label', 'Timestamp', 'Flow ID', 'Src IP',
                          'Src Port', 'Attack', 'Dst IP', 'Dst Port', 'Protocol']
        self.non_numeric_cols = [col for col in candidate_cols if col in df.columns]

        # Encode labels
        self.label_encoder = LabelEncoder()
        df['Label_Num'] = self.label_encoder.fit_transform(df['Label'])
        self.classes = self.label_encoder.classes_

        # Initialize scaler
        self.scaler = StandardScaler()

        # Stratified split: 70% train, 30% held out.
        train_df, holdout_df = train_test_split(
            df,
            test_size=0.3,
            random_state=42,
            stratify=df['Label_Num']
        )

        # Held-out part is halved: 15% validation, 15% test.
        val_df, test_df = train_test_split(
            holdout_df,
            test_size=0.5,
            random_state=42,
            stratify=holdout_df['Label_Num']
        )

        # The scaler is fitted on the training split only.
        self.X_train, self.y_train = self._prepare_features(train_df, fit=True)
        self.X_val, self.y_val = self._prepare_features(val_df, fit=False)
        self.X_test, self.y_test = self._prepare_features(test_df, fit=False)

        # Apply sample limits
        self._limit_samples()
        self._data_ready = True

    @staticmethod
    def _subsample(X, y, limit):
        """Return at most ``limit`` randomly chosen (X, y) pairs, without replacement."""
        if len(X) <= limit:
            return X, y
        keep = np.random.choice(len(X), limit, replace=False)
        return X[keep], y[keep]

    def _limit_samples(self):
        """Limit samples according to configuration"""
        self.X_train, self.y_train = self._subsample(self.X_train, self.y_train, self.max_train_samples)
        self.X_val, self.y_val = self._subsample(self.X_val, self.y_val, self.max_val_samples)
        self.X_test, self.y_test = self._subsample(self.X_test, self.y_test, self.max_test_samples)

    def _prepare_features(self, df, fit=False):
        """Scale the feature columns of ``df`` and turn them into windows.

        Args:
            df: One split of the dataset, including ``Label_Num``.
            fit: Fit the scaler on this split before transforming it.

        Returns:
            ``(windows, labels)`` as produced by :meth:`create_sequences`.
        """
        X = df.drop(['Label_Num'] + self.non_numeric_cols, axis=1)
        y = df['Label_Num']
        if fit:
            X = self.scaler.fit_transform(X)
        else:
            X = self.scaler.transform(X)
        return self.create_sequences(X, y)

    def create_sequences(self, X, y):
        """Build every window of ``sequence_length`` consecutive rows.

        Window ``i`` covers rows ``i .. i + sequence_length - 1`` and carries
        the label of its last row. All ``len(X) - sequence_length + 1`` windows
        are produced, including the final one.

        NOTE(review): windows follow the row order of the split handed in;
        ``train_test_split`` shuffles by default, so rows inside a window were
        not necessarily adjacent in the source file -- confirm this is intended.

        Returns:
            ``(windows, labels)`` with shapes
            ``(n_windows, sequence_length, n_features)`` and ``(n_windows,)``.
        """
        features = np.asarray(X)
        targets = np.asarray(y)
        n_windows = len(features) - self.sequence_length + 1
        if n_windows <= 0:
            empty_windows = np.empty(
                (0, self.sequence_length) + features.shape[1:], dtype=features.dtype
            )
            return empty_windows, np.empty((0,), dtype=targets.dtype)

        # Row-index matrix: entry (i, k) is i + k, i.e. the k-th row of window i.
        row_idx = np.arange(n_windows)[:, None] + np.arange(self.sequence_length)
        return features[row_idx], targets[self.sequence_length - 1:]

    def setup(self, stage=None):
        """Wrap the prepared arrays in ``TensorDataset`` objects."""
        # Covers callers that invoke setup() without a prior prepare_data().
        if not self._data_ready:
            self.prepare_data()

        self.train_dataset = TensorDataset(torch.FloatTensor(self.X_train), torch.LongTensor(self.y_train))
        self.val_dataset = TensorDataset(torch.FloatTensor(self.X_val), torch.LongTensor(self.y_val))
        self.test_dataset = TensorDataset(torch.FloatTensor(self.X_test), torch.LongTensor(self.y_test))

        print(f"Training samples: {len(self.train_dataset)} (limited to {self.max_train_samples})")
        print(f"Validation samples: {len(self.val_dataset)} (limited to {self.max_val_samples})")
        print(f"Test samples: {len(self.test_dataset)} (limited to {self.max_test_samples})")

    def train_dataloader(self):
        """Training loader; with ``oversample`` each class is drawn equally often."""
        if self.oversample:
            # Per-sample weight = 1 / (size of the sample's class).
            class_counts = np.bincount(self.y_train)
            weights = 1. / class_counts[self.y_train]
            sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        else:
            sampler = RandomSampler(self.train_dataset)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )

def init_wandb():
    """Log in to Weights & Biases and start the run for this experiment.

    The API key is read from the Kaggle secret ``mohammad_wandb_secret``. The
    run is created from the ``wandb`` section of the module-level ``config``
    and seeded with the training/model settings; ``input_size`` and
    ``num_classes`` are stored as ``None`` placeholders.

    Returns:
        Tuple ``(wandb_logger, run)``: a ``WandbLogger`` bound to the run
        (with ``log_model='all'``) and the run object itself.
    """
    api_key = UserSecretsClient().get_secret("mohammad_wandb_secret")
    wandb.login(key=api_key)

    wandb_cfg = config.wandb
    training_cfg = config.training

    # Initial run config; the data-dependent sizes are not known yet.
    run_config = {
        "input_size": None,
        "num_classes": None,
        "sequence_length": training_cfg.sequence_length,
        "train_samples": training_cfg.max_train_samples,
        "val_samples": training_cfg.max_val_samples,
        "test_samples": training_cfg.max_test_samples,
        "model_config": dict(config.model),
        "training_config": dict(config.training),
    }

    run = wandb.init(
        project=wandb_cfg.project,
        entity=wandb_cfg.entity,
        tags=wandb_cfg.tags,
        notes=wandb_cfg.notes,
        config=run_config,
    )

    return WandbLogger(experiment=run, log_model='all'), run

def main():
    """Train, evaluate and report the intrusion-detection model end to end.

    Starts the W&B run, prepares the data module, sizes the model from the
    training tensors, fits with early stopping and checkpointing, runs the
    Lightning test loop, then recomputes predictions on the test set to log
    accuracy, weighted F1, a confusion-matrix chart and a per-class report
    before closing the run.
    """
    wandb_logger, run = init_wandb()

    data_module = NIDSDataModule(config)
    data_module.prepare_data()
    data_module.setup()

    # Read the per-timestep feature width from the training tensor itself
    # (setup() stores features as the first tensor of a TensorDataset) instead
    # of spinning up a DataLoader with persistent workers just to peek at a batch.
    input_size_per_timestep = data_module.train_dataset.tensors[0].shape[2]
    total_input_size = input_size_per_timestep * config.training.sequence_length
    num_classes = len(data_module.classes)

    # init_wandb() seeds "num_classes" with a None placeholder. Outside notebook
    # environments wandb refuses to overwrite an existing config key unless
    # the change is explicitly allowed.
    run.config.update(
        {
            "input_size_per_timestep": input_size_per_timestep,
            "total_input_size": total_input_size,
            "num_classes": num_classes
        },
        allow_val_change=True
    )

    model = AutoEncoderModel(total_input_size, num_classes, config)

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=config.training.early_stopping_patience,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        dirpath='checkpoints',
        filename='best_model'
    )

    # NOTE(review): precision=16 presumably targets a CUDA device; confirm the
    # intended behavior on CPU-only hosts.
    trainer = pl.Trainer(
        precision=16,
        logger=wandb_logger,
        max_epochs=config.training.max_epochs,
        callbacks=[early_stopping, checkpoint_callback],
        deterministic=True,
        gradient_clip_val=1.0,
        enable_progress_bar=True,
        log_every_n_steps=1000
    )

    trainer.fit(model, datamodule=data_module)

    # NOTE(review): this evaluates the in-memory (last-epoch) weights, not the
    # checkpoint saved as 'best_model' — confirm that is intended.
    test_results = trainer.test(model, datamodule=data_module)

    # Collect all predictions and targets
    test_loader = data_module.test_dataloader()
    all_preds = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            # Batches leave the loader on CPU; send them to wherever the
            # Trainer left the model so the forward pass never mixes devices.
            y_hat = model(x.to(model.device))
            preds = torch.argmax(y_hat, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(y.cpu().numpy())

    # Calculate metrics
    test_acc = accuracy_score(all_targets, all_preds)
    test_f1 = f1_score(all_targets, all_preds, average='weighted')

    # Log final metrics
    wandb.log({
        'test_acc': test_acc,
        'test_f1': test_f1,
        'test_loss': test_results[0]['test_loss']
    })

    # Enhanced multiclass confusion matrix
    class_names = data_module.classes.tolist()
    # Pin the label set so the matrix and report always cover every class, even
    # one that is missing from the capped test split. Assumes targets are
    # encoded as 0..len(classes)-1, matching the argmax predictions.
    label_ids = list(range(len(class_names)))
    conf_mat = confusion_matrix(all_targets, all_preds, labels=label_ids)

    # One table row per (actual, predicted) cell of the matrix
    data = [
        [actual_name, predicted_name, conf_mat[i, j]]
        for i, actual_name in enumerate(class_names)
        for j, predicted_name in enumerate(class_names)
    ]

    # Keys are the field names of the "wandb/confusion_matrix/v1" preset; the
    # count field is called "nPredictions" there. Values are our table columns.
    fields = {
        "Actual": "Actual",
        "Predicted": "Predicted",
        "nPredictions": "Count"
    }

    wandb.log({
        "multiclass_confusion_matrix": wandb.plot_table(
            "wandb/confusion_matrix/v1",
            wandb.Table(columns=["Actual", "Predicted", "Count"], data=data),
            fields,
            {"title": "Multiclass Confusion Matrix"}
        )
    })

    # Classification Report
    report = classification_report(
        all_targets, all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True
    )

    report_table = wandb.Table(columns=["Class", "Precision", "Recall", "F1-Score", "Support"])
    for class_name in class_names:
        class_metrics = report[class_name]
        report_table.add_data(
            class_name,
            class_metrics["precision"],
            class_metrics["recall"],
            class_metrics["f1-score"],
            class_metrics["support"]
        )

    weighted_avg = report["weighted avg"]
    report_table.add_data(
        "Weighted Avg",
        weighted_avg["precision"],
        weighted_avg["recall"],
        weighted_avg["f1-score"],
        weighted_avg["support"]
    )

    wandb.log({"classification_report": report_table})

    wandb.finish()

# Run the full train/evaluate/report pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmohammad-fleity[0m ([33mmohammad-fleity-lebanese-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training samples: 200000 (limited to 200000)
Validation samples: 30000 (limited to 30000)
Test samples: 30000 (limited to 30000)
Training samples: 200000 (limited to 200000)
Validation samples: 30000 (limited to 30000)
Test samples: 30000 (limited to 30000)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]