In [2]:
# io
import os
from pathlib import Path
from typing import Iterable
from collections import Counter
import string
from tqdm import tqdm
from more_itertools import chunked
import math
# standard
import dotenv
import simplejson as json
import numpy as np
import polars as pl
import pytorch_lightning as L
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from loguru import logger
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset
# justatom
from justatom.processing.sample import Sample, SampleBasket
from justatom.logging.wandb import WandbLogger
from justatom.tooling import stl
from justatom.modeling.mask import ILanguageModel
from justatom.processing import IProcessor, ITokenizer, igniset, INFERProcessor, ContrastiveProcessor
from justatom.processing.loader import NamedDataLoader
from justatom.running.m1 import M1LMRunner

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def source_from_dataset(dataset_name_or_path, **props):
    """Load a named/path dataset through justatom's dataset API as a polars DataFrame.

    ``props`` are forwarded to the dataset's ``iterator``. If the iterator already
    yields a ``pl.DataFrame`` it is returned unchanged; otherwise its items are
    materialised into a list and converted with ``pl.from_dicts``.
    """
    from justatom.storing.dataset import API as DatasetApi
    import polars as pl

    source = DatasetApi.named(dataset_name_or_path).iterator(**props)
    if not isinstance(source, pl.DataFrame):
        source = pl.from_dicts(list(source))
    return source

In [None]:
def maybe_cuda_or_mps():
    """Return the preferred torch device string: ``"cuda:0"``, then ``"mps"``, else ``"cpu"``.

    MPS is chosen only when ``torch.backends.mps.is_available()`` reports the device
    is actually usable. The previously used ``torch.has_mps`` is deprecated and only
    says whether torch was *built* with MPS, so it could pick an unusable device.
    """
    if torch.cuda.is_available():
        return "cuda:0"
    try:
        # Older torch builds may not ship the MPS backend module at all.
        from torch.backends import mps as mps_backend
    except ImportError:
        return "cpu"
    if mps_backend.is_available():
        return "mps"
    return "cpu"

In [None]:
# Load the corpus, keep only the columns used downstream and emit one row per
# (query, content) pair. Null queries are dropped with `.is_not_null()`: in polars a
# comparison with None yields null, so `!= None` is not a reliable null filter.
# The shuffle uses a fixed seed so row order (and therefore batches) is reproducible.
pl_data = (
    source_from_dataset(Path(os.getcwd()) / ".data" / "polaroids.ai.data.all.json")
    .select(["content", "queries", "chunk_id", "keywords_or_phrases"])
    .explode("queries")
    .filter(pl.col("queries").is_not_null())
    .sample(fraction=1.0, shuffle=True, seed=42)
)
js_data = pl_data.to_dicts()
assert len(js_data) > 0, "no (query, content) rows left after filtering null queries"

In [None]:
# Number of (query, content) rows after exploding and filtering.
pl_data.height

In [None]:
# Preview the first five query/content pairs.
pl_data.select(pl.col(["queries", "content"])).head(5)

In [None]:
# Distinct content texts (exploding queries repeats each content once per query).
pl_data.select(pl.col("content")).unique()

In [None]:
# Peek at the first record: one dict per (exploded) query row of pl_data.
js_data[0]

In [None]:
# The same checkpoint provides both the tokenizer and the encoder weights.
model_name_or_path = "intfloat/multilingual-e5-base"
tokenizer = ITokenizer.from_pretrained(model_name_or_path)
lm_model = ILanguageModel.load(model_name_or_path=model_name_or_path)
# Contrastive processor reading queries from the "queries" column; presumably it pairs
# each query with its content as the positive side (pos_input_ids) — verify in justatom.
processor = ContrastiveProcessor(tokenizer=tokenizer, max_seq_len=512, queries_field="queries")

In [None]:
# Which column the processor reads queries from.
getattr(processor, "queries_field")

In [None]:
# Which column the processor uses for the positive side of each pair.
getattr(processor, "pos_queries_field")

In [None]:
# Loader batch size and the torch device shared by every later cell.
batch_size, device = 16, maybe_cuda_or_mps()

