In [3]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/llamaguard-ft-dataset/train.json
/kaggle/input/llamaguard-ft-dataset/test.json
/kaggle/input/llamaguard-ft-dataset/val.json


In [4]:
# ============================
# 0. Imports & global label map
# ============================
import json
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Any

import numpy as np
import torch
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    set_seed,
)

from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
)

2025-12-03 10:52:28.633309: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764759148.837105      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764759148.897382      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [5]:
# Login to Hugging Face (Use Kaggle Secrets if possible, otherwise interactive)

from huggingface_hub import login
try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    hf_token = user_secrets.get_secret("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("‚úÖ Logged into HuggingFace via Kaggle secrets.")
    else:
        print("‚ö†Ô∏è No HF_TOKEN in Kaggle secrets. Skipping HF login.")
except Exception as e:
    print("Kaggle secrets not available or HF login skipped:", e)

print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

‚úÖ Logged into HuggingFace via Kaggle secrets.
GPU available: True
GPU: Tesla T4


In [6]:
# ---------- FULL 26-LABEL TAXONOMY (no catastrophic forgetting) ----------
LABEL2ID: Dict[str, int] = {
    # Safe / routing
    "safe_info": 0,
    "safe_with_disclaimer": 1,
    "redirect_to_professional": 2,
    "redirect_to_docs": 3,
    "ambiguous_needs_review": 4,
    "high_value_approval": 5,

    # General safety (S1‚ÄìS14 style)
    "violent_crime": 6,
    "non_violent_crime": 7,
    "sex_related_crime": 8,
    "child_sexual_exploitation": 9,
    "defamation": 10,
    "specialized_advice": 11,
    "privacy_violation": 12,
    "ip_violation": 13,
    "weapons": 14,
    "hate_speech": 15,
    "suicide_self_harm": 16,
    "sexual_content": 17,
    "elections": 18,
    "code_interpreter_abuse": 19,

    # Financial / illicit (S15‚ÄìS22)
    "illicit_market_manipulation": 20,
    "illicit_insider_trading": 21,
    "illicit_money_laundering": 22,
    "illicit_fraud": 23,
    "illicit_regulatory_evasion": 24,
    "illicit_unauthorized_advice": 25,
}

ID2LABEL: Dict[int, str] = {v: k for k, v in LABEL2ID.items()}
NUM_LABELS = len(LABEL2ID)

# Financial safety labels subset (for safety_f1 metric)
FIN_SAFETY_LABEL_IDS = [
    LABEL2ID["illicit_market_manipulation"],
    LABEL2ID["illicit_insider_trading"],
    LABEL2ID["illicit_money_laundering"],
    LABEL2ID["illicit_fraud"],
    LABEL2ID["illicit_regulatory_evasion"],
    LABEL2ID["illicit_unauthorized_advice"],
]

# -----------------
# Logging setup
# -----------------
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

In [7]:
# ============================
# 1. Config dataclasses
# ============================

@dataclass
class ModelConfig:
    base_model: str = "meta-llama/Llama-Guard-3-1B"
    num_labels: int = NUM_LABELS
    max_length: int = 512
    trust_remote_code: bool = True


@dataclass
class LoRAConfig:
    enabled: bool = True
    r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ])
    bias: str = "none"


@dataclass
class TrainConfig:
    output_dir: str = "/kaggle/working/financial-guardrails"
    num_epochs: int = 5
    batch_size: int = 4        # adjust if OOM: try 2
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-4  # keep conservative on a 1B model
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    fp16: bool = False         # <-- fp32 to avoid NaN loss & scaler issues
    logging_steps: int = 50
    eval_steps: int = 200
    save_steps: int = 500
    seed: int = 42


@dataclass
class DataConfig:
    # Kaggle input dir
    base_dir: str = "/kaggle/input/llamaguard-ft-dataset"

    @property
    def train_path(self) -> str:
        return f"{self.base_dir}/train.json"

    @property
    def val_path(self) -> str:
        return f"{self.base_dir}/val.json"

    @property
    def test_path(self) -> str:
        return f"{self.base_dir}/test.json"

In [8]:
import collections, json

with open("/kaggle/input/llamaguard-ft-dataset/train.json") as f:
    train_raw = json.load(f)

print("Example keys:", train_raw[0].keys())
print("Label sample:", train_raw[0]["label"])

