In [1]:
from terms.model.data import TermsDataModule
from terms.model.metrics import get_metrics
from terms.model.train import TermsTrainer
from terms.model.module import TermsModule

from terms.config import BaseLoraConfig
from terms.preprocess import preprocess, subsample
from terms.schemas import TermsDataModel
from terms.constants import COL_CLASSES, COL_LABELS, COL_EXAMINER_DECISION

import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face model id of the encoder that is fine-tuned in this notebook.
PRETRAINED_MODEL = "BAAI/bge-small-en-v1.5"
# Directory passed as `model_dir` to the data module and trainer, derived from the model id
# ("/" is replaced because it is not valid inside a single path component).
# Single quotes inside the f-string keep this parseable on Python < 3.12: re-using the enclosing
# quote character inside a replacement field is only legal from 3.12 on (PEP 701).
MODEL_DIR = f"./model_{PRETRAINED_MODEL.replace('/', '_')}"

In [3]:
# Raw input files, relative to the notebook's working directory.
FILENAME_NICE = "../data/alphabetical_list.csv"
FILENAME_DB = "../data/en_all_list.parquet"

# One source is a CSV, the other a Parquet export; load each into its own frame.
df_nice = pd.read_csv(filepath_or_buffer=FILENAME_NICE)
df_db = pd.read_parquet(path=FILENAME_DB)

In [4]:
# Keep English rows only and rename the source columns to the schema's field names.
df_nice = (
    df_nice.loc[df_nice["language"] == "en"]
    .rename(columns={"class_number": "NiceClass", "term": "Terms"})
)

# Drop every column that TermsDataModel does not declare (frame column order is preserved).
# NOTE(review): not idempotent — "language" is dropped here, so re-running this cell without
# re-running the load cell raises a KeyError.
list_cols = [name for name in df_nice.columns if name in TermsDataModel.__annotations__]
df_nice = df_nice[list_cols]

In [5]:
# Target column set: every field annotated on TermsDataModel, in annotation order.
list_cols = list(TermsDataModel.__annotations__.keys())

# Source column name -> schema field for the database export.
map_cols = {
    COL_CLASSES: TermsDataModel.NiceClass,
    COL_EXAMINER_DECISION: TermsDataModel.Terms,
    "ID": TermsDataModel.Id,
}
df_db = df_db.rename(columns=map_cols)[list_cols]

In [6]:
# Run the project preprocessing on both sources (with duplicate terms removed).
df_nice_pre = preprocess(data=df_nice, remove_duplicate_terms=True)
df_db_pre = preprocess(data=df_db, remove_duplicate_terms=True)

# Build the working sample from the DB data, drawing on the Nice list as the complementary source.
# Fix: this previously passed `df_db_pre` for both arguments, which made the complementary source
# identical to the base and left `df_nice_pre` computed but unused.
# NOTE(review): `threshold_per_class` presumably bounds rows per class — confirm against `subsample`.
df = subsample(data_base=df_db_pre, data_complementary=df_nice_pre, threshold_per_class=80)
df.to_parquet("data_small.parquet")

# Contiguous integer ids 0..K-1 for the Nice classes, in order of first appearance in `df`
# (`Series.unique` preserves appearance order).
map_id_to_nice = dict(enumerate(df[TermsDataModel.NiceClass].unique().tolist()))
map_nice_to_id = {nice: idx for idx, nice in map_id_to_nice.items()}

# Replace class labels with their integer ids for training.
df[TermsDataModel.NiceClass] = df[TermsDataModel.NiceClass].map(map_nice_to_id)

In [7]:
from sklearn.model_selection import train_test_split

# 70 / 15 / 15 stratified split: hold out 30 % first, then halve that hold-out into test and val.
df_train, df_temp = train_test_split(
    df,
    test_size=0.3,
    shuffle=True,
    stratify=df[TermsDataModel.NiceClass],
    random_state=42,
)

df_test, df_val = train_test_split(
    df_temp,
    test_size=0.5,
    shuffle=True,
    stratify=df_temp[TermsDataModel.NiceClass],
    random_state=42,
)
assert len(df_train) + len(df_val) + len(df_test) == len(df), "splits must partition df"

# Size the classifier head from the label mapping: ids are 0..K-1 by construction, so this is
# correct even if some class were missing from the training split (counting only the classes
# seen in df_train would then under-size the head and break on the larger ids).
num_classes = len(map_nice_to_id)

