## 1) Cargamos el modelo lstm

In [1]:
#@title Subir archivo "lstm_rs_v3.pt"
try:
    from google.colab import files
    uploaded = files.upload()
except Exception:
    pass

Saving lstm_rs_v3.pt to lstm_rs_v3.pt


In [2]:
#@title Recreamos la arquitectura ChordLSTM y cargamos el modelo
import torch, torch.nn as nn

class ChordLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers, dropout, tie_weights=False):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_size, num_layers=num_layers,
                           batch_first=True, dropout=dropout if num_layers>1 else 0.0)
        self.dropout = nn.Dropout(dropout)
        self.tie_weights = tie_weights
        if tie_weights:
            self.proj = nn.Linear(hidden_size, embedding_dim, bias=False) if hidden_size != embedding_dim else nn.Identity()
            self.decoder = nn.Linear(embedding_dim, vocab_size, bias=False)
            self.decoder.weight = self.emb.weight
        else:
            self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        e = self.emb(x)                 # (B, T, E)
        o, _ = self.rnn(e)              # (B, T, H)
        h = self.dropout(o[:, -1, :])   # (B, H)
        return self.decoder(self.proj(h)) if self.tie_weights else self.fc(h)  # (B, V)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_best_checkpoint(path):
    try:
        ckpt = torch.load(path, map_location="cpu")  # intenta modo seguro (weights_only=True)
    except Exception:
        ckpt = torch.load(path, map_location="cpu", weights_only=False)  # fallback
    cfg  = ckpt["config"]
    model = ChordLSTM(
        vocab_size=len(ckpt["stoi"]),
        embedding_dim=cfg["embedding_dim"],
        hidden_size=cfg["hidden_size"],
        num_layers=cfg["num_layers"],
        dropout=cfg["dropout"],
        tie_weights=cfg.get("tie_weights", False),
    )
    model.load_state_dict(ckpt["model_state"], strict=True)
    model.to(device).eval()
    return model, ckpt["stoi"], ckpt["itos"], cfg, ckpt.get("metrics_test", None)

model, stoi, itos, cfg, metrics = load_best_checkpoint("/content/lstm_rs_v3.pt")
print("Cargado. Métricas (test) guardadas:", metrics)


Cargado. Métricas (test) guardadas: {'loss': 2.252060778257626, 'ppl': 9.50730811622195, 'Top@1': 0.45036420395421434, 'Top@3': 0.6936524453818116, 'Top@5': 0.7761706556043317, 'MRR': 0.5956792587544246}


In [3]:
#@title Definicion de `predict_next()` usando el chekpoint empaquetado
import torch.nn.functional as F

@torch.inference_mode()
def predict_next(model, stoi, itos, context_tokens, seq_len, k=5):
    unk_id = stoi.get("<unk>")
    if unk_id is None:
      raise KeyError("El vocabulario no contiene '<unk>'.")
    bos_id = stoi.get("<bos>")

    ids = [stoi.get(t, unk_id) for t in context_tokens]
    if len(ids) < seq_len and bos_id is not None:
        ids = [bos_id] * (seq_len - len(ids)) + ids
    else:
        ids = ids[-seq_len:]

    x = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)
    logits = model(x)                           # (1, V)
    probs  = F.softmax(logits[0], dim=-1)       # (V,)
    topk   = torch.topk(probs, k)
    return [(itos[i.item()], float(p.item())) for i,p in zip(topk.indices, topk.values)]


## 2) Cargamos el modelo kn

In [4]:
#@title Subir archivo "best_kn_model.pkl"
try:
    from google.colab import files
    uploaded = files.upload()
except Exception:
    pass

Saving best_kn_model.pkl to best_kn_model.pkl


In [5]:
#@title Recreamos la arquitectura del modelo: KN interpolado genérico (orden N)
from collections import Counter, defaultdict
from functools import lru_cache


