
# Tabular → Insight Text: E-A-D (Encoder → Aggregator → Decoder) Pipeline

This notebook builds a full pipeline to convert tabular data (rows × features) into natural-language **business insights**.

**Architecture**  
- **E — Encoder (TabTransformer):** Produces a per-row embedding from categorical + continuous features.  
- **A — Aggregator (Attention Pooling):** Collapses *N* row embeddings into a single **table embedding**.  
- **D — Decoder (GPT-2 with Soft Prompting):** Maps the table embedding to a small set of **soft prompt tokens** that condition a decoder-only LLM (GPT-2) to generate insight text.

> The notebook uses a **synthetic dataset** where each example is a small table (80 to 150 rows) paired with a templated ground-truth insight. This keeps training fully supervised and easy to iterate on.



## 1. Setup

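You'll need recent versions of PyTorch and Transformers, plus `tab_transformer_pytorch`, `sentence-transformers`, and `scikit-learn` (the last two are used only for the evaluation in Section 10). If you're in Colab or a fresh environment, run the install cell below. If the packages are already installed, you can skip it.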


In [1]:
!pip install tab_transformer_pytorch

Collecting tab_transformer_pytorch
  Downloading tab_transformer_pytorch-0.4.2-py3-none-any.whl.metadata (914 bytes)
Collecting hyper-connections>=0.1.15 (from tab_transformer_pytorch)
  Downloading hyper_connections-0.2.1-py3-none-any.whl.metadata (6.0 kB)
Downloading tab_transformer_pytorch-0.4.2-py3-none-any.whl (7.2 kB)
Downloading hyper_connections-0.2.1-py3-none-any.whl (16 kB)
Installing collected packages: hyper-connections, tab_transformer_pytorch
Successfully installed hyper-connections-0.2.1 tab_transformer_pytorch-0.4.2


In [15]:
import os, math, random, json, sys, time
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tab_transformer_pytorch import TabTransformer

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
seed = 42
random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

Device: cuda



## 2. Synthetic Data: Tables + Ground-Truth Insights

We synthesize many **mini-tables** (each is a separate training example). For each table, we bake in a couple of clear patterns, then render a short **ground-truth insight** string describing those patterns.  
This lets us train end-to-end without needing a curated labeled corpus.
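
To get a feel for how strong the baked-in signal is, the sketch below works through the churn probabilities used by the generator cell that follows (base rate 0.15, country boost 0.10 to 0.25, segment boost 0.05 to 0.15). The helper `churn_probability` and the 10-row scenario are illustrative assumptions, not part of the notebook's pipeline.

```python
# Worked arithmetic for the churn probabilities used by generate_one_table.
# `churn_probability` is a hypothetical helper written only for this illustration.

P_BASE = 0.15                      # TableConfig.p_base_churn
COUNTRY_BOOST = (0.10, 0.25)       # range of p_churn_hot_country_boost
SEGMENT_BOOST = (0.05, 0.15)       # range of p_churn_hot_segment_boost


def churn_probability(in_hot_country: bool, in_hot_segment: bool,
                      country_boost: float, segment_boost: float) -> float:
    """Per-row churn probability: base rate plus any applicable boosts."""
    p = P_BASE
    if in_hot_country:
        p += country_boost
    if in_hot_segment:
        p += segment_boost
    return p


# Smallest and largest probability for a row in both the hot country and hot segment.
p_low = churn_probability(True, True, COUNTRY_BOOST[0], SEGMENT_BOOST[0])
p_high = churn_probability(True, True, COUNTRY_BOOST[1], SEGMENT_BOOST[1])
print(f"hot country + hot segment: {p_low:.2f} to {p_high:.2f}")

# With 12 equally likely countries and 80-150 rows per table, the hot country
# has only about 115 / 12 rows on average, so its observed churn rate is noisy.
expected_hot_rows = ((80 + 150) / 2) / 12
print(f"expected rows in hot country: {expected_hot_rows:.1f}")

# Chance that a hot country with 10 rows shows zero churners at the weakest
# country boost (ignoring the segment boost): (1 - p) ** n.
p_zero = (1 - churn_probability(True, False, COUNTRY_BOOST[0], 0.0)) ** 10
print(f"P(no churners among 10 hot-country rows): {p_zero:.3f}")
```

Under these assumptions, a 10-row hot country shows no churn at all in roughly 5.6% of tables. Because the insight template is fixed, such a table would still be described as having "elevated churn", so a small share of ground-truth insights can contradict their own numbers.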


In [3]:
# --- Synthetic schema ---
CATEG_COLS = [
    ("country", 12),  # e.g., 12 countries
    ("segment", 4),   # SMB / Mid / Enterprise / Consumer
    ("product", 6),   # product lines
]
NUM_CONT = 6  # numerical features, e.g., balance, age, tenure, etc.

# We'll generate a binary column we don't feed to the model directly (churn/fraud)
# but we use it to derive target insight text for supervision.
# Each table is its own training example.
@dataclass
class TableConfig:
    n_rows_min: int = 80
    n_rows_max: int = 150
    p_base_churn: float = 0.15

def generate_one_table(cfg: TableConfig):
    n_rows = random.randint(cfg.n_rows_min, cfg.n_rows_max)
    # Choose one "hot" country that has higher churn and higher balances on average
    hot_country = random.randint(0, CATEG_COLS[0][1] - 1)
    hot_segment = random.randint(0, CATEG_COLS[1][1] - 1)

    # Continuous feature generative params
    # balance ~ higher for hot_country; age, tenure, etc.
    base_balance_mu = 1000.0
    base_balance_sigma = 400.0
    balance_boost = random.uniform(300, 800)  # hot country boost

    # Base churn odds + boosts
    p_churn_base = cfg.p_base_churn
    p_churn_hot_country_boost = random.uniform(0.10, 0.25)
    p_churn_hot_segment_boost = random.uniform(0.05, 0.15)

    x_categ = torch.zeros(n_rows, len(CATEG_COLS), dtype=torch.long)
    x_cont  = torch.zeros(n_rows, NUM_CONT, dtype=torch.float32)
    churn   = torch.zeros(n_rows, dtype=torch.long)

    for i in range(n_rows):
        c_country = random.randint(0, CATEG_COLS[0][1] - 1)
        c_segment = random.randint(0, CATEG_COLS[1][1] - 1)
        c_product = random.randint(0, CATEG_COLS[2][1] - 1)
        x_categ[i] = torch.tensor([c_country, c_segment, c_product])

        # Continuous features (toy):
        # 0: balance, 1: age, 2: tenure, 3: tx_freq, 4: complaints, 5: income
        balance_mu = base_balance_mu + (balance_boost if c_country == hot_country else 0.0)
        balance = random.gauss(balance_mu, base_balance_sigma)
        age = max(18, random.gauss(42, 12))
        tenure = max(0.0, random.gauss(3.5, 2.0))
        tx_freq = max(0.0, random.gauss(8, 3))
        complaints = max(0.0, random.gauss(0.6, 0.7))
        income = max(0.0, random.gauss(48_000, 15_000))

        x_cont[i] = torch.tensor([balance, age, tenure, tx_freq, complaints, income], dtype=torch.float32)

        # Churn probability with boosts for hot country + hot segment
        p = p_churn_base
        if c_country == hot_country:
            p += p_churn_hot_country_boost
        if c_segment == hot_segment:
            p += p_churn_hot_segment_boost
        churn[i] = 1 if random.random() < p else 0

    # Derive summary stats for ground-truth insight
    # Country-level churn: compare hot_country vs rest
    mask_hot = (x_categ[:,0] == hot_country)
    churn_hot = churn[mask_hot].float().mean().item() if mask_hot.any() else 0.0
    churn_rest = churn[~mask_hot].float().mean().item() if (~mask_hot).any() else 0.0

    avg_balance_hot = x_cont[mask_hot,0].mean().item() if mask_hot.any() else 0.0
    avg_balance_rest = x_cont[~mask_hot,0].mean().item() if (~mask_hot).any() else 0.0

    # Segment-level churn for the chosen hot segment
    mask_seg = (x_categ[:,1] == hot_segment)
    churn_seg = churn[mask_seg].float().mean().item() if mask_seg.any() else 0.0
    churn_not_seg = churn[~mask_seg].float().mean().item() if (~mask_seg).any() else 0.0

    # Build human-readable labels from IDs
    def name_country(i): return f"Country_{i}"
    def name_seg(i): return ["SMB","Mid","Enterprise","Consumer"][i] if CATEG_COLS[1][1]==4 and i<4 else f"Segment_{i}"

    hot_country_name = name_country(hot_country)
    hot_segment_name = name_seg(hot_segment)

    # Ground-truth insight text (templated)
    insight = (
        f"Customers from {hot_country_name} show elevated churn ({churn_hot:.1%} vs {churn_rest:.1%} elsewhere). "
        f"Average balances are higher in {hot_country_name} (${avg_balance_hot:,.0f} vs ${avg_balance_rest:,.0f}). "
        f"The {hot_segment_name} segment also churns more ({churn_seg:.1%} vs {churn_not_seg:.1%}). "
        f"Prioritize retention for {hot_country_name} and {hot_segment_name} customers with higher balances."
    )

    return {
        "x_categ": x_categ,
        "x_cont": x_cont,
        "row_mask": torch.ones(n_rows, dtype=torch.bool),  # all valid rows
        "insight": insight,
    }

class TableToTextDataset(Dataset):
    def __init__(self, n_tables=800, cfg: TableConfig=TableConfig()):
        self.samples = [generate_one_table(cfg) for _ in range(n_tables)]

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

train_ds = TableToTextDataset(n_tables=800)
val_ds   = TableToTextDataset(n_tables=160)
print("Train/Val sizes:", len(train_ds), len(val_ds))
print("Sample insight:", train_ds[0]["insight"][:160], "...")

Train/Val sizes: 800 160
Sample insight: Customers from Country_0 show elevated churn (18.2% vs 13.3% elsewhere). Average balances are higher in Country_0 ($1,553 vs $1,000). The Enterprise segment als ...



## 3. Tokenizer & DataLoader (variable rows per table)

- We pad **rows** per table within a batch and carry a boolean `row_mask` for attention pooling.  
- We tokenize the target insight text and pad to the max text length in the batch.  
- We set `pad_token = eos_token` for GPT-2 to simplify handling.
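
The row padding described above can also be sketched with `torch.nn.utils.rnn.pad_sequence`. This is an alternative illustration on toy tensors, not the `collate_fn` used in this notebook, and the variable names are made up for the example.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "tables" with different row counts and 3 categorical columns each.
table_a = torch.randint(0, 4, (5, 3))   # 5 rows
table_b = torch.randint(0, 4, (2, 3))   # 2 rows
tables = [table_a, table_b]

# Pad along the row dimension so both tables share one shape: [B, max_rows, C].
padded = pad_sequence(tables, batch_first=True, padding_value=0)

# Boolean mask marking real rows (True) versus padding (False): [B, max_rows].
lengths = torch.tensor([t.shape[0] for t in tables])
row_mask = torch.arange(padded.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

print(padded.shape)
print(row_mask)
```

Padding with 0 reuses a valid category ID, so downstream code has to rely on `row_mask` rather than on the padded values themselves to tell real rows from padding.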


In [4]:
# GPT-2 tokenizer
tokenizer_name = "gpt2"  # you can switch to "distilgpt2" for a smaller model
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

MAX_TEXT_TOKENS = 160  # cap for training (keep moderate)

def collate_fn(batch):
    # Compute max rows & pad
    max_rows = max(b["x_categ"].shape[0] for b in batch)
    B = len(batch)
    Cc = len(CATEG_COLS)
    Nc = NUM_CONT

    pad_categ = torch.zeros(B, max_rows, Cc, dtype=torch.long)
    pad_cont  = torch.zeros(B, max_rows, Nc, dtype=torch.float32)
    row_mask  = torch.zeros(B, max_rows, dtype=torch.bool)

    insights = [b["insight"] for b in batch]
    for i, b in enumerate(batch):
        n = b["x_categ"].shape[0]
        pad_categ[i, :n] = b["x_categ"]
        pad_cont[i, :n]  = b["x_cont"]
        row_mask[i, :n]  = True

    # Tokenize and pad targets
    tok = tokenizer(
        insights,
        padding=True,
        truncation=True,
        max_length=MAX_TEXT_TOKENS,
        return_tensors="pt"
    )
    input_ids = tok["input_ids"]
    attention_mask_text = tok["attention_mask"]  # 1 for real tokens, 0 for pads

    return {
        "x_categ": pad_categ,
        "x_cont": pad_cont,
        "row_mask": row_mask,
        "input_ids": input_ids,
        "attention_mask_text": attention_mask_text,
    }

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, collate_fn=collate_fn)

batch = next(iter(train_loader))
print("Batch shapes:")
for k,v in batch.items():
    print(" ", k, tuple(v.shape))

Batch shapes:
  x_categ (4, 133, 3)
  x_cont (4, 133, 6)
  row_mask (4, 133)
  input_ids (4, 74)
  attention_mask_text (4, 74)



## 4. Model — E: Row Encoder (TabTransformer)

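- We use the `TabTransformer` class from the `tab_transformer_pytorch` package.
- The cell below only demonstrates how to instantiate it; the encoder actually used for training is built inside the full model in Section 7.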


In [None]:
# Define encoder using the package
encoder = TabTransformer(
    categories = [12, 4, 6],   # cardinalities of categorical features
    num_continuous = 6,        # number of continuous features
    dim = 128,                 # embedding dimension
    depth = 6,                 # number of transformer layers
    heads = 8,                 # number of attention heads
    attn_dropout = 0.1,
    ff_dropout = 0.1
).to(device)


## 5. Model — A: Attention Pooling over Rows

A simple attention mechanism over rows to produce a fixed-size **table embedding**.


In [5]:
class AttentionPooling(nn.Module):
    def __init__(self, dim, hidden=64, dropout=0.1):
        super().__init__()
        self.att = nn.Sequential(
            nn.Linear(dim, hidden), nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)
        )

    def forward(self, x, mask):
        # x: [B, R, D], mask: [B, R]
        scores = self.att(x).squeeze(-1)  # [B, R]
        scores = scores.masked_fill(~mask, -1e9)  # mask out padded rows
        weights = torch.softmax(scores, dim=1)  # [B, R]
        pooled = (weights.unsqueeze(-1) * x).sum(dim=1)  # [B, D]
        return pooled
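
A small numeric check, on made-up scores, of how the masking step behaves: padded rows receive (numerically) zero attention weight, so they do not contribute to the pooled vector. The tensors below are illustrative only and bypass the learned scoring network.

```python
import torch

# One table with 4 row slots, of which the last two are padding.
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])        # [B=1, R=4] raw attention scores
mask = torch.tensor([[True, True, False, False]])    # [B, R], True = real row
x = torch.tensor([[[1.0, 0.0],
                   [0.0, 1.0],
                   [9.0, 9.0],
                   [9.0, 9.0]]])                     # [B, R, D=2] row embeddings

# Same masking recipe as AttentionPooling.forward: push padded scores to -1e9
# so that softmax assigns them (numerically) zero weight.
masked_scores = scores.masked_fill(~mask, -1e9)
weights = torch.softmax(masked_scores, dim=1)        # [B, R]
pooled = (weights.unsqueeze(-1) * x).sum(dim=1)      # [B, D]

print(weights)
print(pooled)
```

Even though the padded slots hold large values (9.0) and one of them has the highest raw score, the pooled vector depends only on the two real rows.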


## 6. Model — D: GPT-2 with Soft Prompting

We map the table embedding into **K learnable soft tokens** in GPT-2's hidden space.  
During training we concatenate these soft tokens with the token embeddings of the target text and train with standard next-token LM loss (labels = text tokens, soft tokens masked with `-100`).  
During generation we feed only the soft tokens and let GPT-2 produce text.
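
The training-time tensor layout described above can be sketched on toy sizes. In this illustration a small `nn.Embedding` stands in for GPT-2's token embedding table, and all dimensions are assumptions chosen for readability rather than GPT-2's real sizes.

```python
import torch
import torch.nn as nn

B, K, T, H, VOCAB = 2, 3, 4, 8, 50   # toy sizes, not GPT-2's real dimensions

# Stand-in for GPT-2's token embedding table (self.gpt2.transformer.wte).
wte = nn.Embedding(VOCAB, H)

soft_embeds = torch.randn(B, K, H)                    # projector output: [B, K, H]
input_ids = torch.randint(0, VOCAB, (B, T))           # target text tokens: [B, T]
attention_mask_text = torch.tensor([[1, 1, 1, 1],
                                    [1, 1, 0, 0]])    # second sample has 2 pad tokens

# Prepend the soft prompt to the token embeddings: [B, K+T, H].
inputs_embeds = torch.cat([soft_embeds, wte(input_ids)], dim=1)

# Soft tokens are always attended to; text positions follow the tokenizer mask.
attn_mask = torch.cat([torch.ones(B, K, dtype=torch.long), attention_mask_text], dim=1)

# Labels: -100 (ignored by the loss) over the soft prompt, token IDs over the text.
labels = torch.cat([torch.full((B, K), -100, dtype=torch.long), input_ids], dim=1)

print(inputs_embeds.shape, attn_mask.shape, labels.shape)
```

Hugging Face causal LM heads shift the labels internally, so the logits at position *i* are scored against the token at position *i+1*; the last soft token is therefore the position that predicts the first text token.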


In [6]:

class SoftPromptProjector(nn.Module):
    def __init__(self, in_dim, gpt_hidden, K=20, hidden=256, dropout=0.1):
        super().__init__()
        self.K = K
        self.out_dim = gpt_hidden
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, K * gpt_hidden)
        )

    def forward(self, table_emb):  # [B, in_dim]
        B = table_emb.size(0)
        x = self.net(table_emb)                # [B, K*H]
        x = x.view(B, self.K, self.out_dim)    # [B, K, H]
        return x

class InsightDecoder(nn.Module):
    def __init__(self, model_name="gpt2", freeze_gpt2=False):
        super().__init__()
        self.gpt2 = AutoModelForCausalLM.from_pretrained(model_name)
        self.hidden = self.gpt2.config.n_embd
        if freeze_gpt2:
            for p in self.gpt2.parameters():
                p.requires_grad = False

    def forward_train(self, soft_embeds, input_ids, attention_mask_text):
        # soft_embeds: [B, K, H]
        # input_ids: [B, T], attention_mask_text: [B, T] (1 for real tokens)
        B, K, H = soft_embeds.shape
        T = input_ids.size(1)

        # Get word embeddings for tokens
        tok_embeds = self.gpt2.transformer.wte(input_ids)  # [B, T, H]

        inputs_embeds = torch.cat([soft_embeds, tok_embeds], dim=1)  # [B, K+T, H]

        # Build attention mask: soft tokens are all visible (1s), then text mask
        attn_mask = torch.cat([torch.ones(B, K, dtype=attention_mask_text.dtype, device=input_ids.device),
                               attention_mask_text], dim=1)  # [B, K+T]

        # Labels: -100 for soft tokens, then actual token IDs
        labels = torch.cat([torch.full((B, K), -100, dtype=torch.long, device=input_ids.device),
                            input_ids], dim=1)  # [B, K+T]

        out = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels)
        return out  # .loss, .logits

    @torch.no_grad()
    def generate_from_soft(self, soft_embeds, max_new_tokens=100, do_sample=True, temperature=0.8, top_p=0.95):
        B, K, H = soft_embeds.shape
        # Start generation with only the soft prompt
        attn_mask = torch.ones(B, K, dtype=torch.long, device=soft_embeds.device)
        gen_ids = self.gpt2.generate(
            inputs_embeds=soft_embeds,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # When generate() is called with inputs_embeds only, recent HF versions return
        # just the newly generated token IDs (no prompt), so we decode all of gen_ids.
        text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return text


## 7. Full Model Wiring

`Tab2TextModel` wraps the encoder, aggregator, projector, and decoder, exposing `forward()` for training and a `generate()` utility.


In [7]:
class Tab2TextModel(nn.Module):
    def __init__(self, num_categories: List[int], num_continuous: int, gpt_model_name="gpt2",
                 freeze_gpt2=False, soft_K=20, d_token=64, depth=4, heads=8):
        super().__init__()
        self.encoder = TabTransformer(
            categories=num_categories,
            num_continuous=num_continuous,
            dim=d_token,
            depth=depth,
            heads=heads,
            attn_dropout=0.1,
            ff_dropout=0.1,
            dim_out=d_token
        )

        row_emb_dim = d_token
        self.aggregator = AttentionPooling(dim=row_emb_dim)

        self.decoder = InsightDecoder(model_name=gpt_model_name, freeze_gpt2=freeze_gpt2)
        gpt_hidden = self.decoder.hidden
        self.projector = SoftPromptProjector(in_dim=row_emb_dim, gpt_hidden=gpt_hidden, K=soft_K)

    def forward(self, x_categ, x_cont, row_mask, input_ids, attention_mask_text):
        B, R, C = x_categ.shape
        B, R, N = x_cont.shape

        # Reshape to process all rows at once: (B*R, C) and (B*R, N)
        x_categ_flat = x_categ.view(-1, C)  # [B*R, C]
        x_cont_flat = x_cont.view(-1, N)    # [B*R, N]

        # Encode all rows
        row_embs_flat = self.encoder(x_categ_flat, x_cont_flat)  # [B*R, D]

        # Reshape back to [B, R, D]
        D = row_embs_flat.shape[-1]
        row_embs = row_embs_flat.view(B, R, D)

        # Apply row mask to zero out padded rows
        row_embs = row_embs * row_mask.unsqueeze(-1).float()

        # Aggregate rows to table embedding
        table_emb = self.aggregator(row_embs, row_mask)  # [B, D]

        # Project into K soft prompts
        soft_embeds = self.projector(table_emb)  # [B, K, H]

        # Train GPT2 with soft prompts
        out = self.decoder.forward_train(soft_embeds, input_ids, attention_mask_text)
        return out

    @torch.no_grad()
    def generate(self, x_categ, x_cont, row_mask, max_new_tokens=100, do_sample=True, temperature=0.8, top_p=0.95):
        B, R, C = x_categ.shape
        B, R, N = x_cont.shape

        # Reshape and encode
        x_categ_flat = x_categ.view(-1, C)
        x_cont_flat = x_cont.view(-1, N)
        row_embs_flat = self.encoder(x_categ_flat, x_cont_flat)

        # Reshape back and apply mask
        D = row_embs_flat.shape[-1]
        row_embs = row_embs_flat.view(B, R, D)
        row_embs = row_embs * row_mask.unsqueeze(-1).float()

        # Aggregate and generate
        table_emb = self.aggregator(row_embs, row_mask)
        soft_embeds = self.projector(table_emb)
        return self.decoder.generate_from_soft(soft_embeds, max_new_tokens=max_new_tokens,
                                               do_sample=do_sample, temperature=temperature, top_p=top_p)


## 8. Training

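In the run below, GPT-2 is **fine-tuned** together with the other components (`freeze_gpt2=False`), so the trainable parts are:
- TabTransformer encoder
- Attention pooling aggregator
- Soft prompt projector
- GPT-2 decoder

> Pass `freeze_gpt2=True` to keep GPT-2 frozen and train only the first three components, which reduces GPU memory use.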
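
The training cell below uses `get_linear_schedule_with_warmup`. As a rough sketch of what that schedule does, the function below reproduces the learning-rate multiplier in plain Python: a linear ramp to 1.0 over the warmup steps, then a linear decay to 0. `linear_warmup_multiplier` is a hypothetical helper for illustration; the values 100 and 600 correspond to `WARMUP_STEPS` and 3 epochs of 200 iterations.

```python
def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: linear ramp up, then linear decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


WARMUP_STEPS = 100
TOTAL_STEPS = 3 * 200          # EPOCHS * iterations per epoch
BASE_LR = 3e-4

for step in (0, 50, 100, 350, 600):
    lr = BASE_LR * linear_warmup_multiplier(step, WARMUP_STEPS, TOTAL_STEPS)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

With these settings the warmup covers the first half of epoch 1, and the learning rate reaches zero exactly at the last optimizer step.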


In [8]:
LEARNING_RATE = 3e-4
EPOCHS = 3
GRAD_CLIP = 1.0
WARMUP_STEPS = 100
LOG_EVERY = 20

num_categories = [size for _, size in CATEG_COLS]
num_continuous = NUM_CONT

model = Tab2TextModel(num_categories, num_continuous, gpt_model_name=tokenizer_name, freeze_gpt2=False, soft_K=20).to(device)

# Only train unfrozen params
optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE, weight_decay=0.01)

# Linear warmup followed by linear decay; one scheduler step per batch
num_training_steps = EPOCHS * len(train_loader)
sched = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps)

In [9]:
def run_epoch(loader, train=True):
    model.train(train)
    total, n = 0.0, 0
    for it, batch in enumerate(loader):
        x_categ = batch["x_categ"].to(device)
        x_cont  = batch["x_cont"].to(device)
        row_mask= batch["row_mask"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask_text = batch["attention_mask_text"].to(device)

        # Disable autograd during validation to save memory and time
        with torch.set_grad_enabled(train):
            out = model(x_categ, x_cont, row_mask, input_ids, attention_mask_text)
            loss = out.loss

        if train:
            optim.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optim.step()
            sched.step()

        total += loss.item() * x_categ.size(0)
        n += x_categ.size(0)

        if train and (it+1) % LOG_EVERY == 0:
            print(f"iter {it+1:4d}/{len(loader)} | loss {loss.item():.3f}")
    return total / max(1,n)

best_val = float("inf")
for ep in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    dt = time.time()-t0
    print(f"Epoch {ep}: train {tr_loss:.3f} | val {val_loss:.3f}  ({dt:.1f}s)")
    if val_loss < best_val:
        best_val = val_loss
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(model.state_dict(), "checkpoints/tab2text_best.pt")
        print("  Saved best checkpoint")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


iter   20/200 | loss 1.518
iter   40/200 | loss 0.609
iter   60/200 | loss 0.548
iter   80/200 | loss 0.557
iter  100/200 | loss 0.549
iter  120/200 | loss 0.586
iter  140/200 | loss 0.524
iter  160/200 | loss 0.517
iter  180/200 | loss 0.524
iter  200/200 | loss 0.547
Epoch 1: train 0.874 | val 0.532  (49.8s)
  Saved best checkpoint
iter   20/200 | loss 0.508
iter   40/200 | loss 0.509
iter   60/200 | loss 0.478
iter   80/200 | loss 0.481
iter  100/200 | loss 0.455
iter  120/200 | loss 0.538
iter  140/200 | loss 0.503
iter  160/200 | loss 0.502
iter  180/200 | loss 0.497
iter  200/200 | loss 0.527
Epoch 2: train 0.506 | val 0.505  (37.8s)
  Saved best checkpoint
iter   20/200 | loss 0.446
iter   40/200 | loss 0.478
iter   60/200 | loss 0.493
iter   80/200 | loss 0.463
iter  100/200 | loss 0.478
iter  120/200 | loss 0.493
iter  140/200 | loss 0.464
iter  160/200 | loss 0.476
iter  180/200 | loss 0.499
iter  200/200 | loss 0.475
Epoch 3: train 0.476 | val 0.500  (39.5s)
  Saved best checkpoint


## 9. Inference / Generation Demo

We take a validation table, encode → aggregate → project to soft tokens, and ask GPT-2 to **generate** an insight.


In [10]:
# Load best checkpoint (optional if continuing same session)
if os.path.exists("checkpoints/tab2text_best.pt"):
    _ = model.load_state_dict(torch.load("checkpoints/tab2text_best.pt", map_location=device), strict=False)
    print("Loaded best checkpoint.")

model.eval()
with torch.no_grad():
    batch = next(iter(val_loader))
    x_categ = batch["x_categ"][:2].to(device)
    x_cont  = batch["x_cont"][:2].to(device)
    row_mask= batch["row_mask"][:2].to(device)

    gen_texts = model.generate(x_categ, x_cont, row_mask, max_new_tokens=120, do_sample=True, temperature=0.8, top_p=0.95)
    for i, t in enumerate(gen_texts):
        print(f"---- Generated insight #{i} ----\n{t}\n")


Loaded best checkpoint.
---- Generated insight #0 ----
Customers from Country_3 show elevated churn (33.3% vs 18.2% elsewhere). Average balances are higher in Country_3 ($1,723 vs $1,004). The Consumer segment also churns more (30.0% vs 18.1%). Prioritize retention for Country_3 and Consumer customers with higher balances.

---- Generated insight #1 ----
Customers from Country_2 show elevated churn (0.0% vs 23.6% elsewhere). Average balances are higher in Country_2 ($1,769 vs $1,054). The Enterprise segment also churns more (30.0% vs 25.1%). Prioritize retention for Country_2 and Enterprise customers with higher balances.



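## 10. Preliminary Testing

We compare generated insights with the ground truth using a similarity score: the cosine similarity between sentence embeddings from `SentenceTransformer('all-MiniLM-L6-v2')`. This measures semantic closeness rather than factual correctness in a strict sense, so the key-variable accuracy computed further below is a useful complement.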

In [11]:
# Load embedding model for semantic similarity
sbert = SentenceTransformer('all-MiniLM-L6-v2')

def factual_correctness(preds, targets):
    pred_emb = sbert.encode(preds, convert_to_tensor=True)
    targ_emb = sbert.encode(targets, convert_to_tensor=True)
    sims = cosine_similarity(pred_emb.cpu(), targ_emb.cpu())
    return sims.diagonal().mean()

In [12]:
if os.path.exists("checkpoints/tab2text_best.pt"):
    _ = model.load_state_dict(torch.load("checkpoints/tab2text_best.pt", map_location=device), strict=False)
    print("Loaded best checkpoint.")

Loaded best checkpoint.


In [13]:
model.eval()
with torch.no_grad():
    # Take a validation batch
    batch = next(iter(val_loader))
    x_categ = batch["x_categ"][:4].to(device)
    x_cont  = batch["x_cont"][:4].to(device)
    row_mask= batch["row_mask"][:4].to(device)

    # Decode ground-truth insights from input_ids
    true_texts = tokenizer.batch_decode(
        batch["input_ids"][:4],
        skip_special_tokens=True
    )

    # Generate predictions from the model
    gen_texts = model.generate(
        x_categ, x_cont, row_mask,
        max_new_tokens=120,
        do_sample=True,
        temperature=0.8,
        top_p=0.95
    )

    # Print side-by-side
    for i, (pred, truth) in enumerate(zip(gen_texts, true_texts)):
        print(f"---- Pair #{i} ----")
        print(f"Predicted:    {pred}")
        print(f"Ground truth: {truth}\n")

    # Compute factual correctness score
    score = factual_correctness(gen_texts, true_texts)
    print(f"\nAverage factual correctness score: {score:.4f}")

---- Pair #0 ----
Predicted:    Customers from Country_9 show elevated churn (25.0% vs 12.5% elsewhere). Average balances are higher in Country_9 ($1,416 vs $1,025). The Enterprise segment also churns more (24.3% vs 9.6%). Prioritize retention for Country_9 and Enterprise customers with higher balances.
Ground truth: Customers from Country_0 show elevated churn (42.9% vs 17.8% elsewhere). Average balances are higher in Country_0 ($1,680 vs $1,027). The Enterprise segment also churns more (36.7% vs 15.4%). Prioritize retention for Country_0 and Enterprise customers with higher balances.

---- Pair #1 ----
Predicted:    Customers from Country_10 show elevated churn (40.0% vs 12.4% elsewhere). Average balances are higher in Country_10 ($1,744 vs $1,011). The Mid segment also churns more (25.0% vs 12.5%). Prioritize retention for Country_10 and Mid customers with higher balances.
Ground truth: Customers from Country_4 show elevated churn (0.0% vs 19.1% elsewhere). Average balances are high

In [19]:
# Simpler version focusing on the 3 most important variables
def simple_key_variable_accuracy(predictions: List[str], ground_truths: List[str]) -> Dict[str, float]:
    """Focus on the 3 most critical business variables"""

    def extract_key_vars(text: str) -> Dict:
        vars_dict = {}

        # 1. Hot Country (most important)
        country_match = re.search(r'Country_(\d+)', text)
        if country_match:
            vars_dict['country'] = int(country_match.group(1))

        # 2. Hot Segment (second most important)
        segment_match = re.search(r'(SMB|Mid|Enterprise|Consumer)', text)
        if segment_match:
            vars_dict['segment'] = segment_match.group(1)

        # 3. Country churn rate comparison (third most important)
        # Check if the model correctly identifies that hot country has higher churn
        churn_match = re.search(r'elevated churn \((\d+\.?\d*)% vs (\d+\.?\d*)%', text)
        if churn_match:
            hot_churn = float(churn_match.group(1))
            rest_churn = float(churn_match.group(2))
            vars_dict['churn_relationship_correct'] = hot_churn > rest_churn

        return vars_dict

    # Extract variables
    pred_vars = [extract_key_vars(pred) for pred in predictions]
    true_vars = [extract_key_vars(true) for true in ground_truths]

    # Calculate accuracies
    results = {}
    for key in ['country', 'segment', 'churn_relationship_correct']:
        correct = 0
        total = 0

        for pred, true in zip(pred_vars, true_vars):
            if key in true:
                total += 1
                if key in pred and pred[key] == true[key]:
                    correct += 1

        if total > 0:
            results[key] = correct / total

    # Overall accuracy (equal weights)
    scores = list(results.values())
    results['overall'] = sum(scores) / len(scores) if scores else 0.0

    return results

# Run the key-variable evaluation over the whole validation set.
# Generation uses sampling by default, so results vary between runs.
def run_improved_evaluation():
    model.eval()
    all_predictions = []
    all_ground_truths = []

    with torch.no_grad():
        for batch in val_loader:
            x_categ = batch["x_categ"].to(device)
            x_cont = batch["x_cont"].to(device)
            row_mask = batch["row_mask"].to(device)

            true_texts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
            gen_texts = model.generate(x_categ, x_cont, row_mask, max_new_tokens=120)

            all_predictions.extend(gen_texts)
            all_ground_truths.extend(true_texts)

    # Calculate the new metrics
    key_var_results = simple_key_variable_accuracy(all_predictions, all_ground_truths)

    print("=== Key Variable Accuracy ===")
    print(f"Country Identification: {key_var_results.get('country', 0):.3f}")
    print(f"Segment Identification: {key_var_results.get('segment', 0):.3f}")
    print(f"Churn Relationship: {key_var_results.get('churn_relationship_correct', 0):.3f}")
    print(f"Overall Key Variables: {key_var_results.get('overall', 0):.3f}")

    return key_var_results

In [20]:
run_improved_evaluation()

=== Key Variable Accuracy ===
Country Identification: 0.075
Segment Identification: 0.231
Churn Relationship: 0.700
Overall Key Variables: 0.335


{'country': 0.075,
 'segment': 0.23125,
 'churn_relationship_correct': 0.7,
 'overall': 0.33541666666666664}
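
To put these numbers in context, it helps to compare them with a uniform random guess. The sketch below assumes that the hot country and hot segment are drawn uniformly (as in `generate_one_table`) and that a chance-level model also guesses uniformly; the accuracies are copied from the output above.

```python
# Reported validation accuracies (copied from the evaluation output above).
reported = {"country": 0.075, "segment": 0.23125}

# Number of possible values per variable, taken from CATEG_COLS.
n_classes = {"country": 12, "segment": 4}

for key, acc in reported.items():
    chance = 1.0 / n_classes[key]          # accuracy of a uniform random guess
    print(f"{key:8s} accuracy {acc:.3f} | chance {chance:.3f} | ratio {acc / chance:.2f}")
```

Both accuracies sit at or slightly below chance, which suggests that in this run the soft prompt carries little information about which country or segment is "hot", even though the generated text follows the template fluently.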

In [21]:
val_ds[0]["insight"]

'Customers from Country_0 show elevated churn (42.9% vs 17.8% elsewhere). Average balances are higher in Country_0 ($1,680 vs $1,027). The Enterprise segment also churns more (36.7% vs 15.4%). Prioritize retention for Country_0 and Enterprise customers with higher balances.'
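
As a quick sanity check of the extraction logic, the sketch below applies the same regular expressions used in `simple_key_variable_accuracy` to the insight string shown above. It is a standalone illustration; the variable names are chosen for the example.

```python
import re

insight = (
    "Customers from Country_0 show elevated churn (42.9% vs 17.8% elsewhere). "
    "Average balances are higher in Country_0 ($1,680 vs $1,027). "
    "The Enterprise segment also churns more (36.7% vs 15.4%). "
    "Prioritize retention for Country_0 and Enterprise customers with higher balances."
)

# First country ID mentioned in the text.
country = int(re.search(r"Country_(\d+)", insight).group(1))

# First segment name mentioned in the text.
segment = re.search(r"(SMB|Mid|Enterprise|Consumer)", insight).group(1)

# Hot-country churn versus churn elsewhere, as percentages.
churn = re.search(r"elevated churn \((\d+\.?\d*)% vs (\d+\.?\d*)%", insight)
hot_churn, rest_churn = float(churn.group(1)), float(churn.group(2))

print(country, segment, hot_churn, rest_churn, hot_churn > rest_churn)
# prints: 0 Enterprise 42.9 17.8 True
```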


## 11. Environment Info


In [None]:

import torch, transformers, platform
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("python:", sys.version)
print("platform:", platform.platform())

torch: 2.8.0+cpu
transformers: 4.55.2
python: 3.13.2 (tags/v3.13.2:4f8bb39, Feb  4 2025, 15:23:48) [MSC v.1942 64 bit (AMD64)]
platform: Windows-11-10.0.26100-SP0
