In [1]:
# io
import os
from pathlib import Path
from typing import Iterable
from collections import Counter
import string
from tqdm import tqdm
from more_itertools import chunked
import math
# standard
import dotenv
import simplejson as json
import numpy as np
import polars as pl
import pytorch_lightning as L
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from loguru import logger
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset
# justatom
from justatom.logging.io import CSVLogger
from justatom.processing.sample import Sample, SampleBasket
from justatom.logging.wandb import WandbLogger
from justatom.tooling import stl
from justatom.modeling.mask import ILanguageModel
from justatom.processing import IProcessor, ITokenizer, igniset, RuntimeProcessor, TrainWithContrastiveProcessor
from justatom.processing.loader import NamedDataLoader
from justatom.running.encoders import EncoderRunner

In [2]:
from justatom.tooling.dataset import DatasetRecordAdapter

In [3]:
def maybe_cuda_or_mps():
    """Return the torch device string for the best usable accelerator.

    Returns:
        "cuda:0" when CUDA is available, otherwise "mps" when Apple's Metal backend is usable
        on this machine, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda:0"
    # `torch.has_mps` is deprecated and only says PyTorch was *built* with MPS; ask the backend
    # whether it is actually usable. getattr guards very old torch builds without `backends.mps`.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

In [4]:
# Load the dataset, flatten it to one row per (query, chunk) pair and keep a reproducible debug subset.
dataset_path = Path(os.getcwd()) / ".data" / "polaroids.ai.data.json"
num_samples = 100  # debug subset size; set to None to keep every (query, chunk) pair
shuffle_seed = 42  # fixes the shuffled row order, so the subset and the batch order are reproducible

adapter = DatasetRecordAdapter.from_source(
    dataset_path,
    content_col="content",
    queries_col="queries",
    chunk_id_col="chunk_id",
    keywords_col="keywords_or_phrases",
    preserve_all_fields=False,
)

docs = list(adapter.iterator())
docs_df = pl.from_dicts(docs)

pl_data = (
    docs_df
    .with_columns(
        # The queries come out of the adapter as `meta.labels`.
        queries=pl.col("meta").struct.field("labels"),
        chunk_id=pl.col("id"),
        keywords_or_phrases=pl.col("meta").struct.field("keywords_or_phrases"),
    )
    .select(["content", "queries", "chunk_id", "keywords_or_phrases"])
    # One row per query: a chunk with N queries becomes N (query, passage) training pairs.
    .explode("queries")
    .with_columns(queries=pl.col("queries").cast(pl.Utf8))
    # Drop null and blank queries: with no words left the lexical IDF-recall score is undefined.
    .filter(
        pl.col("queries").is_not_null()
        & (pl.col("queries").str.strip_chars().str.len_chars() > 0)
    )
    .sample(fraction=1.0, shuffle=True, seed=shuffle_seed)
)
if num_samples is not None:
    pl_data = pl_data.head(num_samples)

js_data = pl_data.to_dicts()

In [5]:
pl_data.height  # number of (query, passage) training pairs

100

In [6]:
# Peek at the first few training pairs (query text and its positive passage).
pl_data.head().select(["queries", "content"])

queries,content
str,str
"""Hey fellow nerds, can anyone e…","""(Grima): ""The wizard had three…"
"""In the universe of 'Harry Pott…","""(Hagrid): ""Alright. Let's go.""…"
"""Hey gamers! In the Harry Potte…","""No one on Privet Drive had eve…"
"""What is the feared weapon used…","""Lupin's eyebrows crept up in s…"
"""What were the dynamics and out…","""Adelaide Ivanovna, immediately…"


In [7]:
# Distinct passages among the sampled pairs. Fewer than len(pl_data) means some passage has several
# queries; if two of them land in one batch, the duplicate passage is scored as an in-batch "negative".
pl_data.select(pl.col("content").unique()).shape

(99, 1)

