In [None]:
from transformers import RobertaForSequenceClassification, RobertaConfig, RobertaTokenizer

from src.model_compression.training_utils.datasets import (processors,output_modes, 
                                                           load_and_cache_examples)
from  src.model_compression.training_utils.modules import RobertaForSpanClassification

In [0]:
# Run configuration -- fill these in before executing the notebook.

# Task and data locations
task_name = ""    # key into `processors`; lower-cased in a later cell
data_dir = ""     # passed to load_and_cache_examples and get_answers
output_dir = ""   # directory that write_preds writes test predictions into

# Backbone, checkpoint and tokenizer
model_type = ""          # key into MODEL_CLASSES (e.g. "roberta")
model_checkpoint = ""    # passed to from_pretrained for the config and model
tokenizer_name = ""      # passed to the tokenizer class's from_pretrained
do_lower_case = False    # forwarded to the tokenizer as `do_lower_case`

In [0]:
# Registry of supported backbones:
#   model_type -> (config class, tokenizer class, {output_mode: model class})
MODEL_CLASSES = dict(
    roberta=(
        RobertaConfig,
        RobertaTokenizer,
        dict(
            classification=RobertaForSequenceClassification,
            span_classification=RobertaForSpanClassification,
        ),
    ),
)

# Span count per span-classification task; copied into `config.num_spans`
# when the model is built.
tasks_num_spans = dict(wic=2, wsc=2)

In [0]:
# Prepare task: normalise the task name, then look up its data processor,
# output mode (selects the model head below) and label set.
task_name = task_name.lower()
assert task_name in processors, f"Task {task_name} not found!"

processor, output_mode = processors[task_name](), output_modes[task_name]

label_list = processor.get_labels()
num_labels = len(label_list)

In [0]:
# Build the config, tokenizer and model for the selected backbone and task head.
model_type = model_type.lower()
assert model_type in MODEL_CLASSES, (
    f"Model type {model_type} not found! Available: {sorted(MODEL_CLASSES)}"
)
config_class, tokenizer_class, model_classes = MODEL_CLASSES[model_type]

# The output mode of the task selects which head class to instantiate.
assert output_mode in model_classes, (
    f"Output mode {output_mode} is not supported for model type {model_type}!"
)
model_class = model_classes[output_mode]

config = config_class.from_pretrained(
    model_checkpoint,
    num_labels=num_labels,
    finetuning_task=task_name,
)
if output_mode == "span_classification":
    # NOTE(review): presumably read by RobertaForSpanClassification to size its
    # span inputs -- TODO confirm in training_utils.modules.
    assert task_name in tasks_num_spans, (
        f"No span count registered in tasks_num_spans for task {task_name}!"
    )
    config.num_spans = tasks_num_spans[task_name]

# An empty `tokenizer_name` falls back to the model checkpoint, following the
# `tokenizer_name or model_name_or_path` convention of the transformers
# example scripts. Previously an empty name made from_pretrained fail.
tokenizer = tokenizer_class.from_pretrained(
    tokenizer_name or model_checkpoint,
    do_lower_case=do_lower_case,
)
model = model_class.from_pretrained(model_checkpoint, config=config)

# Moves the model onto the default CUDA device; a GPU is required from here on.
model.cuda()

In [0]:
# Load the training examples for the configured task.
# NOTE(review): no split argument is passed, so this presumably returns the
# train split by default -- TODO confirm against load_and_cache_examples.
train_dataset = load_and_cache_examples(
    task_name,
    tokenizer,
    data_dir,
)

In [ ]:
# Fine-tune the model on the training dataset.
# NOTE(review): `train` is not imported or defined in any cell visible here, so
# this raises NameError on a fresh kernel -- TODO import it explicitly.
(global_step, tr_loss) = train(train_dataset, model, tokenizer)

In [ ]:
# Evaluate on the dev split, then predict on the test split and write the
# predictions into `output_dir`.
# NOTE(review): `evaluate` is not imported or defined in any cell visible here;
# TODO import it so the notebook survives Restart & Run All.

# Dev evaluation. No `split` argument is passed, so this presumably uses
# evaluate's default split -- TODO confirm. The metrics are kept under their
# own name so the test loop below no longer overwrites them before they are
# reported.
dev_result, _, _ = evaluate(task_name, model, tokenizer, prefix="", use_tqdm=False)
dev_result = {str(k): v for k, v in dev_result.items()}
print(f"Dev results: {dev_result}")

# Tasks whose test predictions are written. Extra diagnostic tasks can be
# appended (e.g. "ax-b" and "ax-g" alongside "rte").
eval_task_names = (task_name,)
for eval_task_name in eval_task_names:
    result, preds, ex_ids = evaluate(
        eval_task_name, model, tokenizer, split="test", prefix="", use_tqdm=False
    )
    # Separate name so the training task's `processor` is not clobbered.
    eval_processor = processors[eval_task_name]()
    # Check the task being written (not the training task); the "record"
    # writer additionally takes the answers loaded by get_answers.
    if eval_task_name == "record":
        answers = eval_processor.get_answers(data_dir, "test")
        eval_processor.write_preds(preds, ex_ids, output_dir, answers=answers)
    else:
        eval_processor.write_preds(preds, ex_ids, output_dir)