# 🧠 Unified Classification + Reasoning (FLAN‑T5) — Kaggle Training Notebook

> This notebook fine-tunes a FLAN‑T5 model to output both a fraud label and a short explanation in a single generation, using your compiled dataset.

It mirrors `training/unified_t5_fraud.py` but is optimized for Kaggle. Make sure your CSV is provided as a Kaggle dataset input (e.g., `/kaggle/input/fraud-data/final_fraud_detection_dataset.csv`).

Tip: Early stopping uses an absolute loss threshold: training stops once `eval_loss` is at or below `EARLY_STOP_LOSS`, but never before `MIN_EPOCHS` epochs have completed. Adjust both values in the config cell.

In [None]:
# 1) Detect Kaggle Environment and Set Paths
import os, glob, json, sys
from pathlib import Path

# Kaggle exposes its filesystem under /kaggle; anywhere else, fall back to
# paths relative to the current working directory.
IS_KAGGLE = Path('/kaggle').exists()
if IS_KAGGLE:
    INPUT_DIR, WORK_DIR = Path('/kaggle/input'), Path('/kaggle/working')
else:
    INPUT_DIR, WORK_DIR = Path('..'), Path('.')
WORK_DIR.mkdir(parents=True, exist_ok=True)

print(f"Running on Kaggle: {IS_KAGGLE}")
print(f"INPUT_DIR: {INPUT_DIR}")
print(f"WORK_DIR: {WORK_DIR}")

# List the attached datasets so the CSV path can be checked by eye.
if IS_KAGGLE:
    print("\nAvailable datasets under /kaggle/input:")
    for dataset_dir in sorted(INPUT_DIR.glob('*')):
        print(' -', dataset_dir)

In [None]:
# 2) Import and (Optionally) Install Libraries
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, log_loss
from pathlib import Path

import torch
import transformers as tf
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
    set_seed,
)
import inspect

print('PyTorch:', torch.__version__)
print('Transformers:', tf.__version__)

# Pick the compute device: CUDA first, then MPS, then CPU. The hasattr guard
# covers torch builds that expose no `mps` backend attribute.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('Device:', device)

In [None]:
# 3) Configuration (Seeds, Paths, Target, Problem Type)
SEED = 42
set_seed(SEED)

# Point to your fraud dataset CSV under /kaggle/input
# Adjust this path based on your dataset name
CSV_PATH = INPUT_DIR / 'fraud-data' / 'final_fraud_detection_dataset.csv'
if not CSV_PATH.exists():
    # Fallback guess: sometimes users rename the dataset folder.
    # Sorted so the same file is chosen on every run when several copies match.
    matches = sorted(INPUT_DIR.glob('**/final_fraud_detection_dataset.csv'))
    if matches:
        CSV_PATH = matches[0]
print('CSV_PATH:', CSV_PATH)

# Column holding the raw message and column holding the category label.
TEXT_COL = 'text'
LABEL_COL = 'detailed_category'

# Closed set of categories; rows carrying any other label are filtered out.
DEFAULT_LABELS = [
    'job_scam',
    'legitimate',
    'phishing',
    'popup_scam',
    'refund_scam',
    'reward_scam',
    'sms_spam',
    'ssn_scam',
    'tech_support_scam'
]

MODEL_NAME = 'google/flan-t5-small'
OUTPUT_DIR = WORK_DIR / 'unified-flan-t5-small'
MAX_SOURCE_LENGTH = 256   # token budget for the prompt + message
MAX_TARGET_LENGTH = 64    # token budget for "label: ... | reason: ..."
TRAIN_SIZE = 0.9          # fraction of rows used for training
NUM_EPOCHS = 3
LR = 3e-4
BATCH_TRAIN = 8
BATCH_EVAL = 8
GRAD_ACCUM = 1

# Early stopping when eval_loss is small enough
EARLY_STOP_LOSS = 0.025   # stop training once eval_loss <= this threshold
MIN_EPOCHS = 1           # do not stop before this many epochs

# The category list in the prompt is derived from DEFAULT_LABELS so the prompt
# and the label filter cannot drift apart when a category is added or renamed.
INSTRUCTION_PREFIX = (
    "Classify the message into one of these categories and explain briefly:\n"
    f"Categories: {', '.join(DEFAULT_LABELS)}\n"
    "Message: "
)

In [None]:
# 4) Load Data from /kaggle/input
if not CSV_PATH.exists():
    raise FileNotFoundError(f"CSV not found at {CSV_PATH}. Attach your dataset to the notebook.")
df = pd.read_csv(CSV_PATH)
print('Loaded:', df.shape)
print(df.head(3))

