In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler, RandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
from omegaconf import OmegaConf
import os
import warnings
from kaggle_secrets import UserSecretsClient
# from kaggle_secrets import UserSecretsClient


warnings.filterwarnings('ignore')

# Experiment configuration for the GRU intrusion-detection model.
# Built from keyword-style dicts; OmegaConf wraps it into an attribute-accessible DictConfig.
config = OmegaConf.create(dict(
    # Weights & Biases run metadata
    wandb=dict(
        project="DL-NIDS-2--cic-ids-2017",
        entity="mohammad-fleity-lebanese-university",
        tags=["GRU", "CIC-IDS-2017", "PyTorch"],
        notes="Optimized GRU for network intrusion detection",
    ),
    # Architecture and optimiser hyper-parameters (read by GRUModel)
    model=dict(
        hidden_size=128,            # GRU hidden state size
        num_layers=2,               # number of stacked GRU layers
        dropout=0.4,                # dropout between GRU layers and after each dense layer
        dense_units=[128, 64],      # widths of the two dense layers
        learning_rate=1e-4,         # AdamW learning rate
        weight_decay=1e-4,          # AdamW weight decay
    ),
    # Data windowing, batching and training-loop settings
    training=dict(
        sequence_length=5,          # consecutive rows per input window
        batch_size=128,
        max_epochs=10,
        early_stopping_patience=7,  # epochs without val_loss improvement before stopping
        oversample=True,            # class-balanced WeightedRandomSampler for training
        gpus=int(torch.cuda.is_available()),  # 1 when CUDA is visible, otherwise 0
        train_size=0.7,             # fraction of rows used for training
        val_size=0.15,              # fraction of rows used for validation
    ),
    # Input file and loader settings
    data=dict(
        raw="cic_ids_2017.parquet",  # parquet file name under the Kaggle input directory
        num_workers=4,               # DataLoader worker processes
    ),
))

class GRUModel(pl.LightningModule):
    """Stacked-GRU classifier over windows of network-flow features.

    The last GRU timestep is layer-normalised and fed through two dense blocks
    (Linear -> LayerNorm -> ReLU -> Dropout) before a linear output layer.
    Training uses cross-entropy with 0.1 label smoothing and AdamW.

    Args:
        input_size: number of features per timestep.
        num_classes: number of target classes.
        config: OmegaConf config; ``config.model`` supplies hidden size, layer
            count, dropout, dense widths, learning rate and weight decay.
    """

    def __init__(self, input_size, num_classes, config):
        super().__init__()
        self.save_hyperparameters()

        # Per-batch loss/correct/total records, aggregated and cleared in the
        # on_*_epoch_end hooks. Lightning 2.x no longer passes step outputs to
        # epoch-end hooks, so the module has to buffer them itself.
        self.train_outputs = []
        self.val_outputs = []

        # Dropout between stacked GRU layers is only valid with >1 layer.
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=config.model.hidden_size,
            num_layers=config.model.num_layers,
            batch_first=True,
            dropout=config.model.dropout if config.model.num_layers > 1 else 0
        )

        self.gru_ln = nn.LayerNorm(config.model.hidden_size)

        self.dense = nn.Sequential(
            nn.Linear(config.model.hidden_size, config.model.dense_units[0]),
            nn.LayerNorm(config.model.dense_units[0]),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.dense_units[0], config.model.dense_units[1]),
            nn.LayerNorm(config.model.dense_units[1]),
            nn.ReLU(),
            nn.Dropout(config.model.dropout)
        )

        self.output = nn.Linear(config.model.dense_units[1], num_classes)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, x):
        """Return class logits.

        x is assumed to be (batch, seq_len, input_size) because the GRU is
        built with batch_first=True.
        """
        gru_out, _ = self.gru(x)
        gru_out = gru_out[:, -1, :]  # keep only the last timestep
        gru_out = self.gru_ln(gru_out)
        features = self.dense(gru_out)
        return self.output(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        acc = hits.float().mean()

        # Step- and epoch-level (Lightning-averaged) metrics.
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

        # Buffer exact counts for on_train_epoch_end.
        self.train_outputs.append({
            'loss': loss.detach(),
            'correct': hits.sum().detach(),
            'total': len(y),
        })
        return loss

    def on_train_epoch_end(self):
        """Log mean batch loss and exact (sample-weighted) accuracy for the epoch."""
        if not self.train_outputs:
            return

        avg_loss = torch.stack([o['loss'] for o in self.train_outputs]).mean()
        correct = sum(o['correct'] for o in self.train_outputs)
        total = sum(o['total'] for o in self.train_outputs)

        self.log('train_epoch_loss', avg_loss, prog_bar=True)
        self.log('train_acc_epoch', correct / total, prog_bar=True)
        self.train_outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        acc = hits.float().mean()

        # 'val_loss' / 'val_acc' are what EarlyStopping / ModelCheckpoint monitor.
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

        correct = hits.sum()
        self.val_outputs.append({
            'val_loss': loss.detach(),
            'correct': correct.detach(),
            'total': len(y),
        })
        return {'val_loss': loss, 'val_acc': acc, 'correct': correct, 'total': len(y)}

    def on_validation_epoch_end(self):
        """Log mean batch loss and exact accuracy over the validation epoch."""
        if not self.val_outputs:
            return

        avg_loss = torch.stack([o['val_loss'] for o in self.val_outputs]).mean()
        correct = sum(o['correct'] for o in self.val_outputs)
        total = sum(o['total'] for o in self.val_outputs)

        self.log('val_epoch_loss', avg_loss, prog_bar=True)
        self.log('val_acc_epoch', correct / total, prog_bar=True)
        self.val_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        self.log('test_loss', loss)
        self.log('test_acc', acc)

        return {'test_loss': loss, 'test_acc': acc, 'preds': preds, 'targets': y}

    def configure_optimizers(self):
        """AdamW with learning rate and weight decay from ``config.model``."""
        return optim.AdamW(
            self.parameters(),
            lr=self.hparams.config.model.learning_rate,
            weight_decay=self.hparams.config.model.weight_decay
        )

class NIDSDataModule(pl.LightningDataModule):
    """Loads CIC-IDS-2017 from parquet and serves windowed feature sequences.

    ``prepare_data`` loads, cleans, label-encodes, splits (stratified),
    standard-scales and windows the data, storing the arrays on ``self``;
    ``setup`` wraps them in TensorDatasets. Assigning state in
    ``prepare_data`` relies on both hooks running in the same process.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.training.batch_size
        self.sequence_length = config.training.sequence_length
        self.num_workers = config.data.num_workers
        self.oversample = config.training.oversample

    def prepare_data(self):
        df = pd.read_parquet(os.path.join('/kaggle/input/cic-ids-2017-parquet', self.config.data.raw))

        # Drop rows containing +/-inf or NaN, then exact duplicate rows.
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)

        # Label / identifier columns excluded from the feature matrix (only those present).
        candidate_cols = ['Label', 'Timestamp', 'Flow ID', 'Src IP',
                          'Src Port', 'Attack', 'Dst IP', 'Dst Port', 'Protocol']
        self.non_numeric_cols = [col for col in candidate_cols if col in df.columns]

        self.label_encoder = LabelEncoder()
        df['Label_Num'] = self.label_encoder.fit_transform(df['Label'])
        self.classes = self.label_encoder.classes_

        # Fitted on the training split only (see _prepare_features).
        self.scaler = StandardScaler()

        # Stratified train / holdout split, then holdout -> val / test.
        holdout = 1 - self.config.training.train_size
        val_size = self.config.training.get('val_size', holdout / 2)
        if not 0 < val_size < holdout:
            raise ValueError(
                f"val_size must be in (0, {holdout}), got {val_size}")
        # Share of the holdout that goes to test; rounded to absorb float noise
        # (0.15 / 0.3 evaluates to 0.49999999999999994) so split sizes are exact.
        test_fraction = round(1 - val_size / holdout, 6)

        train_df, test_df = train_test_split(
            df,
            test_size=holdout,
            random_state=42,
            stratify=df['Label_Num']
        )
        val_df, test_df = train_test_split(
            test_df,
            test_size=test_fraction,
            random_state=42,
            stratify=test_df['Label_Num']
        )
        print(len(train_df))
        print(len(test_df))

        # The scaler is fitted on the training split only to avoid leakage.
        self.X_train, self.y_train = self._prepare_features(train_df, fit=True)
        self.X_val, self.y_val = self._prepare_features(val_df, fit=False)
        self.X_test, self.y_test = self._prepare_features(test_df, fit=False)

    def _prepare_features(self, df, fit=False):
        """Scale the numeric feature columns and window them into sequences."""
        X = df.drop(['Label_Num'] + self.non_numeric_cols, axis=1)
        y = df['Label_Num']
        if fit:
            X = self.scaler.fit_transform(X)
        else:
            X = self.scaler.transform(X)
        return self.create_sequences(X, y)

    def create_sequences(self, X, y):
        """Build every window of ``sequence_length`` consecutive rows.

        Window ``i`` covers rows ``i .. i+L-1`` and takes the label of its last
        row, giving ``len(X) - L + 1`` windows of shape ``(L, n_features)``.
        Returns empty arrays of the right rank when X has fewer than L rows.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        L = self.sequence_length
        if len(X) < L:
            return np.empty((0, L) + X.shape[1:], dtype=X.dtype), np.empty((0,), dtype=y.dtype)

        # sliding_window_view yields (n_windows, *feature_dims, L); move the
        # window axis next to the batch axis and copy to a contiguous array.
        windows = np.lib.stride_tricks.sliding_window_view(X, L, axis=0)
        sequences = np.ascontiguousarray(np.moveaxis(windows, -1, 1))
        labels = y[L - 1:]
        return sequences, labels

    def setup(self, stage=None):
        self.train_dataset = TensorDataset(torch.FloatTensor(self.X_train), torch.LongTensor(self.y_train))
        self.val_dataset = TensorDataset(torch.FloatTensor(self.X_val), torch.LongTensor(self.y_val))
        self.test_dataset = TensorDataset(torch.FloatTensor(self.X_test), torch.LongTensor(self.y_test))

        print("Training samples:", len(self.train_dataset))
        print("Validation samples:", len(self.val_dataset))
        print("Test samples:", len(self.test_dataset))

    def train_dataloader(self):
        if self.oversample:
            # Inverse-frequency weights -> classes drawn roughly uniformly.
            class_counts = np.bincount(self.y_train)
            weights = 1. / class_counts[self.y_train]
            sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        else:
            sampler = RandomSampler(self.train_dataset)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            # persistent_workers=True raises ValueError when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            pin_memory=torch.cuda.is_available()
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )

