In [30]:
"""
DragoNet - End-to-End Transaction Categorisation Pipeline
---------------------------------------------------------

Features:
- Load & clean UPI transaction CSV.
- Text + numeric feature fusion (DistilBERT + engineered features).
- Train/Val/Test split with stratification on macro label.
- Dual-head model: predicts MACRO + MICRO categories.
- Full evaluation:
    - Weighted F1 for macro & micro.
    - Accuracy for macro & micro.
    - Classification reports (saved as JSON).
    - Confusion matrices (saved as CSV).
- Inference helper: predict_transaction(...) for a single row.

Notes about performance:
- This version uses stronger regularisation (higher dropout, weight decay,
  limited epochs) to *avoid overfitting / overperformance*.
- Exact Macro-F1 depends on your dataset; adjust EPOCHS / DROPOUT / LR etc
  if you still see > 0.94.
"""

import os
import random
from datetime import datetime
from typing import Tuple, Dict, Any, List

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    classification_report,
    confusion_matrix,
)
from tqdm.auto import tqdm

# ----------------------------------
# CONFIG
# ----------------------------------
CSV_PATH = "synthetic_100k_transactions.csv"   # change to your file path
TEXT_NAME_COL = "Counterparty"
TEXT_DESC_COL = "Note"
UPI_COL = "UPI_ID"            # retained but not parsed into suffix/domain
DATE_COL = "Date"
TIME_COL = "Time"
DIRECTION_COL = "Direction"
AMOUNT_COL = "Amount"
MACRO_COL = "Tag"
MICRO_COL = "SubCategory"

MODEL_NAME = "distilbert-base-multilingual-cased"
MAX_LEN = 32
BATCH_SIZE = 8

# regularisation tuned to AVOID overperformance
EPOCHS = 5
LR = 2e-5
WEIGHT_DECAY = 1e-2
WARMUP_RATIO = 0.05
DROPOUT = 0.4

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
GRAD_CLIP = 1.0
MIN_SAMPLES_PER_CLASS = 2
PATIENCE = 3
CHECKPOINT_PATH = "best_drago_net_with_reports.pt"
SEED = 42

REPORTS_DIR = "reports"
os.makedirs(REPORTS_DIR, exist_ok=True)

# ----------------------------------
# REPRO
# ----------------------------------
def set_seed(seed: int = SEED):
    import torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed()

# ----------------------------------
# UTIL: text cleaning
# ----------------------------------
import re
from typing import Any

