In [15]:
import torch
from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.metrics import f1_score, accuracy_score, confusion_matrix
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [8]:
# Load the 'emotion' dataset; the bare name on the last line displays its splits.
ds = load_dataset(path='emotion')
ds

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})

In [4]:
# Pretrained checkpoint shared by the tokenizer and the classifier.
model_chckpt = 'distilbert-base-uncased'
# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained(model_chckpt)
# num_labels sets the size of the newly initialised classification head; it
# must match the number of classes in the dataset's 'label' column.
model = AutoModelForSequenceClassification.from_pretrained(model_chckpt, num_labels=6).to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
def tokenize_batch(batch, tokenizer=tokenizer):
    """Tokenize the 'text' entries of a batch.

    Args:
        batch: Mapping with a 'text' entry holding the raw strings.
        tokenizer: Callable tokenizer; defaults to the notebook-level tokenizer.

    Returns:
        Whatever the tokenizer returns for the texts, with padding and
        truncation both enabled.
    """
    texts = batch['text']
    encoded = tokenizer(texts, truncation=True, padding=True)
    return encoded

# Tokenize every split. batch_size=None passes each split to tokenize_batch as
# a single batch, so padding is computed over the whole split at once.
ds = ds.map(function=tokenize_batch, batched=True, batch_size=None)
ds

Map: 100%|██████████| 16000/16000 [00:00<00:00, 16209.57 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 30067.67 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 30579.31 examples/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2000
    })
})

In [12]:
# Have the dataset return PyTorch tensors when indexed.
ds.set_format(type='pytorch')

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2000
    })
})

In [None]:
def compute_metrics(pred):
    """Compute macro F1, micro F1 and accuracy for an evaluation pass.

    Args:
        pred: Object exposing `predictions` (per-class scores, one row per
            example) and `label_ids` (true class indices), as supplied by
            `Trainer` to its `compute_metrics` callback.

    Returns:
        dict with the keys 'f1_macro', 'f1_micro' and 'acc'.
    """
    labels = pred.label_ids
    # Trainer passes NumPy arrays here, which torch.argmax does not accept,
    # so take the arg-max over the class axis with NumPy instead.
    preds = np.argmax(np.asarray(pred.predictions), axis=1)
    f1_macro = f1_score(labels, preds, average='macro')
    f1_micro = f1_score(labels, preds, average='micro')
    acc = accuracy_score(labels, preds)
    return {
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'acc': acc
    }

In [None]:
# Convenience handles for the three splits. The validation split is keyed
# 'validation' in the DatasetDict (there is no 'val' key).
train_ds = ds['train']
val_ds = ds['validation']
test_ds = ds['test']

In [None]:
# Fine-tuning hyperparameters.
num_train_epochs = 10
batch_size = 32
learning_rate = 2e-5
weight_decay = 1e-5
# Number of full batches in one pass over the training split, so training
# metrics are logged roughly once per epoch (single device, no accumulation).
logging_steps = len(train_ds) // batch_size

training_args = TrainingArguments(
    output_dir=f'{model_chckpt}_emotions',
    # optimisation
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    num_train_epochs=num_train_epochs,
    # batching (same size for training and evaluation)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluation / logging cadence
    evaluation_strategy='epoch',
    logging_steps=logging_steps,
    disable_tqdm=False,
)

In [None]:
# Trainer takes its TrainingArguments through the `args` keyword; it has no
# `training_args` parameter, so that spelling raises a TypeError.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer
)

trainer.train()

In [None]:
def plot_confusion_matrix(y_true, y_pred):
    """Draw a row-normalised confusion matrix as an annotated heatmap.

    Args:
        y_true: True class labels.
        y_pred: Predicted class labels, aligned with `y_true`.

    Returns:
        None. The figure is shown with `plt.show()`.
    """
    # normalize='true' scales each row (true class) to sum to 1.
    cm = confusion_matrix(y_true, y_pred, normalize='true')
    cm = np.round(cm, 3)
    _, ax = plt.subplots(figsize=(6, 6))
    sns.heatmap(cm, annot=True, ax=ax, cmap="YlGnBu")
    # Label the axes only after drawing: for plain arrays sns.heatmap resets
    # the x/y labels to empty strings, wiping any labels set beforehand.
    ax.set_title("Normalized confusion matrix")
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    plt.show()

In [None]:
# Trainer.predict returns a PredictionOutput with the fields `predictions`,
# `label_ids` and `metrics`; it has no `.logits` attribute. `predictions`
# holds the model outputs (logits) for the test split.
test_preds = trainer.predict(test_dataset=test_ds).predictions

In [None]:
def error_analysis(batch, model, loss_func=nn.CrossEntropyLoss(), device='cuda'):
    """Run a forward pass and return the loss for one example or one batch.

    Args:
        batch: Mapping with 'input_ids', 'attention_mask' and 'label'. The
            values may describe a single example (1-D token tensors, scalar
            label) or a batch (2-D token tensors, 1-D labels).
        model: Callable accepting `input_ids` and `attention_mask` keyword
            arguments and returning an object with a `logits` attribute.
            It is not switched to eval mode here; call `model.eval()` first
            if layers such as dropout should be disabled.
        loss_func: Loss applied to `(logits, labels)`. To get one loss per
            example for a batch, pass a loss built with `reduction='none'`.
        device: Device the inputs are moved to; must match the model's device.

    Returns:
        dict with the key 'loss value': a Python float for a single example
        or a scalar loss, otherwise a list of per-example floats.
    """
    input_ids = torch.as_tensor(batch['input_ids']).to(device)
    attention_mask = torch.as_tensor(batch['attention_mask']).to(device)
    labels = torch.as_tensor(batch['label']).to(device)

    # With batched=False the tensors have no batch dimension; add one so the
    # model and the loss receive (batch, seq) inputs and (batch,) labels.
    single_example = input_ids.dim() == 1
    if single_example:
        input_ids = input_ids.unsqueeze(0)
        attention_mask = attention_mask.unsqueeze(0)
        labels = labels.reshape(1)

    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
        loss_vals = loss_func(output.logits, labels)

    # Return plain Python values on the host rather than a device tensor.
    loss_vals = loss_vals.detach().cpu()
    if single_example or loss_vals.dim() == 0:
        return {'loss value': loss_vals.item()}
    return {'loss value': loss_vals.tolist()}

In [None]:
# error_analysis requires a `model` argument that map() does not supply on
# its own; pass it (and the model's device) through fn_kwargs.
val_ds = val_ds.map(error_analysis, fn_kwargs={'model': model, 'device': model.device}, batched=False)

In [None]:
# error_analysis requires a `model` argument that map() does not supply on
# its own; pass it (and the model's device) through fn_kwargs.
test_ds = test_ds.map(error_analysis, fn_kwargs={'model': model, 'device': model.device}, batched=False)