# Import dependencies

In [1]:
import os
import sys

sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [2]:
import time
import gc
import json

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

from torch import cuda

from absa_models.ABSA_BERT_Dropout_BiLSTM_Linear import ABSA_BERT_Dropout_BiLSTM_Linear

In [3]:
device = 'cuda' if cuda.is_available() else 'cpu'
logging.set_verbosity_error() 

In [4]:
print(torch.cuda.get_device_name(0))
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [5]:
def clear_memory():
    torch.cuda.empty_cache()

    with torch.no_grad():
        torch.cuda.empty_cache()

    gc.collect()

# Load Data

In [6]:
DATASET = 'ABSA_SemEval16_Restaurants_train.json'

In [7]:
df = pd.json_normalize(json.load(open(DATASET)))

In [8]:
df.head()

Unnamed: 0,text,tokens,absa_tags
0,Judging from previous posts this used to be a ...,"[Judging, from, previous, posts, this, used, t...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
1,"We, there were four of us, arrived at noon - t...","[We, ,, there, were, four, of, us, ,, arrived,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"They never brought us complimentary noodles, i...","[They, never, brought, us, complimentary, nood...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"
4,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"


# Train & Validate

In [9]:
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

EPOCHS = 2

LEARNING_RATE = 1e-5

TRAIN_SPLIT = 0.8

NO_RUNS = 10

In [10]:
MODEL_OUTPUT = '../../../results/ABSA/SemEval16 - Task 5 - Restaurants/models/bert_pre_trained_dropout_bilstm_linear.pth'
STATS_OUTPUT = '../../../results/ABSA/SemEval16 - Task 5 - Restaurants/stats/bert_pre_trained_dropout_bilstm_linear.csv'

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [12]:
best_accuracy = 0.0

In [13]:
def create_mini_batch(samples):
    ids_tensors = [s[1] for s in samples]
    ids_tensors = pad_sequence(ids_tensors, batch_first=True).to(device)

    tags_tensors = [s[2] for s in samples]
    tags_tensors = pad_sequence(tags_tensors, batch_first=True).to(device)
    
    masks_tensors = torch.zeros(ids_tensors.shape, dtype=torch.long).to(device)
    masks_tensors = masks_tensors.masked_fill(ids_tensors != 0, 1).to(device)
    
    return ids_tensors, tags_tensors, masks_tensors

In [14]:
def train(epoch, model, loss_fn, optimizer, dataloader):
    model.train()

    dataloader_len = len(dataloader)

    for _,data in enumerate(dataloader, 0):
        optimizer.zero_grad()

        ids_tensors, tags_tensors, masks_tensors = data

        outputs = model(ids_tensors, masks_tensors)

        loss = loss_fn(outputs.view(-1, 4), tags_tensors.view(-1))
        
        if _ % (dataloader_len // 10) == 0:
            print(f"Epoch: {epoch}/{EPOCHS}, Batch: {_}/{dataloader_len}, Loss: {loss.item()}")
        
        loss.backward()
        
        optimizer.step()

In [15]:
def validation(model, dataloader):
    model.eval()
    
    fin_targets=[]
    fin_outputs=[]

    with torch.no_grad():
        for _, data in enumerate(dataloader, 0):
            ids_tensors, tags_tensors, masks_tensors = data
            ids_tensors = ids_tensors.to(device)
            tags_tensors = tags_tensors.to(device)
            masks_tensors = masks_tensors.to(device)

            outputs = model(ids_tensors, masks_tensors)
            
            _, predictions = torch.max(outputs, dim=2)

            fin_outputs += list([int(p) for pred in predictions for p in pred])
            fin_targets += list([int(tag) for tags_tensor in tags_tensors for tag in tags_tensor])

    return fin_outputs, fin_targets

In [16]:
results = pd.DataFrame(columns=['accuracy','precision_score_micro','precision_score_macro','recall_score_micro','recall_score_macro','f1_score_micro','f1_score_macro', 'execution_time'])

In [17]:
for i in range(NO_RUNS):
    # clear cache cuda
    torch.cuda.empty_cache()
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()

    start_time = time.time()

    print(f"Run {i + 1}/{NO_RUNS}")

    train_dataset = df.sample(frac=TRAIN_SPLIT)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        sampler = RandomSampler(train_dataset),
        batch_size = TRAIN_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    validation_dataloader = DataLoader(
        testing_set,
        sampler = SequentialSampler(testing_set),
        batch_size = VALID_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    model = ABSA_BERT_Dropout_BiLSTM_Linear(BertModel.from_pretrained('bert-base-uncased'), dropout=0.3, bilstm_in_features=256, no_out_labels=4, device=device).to(device)

    optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        train(epoch, model, loss_fn, optimizer, train_dataloader)

    outputs, targets = validation(model, validation_dataloader)
    
    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model, MODEL_OUTPUT)

    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3372317552566528
Epoch: 0/2, Batch: 50/501, Loss: 0.26957133412361145
Epoch: 0/2, Batch: 100/501, Loss: 0.18267442286014557
Epoch: 0/2, Batch: 150/501, Loss: 0.5470376014709473
Epoch: 0/2, Batch: 200/501, Loss: 0.22714264690876007
Epoch: 0/2, Batch: 250/501, Loss: 0.18650703132152557
Epoch: 0/2, Batch: 300/501, Loss: 0.2192768007516861
Epoch: 0/2, Batch: 350/501, Loss: 0.16719886660575867
Epoch: 0/2, Batch: 400/501, Loss: 0.09343289583921432
Epoch: 0/2, Batch: 450/501, Loss: 0.05971622094511986
Epoch: 0/2, Batch: 500/501, Loss: 0.05397642403841019
Epoch: 1/2, Batch: 0/501, Loss: 0.22075191140174866
Epoch: 1/2, Batch: 50/501, Loss: 0.08641530573368073
Epoch: 1/2, Batch: 100/501, Loss: 0.0952942967414856
Epoch: 1/2, Batch: 150/501, Loss: 0.11962279677391052
Epoch: 1/2, Batch: 200/501, Loss: 0.08681264519691467
Epoch: 1/2, Batch: 250/501, Loss: 0.3792873024940491
Epoch: 1/2, Batch: 300/501, Loss: 0.07292864471673965
Epoch: 1/2, Batch: 350/501, Lo

  _warn_prf(average, modifier, msg_start, len(result))


Run 2/10
Epoch: 0/2, Batch: 0/501, Loss: 1.4734156131744385
Epoch: 0/2, Batch: 50/501, Loss: 0.31056833267211914
Epoch: 0/2, Batch: 100/501, Loss: 0.29673612117767334
Epoch: 0/2, Batch: 150/501, Loss: 0.6667531728744507
Epoch: 0/2, Batch: 200/501, Loss: 0.3433379530906677
Epoch: 0/2, Batch: 250/501, Loss: 0.4033827781677246
Epoch: 0/2, Batch: 300/501, Loss: 0.31040453910827637
Epoch: 0/2, Batch: 350/501, Loss: 0.1484968662261963
Epoch: 0/2, Batch: 400/501, Loss: 0.23793640732765198
Epoch: 0/2, Batch: 450/501, Loss: 0.14800694584846497
Epoch: 0/2, Batch: 500/501, Loss: 0.2618422210216522
Epoch: 1/2, Batch: 0/501, Loss: 0.10553143918514252
Epoch: 1/2, Batch: 50/501, Loss: 0.2981967329978943
Epoch: 1/2, Batch: 100/501, Loss: 0.22884051501750946
Epoch: 1/2, Batch: 150/501, Loss: 0.15063782036304474
Epoch: 1/2, Batch: 200/501, Loss: 0.1385904848575592
Epoch: 1/2, Batch: 250/501, Loss: 0.12228793650865555
Epoch: 1/2, Batch: 300/501, Loss: 0.19119806587696075
Epoch: 1/2, Batch: 350/501, Loss:

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/501, Loss: 1.5055649280548096
Epoch: 0/2, Batch: 50/501, Loss: 0.3573801517486572
Epoch: 0/2, Batch: 100/501, Loss: 0.2724635601043701
Epoch: 0/2, Batch: 150/501, Loss: 0.3449219763278961
Epoch: 0/2, Batch: 200/501, Loss: 0.25143200159072876
Epoch: 0/2, Batch: 250/501, Loss: 0.3694571554660797
Epoch: 0/2, Batch: 300/501, Loss: 0.14510761201381683
Epoch: 0/2, Batch: 350/501, Loss: 0.16649052500724792
Epoch: 0/2, Batch: 400/501, Loss: 0.13574853539466858
Epoch: 0/2, Batch: 450/501, Loss: 0.11637713760137558
Epoch: 0/2, Batch: 500/501, Loss: 0.15276692807674408
Epoch: 1/2, Batch: 0/501, Loss: 0.33111268281936646
Epoch: 1/2, Batch: 50/501, Loss: 0.07547174394130707
Epoch: 1/2, Batch: 100/501, Loss: 0.15486830472946167
Epoch: 1/2, Batch: 150/501, Loss: 0.11387661099433899
Epoch: 1/2, Batch: 200/501, Loss: 0.1502121090888977
Epoch: 1/2, Batch: 250/501, Loss: 0.10882952064275742
Epoch: 1/2, Batch: 300/501, Loss: 0.23007673025131226
Epoch: 1/2, Batch: 350/501, Los

  _warn_prf(average, modifier, msg_start, len(result))


Run 4/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3243271112442017
Epoch: 0/2, Batch: 50/501, Loss: 0.3837937116622925
Epoch: 0/2, Batch: 100/501, Loss: 0.4946613907814026
Epoch: 0/2, Batch: 150/501, Loss: 0.25077465176582336
Epoch: 0/2, Batch: 200/501, Loss: 0.27100756764411926
Epoch: 0/2, Batch: 250/501, Loss: 0.22920796275138855
Epoch: 0/2, Batch: 300/501, Loss: 0.35936474800109863
Epoch: 0/2, Batch: 350/501, Loss: 0.17409077286720276
Epoch: 0/2, Batch: 400/501, Loss: 0.15364164113998413
Epoch: 0/2, Batch: 450/501, Loss: 0.11549869179725647
Epoch: 0/2, Batch: 500/501, Loss: 0.1338236927986145
Epoch: 1/2, Batch: 0/501, Loss: 0.1206924170255661
Epoch: 1/2, Batch: 50/501, Loss: 0.21621465682983398
Epoch: 1/2, Batch: 100/501, Loss: 0.04294122755527496
Epoch: 1/2, Batch: 150/501, Loss: 0.1029665544629097
Epoch: 1/2, Batch: 200/501, Loss: 0.13119420409202576
Epoch: 1/2, Batch: 250/501, Loss: 0.12965109944343567
Epoch: 1/2, Batch: 300/501, Loss: 0.0829215720295906
Epoch: 1/2, Batch: 350/501, Loss

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/501, Loss: 1.2764060497283936
Epoch: 0/2, Batch: 50/501, Loss: 0.7209184169769287
Epoch: 0/2, Batch: 100/501, Loss: 0.36739465594291687
Epoch: 0/2, Batch: 150/501, Loss: 0.38248735666275024
Epoch: 0/2, Batch: 200/501, Loss: 0.19883334636688232
Epoch: 0/2, Batch: 250/501, Loss: 0.20351478457450867
Epoch: 0/2, Batch: 300/501, Loss: 0.13055092096328735
Epoch: 0/2, Batch: 350/501, Loss: 0.10080907493829727
Epoch: 0/2, Batch: 400/501, Loss: 0.14913742244243622
Epoch: 0/2, Batch: 450/501, Loss: 0.20382894575595856
Epoch: 0/2, Batch: 500/501, Loss: 0.3269672393798828
Epoch: 1/2, Batch: 0/501, Loss: 0.295542448759079
Epoch: 1/2, Batch: 50/501, Loss: 0.09146575629711151
Epoch: 1/2, Batch: 100/501, Loss: 0.1619015336036682
Epoch: 1/2, Batch: 150/501, Loss: 0.05600066855549812
Epoch: 1/2, Batch: 200/501, Loss: 0.08621595054864883
Epoch: 1/2, Batch: 250/501, Loss: 0.14950478076934814
Epoch: 1/2, Batch: 300/501, Loss: 0.1776418536901474
Epoch: 1/2, Batch: 350/501, Loss

  _warn_prf(average, modifier, msg_start, len(result))


Run 6/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3815734386444092
Epoch: 0/2, Batch: 50/501, Loss: 0.24794384837150574
Epoch: 0/2, Batch: 100/501, Loss: 0.5780226588249207
Epoch: 0/2, Batch: 150/501, Loss: 0.2936543822288513
Epoch: 0/2, Batch: 200/501, Loss: 0.21889759600162506
Epoch: 0/2, Batch: 250/501, Loss: 0.23975956439971924
Epoch: 0/2, Batch: 300/501, Loss: 0.223497211933136
Epoch: 0/2, Batch: 350/501, Loss: 0.1888204962015152
Epoch: 0/2, Batch: 400/501, Loss: 0.15995334088802338
Epoch: 0/2, Batch: 450/501, Loss: 0.03576294332742691
Epoch: 0/2, Batch: 500/501, Loss: 0.1768287718296051
Epoch: 1/2, Batch: 0/501, Loss: 0.28987035155296326
Epoch: 1/2, Batch: 50/501, Loss: 0.1357325315475464
Epoch: 1/2, Batch: 100/501, Loss: 0.16757024824619293
Epoch: 1/2, Batch: 150/501, Loss: 0.04457523301243782
Epoch: 1/2, Batch: 200/501, Loss: 0.25879761576652527
Epoch: 1/2, Batch: 250/501, Loss: 0.11780251562595367
Epoch: 1/2, Batch: 300/501, Loss: 0.13109520077705383
Epoch: 1/2, Batch: 350/501, Loss:

  _warn_prf(average, modifier, msg_start, len(result))


Run 7/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3527722358703613
Epoch: 0/2, Batch: 50/501, Loss: 0.33541980385780334
Epoch: 0/2, Batch: 100/501, Loss: 0.4073351323604584
Epoch: 0/2, Batch: 150/501, Loss: 0.16363930702209473
Epoch: 0/2, Batch: 200/501, Loss: 0.31986451148986816
Epoch: 0/2, Batch: 250/501, Loss: 0.478752076625824
Epoch: 0/2, Batch: 300/501, Loss: 0.20075854659080505
Epoch: 0/2, Batch: 350/501, Loss: 0.18070539832115173
Epoch: 0/2, Batch: 400/501, Loss: 0.148182213306427
Epoch: 0/2, Batch: 450/501, Loss: 0.08441294729709625
Epoch: 0/2, Batch: 500/501, Loss: 0.09012862294912338
Epoch: 1/2, Batch: 0/501, Loss: 0.08977865427732468
Epoch: 1/2, Batch: 50/501, Loss: 0.23136943578720093
Epoch: 1/2, Batch: 100/501, Loss: 0.08375988900661469
Epoch: 1/2, Batch: 150/501, Loss: 0.16893774271011353
Epoch: 1/2, Batch: 200/501, Loss: 0.20095117390155792
Epoch: 1/2, Batch: 250/501, Loss: 0.1269429624080658
Epoch: 1/2, Batch: 300/501, Loss: 0.09208545088768005
Epoch: 1/2, Batch: 350/501, Loss

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3720585107803345
Epoch: 0/2, Batch: 50/501, Loss: 0.7799583673477173
Epoch: 0/2, Batch: 100/501, Loss: 0.2889355421066284
Epoch: 0/2, Batch: 150/501, Loss: 0.3383268713951111
Epoch: 0/2, Batch: 200/501, Loss: 0.4101695716381073
Epoch: 0/2, Batch: 250/501, Loss: 0.16904692351818085
Epoch: 0/2, Batch: 300/501, Loss: 0.13360919058322906
Epoch: 0/2, Batch: 350/501, Loss: 0.15092360973358154
Epoch: 0/2, Batch: 400/501, Loss: 0.27866610884666443
Epoch: 0/2, Batch: 450/501, Loss: 0.22742612659931183
Epoch: 0/2, Batch: 500/501, Loss: 0.13091197609901428
Epoch: 1/2, Batch: 0/501, Loss: 0.045281846076250076
Epoch: 1/2, Batch: 50/501, Loss: 0.1491203010082245
Epoch: 1/2, Batch: 100/501, Loss: 0.196173757314682
Epoch: 1/2, Batch: 150/501, Loss: 0.16922904551029205
Epoch: 1/2, Batch: 200/501, Loss: 0.1682540625333786
Epoch: 1/2, Batch: 250/501, Loss: 0.06089256331324577
Epoch: 1/2, Batch: 300/501, Loss: 0.11267917603254318
Epoch: 1/2, Batch: 350/501, Loss:

  _warn_prf(average, modifier, msg_start, len(result))


Run 9/10
Epoch: 0/2, Batch: 0/501, Loss: 1.4027537107467651
Epoch: 0/2, Batch: 50/501, Loss: 0.2765120565891266
Epoch: 0/2, Batch: 100/501, Loss: 0.3133908212184906
Epoch: 0/2, Batch: 150/501, Loss: 0.4022553563117981
Epoch: 0/2, Batch: 200/501, Loss: 0.25325843691825867
Epoch: 0/2, Batch: 250/501, Loss: 0.32814452052116394
Epoch: 0/2, Batch: 300/501, Loss: 0.24985751509666443
Epoch: 0/2, Batch: 350/501, Loss: 0.18277513980865479
Epoch: 0/2, Batch: 400/501, Loss: 0.13613088428974152
Epoch: 0/2, Batch: 450/501, Loss: 0.20344184339046478
Epoch: 0/2, Batch: 500/501, Loss: 0.1271086186170578
Epoch: 1/2, Batch: 0/501, Loss: 0.0940302312374115
Epoch: 1/2, Batch: 50/501, Loss: 0.1050601676106453
Epoch: 1/2, Batch: 100/501, Loss: 0.1495361179113388
Epoch: 1/2, Batch: 150/501, Loss: 0.04615339636802673
Epoch: 1/2, Batch: 200/501, Loss: 0.14855574071407318
Epoch: 1/2, Batch: 250/501, Loss: 0.20184896886348724
Epoch: 1/2, Batch: 300/501, Loss: 0.15082241594791412
Epoch: 1/2, Batch: 350/501, Loss:

  _warn_prf(average, modifier, msg_start, len(result))


Run 10/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3867647647857666
Epoch: 0/2, Batch: 50/501, Loss: 0.6434510350227356
Epoch: 0/2, Batch: 100/501, Loss: 0.37690138816833496
Epoch: 0/2, Batch: 150/501, Loss: 0.19747866690158844
Epoch: 0/2, Batch: 200/501, Loss: 0.19600293040275574
Epoch: 0/2, Batch: 250/501, Loss: 0.17990685999393463
Epoch: 0/2, Batch: 300/501, Loss: 0.1686510294675827
Epoch: 0/2, Batch: 350/501, Loss: 0.19198372960090637
Epoch: 0/2, Batch: 400/501, Loss: 0.13128529489040375
Epoch: 0/2, Batch: 450/501, Loss: 0.1493339091539383
Epoch: 0/2, Batch: 500/501, Loss: 0.34613683819770813
Epoch: 1/2, Batch: 0/501, Loss: 0.09689463675022125
Epoch: 1/2, Batch: 50/501, Loss: 0.13778163492679596
Epoch: 1/2, Batch: 100/501, Loss: 0.055622611194849014
Epoch: 1/2, Batch: 150/501, Loss: 0.17002885043621063
Epoch: 1/2, Batch: 200/501, Loss: 0.20751795172691345
Epoch: 1/2, Batch: 250/501, Loss: 0.1467234194278717
Epoch: 1/2, Batch: 300/501, Loss: 0.11128871887922287
Epoch: 1/2, Batch: 350/501, 

  _warn_prf(average, modifier, msg_start, len(result))


In [18]:
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.962472,0.962472,0.544023,0.962472,0.486541,0.962472,0.4502,112.819
1,0.967998,0.967998,0.586315,0.967998,0.595147,0.967998,0.590647,97.074863
2,0.969155,0.969155,0.590735,0.969155,0.621731,0.969155,0.604941,97.658057
3,0.955559,0.955559,0.598544,0.955559,0.533716,0.955559,0.518666,96.817216
4,0.965611,0.965611,0.582737,0.965611,0.598223,0.965611,0.587395,97.123147
5,0.958143,0.958143,0.413822,0.958143,0.466397,0.958143,0.436695,96.497204
6,0.970011,0.970011,0.603725,0.970011,0.572069,0.970011,0.577821,97.277734
7,0.966017,0.966017,0.572028,0.966017,0.539037,0.966017,0.54776,97.449983
8,0.964996,0.964996,0.622778,0.964996,0.508364,0.964996,0.50017,96.985469
9,0.958703,0.958703,0.654543,0.958703,0.499106,0.958703,0.48018,97.651367


In [19]:
results.to_csv(STATS_OUTPUT)