In [None]:
import os
from pathlib import Path
from typing import List, Dict

import pandas as pd
import simplejson as json
import numpy as np
from itertools import chain
from loguru import logger

import polars as pl
import torch

from justatom.modeling.mask import ILanguageModel

from justatom.configuring.prime import Config
from justatom.running.cluster import IBTRunner, IHFWrapperBackend
from justatom.modeling.prime import DocEmbedder
from justatom.clustering.prime import IUMAPDimReducer
from justatom.viewing.prime import PlotlyScatterChart

import altair as alt

In [None]:
def source_from_dataset(dataset_name_or_path, **props):
    """Load a dataset through the justatom dataset API as a polars DataFrame.

    Args:
        dataset_name_or_path: name or path accepted by ``DatasetApi.named``.
        **props: forwarded verbatim to the dataset's ``iterator`` method.

    Returns:
        ``pl.DataFrame``. If ``iterator`` already returns a frame it is passed
        through unchanged; otherwise its items are materialized into a list and
        assembled with ``pl.from_dicts`` (so they are presumably dict records —
        confirm against the dataset API).
    """
    from justatom.storing.dataset import API as DatasetApi
    import polars as pl

    source = DatasetApi.named(dataset_name_or_path).iterator(**props)
    if isinstance(source, pl.DataFrame):
        return source
    records = list(source)
    return pl.from_dicts(records)

In [None]:
# The combined dataset is expected under ./.data relative to the kernel's working directory.
dataset_path = Path.cwd() / ".data" / "polaroids.ai.data.all.in.one.json"
pl_docs = source_from_dataset(dataset_path)

In [None]:
# Titles kept for the clustering experiment; rows of the full dataset whose
# "title" is not in this list are dropped by the filter cell below.
sub_sections = [
    "Гладиатор",
    "451 градус по Фаренгейту",
    "Гарри Поттер и Узник Азкабана",
    "Гарри Поттер и философский камень",
    "Цветы для Элджернона",
    "Гарри Поттер и Дары Смерти",
    "Ведьмак",
    "Сойка-пересмешница",
    "Голодные игры",
    "Голодные игры: И вспыхнет пламя",
]

In [None]:
# Keep only documents whose title belongs to the selected sub-sections.
is_selected_title = pl.col("title").is_in(sub_sections)
pl_sub_docs = pl_docs.filter(is_selected_title)

In [None]:
# Row counts of the selected subset vs. the full corpus.
n_selected, n_total = pl_sub_docs.height, pl_docs.height
logger.info(f"There are S=[{n_selected}] / [{n_total}] subset of documents selected for clustering")

In [None]:
# Column names used to pull document text and its ground-truth label out of the frame.
content_col, title_col = "content", "title"

In [None]:
# Distinct titles in first-seen order. `maintain_order=True` makes the list stable
# across kernel restarts; plain `.unique()` gives no ordering guarantee.
js_titles = pl_sub_docs.get_column(title_col).unique(maintain_order=True).to_list()
# Row-wise view of the subset: one dict per document, keyed by column name.
js_sub_docs = pl_sub_docs.to_dicts()

In [None]:
# Split the row dicts into parallel lists: document texts and their title labels.
js_docs, js_labels = [], []
for record in js_sub_docs:
    js_docs.append(record[content_col])
    js_labels.append(record[title_col])

