# Parse the URLs from message entities in the database and rate their reliability


In [None]:
from sqlalchemy import select, create_engine, Table, Column, Integer, ForeignKey, func, cast

from telegram_data_models import Message, Chat, MessageTextContent, Entity
from telegram_quality_control.chat_language import ChatLanguage

from itertools import batched

import multiprocessing as mp
import dask.dataframe as dd
from dask.distributed import LocalCluster

import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from functools import partial
from tqdm.notebook import tqdm

from telegram_quality_control.urls import load_rating_resources, batch_rate_urls
from telegram_quality_control.db import get_conn_string

from dotenv import load_dotenv

# populate os.environ (SCRATCH_FOLDER, OUTPUT_FOLDER, DB settings, ...) from the local .env file
load_dotenv(dotenv_path=".env")

In [None]:
lang = 'en'  # language code chats must be tagged with (compared to ChatLanguage.lang)
lang_score = 0.8  # intended language-score threshold for selecting chats

test_run = False

# NOTE(review): num_workers is not used anywhere in this notebook's visible code.
num_workers = 1  # number of parallel workers to use
# number of entities to process in one chunk; small chunks for quick test runs
chunk_size = 100 if test_run else 10_000


def _env_path(name):
    """Return environment variable `name` as a Path, raising a clear error if it is unset."""
    value = os.environ.get(name)
    if not value:
        # Path(None) would otherwise fail with an unhelpful TypeError
        raise RuntimeError(f"Environment variable {name} is not set (check the .env file)")
    return Path(value)


scratch_folder = _env_path("SCRATCH_FOLDER") / "urls"
print(f"Scratch folder: {scratch_folder}")
scratch_folder.mkdir(parents=True, exist_ok=True)

data_folder = _env_path("OUTPUT_FOLDER")
print(f"Data folder: {data_folder}")
data_folder.mkdir(parents=True, exist_ok=True)

In [None]:
db_url = get_conn_string()

# SQLAlchemy Core Table objects behind the ORM models
(
    chat_table,
    message_table,
    content_table,
    language_table,
    entity_table,
) = (model.__table__ for model in (Chat, Message, MessageTextContent, ChatLanguage, Entity))

In [None]:
# pre-load entity ids; the result of the (expensive) query is cached as CSV in the scratch folder
entity_id_path = scratch_folder / "entity_ids.csv"

if entity_id_path.exists():
    print("Loading entity ids from cache")

    # in test mode only read enough ids for 100 chunks (nrows=None reads everything)
    nrows = chunk_size * 100 if test_run else None
    entity_ids = pd.read_csv(entity_id_path, usecols=["id"], dtype=int, nrows=nrows)

else:
    print("Collecting the entity ids")
    # all entities of messages in chats tagged with `lang` and a score above `lang_score`
    sql = (
        select(
            entity_table.c.id,
            language_table.c.chat_id,
        )
        .join(message_table, message_table.c.id == entity_table.c.message_id)
        .join(language_table, language_table.c.chat_id == message_table.c.chat_id)
        .where(language_table.c.lang == lang)
        .where(language_table.c.score > lang_score)
    )

    entity_ids = pd.read_sql_query(sql, db_url)

    # write to the same path the cache check reads; the RangeIndex carries no information
    entity_ids.to_csv(entity_id_path, index=False)

entity_ids = sorted(entity_ids["id"].tolist())
print(entity_ids[:20])

In [None]:
# load already parsed entities so an interrupted run resumes where it stopped
finished_path = scratch_folder / "finished_entity_ids.parquet"
if finished_path.exists():
    finished_entity_ids = pd.read_parquet(finished_path)["entity_id"].tolist()

    print(f"Already finished: {len(finished_entity_ids)} entities")

    # sort again: a bare set difference would scramble the ascending order, so the
    # chunks built from this list would no longer be contiguous id ranges
    entity_ids = sorted(set(entity_ids) - set(finished_entity_ids))

    print(f"Remaining: {len(entity_ids)} entities")

In [None]:
# split the remaining ids into tuples of `chunk_size` (the last one may be shorter)
entity_id_chunks = [*batched(entity_ids, chunk_size)]
num_chunks = len(entity_id_chunks)

print(f"Number of chunks: {num_chunks}")

In [None]:
def process_chunk(url_df, original_rating, updated_rating):
    """
    Rate the URLs in `url_df` with both the original and the updated rating resources.

    Parameters
    ----------
    url_df : pandas.DataFrame
        URLs to rate, stored in the ``url`` column (as produced by `download_chunk`).
    original_rating, updated_rating : sequence
        Rating resources (as returned by `load_rating_resources`); they are unpacked as
        positional arguments into `batch_rate_urls`.

    Returns
    -------
    pandas.DataFrame or None
        The result of rating with `original_rating`, left-joined on the index with the
        ``reliability`` column of the result of rating with `updated_rating`.
        Returns ``None`` if `url_df` has no rows.
    """
    if len(url_df) == 0:
        return None

    result_original = batch_rate_urls(url_df, *original_rating, col="url")
    result_updated = batch_rate_urls(url_df, *updated_rating, col="url")

    # presumably both results carry a `reliability` column, which the suffixes turn into
    # `reliability_original` / `reliability_updated` — confirm against batch_rate_urls
    matched_url_df = result_original.merge(
        result_updated[["reliability"]],
        left_index=True,
        right_index=True,
        how="left",
        suffixes=("_original", "_updated"),
    )

    return matched_url_df


