In [1]:
import pickle
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import os
import sys
sys.path.append('TNTM/Code/TNTM')
from octis.dataset.dataset import Dataset
import TNTM_SentenceTransformer
import TNTM_inference

def load_dataset(num_docs=1000):
    """Fetch OCTIS's preprocessed 20NewsGroup dataset and keep the first ``num_docs`` documents.

    Args:
        num_docs: number of documents to keep from the start of the corpus
            (``None`` keeps all of them, since ``[:None]`` is a full slice).

    Returns:
        (corpus, vocab): the truncated corpus as returned by ``Dataset.get_corpus()`` and the
        dataset vocabulary sorted alphabetically.

    NOTE(review): ``vocab`` is the vocabulary of the *whole* dataset, not only of the
    truncated corpus, so it may contain words that never occur in ``corpus``.
    """
    dataset = Dataset()
    dataset.fetch_dataset("20NewsGroup")
    corpus = dataset.get_corpus()[:num_docs]
    # Use the public accessor instead of the name-mangled private attribute
    # ``_Dataset__vocabulary``, which is an implementation detail of OCTIS.
    vocab = sorted(dataset.get_vocabulary())
    print(f"Usando {len(corpus)} documentos")
    print(f"Tamaño del vocabulario: {len(vocab)}")
    return corpus, vocab

