In [1]:
from typing import Literal

import os
import torch
import tqdm
import numpy as np

from torch import optim, nn, Tensor
from torch.utils.data import Dataset, DataLoader

from bert import BertForClassification

In [2]:
DATABASE_PATH = "./aclImdb"
EPOCH = 5
MAX_POSITION_EMBEDDINGS = 512

device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [3]:
from transformers import BertTokenizerFast

tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(
    "bert-base-uncased",
)


def tokenize(examples):
    return tokenizer(examples, padding="max_length", truncation=True)


class ImdbDataset(Dataset):
    def __init__(self, mode: Literal["train", "test"]):
        super().__init__()
        self.data = []
        self.label = []
        for k, v in {"{mode}/neg": 0, "{mode}/pos": 1}.items():
            class_root = os.path.join(DATABASE_PATH, k.format(mode=mode))
            for j in os.listdir(class_root):
                self.data.append(open(os.path.join(class_root, j)).read().strip())
                self.label.append(v)
        self.tokenized = tokenize(self.data).data

    def __getitem__(self, idx):
        input_ids = self.tokenized["input_ids"][idx]
        attention_mask = self.tokenized["attention_mask"][idx]
        labels = self.label[idx]
        return input_ids, attention_mask, labels

    def __len__(self):
        return len(self.data)
    
train_dataset = ImdbDataset("train")
test_dataset = ImdbDataset("test")

In [4]:
bert = BertForClassification(
    vocab_size=tokenizer.vocab_size,
    d_model=768,
    intermediate_size=4 * 768,
    max_position_embeddings=512,
    num_attention_heads=8,
    hidden_dropout_prob=0.1,
    num_hidden_layers=12,
    num_labels=2,
).to(device)

In [5]:
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
optimizer = optim.Adam(bert.parameters(), lr=1e-5)
# lr_scheduler

bert.train()
for epoch in range(EPOCH):
    tqdm_bar = tqdm.tqdm(train_dataloader)
    for i, (input_ids, attention_mask, labels) in enumerate(train_dataloader):
        labels: Tensor
        input_ids = torch.stack(input_ids).to(device)
        attention_mask = torch.stack(attention_mask).to(device)
        labels = labels.to(device)
        loss, logits = bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss: Tensor
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        tqdm_bar.update(1)
        if (i == 300):
            break
    print("Epoch:", epoch, "Loss:", loss)
torch.save(bert.state_dict(), "bert_from_scratch.pt")

 38%|███▊      | 301/782 [05:12<08:20,  1.04s/it]

Epoch: 0 Loss: tensor(0.6555, device='cuda:0', grad_fn=<NllLossBackward0>)


 38%|███▊      | 301/782 [05:12<08:19,  1.04s/it]


Epoch: 1 Loss: tensor(0.6977, device='cuda:0', grad_fn=<NllLossBackward0>)


 38%|███▊      | 301/782 [05:12<08:19,  1.04s/it]
  6%|▌         | 44/782 [00:45<12:43,  1.03s/it]

KeyboardInterrupt: 

In [6]:
torch.save(bert.state_dict(), "bert_from_scratch.pt")

In [None]:
bert.load_state_dict(
    torch.load("bert_from_scratch.pt", map_location=torch.device(device))
)

import evaluate

metric = evaluate.load("accuracy")

bert.eval()
eval_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)
for input_ids, attention_mask, labels in eval_dataloader:
    labels: Tensor
    input_ids = torch.stack(input_ids).to(device)
    attention_mask = torch.stack(attention_mask).to(device)
    labels = labels.to(device)
    with torch.no_grad():
        _, logits = bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=labels)

metric.compute()