In [1]:
import wandb
# Initialise wandb in "disabled" mode before the rest of the notebook runs.
# NOTE(review): presumably this stops the Trainer below from reporting the run
# to Weights & Biases — confirm against get_training_args().
wandb.init(mode="disabled")

from transformers import AutoTokenizer, DataCollatorWithPadding, Trainer
from src.config import NUM_TYPES, NUM_MANIFESTATIONS, MODEL_NAMES
from src.data import load_data, prepare_datasets
from src.model import SharedMTLModel
from src.metrics import compute_metrics
from src.training import compute_pos_weights, get_training_args, get_early_stopping_callback
from src.predict import predict_dev_set
from src.logging_utils import log_experiment_results

# Run configuration shared by the cells below.
# `lang` is passed to load_data, log_experiment_results and predict_dev_set.
lang = "eng"
# `trial_id` is passed to get_training_args, log_experiment_results and predict_dev_set.
trial_id = "MTL_10epochs_full_soft_gating"
# Last entry of the project's MODEL_NAMES sequence; used for the tokenizer and the model.
model_name = MODEL_NAMES[-1]

In [None]:
# Load the three training frames for `lang` (one per value returned by load_data).
train_1, train_2, train_3 = load_data(lang)

# Load the tokenizer from the local Hugging Face cache when it is available.
# Forcing a fresh download on every run makes "Restart & Run All" depend on
# network access and slows it down; pass force_download=True here only as a
# one-off to repair a corrupted cache.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the train/validation datasets from the three frames using this tokenizer.
train_dataset, val_dataset = prepare_datasets(train_1, train_2, train_3, tokenizer)

In [None]:
# Positive-class weights are computed over every column from index 2 onward.
# NOTE(review): presumably the first two columns are non-label fields (id/text)
# — confirm against load_data().
label_cols_2 = train_2.columns[2:]
pos_weight_2 = compute_pos_weights(train_2, label_cols_2)

label_cols_3 = train_3.columns[2:]
pos_weight_3 = compute_pos_weights(train_3, label_cols_3)

# Instantiate the shared multi-task model with both weight sets.
model = SharedMTLModel(
    model_name,
    NUM_TYPES,
    NUM_MANIFESTATIONS,
    pos_weight_2,
    pos_weight_3,
)

In [None]:
# Assemble the Trainer: arguments keyed by trial id, a padding collator built
# from the tokenizer, and the project's early-stopping callback.
training_args = get_training_args(trial_id)
padding_collator = DataCollatorWithPadding(tokenizer)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=padding_collator,
    compute_metrics=compute_metrics,
    callbacks=[get_early_stopping_callback()],
)

# Fit the model, then evaluate once on the validation set.
trainer.train()
eval_results = trainer.evaluate()

# Report the macro-F1 of each subtask; every entry after the header starts with
# a newline so each score is printed on its own line.
validation_report = (
    "Validation Results:",
    f"\nsubtask_1 f1_macro: {eval_results['eval_subtask_1/f1_macro']:.4f}",
    f"\nsubtask_2 f1_macro: {eval_results['eval_subtask_2/f1_macro']:.4f}",
    f"\nsubtask_3 f1_macro: {eval_results['eval_subtask_3/f1_macro']:.4f}",
)
print(*validation_report)

In [None]:
# Hand the evaluation metrics and the run's configuration to the project logger.
# The positional order below is the one log_experiment_results is called with.
experiment_record = (
    eval_results,
    trial_id,
    lang,
    model_name,
    training_args,
    NUM_TYPES,
    NUM_MANIFESTATIONS,
)
log_experiment_results(*experiment_record)

In [None]:
# Run the trained model over the dev sets; predict_dev_set yields three results,
# one per output name below.
dev_outputs = predict_dev_set(trainer, tokenizer, lang, trial_id)
output_1, output_2, output_3 = dev_outputs
print("Predictions saved for all 3 dev sets with Logical Gating applied.")