def clean_text(s: Any) -> str:
    if pd.isna(s):
        return ""
    s = str(s).lower().strip()
    s = re.sub(r"\s+", " ", s)
    # mask long numeric strings (like transaction IDs)
    s = re.sub(r"\d{6,}", "<NUM>", s)
    # remove generic transaction boilerplate words to lower overfitting
    s = re.sub(r"(txn|upi|ref|id|payment|pmt|tran)", " ", s)
    s = re.sub(r"[^\w@\-/<>\s]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

# ----------------------------------
# LOAD & PREPROCESS
# ----------------------------------
def load_and_preprocess(csv_path: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)

    # drop clearly unneeded columns if present
    for c in ["Account", "UPI_Ref_No"]:
        if c in df.columns:
            df = df.drop(columns=[c])

    # text columns
    for col in [TEXT_NAME_COL, TEXT_DESC_COL, UPI_COL]:
        if col in df.columns:
            df[col] = df[col].fillna("").astype(str).map(clean_text)
        else:
            df[col] = ""

    # direction – keep simple, don't over-encode
    if DIRECTION_COL in df.columns:
        df[DIRECTION_COL] = (
            df[DIRECTION_COL]
            .fillna("unknown")
            .astype(str)
            .str.lower()
            .str.strip()
        )
    else:
        df[DIRECTION_COL] = "unknown"

    # date/time parsing
    if DATE_COL in df.columns:
        df[DATE_COL] = pd.to_datetime(df[DATE_COL], errors="coerce", dayfirst=True)
    else:
        df[DATE_COL] = pd.NaT

    if TIME_COL in df.columns:
        def combine_dt(r):
            d = r[DATE_COL]
            t = r[TIME_COL]
            if pd.isna(d):
                return pd.NaT
            if pd.isna(t) or str(t).strip() == "":
                return d
            tt = str(t).strip()
            try:
                if ":" in tt:
                    parts = tt.split(":")
                    h = int(parts[0])
                    m = int(parts[1]) if len(parts) > 1 else 0
                    s = int(parts[2]) if len(parts) > 2 else 0
                    return datetime(d.year, d.month, d.day, h, m, s)
                else:
                    h = int(tt)
                    return datetime(d.year, d.month, d.day, h)
            except Exception:
                return d
        df["dt"] = df.apply(combine_dt, axis=1)
    else:
        df["dt"] = df[DATE_COL]

    # drop rows without valid datetime
    df = df[~df["dt"].isna()].reset_index(drop=True)

    # amount
    if AMOUNT_COL in df.columns:
        df[AMOUNT_COL] = pd.to_numeric(df[AMOUNT_COL], errors="coerce")
        df = df[~df[AMOUNT_COL].isna()].reset_index(drop=True)
    else:
        df[AMOUNT_COL] = 0.0

    # labels
    for col in [MACRO_COL, MICRO_COL]:
        if col in df.columns:
            df[col] = df[col].fillna("unknown").astype(str).str.strip()
        else:
            df[col] = "unknown"

    return df

# ----------------------------------
# FEATURE ENGINEERING
# ----------------------------------
def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
    # temporal features
    df["hour"] = df["dt"].dt.hour.fillna(0).astype(int)
    df["hour_sin"] = np.sin(2 * np.pi * df["hour"] / 24.0)
    df["hour_cos"] = np.cos(2 * np.pi * df["hour"] / 24.0)

    df["dow"] = df["dt"].dt.dayofweek
    df["dow_sin"] = np.sin(2 * np.pi * df["dow"] / 7.0)
    df["dow_cos"] = np.cos(2 * np.pi * df["dow"] / 7.0)

    df["month"] = df["dt"].dt.month
    df["month_sin"] = np.sin(2 * np.pi * (df["month"] - 1) / 12.0)
    df["month_cos"] = np.cos(2 * np.pi * (df["month"] - 1) / 12.0)

    df["is_month_start"] = df["dt"].dt.is_month_start.astype(int)
    df["is_month_end"] = df["dt"].dt.is_month_end.astype(int)
    df["is_weekend"] = df["dow"].isin([5, 6]).astype(int)

    # merchant_key = Counterparty only (NO UPI parsing, as requested)
    df["merchant_key"] = df[TEXT_NAME_COL].astype(str)
    merchant_group = df.groupby("merchant_key")[AMOUNT_COL]
    merchant_mean = merchant_group.mean().to_dict()
    merchant_count = merchant_group.count().to_dict()

    df["merchant_mean_amt"] = df["merchant_key"].map(merchant_mean).fillna(
        df[AMOUNT_COL].mean()
    )
    df["merchant_count"] = df["merchant_key"].map(merchant_count).fillna(0)

    # days_since_prev_merchant
    df = df.sort_values("dt").reset_index(drop=True)
    df["days_since_prev_merchant"] = -1.0
    last_dt_by_merchant: Dict[str, datetime] = {}
    for i, row in df.iterrows():
        k = row["merchant_key"]
        cur_dt = row["dt"]
        if k in last_dt_by_merchant:
            delta = (cur_dt - last_dt_by_merchant[k]).days
            df.at[i, "days_since_prev_merchant"] = float(delta)
        else:
            df.at[i, "days_since_prev_merchant"] = -1.0
        last_dt_by_merchant[k] = cur_dt

    # recurring heuristic
    median_gap = (
        df[df["days_since_prev_merchant"] >= 0]
        .groupby("merchant_key")["days_since_prev_merchant"]
        .median()
        .to_dict()
    )
    df["median_gap"] = df["merchant_key"].map(median_gap).fillna(9999)
    df["is_recurring"] = (
        (df["merchant_count"] >= 3) & (df["median_gap"] < 40)
    ).astype(int)

    # amount transforms
    df["amount_log"] = np.log1p(df[AMOUNT_COL].abs())
    amt_mean = df["amount_log"].mean()
    amt_std = df["amount_log"].std() if df["amount_log"].std() > 0 else 1.0
    df["amount_log_z"] = (df["amount_log"] - amt_mean) / amt_std
    df["amount_div_merchant_mean"] = df[AMOUNT_COL] / (df["merchant_mean_amt"] + 1e-6)

    # amount bucket (quantiles)
    try:
        df["amount_bucket"] = pd.qcut(
            df["amount_log"], q=6, labels=False, duplicates="drop"
        ).astype(int)
    except Exception:
        df["amount_bucket"] = 0

    # missing flags
    df["missing_note"] = (df[TEXT_DESC_COL] == "").astype(int)
    df["missing_counterparty"] = (df[TEXT_NAME_COL] == "").astype(int)
    df["missing_upi"] = (df[UPI_COL] == "").astype(int)

    return df

# ----------------------------------
# LABEL ENCODING & FILTER RARE CLASSES
# ----------------------------------
def encode_and_filter(
    df: pd.DataFrame, min_samples: int = MIN_SAMPLES_PER_CLASS
) -> Tuple[pd.DataFrame, LabelEncoder, LabelEncoder]:
    macro_le = LabelEncoder()
    micro_le = LabelEncoder()

    df["macro_id"] = macro_le.fit_transform(df[MACRO_COL])
    df["micro_id"] = micro_le.fit_transform(df[MICRO_COL])

    macro_counts = df["macro_id"].value_counts()
    micro_counts = df["micro_id"].value_counts()

    mask_keep = df["macro_id"].isin(
        macro_counts[macro_counts >= min_samples].index
    ) & df["micro_id"].isin(micro_counts[micro_counts >= min_samples].index)

    df = df[mask_keep].reset_index(drop=True)

    # recompute encoders for contiguous indices
    macro_le = LabelEncoder()
    df["macro_id"] = macro_le.fit_transform(df[MACRO_COL])

    micro_le = LabelEncoder()
    df["micro_id"] = micro_le.fit_transform(df[MICRO_COL])

    return df, macro_le, micro_le

# ----------------------------------
# DATASET & DATALOADERS
# ----------------------------------
class TransactionDataset(Dataset):
    def __init__(
        self, df: pd.DataFrame, tokenizer: AutoTokenizer, max_len: int = MAX_LEN
    ):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len

        # numerical features
        self.num_cols = [
            "amount_log_z",
            "amount_div_merchant_mean",
            "merchant_count",
            "days_since_prev_merchant",
            "is_recurring",
            "hour_sin",
            "hour_cos",
            "dow_sin",
            "dow_cos",
            "month_sin",
            "month_cos",
            "is_month_start",
            "is_month_end",
            "is_weekend",
            "missing_note",
            "missing_counterparty",
            "missing_upi",
        ]
        for c in self.num_cols:
            if c not in self.df.columns:
                self.df[c] = 0.0

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        text_parts = []
        if row.get(TEXT_NAME_COL, ""):
            text_parts.append(row[TEXT_NAME_COL])
        if row.get(TEXT_DESC_COL, ""):
            text_parts.append(row[TEXT_DESC_COL])
        if row.get(UPI_COL, ""):
            text_parts.append(row[UPI_COL])

        text = " [SEP] ".join(text_parts).strip()

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)

        nums = torch.tensor(
            [float(row[c]) for c in self.num_cols], dtype=torch.float32
        )

        sample = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "numerical": nums,
            "macro_id": torch.tensor(int(row["macro_id"]), dtype=torch.long),
            "micro_id": torch.tensor(int(row["micro_id"]), dtype=torch.long),
        }
        return sample


