# Envis Insight Engine v2 — Improved Training

Changes from v1 baseline:

| Improvement | Rationale |
|-------------|-----------|
| **Unfreeze last 3 FinBERT layers** | v1 only trained 6.1% of params — adapters alone can't capture task-specific signal |
| **Differential learning rate** | 1e-5 for BERT, 1e-4 for new heads — prevents catastrophic forgetting |
| **Unified distress head (4-class)** | Sub-indicators never co-occur and only appear when distress=1 → single classification: none/self_blame/avoidance/secrecy |
| **Log-transform timing delay** | Raw delay is heavily right-skewed (mean 142 h, median 61 h, std 153 h) → `log1p` reduces the std to 1.19 (in log-hours) |
| **Class weights for framing** | 57.6% supportive vs 2.3% urgent — model predicts majority class without weights |
| **Pos weights for binary tasks** | tension 30%, goal_risk 39% — rebalance gradients |
| **Huber loss for delay** | Robust to remaining outliers after log transform |
| **Cross-attention fusion** | Transaction→text and text→household attention applied ahead of the gate, instead of gating alone |
| **Mixed precision (fp16)** | ~2x faster on A100, no quality loss |
| **Gradient accumulation** | Effective batch size of 32 from a per-step batch of 16 (2 accumulation steps) |
| **Reduced MAX_TEXT=128, MAX_TRANS=20** | p95 word count is 86, p95 transaction count is 18 — less padding waste |

**Requirements:** Colab Pro with A100 GPU · `financial_data_8k.csv`

## 1. Setup

In [None]:
!nvidia-smi

In [None]:
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q transformers pandas scikit-learn tqdm tensorboard

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, accuracy_score, f1_score,
    mean_absolute_error, confusion_matrix, precision_recall_curve,
    classification_report
)
from tqdm.auto import tqdm
import json, time, os, math
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, Optional

# Report the runtime environment, then choose the device used by every later cell.
has_cuda = torch.cuda.is_available()

print(f"PyTorch {torch.__version__}")
print(f"CUDA: {has_cuda}")
if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    try:
        total_bytes = torch.cuda.get_device_properties(0).total_memory
    except AttributeError:
        # Some PyTorch builds do not expose `total_memory` on the properties object.
        print("(memory info not available in this PyTorch version)")
    else:
        print(f"Memory: {total_bytes / 1e9:.1f} GB")

device = torch.device('cuda') if has_cuda else torch.device('cpu')

## 2. Load Data

In [None]:
# Colab-only step: `files.upload()` prompts for a local file. The next cell reads
# 'financial_data_8k.csv' from the working directory, so that is the file to pick.
from google.colab import files
uploaded = files.upload()  # Select financial_data_8k.csv

In [None]:
# Load the raw dataset uploaded in the previous cell.
df = pd.read_csv('financial_data_8k.csv')
print(f"Records: {len(df):,}  |  Columns: {len(df.columns)}")

# ── Constants ──
# Class vocabularies. List order defines the integer label ids used below and the
# one-hot layout of the household node features.
FRAMING_CLASSES = ['supportive', 'direct', 'celebratory', 'gentle', 'urgent']
URGENCY_CLASSES = ['immediate', 'soon', 'can_wait']
DISTRESS_CLASSES = ['none', 'self_blame', 'avoidance', 'secrecy']
ROLE_TYPES      = ['partner_1', 'partner_2', 'child_1', 'child_2', 'child_3',
                   'parent_1', 'parent_2', 'head_of_household', 'grandparent', 'other']
AGE_BRACKETS    = ['18-25', '25-35', '35-45', '45-55', '55-65', '65+']
INCOME_BRACKETS = ['low', 'medium', 'high', 'variable', 'unknown']

# ── Encode labels ──
framing_to_idx = {f: i for i, f in enumerate(FRAMING_CLASSES)}
urgency_to_idx = {u: i for i, u in enumerate(URGENCY_CLASSES)}
df['framing_idx'] = df['framing'].map(framing_to_idx)
df['urgency_idx'] = df['timing_urgency'].map(urgency_to_idx)

# `Series.map` leaves NaN for any label missing from the mapping. Such rows would
# only surface later as corrupt integer targets or a KeyError in the class-weight
# computation, so fail here with a message that names the offending values.
for _label_col, _idx_col in (('framing', 'framing_idx'), ('timing_urgency', 'urgency_idx')):
    _unmapped = df.loc[df[_idx_col].isna(), _label_col]
    if len(_unmapped) > 0:
        raise ValueError(
            f"Column '{_label_col}' has {len(_unmapped)} value(s) outside the known classes: "
            f"{sorted(_unmapped.astype(str).unique())}"
        )
    df[_idx_col] = df[_idx_col].astype('int64')

# Unified distress class: sub-indicators never co-occur, only appear when distress=1
# 0=none, 1=self_blame, 2=avoidance, 3=secrecy
def encode_distress_class(row):
    """Map a record's sub-indicator flags to one distress class id.

    Checks ``self_blame``, ``avoidance`` and ``secrecy`` in that order and returns
    1, 2 or 3 for the first flag equal to 1; returns 0 when none is set.
    """
    for class_id, flag in enumerate(('self_blame', 'avoidance', 'secrecy'), start=1):
        if row[flag] == 1:
            return class_id
    return 0
# Derived targets: the unified 4-class distress label and the log-scale delay.
df['distress_class'] = df.apply(encode_distress_class, axis=1)
df['timing_delay_log'] = np.log1p(df['timing_delay_hours'])

# ── Split ──
# 80% train, then the remaining 20% halved into validation and test. Both cuts
# are stratified on the binary `distress` column and use the same seed.
train_df, temp_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
    stratify=df['distress'],
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['distress'],
)

print(f"Train: {len(train_df):,}  Val: {len(val_df):,}  Test: {len(test_df):,}")
print(f"\nDistress classes: {df['distress_class'].value_counts().sort_index().to_dict()}")
print(f"Framing: {df['framing'].value_counts().to_dict()}")
print(f"Urgency: {df['timing_urgency'].value_counts().to_dict()}")

## 3. Class Weights

In [None]:
# Number of training rows; kept as a module-level name for later cells.
n = len(train_df)

