In [None]:
from sqlalchemy import case, create_engine, func, select, tablesample
from sqlalchemy.orm import aliased

import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
import random
from datetime import datetime
import dask.dataframe as dd
from dask.distributed import LocalCluster

from telegram_data_models import Message, Chat, MessageTextContent, Entity
from telegram_quality_control.chat_language import ChatLanguage
from telegram_quality_control.db import get_conn_string
from telegram_quality_control.cleaning import batch_clean_text
from telegram_quality_control.topics import Embeddings, Topics

from dotenv import load_dotenv

# Load variables from the local .env file into os.environ; the params cell reads
# SCRATCH_FOLDER and OUTPUT_FOLDER from the environment.
load_dotenv(".env")

In [None]:
# params

# Set to True for a quick smoke test on a small subset of messages.
test_run = False

language = "english"  # passed to the topic model as `text_language`
language_code = "en"  # value of ChatLanguage.lang used to select chats
lang_score = 0.8  # ChatLanguage.score threshold for selecting chats

embedding_model = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"

if test_run:
    num_messages = 100
    min_cluster_size = 3
    folder_tag = f"{language_code}_test"
else:
    num_messages = 1_000_000
    # For English, min_cluster_size has to be a bit larger, because otherwise we run into this bug:
    # https://github.com/rapidsai/cuml/issues/3568#issuecomment-788316039
    min_cluster_size = 25 if language_code == "en" else 20
    folder_tag = f"{language_code}_{num_messages}_messages_full"

# Minimum message length in characters, applied to the raw text when sampling
# and to the cleaned text afterwards.
min_message_length = 50

min_samples = min_cluster_size

# os.environ[...] fails with a KeyError naming the missing variable, instead of
# the opaque TypeError that Path(None) would raise.
scratch_folder = Path(os.environ["SCRATCH_FOLDER"])
print(f"Scratch folder: {scratch_folder}")
scratch_folder.mkdir(parents=True, exist_ok=True)

data_folder = Path(os.environ["OUTPUT_FOLDER"])
print(f"Data folder: {data_folder}")
data_folder.mkdir(parents=True, exist_ok=True)

message_folder = scratch_folder / "clean_messages" / folder_tag
message_folder.mkdir(parents=True, exist_ok=True)

## Get data


In [None]:
db_url = get_conn_string(".env")

# Core Table objects backing the ORM models.
message_table, language_table, content_table = (
    model.__table__ for model in (Message, ChatLanguage, MessageTextContent)
)

In [None]:
# get message ids first
# Percentage of the message table sampled with Postgres' BERNOULLI TABLESAMPLE
# (the argument is a percentage: 0.5 means 0.5%). Unlisted languages use 4%.
sample_percent_by_language = {"ru": 0.5, "fa": 1, "ar": 1, "en": 2}
frac = sample_percent_by_language.get(language_code, 4)
message_sample = aliased(Message, tablesample(Message, func.bernoulli(frac)))

sql = (
    select(message_sample.id)
    .join(ChatLanguage, ChatLanguage.chat_id == message_sample.chat_id)
    .join(MessageTextContent, MessageTextContent.message_id == message_sample.id)
    # filter for language, using the configured score threshold
    .where(ChatLanguage.lang == language_code)
    .where(ChatLanguage.score > lang_score)
    # filter out short documents: either the text or the caption must be long enough
    .where(
        (func.length(MessageTextContent.text) > min_message_length)
        | (func.length(MessageTextContent.caption) > min_message_length)
    )
)

if test_run:
    # Test run: cap the result with a plain LIMIT (first rows returned, not a
    # random selection) to keep the query fast.
    sql = sql.limit(num_messages * 10)

message_ids = pd.read_sql(sql, db_url, index_col="id")

print(f"Got {len(message_ids)} messages")

# sample twice as many messages, so that there will be enough left after cleaning
if len(message_ids) > num_messages * 2:
    message_ids = message_ids.sample(n=num_messages * 2)

message_ids.to_csv(message_folder / "message_ids.csv")

In [None]:
message_ids = pd.read_csv(message_folder / "message_ids.csv")
message_ids = message_ids["id"].to_list()
message_ids.sort()

# Pick the message body consistently with the length filter used when the ids
# were sampled: the text if it is long enough, otherwise the caption (falling
# back to the text when there is no caption). A plain COALESCE(text, caption)
# would return a short or empty text even when the caption is the long field
# that qualified the message.
content_column = case(
    (func.length(MessageTextContent.text) > min_message_length, MessageTextContent.text),
    else_=func.coalesce(MessageTextContent.caption, MessageTextContent.text),
).label("content")