def init_wandb():
    """Log in to W&B with the Kaggle secret and start a run for ``config``.

    Placeholders (input_size, num_classes, sample counts) are filled in later
    via ``run.config.update``.

    Returns:
        tuple: (WandbLogger wrapping the run, the wandb run object).
    """
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("mohammad_wandb_secret")
    wandb.login(key=wandb_api_key)

    run = wandb.init(
        project=config.wandb.project,
        entity=config.wandb.entity,
        tags=config.wandb.tags,
        notes=config.wandb.notes,
        config={
            "input_size": None,
            "num_classes": None,
            "sequence_length": config.training.sequence_length,
            "train_samples": None,
            "test_samples": None,
            # Deep-convert to plain dict/list (dict() is shallow and would keep
            # nested OmegaConf nodes such as the dense_units ListConfig).
            "model_config": OmegaConf.to_container(config.model, resolve=True),
            "training_config": OmegaConf.to_container(config.training, resolve=True)
        }
    )

    # log_model='all' uploads every saved checkpoint as a W&B artifact.
    wandb_logger = WandbLogger(
        experiment=run,
        log_model='all'
    )

    return wandb_logger, run

def main():
    """Train, evaluate and report the GRU intrusion-detection model.

    Starts a W&B run, builds the data once, trains with early stopping on
    val_loss and best-val_acc checkpointing, evaluates the best checkpoint on
    the test split, and logs accuracy, weighted F1, a confusion matrix and a
    per-class classification report to W&B.
    """
    wandb_logger, run = init_wandb()
    try:
        data_module = NIDSDataModule(config)
        data_module.prepare_data()
        data_module.setup()

        # Read the feature dimension straight from the (N, seq_len, features)
        # training tensor instead of spinning up a DataLoader just to peek.
        input_size = data_module.train_dataset.tensors[0].shape[2]
        num_classes = len(data_module.classes)

        run.config.update({
            "input_size": input_size,
            "num_classes": num_classes,
            "train_samples": len(data_module.train_dataset),
            "test_samples": len(data_module.test_dataset)
        })

        model = GRUModel(input_size, num_classes, config)

        early_stopping = EarlyStopping(
            monitor='val_loss',
            patience=config.training.early_stopping_patience,
            mode='min'
        )

        checkpoint_callback = ModelCheckpoint(
            monitor='val_acc',
            mode='max',
            save_top_k=1,
            dirpath='checkpoints',
            filename='best_model'
        )

        trainer = pl.Trainer(
            # fp16 mixed precision is GPU-only; fall back to fp32 on CPU.
            precision=16 if torch.cuda.is_available() else 32,
            logger=wandb_logger,
            max_epochs=config.training.max_epochs,
            callbacks=[early_stopping, checkpoint_callback],
            deterministic=True,
            gradient_clip_val=1.0,
            enable_progress_bar=True,
            log_every_n_steps=1000
        )

        # Pass the already-built loaders: giving Lightning ``datamodule=`` makes
        # it call prepare_data()/setup() again, re-reading the parquet file and
        # rebuilding every sequence on each fit/test call.
        trainer.fit(
            model,
            train_dataloaders=data_module.train_dataloader(),
            val_dataloaders=data_module.val_dataloader()
        )

        # Evaluate the checkpoint selected by ModelCheckpoint (best val_acc),
        # which is loaded into ``model``; fall back to the current weights if
        # no checkpoint was written.
        ckpt_path = 'best' if checkpoint_callback.best_model_path else None
        test_results = trainer.test(
            model,
            dataloaders=data_module.test_dataloader(),
            ckpt_path=ckpt_path
        )

        # Collect predictions for sklearn metrics.
        device = model.device
        all_preds = []
        all_targets = []
        model.eval()
        with torch.no_grad():
            for x, y in data_module.test_dataloader():
                preds = torch.argmax(model(x.to(device)), dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(y.cpu().numpy())

        test_acc = accuracy_score(all_targets, all_preds)
        test_f1 = f1_score(all_targets, all_preds, average='weighted')

        wandb.log({
            'test_acc': test_acc,
            'test_f1': test_f1,
            'test_loss': test_results[0]['test_loss']
        })

        class_names = data_module.classes.tolist()
        # Pin the full label set so classes absent from the test split still
        # get a row/column and indices stay aligned with class_names.
        label_ids = list(range(len(class_names)))
        conf_mat = confusion_matrix(all_targets, all_preds, labels=label_ids)

        data = [
            [class_names[i], class_names[j], conf_mat[i, j]]
            for i in label_ids
            for j in label_ids
        ]

        # The wandb/confusion_matrix/v1 Vega spec reads its count field from
        # "nPredictions"; map our "Count" column onto it.
        fields = {
            "Actual": "Actual",
            "Predicted": "Predicted",
            "nPredictions": "Count"
        }

        wandb.log({
            "multiclass_confusion_matrix": wandb.plot_table(
                "wandb/confusion_matrix/v1",
                wandb.Table(columns=["Actual", "Predicted", "Count"], data=data),
                fields,
                {"title": "Multiclass Confusion Matrix"}
            )
        })

        report = classification_report(
            all_targets, all_preds,
            labels=label_ids,
            target_names=class_names,
            output_dict=True,
            zero_division=0
        )

        report_table = wandb.Table(columns=["Class", "Precision", "Recall", "F1-Score", "Support"])
        for class_name in class_names:
            report_table.add_data(
                class_name,
                report[class_name]["precision"],
                report[class_name]["recall"],
                report[class_name]["f1-score"],
                report[class_name]["support"]
            )

        report_table.add_data(
            "Weighted Avg",
            report["weighted avg"]["precision"],
            report["weighted avg"]["recall"],
            report["weighted avg"]["f1-score"],
            report["weighted avg"]["support"]
        )

        wandb.log({"classification_report": report_table})
    finally:
        # Close the run even if training/evaluation fails so the next cell can start a fresh one.
        wandb.finish()

if __name__ == "__main__":
    main()



In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler, RandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
from omegaconf import OmegaConf
import os
import warnings
from kaggle_secrets import UserSecretsClient

warnings.filterwarnings('ignore')

# GRU configuration (sample-limited variant)
config = OmegaConf.create({
    "wandb": {
        # W&B project names may not contain '/' or ':'; the previous value was a
        # pasted URL fragment. Use the same project as the full-data run.
        "project": "DL-NIDS-2--cic-ids-2017",
        "entity": "mohammad-fleity-lebanese-university",
        "tags": ["GRU", "CIC-IDS-2017", "PyTorch"],
        "notes": "Optimized GRU for network intrusion detection with limited samples"
    },
    "model": {
        "hidden_size": 128,           # GRU hidden state size
        "num_layers": 2,              # stacked GRU layers
        "dropout": 0.4,               # dropout between GRU layers and after dense layers
        "dense_units": [128, 64],     # widths of the two dense layers
        "learning_rate": 0.0001,      # AdamW learning rate
        "weight_decay": 1e-4          # AdamW weight decay
    },
    "training": {
        "sequence_length": 5,         # consecutive rows per input window
        "batch_size": 128,
        "max_epochs": 10,             # Hard limit of 10 epochs
        "early_stopping_patience": 7,
        "oversample": True,           # class-balanced sampling for training
        "gpus": 1 if torch.cuda.is_available() else 0,
        "max_train_samples": 1000000,  # Maximum training samples
        "max_val_samples": 200000,     # Maximum validation samples
        "max_test_samples": 200000     # Maximum test samples
    },
    "data": {
        "raw": "cic_ids_2017.parquet",
        "num_workers": 4
    }
})

# class GRUModel(pl.LightningModule):
#     def __init__(self, input_size, num_classes, config):
#         super().__init__()
#         self.save_hyperparameters()
#         self.outputs=[]
#         self.gru = nn.GRU(
#             input_size=input_size,
#             hidden_size=config.model.hidden_size,
#             num_layers=config.model.num_layers,
#             batch_first=True,
#             dropout=config.model.dropout if config.model.num_layers > 1 else 0
#         )
        
#         self.gru_ln = nn.LayerNorm(config.model.hidden_size)
        
#         self.dense = nn.Sequential(
#             nn.Linear(config.model.hidden_size, config.model.dense_units[0]),
#             nn.LayerNorm(config.model.dense_units[0]),
#             nn.ReLU(),
#             nn.Dropout(config.model.dropout),
#             nn.Linear(config.model.dense_units[0], config.model.dense_units[1]),
#             nn.LayerNorm(config.model.dense_units[1]),
#             nn.ReLU(),
#             nn.Dropout(config.model.dropout)
#         )
        
#         self.output = nn.Linear(config.model.dense_units[1], num_classes)
#         self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

#     def forward(self, x):
#         gru_out, _ = self.gru(x)
#         gru_out = gru_out[:, -1, :]
#         gru_out = self.gru_ln(gru_out)
#         features = self.dense(gru_out)
#         return self.output(features)
    
#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         logits = self(x)
#         loss = self.criterion(logits, y)
#         acc = (logits.argmax(dim=1) == y).float().mean()
        
#         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
#         self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        
#         return {'loss': loss, 'correct': (logits.argmax(dim=1) == y).sum(), 'total': len(y)}
    
#     # def training_epoch_end(self, outputs):
#     #     avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
#     #     self.log('train_epoch_loss', avg_loss, prog_bar=True)
        
#     #     correct = sum([x['correct'] for x in outputs])
#     #     total = sum([x['total'] for x in outputs])
#     #     epoch_acc = correct / total
#     #     self.log('train_acc_epoch', epoch_acc, prog_bar=True)
#     def on_train_epoch_end(self):
#         # Access saved outputs from training_step
#         if not hasattr(self, 'train_step_outputs'):
#             return
            
#         avg_loss = torch.stack([x['loss'] for x in self.train_step_outputs]).mean()
#         correct = sum([x['correct'] for x in self.train_step_outputs])
#         total = sum([x['total'] for x in self.train_step_outputs])
#         epoch_acc = correct / total
        
#         self.log('train_epoch_loss', avg_loss, prog_bar=True)
#         self.log('train_acc_epoch', epoch_acc, prog_bar=True)
#         self.train_step_outputs.clear()  # free memory

#     def training_step(self, batch, batch_idx):
#         x, y = batch
#         logits = self(x)
#         loss = self.criterion(logits, y)
#         acc = (logits.argmax(dim=1) == y).float().mean()
        
#         self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
#         self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        
#         output = {'loss': loss, 'correct': (logits.argmax(dim=1) == y).sum(), 'total': len(y)}
        
#         if not hasattr(self, 'train_step_outputs'):
#             self.train_step_outputs = []
#         self.train_step_outputs.append(output)
        
#         return output

#     def validation_step(self, batch, batch_idx):
#         x, y = batch
#         logits = self(x)
#         loss = self.criterion(logits, y)
#         acc = (logits.argmax(dim=1) == y).float().mean()
        
#         self.log('val_loss', loss, prog_bar=True)
#         self.log('val_acc', acc, prog_bar=True)
        
#         return {'val_loss': loss, 'correct': (logits.argmax(dim=1) == y).sum(), 'total': len(y)}
    
#     def validation_epoch_end(self, outputs):
#         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
#         self.log('val_epoch_loss', avg_loss, prog_bar=True)
        
#         correct = sum([x['correct'] for x in outputs])
#         total = sum([x['total'] for x in outputs])
#         epoch_acc = correct / total
#         self.log('val_acc_epoch', epoch_acc, prog_bar=True)
    
#     def test_step(self, batch, batch_idx):
#         x, y = batch
#         logits = self(x)
#         loss = self.criterion(logits, y)
#         acc = (logits.argmax(dim=1) == y).float().mean()
        
#         self.log('test_loss', loss)
#         self.log('test_acc', acc)
        
#         return {'test_loss': loss, 'preds': logits.argmax(dim=1), 'targets': y}
    
#     def configure_optimizers(self):
#         optimizer = optim.AdamW(self.parameters(), 
#                               lr=self.hparams.config.model.learning_rate,
#                               weight_decay=self.hparams.config.model.weight_decay)
#         return optimizer

class GRUModel(pl.LightningModule):
    """Stacked-GRU classifier over windows of network-flow features.

    The last GRU timestep is layer-normalised and fed through two dense blocks
    (Linear -> LayerNorm -> ReLU -> Dropout) before a linear output layer.
    Training uses cross-entropy with 0.1 label smoothing and AdamW.

    Logged metrics: ``train_loss``/``train_acc``, ``val_loss``/``val_acc``
    (fractions, averaged by Lightning) and ``test_loss``/``test_acc`` (percent).
    Epoch-end hooks add ``train_epoch_loss``/``train_acc_epoch`` and
    ``val_epoch_loss``/``val_acc_epoch`` (accuracies in percent).

    Args:
        input_size: number of features per timestep.
        num_classes: number of target classes.
        config: OmegaConf config; ``config.model`` supplies layer sizes,
            dropout, learning rate and weight decay.
    """

    def __init__(self, input_size, num_classes, config):
        super().__init__()
        self.save_hyperparameters()

        # Per-batch records, aggregated and cleared in the on_*_epoch_end hooks
        # (Lightning 2.x no longer passes step outputs to those hooks).
        self.train_outputs = []
        self.val_outputs = []

        # Dropout between stacked GRU layers is only valid with >1 layer.
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=config.model.hidden_size,
            num_layers=config.model.num_layers,
            batch_first=True,
            dropout=config.model.dropout if config.model.num_layers > 1 else 0
        )

        self.gru_ln = nn.LayerNorm(config.model.hidden_size)

        self.dense = nn.Sequential(
            nn.Linear(config.model.hidden_size, config.model.dense_units[0]),
            nn.LayerNorm(config.model.dense_units[0]),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.dense_units[0], config.model.dense_units[1]),
            nn.LayerNorm(config.model.dense_units[1]),
            nn.ReLU(),
            nn.Dropout(config.model.dropout)
        )

        self.output = nn.Linear(config.model.dense_units[1], num_classes)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, x):
        """Return class logits; x assumed (batch, seq_len, input_size) (batch_first=True)."""
        gru_out, _ = self.gru(x)
        gru_out = gru_out[:, -1, :]  # last timestep only
        gru_out = self.gru_ln(gru_out)
        features = self.dense(gru_out)
        return self.output(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        acc = hits.float().mean()

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)

        # Save for epoch-end
        self.train_outputs.append({
            'loss': loss.detach(),
            'correct': hits.sum().detach(),
            'total': len(y)
        })

        return loss

    def on_train_epoch_end(self):
        """Log mean batch loss and exact accuracy (percent) for the epoch."""
        if not self.train_outputs:
            return

        avg_loss = torch.stack([x['loss'] for x in self.train_outputs]).mean()
        correct = sum(x['correct'] for x in self.train_outputs)
        total = sum(x['total'] for x in self.train_outputs)
        epoch_acc = correct / total

        self.log('train_epoch_loss', avg_loss, prog_bar=True)
        self.log('train_acc_epoch', epoch_acc * 100, prog_bar=True)
        self.train_outputs.clear()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        acc = hits.float().mean()

        # Monitored by EarlyStopping / ModelCheckpoint; Lightning averages these
        # over the epoch (batch-size weighted).
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

        # Save for epoch-end
        self.val_outputs.append({
            'val_loss': loss.detach(),
            'correct': hits.sum().detach(),
            'total': len(y)
        })

        return loss

    def on_validation_epoch_end(self):
        """Log epoch aggregates under their own keys.

        Re-logging 'val_loss'/'val_acc' here would collide with the values
        logged in validation_step (one silently overwriting the other, and in
        a different unit), so distinct names are used.
        """
        if not self.val_outputs:
            return

        avg_loss = torch.stack([x['val_loss'] for x in self.val_outputs]).mean()
        correct = sum(x['correct'] for x in self.val_outputs)
        total = sum(x['total'] for x in self.val_outputs)
        epoch_acc = (correct / total) * 100

        self.log('val_epoch_loss', avg_loss, prog_bar=True)
        self.log('val_acc_epoch', epoch_acc, prog_bar=True)
        self.val_outputs.clear()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = logits.argmax(dim=1)
        acc = (preds == y).float().mean()

        self.log('test_loss', loss)
        self.log('test_acc', acc * 100)

        return {'test_loss': loss, 'preds': preds, 'targets': y}

    def configure_optimizers(self):
        """AdamW with learning rate and weight decay from ``config.model``."""
        return optim.AdamW(
            self.parameters(),
            lr=self.hparams.config.model.learning_rate,
            weight_decay=self.hparams.config.model.weight_decay
        )

