In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, Subset
import pandas as pd
import numpy as np
from transformers import AutoTokenizer

from models.bert import BERTClassification
from models.attention import SimpleAttention

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TextDataset(Dataset):
    def __init__(self, df_path):
        super().__init__()
        self.df = pd.read_csv(df_path, index_col=0)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index) -> tuple[str, int]:
        sentence = self.df["Sentence"][index]
        level = self.df["Level"][index]
        label = self.level2label(level)

        return sentence, label

    @classmethod
    def level2label(cls, level):
        return {"N1": 0, "N2": 1, "N3": 2, "N4": 3, "N5": 4}[level]

In [3]:
LEARNING_RATE = 3e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
LOAD_FROM = None
DATA_ROOT = r"training_data/train.csv"
NUM_CLASS = 5
EXP_FOLDER = "exp4"

In [4]:
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
dataset = TextDataset(DATA_ROOT)
# sub_dataset = Subset(
#     dataset, np.linspace(0, len(dataset), num=50, endpoint=False, dtype=int)
# )
data_loader = DataLoader(
    dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)

In [5]:
model = SimpleAttention(num_class=NUM_CLASS, vocab_size=len(tokenizer)).to(DEVICE)
criterion = nn.CrossEntropyLoss(reduction="sum")  # to get average easily
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [6]:
for epoch in range(NUM_EPOCHS):
    print(f"Epoch [{epoch}/{NUM_EPOCHS}]")
    total_loss = 0
    all_truths = []
    all_outputs = []

    model.train()
    for batch_idx, (sentences, labels) in enumerate(data_loader):
        inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(DEVICE)
        labels = torch.LongTensor(labels).to(DEVICE)
        
        outputs = model(**inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch_idx % 50 == 0:
            print(
                f"[Batch {batch_idx+1:4d}/{len(data_loader)}]"
                f" Loss: {loss.item()/BATCH_SIZE:.4f}"
                f" Labels: {labels.tolist()}"
            )
    if epoch % 50 == 0:
        print(f"Total loss: {total_loss/len(dataset):.4f}")

Epoch [0/3]
[Batch    1/130] Loss: 1.6061 Labels: [2, 2, 4, 1, 4, 0, 4, 2, 3, 3, 2, 2, 3, 3, 3, 2]
[Batch   51/130] Loss: 1.3689 Labels: [3, 1, 0, 3, 2, 2, 2, 4, 2, 1, 1, 4, 2, 4, 1, 2]
[Batch  101/130] Loss: 1.2640 Labels: [2, 4, 2, 0, 2, 2, 2, 1, 3, 3, 2, 3, 3, 3, 3, 3]
Total loss: 1.3171
Epoch [1/3]
[Batch    1/130] Loss: 1.0874 Labels: [2, 2, 4, 2, 1, 3, 0, 0, 4, 3, 2, 1, 2, 1, 2, 0]
[Batch   51/130] Loss: 0.6830 Labels: [0, 4, 4, 3, 4, 4, 2, 1, 3, 1, 2, 4, 2, 2, 1, 2]
[Batch  101/130] Loss: 1.0881 Labels: [0, 2, 2, 4, 3, 2, 0, 3, 2, 4, 3, 0, 0, 3, 3, 1]
Epoch [2/3]
[Batch    1/130] Loss: 0.6145 Labels: [1, 1, 0, 2, 0, 3, 0, 4, 2, 3, 0, 2, 1, 4, 2, 4]
[Batch   51/130] Loss: 0.5474 Labels: [2, 3, 2, 2, 0, 1, 0, 3, 4, 0, 4, 3, 2, 4, 2, 1]
[Batch  101/130] Loss: 0.5072 Labels: [2, 1, 2, 0, 3, 2, 2, 2, 0, 0, 0, 3, 1, 1, 3, 4]


In [7]:
"""
Epoch    1| Loss: 1.6502
Epoch   51| Loss: 0.5745
Epoch  101| Loss: 0.2626
Epoch  151| Loss: 0.1694
Epoch  201| Loss: 0.1005
Epoch  251| Loss: 0.0651
Epoch  301| Loss: 0.0441
"""

'\nEpoch    1| Loss: 1.6502\nEpoch   51| Loss: 0.5745\nEpoch  101| Loss: 0.2626\nEpoch  151| Loss: 0.1694\nEpoch  201| Loss: 0.1005\nEpoch  251| Loss: 0.0651\nEpoch  301| Loss: 0.0441\n'

In [9]:
batch_idx, (sentences, labels) = next(enumerate(data_loader))
inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(DEVICE)
tokens_list = [tokenizer.convert_ids_to_tokens(ids) for ids in inputs["input_ids"]]

model.eval()
_, attention = model.get_attention_output(**inputs)
cls_attn = attention[:, 0, :]
print(attention.shape)
print(cls_attn.sum())

for tokens, attn in zip(tokens_list, cls_attn):
    for t in tokens:
        print(f"{t:>15}", end="")
    print()
    for a in attn.tolist():
        print(f"{a:15.4f}", end="")
    print()

ValueError: too many values to unpack (expected 2)