# Query the sorted ids in chunks: a single IN (...) with up to 2 * num_messages
# bound parameters produces a very large statement and may exceed the driver's
# bind-parameter limit.
id_chunk_size = 10_000

engine = create_engine(db_url)
with engine.connect() as conn:
    message_chunks = [
        pd.read_sql(
            select(MessageTextContent.message_id, content_column).where(
                MessageTextContent.message_id.in_(message_ids[start : start + id_chunk_size])
            ),
            conn,
        )
        for start in range(0, len(message_ids), id_chunk_size)
    ]

if message_chunks:
    messages = pd.concat(message_chunks, ignore_index=True)
else:
    # no ids: keep the same columns the query would have produced
    messages = pd.DataFrame(columns=["message_id", "content"])

print(f"[{datetime.now():%H:%M:%S}] Extracted {len(messages)} messages")

# filter out system messages
system_messages = [
    "This message couldn't be displayed on your device due to copyright infringement.",
    "This channel can’t be displayed because it violated Telegram's Terms of Service.",
    "This channel can’t be displayed because it violated local laws.",
]

messages = messages[~messages["content"].isin(system_messages)]

print(
    f"[{datetime.now():%H:%M:%S}] Filtered out system messages, left with {len(messages)} messages"
)

messages.head()

In [None]:
# `messages` is a boolean-filtered frame from the previous cell; assign() works
# on a copy, avoiding pandas' SettingWithCopyWarning for the new column.
messages = messages.assign(clean_text=batch_clean_text(messages["content"]))

# drop documents that became too short after cleaning
messages = messages[messages["clean_text"].str.len() >= min_message_length]

print(
    f"[{datetime.now():%H:%M:%S}] Cleaned messages and removed too short ones, left with {len(messages)} messages"
)

if len(messages) > num_messages:
    messages = messages.sample(n=num_messages)
    print(f"[{datetime.now():%H:%M:%S}] Sampled to {len(messages)} messages")
else:
    print(
        f"[{datetime.now():%H:%M:%S}] Number of messages already below {num_messages}"
    )

# save messages
messages.to_parquet(message_folder / "messages.parquet")

In [None]:
messages = pd.read_parquet(message_folder / "messages.parquet")

docs = messages["clean_text"].tolist()

# `{dt:%H:%M:%S}` formats like strftime and, unlike nested double quotes inside
# an f-string, also parses on Python versions before 3.12.
print(f"[{datetime.now():%H:%M:%S}] Loaded {len(docs)} documents")

# Embeddings/Topics are project helpers: `create` builds the result and `save`
# presumably persists it under `folder_tag` (see telegram_quality_control.topics).
embedding_cache = Embeddings(folder_tag, embedding_model=embedding_model)
embeddings = embedding_cache.create(docs)
embedding_cache.save()

print(f"[{datetime.now():%H:%M:%S}] Calculated embeddings")

# Fit the topic model on the precomputed embeddings.
topic_cache = Topics(
    folder_tag, text_language=language, min_samples=min_samples, min_cluster_size=min_cluster_size
)
topic_model, topics, probs = topic_cache.create(docs, embeddings, embedding_cache.embedding_model)
topic_cache.save()

print(f"[{datetime.now():%H:%M:%S}] Finished classifying topics")

In [None]:
topic_info = topic_model.get_topic_info()

# Assumes BERTopic's get_topic_info() layout (columns Topic/Count/Name), where
# topic -1 holds the outlier documents that were not assigned to any topic.
# That row is therefore excluded from the topic count and the ranking.
is_outlier = topic_info["Topic"] == -1
real_topics = topic_info[~is_outlier]

n_topics = len(real_topics)
print(f"Number of discovered topics: {n_topics}")

# Normalize by the number of documents actually clustered, not the requested
# sample size: cleaning can leave fewer than `num_messages` documents. Select
# the outlier row by id instead of assuming it is the first row, since it is
# absent when every document was assigned to a topic.
n_documents = topic_info["Count"].sum()
not_classified_count = topic_info.loc[is_outlier, "Count"].sum()
fraction_unclassified = not_classified_count / n_documents * 100 if n_documents else 0.0
print(f"Fraction of unclassified documents: {fraction_unclassified:.2f}%")

# Print top-20 topics with their counts
top_20 = real_topics.sort_values("Count", ascending=False).head(20)

print("\nTop-20 Topics:")
for _, row in top_20.iterrows():
    print(f"Topic {row['Name']}: {row['Count']} documents")