# 5) Basic Data Checks
# Validate the schema BEFORE the label column is first indexed; otherwise a
# missing column surfaces as a bare KeyError from pandas instead of this message.
missing_cols = [col for col in (TEXT_COL, LABEL_COL) if col not in df.columns]
if missing_cols:
    raise KeyError(f"Missing required columns: {missing_cols}")
print('Columns:', list(df.columns))

# Filter to known labels only
df = df[df[LABEL_COL].isin(DEFAULT_LABELS)].copy()
print('After label filtering:', df.shape)
if df.empty:
    raise ValueError(f"No rows left after filtering {LABEL_COL!r} to the known labels.")
print('Label distribution:')
print(df[LABEL_COL].value_counts())

# Train/Validation split (stratified so both splits keep the label proportions)
train_df, val_df = train_test_split(df, test_size=1-TRAIN_SIZE, random_state=SEED, stratify=df[LABEL_COL])
print('Train:', train_df.shape, 'Val:', val_df.shape)

In [None]:
# 6) Build Targets: label + short reason
def default_reason_for_label(label: str) -> str:
    """Return the canned one-sentence explanation for a category label.

    Labels without a dedicated template receive a generic fallback sentence.
    """
    fallback = "Heuristic cues indicate this category."
    canned = dict(
        legitimate="Message content appears normal and lacks scam indicators.",
        phishing="Contains credential-stealing cues (links/requests for login or personal info).",
        tech_support_scam="Mentions fake support, urgent fixes, or remote access.",
        reward_scam="Promises winnings/prizes with urgency or fees.",
        job_scam="Unrealistic job offers, upfront payments, or unsolicited interviews.",
        sms_spam="Unwanted promotional or bulk messaging characteristics.",
        popup_scam="Fake security alerts/popups demanding immediate action.",
        refund_scam="Unexpected refund/chargeback claims with links or callbacks.",
        ssn_scam="Requests social security info or threats about your SSN.",
    )
    try:
        return canned[label]
    except KeyError:
        return fallback

def build_target(label: str, text: str, reason: str = None) -> str:
    """Format the generation target as ``label: <label> | reason: <reason>``.

    ``text`` is accepted but not used. When ``reason`` is falsy, not a string,
    or only whitespace, the default reason for ``label`` is substituted.
    """
    has_usable_reason = (
        bool(reason)
        and isinstance(reason, str)
        and len(reason.strip()) != 0
    )
    if not has_usable_reason:
        reason = default_reason_for_label(label)
    return "label: {} | reason: {}".format(label, reason)

def make_examples(df, text_col: str, label_col: str):
    """Convert a labelled DataFrame into seq2seq ``source``/``target`` pairs.

    Args:
        df: DataFrame containing at least ``text_col`` and ``label_col``.
        text_col: Name of the column holding the message text.
        label_col: Name of the column holding the category label.

    Returns:
        A list of ``{'source': ..., 'target': ...}`` dicts, one per row that
        has message text. ``source`` is the instruction prefix followed by the
        message; ``target`` is the formatted label and default reason.
    """
    examples = []
    # Walk the two columns directly: iterrows() builds a Series per row, which
    # is slow and can change value dtypes along the way.
    for text, label in zip(df[text_col], df[label_col]):
        # A missing message would otherwise be stringified to the literal
        # "nan" and trained on as if it were real text, so skip such rows.
        if pd.isna(text):
            continue
        text = str(text)
        label = str(label)
        examples.append({
            'source': f"{INSTRUCTION_PREFIX}{text}",
            'target': build_target(label, text),
        })
    return examples

# Build the source/target pairs for each split (train first, then validation).
train_examples, val_examples = (
    make_examples(split_df, TEXT_COL, LABEL_COL) for split_df in (train_df, val_df)
)
print('Examples:', len(train_examples), len(val_examples))

In [None]:
# 7) Tokenizer and Simple Dataset Wrappers
# Load the pretrained tokenizer and seq2seq model, then place the model on the
# selected device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)

