In [None]:
import os
import pickle
import numpy as np
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
from scipy.sparse import load_npz
from sklearn.preprocessing import normalize
import gc

# =============================================================================
# 1. CARGAR DATOS
# =============================================================================
print("=" * 80, "CARGANDO DATOS", "=" * 80, sep="\n")

# NOTE: pd.read_pickle always unpickles its input and np.load(allow_pickle=True)
# can do so as well, so both must only be pointed at trusted files.
df_pairs_unique = pd.read_pickle("data/df_pairs_unique.pkl")
author2idx = np.load("data/author_to_idx.npy", allow_pickle=True).item()

# Count how many times each author id shows up across the two pair columns
# and keep the ids that appear at least twice.
author_counts = df_pairs_unique[['pair_min', 'pair_max']].stack().value_counts()
eligible_authors = set(author_counts.index[author_counts.to_numpy() >= 2])

# Retain only the pairs whose two endpoints are both eligible.
df_filtered = df_pairs_unique.loc[
    df_pairs_unique['pair_max'].isin(eligible_authors)
    & df_pairs_unique['pair_min'].isin(eligible_authors)
].reset_index(drop=True)

# =============================================================================
# 2. CARGAR MODELO CONTENT-BASED (OPTIMIZADO)
# =============================================================================
models_dir = "data"

# Load the sparse author x concept matrix and cast it to float32 straight away
# to keep its memory footprint down.
print("Cargando matriz de conceptos...")
author_matrix = load_npz(os.path.join(models_dir, 'author_concept_matrix.npz')).astype(np.float32)

# L2-normalise every row up front so that a plain dot product between two rows
# is their cosine similarity.
print("Normalizando matriz para similitud de coseno rápida...")
author_matrix = normalize(author_matrix, norm='l2', axis=1)

# NOTE: allow_pickle=True can unpickle the file's contents; only load trusted files.
cb_author_ids = np.load(os.path.join(models_dir, 'cb_author_ids.npy'), allow_pickle=True)

# Matrix rows are addressed through positions in cb_author_ids, so the two
# must line up; fail early instead of mis-indexing during evaluation.
if author_matrix.shape[0] != len(cb_author_ids):
    raise ValueError(
        f"author_concept_matrix tiene {author_matrix.shape[0]} filas pero "
        f"cb_author_ids tiene {len(cb_author_ids)} elementos"
    )

cb_author_to_idx = {aid: i for i, aid in enumerate(cb_author_ids)}

try:
    author_work_counts = np.load(os.path.join(models_dir, 'cb_author_work_counts.npy')).astype(np.float32)
except FileNotFoundError:
    # Best-effort fallback: a uniform count of 1 per author. Report it, because
    # it changes what the smoothing step in the evaluation does.
    print("AVISO: no se encontró cb_author_work_counts.npy; se usa un conteo de 1 para todos los autores.")
    author_work_counts = np.ones(len(cb_author_ids), dtype=np.float32)

# The counts are multiplied element-wise with one similarity row per author;
# any other shape would either fail late or silently broadcast.
if author_work_counts.shape != (len(cb_author_ids),):
    raise ValueError(
        f"author_work_counts tiene forma {author_work_counts.shape}; "
        f"se esperaba ({len(cb_author_ids)},)"
    )

# =============================================================================
# 3. FUNCIONES DE MÉTRICAS Y PROCESAMIENTO POR LOTES
# =============================================================================

