# In [None]:
# ============================================================
# Jobenn Bezuidenhout u22518500
# TRAIN AND FINE-TUNE LLMs on AMD GPU (DirectML) - Same outputs
# ============================================================

import os
import math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch_directml
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
import evaluate  # kept for parity with your original (though we compute accuracy directly too)
from tqdm.auto import tqdm
# -----------------------------
# Device: AMD GPU via DirectML
# -----------------------------
# Every model and tensor below is moved onto this device with `.to(dml)`.
dml = torch_directml.device()
print(f"Using device: {dml!s}")

# ---------------------------------------------
# Load the combined preprocessed CSV from Phase 2
# ---------------------------------------------
combined_df = pd.read_csv('Jobenn_preprocessed_NCF_data.csv')

# Each 'text' cell is "<pm_text> [SEP] <ugr_text>". Split on the FIRST
# separator only; a row without the separator keeps its whole text in pm_text
# and gets an empty ugr_text. Empty CSV cells are read as NaN, so fill them
# with '' first (a bare NaN would crash a string split with a TypeError).
separator = ' [SEP] '
text_parts = combined_df['text'].fillna('').astype(str).str.partition(separator)
combined_df['pm_text'] = text_parts[0]
combined_df['ugr_text'] = text_parts[2]
combined_df = combined_df.drop(columns=['text'])

# Map string labels to integers: 0 for 'Benign', 1 for anything else
# (e.g. 'Ransomware'). Case and surrounding whitespace are ignored.
# Use 'int64' rather than `int`: CrossEntropyLoss needs Long targets, and
# `int` maps to 32-bit on Windows with NumPy < 2.
combined_df['label'] = (
    combined_df['label']
    .astype(str)
    .str.strip()
    .str.lower()
    .ne('benign')
    .astype('int64')
)

# Split: 80/20, stratified so both splits keep the class ratio
train_df, test_df = train_test_split(combined_df, test_size=0.2, stratify=combined_df['label'], random_state=42)

# Handle class imbalance: 'balanced' weights are inversely proportional to
# each class's frequency in the training split.
classes = np.unique(train_df['label'])
class_weights = compute_class_weight('balanced', classes=classes, y=train_df['label'])
class_weights = torch.tensor(class_weights, dtype=torch.float)

# Convert to Hugging Face Dataset
train_dataset = Dataset.from_pandas(train_df.reset_index(drop=True))
test_dataset  = Dataset.from_pandas(test_df.reset_index(drop=True))

