In [None]:
from model import InsiderClassifier, LSTM_Encoder, CNN_Classifier

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, DistributedSampler
import torch.distributed as dist
import os
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

In [None]:
# ==============================================================================
# 💾 DATASET SETUP (Re-using the structure defined previously)
# ==============================================================================

class InsiderThreatDataset(Dataset):
    """Map-style dataset holding the full sequence and label arrays as tensors in memory.

    Both inputs are loaded with ``pd.read_pickle``. The sequence object must
    expose ``.tolist()`` and the label object ``.values`` (e.g. pandas Series).

    Attributes:
        X: ``torch.long`` tensor built from the pickled sequences (long dtype so
            it can feed an ``nn.Embedding``).
        y: ``torch.float32`` tensor of labels with an extra axis inserted at
            dim 1, i.e. ``(N, 1)`` when the pickled labels are one-dimensional.

    Raises:
        ValueError: if the number of sequences and labels differ.
    """

    def __init__(self, X_path, y_path):
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load
        # pickle files that come from a trusted source.
        raw_sequences = pd.read_pickle(X_path)
        raw_labels = pd.read_pickle(y_path)

        # Sequences must be rectangular (already padded) for torch.tensor to accept them.
        self.X = torch.tensor(raw_sequences.tolist(), dtype=torch.long)
        # unsqueeze(1) turns 1-D labels of shape (N,) into (N, 1).
        self.y = torch.tensor(raw_labels.values.astype(float), dtype=torch.float32).unsqueeze(1)

        # Fail fast on misaligned inputs instead of surfacing an IndexError
        # (or silently dropping labels) in the middle of an epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y have different lengths: {len(self.X)} sequences vs {len(self.y)} labels"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]

In [None]:
# ==============================================================================
# ⚙️ CONFIGURATION VARIABLES (Change these for your specific environment)
# ==============================================================================

# --- A. DEVICE / SINGLE-GPU CONFIG ---
# Device string used when USE_DDP is False: 'cpu', 'cuda', or e.g. 'cuda:0'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


# --- B. DDP (Distributed Data Parallel) CONFIG ---
# Set to True to enable DDP for multi-GPU/multi-node training.
USE_DDP = False
# DDP parameters (only relevant if USE_DDP is True).
# With INIT_METHOD = 'env://' each process is expected to find its identity in
# the environment, so RANK / WORLD_SIZE are read from there (the training code
# already does the same for LOCAL_RANK). The fallbacks keep the single-process
# notebook values of 0 and 1 when the variables are not set.
RANK = int(os.environ.get('RANK', 0))              # Rank of the current process (0 to WORLD_SIZE - 1)
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))  # Total number of participating processes
BACKEND = 'nccl'        # Communication backend (usually 'nccl' for GPUs)
INIT_METHOD = 'env://'  # How processes find each other (environment variables)

# --- C. HYPERPARAMETERS & DATA LOADING ---
BATCH_SIZE = 64
LEARNING_RATE = 1e-4  # Step size passed to the Adam optimizers
# Data paths (X_train.pkl is expected to hold padded action-id sequences, y_train.pkl the labels)
X_PATH = 'X_train.pkl'
Y_PATH = 'y_train.pkl'
NUM_WORKERS = 0  # Number of DataLoader worker subprocesses (0 = load in the main process)

# ==============================================================================
# 🚀 MAIN TRAINING FUNCTION (UPDATED)
# ==============================================================================

