In [1]:
from typing import Literal

import os
import torch
import tqdm
import numpy as np

from torch import optim, nn, Tensor
from torch.utils.data import Dataset, DataLoader

from bert import BertForClassification

In [2]:
DATABASE_PATH = "./aclImdb"
EPOCH = 10
MAX_POSITION_EMBEDDINGS = 512

device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [3]:
from datasets import load_dataset

dataset = load_dataset("imdb")
dataset["train"].features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['neg', 'pos'], id=None)}

In [4]:
from transformers import BertTokenizerFast

tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained("bert-base-uncased")


def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

In [5]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [6]:
bert = BertForClassification(
    vocab_size=tokenizer.vocab_size,
    d_model=768,
    intermediate_size=4 * 768,
    max_position_embeddings=512,
    num_attention_heads=8,
    hidden_dropout_prob=0.1,
    num_hidden_layers=12,
    num_labels=2,
).to(device)

In [7]:
from transformers import get_scheduler

train_dataloader = DataLoader(small_train_dataset, batch_size=16, shuffle=True)
optimizer = optim.Adam(bert.parameters(), lr=1e-5)
# lr_scheduler = get_scheduler(
#     name="linear",
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=len(train_dataloader) * EPOCH,
# )

bert.train()
losses = []

with tqdm.tqdm(total=len(train_dataloader) * EPOCH) as tqdm_bar:
    for epoch in range(EPOCH):
        training_loss = 0.0
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            batch["input_ids"].transpose_(0, 1)
            batch["attention_mask"].transpose_(0, 1)
            batch["token_type_ids"].transpose_(0, 1)
            # labels: Tensor
            # input_ids = torch.stack(input_ids).to(device)
            # attention_mask = torch.stack(attention_mask).to(device)
            # labels = labels.to(device)
            loss, logits = bert(**batch)
            loss: Tensor
            loss.backward()
            training_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            tqdm_bar.update(1)
        print("Epoch:", epoch, "Loss:", training_loss / len(train_dataloader))
        losses.append(training_loss / len(train_dataloader))

import matplotlib.pyplot as plt

plt.plot(losses)

torch.save(bert.state_dict(), "bert_from_scratch.pt")

  0%|          | 0/630 [00:00<?, ?it/s]

 10%|█         | 63/630 [00:32<04:04,  2.32it/s]

Epoch: 0 Loss: 0.7238034653285194


 20%|██        | 126/630 [01:03<03:41,  2.28it/s]

Epoch: 1 Loss: 0.7009218032397921


 27%|██▋       | 170/630 [01:26<03:55,  1.95it/s]

In [7]:
bert.load_state_dict(
    torch.load("bert_from_scratch.pt.bk", map_location=torch.device(device))
)

bert.to(device)

import evaluate

metric = evaluate.load("accuracy")

bert.eval()
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8, shuffle=True)
with tqdm.tqdm(eval_dataloader) as tqdm_bar:
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        batch["input_ids"].transpose_(0, 1)
        batch["attention_mask"].transpose_(0, 1)
        batch["token_type_ids"].transpose_(0, 1)
        with torch.no_grad():
            _, logits = bert(**batch)
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
        tqdm_bar.update(1)

metric.compute()

100%|██████████| 125/125 [00:11<00:00, 10.75it/s]


{'accuracy': 0.745}