In [None]:
import os
from pathlib import Path
from typing import List, Dict, Optional, Union

import polars as pl
import simplejson as json
import numpy as np
import torch

from plotly.subplots import make_subplots
import plotly.graph_objects as go

from justatom.configuring.prime import Config
from justatom.running.cluster import IBTRunner, IHFWrapperBackend
from justatom.modeling.prime import HFDocEmbedder
from justatom.clustering.prime import IUMAPDimReducer
from justatom.viewing.prime import PlotlyGroupedBarChart

# from gir.storing.dataset import INTIDataset

from loguru import logger

In [None]:
def ignite_dataset(where) -> List[Dict]:
    """Read a UTF-8 JSON file and return its decoded content.

    Parameters
    ----------
    where : str | os.PathLike
        Location of the JSON file.

    Returns
    -------
    Whatever the JSON decodes to; the annotation (and the notebook's later use
    with ``pl.from_dicts``) expects a list of record dicts.
    """
    source_path = Path(where)
    with source_path.open(encoding="utf-8") as handle:
        return json.load(handle)

In [None]:
# Raw JSON records, read from ./.data relative to the kernel's working directory.
docs = ignite_dataset(where=Path.cwd() / ".data" / "polaroids.ai.data.json")

In [None]:
# Infer the schema from every record rather than polars' default of the first 100:
# JSON records whose keys or value types only show up later could otherwise be
# left out of the inferred schema or trigger a dtype error.
docs_df = pl.from_dicts(docs, infer_schema_length=None)

In [None]:
# Preview the first five loaded records.
docs_df.head(n=5)

In [None]:
# Keep only rows that actually have a query to embed.
docs_df = docs_df.drop_nulls(subset=["query"])

In [None]:
# Number of rows left after dropping null queries.
docs_df.height

In [None]:
# Constant label: every row falls into the single "group" bucket "x".
docs_df = docs_df.with_columns(pl.lit("x").alias("group"))

In [None]:
# Text columns whose embeddings are compared row by row.
sample_col_a = "query"
sample_col_b = "content"
# Columns used to group rows in the chart.
group_col_a = "group"
group_col_b = "title"