In [8]:
# Tokenizer options: pad and truncate every sample to exactly 100 tokens.
tokenizer_kwargs = {
    "padding": "max_length",
    "truncation": True,
    "max_length": 100,
}

# Lightning data module wrapping the three splits and the tokenizer of the pretrained model.
pl_datamodule = TermsDataModule(
    pretrained_model_name=PRETRAINED_MODEL,
    tokenizer_kwargs=tokenizer_kwargs,
    model_dir=MODEL_DIR,
    df_train=df_train,
    df_val=df_val,
    df_test=df_test,
)

2025-07-13 16:26:37,468 - [94mInitialising TermsDataModule …[0m
2025-07-13 16:26:37,469 - [94mBatch size: 32, workers: 14, persistent: True[0m
2025-07-13 16:26:37,470 - [94mTokenizer kwargs: {'padding': 'max_length', 'truncation': True, 'max_length': 100}[0m
2025-07-13 16:26:37,470 - [94mLoading tokenizer: BAAI/bge-small-en-v1.5[0m


In [9]:
# Metric collection handed to the LightningModule, requested for top_k = 1 and 3.
metrics = get_metrics(top_k=[1, 3], num_classes=num_classes)

In [11]:
import torch
from peft import LoraConfig, TaskType
from transformers import BitsAndBytesConfig

# LoRA on all linear layers (PEFT's "all-linear" shortcut). The classification head, which is
# newly initialised on top of the checkpoint, is listed in `modules_to_save` so PEFT trains and
# saves it in full.
lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules="all-linear",
    modules_to_save=["classifier"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Optional 4-bit NF4 quantisation via bitsandbytes. Currently disabled: to enable it, uncomment
# the block and pass `quantisation_config` instead of `None` as `quantization_config` below.
# quantisation_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

pl_model = TermsModule.from_peft_config(
    pretrained_model_name=PRETRAINED_MODEL,
    num_classes=num_classes,
    metrics=metrics,
    lora_config=lora_cfg,
    quantization_config=None,
)

# Quick sanity check of the base model's padding token id.
print("pad_token_id:", pl_model.model.base_model.config.pad_token_id)

2025-07-13 16:27:14,927 - [94mLoading base model 'BAAI/bge-small-en-v1.5' with 45 classes.[0m
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at BAAI/bge-small-en-v1.5 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-07-13 16:27:15,218 - [94mPeft model created successfully from config.[0m
2025-07-13 16:27:15,225 - [94mInitialized TermsModule with model: BAAI/bge-small-en-v1.5[0m


pad_token_id: 0


In [12]:
# Reuse MODEL_DIR from the config cell instead of re-deriving it inline. The inline copy
# duplicated that logic (a divergence risk) and nested double quotes inside the f-string,
# which only parses on Python >= 3.12.
trainer = TermsTrainer(
    pl_datamodule=pl_datamodule,
    pl_model=pl_model,
    max_epochs=100,
    model_dir=MODEL_DIR,
    precision="bf16-mixed",
)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [13]:
# Fit the LoRA model (up to max_epochs=100, as configured on the trainer).
# NOTE(review): the recorded output shows this run was stopped manually (KeyboardInterrupt
# during "Epoch 14"), after which Lightning raised "Please call `iter(combined_loader)` first."
# The checkpoint cell further down hard-codes an `epoch=15` file, which this interrupted run
# would not appear to have produced — confirm which run that checkpoint came from.
trainer.train()

2025-07-13 16:27:21,948 - [94mSetting the seed value to {seed}[0m
Seed set to 42
2025-07-13 16:27:21,953 - [94mStart training the model...[0m
2025-07-13 16:27:21,954 - [94mLabel mapping: {np.int64(0): 0, np.int64(1): 1, np.int64(2): 2, np.int64(3): 3, np.int64(4): 4, np.int64(5): 5, np.int64(6): 6, np.int64(7): 7, np.int64(8): 8, np.int64(9): 9, np.int64(10): 10, np.int64(11): 11, np.int64(12): 12, np.int64(13): 13, np.int64(14): 14, np.int64(15): 15, np.int64(16): 16, np.int64(17): 17, np.int64(18): 18, np.int64(19): 19, np.int64(20): 20, np.int64(21): 21, np.int64(22): 22, np.int64(23): 23, np.int64(24): 24, np.int64(25): 25, np.int64(26): 26, np.int64(27): 27, np.int64(28): 28, np.int64(29): 29, np.int64(30): 30, np.int64(31): 31, np.int64(32): 32, np.int64(33): 33, np.int64(34): 34, np.int64(35): 35, np.int64(36): 36, np.int64(37): 37, np.int64(38): 38, np.int64(39): 39, np.int64(40): 40, np.int64(41): 41, np.int64(42): 42, np.int64(43): 43, np.int64(44): 44}[0m
2025-07-13 16

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

2025-07-13 16:27:22,180 - [94mCreating validation DataLoader …[0m


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



                                                                           

2025-07-13 16:27:45,402 - [94mCreating training DataLoader …[0m


Training: |          | 0/? [00:00<?, ?it/s]

2025-07-13 16:28:08,817 - [94mTraining started.[0m


Epoch 14: 100%|██████████| 79/79 [00:08<00:00,  9.25it/s, v_num=d54d]



Detected KeyboardInterrupt, attempting graceful shutdown ...


RuntimeError: Please call `iter(combined_loader)` first.

In [None]:
# Evaluate via the project trainer, presumably on the df_test split given to the data module.
# NOTE(review): whether this uses the best checkpoint or the in-memory weights left by the
# interrupted fit is decided inside TermsTrainer.test — confirm. This cell has no recorded output.
trainer.test()

: 

In [None]:
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification
from peft import PeftModel

# Pick the best available accelerator. Training above ran on Apple's MPS backend
# ("GPU available: True (mps)"), so check MPS too instead of silently falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Rebuild the base classifier with the same model id and label count used for training
# (previously hard-coded as "BAAI/bge-small-en-v1.5" and 45), then attach the saved LoRA adapter.
cfg = AutoConfig.from_pretrained(PRETRAINED_MODEL)
cfg.num_labels = num_classes

base = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL, config=cfg)

model = PeftModel.from_pretrained(
    model=base,
    model_id=f"{MODEL_DIR}/checkpoints/peft_adapter",
)
# The trailing ';' suppresses the very long module repr in the cell output.
model.eval().to(device);


: 

In [None]:
import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

# Top-1 accuracy / precision / recall under each averaging scheme. The metrics are built in a
# comprehension sized from the training label count instead of nine copy-pasted lines with a
# hard-coded 45. Named `eval_metrics` so the `metrics` collection passed to the LightningModule
# is not shadowed; re-running the model cell would otherwise silently pick up this collection.
eval_metrics = MetricCollection({
    f"{prefix}_{average}": metric_cls(
        task="multiclass", num_classes=num_classes, average=average, top_k=1
    )
    for prefix, metric_cls in (("acc", Accuracy), ("prec", Precision), ("rec", Recall))
    for average in ("macro", "micro", "weighted")
}).to(device)

model.eval()
with torch.no_grad():
    for batch in pl_datamodule.test_dataloader():
        # NOTE(review): assumes each batch is a dict of tensors with a "labels" key — confirm
        # against TermsDataModule's collate output.
        labels = batch["labels"].to(device)
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        probs = model(**inputs).logits.softmax(dim=-1)
        eval_metrics.update(probs, labels)

# Plain floats render more readably than a dict of 0-d tensors.
{name: value.item() for name, value in eval_metrics.compute().items()}


: 

In [None]:
from peft import PeftModel
from transformers import AutoModelForSequenceClassification

# (a) Rebuild the PEFT model that belongs to this checkpoint. The base model, label count and
#     adapter dir come from the config constants instead of being hard-coded.
base_model = AutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL,
    num_labels=num_classes,
)
adapter_dir = f"{MODEL_DIR}/checkpoints/peft_adapter"
peft_model = PeftModel.from_pretrained(base_model, adapter_dir)

# (b) Metrics collection (must match what was used in training).
metrics = get_metrics(num_classes=num_classes, top_k=[1, 3])

# (c) Load the LightningModule. The file name is specific to one training run; update the
#     epoch / val_loss part to the checkpoint you actually want to evaluate.
best_ckpt = (
    f"{MODEL_DIR}/checkpoints/"
    f"{PRETRAINED_MODEL.replace('/', '_')}_epoch=15_val_loss=1.27.ckpt"
)

pl_module = TermsModule.load_from_checkpoint(
    best_ckpt,
    model=peft_model,  # constructor args not stored in the checkpoint's hyperparameters
    metrics=metrics,   # (presumably those excluded via `ignore=...` in TermsModule — confirm)
    strict=True,       # require an exact state-dict key match
)


: 

: 

: 