In [None]:
# Tokenize every record. The third return value is unused here; it is bound to a named
# throwaway rather than `_`, because assigning `_` in a notebook makes IPython stop
# updating its `_` / `__` / `___` output-history variables.
dataset, tensor_names, _unused, baskets = processor.dataset_from_dicts(js_data, return_baskets=True)

In [None]:
# Batches are dicts keyed by tensor_names (e.g. "input_ids", "pos_input_ids").
loader = NamedDataLoader(
    batch_size=batch_size,
    dataset=dataset,
    tensor_names=tensor_names,
)

In [None]:
# Raw text of the first sample in the 4th basket (index 3 assumes at least 4 baskets).
baskets[3].samples[0].clear_text

In [None]:
# Materialise one batch from a fresh iterator to inspect its tensors.
next(iter(loader))

In [None]:
# Fetch a single batch once (the original pulled two separate batches) and log the
# shapes of both sides; each is expected to be batch_size x max_seq_len.
sample_batch = next(iter(loader))
logger.info(f"pos_input_ids (content) shape: {tuple(sample_batch['pos_input_ids'].shape)}")
logger.info(f"input_ids (queries) shape: {tuple(sample_batch['input_ids'].shape)}")

In [None]:
# Round-trip check: decode the second query of a freshly fetched batch back to text.
processor.tokenizer.decode(
    next(iter(loader))["input_ids"][1].squeeze(),
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

In [None]:
# Wrap the encoder in a runner with no prediction heads, placed on `device`, in eval mode.
lm_runner = M1LMRunner(model=lm_model, processor=processor, prediction_heads=[], device=device)
lm_runner = lm_runner.eval()

In [None]:
# Count trainable parameters of the wrapped model (this also puts it in eval mode).
sum(
    map(
        lambda param: param.numel(),
        filter(lambda param: param.requires_grad, lm_runner.model.eval().parameters()),
    )
)

In [None]:
class BiGAMMATrainer(nn.Module):
    
    def __init__(self, lm_runner, device: str = "cpu", stopsyms: str | None = None):
        super().__init__()
        self.gamma1 = nn.Parameter(torch.Tensor([0.5]).to(device), requires_grad=True)
        self.gamma2 = nn.Parameter(torch.Tensor([1.5]).to(device), requires_grad=True)
        self.sigmoid = nn.Sigmoid()
        self.runner = lm_runner.eval()
        self.device = device
        self.processor = lm_runner.processor
        self.stopsyms = "«»:\"'" if stopsyms is None else stopsyms
        
        for name, tensor in self.runner.named_parameters():
            tensor.requires_grad=False
        self.runner.training=False
        
        self.runner.to(device)
    
    def wrapper_for_keywords_or_content(self, js_doc, include_keywords: bool = False, include_explanation: bool = False, include_content: bool = False):
        if not include_content and not include_keywords and not include_explanation:
            raise ValueError(f"You selected [include_keywords=False][include_content=False][include_explanation=False]")
        keywords_or_phrases = js_doc.get("keywords_or_phrases", [])
        keywords_content: str = [js_doc['content']] if include_content else []
        if include_keywords and include_explanation:
            keywords_content += [
                kwp["keyword_or_phrase"].strip() + " " + kwp["explanation"].strip() for kwp in keywords_or_phrases
            ]
        elif include_keywords:
            keywords_content += [kwp["keyword_or_phrase"].strip() for kwp in keywords_or_phrases]
        else:
            keywords_content += [kwp["explanation"].strip() for kwp in keywords_or_phrases]
            keywords_content += "\n".join([kwp["explanation"].strip() for kwp in keywords_or_phrases])
        return keywords_content

    def _fn_inverse_idf_recall(self, query: str, keywords_or_phrases_or_content: list[str] | str, stopsyms: str | None = None, **props):
        stopsyms = stopsyms or self.stopsyms
        stopsyms = string.punctuation if stopsyms is None else stopsyms + string.punctuation
        if isinstance(keywords_or_phrases_or_content, list):
            k_words = Counter(stl.flatten_list(["".join([w for w in kwp.lower().strip() if w not in stopsyms]).split() for kwp in keywords_or_phrases_or_content]))
        else:
            k_words = Counter(["".join([ch for ch in w.lower().strip() if ch not in stopsyms]) for w in keywords_or_phrases_or_content.split()])
        q_words = "".join(w for w in query if w not in stopsyms).lower().strip().split()
        idf_recall = sum([1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words if w in k_words]) / sum(
            [1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words]
        )
        return idf_recall
    
    def forward(self, batch):
        batch = {k:batch[k].to(self.device) for k in batch}
        q_vecs, d_vecs = lm_runner(batch, average=True, norm=True)
        scores = q_vecs @ d_vecs.T
        R = torch.zeros((scores.shape[0], scores.shape[1]), device=self.device, requires_grad=False)
        with torch.no_grad():
            for i, q_tokens in enumerate(batch["input_ids"]):
                for j, d_tokens in enumerate(batch["pos_input_ids"]):
                    queries = self.processor.tokenizer.decode(q_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)[len(self.processor.queries_prefix):].strip()
                    content = self.processor.tokenizer.decode(d_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)[len(self.processor.pos_queries_prefix):].strip()
                    rank = self._fn_inverse_idf_recall(queries, content)
                    try:
                        R[i, j] = rank
                    except IndexError:
                        logger.info(f"Error @ batch for tokens=[{str(i)}, {str(j)}]")
                        return batch
        gamma1_ = self.sigmoid(self.gamma1)
        gamma2_ = self.sigmoid(self.gamma2)
        output = gamma1_ * scores + gamma2_ * R
        
        return output
    
    def train(self, loader: NamedDataLoader, optimizer, logger = None, n_epochs: int = 1):
        for epoch_idx, _ in enumerate(range(n_epochs)):
            for batch_idx, batch in tqdm(enumerate(loader)):
                output = self.forward(batch) # batch_size x batch_size
                labels = torch.arange(len(output), device=self.device)
                loss = F.cross_entropy(output, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if logger is not None:
                    logger.log_metrics({
                        "Loss": loss.item(),
                        "Gamma1": self.gamma1.item(),
                        "Gamma2": self.gamma2.item()
                    })
            _save_dir = Path(save_dir) / "BiGamma" / f"epoch{str(epoch_idx + 1)}"
            self.runner.save(_save_dir)
        _save_dir = Path(save_dir) / "BiGamma" / f"epoch{str(epoch_idx + 1)}"
        self.runner.save(_save_dir)
            
            
class GAMMATrainer(nn.Module):
    
    def __init__(self, lm_runner, device: str = "cpu", stopsyms: str | None = None):
        super().__init__()
        self.gamma = nn.Parameter(torch.Tensor([0.5]).to(device), requires_grad=True)
        self.sigmoid = nn.Sigmoid()
        self.runner = lm_runner.eval()
        self.device = device
        self.processor = lm_runner.processor
        self.stopsyms = "«»:\"'" if stopsyms is None else stopsyms
        
        for name, tensor in self.runner.named_parameters():
            tensor.requires_grad=False
        self.runner.training=False
        
        self.runner.to(device)
    
    def wrapper_for_keywords_or_content(self, js_doc, include_keywords: bool = False, include_explanation: bool = False, include_content: bool = False):
        if not include_content and not include_keywords and not include_explanation:
            raise ValueError(f"You selected [include_keywords=False][include_content=False][include_explanation=False]")
        keywords_or_phrases = js_doc.get("keywords_or_phrases", [])
        keywords_content: str = [js_doc['content']] if include_content else []
        if include_keywords and include_explanation:
            keywords_content += [
                kwp["keyword_or_phrase"].strip() + " " + kwp["explanation"].strip() for kwp in keywords_or_phrases
            ]
        elif include_keywords:
            keywords_content += [kwp["keyword_or_phrase"].strip() for kwp in keywords_or_phrases]
        else:
            keywords_content += [kwp["explanation"].strip() for kwp in keywords_or_phrases]
            keywords_content += "\n".join([kwp["explanation"].strip() for kwp in keywords_or_phrases])
        return keywords_content
    
    def _fn_inverse_idf_recall(self, query: str, keywords_or_phrases_or_content: list[str] | str, stopsyms: str | None = None, **props):
        stopsyms = stopsyms or self.stopsyms
        stopsyms = string.punctuation if stopsyms is None else stopsyms + string.punctuation
        if isinstance(keywords_or_phrases_or_content, list):
            k_words = Counter(stl.flatten_list(["".join([w for w in kwp.lower().strip() if w not in stopsyms]).split() for kwp in keywords_or_phrases_or_content]))
        else:
            k_words = Counter(["".join([ch for ch in w.lower().strip() if ch not in stopsyms]) for w in keywords_or_phrases_or_content.split()])
        q_words = "".join(w for w in query if w not in stopsyms).lower().strip().split()
        idf_recall = sum([1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words if w in k_words]) / sum(
            [1.0 / math.log(1 + k_words.get(w, 1)) for w in q_words]
        )
        return idf_recall
    
    def forward(self, batch):
        batch = {k:batch[k].to(self.device) for k in batch}
        q_vecs, d_vecs = lm_runner(batch, average=True, norm=True)
        scores = q_vecs @ d_vecs.T
        R = torch.zeros((scores.shape[0], scores.shape[1]), device=self.device, requires_grad=False)
        with torch.no_grad():
            for i, q_tokens in enumerate(batch["input_ids"]):
                for j, d_tokens in enumerate(batch["pos_input_ids"]):
                    queries = self.processor.tokenizer.decode(q_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)[len(self.processor.queries_prefix):].strip()
                    content = self.processor.tokenizer.decode(d_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)[len(self.processor.pos_queries_prefix):].strip()
                    rank = self._fn_inverse_idf_recall(queries, content)
                    try:
                        R[i, j] = rank
                    except IndexError:
                        logger.info(f"Error @ batch for tokens=[{str(i)}, {str(j)}]")
                        return batch
        gamma_ = self.sigmoid(self.gamma)
        output = gamma_ * scores + (1 - gamma_) * R

        return output

    def train(self, loader: NamedDataLoader, optimizer, logger = None, n_epochs: int = 1, save_dir: str | Path = None):
        for epoch_idx, _ in enumerate(range(n_epochs)):
            for batch_idx, batch in tqdm(enumerate(loader)):
                output = self.forward(batch) # batch_size x batch_size
                labels = torch.arange(len(output), device=self.device)
                loss = F.cross_entropy(output, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if logger is not None:
                    logger.log_metrics({
                        "Loss": loss.item(),
                        "Gamma": self.gamma.item(),
                    })
            _save_dir = Path(save_dir) / "Gamma" / f"epoch{str(epoch_idx + 1)}"
            self.runner.save(_save_dir)
        _save_dir = Path(save_dir) / "Gamma" / f"epoch{str(epoch_idx + 1)}"
        self.runner.save(_save_dir)

In [None]:
# Two-weight mixer over the frozen runner.
trainer = BiGAMMATrainer(lm_runner, device=device)

In [None]:
# Only the two mixing scalars are optimised; the language model stays frozen.
optimizer = optim.AdamW(params=[trainer.gamma1, trainer.gamma2])
wb_logger = WandbLogger(name="BiGamma AdamW descent", project="justatom.ai")

In [None]:
trainer.train(
    loader,
    optimizer=optimizer,
    logger=wb_logger,
    n_epochs=2,
    save_dir=Path.cwd() / "weights",
)
# Close this run before a later cell creates the next WandbLogger, so the two
# experiments are not logged into the same run.
# NOTE(review): assumes one W&B run per WandbLogger instance — confirm in justatom.
wb_logger.close_log()

In [None]:
# Single convex-weight mixer over the same frozen runner.
trainer = GAMMATrainer(device=device, lm_runner=lm_runner)

In [None]:
# Only the single mixing scalar is optimised.
optimizer = optim.AdamW(params=[trainer.gamma])
wb_logger = WandbLogger(name="Gamma AdamW descent", project="justatom.ai")

In [None]:
# Two epochs; per-epoch checkpoints go under ./weights.
trainer.train(
    loader,
    optimizer=optimizer,
    logger=wb_logger,
    n_epochs=2,
    save_dir=Path.cwd() / "weights",
)

In [None]:
# Close the most recently created WandbLogger (the "Gamma AdamW descent" run).
wb_logger.close_log()