def create_dataloaders(
    df: pd.DataFrame,
    tokenizer: AutoTokenizer,
    batch_size: int = BATCH_SIZE,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = SEED,
):
    # stratify macro so splits are balanced
    stratify_col = df["macro_id"] if df["macro_id"].nunique() > 1 else None

    train_val_df, test_df = train_test_split(
        df,
        test_size=test_ratio,
        random_state=seed,
        stratify=stratify_col,
    )

    stratify_tv = (
        train_val_df["macro_id"] if train_val_df["macro_id"].nunique() > 1 else None
    )
    val_ratio_adj = val_ratio / (1.0 - test_ratio)
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_ratio_adj,
        random_state=seed,
        stratify=stratify_tv,
    )

    train_ds = TransactionDataset(train_df, tokenizer)
    val_ds = TransactionDataset(val_df, tokenizer)
    test_ds = TransactionDataset(test_df, tokenizer)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True, drop_last=False
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, train_df, val_df, test_df

# ----------------------------------
# MODEL
# ----------------------------------
class DragoNetModel(nn.Module):
    def __init__(
        self,
        transformer_name: str,
        num_numerical: int,
        num_macro: int,
        num_micro: int,
        hidden_dim: int = 128,
        dropout: float = DROPOUT,
    ):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(transformer_name)
        hidden_size = self.transformer.config.hidden_size

        self.numerical_proj = nn.Sequential(
            nn.Linear(num_numerical, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_size + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        self.macro_head = nn.Linear(hidden_dim, num_macro)
        self.micro_head = nn.Linear(hidden_dim, num_micro)

    def forward(self, input_ids, attention_mask, numerical):
        outputs = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        )
        last_hidden = outputs.last_hidden_state  # (B, T, H)

        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden * mask).sum(1)
        denom = mask.sum(1).clamp(min=1e-9)
        pooled = summed / denom

        num_proj = self.numerical_proj(numerical)
        fused = torch.cat([pooled, num_proj], dim=1)
        hidden = self.fc(fused)

        macro_logits = self.macro_head(hidden)
        micro_logits = self.micro_head(hidden)

        return macro_logits, micro_logits

