In [1]:
import sys
# Install the CUDA 12.1 PyTorch wheel into this kernel's own interpreter (sys.executable),
# so the package lands in the environment the notebook actually runs in.
# NOTE(review): unpinned — pin the torch version that produced the recorded results.
!{sys.executable} -m pip install torch --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121


In [2]:
import sys
!{sys.executable} -m pip install scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m36.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting scipy>=1.8.0
  Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m37.7/37.7 MB[0m [31m72.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting joblib>=1.2.0
  Downloading joblib-1.5.3-py3-none-any.whl (309 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.1/309.1 KB[0m [31m133.5 MB/s[0m eta [36m0:00:00[0m
Collecting threadpoolctl>=3.1.0
  Downloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
Successfully installed joblib-1.5.3 scikit-learn-1.7.2 scipy-1.15.3 threadpoolctl-3.6.0


In [4]:
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/all_Flaubert.txt
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/french.txt
!wget -q https://raw.githubusercontent.com/GwenTsang/tests/refs/heads/main/l.txt

In [5]:
#@title Setup & Training

import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

import re
import unicodedata
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Sequence

import numpy as np
import torch
import torch.nn as nn

# All tunable settings for data selection, normalisation, model, training and probing.
# NOTE(review): the long float values look like the output of a hyper-parameter search — record provenance.
CONFIG = {
    # Training-paragraph filters / sample size (make_training_text) and vocab cutoff (build_vocab_and_encode).
    "max_pair_dist": 280, "max_depth": 2, "n_paragraphs": 3000, "min_char_freq": 5,
    # Text normalisation switches (TextNormConfig / normalize_text).
    "lowercase": False, "fold_diacritics": True, "digit_map": "none", "collapse_whitespace": True,
    # CharLSTM architecture.
    "hidden_size": 512, "embedding_dim": 128, "n_layers": 1,
    # forget_bias is written into BOTH bias_ih and bias_hh forget slices, so the effective
    # initial forget-gate bias is 2x this value.
    "dropout": 0.17723111503731978, "forget_bias": 0.673403567309223,
    # TBPTT window length, number of parallel rows, AdamW learning rate.
    "seq_length": 256, "batch_size": 128, "lr": 0.0008848537716447663,
    # AdamW weight decay; grad_clip is the max gradient norm (0 disables clipping).
    "weight_decay": 2.5444743948884783e-06, "grad_clip": 5.0, "epochs": 40,
    # Probe: ridge penalty, keep-every-k-th-character, paragraphs sampled per corpus.
    "probe_alpha": 10.0, "probe_subsample_every": 1, "probe_max_paras_per_lang": 200,
    # Probe: characters per paragraph window, drop "(" / ")" targets, test fraction; seed drives numpy/torch RNGs.
    "probe_max_len": 400, "probe_exclude_parens": True, "probe_test_size": 0.2, "seed": 10,
}
TRAIN_FILES = ["all_Flaubert.txt", "french.txt", "l.txt"]
# Both evaluation corpora are also in TRAIN_FILES, so the probe runs on text the model has seen.
# NOTE(review): "french.txt" is tagged as the "en" corpus in the probe — confirm the file's language.
EVAL_FR_PATH, EVAL_EN_PATH = "all_Flaubert.txt", "french.txt"
# Placeholder for characters outside the vocabulary; always present in char2int.
# Assumes NUL never occurs in real text — TODO confirm.
UNK_CHAR = "\u0000"

@dataclass(frozen=True)
class ParagraphInfo:
    """One source paragraph together with summary statistics of its round brackets.

    Immutable; fields are positional in the order below.
    """

    text: str            # the stripped paragraph text
    has_pair: bool       # at least one "(" and at least one ")" occur somewhere
    balanced: bool       # every ")" closes an earlier "(" and no "(" is left open
    max_pair_dist: int   # largest index distance between a "(" and its matching ")"
    max_depth: int       # deepest nesting level of "(" reached during the scan

@dataclass(frozen=True)
class TextNormConfig:
    """Immutable set of switches consumed by ``normalize_text``."""

    lowercase: bool            # lower-case the whole string
    fold_diacritics: bool      # NFKD-decompose, then drop combining marks
    digit_map: str             # "none" keeps digits; "hash" maps them to "#"; anything else to "0"
    collapse_whitespace: bool  # squeeze space/tab runs; cap runs of 3+ newlines at two

def _analyze_parentheses(par: str) -> Tuple[bool, bool, int, int]:
    has_pair = "(" in par and ")" in par
    stack, max_dist, depth, max_depth, balanced = [], 0, 0, 0, True
    for i, ch in enumerate(par):
        if ch == "(":
            stack.append(i); depth += 1; max_depth = max(max_depth, depth)
        elif ch == ")":
            if not stack: balanced = False; break
            max_dist = max(max_dist, i - stack.pop()); depth = max(0, depth - 1)
    if stack: balanced = False
    return has_pair, balanced, max_dist, max_depth

def load_paragraphs(files: Sequence[Path]) -> List[str]:
    """Read every file and return its non-empty, whitespace-stripped paragraphs.

    Paragraphs are separated by a blank line (two newlines with only whitespace
    between them). Files are decoded as UTF-8 and undecodable bytes are silently
    dropped (``errors="ignore"``). Paragraphs keep file order, then in-file order.
    """
    blank_line = re.compile(r"\n\s*\n")
    collected = []
    for path in files:
        raw = path.read_text(encoding="utf-8", errors="ignore")
        for chunk in blank_line.split(raw):
            stripped = chunk.strip()
            if stripped:
                collected.append(stripped)
    return collected

def build_paragraph_infos(files: Sequence[Path]) -> List[ParagraphInfo]:
    """Load all paragraphs from ``files`` and attach their bracket statistics."""
    infos = []
    for paragraph in load_paragraphs(files):
        has_pair, balanced, max_dist, max_depth = _analyze_parentheses(paragraph)
        infos.append(ParagraphInfo(paragraph, has_pair, balanced, max_dist, max_depth))
    return infos

def normalize_text(s: str, cfg: TextNormConfig) -> str:
    """Apply the normalisation steps enabled in ``cfg`` to ``s``, in a fixed order.

    1. lower-casing;
    2. diacritic folding: NFKD decomposition, then removal of combining marks.
       NFKD is a *compatibility* decomposition, so it also rewrites characters such
       as ligatures, non-breaking spaces or fullwidth brackets;
    3. digit mapping: every Unicode digit becomes "#" (``"hash"``) or "0" (any other
       value except ``"none"``);
    4. whitespace collapsing: runs of spaces/tabs become one space, then runs of three
       or more newlines become exactly two.
    """
    out = s.lower() if cfg.lowercase else s
    if cfg.fold_diacritics:
        decomposed = unicodedata.normalize("NFKD", out)
        out = "".join(filter(lambda ch: not unicodedata.combining(ch), decomposed))
    if cfg.digit_map != "none":
        replacement = "#" if cfg.digit_map == "hash" else "0"
        out = re.sub(r"\d", replacement, out)
    if cfg.collapse_whitespace:
        out = re.sub(r"\n{3,}", "\n\n", re.sub(r"[ \t]+", " ", out))
    return out

def make_training_text(infos, *, max_pair_dist, max_depth, n_paragraphs, seed, norm):
    """Sample eligible paragraphs, normalise them and join them with blank lines.

    A paragraph is eligible when it contains a bracket pair, is balanced, and its
    maximum pair distance and nesting depth do not exceed the given limits. Up to
    ``n_paragraphs`` eligible paragraphs are drawn without replacement, in random
    order, using a NumPy generator seeded with ``seed``. Returns "" if none qualify.

    NOTE(review): eligibility is judged on the raw text, but normalisation happens
    afterwards; with NFKD folding, compatibility brackets (e.g. fullwidth) can become
    "(" / ")", so the bracket guarantees may not hold exactly on the returned text.
    """
    def _eligible(pi):
        return (pi.has_pair and pi.balanced
                and pi.max_pair_dist <= max_pair_dist and pi.max_depth <= max_depth)

    candidates = [pi.text for pi in infos if _eligible(pi)]
    if not candidates:
        return ""
    rng = np.random.default_rng(seed)
    n_take = min(n_paragraphs, len(candidates))
    picked = rng.choice(len(candidates), size=n_take, replace=False)
    return "\n\n".join(normalize_text(candidates[i], norm) for i in picked)

def build_vocab_and_encode(text: str, *, min_char_freq: int):
    """Build a character vocabulary from ``text`` and encode ``text`` with it.

    The vocabulary contains every character occurring at least ``min_char_freq``
    times, plus a fixed set of required symbols that are always included however
    rare or absent they are: ``UNK_CHAR``, "(", ")", "\\n" and " ". Characters are
    sorted by code point and numbered from 0.

    Returns:
        ``(char2int, encoded)`` — the character→id dict and an ``int64`` NumPy array
        of length ``len(text)``, where characters outside the vocabulary map to the
        id of ``UNK_CHAR``.
    """
    freq = {}
    for ch in text:
        freq[ch] = freq.get(ch, 0) + 1
    # Forced explicitly: UNK is the fallback id, the brackets are the probe's subject, and
    # newline/space delimit the joined paragraphs. (Previously "\n"/" " were only forced when
    # completely absent, so a rare occurrence below min_char_freq silently became UNK.)
    required = (UNK_CHAR, "(", ")", "\n", " ")
    frequent = {ch for ch, count in freq.items() if count >= min_char_freq}
    chars = sorted(frequent.union(required))
    char2int = {ch: i for i, ch in enumerate(chars)}
    unk = char2int[UNK_CHAR]
    encoded = np.fromiter((char2int.get(ch, unk) for ch in text), dtype=np.int64, count=len(text))
    return char2int, encoded

class CharLSTM(nn.Module):
    """Character-level language model: embedding -> LSTM -> dropout -> linear logits.

    Inputs are batch-first tensors of character ids. ``forward`` threads an explicit
    ``(h, c)`` state through the LSTM and returns vocabulary logits plus the new state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, n_layers, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM's own dropout acts only between stacked layers, so it is disabled
        # for a single layer; output dropout is applied separately by self.drop.
        between_layers = dropout if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_size, n_layers,
                            dropout=between_layers, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def init_hidden(self, batch_size, device):
        """Zero ``(h0, c0)``, each of shape ``(n_layers, batch_size, hidden_size)``."""
        shape = (self.n_layers, batch_size, self.hidden_size)
        return torch.zeros(*shape, device=device), torch.zeros(*shape, device=device)

    def forward(self, x, hc):
        embedded = self.embedding(x)
        lstm_out, state = self.lstm(embedded, hc)
        logits = self.fc(self.drop(lstm_out))
        return logits, state

def init_forget_gate_bias(lstm, forget_bias):
    """Fill the forget-gate slice of every bias vector of ``lstm`` with ``forget_bias``.

    PyTorch packs LSTM gate biases as (input | forget | cell | output) blocks of
    ``hidden_size`` rows each, so the forget gate is rows ``[hidden_size, 2*hidden_size)``.

    Both ``bias_ih`` and ``bias_hh`` are filled, and the LSTM adds them inside each gate,
    so the *effective* initial forget-gate bias is ``2 * forget_bias``.

    Every layer and, for bidirectional LSTMs, the ``*_reverse`` direction are covered.
    An LSTM built with ``bias=False`` has no bias parameters and is left unchanged.
    Weights and the other gate slices are not modified.
    """
    hs = lstm.hidden_size
    with torch.no_grad():
        # named_parameters() includes the *_reverse biases that a hand-built
        # "bias_ih_l{k}" / "bias_hh_l{k}" name list would miss.
        for name, param in lstm.named_parameters():
            if name.startswith("bias_"):
                param[hs:2 * hs].fill_(forget_bias)

# --- Reproducibility & device ------------------------------------------------------------
# Seed before the model is built so its initial weights are reproducible.
torch.manual_seed(CONFIG["seed"]); np.random.seed(CONFIG["seed"])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# --- Training corpus ----------------------------------------------------------------------
norm = TextNormConfig(CONFIG["lowercase"], CONFIG["fold_diacritics"], CONFIG["digit_map"], CONFIG["collapse_whitespace"])
# One ParagraphInfo per blank-line-separated paragraph across all TRAIN_FILES.
infos = build_paragraph_infos([Path(f) for f in TRAIN_FILES])
print(f"Loaded {len(infos)} paragraphs")

# Keep paragraphs with balanced, shallow, short-span brackets; sample up to n_paragraphs,
# normalise them and join with blank lines.
# NOTE(review): returns "" when nothing qualifies, and every epoch below would then be skipped.
text = make_training_text(infos, max_pair_dist=CONFIG["max_pair_dist"], max_depth=CONFIG["max_depth"],
                          n_paragraphs=CONFIG["n_paragraphs"], seed=CONFIG["seed"], norm=norm)
print(f"Training text: {len(text):,} chars")

char2int, encoded = build_vocab_and_encode(text, min_char_freq=CONFIG["min_char_freq"])
vocab_size = len(char2int)
print(f"Vocab size: {vocab_size}")

# Move the whole encoded corpus to the training device once; `encoded` is rebound from a
# NumPy array to a tensor here.
encoded = torch.tensor(encoded, dtype=torch.long, device=device)
n_tokens = encoded.numel()

# --- Model --------------------------------------------------------------------------------
model = CharLSTM(vocab_size, CONFIG["embedding_dim"], CONFIG["hidden_size"], CONFIG["n_layers"], CONFIG["dropout"]).to(device)
# Fills both bias_ih and bias_hh forget slices: effective initial forget bias is 2 * forget_bias.
init_forget_gate_bias(model.lstm, CONFIG["forget_bias"])
# Re-pack LSTM weights into a contiguous buffer so cuDNN can use its fast code path.
model.lstm.flatten_parameters()
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# --- Optimisation -------------------------------------------------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
criterion = nn.CrossEntropyLoss()
# Loss scaling is only enabled on CUDA (paired with fp16 autocast in the training loop).
scaler = torch.amp.GradScaler("cuda", enabled=device.type == "cuda")
B, T = CONFIG["batch_size"], CONFIG["seq_length"]
# Dedicated, seeded generator on the training device for the per-epoch start offsets.
rng = torch.Generator(device=device); rng.manual_seed(CONFIG["seed"])

model.train()
for epoch in range(CONFIG["epochs"]):
    # Random start offset in [0, T) so window boundaries shift from epoch to epoch.
    off = int(torch.randint(0, T, (1,), generator=rng, device=device).item())
    # Number of full B x T batches after the offset; one extra token is kept for the shifted targets.
    n_batches = (n_tokens - off - 1) // (B * T)
    # NOTE(review): the epoch is skipped silently when the corpus is shorter than one B x T batch.
    if n_batches < 1: continue
    data = encoded[off: off + n_batches * B * T + 1]
    # Cut the stream into B contiguous rows; targets are the inputs shifted by one character.
    mat, mat_y = data[:-1].reshape(B, -1), data[1:].reshape(B, -1)
    # Zero state per epoch; it is then carried across consecutive windows of each row.
    hc = model.init_hidden(B, device)
    epoch_loss = 0.0
    for step in range(mat.size(1) // T):
        x, y = mat[:, step*T:(step+1)*T], mat_y[:, step*T:(step+1)*T]
        # Truncated BPTT: keep the state values but cut the autograd graph at the window boundary.
        hc = (hc[0].detach(), hc[1].detach())
        optimizer.zero_grad(set_to_none=True)
        # fp16 autocast on CUDA only; disabled (full precision) on CPU.
        with torch.amp.autocast("cuda", dtype=torch.float16, enabled=device.type == "cuda"):
            logits, hc = model(x, hc)
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradient norm.
        if CONFIG["grad_clip"] > 0: scaler.unscale_(optimizer); nn.utils.clip_grad_norm_(model.parameters(), CONFIG["grad_clip"])
        scaler.step(optimizer); scaler.update()
        # NOTE(review): .item() synchronises with the GPU every step; accumulating on-device would avoid it.
        epoch_loss += loss.item()
    # Mean over windows of the per-character cross-entropy (nats).
    print(f"Epoch {epoch+1}/{CONFIG['epochs']} - Loss: {epoch_loss / (mat.size(1) // T):.4f}")
print("Training complete!")

Device: cuda
Loaded 23644 paragraphs
Training text: 3,640,437 chars
Vocab size: 115
Parameters: 1,388,531
Epoch 1/40 - Loss: 2.5326
Epoch 2/40 - Loss: 1.9547
Epoch 3/40 - Loss: 1.7454
Epoch 4/40 - Loss: 1.6105
Epoch 5/40 - Loss: 1.5205
Epoch 6/40 - Loss: 1.4586
Epoch 7/40 - Loss: 1.4130
Epoch 8/40 - Loss: 1.3772
Epoch 9/40 - Loss: 1.3479
Epoch 10/40 - Loss: 1.3229
Epoch 11/40 - Loss: 1.3014
Epoch 12/40 - Loss: 1.2814
Epoch 13/40 - Loss: 1.2640
Epoch 14/40 - Loss: 1.2479
Epoch 15/40 - Loss: 1.2339
Epoch 16/40 - Loss: 1.2208
Epoch 17/40 - Loss: 1.2087
Epoch 18/40 - Loss: 1.1979
Epoch 19/40 - Loss: 1.1877
Epoch 20/40 - Loss: 1.1782
Epoch 21/40 - Loss: 1.1691
Epoch 22/40 - Loss: 1.1611
Epoch 23/40 - Loss: 1.1536
Epoch 24/40 - Loss: 1.1468
Epoch 25/40 - Loss: 1.1400
Epoch 26/40 - Loss: 1.1335
Epoch 27/40 - Loss: 1.1272
Epoch 28/40 - Loss: 1.1219
Epoch 29/40 - Loss: 1.1164
Epoch 30/40 - Loss: 1.1112
Epoch 31/40 - Loss: 1.1062
Epoch 32/40 - Loss: 1.1019
Epoch 33/40 - Loss: 1.0973
Epoch 34/40 

In [6]:
#@title Evaluation

from sklearn.model_selection import train_test_split

def inside_after_reading(s: str) -> np.ndarray:
    """Label each position of ``s`` with 1 if, after reading that character, at least
    one "(" is still open, else 0.

    Depth starts at 0 and never goes below 0, so an unmatched ")" is ignored.
    Returns an ``int64`` array of length ``len(s)``.
    """
    labels = np.zeros(len(s), dtype=np.int64)
    depth = 0
    for idx, ch in enumerate(s):
        if ch == "(":
            depth += 1
        elif ch == ")" and depth > 0:
            depth -= 1
        labels[idx] = int(depth > 0)
    return labels

@torch.no_grad()
def cell_states_last_layer(model, s, char2int, device, max_len):
    """Return the top LSTM layer's cell state after each character of ``s``.

    ``s`` is truncated to its first ``max_len`` characters and fed to ``model.lstm`` one
    character at a time, starting from ``model.init_hidden(1, device)``. This is needed
    because ``nn.LSTM`` only returns the *final* ``(h, c)`` of a sequence. Characters
    missing from ``char2int`` map to the id of ``UNK_CHAR``.

    Args:
        model: exposes ``embedding``, ``lstm`` and ``init_hidden`` (e.g. ``CharLSTM``).
        s: text to run through the model.
        char2int: character -> id mapping; must contain ``UNK_CHAR``.
        device: device the model lives on.
        max_len: maximum number of characters processed.

    Returns:
        float32 array of shape ``(min(len(s), max_len), hidden_size)``. Row ``t`` is
        ``c[-1, 0]`` after reading character ``t``. Empty input yields ``(0, hidden_size)``.

    NOTE: the model's train/eval mode is not changed; call ``model.eval()`` first so
    inter-layer LSTM dropout (n_layers > 1) is inactive.
    """
    if len(s) > max_len:
        s = s[:max_len]
    if not s:
        # np.stack / torch.stack reject an empty list; return a well-shaped empty array.
        return np.zeros((0, model.lstm.hidden_size), dtype=np.float32)
    unk = char2int[UNK_CHAR]
    x = torch.tensor([char2int.get(ch, unk) for ch in s], device=device, dtype=torch.long).unsqueeze(0)
    h, c = model.init_hidden(1, device)
    model.lstm.flatten_parameters()
    states = []
    for t in range(x.size(1)):
        _, (h, c) = model.lstm(model.embedding(x[:, t:t + 1]), (h, c))
        # Stay on-device: a per-step .cpu() would force a host/device sync for every character.
        states.append(c[-1, 0])
    # A single device->host transfer for the whole sequence.
    return torch.stack(states).float().cpu().numpy()

def build_probe_dataset(model, char2int, device, fr_path, en_path, cfg, norm, min_samples_per_para=10):
    """Build a train/test split of (cell state, inside-brackets label) pairs for probing.

    For each corpus (tagged "fr" for ``fr_path``, then "en" for ``en_path``):

    1. take the blank-line-separated paragraphs whose raw text contains both "(" and ")",
       sub-sampling at most ``cfg["probe_max_paras_per_lang"]`` of them;
    2. normalise each with ``norm`` and, if longer than ``cfg["probe_max_len"]``, keep a
       random window of that length;
    3. label every character with ``inside_after_reading`` and collect the model's
       last-layer cell state after it (``cell_states_last_layer``);
    4. optionally drop the "(" / ")" positions and keep only every
       ``cfg["probe_subsample_every"]``-th character;
    5. keep the paragraph only if at least ``min_samples_per_para`` characters remain.

    All randomness comes from one NumPy generator seeded with ``cfg["seed"]``.

    Returns:
        ``(X_train, X_test, y_train, y_test)`` from ``train_test_split``, stratified by the
        corpus tag; ``y`` is float32.

    Raises:
        ValueError: if no paragraph contributes any samples.

    Side effect: puts ``model`` in eval mode and leaves it there.

    NOTE(review): labels are computed on the window only, so a window that starts inside a
    parenthetical labels those characters 0. The split is per character, so characters of one
    paragraph land in both train and test.
    """
    rng = np.random.default_rng(cfg["seed"])
    model.eval()
    max_len = cfg["probe_max_len"]
    max_paras = cfg["probe_max_paras_per_lang"]

    def get_paras(path):
        # Paragraphs whose raw (un-normalised) text contains both bracket characters.
        txt = path.read_text(encoding="utf-8", errors="ignore")
        ps = [p.strip() for p in re.split(r"\n\s*\n", txt) if p.strip() and "(" in p and ")" in p]
        if len(ps) > max_paras:
            ps = [ps[i] for i in rng.choice(len(ps), size=max_paras, replace=False)]
        return ps

    Xs, ys, langs = [], [], []
    for lang, path in [("fr", fr_path), ("en", en_path)]:
        for p in get_paras(path):
            p = normalize_text(p, norm)
            if len(p) > max_len:
                # Random window of exactly max_len characters.
                start = int(rng.integers(0, len(p) - max_len + 1))
                p = p[start:start + max_len]
            y = inside_after_reading(p)
            X = cell_states_last_layer(model, p, char2int, device, max_len)
            mask = np.ones(len(p), dtype=bool)
            if cfg["probe_exclude_parens"]:
                mask &= np.array([c not in "()" for c in p])
            every = cfg["probe_subsample_every"]
            if every > 1:
                take = np.zeros(len(p), dtype=bool)
                take[::every] = True
                mask &= take
            n_kept = int(mask.sum())
            if n_kept >= min_samples_per_para:
                Xs.append(X[mask])
                ys.append(y[mask])
                langs.append(np.full(n_kept, lang, dtype=object))
    if not Xs:
        # np.concatenate would otherwise fail with an opaque "need at least one array" error.
        raise ValueError(
            f"build_probe_dataset: no paragraph in {fr_path} / {en_path} yielded at least "
            f"{min_samples_per_para} probe samples; check the files and probe_* settings"
        )
    X, y = np.concatenate(Xs), np.concatenate(ys).astype(np.float32)
    return train_test_split(X, y, test_size=cfg["probe_test_size"], random_state=cfg["seed"], stratify=np.concatenate(langs))

def ridge_r2_local(X_train, y_train, X_test, y_test, alpha):
    """Best single-feature linear probe: per-column ridge regression scored by test R².

    Each column ``j`` of ``X_train`` is used on its own to predict ``y_train`` with a
    one-variable ridge regression. The intercept comes from centring on the training
    means, so ``alpha`` penalises only the slope. Each fitted line is scored by R² on
    ``(X_test, y_test)``.

    Args:
        X_train, X_test: 2-D arrays ``(n_samples, n_features)`` (any real dtype).
        y_train, y_test: 1-D targets.
        alpha: ridge penalty (>= 0). With ``alpha == 0`` (ordinary least squares), a feature
            that is constant on the training set gets slope 0 instead of NaN.

    Returns:
        ``(best_r2, best_index)`` as ``(float, int)``. If ``y_test`` is (numerically)
        constant, every R² is defined as 0 and index 0 is returned.

    NOTE(review): the best column is chosen on the same test set its R² is reported on,
    so the score is slightly optimistic when many features are compared.
    """
    Xtr = np.asarray(X_train, dtype=np.float64)
    Xte = np.asarray(X_test, dtype=np.float64)
    ytr = np.asarray(y_train, dtype=np.float64)
    yte = np.asarray(y_test, dtype=np.float64)

    x_mean, y_mean = Xtr.mean(axis=0), ytr.mean()
    Xtr_c, ytr_c = Xtr - x_mean, ytr - y_mean

    # Closed-form 1-D ridge slope per column: <x_j, y> / (||x_j||^2 + alpha).
    num = Xtr_c.T @ ytr_c
    den = np.sum(Xtr_c ** 2, axis=0) + alpha
    # Guard 0/0 (train-constant column with alpha == 0): a NaN here would propagate
    # through max/argmax and poison the reported best score.
    w = np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    # Test SSE per column, expanded: sum((a - w x)^2) = sum(a^2) - 2 w <x, a> + w^2 ||x||^2,
    # where a is y_test minus the *training* mean (the fitted intercept).
    Xte_c, a = Xte - x_mean, yte - y_mean
    sse = np.sum(a ** 2) - 2 * (w * (Xte_c.T @ a)) + (w ** 2) * np.sum(Xte_c ** 2, axis=0)
    ss_tot = np.sum((yte - yte.mean()) ** 2)
    r2 = 1 - sse / ss_tot if ss_tot > 1e-12 else np.zeros_like(sse)
    return float(r2.max()), int(np.argmax(r2))

# Collect (cell state, inside-brackets label) pairs from the two evaluation corpora.
print("Building probe dataset...")
X_train, X_test, y_train, y_test = build_probe_dataset(model, char2int, device, Path(EVAL_FR_PATH), Path(EVAL_EN_PATH), CONFIG, norm)
print(f"Probe: {X_train.shape[0]} train, {X_test.shape[0]} test")

# Per-neuron ridge probe: the single cell-state dimension that best predicts
# "inside parentheses", scored by R² on the held-out characters.
r2_local, best_neuron = ridge_r2_local(X_train, y_train, X_test, y_test, CONFIG["probe_alpha"])
print(f"\n{'='*50}\nRESULTS\n{'='*50}")
print(f"Best R² (local): {r2_local:.4f}")
print(f"Best neuron index: {best_neuron}\n{'='*50}")

Building probe dataset...
Probe: 57448 train, 14362 test

RESULTS
Best R² (local): 0.7363
Best neuron index: 250


In [7]:
# Saving: model weights, the vocabulary and config needed to re-encode text, and the probe results.
# NOTE(review): despite the filename, this is the final-epoch model; no model selection happens here.
# NOTE(review): state_dict tensors live on `device` (CUDA here); load on a CPU-only machine with
# torch.load(..., map_location="cpu"). UNK_CHAR is not stored separately (it is a key of char2int).

checkpoint = {
    "model_state_dict": model.state_dict(),
    "char2int": char2int,
    "config": CONFIG,
    "vocab_size": len(char2int),
    "r2_local": r2_local,
    "best_neuron": best_neuron,
}

torch.save(checkpoint, "best_lstm_model.pt")
print("Saved: best_lstm_model.pt")
print(f"  Vocab size: {len(char2int)}")
print(f"  R² score: {r2_local:.4f}")
print(f"  Best neuron: {best_neuron}")

Saved: best_lstm_model.pt
  Vocab size: 115
  R² score: 0.7363
  Best neuron: 250
