In [1]:
# index_predictor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import defaultdict
from typing import List

# ======== 1) Vocabulary of candidate words + first-letter index ========
# Everything downstream (index, char vocab) assumes lowercase words.
WORDS = [
    word.lower()
    for word in (
        "log", "loss", "let", "length", "const", "class",
        "return", "function", "console", "float", "for", "if", "else",
    )
]
WORD2IDX = {word: position for position, word in enumerate(WORDS)}
IDX2WORD = {position: word for word, position in WORD2IDX.items()}

# INDEX: first letter -> list of word indices (in vocabulary order).
INDEX = defaultdict(list)
for word, position in WORD2IDX.items():
    INDEX[word[0]].append(position)

print("Index (first letter -> words):")
for letter, bucket in sorted(INDEX.items()):
    print(f"  '{letter}': {[IDX2WORD[j] for j in bucket]}")

# ======== 2) Character vocabulary (id 0 is reserved for padding) ========
PAD = "<PAD>"
CHARS = [PAD, *sorted(set("".join(WORDS)))]
CHAR2IDX = {ch: k for k, ch in enumerate(CHARS)}
IDX2CHAR = dict(enumerate(CHARS))
NUM_CHARS = len(CHARS)
NUM_WORDS = len(WORDS)

# Longest proper prefix of any word (full words are not used as inputs,
# except that a one-letter word contributes its single letter); at least 1.
MAX_PREFIX = max(1, max(len(word) for word in WORDS) - 1)

# ======== 3) Training pairs: padded prefix ids, true prefix length, word id ========
_pad_id = CHAR2IDX[PAD]
examples = []
for word in WORDS:
    target = WORD2IDX[word]
    # prefix lengths 1..len(word)-1, but always at least one example per word
    upper = max(2, min(len(word), MAX_PREFIX + 1))
    for cut in range(1, upper):
        ids = [CHAR2IDX[ch] for ch in word[:cut]]
        examples.append((ids + [_pad_id] * (MAX_PREFIX - len(ids)), len(ids), target))

print(f"\nPrepared {len(examples)} training examples (prefixes padded to length {MAX_PREFIX}).")

# order matters for minibatching only; shuffle once up front
random.shuffle(examples)

# ======== 4) Tiny model: char embeddings -> pooled -> MLP -> logits over words ========
class PrefixWordPredictor(nn.Module):
    """Score every vocabulary word from a padded character prefix.

    Characters are embedded and mean-pooled over non-padding positions (so
    character order is not used). The pooled vector goes through a
    one-hidden-layer ReLU MLP. Words the caller disallows receive an additive
    -1e9 penalty on their logits.
    """

    def __init__(self, num_chars, emb_dim, hidden_dim, num_words):
        super().__init__()
        # Submodule names and creation order (emb, fc1, fc2) are part of the
        # state_dict layout and of seeded initialisation, so keep them fixed.
        self.emb = nn.Embedding(num_chars, emb_dim, padding_idx=0)
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_words)

    def forward(self, prefix_idxs, prefix_lens, allowed_mask):
        """Return masked word logits of shape (B, num_words).

        Args:
            prefix_idxs: (B, L) LongTensor of character ids, 0 = padding.
            prefix_lens: (B,) true prefix lengths. Accepted for interface
                compatibility but not used: valid positions are taken from
                the non-zero ids instead.
            allowed_mask: (B, num_words) bool; True marks an allowed word
                (same first letter as the prefix).
        """
        # 1.0 at real characters, 0.0 at padding: (B, L, 1)
        keep = prefix_idxs.ne(0).float().unsqueeze(-1)
        chars = self.emb(prefix_idxs)                   # (B, L, E)
        count = keep.sum(dim=1).clamp(min=1)            # (B, 1); clamp avoids /0 on all-pad rows
        pooled = (chars * keep).sum(dim=1) / count      # (B, E)
        scores = self.fc2(F.relu(self.fc1(pooled)))     # (B, num_words)
        # Finite additive penalty (rather than -inf) for disallowed words.
        penalty = (~allowed_mask).float() * -1e9
        return scores + penalty