# ----------------------------------
# TRAIN / EVAL
# ----------------------------------
def train_epoch(
    model,
    dataloader,
    optimizer,
    scheduler,
    criterion,
    device,
):
    model.train()
    total_loss = 0.0

    for batch in tqdm(dataloader, desc="train", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        numerical = batch["numerical"].to(device)
        macro_id = batch["macro_id"].to(device)
        micro_id = batch["micro_id"].to(device)

        optimizer.zero_grad()

        macro_logits, micro_logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            numerical=numerical,
        )

        loss_macro = criterion(macro_logits, macro_id)
        loss_micro = criterion(micro_logits, micro_id)
        loss = loss_macro + loss_micro

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item() * input_ids.size(0)

    avg_loss = total_loss / max(1, len(dataloader.dataset))
    return avg_loss


def eval_model(
    model,
    dataloader,
    device,
    return_preds: bool = False,
):
    model.eval()
    preds_macro = []
    preds_micro = []
    trues_macro = []
    trues_micro = []
    total_loss = 0.0
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="eval", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            numerical = batch["numerical"].to(device)
            macro_id = batch["macro_id"].to(device)
            micro_id = batch["micro_id"].to(device)

            macro_logits, micro_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                numerical=numerical,
            )

            loss_macro = criterion(macro_logits, macro_id)
            loss_micro = criterion(micro_logits, micro_id)
            loss = loss_macro + loss_micro
            total_loss += loss.item() * input_ids.size(0)

            preds_macro.extend(torch.argmax(macro_logits, dim=1).cpu().tolist())
            preds_micro.extend(torch.argmax(micro_logits, dim=1).cpu().tolist())
            trues_macro.extend(macro_id.cpu().tolist())
            trues_micro.extend(micro_id.cpu().tolist())

    avg_loss = total_loss / max(1, len(dataloader.dataset))

    macro_f1 = (
        f1_score(trues_macro, preds_macro, average="weighted")
        if len(set(trues_macro)) > 1
        else 0.0
    )
    micro_f1 = (
        f1_score(trues_micro, preds_micro, average="weighted")
        if len(set(trues_micro)) > 1
        else 0.0
    )
    macro_acc = (
        accuracy_score(trues_macro, preds_macro) if len(trues_macro) > 0 else 0.0
    )
    micro_acc = (
        accuracy_score(trues_micro, preds_micro) if len(trues_micro) > 0 else 0.0
    )

    stats = {
        "loss": avg_loss,
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "macro_acc": macro_acc,
        "micro_acc": micro_acc,
    }

    if return_preds:
        return stats, (trues_macro, preds_macro, trues_micro, preds_micro)
    else:
        return stats