def calculate_metrics_batched(target_author_ids, gt, author_matrix, work_counts, C, topk=20, batch_size=500):
    """
    Compute mean Recall@topk and NDCG@topk, scoring authors in batches.

    Only one dense ``batch_size x n_authors`` similarity block is held in
    memory at a time.

    Relies on two module-level lookups: ``cb_author_to_idx`` (author id ->
    matrix row) and ``cb_author_ids`` (matrix row -> author id).

    Args:
        target_author_ids: Iterable of author ids to evaluate. Ids missing
            from ``cb_author_to_idx`` are ignored.
        gt: Mapping from author id to the set of relevant author ids. Authors
            that are absent from it, or whose set is empty, are skipped.
        author_matrix: Sparse author-feature matrix. Rows are expected to be
            L2-normalised so that row dot products are cosine similarities.
        work_counts: 1-D array with one weight per matrix row, used as the
            evidence count in the smoothing step.
        C: Smoothing strength; larger values pull every score towards the
            mean similarity of the query author.
        topk: Length of the recommendation list.
        batch_size: Number of query authors scored per dense block.

    Returns:
        Tuple ``(mean_recall, mean_ndcg)`` over the evaluated authors, or
        ``(0, 0)`` when no author could be evaluated.
    """
    total_recall = 0.0
    total_ndcg = 0.0
    valid_users = 0

    # Map author ids to their row index in the matrix.
    target_indices = [cb_author_to_idx[aid] for aid in target_author_ids if aid in cb_author_to_idx]

    # np.argpartition requires kth < n_authors. Clamping to n_authors - 1 also
    # guarantees the masked-out query author never enters a short list.
    n_authors = author_matrix.shape[0]
    k = max(0, min(topk, n_authors - 1))

    # Rank discounts 1 / log2(rank + 1) and their running sum (the ideal DCG
    # for 1, 2, ..., topk relevant items), computed once for all users.
    discounts = 1.0 / np.log2(np.arange(2, topk + 2))
    ideal_dcg = np.cumsum(discounts)

    for i in tqdm(range(0, len(target_indices), batch_size), desc=f"Evaluando C={C}"):
        batch_idx = target_indices[i : i + batch_size]

        # Dense (batch x n_authors) similarity block; batch_size bounds its size.
        sims_batch = author_matrix[batch_idx].dot(author_matrix.T).toarray()

        for j, u_idx in enumerate(batch_idx):
            u_id = cb_author_ids[u_idx]

            # .get() avoids creating entries when gt is a defaultdict; an empty
            # relevance set has no defined recall, so such users are skipped.
            rel_set = gt.get(u_id)
            if not rel_set:
                continue

            sims = sims_batch[j]

            # Smoothing: shrink each similarity towards this author's mean
            # similarity, less so for candidates with larger work counts.
            m = sims.mean()
            sims = (C * m + work_counts * sims) / (C + work_counts)

            # Exclude the query author from its own recommendations.
            sims[u_idx] = -1.0

            if k > 0:
                # Select the k best with argpartition, then sort only those k.
                neg_sims = -sims
                top_indices = np.argpartition(neg_sims, k)[:k]
                top_indices = top_indices[np.argsort(neg_sims[top_indices])]
            else:
                top_indices = np.empty(0, dtype=np.intp)

            recs = cb_author_ids[top_indices]
            hits = [1 if r in rel_set else 0 for r in recs]

            # Accumulate metrics.
            total_recall += sum(hits) / len(rel_set)

            dcg = sum(d for d, h in zip(discounts, hits) if h)
            n_ideal = min(len(rel_set), topk)
            idcg = ideal_dcg[n_ideal - 1] if n_ideal > 0 else 0
            total_ndcg += (dcg / idcg) if idcg > 0 else 0

            valid_users += 1

        # Release the dense block before the next one is allocated. Because i
        # advances in steps of batch_size, the collection below only runs when
        # the batch offset is also a multiple of 5000.
        del sims_batch
        if i % 5000 == 0:
            gc.collect()

    return (total_recall / valid_users, total_ndcg / valid_users) if valid_users > 0 else (0, 0)

# =============================================================================
# 4. SPLIT Y PREPARACIÓN
# =============================================================================
def build_gt(df):
    """Build a symmetric lookup of paired ids from a pair table.

    Every row links its ``pair_min`` and ``pair_max`` values; each endpoint is
    recorded in the other endpoint's set.

    Args:
        df: DataFrame with ``pair_min`` and ``pair_max`` columns.

    Returns:
        A ``defaultdict(set)`` mapping each id to the set of ids it is paired
        with in ``df``.
    """
    partners = defaultdict(set)
    for left, right in zip(df['pair_min'], df['pair_max']):
        partners[left].add(right)
        partners[right].add(left)
    return partners

