In [2]:
%load_ext autoreload
%autoreload 2

from model import SiglipStyleModel, ColSentenceModel
from utils import get_train_and_test_data

# Registry of the retriever architectures this notebook can train, keyed by the short
# name that also appears in the run path.
models = {
    "ColSent": ColSentenceModel,
    "Siglip": SiglipStyleModel,
}

# --- Run hyperparameters -------------------------------------------------------------
loss = "clip"             # loss_type handed to the model; the other option tried here is "siglip"
architecture = "ColSent"  # key into `models`
batch_size = 128          # per-device train batch size
epochs = 10               # an earlier variant used 2 * batch_size
lr = 1e-6
eval_batch = 250          # per-device eval batch size

# Warm-start weights. The path follows the same naming scheme as `model_path` further down,
# i.e. presumably a b64 / lr 1E-06 ColSent run on ms_marco v2.1 — confirm before relying on it.
init_checkpoint = "../../clip/ColSent/bert-mini/b64_lr1E-06_microsoft/ms_marcov2.1/model.safetensors"

model_cls = models[architecture]
model = model_cls(loss_type=loss)
model.load(init_checkpoint)
# Scoring switch defined in model.py; its exact effect is not visible from this notebook.
model.use_max_sim = False

# (Hugging Face dataset repo, config/version) pairs consumed by utils.get_train_and_test_data.
# The cell output shows test_data as a dict keyed by f"{repo}{version}" (one eval split per
# entry), while train_data is a single Dataset (82,326 rows here). That suggests training
# draws only from the first entry — TODO confirm in utils.
data_paths = [("microsoft/ms_marco", version) for version in ("v1.1", "v2.1")]
train_data, test_data = get_train_and_test_data(data_paths)

def _lr_tag(value):
    """Return the shortest upper-case scientific spelling of ``value`` that round-trips.

    Matches the historical ``f"{lr:.0E}"`` spelling (e.g. ``1E-06``) whenever that is exact,
    so existing run names are unchanged. When ``.0E`` would round (1.5e-6 -> ``2E-06``),
    mantissa digits are added instead (``1.5E-06``). This keeps distinct learning rates from
    colliding on the same output directory / wandb run name.
    """
    for digits in range(17):  # 17 significant digits always round-trip a finite double
        text = f"{value:.{digits}E}"
        if float(text) == value:
            return text
    return repr(value)  # only reached for NaN


# Values equal to 1e-7 / 2 get no tag in the run name (presumably the original baseline
# settings); anything else is spelled out.
lr_n = "" if lr == 1e-7 else f"lr{_lr_tag(lr)}_"
b_n = "" if batch_size == 2 else f"b{batch_size}_"

model_name = model.model_name.split("/")[-1]
# Only data_paths[0] names the run. train_data prints as a single Dataset while test_data holds
# every entry, which suggests training uses the first entry only — TODO confirm in utils.
# NOTE(review): the warm-start checkpoint passed to model.load() is not part of this name, so
# this run shares an output directory with a from-scratch run using the same settings.
model_path = f"{loss}/{architecture}/{model_name}/{b_n}{lr_n}{data_paths[0][0]}{data_paths[0][1]}"
print(model_path)
print(train_data)
print(test_data)
print(test_data)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[nltk_data] Downloading package punkt_tab to /home/jan-
[nltk_data]     malte/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


clip/ColSent/bert-mini/b128_lr1E-06_microsoft/ms_marcov1.1
Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 82326
})
{'microsoft/ms_marcov1.1': Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 5000
}), 'microsoft/ms_marcov2.1': Dataset({
    features: ['answers', 'passages', 'query', 'query_id', 'query_type', 'wellFormedAnswers'],
    num_rows: 5000
})}


In [3]:
import os
from transformers import Trainer, TrainingArguments
from callbacks import NotebookProgressCallbackNoTable
from transformers.utils.notebook import NotebookProgressCallback
from evaluation import compute_metrics
import wandb
from utils import collate_fn
from transformers.training_args import OptimizerNames

