In [None]:
from telegram_data_models import Message, MessageTextContent
from telegram_quality_control.chat_language import ChatLanguage
from telegram_quality_control.db import get_conn_string

from datasketch import HyperLogLogPlusPlus
from collections import Counter

from itertools import batched

import multiprocessing as mp

import numpy as np

import tqdm

from sqlalchemy import select, create_engine, func
import pandas as pd
from pathlib import Path
import os
import json
import pickle
from datetime import datetime
import csv

from dotenv import load_dotenv

# Load variables from a .env file into the process environment before any
# configuration is looked up. Presumably this is what supplies SCRATCH_FOLDER and
# OUTPUT_FOLDER, which the configuration cell reads from os.environ -- confirm.
load_dotenv()

In [None]:
# When True, only the first `num_messages * 100` cached ids are read and the main
# loop stops shortly after 1000 chunks.
test_run = False

num_messages = 10_000  # number of messages to download in one go

num_most_common = 1000  # number of top n-grams per n exported to JSON

# A counter that grows beyond this many distinct n-grams is pruned down to its
# most common `max_counter_size // 2` entries (use e.g. 100_000 for quick tests).
max_counter_size = 50_000_000

lang = "en"  # only chats whose detected language equals this value are used

# Fail fast with a readable message instead of the cryptic TypeError that
# Path(None) raises when a variable is not set.
_missing_env = [name for name in ("SCRATCH_FOLDER", "OUTPUT_FOLDER") if name not in os.environ]
if _missing_env:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_env)}"
    )

# Holds the message-id cache and the intermediate counter checkpoints.
scratch_folder = Path(os.environ["SCRATCH_FOLDER"])
print(f"Scratch folder: {scratch_folder}")

# Receives the final pickled counters and the JSON exports.
data_folder = Path(os.environ["OUTPUT_FOLDER"])
print(f"Data folder: {data_folder}")

n_values = [1, 3, 10]  # n-grams to compute

In [None]:
# Database setup: connection string plus the `__table__` object of each model,
# which the queries below are built from.
db_url = get_conn_string(".env")

message_table, content_table, language_table = (
    model.__table__ for model in (Message, MessageTextContent, ChatLanguage)
)

In [None]:
# Pre-load the ids of all messages to process; the result of the (one-off) query
# is cached as a CSV file in the scratch folder.
# NOTE(review): the cache file name is fixed, while the query filters on `lang`.
# A cache written for one language would be reused after `lang` is changed --
# confirm whether the file name should depend on `lang`.
message_id_path = scratch_folder / "message_ids_english.csv"

if not message_id_path.exists():
    print("Collecting the message ids")

    # keep only messages with a non-empty text or caption
    has_content = content_table.c.text.isnot(None) | content_table.c.caption.isnot(None)

    sql = (
        select(message_table.c.id)
        .join(content_table, message_table.c.id == content_table.c.message_id)
        .where(has_content)
        .join(language_table, language_table.c.chat_id == message_table.c.chat_id)
        .where(language_table.c.lang == lang)
        .where(language_table.c.score > 0.8)
    )

    message_id_df = pd.read_sql_query(sql, db_url)
    message_id_df.to_csv(message_id_path, index=False)
else:
    print("Loading message ids from cache")

    # A test run reads only the first rows of the cache; None means "all rows".
    row_limit = num_messages * 100 if test_run else None
    message_id_df = pd.read_csv(
        message_id_path, usecols=["id"], dtype=int, nrows=row_limit
    )

# Ascending order is what lets a checkpoint be identified by the largest id
# processed so far.
message_ids = sorted(message_id_df["id"].tolist())
print(message_ids[:20])

In [None]:
# Look for counter checkpoints in the scratch folder. The file name ends in the
# largest message id that was processed when the checkpoint was written, so the
# newest checkpoint is the one with the highest id.
last_message_id, cache_file = -1, None
for checkpoint_path in scratch_folder.glob("ngram_counters_*.pkl"):
    print(f"Found cached file: {checkpoint_path}")
    checkpoint_id = int(checkpoint_path.stem.rsplit("_", 1)[-1])
    if checkpoint_id <= last_message_id:
        continue
    last_message_id, cache_file = checkpoint_id, checkpoint_path

print(f"Last processed message id: {last_message_id}")
if cache_file is not None:
    # Resume: drop every id that the checkpoint already covers.
    message_ids = [
        message_id for message_id in message_ids if message_id > last_message_id
    ]
    print(f"Remaining messages to process: {len(message_ids)}")

In [None]:
# Split the ids into batches of `num_messages`; each batch is fetched from the
# database with a single query in the main loop.
message_id_chunks = [*batched(message_ids, n=num_messages)]

num_chunks = len(message_id_chunks)
print(f"Number of chunks: {num_chunks}")