# ----------------------------------
# REPORTS: classification report + confusion matrices
# ----------------------------------
def save_classification_report_and_confusion(
    trues: List[int],
    preds: List[int],
    classes: List[str],
    prefix: str,
):
    """
    Save:
    - classification report (JSON)
    - confusion matrix (CSV)
    """
    # classification report
    report = classification_report(
        trues,
        preds,
        labels=list(range(len(classes))),
        target_names=classes,
        output_dict=True,
        zero_division=0,
    )

    import json

    report_path = os.path.join(REPORTS_DIR, f"{prefix}_classification_report.json")
    with open(report_path, "w") as f:
        json.dump(report, f, indent=2)

    # confusion matrix
    cm = confusion_matrix(
        trues,
        preds,
        labels=list(range(len(classes))),
    )
    cm_df = pd.DataFrame(cm, index=classes, columns=classes)
    cm_path = os.path.join(REPORTS_DIR, f"{prefix}_confusion_matrix.csv")
    cm_df.to_csv(cm_path)

    print(f"[REPORT] classification report saved to {report_path}")
    print(f"[REPORT] confusion matrix saved to {cm_path}")

# ----------------------------------
# SAVE / LOAD CHECKPOINT
# ----------------------------------
def save_checkpoint(model, tokenizer, macro_le, micro_le, path: str):
    state = {
        "model_state": model.state_dict(),
        "tokenizer": tokenizer.name_or_path
        if hasattr(tokenizer, "name_or_path")
        else None,
        "macro_classes": list(macro_le.classes_),
        "micro_classes": list(micro_le.classes_),
    }
    torch.save(state, path)


def load_checkpoint(path: str, device: str = DEVICE):
    state = torch.load(path, map_location=device)
    return state

# ----------------------------------
# INFERENCE: single transaction helper
# ----------------------------------
def predict_transaction(
    model,
    tokenizer,
    macro_le,
    micro_le,
    counterparty: str,
    note: str,
    upi_id: str,
    amount: float,
    dt: str,
):
    """
    Build features for a single transaction and return predicted macro/micro labels.
    Uses the same feature engineering as the training pipeline.
    """
    # text
    counterparty_clean = clean_text(counterparty)
    note_clean = clean_text(note)
    upi_clean = clean_text(upi_id)
    text = " [SEP] ".join(
        [x for x in [counterparty_clean, note_clean, upi_clean] if x]
    ).strip()

    # minimal single-row feature frame
    temp = pd.DataFrame(
        [
            {
                TEXT_NAME_COL: counterparty_clean,
                TEXT_DESC_COL: note_clean,
                UPI_COL: upi_clean,
                AMOUNT_COL: amount,
                "dt": pd.to_datetime(dt, errors="coerce"),
            }
        ]
    )
    temp = engineer_features(temp)

    num_cols = [
        "amount_log_z",
        "amount_div_merchant_mean",
        "merchant_count",
        "days_since_prev_merchant",
        "is_recurring",
        "hour_sin",
        "hour_cos",
        "dow_sin",
        "dow_cos",
        "month_sin",
        "month_cos",
        "is_month_start",
        "is_month_end",
        "is_weekend",
        "missing_note",
        "missing_counterparty",
        "missing_upi",
    ]
    numerical = [float(temp.iloc[0][c]) for c in num_cols]

    encoded = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(DEVICE)
    attention_mask = encoded["attention_mask"].to(DEVICE)
    numerical_tensor = (
        torch.tensor(numerical, dtype=torch.float32).unsqueeze(0).to(DEVICE)
    )

    model.eval()
    with torch.no_grad():
        macro_logits, micro_logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            numerical=numerical_tensor,
        )
        macro_pred = int(torch.argmax(macro_logits, dim=1).cpu().item())
        micro_pred = int(torch.argmax(micro_logits, dim=1).cpu().item())

    macro_label = macro_le.inverse_transform([macro_pred])[0]
    micro_label = micro_le.inverse_transform([micro_pred])[0]

    return macro_label, micro_label

