In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm, trange

TRAIN_PATH = "absc_data/train.csv"
TEST_PATH = "absc_data/test.csv"

In [2]:
train = pd.read_csv(TRAIN_PATH)
train, val = train_test_split(train, test_size=0.25, random_state=42)
test = pd.read_csv(TEST_PATH)

print(len(train), len(val), len(test))

3630 1210 1211


In [3]:
# Constants

BATCH_SIZE = 64
CLASSES = train['Polarity'].unique()
CLASS2INDEX = {c: i for i, c in enumerate(CLASSES)}
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 30
LEARNING_RATE = 1e-3
MAX_LEN = 128
NUM_CLASSES = 4
NUM_CLASSES = len(CLASSES)

In [4]:
class BertDataset(Dataset):
    def __init__(self, data, tokenizer, *, max_len=128):
        super(BertDataset, self).__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
    # END __init__

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        text = f"{row['Aspect Term']} [SEP] {row['Sentence']}"
        label = NUM_CLASSES * [0]
        label[CLASS2INDEX[row['Polarity']]] = 1

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
            truncation=True,
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'label': torch.tensor(label)
        }
    # END __getitem__
# END BertDataset

In [5]:
class BertModel(nn.Module):
    def __init__(self, model_name, num_classes):
        super(BertModel, self).__init__()
        self.bert = transformers.BertModel.from_pretrained(model_name)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, num_classes)
    # END __init__

    def forward(self, input_ids, attention_mask, token_type_ids):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=False
        )
        output = self.drop(pooled_output)
        return self.out(output)
    # END forward
# END BertModel

In [6]:
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
train_dataset = BertDataset(train, tokenizer, max_len=MAX_LEN)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_dataset = BertDataset(val, tokenizer, max_len=MAX_LEN)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=BATCH_SIZE,
)

In [7]:
# from pprint import pp
# first = train_dataset[0]
# pp(first)

In [None]:
model = BertModel("bert-base-uncased", NUM_CLASSES)
model.load_state_dict(torch.load("model.pth", weights_only=True))

model = model.to(DEVICE, non_blocking=True)
for bert_param in model.bert.parameters():
    bert_param.requires_grad = False

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [9]:
best_acc = 0
for epoch in trange(EPOCHS):
    tqdm.write(f"\nEpoch {epoch + 1}/{EPOCHS}")

    correct = 0
    train_loss = 0

    tqdm.write("Training...")
    model.train()
    for i, batch in tqdm(enumerate(train_loader),
                         leave=False,
                         total=len(train_loader),
                         colour='magenta'):
        input_ids = batch['input_ids'].to(DEVICE, non_blocking=True)
        attention_mask = batch['attention_mask'].to(DEVICE, non_blocking=True)
        token_type_ids = batch['token_type_ids'].to(DEVICE, non_blocking=True)
        label = batch['label'].to(DEVICE, non_blocking=True)

        optimizer.zero_grad()
        output = model(input_ids, attention_mask, token_type_ids)
        label = label.type_as(output)
        loss = loss_fn(output, label)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        _, pred = torch.max(output, dim=1)
        correct += torch.sum(pred == torch.argmax(label, dim=1)).item()
    # END for i, batch in enumerate(train_loader)

    tqdm.write(f"Training Loss:            {train_loss}")
    tqdm.write(f"Training Accuracy:        {correct / len(train)}")

    correct = 0
    val_loss = 0

    tqdm.write("Validating...")
    with torch.inference_mode():
        model.eval()
        for batch in tqdm(val_loader, leave=False, colour='green'):
            input_ids = batch['input_ids'].to(DEVICE, non_blocking=True)
            attention_mask = batch['attention_mask'].to(
                DEVICE, non_blocking=True)
            token_type_ids = batch['token_type_ids'].to(
                DEVICE, non_blocking=True)
            label = batch['label'].to(DEVICE, non_blocking=True)

            output = model(input_ids, attention_mask, token_type_ids)
            label = label.type_as(output)
            loss = loss_fn(output, label)
            val_loss += loss.item()

            _, pred = torch.max(output, dim=1)
            correct += torch.sum(pred == torch.argmax(label, dim=1)).item()
        # END for batch in val_loader
    # END with torch.inference_mode()

    val_acc = correct / len(val)

    tqdm.write(f"Validation Loss:          {val_loss}")
    tqdm.write(f"Validation Accuracy:      {val_acc}")
    tqdm.write(f"Best Validation Accuracy: {best_acc}")

    if val_acc > best_acc:
        tqdm.write("Saving model...\n")
        best_acc = val_acc
        torch.save(model.state_dict(), "model.pth")
    # END if val_loss < best_loss
