In [1]:
import pandas as pd
import os
import sys
import logging
import pathlib
import random
from itertools import product
from typing import List, Optional, Union

import gensim.downloader as api
import numpy as np

In [2]:

class TopicSelector(object):
    """
    Select matched topics within or across topic models.

    Topics are compared with Word Mover's Distance (WMD), computed by calling
    ``wmdistance`` on a word-embedding model. When the model is given by name
    it is loaded through ``gensim.downloader`` the first time a distance is
    actually needed, and then reused for every later call.
    """

    def __init__(
        self,
        wmd_model: str = 'word2vec-google-news-300',
        logger: Optional[logging.Logger] = None,
        config_path: pathlib.Path = pathlib.Path("config/config.yaml")
    ) -> None:
        """
        Initialize the TopicSelector class.

        Parameters
        ----------
        wmd_model : str
            Name of the embedding model to load with ``gensim.downloader``.
            An already-loaded object exposing ``wmdistance(a, b)`` is also
            accepted and is used as-is.
        logger : logging.Logger, optional
            Logger object to log activity. Defaults to this module's logger.
        config_path : pathlib.Path, optional
            Accepted for interface compatibility; it is not read by this class.
        """
        self._wmd_model = wmd_model
        self._logger = logger if logger is not None else logging.getLogger(__name__)

    def _ensure_wmd_model(self):
        """
        Return the embedding model, loading it on first use.

        ``self._wmd_model`` holds the model *name* until the first call and the
        loaded model afterwards, so repeated calls never reload it.
        """
        if isinstance(self._wmd_model, str):
            self._logger.info("Loading WMD embedding model '%s'", self._wmd_model)
            self._wmd_model = api.load(self._wmd_model)
        return self._wmd_model

    def _get_wmd(self, from_: Union[str, List[str]], to_: Union[str, List[str]], n_words=10) -> float:
        """
        Calculate the Word Mover's Distance between two sentences.

        Parameters
        ----------
        from_ : Union[str, List[str]]
            The source sentence (a whitespace-separated string or a token list).
        to_ : Union[str, List[str]]
            The target sentence (a whitespace-separated string or a token list).
        n_words : int
            The number of leading words of each sentence used to compute the WMD.

        Returns
        -------
        float
            Whatever ``wmdistance`` of the embedding model returns for the
            two (truncated) token lists.
        """
        if isinstance(from_, str):
            from_ = from_.split()
        if isinstance(to_, str):
            to_ = to_.split()

        # Slicing past the end is a no-op, so no length check is needed.
        return self._ensure_wmd_model().wmdistance(from_[:n_words], to_[:n_words])

    def _get_wmd_mat(self, models: list) -> np.ndarray:
        """Calculate inter-topic distances from topic words using Word Mover's Distance.

        Parameters
        ----------
        models : list
            A list containing exactly two models. Each model is a list of
            topics, each topic represented as a list of words.

        Returns
        -------
        np.ndarray
            Matrix of shape ``(len(models[0]), len(models[1]))`` whose entry
            ``[i, j]`` is the WMD between topic ``i`` of the first model and
            topic ``j`` of the second.

        Raises
        ------
        ValueError
            If ``models`` does not contain exactly two entries.
        """
        if len(models) != 2:
            raise ValueError(
                "models must contain exactly two sublists/arrays.")

        first, second = models
        wmd_dists = np.zeros((len(first), len(second)))
        for row, topic_a in enumerate(first):
            for col, topic_b in enumerate(second):
                wmd_dists[row, col] = self._get_wmd(topic_a, topic_b)

        return wmd_dists

    @staticmethod
    def _drop_removed_topics(models, remove_topic_ids):
        """
        Filter out unwanted topics and remember where the survivors came from.

        Returns ``(models, id_mappings)`` where ``id_mappings[m]`` maps a topic's
        position in the (possibly filtered) model ``m`` to its original index.
        """
        if remove_topic_ids is None:
            return models, [{i: i for i in range(len(model))} for model in models]

        filtered, id_mappings = [], []
        for i_model, model in enumerate(models):
            kept = [
                (orig_id, topic)
                for orig_id, topic in enumerate(model)
                if orig_id not in remove_topic_ids[i_model]
            ]
            filtered.append([topic for _, topic in kept])
            id_mappings.append(
                {new_id: orig_id for new_id, (orig_id, _) in enumerate(kept)})
        return filtered, id_mappings

    def _match_across_models(self, models):
        """
        Greedily pair topics across two or more models.

        Each model takes turns as the "seed": its topic with the smallest mean
        distance to the closest topic of every other model is matched with
        those closest topics, and all matched topics are then masked out.
        Exactly ``min(len(m) for m in models)`` matches are produced, so every
        match uses topics that have not been matched before.
        """
        n_models = len(models)
        target = min(len(m) for m in models)

        # Only cross-model matrices are ever read, so the (a, a) pairs are skipped.
        dists = {
            (a, b): self._get_wmd_mat([models[a], models[b]])
            for a, b in product(range(n_models), repeat=2)
            if a != b
        }

        matches = []
        while len(matches) < target:
            for seed_model in range(n_models):
                # Stop mid-round once the smallest model is exhausted; going
                # on would match against fully masked (all-inf) matrices.
                if len(matches) >= target:
                    break

                min_dists, nearest = [], {}
                for other_model in range(n_models):
                    if other_model == seed_model:
                        continue
                    pair_dists = dists[(seed_model, other_model)]
                    # Closest topic of the other model for each seed-model topic
                    min_dists.append(pair_dists.min(1))
                    nearest[other_model] = pair_dists.argmin(1)

                mean_min_dists = np.mean(min_dists, axis=0)
                seed_topic = int(np.argmin(mean_min_dists))
                match = [
                    (m, seed_topic if m == seed_model else int(nearest[m][seed_topic]))
                    for m in range(n_models)
                ]
                matches.append(match)

                # Mask the matched topics in every matrix that involves them
                for model_a, topic_a in match:
                    for model_b in range(n_models):
                        if model_a != model_b:
                            dists[(model_a, model_b)][topic_a, :] = np.inf
                            dists[(model_b, model_a)][:, topic_a] = np.inf

        return matches

    def iterative_matching(self, models, N, remove_topic_ids=None, seed=2357_11):
        """
        Performs an iterative pairing process between the topics of multiple models.

        With a single model, ``N`` topics are drawn at random. With several
        models, topics are greedily matched across models by WMD and ``N`` of
        the resulting matches are drawn at random; the drawn matches are also
        printed, both as positions and as original topic IDs.

        Parameters
        ----------
        models : list
            A list of models. Each model is a list of topics, each topic
            represented as a list of words.
        N : int
            Number of matches to return.
        remove_topic_ids : optional
            Indexed by model position; ``remove_topic_ids[m]`` is a container of
            topic indices of model ``m`` to exclude before matching.
        seed : int
            Seed passed to ``random.seed`` (this reseeds the global generator).

        Returns
        -------
        list of list of tuple
            List with the N matches found. Each match is a list of tuples
            ``(model index, original topic index)``, one tuple per model.

        Raises
        ------
        ValueError
            If ``models`` is empty, or (single model) if ``N`` exceeds the
            number of available topics.
        AssertionError
            (Multiple models) if ``N`` exceeds the topic count of any model.
        """
        random.seed(seed)

        models, id_mappings = self._drop_removed_topics(models, remove_topic_ids)

        if len(models) == 0:
            raise ValueError("models must contain at least one model.")

        # Case 1: Single Model - Return single-topic matches
        if len(models) == 1:
            num_topics = len(models[0])
            if N > num_topics:
                raise ValueError("N must be less than or equal to the number of topics in the model.")

            selected_topics = random.sample(range(num_topics), N)
            return [[(0, id_mappings[0][topic_id])] for topic_id in selected_topics]

        # Case 2: Multiple Models - Perform cross-model matching.
        # Validate before the embedding model is loaded and distances computed.
        assert all(N <= len(m) for m in models), \
            "N must be less than or equal to the number of topics in every model."

        matches = self._match_across_models(models)
        sampled_matches = random.sample(matches, N)

        # Map the sampled matches (positions in the filtered models) back to
        # their original topic IDs
        sampled_matches_original = [
            [(model_idx, id_mappings[model_idx][topic_id])
             for model_idx, topic_id in match]
            for match in sampled_matches
        ]

        # Output the sampled matches in both forms
        print("Sampled Matches (Position IDs):", sampled_matches)
        print("Sampled Matches (Original Topic IDs):", sampled_matches_original)

        return sampled_matches_original