# ======== 5) Helpers: make allowed_mask per example, encode prefixes, predict function ========
def make_allowed_mask(prefixs_chars: List[str]) -> torch.Tensor:
    """Build the per-row mask of words each prefix is allowed to predict.

    A word is allowed when it starts with the prefix's first letter, i.e. it
    sits in that letter's ``INDEX`` bucket.

    Args:
        prefixs_chars: real (unpadded) prefix strings, one per batch row.

    Returns:
        Bool tensor of shape ``(len(prefixs_chars), NUM_WORDS)``.
        An empty prefix allows every word. A prefix whose first letter starts
        no word allows none; callers should check ``pre[0] in INDEX`` first if
        they need a meaningful prediction. An empty list yields a
        ``(0, NUM_WORDS)`` tensor.
    """
    masks = torch.zeros((len(prefixs_chars), NUM_WORDS), dtype=torch.bool)
    for row, pre in enumerate(prefixs_chars):
        if not pre:
            masks[row] = True
            continue
        # INDEX is a defaultdict: INDEX[letter] would *insert* an empty bucket
        # for an unknown letter and mutate the global index; .get() does not.
        allowed_indices = INDEX.get(pre[0], [])
        if allowed_indices:
            masks[row, allowed_indices] = True
    return masks

def encode_prefix(prefix: str):
    """Map ``prefix`` to character ids, right-padded with the PAD id to MAX_PREFIX.

    Raises KeyError for characters outside the vocabulary. A prefix longer
    than MAX_PREFIX comes back with no padding and no truncation.
    """
    ids = list(map(CHAR2IDX.__getitem__, prefix))
    ids.extend([CHAR2IDX[PAD]] * (MAX_PREFIX - len(ids)))
    return ids

def pretty_predictions(logits, topk=5):
    """Return the top-``topk`` words with softmax probabilities for each row.

    The result is a list with one entry per batch row. Each entry is a list
    of ``(word, prob)`` tuples, highest probability first. ``topk`` is capped
    at the number of words.
    """
    k = min(topk, logits.size(-1))
    best_probs, best_ids = F.softmax(logits, dim=-1).topk(k, dim=-1)
    return [
        [(IDX2WORD[j], p) for j, p in zip(ids_row, probs_row)]
        for ids_row, probs_row in zip(best_ids.tolist(), best_probs.tolist())
    ]

# ======== 6) Training (tiny) ========
device = torch.device("cpu")
model = PrefixWordPredictor(num_chars=NUM_CHARS, emb_dim=16, hidden_dim=32, num_words=NUM_WORDS).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.02)
loss_fn = nn.CrossEntropyLoss(reduction='mean')

BATCH = 8
EPOCHS = 150

# prepare tensors for quick training (rows aligned with `examples`)
Xs = torch.tensor([ex[0] for ex in examples], dtype=torch.long)
prefix_lens = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
Ys = torch.tensor([ex[2] for ex in examples], dtype=torch.long)

# Each example's allowed-word mask depends only on its (fixed) prefix, so build
# all of them once instead of decoding padded prefixes on every batch of every
# epoch. Row i corresponds to Xs[i].
_pad_id = CHAR2IDX[PAD]
train_prefix_strs = [
    "".join(IDX2CHAR[c] for c in ex[0] if c != _pad_id) for ex in examples
]
train_allowed_masks = make_allowed_mask(train_prefix_strs)  # (N, NUM_WORDS) bool