def inv_freq_weights(series, n_classes):
    """Return balanced inverse-frequency class weights as a float tensor on `device`.

    The weight of class ``c`` is ``len(series) / (n_classes * count_c)``.

    Args:
        series: pandas Series of integer class ids in ``range(n_classes)``.
        n_classes: total number of classes, including any absent from `series`.

    Returns:
        1-D float tensor of length `n_classes`.
    """
    # Reindex so a class missing from this split yields a count of 0 instead of a
    # KeyError, then clip to 1 so its weight stays finite (it has no samples here,
    # so the value only matters for other splits).
    counts = series.value_counts().reindex(range(n_classes), fill_value=0).clip(lower=1)
    total = len(series)
    weights = [total / (n_classes * int(counts[i])) for i in range(n_classes)]
    return torch.tensor(weights, dtype=torch.float, device=device)

framing_w  = inv_freq_weights(train_df['framing_idx'], 5)
urgency_w  = inv_freq_weights(train_df['urgency_idx'], 3)
distress_w = inv_freq_weights(train_df['distress_class'], 4)

print("Framing weights:  ", {FRAMING_CLASSES[i]: f"{framing_w[i]:.2f}" for i in range(5)})
print("Urgency weights:  ", {URGENCY_CLASSES[i]: f"{urgency_w[i]:.2f}" for i in range(3)})
print("Distress weights: ", {DISTRESS_CLASSES[i]: f"{distress_w[i]:.2f}" for i in range(4)})

def pos_weight(col):
    """Return the negative/positive ratio of a binary `train_df` column.

    The result is a 1-element float tensor on `device`, in the form expected by
    ``BCEWithLogitsLoss(pos_weight=...)``.

    Args:
        col: name of a 0/1 column in `train_df`.
    """
    pos = float(train_df[col].sum())
    if pos <= 0:
        # No positives: n / pos would be infinite and turn the loss into NaN.
        # Fall back to a neutral weight.
        return torch.ones(1, dtype=torch.float, device=device)
    return torch.tensor([len(train_df) / pos - 1], dtype=torch.float, device=device)

tension_pw   = pos_weight('tension')
goal_risk_pw = pos_weight('goal_risk')
print(f"\nTension pos_weight:   {tension_pw.item():.2f}")
print(f"Goal-risk pos_weight: {goal_risk_pw.item():.2f}")

## 4. Dataset

Reduced `MAX_TEXT=128` (p95 word count = 86) and `MAX_TRANS=20` (p95 = 18).
Less padding → faster per-batch and lower memory.

In [None]:
# FinBERT tokenizer, matching the text encoder checkpoint used by the model.
tokenizer = AutoTokenizer.from_pretrained('ProsusAI/finbert')

# Fixed sizes that every sample is padded or truncated to.
MAX_TEXT = 128     # tokenizer max_length
MAX_TRANS = 20     # transactions kept per record
MAX_MEMBERS = 5    # household members kept per record

# Width of one node feature row: one-hot role + one-hot age + one-hot income (21).
NODE_DIM = sum(len(vocab) for vocab in (ROLE_TYPES, AGE_BRACKETS, INCOME_BRACKETS))

def parse_json(val):
    """Decode `val` with ``json.loads`` when it is a string; return anything else as-is."""
    if not isinstance(val, str):
        return val
    return json.loads(val)

def pad_seq(seq, max_len, pad=0):
    """Return `seq` as a new list of exactly `max_len` items.

    Longer inputs are cut to their first `max_len` items; shorter ones are
    right-padded with `pad`.
    """
    clipped = list(seq)[:max_len]
    clipped.extend([pad] * (max_len - len(clipped)))
    return clipped

def encode_nodes(roles, ages, incomes):
    """Build the fixed-size household node feature matrix as nested lists.

    Each of the first ``MAX_MEMBERS`` members becomes one row: one-hot role,
    one-hot age bracket and one-hot income bracket concatenated (``NODE_DIM``
    values). Values not in the vocabularies give an all-zero segment. Missing
    members are filled with all-zero rows, so the result always has
    ``MAX_MEMBERS`` rows.
    """
    def one_hot(value, vocab):
        return [1.0 if value == entry else 0.0 for entry in vocab]

    n_real = min(len(roles), MAX_MEMBERS)
    rows = [
        one_hot(roles[m], ROLE_TYPES)
        + one_hot(ages[m], AGE_BRACKETS)
        + one_hot(incomes[m], INCOME_BRACKETS)
        for m in range(n_real)
    ]
    rows.extend([0.0] * NODE_DIM for _ in range(MAX_MEMBERS - n_real))
    return rows

def build_edges(edge_list, roles):
    """Convert role-name edges into a bidirectional ``(2, E)`` long index tensor.

    Node ids are positions within the first ``MAX_MEMBERS`` entries of `roles`.
    Edges with fewer than two items, or that name a role outside that prefix,
    are skipped. Every kept edge is emitted in both directions. With no valid
    edges the result is an empty ``(2, 0)`` tensor.
    """
    # NOTE(review): repeated role names collapse onto the last matching position.
    position = {role: idx for idx, role in enumerate(roles[:MAX_MEMBERS])}
    sources, targets = [], []
    for edge in edge_list:
        if len(edge) < 2:
            continue
        if edge[0] not in position or edge[1] not in position:
            continue
        a, b = position[edge[0]], position[edge[1]]
        sources += [a, b]
        targets += [b, a]
    if sources:
        return torch.tensor([sources, targets], dtype=torch.long)
    return torch.zeros(2, 0, dtype=torch.long)


