In [17]:
import torch

EPOCH = 20
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [2]:
from datasets import load_dataset

dataset = load_dataset("imdb")
dataset["train"].features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['neg', 'pos'], id=None)}

In [8]:
from transformers import BertTokenizerFast, WordpieceTokenizer

tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding=True,
        truncation=True,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [9]:
maxlen = 0
for seq in tokenized_datasets["train"]["input_ids"]:
    if len(seq) > maxlen:
        maxlen = len(seq)


maxlen

512

In [20]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(5000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [13]:
import torch
import tqdm

from torch import optim, nn, Tensor
from torch.utils.data import DataLoader

from lstm import LstmForClassification

lstm = LstmForClassification(
    num_labels=2,
    hidden_size=768,
    num_layers=1,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=0,
)

lstm.to(device)

LstmForClassification(
  (embedding): Embedding(30522, 768, padding_idx=0)
  (lstm): LSTM(
    (cells): ModuleList(
      (0): LSTMCell(
        (i): Linear(in_features=1536, out_features=768, bias=True)
        (f): Linear(in_features=1536, out_features=768, bias=True)
        (c): Linear(in_features=1536, out_features=768, bias=True)
        (o): Linear(in_features=1536, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=2, bias=True)
)

In [21]:
train_dataloader = DataLoader(small_train_dataset, batch_size=256, shuffle=True)
optimizer = optim.Adam(lstm.parameters(), lr=1e-3)
# lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
lstm.train()
with tqdm.tqdm(total=len(train_dataloader) * EPOCH) as tqdm_bar:
    for epoch in range(EPOCH):
        training_loss = 0.0
        for batch in train_dataloader:
            labels: Tensor = batch["label"].to(device)
            input_ids = torch.stack(batch["input_ids"]).to(device)
            loss, logits = lstm(
                x=input_ids,
                labels=labels,
            )
            loss: Tensor
            loss.backward()
            training_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            tqdm_bar.update(1)
        print(epoch, training_loss / len(train_dataloader))

  5%|▌         | 20/400 [00:40<11:32,  1.82s/it]

0 0.7393235981464386


 10%|█         | 40/400 [01:20<10:54,  1.82s/it]

1 0.6526909589767456


 15%|█▌        | 60/400 [02:00<10:15,  1.81s/it]

2 0.6230518490076065


 20%|██        | 80/400 [02:40<09:39,  1.81s/it]

3 0.6007402092218399


 25%|██▌       | 100/400 [03:21<09:03,  1.81s/it]

4 0.5943064004182815


 30%|███       | 120/400 [04:01<08:34,  1.84s/it]

5 0.6230233013629913


 35%|███▌      | 140/400 [04:41<07:53,  1.82s/it]

6 0.6064278662204743


 40%|████      | 160/400 [05:22<07:13,  1.80s/it]

7 0.5968321055173874


 45%|████▌     | 180/400 [06:02<06:42,  1.83s/it]

8 0.5911589056253433


 50%|█████     | 200/400 [06:42<06:01,  1.81s/it]

9 0.5927586197853089


 55%|█████▌    | 220/400 [07:22<05:27,  1.82s/it]

10 0.5919461816549301


 60%|██████    | 240/400 [08:03<04:50,  1.82s/it]

11 0.5923446923494339


 65%|██████▌   | 260/400 [08:43<04:13,  1.81s/it]

12 0.5896066635847091


 70%|███████   | 280/400 [09:23<03:39,  1.83s/it]

13 0.587950485944748


 75%|███████▌  | 300/400 [10:03<03:01,  1.81s/it]

14 0.5874542742967606


 80%|████████  | 320/400 [10:43<02:27,  1.84s/it]

15 0.5856389224529266


 85%|████████▌ | 340/400 [11:23<01:48,  1.81s/it]

16 0.5865317910909653


 90%|█████████ | 360/400 [12:04<01:12,  1.81s/it]

17 0.5853672564029694


 95%|█████████▌| 380/400 [12:44<00:36,  1.85s/it]

18 0.5858426541090012


100%|██████████| 400/400 [13:24<00:00,  2.01s/it]

19 0.5875193804502488





In [22]:
import evaluate

metric = evaluate.load("accuracy")
lstm.eval()
eval_dataloader = DataLoader(small_eval_dataset, batch_size=256, shuffle=True)
tqdm_bar = tqdm.tqdm(eval_dataloader)

with tqdm.tqdm(eval_dataloader) as tqdm_bar:
    for batch in eval_dataloader:
        labels: Tensor = batch["label"].to(device)
        input_ids = torch.stack(batch["input_ids"]).to(device)
        with torch.no_grad():
            _, logits = lstm(
                x=input_ids,
                labels=labels,
            )
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=labels)
        tqdm_bar.update(1)

metric.compute()

  0%|          | 0/4 [00:00<?, ?it/s]
100%|██████████| 4/4 [00:01<00:00,  2.04it/s]


{'accuracy': 0.508}