class KNInterpolatedNGram:
    """
    Modelo de lenguaje basado en n-gramas con suavizado Kneser-Ney.
    """

    def __init__(self, order=3, discount=0.75, unk_threshold=1):
        assert order >= 1
        self.N = order
        self.D = discount
        self.unk_threshold = unk_threshold
        self.vocab = set()
        self.counts = {n: Counter()
                       for n in range(1, self.N+1)}      # n-gram counts
        self.context_totals = {n: Counter()
                               for n in range(1, self.N)}  # c(context)
        self.unique_continuations = {n: Counter()
                                     # N1+(context •)
                                     for n in range(1, self.N)}
        self.continuation_counts_unigram = Counter()  # N1+(• w)
        self.total_unique_bigrams = 0
        self._rank_cache = {}
        self._prob_cache = {}
        self.fitted = False

    def _add_bounds(self, seq):
        """
        Añade marcas de comienzo y fin. "<s>" y "</s>"
        """
        return ["<s>"]*(self.N-1) + seq + ["</s>"]

    def fit(self, sequences):
        """
        Ajusta el modelo a las secuencias de entrenamiento.
        """
        token_counts = Counter(t for seq in sequences for t in seq)
        vocab = set([t for t, c in token_counts.items()
                    if c > self.unk_threshold])
        vocab.update({"<s>", "</s>", "<unk>"})
        self.vocab = vocab

        def map_unk(seq):
            return [t if t in vocab else "<unk>" for t in seq]

        for seq in sequences:
            s = self._add_bounds(map_unk(seq))
            for i in range(len(s)):
                for n in range(1, self.N+1):
                    if i-n+1 < 0:
                        continue
                    ngram = tuple(s[i-n+1:i+1])
                    self.counts[n][ngram] += 1

        # context totals + unique continuations
        for n in range(2, self.N+1):
            seen = defaultdict(set)
            for ngram, c in self.counts[n].items():
                ctx, w = ngram[:-1], ngram[-1]
                self.context_totals[n-1][ctx] += c
                seen[ctx].add(w)
            for ctx, ws in seen.items():
                self.unique_continuations[n-1][ctx] = len(ws)

        # unigram continuation counts
        left_contexts = defaultdict(set)
        for (w1, w2) in self.counts[2].keys():
            left_contexts[w2].add(w1)
        self.continuation_counts_unigram = Counter(
            {w: len(ctxs) for w, ctxs in left_contexts.items()})
        self.total_unique_bigrams = len(self.counts[2])
        self.fitted = True

    @lru_cache(maxsize=None)
    def _p_cont_unigram(self, w):
        if self.total_unique_bigrams == 0:
            return 1.0 / max(1, len(self.vocab))
        return self.continuation_counts_unigram.get(w, 0) / self.total_unique_bigrams

    def _lambda(self, ctx):
        m = len(ctx)
        if m == 0:
            return 1.0
        cont_types = self.unique_continuations[m].get(ctx, 0)
        total = self.context_totals[m].get(ctx, 0)
        if total == 0:
            return 1.0
        return (self.D * cont_types) / total

    def _base(self, ctx, w):
        m = len(ctx)
        if m == 0:
            return self._p_cont_unigram(w)
        total = self.context_totals[m].get(ctx, 0)
        c = self.counts[m+1].get(tuple(list(ctx)+[w]), 0)
        if total == 0:
            return 0.0
        return max(c - self.D, 0) / total

    def prob(self, ctx, w):
        key = (ctx, w)
        if key in self._prob_cache:
            return self._prob_cache[key]
        m = len(ctx)
        if m == 0:
            p = self._p_cont_unigram(w)
        else:
            p = self._base(ctx, w) + self._lambda(ctx) * self.prob(ctx[1:], w)
        self._prob_cache[key] = p
        return p

    def predict_ranking(self, history):
        # mapeo a <unk> interno para usar directamente evaluate_next_token_ranking(...)
        hist = ["<s>"]*(self.N-1) + \
            [t if t in self.vocab else "<unk>" for t in history]
        ctx = tuple(hist[-(self.N-1):]) if self.N > 1 else tuple()
        if ctx in self._rank_cache:
            return self._rank_cache[ctx]
        cands = [w for w in self.vocab if w not in {"<s>"}]
        scores = [(w, self.prob(ctx, w)) for w in cands]
        scores.sort(key=lambda x: x[1], reverse=True)
        self._rank_cache[ctx] = scores
        return scores


import joblib
best_kn = joblib.load('/content/best_kn_model.pkl')


In [6]:
#@title definimos `topk_next()`
def topk_next(model, context, k=5, exclude_special=True):
    """Devuelve las k mejores sugerencias (token, prob) dado el contexto."""
    ranking = model.predict_ranking(context)
    if exclude_special:
        ranking = [(w,p) for (w,p) in ranking if w not in {"<s>", "</s>", "<unk>"}]
    return ranking[:k]

## 3) Comparativa cualitativa