label_counts = collections.Counter(ex["label"] for ex in train_raw)
print("Label counts:", label_counts)
print("Labels in dataset but not in LABEL2ID:",
      set(label_counts.keys()) - set(LABEL2ID.keys()))

Example keys: dict_keys(['id', 'text', 'label', 'policy_label', 'tags', 'adversarial_flag', 'metadata', 'explanation'])
Label sample: illicit_market_manipulation
Label counts: Counter({'illicit_market_manipulation': 233, 'safe_info': 191, 'illicit_money_laundering': 179, 'safe_with_disclaimer': 175, 'high_value_approval': 163, 'illicit_regulatory_evasion': 163, 'redirect_to_docs': 136, 'redirect_to_professional': 119, 'illicit_insider_trading': 101, 'ambiguous_needs_review': 99, 'code_interpreter_abuse': 93, 'illicit_fraud': 92, 'violent_crime': 68, 'hate_speech': 48, 'specialized_advice': 37, 'defamation': 28, 'illicit_unauthorized_advice': 28, 'privacy_violation': 22, 'weapons': 21, 'child_sexual_exploitation': 18, 'sexual_content': 18, 'ip_violation': 18, 'elections': 14, 'sex_related_crime': 9, 'non_violent_crime': 8, 'suicide_self_harm': 8})
Labels in dataset but not in LABEL2ID: set()


In [14]:
# ============================
# 2. Data loading & preprocessing
# ============================

def load_dataset_from_json(path: str) -> List[Dict[str, Any]]:
    with open(path, "r") as f:
        return json.load(f)


def preprocess_function(examples: Dict[str, List[Any]], tokenizer, max_length: int) -> Dict:
    # examples["text"] and examples["label"] are lists (batched)
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    tokenized["labels"] = [LABEL2ID[label] for label in examples["label"]]
    return tokenized


def prepare_datasets(
    data_config: DataConfig,
    tokenizer,
    max_length: int
) -> DatasetDict:
    """
    Loads train/val/test from JSON and returns a DatasetDict.
    Each split is set to 'torch' format so DataLoader + model(...) works directly.
    """
    datasets = {}

    split_paths = {
        "train": data_config.train_path,
        "validation": data_config.val_path,
        "test": data_config.test_path,
    }

    for split_name, path in split_paths.items():
        if not Path(path).exists():
            logger.warning(f"[{split_name}] Dataset not found: {path}")
            continue

        raw_data = load_dataset_from_json(path)

        hf_ds = Dataset.from_list([
            {"text": ex["text"], "label": ex["label"]}
            for ex in raw_data
        ])

        hf_ds = hf_ds.map(
            lambda batch: preprocess_function(batch, tokenizer, max_length),
            batched=True,
            remove_columns=["text", "label"],
        )

        # üëá Make it return torch tensors (input_ids, attention_mask, labels)
        hf_ds.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "labels"],
        )

        datasets[split_name] = hf_ds
        logger.info(f"Loaded {split_name}: {len(hf_ds)} examples from {path}")

    return DatasetDict(datasets)

In [15]:
# ============================
# 2. Data loading & preprocessing
# ============================

def load_dataset_from_json(path: str) -> List[Dict[str, Any]]:
    with open(path, "r") as f:
        return json.load(f)


def preprocess_function(examples: Dict[str, List[Any]], tokenizer, max_length: int) -> Dict:
    # examples["text"] and examples["label"] are lists (batched)
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    tokenized["labels"] = [LABEL2ID[label] for label in examples["label"]]
    return tokenized


def prepare_datasets(
    data_config: DataConfig,
    tokenizer,
    max_length: int
) -> DatasetDict:
    """
    Loads train/val/test from JSON and returns a DatasetDict.
    Each split is set to 'torch' format so DataLoader + model(...) works directly.
    """
    datasets = {}

    split_paths = {
        "train": data_config.train_path,
        "validation": data_config.val_path,
        "test": data_config.test_path,
    }

    for split_name, path in split_paths.items():
        if not Path(path).exists():
            logger.warning(f"[{split_name}] Dataset not found: {path}")
            continue

        raw_data = load_dataset_from_json(path)

        hf_ds = Dataset.from_list([
            {"text": ex["text"], "label": ex["label"]}
            for ex in raw_data
        ])

        hf_ds = hf_ds.map(
            lambda batch: preprocess_function(batch, tokenizer, max_length),
            batched=True,
            remove_columns=["text", "label"],
        )

        # üëá Make it return torch tensors (input_ids, attention_mask, labels)
        hf_ds.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "labels"],
        )

        datasets[split_name] = hf_ds
        logger.info(f"Loaded {split_name}: {len(hf_ds)} examples from {path}")

    return DatasetDict(datasets)

