In [None]:
from sqlalchemy import select, create_engine, func
import pandas as pd
from pathlib import Path
import os

from datetime import datetime

import pickle

from datasketch import HyperLogLogPlusPlus

from telegram_data_models import MessageTextContent, Message
from telegram_quality_control.db import get_conn_string

from dotenv import load_dotenv

# Load variables from a local .env file into os.environ so later cells can read
# SCRATCH_FOLDER and OUTPUT_FOLDER.
load_dotenv()

In [None]:
# Set to True for a quick run with smaller chunks (the main loop also stops early).
test_run = False

# Number of content rows fetched per database query.
chunk_size = 10_000 if test_run else 100_000

# os.environ[...] fails fast with a KeyError naming the missing variable;
# Path(os.environ.get(...)) would instead raise an unhelpful TypeError on Path(None).
scratch_folder = Path(os.environ["SCRATCH_FOLDER"]) / "duplicate_analysis"
scratch_folder.mkdir(parents=True, exist_ok=True)
print(f"Scratch folder: {scratch_folder}")

data_folder = Path(os.environ["OUTPUT_FOLDER"])
# Final results are written here at the end of the run; make sure it exists.
data_folder.mkdir(parents=True, exist_ok=True)
print(f"Data folder: {data_folder}")

In [None]:
# Database setup: connection string plus the SQLAlchemy Core tables behind the ORM models.
db_url = get_conn_string(".env")
content_table, message_table = MessageTextContent.__table__, Message.__table__

In [None]:
def process_chunk(docs, hyperlog, length_distribution):
    """
    Fold a batch of documents into a distinct-count sketch and a length histogram.

    Parameters
    ----------
    docs : iterable
        Documents to process (in this notebook, a pandas Series of message text).
        Entries that are not non-empty strings (None, NaN, "") are skipped.
    hyperlog : object with an ``update(bytes)`` method
        Cardinality sketch (a ``HyperLogLogPlusPlus`` here). It is updated in place
        with the ``str.encode()`` bytes of every processed document.
    length_distribution : dict[int, int]
        Histogram mapping document length (in characters) to the number of documents
        with that length. It is updated in place.

    Returns
    -------
    The same ``hyperlog`` object that was passed in.
    """
    for doc in docs:
        # Missing content arrives as None/NaN (NaN is truthy, so ``not doc`` alone
        # would let it through and len() would crash); empty strings carry no content.
        if not isinstance(doc, str) or not doc:
            continue

        # Update unique estimate
        hyperlog.update(doc.encode())

        # Update length distribution; the first occurrence of a length counts as 1.
        doc_length = len(doc)
        length_distribution[doc_length] = length_distribution.get(doc_length, 0) + 1

    return hyperlog


def download_chunk(db_conn_string, num_messages):
    """
    Stream non-forwarded message contents from the database in id-ordered chunks.

    Uses keyset pagination on ``content_table.c.message_id``: each query fetches up to
    ``num_messages`` rows with ``message_id`` greater than the last id already fetched.
    Forwarded messages (non-NULL ``forward_date``) and content rows without a matching
    row in ``message_table`` are then filtered out.

    Parameters
    ----------
    db_conn_string : str
        SQLAlchemy database URL.
    num_messages : int
        Maximum number of content rows fetched per query. Yielded chunks may be
        smaller, because filtering happens after the fetch.

    Yields
    ------
    pandas.DataFrame
        Indexed by ``message_id``, with a single ``content`` column holding the
        message text, or its caption when the text is NULL. Empty chunks are never
        yielded.
    """
    engine = create_engine(db_conn_string)
    last_message_id = 0

    try:
        while True:
            with engine.connect() as conn:
                content_query = (
                    select(
                        content_table.c.message_id,
                        func.coalesce(content_table.c.text, content_table.c.caption).label('content'),
                    )
                    .where(content_table.c.message_id > last_message_id)
                    .order_by(content_table.c.message_id)
                    .limit(num_messages)
                )
                chunk = pd.read_sql_query(content_query, conn, index_col="message_id")

                if chunk.empty:
                    # Nothing left past the cursor: the table is exhausted.
                    break

                original_query = (
                    select(message_table.c.id)
                    .where(message_table.c.id.in_(chunk.index.tolist()))
                    .where(message_table.c.forward_date.is_(None))
                )
                originals = pd.read_sql_query(original_query, conn, index_col="id")

            # Advance the cursor past every fetched row *before* filtering. Otherwise a
            # chunk made up entirely of forwarded messages would end iteration early,
            # and forwarded rows at the tail of a chunk would be fetched again.
            last_message_id = int(chunk.index.max())

            # keep only non-forwarded messages
            chunk = chunk[chunk.index.isin(originals.index)]
            if chunk.empty:
                continue

            yield chunk
    finally:
        # Release pooled connections once the generator is exhausted or closed.
        engine.dispose()