# --------------------------------------
# 1) FAST Preprocessing Function (same)
# --------------------------------------
def preprocess_fast(examples, tokenizer, *, max_length=64):
    """Tokenize a batch of (pm_text, ugr_text) pairs to fixed-length encodings.

    Args:
        examples: batch mapping with list-valued 'pm_text' and 'ugr_text'
            entries (as passed by ``Dataset.map(..., batched=True)``).
        tokenizer: tokenizer called with the two texts as a sentence pair.
        max_length: token length every pair is truncated/padded to. Defaults
            to 64, the length used for training in this script.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    return tokenizer(
        examples['pm_text'],
        examples['ugr_text'],
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )

# Metadata columns the models never read, plus the two raw text columns once
# they have been tokenized. Names missing from a split are skipped.
columns_to_remove = [
    'protocol', 'flag', 'family', 'address', 'usd', 'seed_address',
    'btc', 'netflow_bytes', 'ip_address', 'clusters', 'threats',
    'port', 'time', 'prediction', 'segment', 'embeddings',
    'dataset', 'r', 'rw', 'rx', 'rwc', 'rwx', 'rwxc', 'category',
    'pm_text', 'ugr_text',
]

# --------------------------
# Tokenize (BERT + RoBERTa)
# --------------------------
def _tokenize_split(dataset, tokenizer):
    """Tokenize one HF split with preprocess_fast and drop the unused columns."""
    present = dataset.column_names
    droppable = [name for name in columns_to_remove if name in present]
    return dataset.map(
        lambda batch: preprocess_fast(batch, tokenizer),
        batched=True,
        remove_columns=droppable,
    )


bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
train_ds_bert = _tokenize_split(train_dataset, bert_tokenizer)
test_ds_bert = _tokenize_split(test_dataset, bert_tokenizer)

roberta_tokenizer = AutoTokenizer.from_pretrained('roberta-base')
train_ds_roberta = _tokenize_split(train_dataset, roberta_tokenizer)
test_ds_roberta = _tokenize_split(test_dataset, roberta_tokenizer)

def prepare_for_torch(ds, tokenizer):
    """Rename 'label' to 'labels' and expose the model inputs as torch tensors.

    Returns ``(dataset, uses_token_type_ids)``. The flag is True when the
    tokenizer lists 'token_type_ids' among its model inputs (e.g. BERT); in
    that case the column is also included in the torch format.
    """
    uses_token_type_ids = 'token_type_ids' in tokenizer.model_input_names
    renamed = ds.rename_column('label', 'labels')
    torch_columns = ['input_ids', 'attention_mask', 'labels']
    if uses_token_type_ids:
        torch_columns += ['token_type_ids']
    renamed.set_format(type='torch', columns=torch_columns)
    return renamed, uses_token_type_ids

# Switch every tokenized split to torch tensors ('label' -> 'labels').
train_ds_bert_torch, bert_uses_token_type = prepare_for_torch(train_ds_bert, bert_tokenizer)
test_ds_bert_torch, _ = prepare_for_torch(test_ds_bert, bert_tokenizer)

train_ds_roberta_torch, roberta_uses_token_type = prepare_for_torch(train_ds_roberta, roberta_tokenizer)
test_ds_roberta_torch, _ = prepare_for_torch(test_ds_roberta, roberta_tokenizer)

# DataLoaders: an effective train batch of 16 is simulated as MICRO_BATCH
# examples per forward/backward pass times ACCUM_STEPS accumulated passes.
MICRO_BATCH = 4      # examples per forward/backward pass (sized to fit VRAM)
ACCUM_STEPS = 4      # micro-batches per optimizer step -> 4 * 4 = 16 effective
BATCH_EVAL = 16      # examples per evaluation batch

_train_loader_kwargs = dict(batch_size=MICRO_BATCH, shuffle=True, num_workers=0)
_eval_loader_kwargs = dict(batch_size=BATCH_EVAL, shuffle=False, num_workers=0)

train_loader_bert = DataLoader(train_ds_bert_torch, **_train_loader_kwargs)
test_loader_bert = DataLoader(test_ds_bert_torch, **_eval_loader_kwargs)

train_loader_roberta = DataLoader(train_ds_roberta_torch, **_train_loader_kwargs)
test_loader_roberta = DataLoader(test_ds_roberta_torch, **_eval_loader_kwargs)

# -------------------------------------------------------------
# 2) "Trainer-like" wrapper so .state.log_history still exists
# -------------------------------------------------------------
class _State:
    def __init__(self):
        self.log_history = []  # will hold dicts with 'loss' and 'eval_loss' like HF

class TrainerLike:
    def __init__(self):
        self.state = _State()

# -------------------------------------------------
# 3) Fine-tune function (DirectML, class weights)
#    - Mirrors your metrics/plots/attention output
# -------------------------------------------------
def _batch_to_inputs(batch, uses_token_type_ids):
    """Move one collated batch onto the DirectML device.

    Returns ``(inputs, labels)``, where ``inputs`` is the keyword-argument dict
    for the model. 'token_type_ids' is included only when the caller says the
    model uses them AND the batch actually carries that key.
    """
    inputs = {
        "input_ids": batch["input_ids"].to(dml),
        "attention_mask": batch["attention_mask"].to(dml),
    }
    if uses_token_type_ids and "token_type_ids" in batch:
        inputs["token_type_ids"] = batch["token_type_ids"].to(dml)
    return inputs, batch["labels"].to(dml)


def _evaluate(model, test_loader, criterion, uses_token_type_ids, desc):
    """Run one gradient-free pass over ``test_loader`` with the model in eval mode.

    Returns ``(eval_loss, eval_acc)``: the mean of the per-batch losses and the
    accuracy of the argmax predictions. Both are NaN for an empty loader.
    Leaves ``model`` in eval mode.
    """
    model.eval()
    loss_sum = 0.0
    all_preds, all_labels = [], []

    eval_bar = tqdm(test_loader, desc=desc, leave=False)
    with torch.no_grad():
        for batch in eval_bar:
            inputs, labels = _batch_to_inputs(batch, uses_token_type_ids)
            logits = model(**inputs).logits
            batch_loss = criterion(logits, labels).item()
            loss_sum += batch_loss
            all_preds.append(logits.argmax(dim=1).cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            eval_bar.set_postfix({'eval_loss': f'{batch_loss:.4f}'})
    eval_bar.close()

    if not all_preds:
        # np.concatenate([]) would raise; report "no data" instead.
        return float('nan'), float('nan')
    eval_loss = loss_sum / len(all_preds)
    eval_acc = accuracy_score(np.concatenate(all_labels), np.concatenate(all_preds))
    return eval_loss, eval_acc


def _plot_losses(model_name, log_history, steps_per_epoch):
    """Plot per-micro-batch train loss and per-epoch eval loss on one step axis."""
    train_logs = [log for log in log_history if 'loss' in log]
    eval_logs = [log for log in log_history if 'eval_loss' in log]

    plt.figure(figsize=(10, 5))
    plt.plot([log['step'] for log in train_logs],
             [log['loss'] for log in train_logs],
             label='Train Loss')
    # Evaluation runs at the end of each epoch, i.e. after
    # epoch * steps_per_epoch micro-batches. Plot the points there rather than
    # at x = 1, 2, 3, where they would pile up on the first training steps.
    plt.plot([round(log['epoch'] * steps_per_epoch) for log in eval_logs],
             [log['eval_loss'] for log in eval_logs],
             marker='o', label='Eval Loss')
    plt.xlabel('Training step (micro-batch)')
    plt.ylabel('Loss')
    plt.title(f'Loss Curve for {model_name}')
    plt.legend()
    plt.show()


def _plot_attention(model_name, model, tokenizer):
    """Show a heatmap of the last layer's first attention head for test row 0."""
    model.eval()  # disable dropout so the plotted weights are not randomly perturbed
    with torch.no_grad():
        sample_input = tokenizer(
            test_df['pm_text'].iloc[0],
            test_df['ugr_text'].iloc[0],
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=128,
        )
        sample_input = {k: v.to(dml) for k, v in sample_input.items()}
        attentions = model(**sample_input, output_attentions=True).attentions

    # NOTE(review): some attention backends may not return weights; skip the
    # plot rather than crash if none came back.
    if not attentions or attentions[-1] is None:
        print(f"[{model_name}] attention weights not returned; skipping heatmap")
        return
    # Per HF docs: one (batch, heads, seq, seq) tensor per layer.
    attn = attentions[-1][0][0].detach().cpu().numpy()

    plt.figure(figsize=(12, 10))
    sns.heatmap(attn, cmap='viridis')
    plt.title(f'Attention Weights for {model_name} (Sample)')
    plt.show()