class EnvisDataset(Dataset):
    """Map-style dataset that turns one dataframe row into a dict of model tensors.

    Text is tokenized on access; transaction lists are padded or truncated to
    ``MAX_TRANS``; household members become a ``MAX_MEMBERS x NODE_DIM`` feature
    matrix plus a variable-size edge index.
    """

    def __init__(self, dataframe, tokenizer):
        # Reset the index so positional lookups in __getitem__ run 0..len-1.
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        def column(name):
            return parse_json(row[name])

        def scalar(name, dtype):
            return torch.tensor(row[name], dtype=dtype)

        def trans_tensor(values, dtype, pad=0):
            return torch.tensor(pad_seq(values, MAX_TRANS, pad), dtype=dtype)

        encoded = self.tokenizer(
            row['text'],
            truncation=True,
            max_length=MAX_TEXT,
            padding='max_length',
            return_tensors='pt',
        )

        amounts = column('transaction_amounts')
        categories = column('transaction_categories')
        merchants = column('transaction_merchants')
        days = column('transaction_day_of_month')
        months = column('transaction_month')

        # 1.0 marks a real transaction slot, 0.0 marks padding.
        n_real = min(len(amounts), MAX_TRANS)
        trans_mask = torch.zeros(MAX_TRANS, dtype=torch.float)
        trans_mask[:n_real] = 1.0

        roles = column('member_roles')
        ages = column('member_age_brackets')
        incomes = column('member_income_brackets')
        edges = column('edge_list')

        sample = {}
        sample['input_ids'] = encoded['input_ids'].squeeze()
        sample['attention_mask'] = encoded['attention_mask'].squeeze()
        sample['amounts'] = trans_tensor(amounts, torch.float)
        sample['categories'] = trans_tensor(categories, torch.long)
        sample['merchants'] = trans_tensor(merchants, torch.long)
        # Day and month are padded with 1 rather than 0.
        sample['days'] = trans_tensor(days, torch.long, 1)
        sample['months'] = trans_tensor(months, torch.long, 1)
        sample['trans_mask'] = trans_mask
        sample['node_features'] = torch.tensor(encode_nodes(roles, ages, incomes), dtype=torch.float)
        sample['edge_index'] = build_edges(edges, roles)
        sample['n_members'] = torch.tensor(min(len(roles), MAX_MEMBERS), dtype=torch.long)
        sample['distress_class'] = scalar('distress_class', torch.long)
        sample['distress_binary'] = scalar('distress', torch.float)
        sample['framing'] = scalar('framing_idx', torch.long)
        sample['tension'] = scalar('tension', torch.float)
        sample['goal_risk'] = scalar('goal_risk', torch.float)
        sample['timing_delay'] = scalar('timing_delay_log', torch.float)
        sample['timing_urgency'] = scalar('urgency_idx', torch.long)
        return sample


def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict.

    Every field is stacked along a new leading dimension, except ``'edge_index'``,
    which is kept as a plain list of per-sample tensors.
    """
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if key == 'edge_index':
            collated[key] = values
        else:
            collated[key] = torch.stack(values)
    return collated

BATCH_SIZE = 16

train_dataset = EnvisDataset(train_df, tokenizer)
val_dataset   = EnvisDataset(val_df, tokenizer)
test_dataset  = EnvisDataset(test_df, tokenizer)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by all three splits."""
    return DataLoader(
        dataset,
        BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )


# Only the training split is shuffled.
train_loader = _make_loader(train_dataset, True)
val_loader   = _make_loader(val_dataset, False)
test_loader  = _make_loader(test_dataset, False)

print(f"Batches — Train: {len(train_loader)}  Val: {len(val_loader)}  Test: {len(test_loader)}")

# Sanity check: print shape and dtype of every tensor field in the first sample.
s = train_dataset[0]
print("\nSample shapes:")
for field, value in s.items():
    if not isinstance(value, torch.Tensor):
        continue
    print(f"  {field:20s} {str(value.shape):15s} {value.dtype}")

## 5. Model v2