In [None]:
def maybe_cuda_or_mps():
    """Pick the best available torch device string.

    Returns:
        ``"cuda:0"`` if CUDA is usable, else ``"mps"`` if the Apple MPS backend is
        available at runtime, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # `torch.has_mps` is deprecated and only reports build-time support; the
    # backend query checks runtime availability. The getattr guard keeps very old
    # torch builds (without `torch.backends.mps`) working.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

In [None]:
# Resolve the compute device once; it is passed as `device=` to the DocEmbedder further down.
device = maybe_cuda_or_mps()
logger.info(f"Using device {device}")

In [None]:
# Hub id (or local path) of the encoder; used for both the tokenizer and the language model.
model_name_or_path = "intfloat/multilingual-e5-base"

In [None]:
from justatom.processing.mask import IProcessor
from justatom.processing.prime import INFERProcessor, TripletProcessor
from justatom.processing import ITokenizer

In [None]:
# Tokenizer plus inference-time processor for the chosen checkpoint.
# NOTE(review): E5 checkpoints are trained with "query: " / "passage: " input
# prefixes; whether INFERProcessor inserts the separating space after "query:"
# is not visible here — TODO confirm.
tokenizer = ITokenizer.from_pretrained(model_name_or_path)
processor = INFERProcessor(
    tokenizer=tokenizer,
    max_seq_len=512,
    prefix="query:",
)

In [None]:
# Load the language model identified by `model_name_or_path` through the project's ILanguageModel wrapper.
lm_model = ILanguageModel.load(model_name_or_path)

In [None]:
# Document embedder on the selected device, wrapped as the clustering backend.
embedder = DocEmbedder(model=lm_model, processor=processor, device=device)
backend_params = Config.clustering.transformers_backend.toDict()
backend_wrapper = IHFWrapperBackend(embedder, **backend_params)

In [None]:
# Convert the config section to plain dicts with `.toDict()` before unpacking, the
# same way the transformers-backend and UMAP sections are handled in this notebook,
# so nested config objects (presumably DotMap) don't leak into the runner's kwargs.
bt_runner = IBTRunner(**Config.clustering.bertopic.toDict(), model=backend_wrapper, verbose=True)

In [None]:
# `encode` output is flattened one level, i.e. it presumably yields one chunk of
# vectors per batch; `embeddings` is the flat list of those vectors.
embedding_batches = embedder.encode(js_docs, verbose=True, batch_size=4)
embeddings = [vector for batch in embedding_batches for vector in batch]
# NOTE(review): the runner was built around the same embedder backend, so this
# likely re-embeds js_docs; consider passing `embeddings` if IBTRunner supports it.
topics, probs = bt_runner.fit_transform(docs=js_docs)

In [None]:
# Project the document embeddings into a low-dimensional space for plotting;
# `prepare2d` requires exactly 2 output dimensions.
umap_params = Config.clustering.umap.toDict()
reducer = IUMAPDimReducer(**umap_params)
points = reducer.fit_transform(embeddings)

In [None]:
def prepare2d(docs, topics, labels, reduced_embeddings):
    """Assemble a polars frame for a 2-D scatter view of clustered documents.

    Args:
        docs: document texts, one per point.
        topics: topic assignment per document.
        labels: label per document (e.g. the source title).
        reduced_embeddings: array-like with ``shape[1] == 2`` that supports
            ``[:, 0]`` / ``[:, 1]`` column slicing (presumably the UMAP output).

    Returns:
        ``pl.DataFrame`` with columns ``text``, ``topic``, ``label``, ``x``, ``y``.

    Raises:
        AssertionError: if the embeddings are not 2-D, or the inputs differ in length.
    """
    n_dims = reduced_embeddings.shape[1]
    # Report the shape of the argument actually checked. The old message read an
    # unrelated global and would itself crash when the assertion fired.
    assert n_dims == 2, f"Embeddings shape mismatch. Expected 2D, got {n_dims}D"
    docs, topics, labels = list(docs), list(topics), list(labels)
    n_points = reduced_embeddings.shape[0]
    # zip() used to truncate silently on mismatched lengths; fail loudly instead.
    assert len(docs) == len(topics) == len(labels) == n_points, (
        f"Length mismatch: docs={len(docs)}, topics={len(topics)}, "
        f"labels={len(labels)}, points={n_points}"
    )
    # Column-oriented construction names the columns directly. There is no
    # positional column_N renaming and no reliance on row-orientation inference.
    return pl.DataFrame(
        {
            "text": docs,
            "topic": topics,
            "label": labels,
            "x": reduced_embeddings[:, 0].tolist(),
            "y": reduced_embeddings[:, 1].tolist(),
        }
    )

In [None]:
# Put the BERTopic assignments (`topics` from `fit_transform`) in the `topic` column.
# Passing `js_labels` for both `topics` and `labels` would discard the model output.
pl_view = prepare2d(docs=js_docs, topics=topics, labels=js_labels, reduced_embeddings=points)

In [None]:
from justatom.viewing.prime import PlotlyScatterChart

In [None]:
# Build the scatter figure from the 2-D view frame.
scatter_chart = PlotlyScatterChart()
chart = scatter_chart.view(pl_view, label_to_view="Вселенная")

In [None]:
# Display the chart object returned by PlotlyScatterChart().view.
chart.show()

In [None]:
# Export a static PNG via kaleido at 2x scale. The name has no placeholders, so it
# is a plain string rather than an f-string.
figure_path = "clustering_model=[e5]_dataset=[universe].png"
chart.write_image(figure_path, engine="kaleido", scale=2)

In [None]:
# Sanity check: (rows, columns) of the selected subset.
pl_sub_docs.shape