def fine_tune_model_directml(model_name, train_loader, test_loader, class_weights,
                             uses_token_type_ids: bool, tokenizer,
                             *, num_epochs: int = 3, lr: float = 5e-5):
    """Fine-tune a 2-label sequence classifier on the DirectML device.

    Trains with class-weighted cross-entropy. Gradients are accumulated over
    groups of ACCUM_STEPS micro-batches and clipped to norm 1.0. The learning
    rate uses a linear warmup/decay schedule sized in optimizer updates. The
    model is evaluated after every epoch. At the end, the loss curves and a
    sample attention heatmap are plotted.

    Args:
        model_name: Hugging Face checkpoint name, e.g. 'bert-base-uncased'.
        train_loader: DataLoader of dicts with 'input_ids', 'attention_mask',
            'labels' and optionally 'token_type_ids'.
        test_loader: DataLoader with the same batch layout, used for evaluation.
        class_weights: per-class loss weights (moved to the device here).
        uses_token_type_ids: whether to feed 'token_type_ids' to the model.
        tokenizer: tokenizer matching ``model_name``. It is used for the
            attention sample and returned unchanged.
        num_epochs: number of passes over ``train_loader``.
        lr: peak AdamW learning rate.

    Returns:
        ``(trainer_like, model, tokenizer)``.
        ``trainer_like.state.log_history`` holds one
        {'loss', 'step', 'epoch'} dict per micro-batch and one
        {'eval_loss', 'eval_accuracy', 'epoch'} dict per epoch.
        ``model`` is returned in eval mode.
    """
    trainer_like = TrainerLike()

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=2, dtype=torch.float32
    )
    model.gradient_checkpointing_enable()
    model.to(dml)

    steps_per_epoch = len(train_loader)
    # The optimizer and scheduler advance once per group of ACCUM_STEPS
    # micro-batches (plus once for a trailing partial group). The schedule
    # must therefore be sized in optimizer updates, not micro-batches.
    total_updates = num_epochs * math.ceil(steps_per_epoch / ACCUM_STEPS)

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=max(1, int(0.06 * total_updates)),
        num_training_steps=total_updates,
    )
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(dml))

    global_step = 0
    for epoch in range(num_epochs):
        model.train()
        running = 0.0
        optimizer.zero_grad()

        # ---- TRAIN ----
        train_bar = tqdm(train_loader,
                         desc=f'Epoch {epoch+1}/{num_epochs} [Train]',
                         leave=True)
        for step, batch in enumerate(train_bar, start=1):
            inputs, labels = _batch_to_inputs(batch, uses_token_type_ids)
            loss = criterion(model(**inputs).logits, labels)

            # Divide by the size of this accumulation group so the summed
            # gradients equal the mean over the group. The last group of an
            # epoch may be shorter than ACCUM_STEPS.
            group_start = ((step - 1) // ACCUM_STEPS) * ACCUM_STEPS
            group_size = min(ACCUM_STEPS, steps_per_epoch - group_start)
            (loss / group_size).backward()

            # Step at the end of every group, including a trailing partial
            # one, so no accumulated gradients are dropped by the next
            # epoch's zero_grad().
            if step % ACCUM_STEPS == 0 or step == steps_per_epoch:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                # NOTE(review): stock PyTorch has no `torch.dml`, so this is
                # probably a no-op — confirm before relying on it.
                if hasattr(torch, 'dml'):
                    torch.dml.empty_cache()

            loss_value = loss.item()  # a single device->host sync per micro-batch
            running += loss_value
            global_step += 1

            train_bar.set_postfix({
                'loss': f'{loss_value:.4f}',
                'step': f'{step}/{steps_per_epoch}',
            })
            # Record a per-step training log entry, HF-style.
            trainer_like.state.log_history.append({
                'loss': loss_value,
                'step': global_step,
                'epoch': float(epoch + step / steps_per_epoch),
            })
        train_bar.close()

        # ---- EVAL ----
        eval_loss, eval_acc = _evaluate(
            model, test_loader, criterion, uses_token_type_ids,
            desc=f'Epoch {epoch+1}/{num_epochs} [Eval]',
        )
        trainer_like.state.log_history.append({
            'eval_loss': float(eval_loss),
            'eval_accuracy': float(eval_acc),
            'epoch': float(epoch + 1),
        })
        print(f"[{model_name}] Epoch {epoch+1}/{num_epochs}  "
              f"train_loss(avg/step): {running / max(1, steps_per_epoch):.4f} | "
              f"eval_loss: {eval_loss:.4f} | eval_acc: {eval_acc:.4f}")

    _plot_losses(model_name, trainer_like.state.log_history, steps_per_epoch)
    _plot_attention(model_name, model, tokenizer)

    return trainer_like, model, tokenizer

# -------------------------------------
# 4) Train BERT and RoBERTa (as before)
# -------------------------------------
# The two runs are sequential: BERT first, then RoBERTa, each on its own
# tokenized loaders but sharing the same class weights.
bert_trainer, bert_model, bert_tokenizer = fine_tune_model_directml(
    'bert-base-uncased', train_loader_bert, test_loader_bert, class_weights,
    uses_token_type_ids=bert_uses_token_type, tokenizer=bert_tokenizer,
)

roberta_trainer, roberta_model, roberta_tokenizer = fine_tune_model_directml(
    'roberta-base', train_loader_roberta, test_loader_roberta, class_weights,
    uses_token_type_ids=roberta_uses_token_type, tokenizer=roberta_tokenizer,
)

# (Optional) DeBERTa — uncomment if needed and VRAM allows.
# It follows the same tokenize -> torch-format -> DataLoader recipe as BERT
# and RoBERTa, but token_type_ids are deliberately not passed to the model.
# de_tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-base')
# train_ds_deberta = train_dataset.map(
#     lambda e: preprocess_fast(e, de_tokenizer), batched=True,
#     remove_columns=[c for c in columns_to_remove if c in train_dataset.column_names])
# test_ds_deberta = test_dataset.map(
#     lambda e: preprocess_fast(e, de_tokenizer), batched=True,
#     remove_columns=[c for c in columns_to_remove if c in test_dataset.column_names])
# train_ds_deberta = train_ds_deberta.rename_column('label', 'labels')
# test_ds_deberta = test_ds_deberta.rename_column('label', 'labels')
# cols_de = ['input_ids', 'attention_mask', 'labels']  # token_type_ids not fed to the model
# train_ds_deberta.set_format(type='torch', columns=cols_de)
# test_ds_deberta.set_format(type='torch', columns=cols_de)
# train_loader_de = DataLoader(train_ds_deberta, batch_size=MICRO_BATCH, shuffle=True, num_workers=0)
# test_loader_de = DataLoader(test_ds_deberta, batch_size=BATCH_EVAL, shuffle=False, num_workers=0)
# deberta_trainer, deberta_model, deberta_tokenizer = fine_tune_model_directml(
#     'microsoft/deberta-base', train_loader_de, test_loader_de, class_weights,
#     uses_token_type_ids=False, tokenizer=de_tokenizer,
# )

# ---------------------------------------------
# Training evidence: Print sample logs (same UX)
# ---------------------------------------------
first_bert_logs = bert_trainer.state.log_history[:5]
print(f"BERT Training Logs (first 5): {first_bert_logs!s}")