In [3]:
def parse_topic_file(filepath, topn=5):
    """
    Parse a bilingual topic-keys file into two aligned lists of topics.

    Line handling:
      * blank lines are ignored;
      * a line that starts with a digit and contains no tab is treated as a
        topic header and discards any half-collected topic;
      * any other line needs at least four whitespace-separated fields:
        field 0 is the integer model index and fields 3+ are the keywords;
        shorter lines are ignored.

    A topic is emitted once keyword lists for both model 0 and model 1 have
    been collected.

    Args:
        filepath: path of the topic file (read as UTF-8).
        topn: if truthy, keep only ``keywords[:topn]`` for each topic.

    Returns:
        models: list of two models, each is a list of topics (each topic is a
        list of keywords).
    """
    per_model = {0: [], 1: []}
    pending = {0: None, 1: None}

    with open(filepath, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if not entry:
                continue

            # Header line (e.g., "1    0.02"): start collecting a new topic
            is_header = entry[0].isdigit() and '\t' not in entry
            if is_header:
                pending = {0: None, 1: None}
                continue

            fields = entry.split()
            if len(fields) < 4:
                continue

            words = fields[3:]
            if topn:
                words = words[:topn]
            pending[int(fields[0])] = words

            # Both sides present: flush the pair and start over
            if all(pending.values()):
                per_model[0].append(pending[0])
                per_model[1].append(pending[1])
                pending = {0: None, 1: None}

    return [per_model[0], per_model[1]]


In [4]:
def parse_ZS_file(filepath, topn=None):
    """
    Parse a plain topic file with one topic per line:
    <topic_id> <keywords...>

    The first whitespace-separated token of each non-blank line (the topic
    index) is dropped; the remaining tokens are the topic's keywords.

    Args:
        filepath: path of the topic file (read as UTF-8).
        topn: if truthy, keep only ``keywords[:topn]`` for each topic.

    Returns:
        models: list with a single model (list of topic keyword lists)
    """
    topics = []

    with open(filepath, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            tokens = raw_line.split()
            if len(tokens) == 0:
                continue
            _topic_id, *words = tokens
            topics.append(words[:topn] if topn else words)

    return [topics]


In [7]:
# Input topic files for the three models compared below.
# NOTE(review): these are absolute, machine-specific paths; the notebook only
# runs where this directory tree exists. Consider a configurable base directory.
# Parsed later with parse_ZS_file (one "<topic_id> <keywords...>" line per topic).
zs_top_path = '/export/usuarios_ml4ds/ammesa/ZS_results/en_2025-06-04_segmented_dataset.parquet.gzip/n_topics_9/ZS_output/topics.txt'
# Parsed later with parse_topic_file (bilingual topic-keys format).
# Presumably Mallet output for the "translated" and "matched" corpora, judging
# by the directory names only — TODO confirm.
m_trans_path = '/export/usuarios_ml4ds/ammesa/mallet_folder/en_2025-06-04_segm_trans/n_topics_50/mallet_output/topickeys.txt'
m_match_path = '/export/usuarios_ml4ds/ammesa/mallet_folder/en_2025_06_05_matched/n_topics_50/mallet_output/topickeys.txt'

In [8]:
# Parse every topic file, keeping at most the first 10 keywords of each topic.
tpcs_trans = parse_topic_file(m_trans_path, topn=10)
tpcs_match = parse_topic_file(m_match_path, topn=10)
tpcs_ZS = parse_ZS_file(zs_top_path, topn=10)

In [9]:
# Inspect the parsed ZS topics: a one-element list holding the model, i.e. a
# list of keyword lists (at most 10 keywords each, per the topn used above).
tpcs_ZS

[[['albany',
   'gray',
   'baltimore',
   'church',
   'anglicanism',
   'albany_new',
   'school',
   'alexandria',
   'anglican',
   'bell'],
  ['ley',
   'congreso',
   'presidente',
   'voto',
   'constitución',
   'estados_unidos',
   'coolidge',
   'bill_clinton',
   'artículo_constitución',
   'vicepresidente'],
  ['states_constitution',
   'article_united',
   'amendment',
   'constitution',
   'supreme_court',
   'president',
   'clause',
   'court',
   'congress',
   'senate'],
  ['american_foxhound',
   'hound',
   'breed',
   'augustus_saint',
   'gaudens',
   'american_women',
   'historical_society',
   'archive',
   'anti_clericalism',
   'state_quarter'],
  ['británicos',
   'batalla',
   'fuerte',
   'cochrane',
   'mando',
   'ejército',
   'hombre',
   'ataque',
   'británico',
   'tropas'],
  ['espiridón',
   'presidente_república',
   'arbitraje_derecho',
   'espiridón_tremitunte',
   'theodoros_deligiannis',
   'político_griego',
   'seigneur',
   'laicos',
   't

In [10]:
# One topic list per model: index 0 selects the first of the two models returned
# by parse_topic_file, and the only model returned by parse_ZS_file.
models = [tpcs_trans[0], tpcs_match[0], tpcs_ZS[0]]

In [11]:
tp = TopicSelector()

# Draw 6 cross-model topic matches; the returned list is the cell's displayed value.
tp.iterative_matching(models=models, N=6)


Sampled Matches (Position IDs): [[(0, 10), (1, 9), (2, 6)], [(0, 31), (1, 48), (2, 0)], [(0, 22), (1, 38), (2, 8)], [(0, 6), (1, 25), (2, 7)], [(0, 32), (1, 10), (2, 1)], [(0, 14), (1, 13), (2, 4)]]
Sampled Matches (Original Topic IDs): [[(0, 10), (1, 9), (2, 6)], [(0, 31), (1, 48), (2, 0)], [(0, 22), (1, 38), (2, 8)], [(0, 6), (1, 25), (2, 7)], [(0, 32), (1, 10), (2, 1)], [(0, 14), (1, 13), (2, 4)]]


[[(0, 10), (1, 9), (2, 6)],
 [(0, 31), (1, 48), (2, 0)],
 [(0, 22), (1, 38), (2, 8)],
 [(0, 6), (1, 25), (2, 7)],
 [(0, 32), (1, 10), (2, 1)],
 [(0, 14), (1, 13), (2, 4)]]