In [None]:
import os
from itertools import chain
from pathlib import Path

import polars as pl
import simplejson as json
import torch
from loguru import logger

from justatom.clustering.prime import IUMAPDimReducer
from justatom.configuring.prime import Config
from justatom.modeling.mask import ILanguageModel
from justatom.modeling.prime import DocEmbedder
from justatom.running.cluster import IBTRunner, IHFWrapperBackend
from justatom.viewing.prime import PlotlyScatterChart

In [None]:
def ignite_dataset(where, mask: str = None) -> list[dict]:
    """Read a JSON dataset from disk.

    Args:
        where: Location of the UTF-8 encoded JSON file (string or path-like).
        mask: Optional top-level key. When truthy, only the value stored
            under that key is returned.

    Returns:
        The parsed JSON document, or ``document[mask]`` when ``mask`` is truthy.
    """
    source = str(Path(where))
    with open(source, encoding="utf-8") as handle:
        payload = json.load(handle)
    # A falsy mask (None or "") means "give me the whole document".
    return payload[mask] if mask else payload

In [None]:
# Load the corpus from `.data/` under the current working directory.
docs = ignite_dataset(
    where=Path.cwd() / ".data" / "polaroids.ai.data.json",
)

In [None]:
# Split each record into the text to embed and the title used as its label.
documents = [entry["content"] for entry in docs]
labels = [entry["title"] for entry in docs]

In [None]:
def maybe_cuda_or_mps():
    """Pick the torch device string to run on.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        ``"cuda:0"`` if CUDA is available, ``"mps"`` if the MPS backend is
        available, otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # `torch.has_mps` is deprecated and only reflects build-time support;
    # `torch.backends.mps.is_available()` also checks that MPS is usable at
    # runtime. The getattr guard keeps torch builds without the mps backend
    # module falling through to CPU instead of raising.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

In [None]:
# Resolve the compute device once; it is passed to the embedder further down.
device = maybe_cuda_or_mps()
logger.info(f"Using device {device}")

In [None]:
# Model identifier handed to `ILanguageModel.load`. The commented-out line is
# an alternative that points at local weights under `./weights` instead.
# model_name_or_path = Path(os.getcwd()) / "weights" / "polaroids.ai-bs=128-margin=0.4"
model_name_or_path = "intfloat/multilingual-e5-base"

In [None]:
from justatom.processing import ITokenizer
from justatom.processing.prime import TripletProcessor

In [None]:
# NOTE(review): the tokenizer name is hard-coded here rather than taken from
# `model_name_or_path`; if that variable is switched to other weights, confirm
# the tokenizer still matches the model.
tokenizer = ITokenizer.from_pretrained("intfloat/multilingual-e5-base")
# The processor wraps the tokenizer; max_seq_len=512 is presumably the
# per-sequence token limit — TODO confirm against TripletProcessor.
processor = TripletProcessor(tokenizer=tokenizer, max_seq_len=512)

In [None]:
# Load the language model selected by `model_name_or_path`.
lm_model = ILanguageModel.load(model_name_or_path)

In [None]:
# Combine the model and processor into a document embedder on the chosen device.
embedder = DocEmbedder(model=lm_model, processor=processor, device=device)
# Adapt the embedder for the topic-model runner; the wrapper's options come
# from the `clustering.transformers_backend` section of the project config.
backend_wrapper = IHFWrapperBackend(
    embedder, **Config.clustering.transformers_backend.toDict()
)

In [None]:
# Build the topic-model runner from the `clustering.bertopic` config section,
# using the wrapped embedder as its embedding model.
# NOTE(review): unlike the other Config usages in this notebook, this one is
# unpacked without `.toDict()` — confirm the config node supports `**` directly.
bt_runner = IBTRunner(**Config.clustering.bertopic, model=backend_wrapper, verbose=True)

In [None]:
# Flatten the batches produced by the embedder into a single flat list.
embeddings = [
    vector
    for batch in embedder.encode(documents, verbose=True, batch_size=4)
    for vector in batch
]
# Fit the topic model on the raw texts and keep the per-document assignments.
topics, probs = bt_runner.fit_transform(docs=documents)

In [None]:
# Reduce the embeddings with the reducer configured by `clustering.umap`.
# `prepare2d` below requires the result to have exactly two columns.
reducer = IUMAPDimReducer(**Config.clustering.umap.toDict())
points = reducer.fit_transform(embeddings)

In [None]:
def prepare2d(docs, topics, labels, reduced_embeddings):
    """Assemble texts, topics, labels and 2-D coordinates into a polars frame.

    Args:
        docs: Document texts, one per point.
        topics: Topic assignment per document.
        labels: Label per document.
        reduced_embeddings: Array-like exposing ``.shape`` and supporting
            ``[:, 0]`` / ``[:, 1]`` slicing; must have exactly two columns.

    Returns:
        A polars DataFrame whose columns are renamed to ``text``, ``topic``,
        ``label``, ``x`` and ``y``.

    Raises:
        AssertionError: If ``reduced_embeddings`` does not have two columns.
    """
    n_dims = reduced_embeddings.shape[1]
    # Report the dimensionality of the argument itself. The message previously
    # read the notebook-level `embeddings` variable, which made a failed check
    # raise AttributeError/NameError instead of this AssertionError.
    assert n_dims == 2, f"Embeddings shape mismatch. Expected 2D, got {n_dims}D"
    COLS_MAPPING = dict(
        column_0="text", column_1="topic", column_2="label", column_3="x", column_4="y"
    )
    # strict=False: inputs of unequal length are truncated to the shortest one.
    pl_view = pl.from_dicts(
        zip(
            docs,
            topics,
            labels,
            reduced_embeddings[:, 0],
            reduced_embeddings[:, 1],
            strict=False,
        )
    )
    pl_view = pl_view.rename(COLS_MAPPING)
    return pl_view

In [None]:
# Join texts, topic ids, titles and the reduced coordinates into one table.
pl_view = prepare2d(
    docs=documents, topics=topics, labels=labels, reduced_embeddings=points
)

In [None]:
# Build the scatter chart from the table; `label_to_view` names the label
# passed to the viewer (presumably the one to highlight — TODO confirm).
chart = PlotlyScatterChart().view(pl_view, label_to_view="Вселенная")

In [None]:
# Bare expression so the notebook renders the chart as this cell's output.
chart

In [20]:
# Preview the first rows of the assembled table.
pl_view.head()

text,topic,label,x,y
str,i64,str,f64,f64
"""В реалисте вер…",4,"""Братья Карамаз…",10.780663,8.398505
"""Жизнь — это ми…",4,"""Человек в футл…",10.056022,6.705342
"""Нет, не так. К…",4,"""Метро 2033""",10.17023,6.336634
"""Станьте солнце…",4,"""Преступление и…",10.817473,6.70468
"""Когда-то он бы…",4,"""Джон Уик 3""",9.281586,12.412477
