In [None]:
import os
import random
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
from cmcrameri import cm
from dotenv import load_dotenv
from tqdm import tqdm
from wordcloud import WordCloud

from telegram_quality_control.topics import Topics
from telegram_quality_control.visualization import single_col_figure, double_col_figure

# Populate os.environ from a .env file; load_model below reads SCRATCH_FOLDER from it.
load_dotenv()

# Notebook-level figure setting (presumably the target medium; not referenced in the cells shown here).
figure_style = "print"

# Shared matplotlib style sheet for every figure (path is relative to the working directory).
plt.style.use(
    './resources/mpl_styles/default.mplstyle'
)

In [None]:
def load_model(language_code):

    scratch_folder = Path(os.environ["SCRATCH_FOLDER"]) / "topics"
    topic_folder = scratch_folder / f"{language_code}_1000000_messages_full"

    topic_model, topics, probs = Topics.load_from_file(topic_folder)

    topic_df = topic_model.get_topic_info()

    def parse_label(label):
        return " ".join(label.split("_")[1:])

    topic_df["pretty_label"] = topic_df["Name"].map(parse_label)

    return topic_df

In [None]:
# Topic table of the English model; displayed here and reused by the next cell.
topic_df = load_model(language_code="en")
topic_df

In [None]:
# Fraction = each topic's Count divided by the total Count over *all* topics
# (topic -1 included), computed before keeping only the largest topics.
# `.assign` builds a new frame, so there is no SettingWithCopyWarning and
# `topic_df` itself is left untouched.
# NOTE(review): assumes `topic_df` still holds the English model loaded above.
top_n_topics = 40
topic_frequency = (
    topic_df[["Topic", "pretty_label", "Count"]]
    .assign(Fraction=lambda frame: frame["Count"] / frame["Count"].sum())
    .sort_values("Count", ascending=False)
    .head(top_n_topics)
)

# NOTE: written tab-separated despite the .csv extension; name kept for downstream consumers.
frequency_path = Path("data") / "topic_frequency_en.csv"
frequency_path.parent.mkdir(parents=True, exist_ok=True)
topic_frequency.to_csv(frequency_path, index=False, sep="\t")
topic_frequency

In [None]:
def create_wordcloud(topic_df, width, height, seed, num_topics=500):
    topic_frequency = topic_df[topic_df["Topic"] != -1][["Count", "pretty_label"]]
    topic_frequency = topic_frequency.sort_values("Count", ascending=False).head(num_topics)
    topic_frequency = topic_frequency.set_index("pretty_label")
    topic_frequency = topic_frequency.to_dict()["Count"]

    def color_func(*args, **kwargs):
        cmap = plt.get_cmap("cmc.lipari")
        value = np.random.uniform(0, 0.8)
        color = cmap(value)
        return tuple(int(c * 255) for c in color[:3])

    wc = WordCloud(
        background_color="white",
        color_func=color_func,
        width=width,
        height=height,
        random_state=seed,
        prefer_horizontal=0.9,
        min_font_size=12,
        max_font_size=50,
        relative_scaling=0.5,
    ).generate_from_frequencies(topic_frequency)

    return wc


# Height ratio passed to double_col_figure; also used below to derive the bitmap width.
fig_height_ratio = 0.35
fig = double_col_figure(fig_height_ratio)

languages = ["ru", "en"]
language_labels = {"fa": "Farsi", "ar": "Arabic", "en": "English", "ru": "Russian"}
# WordCloud layout seed per language.
seeds = {"fa": 1, "ar": 0, "en": 2, "ru": 1}

# One panel per language, side by side.
n_panels = len(languages)
ax = {lang: fig.add_subplot(1, n_panels, i + 1) for i, lang in enumerate(languages)}

# Size the word-cloud bitmaps from the panel height (in inches) at this resolution.
wordcloud_dpi = 300
bbox = ax[languages[-1]].get_window_extent().transformed(fig.dpi_scale_trans.inverted())
height_pixel = int(np.floor(bbox.height * wordcloud_dpi))
# Width comes from the height and the figure ratio rather than the bbox width
# (presumably: full figure width split evenly across the panels).
width_pixel = int(height_pixel / fig_height_ratio / n_panels)

for lang in languages:
    # Local name so this cell does not overwrite the notebook-level `topic_df`.
    lang_topic_df = load_model(lang)
    wc = create_wordcloud(lang_topic_df, width_pixel, height_pixel, seeds[lang])
    ax[lang].imshow(wc)
    ax[lang].tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    ax[lang].set_title(language_labels[lang], fontsize=10)
    ax[lang].set_axis_off()

fig.tight_layout()
figure_path = Path("figures") / f"wordcloud_{'_'.join(languages)}.png"
figure_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(figure_path, pad_inches=0, bbox_inches=None)
fig