In [1]:
import torch

device = torch.device('mps')

In [3]:
from datasets import load_dataset

trec = load_dataset('trec', split='train[:1000]')
trec

Using custom data configuration default
Reusing dataset trec (/Users/xinli/.cache/huggingface/datasets/trec/default/1.1.0/751da1ab101b8d297a3d6e9c79ee9b0173ff94c4497b75677b59b61d5467a9b9)


Dataset({
    features: ['label-coarse', 'label-fine', 'text'],
    num_rows: 1000
})

In [4]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

tokens = tokenizer(
    trec['text'], max_length=512,
    truncation=True, padding='max_length'
)

In [5]:
import numpy as np

labels = np.zeros(
    (len(trec), max(trec['label-coarse'])+1)
)

labels[np.arange(len(trec)), trec['label-coarse']] = 1
labels[:5]

array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0.]])

In [6]:
labels = torch.Tensor(labels)

In [18]:
class TrecDataset(torch.utils.data.Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels
        
    def __getitem__(self, idx):
        input_ids = self.tokens[idx].ids
        attention_mask = self.tokens[idx].attention_mask
        labels = self.labels[idx]
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels)
        }
    
    def __len__(self):
        return len(self.labels)
    
dataset = TrecDataset(tokens, labels)

In [19]:
loader = torch.utils.data.DataLoader(dataset, batch_size=1)

In [20]:
from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = max(trec['label-coarse'])+1           # six outputs
model = BertForSequenceClassification(config).to(device)  # to mps

In [21]:
model.train()

optim = torch.optim.Adam(model.parameters(), lr=5e-5)

In [22]:
from time import time
from tqdm.auto import tqdm

loop_time = []

# setup loop (using tqdm for the progress bar)
loop = tqdm(loader, leave=True)
for batch in loop:
    batch_mps = {
        'input_ids': batch['input_ids'].to(device),
        'attention_mask': batch['attention_mask'].to(device),
        'labels': batch['labels'].to(device)
    }
    
    t0 = time()
    
    optim.zero_grad()
    outputs = model(**batch_mps)
    
    loss = outputs[0]
    loss.backward()
    
    optim.step()
    
    loop_time.append(time()-t0)
    loop.set_postfix(loss=loss.item())

  0%|          | 0/1000 [00:00<?, ?it/s]

  'labels': torch.tensor(labels)


In [None]:
loop_time