In [1]:
# -*- coding: utf-8 -*-
"""
Integrated script for MALLET-based topic modeling with per-file Fisher’s Exact
filtering and stopword exclusion. Supports UTF-8 encoding for French accents.
Outputs:
  - Topic keys
  - Per-document topic distributions
  - Heatmap visualization (PDF + JPG)
  - Excel workbook with:
      • Topic tokens
      • Token counts per document
      • Topic probabilities per document
  - Separate Excel workbook with top-document titles per topic
"""

import os
import csv
import time
from pathlib import Path
from collections import Counter

import pandas as pd
from scipy.stats import fisher_exact
import little_mallet_wrapper
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows

# Location of the MALLET launcher script handed to little_mallet_wrapper when
# training. "~" is expanded to the current user's home directory; adjust the
# path if MALLET is installed elsewhere.
path_to_mallet = os.path.expanduser("~/mallet-2.0.8/bin/mallet")


def tokenize_file(filepath):
    """Read *filepath* as UTF-8 text and return its whitespace-separated tokens.

    No normalisation is applied, so accented characters are kept exactly as
    they appear in the file.
    """
    with open(filepath, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    return text.split()


def list_txt_files(directory):
    """List the names of the entries in *directory* that end in ".txt", sorted."""
    txt_names = [name for name in os.listdir(directory) if name.endswith(".txt")]
    txt_names.sort()
    return txt_names


def list_csv_files(directory):
    """
    Recursively find all .csv files under a directory, skipping
    any .ipynb_checkpoints folders.

    The extension match is case-insensitive. Results are returned in a
    deterministic order (files of a folder alphabetically, then its
    sub-folders alphabetically) so that a numbered menu built from the list
    is stable from one run to the next.

    Returns:
      - list of paths, each formed by joining the walked folder and file name
    """
    csv_paths = []
    for root, dirs, files in os.walk(directory):
        # Assigning to dirs[:] edits the list os.walk is iterating over, which
        # both prunes the checkpoint folders and fixes the traversal order.
        dirs[:] = sorted(d for d in dirs if d != ".ipynb_checkpoints")
        csv_paths.extend(
            os.path.join(root, fn)
            for fn in sorted(files)
            if fn.lower().endswith(".csv")
        )
    return csv_paths


def choose_directory(prompt):
    """
    Display a numbered list of immediate subdirectories (plus current dir)
    and return the full path selected by the user.

    Subdirectories are listed alphabetically so the numbering is the same on
    every run; the current directory is always the last option.
    ".ipynb_checkpoints" is never offered. The user is re-prompted until a
    valid number is entered.
    """
    cwd = os.getcwd()
    base = Path(cwd).name
    # os.listdir returns entries in arbitrary order; sort for a stable menu.
    subdirs = sorted(
        d for d in os.listdir(cwd)
        if d != ".ipynb_checkpoints" and os.path.isdir(os.path.join(cwd, d))
    )
    options = [(os.path.join(cwd, d), f"{base}/{d}") for d in subdirs]
    options.append((cwd, base))

    print(prompt)
    for i, (_, label) in enumerate(options, start=1):
        print(f"{i}. {label}")

    while True:
        try:
            choice = int(input("Enter number: ").strip())
        except ValueError:
            choice = None  # non-numeric input: fall through to the retry message
        if choice is not None and 1 <= choice <= len(options):
            return options[choice - 1][0]
        print("Invalid choice, please try again.")


def choose_files(filenames):
    """
    Let user pick one or more filenames by:
      - indices ("2" or "1,3")
      - index ranges ("1-4")
      - prefix matching ("report"), compared case-insensitively
      - or the keyword "all" to select every file.
    Several of these may be combined with commas.

    Empty parts (e.g. a trailing comma) are ignored, indices outside
    1..len(filenames) are reported and skipped, and ranges are clamped to the
    list. If nothing was selected from a non-empty list the user is asked
    again.

    Returns a sorted, unique list of chosen filenames.
    """
    print("Available files:")
    for i, fn in enumerate(filenames, start=1):
        print(f"{i}. {fn}")

    count = len(filenames)
    while True:
        choice = input("Select files (indices, ranges, prefix, or 'all'): ").strip()

        selected = []
        for part in choice.split(","):
            part = part.strip()
            if not part:
                # An empty prefix would match every file; treat it as noise.
                continue
            if part.lower() == "all":
                return filenames.copy()

            low, dash, high = (piece.strip() for piece in part.partition("-"))
            if dash and low.isdecimal() and high.isdecimal():
                # Only "<int>-<int>" is a range; other hyphenated text is a prefix.
                start = max(int(low), 1)  # a 0 start must not wrap to the list end
                selected.extend(filenames[start - 1:int(high)])
            elif part.isdecimal():
                index = int(part)
                if 1 <= index <= count:
                    selected.append(filenames[index - 1])
                else:
                    print(f"Ignoring out-of-range index {index} (valid: 1-{count}).")
            else:
                # Compare case-insensitively so "bodin" finds "Bodin_stemmed.txt".
                prefix = part.lower()
                selected.extend(
                    fn for fn in filenames if fn.lower().startswith(prefix)
                )

        # With no files on offer there is nothing to retry for.
        if selected or not filenames:
            return sorted(set(selected))
        print("No files matched, please try again.")


def choose_csv_file(csv_paths):
    """
    Prompt the user to select one CSV from a list of paths.

    Each path is displayed relative to the current working directory,
    prefixed with that directory's name. The user is re-prompted until a
    valid number is entered.

    Returns the chosen filepath.
    Raises FileNotFoundError if csv_paths is empty, because no number could
    ever be valid and the prompt would otherwise repeat forever.
    """
    if not csv_paths:
        raise FileNotFoundError("No CSV files available to choose from.")

    cwd = os.getcwd()
    base = Path(cwd).name
    print("Select your stopwords CSV:")
    for i, full in enumerate(csv_paths, start=1):
        rel = os.path.relpath(full, cwd)
        print(f"{i}. {base}/{rel}")

    while True:
        try:
            choice = int(input("Enter number: ").strip())
        except ValueError:
            choice = None  # non-numeric input: fall through to the retry message
        if choice is not None and 1 <= choice <= len(csv_paths):
            return csv_paths[choice - 1]
        print("Invalid choice, please try again.")


def read_stopwords(filepath):
    """
    Load stopwords from a CSV, splitting on commas and trimming whitespace.

    Every cell of every row contributes words; a quoted cell that itself
    contains commas is split further. Empty entries are dropped. Order and
    duplicates are preserved.

    Returns a list of stopword strings.
    """
    sw = []
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading byte-order
    # mark, which would otherwise stay attached to the first stopword.
    # newline="" lets the csv module do its own line-ending handling.
    with open(filepath, "r", encoding="utf-8-sig", newline="") as f:
        for row in csv.reader(f):
            for cell in row:
                sw.extend(cell.split(","))
    return [w.strip() for w in sw if w.strip()]


def get_fishers(word, freq_dict, rate_dict, alternative="greater"):
    """
    Run Fisher’s Exact Test for a single token and return its p-value.

    The 2x2 table compares the token's observed count in the document with
    the count expected from the background rate, both out of the document's
    total token count:
        [[observed, total - observed],
         [expected, total - expected]]
    Tokens missing from either dictionary count as 0. `alternative` is
    forwarded to scipy's fisher_exact.
    """
    n_tokens = sum(freq_dict.values())
    seen = freq_dict.get(word, 0)
    background_rate = rate_dict.get(word, 0)
    anticipated = round(background_rate * n_tokens)

    contingency = [
        [seen, n_tokens - seen],
        [anticipated, n_tokens - anticipated],
    ]
    _odds_ratio, p_value = fisher_exact(contingency, alternative=alternative)
    return p_value


def calculate_rate_dictionary(rate_files, rate_dir):
    """
    Compute background token rates from the reference documents.

    Every file named in rate_files (looked up inside rate_dir) is tokenized
    and its tokens are tallied together.
    Returns a mapping { token: count / total number of tokens }.
    """
    tallies = Counter()
    for name in rate_files:
        tallies.update(tokenize_file(os.path.join(rate_dir, name)))

    # The counts sum to the number of tokens read across all files.
    grand_total = sum(tallies.values())
    return {token: n / grand_total for token, n in tallies.items()}


def prepare_training_data(files, directory, stopwords, rate_dict, alpha):
    """
    For each file:
      1. Tokenize and count every token.
      2. Exclude stopwords and tokens with Fisher p-value ≥ alpha.
      3. Collect filtered document text and raw token counts.
    Prints per-file progress with elapsed time.

    The Fisher test is run once per distinct token rather than once per
    occurrence: its result depends only on the token and the file's counts,
    so repeating it for every occurrence is pure overhead on long documents.

    Returns:
      - docs: list of filtered document strings for MALLET input
      - distributions: list of Counter objects of raw token counts
    """
    docs = []
    distributions = []
    total = len(files)
    # Set membership is O(1); a list would be rescanned for every token.
    stop_set = set(stopwords)
    overall_start = time.time()

    for idx, fn in enumerate(files, start=1):
        file_start = time.time()
        print(f"[{idx}/{total}] ⏳ Processing '{fn}'... ", end="", flush=True)

        path = os.path.join(directory, fn)
        tokens = tokenize_file(path)
        freq = Counter(tokens)

        # Decide once per distinct token whether it survives the filter.
        keep = {
            w for w in freq
            if w not in stop_set and get_fishers(w, freq, rate_dict) < alpha
        }
        # Preserve original token order and repetitions in the output text.
        filtered = [w for w in tokens if w in keep]
        docs.append(" ".join(filtered))
        distributions.append(freq)

        elapsed = time.time() - file_start
        kept = len(filtered)
        before = len(tokens)
        pct = (kept / before * 100) if before else 0
        print(f"done in {elapsed:.1f}s – kept {kept}/{before} tokens ({pct:.1f}%).")

    total_elapsed = time.time() - overall_start
    print(f"[Done] Prepared {total} documents in {total_elapsed:.1f}s.\n")
    return docs, distributions


def train_topic_model(training_docs, num_topics, output_dir, num_top_words):
    """
    Train a MALLET topic model through little_mallet_wrapper and load its output.

    output_dir is created if missing. After training, the topic keys are
    read from "<output_dir>/mallet.topic_keys.<num_topics>" and the
    per-document distributions from "<output_dir>/mallet.doc_topics.<num_topics>".

    Returns:
      - topics: the loaded topic keys
      - doc_topics: the loaded per-document topic distributions
    """
    # NOTE(review): confirm against the installed little_mallet_wrapper that
    # quick_train_topic_model accepts `num_top_words`, that it writes a
    # "mallet.doc_topics.<k>" file, and that `load_document_topics` exists —
    # these names could not be verified from this file.
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    little_mallet_wrapper.quick_train_topic_model(
        path_to_mallet,
        output_dir,
        num_topics,
        training_docs,
        num_top_words=num_top_words,
    )

    topics = little_mallet_wrapper.load_topic_keys(
        f"{output_dir}/mallet.topic_keys.{num_topics}"
    )
    doc_topics = little_mallet_wrapper.load_document_topics(
        f"{output_dir}/mallet.doc_topics.{num_topics}"
    )
    return topics, doc_topics


def _excel_sheet_title(raw, taken):
    """Return an Excel-safe sheet title derived from raw that is not in taken.

    Characters not allowed in sheet names are replaced with "_", the title is
    capped at 31 characters, and a numeric suffix is added when the title
    (compared case-insensitively) is already used. The chosen title is added
    to taken.
    """
    forbidden = str.maketrans({ch: "_" for ch in "\\/*?:[]"})
    # Excel also refuses names that start or end with an apostrophe.
    base = raw.translate(forbidden)[:31].strip("'") or "Sheet"
    title = base
    n = 1
    while title.lower() in taken:
        tail = f"_{n}"
        title = base[:31 - len(tail)] + tail
        n += 1
    taken.add(title.lower())
    return title


def save_results_to_excel(excel_path, topics, token_distributions, doc_topics, files):
    """
    Create an Excel workbook with:
      1) 'Topics' sheet: one row per topic (Topic#, top tokens...)
      2) One sheet per document: raw token counts
      3) 'DocTopicDist' sheet: topic probabilities per document

    Per-document sheets are named after the file stem, adjusted where needed
    to be a valid, unique sheet title (see _excel_sheet_title). The
    'Document' column of the DocTopicDist sheet keeps the unmodified stem.
    """
    wb = Workbook()
    # -- Topics sheet --
    ws0 = wb.active
    ws0.title = "Topics"
    for idx, topic in enumerate(topics):
        ws0.append([f"Topic {idx}"] + topic)

    # -- Token counts per document --
    # Reserve the two fixed sheet names so a document cannot collide with them.
    taken = {"topics", "doctopicdist"}
    for fn, dist in zip(files, token_distributions):
        sheet = wb.create_sheet(title=_excel_sheet_title(Path(fn).stem, taken))
        df = pd.DataFrame.from_dict(dist, orient="index", columns=["count"])
        for row in dataframe_to_rows(df, index=True, header=True):
            sheet.append(row)

    # -- Document-topic probabilities --
    ws3 = wb.create_sheet(title="DocTopicDist")
    header = ["Document"] + [f"Topic{t}" for t in range(len(topics))]
    ws3.append(header)
    for fn, probs in zip(files, doc_topics):
        ws3.append([Path(fn).stem] + [round(p, 4) for p in probs])

    wb.save(excel_path)


def save_top_titles_excel(xlsx_path, topics, training_docs, doc_topics, doc_titles, n_docs):
    """
    Save the top n_docs document titles per topic into an Excel file,
    one sheet per topic named 'Topic{#}'.

    Each sheet has a header row followed by one row per top document:
    the probability rounded to 4 decimals and the document's title.
    doc_titles maps a training document string to its title; documents
    missing from it fall back to Path(doc).stem.

    The workbook's initial sheet is reused for the first topic so that no
    empty default sheet is left in the file. (With no topics at all, that
    initial sheet is kept, untouched.)
    """
    wb = Workbook()
    for t_idx in range(len(topics)):
        # Workbook() starts with one sheet; use it instead of leaving it empty.
        ws = wb.active if t_idx == 0 else wb.create_sheet()
        ws.title = f"Topic{t_idx}"
        # header
        ws.append(["Probability", "Document Title"])
        # fetch top docs
        for prob, doc in little_mallet_wrapper.get_top_docs(
            training_docs, doc_topics, t_idx, n=n_docs
        ):
            title = doc_titles.get(doc, Path(doc).stem)
            ws.append([round(prob, 4), title])
    wb.save(xlsx_path)


def export_heatmap(files, doc_topics, topics, output_dir):
    """
    Plot the topic-by-document heatmap and write it to output_dir twice:
      - categories_by_topics.pdf (path passed to little_mallet_wrapper)
      - heatmap.jpg (saved from the returned figure at 300 dpi)
    Each document is labelled with its filename minus the extension.
    """
    row_labels = []
    for fn in files:
        row_labels.append(Path(fn).stem)

    pdf_path = os.path.join(output_dir, "categories_by_topics.pdf")
    jpg_path = os.path.join(output_dir, "heatmap.jpg")

    figure = little_mallet_wrapper.plot_categories_by_topics_heatmap(
        row_labels,
        doc_topics,
        topics,
        pdf_path,
        target_labels=row_labels,
        dim=(13, 9),
    )
    # NOTE(review): this relies on the plotting call returning an object with
    # a savefig() method — confirm for the installed little_mallet_wrapper.
    figure.savefig(jpg_path, format="jpg", dpi=300)


def main():
    """
    Interactive entry point: collect every input from the user, filter the
    target documents, train the topic model and write all output files
    (results workbook, heatmap, top-titles workbook) to the chosen folder.
    """
    # 1) Stopword list: pick a CSV found under the working directory
    csv_candidates = list_csv_files(os.getcwd())
    stopwords = read_stopwords(choose_csv_file(csv_candidates))

    # 2) Reference corpus -> background token rates
    rate_dir = choose_directory("Select reference text directory:")
    reference_files = choose_files(list_txt_files(rate_dir))
    rate_dict = calculate_rate_dictionary(reference_files, rate_dir)

    # 3) Documents to model
    target_dir = choose_directory("Select target text directory:")
    target_files = choose_files(list_txt_files(target_dir))

    # 4) Significance threshold for the Fisher filter
    alpha_text = input("Enter Fisher’s Exact alpha threshold (e.g. 0.05): ")
    alpha = float(alpha_text.strip())

    # 5) Filter each target document
    training_docs, token_distributions = prepare_training_data(
        target_files, target_dir, stopwords, rate_dict, alpha
    )

    # 6) Model size and number of top tokens per topic
    topics_text = input("Enter number of topics to generate: ")
    num_topics = int(topics_text.strip())
    top_words_text = input(
        "Enter how many top tokens MALLET should output per topic: "
    )
    num_top_words = int(top_words_text.strip())

    # 7) Number of top documents listed per topic
    top_docs_text = input("Enter number of top documents per topic: ")
    n_top_docs = int(top_docs_text.strip())

    # 8) Output folder, created under the working directory
    folder_name = input("Enter name for output folder: ").strip()
    output_dir = os.path.join(os.getcwd(), folder_name)

    # 9) Train and load topic keys + per-document distributions
    topics, doc_topics = train_topic_model(
        training_docs, num_topics, output_dir, num_top_words
    )

    # 10) Combined results workbook
    save_results_to_excel(
        os.path.join(output_dir, "topic_model_results.xlsx"),
        topics,
        token_distributions,
        doc_topics,
        target_files,
    )

    # 11) Heatmap (PDF + JPG)
    export_heatmap(target_files, doc_topics, topics, output_dir)

    # 12) Top titles per topic; titles are looked up by the document text
    doc_titles = {}
    for doc, fn in zip(training_docs, target_files):
        doc_titles[doc] = Path(fn).stem
    save_top_titles_excel(
        os.path.join(output_dir, "top_titles.xlsx"),
        topics,
        training_docs,
        doc_topics,
        doc_titles,
        n_top_docs,
    )

    print("✅ Topic modeling pipeline completed successfully!")


# Start the interactive pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Select your stopwords CSV:
1. lemmatized/stop_words.csv


Enter number:  1


Select reference text directory:
1. lemmatized/10Topics
2. lemmatized


Enter number:  2


Available files:
1. Bodin_stemmed.txt
2. L'Hospital_stemmed.txt


Select files (indices, ranges, or prefix text):  1-2


Select target text directory:
1. lemmatized/10Topics
2. lemmatized


Enter number:  2


Available files:
1. Bodin_stemmed.txt
2. L'Hospital_stemmed.txt


Select files (indices, ranges, or prefix text):  1-2
Enter Fisher’s Exact alpha threshold (e.g. 0.05):  0.075


KeyboardInterrupt: 