In [None]:
from terms.model.data import TermsDataModule
from terms.model.metrics import get_metrics
from terms.model.train import TermsTrainer
from terms.model.module import TermsModule

from terms.config import BaseLoraConfig
from terms.preprocess import preprocess, subsample
from terms.schemas import TermsDataModel
from terms.constants import COL_CLASSES, COL_LABELS, COL_EXAMINER_DECISION

import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
PRETRAINED_MODEL = "BAAI/bge-small-en-v1.5"
# Output directory derived from the model id ("/" is not filesystem-safe, so it becomes "_").
# Single quotes inside the f-string keep this valid on Python < 3.12 (nested same-quote
# f-strings are only allowed from 3.12, PEP 701).
MODEL_DIR = f"./model_{PRETRAINED_MODEL.replace('/', '_')}"

In [3]:
# Source files: the Nice alphabetical term list (CSV) and the terms database (Parquet).
FILENAME_NICE = "../data/alphabetical_list.csv"
FILENAME_DB = "../data/en_all_list.parquet"

df_nice, df_db = pd.read_csv(FILENAME_NICE), pd.read_parquet(FILENAME_DB)

In [None]:
# Keep English rows, rename columns to the TermsDataModel field names, then drop every
# column the schema does not declare (surviving columns keep their original order).
schema_fields = TermsDataModel.__annotations__
df_nice = df_nice.loc[df_nice["language"] == "en"].rename(
    columns={"class_number": "NiceClass", "term": "Terms"}
)
list_cols = [col for col in df_nice.columns if col in schema_fields]
df_nice = df_nice[list_cols]

In [None]:
# Rename the database columns to schema names, then keep exactly the schema's
# columns in schema-declaration order (a missing column raises KeyError).
list_cols = list(TermsDataModel.__annotations__)
map_cols = {
    COL_CLASSES: TermsDataModel.NiceClass,
    COL_EXAMINER_DECISION: TermsDataModel.Terms,
    "ID": TermsDataModel.Id,
}
df_db = df_db.rename(columns=map_cols)[list_cols]

In [None]:
# Clean both sources, then build the training sample from the database, using the
# Nice list as the complementary source (presumably to top up classes below the threshold).
df_nice_pre = preprocess(dataframe=df_nice, remove_duplicate_terms=True)
df_db_pre = preprocess(dataframe=df_db, remove_duplicate_terms=True)
# Previously `data_complementary=df_db_pre` passed the base frame twice, leaving
# df_nice_pre computed but unused.
df = subsample(data_base=df_db_pre, data_complementary=df_nice_pre, threshold_per_class=80)
df.to_parquet("data_small.parquet")  # saved before label encoding: raw Nice classes


# Encode Nice classes as contiguous ids 0..K-1, in order of first appearance in `df`,
# as required by the classification head. (`idx` avoids shadowing the builtin `id`.)
map_id_to_nice = {idx: nice for idx, nice in enumerate(df[TermsDataModel.NiceClass].unique().tolist())}
map_nice_to_id = {nice: idx for idx, nice in map_id_to_nice.items()}

df[TermsDataModel.NiceClass] = df[TermsDataModel.NiceClass].map(map_nice_to_id)

In [None]:
from sklearn.model_selection import train_test_split

# Stratified 70 / 15 / 15 train / test / val split, preserving class proportions.
df_train, df_temp = train_test_split(
    df,
    test_size=0.3,
    shuffle=True,
    stratify=df[TermsDataModel.NiceClass],
    random_state=42,
)

df_test, df_val = train_test_split(
    df_temp,
    test_size=0.5,
    shuffle=True,
    stratify=df_temp[TermsDataModel.NiceClass],
    random_state=42,
)

# Size the classifier head from the full label encoding (ids 0..K-1), not from the
# classes that happen to land in df_train: a class missing from train would otherwise
# make the head smaller than the largest label id.
num_classes = len(map_id_to_nice)

In [None]:
# Pad/truncate every term to a fixed 100-token window so all batches share one shape.
tokenizer_kwargs = {
    "padding": "max_length",
    "truncation": True,
    "max_length": 100,
}

pl_datamodule = TermsDataModule(
    df_train=df_train,
    df_val=df_val,
    df_test=df_test,
    pretrained_model_name=PRETRAINED_MODEL,
    model_dir=MODEL_DIR,
    tokenizer_kwargs=tokenizer_kwargs,
)

In [None]:
# Metric collection for the Lightning module (presumably top-1 and top-3 variants).
metrics = get_metrics(top_k=[1, 3], num_classes=num_classes)