def train_cnn_classifier(EPOCHS=10, OUTPUT_FILENAME='model.pkl', LSTM_CHECKPOINT_PATH='./kk',
                         CLASS_WEIGHTS=(1.0, 49.0), EXPECTED_SEQ_LEN=250, LOG_EVERY=100):
    """Train ``InsiderClassifier`` with a class-weighted cross-entropy loss and save its weights.

    Reads the module-level settings USE_DDP, RANK, WORLD_SIZE, BACKEND,
    INIT_METHOD, DEVICE, X_PATH, Y_PATH, BATCH_SIZE, NUM_WORKERS and
    LEARNING_RATE.

    Args:
        EPOCHS: number of passes over the training data.
        OUTPUT_FILENAME: path the final ``state_dict`` is written to
            (only by rank 0 when USE_DDP is True).
        LSTM_CHECKPOINT_PATH: forwarded to ``InsiderClassifier(lstm_checkpoint=...)``.
        CLASS_WEIGHTS: per-class weights for ``nn.CrossEntropyLoss``
            (index 0 = class 0, index 1 = class 1).
        EXPECTED_SEQ_LEN: a diagnostic line is printed for every batch whose
            second dimension differs from this value; training continues.
        LOG_EVERY: number of batches between progress reports. Batches left
            over at the end of an epoch are reported as a final, shorter window.
    """
    # ------------------------------------
    # 1. DDP and Device Initialization
    # ------------------------------------
    if USE_DDP:
        dist.init_process_group(BACKEND, init_method=INIT_METHOD, rank=RANK, world_size=WORLD_SIZE)

    try:
        if USE_DDP:
            # The local rank selects the GPU used by this process.
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            # Pin the process to its GPU so default-device CUDA work and
            # collectives do not all land on cuda:0.
            torch.cuda.set_device(local_rank)
            current_device = torch.device(f'cuda:{local_rank}')
            print(f"DDP: Rank {RANK} initialized on device {current_device}")
        else:
            current_device = torch.device(DEVICE)
            print(f"Single Device: Initializing on device {current_device}")

        # ------------------------------------
        # 2. Data Loading
        # ------------------------------------
        dataset = InsiderThreatDataset(X_PATH, Y_PATH)

        if USE_DDP:
            # The sampler shards and shuffles; the DataLoader itself must not shuffle.
            sampler = DistributedSampler(dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True)
            dataloader = DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                sampler=sampler,
                num_workers=NUM_WORKERS
            )
        else:
            dataloader = DataLoader(
                dataset,
                batch_size=BATCH_SIZE,
                shuffle=True,
                num_workers=NUM_WORKERS
            )

        # ------------------------------------
        # 3. Model, Loss, and Optimizer
        # ------------------------------------
        # NOTE(review): the model is not moved with .to(current_device) here;
        # presumably InsiderClassifier places itself using the `device`
        # argument -- confirm in model.py.
        model = InsiderClassifier(
            lstm_checkpoint=LSTM_CHECKPOINT_PATH,
            device=current_device
        )

        if USE_DDP:
            model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

        # The weight tensor must live on the same device as the scores, which
        # is current_device (not necessarily DEVICE when running under DDP).
        class_weights = torch.tensor(CLASS_WEIGHTS, dtype=torch.float32, device=current_device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        print("Starting class-weighted training...")

        # ------------------------------------
        # 4. Training Loop
        # ------------------------------------
        for epoch in range(EPOCHS):
            if USE_DDP:
                # Re-seed the sampler so each epoch sees a different shuffle.
                sampler.set_epoch(epoch)

            model.train()
            running_loss = 0.0
            window_batches = 0
            # Predictions/labels accumulated since the last progress report.
            temp_preds, temp_labels = [], []

            for i, (X_batch, y_batch) in enumerate(dataloader):
                if X_batch.shape[1] != EXPECTED_SEQ_LEN:
                    print("actual batch shape is", X_batch.shape)
                X_batch = X_batch.long().to(current_device)
                # CrossEntropyLoss expects class indices as an (N,) long tensor.
                y_batch_long = y_batch.long().squeeze(1).to(current_device)

                optimizer.zero_grad()
                scores = model(X_batch)
                loss = criterion(scores, y_batch_long)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                window_batches += 1

                # Metrics only need the predicted class index, not the graph.
                predicted_classes = scores.detach().argmax(dim=1)
                temp_preds.extend(predicted_classes.cpu().tolist())
                temp_labels.extend(y_batch_long.cpu().tolist())

                if window_batches == LOG_EVERY:
                    _report_training_window(epoch, i + 1, running_loss, window_batches,
                                            temp_preds, temp_labels)
                    running_loss = 0.0
                    window_batches = 0
                    temp_preds, temp_labels = [], []

            # Report the batches that did not fill a complete window, so short
            # epochs (fewer than LOG_EVERY batches) still produce output.
            if window_batches:
                _report_training_window(epoch, i + 1, running_loss, window_batches,
                                        temp_preds, temp_labels)

        # ------------------------------------
        # 5. Final Model Saving
        # ------------------------------------
        # Under DDP only rank 0 writes the file.
        if not USE_DDP or RANK == 0:
            # Unwrap the DDP container so the saved keys have no 'module.' prefix.
            model_to_save = model.module if USE_DDP else model
            torch.save(model_to_save.state_dict(), OUTPUT_FILENAME)
            print(f"\nTraining Complete. Final model parameters saved to: {OUTPUT_FILENAME}")
    finally:
        # Always release the process group, even if training raised.
        if USE_DDP:
            dist.destroy_process_group()


def _report_training_window(epoch, batch_number, loss_sum, num_batches, preds, labels):
    """Print mean loss plus accuracy/F1/precision/recall (positive class = 1) for one window."""
    window_labels = np.array(labels)
    window_preds = np.array(preds)

    accuracy = accuracy_score(window_labels, window_preds)
    f1 = f1_score(window_labels, window_preds, average='binary', zero_division=0)
    recall = recall_score(window_labels, window_preds, average='binary', zero_division=0)
    precision = precision_score(window_labels, window_preds, average='binary', zero_division=0)

    avg_loss = loss_sum / num_batches
    print(f"[Epoch {epoch+1}, Batch {batch_number}] "
          f"Loss: {avg_loss:.4f} | "
          f"Acc: {accuracy:.4f} | "
          f"F1: {f1:.4f} | "
          f"Prec: {precision:.4f} | "
          f"Recall: {recall:.4f}")

In [None]:
def train_lstm_encoder(EPOCHS=2, LSTM_CHECKPOINT_PATH='./kk', PAD_INDEX=None):
    """Train ``LSTM_Encoder`` to reconstruct its input sequences and save its ``state_dict``.

    Runs on the single device named by DEVICE and reads the module-level
    settings X_PATH, Y_PATH, BATCH_SIZE, NUM_WORKERS and LEARNING_RATE.

    Args:
        EPOCHS: number of passes over the dataset.
        LSTM_CHECKPOINT_PATH: path the trained ``state_dict`` is written to.
        PAD_INDEX: optional token id to exclude from the reconstruction loss
            (forwarded as ``NLLLoss(ignore_index=PAD_INDEX)``). ``None`` keeps
            the default ``NLLLoss()`` so every position contributes.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    # ------------------------------------
    # 1. Device and Data Loading Setup
    # ------------------------------------
    current_device = torch.device(DEVICE)
    print(f"Starting LSTM Encoder training on device {current_device}...")

    # The labels are loaded by the dataset but ignored by this loop.
    dataset = InsiderThreatDataset(X_PATH, Y_PATH)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS
    )

    # Guard the per-epoch average below against division by zero.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(f"No training batches available from {X_PATH!r}")

    # ------------------------------------
    # 2. Model, Loss, and Optimizer
    # ------------------------------------
    # The encoder is trained standalone here, not wrapped in InsiderClassifier.
    model = LSTM_Encoder().to(current_device)
    model.train()

    # NLLLoss consumes log-probabilities and integer class targets.
    # NOTE(review): assumes LSTM_Encoder's forward returns log-softmax output
    # of shape (B, S, V) -- confirm in model.py.
    if PAD_INDEX is None:
        criterion = nn.NLLLoss()
    else:
        criterion = nn.NLLLoss(ignore_index=PAD_INDEX)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # ------------------------------------
    # 3. Training Loop
    # ------------------------------------
    for epoch in range(EPOCHS):
        running_loss = 0.0

        for X_batch, _ in dataloader:
            # The input sequence doubles as the reconstruction target.
            X_batch = X_batch.to(current_device)

            optimizer.zero_grad()
            reconstructed_X = model(X_batch)

            # For sequence inputs NLLLoss wants the class axis second:
            # (B, S, V) -> (B, V, S), with targets of shape (B, S).
            loss = criterion(
                reconstructed_X.permute(0, 2, 1),
                X_batch.long()
            )

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / num_batches
        # end='\t\n' reproduces the original two-step print (tab, then newline).
        print(f"Epoch {epoch+1}/{EPOCHS} | Avg Reconstruction Loss: {avg_loss:.4f}", end='\t\n')

    # ------------------------------------
    # 4. Save the Checkpoint
    # ------------------------------------
    torch.save(model.state_dict(), LSTM_CHECKPOINT_PATH)
    print(f"\nLSTM Encoder training complete. Checkpoint saved to: {LSTM_CHECKPOINT_PATH}")

In [None]:
# Stage 1: train the LSTM encoder for two epochs and write its checkpoint to 'kk (1)'.
train_lstm_encoder(
    EPOCHS=2,
    LSTM_CHECKPOINT_PATH='kk (1)',
)

In [None]:
# Stage 2: train the classifier for ten epochs, handing it the encoder
# checkpoint path 'kk (1)', and save the result to 'model.pkl'.
train_cnn_classifier(
    EPOCHS=10,
    OUTPUT_FILENAME='model.pkl',
    LSTM_CHECKPOINT_PATH='kk (1)',
)