for epoch in range(EPOCHS):
    perm = torch.randperm(len(Xs))
    total_loss = 0.0
    for i in range(0, len(Xs), BATCH):
        idx = perm[i:i + BATCH]
        xb = Xs[idx].to(device)
        yb = Ys[idx].to(device)
        allowed_mask = train_allowed_masks[idx].to(device)
        logits = model(xb, prefix_lens[idx].to(device), allowed_mask)
        loss = loss_fn(logits, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # loss is a batch mean; weight by batch size so the epoch figure is a
        # per-example average even when the last batch is short
        total_loss += loss.item() * xb.size(0)
    if (epoch + 1) % 50 == 0 or epoch == 0:
        print(f"Epoch {epoch+1}/{EPOCHS} - avg loss: {total_loss/len(Xs):.4f}")

print("Training finished.\n")

# ======== 7) Interactive-like demo: simulate typing a word and measuring errors ========
def simulate_typing(true_word: str):
    """Reveal ``true_word`` one letter at a time and print the model's view.

    For every proper prefix (lengths 1..len(true_word)-1) this prints:
      * the top candidates and their probabilities,
      * p(true word) and its negative log-likelihood,
      * the probability mass inside the first-letter bucket, and an
        "index error" = 1 - p(true) / bucket mass,
      * from the second prefix on, KL(previous || current) restricted to the
        bucket words.
    A one-letter word prints only the header line.

    Raises:
        KeyError: if ``true_word`` is not in the vocabulary.
    """
    print(f"Simulating typing for true word: '{true_word}'")
    true_idx = WORD2IDX[true_word]
    # Every prefix shares the word's first letter, so the candidate bucket is
    # the same at every step; look it up once.
    allowed_indices = INDEX[true_word[0]]
    allowed = torch.tensor(allowed_indices)
    prev_probs = None
    for step in range(1, len(true_word)):  # reveal letters one by one (stop before full word)
        prefix = true_word[:step]
        enc = torch.tensor([encode_prefix(prefix)], dtype=torch.long).to(device)
        allowed_mask = make_allowed_mask([prefix]).to(device)
        lens = torch.tensor([len(prefix)], device=device)
        # pure inference: no autograd graph needed
        with torch.no_grad():
            logits = model(enc, lens, allowed_mask)
            probs = F.softmax(logits, dim=-1).squeeze(0).cpu()
        top = torch.topk(probs, k=min(6, NUM_WORDS))
        # metrics
        p_true = probs[true_idx].item()
        nll = -torch.log(torch.clamp(probs[true_idx], min=1e-9)).item()
        # index error: 1 - share of the bucket's probability mass on the true word
        allowed_mass = probs[allowed_indices].sum().item()
        index_error = 1.0 - p_true / allowed_mass if allowed_mass > 0 else 1.0
        print(f"\nPrefix '{prefix}' -> top candidates (probabilities):")
        for w_i, p in zip(top.indices.tolist(), top.values.tolist()):
            print(f"   {IDX2WORD[w_i]:10s} : {p:.3f}")
        print(f"  -> p(true='{true_word}') = {p_true:.4f}, NLL = {nll:.3f}")
        print(f"  -> allowed mass (same-first-letter) = {allowed_mass:.3f}, index_error = {index_error:.3f}")
        if prev_probs is not None:
            # KL(prev || cur) over the bucket only; clamp avoids log(0)
            prev_sub = torch.clamp(prev_probs[allowed], 1e-9, 1.0)
            cur_sub = torch.clamp(probs[allowed], 1e-9, 1.0)
            kl = (prev_sub * (prev_sub.log() - cur_sub.log())).sum().item()
            print(f"  KL(prev->cur) over allowed words = {kl:.4f}")
        prev_probs = probs

# test on two words with same first letter 'l': 'log' and 'loss'
simulate_typing("log")
simulate_typing("loss")

# ======== 8) Example: use only first-letter to restrict candidates (index behavior) ========
def predict_from_single_letter(letter: str, topk=6):
    """Print the model's top-``topk`` candidates given only ``letter`` as the prefix.

    If ``letter`` is empty, or no vocabulary word starts with it, a notice is
    printed instead. An unknown letter would otherwise produce an all-masked,
    meaningless distribution.
    """
    prefix = letter
    if not prefix:
        print("\nPlease provide at least one letter.")
        return
    # `in` on a defaultdict does not insert a key (unlike INDEX[...])
    if prefix[0] not in INDEX:
        print(f"\nNo words starting with '{prefix[0]}'.")
        return
    enc = torch.tensor([encode_prefix(prefix)], dtype=torch.long).to(device)
    allowed_mask = make_allowed_mask([prefix]).to(device)
    with torch.no_grad():  # inference only
        logits = model(enc, torch.tensor([1], device=device), allowed_mask)
    outs = pretty_predictions(logits, topk=topk)[0]
    print(f"\nGiven only first letter '{letter}': top candidates:")
    for w, p in outs:
        print(f"  {w:10s} : {p:.3f}")

predict_from_single_letter("l")
predict_from_single_letter("r")


Index (first letter -> words):
  'c': ['const', 'class', 'console']
  'e': ['else']
  'f': ['function', 'float', 'for']
  'i': ['if']
  'l': ['log', 'loss', 'let', 'length']
  'r': ['return']

Prepared 48 training examples (prefixes padded to length 7).
Epoch 1/150 - avg loss: 1.0516
Epoch 50/150 - avg loss: 0.4251
Epoch 100/150 - avg loss: 0.4130
Epoch 150/150 - avg loss: 0.4127
Training finished.

Simulating typing for true word: 'log'

Prefix 'l' -> top candidates (probabilities):
   log        : 0.284
   loss       : 0.244
   let        : 0.236
   length     : 0.235
   const      : 0.000
   class      : 0.000
  -> p(true='log') = 0.2844, NLL = 1.257
  -> allowed mass (same-first-letter) = 1.000, index_error = 0.716

Prefix 'lo' -> top candidates (probabilities):
   log        : 0.589
   loss       : 0.408
   length     : 0.002
   let        : 0.001
   const      : 0.000
   class      : 0.000
  -> p(true='log') = 0.5889, NLL = 0.529
  -> allowed mass (same-first-letter) = 1.000, ind

In [None]:
# interactive_index_predictor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from collections import defaultdict

# ---------------------------
# 1) Small word list + index
# ---------------------------
WORDS = ["log", "loss", "let", "length", "const", "class",
         "return", "function", "console", "float", "for", "if", "else"]
WORDS = [w.lower() for w in WORDS]
WORD2IDX = {w: i for i, w in enumerate(WORDS)}
IDX2WORD = {i: w for w, i in WORD2IDX.items()}

# first letter -> list of word indices
INDEX = defaultdict(list)
for w, idx in WORD2IDX.items():
    INDEX[w[0]].append(idx)

# ---------------------------
# 2) char vocabulary + padding
# ---------------------------
PAD = "<PAD>"
CHARS = sorted({c for w in WORDS for c in w})
CHARS = [PAD] + CHARS  # PAD gets id 0
CHAR2IDX = {c: i for i, c in enumerate(CHARS)}
IDX2CHAR = {i: c for c, i in CHAR2IDX.items()}
NUM_CHARS = len(CHARS)
NUM_WORDS = len(WORDS)
MAX_PREFIX = max(1, max(len(w) - 1 for w in WORDS))

# ---------------------------
# 3) dataset: prefixes -> word
# ---------------------------
examples = []
for w in WORDS:
    wi = WORD2IDX[w]
    # Proper prefixes of length 1..len(w)-1. A one-letter word would get no
    # example at all from range(1, len(w)), making it unlearnable, so it
    # still contributes its single letter (same rule as the first script).
    for L in range(1, max(2, len(w))):
        prefix = w[:L]
        idxs = [CHAR2IDX[ch] for ch in prefix]
        pad = [CHAR2IDX[PAD]] * (MAX_PREFIX - len(idxs))
        examples.append((idxs + pad, len(prefix), wi))
random.shuffle(examples)

# row i of each tensor corresponds to examples[i]
Xs = torch.tensor([ex[0] for ex in examples], dtype=torch.long)
prefix_lens = torch.tensor([ex[1] for ex in examples], dtype=torch.long)
Ys = torch.tensor([ex[2] for ex in examples], dtype=torch.long)

# ---------------------------
# 4) tiny model
# ---------------------------
class PrefixWordPredictor(nn.Module):
    """Bag-of-characters prefix encoder followed by an MLP over the word vocabulary.

    Args:
        num_chars: size of the character vocabulary (id 0 is padding).
        emb_dim: character embedding width.
        hidden_dim: width of the hidden ReLU layer.
        num_words: number of output word logits.
    """

    def __init__(self, num_chars, emb_dim, hidden_dim, num_words):
        super().__init__()
        # keep names/order: they define state_dict keys and seeded init order
        self.emb = nn.Embedding(num_chars, emb_dim, padding_idx=0)
        self.fc1 = nn.Linear(emb_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_words)

    def forward(self, prefix_idxs, prefix_lens, allowed_mask):
        """Return (B, num_words) logits, with disallowed words pushed down by 1e9.

        ``prefix_idxs`` holds (B, L) character ids, 0 = padding. ``prefix_lens``
        is accepted but unused; valid positions come from the non-zero ids.
        ``allowed_mask`` is a (B, num_words) bool tensor, True = allowed.
        """
        is_char = (prefix_idxs != 0).float()[..., None]        # (B, L, 1)
        n_chars = torch.clamp(is_char.sum(dim=1), min=1)       # (B, 1), no /0
        mean_emb = torch.sum(self.emb(prefix_idxs) * is_char, dim=1) / n_chars
        hidden = F.relu(self.fc1(mean_emb))
        logits = self.fc2(hidden)
        blocked = (~allowed_mask).float()
        return logits + blocked * -1e9

device = torch.device("cpu")
model = PrefixWordPredictor(NUM_CHARS, emb_dim=16, hidden_dim=32, num_words=NUM_WORDS).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.02)
loss_fn = nn.CrossEntropyLoss()

