# 🧪 Unified FLAN‑T5 Fraud Demo

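This short notebook loads your fine-tuned FLAN‑T5 model, which outputs both a fraud category and a brief explanation, and runs a few sample predictions. The reported confidence is the average, over generated tokens, of the highest softmax probability at each step. It is a rough heuristic, not a calibrated probability for the predicted category.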

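Set `MODEL_DIR` in the configuration cell below to the directory containing your fine-tuned model (for example `models/unified-flan-t5-small`). It currently points to the Kaggle input path `/kaggle/input/flan-t5/transformers/default/1`.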

In [1]:
# Imports and device selection
import re
from typing import Dict, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Pick the best available accelerator: CUDA first, then Apple MPS (only on torch
# builds that expose torch.backends.mps), otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('Using device:', device)

Using device: cuda


In [2]:
# Configuration
MODEL_DIR = '/kaggle/input/flan-t5/transformers/default/1'  # change if you saved elsewhere

# Category names the classifier may emit; also used to build the prompt below.
DEFAULT_LABELS = [
    'job_scam', 'legitimate', 'phishing', 'popup_scam', 'refund_scam',
    'reward_scam', 'sms_spam', 'ssn_scam', 'tech_support_scam'
]

# Prompt prepended to every message. The category list is derived from
# DEFAULT_LABELS so the prompt and the label set cannot drift apart; the result
# is character-for-character the same text as the original hand-written prompt.
# NOTE(review): presumably this must match the prompt used during fine-tuning —
# keep the wording stable unless the model is retrained.
INSTRUCTION_PREFIX = (
    'Classify the message into one of these categories and explain briefly:\n'
    'Categories: ' + ', '.join(DEFAULT_LABELS) + '\n'
    'Message: '
)

In [3]:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR)
# `device` is already one of 'cuda' / 'mps' / 'cpu', so pass it to .to()
# directly instead of branching on each possible value.
model.to(device)
print('Loaded model from:', MODEL_DIR)

2025-09-19 13:49:32.987210: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1758289773.212729      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1758289773.276150      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loaded model from: /kaggle/input/flan-t5/transformers/default/1


In [4]:
# Parsing utilities and generator
# Case-insensitive so 'Label:' / 'Reason:' variants in the generated text are recognised.
# The label class allows spaces and hyphens so multi-word labels are captured whole.
LABEL_PATTERN = re.compile(r'label\s*:\s*([a-zA-Z0-9_\- ]+)', re.IGNORECASE)
REASON_PATTERN = re.compile(r'reason\s*:\s*(.+)$', re.IGNORECASE)


def _normalize_label(raw: str) -> str:
    """Lower-case a captured label and join its words with '_' ('Tech Support-Scam' -> 'tech_support_scam')."""
    return re.sub(r'[\s\-]+', '_', raw.strip(' -').lower())


def parse_output(text: str) -> Tuple[str, str]:
    """Extract ``(label, reason)`` from generated text.

    Expected shape is ``'label: <label> | reason: <free text>'``; the ``|``
    separator is optional and the keywords are matched case-insensitively.

    Args:
        text: Decoded model output.

    Returns:
        ``(label, reason)``. ``label`` is lower-cased with spaces/hyphens turned
        into underscores, or ``'unknown'`` when no label is found. ``reason`` is
        the stripped explanation, or ``''`` when none is found. If several
        segments match, the last match wins.
    """
    flat = ' '.join(text.strip().split())  # collapse newlines and runs of whitespace

    if '|' in flat:
        segments = [part.strip() for part in flat.split('|')]
    else:
        # No separator: cut just before 'reason:' so the label capture (whose
        # character class allows spaces) cannot swallow the word 'reason'.
        reason_match = REASON_PATTERN.search(flat)
        cut = reason_match.start() if reason_match else len(flat)
        segments = [flat[:cut], flat[cut:]]

    label, reason = None, None
    for segment in segments:
        label_match = LABEL_PATTERN.search(segment)
        if label_match:
            label = _normalize_label(label_match.group(1))
        reason_match = REASON_PATTERN.search(segment)
        if reason_match:
            reason = reason_match.group(1).strip()
    return (label or 'unknown'), (reason or '')

def generate(text: str, max_new_tokens: int = 64) -> Dict:
    """Classify one message with the fine-tuned model and return its explanation.

    Args:
        text: Raw message to classify; INSTRUCTION_PREFIX is prepended to it.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Dict with keys:
          'raw_output'                -- decoded model text;
          'predicted_label'           -- label from parse_output, repaired to a
                                         DEFAULT_LABELS entry when only spacing or
                                         hyphens differ, otherwise kept as parsed
                                         ('unknown' when no label was found);
          'reason'                    -- parsed explanation ('' when absent);
          'avg_generation_confidence' -- mean over generated steps of the highest
                                         softmax probability, or None if no scores.
    """
    prompt = INSTRUCTION_PREFIX + text
    # truncation=True caps the encoded prompt at the tokenizer's maximum length.
    inputs = tokenizer([prompt], return_tensors='pt', truncation=True, padding=True).to(model.device)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding -> deterministic output
            return_dict_in_generate=True,
            output_scores=True,
        )
    decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    avg_token_conf = None
    if outputs.scores:
        # One score tensor per generated step (batch of one here). Averaging the
        # per-step max probability is a heuristic: it is not a calibrated
        # probability for the predicted label.
        step_probs = [torch.softmax(s, dim=-1).max().item() for s in outputs.scores]
        avg_token_conf = sum(step_probs) / len(step_probs)

    label, reason = parse_output(decoded)
    if label not in DEFAULT_LABELS:
        # A case-insensitive lookup cannot help here (the parsed label is already
        # lower-case), so repair spacing/hyphen variants such as 'tech support scam'.
        candidate = re.sub(r'[\s\-]+', '_', label.strip(' -').lower())
        if candidate in DEFAULT_LABELS:
            label = candidate
    return {
        'raw_output': decoded,
        'predicted_label': label,
        'reason': reason,
        'avg_generation_confidence': avg_token_conf,
    }

In [5]:
# Quick samples
SHOW_RAW = False  # set True to also print the unparsed model output (debugging)

samples = [
    'We detected unusual activity on your account. Please login at http://phish.example to verify.',
    'Your interview is scheduled for tomorrow. Please see the attached calendar invite.',
    'Congratulations! You have won a $1000 gift card. Claim now.',
]
for i, s in enumerate(samples, 1):
    res = generate(s)
    print(f'Sample {i}')
    print('-'*60)
    print('Text  :', s)
    print('Label :', res['predicted_label'])
    print('Reason:', res['reason'])
    conf = res['avg_generation_confidence']
    if conf is not None:
        print(f'Confidence (token avg): {conf:.3f}')
    if SHOW_RAW:
        print('Raw:', res['raw_output'])

Sample 1
------------------------------------------------------------
Text  : We detected unusual activity on your account. Please login at http://phish.example to verify.
Label : sms_spam
Reason: Unwanted promotional or bulk messaging characteristics.
Confidence (token avg): 1.000
Sample 2
------------------------------------------------------------
Text  : Your interview is scheduled for tomorrow. Please see the attached calendar invite.
Label : legitimate
Reason: Message content appears normal and lacks scam indicators.
Confidence (token avg): 1.000
Sample 3
------------------------------------------------------------
Text  : Congratulations! You have won a $1000 gift card. Claim now.
Label : sms_spam
Reason: Unwanted promotional or bulk messaging characteristics.
Confidence (token avg): 1.000


In [6]:
# Convenience helper
def predict(text: str, **generate_kwargs) -> Dict:
    """Classify ``text``; a thin alias for ``generate``.

    Extra keyword arguments (e.g. ``max_new_tokens=128``) are forwarded to
    ``generate`` unchanged. Returns the same dict as ``generate``
    (raw_output, predicted_label, reason, avg_generation_confidence).
    """
    return generate(text, **generate_kwargs)

# Try your own text
predict('Hello, your SSN has been compromised. Call us immediately to avoid charges.')

{'raw_output': 'label: sms_spam | reason: Unwanted promotional or bulk messaging characteristics.',
 'predicted_label': 'sms_spam',
 'reason': 'Unwanted promotional or bulk messaging characteristics.',
 'avg_generation_confidence': 0.9998764043504541}