class NIDSDataModule(pl.LightningDataModule):
    """Loads CIC-IDS-2017 from parquet and serves windowed, size-capped splits.

    ``prepare_data`` loads, cleans, label-encodes, splits 70/15/15 (stratified),
    standard-scales, windows the data and caps each split at its configured
    maximum, storing the arrays on ``self``; ``setup`` wraps them in
    TensorDatasets. Assigning state in ``prepare_data`` relies on both hooks
    running in the same process.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.training.batch_size
        self.sequence_length = config.training.sequence_length
        self.num_workers = config.data.num_workers
        self.oversample = config.training.oversample
        self.max_train_samples = config.training.max_train_samples
        self.max_val_samples = config.training.max_val_samples
        self.max_test_samples = config.training.max_test_samples

    def prepare_data(self):
        df = pd.read_parquet(os.path.join('/kaggle/input/cic-ids-2017-parquet', self.config.data.raw))

        # Drop rows containing +/-inf or NaN, then exact duplicate rows.
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)

        # Label / identifier columns excluded from the feature matrix (only those present).
        candidate_cols = ['Label', 'Timestamp', 'Flow ID', 'Src IP',
                          'Src Port', 'Attack', 'Dst IP', 'Dst Port', 'Protocol']
        self.non_numeric_cols = [col for col in candidate_cols if col in df.columns]

        self.label_encoder = LabelEncoder()
        df['Label_Num'] = self.label_encoder.fit_transform(df['Label'])
        self.classes = self.label_encoder.classes_

        # Fitted on the training split only (see _prepare_features).
        self.scaler = StandardScaler()

        train_df, test_df = train_test_split(
            df,
            test_size=0.3,  # 70% train, 30% test+val
            random_state=42,
            stratify=df['Label_Num']
        )
        val_df, test_df = train_test_split(
            test_df,
            test_size=0.5,  # 15% val, 15% test
            random_state=42,
            stratify=test_df['Label_Num']
        )

        self.X_train, self.y_train = self._prepare_features(train_df, fit=True)
        self.X_val, self.y_val = self._prepare_features(val_df, fit=False)
        self.X_test, self.y_test = self._prepare_features(test_df, fit=False)

        self._limit_samples()

    def _limit_samples(self, seed=42):
        """Randomly subsample each split down to its configured maximum.

        A dedicated generator seeded with ``seed`` makes the selection
        reproducible, including when prepare_data() is invoked more than once.
        A limit of ``None`` disables capping for that split.
        """
        rng = np.random.default_rng(seed)
        self.X_train, self.y_train = self._subsample(self.X_train, self.y_train, self.max_train_samples, rng)
        self.X_val, self.y_val = self._subsample(self.X_val, self.y_val, self.max_val_samples, rng)
        self.X_test, self.y_test = self._subsample(self.X_test, self.y_test, self.max_test_samples, rng)

    @staticmethod
    def _subsample(X, y, limit, rng):
        """Return (X, y) unchanged, or ``limit`` rows drawn without replacement (original order kept)."""
        if limit is None or len(X) <= limit:
            return X, y
        keep = np.sort(rng.choice(len(X), size=limit, replace=False))
        return X[keep], y[keep]

    def _prepare_features(self, df, fit=False):
        """Scale the numeric feature columns and window them into sequences."""
        X = df.drop(['Label_Num'] + self.non_numeric_cols, axis=1)
        y = df['Label_Num']
        if fit:
            X = self.scaler.fit_transform(X)
        else:
            X = self.scaler.transform(X)
        return self.create_sequences(X, y)

    def create_sequences(self, X, y):
        """Build every window of ``sequence_length`` consecutive rows.

        Window ``i`` covers rows ``i .. i+L-1`` and takes the label of its last
        row, giving ``len(X) - L + 1`` windows of shape ``(L, n_features)``.
        Returns empty arrays of the right rank when X has fewer than L rows.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        L = self.sequence_length
        if len(X) < L:
            return np.empty((0, L) + X.shape[1:], dtype=X.dtype), np.empty((0,), dtype=y.dtype)

        # sliding_window_view yields (n_windows, *feature_dims, L); move the
        # window axis next to the batch axis and copy to a contiguous array.
        windows = np.lib.stride_tricks.sliding_window_view(X, L, axis=0)
        sequences = np.ascontiguousarray(np.moveaxis(windows, -1, 1))
        labels = y[L - 1:]
        return sequences, labels

    def setup(self, stage=None):
        self.train_dataset = TensorDataset(torch.FloatTensor(self.X_train), torch.LongTensor(self.y_train))
        self.val_dataset = TensorDataset(torch.FloatTensor(self.X_val), torch.LongTensor(self.y_val))
        self.test_dataset = TensorDataset(torch.FloatTensor(self.X_test), torch.LongTensor(self.y_test))

        print(f"Training samples: {len(self.train_dataset)} (limited to {self.max_train_samples})")
        print(f"Validation samples: {len(self.val_dataset)} (limited to {self.max_val_samples})")
        print(f"Test samples: {len(self.test_dataset)} (limited to {self.max_test_samples})")

    def train_dataloader(self):
        if self.oversample:
            # Inverse-frequency weights -> classes drawn roughly uniformly.
            class_counts = np.bincount(self.y_train)
            weights = 1. / class_counts[self.y_train]
            sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        else:
            sampler = RandomSampler(self.train_dataset)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            # persistent_workers=True raises ValueError when num_workers == 0.
            persistent_workers=self.num_workers > 0,
            pin_memory=torch.cuda.is_available()
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=torch.cuda.is_available()
        )