# END for epoch in trange(EPOCHS)

  0%|          | 0/30 [00:00<?, ?it/s]


Epoch 1/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            62.90908396244049
Training Accuracy:        0.5179063360881543
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          20.457569122314453
Validation Accuracy:      0.5198347107438016
Best Validation Accuracy: 0
Saving model...


Epoch 2/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            59.84575939178467
Training Accuracy:        0.5432506887052342
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          18.894426345825195
Validation Accuracy:      0.556198347107438
Best Validation Accuracy: 0.5198347107438016
Saving model...


Epoch 3/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            57.814793825149536
Training Accuracy:        0.5702479338842975
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          18.497022807598114
Validation Accuracy:      0.5685950413223141
Best Validation Accuracy: 0.556198347107438
Saving model...


Epoch 4/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            56.88080567121506
Training Accuracy:        0.5782369146005509
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          18.513960242271423
Validation Accuracy:      0.5628099173553719
Best Validation Accuracy: 0.5685950413223141

Epoch 5/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            56.765028297901154
Training Accuracy:        0.5826446280991735
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          18.896151781082153
Validation Accuracy:      0.5396694214876033
Best Validation Accuracy: 0.5685950413223141

Epoch 6/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            56.2214418053627
Training Accuracy:        0.5931129476584022
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.589060187339783
Validation Accuracy:      0.6239669421487604
Best Validation Accuracy: 0.5685950413223141
Saving model...


Epoch 7/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            55.44809937477112
Training Accuracy:        0.5988980716253444
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.83389723300934
Validation Accuracy:      0.5867768595041323
Best Validation Accuracy: 0.6239669421487604

Epoch 8/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            54.97909379005432
Training Accuracy:        0.5964187327823691
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.185012459754944
Validation Accuracy:      0.6338842975206611
Best Validation Accuracy: 0.6239669421487604
Saving model...


Epoch 9/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            54.200892329216
Training Accuracy:        0.61267217630854
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.04141479730606
Validation Accuracy:      0.6429752066115703
Best Validation Accuracy: 0.6338842975206611
Saving model...


Epoch 10/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            54.040095806121826
Training Accuracy:        0.6123966942148761
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.17557603120804
Validation Accuracy:      0.6462809917355372
Best Validation Accuracy: 0.6429752066115703
Saving model...


Epoch 11/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.50017488002777
Training Accuracy:        0.609366391184573
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.831595599651337
Validation Accuracy:      0.6553719008264463
Best Validation Accuracy: 0.6462809917355372
Saving model...


Epoch 12/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.685641050338745
Training Accuracy:        0.6214876033057851
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          17.07013750076294
Validation Accuracy:      0.6396694214876033
Best Validation Accuracy: 0.6553719008264463

Epoch 13/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.457115173339844
Training Accuracy:        0.6104683195592286
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.662864685058594
Validation Accuracy:      0.6537190082644628
Best Validation Accuracy: 0.6553719008264463

Epoch 14/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.23262292146683
Training Accuracy:        0.6143250688705234
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.63130784034729
Validation Accuracy:      0.6380165289256199
Best Validation Accuracy: 0.6553719008264463

Epoch 15/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.371110677719116
Training Accuracy:        0.615426997245179
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.706717669963837
Validation Accuracy:      0.6388429752066116
Best Validation Accuracy: 0.6553719008264463

Epoch 16/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.11155146360397
Training Accuracy:        0.615426997245179
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.46896457672119
Validation Accuracy:      0.643801652892562
Best Validation Accuracy: 0.6553719008264463

Epoch 17/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.76141905784607
Training Accuracy:        0.6212121212121212
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.334835648536682
Validation Accuracy:      0.6677685950413224
Best Validation Accuracy: 0.6553719008264463
Saving model...