In [None]:
import torch
from peft import LoraConfig, TaskType
from transformers import BitsAndBytesConfig

# LoRA for sequence classification: rank-16 adapters on all linear layers, with the
# `classifier` module listed in modules_to_save so PEFT trains and saves it too.
lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules="all-linear",
    modules_to_save=["classifier"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Optional 4-bit NF4 loading; uncomment here and in the call below to enable.
# quantisation_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )


pl_model = TermsModule.from_peft_config(
    pretrained_model_name=PRETRAINED_MODEL,
    num_classes=num_classes,
    lora_config=lora_cfg,
    metrics=metrics,
    # quantisation_config=quantisation_config,
)

# Inspect the pad token id reported by the wrapped model's config
# (the tokenizer pads every sequence to max_length).
pad_token_id = pl_model.model.base_model.config.pad_token_id
print("pad_token_id:", pad_token_id)

In [None]:
# Write training outputs under the same MODEL_DIR the datamodule uses. Reusing the
# constant avoids duplicating the path logic, and the nested-quote f-string that
# only parses on Python >= 3.12.
trainer = TermsTrainer(
    pl_datamodule=pl_datamodule,
    pl_model=pl_model,
    max_epochs=100,
    model_dir=MODEL_DIR,
    precision="bf16-mixed",
)

In [None]:
# Run training through the project's TermsTrainer wrapper.
trainer.train()

In [None]:
# Evaluate the trained model; presumably runs on pl_datamodule's test split (df_test) — TODO confirm.
trainer.test()

In [None]:
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification
from peft import PeftModel

# Rebuild the fine-tuned classifier from the saved LoRA adapter for a manual evaluation.
# Model name, class count and adapter path come from the config/data cells rather than
# being hard-coded, so this cell stays in sync with the training run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cfg = AutoConfig.from_pretrained(PRETRAINED_MODEL)
cfg.num_labels = num_classes  # head size must match the classifier saved with the adapter

base = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL, config=cfg)

model = PeftModel.from_pretrained(
    model=base,
    model_id=f"{MODEL_DIR}/checkpoints/peft_adapter",
)
# Assign instead of leaving a bare expression, so the cell does not dump the model repr.
model = model.eval().to(device)


In [None]:
import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

# Top-1 evaluation of the reloaded PEFT model on the test split: accuracy, precision and
# recall under macro / micro / weighted averaging. Named `eval_metrics` so it does not
# clobber the `metrics` collection used for training.
METRIC_FAMILIES = {"acc": Accuracy, "prec": Precision, "rec": Recall}
AVERAGES = ("macro", "micro", "weighted")

eval_metrics = MetricCollection({
    f"{prefix}_{average}": metric_cls(
        task="multiclass", num_classes=num_classes, average=average, top_k=1
    )
    for prefix, metric_cls in METRIC_FAMILIES.items()
    for average in AVERAGES
}).to(device)

model.eval()
with torch.no_grad():
    for batch in pl_datamodule.test_dataloader():
        # Batches are dicts: "labels" is the target, every other entry is a model input.
        labels = batch["labels"].to(device)
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        probs = model(**inputs).logits.softmax(dim=-1)
        eval_metrics.update(probs, labels)

eval_metrics.compute()


In [None]:
from peft import PeftModel
from transformers import AutoModelForSequenceClassification

# Reload a saved Lightning checkpoint as a TermsModule. Model name, class count and
# paths come from the config/data cells instead of being hard-coded.

# (a) rebuild the PEFT model that belongs to this checkpoint
base_model = AutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL,
    num_labels=num_classes,
)
checkpoint_dir = f"{MODEL_DIR}/checkpoints"
adapter_dir = f"{checkpoint_dir}/peft_adapter"
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)

# (b) metrics collection (must match what was used in training)
metrics = get_metrics(num_classes=num_classes, top_k=[1, 3])

# (c) load the LightningModule
# TODO: epoch/val_loss in this filename belong to one specific run; update after retraining.
best_ckpt = (
    f"{checkpoint_dir}/"
    f"{PRETRAINED_MODEL.replace('/', '_')}_epoch=15_val_loss=1.27.ckpt"
)

pl_module = TermsModule.load_from_checkpoint(
    best_ckpt,
    model=peft_model,  # presumably excluded from saved hparams (`ignore=...`), so passed explicitly
    metrics=metrics,   # same assumption as above — confirm against TermsModule.save_hyperparameters
    strict=True,       # require checkpoint keys to match the module's state_dict exactly
)
