In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    set_seed,
)
from model.dataset import QAClassifierDataset
from misc.dataset_modifier import get_json
from model.weight import compute_class_weights
from model.trainer import WeightedTrainer

In [None]:
# Load the tokenizer and a 3-way sequence-classification model on top of DeBERTa-v3-base.
model_id = "microsoft/deberta-v3-base"
NUM_LABELS = 3  # size of the classification head; TODO confirm it matches the label set in the training JSON
SEED = 42  # same value as the TrainingArguments default seed the Trainer applies later

# Seed *before* loading: the sequence-classification head does not exist in the
# base checkpoint and is randomly initialised by from_pretrained. The Trainer only
# seeds itself when it is constructed (later, inside train_epoch), so without this
# call the starting head weights differ on every Restart & Run All.
set_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=NUM_LABELS,
)


In [None]:
def train_epoch(
    path,
    epoch_id,
    *,
    output_root="./ckpts",
    batch_size=8,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=50,
    save_final=False,
):
    """Run one training pass over the samples stored at ``path``.

    Trains the notebook-level ``model`` (with the notebook-level ``tokenizer``)
    in place. The model is not reloaded here, so successive calls continue from
    the weights left by the previous call.

    Args:
        path: JSON file loaded with ``get_json``.
        epoch_id: Label used only to name the output directory.
        output_root: Parent of the Trainer's output directory
            ``<output_root>/epoch_<epoch_id>``.
        batch_size: Per-device training batch size.
        learning_rate: Learning rate passed to ``TrainingArguments``.
        weight_decay: Weight decay passed to ``TrainingArguments``.
        logging_steps: Log training loss every this many optimizer steps.
        save_final: If True, write the trained model to the output directory
            after training. Off by default; with ``save_strategy="no"`` nothing
            is persisted otherwise.

    Returns:
        The value returned by ``trainer.train()`` (global step, training loss
        and metrics), so the caller can inspect or display it.
    """
    samples = get_json(path)

    # Class weights are recomputed for each file, because its label mix may differ.
    class_weight = compute_class_weights(samples)

    # NOTE(review): the weights are handed to the dataset, presumably so that
    # WeightedTrainer can read them from each batch — confirm against
    # QAClassifierDataset / WeightedTrainer.
    dataset = QAClassifierDataset(
        samples,
        tokenizer,
        class_weight
    )

    args = TrainingArguments(
        output_dir=f"{output_root}/epoch_{epoch_id}",
        num_train_epochs=1,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        logging_steps=logging_steps,
        save_strategy="no",  # no intermediate checkpoints are written
        report_to="none"
    )

    # A new trainer is built on every call. Assuming WeightedTrainer keeps
    # Trainer's default optimizer setup, the optimizer state and the
    # learning-rate schedule restart each time.
    # NOTE: newer transformers versions deprecate `tokenizer=` in favour of
    # `processing_class=`; it is kept here for compatibility with older releases.
    trainer = WeightedTrainer(
        model=model,
        args=args,
        train_dataset=dataset,
        tokenizer=tokenizer
    )

    result = trainer.train()

    if save_final:
        trainer.save_model()  # writes to args.output_dir

    return result

In [None]:
# Training schedule, run in order on the same in-memory model:
#   epoch 1 — the original curated training set only;
#   epoch 2 — the combined set under ./data/updated (add further
#             (epoch_id, path) pairs below for later epochs).
for epoch_id, train_path in [
    (1, "./data/curated/train.json"),
    (2, "./data/updated/combined/train.json"),
]:
    train_epoch(train_path, epoch_id)