def download_chunk(entity_ids, db_conn_string):
    """
    Download the URL and text-link entities whose ids are in `entity_ids`.

    Ids of entities with any other type are skipped by the query.

    Parameters
    ----------
    entity_ids : iterable of int
        Entity ids to fetch.
    db_conn_string : str
        SQLAlchemy connection URL.

    Returns
    -------
    pandas.DataFrame
        Indexed by entity id (``id``), with columns ``message_id`` and ``url``.
    """
    sql = (
        select(
            Message.id.label("message_id"),
            Entity.id,
            Entity.type,
            Entity.url,
            # NOTE(review): Telegram entity offsets/lengths are usually UTF-16 code units,
            # while SQL substring typically counts characters, so a URL preceded by astral
            # characters (e.g. emoji) may be sliced incorrectly — confirm how offsets are stored.
            func.substring(MessageTextContent.text, Entity.offset + 1, Entity.length).label(
                'entity_text'
            ),
        )
        .join(Entity, Entity.message_id == Message.id)
        .join(MessageTextContent, MessageTextContent.message_id == Entity.message_id)
        .where(Entity.type.in_(["MessageEntityType.URL", "MessageEntityType.TEXT_LINK"]))
        .where(Entity.id.in_(entity_ids))
    )

    engine = create_engine(db_conn_string)
    try:
        with engine.connect() as conn:
            url_df = pd.read_sql_query(sql, conn, index_col="id")
    finally:
        # a fresh engine is created per chunk; close its pooled connections right away
        # instead of leaving them open until the engine is garbage collected
        engine.dispose()

    # The URL is stored in the 'entity_text' column (the slice of the message text) for URL
    # entities and in the 'url' column for TEXT_LINK entities. This merges both columns.
    is_plain_url = url_df["type"] == "MessageEntityType.URL"
    url_df["url"] = url_df["entity_text"].where(is_plain_url, url_df["url"])

    return url_df.drop(columns=["entity_text", "type"])


def append_to_parquet(df, path):
    """
    Save `df` to the parquet file at `path`; if the file already exists, append to it.

    Both writes use the fastparquet engine (appending is done with ``append=True``), so the
    appended frame presumably has to match the existing file's schema — TODO confirm.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to write.
    path : pathlib.Path or str
        Target file; must have a ``.parquet`` suffix. Missing parent folders are created.

    Raises
    ------
    ValueError
        If `path` does not end in ``.parquet``.
    """
    path = Path(path)
    # explicit check instead of `assert`, which is stripped when running with -O
    if path.suffix != ".parquet":
        raise ValueError(f"Path must be a parquet file, got {path}")

    if path.exists():
        df.to_parquet(path, engine="fastparquet", append=True)
    else:
        path.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(path, engine="fastparquet")

In [None]:
from datetime import datetime

original_rating = load_rating_resources(version="original")
updated_rating = load_rating_resources(version="updated")

# Results are appended chunk by chunk; finished_entity_ids.parquet is the file the
# resume step reads to skip entities that were already processed.
cache_files = {
    "urls": scratch_folder / "urls.parquet",
    "matched_urls": scratch_folder / "matched_urls.parquet",
    "entity_ids": scratch_folder / "finished_entity_ids.parquet",
}

# uncomment to discard previous partial results instead of resuming
# for file in cache_files.values():
#     file.unlink(missing_ok=True)

# `chunk_ids` (not `entity_ids`) so the full id list defined earlier is not clobbered
for i, chunk_ids in enumerate(entity_id_chunks, start=1):
    print(f"[{datetime.now().strftime('%H:%M:%S')}] chunk {i}/{num_chunks}")
    url_df = download_chunk(chunk_ids, db_url)
    matched_urls = process_chunk(url_df, original_rating, updated_rating)

    # process_chunk returns None exactly when url_df is empty
    if len(url_df) > 0:
        append_to_parquet(url_df, cache_files["urls"])
        append_to_parquet(matched_urls, cache_files["matched_urls"])

    # mark the chunk as finished only after its results have been written
    append_to_parquet(pd.DataFrame(chunk_ids, columns=["entity_id"]), cache_files["entity_ids"])


# publish the results; a file is missing if no chunk ever wrote to it
for file in cache_files.values():
    if file.exists():
        shutil.move(file, data_folder / file.name)
    else:
        print(f"{file.name} was not produced; nothing to move")