In [9]:
# @title Comparativa cualitativa KN vs LSTM

import pandas as pd

K = 5                     # top-k
SEQ_LEN = cfg["seq_len"]  # del checkpoint LSTM

def count_hits(cands, acceptable):
    if not acceptable:
        return None
    s = set(acceptable)
    return sum(1 for c in cands if c in s)

def first_hit_rank(cands, acceptable):
    if not acceptable:
        return None
    s = set(acceptable)
    for i, c in enumerate(cands, start=1):
        if c in s:
            return i
    return None

def eval_case(context, acceptable=None, note=""):
    # KN
    kn_preds = topk_next(best_kn, context, k=K)            # [('I', 0.21), ...]
    # LSTM
    lstm_preds = predict_next(model, stoi, itos, context, seq_len=SEQ_LEN, k=K)

    kn_top   = [c for c,_ in kn_preds]
    lstm_top = [c for c,_ in lstm_preds]

    row = {
        "context": " ".join(context),
        "note": note,
        "acc_targets": sorted(list(acceptable)) if acceptable else None,
        "kn_top": kn_top,
        "lstm_top": lstm_top,
    }

    # hit@k y rank (si hay acc_targets)
    row["KN_hit@k"]   = any(c in acceptable for c in kn_top)   if acceptable else None
    row["LSTM_hit@k"] = any(c in acceptable for c in lstm_top) if acceptable else None
    row["KN_rank"]    = first_hit_rank(kn_top, acceptable)
    row["LSTM_rank"]  = first_hit_rank(lstm_top, acceptable)

    # “plaus_sum” es simplemente el nº de aciertos dentro del top-k
    row["KN_plaus_sum"]   = count_hits(kn_top, acceptable)
    row["LSTM_plaus_sum"] = count_hits(lstm_top, acceptable)

    return row

# ---------- Suites de prueba (ROMANOS) ----------
SUITES = {
    "Modo mayor": [
        (["ii","V7"], {"I"}, "Cadencia auténtica"),
        (["I","V"], {"vi","IV","I"}, "Patrones pop"),
        (["I","IV"], {"I","V","ii"}, "Regreso a I o preparación a V"),
        (["IV","iv"], {"I"}, "Intercambio modal (iv→I)"),
        (["I","bVII"], {"I","IV","V"}, "Rock bVII"),
        (["I","V"], {"vi"}, "Cadencia punk I–V–vi–IV"),
        (["I","V","vi"], {"IV"}, "Cadencia punk I–V–vi–IV"),
    ],
    "Modo menor": [
        (["iiø","V7"], {"i"}, "iiø–V7–i"),
        (["i","iv"], {"V7", "v"}, "Preparar dominante"),
        (["VI","iiø"], {"V7"}, "VI–iiø–V7"),
    ],
    "Secundarios y sustitutos": [
        (["i","V/IV"], {"iv"}, "V/IV→iv"),
        (["I","V/IV"], {"IV"}, "V/IV→IV"),
        (["I","Vsub/V"], {"V7","V"}, "Sust. tritono a V"),
        (["I","biio"], {"ii"}, "bii°→ii (vecindad cromática)"),
        (["I","V/ii"], {"ii"}, "Cadena de dominantes"),
    ],
    "Cold-start y rarezas": [
        ([], {"I","i","ii","iiø"}, "Comienzo sin contexto"),
        (["i"], {"iiø","iv","V7","vi"}, "Arranque menor"),
        (["I"], {"ii","IV","V7","VI"}, "Arranque mayor"),
    ],
}

# ---------- Ejecutar ----------
rows = []
for suite, cases in SUITES.items():
    for ctx, acceptable, note in cases:
        rows.append({"suite": suite, **eval_case(ctx, acceptable, note)})

df = pd.DataFrame(rows)

def summarize(group):
    return pd.Series({
        "cases": len(group),
        "KN_hits": group["KN_hit@k"].fillna(False).sum() if "KN_hit@k" in group else None,
        "LSTM_hits": group["LSTM_hit@k"].fillna(False).sum() if "LSTM_hit@k" in group else None,
        "KN_mean_rank": group["KN_rank"].dropna().mean() if group["KN_rank"].notna().any() else None,
        "LSTM_mean_rank": group["LSTM_rank"].dropna().mean() if group["LSTM_rank"].notna().any() else None,
        # promedio de aciertos dentro del top-k
        "KN_plaus_sum_avg": group["KN_plaus_sum"].dropna().mean(),
        "LSTM_plaus_sum_avg": group["LSTM_plaus_sum"].dropna().mean(),
    })