class SimpleMapDataset:
    """Map-style dataset that tokenizes one source/target example per lookup.

    Relies on the module-level ``tokenizer``, ``MAX_SOURCE_LENGTH`` and
    ``MAX_TARGET_LENGTH``.

    Sequences are truncated here but deliberately NOT padded. Padding is left
    to the seq2seq data collator, which pads each batch to its longest member
    and fills label padding with -100 so padded positions are ignored by the
    loss. Pre-padding labels to ``max_length`` with the pad token id made the
    loss (and therefore ``eval_loss``) include every padding position.

    NOTE(review): ``eval_loss`` now reflects real target tokens only, so an
    ``EARLY_STOP_LOSS`` tuned against the padded loss may need re-tuning.
    """

    def __init__(self, examples):
        # Sequence of {'source': str, 'target': str} dicts.
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one example."""
        example = self.examples[idx]
        source = tokenizer(
            example['source'],
            max_length=MAX_SOURCE_LENGTH,
            truncation=True,
        )
        # `text_target=` replaces the deprecated `as_target_tokenizer()` context
        # manager; a separate call is needed because source and target use
        # different length limits.
        target = tokenizer(
            text_target=example['target'],
            max_length=MAX_TARGET_LENGTH,
            truncation=True,
        )
        return {
            'input_ids': source['input_ids'],
            'attention_mask': source['attention_mask'],
            'labels': target['input_ids'],
        }

# Wrap both splits so examples are tokenized when they are looked up.
train_ds, val_ds = SimpleMapDataset(train_examples), SimpleMapDataset(val_examples)

# Batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
# 8) TrainingArguments, Trainer, and Metrics
from typing import Any


def extract_label_from_text(text: str) -> str:
    """Pull the label out of a ``label: <x> | reason: <y>`` style string.

    The input is lower-cased and whitespace-normalized. Everything between the
    first ``label:`` marker and the next ``|`` (or the end of the string) is
    taken, capped at five whitespace-separated words. Returns ``'unknown'``
    when no ``label:`` marker is present.
    """
    normalized = " ".join(str(text).lower().split())
    _, marker, tail = normalized.partition('label:')
    if not marker:
        return 'unknown'
    label_part = tail.partition('|')[0]
    return " ".join(label_part.split()[:5])


def compute_metrics(eval_pred: Any):
    """Compute label accuracy from decoded predictions and references.

    Args:
        eval_pred: Either an object exposing ``predictions`` and ``label_ids``
            or a ``(predictions, label_ids)`` pair.

    Returns:
        ``{'label_accuracy': float}``: the share of examples whose extracted
        and canonicalized predicted label equals the gold label.
    """
    # Support both (preds, labels) tuple and EvalPrediction object
    preds = getattr(eval_pred, 'predictions', None)
    labels_ids = getattr(eval_pred, 'label_ids', None)
    if preds is None or labels_ids is None:
        preds, labels_ids = eval_pred

    # Predictions may arrive as a tuple of arrays; presumably the first entry
    # is the one to decode (logits or generated ids) - TODO confirm.
    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim == 3:
        # A 3-D array is presumably raw logits (batch, seq, vocab) produced
        # when generation is not enabled; reduce it to token ids.
        preds = preds.argmax(axis=-1)

    # -100 is the ignore/padding marker and is not a valid token id, so it is
    # swapped for the pad token before decoding, in predictions as well as labels.
    pad_id = tokenizer.pad_token_id
    preds = np.where(preds == -100, pad_id, preds)
    labels_ids = np.asarray(labels_ids)
    labels_ids = np.where(labels_ids == -100, pad_id, labels_ids)

    pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=True)
    gold_texts = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    pred_labels = [extract_label_from_text(t) for t in pred_texts]
    gold_labels = [extract_label_from_text(t) for t in gold_texts]
    canon = set(DEFAULT_LABELS)

    def canonize(s: str) -> str:
        """Map a raw label onto a known category where a substring match exists."""
        s = s.strip()
        # An empty string is a substring of every label, so it must never be
        # fuzzy-matched; leave it as-is so it counts as a miss.
        if not s or s in canon:
            return s
        # Iterate the ordered list, not the set, so an ambiguous fragment
        # resolves to the same category on every run.
        for c in DEFAULT_LABELS:
            if s in c or c in s:
                return c
        return s

    pred_labels = [canonize(x) for x in pred_labels]
    gold_labels = [canonize(x) for x in gold_labels]
    acc = accuracy_score(gold_labels, pred_labels)
    return {'label_accuracy': acc}


# Custom early-stopping callback based on absolute eval_loss threshold
from transformers import TrainerCallback, TrainerState, TrainerControl


class LossThresholdEarlyStop(TrainerCallback):
    """Stop training once ``eval_loss`` falls to or below a fixed threshold.

    The check runs after each evaluation and is skipped until the reported
    epoch has reached ``min_epochs``.
    """

    def __init__(self, threshold: float, min_epochs: int = 0):
        self.min_epochs = int(min_epochs)
        self.threshold = float(threshold)

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Request a stop when the latest ``eval_loss`` meets the threshold."""
        eval_loss = (kwargs.get('metrics') or {}).get('eval_loss')
        # state.epoch may be None; treat that as epoch 0.
        current_epoch = state.epoch or 0

        # Nothing to do without a loss value or before the minimum epoch count.
        if eval_loss is None or not (current_epoch >= self.min_epochs):
            return control
        if not (eval_loss <= self.threshold):
            return control

        control.should_training_stop = True
        print(f"Early stopping triggered: eval_loss={eval_loss:.4f} <= threshold={self.threshold} (epoch={current_epoch})")
        return control