def triple_loo_split(df, seed=42):
    """Split a pair table into train / validation / test by holding out rows per id.

    Ids are visited in order of first appearance in ``df``. For each id, the
    rows it takes part in that have not been held out yet are considered:

    * three or more such rows: they are shuffled, the first goes to test and
      the second to validation;
    * exactly two: the first one (in row order, no shuffle) goes to test;
    * fewer: nothing is held out for that id.

    Args:
        df: DataFrame with ``pair_min`` and ``pair_max`` columns.
        seed: Seed for the NumPy random generator used for shuffling.

    Returns:
        Tuple ``(train, val, test)`` of DataFrames with fresh 0..n-1 indexes.
        Train rows keep their original relative order.
    """
    rng = np.random.default_rng(seed)

    # Row positions each id participates in, keyed in first-appearance order.
    rows_by_author = defaultdict(list)
    for row_pos, (left, right) in enumerate(zip(df['pair_min'], df['pair_max'])):
        rows_by_author[left].append(row_pos)
        rows_by_author[right].append(row_pos)

    test_idx = set()
    val_idx = set()
    taken = set()  # rows already held out while processing an earlier id

    for row_list in rows_by_author.values():
        free = [pos for pos in row_list if pos not in taken]
        n_free = len(free)

        if n_free < 2:
            continue

        if n_free == 2:
            held_out = free[0]
            test_idx.add(held_out)
            taken.add(held_out)
            continue

        rng.shuffle(free)
        test_row, val_row = free[0], free[1]
        test_idx.add(test_row)
        val_idx.add(val_row)
        taken.update((test_row, val_row))

    # Everything that was not held out stays in train, in original row order.
    train_rows = sorted(set(range(len(df))) - test_idx - val_idx)

    train_df = df.iloc[train_rows].reset_index(drop=True)
    val_df = df.iloc[list(val_idx)].reset_index(drop=True)
    test_df = df.iloc[list(test_idx)].reset_index(drop=True)
    return (train_df, val_df, test_df)

print("\nGenerando splits...")
df_train, df_val, df_test = triple_loo_split(df_filtered)

# Ground-truth lookups for the held-out validation and test pairs (in that order).
gt_val, gt_test = (build_gt(part) for part in (df_val, df_test))

# =============================================================================
# 5. TUNING (20,000 autores sample)
# =============================================================================
#rng = np.random.default_rng(42)
#val_authors_all = [a for a in gt_val.keys() if a in cb_author_to_idx]
#val_sample = rng.choice(val_authors_all, size=min(20000, len(val_authors_all)), replace=False)

#Cs = [1, 10, 20, 50, 100]
#best_C, best_ndcg = None, -1

#print(f"\n--- TUNING EN VALIDACIÓN ({len(val_sample)} autores) ---")
#for C in Cs:
#    rec, ndcg = calculate_metrics_batched(val_sample, gt_val, author_matrix, author_work_counts, C)
#    print(f"C={C:5.1f} | Recall@20={rec:.4f} | NDCG@20={ndcg:.4f}")
#    if ndcg > best_ndcg:
#        best_ndcg, best_C = ndcg, C

# =============================================================================
# 6. EVALUACIÓN FINAL FULL (518k autores)
# =============================================================================
print(
    "\n" + "=" * 80,
    f"EVALUACIÓN FINAL EN TEST (FULL: {len(gt_test)} autores)",
    "=" * 80,
    sep="\n",
)

# Smoothing constant used for the final run; set by hand here.
best_C = 1

# Free whatever can be freed before the long-running evaluation.
gc.collect()

# NOTE(review): each batch materialises a dense batch_size x n_authors
# similarity block inside calculate_metrics_batched, so batch_size=8000 needs
# correspondingly more memory than the function's default of 500.
test_authors_full = list(gt_test)
final_rec, final_ndcg = calculate_metrics_batched(
    test_authors_full,
    gt_test,
    author_matrix,
    author_work_counts,
    best_C,
    batch_size=8000,
)

print(
    "\n" + "#" * 40,
    "RESULTADOS FINALES",
    f"Mejor C: {best_C}",
    f"Recall@20: {final_rec:.4f}",
    f"NDCG@20 : {final_ndcg:.4f}",
    "#" * 40,
    sep="\n",
)