In [None]:
hyperlog = HyperLogLogPlusPlus()
length_distribution = {}

message_stats = pd.read_sql_query("SELECT MAX(message_id) AS max FROM message_content", db_url)
max_id_value = message_stats["max"].iloc[0]
if pd.isna(max_id_value):
    # MAX() over an empty table is NULL; int(None) would fail with an opaque TypeError.
    raise ValueError("message_content is empty: nothing to analyse")
max_message_id = int(max_id_value)

# Approximate total row count, used only for the progress percentage. It is hard-coded,
# presumably because COUNT(*) over the full table is too slow.
# TODO: refresh if the data changes.
total_rows = 5505989600

print(f"Max message id: {max_message_id}")
print(f"Total rows: {total_rows}")

num_rows = 0
# Per-chunk timings (seconds) accumulated since the last progress report.
download_times = []
process_times = []

generator = download_chunk(db_url, chunk_size)

try:
    tic = datetime.now()
    for i, content_df in enumerate(generator):
        tac = datetime.now()
        # Counts every yielded row, including rows whose content is empty/NULL, which
        # process_chunk skips. NOTE(review): confirm this is the intended denominator
        # for the unique fraction computed after the loop.
        num_rows += len(content_df)
        process_chunk(content_df["content"], hyperlog, length_distribution)
        toc = datetime.now()

        download_times.append((tac - tic).total_seconds())
        process_times.append((toc - tac).total_seconds())

        if test_run and i > 1000:
            break

        current_id = content_df.index.max()

        if i % 10 == 0:
            frac_ids = current_id / max_message_id * 100
            frac_messages = num_rows / total_rows * 100
            mean_download = sum(download_times) / len(download_times)
            mean_process = sum(process_times) / len(process_times)
            print(
                f"[{datetime.now().strftime('%H:%M:%S')}] chunk {i}, \t{frac_ids:.4f}% of ids, \t{frac_messages:.4f}% of messages, \t{mean_download:.2f} s. download, \t{mean_process:.2f} s. process"
            )
            download_times.clear()
            process_times.clear()

        if i % 1000 == 0 and i > 0:
            # Periodically persist the sketch and histogram so a long run's progress
            # is not lost if the kernel dies.
            print(f"Checkpointing at message id {current_id}")
            with open(scratch_folder / f"hyperlog_{current_id}.pkl", 'wb') as f:
                pickle.dump(hyperlog, f)

            with open(scratch_folder / f"length_distribution_{current_id}.pkl", 'wb') as f:
                pickle.dump(length_distribution, f)

        tic = datetime.now()
finally:
    # Close the generator explicitly (e.g. after the test-run break) so any cleanup
    # inside download_chunk runs now rather than at garbage collection.
    generator.close()

unique_estimate = hyperlog.count()
# Guard against an empty run (no rows yielded) instead of raising ZeroDivisionError.
fraction = unique_estimate / num_rows if num_rows else float("nan")

print(f"Unique fraction: {fraction:.2f}")

with open(data_folder / "unique_estimate.txt", 'w') as f:
    f.write(f"unique_estimate: {unique_estimate}\n")
    f.write(f"total messages: {num_rows}\n")
    f.write(f"fraction: {fraction}")

# One row per observed message length (characters) with its document count.
length_dist_df = (
    pd.DataFrame.from_dict(length_distribution, orient='index', columns=['count'])
    .rename_axis('length')
    .sort_index()
)
length_dist_df.to_csv(data_folder / "message_length.csv")