## Load model and tokenizer

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, XLMRobertaTokenizerFast
import torch

In [2]:
model_checkpoint = "FacebookAI/xlm-roberta-base"
# Number of topic labels; must equal the number of entries in label_to_int
# (15 in the per-class results at the end). TODO: derive it from label_to_int.
num_classes = 15
SEED = 42
# Seed torch before the model is created: the classification head is freshly
# initialized from random weights, so without a seed every run trains a
# different model. torch.manual_seed seeds the RNG on all devices.
torch.manual_seed(SEED)

In [3]:
# Pretrained XLM-R encoder plus a newly initialized classification head sized
# for `num_classes` labels (hence the "newly initialized" warning below).
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=num_classes,
)

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at FacebookAI/xlm-roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Tokenizer loaded from the same checkpoint as the model so vocabularies match.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
)

## Load and split dataset

In [5]:
from utils.loaders import DatasetUA, convert_to_sequences

In [6]:
dataset_ua_full = DatasetUA("DatasetUA").load().shuffle()
dataset_ua = convert_to_sequences(dataset_ua_full).shuffle()

Generating train split: 0 examples [00:00, ? examples/s]

In [7]:
dataset_ua_splitted = dataset_ua.train_test_split(test_size=0.15)

## Create collator and data loaders

In [8]:
from torch.utils.data import DataLoader
from utils.collators import XlmRobertaCollator
from utils.meta import label_to_int

In [9]:
# Batch collator built from the tokenizer and the label mapping;
# presumably it tokenizes texts and maps label names to class ids — confirm
# in utils.collators.
collator = XlmRobertaCollator(
    label_to_int=label_to_int,
    tokenizer=tokenizer,
)

In [10]:
TRAIN_BATCH_SIZE = 16
VAL_BATCH_SIZE = 4
# Reshuffle training batches every epoch; without shuffle=True each epoch
# visits the examples in exactly the same order. Validation order is irrelevant.
train_loader = DataLoader(
    dataset_ua_splitted["train"],
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=True,
    collate_fn=collator,
)
val_loader = DataLoader(
    dataset_ua_splitted["test"],
    batch_size=VAL_BATCH_SIZE,
    collate_fn=collator,
)

## Train

In [11]:
from utils.trainers import Trainer, TrainConfig

In [12]:
# Default training configuration; the optimizer is passed as a class, so
# presumably Trainer instantiates it over the model parameters — confirm in
# utils.trainers.
train_config = TrainConfig()
trainer = Trainer(
    config=train_config,
    model=model,
    optimizer=torch.optim.Adam,
    train_loader=train_loader,
    val_loader=val_loader,
)

In [13]:
# Run fine-tuning. The captured output shows an outer progress bar of 2 steps
# (presumably epochs) with inner bars of 2412 and 1703 steps (presumably train
# and validation batches), followed by transfers of model.pt and a TensorBoard
# events file — presumably uploaded by Trainer; confirm in utils.trainers.
# NOTE(review): the Predict section reuses `model`, which assumes training
# updates it in place — verify.
trainer.train()

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2412 [00:00<?, ?it/s]

  0%|          | 0/1703 [00:00<?, ?it/s]

  0%|          | 0/2412 [00:00<?, ?it/s]

  0%|          | 0/1703 [00:00<?, ?it/s]

model.pt:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

events.out.tfevents.1708863156.zarawindows.11276.0:   0%|          | 0.00/472 [00:00<?, ?B/s]

## Evaluate per-class accuracy

In [16]:
from utils.predictors import predict_class_accuracy

In [38]:
# Score the fine-tuned model per label; the output is a dict mapping each
# label name to a fraction in [0, 1].
# NOTE(review): dataset_ua_full holds every loaded article, including the ones
# whose sequences were used for training, so these numbers are NOT held-out
# estimates. Evaluate on articles that were excluded from training instead.
predict_class_accuracy(dataset_ua_full, model, tokenizer, label_to_int)

  0%|          | 0/1919 [00:00<?, ?it/s]

{'business': 0.3492063492063492,
 'economy': 0.7777777777777778,
 'education': 0.9375,
 'fashion': 0.875,
 'financy': 0.9375,
 'fun': 0.10784313725490197,
 'health': 0.8785714285714286,
 'kino': 0.9285714285714286,
 'porady': 0.9365079365079365,
 'realestate': 0.9017857142857143,
 'show': 0.6492537313432836,
 'smachnonews': 0.8303571428571429,
 'sport': 0.9698492462311558,
 'tech': 0.8125,
 'zakordon': 0.8690476190476191}