summary = df.groupby("suite").apply(summarize, include_groups=False).reset_index()

global_scores = pd.DataFrame([{
    "KN_plaus_avg": df["KN_plaus_sum"].dropna().mean(),
    "LSTM_plaus_avg": df["LSTM_plaus_sum"].dropna().mean(),
    "KN_hit_rate": df["KN_hit@k"].fillna(False).mean(),
    "LSTM_hit_rate": df["LSTM_hit@k"].fillna(False).mean(),
}]).round(3)

print("=== RESUMEN POR SUITE ===")
display(summary.round(3))
print("\n=== SCORES GLOBALES ===")
display(global_scores)

cols_show = ["suite","context","note","acc_targets",
             "kn_top","lstm_top","KN_rank","LSTM_rank",
             "KN_plaus_sum","LSTM_plaus_sum"]
display(df[cols_show])


=== RESUMEN POR SUITE ===


Unnamed: 0,suite,cases,KN_hits,LSTM_hits,KN_mean_rank,LSTM_mean_rank,KN_plaus_sum_avg,LSTM_plaus_sum_avg
0,Cold-start y rarezas,3.0,3.0,0.0,1.667,,3.0,0.0
1,Modo mayor,7.0,6.0,4.0,1.667,2.25,1.429,0.714
2,Modo menor,3.0,3.0,2.0,1.0,1.5,1.0,0.667
3,Secundarios y sustitutos,5.0,5.0,4.0,1.4,1.0,1.2,1.0



=== SCORES GLOBALES ===


Unnamed: 0,KN_plaus_avg,LSTM_plaus_avg,KN_hit_rate,LSTM_hit_rate
0,1.556,0.667,0.944,0.556


Unnamed: 0,suite,context,note,acc_targets,kn_top,lstm_top,KN_rank,LSTM_rank,KN_plaus_sum,LSTM_plaus_sum
0,Modo mayor,ii V7,Cadencia auténtica,[I],"[I, ii, iii, i, V7]","[V7, I, bii, Vsub/VII, #IV7]",1.0,2.0,1,1
1,Modo mayor,I V,Patrones pop,"[I, IV, vi]","[V, IV, I, vi, v]","[V, I, bii, II, vo]",2.0,2.0,3,1
2,Modo mayor,I IV,Regreso a I o preparación a V,"[I, V, ii]","[I, iii, V7, V, IV]","[IV, VII, iii, I, bIII]",1.0,4.0,2,1
3,Modo mayor,IV iv,Intercambio modal (iv→I),[I],"[I, bVII7, iii, iv, V]","[iii, iv, V7, bIII7, VII]",1.0,,1,0
4,Modo mayor,I bVII,Rock bVII,"[I, IV, V]","[I, IV, bVII, V/IV, V7]","[I, bii, IV, V7, vii]",1.0,1.0,2,2
5,Modo mayor,I V,Cadencia punk I–V–vi–IV,[vi],"[V, IV, I, vi, v]","[V, I, bii, II, vo]",4.0,,1,0
6,Modo mayor,I V vi,Cadencia punk I–V–vi–IV,[IV],"[II7, V, V/V, Vsub/II, vi]","[II, V/V, vii, V, II7]",,,0,0
7,Modo menor,iiø V7,iiø–V7–i,[i],"[i, I, V7, iiø, VI]","[V7, bii, vo, #IV7, I]",1.0,,1,0
8,Modo menor,i iv,Preparar dominante,"[V7, v]","[V7, i, iv, V/III, bVII7]","[iv, V7, bIII7, i, iii]",1.0,2.0,1,1
9,Modo menor,VI iiø,VI–iiø–V7,[V7],"[V7, iiø, i, I, bVI7]","[V7, vo, bii, iiø, bII7]",1.0,1.0,1,1


## Justificación de la elección del modelo Kneser-Ney

A pesar de que el LSTM ofrecía métricas ligeramente mejores en validación (Top@k, MRR), en las pruebas cualitativas con progresiones armónicas reales se observa que el modelo KN interpolado:

- Tiene una tasa de acierto mucho más alta (94% vs 56%).

- Ofrece sugerencias más plausibles y consistentes en todos los contextos (cadencias, arranques, dominantes secundarios).

- Es estable, explicable y determinista, lo que lo hace más adecuado para un sistema productivo.