- **FinBERT layers 9–11 unfrozen** with lower LR
- **Unified 4-class distress head** replaces separate binary distress + 3 sub-indicator heads
- **Cross-attention fusion** — the pooled transaction vector attends over the text tokens, and the text [CLS] vector attends to the household embedding, before gating
- **Residual prediction heads** for better gradient flow
- **Raw logits** everywhere — loss functions handle sigmoid/softmax

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters consumed by the encoders, the fusion module and the heads."""

    # ── Transaction encoder ──
    transaction_vocab_size: int = 8001      # rows in the merchant embedding table
    transaction_embedding_dim: int = 256    # Transformer d_model
    transaction_num_layers: int = 4         # number of Transformer encoder layers
    transaction_num_heads: int = 8          # attention heads per encoder layer
    transaction_ff_dim: int = 1024          # feed-forward width per encoder layer
    num_amount_buckets: int = 13            # rows in the amount-bucket embedding table
    amount_embedding_dim: int = 32
    num_categories: int = 120               # rows in the category embedding table
    category_embedding_dim: int = 64        # also used as the merchant embedding width
    # ── Text encoder ──
    text_model_name: str = 'ProsusAI/finbert'
    text_embedding_dim: int = 768           # hidden size the adapters and fusion projection expect
    adapter_dim: int = 64                   # adapter bottleneck width
    unfreeze_layers: int = 3                # trailing BERT encoder layers left trainable
    # ── Household encoder ──
    node_feature_dim: int = NODE_DIM        # width of one household node feature row
    household_hidden_dim: int = 64
    household_num_layers: int = 3           # number of GAT layers
    household_num_heads: int = 4            # attention heads per GAT layer
    # ── Fusion and heads ──
    fusion_dim: int = 512                   # width of the fused vector fed to every head
    dropout: float = 0.1
    num_distress_classes: int = 4
    num_framing_classes: int = 5
    num_urgency_classes: int = 3


# ─── Amount Encoder (log-scale buckets) ───

class AmountEncoder(nn.Module):
    """Embed transaction amounts through fixed, roughly log-spaced buckets.

    An amount falls in the bucket of the largest ``BOUNDS`` entry it is greater
    than or equal to. Anything below ``BOUNDS[0]`` (including negatives) lands in
    bucket 0.
    """

    BOUNDS = [0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000]

    def __init__(self, cfg):
        super().__init__()
        self.emb = nn.Embedding(cfg.num_amount_buckets, cfg.amount_embedding_dim)

    def forward(self, amounts):
        edges = torch.tensor(self.BOUNDS, dtype=amounts.dtype, device=amounts.device)
        # BOUNDS is ascending, so the number of edges an amount reaches, minus
        # one, is the index of the last edge it reaches.
        reached = (amounts.unsqueeze(-1) >= edges).sum(dim=-1)
        bucket_ids = (reached - 1).clamp(min=0, max=len(self.BOUNDS))
        return self.emb(bucket_ids)


# ─── Temporal Encoder ───

class TemporalEncoder(nn.Module):
    """Embed day-of-month, month and sequence position, then concatenate them.

    Args:
        max_seq: number of rows in the position embedding table.
        pos_dim: width of the position embedding.

    Attributes:
        out_dim: width of the concatenated output, for sizing the next layer.
    """

    def __init__(self, max_seq=MAX_TRANS, pos_dim=32):
        super().__init__()
        day_dim, month_dim = 16, 16
        self.dom_emb   = nn.Embedding(32, day_dim)
        self.month_emb = nn.Embedding(13, month_dim)
        self.pos_emb   = nn.Embedding(max_seq, pos_dim)
        # Derive the width from the parts so it stays correct for any pos_dim
        # (it was a literal 64, which only holds for pos_dim == 32).
        self.out_dim   = day_dim + month_dim + pos_dim

    def forward(self, day_of_month, month, positions):
        """Return the three embeddings concatenated on the last dimension."""
        return torch.cat([
            self.dom_emb(day_of_month), self.month_emb(month), self.pos_emb(positions)
        ], dim=-1)


# ─── Transaction Encoder (4-layer Transformer) ───

class TransactionEncoder(nn.Module):
    """Encode a padded transaction history with a Transformer encoder.

    Each position concatenates an amount-bucket embedding, a category embedding,
    a merchant embedding and the temporal/positional embeddings. The result is
    projected to ``cfg.transaction_embedding_dim`` and passed through
    ``cfg.transaction_num_layers`` encoder layers.
    """

    def __init__(self, cfg):
        super().__init__()
        self.amount_enc = AmountEncoder(cfg)
        self.cat_emb    = nn.Embedding(cfg.num_categories, cfg.category_embedding_dim)
        # Merchants share the category embedding width.
        self.merch_emb  = nn.Embedding(cfg.transaction_vocab_size, cfg.category_embedding_dim)
        self.temp_enc   = TemporalEncoder()
        in_dim = cfg.amount_embedding_dim + cfg.category_embedding_dim * 2 + self.temp_enc.out_dim
        self.input_proj = nn.Linear(in_dim, cfg.transaction_embedding_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=cfg.transaction_embedding_dim, nhead=cfg.transaction_num_heads,
            dim_feedforward=cfg.transaction_ff_dim, dropout=cfg.dropout, batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=cfg.transaction_num_layers)

    def forward(self, amounts, categories, merchants, days, months, mask=None):
        """Return one contextual embedding per transaction slot.

        Args:
            amounts, categories, merchants, days, months: ``(B, S)`` tensors.
            mask: optional ``(B, S)`` tensor, nonzero for real transactions and
                zero for padding. ``None`` disables padding masking.

        Returns:
            The Transformer encoder output for the projected sequence.
        """
        B, S = amounts.shape
        pos = torch.arange(S, device=amounts.device).unsqueeze(0).expand(B, -1)
        combined = torch.cat([
            self.amount_enc(amounts), self.cat_emb(categories),
            self.merch_emb(merchants), self.temp_enc(days, months, pos),
        ], dim=-1)
        projected = self.input_proj(combined)

        pad_mask = None
        if mask is not None:
            pad_mask = ~mask.bool()
            if S > 0:
                # A row with no real transactions is masked at every key, and a
                # softmax over an all-masked row can produce NaN that spreads to
                # the whole loss. Keep slot 0 visible for such rows so the output
                # stays finite.
                fully_padded = pad_mask.all(dim=1)
                pad_mask[:, 0] = pad_mask[:, 0] & ~fully_padded
        return self.transformer(projected, src_key_padding_mask=pad_mask)


# ─── Adapter (Houlsby et al., 2019) ───

class Adapter(nn.Module):
    """Residual bottleneck adapter: ``x + up(GELU(down(x)))``."""

    def __init__(self, dim, bottleneck):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up   = nn.Linear(bottleneck, dim)
        self.act  = nn.GELU()

    def forward(self, x):
        squeezed = self.act(self.down(x))
        delta = self.up(squeezed)
        return x + delta


# ─── Text Encoder (FinBERT, last 3 layers unfrozen + adapters) ───

class TextEncoder(nn.Module):
    """FinBERT with its last ``cfg.unfreeze_layers`` layers trainable, plus adapters.

    All BERT weights are frozen first, then the trailing encoder layers are
    re-enabled. One adapter is created per encoder layer and applied to that
    layer's [CLS] hidden state.
    """

    def __init__(self, cfg):
        super().__init__()
        self.bert = AutoModel.from_pretrained(cfg.text_model_name)
        self.bert.requires_grad_(False)

        encoder_layers = self.bert.encoder.layer
        n_layers = len(encoder_layers)
        first_trainable = n_layers - cfg.unfreeze_layers
        for layer_idx in range(first_trainable, n_layers):
            encoder_layers[layer_idx].requires_grad_(True)

        self.adapters = nn.ModuleList(
            Adapter(cfg.text_embedding_dim, cfg.adapter_dim) for _ in range(n_layers)
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(cls_vector, last_hidden_state)``.

        The [CLS] vector starts from the final hidden state and accumulates
        0.1 × each adapter's output on the matching layer's [CLS] state.
        """
        out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = out.hidden_states
        cls = hidden[-1][:, 0, :]
        # hidden[0] is skipped: adapter i pairs with hidden[i + 1].
        for adapter, layer_hidden in zip(self.adapters, hidden[1:]):
            cls = cls + 0.1 * adapter(layer_hidden[:, 0, :])
        return cls, out.last_hidden_state


# ─── GAT Layer ───

class GATLayer(nn.Module):
    """Multi-head graph attention layer that averages its heads.

    Attention coefficients are normalised per target node, i.e. over the edges
    arriving at that node, and messages are summed into the target.

    Args:
        in_dim: input feature width per node.
        out_dim: output feature width per node (per head before averaging).
        heads: number of attention heads.
        dropout: dropout probability applied to the attention coefficients.
    """

    def __init__(self, in_dim, out_dim, heads, dropout):
        super().__init__()
        self.heads, self.out_dim = heads, out_dim
        self.W = nn.Linear(in_dim, out_dim * heads, bias=False)
        self.a = nn.Parameter(torch.zeros(heads, 2 * out_dim))
        nn.init.xavier_uniform_(self.a)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index):
        """Return updated node features of shape ``(N, out_dim)``.

        Args:
            x: ``(N, in_dim)`` node features.
            edge_index: ``(2, E)`` long tensor of ``(source, target)`` node ids.
                With ``E == 0`` the projected features are returned unchanged
                (averaged over heads).
        """
        h = self.W(x).view(-1, self.heads, self.out_dim)           # (N, H, D)
        if edge_index.shape[1] == 0:
            return h.mean(dim=1)
        src, tgt = edge_index

        pair = torch.cat([h[src], h[tgt]], dim=-1)                 # (E, H, 2D)
        # Work in fp32 from here on: under autocast `h` may be fp16 while `self.a`
        # is fp32, and index_add_ requires buffer and source dtypes to match.
        scores = self.leaky_relu((pair * self.a).sum(dim=-1)).float()  # (E, H)

        # Softmax over the incoming edges of each target node. The previous
        # softmax over dim 0 normalised across every edge in the graph, so a
        # node's weights depended on unrelated edges and did not sum to 1.
        # Subtracting the global max is a constant shift inside every group, so
        # it keeps exp() stable without changing the result.
        scores = scores - scores.max(dim=0, keepdim=True).values.detach()
        weights = scores.exp()
        denom = torch.zeros(h.size(0), self.heads, dtype=weights.dtype, device=weights.device)
        denom.index_add_(0, tgt, weights)
        alpha = weights / denom[tgt].clamp_min(1e-16)
        alpha = self.dropout(alpha)

        messages = alpha.unsqueeze(-1) * h[src].float()            # (E, H, D)
        out = torch.zeros(h.shape, dtype=messages.dtype, device=h.device)
        out.index_add_(0, tgt, messages)
        # NOTE(review): nodes with no incoming edge receive an all-zero row here,
        # as there are no self-loops.
        return out.mean(dim=1).to(h.dtype)


