In [2]:
import pickle
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
import os
import sys
sys.path.append('TNTM/Code/TNTM')
from octis.dataset.dataset import Dataset
import TNTM_SentenceTransformer
import TNTM_inference

def load_dataset(num_docs=600):
    dataset = Dataset()
    dataset.fetch_dataset("20NewsGroup")
    corpus = dataset.get_corpus()[:num_docs]
    vocab = sorted(dataset._Dataset__vocabulary)
    print(f"Usando {len(corpus)} documentos")
    print(f"Tamaño del vocabulario: {len(vocab)}")
    return corpus, vocab

def load_word_embeddings(file_path):
    word_df = pickle.load(open(file_path, "rb"))
    
    def embeddings_to_tensor(emb_list):
        if isinstance(emb_list, list):
            return torch.stack(emb_list)
        elif isinstance(emb_list, torch.Tensor):
            return emb_list
        else:
            raise ValueError(f"Tipo inesperado: {type(emb_list)}")
    
    word_df['emb_tensor'] = word_df['embedding'].apply(embeddings_to_tensor)
    
    word_to_emb = dict(zip(word_df.index, word_df['emb_tensor']))
    embedding_dim = word_df['emb_tensor'].iloc[0].shape[0]
    
    return word_to_emb, embedding_dim

def compute_word_embeddings(vocab, word_to_emb, embedding_dim):
    word_embeddings = torch.stack([word_to_emb.get(w, torch.zeros(embedding_dim)) for w in vocab])
    print(f"word_embeddings shape: {word_embeddings.shape}")
    return word_embeddings

def compute_document_embeddings(corpus, word_to_emb, embedding_dim):
    print("Calculando embeddings de documentos...")
    document_embeddings_list = []
    for doc in tqdm(corpus, desc="Document embeddings"):
        if len(doc) == 0:
            doc_emb = torch.zeros(embedding_dim)
        else:
            doc_embs = [word_to_emb.get(w, torch.zeros(embedding_dim)) for w in doc]
            doc_emb = torch.stack(doc_embs).mean(dim=0)
        document_embeddings_list.append(doc_emb)
    
    document_embeddings = torch.stack(document_embeddings_list)
    print(f"document_embeddings shape: {document_embeddings.shape}")
    return document_embeddings

def split_dataset(n_docs):
    indices = np.arange(n_docs)
    np.random.seed(42)
    np.random.shuffle(indices)
    
    n_train = int(0.8 * n_docs)
    n_val = int(0.1 * n_docs)
    n_test = n_docs - n_train - n_val
    
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    
    print(f"Split: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
    return train_idx, val_idx, test_idx

def custom_collate(batch):
    document_embs, bow_tens = zip(*batch)
    document_embs = torch.stack(document_embs)
    bow_tens_dense = [t.to_dense() if t.is_sparse else t for t in bow_tens]
    bow_tens = torch.stack(bow_tens_dense)
    return document_embs, bow_tens

def patch_train_test_split():
    def patched_train_test_split(dataset, train_frac, val_frac, batch_size):
        tot_len = len(dataset)
        train_len = int(tot_len * train_frac)
        val_len = int(tot_len * val_frac)
        test_len = tot_len - train_len - val_len
        train, val, test = torch.utils.data.random_split(dataset, [train_len, val_len, test_len])
        
        train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        
        return train_loader, val_loader, test_loader
    
    TNTM_inference.train_test_split = patched_train_test_split

def train_model(tntm, corpus_train_val, vocab, word_embeddings, document_embeddings_train_val):
    print("Iniciando entrenamiento...")
    result = tntm.fit(
        corpus=corpus_train_val,
        vocab=vocab,
        word_embeddings=word_embeddings,
        document_embeddings=document_embeddings_train_val
    )
    print("¡Entrenamiento completado con éxito!")
    print(f"Resultados guardados en: {tntm.save_path}")
    return result

def save_test_set(save_dir, corpus, document_embeddings, test_idx):
    test_data = {
        "corpus": [corpus[i] for i in test_idx],
        "document_embeddings": document_embeddings[test_idx],
        "original_indices": test_idx.tolist()
    }
    with open(os.path.join(save_dir, "test_set_60docs.pickle"), "wb") as f:
        pickle.dump(test_data, f)
    print("Conjunto de test guardado para evaluación posterior.")


  from .autonotebook import tqdm as notebook_tqdm


current device: cpu
current device: cpu


In [5]:
embeddings_file = "TNTM/Data/DataResults_BERT/cleaned_embedding_df_20ng_BERT.pickle"
save_dir = "TNTM/Data/example/20_topics_600docs"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "model.pt")

