# Information-Theoretic Analysis of Protein Sequence Representations

This notebook compares the information content between traditional amino acid sequences and FoldTree2 discrete structure representations. The goal is to quantify whether our discrete structural alphabet successfully captures and preserves structural information that complements or enhances sequence-based representations.

## Key Analyses

- **Sequence-level**: Direct comparison of entropy rates between AA and DSR sequences
- **K-mer analysis**: Information content in local sequence contexts using k-mer frequency distributions
- **Family-level MSAs**: Positional entropy analysis and mutual information between aligned positions
- **Cross-validation**: Robust entropy rate estimation using backoff smoothing and k-order Markov models

The analyses will reveal whether the FoldTree2 representation provides a more compressed encoding of structural information compared to amino acid sequences alone.

In [None]:
from typing import Dict, List, Optional, Tuple

import numpy as np

def entropy(p):
    """Shannon entropy, in bits, of a probability vector ``p`` (NumPy array).

    Zero-probability entries are dropped first so that 0 * log2(0)
    contributes nothing.
    """
    support = p[p > 0]
    return -np.sum(support * np.log2(support))

def empirical_probs(tokens, alphabet):
    """Additively smoothed symbol frequencies of integer-encoded ``tokens``.

    Each of the A = len(alphabet) symbols receives a pseudocount of 1/A
    (total pseudo-mass 1), so when every token is < A the result sums to 1
    and has no zeros.  ``alphabet`` is only used for its length.
    """
    n_symbols = len(alphabet)
    counts = np.bincount(tokens, minlength=n_symbols).astype(float)
    smoothed = counts + 1.0 / n_symbols
    return smoothed / (counts.sum() + 1.0)

def mutual_info(x, y, ax, ay):
    """Mutual information, in bits, between two paired integer-encoded sequences.

    Parameters
    ----------
    x, y : sequences of ints
        Paired observations; ``x[i]`` indexes an alphabet of size ``ax`` and
        ``y[i]`` one of size ``ay``.  Both must have the same length.
    ax, ay : int
        Alphabet sizes of ``x`` and ``y``.

    Returns
    -------
    float
        I(X; Y) of the additively smoothed joint distribution.  A total
        pseudo-mass of 1 is spread uniformly over the ax*ay cells, so the
        marginals of the smoothed joint equal ``empirical_probs`` applied to
        ``x`` and ``y`` separately.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length.  Without this check ``zip``
        would silently drop the unpaired tail while the marginals still
        counted it, giving an inconsistent distribution.
    """
    x = np.asarray(x, dtype=np.intp)
    y = np.asarray(y, dtype=np.intp)
    if x.shape[0] != y.shape[0]:
        raise ValueError(
            f"x and y must have the same length ({x.shape[0]} != {y.shape[0]})"
        )
    joint = np.zeros((ax, ay), dtype=float)
    np.add.at(joint, (x, y), 1.0)  # unbuffered scatter-add handles repeats
    joint = (joint + 1.0 / (ax * ay)) / (x.shape[0] + 1.0)
    # Marginals of the smoothed joint (identical to empirical_probs(x / y)).
    px = joint.sum(axis=1)
    py = joint.sum(axis=0)
    # Every cell is strictly positive thanks to the pseudocount, so no
    # masking is needed before the log.
    return float(np.sum(joint * np.log2(joint / np.outer(px, py))))

In [None]:
def ensure_same_shape(msa1: List[str], msa2: List[str]):
    """Assert that two alignments have identical shape (rows x columns).

    The first row of ``msa1`` (AA) fixes the reference width; every row of
    both alignments must match it.  Raises AssertionError on any mismatch.
    """
    assert len(msa1) == len(msa2), "MSA row count mismatch between AA and DSR."
    width = len(msa1[0])
    assert all(len(row) == width for row in msa1), "AA MSA rows must have equal length."
    assert all(len(row) == width for row in msa2), "DSR MSA must have same number of columns as AA MSA."

# ------------------------- Sequence reweighting ------------------------

def seq_identity(a: str, b: str) -> float:
    """Pairwise identity over columns where neither row has a gap ('-').

    Returns the fraction of such columns whose characters are equal, or
    0.0 when there are none.  Columns past the shorter string are ignored
    (``zip`` semantics).
    """
    shared = [(x, y) for x, y in zip(a, b) if x != '-' and y != '-']
    if not shared:
        return 0.0
    identical = sum(1 for x, y in shared if x == y)
    return identical / len(shared)

def reweight_sequences(msa: List[str], thresh: float = 0.8) -> np.ndarray:
    """
    Identity-threshold sequence reweighting (cluster-size weighting as used
    in coevolution/DCA analyses; not Henikoff position-based weighting).

    w_i = 1 / |{j: pid(i,j) >= thresh}|   (j runs over all rows, including i)

    Weights are then rescaled so they sum to n, the number of rows.
    An empty MSA yields an empty array.
    """
    n = len(msa)
    neighbours = np.zeros(n, dtype=float)
    for i in range(n):
        # Diagonal term: self-identity is 1.0 unless the row is all gaps.
        if seq_identity(msa[i], msa[i]) >= thresh:
            neighbours[i] += 1
        # seq_identity is symmetric, so each unordered pair is scored once
        # and credited to both rows (halves the number of comparisons).
        for j in range(i + 1, n):
            if seq_identity(msa[i], msa[j]) >= thresh:
                neighbours[i] += 1
                neighbours[j] += 1
    w = 1.0 / np.maximum(neighbours, 1.0)
    # normalize so sum weights ~ n (convention)
    total = w.sum()
    if total > 0:
        w *= n / total
    return w

# ----------------------- Per-position entropy -------------------------

def shannon_entropy_from_counts(counts: np.ndarray, pseudocount: float = 0.0) -> float:
    """
    Shannon entropy (bits) of a weighted count vector.

    counts: array of size A (weighted counts over alphabet)
    pseudocount: added to each of the A cells before normalizing
    returns entropy in bits (never negative).  When there is no probability
    mass at all (all-zero counts and zero pseudocount) 0.0 is returned
    instead of computing 0/0.
    """
    A = counts.shape[0]
    mass = counts.sum() + pseudocount * A
    if mass <= 0:
        # Nothing to normalize; avoids NaN and a RuntimeWarning.
        return 0.0
    p = (counts + pseudocount) / mass
    p = p[p > 0]
    # max() turns -0.0 (degenerate distributions) and tiny negative
    # rounding residue into a clean 0.0.
    return max(0.0, float(-(p * np.log2(p)).sum()))

def per_position_entropy(
    msa: List[str],
    alphabet: List[str],
    weights: np.ndarray,
    occupancy_threshold: float = 0.7,
    pseudocount: Optional[float] = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Weighted Shannon entropy (bits) of every alignment column.

    msa: aligned rows of equal length; '-' marks a gap
    alphabet: symbols to count; characters outside it are ignored
    weights: one weight per row (e.g. from reweight_sequences)
    occupancy_threshold: columns whose weighted non-gap fraction is below
        this value are skipped (entropy NaN, mask False)
    pseudocount: added to every symbol count; defaults to 1/len(alphabet)

    Returns (entropy_per_position, valid_mask) arrays of length L
    (both empty for an empty MSA).
    """
    A = len(alphabet)
    alpha_index = {c: i for i, c in enumerate(alphabet)}
    L = len(msa[0]) if msa else 0

    if pseudocount is None:
        pseudocount = 1.0 / A

    ent = np.full(L, np.nan, dtype=float)
    mask = np.zeros(L, dtype=bool)
    if L == 0:
        return ent, mask

    # occupancy per column (weighted fraction non-gap); the denominator is
    # clamped so an all-zero weight vector cannot produce 0/0 = NaN.
    occ = np.zeros(L, dtype=float)
    for i, seq in enumerate(msa):
        g = np.fromiter((c != '-' for c in seq), dtype=bool, count=L)
        occ += weights[i] * g.astype(float)
    occ /= max(1e-9, weights.sum())

    for t in range(L):
        if occ[t] < occupancy_threshold:
            continue  # too gappy for a reliable estimate
        counts = np.zeros(A, dtype=float)
        for i, seq in enumerate(msa):
            ch = seq[t]
            if ch == '-' or ch not in alpha_index:
                continue  # gaps and unknown symbols are not counted
            counts[alpha_index[ch]] += weights[i]
        if counts.sum() <= 0:
            continue
        ent[t] = shannon_entropy_from_counts(counts, pseudocount=pseudocount)
        mask[t] = True

    return ent, mask

In [None]:
# NOTE: try alignments from FoldMason, FoldTree2 (ft2), and regular MAFFT.
# The aligner could also be exposed as a parameter of each character model.

#

## K-mer Frequency Analysis for Fold Discrimination

This experiment analyzes the discriminative power of FoldTree2 (DSR) versus amino acid representations by examining k-mer frequency distributions. The goal is to determine which alphabet better distinguishes between different protein folds.

### Methodology

- **K-mer extraction**: Compute frequency distributions of subsequences of length k for both AA and DSR sequences
- **Within-family distances**: Calculate Jensen-Shannon divergences between k-mer distributions of sequences within the same fold family
- **Between-family distances**: Measure k-mer distribution differences across distinct fold families
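- **Discrimination analysis**: Compare the separation between within-family (similar folds) and between-family (different folds) distance distributions. In addition, cluster the k-mer profiles with KMeans and score agreement with the fold labels using the adjusted Rand index (ARI) and normalized mutual information (NMI).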

### Key Questions

1. **Fold specificity**: Does the FoldTree2 alphabet capture fold-specific sequence patterns more effectively than amino acid sequences?
2. **Optimal k-mer length**: What subsequence length provides the best discrimination for each representation?
3. **Distribution separation**: Which alphabet shows clearer separation between intra-fold similarity and inter-fold dissimilarity?

The analysis will reveal whether structural alphabets provide enhanced discriminative power for protein fold classification compared to traditional sequence-based representations.

In [None]:
from itertools import product
def kmer_freqs(seq, k, alpha_map, A):
    """Smoothed frequency vector of overlapping k-mers in ``seq``.

    seq: string; characters missing from ``alpha_map`` (gaps, 'X', ...) act
        as breaks, so no k-mer spans them.  Deleting them first would
        fabricate k-mers that join non-adjacent residues.
    k: k-mer length
    alpha_map: dict symbol -> index in [0, A)
    A: alphabet size

    Returns an array of length A**k that sums to 1.  k-mers are indexed in
    base A with the first symbol most significant.  A tiny constant (1e-9)
    is added to every cell so the vector has no zeros (safe for KL/JSD).
    When ``seq`` contains no valid k-mer, the uniform distribution is returned.
    """
    n_cells = A ** k
    codes = [alpha_map.get(c) for c in seq]
    counts = np.zeros(n_cells)
    for t in range(len(codes) - k + 1):
        window = codes[t:t + k]
        if any(sym is None for sym in window):
            continue  # k-mer overlaps a character outside the alphabet
        code = 0
        for sym in window:
            code = code * A + sym
        counts[code] += 1
    if counts.sum() == 0:
        return np.ones(n_cells) / n_cells
    p = counts + 1e-9
    p /= p.sum()
    return p

def jsd(p, q):
    """Jensen-Shannon divergence, in bits (range [0, 1]), of two NumPy probability vectors."""
    mid = 0.5 * (p + q)

    def _kl_bits(a, b):
        # Restrict to the support of ``a``; b (the mixture) is > 0 there.
        support = a > 0
        a_s, b_s = a[support], b[support]
        return np.sum(a_s * np.log2(a_s / b_s))

    return 0.5 * (_kl_bits(p, mid) + _kl_bits(q, mid))

In [None]:
from typing import List, Tuple, Dict
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# ---------------------------- k-mer utils ------------------------------

def build_alpha_index(alphabet: str) -> Dict[str, int]:
    """Map each symbol of ``alphabet`` to its position (a repeated symbol keeps its last position)."""
    return dict(zip(alphabet, range(len(alphabet))))

def kmers_counts(seq: str, k: int, alpha_idx: Dict[str,int], A: int) -> np.ndarray:
    """Return counts vector of length A^k for overlapping k-mers in seq (skip k-mer if any unseen char).

    k-mers are encoded in base A with the first symbol most significant.
    Degenerate inputs (k == 0 or len(seq) < k) yield an all-zero vector of
    length A**max(k, 1).
    """
    if k == 0 or len(seq) < k:
        return np.zeros(A**max(k,1), dtype=np.float64)
    counts = np.zeros(A**k, dtype=np.float64)
    drop_oldest = A ** (k - 1)  # modulus that strips the leading symbol
    code = 0  # base-A code of (up to) the last k in-alphabet symbols
    run = 0   # length of the current run of consecutive in-alphabet chars
    for ch in seq:
        sym = alpha_idx.get(ch)
        if sym is None:
            # Unknown character: restart so no k-mer spans it.
            code, run = 0, 0
            continue
        code = (code % drop_oldest) * A + sym
        run += 1
        if run >= k:
            counts[code] += 1.0
    return counts

def kmer_prob(seq: str, k: int, alphabet: str, pseudocount: float = 1e-9) -> np.ndarray:
    """Smoothed k-mer probability vector for a single sequence.

    ``pseudocount`` is added to every cell of the count vector before
    normalizing.  If the sequence contains no valid k-mer at all, a
    uniform vector of length len(alphabet)**k is returned instead.
    """
    n_sym = len(alphabet)
    counts = kmers_counts(seq, k, build_alpha_index(alphabet), n_sym)
    n_kmers = counts.sum()
    if n_kmers != 0:
        return (counts + pseudocount) / (n_kmers + pseudocount * counts.shape[0])
    # no valid k-mers: uniform distribution over all A**k k-mers
    uniform = np.ones(n_sym**k, dtype=np.float64)
    return uniform / uniform.sum()

def build_feature_matrix(fasta: List[Tuple[str,str]], alphabet: str, k: int, pseudocount: float):
    """Stack per-sequence k-mer probability vectors into a feature matrix.

    fasta: (id, sequence) pairs (e.g. from read_fasta); sequences are upper-cased
    alphabet, k, pseudocount: forwarded to kmer_prob

    Returns (ids, P) with P of shape (len(fasta), len(alphabet)**k).  An
    empty ``fasta`` yields ([], an empty (0, A**k) matrix) instead of the
    ValueError that np.vstack raises on an empty list.
    """
    ids = [name for name, _ in fasta]
    if not ids:
        return ids, np.zeros((0, len(alphabet) ** k), dtype=np.float64)
    seqs = [seq.upper() for _, seq in fasta]
    P = np.vstack([kmer_prob(s, k, alphabet, pseudocount=pseudocount) for s in seqs])
    return ids, P


In [None]:

# --------------------------- KMeans + eval -----------------------------

def kmeans_cluster(P: np.ndarray, K: int, n_init: int = 20, max_iter: int = 300, random_state: int = 0):
    """Cluster the rows of ``P`` into ``K`` groups with scikit-learn KMeans.

    Distances are plain Euclidean on the probability vectors; a Hellinger
    embedding (element-wise sqrt of P) is a common alternative for
    probability geometry.

    Returns (labels, fitted_model).
    """
    km = KMeans(
        n_clusters=K,
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
    )
    return km.fit_predict(P), km

def run_rep(
    fasta_path: str, labels_map: Dict[str,str],
    alphabet: str, k: int, pseudocount: float,
    target_clusters: int, random_state: int, n_init: int, max_iter: int
):
    """Cluster one representation's k-mer profiles and score them against labels.

    Reads ``fasta_path`` and builds k-mer probability features.  It keeps only
    sequences whose id appears in ``labels_map`` (FASTA order preserved).  It
    then runs KMeans with ``target_clusters`` clusters (or, if that is falsy,
    one cluster per distinct label) and compares the clustering to the labels.

    Returns a dict with the kept ids, feature matrix P, integer and string
    labels, predicted cluster labels, K, and the ARI / NMI scores.

    Raises ValueError if no FASTA id is present in ``labels_map``.
    """
    all_ids, all_P = build_feature_matrix(read_fasta(fasta_path), alphabet, k, pseudocount)

    # align labels and filter: keep only labelled sequences, in FASTA order
    keep = [row for row, sid in enumerate(all_ids) if sid in labels_map]
    if not keep:
        raise ValueError("No IDs from FASTA matched labels.tsv")
    ids = [all_ids[row] for row in keep]
    P = all_P[keep]
    y = np.array([labels_map[sid] for sid in ids])

    # Integer-encode labels in sorted order so the encoding is deterministic.
    label_code = {lab: code for code, lab in enumerate(sorted(set(y)))}
    y_int = np.array([label_code[lab] for lab in y], dtype=int)
    K = target_clusters or len(label_code)

    labels_pred, _ = kmeans_cluster(
        P, K=K, n_init=n_init, max_iter=max_iter, random_state=random_state
    )

    return dict(
        ids=ids,
        P=P,
        y_int=y_int,
        y_str=y,
        labels_pred=labels_pred,
        K=K,
        ari=adjusted_rand_score(y_int, labels_pred),
        nmi=normalized_mutual_info_score(y_int, labels_pred),
    )

## Entropy Rate Estimation with k-order Markov Models

This experiment estimates the global entropy rate using k-order Markov models across multiple protein families. The analysis compares two different sequence representations:

- **AA sequences**: Traditional amino acid sequences using the 20-letter alphabet
- **DSR sequences**: Discrete structure representation using a K-token alphabet

### Methodology

The experiment uses a **backoff smoothing approach** with cross-validation to estimate entropy rates:

1. **k-order Markov modeling**: Models conditional probabilities P(x|context) where context length = k
2. **Backoff smoothing**: Handles sparse data by interpolating between different order models (k-gram → (k-1)-gram → ... → unigram)
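3. **5-fold cross-validation**: Within each family, sequences are shuffled and split into five folds. Each fold is scored by a model trained on the remaining sequences, so held-out sequences are never seen during training.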
4. **Additive smoothing**: Regularizes maximum likelihood estimates with parameter α

### Aggregation Strategy

Results are aggregated at two levels:
- **Macro-averaging**: Equal weight per family (family-centric view)
- **Micro-averaging**: Weight by total tokens (sequence-centric view)

This allows comparison of structural vs. sequence-based entropy rates across different context lengths (k=0,1,2,3,4).

In [None]:
# Example driver for the entropy-rate experiment.
# NOTE: run_entropy_over_families is defined in the next cell, and K (the
# number of DSR tokens) must already be defined; run those cells first.

# 1) Prepare your data as dict: family -> list of sequences (strings).
#    The entries below are placeholders; add real families the same way,
#    e.g. "PF00002": ["...", "..."].  Characters outside the alphabet
#    (such as the '.' padding here) are dropped by tokenize().
families_AA  = {"PF00001": ["MKT...", "MSS..."]}
families_DSR = {"PF00001": ["QAB...", "QAA..."]}

# 2) Define alphabets
alphabet_AA  = list("ACDEFGHIKLMNPQRSTVWY")       # or include 'X' if you keep it
alphabet_DSR = [chr(i) for i in range(65, 65+K)]   # e.g., 'A'.. for K tokens, or your actual token set

# 3) Run
ks = (0,1,2,3,4)
aa_res, aa_agg   = run_entropy_over_families(families_AA,  alphabet_AA,  k_values=ks, alpha=0.1, delta=None, folds=5)
dsr_res, dsr_agg = run_entropy_over_families(families_DSR, alphabet_DSR, k_values=ks, alpha=0.1, delta=None, folds=5)

# 4) Compare and plot:
#   - aa_agg[k]['macro'] vs dsr_agg[k]['macro']
#   - per-family deltas: {fam: dsr_res[k][fam]-aa_res[k][fam]}

In [None]:
from collections import defaultdict, Counter
import random, math
from typing import List, Dict, Tuple

# ---------- utils
def tokenize(seq: str, alpha_map: Dict[str,int]) -> List[int]:
    """Encode ``seq`` as integer codes, silently dropping characters not in ``alpha_map``."""
    known = (ch for ch in seq if ch in alpha_map)
    return [alpha_map[ch] for ch in known]

def k_context(stream: List[int], k: int):
    """Yield ``(context, symbol)`` pairs for an order-k Markov model.

    ``context`` is a tuple of the k symbols immediately preceding
    ``symbol``.  The first k symbols, which lack a full context, are not
    yielded.  For k == 0 every symbol is yielded with the empty context ().
    """
    if k == 0:
        for sym in stream:
            yield (), sym
        return
    history: List[int] = []
    for sym in stream:
        if len(history) >= k:
            yield tuple(history[-k:]), sym
        history.append(sym)

def add_counts(counts, stream: List[int], k: int):
    """Accumulate order-k transition counts from ``stream`` into ``counts`` in place.

    ``counts[context][symbol]`` is incremented for each pair yielded by
    k_context; ``counts`` must create missing contexts on access
    (e.g. a defaultdict of Counters).
    """
    for context, symbol in k_context(stream, k):
        per_context = counts[context]
        per_context[symbol] += 1

def build_counts(seqs: List[List[int]], k: int, A: int):
    """Order-k context counts over all sequences in ``seqs``.

    Returns ``(counts, total_tokens)``.  ``counts[context][symbol]`` is a
    defaultdict of Counters.  ``total_tokens`` counts the positions that
    have a full k-length context.  ``A`` is not used by this function.
    """
    counts = defaultdict(Counter)
    total_tokens = 0
    for s in seqs:
        add_counts(counts, s, k)
        total_tokens += max(0, len(s) - k)
    return counts, total_tokens

# ---------- smoothed conditional with simple backoff
class BackoffKModel:
    """Order-k Markov model with additive smoothing and interpolated backoff.

    P_k(x | ctx) = (1 - g) * P_mle_k(x | ctx) + g * P_{k-1}(x | ctx[1:])
    where P_mle_k(x | ctx) = (c(ctx, x) + alpha) / (c(ctx) + alpha * A)
    and g = delta / (delta + c(ctx)).  The recursion ends at the smoothed
    unigram P_0(x) = (c(x) + alpha) / (N + alpha * A).

    Probability lookups never mutate the count tables: querying an unseen
    context does not insert an empty entry, even when the tables are
    defaultdicts.
    """

    def __init__(self, counts_k_list, A: int, alpha=0.1, delta=None):
        """
        counts_k_list: list where idx j holds counts for order j (0..k);
            each maps context tuple -> Counter/dict of symbol -> count
        A: alphabet size
        alpha: additive smoothing for MLE
        delta: backoff strength; if None, set to A (alphabet size)
        """
        self.counts = counts_k_list
        self.A = A
        self.alpha = alpha
        self.delta = delta if delta is not None else A

    def p_cond(self, ctx: Tuple[int,...], x: int) -> float:
        """Smoothed P(x | ctx), with the model order taken from len(ctx)."""
        return self._p_k(len(ctx), ctx, x)

    def _smoothed(self, k: int, ctx: Tuple[int,...], x: int) -> Tuple[float, int]:
        """Additively smoothed order-k estimate of P(x | ctx) and the total c(ctx)."""
        # .get() rather than [] so evaluation never inserts unseen contexts
        # into defaultdict-based count tables.
        cnts = self.counts[k].get(ctx) or {}
        total = sum(cnts.values())
        p = (cnts.get(x, 0) + self.alpha) / (total + self.alpha * self.A)
        return p, total

    def _p_k(self, k: int, ctx: Tuple[int,...], x: int) -> float:
        # base: unigram (order 0) always uses the empty context
        if k == 0:
            return self._smoothed(0, (), x)[0]

        # order-k smoothed MLE
        p_mle, total = self._smoothed(k, ctx, x)
        # backoff weight: larger when the context has been seen less often
        gamma = self.delta / (self.delta + total)
        # back off to the suffix context (drop the oldest symbol)
        return (1 - gamma) * p_mle + gamma * self._p_k(k - 1, ctx[1:], x)

# ---------- cross-entropy (held-out)
def cross_entropy_bits(model: BackoffKModel, seqs: List[List[int]], k: int) -> float:
    """Mean held-out negative log2-likelihood per scored token.

    Each sequence is scored independently, so contexts never cross sequence
    boundaries.  Only tokens with a full k-length context are scored.
    Probabilities are floored at 1e-300 before taking the log.
    Returns 0.0 when no token can be scored.
    """
    total_bits = 0.0
    n_scored = 0
    for s in seqs:
        for ctx, x in k_context(s, k):
            total_bits -= math.log2(max(model.p_cond(ctx, x), 1e-300))
            n_scored += 1
    return total_bits / max(1, n_scored)

# ---------- 5-fold CV by sequence
def entropy_rate_cv(seqs: List[List[int]], A: int, k: int, alpha=0.1, delta=None, folds=5, seed=0):
    """Cross-validated entropy rate (bits/token) of an order-k backoff model.

    Sequences are shuffled deterministically from ``seed`` and split into
    ``folds`` contiguous, near-equal folds.  Each fold is scored by a model
    trained on the other folds.  Fold cross-entropies are micro-averaged,
    i.e. weighted by the number of scored tokens.

    The caller's list is not reordered (shuffling happens on a copy), so
    repeated calls, e.g. one per k, see identical fold assignments.
    With fewer sequences than ``folds`` the fold count drops to
    max(2, len(seqs)).
    """
    seqs = list(seqs)
    random.Random(seed).shuffle(seqs)
    n_seqs = len(seqs)
    if n_seqs < folds:
        folds = max(2, n_seqs)
    # Balanced boundaries: every fold gets floor/ceil(n_seqs/folds) sequences.
    # Ceil-sized slices could leave trailing folds empty, silently reducing
    # the effective number of folds.
    bounds = [f * n_seqs // folds for f in range(folds + 1)]
    losses = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        test = seqs[lo:hi]
        if not test:
            continue  # contributes no tokens to the average
        train = seqs[:lo] + seqs[hi:]
        # build counts for orders 0..k
        counts_k_list = [build_counts(train, j, A)[0] for j in range(k + 1)]
        model = BackoffKModel(counts_k_list, A, alpha=alpha, delta=delta)
        H = cross_entropy_bits(model, test, k)
        losses.append((H, sum(max(0, len(s) - k) for s in test)))
    # micro-average over folds
    num = sum(H * n_tok for H, n_tok in losses)
    den = sum(n_tok for _, n_tok in losses) or 1
    return num / den

# ---------- run over families and ks
def run_entropy_over_families(
    families: Dict[str, List[str]],
    alphabet: List[str],
    k_values=(0,1,2,3,4),
    alpha=0.1, delta=None, folds=5, seed=0
):
    """Cross-validated entropy rates for every family and context order k.

    families: family id -> list of raw sequences (strings)
    alphabet: symbols to keep; other characters are dropped by tokenize()
    k_values: Markov orders to evaluate
    alpha, delta, folds, seed: forwarded to entropy_rate_cv

    Returns (results, aggregates):
      results[k][family] -> entropy rate in bits/token.  Families with no
        usable sequence are omitted.
      aggregates[k] -> dict(macro=..., micro=..., n_families=...).  macro is
        the unweighted mean over families.  micro weights each family by its
        scored-token count, sum(max(0, len(s) - k)).  macro and micro are
        None when no family could be scored.
    """
    alpha_map = {c: i for i, c in enumerate(alphabet)}
    A = len(alphabet)

    # Tokenize each sequence once; drop sequences with no in-alphabet symbol.
    fam_tok = {}
    for fam, seqs in families.items():
        tokenized = (tokenize(s, alpha_map) for s in seqs)
        fam_tok[fam] = [t for t in tokenized if t]

    results = {}
    for k in k_values:
        # Pass a copy so every k sees the same fold assignment even if the
        # CV routine reorders its input in place.
        results[k] = {
            fam: entropy_rate_cv(list(seqs), A, k, alpha=alpha, delta=delta,
                                 folds=folds, seed=seed)
            for fam, seqs in fam_tok.items()
            if seqs
        }

    # aggregates
    aggregates = {}
    for k in k_values:
        fam_H = results[k]
        if not fam_H:
            aggregates[k] = dict(macro=None, micro=None, n_families=0)
            continue
        macro = sum(fam_H.values()) / len(fam_H)
        # micro weight by the number of tokens scored at this k
        weights = {fam: sum(max(0, len(s) - k) for s in fam_tok[fam]) for fam in fam_H}
        num = sum(fam_H[fam] * weights[fam] for fam in fam_H)
        den = sum(weights.values()) or 1
        aggregates[k] = dict(macro=macro, micro=num / den, n_families=len(fam_H))
    return results, aggregates


## Cross-Representation Information Analysis

This experiment investigates the **mutual information** between amino acid (AA) and discrete structure representation (DSR) sequences by training probabilistic mappings in both directions using local sequence windows.

### Approach

- **Windowed prediction**: Train small softmax (multinomial logistic-regression) classifiers to predict the target token at each alignment column from a source context window (e.g., predict the AA from a 7-token DSR window)
- **Bidirectional mapping**: Learn both DSR→AA and AA→DSR predictors to estimate cross-entropies
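- **Information bounds**: Estimate mutual information as $\hat I = H(\text{target}) - \hat H(\text{target} \mid \text{source window})$, where $\hat H$ is the held-out cross-entropy of the trained predictor. Cross-entropy upper-bounds the true conditional entropy, so $\hat I$ is a lower bound on $I(\text{target}; \text{source window})$, up to the error in estimating $H(\text{target})$.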

### Key Questions

1. **Complementarity**: How much structural information is captured in DSR that's not present in AA sequences?
2. **Redundancy**: What fraction of sequence information is already encoded in structural representations?
3. **Context dependence**: How does prediction accuracy vary with window size and MSA position?

The analysis will reveal whether the two representations contain overlapping or complementary information about protein structure and function.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Cross-representation information (Strategy 3)
- Learn small probabilistic mappers DSR→AA and AA→DSR using windowed tokens
- Estimate per-position and global H(AA), H(DSR), and conditional cross-entropies
- Derive MI lower bounds: I_hat = H - H_hat(cond)
"""

import argparse, math, random
from collections import Counter, defaultdict
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------- FASTA utils -----------------------------

def read_fasta(path: str) -> List[Tuple[str,str]]:
    """Parse a FASTA file into a list of (header, sequence) pairs.

    Headers lose the leading '>' and surrounding whitespace.  Sequence lines
    are stripped and concatenated; blank lines are ignored.  Any lines
    before the first header are discarded.
    """
    records: List[Tuple[str, str]] = []
    header = None
    chunks: List[str] = []
    with open(path, 'r') as handle:
        stripped = (raw.strip() for raw in handle)
        for line in filter(None, stripped):
            if line[0] != '>':
                chunks.append(line)
                continue
            if header is not None:
                records.append((header, ''.join(chunks)))
            header, chunks = line[1:].strip(), []
    if header is not None:
        records.append((header, ''.join(chunks)))
    return records

def ensure_same_shape(msa1: List[str], msa2: List[str]):
    """Assert that the AA (msa1) and DSR (msa2) alignments have the same shape.

    Row counts must match, and every row of both alignments must have the
    width of msa1's first row.  Raises AssertionError otherwise.
    """
    assert len(msa1) == len(msa2), "MSA row count mismatch between AA and DSR."
    n_cols = len(msa1[0])
    assert {len(row) for row in msa1} == {n_cols}, "AA MSA rows must have equal length."
    assert {len(row) for row in msa2} == {n_cols}, "DSR MSA must have same number of columns as AA MSA."

# ------------------------- Sequence reweighting ------------------------

def seq_identity(a: str, b: str) -> float:
    """Fraction of identical characters over aligned columns gap-free in both rows.

    Returns 0.0 when no column is gap-free in both.  Extra columns of the
    longer string are ignored.
    """
    pairs = [(x, y) for x, y in zip(a, b) if '-' not in (x, y)]
    return sum(x == y for x, y in pairs) / len(pairs) if pairs else 0.0

def reweight_sequences(msa: List[str], thresh: float = 0.8) -> np.ndarray:
    """Identity-threshold sequence reweighting (cluster-size weighting).

    w_i = 1 / |{j: pid(i,j) >= thresh}|   (j runs over all rows, including i)

    Weights are then rescaled to sum to n, the number of rows.  An empty MSA
    yields an empty array.
    """
    n = len(msa)
    neighbours = np.zeros(n, dtype=float)
    for i in range(n):
        # Diagonal term: self-identity is 1.0 unless the row is all gaps.
        if seq_identity(msa[i], msa[i]) >= thresh:
            neighbours[i] += 1
        # seq_identity is symmetric: score each unordered pair once and
        # credit both rows (halves the number of comparisons).
        for j in range(i + 1, n):
            if seq_identity(msa[i], msa[j]) >= thresh:
                neighbours[i] += 1
                neighbours[j] += 1
    w = 1.0 / np.maximum(neighbours, 1.0)
    total = w.sum()
    if total > 0:
        w *= n / total
    return w

# ----------------------- Entropy (empirical) ---------------------------

def shannon_entropy_from_counts(counts: np.ndarray, pseudocount: float = 0.0) -> float:
    """Shannon entropy (bits) of a weighted count vector over an alphabet of size A.

    ``pseudocount`` is added to each of the A cells before normalizing.
    Returns a non-negative float.  When there is no probability mass at all
    (all-zero counts and zero pseudocount), it returns 0.0 rather than
    dividing 0 by 0.
    """
    A = counts.shape[0]
    mass = counts.sum() + pseudocount * A
    if mass <= 0:
        # Nothing to normalize; avoids NaN and a RuntimeWarning.
        return 0.0
    p = (counts + pseudocount) / mass
    p = p[p > 0]
    # max() maps -0.0 and tiny negative rounding residue to 0.0.
    return max(0.0, float(-(p * np.log2(p)).sum()))

def per_position_entropy(msa: List[str], alphabet: List[str], weights: np.ndarray,
                         occupancy_threshold: float = 0.7,
                         pseudocount: float = None) -> Tuple[np.ndarray, np.ndarray]:
    """Weighted Shannon entropy (bits) of every alignment column.

    msa: aligned rows of equal length; '-' marks a gap
    alphabet: symbols to count; characters outside it are ignored
    weights: one weight per row (e.g. from reweight_sequences)
    occupancy_threshold: columns whose weighted non-gap fraction is below
        this value are skipped (entropy NaN, mask False)
    pseudocount: added to every symbol count; None means 1/len(alphabet)

    Returns (entropy_per_position, valid_mask), both of length L.  Both are
    empty for an empty MSA, which previously raised IndexError.
    """
    A = len(alphabet)
    if pseudocount is None:
        pseudocount = 1.0 / A
    alpha_index = {c:i for i,c in enumerate(alphabet)}
    L = len(msa[0]) if msa else 0
    ent = np.full(L, np.nan, dtype=float)
    mask = np.zeros(L, dtype=bool)
    if L == 0:
        return ent, mask

    # Weighted fraction of non-gap characters per column.
    occ = np.zeros(L, dtype=float)
    for i, seq in enumerate(msa):
        g = np.fromiter((c != '-' for c in seq), dtype=bool, count=L)
        occ += weights[i] * g.astype(float)
    occ /= max(1e-9, weights.sum())

    for t in range(L):
        if occ[t] < occupancy_threshold: continue
        counts = np.zeros(A, dtype=float)
        for i, seq in enumerate(msa):
            ch = seq[t]
            if ch == '-' or ch not in alpha_index: continue
            counts[alpha_index[ch]] += weights[i]
        if counts.sum() <= 0: continue
        ent[t] = shannon_entropy_from_counts(counts, pseudocount=pseudocount)
        mask[t] = True
    return ent, mask

# ------------------------ Windowed samples -----------------------------

def build_window_samples(
    src_msa: List[str], tgt_msa: List[str], weights: np.ndarray,
    src_alpha: List[str], tgt_alpha: List[str],
    pos_mask: np.ndarray, win: int,
    return_seq_idx: bool = False,
):
    """
    Build samples to predict target token at t from a src window around t.
    - Requires no gaps in the src window or target at t (characters outside
      the respective alphabet count as invalid too).
    - pos_mask picks columns to consider (e.g., occupancy intersection).
    - Columns closer than win//2 to either alignment edge are never used,
      because their window would run off the alignment.
    Returns X (float32 one-hot per window, shape (n_samples, win*len(src_alpha))),
    y (int64 labels), w (float32 sample weights = weight of the row), and
    pos_idx (int32 column indices).
    If return_seq_idx is True, a fifth int32 array with the MSA row index of
    each sample is appended.  This allows train/val/test splits by sequence
    (e.g. split_by_sequence_indices + mask_samples_by_seqpos) rather than
    by individual window.
    """
    assert win % 2 == 1, "Window must be odd size."
    half = win // 2
    A_src = len(src_alpha)
    src_idx = {c: i for i, c in enumerate(src_alpha)}
    tgt_idx = {c: i for i, c in enumerate(tgt_alpha)}

    L = len(src_msa[0]) if src_msa else 0
    n = len(src_msa)

    feats, labels, sw, pos_idx, seq_idx = [], [], [], [], []

    # precompute valid (non-gap and in alphabet) masks
    src_valid = np.array([[(c != '-' and c in src_idx) for c in row] for row in src_msa], dtype=bool)
    tgt_valid = np.array([[(c != '-' and c in tgt_idx) for c in row] for row in tgt_msa], dtype=bool)
    window_rows = np.arange(win)

    for i in range(n):
        src_row = src_msa[i]
        tgt_row = tgt_msa[i]
        for t in range(half, L - half):
            if not pos_mask[t] or not tgt_valid[i, t]:
                continue
            # window must be fully valid in src
            if not src_valid[i, t - half:t + half + 1].all():
                continue
            # one-hot encode window: one row per window position, then flatten
            v = np.zeros((win, A_src), dtype=np.float32)
            v[window_rows, [src_idx[c] for c in src_row[t - half:t + half + 1]]] = 1.0
            feats.append(v.reshape(-1))  # (win*A_src,)
            labels.append(tgt_idx[tgt_row[t]])
            sw.append(weights[i])
            pos_idx.append(t)
            seq_idx.append(i)

    if not feats:
        out = (np.zeros((0, win * A_src), dtype=np.float32),
               np.zeros((0,), dtype=np.int64),
               np.zeros((0,), dtype=np.float32),
               np.zeros((0,), dtype=np.int32))
    else:
        out = (np.stack(feats, axis=0),
               np.array(labels, dtype=np.int64),
               np.array(sw, dtype=np.float32),
               np.array(pos_idx, dtype=np.int32))
    if return_seq_idx:
        return out + (np.array(seq_idx, dtype=np.int32),)
    return out

# --------------------------- Torch model --------------------------------

class SoftmaxLinear(nn.Module):
    """Multinomial logistic regression: one affine layer that outputs logits.

    ``forward`` returns unnormalized logits.  The softmax is applied by the
    loss (``cross_entropy``) or by the caller (``log_softmax``).
    """
    def __init__(self, D_in: int, C_out: int, bias=True):
        super().__init__()
        # The attribute is named ``lin``: predict_log_probs reads
        # lin.out_features, and state_dict keys are derived from it.
        self.lin = nn.Linear(in_features=D_in, out_features=C_out, bias=bias)

    def forward(self, x):
        logits = self.lin(x)
        return logits

class NumpyDataset(Dataset):
    """Map-style torch Dataset over aligned NumPy arrays of features, labels and weights.

    Arrays are wrapped with ``torch.from_numpy``, which shares memory with
    the NumPy inputs (no copy).  Item ``i`` is the tuple ``(X[i], y[i], w[i])``.
    """
    def __init__(self, X, y, w):
        self.X, self.y, self.w = (torch.from_numpy(arr) for arr in (X, y, w))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return tuple(tensor[idx] for tensor in (self.X, self.y, self.w))

def split_by_sequence_indices(n_seq: int, seed=0, ratios=(0.7, 0.15, 0.15)):
    """Randomly partition indices 0..n_seq-1 into (train, val, test) sets.

    The train and val sizes are int(ratios[0]*n_seq) and int(ratios[1]*n_seq).
    The test set takes the remainder, so ratios[2] is not read.  The split is
    deterministic for a given ``seed``.
    """
    order = list(range(n_seq))
    random.Random(seed).shuffle(order)
    cut_train = int(ratios[0] * n_seq)
    cut_val = cut_train + int(ratios[1] * n_seq)
    return set(order[:cut_train]), set(order[cut_train:cut_val]), set(order[cut_val:])

def mask_samples_by_seqpos(seq_pos_of_sample: List[int], seq_ids_set: set):
    """Boolean mask that is True for samples whose sequence index is in ``seq_ids_set``."""
    return np.fromiter((sp in seq_ids_set for sp in seq_pos_of_sample), dtype=bool)

def train_softmax(
    X_train, y_train, w_train,
    X_val,   y_val,   w_val,
    D_in, C_out, lr=1e-2, epochs=20, bs=4096, weight_decay=1e-4, device="cpu"
):
    """Fit a SoftmaxLinear classifier with sample-weighted cross-entropy.

    Training uses AdamW over shuffled mini-batches of size ``bs`` for up to
    ``epochs`` epochs.  After each epoch the weighted validation
    cross-entropy (nats) is computed.  The best parameters are kept, and
    training stops after 4 consecutive epochs without an improvement of at
    least 1e-5.

    When the validation set is empty (X_val has zero rows) the validation
    loss is undefined.  In that case early stopping is disabled and the
    final-epoch model is returned.  Otherwise NaN would never register as an
    improvement and training would silently stop after 4 epochs.

    Returns the trained model on ``device``.
    """
    model = SoftmaxLinear(D_in, C_out).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    ds_tr = NumpyDataset(X_train, y_train, w_train)
    dl_tr = DataLoader(ds_tr, batch_size=bs, shuffle=True, drop_last=False)

    def eval_ce(X, y, w):
        """Weighted mean cross-entropy of the current model on (X, y, w)."""
        if X.shape[0] == 0:
            return float('nan')
        model.eval()
        with torch.no_grad():
            X_t = torch.from_numpy(X).to(device)
            y_t = torch.from_numpy(y).to(device)
            w_t = torch.from_numpy(w).to(device)
            logits = model(X_t)
            ce = nn.functional.cross_entropy(logits, y_t, reduction='none')
            return float((ce * w_t).sum().item() / max(1e-9, w_t.sum().item()))

    has_val = X_val.shape[0] > 0
    best_val = float('inf')
    best_state = None
    patience, bad = 4, 0

    for _ in range(epochs):
        model.train()
        for xb, yb, wb in dl_tr:
            xb, yb, wb = xb.to(device), yb.to(device), wb.to(device)
            logits = model(xb)
            ce = nn.functional.cross_entropy(logits, yb, reduction='none')
            loss = (ce * wb).sum() / (wb.sum() + 1e-9)
            opt.zero_grad()
            loss.backward()
            opt.step()

        if not has_val:
            continue  # no validation signal: use the full epoch budget

        val_ce = eval_ce(X_val, y_val, w_val)
        if val_ce < best_val - 1e-5:
            best_val = val_ce
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def predict_log_probs(model: nn.Module, X: np.ndarray, bs=8192, device="cpu") -> np.ndarray:
    """Row-wise log-softmax class probabilities for ``X``, computed in batches of ``bs``.

    Switches ``model`` to eval mode (side effect) and returns a NumPy array
    of shape (len(X), n_classes).  ``model`` must expose
    ``lin.out_features``, which is used to shape the result for an empty ``X``.
    """
    n_rows = X.shape[0]
    if n_rows == 0:
        return np.zeros((0, model.lin.out_features), dtype=np.float32)
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, n_rows, bs):
            batch = torch.from_numpy(X[start:start + bs]).to(device)
            log_probs = nn.functional.log_softmax(model(batch), dim=-1)
            chunks.append(log_probs.detach().cpu().numpy())
    return np.vstack(chunks)