[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/antoniotrapote/chord-prediction-tfm/blob/main/anexos/notebooks/04_reranking/01_reranking.ipynb)
[![View on GitHub](https://img.shields.io/badge/View_on-GitHub-black?logo=github)](https://github.com/antoniotrapote/chord-prediction-tfm/blob/main/anexos/notebooks/04_reranking/01_reranking.ipynb)

# Re-ranking

En este notebook se diseñaron las funciones necesarias para implementar un sistema de re-ranking de las predicciones de acordes generadas por nuestro modelo, mejorando así la experiencia del usuario y la calidad de las sugerencias musicales.

Objetivos:
1. Penalizar repeticiones
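2. Filtrar los candidatos según tres modos: "free" (sin filtro), "diatonic" (solo acordes diatónicos) y "functional_plus" (diatónicos + dominantes secundarios y sustitutos por tritono)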


## 1) Utilidades de filtrado
- _has_accidental() - detecta grados alterados (símbolos que empiezan por "b" o "#"), es decir, acordes construidos sobre grados cromáticos
- _is_secondary() - detecta dominantes secundarios ("V/XX") y sustitutos por tritono ("Vsub/XX")
- _is_diatonic() - detecta acordes diatónicos (incluye la excepción de bVII y bVII7 en modo menor)

In [1]:
def _has_accidental(roman: str) -> bool:
    # Consideramos alteraciones en el grado (no en tensiones)
    # Ej: 'bII', '#iv', 'bVII7', etc.
    return roman.startswith("b") or roman.startswith("#")

In [2]:
test = ["bii", "#IV", "V/VI", "Vsub/ii", "V7", "vi", "viio"]
# Pair each symbol with its accidental flag and show the mapping.
print(dict(zip(test, map(_has_accidental, test))))

{'bii': True, '#IV': True, 'V/VI': False, 'Vsub/ii': False, 'V7': False, 'vi': False, 'viio': False}


In [3]:
def _is_secondary(roman: str) -> bool:
    """
    True solo para dominantes secundarios y sustitutos por tritono
    codificados como 'V/XX' o 'Vsub/XX'.
    """
    # Normalizamos espacios accidentales
    r = roman.strip()
    return r.startswith("V/") or r.startswith("Vsub/")

In [4]:
test = ["V/VI", "V/IV", "Vsub/V", "V7", "v"]
# Pair each symbol with its "is secondary" flag and show the mapping.
print(dict(zip(test, map(_is_secondary, test))))

{'V/VI': True, 'V/IV': True, 'Vsub/V': True, 'V7': False, 'v': False}


In [5]:
DEGREES_MAJOR = {"I","ii","iii","IV","V", "V7","vi", "viio", "viiø"}
DEGREES_MINOR = {"i","iiø","III","iv","v", "V", "V7", "VI", "vii", "viio"}

BVII_MINOR = {"bVII", "bVII7"} #Para consderar diatónico el 7º grado de la menor natural

def _is_diatonic(roman: str, mode: str) -> bool:
    # Excepción: en menor, bVII y bVII7 pasan como diatónicos (antes de filtrar accidentales)
    if mode == "minor" and roman in BVII_MINOR:
        return True
    
    # Filtro: sin alteraciones (#/b) ni secundarios ('/').
    if _has_accidental(roman) or ("/" in roman):
        return False
    
    # Comparación directa con set
    return roman in (DEGREES_MAJOR if mode == "major" else DEGREES_MINOR)

In [6]:
# Minor mode: each symbol is listed once (a repeated entry would just collapse into a single dict key).
test = ["i", "iiø", "III", "iv", "v", "V", "V7", "VI", "vii", "viio", "bVII7", "bVII", "#iv", "V/vi", "Vsub/ii"]
print({t: _is_diatonic(t, "minor") for t in test})

{'i': True, 'iiø': True, 'III': True, 'iv': True, 'v': True, 'V': True, 'V7': True, 'VI': True, 'vii': True, 'viio': True, 'bVII7': True, 'bVII': True, '#iv': False, 'V/vi': False, 'Vsub/ii': False}


In [7]:
# Major mode: each symbol is listed once (a repeated entry would just collapse into a single dict key).
test = ["I", "ii", "iii", "IV", "V", "V7", "vi", "viiø", "viio", "bVII7", "#iv", "V/vi", "Vsub/ii"]
print({t: _is_diatonic(t, "major") for t in test})

{'I': True, 'ii': True, 'iii': True, 'IV': True, 'V': True, 'V7': True, 'vi': True, 'viiø': True, 'viio': True, 'bVII7': False, '#iv': False, 'V/vi': False, 'Vsub/ii': False}


## 2) Reranking

In [8]:
def allow_in_functional_plus(roman: str, mode: str) -> bool:
    """Decide whether a chord passes the "functional_plus" filter.

    Accepts every chord that ``_is_diatonic`` accepts for ``mode``, plus
    secondary dominants ('V/XX') and tritone substitutes ('Vsub/XX').
    """
    if _is_diatonic(roman, mode):
        return True
    return _is_secondary(roman)

In [9]:
test = ["I", "ii", "iii", "IV", "v", "V", "V7", "vi", "viiø", "viio", "V/vi", "Vsub/ii", "bVII", "#iv"]
# Pair each symbol with its functional_plus verdict in major and show the mapping.
print(dict(zip(test, (allow_in_functional_plus(t, "major") for t in test))))

{'I': True, 'ii': True, 'iii': True, 'IV': True, 'v': False, 'V': True, 'V7': True, 'vi': True, 'viiø': True, 'viio': True, 'V/vi': True, 'Vsub/ii': True, 'bVII': False, '#iv': False}


In [10]:
# Minor mode: each symbol is listed once (a repeated entry would just collapse into a single dict key).
test = ["i", "iiø", "III", "iv", "v", "V", "V7", "VI", "vii", "viio", "bVII7", "bVII", "V/vi", "Vsub/ii", "#iv", "bii"]
print({t: allow_in_functional_plus(t, "minor") for t in test})

{'i': True, 'iiø': True, 'III': True, 'iv': True, 'v': True, 'V': True, 'V7': True, 'VI': True, 'vii': True, 'viio': True, 'bVII7': True, 'bVII': True, 'V/vi': True, 'Vsub/ii': True, '#iv': False, 'bii': False}


In [11]:
from typing import List, Literal, Tuple

def rerank(
    candidates: List[Tuple[str, float]],
    recent_context: List[str],
    mode: str,
    filter_mode: Literal["free", "diatonic", "functional_plus"] = "free",
    alpha_repeat: float = 0.25,   # repetition penalty, intended range [0..1]
    rep_window: int = 2,          # number of trailing context chords checked for repeats
    beta_filter: float = 0.15,    # multiplier for disallowed chords when hard_filter=False
    hard_filter: bool = True      # True = drop (prob=0); False = downweight by beta_filter
) -> List[Tuple[str, float]]:
    """Re-score chord candidates with an anti-repeat penalty and a harmonic filter.

    Args:
        candidates: ``(chord, probability)`` pairs produced by the model.
        recent_context: Previously emitted chords, oldest first; only the
            last ``rep_window`` entries are used for the repetition penalty.
        mode: Key mode forwarded to the filter helpers (e.g. ``"major"`` or
            ``"minor"``).
        filter_mode: ``"free"`` (no filtering), ``"diatonic"`` or
            ``"functional_plus"`` (diatonic plus V/XX and Vsub/XX chords).
        alpha_repeat: Fraction removed from the score of a chord that already
            appears in the recent window; ``<= 0`` disables the penalty.
        rep_window: Size of that window; ``<= 0`` disables the penalty.
        beta_filter: Multiplier applied to disallowed chords when
            ``hard_filter`` is False.
        hard_filter: If True, disallowed chords get score 0.

    Returns:
        The candidates with adjusted scores, renormalized to sum to 1 (left
        as-is when every score is 0), sorted by score in descending order.
        Ties keep their input order.

    Raises:
        ValueError: If ``filter_mode`` is not one of the supported values
            (previously an unknown value silently behaved like ``"free"``).
    """
    if filter_mode not in ("free", "diatonic", "functional_plus"):
        raise ValueError(
            "filter_mode must be 'free', 'diatonic' or 'functional_plus', "
            f"got {filter_mode!r}"
        )

    last_window = recent_context[-rep_window:] if rep_window > 0 else []
    scored = []
    for tok, p in candidates:
        score = p

        # Anti-repeat: penalize a chord that already appears in the recent window.
        if alpha_repeat > 0 and tok in last_window:
            score *= (1.0 - alpha_repeat)

        # Harmonic filter ("free" lets everything through).
        if filter_mode == "diatonic":
            allowed = _is_diatonic(tok, mode)
        elif filter_mode == "functional_plus":
            allowed = allow_in_functional_plus(tok, mode)
        else:
            allowed = True

        if not allowed:
            score = 0.0 if hard_filter else score * beta_filter

        # Clamp: alpha_repeat > 1 or a negative beta_filter could push a score below 0.
        scored.append((tok, max(0.0, score)))

    # Renormalize so the result is still an interpretable distribution.
    total = sum(s for _, s in scored)
    if total > 0:
        scored = [(t, s / total) for t, s in scored]
    # list.sort is stable (also with reverse=True): equal scores keep candidate order.
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored

In [12]:
#@title Testeamos el re-ranking

# --- Utilidad de impresión ---
def show(title, out):
    """Pretty-print a re-ranking result.

    Prints a blank line, ``title`` underlined with dashes, one
    ``chord  probability`` row per pair (probability to 3 decimals), and
    finally the sum of all probabilities.
    """
    rule = "-" * len(title)
    print(f"\n{title}")
    print(rule)
    total = sum(prob for _chord, prob in out)
    for chord, prob in out:
        print(f"{chord:8s}  {prob:0.3f}")
    print(f"sum = {total:0.3f}")

# --- Test candidates (the probabilities sum to 1.0) ---
candidates = [
    ("I",     0.30),
    ("V7",    0.25),
    ("ii",    0.15),
    ("V/ii",  0.12),
    ("bVII7", 0.10),
    ("bII",   0.08),
]

# --- Test scenarios ---

# 1) Major mode. The context ends in 'V7', 'I', so with rep_window=2 both
#    chords are penalized in the anti-repeat run.
context_M = ["ii", "V7", "I"]
print("=== TEST: MODE=MAJOR ===")
show("Base / free", rerank(candidates, context_M, mode="major",
                           filter_mode="free", alpha_repeat=0.0))
show("Anti-repeat (rep_window=2, alpha=0.5)", rerank(candidates, context_M, mode="major",
                           filter_mode="free", alpha_repeat=0.5, rep_window=2))
show("Diatonic (hard_filter=True)", rerank(candidates, context_M, mode="major",
                           filter_mode="diatonic", alpha_repeat=0.0, hard_filter=True))
show("Functional+ (hard_filter=False, beta=0.2)", rerank(candidates, context_M, mode="major",
                           filter_mode="functional_plus", alpha_repeat=0.0, hard_filter=False, beta_filter=0.2))

# 2) Minor mode, where bVII7 is treated as diatonic by exception.
context_m = ["iv", "V7", "i"]
print("\n=== TEST: MODE=MINOR ===")
show("Base / free", rerank(candidates, context_m, mode="minor",
                           filter_mode="free", alpha_repeat=0.0))
show("Diatonic (minor, bVII7 permitido)", rerank(candidates, context_m, mode="minor",
                           filter_mode="diatonic", alpha_repeat=0.0, hard_filter=True))
show("Functional+ (permitirá también V/ii)", rerank(candidates, context_m, mode="minor",
                           filter_mode="functional_plus", alpha_repeat=0.0, hard_filter=True))

# 3) Edge case: everything is filtered out (major + diatonic with only
#    secondary/altered candidates). The default alpha_repeat/rep_window apply,
#    but neither chord is in the last two context chords. With every score at 0,
#    rerank skips renormalization, so the sum below is exactly 0.0.
cands_extremos = [("V/ii", 0.6), ("bII", 0.4)]
out_ext = rerank(cands_extremos, context_M, mode="major", filter_mode="diatonic", hard_filter=True)
show("Extremo (todo fuera): mayor + diatonic", out_ext)
if sum(p for _, p in out_ext) == 0.0:
    print(">> Aviso: todos los candidatos fueron filtrados (prob=0).")

=== TEST: MODE=MAJOR ===

Base / free
-----------
I         0.300
V7        0.250
ii        0.150
V/ii      0.120
bVII7     0.100
bII       0.080
sum = 1.000

Anti-repeat (rep_window=2, alpha=0.5)
-------------------------------------
I         0.207
ii        0.207
V7        0.172
V/ii      0.166
bVII7     0.138
bII       0.110
sum = 1.000

Diatonic (hard_filter=True)
---------------------------
I         0.429
V7        0.357
ii        0.214
V/ii      0.000
bVII7     0.000
bII       0.000
sum = 1.000

Functional+ (hard_filter=False, beta=0.2)
-----------------------------------------
I         0.350
V7        0.292
ii        0.175
V/ii      0.140
bVII7     0.023
bII       0.019
sum = 1.000

=== TEST: MODE=MINOR ===

Base / free
-----------
I         0.300
V7        0.250
ii        0.150
V/ii      0.120
bVII7     0.100
bII       0.080
sum = 1.000

Diatonic (minor, bVII7 permitido)
---------------------------------
V7        0.714
bVII7     0.286
I         0.000
ii        0.000
V/ii   