In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.metrics.pairwise import cosine_distances
from sklearn.preprocessing import normalize

In [2]:
class MLPEncoder(nn.Module):
    """Two-layer softplus MLP producing a diagonal-Gaussian latent code.

    The network maps an input of width ``vocab_size`` through two hidden
    layers of width ``hidden_dim`` and then through two separate linear heads
    (mean and log-variance) of width ``num_topic``. Each head is followed by
    a BatchNorm1d layer whose scale (``weight``) is frozen; only its bias is
    trainable.

    Args:
        vocab_size: Width of the input vectors.
        num_topic: Width of the latent code.
        hidden_dim: Width of both hidden layers.
        dropout: Dropout probability applied after the second hidden layer.
    """

    def __init__(self, vocab_size, num_topic, hidden_dim, dropout):
        super().__init__()

        # Shared trunk.
        self.fc11 = nn.Linear(vocab_size, hidden_dim)
        self.fc12 = nn.Linear(hidden_dim, hidden_dim)
        # Separate heads for the mean and the log-variance.
        self.fc21 = nn.Linear(hidden_dim, num_topic)
        self.fc22 = nn.Linear(hidden_dim, num_topic)

        self.fc1_drop = nn.Dropout(dropout)

        self.mean_bn = nn.BatchNorm1d(num_topic, affine=True)
        self.logvar_bn = nn.BatchNorm1d(num_topic, affine=True)
        # Freeze the BatchNorm scale on both heads; the bias stays trainable.
        for norm_layer in (self.mean_bn, self.logvar_bn):
            norm_layer.weight.requires_grad_(False)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` while training; return ``mu`` otherwise."""
        if not self.training:
            return mu
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x`` and return the tuple ``(z, mu, logvar)``."""
        hidden = F.softplus(self.fc11(x))
        hidden = self.fc1_drop(F.softplus(self.fc12(hidden)))
        mu = self.mean_bn(self.fc21(hidden))
        logvar = self.logvar_bn(self.fc22(hidden))
        return self.reparameterize(mu, logvar), mu, logvar

In [3]:
class RPS_XTM(nn.Module):
    """Bilingual (en/cn) VAE topic model with a prototype-level contrastive loss.

    A single shared ``MLPEncoder`` maps document vectors of either language to
    a latent code. Each language has its own topic-word matrix (``phi_en`` /
    ``phi_cn``) and its own decoder BatchNorm. The training objective is the
    sum of the per-language reconstruction + KL losses and a weighted
    contrastive loss computed between label prototypes of the two languages.

    Args:
        input_size: Width of the encoder input vectors.
        vocab_size_en: Size of the English vocabulary (width of ``x_en_bow``).
        vocab_size_cn: Size of the Chinese vocabulary (width of ``x_cn_bow``).
        num_topics: Number of topics (latent width).
        DCL_weight: Multiplier applied to the contrastive loss.
        temperature: Temperature dividing the cosine similarities.
        en_units: Hidden width of the encoder.
        dropout: Dropout probability used in the encoder and on theta.
    """

    def __init__(self, input_size, vocab_size_en, vocab_size_cn,
                 num_topics, DCL_weight, temperature, en_units=200, dropout=0.1):
        super().__init__()

        self.DCL_weight = DCL_weight
        self.num_topics = num_topics
        self.temperature = temperature

        # One encoder shared by both languages.
        self.encoder = MLPEncoder(input_size, num_topics, en_units, dropout)
        self.z_drop = nn.Dropout(dropout)

        # Fixed (non-trainable) Gaussian prior parameters derived from `a`.
        # NOTE(review): the formulas look like the Laplace approximation to a
        # symmetric Dirichlet(a) prior as used in ProdLDA -- confirm.
        self.a = 1 * np.ones((1, int(num_topics))).astype(np.float32)
        self.mu2 = nn.Parameter(torch.as_tensor((np.log(self.a).T - np.mean(np.log(self.a), 1)).T), requires_grad=False)
        self.var2 = nn.Parameter(torch.as_tensor((((1.0 / self.a) * (1 - (2.0 / num_topics))).T + (1.0 / (num_topics * num_topics)) * np.sum(1.0 / self.a, 1)).T), requires_grad=False)

        # Per-language decoder BatchNorm with a frozen scale.
        self.decoder_bn_en = nn.BatchNorm1d(vocab_size_en, affine=True)
        self.decoder_bn_en.weight.requires_grad = False
        self.decoder_bn_cn = nn.BatchNorm1d(vocab_size_cn, affine=True)
        self.decoder_bn_cn.weight.requires_grad = False

        # Topic-word matrices. phi_cn must be as wide as the Chinese
        # vocabulary so that theta @ phi_cn matches decoder_bn_cn and x_cn_bow.
        self.phi_en = nn.Parameter(nn.init.xavier_uniform_(torch.empty((num_topics, vocab_size_en))))
        self.phi_cn = nn.Parameter(nn.init.xavier_uniform_(torch.empty((num_topics, vocab_size_cn))))

    def get_beta(self):
        """Return the topic-word matrices ``(phi_en, phi_cn)``."""
        return self.phi_en, self.phi_cn

    def get_latent_vector(self, x):
        """Encode ``x``.

        Returns ``(z, mu, logvar)`` while the module is in training mode and
        only ``mu`` in eval mode.
        """
        z, mu, logvar = self.encoder(x)

        if self.training:
            return z, mu, logvar
        return mu

    def get_theta(self, x):
        """Softmax ``x`` over dim 1, then apply dropout (a no-op in eval mode)."""
        theta = F.softmax(x, dim=1)
        return self.z_drop(theta)

    def decode(self, theta, beta, lang):
        """Reconstruct a word distribution for language ``lang`` ('en' or 'cn')."""
        bn = getattr(self, f'decoder_bn_{lang}')
        return F.softmax(bn(torch.matmul(theta, beta)), dim=1)

    def forward(self, x_en, x_cn, x_en_bow, x_cn_bow, labels_en, labels_cn, labels_c2e, labels_e2c):
        """Compute the losses for one batch of English and one of Chinese documents.

        Returns:
            dict with keys ``'topic_modeling_loss'``, ``'contrastive_loss'``
            (already multiplied by ``DCL_weight``) and ``'total_loss'``.
        """
        # Call the encoder directly: get_latent_vector returns a single tensor
        # in eval mode, which cannot be unpacked into (z, mu, logvar).
        z_en, mu_en, logvar_en = self.encoder(x_en)
        z_cn, mu_cn, logvar_cn = self.encoder(x_cn)

        # Contrastive term, computed in both directions.
        dcl_loss_e2c = self.compute_dcl_loss(z_en, z_cn, labels_en, labels_e2c)
        dcl_loss_c2e = self.compute_dcl_loss(z_cn, z_en, labels_cn, labels_c2e)
        dcl_loss = self.DCL_weight * (dcl_loss_e2c + dcl_loss_c2e)

        theta_en = self.get_theta(z_en)
        theta_cn = self.get_theta(z_cn)

        beta_en, beta_cn = self.get_beta()

        x_recon_en = self.decode(theta_en, beta_en, lang='en')
        x_recon_cn = self.decode(theta_cn, beta_cn, lang='cn')
        loss_en = self.compute_loss_TM(x_recon_en, x_en_bow, mu_en, logvar_en)
        loss_cn = self.compute_loss_TM(x_recon_cn, x_cn_bow, mu_cn, logvar_cn)

        TM_loss = loss_en + loss_cn
        total_loss = TM_loss + dcl_loss

        return {
            'topic_modeling_loss': TM_loss,
            'contrastive_loss': dcl_loss,
            'total_loss': total_loss
        }

    def compute_loss_TM(self, recon_x, x, mu, logvar):
        """Batch mean of reconstruction loss plus KL divergence to the prior.

        The KL term is the closed-form divergence between the diagonal
        Gaussian ``N(mu, exp(logvar))`` and the fixed prior
        ``N(self.mu2, self.var2)``. The reconstruction term is the negative
        log-likelihood ``-sum(x * log(recon_x + 1e-10))`` per document.
        """
        var = logvar.exp()
        var_division = var / self.var2
        diff = mu - self.mu2
        diff_term = diff * diff / self.var2
        logvar_division = self.var2.log() - logvar
        KLD = 0.5 * ((var_division + diff_term + logvar_division).sum(1) - self.num_topics)

        RECON = -(x * (recon_x + 1e-10).log()).sum(1)

        return (RECON + KLD).mean()

    @staticmethod
    def _label_prototypes(z, labels, unique_labels):
        """Mean of the rows of ``z`` for each entry of ``unique_labels``.

        A label with no member rows yields an all-zero prototype.
        """
        membership = (labels.reshape(1, -1) == unique_labels.reshape(-1, 1)).to(z.dtype)
        counts = membership.sum(dim=1, keepdim=True)
        # The epsilon keeps the division defined for labels with no members.
        return torch.mm(membership, z) / (counts + 1e-8)

    def compute_dcl_loss(self, z_en, z_cn, labels_en, labels_e2c):
        """InfoNCE-style loss between per-label prototypes of two embedding sets.

        For every label found in ``labels_en`` or ``labels_e2c`` a prototype is
        built from ``z_en`` (grouped by ``labels_en``) and from ``z_cn``
        (grouped by ``labels_e2c``). The prototype pair that shares a label is
        the positive; all other cross pairs act as negatives.

        Args:
            z_en: Anchor-side embeddings, one row per document.
            z_cn: Other-side embeddings, one row per document.
            labels_en: Label of each row of ``z_en``.
            labels_e2c: Label of each row of ``z_cn``.

        Returns:
            Scalar tensor: the loss averaged over labels.
        """
        unique_labels = torch.unique(torch.cat((labels_en.reshape(-1), labels_e2c.reshape(-1))))
        prototypes_en = self._label_prototypes(z_en, labels_en, unique_labels)
        prototypes_cn = self._label_prototypes(z_cn, labels_e2c, unique_labels)

        # Temperature-scaled cosine similarity between all prototype pairs.
        norm_en = torch.norm(prototypes_en, dim=1, keepdim=True) + 1e-8
        norm_cn = torch.norm(prototypes_cn, dim=1, keepdim=True) + 1e-8
        logits = torch.mm(prototypes_en, prototypes_cn.t()) / (norm_en * norm_cn.t())
        logits = logits / self.temperature

        # -log(exp(l_ii) / sum_j exp(l_ij)) evaluated through logsumexp, which
        # stays finite for small temperatures where a plain exp() overflows.
        loss = torch.logsumexp(logits, dim=1) - torch.diagonal(logits)

        return loss.mean()

In [4]:
from torch.utils.data import TensorDataset, DataLoader
import torch

def create_dataloader_separate(sbert_doc_embeddings_en, bow_en, labels_en, labels_c2e,
                               sbert_doc_embeddings_cn, bow_cn, labels_cn, labels_e2c,
                               batch_size_en, batch_size_cn):
    """Build one shuffled DataLoader per language.

    Every input is converted to a float32 tensor. The English loader yields
    ``(embeddings_en, bow_en, labels_en, labels_c2e)`` and the Chinese loader
    yields ``(embeddings_cn, bow_cn, labels_cn, labels_e2c)``.

    Returns:
        Tuple ``(dataloader_en, dataloader_cn)``.
    """
    def _as_float32(values):
        return torch.tensor(values, dtype=torch.float32)

    english_fields = (sbert_doc_embeddings_en, bow_en, labels_en, labels_c2e)
    chinese_fields = (sbert_doc_embeddings_cn, bow_cn, labels_cn, labels_e2c)

    dataset_en = TensorDataset(*[_as_float32(field) for field in english_fields])
    dataset_cn = TensorDataset(*[_as_float32(field) for field in chinese_fields])

    return (
        DataLoader(dataset_en, batch_size=batch_size_en, shuffle=True),
        DataLoader(dataset_cn, batch_size=batch_size_cn, shuffle=True),
    )

In [5]:
def train_RPSXTM(model, dataloader_en, dataloader_cn, optimizer, num_epochs=500, device='cpu'):
    """Train ``model`` on paired English/Chinese batches and return it.

    The two loaders are iterated in lockstep with ``zip``, so each epoch runs
    for as many steps as the shorter loader provides. After every epoch the
    mean total loss over the steps actually executed is printed.

    NOTE: a later cell of this notebook defines ``train_RPSXTM`` again (with
    per-epoch timing); once that cell has run it replaces this definition.

    Args:
        model: Module whose forward returns a dict of losses.
        dataloader_en: Iterable of ``(x_en, x_en_bow, labels_en, labels_c2e)``.
        dataloader_cn: Iterable of ``(x_cn, x_cn_bow, labels_cn, labels_e2c)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over the zipped loaders.
        device: Device the model and batches are moved to.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        # zip stops at the shorter loader, so count the steps really taken.
        for batch_en, batch_cn in zip(dataloader_en, dataloader_cn):
            x_en, x_en_bow, labels_en, labels_c2e = (t.to(device) for t in batch_en)
            x_cn, x_cn_bow, labels_cn, labels_e2c = (t.to(device) for t in batch_cn)

            outputs = model(x_en, x_cn, x_en_bow, x_cn_bow, labels_en, labels_cn, labels_c2e, labels_e2c)

            # Fall back to the sum of the parts if 'total_loss' is missing.
            tm_loss = outputs.get('topic_modeling_loss', torch.tensor(0.0, device=device))
            dcl_loss = outputs.get('contrastive_loss', torch.tensor(0.0, device=device))
            total_loss = outputs.get('total_loss', tm_loss + dcl_loss)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()
            num_batches += 1

        # max(..., 1) keeps the report defined when no batch was produced.
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss / max(num_batches, 1)}")

    return model

In [6]:
import time
import torch

def train_RPSXTM(model, dataloader_en, dataloader_cn, optimizer, num_epochs=500, device='cpu'):
    """Train ``model`` on paired English/Chinese batches, timing every epoch.

    The two loaders are iterated in lockstep with ``zip``, so each epoch runs
    for as many steps as the shorter loader provides. After every epoch the
    mean total loss over the steps actually executed and the elapsed
    wall-clock time are printed.

    Args:
        model: Module whose forward returns a dict of losses.
        dataloader_en: Iterable of ``(x_en, x_en_bow, labels_en, labels_c2e)``.
        dataloader_cn: Iterable of ``(x_cn, x_cn_bow, labels_cn, labels_e2c)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over the zipped loaders.
        device: Device the model and batches are moved to.

    Returns:
        The same ``model`` object, left in training mode.
    """
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        start_time = time.perf_counter()  # epoch start (monotonic clock)
        epoch_loss = 0.0
        num_batches = 0

        # zip stops at the shorter loader, so count the steps really taken.
        for batch_en, batch_cn in zip(dataloader_en, dataloader_cn):
            x_en, x_en_bow, labels_en, labels_c2e = (t.to(device) for t in batch_en)
            x_cn, x_cn_bow, labels_cn, labels_e2c = (t.to(device) for t in batch_cn)

            outputs = model(x_en, x_cn, x_en_bow, x_cn_bow, labels_en, labels_cn, labels_c2e, labels_e2c)

            # Fall back to the sum of the parts if 'total_loss' is missing.
            tm_loss = outputs.get('topic_modeling_loss', torch.tensor(0.0, device=device))
            dcl_loss = outputs.get('contrastive_loss', torch.tensor(0.0, device=device))
            total_loss = outputs.get('total_loss', tm_loss + dcl_loss)

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            epoch_loss += total_loss.item()
            num_batches += 1

        epoch_time = time.perf_counter() - start_time  # epoch duration in seconds

        # max(..., 1) keeps the report defined when no batch was produced.
        print(f"Epoch {epoch + 1}/{num_epochs}, "
              f"Loss: {epoch_loss / max(num_batches, 1):.4f}, "
              f"Time: {epoch_time:.2f} sec")

    return model

In [7]:
# Usage example: load the data, build the model and train it.
# NOTE(review): this header originally said "on ECNews", but every active path
# below points at Amazon_Review; the ECNews paths are commented out.
# Every name bound in this cell is a kernel-global that later cells read
# (e.g. `model`, `trained_model`, `sbert_doc_embeddings_en`).
if __name__ == "__main__":
    # Load the SBERT document embeddings and the BoW representations
    doc_embeddings_en_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_sbert_doc_embeddings_en.npz"
    doc_embeddings_cn_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_sbert_doc_embeddings_cn.npz"
    
    # doc_embeddings_en_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_sbert_doc_embeddings_en.npz"
    # doc_embeddings_cn_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_sbert_doc_embeddings_cn.npz"
    
    # bow_embeddings_en_path = "/users/seung-won/documents/datasets/ECNews/ECN_bow_embeddings_en.npy"
    # bow_embeddings_cn_path = "/users/seung-won/documents/datasets/ECNews/ECN_bow_embeddings_cn.npy"
    
    bow_embeddings_en_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_bow_embeddings_en.npy"
    bow_embeddings_cn_path = "/users/seung-won/documents/datasets/Amazon_Review/AR_bow_embeddings_cn.npy"

    # Per-document labels for each language.
    # NOTE(review): these files carry a "_50" suffix while the cross-lingual
    # label files below carry "_30" -- confirm the pairing is intended.
    labels_en_path = "/users/seung-won/documents/TPL_method/data/Amazon_Review/XLM_labels_en_50.npy"
    labels_cn_path = "/users/seung-won/documents/TPL_method/data/Amazon_Review/XLM_labels_cn_50.npy"
    
    # labels_en_path = "/users/seung-won/documents/TPL_method/data/Amazon_Review/k=50/labels_en.npy"
    # labels_cn_path = "/users/seung-won/documents/TPL_method/data/Amazon_Review/k=50/labels_cn.npy"

    
    # labels_c2e_path = "/users/seung-won/documents/RPS/data/ECNews/labels_c2e_70.npy"
    # labels_e2c_path = "/users/seung-won/documents/RPS/data/ECNews/labels_e2c_70.npy"
    
    labels_c2e_path = "/users/seung-won/documents/RPS/data/Amazon_Review/XLM_labels_c2e_30.npy"
    labels_e2c_path = "/users/seung-won/documents/RPS/data/Amazon_Review/XLM_labels_e2c_30.npy"

    sbert_doc_embeddings_en = np.load(doc_embeddings_en_path)  # English SBERT embeddings
    sbert_doc_embeddings_cn = np.load(doc_embeddings_cn_path)  # Chinese SBERT embeddings
    
    # The .npz archives store the matrix under the 'embeddings' key.
    sbert_doc_embeddings_en = sbert_doc_embeddings_en['embeddings']
    sbert_doc_embeddings_cn = sbert_doc_embeddings_cn['embeddings']
    
    bow_en = np.load(bow_embeddings_en_path)  # English BoW embeddings
    bow_cn = np.load(bow_embeddings_cn_path)  # Chinese BoW embeddings
    
    labels_en = np.load(labels_en_path)
    labels_cn = np.load(labels_cn_path)
    
    labels_c2e = np.load(labels_c2e_path)
    labels_e2c = np.load(labels_e2c_path)
    

    # Model initialisation: sizes are read from the loaded arrays.
    input_size = sbert_doc_embeddings_en.shape[1]
    vocab_size_en = bow_en.shape[1]
    vocab_size_cn = bow_cn.shape[1]
    num_topics = 20
    DCL_weight = 1
    temperature = 0.1
    
    model = RPS_XTM(input_size=input_size, vocab_size_en=vocab_size_en, vocab_size_cn=vocab_size_cn,
                    num_topics=num_topics, DCL_weight=DCL_weight,
                    temperature=temperature, en_units=200, dropout=0.1)

    # Define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Build the DataLoaders
    batch_size_en = 512
    batch_size_cn = 512
    
    dataloader_en, dataloader_cn = create_dataloader_separate(sbert_doc_embeddings_en, bow_en, labels_en,
                                                              labels_c2e, sbert_doc_embeddings_cn, bow_cn,
                                                              labels_cn, labels_e2c,
                                                              batch_size_en, batch_size_cn)

    # Train the model. No random seed is set in this cell, so the printed
    # losses will differ from run to run.
    trained_model = train_RPSXTM(model, dataloader_en, dataloader_cn, optimizer, num_epochs=500, device='cpu')

Epoch 1/500, Loss: 698.5715, Time: 1.72 sec
Epoch 2/500, Loss: 702.4260, Time: 1.59 sec
Epoch 3/500, Loss: 699.6259, Time: 1.57 sec
Epoch 4/500, Loss: 695.0632, Time: 1.50 sec
Epoch 5/500, Loss: 696.4761, Time: 1.50 sec
Epoch 6/500, Loss: 692.0394, Time: 1.52 sec
Epoch 7/500, Loss: 688.4872, Time: 1.66 sec
Epoch 8/500, Loss: 686.8682, Time: 1.57 sec
Epoch 9/500, Loss: 681.0563, Time: 1.54 sec
Epoch 10/500, Loss: 674.9122, Time: 1.58 sec
Epoch 11/500, Loss: 673.5864, Time: 1.60 sec
Epoch 12/500, Loss: 669.7195, Time: 1.73 sec
Epoch 13/500, Loss: 666.6424, Time: 1.57 sec
Epoch 14/500, Loss: 664.6134, Time: 1.71 sec
Epoch 15/500, Loss: 661.0383, Time: 1.53 sec
Epoch 16/500, Loss: 655.8762, Time: 1.57 sec
Epoch 17/500, Loss: 653.8022, Time: 1.60 sec
Epoch 18/500, Loss: 652.6873, Time: 1.56 sec
Epoch 19/500, Loss: 651.5659, Time: 1.64 sec
Epoch 20/500, Loss: 648.3586, Time: 1.59 sec
Epoch 21/500, Loss: 643.9933, Time: 1.57 sec
Epoch 22/500, Loss: 645.2629, Time: 1.57 sec
Epoch 23/500, Loss:

Epoch 182/500, Loss: 574.1979, Time: 1.66 sec
Epoch 183/500, Loss: 575.2690, Time: 1.60 sec
Epoch 184/500, Loss: 573.1657, Time: 1.63 sec
Epoch 185/500, Loss: 574.2065, Time: 1.67 sec
Epoch 186/500, Loss: 574.3806, Time: 1.69 sec
Epoch 187/500, Loss: 574.8213, Time: 1.62 sec
Epoch 188/500, Loss: 574.8298, Time: 1.67 sec
Epoch 189/500, Loss: 575.5063, Time: 1.62 sec
Epoch 190/500, Loss: 574.4574, Time: 1.68 sec
Epoch 191/500, Loss: 572.1597, Time: 1.66 sec
Epoch 192/500, Loss: 574.4739, Time: 1.62 sec
Epoch 193/500, Loss: 572.7712, Time: 1.61 sec
Epoch 194/500, Loss: 572.0812, Time: 1.61 sec
Epoch 195/500, Loss: 573.4064, Time: 1.67 sec
Epoch 196/500, Loss: 573.1377, Time: 1.61 sec
Epoch 197/500, Loss: 572.2444, Time: 1.61 sec
Epoch 198/500, Loss: 571.2118, Time: 1.66 sec
Epoch 199/500, Loss: 572.0771, Time: 1.70 sec
Epoch 200/500, Loss: 569.8692, Time: 1.66 sec
Epoch 201/500, Loss: 574.9997, Time: 1.60 sec
Epoch 202/500, Loss: 572.7728, Time: 1.61 sec
Epoch 203/500, Loss: 570.8668, Tim

Epoch 361/500, Loss: 558.7311, Time: 1.61 sec
Epoch 362/500, Loss: 557.2483, Time: 1.61 sec
Epoch 363/500, Loss: 556.3551, Time: 1.62 sec
Epoch 364/500, Loss: 556.5291, Time: 1.61 sec
Epoch 365/500, Loss: 557.9483, Time: 1.67 sec
Epoch 366/500, Loss: 557.2259, Time: 1.61 sec
Epoch 367/500, Loss: 556.6408, Time: 1.66 sec
Epoch 368/500, Loss: 558.5637, Time: 1.66 sec
Epoch 369/500, Loss: 555.5112, Time: 1.65 sec
Epoch 370/500, Loss: 554.4832, Time: 1.73 sec
Epoch 371/500, Loss: 554.9550, Time: 1.62 sec
Epoch 372/500, Loss: 555.8599, Time: 1.63 sec
Epoch 373/500, Loss: 555.1079, Time: 1.62 sec
Epoch 374/500, Loss: 555.5967, Time: 1.67 sec
Epoch 375/500, Loss: 557.5265, Time: 1.61 sec
Epoch 376/500, Loss: 557.9375, Time: 1.68 sec
Epoch 377/500, Loss: 559.7440, Time: 1.61 sec
Epoch 378/500, Loss: 555.5316, Time: 1.66 sec
Epoch 379/500, Loss: 556.1356, Time: 1.66 sec
Epoch 380/500, Loss: 555.4844, Time: 1.61 sec
Epoch 381/500, Loss: 554.5845, Time: 1.61 sec
Epoch 382/500, Loss: 556.4739, Tim

In [8]:
def save_top_words(beta, vocab, file_path, top_n=15):
    """Write the highest-weighted words of every topic to a text file.

    One line is written per row of ``beta``: the ``top_n`` words with the
    largest weights, space-separated, in descending weight order.

    Args:
        beta: 2-D tensor with one row per topic and one column per word.
        vocab: Sequence mapping a column index to its word.
        file_path: Destination text file (overwritten).
        top_n: Number of words per topic; ``0`` writes empty lines.
    """
    beta_np = beta.detach().cpu().numpy()

    # Write UTF-8 explicitly: the vocabulary may contain non-ASCII words and
    # the platform default encoding is not guaranteed to represent them.
    with open(file_path, "w", encoding="utf-8") as f:
        for topic_dist in beta_np:
            # Reverse first, then truncate, so that top_n == 0 selects nothing
            # (a trailing slice [-0:] would select every word).
            top_word_indices = topic_dist.argsort()[::-1][:top_n]
            f.write(" ".join(vocab[idx] for idx in top_word_indices) + "\n")

# Usage example: export the top words of every topic for both languages.
if __name__ == "__main__":
    # Fetch beta_en and beta_cn from the trained model
    beta_en, beta_cn = trained_model.get_beta()

    # Load the English and Chinese vocabularies (one word per line); the
    # with-blocks close the files as soon as they have been read.
    with open("/users/seung-won/documents/datasets/Amazon_Review/AR_vocab_en", encoding="utf-8") as vocab_file:
        vocab_en = [line.strip() for line in vocab_file]
    with open("/users/seung-won/documents/datasets/Amazon_Review/AR_vocab_cn", encoding="utf-8") as vocab_file:
        vocab_cn = [line.strip() for line in vocab_file]

    # Save the top 15 words of each topic to one text file per language
    topic_en_path = "/users/seung-won/documents/topic_en.txt"
    topic_cn_path = "/users/seung-won/documents/topic_cn.txt"
    save_top_words(beta_en, vocab_en, topic_en_path, top_n=15)
    save_top_words(beta_cn, vocab_cn, topic_cn_path, top_n=15)

    # Report the paths that were actually written.
    print(f"Top words saved to '{topic_en_path}' and '{topic_cn_path}'")

Top words saved to 'top_words_en.txt' and 'top_words_cn.txt'


In [9]:
def create_dataloader_separate_fixed(x_en, x_cn, batch_size_en, batch_size_cn):
    """Build one order-preserving (unshuffled) DataLoader per language.

    Each input is converted to a float32 tensor and wrapped in a
    single-tensor ``TensorDataset``, so every batch is a 1-tuple ``(x,)``.

    Returns:
        Tuple ``(dataloader_en, dataloader_cn)``.
    """
    loaders = []
    for features, batch_size in ((x_en, batch_size_en), (x_cn, batch_size_cn)):
        dataset = TensorDataset(torch.tensor(features, dtype=torch.float32))
        loaders.append(DataLoader(dataset, batch_size=batch_size, shuffle=False))

    dataloader_en, dataloader_cn = loaders
    return dataloader_en, dataloader_cn

def save_doc_topic_distributions(model, dataloader, output_file, device):
    """Compute doc-topic distributions for every batch and save them with np.save.

    The model is switched to eval mode (and left there). For each batch
    ``(x,)`` the result of ``model.get_theta(model.get_latent_vector(x))`` is
    collected; all batches are concatenated along axis 0 and written to
    ``output_file``.

    Args:
        model: Object exposing ``eval``, ``get_latent_vector`` and ``get_theta``.
        dataloader: Iterable yielding 1-tuples ``(x,)`` of tensors.
        output_file: Destination passed to ``np.save``.
        device: Device each batch is moved to before inference.
    """
    model.eval()

    # Gradients are not needed for inference.
    with torch.no_grad():
        per_batch_theta = [
            model.get_theta(model.get_latent_vector(batch.to(device))).cpu().numpy()
            for (batch,) in dataloader
        ]

    # Stack the batches back into a single (documents x topics) array.
    stacked = np.concatenate(per_batch_theta, axis=0)

    np.save(output_file, stacked)
    print(f"Doc-topic distributions saved to {output_file}")

    
# Usage example: export the doc-topic distributions of both corpora.
# Requires the trained model and the SBERT embeddings bound by the training cell.
device = 'cpu'
output_file_en = "/users/seung-won/documents/AR_doc_topic_dist_en_20.npy"
output_file_cn = "/users/seung-won/documents/AR_doc_topic_dist_cn_20.npy"


# Save the doc-topic distributions.
# These assignments rebind `dataloader_en` / `dataloader_cn`, replacing the
# shuffled training loaders with unshuffled, embeddings-only loaders.
# `model` is the same object as `trained_model`, because train_RPSXTM returns
# the module it was given.
dataloader_en, dataloader_cn= create_dataloader_separate_fixed(sbert_doc_embeddings_en, sbert_doc_embeddings_cn,
                                                               batch_size_en=128, batch_size_cn=128)
save_doc_topic_distributions(model, dataloader_en, output_file_en, device)
save_doc_topic_distributions(model, dataloader_cn, output_file_cn, device)

Doc-topic distributions saved to /users/seung-won/documents/AR_doc_topic_dist_en_20.npy
Doc-topic distributions saved to /users/seung-won/documents/AR_doc_topic_dist_cn_20.npy