# Build training arguments robustly against Transformers version changes.
# The seq2seq variants are used on purpose: `predict_with_generate` and
# `generation_max_length` belong to Seq2SeqTrainingArguments, so probing the
# plain TrainingArguments signature skips both keys and compute_metrics then
# receives logits instead of generated token ids.
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args_dict = {
    'output_dir': str(OUTPUT_DIR),
    'learning_rate': LR,
    'per_device_train_batch_size': BATCH_TRAIN,
    'per_device_eval_batch_size': BATCH_EVAL,
    'num_train_epochs': NUM_EPOCHS,
    'weight_decay': 0.01,
    'warmup_ratio': 0.05,
    'logging_steps': 50,
    'gradient_accumulation_steps': GRAD_ACCUM,
    'report_to': ['none'],
    # NOTE(review): T5-family checkpoints are widely reported to overflow under
    # fp16 mixed precision; if the loss turns NaN or collapses to 0.0, set this
    # to False - confirm on your GPU.
    'fp16': torch.cuda.is_available(),
    'bf16': False,
}

sig = inspect.signature(Seq2SeqTrainingArguments.__init__)
# Per-epoch evaluation is what makes on_evaluate (and the early-stop callback)
# fire. Newer releases name the option `eval_strategy`, older ones
# `evaluation_strategy`; set whichever this version accepts, preferring the new name.
for strategy_key in ('eval_strategy', 'evaluation_strategy'):
    if strategy_key in sig.parameters:
        args_dict[strategy_key] = 'epoch'
        break
if 'save_strategy' in sig.parameters:
    args_dict['save_strategy'] = 'epoch'
if 'predict_with_generate' in sig.parameters:
    args_dict['predict_with_generate'] = True
if 'generation_max_length' in sig.parameters:
    args_dict['generation_max_length'] = MAX_TARGET_LENGTH

training_args = Seq2SeqTrainingArguments(**args_dict)

# Instantiate callback with configured threshold
callbacks = [LossThresholdEarlyStop(threshold=EARLY_STOP_LOSS, min_epochs=MIN_EPOCHS)]

trainer_kwargs = {
    'model': model,
    'args': training_args,
    'train_dataset': train_ds,
    'eval_dataset': val_ds,
    'data_collator': data_collator,
    'compute_metrics': compute_metrics,
    'callbacks': callbacks,
}
# Newer releases take the tokenizer as `processing_class`; older ones as
# `tokenizer`. Pass it under whichever name this version's signature exposes.
trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
tokenizer_key = 'processing_class' if 'processing_class' in trainer_params else 'tokenizer'
trainer_kwargs[tokenizer_key] = tokenizer

trainer = Seq2SeqTrainer(**trainer_kwargs)

In [None]:
# 9) Train, then write the model and tokenizer to the same output directory
trainer.train()
save_dir = str(OUTPUT_DIR)
trainer.save_model(save_dir)
tokenizer.save_pretrained(save_dir)
print('Model saved to:', OUTPUT_DIR)

# Note: Training may stop early if eval_loss <= EARLY_STOP_LOSS after MIN_EPOCHS epochs.

In [None]:
# 10) Quick Sanity Generation
model.eval()
# Prefer a validation row; fall back to a training row if validation is empty.
sample_df = val_df if len(val_df) else train_df
# str() so a non-string cell (e.g. a missing value) cannot break the slice below.
sample_text = str(sample_df.iloc[0][TEXT_COL])
# Use the same source/target length limits as training so this check sees the
# inputs the model was actually trained on.
inputs = tokenizer(
    [f"{INSTRUCTION_PREFIX}{sample_text}"],
    return_tensors='pt',
    truncation=True,
    max_length=MAX_SOURCE_LENGTH,
    padding=True,
).to(model.device)
with torch.no_grad():
    gen = model.generate(**inputs, max_new_tokens=MAX_TARGET_LENGTH, do_sample=False)
print('Input snippet:', sample_text[:120])
print('Output:', tokenizer.decode(gen[0], skip_special_tokens=True))

In [None]:
# 11) Zip Artifacts (Optional)
import shutil
ZIP_PATH = WORK_DIR / 'unified-flan-t5-small.zip'
# Remove any archive left at this path before creating a new one.
if ZIP_PATH.exists():
    ZIP_PATH.unlink()
# make_archive appends the '.zip' extension itself, so pass the path without it.
archive_base = ZIP_PATH.with_suffix('')
shutil.make_archive(str(archive_base), 'zip', str(OUTPUT_DIR))
print('Zipped model to:', ZIP_PATH)