corpus, vocab = load_dataset(num_docs=600)

word_to_emb, embedding_dim = load_word_embeddings(embeddings_file)
word_embeddings = compute_word_embeddings(vocab, word_to_emb, embedding_dim)

document_embeddings = compute_document_embeddings(corpus, word_to_emb, embedding_dim)

Usando 600 documentos
Tamaño del vocabulario: 1612
word_embeddings shape: torch.Size([1612, 768])
Calculando embeddings de documentos...


Document embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 600/600 [00:00<00:00, 5448.27it/s]

document_embeddings shape: torch.Size([600, 768])





In [6]:
train_idx, val_idx, test_idx = split_dataset(len(corpus))

train_val_idx = np.concatenate((train_idx, val_idx))
corpus_train_val = [corpus[i] for i in train_val_idx]
document_embeddings_train_val = document_embeddings[torch.tensor(train_val_idx)]

total_train_val = len(train_idx) + len(val_idx)
validation_set_size = 1 - ((len(train_idx) - 0.5) / total_train_val)
print(f"validation_set_size ajustado: {validation_set_size:.6f} (para evitar test vacío interno)")

tntm = TNTM_SentenceTransformer.TNTM_SentenceTransformer(
    n_topics=20,
    save_path=save_path,
    enc_lr=1e-3,
    dec_lr=1e-3,
    validation_set_size=validation_set_size
)

patch_train_test_split()

result = train_model(tntm, corpus_train_val, vocab, word_embeddings, document_embeddings_train_val)

save_test_set(save_dir, corpus, document_embeddings, test_idx)

Split: Train=480, Val=60, Test=60
validation_set_size ajustado: 0.112037 (para evitar test vacío interno)
Iniciando entrenamiento...


  mus_init_ten = torch.tensor(mus_init).to(self.device)
  L_lower_init_ten = torch.tensor(L_lower_init).to(self.device)
  log_diag_init_ten = torch.tensor(log_diag_init).to(self.device)


Epoch nr 0: mean_train_loss = -537.1713256835938, mean_train_nl = -539.144287109375, mean_train_kld = 1.9729321002960205, elapsed time: 0.5833618640899658
Epoch nr 0: median_train_loss = -484.56451416015625, median_train_nl = -486.8547058105469, median_train_kld = 1.7464098930358887, elapsed time: 0.5833618640899658
Epoch nr 0: mean_val_loss = -441.6573486328125, mean_val_nl = -445.99737548828125, mean_val_kld = 4.339963912963867
Epoch nr 0: median_val_loss = -439.7626953125, median_val_nl = -444.1026306152344, median_val_kld = 4.339963912963867
gradient norm: mean: 671.1299402849223, median: 581.520677780878, max: 1004.8846705383926


Epoch nr 1: mean_train_loss = -548.2445678710938, mean_train_nl = -554.32958984375, mean_train_kld = 6.085070610046387, elapsed time: 0.27857446670532227
Epoch nr 1: median_train_loss = -550.9362182617188, median_train_nl = -557.143310546875, median_train_kld = 6.207151412963867, elapsed time: 0.27857446670532227
Epoch nr 1: mean_val_loss = -448.17523193