# Allowed-word mask per training example = bucket of its prefix's first letter.
# Prefixes never change, so build every row once instead of decoding padded
# ids on every batch. Row i corresponds to Xs[i] (both follow `examples`).
# ids[0] is always a real character because every prefix has >= 1 char, and
# that letter starts a vocabulary word, so the INDEX lookup never adds a key.
train_allowed_masks = torch.zeros((len(examples), NUM_WORDS), dtype=torch.bool)
for row, (ids, _, _) in enumerate(examples):
    train_allowed_masks[row, INDEX[IDX2CHAR[ids[0]]]] = True

# quick training (tiny)
BATCH = 8
EPOCHS = 120
for epoch in range(EPOCHS):
    perm = torch.randperm(len(Xs))
    total = 0.0
    for i in range(0, len(Xs), BATCH):
        batch_idx = perm[i:i + BATCH]
        xb = Xs[batch_idx].to(device)
        yb = Ys[batch_idx].to(device)
        allowed_mask = train_allowed_masks[batch_idx].to(device)
        logits = model(xb, prefix_lens[batch_idx].to(device), allowed_mask)
        loss = loss_fn(logits, yb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # batch-mean loss weighted by batch size -> per-example epoch average
        total += loss.item() * xb.size(0)
    if (epoch + 1) % 40 == 0 or epoch == 0:
        print(f"[train] epoch {epoch+1}/{EPOCHS} avg loss {total/len(Xs):.4f}")

print("Training done.\n")

# ---------------------------
# 5) helper functions
# ---------------------------
def encode_prefix(prefix: str):
    """Encode ``prefix`` as character ids right-padded with the PAD id to ``MAX_PREFIX``.

    Characters outside the vocabulary map to id 0 (the padding id). Prefixes
    longer than ``MAX_PREFIX`` are returned without padding or truncation.
    """
    encoded = [CHAR2IDX.get(ch, 0) for ch in prefix]
    pad_id = CHAR2IDX[PAD]
    return encoded + [pad_id] * max(0, MAX_PREFIX - len(encoded))

def make_allowed_mask(prefix_list):
    """Return a ``(len(prefix_list), NUM_WORDS)`` bool mask of allowed words.

    Row b allows the words whose first letter equals ``prefix_list[b][0]``.
    An empty prefix allows every word (the app itself requires >= 1 letter).
    A first letter that starts no word allows none. An empty list yields a
    ``(0, NUM_WORDS)`` tensor instead of failing in ``torch.stack``.
    """
    mask = torch.zeros((len(prefix_list), NUM_WORDS), dtype=torch.bool)
    for b, pre in enumerate(prefix_list):
        if not pre:
            mask[b] = True
            continue
        # INDEX is a defaultdict: INDEX[letter] would insert an empty bucket
        # for an unknown letter and mutate the global index; .get() does not.
        bucket = INDEX.get(pre[0], [])
        if bucket:
            mask[b, bucket] = True
    return mask

def pretty_print_preds(logits, topk=6):
    """Return ``(top, probs)`` for a single-row batch of word logits.

    ``probs`` is the softmax over words, squeezed on dim 0, detached and moved
    to CPU. ``top`` is a list of up to ``topk`` ``(word, prob)`` pairs,
    highest probability first.
    """
    probs = F.softmax(logits, dim=-1).squeeze(0).detach().cpu()
    values, indices = torch.topk(probs, min(topk, probs.size(0)))
    top = list(zip(
        (IDX2WORD[j] for j in indices.tolist()),
        values.tolist(),
    ))
    return top, probs

# ---------------------------
# 6) interactive loop
# ---------------------------
def _read_line(prompt):
    """Return the stripped, lower-cased reply to ``prompt``; None on EOF or Ctrl-C."""
    try:
        return input(prompt).strip().lower()
    except (EOFError, KeyboardInterrupt):
        # stdin closed (e.g. non-interactive run) or user interrupt: treat as quit
        return None


def interactive_loop():
    """Read prefixes from stdin and print the model's candidates for each one.

    The user may first enter a ground-truth target word. When it is in the
    vocabulary, every prediction also reports p(true), NLL, rank and index
    error. Typing 'quit', closing stdin, or interrupting ends the loop.
    """
    print("Interactive index-first predictor.")
    print("Words in index:", WORDS)
    print("Type single or multiple letters (prefix). Type 'quit' to exit.")
    print("Optionally enter a ground-truth target word (must be in the word list) to compute errors.")
    target = _read_line("Enter ground-truth target word (or press Enter to skip): ")
    if target is None:
        return
    if target == "":
        target = None
    elif target not in WORD2IDX:
        print("Target not in word list. Ignoring.")
        target = None
    print("---- start typing prefixes ----")
    while True:
        pre = _read_line("prefix> ")
        if pre is None or pre == "quit":
            break
        if len(pre) == 0:
            print("Please type at least the first letter.")
            continue
        # `in` on a defaultdict does not insert a key
        if pre[0] not in INDEX:
            print(f"No words starting with '{pre[0]}'.")
            continue
        enc = torch.tensor([encode_prefix(pre)], dtype=torch.long).to(device)
        allowed = make_allowed_mask([pre]).to(device)
        with torch.no_grad():  # inference only; skip building an autograd graph
            logits = model(enc, torch.tensor([len(pre)], device=device), allowed)
        preds, probs = pretty_print_preds(logits, topk=6)
        allowed_indices = INDEX[pre[0]]
        allowed_mass = probs[allowed_indices].sum().item()
        print(f"\nTop candidates for prefix '{pre}':")
        for w, p in preds:
            print(f"  {w:12s}  p = {p:.3f}")
        print(f"Allowed-bucket (first-letter='{pre[0]}') total mass: {allowed_mass:.3f}")
        if target:
            true_i = WORD2IDX[target]
            p_true = probs[true_i].item()
            nll = -torch.log(torch.clamp(probs[true_i], 1e-9)).item()
            index_error = 1.0 - (p_true / allowed_mass) if allowed_mass > 0 else 1.0
            # 1-based rank: number of words strictly more probable, plus one
            rank = (probs > probs[true_i]).sum().item() + 1
            print(f"Metrics (true = '{target}'):")
            print(f"  p(true) = {p_true:.4f}, NLL = {nll:.3f}, rank = {rank}, index_error = {index_error:.3f}")
        print("\n---")


if __name__ == "__main__":
    interactive_loop()


[train] epoch 1/120 avg loss 1.0385
[train] epoch 40/120 avg loss 0.4339
[train] epoch 80/120 avg loss 0.4101
[train] epoch 120/120 avg loss 0.4059
Training done.

Interactive index-first predictor.
Words in index: ['log', 'loss', 'let', 'length', 'const', 'class', 'return', 'function', 'console', 'float', 'for', 'if', 'else']
Type single or multiple letters (prefix). Type 'quit' to exit.
Optionally enter a ground-truth target word (must be in the word list) to compute errors.


Enter ground-truth target word (or press Enter to skip):  l


Target not in word list. Ignoring.
---- start typing prefixes ----


prefix>  lo



Top candidates for prefix 'lo':
  loss          p = 0.551
  log           p = 0.448
  length        p = 0.001
  let           p = 0.000
  const         p = 0.000
  class         p = 0.000
Allowed-bucket (first-letter='l') total mass: 1.000

---


prefix>  r



Top candidates for prefix 'r':
  return        p = 1.000
  log           p = 0.000
  loss          p = 0.000
  let           p = 0.000
  length        p = 0.000
  const         p = 0.000
Allowed-bucket (first-letter='r') total mass: 1.000

---


prefix>  re



Top candidates for prefix 're':
  return        p = 1.000
  log           p = 0.000
  loss          p = 0.000
  let           p = 0.000
  length        p = 0.000
  const         p = 0.000
Allowed-bucket (first-letter='r') total mass: 1.000

---


prefix>  ro



Top candidates for prefix 'ro':
  return        p = 1.000
  log           p = 0.000
  loss          p = 0.000
  let           p = 0.000
  length        p = 0.000
  const         p = 0.000
Allowed-bucket (first-letter='r') total mass: 1.000

---


prefix>  l



Top candidates for prefix 'l':
  loss          p = 0.282
  log           p = 0.257
  length        p = 0.247
  let           p = 0.214
  const         p = 0.000
  class         p = 0.000
Allowed-bucket (first-letter='l') total mass: 1.000

---