def load_word_embeddings(file_path):
    """Load a pickled word-embedding table and build a word -> tensor lookup.

    Args:
        file_path: path to a pickle holding a DataFrame whose index labels are the words and
            whose ``'embedding'`` column holds, per word, either a tensor or a list of tensors
            (a list is combined with ``torch.stack``).

    Returns:
        (word_to_emb, embedding_dim): a dict mapping each index label to its tensor, and the
        size of the first dimension of the first entry's tensor.

    Raises:
        ValueError: if the table is empty, or an ``'embedding'`` cell is neither a list nor
            a tensor.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file. Only point this
    # at pickles you produced yourself or otherwise trust.
    with open(file_path, "rb") as fh:  # the context manager closes the file deterministically
        word_df = pickle.load(fh)

    if len(word_df) == 0:
        # Fail with a clear message instead of an opaque IndexError from ``.iloc[0]`` below.
        raise ValueError(f"No embeddings found in {file_path!r}")

    def embeddings_to_tensor(emb_list):
        # A list of tensors is stacked into a single tensor; a tensor is passed through.
        if isinstance(emb_list, list):
            return torch.stack(emb_list)
        elif isinstance(emb_list, torch.Tensor):
            return emb_list
        else:
            raise ValueError(f"Tipo inesperado: {type(emb_list)}")

    # Kept as a standalone Series so the loaded DataFrame is not mutated.
    emb_tensors = word_df['embedding'].apply(embeddings_to_tensor)
    word_to_emb = dict(zip(word_df.index, emb_tensors))
    # NOTE(review): assumes every entry has the same size as the first one — not checked here.
    embedding_dim = emb_tensors.iloc[0].shape[0]

    return word_to_emb, embedding_dim

def compute_word_embeddings(vocab, word_to_emb, embedding_dim):
    """Build the vocabulary embedding matrix, one row per word in ``vocab`` order.

    Args:
        vocab: sequence of words; row ``i`` of the result corresponds to ``vocab[i]``.
        word_to_emb: mapping word -> tensor (expected 1-D of length ``embedding_dim``).
        embedding_dim: embedding size, used for the zero rows of unknown words.

    Returns:
        A stacked tensor with one row per word. Words missing from ``word_to_emb`` get an
        all-zero row. An empty ``vocab`` yields a ``(0, embedding_dim)`` tensor instead of
        the error ``torch.stack`` raises on an empty list.
    """
    # One shared zero vector instead of a fresh allocation per lookup: torch.stack copies
    # its inputs, so reusing the same tensor for every missing word is safe.
    zero_vector = torch.zeros(embedding_dim)
    rows = [word_to_emb.get(w, zero_vector) for w in vocab]
    word_embeddings = torch.stack(rows) if rows else torch.zeros((0, embedding_dim))
    print(f"word_embeddings shape: {word_embeddings.shape}")
    return word_embeddings

def compute_document_embeddings(corpus, word_to_emb, embedding_dim):
    """Embed each document as the mean of its tokens' word embeddings.

    Args:
        corpus: sequence of documents; each document is iterated token by token
            (presumably a list of word strings, as OCTIS ``get_corpus()`` returns — confirm).
        word_to_emb: mapping word -> tensor (expected 1-D of length ``embedding_dim``).
        embedding_dim: embedding size, used for zero vectors.

    Returns:
        A stacked tensor with one row per document. Unknown tokens contribute an all-zero
        vector to the mean. An empty document maps to an all-zero row. An empty corpus
        yields a ``(0, embedding_dim)`` tensor instead of the error ``torch.stack`` raises
        on an empty list.
    """
    print("Calculando embeddings de documentos...")
    # Shared zero vector: torch.stack copies its inputs, so reusing it is safe and avoids
    # allocating a new tensor for every token lookup.
    zero_vector = torch.zeros(embedding_dim)
    document_embeddings_list = []
    for doc in tqdm(corpus, desc="Document embeddings"):
        if len(doc) == 0:
            doc_emb = zero_vector
        else:
            doc_embs = [word_to_emb.get(w, zero_vector) for w in doc]
            doc_emb = torch.stack(doc_embs).mean(dim=0)
        document_embeddings_list.append(doc_emb)

    if document_embeddings_list:
        document_embeddings = torch.stack(document_embeddings_list)
    else:
        document_embeddings = torch.zeros((0, embedding_dim))
    print(f"document_embeddings shape: {document_embeddings.shape}")
    return document_embeddings

def split_dataset(n_docs, train_frac=0.7, val_frac=0.2, seed=42):
    """Shuffle the indices ``0..n_docs-1`` and split them into train/validation/test sets.

    Args:
        n_docs: number of documents to split.
        train_frac: fraction of documents for training (size is floored).
        val_frac: fraction of documents for validation (size is floored); the test set
            receives the remainder.
        seed: seed applied to NumPy's global RNG before shuffling.

    Returns:
        (train_idx, val_idx, test_idx): disjoint NumPy index arrays that together cover
        every index exactly once.

    Raises:
        ValueError: if a fraction is negative or ``train_frac + val_frac > 1``.

    Side effect: this seeds NumPy's *global* RNG (``np.random.seed``). It is kept on purpose,
    because later code that draws from the global RNG may rely on it for reproducibility.
    """
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    indices = np.arange(n_docs)
    np.random.seed(seed)
    np.random.shuffle(indices)

    n_train = int(train_frac * n_docs)
    n_val = int(val_frac * n_docs)
    n_test = n_docs - n_train - n_val

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    assert len(test_idx) == n_test

    print(f"Split: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
    return train_idx, val_idx, test_idx

def custom_collate(batch):
    """Collate ``(document_embedding, bag_of_words)`` pairs into two batched tensors.

    Document embeddings are stacked as-is. Bag-of-words tensors in sparse COO form are
    converted to dense first, since they must be dense before they can be stacked here.

    Returns:
        (stacked_document_embeddings, stacked_dense_bows)
    """
    doc_column, bow_column = zip(*batch)
    stacked_docs = torch.stack(doc_column)
    dense_bows = [bow.to_dense() if bow.is_sparse else bow for bow in bow_column]
    stacked_bows = torch.stack(dense_bows)
    return stacked_docs, stacked_bows

def patch_train_test_split(seed=None):
    """Replace ``TNTM_inference.train_test_split`` with a version whose loaders use ``custom_collate``.

    The replacement splits ``dataset`` with ``torch.utils.data.random_split`` into train /
    validation / test subsets. The train and validation sizes are floored from the given
    fractions and the test subset gets the remainder. Each subset is wrapped in a
    ``DataLoader`` whose batches are assembled by ``custom_collate``.

    Only the training loader shuffles. The validation/test loaders serve evaluation, so
    shuffling them just re-draws batch composition every epoch. That adds noise to batch-level
    statistics such as the logged median validation loss used alongside early stopping, and
    consumes extra RNG draws.

    Args:
        seed: optional int. When given, the random split uses a dedicated ``torch.Generator``
            seeded with it, so the split does not depend on global RNG state. When ``None``
            (default), the global torch RNG is used.

    NOTE(review): the patch only takes effect if TNTM looks up
    ``TNTM_inference.train_test_split`` as a module attribute at call time; a
    ``from TNTM_inference import train_test_split`` elsewhere would keep the original
    function — confirm in the TNTM sources.
    """
    def patched_train_test_split(dataset, train_frac, val_frac, batch_size):
        tot_len = len(dataset)
        train_len = int(tot_len * train_frac)
        val_len = int(tot_len * val_frac)
        test_len = tot_len - train_len - val_len

        split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
        train, val, test = torch.utils.data.random_split(
            dataset, [train_len, val_len, test_len], **split_kwargs
        )

        train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
        val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
        test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

        return train_loader, val_loader, test_loader

    TNTM_inference.train_test_split = patched_train_test_split

def showTopicsTNTM(sorted_words=None, n_words=10):
    """Print the top ``n_words`` words of every topic, one ``Topic i: ...`` line per topic.

    Args:
        sorted_words: per-topic word rankings supporting ``.shape[0]`` (number of topics),
            ``sorted_words[i][:n]`` and ``.tolist()`` on that slice, e.g. a NumPy array of
            strings. Defaults to the notebook-global ``topic_sorted_words`` (as returned by
            ``tntm.fit``) so the zero-argument call keeps working. Prefer passing it
            explicitly to avoid relying on hidden kernel state.
        n_words: number of words printed per topic.
    """
    if sorted_words is None:
        # Backward-compatible fallback to the global set by the training cell.
        sorted_words = topic_sorted_words
    for topic_idx in range(sorted_words.shape[0]):
        top_words = sorted_words[topic_idx][:n_words].tolist()
        print(f"Topic {topic_idx}: {' '.join(top_words)}")


  from .autonotebook import tqdm as notebook_tqdm


current device: cpu
current device: cpu


In [2]:
# --- Configuration: the values that define this experiment ---------------------
N_DOCS = 1000    # documents taken from 20NewsGroup
N_TOPICS = 10    # number of topics TNTM learns
SEED = 42        # seed for torch's global RNG (split_dataset seeds NumPy itself)

# Derived from the settings above so the directory name cannot drift out of sync with them.
save_dir = f"TNTM/Data/example/{N_TOPICS}_topics_{N_DOCS}docs"
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, "model.pt")

embeddings_file = "TNTM/Data/DataResults_BERT/cleaned_embedding_df_20ng_BERT.pickle"

# --- Data and embeddings ---------------------------------------------------------
corpus, vocab = load_dataset(num_docs=N_DOCS)
word_to_emb, embedding_dim = load_word_embeddings(embeddings_file)
word_embeddings = compute_word_embeddings(vocab, word_to_emb, embedding_dim)
document_embeddings = compute_document_embeddings(corpus, word_to_emb, embedding_dim)
assert word_embeddings.shape == (len(vocab), embedding_dim)
assert document_embeddings.shape == (len(corpus), embedding_dim)

# --- Split -----------------------------------------------------------------------
train_idx, val_idx, test_idx = split_dataset(len(corpus))

train_val_idx = np.concatenate((train_idx, val_idx))
corpus_train_val = [corpus[i] for i in train_val_idx]
document_embeddings_train_val = document_embeddings[torch.tensor(train_val_idx)]
assert len(corpus_train_val) == document_embeddings_train_val.shape[0]

# Presumably TNTM carves its own random validation subset out of what is passed to fit()
# (via the patched train_test_split), so val_idx only determines the *size* of that subset,
# not its members; test_idx is what stays held out. This share replaces the former magic
# number 0.222 (~= 200 / 900).
validation_set_size = len(val_idx) / len(train_val_idx)

patch_train_test_split()
# Seed torch's global RNG so the random split in the patched train_test_split (and any other
# torch-based randomness during training) is reproducible across runs.
torch.manual_seed(SEED)

# --- Training --------------------------------------------------------------------
print("Iniciando entrenamiento...")
tntm = TNTM_SentenceTransformer.TNTM_SentenceTransformer(
    n_topics=N_TOPICS,
    save_path=save_path,
    enc_lr=1e-3,
    dec_lr=1e-3,
    validation_set_size=validation_set_size,
    n_epochs=500,
    n_epochs_early_stopping=50,
    early_stopping=True,
    n_topwords=10
)

topic_sorted_words, topic_probs = tntm.fit(
    corpus=corpus_train_val,
    vocab=vocab,
    word_embeddings=word_embeddings,
    document_embeddings=document_embeddings_train_val
)

print("¡Entrenamiento completado con éxito!")

print("\n=== Tópicos TNTM ===")
showTopicsTNTM()

Usando 1000 documentos
Tamaño del vocabulario: 1612
word_embeddings shape: torch.Size([1612, 768])
Calculando embeddings de documentos...


Document embeddings: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 4154.96it/s]


document_embeddings shape: torch.Size([1000, 768])
Split: Train=700, Val=200, Test=100
Iniciando entrenamiento...


  mus_init_ten = torch.tensor(mus_init).to(self.device)
  L_lower_init_ten = torch.tensor(L_lower_init).to(self.device)
  log_diag_init_ten = torch.tensor(log_diag_init).to(self.device)


Epoch nr 0: mean_train_loss = -425.5240783691406, mean_train_nl = -427.563232421875, mean_train_kld = 2.039156436920166, elapsed time: 0.19446420669555664
Epoch nr 0: median_train_loss = -410.65826416015625, median_train_nl = -412.8607177734375, median_train_kld = 1.8102588653564453, elapsed time: 0.19446420669555664
Epoch nr 0: mean_val_loss = -469.6174621582031, mean_val_nl = -474.1396484375, mean_val_kld = 4.522181510925293
Epoch nr 0: median_val_loss = -409.9067077636719, median_val_nl = -414.4252624511719, median_val_kld = 4.518590927124023
gradient norm: mean: 693.3844873919537, median: 664.0408136145145, max: 840.6415926179606


Epoch nr 1: mean_train_loss = -437.5760803222656, mean_train_nl = -443.0464172363281, mean_train_kld = 5.470337390899658, elapsed time: 0.18983721733093262
Epoch nr 1: median_train_loss = -389.94854736328125, median_train_nl = -395.95001220703125, median_train_kld = 5.601262092590332, elapsed time: 0.18983721733093262
Epoch nr 1: mean_val_loss = -425.644