In [None]:
EPOCH: int = 3  # number of full passes over the training subset

In [None]:
from datasets import load_dataset

# Download (or reuse the local cache of) the IMDB sentiment corpus from the HF Hub.
dataset = load_dataset("imdb")

# Display the column schema of the training split as this cell's output.
train_split = dataset["train"]
train_split.features

In [None]:
from transformers import BertTokenizerFast

# Uncased BERT WordPiece tokenizer; the LSTM cell below reuses its vocabulary size.
tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    """Tokenize a batched slice of the IMDB dataset.

    Each review is truncated and then padded to the tokenizer's maximum length
    (``padding="max_length"`` with no explicit ``max_length``). As a result, every
    row has the same number of tokens, and the training/eval cells can
    ``torch.stack`` the default-collated ``input_ids``.

    Args:
        examples: A batch from ``dataset.map(..., batched=True)``; must contain a
            ``"text"`` column.

    Returns:
        The tokenizer's batch encoding (e.g. ``input_ids``, ``attention_mask``).
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding="max_length")


tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
# Fixed-seed random subsets keep iteration fast and the selection reproducible.
TRAIN_SUBSET_SIZE = 10_000
EVAL_SUBSET_SIZE = 1_000

small_train_dataset = (
    tokenized_datasets["train"].shuffle(seed=42).select(range(TRAIN_SUBSET_SIZE))
)
small_eval_dataset = (
    tokenized_datasets["test"].shuffle(seed=42).select(range(EVAL_SUBSET_SIZE))
)

In [8]:
import torch
import tqdm

from torch import optim, nn, Tensor
from torch.utils.data import DataLoader

from lstm import LstmForClassification

# Prefer Apple-silicon MPS (what this notebook targets), then CUDA, then CPU, so the
# cell still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

LEARNING_RATE = 1e-2
TRAIN_BATCH_SIZE = 200
# Gradient-norm clip threshold: LSTMs trained with a high learning rate are prone to
# exploding gradients.
MAX_GRAD_NORM = 1.0

torch.manual_seed(42)  # reproducible weight init and batch order

lstm = LstmForClassification(
    num_labels=2,
    hidden_size=512,
    num_layers=2,
    vocab_size=tokenizer.vocab_size,
    # Take the pad id from the tokenizer rather than hard-coding it.
    pad_token_id=tokenizer.pad_token_id,
).to(DEVICE)

train_dataloader = DataLoader(small_train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
optimizer = optim.Adam(lstm.parameters(), lr=LEARNING_RATE)
# Stepped once per epoch on the mean epoch loss (below). The current LR is printed
# explicitly because the scheduler's `verbose` flag is deprecated in recent PyTorch.
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

lstm.train()
with tqdm.tqdm(total=len(train_dataloader) * EPOCH) as tqdm_bar:
    for epoch in range(EPOCH):
        epoch_loss_sum = 0.0
        for batch in train_dataloader:
            labels: Tensor = batch["label"].to(DEVICE)
            # Default collation turns each row's list of token ids into a list of
            # per-position tensors; stacking gives (seq_len, batch).
            # NOTE(review): confirm this is the layout LstmForClassification expects.
            input_ids = torch.stack(batch["input_ids"]).to(DEVICE)

            optimizer.zero_grad()
            loss, logits = lstm(
                x=input_ids,
                labels=labels,
            )
            loss: Tensor
            loss.backward()
            nn.utils.clip_grad_norm_(lstm.parameters(), MAX_GRAD_NORM)
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss_sum += batch_loss
            tqdm_bar.update(1)
            tqdm_bar.set_postfix(epoch=epoch, loss=f"{batch_loss:.4f}")

        # ReduceLROnPlateau counts `patience` in calls to step(). Feeding it per-batch
        # losses would cut the LR after only a few noisy batches, so use the epoch mean.
        mean_loss = epoch_loss_sum / len(train_dataloader)
        lr_scheduler.step(mean_loss)
        print(epoch, mean_loss, optimizer.param_groups[0]["lr"])

  1%|▏         | 2/150 [04:33<5:37:39, 136.89s/it] 


KeyboardInterrupt: 

In [None]:
import evaluate

EVAL_BATCH_SIZE = 256

metric = evaluate.load("accuracy")
lstm.eval()
# Run evaluation on whatever device the model already lives on.
eval_device = next(lstm.parameters()).device
# Accuracy does not depend on order; not shuffling keeps evaluation deterministic.
eval_dataloader = DataLoader(small_eval_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

# A single progress bar that advances by iterating the wrapped loader and is closed
# when the `with` block exits.
with torch.no_grad(), tqdm.tqdm(eval_dataloader, desc="eval") as eval_bar:
    for batch in eval_bar:
        labels: Tensor = batch["label"].to(eval_device)
        # Same (seq_len, batch) stacking as in the training cell.
        input_ids = torch.stack(batch["input_ids"]).to(eval_device)
        _, logits = lstm(
            x=input_ids,
            labels=labels,
        )
        predictions = torch.argmax(logits, dim=-1)
        # Hand CPU tensors to `evaluate`, which works on NumPy/Python values.
        metric.add_batch(predictions=predictions.cpu(), references=labels.cpu())

metric.compute()