In [None]:
def maybe_cuda_or_mps():
    """Return the device string of the best available torch backend.

    Returns
    -------
    str
        ``"cuda:0"`` if CUDA is available, otherwise ``"mps"`` if the Apple
        Metal (MPS) backend is available, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # `torch.has_mps` is deprecated; `torch.backends.mps.is_available()` is the
    # supported check. getattr guards torch builds that predate the mps backend.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

In [None]:
# Embedding model configured from Config, placed on the best available device,
# then wrapped as the transformers backend for the topic model.
embedder = HFDocEmbedder(
    **Config.clustering.embedder.toDict(),
    device=maybe_cuda_or_mps(),
)
transformers_backend = IHFWrapperBackend(
    embedder,
    **Config.clustering.transformers_backend.toDict(),
)

In [None]:
# Unpack the config section via `.toDict()`, like the embedder/backend cells do, so
# IBTRunner presumably receives plain dict values rather than config objects
# (TODO confirm `Config.clustering.bertopic` is the same config type as its siblings).
topic_model = IBTRunner(**Config.clustering.bertopic.toDict(), model=transformers_backend, verbose=True)

In [None]:
# Plain Python lists of the two text columns, one entry per docs_df row.
documents_1 = docs_df.get_column(sample_col_a).to_list()
documents_2 = docs_df.get_column(sample_col_b).to_list()

In [None]:
def dot_score_metric(x: np.ndarray, y: np.ndarray):
    """Unnormalised dot-product score between ``x`` and ``y``.

    For two 1-D vectors this is their scalar dot product; for 2-D inputs it is
    the matrix ``x @ y.T`` (every row of ``x`` against every row of ``y``).
    Higher means more similar — it is a similarity, not a distance.
    """
    y_transposed = y.T
    return np.matmul(x, y_transposed)

# Encode both text columns; `np.vstack` stacks whatever `embedder.encode` yields
# (presumably per-batch or per-document arrays — TODO confirm) into one matrix.
embeddings_1 = np.vstack(list(embedder.encode(documents_1, verbose=True, batch_size=50)))
embeddings_2 = np.vstack(list(embedder.encode(documents_2, verbose=True, batch_size=50)))

# `zip` below would silently truncate to the shorter matrix, misaligning the scores
# with docs_df rows when they are attached to df_view later; fail loudly instead.
if embeddings_1.shape[0] != embeddings_2.shape[0]:
    raise ValueError(
        f"Embedding count mismatch: {embeddings_1.shape[0]} rows for {sample_col_a!r} "
        f"vs {embeddings_2.shape[0]} rows for {sample_col_b!r}"
    )

# NOTE: despite the name, these are dot-product similarities (higher = closer).
distances = [dot_score_metric(x, y) for x, y in zip(embeddings_1, embeddings_2)]

In [None]:
# Peek at the first text from column `sample_col_a` ("query").
documents_1[0]

In [None]:
# Peek at the first text from column `sample_col_b` ("content").
documents_2[0]

In [None]:
# Echo the compared text columns, a separator, then the grouping columns, one per line.
print(sample_col_a, sample_col_b, " --- ", group_col_a, group_col_b, sep="\n")

In [None]:
# Number of distinct values in each grouping column.
print(docs_df.select(pl.col(group_col_a).unique()).height)
print(docs_df.select(pl.col(group_col_b).unique()).height)

In [None]:
def prepare_view(df, group_col_a:str, group_col_b:str):
    """Build the per-row frame consumed by the grouped bar chart.

    Parameters
    ----------
    df : pl.DataFrame
        Source frame; must contain ``group_col_a`` and ``group_col_b``.
    group_col_a, group_col_b : str
        Names of the two grouping columns to keep.

    Returns
    -------
    pl.DataFrame
        One row per row of ``df``, in the same order. It holds both grouping
        columns cast to strings, plus ``short_<col>`` copies truncated to their
        first 50 characters.
    """
    # Use a plain select rather than a join: it keeps the row count and order of `df`.
    # Later cells attach per-row scores positionally, so rows must stay aligned with
    # `df`. An outer join makes no ordering guarantee, and on newer polars it no
    # longer coalesces the key columns.
    df_view = df.select([
        pl.col(group_col_a).cast(pl.Utf8),
        pl.col(group_col_b).cast(pl.Utf8),
    ])
    df_view = df_view.with_columns([
        pl.col(group_col_a).str.slice(0, 50).alias(f"short_{group_col_a}"),
        pl.col(group_col_b).str.slice(0, 50).alias(f"short_{group_col_b}"),
    ])
    return df_view

In [None]:
# Per-row view of the grouping columns, aligned with docs_df.
df_view = prepare_view(docs_df, group_col_a, group_col_b)

In [None]:
def counts_per_col(df, col):
    """Keep the first row of each distinct value of ``col``, annotated with its group size.

    Parameters
    ----------
    df : pl.DataFrame
        Input frame (must not already have a ``row_nr`` column).
    col : str
        Column whose distinct values define the groups.

    Returns
    -------
    pl.DataFrame
        For each value of ``col``, the earliest row (by original position). The
        columns of ``df`` are extended with:

        - ``row_nr``: original row position.
        - ``counts_per_<col>``: number of rows sharing that value.
        - ``mask``: ``row_nr`` of the group's first row.
    """
    numbered = df.with_row_count()
    group_size = pl.col("row_nr").count().over(col).alias(f"counts_per_{col}")
    group_first = pl.col("row_nr").first().over(col).alias("mask")
    annotated = numbered.with_columns([group_size, group_first])
    # A row is its group's representative exactly when it is the first row of that group.
    return annotated.filter(pl.col("row_nr") == pl.col("mask"))

In [None]:
# Preview the chart view.
df_view.head(n=5)

In [None]:
# One representative row (plus group size) per distinct value of each grouping column.
df_cut_per_group_a, df_cut_per_group_b = (
    counts_per_col(docs_df, group_col_a),
    counts_per_col(docs_df, group_col_b),
)

In [None]:
# Number of distinct groups in `group_col_a`, then a preview.
print(df_cut_per_group_a.height)
df_cut_per_group_a.head(n=5)

In [None]:
# Number of distinct groups in `group_col_b`, then a preview.
print(df_cut_per_group_b.height)
df_cut_per_group_b.head(n=5)

In [None]:
# Attach the per-row scores positionally; this assumes df_view rows are in docs_df order.
df_view = df_view.with_columns([pl.Series("distance", distances)])

In [None]:
# Sanity check: number of scores; expected to match the row count of docs_df / df_view.
len(distances)

In [None]:
# Grouped bar chart over the two grouping columns. `dist_threshold` is passed
# through to PlotlyGroupedBarChart (its exact use is defined there).
chart = PlotlyGroupedBarChart(
    group_col_a=group_col_a,
    group_col_b=group_col_b,
    distance_col="distance",
    dist_threshold=0.80,
)
fig = chart.view(df_view)
fig.show()

In [None]:
# Export the chart to disk; the output format is presumably chosen from the file
# extension — confirm in PlotlyGroupedBarChart.save.
chart.save("comparison.png")