# No project is passed to wandb.init, so it is taken from WANDB_PROJECT. WANDB_LOG_MODEL="false"
# is meant to stop the Trainer's wandb integration from uploading model checkpoints.
os.environ.update(WANDB_PROJECT="MSE", WANDB_LOG_MODEL="false")
wandb.init(name=model_path, entity="mse-jan-simon")

training_args = TrainingArguments(
    # Output / data handling. Checkpoints land under a directory that mirrors the wandb run name.
    output_dir="models/" + model_path,
    remove_unused_columns=False,  # keep raw dataset columns for the custom collate_fn
    # Batching and precision
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=eval_batch,
    bf16=True,
    # Optimisation: 8-bit AdamW, 500 warmup steps, then a constant learning rate
    optim=OptimizerNames.ADAMW_8BIT,
    learning_rate=lr,
    lr_scheduler_type='constant_with_warmup',
    warmup_steps=500,
    num_train_epochs=epochs,  # alternatively cap the run with max_steps (e.g. 2000)
    # Logging / evaluation cadence in optimizer steps, plus one evaluation before training starts
    logging_steps=100,
    eval_strategy="steps",
    eval_steps=200,
    eval_on_start=True,
    report_to='wandb',
    # Checkpointing: save every 1000 steps, keep at most one checkpoint on disk
    save_steps=1000,
    save_total_limit=1,
)

from transformers.trainer_utils import get_last_checkpoint

# Set to True to train. With False (the current setting) this cell only evaluates the
# warm-started model on every test split.
RUN_TRAINING = False

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,  # dict of splits -> metrics come back as "eval_<split name>_<metric>"
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

# Replace the stock notebook progress callback with the project's variant from callbacks.py
# (presumably the same progress bar without the per-evaluation metrics table).
trainer.remove_callback(NotebookProgressCallback)
trainer.add_callback(NotebookProgressCallbackNoTable)

if RUN_TRAINING:
    # Resume from the newest checkpoint in output_dir if one exists, otherwise start fresh.
    # This is an explicit check rather than try/bare-except around
    # train(resume_from_checkpoint=True). A bare except also catches OOMs, KeyboardInterrupt
    # and genuine bugs, and would silently restart training from scratch.
    output_dir = training_args.output_dir
    last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
    trainer.train(resume_from_checkpoint=last_checkpoint)

trainer.evaluate()

[34m[1mwandb[0m: Currently logged in as: [33mjanmaltegiannikos[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




{'eval_microsoft/ms_marcov1.1_loss': 5.336982250213623,
 'eval_microsoft/ms_marcov1.1_model_preparation_time': 0.0016,
 'eval_microsoft/ms_marcov1.1_recall@1': 0.1682,
 'eval_microsoft/ms_marcov1.1_recall@2': 0.2708,
 'eval_microsoft/ms_marcov1.1_recall@4': 0.3898,
 'eval_microsoft/ms_marcov1.1_recall@8': 0.5294,
 'eval_microsoft/ms_marcov1.1_recall@16': 0.6676,
 'eval_microsoft/ms_marcov1.1_recall@32': 0.8108,
 'eval_microsoft/ms_marcov1.1_mean_rank': 18.5364,
 'eval_microsoft/ms_marcov1.1_median_rank': 6.0,
 'eval_microsoft/ms_marcov1.1_mean_rank_norm': 0.0741456,
 'eval_microsoft/ms_marcov1.1_median_rank_norm': 0.024,
 'eval_microsoft/ms_marcov1.1_min_rank': 213,
 'eval_microsoft/ms_marcov1.1_min_rank_norm': 0.852,
 'eval_microsoft/ms_marcov1.1_recall@1%': 0.3428,
 'eval_microsoft/ms_marcov1.1_recall@2%': 0.4322,
 'eval_microsoft/ms_marcov1.1_recall@5%': 0.6282,
 'eval_microsoft/ms_marcov1.1_recall@10%': 0.7646,
 'eval_microsoft/ms_marcov1.1_recall@25%': 0.924,
 'eval_microsoft/ms_m