Epoch 18/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.625972509384155
Training Accuracy:        0.615426997245179
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.334219217300415
Validation Accuracy:      0.6677685950413224
Best Validation Accuracy: 0.6677685950413224

Epoch 19/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            53.00153458118439
Training Accuracy:        0.6253443526170799
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.279368817806244
Validation Accuracy:      0.6743801652892562
Best Validation Accuracy: 0.6677685950413224
Saving model...


Epoch 20/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.40816056728363
Training Accuracy:        0.6258953168044077
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.28353500366211
Validation Accuracy:      0.6669421487603305
Best Validation Accuracy: 0.6743801652892562

Epoch 21/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.364370942115784
Training Accuracy:        0.6264462809917355
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.155362367630005
Validation Accuracy:      0.6677685950413224
Best Validation Accuracy: 0.6743801652892562

Epoch 22/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.491558372974396
Training Accuracy:        0.6214876033057851
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.159138560295105
Validation Accuracy:      0.6661157024793388
Best Validation Accuracy: 0.6743801652892562

Epoch 23/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.66130083799362
Training Accuracy:        0.6267217630853994
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.222151935100555
Validation Accuracy:      0.6512396694214876
Best Validation Accuracy: 0.6743801652892562

Epoch 24/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.51752734184265
Training Accuracy:        0.6179063360881543
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.27602595090866
Validation Accuracy:      0.6454545454545455
Best Validation Accuracy: 0.6743801652892562

Epoch 25/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.393547773361206
Training Accuracy:        0.6201101928374656
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.048708260059357
Validation Accuracy:      0.6611570247933884
Best Validation Accuracy: 0.6743801652892562

Epoch 26/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            51.93385946750641
Training Accuracy:        0.6203856749311295
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          15.921619355678558
Validation Accuracy:      0.6727272727272727
Best Validation Accuracy: 0.6743801652892562

Epoch 27/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.36105525493622
Training Accuracy:        0.6198347107438017
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          15.979498744010925
Validation Accuracy:      0.6661157024793388
Best Validation Accuracy: 0.6743801652892562

Epoch 28/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.0375674366951
Training Accuracy:        0.6275482093663912
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.0925110578537
Validation Accuracy:      0.6727272727272727
Best Validation Accuracy: 0.6743801652892562

Epoch 29/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            52.100785315036774
Training Accuracy:        0.6212121212121212
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.024693489074707
Validation Accuracy:      0.6661157024793388
Best Validation Accuracy: 0.6743801652892562

Epoch 30/30
Training...


  0%|          | 0/57 [00:00<?, ?it/s]

Training Loss:            51.76095658540726
Training Accuracy:        0.6300275482093664
Validating...


  0%|          | 0/19 [00:00<?, ?it/s]

Validation Loss:          16.248477399349213
Validation Accuracy:      0.6471074380165289
Best Validation Accuracy: 0.6743801652892562


In [10]:
test_dataset = BertDataset(test, tokenizer, max_len=MAX_LEN)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
)

In [11]:
model.load_state_dict(torch.load("model.pth", weights_only=True))
model.eval()

correct = 0
test_loss = 0

with torch.inference_mode():
    for batch in tqdm(test_loader, colour='cyan'):
        input_ids = batch['input_ids'].to(DEVICE, non_blocking=True)
        attention_mask = batch['attention_mask'].to(DEVICE, non_blocking=True)
        token_type_ids = batch['token_type_ids'].to(DEVICE, non_blocking=True)
        label = batch['label'].to(DEVICE, non_blocking=True)

        output = model(input_ids, attention_mask, token_type_ids)
        label = label.type_as(output)
        loss = loss_fn(output, label)
        test_loss += loss.item()

        _, pred = torch.max(output, dim=1)
        correct += torch.sum(pred == torch.argmax(label, dim=1)).item()
    # END for batch in test_loader
# END with torch.inference_mode()

tqdm.write(f"Test Loss:     {test_loss}")
tqdm.write(f"Test Accuracy: {correct / len(test)}")

  0%|          | 0/19 [00:00<?, ?it/s]

Test Loss:     16.089314579963684
Test Accuracy: 0.6416184971098265