In [None]:
def process_chunk(content_df, ngram_counters, max_size=None):
    """
    Count n-grams from messages and update the counters.

    Every message is stripped, lower-cased and split on whitespace. For each
    ``n`` that is a key of ``ngram_counters``, every window of ``n`` consecutive
    words is counted as a tuple of words.

    Args:
        content_df: DataFrame with a ``content`` column holding the message
            texts. Entries that are not strings (e.g. ``None`` or NaN) are skipped.
        ngram_counters: dict mapping ``n`` to a ``Counter`` of n-gram tuples.
            It is updated in place.
        max_size: a counter holding more distinct n-grams than this is pruned to
            its ``max_size // 2`` most common entries. Defaults to the
            module-level ``max_counter_size``.

    Returns:
        The same ``ngram_counters`` dict that was passed in.
    """
    if max_size is None:
        max_size = max_counter_size

    for message in content_df["content"].tolist():
        if not isinstance(message, str):
            # missing content cannot be tokenised
            continue

        words = message.strip().lower().split()
        num_words = len(words)
        for n, counter in ngram_counters.items():
            # range() is empty when the message has fewer than n words
            for start in range(num_words - n + 1):
                counter[tuple(words[start : start + n])] += 1

    # Prune if the counters become too large. Iterate over the counters' own keys
    # (not a global list of n values) so that counting and pruning always agree.
    for n in list(ngram_counters):
        counter = ngram_counters[n]
        if len(counter) > max_size:
            ngram_counters[n] = Counter(dict(counter.most_common(max_size // 2)))

    return ngram_counters


def download_messages(db_conn_string, message_ids):
    """
    Fetch the text content of the given messages from the database.

    Args:
        db_conn_string: connection string passed to ``create_engine``.
        message_ids: iterable of message ids to fetch.

    Returns:
        DataFrame indexed by ``message_id`` with a single ``content`` column that
        holds the message text, or the caption when the text is NULL.
    """
    engine = create_engine(db_conn_string)
    try:
        sql = select(
            content_table.c.message_id,
            func.coalesce(content_table.c.text, content_table.c.caption).label('content'),
        ).where(content_table.c.message_id.in_(list(message_ids)))

        with engine.connect() as conn:
            return pd.read_sql_query(sql, conn, index_col="message_id")
    finally:
        # This function is called once per chunk; without dispose() every call
        # would leave a connection pool with an open connection behind.
        engine.dispose()

In [None]:
# Make sure the output folder exists before the long-running loop starts, rather
# than discovering a missing folder only when the results are written.
data_folder.mkdir(parents=True, exist_ok=True)

if cache_file is not None:
    print(f"Loading ngram counters from cache file: {cache_file}")
    # NOTE: pickle.load can execute arbitrary code -- only load checkpoint files
    # that were written by this notebook.
    with open(cache_file, 'rb') as f:
        ngram_counters = pickle.load(f)
else:
    print("Initializing new ngram counters")
    ngram_counters = {n: Counter() for n in n_values}

# Derived from the chunks themselves so the progress report stays correct even
# if `message_ids` was modified after the chunks were built.
total_messages = sum(len(chunk_ids) for chunk_ids in message_id_chunks)

num_rows = 0  # running total of rows downloaded from the database
download_times = []  # seconds per chunk since the last progress report
process_times = []

# The loop variable is deliberately NOT called `message_ids`: that would replace
# the global id list with the last chunk and break a re-run of this cell.
for i, chunk_ids in enumerate(message_id_chunks):
    tic = datetime.now()
    chunk = download_messages(db_url, chunk_ids)
    tac = datetime.now()
    ngram_counters = process_chunk(chunk, ngram_counters)
    toe = datetime.now()
    download_times.append((tac - tic).total_seconds())
    process_times.append((toe - tac).total_seconds())
    num_rows += len(chunk)

    if test_run and i > 1000:
        break
    if i % 10 == 0:
        # Chunks 0..i are finished, i.e. i + 1 chunks; the last chunk may be
        # shorter than `num_messages`, hence the cap.
        frac = min((i + 1) * num_messages, total_messages) / total_messages * 100
        average_download_time = sum(download_times) / len(download_times)
        average_process_time = sum(process_times) / len(process_times)
        largest_n = n_values[-1]
        print(
            f"[{datetime.now().strftime('%H:%M:%S')}] chunk {i}, \t{frac:.4f}% messages \t{average_download_time:.2f} s. download, \t{average_process_time:.2f} s. process, \t{len(ngram_counters[largest_n])} {largest_n}-grams"
        )
        download_times.clear()
        process_times.clear()

    if i % 1000 == 0 and i > 0:
        # checkpoint results
        current_id = int(max(chunk_ids))
        print(f"Checkpointing at message id {current_id}")
        checkpoint_path = scratch_folder / f"ngram_counters_{current_id}.pkl"
        # Write to a temporary name first and rename afterwards, so that an
        # interrupted dump never leaves a truncated file that looks like the
        # newest checkpoint. The ".tmp" name does not match the "*.pkl" globs.
        tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
        with open(tmp_path, 'wb') as f:
            pickle.dump(ngram_counters, f)
        tmp_path.replace(checkpoint_path)

        # delete previous checkpoints except for the second-to-last one
        checkpoint_ids = {
            file: int(file.stem.split("_")[-1])
            for file in scratch_folder.glob("ngram_counters_*.pkl")
        }
        second_to_last_id = max(
            (cid for cid in checkpoint_ids.values() if -1 < cid < current_id),
            default=-1,
        )

        if second_to_last_id != -1:
            for file, checkpoint_id in checkpoint_ids.items():
                if checkpoint_id not in (current_id, second_to_last_id):
                    print(f"Deleting old checkpoint file: {file}")
                    file.unlink()

# O(N log k) complexity where N = length of the counter, k = num_most_common
most_common = {n: ngram_counters[n].most_common(num_most_common) for n in n_values}

print("Calculated most common")

with open(data_folder / 'ngram_counters.pkl', 'wb') as f:
    pickle.dump(ngram_counters, f)

print("Pickle dump counters")

# Convert the (ngram, count) lists to a JSON-serializable mapping
for n, top_ngrams in most_common.items():
    # Tuple n-grams become space-separated strings; any other key is kept as is.
    json_counter = {
        (' '.join(ngram) if isinstance(ngram, tuple) else ngram): count
        for ngram, count in top_ngrams
    }

    with open(data_folder / f'most_common_{n}-grams.json', 'w', encoding='utf-8') as f:
        json.dump(json_counter, f, ensure_ascii=False, indent=2)