
# Tabular → Insight Text: E-A-D (Encoder → Aggregator → Decoder) Pipeline

This notebook builds a full pipeline to convert tabular data (rows × features) into natural-language **business insights**.

**Architecture**  
- **E — Encoder (TabTransformer-style):** Produces a per-row embedding from categorical + continuous features.  
- **A — Aggregator (Attention Pooling):** Collapses *N* row embeddings into a single **table embedding**.  
- **D — Decoder (GPT-2 with Soft Prompting):** Maps the table embedding to a small set of **soft prompt tokens** that condition a decoder-only LLM (GPT-2) to generate insight text.

> The notebook uses a **synthetic dataset** where each example is a small table (many rows) plus a templated ground-truth insight. This keeps training fully supervised and easy to iterate.
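To make the E-A-D hand-offs concrete, here is a small bookkeeping sketch that traces tensor shapes through the pipeline. The sizes mirror the notebook's defaults (3 categorical and 6 continuous columns, `row_emb_dim=128`, `soft_K=20`) and the batch printed in Section 3 (4 tables, 133 padded rows, 74 text tokens). `H = 768` is the hidden size of the base `gpt2` checkpoint. `eda_shapes` is a hypothetical helper, not part of the pipeline.

```python
from typing import Dict, Tuple

def eda_shapes(B: int, R: int, Cc: int, Nc: int, D: int, K: int, H: int, T: int) -> Dict[str, Tuple[int, ...]]:
    """Trace tensor shapes through Encoder -> Aggregator -> Decoder (bookkeeping only, no model)."""
    return {
        "x_categ": (B, R, Cc),           # categorical IDs per row
        "x_cont": (B, R, Nc),            # continuous features per row
        "row_embs": (B, R, D),           # E: one embedding per row
        "table_emb": (B, D),             # A: attention-pooled over rows
        "soft_prompt": (B, K, H),        # projector: K vectors in GPT-2's hidden size
        "decoder_input": (B, K + T, H),  # D (training): soft prompt + text token embeddings
    }

shapes = eda_shapes(B=4, R=133, Cc=3, Nc=6, D=128, K=20, H=768, T=74)
for name, shape in shapes.items():
    print(f"{name:14s} {shape}")
```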



## 1. Setup

You'll need recent versions of PyTorch and Transformers. In Colab or a fresh environment, uncomment and run the install lines in the cell below; if the packages are already installed (or you're offline), skip them.
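To decide whether the install lines are needed, a minimal check of what is already installed might look like this. `installed_versions` is a hypothetical helper built on the standard library's `importlib.metadata`; it only reports versions (or `None` for a missing package) and does not enforce any minimum version.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version string for each package, or None if it is not installed."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

print(installed_versions(["torch", "transformers", "accelerate", "datasets"]))
```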


In [1]:

# If needed, uncomment:
# !pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# !pip install -q transformers accelerate datasets

import os, math, random, json, sys, time
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Hugging Face
from transformers import AutoTokenizer, AutoModelForCausalLM, get_linear_schedule_with_warmup

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)
seed = 42
random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)




Device: cpu



## 2. Synthetic Data: Tables + Ground-Truth Insights

We synthesize many **mini-tables** (each is a separate training example). Each table has two planted patterns: one "hot" country with higher churn and higher average balances, and one "hot" segment with higher churn. We then render a short **ground-truth insight** string describing those patterns.  
This lets us train end-to-end without needing a curated labeled corpus.
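The ground-truth numbers are plain group means of the churn flag. The sketch below restates the `churn[mask].float().mean()` pattern from `generate_one_table` in pure Python on a hypothetical 8-row table; `rate` is an illustrative helper. For scale: with the default `p_base_churn=0.15`, a row that is in both the hot country and the hot segment churns with probability between 0.30 and 0.55, depending on the sampled boosts.

```python
from typing import Sequence

def rate(flags: Sequence[int], mask: Sequence[bool]) -> float:
    """Mean of 0/1 flags over rows where mask is True (0.0 if no row qualifies)."""
    selected = [f for f, m in zip(flags, mask) if m]
    return sum(selected) / len(selected) if selected else 0.0

# Toy table of (country_id, churned) pairs
rows = [(3, 1), (3, 0), (3, 1), (1, 0), (2, 0), (2, 1), (1, 0), (0, 0)]
hot_country = 3
churn = [c for _, c in rows]
mask_hot = [country == hot_country for country, _ in rows]
mask_rest = [not m for m in mask_hot]

churn_hot, churn_rest = rate(churn, mask_hot), rate(churn, mask_rest)
print(f"Customers from Country_{hot_country} show elevated churn "
      f"({churn_hot:.1%} vs {churn_rest:.1%} elsewhere).")
```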


In [None]:

# --- Synthetic schema ---
CATEG_COLS = [
    ("country", 12),  # e.g., 12 countries
    ("segment", 4),   # SMB / Mid / Enterprise / Consumer
    ("product", 6),   # product lines
]
NUM_CONT = 6  # numerical features, e.g., balance, age, tenure, etc.

# Each table also gets a binary churn column that is never fed to the model;
# it is only used to compute the statistics quoted in the target insight text.
# Each table is its own training example.
@dataclass
class TableConfig:
    n_rows_min: int = 80
    n_rows_max: int = 150
    p_base_churn: float = 0.15

def generate_one_table(cfg: TableConfig):
    n_rows = random.randint(cfg.n_rows_min, cfg.n_rows_max)
    # Choose one "hot" country that has higher churn and higher balances on average
    hot_country = random.randint(0, CATEG_COLS[0][1] - 1)
    hot_segment = random.randint(0, CATEG_COLS[1][1] - 1)

    # Continuous feature generative params
    # balance ~ higher for hot_country; age, tenure, etc.
    base_balance_mu = 1000.0
    base_balance_sigma = 400.0
    balance_boost = random.uniform(300, 800)  # hot country boost

    # Base churn odds + boosts
    p_churn_base = cfg.p_base_churn
    p_churn_hot_country_boost = random.uniform(0.10, 0.25)
    p_churn_hot_segment_boost = random.uniform(0.05, 0.15)

    x_categ = torch.zeros(n_rows, len(CATEG_COLS), dtype=torch.long)
    x_cont  = torch.zeros(n_rows, NUM_CONT, dtype=torch.float32)
    churn   = torch.zeros(n_rows, dtype=torch.long)

    for i in range(n_rows):
        c_country = random.randint(0, CATEG_COLS[0][1] - 1)
        c_segment = random.randint(0, CATEG_COLS[1][1] - 1)
        c_product = random.randint(0, CATEG_COLS[2][1] - 1)
        x_categ[i] = torch.tensor([c_country, c_segment, c_product])

        # Continuous features (toy):
        # 0: balance, 1: age, 2: tenure, 3: tx_freq, 4: complaints, 5: income
        balance_mu = base_balance_mu + (balance_boost if c_country == hot_country else 0.0)
        balance = random.gauss(balance_mu, base_balance_sigma)
        age = max(18, random.gauss(42, 12))
        tenure = max(0.0, random.gauss(3.5, 2.0))
        tx_freq = max(0.0, random.gauss(8, 3))
        complaints = max(0.0, random.gauss(0.6, 0.7))
        income = max(0.0, random.gauss(48_000, 15_000))

        x_cont[i] = torch.tensor([balance, age, tenure, tx_freq, complaints, income], dtype=torch.float32)

        # Churn probability with boosts for hot country + hot segment
        p = p_churn_base
        if c_country == hot_country:
            p += p_churn_hot_country_boost
        if c_segment == hot_segment:
            p += p_churn_hot_segment_boost
        churn[i] = 1 if random.random() < p else 0

    # Derive summary stats for ground-truth insight
    # Country-level churn: compare hot_country vs rest
    mask_hot = (x_categ[:,0] == hot_country)
    churn_hot = churn[mask_hot].float().mean().item() if mask_hot.any() else 0.0
    churn_rest = churn[~mask_hot].float().mean().item() if (~mask_hot).any() else 0.0

    avg_balance_hot = x_cont[mask_hot,0].mean().item() if mask_hot.any() else 0.0
    avg_balance_rest = x_cont[~mask_hot,0].mean().item() if (~mask_hot).any() else 0.0

    # Segment-level churn for the chosen hot segment
    mask_seg = (x_categ[:,1] == hot_segment)
    churn_seg = churn[mask_seg].float().mean().item() if mask_seg.any() else 0.0
    churn_not_seg = churn[~mask_seg].float().mean().item() if (~mask_seg).any() else 0.0

    # Build human-readable labels from IDs
    def name_country(i): return f"Country_{i}"
    def name_seg(i): return ["SMB","Mid","Enterprise","Consumer"][i] if CATEG_COLS[1][1]==4 and i<4 else f"Segment_{i}"

    hot_country_name = name_country(hot_country)
    hot_segment_name = name_seg(hot_segment)

    # Ground-truth insight text (templated)
    insight = (
        f"Customers from {hot_country_name} show elevated churn ({churn_hot:.1%} vs {churn_rest:.1%} elsewhere). "
        f"Average balances are higher in {hot_country_name} (${avg_balance_hot:,.0f} vs ${avg_balance_rest:,.0f}). "
        f"The {hot_segment_name} segment also churns more ({churn_seg:.1%} vs {churn_not_seg:.1%}). "
        f"Prioritize retention for {hot_country_name} and {hot_segment_name} customers with higher balances."
    )

    return {
        "x_categ": x_categ,
        "x_cont": x_cont,
        "row_mask": torch.ones(n_rows, dtype=torch.bool),  # all valid rows
        "insight": insight,
    }

class TableToTextDataset(Dataset):
    def __init__(self, n_tables=800, cfg: TableConfig=TableConfig()):
        self.samples = [generate_one_table(cfg) for _ in range(n_tables)]

    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]

train_ds = TableToTextDataset(n_tables=800)
val_ds   = TableToTextDataset(n_tables=160)
print("Train/Val sizes:", len(train_ds), len(val_ds))
print("Sample insight:", train_ds[0]["insight"][:160], "...")


Train/Val sizes: 800 160
Sample insight: Customers from Country_0 show elevated churn (18.2% vs 13.3% elsewhere). Average balances are higher in Country_0 ($1,553 vs $1,000). The Enterprise segment als ...



## 3. Tokenizer & DataLoader (variable rows per table)

- We pad **rows** per table within a batch and carry a boolean `row_mask` for attention pooling.  
- We tokenize the target insight text and pad to the max text length in the batch.  
- GPT-2 ships without a pad token, so we reuse `eos_token` as `pad_token`; padded text positions are identified by `attention_mask_text == 0`.
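The row-padding step in `collate_fn` can be illustrated without tensors. This pure-Python sketch (the helper `pad_rows` and the toy tables are hypothetical) pads every table to the largest row count in the batch and records which rows are real. Padded rows still pass through the row encoder; it is `row_mask` that keeps them out of the attention pooling.

```python
from typing import List, Sequence, Tuple

def pad_rows(tables: Sequence[Sequence[Sequence[float]]], pad_value: float = 0.0
             ) -> Tuple[List[List[List[float]]], List[List[bool]]]:
    """Pad every table to the batch's largest row count; return padded tables and a row mask.

    Assumes every table has at least one row and all rows have the same number of columns.
    """
    max_rows = max(len(t) for t in tables)
    n_cols = len(tables[0][0])
    padded, mask = [], []
    for t in tables:
        n_pad = max_rows - len(t)
        padded.append([list(row) for row in t] + [[pad_value] * n_cols for _ in range(n_pad)])
        mask.append([True] * len(t) + [False] * n_pad)  # True = real row, False = padding
    return padded, mask

# Two toy tables with 2 and 3 rows of (balance, age)
tables = [
    [[1200.0, 35.0], [900.0, 51.0]],
    [[1500.0, 40.0], [800.0, 29.0], [1100.0, 62.0]],
]
padded, row_mask = pad_rows(tables)
print(row_mask)  # [[True, True, False], [True, True, True]]
```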


In [6]:

# GPT-2 tokenizer
tokenizer_name = "gpt2"  # you can switch to "distilgpt2" for a smaller model
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

MAX_TEXT_TOKENS = 160  # cap for training (keep moderate)

def collate_fn(batch):
    # Compute max rows & pad
    max_rows = max(b["x_categ"].shape[0] for b in batch)
    B = len(batch)
    Cc = len(CATEG_COLS)
    Nc = NUM_CONT

    pad_categ = torch.zeros(B, max_rows, Cc, dtype=torch.long)
    pad_cont  = torch.zeros(B, max_rows, Nc, dtype=torch.float32)
    row_mask  = torch.zeros(B, max_rows, dtype=torch.bool)

    insights = [b["insight"] for b in batch]
    for i, b in enumerate(batch):
        n = b["x_categ"].shape[0]
        pad_categ[i, :n] = b["x_categ"]
        pad_cont[i, :n]  = b["x_cont"]
        row_mask[i, :n]  = True

    # Tokenize and pad targets
    tok = tokenizer(
        insights,
        padding=True,
        truncation=True,
        max_length=MAX_TEXT_TOKENS,
        return_tensors="pt"
    )
    input_ids = tok["input_ids"]
    attention_mask_text = tok["attention_mask"]  # 1 for real tokens, 0 for pads

    return {
        "x_categ": pad_categ,
        "x_cont": pad_cont,
        "row_mask": row_mask,
        "input_ids": input_ids,
        "attention_mask_text": attention_mask_text,
    }

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, collate_fn=collate_fn)

batch = next(iter(train_loader))
print("Batch shapes:")
for k,v in batch.items():
    print(" ", k, tuple(v.shape))


Batch shapes:
  x_categ (4, 133, 3)
  x_cont (4, 133, 6)
  row_mask (4, 133)
  input_ids (4, 74)
  attention_mask_text (4, 74)



## 4. Model — E: Row Encoder (TabTransformer-style)

- Each categorical column has its **own embedding**.  
- We treat a row's categorical values as a **token sequence** (one token per column), pass it through a small Transformer encoder, and mean-pool the outputs.  
- Continuous features go through an MLP and are concatenated with the pooled categorical representation.  
- The concatenation is projected (Linear → GELU → LayerNorm) to a compact **row embedding**.
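The six continuous features are fed to the encoder raw, and their scales differ by orders of magnitude (balance around 1,000, age around 42, income around 48,000). The `nn.LayerNorm(self.Nc)` at the start of `cont_mlp` normalizes across the six features *within each row*, not each feature across rows, so income-sized values dominate that normalization. A common alternative, sketched here as an assumption rather than part of the original pipeline, is to z-score each column with statistics from the training tables before batching. `fit_standardizer` and `standardize` are hypothetical helpers.

```python
import statistics
from typing import List, Sequence, Tuple

def fit_standardizer(rows: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-column mean and population std, computed once over the training rows."""
    columns = list(zip(*rows))
    means = [statistics.fmean(col) for col in columns]
    stds = [statistics.pstdev(col) or 1.0 for col in columns]  # constant column -> std 1.0
    return means, stds

def standardize(rows: Sequence[Sequence[float]], means: Sequence[float],
                stds: Sequence[float]) -> List[List[float]]:
    """Apply (x - mean) / std column by column, reusing the training statistics."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]

# Toy rows of (balance, income): raw scales differ by more than an order of magnitude
train_rows = [[1000.0, 40_000.0], [1500.0, 56_000.0], [500.0, 48_000.0]]
means, stds = fit_standardizer(train_rows)
print(means)  # [1000.0, 48000.0]
print(standardize(train_rows, means, stds))
```

The same training means and stds would then be applied to validation tables; with tensors this is simply `(x_cont - mean) / std`.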


In [None]:

class CatEmbeddingBlock(nn.Module):
    def __init__(self, cardinalities: List[int], d_cat: int):
        super().__init__()
        # One embedding table per categorical column, sized by that column's cardinality
        self.embeds = nn.ModuleList([nn.Embedding(n, d_cat) for n in cardinalities])
        for emb in self.embeds:
            nn.init.normal_(emb.weight, std=0.02)

    def forward(self, x_categ):  # [B, R, Cc]
        B,R,Cc = x_categ.shape
        embs = []
        for j in range(Cc):
            embs.append(self.embeds[j](x_categ[:,:,j]))  # [B, R, d_cat]
        # Stack into [B, R, Cc, d_cat] then view as sequence over Cc per row
        cat_tok = torch.stack(embs, dim=2)  # [B, R, Cc, d_cat]
        return cat_tok

class RowEncoder(nn.Module):
    def __init__(self,
                 d_cat: int = 32,
                 n_heads: int = 4,
                 n_layers: int = 2,
                 d_ff: int = 128,
                 d_cont: int = 64,
                 row_emb_dim: int = 128,
                 dropout: float = 0.1):
        super().__init__()
        self.Cc = len(CATEG_COLS)
        self.Nc = NUM_CONT
        self.d_cat = d_cat
        self.cat_block = CatEmbeddingBlock([n for _,n in CATEG_COLS], d_cat)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_cat, nhead=n_heads, dim_feedforward=d_ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.cat_encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.cat_ln = nn.LayerNorm(d_cat)

        self.cont_mlp = nn.Sequential(
            nn.LayerNorm(self.Nc),
            nn.Linear(self.Nc, d_cont),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_cont, d_cont),
            nn.GELU(),
        )

        self.fuse = nn.Sequential(
            nn.Linear(d_cat + d_cont, row_emb_dim),
            nn.GELU(),
            nn.LayerNorm(row_emb_dim),
        )

    def forward(self, x_categ, x_cont):  # x_categ:[B,R,Cc], x_cont:[B,R,Nc]
        B,R,Cc = x_categ.shape
        cat_tok = self.cat_block(x_categ)  # [B, R, Cc, d_cat]
        # Process cat sequence per row: we reshape to merge batch and rows
        cat_tok = cat_tok.view(B*R, Cc, self.d_cat)  # [B*R, Cc, d_cat]
        cat_enc = self.cat_encoder(cat_tok)          # [B*R, Cc, d_cat]
        cat_enc = self.cat_ln(cat_enc.mean(dim=1))   # [B*R, d_cat] mean over Cc tokens

        cont_enc = self.cont_mlp(x_cont.view(B*R, -1))  # [B*R, d_cont]

        row = torch.cat([cat_enc, cont_enc], dim=-1)    # [B*R, d_cat+d_cont]
        row = self.fuse(row)                            # [B*R, row_emb_dim]
        row = row.view(B, R, -1)                        # [B, R, row_emb_dim]
        return row


## 5. Model — A: Attention Pooling over Rows

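A small MLP (Linear → Tanh → Dropout → Linear) scores each row; a softmax over the valid rows (padding is masked out with `row_mask`) turns the scores into weights, and the weighted sum of row embeddings gives a fixed-size **table embedding**.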
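The masking logic is the important detail: a padded row must receive exactly zero weight. A pure-Python sketch of the same computation, taking the per-row scores as given rather than computing them with the MLP (`masked_attention_pool` is an illustrative helper, not the notebook's class):

```python
import math
from typing import List, Sequence, Tuple

def masked_attention_pool(row_embs: Sequence[Sequence[float]],
                          scores: Sequence[float],
                          row_mask: Sequence[bool]) -> Tuple[List[float], List[float]]:
    """Softmax the row scores over valid rows only, then average row embeddings with those weights.

    Assumes at least one valid row.
    """
    s_max = max(s for s, m in zip(scores, row_mask) if m)  # subtract max for numerical stability
    exps = [math.exp(s - s_max) if m else 0.0 for s, m in zip(scores, row_mask)]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(row_embs[0])
    pooled = [sum(w * row[d] for w, row in zip(weights, row_embs)) for d in range(dim)]
    return pooled, weights

# Two real rows with equal scores plus one padded row with a large (ignored) score
pooled, weights = masked_attention_pool(
    row_embs=[[1.0, 0.0], [0.0, 1.0], [9.0, 9.0]],
    scores=[0.0, 0.0, 5.0],
    row_mask=[True, True, False],
)
print(weights, pooled)  # [0.5, 0.5, 0.0] [0.5, 0.5]
```

Filling masked scores with `-inf` before `torch.softmax` has the same effect, since `exp(-inf) = 0`.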


In [8]:

class AttentionPooling(nn.Module):
    def __init__(self, dim, hidden=64, dropout=0.1):
        super().__init__()
        self.att = nn.Sequential(
            nn.Linear(dim, hidden), nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)
        )

    def forward(self, x, row_mask=None):  # x:[B,R,D]
        scores = self.att(x).squeeze(-1)  # [B,R]
        if row_mask is not None:
            scores = scores.masked_fill(~row_mask, float('-inf'))
        w = torch.softmax(scores, dim=1)  # [B,R]
        pooled = torch.bmm(w.unsqueeze(1), x).squeeze(1)  # [B,D]
        return pooled, w



## 6. Model — D: GPT-2 with Soft Prompting

We map the table embedding into **K learnable soft tokens** in GPT-2's hidden space.  
During training we concatenate these soft tokens with the token embeddings of the target text and train with standard next-token LM loss (labels = text tokens, soft tokens masked with `-100`).  
During generation we feed only the soft tokens and let GPT-2 produce text.
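The labels line up position by position with `inputs_embeds`: K soft-prompt positions followed by T text positions. Hugging Face causal-LM models shift the labels internally, so each position is scored on predicting the next token, and positions labeled `-100` are skipped by the loss. Since this notebook pads with the EOS token, padding can only be told apart from real text through the attention mask. A pure-Python sketch of one labeling scheme that ignores padding but keeps a single end-of-text target (the helper and token IDs are illustrative, and the scheme assumes right padding):

```python
from typing import List, Sequence

IGNORE_INDEX = -100  # label value that PyTorch / Hugging Face cross-entropy skips

def build_labels(input_ids: Sequence[int], attention_mask: Sequence[int], n_soft: int) -> List[int]:
    """Labels for the sequence [K soft-prompt positions | T text positions].

    Soft-prompt positions are never predicted. With right padding and pad == EOS, the first
    padded position is kept as an end-of-text target and the remaining padding is ignored.
    """
    n_real = sum(attention_mask)  # right padding: the real tokens come first
    labels = [IGNORE_INDEX] * n_soft
    for pos, tok in enumerate(input_ids):
        labels.append(tok if pos <= n_real else IGNORE_INDEX)
    return labels

EOS = 50256  # GPT-2's <|endoftext|> ID, reused as the pad token in this notebook
ids = [464, 1110, 318, EOS, EOS]  # three illustrative text token IDs, then two pads
mask = [1, 1, 1, 0, 0]
print(build_labels(ids, mask, n_soft=2))
# [-100, -100, 464, 1110, 318, 50256, -100]
```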


In [None]:

class SoftPromptProjector(nn.Module):
    def __init__(self, in_dim, gpt_hidden, K=20, hidden=256, dropout=0.1):
        super().__init__()
        self.K = K
        self.out_dim = gpt_hidden
        self.net = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, K * gpt_hidden)
        )

    def forward(self, table_emb):  # [B, in_dim]
        B = table_emb.size(0)
        x = self.net(table_emb)                # [B, K*H]
        x = x.view(B, self.K, self.out_dim)    # [B, K, H]
        return x

class InsightDecoder(nn.Module):
    def __init__(self, model_name="gpt2", freeze_gpt2=False):
        super().__init__()
        self.gpt2 = AutoModelForCausalLM.from_pretrained(model_name)
        self.hidden = self.gpt2.config.n_embd
        if freeze_gpt2:
            for p in self.gpt2.parameters():
                p.requires_grad = False

    def forward_train(self, soft_embeds, input_ids, attention_mask_text):
        # soft_embeds: [B, K, H]
        # input_ids: [B, T], attention_mask_text: [B, T] (1 for real tokens)
        B, K, H = soft_embeds.shape
        T = input_ids.size(1)

        # Get word embeddings for tokens
        tok_embeds = self.gpt2.transformer.wte(input_ids)  # [B, T, H]

        inputs_embeds = torch.cat([soft_embeds, tok_embeds], dim=1)  # [B, K+T, H]

        # Build attention mask: soft tokens are all visible (1s), then text mask
        attn_mask = torch.cat([torch.ones(B, K, dtype=attention_mask_text.dtype, device=input_ids.device),
                               attention_mask_text], dim=1)  # [B, K+T]

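        # Labels: -100 for soft tokens, then the text token IDs. Because pad == EOS, padded
        # positions would otherwise be scored as EOS predictions; keep only the first padded
        # position (one end-of-text target) and set the rest to -100 so the loss ignores them.
        n_real = attention_mask_text.sum(dim=1, keepdim=True)         # [B, 1] real tokens per row
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)   # [1, T]
        text_labels = input_ids.masked_fill(pos > n_real, -100)       # [B, T]
        labels = torch.cat([torch.full((B, K), -100, dtype=torch.long, device=input_ids.device),
                            text_labels], dim=1)  # [B, K+T]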

        out = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attn_mask, labels=labels)
        return out  # .loss, .logits

    @torch.no_grad()
    def generate_from_soft(self, soft_embeds, max_new_tokens=100, do_sample=True, temperature=0.8, top_p=0.95):
        B, K, H = soft_embeds.shape
        # Start generation with only the soft prompt
        attn_mask = torch.ones(B, K, dtype=torch.long, device=soft_embeds.device)
        gen_ids = self.gpt2.generate(
            inputs_embeds=soft_embeds,
            attention_mask=attn_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # When generation starts from `inputs_embeds` alone (no `input_ids`), recent Transformers
        # versions return only the newly generated token IDs, so the whole output is decoded.
        text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return text


## 7. Full Model Wiring

`Table2TextModel` wraps the encoder, aggregator, projector, and decoder, exposing `forward()` for training and a `generate()` utility.


In [None]:
class Table2TextModel(nn.Module):
    def __init__(self,
                 row_emb_dim=128,
                 gpt_model_name="gpt2",
                 freeze_gpt2=True,
                 soft_K=20):
        super().__init__()
        self.encoder = RowEncoder(row_emb_dim=row_emb_dim)
        self.aggregator = AttentionPooling(dim=row_emb_dim)
        # Instantiate the decoder before the projector: the projector needs GPT-2's hidden size
        self.decoder = InsightDecoder(model_name=gpt_model_name, freeze_gpt2=freeze_gpt2)
        gpt_hidden = self.decoder.hidden
        self.projector = SoftPromptProjector(in_dim=row_emb_dim, gpt_hidden=gpt_hidden, K=soft_K)

    def forward(self, x_categ, x_cont, row_mask, input_ids, attention_mask_text):
        # E
        row_embs = self.encoder(x_categ, x_cont)  # [B,R,D]
        # A
        table_emb, _ = self.aggregator(row_embs, row_mask=row_mask)  # [B,D]
        # Soft prompt
        soft_embeds = self.projector(table_emb)  # [B,K,H]
        # D
        out = self.decoder.forward_train(soft_embeds, input_ids, attention_mask_text)
        return out

    @torch.no_grad()
    def generate(self, x_categ, x_cont, row_mask, max_new_tokens=100, **gen_kw):
        row_embs = self.encoder(x_categ, x_cont)
        table_emb, _ = self.aggregator(row_embs, row_mask=row_mask)
        soft = self.projector(table_emb)
        texts = self.decoder.generate_from_soft(soft, max_new_tokens=max_new_tokens, **gen_kw)
        return texts


## 8. Training

We keep GPT-2 **frozen** by default and only train:
- Row encoder (TabTransformer-style)
- Attention pooling aggregator
- Soft prompt projector

> You can unfreeze GPT-2 later (Section 10) for potentially better quality if you have enough GPU memory.
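A quick count shows where the trainable parameters live. With the notebook's defaults (`row_emb_dim=128`, projector `hidden=256`, `K=20`, GPT-2 hidden size 768), the projector's last `nn.Linear` alone maps 256 features to 20 × 768 = 15,360 outputs. The arithmetic below (hypothetical helpers, using the standard parameter counts of `nn.Linear` and `nn.LayerNorm`) puts the projector at about 4M parameters, while base GPT-2 has roughly 124M, all frozen here.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """nn.Linear with bias: weight matrix plus bias vector."""
    return n_in * n_out + n_out

def layernorm_params(dim: int) -> int:
    """nn.LayerNorm with elementwise affine: scale plus shift."""
    return 2 * dim

def projector_params(in_dim: int = 128, hidden: int = 256, K: int = 20, gpt_hidden: int = 768) -> int:
    """Parameters in SoftPromptProjector's LayerNorm -> Linear -> Linear stack (GELU/Dropout have none)."""
    return (layernorm_params(in_dim)
            + linear_params(in_dim, hidden)
            + linear_params(hidden, K * gpt_hidden))

print(f"{projector_params():,}")  # 3,980,800 with the notebook's defaults
```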


In [11]:
LEARNING_RATE = 3e-4
EPOCHS = 3
GRAD_CLIP = 1.0
WARMUP_STEPS = 100
LOG_EVERY = 20

model = Table2TextModel(gpt_model_name=tokenizer_name, freeze_gpt2=True, soft_K=20).to(device)

# Only train unfrozen params
optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE, weight_decay=0.01)

# Simple linear warmup schedule
num_training_steps = EPOCHS * len(train_loader)  # one scheduler step per batch
sched = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps)

def run_epoch(loader, train=True):
    model.train(train)  # train mode enables dropout; eval mode disables it
    total, n = 0.0, 0
    # Build autograd graphs only when training; validation runs without gradients.
    with torch.set_grad_enabled(train):
        for it, batch in enumerate(loader):
            x_categ = batch["x_categ"].to(device)
            x_cont  = batch["x_cont"].to(device)
            row_mask= batch["row_mask"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask_text = batch["attention_mask_text"].to(device)

            out = model(x_categ, x_cont, row_mask, input_ids, attention_mask_text)
            loss = out.loss

            if train:
                optim.zero_grad(set_to_none=True)
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optim.step()
                sched.step()

            # Weight each batch's mean loss by its size (the last batch may be smaller)
            total += loss.item() * x_categ.size(0)
            n += x_categ.size(0)

            if train and (it+1) % LOG_EVERY == 0:
                print(f"iter {it+1:4d}/{len(loader)} | loss {loss.item():.3f}")
    return total / max(1, n)

best_val = float("inf")
for ep in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    dt = time.time()-t0
    print(f"Epoch {ep}: train {tr_loss:.3f} | val {val_loss:.3f}  ({dt:.1f}s)")
    if val_loss < best_val:
        best_val = val_loss
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(model.state_dict(), "checkpoints/tab2text_best.pt")
        print("  Saved best checkpoint")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


iter   20/200 | loss 5.004
iter   40/200 | loss 4.170
iter   60/200 | loss 3.500
iter   80/200 | loss 3.019
iter  100/200 | loss 2.383
iter  120/200 | loss 1.816
iter  140/200 | loss 1.577
iter  160/200 | loss 1.254
iter  180/200 | loss 1.205
iter  200/200 | loss 1.269
Epoch 1: train 2.740 | val 0.765  (736.7s)
  Saved best checkpoint
iter   20/200 | loss 1.072
iter   40/200 | loss 1.086
iter   60/200 | loss 1.031
iter   80/200 | loss 0.923
iter  100/200 | loss 1.025
iter  120/200 | loss 1.030
iter  140/200 | loss 0.859
iter  160/200 | loss 0.889
iter  180/200 | loss 0.848
iter  200/200 | loss 0.848
Epoch 2: train 0.976 | val 0.622  (734.5s)
  Saved best checkpoint
iter   20/200 | loss 0.820
iter   40/200 | loss 0.895
iter   60/200 | loss 0.916
iter   80/200 | loss 0.796
iter  100/200 | loss 0.857
iter  120/200 | loss 0.803
iter  140/200 | loss 0.803
iter  160/200 | loss 0.760
iter  180/200 | loss 0.817
iter  200/200 | loss 0.729
Epoch 3: train 0.827 | val 0.607  (910.3s)
  Saved best checkpoint
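The training cell builds its schedule with Transformers' `get_linear_schedule_with_warmup`. The helper below is an independent restatement of that schedule's shape (linear warmup from 0, then linear decay to 0 at the last step) for inspecting the numbers; `linear_warmup_multiplier` is hypothetical, not a Transformers function. With 200 iterations per epoch in the log above, the 100 warmup steps cover the first half of epoch 1, and the peak rate of 3e-4 is reached mid-epoch.

```python
def linear_warmup_multiplier(step: int, warmup: int, total: int) -> float:
    """LR multiplier: ramps linearly from 0 to 1 over `warmup` steps, then decays linearly to 0 at `total`."""
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, (total - step) / max(1, total - warmup))

PEAK_LR = 3e-4
TOTAL_STEPS = 3 * 200  # EPOCHS * iterations per epoch in the log above
for step in (0, 50, 100, 200, 400, 600):
    lr = PEAK_LR * linear_warmup_multiplier(step, warmup=100, total=TOTAL_STEPS)
    print(f"step {step:3d}: lr {lr:.2e}")
```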


## 9. Inference / Generation Demo

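We take the first two tables of a validation batch, encode → aggregate → project them to soft tokens, and ask GPT-2 to **generate** an insight for each. The output follows the training template, but its figures are produced token by token by the decoder rather than computed from the table, so they can differ from the table's actual statistics (by construction, the average balance outside the hot country is close to $1,000).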


In [12]:

# Load best checkpoint (optional if continuing same session)
if os.path.exists("checkpoints/tab2text_best.pt"):
    _ = model.load_state_dict(torch.load("checkpoints/tab2text_best.pt", map_location=device))  # strict: every key must match
    print("Loaded best checkpoint.")

model.eval()
with torch.no_grad():
    batch = next(iter(val_loader))
    x_categ = batch["x_categ"][:2].to(device)
    x_cont  = batch["x_cont"][:2].to(device)
    row_mask= batch["row_mask"][:2].to(device)

    gen_texts = model.generate(x_categ, x_cont, row_mask, max_new_tokens=120, do_sample=True, temperature=0.8, top_p=0.95)
    for i, t in enumerate(gen_texts):
        print(f"---- Generated insight #{i} ----\n{t}\n")


Loaded best checkpoint.
---- Generated insight #0 ----
Customers from Country_5 show elevated churn (43.7% vs 19.1% elsewhere). Average balances are higher in Country_5 ($2,831 vs $1,822). The SMB segment also churns more (25.1% vs 10.9%). Prioritize retention for Country_5 and SMB customers with higher balances.

---- Generated insight #1 ----
Customers from Country_4 show elevated churn (27.3% vs 14.9% elsewhere). Average balances are higher in Country_4 ($6,739 vs $5,914). The Enterprise segment also churns more (20.9% vs 15.5%). Prioritize retention for Country_4 and Enterprise customers with higher balances.
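The demo samples with `temperature=0.8` and `top_p=0.95`. To show what those settings do to a single next-token distribution, here is a minimal pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering. It illustrates the technique only; Transformers applies it to logit tensors inside `generate()` and differs in details. `nucleus_filter` and the toy logits are hypothetical.

```python
import math
from typing import Dict

def nucleus_filter(logits: Dict[str, float], temperature: float = 0.8, top_p: float = 0.95) -> Dict[str, float]:
    """Temperature-scale logits, softmax them, keep the smallest most-probable set whose
    cumulative probability reaches top_p, and renormalize over that set."""
    scaled = {tok: value / temperature for tok, value in logits.items()}
    m = max(scaled.values())  # subtract max for numerical stability
    exps = {tok: math.exp(v - m) for tok, v in scaled.items()}
    z = sum(exps.values())
    ranked = sorted(((tok, e / z) for tok, e in exps.items()), key=lambda kv: kv[1], reverse=True)
    kept, cum = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cum += p
        if cum >= top_p:
            break
    mass = sum(p for _, p in kept)
    return {tok: p / mass for tok, p in kept}

toy_logits = {"a": 2.0, "b": 1.0, "c": -3.0}
print(nucleus_filter(toy_logits, temperature=1.0, top_p=0.9))  # "c" is filtered out
```

Lower temperatures sharpen the distribution before filtering, so fewer tokens survive a given `top_p`.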




## 10. (Optional) Unfreeze GPT-2 for End-to-End Fine-tuning

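If you have enough GPU memory, you can unfreeze GPT-2 so the decoder's own weights also adapt to the soft prompts and the insight text. Training is slower and needs more memory, but it may improve quality.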


In [None]:

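# To unfreeze GPT-2, re-initialize the model with freeze_gpt2=False, then re-run training.
# model = Table2TextModel(gpt_model_name=tokenizer_name, freeze_gpt2=False, soft_K=20).to(device)
# optim = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-5, weight_decay=0.01)
# Rebuild the scheduler too: run_epoch() steps the global `sched`, which is tied to the old optimizer.
# sched = get_linear_schedule_with_warmup(optim, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps)
# ... then re-run the training loop; the smaller LR helps keep updates to the pretrained weights stable.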
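When GPT-2 is unfrozen, a single small learning rate also slows down the encoder, aggregator, and projector, which start from random initialization. One option (an assumption, not something the notebook does) is to give the pretrained weights and the new modules separate learning rates using PyTorch parameter groups. The `decoder.gpt2.` prefix follows from the attribute names in `Table2TextModel` (`self.decoder`) and `InsightDecoder` (`self.gpt2`); `build_param_groups` is a hypothetical helper.

```python
from typing import Any, Dict, Iterable, List, Tuple

def build_param_groups(named_params: Iterable[Tuple[str, Any]],
                       gpt2_prefix: str = "decoder.gpt2.",
                       gpt2_lr: float = 1e-5,
                       new_lr: float = 3e-4) -> List[Dict[str, Any]]:
    """Split trainable parameters into a pretrained-GPT-2 group and a group for the new modules."""
    gpt2_params, new_params = [], []
    for name, param in named_params:
        if not getattr(param, "requires_grad", True):
            continue  # frozen parameters stay out of the optimizer
        (gpt2_params if name.startswith(gpt2_prefix) else new_params).append(param)
    groups = [{"params": gpt2_params, "lr": gpt2_lr}, {"params": new_params, "lr": new_lr}]
    return [g for g in groups if g["params"]]  # drop empty groups (e.g. GPT-2 still frozen)
```

With the model from Section 7, this could be passed as `torch.optim.AdamW(build_param_groups(model.named_parameters()), weight_decay=0.01)`; each group's `"lr"` overrides the optimizer's default learning rate.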



## 11. Next Steps & Extensions

- **Richer Encoders:** Try feature-wise masking, residual connections across row/feature blocks, or pretrained tabular encoders.  
- **Better Aggregation:** Multi-head attention pooling, Set Transformers, or learn multiple table tokens (a *set* of soft prompts).  
- **Multiple Insight Types:** Train with multi-target texts (e.g., churn, fraud, revenue) using task tokens.  
- **Scaling:** Use `distilgpt2` for faster iteration or larger GPT-2 variants for more capacity; add gradient checkpointing; train on real tables with curated insight texts.  
- **Evaluation:** ROUGE/BLEU for text, plus *factuality checks* comparing generated statements to table stats.
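Because the targets are templated, a first factuality check can simply compare the numbers in a generated insight with those in the reference `insight` string from `generate_one_table`. A minimal sketch, assuming hypothetical tolerances (1 percentage point, $50) and helper names; it checks figures only, not which country or segment is named, and the `generated` string is an illustrative edit of a reference, not real model output.

```python
import re
from typing import List, Tuple

PCT_RE = re.compile(r"(\d+(?:\.\d+)?)%")        # e.g. "18.2%"
USD_RE = re.compile(r"\$(\d[\d,]*(?:\.\d+)?)")  # e.g. "$1,553"

def extract_figures(text: str) -> Tuple[List[float], List[float]]:
    """Percentages and dollar amounts, in order of appearance."""
    pcts = [float(m) for m in PCT_RE.findall(text)]
    usd = [float(m.replace(",", "")) for m in USD_RE.findall(text)]
    return pcts, usd

def figures_match(generated: str, reference: str,
                  pct_tol: float = 1.0, usd_tol: float = 50.0) -> bool:
    """True if both texts quote the same number of figures and each pair is within tolerance."""
    gen_pct, gen_usd = extract_figures(generated)
    ref_pct, ref_usd = extract_figures(reference)
    if len(gen_pct) != len(ref_pct) or len(gen_usd) != len(ref_usd):
        return False
    return (all(abs(g - r) <= pct_tol for g, r in zip(gen_pct, ref_pct))
            and all(abs(g - r) <= usd_tol for g, r in zip(gen_usd, ref_usd)))

reference = ("Customers from Country_0 show elevated churn (18.2% vs 13.3% elsewhere). "
             "Average balances are higher in Country_0 ($1,553 vs $1,000).")
generated = reference.replace("$1,000", "$5,914")
print(extract_figures(reference))           # ([18.2, 13.3], [1553.0, 1000.0])
print(figures_match(generated, reference))  # False
```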



## 12. Environment Info


In [13]:

import torch, transformers, platform
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("python:", sys.version)
print("platform:", platform.platform())

torch: 2.8.0+cpu
transformers: 4.55.2
python: 3.13.2 (tags/v3.13.2:4f8bb39, Feb  4 2025, 15:23:48) [MSC v.1942 64 bit (AMD64)]
platform: Windows-11-10.0.26100-SP0
