In [1]:
!pip install datasets transformers



In [2]:
from datasets import load_dataset
import torch
from transformers import RobertaModel, RobertaTokenizer, Trainer, TrainingArguments

In [3]:
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained("roberta-base")
loss_fn = torch.nn.BCEWithLogitsLoss()

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

In [4]:
dataset = load_dataset("civil_comments", split='train[:1000]')
dataset_train = dataset.map(tokenize, batched=True, batch_size=len(dataset))
labels = ['insult', 'obscene', 'severe_toxicity', 'sexual_explicit', 'threat', 'toxicity']
dataset_train.set_format('torch', columns=['input_ids', 'attention_mask']+labels)

Using custom data configuration default
Reusing dataset civil_comments (/home/henry/.cache/huggingface/datasets/civil_comments/default/0.9.0/98bdc73fc77a117cf5d17c9977e278c8023c64177a3ed9e0c49f7a5bdf10a47b)
Loading cached processed dataset at /home/henry/.cache/huggingface/datasets/civil_comments/default/0.9.0/98bdc73fc77a117cf5d17c9977e278c8023c64177a3ed9e0c49f7a5bdf10a47b/cache-f900a6a5583b5e79.arrow


In [5]:
class RobertaClass(torch.nn.Module):
    def __init__(self, num_labels: int):
        super(RobertaClass, self).__init__()
        self.l1 = RobertaModel.from_pretrained('roberta-base', return_dict=True)
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, num_labels)
    
    def forward(self, input_ids, attention_mask):
        output_1 = self.l1(input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        return output

model = RobertaClass(num_labels=len(labels))

In [6]:
dataset_train[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0

In [7]:
class MyTrainer(Trainer):
    def compute_loss(self, model, inputs):
        labels = inputs.pop("labels")
        logits = model(**inputs)[0]
        return loss_fn(logits, labels)

In [8]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    evaluate_during_training=True,
    logging_dir='./logs',
)

trainer = MyTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
)

In [9]:
trainer.train()

KeyError: 'labels'

In [35]:
train_dataloader = trainer.get_train_dataloader()
for thing in train_dataloader:
    print(thing)
    break

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'input_ids': tensor([[    0, 37562,     8,  ...,     1,     1,     1],
        [    0,   100,   524,  ...,     1,     1,     1],
        [    0,  4239,  6975,  ...,     1,     1,     1],
        ...,
        [    0, 14181,  7495,  ...,     1,     1,     1],
        [    0, 33757,    47,  ...,     1,     1,     1],
        [    0,   100,   216,  ...,     1,     1,     1]])}
