# 🔁 Enhanced Unified Joint Model (BART) — High-Quality Contextual Reasoning

This notebook trains a BART model with **contextual, feature-based reasoning** instead of template responses.

**Key Improvements:**
- Analyzes message features (keywords, urgency, entities, patterns)
- Generates detailed, context-aware explanations
- Higher MAX_TARGET_LENGTH for richer reasoning
- Improved instruction prompting
- Better generation parameters (beam search, length penalty)

Model: `facebook/bart-base` with joint classification + contextual generation.
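Optimized for Kaggle: the notebook auto-detects `/kaggle/input` and saves artifacts to `/kaggle/working`. When `/kaggle` is absent it reads data from `..` and writes artifacts to the current directory.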

In [None]:
# 1) Detect Kaggle Environment and Set Paths
import os, json, sys
from pathlib import Path

# The presence of a /kaggle directory is the signal that we are running on Kaggle.
_KAGGLE_ROOT = Path('/kaggle')
IS_KAGGLE = _KAGGLE_ROOT.exists()

# On Kaggle: read from /kaggle/input and write to /kaggle/working.
# Elsewhere: read from the parent directory and write to the current one.
if IS_KAGGLE:
    INPUT_DIR = _KAGGLE_ROOT / 'input'
    WORK_DIR = _KAGGLE_ROOT / 'working'
else:
    INPUT_DIR = Path('..')
    WORK_DIR = Path('.')
WORK_DIR.mkdir(parents=True, exist_ok=True)

for _label, _value in (
    ('Running on Kaggle', IS_KAGGLE),
    ('INPUT_DIR', INPUT_DIR),
    ('WORK_DIR', WORK_DIR),
):
    print(f'{_label}: {_value}')

if IS_KAGGLE:
    # List the attached datasets so the user can verify the expected one is present.
    print('\nAvailable datasets under /kaggle/input:')
    for dataset_path in sorted(INPUT_DIR.glob('*')):
        print(' -', dataset_path)

In [None]:
# 1.5) Ensure compatible libraries
import subprocess

def pip_install(pkgs):
    """Quietly pip-install the requirement strings in the list *pkgs*.

    pip is invoked as a module of the interpreter running this notebook
    (``sys.executable``).  ``subprocess.check_call`` raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q"]
    subprocess.check_call(pip_command + pkgs)

# Decide whether the installed transformers / huggingface_hub are recent enough.
# Any failure while checking (package missing, `packaging` unavailable, unparsable
# version string, ...) is treated as "needs install".
need_install = False
try:
    import transformers as _tf_ver
    import huggingface_hub as _hf_ver
    from packaging import version as _v
    # Minimum accepted versions: transformers 4.43.0 and huggingface_hub 0.23.4.
    if _v.parse(_tf_ver.__version__) < _v.parse("4.43.0") or _v.parse(_hf_ver.__version__) < _v.parse("0.23.4"):
        need_install = True
except Exception:
    need_install = True

if need_install:
    print("Installing compatible versions of transformers/huggingface_hub...")
    # Pinned versions for the core libraries; safetensors and sentencepiece are left unpinned.
    pip_install([
        "transformers==4.44.2",
        "huggingface_hub==0.24.6",
        "tokenizers==0.19.1",
        "safetensors",
        "sentencepiece"
    ])
    import importlib
    # Refresh the import system's finder caches so newly installed packages can be found.
    # NOTE(review): modules already imported by the check above stay in sys.modules at
    # their old version; a kernel restart is presumably needed in that case - confirm.
    importlib.invalidate_caches()
    print("Installed.")

In [None]:
# 2) Imports and Device
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import torch
from torch import nn
import transformers as tf
from transformers import (
    AutoTokenizer,
    AutoConfig,
    BartForConditionalGeneration,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq,
    set_seed
)

for _lib_label, _lib_module in (('PyTorch:', torch), ('Transformers:', tf)):
    print(_lib_label, _lib_module.__version__)

# Device preference: CUDA first, then the MPS backend (only probed when this torch
# build exposes `torch.backends.mps`), otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print('Device:', device)

In [None]:
# 3) Enhanced Config — Better reasoning generation
SEED = 42
set_seed(SEED)

# Dataset path: try the expected location first, otherwise take the first
# recursive match under INPUT_DIR (the expected path is kept if nothing matches).
_CSV_NAME = 'final_fraud_detection_dataset.csv'
CSV_PATH = INPUT_DIR / 'fraud-data' / _CSV_NAME
if not CSV_PATH.exists():
    CSV_PATH = next(iter(INPUT_DIR.glob(f'**/{_CSV_NAME}')), CSV_PATH)
print('CSV_PATH:', CSV_PATH)

# Column names in the CSV
TEXT_COL = 'text'
LABEL_COL = 'detailed_category'

# Closed set of categories; the list position of a label is its class id.
DEFAULT_LABELS = [
    'job_scam',
    'legitimate',
    'phishing',
    'popup_scam',
    'refund_scam',
    'reward_scam',
    'sms_spam',
    'ssn_scam',
    'tech_support_scam',
]
LABEL2ID = dict(zip(DEFAULT_LABELS, range(len(DEFAULT_LABELS))))
ID2LABEL = dict(enumerate(DEFAULT_LABELS))

BASE_MODEL = 'facebook/bart-base'
OUTPUT_DIR = WORK_DIR / 'unified-bart-joint-enhanced'

# Sequence lengths (in tokens) and optimisation settings
MAX_SOURCE_LENGTH = 256
MAX_TARGET_LENGTH = 128  # room for longer reasoning text
TRAIN_SIZE = 0.9         # fraction of rows used for training
NUM_EPOCHS = 4
LR = 2e-5
BATCH_TRAIN = 8
BATCH_EVAL = 8
GRAD_ACCUM = 2           # 8 x 2 = effective batch size of 16

# Weights of the two terms of the joint loss (generation is weighted higher)
CLS_LOSS_WEIGHT = 0.8
GEN_LOSS_WEIGHT = 1.2

# Training stops once eval loss <= EARLY_STOP_LOSS, but not before MIN_EPOCHS
EARLY_STOP_LOSS = 0.02
MIN_EPOCHS = 2

# Instruction text prepended to every message fed to the model
INSTRUCTION_PREFIX = ''.join([
    'Analyze this message and classify it into one of these categories: ',
    'job_scam, legitimate, phishing, popup_scam, refund_scam, reward_scam, sms_spam, ssn_scam, tech_support_scam. ',
    'Then explain your reasoning by identifying specific suspicious elements, patterns, or indicators in the message.\n\n',
    'Message: ',
])

print('\n'.join([
    f'\nModel: {BASE_MODEL}',
    f'Output: {OUTPUT_DIR}',
    f'Max target length: {MAX_TARGET_LENGTH}',
    f'Loss weights - CLS: {CLS_LOSS_WEIGHT}, GEN: {GEN_LOSS_WEIGHT}',
]))

In [None]:
# 4) Load and Split Data
# Fail with an explicit exception rather than `assert`, which is stripped when
# Python runs with -O.
if not CSV_PATH.exists():
    raise FileNotFoundError(f'CSV not found at {CSV_PATH}. Attach your dataset to the notebook.')
df = pd.read_csv(CSV_PATH)
print('Loaded:', df.shape)

# Keep only known labels and the two columns used downstream.
df = df.loc[df[LABEL_COL].isin(DEFAULT_LABELS), [TEXT_COL, LABEL_COL]]
# Rows without text cannot be analysed: a missing value is a float NaN, which would
# break the string handling in the reasoning generator. Drop them and make sure
# every remaining message is a str.
df = df.dropna(subset=[TEXT_COL]).copy()
df[TEXT_COL] = df[TEXT_COL].astype(str)
print('After filtering:', df.shape)
print('Label distribution:')
print(df[LABEL_COL].value_counts())

# Stratified split so both parts keep the label proportions of the full set.
train_df, val_df = train_test_split(df, test_size=1-TRAIN_SIZE, random_state=SEED, stratify=df[LABEL_COL])
print('\nTrain:', train_df.shape, 'Val:', val_df.shape)

In [None]:
# 5) Define Reasoning Generator (inline)
import re
from typing import Dict, List

# Inline reasoning generator (same as reasoning_generator.py)
class ReasoningGenerator:
    """Build short, rule-based explanations of why a message fits a category.

    ``FRAUD_INDICATORS`` maps every category to four kinds of cues: keywords,
    regex patterns, entities and urgency phrases.  ``analyze_message`` reports
    which cues occur in a message; ``generate_reason`` turns that report into
    a single sentence.
    """

    FRAUD_INDICATORS = {
        'phishing': {
            'keywords': ['verify', 'account', 'suspended', 'confirm', 'click here', 'update',
                        'security', 'alert', 'unauthorized', 'immediately', 'login'],
            'patterns': [r'click\s+(here|link)', r'verify\s+your\s+(account|identity|information)',
                        r'account\s+(suspended|locked|compromised)', r'urgent\s+action'],
            'entities': ['bank', 'paypal', 'amazon', 'netflix', 'apple', 'microsoft'],
            'urgency': ['immediately', 'urgent', 'now', 'expire', 'within 24', 'act now']
        },
        'job_scam': {
            'keywords': ['work from home', 'easy money', 'no experience', 'apply now',
                        'guaranteed income', 'flexible hours', 'earn', 'opportunity'],
            'patterns': [r'\$\d+\s*(?:per|\/)\s*(?:day|week|month|hour)',
                        r'work\s+from\s+home', r'no\s+experience\s+(?:required|needed)',
                        r'guaranteed\s+(?:income|salary|pay)'],
            'entities': ['remote', 'telecommute', 'online'],
            'urgency': ['apply now', 'limited positions', 'act fast', 'immediate hire']
        },
        'reward_scam': {
            'keywords': ['congratulations', 'winner', 'prize', 'reward', 'gift card',
                        'won', 'claim', 'free', 'selected'],
            'patterns': [r'(?:won|winner|prize|reward|gift\s+card).*\$\d+',
                        r'congratulations.*(?:won|winner|selected)',
                        r'claim\s+your\s+(?:prize|reward|gift)'],
            'entities': ['walmart', 'target', 'amazon', 'visa', 'mastercard'],
            'urgency': ['claim now', 'expires', 'limited time', 'act now']
        },
        'refund_scam': {
            'keywords': ['refund', 'overpaid', 'owed', 'payment', 'credit', 'return'],
            'patterns': [r'(?:refund|owed).*\$\d+', r'overpaid'],
            'entities': ['irs', 'tax', 'government'],
            'urgency': ['claim now', 'expires']
        },
        'tech_support_scam': {
            'keywords': ['virus', 'infected', 'malware', 'computer', 'tech support', 'call'],
            'patterns': [r'(?:virus|malware)\s+detected', r'computer\s+(?:infected|compromised)'],
            'entities': ['microsoft', 'windows', 'apple'],
            'urgency': ['call now', 'immediately']
        },
        'popup_scam': {
            'keywords': ['click', 'download', 'install', 'update required'],
            'patterns': [r'click\s+(?:here|now)', r'download\s+now'],
            'entities': [],
            'urgency': ['now', 'immediately']
        },
        'ssn_scam': {
            'keywords': ['social security', 'ssn', 'suspended', 'verify'],
            'patterns': [r'social\s+security\s+(?:number|suspended)'],
            'entities': ['social security', 'ssa'],
            'urgency': ['immediately', 'urgent']
        },
        'sms_spam': {
            'keywords': ['text', 'reply', 'stop', 'offer'],
            'patterns': [r'text\s+\w+\s+to\s+\d+'],
            'entities': [],
            'urgency': ['now', 'limited time']
        },
        'legitimate': {
            'keywords': ['confirmation', 'receipt', 'order', 'tracking'],
            'patterns': [r'tracking\s+(?:number|#):\s*\w+'],
            'entities': [],
            'urgency': []
        }
    }

    # Category-independent extractors, compiled once for the class instead of
    # on every analyze_message() call.
    _MONEY_RE = re.compile(r'\$\s*\d+(?:,\d{3})*(?:\.\d{2})?')
    _PHONE_RE = re.compile(r'1?\s*[-.\s]?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}')
    _URL_RE = re.compile(r'https?://\S+|www\.\S+|click\s+here', re.IGNORECASE)

    def __init__(self):
        # Pre-compile each category's regex patterns (case-insensitive).
        self.compiled_patterns = {
            category: [
                re.compile(pattern, re.IGNORECASE)
                for pattern in indicators['patterns']
            ]
            for category, indicators in self.FRAUD_INDICATORS.items()
        }

    def analyze_message(self, text: str, category: str) -> Dict:
        """Report which of *category*'s cues occur in *text*.

        Keywords, urgency phrases and entities are matched as plain substrings
        of the lower-cased text.  An unknown *category* is analysed with the
        ``'legitimate'`` cue set.

        Returns a dict with the keys ``keywords`` (at most 5), ``patterns``
        (at most 3), ``urgency`` (at most 3), ``entities`` (at most 3),
        ``money`` (at most 2), ``phones`` (at most 2) and ``urls`` (a bool).
        """
        text_lower = text.lower()
        if category not in self.FRAUD_INDICATORS:
            category = 'legitimate'

        indicators = self.FRAUD_INDICATORS[category]
        found_keywords = [kw for kw in indicators['keywords'] if kw in text_lower]
        found_urgency = [u for u in indicators['urgency'] if u in text_lower]
        found_entities = [e for e in indicators['entities'] if e in text_lower]

        # re.findall returns the captured group (not the whole match) for patterns
        # that contain a capture group; the results are kept as returned.
        found_patterns = []
        for pattern in self.compiled_patterns[category]:
            found_patterns.extend(pattern.findall(text))

        money_amounts = self._MONEY_RE.findall(text)
        phone_numbers = self._PHONE_RE.findall(text)
        urls = self._URL_RE.findall(text)

        return {
            'keywords': found_keywords[:5],
            'patterns': found_patterns[:3],
            'urgency': found_urgency[:3],
            'entities': found_entities[:3],
            'money': money_amounts[:2],
            'phones': phone_numbers[:2],
            'urls': bool(urls)
        }

    def generate_reason(self, text: str, category: str) -> str:
        """Return a one-sentence explanation for labelling *text* as *category*.

        Each known category contributes an opening fragment, one fragment per
        indicator found in the message, and a closing fragment.  When no
        indicator fragment was produced (or the category is unknown) a generic
        sentence naming the category is returned instead.
        """
        analysis = self.analyze_message(text, category)
        reasons = []

        if category == 'phishing':
            reasons.append("This appears to be a phishing attempt because")
            if analysis['entities']:
                reasons.append(f"it impersonates {', '.join(analysis['entities'])}")
            if analysis['urgency']:
                reasons.append(f"uses urgent language like '{analysis['urgency'][0]}'")
            if analysis['keywords']:
                suspicious = [kw for kw in analysis['keywords'] if kw in ['verify', 'suspended', 'unauthorized']]
                if suspicious:
                    reasons.append(f"requests to {', '.join(suspicious[:2])}")
            if analysis['urls']:
                reasons.append("includes suspicious links")
            reasons.append("to steal personal information.")

        elif category == 'job_scam':
            reasons.append("This is likely a job scam as")
            if analysis['money']:
                reasons.append(f"it promises unrealistic income ({analysis['money'][0]})")
            if 'no experience' in ' '.join(analysis['keywords']):
                reasons.append("requires no qualifications")
            if analysis['urgency']:
                reasons.append(f"creates urgency ('{analysis['urgency'][0]}')")
            reasons.append("which are common job scam tactics.")

        elif category == 'reward_scam':
            reasons.append("This is a reward scam because")
            if analysis['money']:
                reasons.append(f"it claims you've won {analysis['money'][0]}")
            if any(kw in analysis['keywords'] for kw in ['winner', 'selected']):
                reasons.append("falsely claims you've been selected")
            if analysis['urgency']:
                reasons.append("pressures quick action")
            reasons.append("to collect personal information.")

        elif category == 'refund_scam':
            reasons.append("This appears to be a refund scam as")
            if analysis['money']:
                reasons.append(f"it claims you're owed {analysis['money'][0]}")
            if analysis['entities']:
                reasons.append(f"impersonates {', '.join(analysis['entities'])}")
            reasons.append("to trick you into sharing sensitive data.")

        elif category == 'tech_support_scam':
            reasons.append("This is a tech support scam because")
            if any(kw in analysis['keywords'] for kw in ['virus', 'infected', 'malware']):
                reasons.append("falsely claims device infection")
            if analysis['phones']:
                reasons.append("provides a number to call")
            if analysis['entities']:
                reasons.append(f"impersonates {', '.join(analysis['entities'])}")
            reasons.append("to sell fake tech support.")

        elif category == 'popup_scam':
            reasons.append("This is a popup scam as")
            if any(kw in analysis['keywords'] for kw in ['click', 'download']):
                reasons.append("prompts immediate action")
            if analysis['urgency']:
                reasons.append("uses urgent calls-to-action")
            reasons.append("to install malware.")

        elif category == 'ssn_scam':
            reasons.append("This is an SSN scam because")
            if 'social security' in text.lower():
                reasons.append("involves SSN threats")
            if any(kw in analysis['keywords'] for kw in ['suspended', 'verify']):
                reasons.append("falsely claims SSN issues")
            if analysis['urgency']:
                reasons.append("creates fear with urgency")
            reasons.append("which is a common SSN fraud tactic.")

        elif category == 'sms_spam':
            reasons.append("This is SMS spam as")
            if any(kw in analysis['keywords'] for kw in ['text', 'reply']):
                reasons.append("uses text marketing tactics")
            if analysis['urls']:
                reasons.append("includes promotional links")
            reasons.append("for unsolicited marketing.")

        elif category == 'legitimate':
            reasons.append("This appears legitimate because")
            if any(kw in analysis['keywords'] for kw in ['confirmation', 'tracking', 'receipt']):
                reasons.append("contains transactional language")
            if not analysis['urgency']:
                reasons.append("lacks urgent pressure")
            reasons.append("without suspicious fraud indicators.")

        # Fallback.  Every category branch always adds an opener and a closer, so
        # two or fewer fragments means no concrete indicator was found; gluing the
        # opener straight onto the closer would give a broken sentence such as
        # "This is SMS spam as for unsolicited marketing."
        if len(reasons) <= 2:
            category_name = category.replace('_', ' ')
            reasons = [
                f"This message exhibits characteristics of {category_name}",
                "based on its language patterns and content structure."
            ]

        # Two fragments (the fallback) are joined with a space; otherwise the
        # indicator fragments are comma-separated and the closer follows ", and".
        if len(reasons) == 2:
            return f"{reasons[0]} {reasons[1]}"
        middle = ', '.join(reasons[1:-1])
        return f"{reasons[0]} {middle}, and {reasons[-1]}"

# Initialize generator
reasoning_gen = ReasoningGenerator()
print('✓ Reasoning generator initialized')

In [None]:
# 6) Generate Enhanced Training Data
print('Generating contextual reasoning for training data...')

# Pull the raw messages and category names out of each split once.
_train_messages = train_df[TEXT_COL].tolist()
_train_categories = train_df[LABEL_COL].tolist()
_val_messages = val_df[TEXT_COL].tolist()
_val_categories = val_df[LABEL_COL].tolist()

# Model inputs: the instruction prompt followed by the message
train_X = [f'{INSTRUCTION_PREFIX}{message}' for message in _train_messages]
val_X = [f'{INSTRUCTION_PREFIX}{message}' for message in _val_messages]

# Classification targets: integer class ids
train_y = [LABEL2ID[category] for category in _train_categories]
val_y = [LABEL2ID[category] for category in _val_categories]

# Generation targets: one reasoning sentence per (message, category) pair
print('Generating training reasoning...')
train_reason = list(map(reasoning_gen.generate_reason, _train_messages, _train_categories))

print('Generating validation reasoning...')
val_reason = list(map(reasoning_gen.generate_reason, _val_messages, _val_categories))

print(f'\n✓ Prepared {len(train_X)} training samples and {len(val_X)} validation samples')

# Show the first sample, plus the category and reasoning of the sixth one.
print('\nSample input:')
print(train_X[0][:200])
print('\nSample reasoning:')
print(train_reason[0])
print('\nAnother sample:')
print(f'Category: {train_df[LABEL_COL].iloc[5]}')
print(f'Reasoning: {train_reason[5]}')

In [None]:
# 7) Tokenizer and Dataset
# Load the tokenizer for BASE_MODEL, falling back through three strategies.
# TRANSFORMERS_OFFLINE is set to '0' first (presumably so hub downloads are allowed -
# confirm against the transformers environment-variable docs).
os.environ['TRANSFORMERS_OFFLINE'] = '0'

try:
    # Attempt 1: regular load of the fast tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True, local_files_only=False)
except Exception as e:
    print(f'First attempt failed: {e}')
    print('Trying with trust_remote_code=True...')
    try:
        # Attempt 2: same load with trust_remote_code=True.
        # NOTE(review): trust_remote_code opts in to running code shipped with the
        # model repo - confirm this is really needed for this model.
        tokenizer = AutoTokenizer.from_pretrained(
            BASE_MODEL, 
            use_fast=True, 
            trust_remote_code=True,
            force_download=False
        )
    except Exception as e2:
        print(f'Second attempt failed: {e2}')
        # Attempt 3: download a snapshot of the repo (skipping *.msgpack, *.h5 and
        # *.bin files) into a cache under WORK_DIR, then load from that local copy.
        from huggingface_hub import snapshot_download
        model_path = snapshot_download(repo_id=BASE_MODEL, 
                                       ignore_patterns=["*.msgpack", "*.h5", "*.bin"],
                                       cache_dir=str(WORK_DIR / 'hf_cache'))
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, local_files_only=True)

print('✓ Tokenizer loaded')
print(f'Vocab size: {tokenizer.vocab_size}')

class JointDataset:
    """Index-able dataset that tokenizes one (input, class id, reasoning) triple on access.

    Args:
        texts: model input strings (instruction prompt + message).
        cls_labels: integer class ids, aligned with ``texts``.
        reasons: target reasoning strings, aligned with ``texts``.
    """

    def __init__(self, texts, cls_labels, reasons):
        self.texts = texts
        self.cls_labels = cls_labels
        self.reasons = reasons

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized sample at *idx*.

        The result is the source encoding (``input_ids`` / ``attention_mask``
        padded to MAX_SOURCE_LENGTH) plus ``labels`` (target token ids of length
        MAX_TARGET_LENGTH, padding positions set to -100) and ``cls_labels``
        (the integer class id).
        """
        enc = tokenizer(
            self.texts[idx],
            max_length=MAX_SOURCE_LENGTH,
            truncation=True,
            padding='max_length',
        )
        # `text_target=` is the replacement for the deprecated
        # `tokenizer.as_target_tokenizer()` context manager.
        dec = tokenizer(
            text_target=self.reasons[idx],
            max_length=MAX_TARGET_LENGTH,
            truncation=True,
            padding='max_length',
        )
        # Padding positions must not contribute to the generation loss: replace
        # them with -100, the index ignored by the cross-entropy loss. Without this
        # the model is trained (and evaluated) on predicting pad tokens, which
        # deflates the reported loss.
        enc['labels'] = [
            token_id if is_real_token else -100
            for token_id, is_real_token in zip(dec['input_ids'], dec['attention_mask'])
        ]
        enc['cls_labels'] = int(self.cls_labels[idx])
        return enc

train_ds = JointDataset(train_X, train_y, train_reason)
val_ds = JointDataset(val_X, val_y, val_reason)
print(f'✓ Dataset created: {len(train_ds)} train, {len(val_ds)} val samples')

In [None]:
# 8) Joint Model Wrapper
from transformers.modeling_outputs import Seq2SeqLMOutput

class BartForJointClassificationAndGeneration(BartForConditionalGeneration):
    """BART seq2seq model with an additional linear classification head.

    The classification head reads a mask-weighted mean of the encoder's final
    hidden states.  When both generation labels and class labels are supplied,
    the returned loss is ``gen_loss_weight * generation_loss +
    cls_loss_weight * classification_loss``.

    Config attributes read (with defaults): ``num_labels`` (2),
    ``classifier_dropout`` (0.1), ``cls_loss_weight`` (1.0) and
    ``gen_loss_weight`` (1.0).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = int(getattr(config, 'num_labels', 2))
        self.dropout = nn.Dropout(getattr(config, 'classifier_dropout', 0.1))
        self.classifier = nn.Linear(config.d_model, self.num_labels)
        self.cls_loss_fn = nn.CrossEntropyLoss()
        self.cls_loss_weight = float(getattr(config, 'cls_loss_weight', 1.0))
        self.gen_loss_weight = float(getattr(config, 'gen_loss_weight', 1.0))
        # Write the resolved values back so they are stored with the saved config.
        self.config.num_labels = self.num_labels
        self.config.cls_loss_weight = self.cls_loss_weight
        self.config.gen_loss_weight = self.gen_loss_weight

    def pooled_encoder(self, encoder_hidden_states, attention_mask):
        """Mean-pool *encoder_hidden_states* over dim 1, counting only positions
        where *attention_mask* is non-zero."""
        mask = attention_mask.unsqueeze(-1).float()
        summed = torch.sum(encoder_hidden_states * mask, dim=1)
        # Clamp so an all-zero mask row cannot cause a division by zero.
        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
        return summed / counts

    def forward(self, input_ids=None, attention_mask=None, labels=None, cls_labels=None, **kwargs):
        """Run generation and classification in one pass.

        Returns:
            ``(Seq2SeqLMOutput, cls_logits)`` for a regular call; the output's
            ``loss`` is the weighted joint loss (or whichever single loss is
            available, or ``None``).

            During ``generate()`` decoding steps - recognised by
            ``input_ids is None`` together with a cached ``encoder_outputs``
            keyword - the plain ``Seq2SeqLMOutput`` of the parent class is
            returned, because the generation loop reads ``.logits`` from the
            result and there are no input ids to re-encode for classification.
        """
        kwargs.pop('return_dict', None)

        if input_ids is None and kwargs.get('encoder_outputs') is not None:
            # Incremental decoding: behave exactly like the parent model.
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
                **kwargs
            )

        # Hidden states are never returned on this path; drop a caller-supplied
        # flag so it does not collide with the explicit keyword below.
        kwargs.pop('output_hidden_states', None)
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=False,
            return_dict=True,
            **kwargs
        )

        # Reuse the encoder states computed by the pass above rather than running
        # the encoder a second time.
        encoder_states = outputs.encoder_last_hidden_state
        if attention_mask is None:
            # No mask supplied: treat every position as a real token.
            attention_mask = torch.ones(
                encoder_states.shape[:2], dtype=torch.long, device=encoder_states.device
            )
        pooled = self.pooled_encoder(encoder_states, attention_mask)
        cls_logits = self.classifier(self.dropout(pooled))
        cls_loss = None
        if cls_labels is not None:
            cls_loss = self.cls_loss_fn(cls_logits, cls_labels)

        # Combine whichever losses are available.
        total_loss = None
        if outputs.loss is not None and cls_loss is not None:
            total_loss = self.gen_loss_weight * outputs.loss + self.cls_loss_weight * cls_loss
        elif outputs.loss is not None:
            total_loss = outputs.loss
        elif cls_loss is not None:
            total_loss = cls_loss

        return Seq2SeqLMOutput(
            loss=total_loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=encoder_states,
            encoder_hidden_states=None,
            encoder_attentions=None
        ), cls_logits

# Start from the pretrained config and add the joint-training settings.
cfg = AutoConfig.from_pretrained(BASE_MODEL)
cfg.num_labels = len(DEFAULT_LABELS)
cfg.cls_loss_weight = CLS_LOSS_WEIGHT
cfg.gen_loss_weight = GEN_LOSS_WEIGHT
joint_model = BartForJointClassificationAndGeneration.from_pretrained(BASE_MODEL, config=cfg)
joint_model.to(device)
print('✓ Joint model initialized')

In [None]:
# 9) Training Setup
import inspect
from transformers import TrainerCallback, TrainerState, TrainerControl

class JointCollator:
    """Collator wrapper that carries the per-sample ``cls_labels`` through batching.

    ``cls_labels`` is removed from every feature (the features are modified in
    place), the remaining fields are batched by ``base_collator``, and the class
    ids are added to the resulting batch as a ``torch.long`` tensor.
    """

    def __init__(self, base_collator):
        self.base_collator = base_collator

    def __call__(self, features):
        class_ids = []
        for feature in features:
            class_ids.append(feature.pop('cls_labels'))
        class_id_tensor = torch.tensor(class_ids, dtype=torch.long)

        collated = self.base_collator(features)
        collated['cls_labels'] = class_id_tensor
        return collated

class LossThresholdEarlyStop(TrainerCallback):
    """Stop training once the evaluation loss is at or below a threshold.

    Args:
        threshold: stop when ``eval_loss <= threshold``.
        min_epochs: never stop before this many epochs have elapsed.
    """

    def __init__(self, threshold: float, min_epochs: int = 0):
        self.threshold = float(threshold)
        self.min_epochs = int(min_epochs)

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Inspect the ``metrics`` keyword after an evaluation and request a stop if warranted."""
        eval_loss = (kwargs.get('metrics') or {}).get('eval_loss')
        current_epoch = state.epoch or 0

        # Nothing to compare against when no eval loss was reported.
        if eval_loss is None:
            return control

        should_stop = current_epoch >= self.min_epochs and eval_loss <= self.threshold
        if should_stop:
            control.should_training_stop = True
            print(f'Early stopping: eval_loss={eval_loss:.4f} <= {self.threshold} at epoch {current_epoch}')
        return control

# Base training arguments shared by every transformers version.
args_dict = {
    'output_dir': str(OUTPUT_DIR),
    'learning_rate': LR,
    'per_device_train_batch_size': BATCH_TRAIN,
    'per_device_eval_batch_size': BATCH_EVAL,
    'num_train_epochs': NUM_EPOCHS,
    'weight_decay': 0.01,
    'warmup_ratio': 0.1,
    'logging_steps': 50,
    'gradient_accumulation_steps': GRAD_ACCUM,
    'report_to': ['none'],
    'fp16': torch.cuda.is_available(),  # mixed precision only when a CUDA GPU is present
    'bf16': False,
}

# Optional arguments are added only when the installed TrainingArguments accepts them.
sig = inspect.signature(TrainingArguments.__init__)
# Per-epoch evaluation is what drives LossThresholdEarlyStop. The argument is named
# `eval_strategy` in newer releases and `evaluation_strategy` in older ones; probing
# both keeps evaluation enabled whichever name is available.
if 'eval_strategy' in sig.parameters:
    args_dict['eval_strategy'] = 'epoch'
elif 'evaluation_strategy' in sig.parameters:
    args_dict['evaluation_strategy'] = 'epoch'
if 'save_strategy' in sig.parameters:
    args_dict['save_strategy'] = 'epoch'
# NOTE(review): these two look like Seq2SeqTrainingArguments options, so with plain
# TrainingArguments the branches are presumably skipped - confirm.
if 'predict_with_generate' in sig.parameters:
    args_dict['predict_with_generate'] = True
if 'generation_max_length' in sig.parameters:
    args_dict['generation_max_length'] = MAX_TARGET_LENGTH

training_args = TrainingArguments(**args_dict)
callbacks = [LossThresholdEarlyStop(EARLY_STOP_LOSS, MIN_EPOCHS)]

# The seq2seq collator batches the token fields; JointCollator adds the class ids.
_base_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=joint_model)
data_collator = JointCollator(_base_collator)

class JointTrainer(Trainer):
    """Trainer for a model whose forward pass returns ``(lm_outputs, cls_logits)``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the loss carried by the LM output, optionally with both outputs.

        With ``return_outputs=True`` the result is ``(loss, {'lm': ..., 'cls_logits': ...})``.
        """
        lm_outputs, cls_logits = model(**inputs)
        if not return_outputs:
            return lm_outputs.loss
        return lm_outputs.loss, {'lm': lm_outputs, 'cls_logits': cls_logits}

_trainer_kwargs = {
    'model': joint_model,
    'args': training_args,
    'train_dataset': train_ds,
    'eval_dataset': val_ds,
    'tokenizer': tokenizer,
    'data_collator': data_collator,
    'callbacks': callbacks,
}
trainer = JointTrainer(**_trainer_kwargs)

print('✓ Trainer configured')

In [None]:
# 10) Train and Save
print('Starting training with enhanced contextual reasoning...')
trainer.train()

# Write the model first, then the tokenizer, into the same output directory.
_save_dir = str(OUTPUT_DIR)
for _save_artifact in (trainer.save_model, tokenizer.save_pretrained):
    _save_artifact(_save_dir)
print(f'\n✓ Model saved to: {OUTPUT_DIR}')

In [None]:
# 11) Enhanced Inference Test
joint_model.eval()

# Use the first validation row, or the first training row if the validation split is empty.
_sample_row = (val_df if len(val_df) else train_df).iloc[0]
sample_text = _sample_row[TEXT_COL]
true_label = _sample_row[LABEL_COL]
src = f'{INSTRUCTION_PREFIX}{sample_text}'

inputs = tokenizer(
    [src],
    return_tensors='pt',
    truncation=True,
    padding=True,
    max_length=MAX_SOURCE_LENGTH,
).to(device)

# Decoding settings: deterministic beam search
_generation_kwargs = {
    'max_new_tokens': MAX_TARGET_LENGTH,
    'num_beams': 5,
    'length_penalty': 1.0,
    'no_repeat_ngram_size': 3,  # avoid repeating 3-grams
    'early_stopping': True,
    'do_sample': False,
}

with torch.no_grad():
    # Classification: encode, mean-pool, apply the linear head directly
    # (the dropout layer is not applied here).
    enc_out = joint_model.model.encoder(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        return_dict=True
    )
    pooled = joint_model.pooled_encoder(enc_out.last_hidden_state, inputs['attention_mask'])
    cls_logits = joint_model.classifier(pooled)
    pred_id = int(cls_logits.argmax(-1).cpu().item())
    pred_label = ID2LABEL[pred_id]

    # Generation of the reasoning text
    gen_ids = joint_model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        **_generation_kwargs
    )
    reason = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

_rule = '=' * 80
print('\n' + _rule)
print('ENHANCED INFERENCE TEST')
print(_rule)
print(f'\nInput snippet: {sample_text[:200]}...')
print(f'\nTrue label: {true_label}')
print(f'Predicted label: {pred_label}')
print('\nContextual Reasoning:')
print(reason)
print(_rule)

In [None]:
# 12) Test Multiple Examples
_banner = '=' * 80
print('\n' + _banner)
print('TESTING MULTIPLE EXAMPLES')
print(_banner + '\n')

# Up to five validation rows, drawn with a fixed seed.
test_samples = val_df.sample(n=min(5, len(val_df)), random_state=42)

_sample_pairs = zip(test_samples[TEXT_COL], test_samples[LABEL_COL])
for idx, (text, true_label) in enumerate(_sample_pairs, 1):
    src = f'{INSTRUCTION_PREFIX}{text}'
    inputs = tokenizer(
        [src],
        return_tensors='pt',
        truncation=True,
        padding=True,
        max_length=MAX_SOURCE_LENGTH,
    ).to(device)

    with torch.no_grad():
        # Predicted class from the pooled encoder states
        enc_out = joint_model.model.encoder(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            return_dict=True
        )
        pooled = joint_model.pooled_encoder(enc_out.last_hidden_state, inputs['attention_mask'])
        cls_logits = joint_model.classifier(pooled)
        pred_id = int(cls_logits.argmax(-1).cpu().item())
        pred_label = ID2LABEL[pred_id]

        # Generated reasoning (beam search)
        gen_ids = joint_model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=MAX_TARGET_LENGTH,
            num_beams=5,
            length_penalty=1.0,
            no_repeat_ngram_size=3,
            early_stopping=True
        )
        reason = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    if pred_label == true_label:
        correct = '✓'
    else:
        correct = '✗'
    print(f'Example {idx} {correct}')
    print(f'Text: {text[:150]}...')
    print(f'True: {true_label} | Predicted: {pred_label}')
    print(f'Reasoning: {reason}')
    print('-'*80 + '\n')

In [None]:
# 13) Zip Artifacts (Optional)
import shutil
ZIP_PATH = WORK_DIR / 'unified-bart-joint-enhanced.zip'

# Remove an archive left over from an earlier run before writing a new one.
_stale_archive_present = ZIP_PATH.exists()
if _stale_archive_present:
    ZIP_PATH.unlink()

# make_archive appends '.zip' itself, so it is given the path without the suffix.
_archive_base = ZIP_PATH.parent / ZIP_PATH.stem
shutil.make_archive(str(_archive_base), 'zip', str(OUTPUT_DIR))
print(f'✓ Zipped model to: {ZIP_PATH}')