# ─── Household Encoder (3-layer GAT) ───

class HouseholdEncoder(nn.Module):
    """Encode one household graph into a single vector.

    Node features are projected, passed through ``cfg.household_num_layers`` GAT
    layers (each followed by ELU), and pooled with learned attention weights.
    """

    def __init__(self, cfg):
        super().__init__()
        hidden = cfg.household_hidden_dim
        self.input_proj = nn.Linear(cfg.node_feature_dim, hidden)
        self.layers = nn.ModuleList(
            GATLayer(hidden, hidden, cfg.household_num_heads, cfg.dropout)
            for _ in range(cfg.household_num_layers)
        )
        self.pool_attn = nn.Linear(hidden, 1)

    def forward(self, node_features, edge_index, n_members):
        """Return the pooled household embedding for a single sample.

        Only the first `n_members` rows of `node_features` are used; the rest are
        padding.
        """
        nodes = self.input_proj(node_features[:n_members])
        for gat in self.layers:
            nodes = F.elu(gat(nodes, edge_index))
        # Softmax across nodes gives each member a pooling weight.
        pool_weights = F.softmax(self.pool_attn(nodes), dim=0)
        return torch.sum(pool_weights * nodes, dim=0)


# ─── Cross-Attention Fusion ───

class CrossAttentionFusion(nn.Module):
    """Fuse transaction, text and household representations into one vector.

    The pooled transaction vector attends over the text tokens, the text [CLS]
    vector attends to the household embedding, and a softmax gate weights the
    three resulting vectors before a final projection and LayerNorm.
    """

    def __init__(self, cfg):
        super().__init__()
        d = cfg.fusion_dim
        self.trans_proj = nn.Linear(cfg.transaction_embedding_dim, d)
        self.text_proj  = nn.Linear(cfg.text_embedding_dim, d)
        self.house_proj = nn.Linear(cfg.household_hidden_dim, d)
        self.trans_text_attn = nn.MultiheadAttention(d, 8, dropout=cfg.dropout, batch_first=True)
        self.text_house_attn = nn.MultiheadAttention(d, 4, dropout=cfg.dropout, batch_first=True)
        self.gate = nn.Sequential(
            nn.Linear(d * 3, d), nn.ReLU(), nn.Dropout(cfg.dropout),
            nn.Linear(d, 3), nn.Softmax(dim=-1),
        )
        self.out_proj = nn.Linear(d * 3, d)
        self.norm = nn.LayerNorm(d)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, trans_seq, text_cls, text_tokens, house_emb,
                trans_mask=None, text_mask=None):
        """Return the fused ``(B, fusion_dim)`` representation.

        Args:
            trans_seq: ``(B, S, transaction_embedding_dim)`` transaction encodings.
            text_cls: ``(B, text_embedding_dim)`` text summary vector.
            text_tokens: ``(B, T, text_embedding_dim)`` per-token text encodings.
            house_emb: ``(B, household_hidden_dim)`` household embeddings.
            trans_mask: optional ``(B, S)`` mask, nonzero for real transactions.
                When given, padded slots are left out of the pooled transaction
                query. ``None`` keeps the plain mean over all slots.
            text_mask: optional ``(B, T)`` mask, nonzero for real tokens. When
                given, padded tokens are hidden from the transaction→text
                attention. ``None`` attends over every token.
        """
        t_proj = self.trans_proj(trans_seq)
        x_proj = self.text_proj(text_tokens)
        x_cls  = self.text_proj(text_cls)
        h      = self.house_proj(house_emb)

        if trans_mask is None:
            t_query = t_proj.mean(dim=1, keepdim=True)
        else:
            # Masked mean: padded slots carry no information and would otherwise
            # dilute the query in proportion to how short the history is.
            keep = trans_mask.bool().unsqueeze(-1)
            summed = t_proj.masked_fill(~keep, 0.0).sum(dim=1, keepdim=True)
            n_real = keep.sum(dim=1, keepdim=True).clamp(min=1).to(t_proj.dtype)
            t_query = summed / n_real

        # key_padding_mask is True where a key must be ignored.
        text_pad = None if text_mask is None else (text_mask == 0)
        t_cross, _ = self.trans_text_attn(t_query, x_proj, x_proj, key_padding_mask=text_pad)
        t_enh = t_query.squeeze(1) + t_cross.squeeze(1)

        h_kv = h.unsqueeze(1)
        x_cross, _ = self.text_house_attn(x_cls.unsqueeze(1), h_kv, h_kv)
        x_enh = x_cls + x_cross.squeeze(1)

        cat = torch.cat([t_enh, x_enh, h], dim=-1)
        g = self.gate(cat)
        gated = torch.cat([g[:,0:1]*t_enh, g[:,1:2]*x_enh, g[:,2:3]*h], dim=-1)
        return self.norm(self.drop(self.out_proj(gated)))


# ─── Residual Prediction Head ───

class ResHead(nn.Module):
    """Three-layer MLP head with a skip connection around the second layer.

    Returns raw outputs (logits or regression values); no activation is applied
    to the final layer.
    """

    def __init__(self, in_d, out_d, hidden=256):
        super().__init__()
        self.fc1 = nn.Linear(in_d, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, out_d)
        # Project the input for the skip path only when the widths differ.
        self.skip = nn.Linear(in_d, hidden) if in_d != hidden else nn.Identity()
        self.drop = nn.Dropout(0.1)

    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = self.drop(h)
        h = F.relu(self.fc2(h) + self.skip(x))
        h = self.drop(h)
        return self.fc3(h)


# ─── Envis v2 ───