In [23]:
# ============================
# 3. Model + LoRA setup
# ============================

def setup_model_and_tokenizer(
    model_config: ModelConfig,
    lora_config: LoRAConfig,
):
    """
    Load Llama-Guard-3-1B and apply LoRA.
    We keep it in fp32 for stability; LoRA still works fine.
    """

    logger.info(f"Loading tokenizer: {model_config.base_model}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.base_model,
        trust_remote_code=model_config.trust_remote_code,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading base model (fp32): {model_config.base_model}")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_config.base_model,
        num_labels=model_config.num_labels,
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch.float32,      # <-- keep fp32
        device_map="auto",              # put on GPU
    )

    model.config.pad_token_id = tokenizer.pad_token_id

    if lora_config.enabled:
        logger.info("Applying LoRA adapters...")

        lora_cfg = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=lora_config.r,
            lora_alpha=lora_config.lora_alpha,
            lora_dropout=lora_config.lora_dropout,
            target_modules=lora_config.target_modules,
            bias=lora_config.bias,
        )

        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()

    # For now, do *not* enable gradient checkpointing (to keep things simpler)
    return model, tokenizer



In [24]:
# ============================
# 4. Metrics (incl. safety_f1)
# ============================

def compute_metrics(eval_pred) -> Dict[str, float]:
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    preds = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average=None, zero_division=0
    )

    macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )

    # Financial safety subset
    safety_mask = np.isin(labels, FIN_SAFETY_LABEL_IDS)
    if safety_mask.sum() > 0:
        safety_true = labels[safety_mask]
        safety_pred = preds[safety_mask]
        safety_f1 = precision_recall_fscore_support(
            safety_true, safety_pred, average="macro", zero_division=0
        )[2]
    else:
        safety_f1 = 0.0

    metrics = {
        "accuracy": float(accuracy),
        "macro_precision": float(macro_p),
        "macro_recall": float(macro_r),
        "macro_f1": float(macro_f1),
        "safety_f1": float(safety_f1),
    }

    # Per-label F1 for the 6 financial labels
    for lbl_id in FIN_SAFETY_LABEL_IDS:
        name = ID2LABEL[lbl_id]
        metrics[f"f1_{name}"] = float(f1[lbl_id]) if lbl_id < len(f1) else 0.0

    return metrics

In [25]:
# ============================
# 5. Training function
# ============================

