In [1]:
import os
import re
import sys
import json
import logging
import collections
import pandas as pd
import numpy as np

from typing import Any, Literal, TypedDict, Callable, cast, TypeVar, Generic

In [2]:
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import PrecisionRecallDisplay, precision_recall_curve, average_precision_score
from sklearn.manifold import TSNE
import scipy.cluster.hierarchy as hc
import scipy.spatial as sp

In [3]:
import time
from typing import Callable, Generic, TypeVar


KT = TypeVar('KT')
VT = TypeVar('VT')


MAX_RETRY = 10
RETRY_WAIT = 0.1


class LRU(Generic[KT, VT]):
    def __init__(
            self,
            max_items: int,
            soft_limit: int | None = None) -> None:
        self._values: dict[KT, VT] = {}
        self._times: dict[KT, float] = {}
        self._max_items = max_items
        self._soft_limit = (
            max(1, int(max_items * 0.9)) if soft_limit is None else soft_limit)
        assert self._max_items >= self._soft_limit

    def get(self, key: KT) -> VT | None:
        res = self._values.get(key)
        if res is not None:
            self._times[key] = time.monotonic()
        return res

    def set(self, key: KT, value: VT) -> None:
        self._values[key] = value
        self._times[key] = time.monotonic()
        self.gc()

    def clear_keys(self, prefix_match: Callable[[KT], bool]) -> None:
        for key in list(self._values.keys()):
            if prefix_match(key):
                self._values.pop(key, None)
                self._times.pop(key, None)

    def gc(self) -> None:
        retry = 0
        while len(self._values) > self._max_items:
            try:
                to_remove = sorted(
                    self._times.copy().items(),
                    key=lambda item: item[1])[:-self._soft_limit]
                for rm_item in to_remove:
                    key = rm_item[0]
                    self._values.pop(key, None)
                    self._times.pop(key, None)
            except RuntimeError:
                # dictionary changed size during iteration: try again
                if retry >= MAX_RETRY:
                    raise
                retry += 1
                if RETRY_WAIT > 0:
                    time.sleep(RETRY_WAIT)

In [4]:
import torch
from torch import nn

def get_device() -> torch.device:
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = get_device()
device

device(type='mps')

In [5]:
from transformers import (
    DistilBertModel,
    DistilBertTokenizer,
    modeling_utils,
    get_scheduler,
)
from torch.optim import AdamW

In [6]:
from tqdm.auto import tqdm
import evaluate
import time

In [7]:
MODEL_FOLDER = "checkpoints"
os.makedirs(MODEL_FOLDER, exist_ok=True)

In [8]:
def batch_dot(batch_a: torch.Tensor, batch_b: torch.Tensor) -> torch.Tensor:
    batch_size = batch_a.shape[0]
    return torch.bmm(
        batch_a.reshape([batch_size, 1, -1]),
        batch_b.reshape([batch_size, -1, 1])).reshape([-1, 1])

In [9]:
TokenizedInput = TypedDict('TokenizedInput', {
    "text": list[str] | None,
    "input_ids": torch.Tensor,
    "attention_mask": torch.Tensor,
})

In [10]:
VERSION = 3
MODEL_BASE = {
    1: "distilbert-base-uncased",
    2: "distilbert-base-uncased",
    3: "distilbert-base-multilingual-cased",
}[VERSION]

In [11]:
def get_tokenizer() -> Callable[[list[str], bool], TokenizedInput]:
    tokenizer = DistilBertTokenizer.from_pretrained(MODEL_BASE)
    device = get_device()

    def tokens(texts: list[str], preserve_text: bool) -> TokenizedInput:
        res = tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True)
        obj = {k: v.to(device) for k, v in res.items()}
        obj["text"] = texts if preserve_text else None
        return cast(TokenizedInput, obj)

    return tokens

In [12]:
AggType = Literal["cls", "mean"]
AGG_CLS: AggType = "cls"
AGG_MEAN: AggType = "mean"