def init_wandb():
    """Log in to Weights & Biases with the Kaggle-stored key and start a run.

    Returns:
        (wandb_logger, run): a ``WandbLogger`` bound to the new run (uploading
        every checkpoint, ``log_model='all'``) and the run object itself.
    """
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("mohammad_wandb_secret")
    wandb.login(key=wandb_api_key)

    training_cfg = config.training
    run = wandb.init(
        project=config.wandb.project,
        entity=config.wandb.entity,
        tags=config.wandb.tags,
        notes=config.wandb.notes,
        config={
            # Filled in by main() once the data has been loaded.
            "input_size": None,
            "num_classes": None,
            "sequence_length": training_cfg.sequence_length,
            # .get(): the CIC-IDS-2017 config defined at the top of this
            # notebook has no max_*_samples keys, and attribute access on a
            # missing key raises in recent OmegaConf versions.
            "train_samples": training_cfg.get("max_train_samples"),
            "val_samples": training_cfg.get("max_val_samples"),
            "test_samples": training_cfg.get("max_test_samples"),
            # to_container converts nested DictConfig/ListConfig values
            # (e.g. dense_units) into plain dicts/lists W&B can serialise.
            "model_config": OmegaConf.to_container(config.model, resolve=True),
            "training_config": OmegaConf.to_container(training_cfg, resolve=True)
        }
    )

    wandb_logger = WandbLogger(
        experiment=run,
        log_model='all'
    )

    return wandb_logger, run