class EnvisV2(nn.Module):
    """Multi-task model: three modality encoders, fusion, and six prediction heads.

    ``log_vars`` holds one learnable value per task, returned with the
    predictions for use in the loss.
    """

    def __init__(self, cfg=None):
        super().__init__()
        self.cfg = cfg or ModelConfig()
        c = self.cfg
        self.trans_enc  = TransactionEncoder(c)
        self.text_enc   = TextEncoder(c)
        self.house_enc  = HouseholdEncoder(c)
        self.fusion     = CrossAttentionFusion(c)

        self.distress_head  = ResHead(c.fusion_dim, c.num_distress_classes)
        self.framing_head   = ResHead(c.fusion_dim, c.num_framing_classes)
        self.urgency_head   = ResHead(c.fusion_dim, c.num_urgency_classes)
        self.tension_head   = ResHead(c.fusion_dim, 1)
        self.goal_risk_head = ResHead(c.fusion_dim, 1)
        self.delay_head     = ResHead(c.fusion_dim, 1)
        self.log_vars = nn.Parameter(torch.zeros(6))

    def forward(self, batch):
        """Run the full model on one collated batch dict.

        Returns:
            Dict of raw outputs: class logits for ``distress``, ``framing`` and
            ``timing_urgency``; per-sample scalars for ``tension``,
            ``goal_risk`` and ``timing_delay``; plus ``log_vars``.
        """
        target_device = batch['input_ids'].device
        trans_seq = self.trans_enc(
            batch['amounts'], batch['categories'], batch['merchants'],
            batch['days'], batch['months'], batch['trans_mask'],
        )
        text_cls, text_tokens = self.text_enc(batch['input_ids'], batch['attention_mask'])

        # Household graphs differ in size, so they are encoded one sample at a
        # time. Fetch all member counts with a single host transfer instead of
        # one .item() call per sample.
        member_counts = batch['n_members'].tolist()
        house_emb = torch.stack([
            self.house_enc(
                batch['node_features'][i],
                batch['edge_index'][i].to(target_device),
                count,
            )
            for i, count in enumerate(member_counts)
        ])

        # Pass both padding masks so fusion ignores padded transaction slots and
        # padded text tokens.
        fused = self.fusion(
            trans_seq, text_cls, text_tokens, house_emb,
            trans_mask=batch['trans_mask'], text_mask=batch['attention_mask'],
        )

        return {
            'distress':       self.distress_head(fused),
            'framing':        self.framing_head(fused),
            'timing_urgency': self.urgency_head(fused),
            'tension':        self.tension_head(fused).squeeze(-1),
            'goal_risk':      self.goal_risk_head(fused).squeeze(-1),
            'timing_delay':   self.delay_head(fused).squeeze(-1),
            'log_vars':       self.log_vars,
        }


# Build the model on the selected device and report its parameter budget.
cfg = ModelConfig()
model = EnvisV2(cfg).to(device)


def _n_params(params):
    """Total number of scalar elements across an iterable of parameters."""
    return sum(p.numel() for p in params)


total         = _n_params(model.parameters())
trainable     = _n_params(p for p in model.parameters() if p.requires_grad)
bert_unfrozen = _n_params(p for p in model.text_enc.bert.parameters() if p.requires_grad)

print(f"Total parameters:     {total:,}")
print(f"Trainable:            {trainable:,} ({100*trainable/total:.1f}%)")
print(f"  BERT unfrozen:      {bert_unfrozen:,} (last {cfg.unfreeze_layers} layers)")
print(f"  New components:     {trainable - bert_unfrozen:,}")
print(f"Frozen:               {total - trainable:,}")

## 6. Loss & Evaluation

- Weighted `CrossEntropyLoss` for distress 4-class, framing, urgency
- `BCEWithLogitsLoss` with `pos_weight` for tension, goal_risk
- `HuberLoss` for log-scale timing delay

In [None]:
# Per-task criteria, all operating on raw logits / raw regression outputs.
# The weights come from the training-split class frequencies computed in section 3.
ce_distress = nn.CrossEntropyLoss(weight=distress_w)
ce_framing  = nn.CrossEntropyLoss(weight=framing_w)
ce_urgency  = nn.CrossEntropyLoss(weight=urgency_w)
bce_tension = nn.BCEWithLogitsLoss(pos_weight=tension_pw)
bce_goal    = nn.BCEWithLogitsLoss(pos_weight=goal_risk_pw)
# The delay target is log1p(hours), so delta=1.0 is in log-hour units.
huber_delay = nn.HuberLoss(delta=1.0)


def compute_loss(preds, batch):
    """Combine the six task losses using the model's learned ``log_vars``.

    Task ``i`` contributes ``exp(-log_vars[i]) * loss_i + log_vars[i]``.

    Returns:
        ``(total, per_task)``: the scalar tensor to backpropagate, and a dict of
        unweighted per-task loss values as Python floats.
    """
    log_vars = preds['log_vars']

    # (task name, criterion, prediction key, target key). The order fixes which
    # log_vars entry belongs to which task.
    task_specs = (
        ('distress',  ce_distress, 'distress',       'distress_class'),
        ('framing',   ce_framing,  'framing',        'framing'),
        ('urgency',   ce_urgency,  'timing_urgency', 'timing_urgency'),
        ('tension',   bce_tension, 'tension',        'tension'),
        ('goal_risk', bce_goal,    'goal_risk',      'goal_risk'),
        ('delay',     huber_delay, 'timing_delay',   'timing_delay'),
    )
    losses = {
        name: criterion(preds[pred_key], batch[target_key])
        for name, criterion, pred_key, target_key in task_specs
    }

    total = 0
    for i, task_loss in enumerate(losses.values()):
        total = total + (torch.exp(-log_vars[i]) * task_loss + log_vars[i])

    per_task = {name: task_loss.item() for name, task_loss in losses.items()}
    return total, per_task