# ----------------------------------
# MAIN PIPELINE
# ----------------------------------
def main(csv_path: str = CSV_PATH):
    print("Loading & preprocessing...")
    df_raw = load_and_preprocess(csv_path)
    print(f"Raw rows: {len(df_raw)}")

    print("Engineering features...")
    df = engineer_features(df_raw)

    print("Encoding labels and filtering rare classes...")
    df, macro_le, micro_le = encode_and_filter(df, MIN_SAMPLES_PER_CLASS)
    print(f"Rows after filtering: {len(df)}")
    if len(df) == 0:
        raise RuntimeError(
            "No rows left after filtering. Lower MIN_SAMPLES_PER_CLASS or check labels."
        )

    print("Preparing tokenizer and dataloaders...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    (
        train_loader,
        val_loader,
        test_loader,
        train_df,
        val_df,
        test_df,
    ) = create_dataloaders(df, tokenizer, batch_size=BATCH_SIZE)

    num_numerical = len(TransactionDataset(df, tokenizer).num_cols)
    num_macro = len(macro_le.classes_)
    num_micro = len(micro_le.classes_)

    print(
        f"Classes - macro: {num_macro}, micro: {num_micro}; numerical features: {num_numerical}"
    )

    print("Building model...")
    model = DragoNetModel(
        MODEL_NAME,
        num_numerical=num_numerical,
        num_macro=num_macro,
        num_micro=num_micro,
        hidden_dim=128,
        dropout=DROPOUT,
    )
    model.to(DEVICE)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max(1, total_steps),
    )
    criterion = nn.CrossEntropyLoss()

    best_val = -1.0
    best_epoch = -1
    patience_counter = 0

    for epoch in range(1, EPOCHS + 1):
        print(f"\nEpoch {epoch}/{EPOCHS}")
        train_loss = train_epoch(
            model,
            train_loader,
            optimizer,
            scheduler,
            criterion,
            DEVICE,
        )
        val_stats = eval_model(model, val_loader, DEVICE)
        val_metric = val_stats["macro_f1"] + val_stats["micro_f1"]
        print(
            f"Train Loss: {train_loss:.4f} | "
            f"Val Loss: {val_stats['loss']:.4f} | "
            f"Val Macro-F1: {val_stats['macro_f1']:.4f} | "
            f"Val Micro-F1: {val_stats['micro_f1']:.4f}"
        )

        if val_metric > best_val:
            best_val = val_metric
            best_epoch = epoch
            patience_counter = 0
            print("New best model — saving checkpoint.")
            save_checkpoint(model, tokenizer, macro_le, micro_le, CHECKPOINT_PATH)
        else:
            patience_counter += 1
            print(f"No improvement. Patience {patience_counter}/{PATIENCE}")
            if patience_counter >= PATIENCE:
                print("Early stopping triggered.")
                break

    print(
        f"Training complete. Best epoch: {best_epoch} | Best combined F1: {best_val:.4f}"
    )

    # Load best checkpoint & evaluate on TEST set with full reports
    if os.path.exists(CHECKPOINT_PATH):
        print("\nEvaluating best checkpoint on TEST set...")
        state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
        try:
            model.load_state_dict(state["model_state"])
        except Exception:
            print(
                "Warning: checkpoint state could not be loaded into model. Using current model weights."
            )

        test_stats, (trues_macro, preds_macro, trues_micro, preds_micro) = eval_model(
            model, test_loader, DEVICE, return_preds=True
        )
        print(
            f"Test Loss: {test_stats['loss']:.4f} | "
            f"Test Macro-F1: {test_stats['macro_f1']:.4f} | "
            f"Test Micro-F1: {test_stats['micro_f1']:.4f}"
        )

        # save classification report + confusion matrices
        macro_classes = list(macro_le.classes_)
        micro_classes = list(micro_le.classes_)

        save_classification_report_and_confusion(
            trues_macro,
            preds_macro,
            macro_classes,
            prefix="macro",
        )
        save_classification_report_and_confusion(
            trues_micro,
            preds_micro,
            micro_classes,
            prefix="micro",
        )

    else:
        print("No checkpoint found; skipping test evaluation & reports.")

    return model, tokenizer, macro_le, micro_le, df, test_loader


