In [1]:
EPOCH = 3

In [2]:
from datasets import load_dataset

dataset = load_dataset("imdb")
dataset["train"].features

Found cached dataset imdb (/Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['neg', 'pos'], id=None)}

In [3]:
from transformers import BertTokenizerFast

tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)

Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-e0a0342ae289143d.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-fd8ddff947474c37.arrow
Loading cached processed dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-2a73e6194285aadb.arrow


In [4]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(10000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

Loading cached shuffled indices for dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-95a7ea67f59766e0.arrow
Loading cached shuffled indices for dataset at /Users/louiechou/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-db88639656d75d2e.arrow


In [5]:
import torch
import tqdm

from torch import optim, nn, Tensor
from torch.utils.data import DataLoader

from lstm import LstmForClassification

lstm = LstmForClassification(
    num_labels=2,
    hidden_size=512,
    num_layers=2,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=0,
).to("mps")

train_dataloader = DataLoader(small_train_dataset, batch_size=64, shuffle=True)
optimizer = optim.Adam(lstm.parameters(), lr=1e-2)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
lstm.train()
with tqdm.tqdm(total=len(train_dataloader) * EPOCH) as tqdm_bar:
    for epoch in range(EPOCH):
        for batch in train_dataloader:
            labels: Tensor = batch["label"].to("mps")
            input_ids = torch.stack(batch["input_ids"]).to("mps")
            loss, logits = lstm(
                x=input_ids,
                labels=labels,
            )
            loss: Tensor
            loss.backward()
            optimizer.step()
            lr_scheduler.step(loss.item())
            optimizer.zero_grad()
            tqdm_bar.update(1)
        print(epoch, loss.item())

  1%|          | 5/471 [00:19<30:29,  3.93s/it]

Epoch 00005: reducing learning rate of group 0 to 1.0000e-03.


  2%|▏         | 9/471 [00:35<30:10,  3.92s/it]

Epoch 00009: reducing learning rate of group 0 to 1.0000e-04.


  3%|▎         | 13/471 [00:51<30:21,  3.98s/it]

Epoch 00013: reducing learning rate of group 0 to 1.0000e-05.


  4%|▎         | 17/471 [01:08<31:14,  4.13s/it]

Epoch 00017: reducing learning rate of group 0 to 1.0000e-06.


  4%|▍         | 21/471 [01:24<30:05,  4.01s/it]

Epoch 00021: reducing learning rate of group 0 to 1.0000e-07.


  5%|▌         | 25/471 [01:40<29:41,  3.99s/it]

Epoch 00025: reducing learning rate of group 0 to 1.0000e-08.


 19%|█▊        | 88/471 [05:48<25:15,  3.96s/it]


KeyboardInterrupt: 

In [None]:
import evaluate

metric = evaluate.load("accuracy")
lstm.eval()
eval_dataloader = DataLoader(small_eval_dataset, batch_size=64, shuffle=True)
tqdm_bar = tqdm.tqdm(eval_dataloader)
for batch in eval_dataloader:
    labels: Tensor = batch["label"].to("mps")
    input_ids = torch.stack(batch["input_ids"]).to("mps")
    with torch.no_grad():
        _, logits = lstm(
            x=input_ids,
            labels=labels,
        )
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=labels)
    tqdm_bar.update(1)

metric.compute()