@torch.no_grad()
def evaluate(model, loader):
    """Run `model` over `loader` and compute loss and per-task metrics.

    Args:
        model: the network to evaluate; switched to eval mode here.
        loader: DataLoader yielding collated batch dicts.

    Returns:
        ``(metrics, P, L)``: a dict of scalar metrics, plus dicts of predictions
        and labels keyed by task. Timing delay is reported in hours, after
        inverting the ``log1p`` transform.
    """
    model.eval()
    keys = ['distress_class', 'distress_prob', 'framing', 'tension',
            'goal_risk', 'timing_delay', 'timing_urgency']
    P = {k: [] for k in keys}
    L = {k: [] for k in keys}
    total_loss = 0

    for batch in loader:
        b = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        with autocast():
            preds = model(b)
            loss, _ = compute_loss(preds, b)
        total_loss += loss.item()

        # Outputs produced under autocast may be fp16. Cast to fp32 before any
        # post-processing: expm1 overflows to inf in half precision for log-delays
        # above ~11, and inf makes the sklearn metrics raise.
        distress_logits = preds['distress'].float()
        delay_log = preds['timing_delay'].float()

        d_probs = F.softmax(distress_logits, dim=1)
        P['distress_class'].extend(d_probs.argmax(1).cpu().numpy())
        L['distress_class'].extend(batch['distress_class'].numpy())
        # Binary distress probability = 1 - P(class "none").
        P['distress_prob'].extend((1 - d_probs[:, 0]).cpu().numpy())
        L['distress_prob'].extend(batch['distress_binary'].numpy())

        P['framing'].extend(preds['framing'].argmax(1).cpu().numpy())
        L['framing'].extend(batch['framing'].numpy())
        P['timing_urgency'].extend(preds['timing_urgency'].argmax(1).cpu().numpy())
        L['timing_urgency'].extend(batch['timing_urgency'].numpy())
        P['tension'].extend(torch.sigmoid(preds['tension'].float()).cpu().numpy())
        L['tension'].extend(batch['tension'].numpy())
        P['goal_risk'].extend(torch.sigmoid(preds['goal_risk'].float()).cpu().numpy())
        L['goal_risk'].extend(batch['goal_risk'].numpy())
        P['timing_delay'].extend(torch.expm1(delay_log).cpu().numpy())
        L['timing_delay'].extend(torch.expm1(batch['timing_delay']).cpu().numpy())

    m = {
        'loss':              total_loss / len(loader),
        'distress_auc':      roc_auc_score(L['distress_prob'], P['distress_prob']),
        'distress_4cls_acc': accuracy_score(L['distress_class'], P['distress_class']),
        'distress_4cls_f1':  f1_score(L['distress_class'], P['distress_class'], average='macro'),
        'framing_acc':       accuracy_score(L['framing'], P['framing']),
        'framing_f1':        f1_score(L['framing'], P['framing'], average='macro'),
        'tension_auc':       roc_auc_score(L['tension'], P['tension']),
        'goal_risk_auc':     roc_auc_score(L['goal_risk'], P['goal_risk']),
        'delay_mae_hours':   mean_absolute_error(L['timing_delay'], P['timing_delay']),
        'urgency_acc':       accuracy_score(L['timing_urgency'], P['timing_urgency']),
    }
    return m, P, L

print("Loss and eval defined ✓")

## 7. Training

- **Differential LR**: BERT 1e-5, new components 1e-4
- **Mixed precision** fp16
- **Gradient accumulation** 2 steps → effective batch 32
- **Early stopping** patience 5

In [None]:
EPOCHS       = 25
LR_BERT      = 1e-5
LR_NEW       = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_STEPS = 500
PATIENCE     = 5
ACCUM_STEPS  = 2

# Two parameter groups: the unfrozen BERT layers get the small learning rate,
# every other trainable parameter gets the larger one.
bert_ids = set(id(p) for p in model.text_enc.bert.parameters() if p.requires_grad)
bert_params = [p for p in model.text_enc.bert.parameters() if p.requires_grad]
new_params  = [p for p in model.parameters() if p.requires_grad and id(p) not in bert_ids]

optimizer = torch.optim.AdamW([
    {'params': bert_params, 'lr': LR_BERT},
    {'params': new_params,  'lr': LR_NEW},
], weight_decay=WEIGHT_DECAY)

# One optimizer step per ACCUM_STEPS batches, plus one for a leftover partial
# group at the end of an epoch. Round up so the cosine schedule is never stepped
# past its configured length.
steps_per_epoch = math.ceil(len(train_loader) / ACCUM_STEPS)
total_steps = steps_per_epoch * EPOCHS
scheduler = get_cosine_schedule_with_warmup(optimizer, WARMUP_STEPS, total_steps)
scaler = GradScaler()

print(f"BERT trainable (lr={LR_BERT}): {sum(p.numel() for p in bert_params):,}")
print(f"New params (lr={LR_NEW}):      {sum(p.numel() for p in new_params):,}")
print(f"Effective batch size:           {BATCH_SIZE * ACCUM_STEPS}")
print("=" * 70)

best_val_loss = float('inf')
patience_ctr  = 0
training_log  = []

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    t0 = time.time()
    optimizer.zero_grad()

    n_batches = len(train_loader)
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for step, batch in enumerate(pbar):
        b = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}

        with autocast():
            preds = model(b)
            loss, task_losses = compute_loss(preds, b)
            # Scale so accumulated gradients average over the group, not sum.
            loss = loss / ACCUM_STEPS

        scaler.scale(loss).backward()

        # Step on every full accumulation group, and also on the final batch, so
        # gradients from a trailing partial group are applied rather than being
        # wiped by zero_grad() at the start of the next epoch.
        if (step + 1) % ACCUM_STEPS == 0 or (step + 1) == n_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            # GradScaler skips optimizer.step() on inf/NaN gradients and then
            # lowers its scale; only advance the LR schedule on a real step.
            if scaler.get_scale() >= scale_before:
                scheduler.step()

        batch_loss = loss.item() * ACCUM_STEPS
        epoch_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    epoch_time = time.time() - t0
    val_m, _, _ = evaluate(model, val_loader)

    # float() turns numpy scalars into plain Python floats so the log stays
    # JSON-serialisable.
    entry = {
        'epoch':            epoch + 1,
        'train_loss':       round(float(epoch_loss / len(train_loader)), 4),
        'val_loss':         round(float(val_m['loss']), 4),
        'distress_auc':     round(float(val_m['distress_auc']), 4),
        'distress_4cls_f1': round(float(val_m['distress_4cls_f1']), 4),
        'framing_acc':      round(float(val_m['framing_acc']), 4),
        'framing_f1':       round(float(val_m['framing_f1']), 4),
        'tension_auc':      round(float(val_m['tension_auc']), 4),
        'goal_risk_auc':    round(float(val_m['goal_risk_auc']), 4),
        'delay_mae':        round(float(val_m['delay_mae_hours']), 1),
        'urgency_acc':      round(float(val_m['urgency_acc']), 4),
        'time_sec':         round(epoch_time, 1),
    }
    training_log.append(entry)

    print(f"\nEpoch {epoch+1}: train={entry['train_loss']:.4f} val={entry['val_loss']:.4f} "
          f"d_auc={entry['distress_auc']:.4f} d4_f1={entry['distress_4cls_f1']:.4f} "
          f"f_acc={entry['framing_acc']:.4f} f_f1={entry['framing_f1']:.4f}")
    print(f"  t_auc={entry['tension_auc']:.4f} gr_auc={entry['goal_risk_auc']:.4f} "
          f"delay={entry['delay_mae']:.1f}h urg_acc={entry['urgency_acc']:.4f} "
          f"time={epoch_time:.0f}s")

    # Checkpoint on validation-loss improvement; otherwise count towards early stopping.
    if val_m['loss'] < best_val_loss:
        best_val_loss = val_m['loss']
        patience_ctr = 0
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_m['loss'],
            'val_metrics': val_m,
            'config': cfg.__dict__,
        }, 'best_model_v2.pt')
        print("  ✓ Best model saved")
    else:
        patience_ctr += 1
        print(f"  No improvement ({patience_ctr}/{PATIENCE})")

    if patience_ctr >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch+1}")
        break