def main():
    """Train the GRU on CIC-IDS-2017, evaluate the best checkpoint and log to W&B.

    Steps:
    1. Start the W&B run and prepare the data module.
    2. Build the model and fit it with early stopping on ``val_loss`` and
       best-checkpoint tracking on ``val_acc``.
    3. Test the best checkpoint.
    4. Log test accuracy, weighted F1 and test loss, plus a confusion-matrix
       chart and a per-class classification report.
    """
    wandb_logger, run = init_wandb()

    data_module = NIDSDataModule(config)
    data_module.prepare_data()
    data_module.setup()

    # Read the feature width straight from the stored tensor instead of
    # starting the (persistent, multi-worker) training DataLoader just to
    # peek at one batch.
    input_size = data_module.train_dataset.tensors[0].shape[2]
    num_classes = len(data_module.classes)

    run.config.update({
        "input_size": input_size,
        "num_classes": num_classes
    })

    model = GRUModel(input_size, num_classes, config)

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=config.training.early_stopping_patience,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        dirpath='checkpoints',
        filename='best_model'
    )

    trainer = pl.Trainer(
        # 16-bit mixed precision requires a GPU; use full precision on CPU.
        precision=16 if torch.cuda.is_available() else 32,
        logger=wandb_logger,
        max_epochs=config.training.max_epochs,
        callbacks=[early_stopping, checkpoint_callback],
        deterministic=True,
        gradient_clip_val=1.0,
        enable_progress_bar=True,
        log_every_n_steps=1000
    )

    trainer.fit(model, datamodule=data_module)

    # Evaluate the checkpoint ModelCheckpoint kept (highest val_acc), not the
    # weights left by the last epoch; this also loads those weights into
    # `model` for the manual pass below.
    test_results = trainer.test(model, datamodule=data_module, ckpt_path="best")

    # Collect predictions for the sklearn metrics, on whichever device
    # Lightning left the model.
    device = next(model.parameters()).device
    all_preds = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for x, y in data_module.test_dataloader():
            logits = model(x.to(device))
            all_preds.append(logits.argmax(dim=1).cpu())
            all_targets.append(y.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    test_acc = accuracy_score(all_targets, all_preds)
    test_f1 = f1_score(all_targets, all_preds, average='weighted')

    wandb.log({
        'test_acc': test_acc,
        'test_f1': test_f1,
        'test_loss': test_results[0]['test_loss']
    })

    class_names = data_module.classes.tolist()
    # Every encoded class index, so per-class outputs stay aligned with
    # class_names even if a class never occurs in the test targets/predictions.
    label_ids = list(range(len(class_names)))

    # wandb's helper fills the confusion-matrix chart spec with the field
    # names that spec expects.
    wandb.log({
        "multiclass_confusion_matrix": wandb.plot.confusion_matrix(
            y_true=all_targets.tolist(),
            preds=all_preds.tolist(),
            class_names=class_names,
            title="Multiclass Confusion Matrix"
        )
    })

    # Without labels=, a class missing from both targets and predictions makes
    # target_names longer than the label set and sklearn raises ValueError.
    report = classification_report(
        all_targets, all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    report_table = wandb.Table(columns=["Class", "Precision", "Recall", "F1-Score", "Support"])
    for class_name in class_names:
        report_table.add_data(
            class_name,
            report[class_name]["precision"],
            report[class_name]["recall"],
            report[class_name]["f1-score"],
            report[class_name]["support"]
        )

    report_table.add_data(
        "Weighted Avg",
        report["weighted avg"]["precision"],
        report["weighted avg"]["recall"],
        report["weighted avg"]["f1-score"],
        report["weighted avg"]["support"]
    )

    wandb.log({"classification_report": report_table})

    wandb.finish()

if __name__ == "__main__":
    # Seed the global NumPy and PyTorch RNGs so weight initialisation, the
    # weighted sampler and any NumPy-based sampling repeat across runs. This
    # matches the seeding of the CIC-ToN-IoT cell and gives
    # Trainer(deterministic=True) a fixed starting point.
    torch.manual_seed(42)
    np.random.seed(42)

    main()


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmohammad-fleity[0m ([33mmohammad-fleity-lebanese-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


<h2>Second dataset: CIC-ToN-IoT</h2>

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler, RandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import wandb
from omegaconf import OmegaConf
import os
import warnings
from kaggle_secrets import UserSecretsClient

warnings.filterwarnings('ignore')

# GRU configuration for CIC-ToN-IoT: same model architecture as the
# CIC-IDS-2017 run, with shorter windows, smaller batches, fewer loader
# workers and per-split sample caps to reduce memory use.
config = OmegaConf.create({
    "wandb": {
        "project": "DL-NIDS-2--cic-ton-iot",
        "entity": "mohammad-fleity-lebanese-university",
        "tags": ["GRU", "CIC-TON-IOT", "PyTorch"],
        "notes": "GRU for network intrusion detection with memory optimizations"
    },
    "model": {
        "hidden_size": 128,
        "num_layers": 2,
        "dropout": 0.4,
        "dense_units": [128, 64],
        "learning_rate": 0.0001,
        "weight_decay": 1e-4
    },
    "training": {
        "sequence_length": 3,       # Reduced from 5 (CIC-IDS-2017 run)
        "batch_size": 64,           # Reduced from 128
        "max_epochs": 40,
        "early_stopping_patience": 7,  # epochs without val_loss improvement
        "oversample": True,         # class-balanced WeightedRandomSampler
        "gpus": 1 if torch.cuda.is_available() else 0,  # not passed to pl.Trainer in main(); only logged
        "max_train_samples": 100000,  # cap on training windows after sequencing
        "max_val_samples": 20000,     # cap on validation windows after sequencing
        "max_test_samples": 10000      # cap on test windows after sequencing
    },
    "data": {
        "raw": "cic_ton_iot.parquet",  # file under /kaggle/input/cic-ton-iot-parquet
        "num_workers": 2            # Reduced from 4; used by the train/val loaders
    }
})

class GRUModel(pl.LightningModule):
    """Stacked-GRU classifier over windows of flow features.

    The GRU runs with ``batch_first=True`` over windows of shape
    ``(batch, time, input_size)``. Only the output at the last time step is
    kept. That output is layer-normalised, passed through two
    ``Linear -> LayerNorm -> ReLU -> Dropout`` stages and projected to
    ``num_classes`` logits. Loss is cross-entropy with 0.1 label smoothing;
    the optimiser is AdamW.
    """

    def __init__(self, input_size, num_classes, config):
        super().__init__()
        # Stores input_size, num_classes and config in self.hparams. Lightning
        # inspects this __init__ frame, so the call must stay here.
        self.save_hyperparameters()

        cfg = config.model
        first_width = cfg.dense_units[0]
        second_width = cfg.dense_units[1]

        # nn.GRU only applies dropout between stacked layers.
        gru_dropout = cfg.dropout if cfg.num_layers > 1 else 0
        self.gru = nn.GRU(
            input_size=input_size,
            hidden_size=cfg.hidden_size,
            num_layers=cfg.num_layers,
            batch_first=True,
            dropout=gru_dropout
        )
        self.gru_ln = nn.LayerNorm(cfg.hidden_size)

        # Build the head stage by stage; module order yields the parameter
        # names dense.0 ... dense.7.
        head_layers = []
        for in_width, out_width in ((cfg.hidden_size, first_width), (first_width, second_width)):
            head_layers.extend([
                nn.Linear(in_width, out_width),
                nn.LayerNorm(out_width),
                nn.ReLU(),
                nn.Dropout(cfg.dropout),
            ])
        self.dense = nn.Sequential(*head_layers)

        self.output = nn.Linear(second_width, num_classes)
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, x):
        """Return class logits for a batch of windows."""
        step_outputs, _ = self.gru(x)
        final_state = self.gru_ln(step_outputs[:, -1, :])
        return self.output(self.dense(final_state))

    def _score_batch(self, batch):
        """Shared step: targets, logits, loss and the per-sample hit mask."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        hits = logits.argmax(dim=1) == y
        return y, logits, loss, hits

    def training_step(self, batch, batch_idx):
        y, _, loss, hits = self._score_batch(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', hits.float().mean() * 100, on_step=True, on_epoch=True, prog_bar=True)

        return {'loss': loss, 'correct': hits.sum(), 'total': len(y)}

    def on_train_epoch_end(self):
        # Epoch-level reduction of the metrics logged above is left to Lightning.
        pass

    def validation_step(self, batch, batch_idx):
        y, _, loss, hits = self._score_batch(batch)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', hits.float().mean() * 100, prog_bar=True)

        return {'val_loss': loss, 'correct': hits.sum(), 'total': len(y)}

    def test_step(self, batch, batch_idx):
        y, logits, loss, hits = self._score_batch(batch)

        self.log('test_loss', loss)
        self.log('test_acc', hits.float().mean() * 100)

        return {'test_loss': loss, 'preds': logits.argmax(dim=1), 'targets': y}

    def configure_optimizers(self):
        model_cfg = self.hparams.config.model
        return optim.AdamW(
            self.parameters(),
            lr=model_cfg.learning_rate,
            weight_decay=model_cfg.weight_decay
        )

class NIDSDataModule(pl.LightningDataModule):
    """Lightning data module for the CIC-ToN-IoT flow dataset.

    ``prepare_data`` does the data preparation:
    - loads the parquet file and drops inf/NaN and duplicate rows;
    - label-encodes ``Label``;
    - makes a stratified 70 / 18 / 12 train/val/test split;
    - standardises the numeric features, with the scaler fitted on train only;
    - turns each split into sliding windows of ``sequence_length`` rows;
    - caps every split at its ``max_*_samples`` limit.

    ``setup`` wraps the arrays in ``TensorDataset`` objects and frees the
    NumPy copies.

    Both hooks are idempotent. ``main()`` calls them explicitly, and the
    Trainer calls them again inside ``fit``/``test``. Without the guards, each
    repeat would re-read the parquet file and draw new random sub-samples.

    NOTE(review): ``train_test_split`` shuffles rows before windowing, so the
    rows inside one window are not consecutive flows — confirm this is the
    intended notion of a "sequence".
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.training.batch_size
        self.sequence_length = config.training.sequence_length
        self.num_workers = config.data.num_workers
        self.oversample = config.training.oversample
        self.max_train_samples = config.training.max_train_samples
        self.max_val_samples = config.training.max_val_samples
        self.max_test_samples = config.training.max_test_samples
        # Set once prepare_data() has built the arrays; later calls are no-ops.
        self._nids_prepared = False

    def prepare_data(self):
        """Load, clean, encode, split, scale, window and cap the dataset (once)."""
        if self._nids_prepared:
            return

        # Load data
        df = pd.read_parquet(os.path.join('/kaggle/input/cic-ton-iot-parquet', self.config.data.raw))

        # Clean data: +/-inf -> NaN, drop NaN rows, drop exact duplicates
        df.replace([np.inf, -np.inf], np.nan, inplace=True)
        df.dropna(inplace=True)
        df.drop_duplicates(inplace=True)

        # Downcast to save memory
        for col in df.select_dtypes(include=['float64']).columns:
            df[col] = pd.to_numeric(df[col], downcast='float')
        for col in df.select_dtypes(include=['int64']).columns:
            df[col] = pd.to_numeric(df[col], downcast='integer')

        # Identifier / label columns excluded from the features (only those present)
        self.non_numeric_cols = ['Label', 'Timestamp', 'Flow ID', 'Src IP',
                               'Src Port', 'Attack', 'Dst IP', 'Dst Port', 'Protocol']
        self.non_numeric_cols = [col for col in self.non_numeric_cols if col in df.columns]

        # Encode labels
        self.label_encoder = LabelEncoder()
        df['Label_Num'] = self.label_encoder.fit_transform(df['Label'])
        self.classes = self.label_encoder.classes_

        # Initialize scaler (fitted on the training split only)
        self.scaler = StandardScaler()

        # Stratified 70 / 30 split ...
        train_df, test_df = train_test_split(
            df,
            test_size=0.3,
            random_state=42,
            stratify=df['Label_Num']
        )

        # ... then the 30 % into val (60 %) and test (40 %), i.e. 18 % / 12 % overall
        val_df, test_df = train_test_split(
            test_df,
            test_size=0.4,
            random_state=42,
            stratify=test_df['Label_Num']
        )

        self.X_train, self.y_train = self._prepare_features(train_df, fit=True)
        self.X_val, self.y_val = self._prepare_features(val_df, fit=False)
        self.X_test, self.y_test = self._prepare_features(test_df, fit=False)

        # Apply sample limits
        self._limit_samples()

        # Clean up
        del df, train_df, test_df, val_df
        self._nids_prepared = True

    def _limit_samples(self):
        """Randomly sub-sample each split (without replacement) down to its cap."""
        # Training data
        if len(self.X_train) > self.max_train_samples:
            indices = np.random.choice(len(self.X_train), self.max_train_samples, replace=False)
            self.X_train = self.X_train[indices]
            self.y_train = self.y_train[indices]

        # Validation data
        if len(self.X_val) > self.max_val_samples:
            indices = np.random.choice(len(self.X_val), self.max_val_samples, replace=False)
            self.X_val = self.X_val[indices]
            self.y_val = self.y_val[indices]

        # Test data
        if len(self.X_test) > self.max_test_samples:
            indices = np.random.choice(len(self.X_test), self.max_test_samples, replace=False)
            self.X_test = self.X_test[indices]
            self.y_test = self.y_test[indices]

        print(f"Sample counts - Train: {len(self.X_train)}, Val: {len(self.X_val)}, Test: {len(self.X_test)}")

    def _prepare_features(self, df, fit=False):
        """Drop non-feature columns, standardise, and window one split."""
        X = df.drop(['Label_Num'] + self.non_numeric_cols, axis=1)
        y = df['Label_Num']
        if fit:
            X = self.scaler.fit_transform(X)
        else:
            X = self.scaler.transform(X)
        return self.create_sequences(X, y)

    def create_sequences(self, X, y):
        """Turn row-wise features into overlapping fixed-length windows.

        Window ``i`` is ``X[i : i + sequence_length]``, labelled with
        ``y[i + sequence_length - 1]``; all ``len(X) - sequence_length + 1``
        start positions are used.

        Returns:
            (sequences, labels): contiguous float32 array of shape
            ``(n_windows, sequence_length, n_features)`` and labels of shape
            ``(n_windows,)``. Both are empty when ``X`` is shorter than one
            window.

        Raises:
            ValueError: if ``sequence_length`` is smaller than 1.
        """
        seq_len = self.sequence_length
        if seq_len < 1:
            raise ValueError(f"sequence_length must be >= 1, got {seq_len}")

        X = np.asarray(X, dtype=np.float32)
        y = np.asarray(y)  # positional, matching the old y.iloc indexing
        n_windows = len(X) - seq_len + 1

        if n_windows <= 0:
            return np.empty((0, seq_len) + X.shape[1:], dtype=np.float32), y[:0]

        # sliding_window_view appends the window axis last:
        # (n_windows, n_features, seq_len) -> move it to position 1.
        windows = np.lib.stride_tricks.sliding_window_view(X, seq_len, axis=0)
        sequences = np.ascontiguousarray(np.moveaxis(windows, -1, 1))
        labels = y[seq_len - 1:seq_len - 1 + n_windows]
        return sequences, labels

    def setup(self, stage=None):
        """Wrap the prepared arrays in TensorDatasets and free the NumPy copies."""
        # The arrays are deleted below, so later calls (from the Trainer) must
        # reuse the datasets built on the first call.
        if getattr(self, "train_dataset", None) is not None:
            return

        self.train_dataset = TensorDataset(
            torch.FloatTensor(self.X_train),
            torch.LongTensor(self.y_train)
        )
        self.val_dataset = TensorDataset(
            torch.FloatTensor(self.X_val),
            torch.LongTensor(self.y_val)
        )
        self.test_dataset = TensorDataset(
            torch.FloatTensor(self.X_test),
            torch.LongTensor(self.y_test)
        )

        # Clear memory
        del self.X_train, self.y_train, self.X_val, self.y_val, self.X_test, self.y_test

    def train_dataloader(self):
        """Training loader; class-balanced (in expectation) when ``oversample`` is set."""
        if self.oversample:
            # Inverse-class-frequency weight per sample, taken from the stored labels
            _, y = self.train_dataset[:]
            class_counts = torch.bincount(y)
            weights = 1. / class_counts[y]
            sampler = WeightedRandomSampler(weights, len(weights), replacement=True)
        else:
            sampler = RandomSampler(self.train_dataset)

        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            persistent_workers=False,  # Disabled to save memory
            pin_memory=True,
            drop_last=True
        )

    def val_dataloader(self):
        """Validation loader in fixed order that covers every validation window."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            # Keep the final partial batch: dropping it would silently leave
            # validation windows out of val_loss / val_acc.
            drop_last=False
        )

    def test_dataloader(self):
        """Single-process test loader with small batches to limit memory."""
        return DataLoader(
            self.test_dataset,
            batch_size=32,  # Reduced test batch size
            shuffle=False,
            num_workers=0,   # No multiprocessing for test
            pin_memory=False,
            drop_last=False
        )

def init_wandb():
    """Log in to Weights & Biases with the Kaggle-stored key and start a run.

    Returns:
        (wandb_logger, run): a ``WandbLogger`` bound to the new run (model
        checkpoints are not uploaded, ``log_model=False``) and the run object.
    """
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("mohammad_wandb_secret")
    wandb.login(key=wandb_api_key)

    run = wandb.init(
        project=config.wandb.project,
        entity=config.wandb.entity,
        tags=config.wandb.tags,
        notes=config.wandb.notes,
        config={
            # Filled in by main() once the data has been loaded.
            "input_size": None,
            "num_classes": None,
            "sequence_length": config.training.sequence_length,
            "train_samples": config.training.max_train_samples,
            "val_samples": config.training.max_val_samples,
            "test_samples": config.training.max_test_samples,
            # to_container converts nested DictConfig/ListConfig values
            # (e.g. dense_units) into plain dicts/lists W&B can serialise.
            "model_config": OmegaConf.to_container(config.model, resolve=True),
            "training_config": OmegaConf.to_container(config.training, resolve=True)
        }
    )

    wandb_logger = WandbLogger(
        experiment=run,
        log_model=False  # Disabled to save space
    )

    return wandb_logger, run

def main():
    """Train the GRU on CIC-ToN-IoT, evaluate the best checkpoint and log to W&B.

    Steps:
    1. Start the W&B run and prepare the data module.
    2. Build the model and fit it with early stopping on ``val_loss``,
       best-checkpoint tracking on ``val_acc`` and 2-batch gradient
       accumulation.
    3. Test the best checkpoint.
    4. Log test accuracy, weighted F1 and test loss, plus a confusion-matrix
       table and a per-class classification report.
    """
    wandb_logger, run = init_wandb()

    data_module = NIDSDataModule(config)
    data_module.prepare_data()
    data_module.setup()

    # Read the feature width from the stored tensor instead of spinning up a
    # worker-backed DataLoader (and weighted sampler) for a single batch.
    input_size = data_module.train_dataset.tensors[0].shape[2]
    num_classes = len(data_module.classes)

    run.config.update({
        "input_size": input_size,
        "num_classes": num_classes
    })

    model = GRUModel(input_size, num_classes, config)

    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=config.training.early_stopping_patience,
        mode='min'
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_acc',
        mode='max',
        save_top_k=1,
        dirpath='checkpoints',
        filename='best_model'
    )

    trainer = pl.Trainer(
        # 16-bit mixed precision requires a GPU; use full precision on CPU.
        precision=16 if torch.cuda.is_available() else 32,
        logger=wandb_logger,
        max_epochs=config.training.max_epochs,
        callbacks=[early_stopping, checkpoint_callback],
        deterministic=True,
        gradient_clip_val=1.0,
        enable_progress_bar=True,
        log_every_n_steps=100,
        accumulate_grad_batches=2  # Gradient accumulation
    )

    trainer.fit(model, datamodule=data_module)

    # Evaluate the checkpoint ModelCheckpoint kept (highest val_acc), not the
    # weights left by the last epoch; this also loads those weights into
    # `model` for the manual pass below.
    test_results = trainer.test(model, datamodule=data_module, ckpt_path="best", verbose=False)

    # Collect predictions in batches, on whichever device Lightning left the model
    device = next(model.parameters()).device
    all_preds = []
    all_targets = []

    model.eval()
    with torch.no_grad():
        for x, y in data_module.test_dataloader():
            logits = model(x.to(device))
            all_preds.append(logits.argmax(dim=1).cpu())
            all_targets.append(y.cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()

    test_acc = accuracy_score(all_targets, all_preds)
    test_f1 = f1_score(all_targets, all_preds, average='weighted')

    wandb.log({
        'test_acc': test_acc,
        'test_f1': test_f1,
        'test_loss': test_results[0]['test_loss']
    })

    class_names = data_module.classes.tolist()
    # Pin the full label set. The test split is capped at max_test_samples,
    # so rare classes may be missing. Without labels= the confusion matrix
    # would shrink (IndexError below) and classification_report would reject
    # target_names (ValueError).
    label_ids = list(range(len(class_names)))
    conf_mat = confusion_matrix(all_targets, all_preds, labels=label_ids)

    # Log as table to save memory
    conf_mat_table = wandb.Table(columns=["Actual", "Predicted", "Count"])
    for i, actual in enumerate(class_names):
        for j, predicted in enumerate(class_names):
            conf_mat_table.add_data(actual, predicted, int(conf_mat[i, j]))

    wandb.log({"confusion_matrix": conf_mat_table})

    report = classification_report(
        all_targets, all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0
    )

    report_table = wandb.Table(columns=["Class", "Precision", "Recall", "F1-Score", "Support"])
    for class_name in class_names:
        report_table.add_data(
            class_name,
            report[class_name]["precision"],
            report[class_name]["recall"],
            report[class_name]["f1-score"],
            report[class_name]["support"]
        )

    # Same summary row as the CIC-IDS-2017 report, for side-by-side comparison
    report_table.add_data(
        "Weighted Avg",
        report["weighted avg"]["precision"],
        report["weighted avg"]["recall"],
        report["weighted avg"]["f1-score"],
        report["weighted avg"]["support"]
    )

    wandb.log({"classification_report": report_table})
    wandb.finish()

if __name__ == "__main__":
    # Fix the global NumPy and PyTorch RNGs before main() runs. This makes the
    # random per-split sub-sampling (np.random.choice) and the model's weight
    # initialisation reproducible.
    np.random.seed(42)
    torch.manual_seed(42)
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmohammad-fleity[0m ([33mmohammad-fleity-lebanese-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Sample counts - Train: 100000, Val: 20000, Test: 10000
Sample counts - Train: 100000, Val: 20000, Test: 10000


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]