In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, Subset
import pandas as pd
import numpy as np
from transformers import AutoTokenizer

from model import BERTClassification

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TextDataset(Dataset):
    """Map-style dataset of ``(sentence, label)`` pairs read from a CSV file.

    The CSV's first column becomes the DataFrame index. The ``Sentence`` column
    holds the text and the ``Level`` column holds a level string ``"N1"``..``"N5"``.
    Rows are addressed by *position*, so the index values stored in the file
    do not need to be ``0..n-1``.

    Args:
        df_path: Path (or file-like object) accepted by ``pandas.read_csv``.
    """

    # Level string -> integer class id.
    LEVEL2LABEL = {"N1": 0, "N2": 1, "N3": 2, "N4": 3, "N5": 4}

    def __init__(self, df_path):
        self.df = pd.read_csv(df_path, index_col=0)
        # Resolve column positions once so __getitem__ can do pure positional
        # lookups (and a missing column fails here, not mid-training).
        self._sentence_col = self.df.columns.get_loc("Sentence")
        self._level_col = self.df.columns.get_loc("Level")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index) -> tuple[str, int]:
        """Return the ``(sentence, label)`` pair for the row at position ``index``."""
        # Positional access: Subset/DataLoader hand us positions in [0, len(self)),
        # which need not coincide with the CSV's index labels.
        sentence = self.df.iat[index, self._sentence_col]
        level = self.df.iat[index, self._level_col]
        label = self.level2label(level)

        return sentence, label

    @classmethod
    def level2label(cls, level):
        """Map a level string such as ``"N3"`` to its class id; raises KeyError for unknown levels."""
        return cls.LEVEL2LABEL[level]

In [3]:
# Run configuration: hyper-parameters, paths and device selection.
LEARNING_RATE = 3e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2
NUM_EPOCHS = 1000
NUM_WORKERS = 2  # DataLoader worker processes
IMAGE_SCALE = 0.1  # NOTE(review): not read by any cell visible here (text model) — candidate for removal
LOAD_FROM = None  # NOTE(review): not read by any visible cell; presumably a checkpoint path
DATA_ROOT = r"util/jlpt_sentences.csv"
NUM_CLASS = 5  # one class per level string N1..N5
EXP_FOLDER = "exp4"  # NOTE(review): not read by any visible cell

# Seed the RNGs so weight init of randomly initialised layers and DataLoader
# shuffling (drawn from torch's default generator) are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Tokenizer for the Japanese BERT checkpoint; applied per batch in the training loop.
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
dataset = TextDataset(DATA_ROOT)

# Train on a small, evenly spaced subset of rows (presumably a quick
# overfitting sanity check). The count is clamped to the dataset size:
# with integer linspace, asking for more samples than rows would silently
# repeat indices.
NUM_SUBSET_SAMPLES = 50
sub_dataset = Subset(
    dataset,
    np.linspace(
        0,
        len(dataset),
        num=min(NUM_SUBSET_SAMPLES, len(dataset)),
        endpoint=False,
        dtype=int,
    ),
)
data_loader = DataLoader(
    sub_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

In [5]:
# Classifier, loss and optimizer used by the training loop.
model = BERTClassification(num_class=NUM_CLASS)
model = model.to(DEVICE)

# reduction="sum": each batch loss is the sum over its samples, so the epoch
# total divided by the number of samples gives the per-sample mean loss.
criterion = nn.CrossEntropyLoss(reduction="sum")
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
# Training loop: one optimizer step per batch; every LOG_EVERY epochs print the
# mean per-sample loss over the subset.
LOG_EVERY = 50

for epoch in range(NUM_EPOCHS):
    # Ensure training mode (dropout on) even if some other cell switched the
    # model to eval mode.
    model.train()
    total_loss = 0.0

    for batch_idx, (sentences, labels) in enumerate(data_loader):
        # truncation=True caps each sequence at the tokenizer's model_max_length;
        # without it an over-long sentence would overflow BERT's position embeddings.
        inputs = tokenizer(
            sentences, return_tensors="pt", padding=True, truncation=True
        ).to(DEVICE)
        # default_collate already yields an int64 tensor; as_tensor avoids the
        # legacy torch.LongTensor constructor and only copies if needed.
        labels = torch.as_tensor(labels, dtype=torch.long, device=DEVICE)

        outputs = model(**inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion is built with reduction="sum", so this accumulates the
        # summed per-sample loss for the epoch.
        total_loss += loss.item()

    if epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch+1:4d}| Loss: {total_loss/len(sub_dataset):.4f}")

Epoch    1| Loss: 1.6111
Epoch   51| Loss: 0.5750
Epoch  101| Loss: 0.2899
Epoch  151| Loss: 0.1675
Epoch  201| Loss: 0.0954
Epoch  251| Loss: 0.0780
Epoch  301| Loss: 0.0387


KeyboardInterrupt: 

In [None]:
"""
Epoch    1| Loss: 1.6502
Epoch   51| Loss: 0.5745
Epoch  101| Loss: 0.2626
Epoch  151| Loss: 0.1694
Epoch  201| Loss: 0.1005
Epoch  251| Loss: 0.0651
Epoch  301| Loss: 0.0441
"""