def train(
    model_config: ModelConfig,
    lora_config: LoRAConfig,
    train_config: TrainConfig,
    data_config: DataConfig,
    use_wandb: bool = False,
):
    """Main training function."""

    set_seed(train_config.seed)

    # 1. Model + tokenizer
    model, tokenizer = setup_model_and_tokenizer(model_config, lora_config)

    # 2. Datasets
    datasets = prepare_datasets(data_config, tokenizer, model_config.max_length)
    assert "train" in datasets and "validation" in datasets, "Train/validation splits missing!"

    # 3. HF eval/save step multiple fix
    eval_steps = train_config.eval_steps
    save_steps = train_config.save_steps
    if save_steps % eval_steps != 0:
        logger.warning(
            f"save_steps={save_steps} is not a multiple of eval_steps={eval_steps}. "
            f"Overriding save_steps to {eval_steps} to satisfy HF constraint."
        )
        save_steps = eval_steps

    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        num_train_epochs=train_config.num_epochs,
        per_device_train_batch_size=train_config.batch_size,
        per_device_eval_batch_size=train_config.batch_size * 2,
        gradient_accumulation_steps=train_config.gradient_accumulation_steps,
        learning_rate=train_config.learning_rate,
        warmup_ratio=train_config.warmup_ratio,
        weight_decay=train_config.weight_decay,
        max_grad_norm=train_config.max_grad_norm,

        fp16=train_config.fp16,     # False => pure fp32, no GradScaler weirdness
        logging_steps=train_config.logging_steps,

        eval_strategy="steps",
        eval_steps=eval_steps,
        save_strategy="steps",
        save_steps=save_steps,

        load_best_model_at_end=True,
        metric_for_best_model="safety_f1",
        greater_is_better=True,

        report_to="none",
        seed=train_config.seed,
        remove_unused_columns=False,

        # Tell Trainer that our label field is 'labels'
        label_names=["labels"],
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=datasets["train"],
        eval_dataset=datasets["validation"],
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    # Optional sanity check BEFORE training
    logger.info("Running one sanity-check forward pass...")
    from torch.utils.data import DataLoader
    train_dl = DataLoader(datasets["train"], batch_size=2, shuffle=True)
    batch = next(iter(train_dl))
    batch = {k: v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():
        out = model(**batch)
    logger.info(f"[Sanity] loss={float(out.loss):.4f}, logits_shape={tuple(out.logits.shape)}")

    # 4. Train
    logger.info("Starting training...")
    train_result = trainer.train()

    logger.info("Saving final model...")
    trainer.save_model(f"{train_config.output_dir}/final")
    tokenizer.save_pretrained(f"{train_config.output_dir}/final")

    with open(f"{train_config.output_dir}/train_metrics.json", "w") as f:
        json.dump(train_result.metrics, f, indent=2)

    # 5. Test evaluation
    if "test" in datasets:
        logger.info("Evaluating on test set...")
        test_results = trainer.evaluate(datasets["test"])
        with open(f"{train_config.output_dir}/test_metrics.json", "w") as f:
            json.dump(test_results, f, indent=2)
        logger.info(f"Test Results: {test_results}")

    return trainer, model, tokenizer

In [26]:
from datasets import load_dataset
import json

with open("/kaggle/input/llamaguard-ft-dataset/train.json") as f:
    tmp = json.load(f)
print(tmp[0].keys())
# should include: 'text', 'label', 'policy_label', 'tags', 'adversarial_flag', 'metadata', 'explanation'

# After prepare_datasets:
# 1. Build model + tokenizer (reuses your setup function)
model_dbg, tokenizer_dbg = setup_model_and_tokenizer(model_cfg, lora_cfg)

# 2. Prepare datasets
datasets_dbg = prepare_datasets(data_cfg, tokenizer_dbg, model_cfg.max_length)
print(datasets_dbg["train"].column_names)
# -> should show: ['input_ids', 'attention_mask', 'labels']

# should include: 'input_ids', 'attention_mask', 'labels'


dict_keys(['id', 'text', 'label', 'policy_label', 'tags', 'adversarial_flag', 'metadata', 'explanation'])


tokenizer_config.json:   0%|          | 0.00/53.2k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.00G [00:00<?, ?B/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-Guard-3-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 11,325,440 || all params: 1,247,193,088 || trainable%: 0.9081


Map:   0%|          | 0/2089 [00:00<?, ? examples/s]

Map:   0%|          | 0/231 [00:00<?, ? examples/s]

Map:   0%|          | 0/267 [00:00<?, ? examples/s]

['input_ids', 'attention_mask', 'labels']


In [27]:
# 1) Confirm labels & mapping
import collections, json

with open("/kaggle/input/llamaguard-ft-dataset/train.json") as f:
    train_raw = json.load(f)

label_counts = collections.Counter(ex["label"] for ex in train_raw)
print("Num train examples:", len(train_raw))
print("Num unique labels:", len(label_counts))
print("Labels present:", label_counts)
print("Labels in dataset but not in LABEL2ID:",
      set(label_counts.keys()) - set(LABEL2ID.keys()))

Num train examples: 2089
Num unique labels: 26
Labels present: Counter({'illicit_market_manipulation': 233, 'safe_info': 191, 'illicit_money_laundering': 179, 'safe_with_disclaimer': 175, 'high_value_approval': 163, 'illicit_regulatory_evasion': 163, 'redirect_to_docs': 136, 'redirect_to_professional': 119, 'illicit_insider_trading': 101, 'ambiguous_needs_review': 99, 'code_interpreter_abuse': 93, 'illicit_fraud': 92, 'violent_crime': 68, 'hate_speech': 48, 'specialized_advice': 37, 'defamation': 28, 'illicit_unauthorized_advice': 28, 'privacy_violation': 22, 'weapons': 21, 'child_sexual_exploitation': 18, 'sexual_content': 18, 'ip_violation': 18, 'elections': 14, 'sex_related_crime': 9, 'non_violent_crime': 8, 'suicide_self_harm': 8})
Labels in dataset but not in LABEL2ID: set()


In [28]:
# 2) Check HF dataset columns
model, tokenizer = setup_model_and_tokenizer(model_cfg, lora_cfg)
datasets = prepare_datasets(data_cfg, tokenizer, model_cfg.max_length)
print(datasets["train"].column_names)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-Guard-3-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 11,325,440 || all params: 1,247,193,088 || trainable%: 0.9081


Map:   0%|          | 0/2089 [00:00<?, ? examples/s]

Map:   0%|          | 0/231 [00:00<?, ? examples/s]

Map:   0%|          | 0/267 [00:00<?, ? examples/s]

['input_ids', 'attention_mask', 'labels']


In [29]:
# Quick numeric sanity: run 1 batch and compute loss manually
from torch.utils.data import DataLoader

datasets = prepare_datasets(data_cfg, tokenizer, model_cfg.max_length)
train_dl = DataLoader(datasets["train"], batch_size=2)

batch = next(iter(train_dl))
batch = {k: v.to(model.device) for k, v in batch.items()}

with torch.no_grad():
    out = model(**batch)
    print("Initial loss:", out.loss.item())

Map:   0%|          | 0/2089 [00:00<?, ? examples/s]

Map:   0%|          | 0/231 [00:00<?, ? examples/s]

Map:   0%|          | 0/267 [00:00<?, ? examples/s]

Initial loss: 2.844724416732788


In [30]:


# ============================
# 6. Start button
# ============================

model_cfg = ModelConfig(
    base_model="meta-llama/Llama-Guard-3-1B",
    num_labels=NUM_LABELS,
    max_length=512,
)

lora_cfg = LoRAConfig(
    enabled=True,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

train_cfg = TrainConfig(
    output_dir="/kaggle/working/financial-guardrails",
    num_epochs=5,
    batch_size=4,
)

data_cfg = DataConfig(
    base_dir="/kaggle/input/llamaguard-ft-dataset"
)

trainer, model, tokenizer = train(
    model_config=model_cfg,
    lora_config=lora_cfg,
    train_config=train_cfg,
    data_config=data_cfg,
)

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-Guard-3-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 11,325,440 || all params: 1,247,193,088 || trainable%: 0.9081


Map:   0%|          | 0/2089 [00:00<?, ? examples/s]

Map:   0%|          | 0/231 [00:00<?, ? examples/s]

Map:   0%|          | 0/267 [00:00<?, ? examples/s]



Step,Training Loss,Validation Loss,Accuracy,Macro Precision,Macro Recall,Macro F1,Safety F1,F1 Illicit Market Manipulation,F1 Illicit Insider Trading,F1 Illicit Money Laundering,F1 Illicit Fraud,F1 Illicit Regulatory Evasion,F1 Illicit Unauthorized Advice
200,0.1455,0.011064,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0
400,0.0159,0.01668,0.995671,0.998264,0.997024,0.99757,1.0,1.0,1.0,1.0,1.0,0.0,0.0
600,0.0066,0.019443,0.995671,0.998106,0.997024,0.997488,1.0,1.0,1.0,1.0,1.0,0.0,0.0


In [31]:
import json, os

base = "/kaggle/working/financial-guardrails"

with open(os.path.join(base, "train_metrics.json")) as f:
    train_metrics = json.load(f)

with open(os.path.join(base, "test_metrics.json")) as f:
    test_metrics = json.load(f)

print("Train metrics:", train_metrics)
print("\nTest metrics:", test_metrics)

Train metrics: {'train_runtime': 7450.5524, 'train_samples_per_second': 1.402, 'train_steps_per_second': 0.088, 'total_flos': 3.1590486638592e+16, 'train_loss': 0.3119972374494621, 'epoch': 5.0}

Test metrics: {'eval_loss': 0.022920673713088036, 'eval_accuracy': 0.9925093632958801, 'eval_macro_precision': 0.9960653157584104, 'eval_macro_recall': 0.9967892976588628, 'eval_macro_f1': 0.9963773286754652, 'eval_safety_f1': 1.0, 'eval_f1_illicit_market_manipulation': 1.0, 'eval_f1_illicit_insider_trading': 1.0, 'eval_f1_illicit_money_laundering': 1.0, 'eval_f1_illicit_fraud': 1.0, 'eval_f1_illicit_regulatory_evasion': 1.0, 'eval_f1_illicit_unauthorized_advice': 1.0, 'eval_runtime': 86.7525, 'eval_samples_per_second': 3.078, 'eval_steps_per_second': 0.392, 'epoch': 5.0}


In [35]:
# ============================
# 7. Manual interactive sanity check
# ============================
import time
import torch
import torch.nn.functional as F

# Make sure we're in eval mode
model.eval()

# Figure out which device the model is on
device = next(model.parameters()).device
print(f"\nModel is on device: {device}")
print("Type a query to test the guardrail classifier. Type 'q' or 'quit' to exit.\n")

def classify_guardrail(text: str):
    """Run a single query through the fine-tuned Llama-Guard classifier."""
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=model_cfg.max_length,
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        outputs = model(**enc)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)[0]  # (num_labels,)

    pred_id = int(probs.argmax().item())
    pred_label = ID2LABEL[pred_id]
    confidence = float(probs[pred_id].item())

    return pred_label, confidence, probs.cpu().tolist()

# Optional: helper to pretty-print top-k labels
def print_top_k(probs_list, k: int = 5):
    probs_tensor = torch.tensor(probs_list)
    top_vals, top_idx = torch.topk(probs_tensor, k)
    print("\nTop classes:")
    for val, idx in zip(top_vals, top_idx):
        lbl = ID2LABEL[int(idx)]
        print(f"  {lbl:30s} : {float(val):.4f}")

# Interactive loop
while True:
    q = input("\nEnter a query (or 'q' to quit): ").strip()
    if q.lower() in {"q", "quit", "exit"}:
        print("Exiting manual check loop.")
        break
    if not q:
        print("Empty query, try again.")
        continue

    t0 = time.time()
    label, conf, probs_list = classify_guardrail(q)
    dt = (time.time() - t0) * 1000.0

    print(f"\nPredicted label : {label}")
    print(f"Confidence      : {conf:.3f}")
    print(f"Latency         : {dt:.1f} ms")

    # Show top-5 labels for debugging / intuition
    print_top_k(probs_list, k=5)


Model is on device: cuda:0
Type a query to test the guardrail classifier. Type 'q' or 'quit' to exit.




Enter a query (or 'q' to quit):  insider trading



Predicted label : illicit_insider_trading
Confidence      : 0.442
Latency         : 44.7 ms

Top classes:
  illicit_insider_trading        : 0.4416
  redirect_to_docs               : 0.1791
  weapons                        : 0.0622
  safe_info                      : 0.0530
  child_sexual_exploitation      : 0.0522



Enter a query (or 'q' to quit):  should i invest in a nearby tea shop 



Predicted label : redirect_to_docs
Confidence      : 0.256
Latency         : 54.9 ms

Top classes:
  redirect_to_docs               : 0.2563
  weapons                        : 0.1891
  ambiguous_needs_review         : 0.1140
  sexual_content                 : 0.1139
  illicit_insider_trading        : 0.0720



Enter a query (or 'q' to quit):  jdsfjdkgdsk



Predicted label : safe_info
Confidence      : 0.239
Latency         : 54.6 ms

Top classes:
  safe_info                      : 0.2387
  privacy_violation              : 0.1292
  specialized_advice             : 0.1266
  ip_violation                   : 0.0683
  hate_speech                    : 0.0682



Enter a query (or 'q' to quit):  aryan jain



Predicted label : safe_info
Confidence      : 0.684
Latency         : 49.8 ms

Top classes:
  safe_info                      : 0.6839
  suicide_self_harm              : 0.1829
  weapons                        : 0.0725
  illicit_market_manipulation    : 0.0154
  ambiguous_needs_review         : 0.0080



Enter a query (or 'q' to quit):  aryan



Predicted label : safe_info
Confidence      : 0.717
Latency         : 46.4 ms

Top classes:
  safe_info                      : 0.7165
  weapons                        : 0.0613
  suicide_self_harm              : 0.0391
  sexual_content                 : 0.0321
  violent_crime                  : 0.0285



Enter a query (or 'q' to quit):  aryan fuck you



Predicted label : ambiguous_needs_review
Confidence      : 0.309
Latency         : 49.3 ms

Top classes:
  ambiguous_needs_review         : 0.3086
  weapons                        : 0.1318
  safe_info                      : 0.1181
  hate_speech                    : 0.0953
  sexual_content                 : 0.0895



Enter a query (or 'q' to quit):  bitch



Predicted label : safe_info
Confidence      : 0.436
Latency         : 46.5 ms

Top classes:
  safe_info                      : 0.4362
  suicide_self_harm              : 0.2092
  sexual_content                 : 0.1710
  weapons                        : 0.0477
  ambiguous_needs_review         : 0.0419



Enter a query (or 'q' to quit):  shut up



Predicted label : sexual_content
Confidence      : 0.213
Latency         : 44.3 ms

Top classes:
  sexual_content                 : 0.2130
  weapons                        : 0.1625
  safe_info                      : 0.1315
  specialized_advice             : 0.0960
  ambiguous_needs_review         : 0.0765



Enter a query (or 'q' to quit):  son of a bitch



Predicted label : suicide_self_harm
Confidence      : 0.280
Latency         : 50.8 ms

Top classes:
  suicide_self_harm              : 0.2798
  safe_info                      : 0.1855
  ambiguous_needs_review         : 0.1712
  hate_speech                    : 0.1012
  weapons                        : 0.0492



Enter a query (or 'q' to quit):  finance mafia



Predicted label : safe_info
Confidence      : 0.410
Latency         : 47.0 ms

Top classes:
  safe_info                      : 0.4095
  suicide_self_harm              : 0.1214
  ip_violation                   : 0.0853
  illicit_money_laundering       : 0.0805
  sexual_content                 : 0.0650



Enter a query (or 'q' to quit):  how to start a mafia



Predicted label : redirect_to_professional
Confidence      : 0.198
Latency         : 53.6 ms

Top classes:
  redirect_to_professional       : 0.1979
  suicide_self_harm              : 0.1927
  hate_speech                    : 0.1251
  safe_info                      : 0.1003
  sexual_content                 : 0.0771



Enter a query (or 'q' to quit):  mafia



Predicted label : safe_info
Confidence      : 0.266
Latency         : 46.3 ms

Top classes:
  safe_info                      : 0.2660
  weapons                        : 0.2103
  ip_violation                   : 0.1526
  sexual_content                 : 0.1392
  ambiguous_needs_review         : 0.0516



Enter a query (or 'q' to quit):  mafia start



Predicted label : sexual_content
Confidence      : 0.204
Latency         : 44.3 ms

Top classes:
  sexual_content                 : 0.2037
  suicide_self_harm              : 0.1651
  safe_info                      : 0.1472
  ambiguous_needs_review         : 0.0967
  hate_speech                    : 0.0812



Enter a query (or 'q' to quit):  bomb



Predicted label : sexual_content
Confidence      : 0.574
Latency         : 40.0 ms

Top classes:
  sexual_content                 : 0.5745
  weapons                        : 0.1581
  privacy_violation              : 0.0572
  ambiguous_needs_review         : 0.0541
  suicide_self_harm              : 0.0428



Enter a query (or 'q' to quit):  aryan bomb



Predicted label : weapons
Confidence      : 0.371
Latency         : 44.2 ms

Top classes:
  weapons                        : 0.3712
  ambiguous_needs_review         : 0.1332
  sexual_content                 : 0.1003
  privacy_violation              : 0.0917
  safe_info                      : 0.0832



Enter a query (or 'q' to quit):  nigga



Predicted label : safe_info
Confidence      : 0.307
Latency         : 44.4 ms

Top classes:
  safe_info                      : 0.3074
  ambiguous_needs_review         : 0.2831
  suicide_self_harm              : 0.1112
  sexual_content                 : 0.0972
  weapons                        : 0.0554



Enter a query (or 'q' to quit):  q


Exiting manual check loop.


In [40]:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="meta-llama/Llama-Guard-3-1B")
messages = [
    {"role": "user", "content": "what do a aryan eat for jain"},
]
pipe(messages)

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


[{'generated_text': [{'role': 'user',
    'content': 'what do a aryan eat for jain'},
   {'role': 'assistant', 'content': '\n\nsafe'}]}]