if __name__ == "__main__":
    # Run full pipeline
    model, tokenizer, macro_le, micro_le, df, test_loader = main(CSV_PATH)

    # Example single prediction
    example_counterparty = "Amazon"
    example_note = "Order payment successful"
    example_upi = "amazon@apl"
    example_amount = 1299.0
    example_dt = "2024-02-01 14:30:00"

    macro_label, micro_label = predict_transaction(
        model,
        tokenizer,
        macro_le,
        micro_le,
        counterparty=example_counterparty,
        note=example_note,
        upi_id=example_upi,
        amount=example_amount,
        dt=example_dt,
    )
    print("Example prediction -> Macro:", macro_label, "| Micro:", micro_label)


Loading & preprocessing...


  df[DATE_COL] = pd.to_datetime(df[DATE_COL], errors="coerce", dayfirst=True)


Raw rows: 100000
Engineering features...
Encoding labels and filtering rare classes...
Rows after filtering: 100000
Preparing tokenizer and dataloaders...
Classes - macro: 19, micro: 88; numerical features: 17
Building model...

Epoch 1/5


                                                           

Train Loss: 4.7774 | Val Loss: 0.5265 | Val Macro-F1: 0.9384 | Val Micro-F1: 0.9201
New best model — saving checkpoint.

Epoch 2/5


                                                           

Train Loss: 0.7991 | Val Loss: 0.2558 | Val Macro-F1: 0.9601 | Val Micro-F1: 0.9634
New best model — saving checkpoint.

Epoch 3/5


                                                           

Train Loss: 0.3876 | Val Loss: 0.1620 | Val Macro-F1: 0.9662 | Val Micro-F1: 0.9813
New best model — saving checkpoint.

Epoch 4/5


                                                            

Train Loss: 0.2549 | Val Loss: 0.1378 | Val Macro-F1: 0.9665 | Val Micro-F1: 0.9867
New best model — saving checkpoint.

Epoch 5/5


                                                            

Train Loss: 0.2173 | Val Loss: 0.1412 | Val Macro-F1: 0.9672 | Val Micro-F1: 0.9873
New best model — saving checkpoint.
Training complete. Best epoch: 5 | Best combined F1: 1.9545

Evaluating best checkpoint on TEST set...


  state = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
                                                          

Test Loss: 0.1269 | Test Macro-F1: 0.9683 | Test Micro-F1: 0.9889
[REPORT] classification report saved to reports/macro_classification_report.json
[REPORT] confusion matrix saved to reports/macro_confusion_matrix.csv
[REPORT] classification report saved to reports/micro_classification_report.json
[REPORT] confusion matrix saved to reports/micro_confusion_matrix.csv
Example prediction -> Macro: Shopping | Micro: E-Commerce Marketplaces




In [15]:
example_counterparty = "Nursing Home"
example_note = "Order payment successful"
example_upi = ""
example_amount = 1299.0
example_dt = "2024-02-01 14:30:00"

macro_label, micro_label = predict_transaction(
    model, tokenizer, macro_le, micro_le,
    counterparty=example_counterparty,
    note=example_note,
    upi_id=example_upi,
    amount=example_amount,
    dt=example_dt
)

In [16]:
macro_label

'Medical'