print("\nTraining complete!")

## 8. Test Evaluation

In [None]:
# Restore the best checkpoint. It stores the metrics dict alongside the weights,
# hence weights_only=False; map_location keeps loading working if the device
# differs from the one used for saving.
ckpt = torch.load('best_model_v2.pt', map_location=device, weights_only=False)
model.load_state_dict(ckpt['model_state_dict'])
print(f"Loaded best model from epoch {ckpt['epoch']} (val_loss={ckpt['val_loss']:.4f})")

test_m, test_P, test_L = evaluate(model, test_loader)

print("\n" + "=" * 60)
print("TEST RESULTS — v2")
print("=" * 60)
print(f"  Distress AUC (binary):     {test_m['distress_auc']:.4f}")
print(f"  Distress 4-class Acc:      {test_m['distress_4cls_acc']:.4f}")
print(f"  Distress 4-class F1:       {test_m['distress_4cls_f1']:.4f}")
print(f"  Framing Accuracy:          {test_m['framing_acc']:.4f}")
print(f"  Framing F1 (macro):        {test_m['framing_f1']:.4f}")
print(f"  Tension AUC:               {test_m['tension_auc']:.4f}")
print(f"  Goal-Risk AUC:             {test_m['goal_risk_auc']:.4f}")
print(f"  Timing Delay MAE:          {test_m['delay_mae_hours']:.1f} hours")
print(f"  Timing Urgency Acc:        {test_m['urgency_acc']:.4f}")
print("=" * 60)

# Explicit label lists keep the reports and matrices at a fixed size and order
# even if a class never appears in the test split.
distress_labels = list(range(len(DISTRESS_CLASSES)))
framing_labels  = list(range(len(FRAMING_CLASSES)))

print("\n--- Distress 4-class report ---")
print(classification_report(
    test_L['distress_class'], test_P['distress_class'],
    labels=distress_labels, target_names=DISTRESS_CLASSES, zero_division=0
))

print("--- Framing report ---")
print(classification_report(
    test_L['framing'], test_P['framing'],
    labels=framing_labels, target_names=FRAMING_CLASSES, zero_division=0
))

# Binary distress threshold
# Choose the operating point on the validation split and apply it unchanged to
# the test split. Tuning it on the test labels would leak them into the
# reported confusion matrix.
_, val_P, val_L = evaluate(model, val_loader)
prec, rec, thr = precision_recall_curve(val_L['distress_prob'], val_P['distress_prob'])
f1s = 2 * prec * rec / (prec + rec + 1e-8)
# prec/rec have one more entry than thr (the final precision=1, recall=0 point),
# so restrict the argmax to indices that have a threshold.
best_idx = int(np.argmax(f1s[:-1]))
best_thr = float(thr[best_idx])
print(f"Binary distress threshold (selected on validation set): {best_thr:.3f}, "
      f"validation F1: {f1s[best_idx]:.3f}")

d_bin_pred = (np.array(test_P['distress_prob']) >= best_thr).astype(int)
d_bin_true = np.array(test_L['distress_prob']).astype(int)
print("\nDistress binary confusion matrix:")
print(confusion_matrix(d_bin_true, d_bin_pred, labels=[0, 1]))
print("\nFraming confusion matrix:")
print(confusion_matrix(test_L['framing'], test_P['framing'], labels=framing_labels))

## 9. Save & Download

In [None]:
log_df = pd.DataFrame(training_log)
log_df.to_csv('training_log_v2.csv', index=False)


def _json_default(obj):
    """Convert numpy scalars and arrays into JSON-serialisable Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Use one timestamp so run_id and completion time agree.
_finished = datetime.now()

results = {
    'run_id':             f"envis_v2_{_finished.strftime('%Y%m%d_%H%M%S')}",
    'model_version':      'v2',
    'completed':          _finished.isoformat(),
    'epochs_trained':     len(training_log),
    'best_epoch':         ckpt['epoch'],
    'v2_improvements': [
        'Unfrozen last 3 FinBERT layers + differential LR',
        'Unified distress/sub-indicator 4-class head',
        'Log1p timing delay + Huber loss',
        'Class-weighted CE for framing/urgency/distress',
        'BCEWithLogitsLoss + pos_weight for tension/goal_risk',
        'Cross-attention fusion',
        'Residual prediction heads',
        'Mixed precision fp16 + gradient accumulation',
        'Reduced MAX_TEXT=128, MAX_TRANS=20',
    ],
    'config':             cfg.__dict__,
    # Metrics may be numpy scalars (np.float32 is not a Python float subclass),
    # so convert explicitly before rounding.
    'test_metrics':       {
        k: round(float(v), 4) if isinstance(v, (float, np.floating)) else v
        for k, v in test_m.items()
    },
    'distress_threshold': round(best_thr, 3),
    'confusion_matrices': {
        'distress_4class': confusion_matrix(test_L['distress_class'], test_P['distress_class']).tolist(),
        'distress_binary': confusion_matrix(d_bin_true, d_bin_pred).tolist(),
        'framing': confusion_matrix(test_L['framing'], test_P['framing']).tolist(),
    },
    'training_log':       training_log,
}
with open('evaluation_results_v2.json', 'w') as f:
    # `default` catches any numpy value still present, e.g. in training_log.
    json.dump(results, f, indent=2, default=_json_default)

print("Saved:")
print(f"  best_model_v2.pt            ({os.path.getsize('best_model_v2.pt')/1e6:.1f} MB)")
print(f"  training_log_v2.csv         ({len(training_log)} epochs)")
print(f"  evaluation_results_v2.json")

In [None]:
# Colab-only: download the three artifacts written by the previous cells
# (best checkpoint, per-epoch training log, evaluation summary).
from google.colab import files
files.download('best_model_v2.pt')
files.download('training_log_v2.csv')
files.download('evaluation_results_v2.json')