In [8]:
first_record = js_data[0]  # one flattened (query, passage, chunk_id, keywords) record
first_record

{'content': '(Grima): "The wizard had three companions. An Elf, a dwarf, and a human".\n(Saruman): "You smell like a horse. Was the man from Gondor?"\n(Grima): "No. From the north. I think, one of the Ranger of the North. Dressed poorly, and also... he wears a strange ring. Two snakes with emerald eyes. One consuming the other, crowned with golden flowers".\n(Saruman): "The ring of Barahir. Gandalf the Grey thinks that he found Isildur\'s Heir - the lost king of Gondor. What a fool. Their line ended many years ago. But it doesn\'t matter. The world of men will fall anyway and it will start with Edoras".',
 'queries': "Hey fellow nerds, can anyone explain why Saruman dismisses the possibility of a surviving line from Gondor's kings despite the clues about the ring of Barahir presented by Grima? Let's dissect the lore! #MiddleEarthMysteries",
 'chunk_id': '918d8c4a-4e0c-5d95-ac21-07746b59a465',
 'keywords_or_phrases': [{'keyword_or_phrase': 'Ranger of the North',
   'explanation': "A gro

In [9]:
# Encoder, tokenizer and the contrastive processor that turns each record into a paired
# query tensor (`input_ids`) and positive-passage tensor (`pos_input_ids`).
model_name_or_path = "intfloat/multilingual-e5-base"
max_seq_len = 512
tokenizer = ITokenizer.from_pretrained(model_name_or_path)
lm_model = ILanguageModel.load(model_name_or_path=model_name_or_path)
processor = TrainWithContrastiveProcessor(tokenizer=tokenizer, max_seq_len=max_seq_len, queries_field="queries")

[32m2026-02-17 07:45:35.176[0m | [1mINFO    [0m | [36mjustatom.modeling.mask[0m:[36mload[0m:[36m149[0m - [1mLoading from huggingface hub via "intfloat/multilingual-e5-base"[0m


In [10]:
query_column = processor.queries_field  # record key the processor reads the query text from
query_column

'queries'

In [11]:
passage_column = processor.pos_queries_field  # record key the processor reads the positive passage from
passage_column

'content'

In [12]:
device = maybe_cuda_or_mps()  # "cuda:0", "mps" or "cpu"
batch_size = 4  # queries per step; every other passage in the batch acts as a negative

In [13]:
# Tensorize the records; the baskets are kept so individual samples can be inspected below.
processed = processor.dataset_from_dicts(js_data, return_baskets=True)
dataset, tensor_names, _, baskets = processed  # third element is not needed here

In [14]:
# No shuffling is requested here, so batches follow the (already shuffled) order of js_data.
loader = NamedDataLoader(
    dataset=dataset,
    tensor_names=tensor_names,
    batch_size=batch_size,
)

In [15]:
fourth_sample = baskets[3].samples[0]
fourth_sample.clear_text  # prefixed texts: ["query: ...", "passage: ..."]

["query: What is the feared weapon used by dementors to destroy their victims in the 'Harry Potter and the Prisoner of Azkaban'?",
 "passage: Lupin's eyebrows crept up in surprise. 'Ron and Hermione brought me from Hogsmeade,' Harry lied without blinking an eye. 'Ah,' Lupin drawled. But there was still a moment of suspicion in his look. 'Well, let's toast to the victory of Gryffindor over Slytherin! But, of course, as a teacher, I shouldn't prefer any house,' he hastily added. They were drinking lemonade in silence, but Harry had one question on his tongue. 'What's under a dementor's hood?' he finally blurted out. Professor Lupin, detaching himself from the bottle, frowned. 'You see, those few who know it are not able to tell about it. The thing is that dementors pull back the hood only to use their final and most frightening weapon... 'What weapon?' 'It's called the Dementor’s Kiss,' Lupin said, grimacing. 'Dementors use it on those they want to completely destroy. I think they have s

In [17]:
# Fetch a single batch once; building two iterators would collate two separate batches.
first_batch = next(iter(loader))
logger.info(first_batch["pos_input_ids"].shape)  # passages (content): batch_size x max_seq_len
logger.info(first_batch["input_ids"].shape)  # queries: batch_size x max_seq_len

[32m2026-02-17 07:45:42.345[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mtorch.Size([4, 512])[0m
[32m2026-02-17 07:45:42.346[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mtorch.Size([4, 512])[0m


In [18]:
# Round-trip check: decode the second query of the first batch back to text.
probe_batch = next(iter(loader))
second_query_ids = probe_batch["input_ids"][1].squeeze()
processor.tokenizer.decode(second_query_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

"query: In the universe of 'Harry Potter and the Philosopher's Stone', what discovery did Hagrid make a few weeks prior to taking Harry and others into the Forbidden Forest?"

In [19]:
lm_runner = EncoderRunner(model=lm_model, processor=processor, prediction_heads=[], device=device)
lm_runner = lm_runner.eval()  # inference mode until a trainer reconfigures the encoder

[32m2026-02-17 07:45:47.866[0m | [1mINFO    [0m | [36mjustatom.running.encoders[0m:[36mto[0m:[36m36[0m - [1mMoving to device cuda:0[0m


In [20]:
# Trainable parameter count of the encoder. Counting must not change the module's train/eval mode,
# so there is no `.eval()` call here (the runner was already put in eval mode when it was built).
sum(p.numel() for p in lm_runner.model.parameters() if p.requires_grad)

278043648

In [21]:
class _BaseGammaTrainer(nn.Module):
    def __init__(self, lm_runner, device: str = "cpu", stopsyms: str | None = None, freeze_encoder: bool = True):
        super().__init__()
        self.runner = lm_runner
        self.device = device
        self.processor = lm_runner.processor
        self.freeze_encoder = freeze_encoder
        self.stopsyms = "«»:\"'" if stopsyms is None else stopsyms
        self._configure_encoder()
        self.runner.to(device)

    def _configure_encoder(self):
        if self.freeze_encoder:
            self.runner.eval()
            for tensor in self.runner.parameters():
                tensor.requires_grad = False
        else:
            self.runner.train()
            for tensor in self.runner.parameters():
                tensor.requires_grad = True

    def _fn_inverse_idf_recall(self, query: str, keywords_or_phrases_or_content: list[str] | str, stopsyms: str | None = None):
        stopsyms = stopsyms or self.stopsyms
        stopsyms = string.punctuation if stopsyms is None else stopsyms + string.punctuation
        if isinstance(keywords_or_phrases_or_content, list):
            k_words = Counter(
                stl.flatten_list([
                    "".join([w for w in kwp.lower().strip() if w not in stopsyms]).split()
                    for kwp in keywords_or_phrases_or_content
                ])
            )
        else:
            k_words = Counter([
                "".join([ch for ch in w.lower().strip() if ch not in stopsyms])
                for w in keywords_or_phrases_or_content.split()
            ])
        q_words = "".join(w for w in query if w not in stopsyms).lower().strip().split()
        idf_recall = sum([1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words if w in k_words]) / sum(
            [1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words]
        )
        return idf_recall

    def _encode(self, batch):
        if self.freeze_encoder:
            with torch.no_grad():
                return self.runner(batch, average=True, norm=True)
        return self.runner(batch, average=True, norm=True)

    def _build_rank_matrix(self, batch, shape):
        rank_matrix = torch.zeros(shape, device=self.device, requires_grad=False)
        with torch.no_grad():
            for i, q_tokens in enumerate(batch["input_ids"]):
                for j, d_tokens in enumerate(batch["pos_input_ids"]):
                    queries = self.processor.tokenizer.decode(
                        q_tokens,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )[len(self.processor.queries_prefix):].strip()
                    content = self.processor.tokenizer.decode(
                        d_tokens,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=True,
                    )[len(self.processor.pos_queries_prefix):].strip()
                    rank_matrix[i, j] = self._fn_inverse_idf_recall(queries, content)
        return rank_matrix

    def _grad_norm(self, parameters) -> float:
        grads = [p.grad.detach().float().norm(2) for p in parameters if p.grad is not None]
        if not grads:
            return 0.0
        return torch.stack(grads).norm(2).item()

    def gamma_parameters(self) -> list[nn.Parameter]:
        raise NotImplementedError

    def gamma_metrics(self) -> dict[str, float]:
        raise NotImplementedError

    def gamma_grad_metrics(self) -> dict[str, float]:
        raise NotImplementedError

    def mix_scores(self, scores, rank_matrix):
        raise NotImplementedError

    def build_optimizer(self, lr_gamma: float = 1e-2, lr_encoder: float = 2e-5, weight_decay: float = 0.01):
        param_groups = [{"params": self.gamma_parameters(), "lr": lr_gamma, "weight_decay": 0.0}]
        if not self.freeze_encoder:
            encoder_params = [p for p in self.runner.parameters() if p.requires_grad]
            if encoder_params:
                param_groups.append({
                    "params": encoder_params,
                    "lr": lr_encoder,
                    "weight_decay": weight_decay,
                })
        return optim.AdamW(param_groups)

    def forward(self, batch):
        batch = {k: batch[k].to(self.device) for k in batch}
        q_vecs, d_vecs = self._encode(batch)
        scores = q_vecs @ d_vecs.T
        rank_matrix = self._build_rank_matrix(batch, scores.shape)
        return self.mix_scores(scores, rank_matrix)

    def train(self, loader: NamedDataLoader, optimizer, logger=None, n_epochs: int = 1, save_dir: str | Path = None):
        for epoch_idx in range(n_epochs):
            for _, batch in tqdm(enumerate(loader)):
                output = self.forward(batch)
                labels = torch.arange(len(output), device=self.device)
                loss = F.cross_entropy(output, labels)
                optimizer.zero_grad()
                loss.backward()

                grad_metrics = self.gamma_grad_metrics()
                grad_metrics["Grad_Norm_model"] = self._grad_norm(self.runner.parameters())

                optimizer.step()
                if logger is not None:
                    logger.log_metrics({
                        "Loss": loss.item(),
                        "FreezeEncoder": int(self.freeze_encoder),
                        **self.gamma_metrics(),
                        **grad_metrics,
                    })
            if save_dir is not None:
                save_path = Path(save_dir) / self.save_subdir / f"epoch{str(epoch_idx + 1)}"
                self.runner.save(save_path)


class BiGAMMATrainer(_BaseGammaTrainer):
    """Mixes dense and lexical scores with two independent learned weights.

    logits = sigmoid(gamma1) * scores + sigmoid(gamma2) * rank_matrix; unlike GAMMATrainer the
    two effective weights are not tied to sum to one.
    """

    save_subdir = "BiGamma"

    def __init__(self, lm_runner, device: str = "cpu", stopsyms: str | None = None, freeze_encoder: bool = True):
        super().__init__(lm_runner=lm_runner, device=device, stopsyms=stopsyms, freeze_encoder=freeze_encoder)

        def make_weight(raw_value: float) -> nn.Parameter:
            return nn.Parameter(torch.tensor([raw_value], device=device), requires_grad=True)

        # Raw (pre-sigmoid) initial values: sigmoid(0.5) ~= 0.62 (dense), sigmoid(1.5) ~= 0.82 (lexical).
        self.gamma1 = make_weight(0.5)
        self.gamma2 = make_weight(1.5)
        self.sigmoid = nn.Sigmoid()

    def gamma_parameters(self) -> list[nn.Parameter]:
        """Both raw mixing weights: dense first, lexical second."""
        dense, lexical = self.gamma1, self.gamma2
        return [dense, lexical]

    def gamma_metrics(self) -> dict[str, float]:
        """Raw (pre-sigmoid) values of both mixing weights."""
        dense_raw = self.gamma1.item()
        lexical_raw = self.gamma2.item()
        return {"Gamma1": dense_raw, "Gamma2": lexical_raw}

    def gamma_grad_metrics(self) -> dict[str, float]:
        """Gradient norm of each mixing weight."""
        dense_grad = self._grad_norm([self.gamma1])
        lexical_grad = self._grad_norm([self.gamma2])
        return {"Grad_norm_gamma1": dense_grad, "Grad_norm_gamma2": lexical_grad}

    def mix_scores(self, scores, rank_matrix):
        dense_weight = self.sigmoid(self.gamma1)
        lexical_weight = self.sigmoid(self.gamma2)
        return dense_weight * scores + lexical_weight * rank_matrix


class GAMMATrainer(_BaseGammaTrainer):
    """Convex mix of dense and lexical scores controlled by a single learned weight.

    logits = w * scores + (1 - w) * rank_matrix, with w = sigmoid(gamma).
    """

    save_subdir = "Gamma"

    def __init__(self, lm_runner, device: str = "cpu", stopsyms: str | None = None, freeze_encoder: bool = True):
        super().__init__(lm_runner=lm_runner, device=device, stopsyms=stopsyms, freeze_encoder=freeze_encoder)
        initial = torch.tensor([0.5], device=device)  # raw value; sigmoid(0.5) ~= 0.62 on the dense term
        self.gamma = nn.Parameter(initial, requires_grad=True)
        self.sigmoid = nn.Sigmoid()

    def gamma_parameters(self) -> list[nn.Parameter]:
        """The single raw (pre-sigmoid) mixing weight."""
        params = [self.gamma]
        return params

    def gamma_metrics(self) -> dict[str, float]:
        """Raw (pre-sigmoid) value of the mixing weight."""
        raw_gamma = self.gamma.item()
        return {"Gamma": raw_gamma}

    def gamma_grad_metrics(self) -> dict[str, float]:
        # Same keys as BiGAMMATrainer (presumably to keep the logged columns aligned);
        # there is no second weight, so its slot is reported as 0.0.
        gamma_grad = self._grad_norm([self.gamma])
        return {"Grad_norm_gamma1": gamma_grad, "Grad_norm_gamma2": 0.0}

    def mix_scores(self, scores, rank_matrix):
        weight = self.sigmoid(self.gamma)
        return weight * scores + (1 - weight) * rank_matrix

In [22]:
# Stage 1: frozen encoder; only the two mixing weights are learned.
experiment_freeze_encoder = True  # True: tune only gammas | False: end2end encoder+gammas
trainer = BiGAMMATrainer(lm_runner=lm_runner, device=device, freeze_encoder=experiment_freeze_encoder)

[32m2026-02-17 07:46:01.072[0m | [1mINFO    [0m | [36mjustatom.running.encoders[0m:[36mto[0m:[36m36[0m - [1mMoving to device cuda:0[0m


In [23]:
bi_gamma_csv = Path(os.getcwd()) / "weights" / "debug_bi_gamma_metrics.csv"
# With a frozen encoder only the gamma group is created, so lr_encoder / weight_decay are unused here.
optimizer = trainer.build_optimizer(lr_gamma=1e-2, lr_encoder=2e-5, weight_decay=0.01)
wb_logger = CSVLogger(bi_gamma_csv)

In [24]:
import csv
import gc

def _to_gb(num_bytes: int) -> float:
    """Convert a byte count to gibibytes (units of 2**30 bytes)."""
    bytes_per_gib = 1024 ** 3
    return num_bytes / bytes_per_gib

def cuda_mem_report(tag: str, reset_peak: bool = False):
    """Print current and peak CUDA memory (allocated / reserved, in GiB) prefixed by ``[tag]``.

    Args:
        tag: label printed in square brackets at the start of the line.
        reset_peak: reset CUDA's peak-memory counters before reading, so the peaks start from now.

    Returns:
        None. When CUDA is unavailable (e.g. CPU or MPS) only a notice is printed.
    """
    if torch.cuda.is_available():
        if reset_peak:
            torch.cuda.reset_peak_memory_stats()
        allocated, reserved, max_allocated, max_reserved = (
            torch.cuda.memory_allocated(),
            torch.cuda.memory_reserved(),
            torch.cuda.max_memory_allocated(),
            torch.cuda.max_memory_reserved(),
        )
        print(
            f"[{tag}] alloc={_to_gb(allocated):.2f}GB | reserved={_to_gb(reserved):.2f}GB | "
            f"peak_alloc={_to_gb(max_allocated):.2f}GB | peak_reserved={_to_gb(max_reserved):.2f}GB"
        )
    else:
        print(f"[{tag}] CUDA is not available")

In [25]:
cuda_mem_report("BiGamma BEFORE", reset_peak=True)
trainer.train(loader, optimizer=optimizer, logger=wb_logger, n_epochs=1, save_dir=Path(os.getcwd()) / "weights")
# Close this run's CSV log now: `wb_logger` is rebound to a new CSVLogger for the Gamma run,
# and only that last logger is closed at the end of the notebook.
wb_logger.close_log()
cuda_mem_report("BiGamma AFTER", reset_peak=False)

[BiGamma BEFORE] alloc=1.04GB | reserved=1.09GB | peak_alloc=1.04GB | peak_reserved=1.09GB


25it [00:02, 11.31it/s]


[BiGamma AFTER] alloc=1.04GB | reserved=1.20GB | peak_alloc=1.12GB | peak_reserved=1.20GB


In [None]:
# Actually release memory before the next experiment; the report below is labelled "After cleanup".
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached, unused blocks held by PyTorch's CUDA allocator
cuda_mem_report("After cleanup", reset_peak=False)

In [26]:
# Stage 2: unfreeze the encoder and fine-tune it end-to-end with a single mixing weight.
# lm_runner is reused; stage 1 kept it frozen, so it still holds the pretrained weights.
experiment_freeze_encoder = False  # compare against frozen setup
trainer = GAMMATrainer(lm_runner=lm_runner, device=device, freeze_encoder=experiment_freeze_encoder)

[32m2026-02-17 07:49:01.849[0m | [1mINFO    [0m | [36mjustatom.running.encoders[0m:[36mto[0m:[36m36[0m - [1mMoving to device cuda:0[0m


In [None]:
gamma_csv = Path(os.getcwd()) / "weights" / "gamma_metrics.csv"
# Two parameter groups here: gammas at lr_gamma and the unfrozen encoder at lr_encoder with weight decay.
optimizer = trainer.build_optimizer(lr_gamma=1e-2, lr_encoder=2e-5, weight_decay=0.01)
wb_logger = CSVLogger(gamma_csv)

In [28]:
weights_dir = Path(os.getcwd()) / "weights"
cuda_mem_report("Gamma BEFORE", reset_peak=True)
trainer.train(loader, optimizer=optimizer, logger=wb_logger, n_epochs=1, save_dir=weights_dir)

[Gamma BEFORE] alloc=1.04GB | reserved=1.20GB | peak_alloc=1.04GB | peak_reserved=1.20GB


25it [00:05,  4.87it/s]


In [None]:
cuda_mem_report(tag="Gamma AFTER")  # keep the peak counters from the training run (reset_peak defaults to False)

In [None]:
# Finalize the metrics file of the most recent run.
pending_logger = wb_logger
if pending_logger is not None:
    pending_logger.close_log()