In [13]:
class TagModel(nn.Module):
    def __init__(
            self,
            *,
            agg: AggType,
            ignore_pretrained_warning: bool = False) -> None:
        super().__init__()
        logger = modeling_utils.logger
        level = logger.getEffectiveLevel()
        try:
            if ignore_pretrained_warning:
                logger.setLevel(logging.ERROR)
            self._bert = DistilBertModel.from_pretrained(MODEL_BASE)
        finally:
            if ignore_pretrained_warning:
                logger.setLevel(level)
        self._agg = agg
        self._test_lru = None
        
        def cache_hook(*args: Any, **kwargs: Any) -> None:
            self.clear_cache()
        
        self.register_load_state_dict_post_hook(cache_hook)
        self.register_state_dict_pre_hook(cache_hook)

    def _get_agg(self, lhs: torch.Tensor) -> torch.Tensor:
        if self._agg == AGG_CLS:
            return lhs[:, 0]
        if self._agg == AGG_MEAN:
            return torch.mean(lhs, dim=1)
        raise ValueError(f"unknown aggregation: {self._agg}")

    def _embed(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self._bert(
            input_ids=input_ids, attention_mask=attention_mask)
        out = self._get_agg(outputs.last_hidden_state)
        return out
    
    def clear_cache(self) -> None:
        self._test_lru = None

    def forward(
            self,
            x: TokenizedInput) -> torch.Tensor:
        if self.training and self._test_lru is not None:
            self.clear_cache()
        if not self.training:
            if self._test_lru is None:
                self._test_lru = LRU(1000)
            if x["text"] is not None:
                res = []
                for row_ix, text in enumerate(x["text"]):
                    cache_res = self._test_lru.get(text)
                    if cache_res is None:
                        cache_res = self._embed(
                            input_ids=x["input_ids"][[row_ix]],
                            attention_mask=x["attention_mask"][[row_ix]]).detach()
                        self._test_lru.set(text, cache_res)
                    res.append(torch.clone(cache_res))
                return torch.vstack(res)
        return self._embed(
            input_ids=x["input_ids"],
            attention_mask=x["attention_mask"])


class TrainingHarness(nn.Module):
    def __init__(self, model: TagModel, use_cos: bool) -> None:
        super().__init__()
        self._model = model
        self._loss = nn.BCELoss()
        self._cos = nn.CosineSimilarity() if use_cos else None
        
    def _combine(self, left_embed: torch.Tensor, right_embed: torch.Tensor) -> torch.Tensor:
        if self._cos is None:
            # NOTE: torch.sigmoid would be a bad idea here
            return batch_dot(left_embed, right_embed)
        return self._cos(left_embed, right_embed).reshape([-1, 1])
    
    def get_model(self) -> TagModel:
        return self._model

    def forward(
            self,
            *,
            left: TokenizedInput,
            right: TokenizedInput,
            labels: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        left_embed = self._model(left)
        right_embed = self._model(right)
        preds = self._combine(left_embed, right_embed)
        if labels is None:
            return preds
        probs = torch.hstack([1.0 - preds, preds])
        return preds, self._loss(probs, labels)

In [14]:
tokens = get_tokenizer()

In [15]:
def create_model(config) -> TagModel:
    return TagModel(agg=config["agg"], ignore_pretrained_warning=True)


def load_model(harness, model_fname):
    print(f"loading {model_fname}")
    with open(model_fname, "rb") as fin:
        harness.load_state_dict(torch.load(fin, map_location=device))


def compute(harness, df, *, preserve_text):
    lefts = tokens(df["left"].tolist(), preserve_text)
    rights = tokens(df["right"].tolist(), preserve_text)
    labels = torch.tensor(
        np.array([~df["is_correct"], df["is_correct"]]),
        dtype=torch.float32).T.to(device)
   
    preds, loss = harness(
        left=lefts,
        right=rights,
        labels=labels)
    return preds, loss


def run_training(*, prefix, num_epochs, train_size, train_df_gen, test_df, config, resume, skip_test):
    model = create_model(config)
    model.to(device)
    harness = TrainingHarness(model, use_cos=config["use_cos"])
    harness.to(device)

    try:
        batch_size_train = 3
        batch_size_test = 12
        optimizer = AdamW(harness.parameters(), lr=5e-5)

        num_training_steps = num_epochs * train_size // batch_size_train
        warmup = 1000
        lr_scheduler = get_scheduler(
            name="linear",
            optimizer=optimizer,
            num_warmup_steps=warmup,
            num_training_steps=num_training_steps - warmup)

        for epoch in range(num_epochs):
            print(f"epoch {epoch} train size {train_size} test size {test_df.shape[0]}")
            real_time = time.monotonic()

            model_fname = os.path.join(MODEL_FOLDER, f"{prefix}_model_{epoch}.pkl")

            acc_train = evaluate.load("accuracy")
            pre_train = evaluate.load("precision")
            rec_train = evaluate.load("recall")
            train_loss = []

            cur_skip_test = skip_test
            if os.path.exists(model_fname):
                load_model(harness, model_fname)
                train_loss.append(0)
                acc_train.add_batch(
                    predictions=[0, 1],
                    references=[0, 1])
                pre_train.add_batch(
                    predictions=[0, 1],
                    references=[0, 1])
                rec_train.add_batch(
                    predictions=[0, 1],
                    references=[0, 1])
            else:
                train_df = train_df_gen(train_size)
                model.train()
                harness.train()
                with tqdm(desc="train", total=train_df.shape[0]) as progress_bar:
                    for train_chunk in np.array_split(train_df, np.ceil(train_df.shape[0] / batch_size_train)):
                        preds, loss = compute(harness, train_chunk, preserve_text=False)
                        train_loss.append(loss.item())
                        loss.backward()

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        progress_bar.update(train_chunk.shape[0])

                        labels = train_chunk["is_correct"].astype(int).to_numpy()
                        predictions = (preds > 0.5).to(int)
                        acc_train.add_batch(
                            predictions=predictions,
                            references=labels)
                        pre_train.add_batch(
                            predictions=predictions,
                            references=labels)
                        rec_train.add_batch(
                            predictions=predictions,
                            references=labels)

                torch.save(harness.state_dict(), model_fname)
                cur_skip_test = False

            if not cur_skip_test:
                acc_test = evaluate.load("accuracy")
                pre_test = evaluate.load("precision")
                rec_test = evaluate.load("recall")
                test_loss = []

                model.eval()
                harness.eval()
                with torch.no_grad():
                    with tqdm(desc="test", total=test_df.shape[0]) as progress_bar:
                        for test_chunk in np.array_split(test_df, np.ceil(test_df.shape[0] / batch_size_test)):
                            preds, loss = compute(harness, test_chunk, preserve_text=True)
                            test_loss.append(loss.item())

                            progress_bar.update(test_chunk.shape[0])
                            
                            labels = test_chunk["is_correct"].astype(int).to_numpy()
                            predictions = (preds > 0.5).to(int)
                            acc_test.add_batch(
                                predictions=predictions,
                                references=labels)
                            pre_test.add_batch(
                                predictions=predictions,
                                references=labels)
                            rec_test.add_batch(
                                predictions=predictions,
                                references=labels)

                stats = {
                    "epoch": int(epoch),
                    "train_acc": float(acc_train.compute()["accuracy"]),
                    "train_pre": float(pre_train.compute()["precision"]),
                    "train_rec": float(rec_train.compute()["recall"]),
                    "train_loss": float(np.mean(train_loss)),
                    "test_acc": float(acc_test.compute()["accuracy"]),
                    "test_pre": float(pre_test.compute()["precision"]),
                    "test_rec": float(rec_test.compute()["recall"]),
                    "test_loss": float(np.mean(test_loss)),
                    "time": 0.0,
                    "config": config,
                }

                print(
                    f"train[acc: {stats['train_acc']} "
                    f"pre: {stats['train_pre']} "
                    f"rec: {stats['train_rec']} "
                    f"loss: {stats['train_loss']}]")
                print(
                    f"test[acc: {stats['test_acc']} "
                    f"pre: {stats['test_pre']} "
                    f"rec: {stats['test_rec']} "
                    f"loss: {stats['test_loss']}]")
                stats["time"] = float((time.monotonic() - real_time) / 60.0)
                print(f"epoch time: {stats['time']:.2f}min")

                stats_fname = os.path.join(MODEL_FOLDER, f"{prefix}_stats_{epoch}.json")
                with open(stats_fname, "w") as fout:
                    print(json.dumps(stats, indent=4, sort_keys=True), file=fout)
    
    except KeyboardInterrupt:
        print("TERMINATED BY USER", file=sys.stderr)
    return harness

In [16]:
DF = pd.read_parquet("traintest.pq")
DF

Unnamed: 0,stage,id,db,title,text,tag_3d printing,tag_Acupuntura Urbana,tag_Bebidas Tradicionales,tag_Bicicleta,tag_Bicycle,...,tag_wood gas,tag_wood stove,tag_woodchips,tag_worker safety,tag_youth,tag_youth activism,tag_youth and unemployment,tag_youth empowerment,tag_youth informality,tag_zero waste
0,validation,5019,sm,The power of faith facing of the weakness of t...,The power of faith facing of the weakness of t...,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
1,validation,5164,sm,local three lines power agregetor,local three lines power agregetor\nDIALLO Thie...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,test,5022,sm,Teflon REGULATOR,Teflon REGULATOR\nThe TEFLON or PTFE is stable...,False,False,False,False,False,...,False,False,False,False,False,False,True,False,False,False
3,train,2783,sm,Public Lights auto managed,Public Lights auto managed \nSolution developp...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,test,5021,sm,ORGANIC WASTE Matanizer !,ORGANIC WASTE Matanizer !\nHere is a Solide wa...,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4248,validation,197,exp,Reducing the Use of Single-Use Plastic Bags in...,Reducing the Use of Single-Use Plastic Bags in...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4249,validation,362,exp,Using DPPD to identify greater indigenous poli...,Using DPPD to identify greater indigenous poli...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4250,validation,252,exp,"Local Convergence: Promoting Agile, Adaptive, ...","Local Convergence: Promoting Agile, Adaptive, ...",False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4251,validation,241,exp,Marine Litter: Behavioral Insights Experiment ...,Marine Litter: Behavioral Insights Experiment ...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


In [17]:
FULL_TRAIN = DF[DF["stage"] == "train"]
FULL_TEST = DF[DF["stage"] == "test"]
FULL_REST = DF[(DF["stage"] != "train") & (DF["stage"] != "test")]
display(FULL_TRAIN)
display(FULL_TEST)
FULL_TRAIN.shape, FULL_TEST.shape, FULL_REST.shape

Unnamed: 0,stage,id,db,title,text,tag_3d printing,tag_Acupuntura Urbana,tag_Bebidas Tradicionales,tag_Bicicleta,tag_Bicycle,...,tag_wood gas,tag_wood stove,tag_woodchips,tag_worker safety,tag_youth,tag_youth activism,tag_youth and unemployment,tag_youth empowerment,tag_youth informality,tag_zero waste
3,train,2783,sm,Public Lights auto managed,Public Lights auto managed \nSolution developp...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
6,train,4466,sm,"Certificates (birth, marriage): request and de...","Certificates (birth, marriage): request and de...",False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
7,train,5367,sm,Ndinda Gully Rehabilitation Project,Ndinda Gully Rehabilitation Project \nSmallhol...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
9,train,4357,sm,Improved cook stove,Improved cook stove\nSileshi Abebe Alemayehu ...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
10,train,5275,sm,[Reciclaje] Compra y Venta Raschell,[Reciclaje] Compra y Venta Raschell\nCONTACTO ...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3175,train,5954,sm,A Syria geography teacher innovates a new mach...,A Syria geography teacher innovates a new mach...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3177,train,5295,sm,PYROLYSE DE PLASTIQUE (PDP),PYROLYSE DE PLASTIQUE (PDP) \nFatma SAGHIR...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3178,train,5782,sm,Citizen science on monitoring and mapping of t...,Citizen science on monitoring and mapping of t...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3180,train,5965,sm,Citizen Science on monitoring of invasive alie...,Citizen Science on monitoring of invasive alie...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


Unnamed: 0,stage,id,db,title,text,tag_3d printing,tag_Acupuntura Urbana,tag_Bebidas Tradicionales,tag_Bicicleta,tag_Bicycle,...,tag_wood gas,tag_wood stove,tag_woodchips,tag_worker safety,tag_youth,tag_youth activism,tag_youth and unemployment,tag_youth empowerment,tag_youth informality,tag_zero waste
2,test,5022,sm,Teflon REGULATOR,Teflon REGULATOR\nThe TEFLON or PTFE is stable...,False,False,False,False,False,...,False,False,False,False,False,False,True,False,False,False
4,test,5021,sm,ORGANIC WASTE Matanizer !,ORGANIC WASTE Matanizer !\nHere is a Solide wa...,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
8,test,4364,sm,Tikikil stove,Tikikil stove\nGIZ\nOffice contact GIZ Office ...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
13,test,5777,sm,Green Tech Packaging Solutions,Green Tech Packaging Solutions \nDesign and de...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
19,test,5827,sm,Origins: Personal Incubator,Origins: Personal Incubator\nWhat is Origins?\...,False,False,False,False,False,...,False,False,False,False,False,False,False,True,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3166,test,5759,sm,Smart O' ou jardin smart,Smart O' ou jardin smart\n\n\n&nbsp; &nbsp; &n...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3167,test,5608,sm,Kizha Intregrate farm_Race Taureau,Kizha Intregrate farm_Race Taureau\n\n\nkizha ...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3169,test,4690,sm,Safepad - sustainable sanitary napkin,Safepad - sustainable sanitary napkin\n\n\nSaf...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3173,test,4362,sm,Cow dung beehive briquette,Cow dung beehive briquette \n\n\nAstmamgn Amar...,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


((1000, 888), (1000, 888), (2253, 888))

In [18]:
RNG = np.random.default_rng(42)

In [19]:
TAG_PREFIX = "tag_"
TAG_COLS = [col for col in DF.columns if col.startswith(TAG_PREFIX)]
TAGS = [col[len(TAG_PREFIX):] for col in TAG_COLS]

def get_tags(df: pd.DataFrame, ix: int) -> list[str]:
    cur = df.iloc[ix][TAG_COLS]
    return sorted(col[len(TAG_PREFIX):] for col in cur[cur].index)

TAGS[:10]

['3d printing',
 'Acupuntura Urbana',
 'Bebidas Tradicionales',
 'Bicicleta',
 'Bicycle',
 'Cadenas de Valor',
 'Coca',
 'Conectividad',
 'Construction',
 'Contactlessness']

In [20]:
def clean(text: str) -> str:
    return re.sub("[ \t]+", " ", re.sub("\n[ \t]+", "\n", re.sub("\n\n+", "\n", re.sub("\r", "\n", text.strip()))))

def build_df(full_df: pd.DataFrame, num_rows: int | None, *, prob_correct: float = 0.5) -> pd.DataFrame:
    df_builder = {
        "left": [],
        "right": [],
        "is_correct": [],
    }
    if num_rows is None:
        ixs = list(range(full_df.shape[0]))
    else:
        ixs = RNG.choice(list(range(full_df.shape[0])), num_rows, replace=True)
    for row_ix in ixs:
        cur_row = full_df.iloc[row_ix]
        cur_text = clean(f"{cur_row['title']}:\n{cur_row['text']}")
        cur_tags = get_tags(full_df, row_ix)
        if num_rows is None:
            for tag in TAGS:
                df_builder["left"].append(cur_text)
                df_builder["right"].append(tag)
                df_builder["is_correct"].append(bool(cur_row[f"{TAG_PREFIX}{tag}"]))
        else:
            df_builder["left"].append(cur_text)
            if cur_tags and RNG.random() < prob_correct:
                df_builder["right"].append(RNG.choice(cur_tags, 1)[0])
                df_builder["is_correct"].append(True)
            else:
                tag = RNG.choice(TAGS, 1)[0]
                df_builder["right"].append(tag)
                df_builder["is_correct"].append(bool(cur_row[f"{TAG_PREFIX}{tag}"]))
    return pd.DataFrame(df_builder, columns=["left", "right", "is_correct"])

In [21]:
TEST_DF = build_df(FULL_TEST, None).groupby("is_correct", group_keys=False).apply(lambda x: x.sample(min(x.shape[0], 7000)))
TEST_DF = TEST_DF.sort_values("left")
TEST_DF

Unnamed: 0,left,right,is_correct
595947,"""Ai Nono"" - Traditional practice of construct...",transfer of knowledge and technology,True
595511,"""Ai Nono"" - Traditional practice of construct...",fishery,True
595779,"""Ai Nono"" - Traditional practice of construct...",pumping,False
595707,"""Ai Nono"" - Traditional practice of construct...",organic fertilizer,False
596013,"""Ai Nono"" - Traditional practice of construct...",women-led solution,False
...,...,...,...
786930,“Books on the Road”:\n“Books on the Road”\nMee...,commons,True
787565,“Books on the Road”:\n“Books on the Road”\nMee...,trust,False
787630,“Books on the Road”:\n“Books on the Road”\nMee...,youth,True
786780,“Books on the Road”:\n“Books on the Road”\nMee...,Library,True


In [None]:
harness = run_training(
    prefix=f"v{VERSION}",
    num_epochs=6,
    train_size=12000,
    train_df_gen=lambda size: build_df(FULL_TRAIN, size, prob_correct=0.6),
    test_df=TEST_DF,
    config={
        "agg": "cls",
        "use_cos": True,
    },
    resume=True,
    skip_test=True)

epoch 0 train size 12000 test size 10000


train:   0%|          | 0/12000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train[acc: 0.6125 pre: 0.6098375451263538 rec: 0.9538396386222473 loss: 0.8007726018056274]
test[acc: 0.4755 pre: 0.34971214352657654 rec: 0.8706666666666667 loss: 0.797020734464236]
epoch time: 57.04min
epoch 1 train size 12000 test size 10000


train:   0%|          | 0/12000 [00:00<?, ?it/s]

test:   0%|          | 0/10000 [00:00<?, ?it/s]

train[acc: 0.65475 pre: 0.6489183048503409 rec: 0.9177144453758033 loss: 0.8212587606897578]
test[acc: 0.3 pre: 0.3 rec: 1.0 loss: 3.243213752762591]
epoch time: 53.89min
epoch 2 train size 12000 test size 10000


train:   0%|          | 0/12000 [00:00<?, ?it/s]

In [None]:
def compute_tag_matrix(harness: TrainingHarness) -> pd.DataFrame:
    plan = []
    for lix, left in enumerate(TAGS):
        for rix, right in enumerate(TAGS):
            if lix > rix:
                continue
            plan.append((lix, rix, left, right))
    
    res = np.zeros((len(TAGS), len(TAGS)))
    batch_size = 16
    model = harness.get_model()
    model.eval()
    harness.eval()
    with torch.no_grad():
        with tqdm(desc="tags", total=len(plan)) as progress_bar:
            for chunk in np.array_split(plan, np.ceil(len(plan) / batch_size)):
                lixs, rixs, left_strs, right_strs = tuple(zip(*chunk))
                lefts = tokens(left_strs, True)
                rights = tokens(right_strs, True)   
                preds = harness(
                    left=lefts,
                    right=rights)
                for lix, rix, val in zip(lixs, rixs, preds.ravel().tolist()):
                    res[int(lix), int(rix)] = val
                    res[int(rix), int(lix)] = val
                progress_bar.update(chunk.shape[0])
    return pd.DataFrame(res, columns=TAGS, index=TAGS)

In [None]:
if os.path.exists("tags.pq"):
    tag_matrix = pd.read_parquet("tags.pq")
else:
    tag_matrix = compute_tag_matrix(harness)
    tag_matrix.to_parquet("tags.pq")
tag_matrix

In [None]:
(1.0 - tag_matrix.to_numpy()).round(6)

In [None]:
tag_linkage = hc.linkage(sp.distance.squareform((1.0 - tag_matrix.to_numpy()).round(6)), method="average")

In [None]:
sns.clustermap(
    tag_matrix,
    row_linkage=tag_linkage,
    col_linkage=tag_linkage,
    cmap="OrRd",
    figsize=(300, 300),
    linewidth=0)
plt.savefig("dendrogram_large.png", bbox_inches="tight")
plt.close()

In [None]:
sns.clustermap(
    tag_matrix,
    row_linkage=tag_linkage,
    col_linkage=tag_linkage,
    cmap="OrRd",
    figsize=(15, 15),
    linewidth=0)
plt.savefig("dendrogram.png", bbox_inches="tight")

In [None]:
len(TAGS)

In [None]:
def get_preds(df: pd.DataFrame, harness: TrainingHarness) -> pd.DataFrame:
    batch_size = 16
    model = harness.get_model()
    model.eval()
    harness.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        with tqdm(desc="data", total=df.shape[0]) as progress_bar:
            for chunk in np.array_split(df, np.ceil(df.shape[0] / batch_size)):
                preds, _ = compute(harness, chunk, preserve_text=True)

                progress_bar.update(chunk.shape[0])
                all_preds.extend(preds.ravel().tolist())
                all_labels.extend(chunk["is_correct"].astype(int).to_list())
    return pd.DataFrame({"preds": all_preds, "labels": all_labels}, columns=["preds", "labels"])

In [None]:
preds = get_preds(TEST_DF, harness)
preds

In [None]:
preds["labels"].astype(bool).describe()

In [None]:
precision, recall, thresholds = precision_recall_curve(preds["labels"], preds["preds"], pos_label=1)
f1_scores = 2 * recall * precision / (recall + precision)
best_th_ix = np.argmax(f1_scores)
best_thresh = thresholds[best_th_ix]
average_precision = average_precision_score(preds["labels"], preds["preds"], pos_label=1)
display = PrecisionRecallDisplay(
    precision=precision,
    recall=recall,
    average_precision=average_precision,
    estimator_name="Model",
    pos_label=1)
display.plot(name="Model")
display.ax_.set_title("Test Data")
display.ax_.plot(recall[best_th_ix], precision[best_th_ix], "ro", label=f"f1max (th = {best_thresh:.2f})")
display.ax_.legend()
None

In [None]:
def add_embeds(
        df: pd.DataFrame,
        harness: TrainingHarness,
        name: str,
        embeds: list[np.ndarray],
        names: list[str]) -> None:
    batch_size = 8
    model = harness.get_model()
    model.eval()
    with torch.no_grad():
        with tqdm(desc=name, total=df.shape[0]) as progress_bar:
            for chunk in np.array_split(df, np.ceil(df.shape[0] / batch_size)):
                embed = model(tokens([clean(txt) for txt in chunk["text"].tolist()], False))
                embeds.append(embed.numpy(force=True))
                names.extend([name] * chunk.shape[0])
                
                progress_bar.update(chunk.shape[0])

In [None]:
def get_all_embeds(harness: TrainingHarness) -> tuple[np.ndarray, list[str]]:
    embeds = []
    names = []
    add_embeds(pd.DataFrame({"text": FULL_REST["title"].astype(str)}), harness, "title", embeds, names)
    add_embeds(
        pd.DataFrame({"text": FULL_REST["title"].astype(str) + ":\n" + FULL_REST["text"].astype(str)}),
        harness,
        "full",
        embeds,
        names)
    add_embeds(pd.DataFrame({"text": TAGS}), harness, "tags", embeds, names)
    return np.vstack(embeds), names

In [None]:
ALL_EMBEDS, ALL_NAMES = get_all_embeds(harness)
ALL_EMBEDS.shape[0], len(ALL_NAMES)

In [None]:
CMAP = {
    "title": "tab:blue",
    "tags": "tab:orange",
    "full": "tab:green",
}

In [None]:
if os.path.exists("embeds.pq"):
    ALL_EMBED_DF = pd.read_parquet("embeds.pq")
else:
    ALL_TSNE_EMBEDS = TSNE(
        n_components=2,
        learning_rate="auto",
        init="random",
        # perplexity=10,
        method="barnes_hut",
        random_state=42,
        metric="cosine",
        n_jobs=-1).fit_transform(ALL_EMBEDS)
    ALL_EMBED_DF = pd.DataFrame({
        "x": ALL_TSNE_EMBEDS[:, 0],
        "y": ALL_TSNE_EMBEDS[:, 1],
        "l": [CMAP[name] for name in ALL_NAMES],
        "cat": ALL_NAMES})
    ALL_EMBED_DF.to_parquet("embeds.pq")
ALL_EMBED_DF

In [None]:
sns.scatterplot(
    data=ALL_EMBED_DF,
    x="x",
    y="y",
    s=1.0,
    hue="cat",
    hue_order=["tags", "title", "full"])

In [None]:
def scatter_embed(all_embeds, all_names, filter_name):
    tsne_embed = TSNE(
        n_components=2,
        learning_rate="auto",
        init="random",
        # perplexity=10,
        method="barnes_hut",
        random_state=42,
        metric="cosine",
        n_jobs=-1).fit_transform(
            all_embeds[[
                ix
                for ix, name in enumerate(all_names)
                if name == filter_name
            ], :])
    embed_df = pd.DataFrame({
        "x": tsne_embed[:, 0],
        "y": tsne_embed[:, 1]})
    sns.scatterplot(
        data=embed_df,
        x="x",
        y="y",
        s=4.0).set(title=filter_name)

In [None]:
scatter_embed(ALL_EMBEDS, ALL_NAMES, "tags")

In [None]:
scatter_embed(ALL_EMBEDS, ALL_NAMES, "title")

In [None]:
scatter_embed(ALL_EMBEDS, ALL_NAMES, "full")

In [None]:
#  Using the embeddings
# ======================

# 1. remove (clusters of) duplicate tags
# 2. suggest tags for a given title or title + text
# 3. allow to add new tags and suggest documents which might also fit that tag
# 4. prompt for searching related documents

In [None]:
def add_db(x: pd.DataFrame) -> pd.DataFrame:
    embeds = []
    add_embeds(pd.DataFrame({"text": x["text"].astype(str)}), harness, "", embeds, [])
    all_embeds = np.vstack(embeds)
    x = x[[col for col in x.columns if not col.startswith(TAG_PREFIX) and col != "stage"]]
    x["embedding"] = all_embeds.tolist()
    return x

LOOKUP_DF = DF.groupby("db", group_keys=False).apply(add_db)
LOOKUP_DF

In [None]:
def from_embeds(df: pd.DataFrame) -> np.ndarray:
    return np.array(df["embedding"].to_list())

In [None]:
LOOKUP_EMBEDS = from_embeds(LOOKUP_DF)
LOOKUP_EMBEDS

In [None]:
LOOKUP_EMBEDS.shape

In [None]:
def single_embed(
        prompt: str,
        harness: TrainingHarness) -> np.ndarray:
    model = harness.get_model()
    model.eval()
    with torch.no_grad():
        embed = model(tokens([prompt], False))
        return embed.numpy(force=True)

In [None]:
def get_distances(embeds: np.ndarray, single: np.ndarray) -> np.ndarray:
    x = embeds
    y = single.ravel().reshape((-1, 1))
    return 1.0 - np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

In [None]:
def search(
        df: pd.DataFrame,
        embeds: np.ndarray | None,
        prompt: str,
        harness: TrainingHarness) -> pd.DataFrame:
    if embeds is None:
        embeds = from_embeds(df)
    dists = get_distances(embeds, single_embed(prompt, harness)).ravel()
    res = df[["db", "title", "text"]].copy()
    res["dist"] = dists
    return res.sort_values("dist", ascending=True)

In [None]:
# "biogas from waste"
# "national innovation ecosystem and scalability"
# "social justice"
search_results = search(LOOKUP_DF, LOOKUP_EMBEDS, "biogas from waste", harness)
# search_results = search_results[search_results["db"] == "ap"]
search_results.head(5)

In [None]:
for s_ix in range(5):
    s_cur = search_results.iloc[s_ix]
    title = clean(s_cur["title"])
    bar_size = len(title) + 2
    print(f"┌{'─' * bar_size}┐")
    print(f"│ {title} │")
    print(f"└{'─' * bar_size}┘")
    print(clean(s_cur["text"]))
    print()

In [None]:
# Data Sources
# 1. direct query feedback (like/dislike)
# 2. recording query and last clicked link

In [None]:
# concrete clusters

In [None]:
tag_matrix

In [None]:
def get_tag_clusters():
    cols = tag_matrix.columns
    cids = hc.fcluster(tag_linkage, 0.08, criterion="distance")
    res = collections.defaultdict(list)
    for col, cid in zip(cols, cids):
        res[int(cid)].append(f"{col}")
    return dict(res)
    
tag_clusters = get_tag_clusters()
len(tag_clusters)

In [None]:
sns.histplot([len(c) for c in tag_clusters.values()], discrete=True)

In [None]:
with open("tag_clusters.json", "w") as fout:
    print(json.dumps(tag_clusters, indent=2, sort_keys=True), file=fout)

In [None]:
LOOKUP_DF

In [None]:
def compute_doc_matrix(db: str) -> np.ndarray:
    df = LOOKUP_DF[LOOKUP_DF["db"] == db]
    res = np.zeros((df.shape[0], df.shape[0]))
    with tqdm(desc=db, total=df.shape[0] * df.shape[0]) as progress_bar:
        for rix in range(df.shape[0]):
            row = df.iloc[rix]
            for cix in range(df.shape[0]):
                progress_bar.update(1)
                if cix > rix:
                    continue
                col = df.iloc[cix]
                dist = get_distances(
                    np.array(row["embedding"]).ravel().reshape((1, -1)),
                    np.array(col["embedding"]))[0]
                res[rix, cix] = dist
                res[cix, rix] = dist
    return res, df["text"].astype(str)

def compute_doc_linkage(doc_matrix):
    return hc.linkage(sp.distance.squareform(doc_matrix.round(10)), method="average")

def get_doc_clusters(doc_linkage, prompts, *, th):
    cids = hc.fcluster(doc_linkage, th, criterion="distance")
    res = collections.defaultdict(list)
    for prompt, cid in zip(prompts, cids):
        res[int(cid)].append(f"{prompt}")
    return dict(res)

def get_doc_clusters_from_db(db: str, th: float) -> dict[int, str]:
    doc_matrix, doc_prompts = compute_doc_matrix(db)
    doc_linkage = compute_doc_linkage(doc_matrix)
    return get_doc_clusters(doc_linkage, doc_prompts, th=th)

In [None]:
sm_doc_matrix, sm_doc_prompts = compute_doc_matrix("sm")
sm_doc_matrix

In [None]:
sm_doc_matrix.shape, sm_doc_matrix.round(10)

In [None]:
sm_doc_linkage = compute_doc_linkage(sm_doc_matrix)

In [None]:
sm_doc_clusters = get_doc_clusters(sm_doc_linkage, sm_doc_prompts, th=0.03)
len(sm_doc_clusters)

In [None]:
sns.histplot([len(c) for c in sm_doc_clusters.values()], discrete=True)

In [None]:
with open("sm_doc_clusters.json", "w") as fout:
    print(json.dumps(sm_doc_clusters, indent=2, sort_keys=True), file=fout)

In [None]:
ap_doc_clusters = get_doc_clusters_from_db("ap", 0.03)
sns.histplot([len(c) for c in ap_doc_clusters.values()], discrete=True)

In [None]:
len(ap_doc_clusters)

In [None]:
with open("ap_doc_clusters.json", "w") as fout:
    print(json.dumps(ap_doc_clusters, indent=2, sort_keys=True), file=fout)

In [None]:
sns.histplot([len(clean(txt)) for txt in LOOKUP_DF["text"]], discrete=True)

In [None]:
ix = np.argmax([len(clean(txt)) for txt in LOOKUP_DF["text"]])
ix, DF.iloc[ix], len(LOOKUP_DF.iloc[ix]["text"])