# Import dependencies

In [1]:
import os
import sys

sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [2]:
import time
import gc
import json

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

from torch import cuda

from absa_models.ABSA_BERT_Dropout_Linear import ABSA_BERT_Dropout_Linear

In [3]:
device = 'cuda' if cuda.is_available() else 'cpu'
logging.set_verbosity_error() 

In [4]:
print(torch.cuda.get_device_name(0))
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [5]:
def clear_memory():
    torch.cuda.empty_cache()

    with torch.no_grad():
        torch.cuda.empty_cache()

    gc.collect()

# Load Data

In [6]:
DATASET = 'ABSA_SemEval16_Restaurants_train.json'

In [7]:
df = pd.json_normalize(json.load(open(DATASET)))

In [8]:
df.head()

Unnamed: 0,text,tokens,absa_tags
0,Judging from previous posts this used to be a ...,"[Judging, from, previous, posts, this, used, t...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
1,"We, there were four of us, arrived at noon - t...","[We, ,, there, were, four, of, us, ,, arrived,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"They never brought us complimentary noodles, i...","[They, never, brought, us, complimentary, nood...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"
4,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"


# Train & Validate

In [9]:
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

EPOCHS = 2

LEARNING_RATE = 1e-5

TRAIN_SPLIT = 0.8

NO_RUNS = 10

In [10]:
BERT_FINE_TUNED_OUTPUT = '../../../results/ABSA/SemEval16 - Task 5 - Restaurants/models/bert_fine_tuned.pth'

In [11]:
MODEL_OUTPUT = '../../../results/ABSA/SemEval16 - Task 5 - Restaurants/models/bert_pre_trained_dropout_linear.pth'
STATS_OUTPUT = '../../../results/ABSA/SemEval16 - Task 5 - Restaurants/stats/bert_pre_trained_dropout_linear.csv'

In [12]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [13]:
best_accuracy = 0.0

In [14]:
def create_mini_batch(samples):
    ids_tensors = [s[1] for s in samples]
    ids_tensors = pad_sequence(ids_tensors, batch_first=True).to(device)

    tags_tensors = [s[2] for s in samples]
    tags_tensors = pad_sequence(tags_tensors, batch_first=True).to(device)
    
    masks_tensors = torch.zeros(ids_tensors.shape, dtype=torch.long).to(device)
    masks_tensors = masks_tensors.masked_fill(ids_tensors != 0, 1).to(device)
    
    return ids_tensors, tags_tensors, masks_tensors

In [15]:
def train(epoch, model, loss_fn, optimizer, dataloader):
    model.train()

    dataloader_len = len(dataloader)

    for _,data in enumerate(dataloader, 0):
        optimizer.zero_grad()

        ids_tensors, tags_tensors, masks_tensors = data

        outputs = model(ids_tensors, masks_tensors)

        loss = loss_fn(outputs.view(-1, 4), tags_tensors.view(-1))
        
        if _ % (dataloader_len // 10) == 0:
            print(f"Epoch: {epoch}/{EPOCHS}, Batch: {_}/{dataloader_len}, Loss: {loss.item()}")
        
        loss.backward()
        
        optimizer.step()

In [16]:
def validation(model, dataloader):
    model.eval()
    
    fin_targets=[]
    fin_outputs=[]

    with torch.no_grad():
        for _, data in enumerate(dataloader, 0):
            ids_tensors, tags_tensors, masks_tensors = data
            ids_tensors = ids_tensors.to(device)
            tags_tensors = tags_tensors.to(device)
            masks_tensors = masks_tensors.to(device)

            outputs = model(ids_tensors, masks_tensors)
            
            _, predictions = torch.max(outputs, dim=2)

            fin_outputs += list([int(p) for pred in predictions for p in pred])
            fin_targets += list([int(tag) for tags_tensor in tags_tensors for tag in tags_tensor])

    return fin_outputs, fin_targets

In [17]:
results = pd.DataFrame(columns=['accuracy','precision_score_micro','precision_score_macro','recall_score_micro','recall_score_macro','f1_score_micro','f1_score_macro', 'execution_time'])

In [18]:
for i in range(NO_RUNS):
    # clear cache cuda
    torch.cuda.empty_cache()
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()

    start_time = time.time()

    print(f"Run {i + 1}/{NO_RUNS}")

    train_dataset = df.sample(frac=TRAIN_SPLIT)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        sampler = RandomSampler(train_dataset),
        batch_size = TRAIN_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    validation_dataloader = DataLoader(
        testing_set,
        sampler = SequentialSampler(testing_set),
        batch_size = VALID_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    model = ABSA_BERT_Dropout_Linear(BertModel.from_pretrained('bert-base-uncased'), dropout=0.3, no_out_labels=4, device=device).to(device)

    optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        train(epoch, model, loss_fn, optimizer, train_dataloader)

    outputs, targets = validation(model, validation_dataloader)
    
    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.bert, BERT_FINE_TUNED_OUTPUT)
        torch.save(model, MODEL_OUTPUT)

    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/501, Loss: 1.425561547279358
Epoch: 0/2, Batch: 50/501, Loss: 0.4436885118484497
Epoch: 0/2, Batch: 100/501, Loss: 0.20202913880348206
Epoch: 0/2, Batch: 150/501, Loss: 0.29910609126091003
Epoch: 0/2, Batch: 200/501, Loss: 0.329155296087265
Epoch: 0/2, Batch: 250/501, Loss: 0.10395553708076477
Epoch: 0/2, Batch: 300/501, Loss: 0.10924636572599411
Epoch: 0/2, Batch: 350/501, Loss: 0.21530422568321228
Epoch: 0/2, Batch: 400/501, Loss: 0.06656711548566818
Epoch: 0/2, Batch: 450/501, Loss: 0.06816624104976654
Epoch: 0/2, Batch: 500/501, Loss: 0.11064442992210388
Epoch: 1/2, Batch: 0/501, Loss: 0.15480688214302063
Epoch: 1/2, Batch: 50/501, Loss: 0.034236203879117966
Epoch: 1/2, Batch: 100/501, Loss: 0.12474062293767929
Epoch: 1/2, Batch: 150/501, Loss: 0.09590430557727814
Epoch: 1/2, Batch: 200/501, Loss: 0.06961847096681595
Epoch: 1/2, Batch: 250/501, Loss: 0.17892388999462128
Epoch: 1/2, Batch: 300/501, Loss: 0.12749789655208588
Epoch: 1/2, Batch: 350/501, L

  _warn_prf(average, modifier, msg_start, len(result))


Run 2/10
Epoch: 0/2, Batch: 0/501, Loss: 1.6699203252792358
Epoch: 0/2, Batch: 50/501, Loss: 0.3681097626686096
Epoch: 0/2, Batch: 100/501, Loss: 0.4129892587661743
Epoch: 0/2, Batch: 150/501, Loss: 0.1622176170349121
Epoch: 0/2, Batch: 200/501, Loss: 0.1882394254207611
Epoch: 0/2, Batch: 250/501, Loss: 0.07583526521921158
Epoch: 0/2, Batch: 300/501, Loss: 0.16277767717838287
Epoch: 0/2, Batch: 350/501, Loss: 0.08312597870826721
Epoch: 0/2, Batch: 400/501, Loss: 0.029837941750884056
Epoch: 0/2, Batch: 450/501, Loss: 0.33205267786979675
Epoch: 0/2, Batch: 500/501, Loss: 0.08451253920793533
Epoch: 1/2, Batch: 0/501, Loss: 0.07834344357252121
Epoch: 1/2, Batch: 50/501, Loss: 0.09820043295621872
Epoch: 1/2, Batch: 100/501, Loss: 0.07585856318473816
Epoch: 1/2, Batch: 150/501, Loss: 0.18355832993984222
Epoch: 1/2, Batch: 200/501, Loss: 0.0540386326611042
Epoch: 1/2, Batch: 250/501, Loss: 0.0863758847117424
Epoch: 1/2, Batch: 300/501, Loss: 0.08957051485776901
Epoch: 1/2, Batch: 350/501, Los

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/501, Loss: 1.4225574731826782
Epoch: 0/2, Batch: 50/501, Loss: 0.35757768154144287
Epoch: 0/2, Batch: 100/501, Loss: 0.33134925365448
Epoch: 0/2, Batch: 150/501, Loss: 0.43298694491386414
Epoch: 0/2, Batch: 200/501, Loss: 0.23885276913642883
Epoch: 0/2, Batch: 250/501, Loss: 0.06208469346165657
Epoch: 0/2, Batch: 300/501, Loss: 0.034357696771621704
Epoch: 0/2, Batch: 350/501, Loss: 0.2663525342941284
Epoch: 0/2, Batch: 400/501, Loss: 0.13552376627922058
Epoch: 0/2, Batch: 450/501, Loss: 0.17041681706905365
Epoch: 0/2, Batch: 500/501, Loss: 0.09563403576612473
Epoch: 1/2, Batch: 0/501, Loss: 0.24745702743530273
Epoch: 1/2, Batch: 50/501, Loss: 0.060177527368068695
Epoch: 1/2, Batch: 100/501, Loss: 0.2502489387989044
Epoch: 1/2, Batch: 150/501, Loss: 0.08058343827724457
Epoch: 1/2, Batch: 200/501, Loss: 0.13486836850643158
Epoch: 1/2, Batch: 250/501, Loss: 0.0709654912352562
Epoch: 1/2, Batch: 300/501, Loss: 0.10875704884529114
Epoch: 1/2, Batch: 350/501, Lo

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/501, Loss: 1.2073097229003906
Epoch: 0/2, Batch: 50/501, Loss: 0.09565231204032898
Epoch: 0/2, Batch: 100/501, Loss: 0.204040065407753
Epoch: 0/2, Batch: 150/501, Loss: 0.12114951014518738
Epoch: 0/2, Batch: 200/501, Loss: 0.27892807126045227
Epoch: 0/2, Batch: 250/501, Loss: 0.12061385065317154
Epoch: 0/2, Batch: 300/501, Loss: 0.23567767441272736
Epoch: 0/2, Batch: 350/501, Loss: 0.0716610923409462
Epoch: 0/2, Batch: 400/501, Loss: 0.3098052740097046
Epoch: 0/2, Batch: 450/501, Loss: 0.1737252175807953
Epoch: 0/2, Batch: 500/501, Loss: 0.19318731129169464
Epoch: 1/2, Batch: 0/501, Loss: 0.034299056977033615
Epoch: 1/2, Batch: 50/501, Loss: 0.10275238007307053
Epoch: 1/2, Batch: 100/501, Loss: 0.04211714491248131
Epoch: 1/2, Batch: 150/501, Loss: 0.09699420630931854
Epoch: 1/2, Batch: 200/501, Loss: 0.07792932540178299
Epoch: 1/2, Batch: 250/501, Loss: 0.08042173832654953
Epoch: 1/2, Batch: 300/501, Loss: 0.0717870220541954
Epoch: 1/2, Batch: 350/501, Los

  _warn_prf(average, modifier, msg_start, len(result))


Run 6/10
Epoch: 0/2, Batch: 0/501, Loss: 1.4859004020690918
Epoch: 0/2, Batch: 50/501, Loss: 0.1725979447364807
Epoch: 0/2, Batch: 100/501, Loss: 0.3848600387573242
Epoch: 0/2, Batch: 150/501, Loss: 0.24990534782409668
Epoch: 0/2, Batch: 200/501, Loss: 0.1501498520374298
Epoch: 0/2, Batch: 250/501, Loss: 0.39123716950416565
Epoch: 0/2, Batch: 300/501, Loss: 0.21294443309307098
Epoch: 0/2, Batch: 350/501, Loss: 0.05886194109916687
Epoch: 0/2, Batch: 400/501, Loss: 0.10586221516132355
Epoch: 0/2, Batch: 450/501, Loss: 0.19457115232944489
Epoch: 0/2, Batch: 500/501, Loss: 0.0652860552072525
Epoch: 1/2, Batch: 0/501, Loss: 0.14239910244941711
Epoch: 1/2, Batch: 50/501, Loss: 0.10167951881885529
Epoch: 1/2, Batch: 100/501, Loss: 0.04623642936348915
Epoch: 1/2, Batch: 150/501, Loss: 0.07714298367500305
Epoch: 1/2, Batch: 200/501, Loss: 0.06798572093248367
Epoch: 1/2, Batch: 250/501, Loss: 0.17029336094856262
Epoch: 1/2, Batch: 300/501, Loss: 0.08767244219779968
Epoch: 1/2, Batch: 350/501, Lo

  _warn_prf(average, modifier, msg_start, len(result))


Run 7/10
Epoch: 0/2, Batch: 0/501, Loss: 1.239093542098999
Epoch: 0/2, Batch: 50/501, Loss: 0.40008509159088135
Epoch: 0/2, Batch: 100/501, Loss: 0.3836880922317505
Epoch: 0/2, Batch: 150/501, Loss: 0.24102827906608582
Epoch: 0/2, Batch: 200/501, Loss: 0.21157719194889069
Epoch: 0/2, Batch: 250/501, Loss: 0.4001758396625519
Epoch: 0/2, Batch: 300/501, Loss: 0.2492435723543167
Epoch: 0/2, Batch: 350/501, Loss: 0.06420967727899551
Epoch: 0/2, Batch: 400/501, Loss: 0.10267505794763565
Epoch: 0/2, Batch: 450/501, Loss: 0.13099487125873566
Epoch: 0/2, Batch: 500/501, Loss: 0.15378797054290771
Epoch: 1/2, Batch: 0/501, Loss: 0.12170278280973434
Epoch: 1/2, Batch: 50/501, Loss: 0.016754547134041786
Epoch: 1/2, Batch: 100/501, Loss: 0.06999333202838898
Epoch: 1/2, Batch: 150/501, Loss: 0.09414666891098022
Epoch: 1/2, Batch: 200/501, Loss: 0.03757647052407265
Epoch: 1/2, Batch: 250/501, Loss: 0.05663006380200386
Epoch: 1/2, Batch: 300/501, Loss: 0.12362111359834671
Epoch: 1/2, Batch: 350/501, L

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3388924598693848
Epoch: 0/2, Batch: 50/501, Loss: 0.18256591260433197
Epoch: 0/2, Batch: 100/501, Loss: 0.31771034002304077
Epoch: 0/2, Batch: 150/501, Loss: 0.25696834921836853
Epoch: 0/2, Batch: 200/501, Loss: 0.1885557472705841
Epoch: 0/2, Batch: 250/501, Loss: 0.095547154545784
Epoch: 0/2, Batch: 300/501, Loss: 0.19578662514686584
Epoch: 0/2, Batch: 350/501, Loss: 0.09658917784690857
Epoch: 0/2, Batch: 400/501, Loss: 0.18980520963668823
Epoch: 0/2, Batch: 450/501, Loss: 0.09581949561834335
Epoch: 0/2, Batch: 500/501, Loss: 0.09768068790435791
Epoch: 1/2, Batch: 0/501, Loss: 0.18713942170143127
Epoch: 1/2, Batch: 50/501, Loss: 0.08589942008256912
Epoch: 1/2, Batch: 100/501, Loss: 0.624434769153595
Epoch: 1/2, Batch: 150/501, Loss: 0.10930667817592621
Epoch: 1/2, Batch: 200/501, Loss: 0.052737876772880554
Epoch: 1/2, Batch: 250/501, Loss: 0.038205765187740326
Epoch: 1/2, Batch: 300/501, Loss: 0.16841557621955872
Epoch: 1/2, Batch: 350/501, L

  _warn_prf(average, modifier, msg_start, len(result))


Run 9/10
Epoch: 0/2, Batch: 0/501, Loss: 1.690863847732544
Epoch: 0/2, Batch: 50/501, Loss: 0.3572140634059906
Epoch: 0/2, Batch: 100/501, Loss: 0.40639567375183105
Epoch: 0/2, Batch: 150/501, Loss: 0.27099713683128357
Epoch: 0/2, Batch: 200/501, Loss: 0.1214594617486
Epoch: 0/2, Batch: 250/501, Loss: 0.11732137948274612
Epoch: 0/2, Batch: 300/501, Loss: 0.18091920018196106
Epoch: 0/2, Batch: 350/501, Loss: 0.13125145435333252
Epoch: 0/2, Batch: 400/501, Loss: 0.2712176442146301
Epoch: 0/2, Batch: 450/501, Loss: 0.09992582350969315
Epoch: 0/2, Batch: 500/501, Loss: 0.07293074578046799
Epoch: 1/2, Batch: 0/501, Loss: 0.08985713869333267
Epoch: 1/2, Batch: 50/501, Loss: 0.11777885258197784
Epoch: 1/2, Batch: 100/501, Loss: 0.11972327530384064
Epoch: 1/2, Batch: 150/501, Loss: 0.09321782737970352
Epoch: 1/2, Batch: 200/501, Loss: 0.11947367340326309
Epoch: 1/2, Batch: 250/501, Loss: 0.051883991807699203
Epoch: 1/2, Batch: 300/501, Loss: 0.07743975520133972
Epoch: 1/2, Batch: 350/501, Loss

  _warn_prf(average, modifier, msg_start, len(result))


Run 10/10
Epoch: 0/2, Batch: 0/501, Loss: 1.3803911209106445
Epoch: 0/2, Batch: 50/501, Loss: 0.48764005303382874
Epoch: 0/2, Batch: 100/501, Loss: 0.3612918257713318
Epoch: 0/2, Batch: 150/501, Loss: 0.1417476385831833
Epoch: 0/2, Batch: 200/501, Loss: 0.20340627431869507
Epoch: 0/2, Batch: 250/501, Loss: 0.19702062010765076
Epoch: 0/2, Batch: 300/501, Loss: 0.053664423525333405
Epoch: 0/2, Batch: 350/501, Loss: 0.08173860609531403
Epoch: 0/2, Batch: 400/501, Loss: 0.15476039052009583
Epoch: 0/2, Batch: 450/501, Loss: 0.19059699773788452
Epoch: 0/2, Batch: 500/501, Loss: 0.06809026747941971
Epoch: 1/2, Batch: 0/501, Loss: 0.0657120794057846
Epoch: 1/2, Batch: 50/501, Loss: 0.1011616662144661
Epoch: 1/2, Batch: 100/501, Loss: 0.07117016613483429
Epoch: 1/2, Batch: 150/501, Loss: 0.2628721594810486
Epoch: 1/2, Batch: 200/501, Loss: 0.08059419691562653
Epoch: 1/2, Batch: 250/501, Loss: 0.09309190511703491
Epoch: 1/2, Batch: 300/501, Loss: 0.046441178768873215
Epoch: 1/2, Batch: 350/501, 

  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.977123,0.977123,0.645671,0.977123,0.649908,0.977123,0.646427,108.376989
1,0.975359,0.975359,0.608283,0.975359,0.642118,0.975359,0.624018,107.031998
2,0.973911,0.973911,0.624107,0.973911,0.616845,0.973911,0.615034,109.782082
3,0.97362,0.97362,0.611122,0.97362,0.667176,0.97362,0.634214,110.095
4,0.974552,0.974552,0.642253,0.974552,0.665266,0.974552,0.65337,107.730155
5,0.969545,0.969545,0.605571,0.969545,0.650758,0.969545,0.614738,108.889863
6,0.975014,0.975014,0.639595,0.975014,0.642169,0.975014,0.640552,109.059339
7,0.974743,0.974743,0.63914,0.974743,0.650333,0.974743,0.644279,108.987041
8,0.974452,0.974452,0.632667,0.974452,0.659407,0.974452,0.644342,108.337
9,0.976919,0.976919,0.637421,0.976919,0.652782,0.976919,0.643509,106.13


In [20]:
results.to_csv(STATS_OUTPUT)