# Import dependencies

In [1]:
import os
import sys

# Put the parent of the current working directory first on the import path.
# NOTE(review): presumably this is what makes the project-local imports in the
# next cell (InputDataset, absa_models) resolvable - confirm the repo layout.
sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [2]:
import time
import gc
import json

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

from torch import cuda

from absa_models.ABSA_BERT_Dropout_CNN_BiLSTM_Linear import ABSA_BERT_Dropout_CNN_BiLSTM_Linear

In [3]:
# Compute device used for all tensors and the model: GPU when CUDA is
# available, CPU otherwise.
device = 'cuda' if cuda.is_available() else 'cpu'
# `logging` is transformers.logging here: only report errors from the library.
logging.set_verbosity_error() 

In [4]:
# Report the GPU in use. The CUDA queries raise on a CPU-only machine, so they
# are guarded to stay consistent with the CPU fallback chosen for `device`.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")
else:
    print("CUDA is not available; running on CPU")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [5]:
def clear_memory():
    """Free unreferenced Python objects and release cached CUDA memory.

    Returns:
        None
    """
    # Collect first: tensors that are only kept alive by reference cycles must
    # be destroyed before their blocks can be handed back by empty_cache().
    gc.collect()

    # A single call is enough; wrapping it in no_grad() changes nothing because
    # no_grad() only affects gradient tracking, not the caching allocator.
    torch.cuda.empty_cache()

# Load Data

In [6]:
# Path of the JSON dataset file, relative to the notebook's working directory.
DATASET = 'ABSA_MAMS_train.json'

In [7]:
# Parse the JSON dataset and flatten it into a DataFrame. The context manager
# closes the file deterministically instead of leaving that to garbage collection.
with open(DATASET) as dataset_file:
    df = pd.json_normalize(json.load(dataset_file))

In [8]:
df.head()

Unnamed: 0,text,tokens,absa_tags
0,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
1,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
2,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
3,"when tables opened up, the manager sat another...","[when, tables, opened, up, ,, the, manager, sa...","[0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"
4,"when tables opened up, the manager sat another...","[when, tables, opened, up, ,, the, manager, sa...","[0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"


# Train & Validate

In [9]:
# Batch sizes of the training and validation DataLoaders.
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

# Number of passes over the training data in each run.
EPOCHS = 2

# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-5

# Fraction of the rows sampled for training; the remaining rows are used for evaluation.
TRAIN_SPLIT = 0.8

# Number of independent split/train/evaluate repetitions.
NO_RUNS = 10

In [10]:
# Sequence length used by the collate function when padding batches; it is also
# passed to the model as `bert_seq_len` and `conv_out_channels`.
SEQ_LEN = 512

In [11]:
# File read with torch.load() at the start of every run; the loaded object is
# passed as the first argument of the model constructor.
BERT_FINE_TUNED_PATH = '../../../results/ABSA/MAMS/models/bert_fine_tuned.pth'

In [12]:
# Where the model of the most accurate run is saved with torch.save().
MODEL_OUTPUT = '../../../results/ABSA/MAMS/models/bert_fine_tuned_dropout_cnn_bilstm_linear.pth'
# Where the per-run metrics table is written as CSV.
STATS_OUTPUT = '../../../results/ABSA/MAMS/stats/bert_fine_tuned_dropout_cnn_bilstm_linear.csv'

In [13]:
# Tokenizer of the pretrained 'bert-base-uncased' checkpoint; handed to
# InputDataset for both the training and the evaluation split.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [14]:
# Highest accuracy reached by any run so far; the run loop saves the model
# only when a run exceeds this value and then updates it.
best_accuracy = 0.0

In [15]:
def create_mini_batch(samples):
    """Collate dataset samples into fixed-width batch tensors.

    Every sequence is cropped or right-padded with 0 to exactly SEQ_LEN, so the
    batch width never depends on which sample comes first or on a single
    over-long sample.

    Args:
        samples: sequence of dataset items. Element 1 of each item is used as
            the 1-D tensor of token ids and element 2 as the 1-D tensor of tag
            ids; element 0 is not used here.

    Returns:
        Tuple ``(ids_tensors, tags_tensors, masks_tensors)``, each of shape
        ``(len(samples), SEQ_LEN)`` and located on ``device``. ``masks_tensors``
        is a long tensor holding 1 where the token id is non-zero, 0 elsewhere.
    """
    def _fit_to_seq_len(sequence):
        # Crop anything beyond SEQ_LEN, then right-pad with 0 up to SEQ_LEN.
        sequence = sequence[:SEQ_LEN]
        return torch.nn.functional.pad(sequence, (0, SEQ_LEN - len(sequence)), value=0)

    ids_tensors = torch.stack([_fit_to_seq_len(s[1]) for s in samples]).to(device)
    tags_tensors = torch.stack([_fit_to_seq_len(s[2]) for s in samples]).to(device)

    # Padding uses id 0, so every non-zero id is treated as a real token.
    masks_tensors = (ids_tensors != 0).long()

    return ids_tensors, tags_tensors, masks_tensors

In [16]:
def train(epoch, model, loss_fn, optimizer, dataloader):
    """Train the model for one pass over the dataloader.

    Args:
        epoch: epoch number, used only in the progress message.
        model: module called as ``model(ids_tensors, masks_tensors)``; its
            output is reshaped to ``(-1, 4)`` before the loss is computed.
        loss_fn: callable taking ``(flattened outputs, flattened tags)`` and
            returning a scalar loss tensor.
        optimizer: optimizer whose gradients are reset and stepped per batch.
        dataloader: sized iterable yielding
            ``(ids_tensors, tags_tensors, masks_tensors)`` batches.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()

    dataloader_len = len(dataloader)
    # Log about ten times per pass. Clamp to 1 so that a dataloader with fewer
    # than ten batches does not trigger a modulo by zero.
    log_every = max(1, dataloader_len // 10)

    for batch_idx, (ids_tensors, tags_tensors, masks_tensors) in enumerate(dataloader):
        optimizer.zero_grad()

        outputs = model(ids_tensors, masks_tensors)

        # One row of 4 class scores per position against one tag id per position.
        loss = loss_fn(outputs.view(-1, 4), tags_tensors.view(-1))

        if batch_idx % log_every == 0:
            print(f"Epoch: {epoch}/{EPOCHS}, Batch: {batch_idx}/{dataloader_len}, Loss: {loss.item()}")

        loss.backward()

        optimizer.step()

In [17]:
def validation(model, dataloader):
    """Run the model over a dataloader and collect flat predictions and targets.

    Args:
        model: module called as ``model(ids_tensors, masks_tensors)``; the
            predicted label of a position is the index of the largest value
            along dim 2 of its output.
        dataloader: iterable yielding
            ``(ids_tensors, tags_tensors, masks_tensors)`` batches.

    Returns:
        Tuple ``(fin_outputs, fin_targets)``: two flat lists of Python numbers
        holding the predicted and the reference tag of every position of every
        batch, in row-major order.

    NOTE(review): no masking is applied, so every position of a batch -
    including any padding added by the collate function - is part of the
    returned lists and therefore of the metrics computed from them.
    """
    model.eval()

    fin_targets = []
    fin_outputs = []

    with torch.no_grad():
        for ids_tensors, tags_tensors, masks_tensors in dataloader:
            ids_tensors = ids_tensors.to(device)
            tags_tensors = tags_tensors.to(device)
            masks_tensors = masks_tensors.to(device)

            outputs = model(ids_tensors, masks_tensors)

            _max_values, predictions = torch.max(outputs, dim=2)

            # One bulk transfer per batch instead of one int() call (and one
            # device sync) per element.
            fin_outputs.extend(predictions.reshape(-1).tolist())
            fin_targets.extend(tags_tensors.reshape(-1).tolist())

    return fin_outputs, fin_targets

In [18]:
# Empty metrics table; the run loop fills in one row per run.
results = pd.DataFrame(
    columns=[
        'accuracy',
        'precision_score_micro',
        'precision_score_macro',
        'recall_score_micro',
        'recall_score_macro',
        'f1_score_micro',
        'f1_score_macro',
        'execution_time',
    ]
)

In [19]:
# Repeat the whole experiment NO_RUNS times: new random split, fresh model,
# EPOCHS passes of training, one evaluation, one row in `results`.
for i in range(NO_RUNS):
    # Release whatever the previous run left behind before allocating again.
    clear_memory()

    start_time = time.time()

    print(f"Run {i + 1}/{NO_RUNS}")

    # Random split without a fixed seed, so every run (and every re-execution
    # of this cell) draws a different partition.
    train_dataset = df.sample(frac=TRAIN_SPLIT)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        # Sample indices from the dataset the loader actually indexes, as the
        # validation loader below does.
        sampler = RandomSampler(training_set),
        batch_size = TRAIN_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    # NOTE(review): drop_last=True also applies here, so up to
    # VALID_BATCH_SIZE - 1 evaluation samples are left out of the metrics.
    validation_dataloader = DataLoader(
        testing_set,
        sampler = SequentialSampler(testing_set),
        batch_size = VALID_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    # NOTE(review): torch.load() unpickles the file; only load checkpoints
    # from a trusted source.
    model = ABSA_BERT_Dropout_CNN_BiLSTM_Linear(torch.load(BERT_FINE_TUNED_PATH), bert_seq_len=SEQ_LEN, dropout=0.3, bilstm_in_features=256, conv_out_channels=SEQ_LEN, conv_kernel_size=3, no_out_labels=4, device=device).to(device)

    optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        train(epoch, model, loss_fn, optimizer, train_dataloader)

    outputs, targets = validation(model, validation_dataloader)
    
    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    # Covers data preparation, training, evaluation and metric computation.
    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    # Keep only the model of the most accurate run seen so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model, MODEL_OUTPUT)

    # Drop every reference created by this run. The dataloaders go too, since
    # they hold the datasets and would otherwise keep them alive.
    del train_dataloader
    del validation_dataloader
    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.337424874305725
Epoch: 0/2, Batch: 223/2237, Loss: 0.06588121503591537
Epoch: 0/2, Batch: 446/2237, Loss: 0.04964529350399971
Epoch: 0/2, Batch: 669/2237, Loss: 0.046402085572481155
Epoch: 0/2, Batch: 892/2237, Loss: 0.045496005564928055
Epoch: 0/2, Batch: 1115/2237, Loss: 0.03715670853853226
Epoch: 0/2, Batch: 1338/2237, Loss: 0.026098119094967842
Epoch: 0/2, Batch: 1561/2237, Loss: 0.04837885499000549
Epoch: 0/2, Batch: 1784/2237, Loss: 0.04134152829647064
Epoch: 0/2, Batch: 2007/2237, Loss: 0.03735505789518356
Epoch: 0/2, Batch: 2230/2237, Loss: 0.029250850901007652
Epoch: 1/2, Batch: 0/2237, Loss: 0.02179548144340515
Epoch: 1/2, Batch: 223/2237, Loss: 0.019059281796216965
Epoch: 1/2, Batch: 446/2237, Loss: 0.03356434404850006
Epoch: 1/2, Batch: 669/2237, Loss: 0.023434842005372047
Epoch: 1/2, Batch: 892/2237, Loss: 0.016522405669093132
Epoch: 1/2, Batch: 1115/2237, Loss: 0.027374884113669395
Epoch: 1/2, Batch: 1338/2237, Loss: 0.020623428

  _warn_prf(average, modifier, msg_start, len(result))


Run 2/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3903459310531616
Epoch: 0/2, Batch: 223/2237, Loss: 0.05805936083197594
Epoch: 0/2, Batch: 446/2237, Loss: 0.04446237161755562
Epoch: 0/2, Batch: 669/2237, Loss: 0.035281069576740265
Epoch: 0/2, Batch: 892/2237, Loss: 0.030495591461658478
Epoch: 0/2, Batch: 1115/2237, Loss: 0.02675103023648262
Epoch: 0/2, Batch: 1338/2237, Loss: 0.030400747433304787
Epoch: 0/2, Batch: 1561/2237, Loss: 0.024205323308706284
Epoch: 0/2, Batch: 1784/2237, Loss: 0.03527204692363739
Epoch: 0/2, Batch: 2007/2237, Loss: 0.024949975311756134
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02465018257498741
Epoch: 1/2, Batch: 0/2237, Loss: 0.047750718891620636
Epoch: 1/2, Batch: 223/2237, Loss: 0.01827072724699974
Epoch: 1/2, Batch: 446/2237, Loss: 0.022552605718374252
Epoch: 1/2, Batch: 669/2237, Loss: 0.027882158756256104
Epoch: 1/2, Batch: 892/2237, Loss: 0.027007728815078735
Epoch: 1/2, Batch: 1115/2237, Loss: 0.021534988656640053
Epoch: 1/2, Batch: 1338/2237, Loss: 0.014743

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.4443761110305786
Epoch: 0/2, Batch: 223/2237, Loss: 0.07694996893405914
Epoch: 0/2, Batch: 446/2237, Loss: 0.0510706827044487
Epoch: 0/2, Batch: 669/2237, Loss: 0.04272520914673805
Epoch: 0/2, Batch: 892/2237, Loss: 0.03513428568840027
Epoch: 0/2, Batch: 1115/2237, Loss: 0.04658816382288933
Epoch: 0/2, Batch: 1338/2237, Loss: 0.05464542284607887
Epoch: 0/2, Batch: 1561/2237, Loss: 0.04031216353178024
Epoch: 0/2, Batch: 1784/2237, Loss: 0.028560126200318336
Epoch: 0/2, Batch: 2007/2237, Loss: 0.06860407441854477
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02433713525533676
Epoch: 1/2, Batch: 0/2237, Loss: 0.033887751400470734
Epoch: 1/2, Batch: 223/2237, Loss: 0.039969105273485184
Epoch: 1/2, Batch: 446/2237, Loss: 0.02108311839401722
Epoch: 1/2, Batch: 669/2237, Loss: 0.030116135254502296
Epoch: 1/2, Batch: 892/2237, Loss: 0.036938607692718506
Epoch: 1/2, Batch: 1115/2237, Loss: 0.02099326252937317
Epoch: 1/2, Batch: 1338/2237, Loss: 0.030645368620

  _warn_prf(average, modifier, msg_start, len(result))


Run 4/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3898239135742188
Epoch: 0/2, Batch: 223/2237, Loss: 0.057664304971694946
Epoch: 0/2, Batch: 446/2237, Loss: 0.0434998981654644
Epoch: 0/2, Batch: 669/2237, Loss: 0.03203262388706207
Epoch: 0/2, Batch: 892/2237, Loss: 0.06295375525951385
Epoch: 0/2, Batch: 1115/2237, Loss: 0.024290993809700012
Epoch: 0/2, Batch: 1338/2237, Loss: 0.03382594510912895
Epoch: 0/2, Batch: 1561/2237, Loss: 0.030079802498221397
Epoch: 0/2, Batch: 1784/2237, Loss: 0.04680955410003662
Epoch: 0/2, Batch: 2007/2237, Loss: 0.02906038984656334
Epoch: 0/2, Batch: 2230/2237, Loss: 0.032683663070201874
Epoch: 1/2, Batch: 0/2237, Loss: 0.019945180043578148
Epoch: 1/2, Batch: 223/2237, Loss: 0.029746176674962044
Epoch: 1/2, Batch: 446/2237, Loss: 0.03460461646318436
Epoch: 1/2, Batch: 669/2237, Loss: 0.015263566747307777
Epoch: 1/2, Batch: 892/2237, Loss: 0.034105055034160614
Epoch: 1/2, Batch: 1115/2237, Loss: 0.02013171836733818
Epoch: 1/2, Batch: 1338/2237, Loss: 0.016732679

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.368764042854309
Epoch: 0/2, Batch: 223/2237, Loss: 0.10554052889347076
Epoch: 0/2, Batch: 446/2237, Loss: 0.10064486414194107
Epoch: 0/2, Batch: 669/2237, Loss: 0.0597253255546093
Epoch: 0/2, Batch: 892/2237, Loss: 0.04104592651128769
Epoch: 0/2, Batch: 1115/2237, Loss: 0.044972144067287445
Epoch: 0/2, Batch: 1338/2237, Loss: 0.04261232167482376
Epoch: 0/2, Batch: 1561/2237, Loss: 0.041253041476011276
Epoch: 0/2, Batch: 1784/2237, Loss: 0.036957018077373505
Epoch: 0/2, Batch: 2007/2237, Loss: 0.024608302861452103
Epoch: 0/2, Batch: 2230/2237, Loss: 0.029533429071307182
Epoch: 1/2, Batch: 0/2237, Loss: 0.03254769742488861
Epoch: 1/2, Batch: 223/2237, Loss: 0.028459109365940094
Epoch: 1/2, Batch: 446/2237, Loss: 0.030201751738786697
Epoch: 1/2, Batch: 669/2237, Loss: 0.028992995619773865
Epoch: 1/2, Batch: 892/2237, Loss: 0.017882371321320534
Epoch: 1/2, Batch: 1115/2237, Loss: 0.024692241102457047
Epoch: 1/2, Batch: 1338/2237, Loss: 0.03785777

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3674522638320923
Epoch: 0/2, Batch: 223/2237, Loss: 0.06595898419618607
Epoch: 0/2, Batch: 446/2237, Loss: 0.06338416039943695
Epoch: 0/2, Batch: 669/2237, Loss: 0.042411599308252335
Epoch: 0/2, Batch: 892/2237, Loss: 0.03782764822244644
Epoch: 0/2, Batch: 1115/2237, Loss: 0.03883666172623634
Epoch: 0/2, Batch: 1338/2237, Loss: 0.022984547540545464
Epoch: 0/2, Batch: 1561/2237, Loss: 0.03253750503063202
Epoch: 0/2, Batch: 1784/2237, Loss: 0.03576252982020378
Epoch: 0/2, Batch: 2007/2237, Loss: 0.020682288333773613
Epoch: 0/2, Batch: 2230/2237, Loss: 0.03230959177017212
Epoch: 1/2, Batch: 0/2237, Loss: 0.036385778337717056
Epoch: 1/2, Batch: 223/2237, Loss: 0.030680935829877853
Epoch: 1/2, Batch: 446/2237, Loss: 0.024990489706397057
Epoch: 1/2, Batch: 669/2237, Loss: 0.022984640672802925
Epoch: 1/2, Batch: 892/2237, Loss: 0.015107674524188042
Epoch: 1/2, Batch: 1115/2237, Loss: 0.03476819023489952
Epoch: 1/2, Batch: 1338/2237, Loss: 0.01251899

  _warn_prf(average, modifier, msg_start, len(result))


Run 9/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3470470905303955
Epoch: 0/2, Batch: 223/2237, Loss: 0.06796323508024216
Epoch: 0/2, Batch: 446/2237, Loss: 0.04075315222144127
Epoch: 0/2, Batch: 669/2237, Loss: 0.04358958825469017
Epoch: 0/2, Batch: 892/2237, Loss: 0.03260678052902222
Epoch: 0/2, Batch: 1115/2237, Loss: 0.030927062034606934
Epoch: 0/2, Batch: 1338/2237, Loss: 0.04044903814792633
Epoch: 0/2, Batch: 1561/2237, Loss: 0.0471179224550724
Epoch: 0/2, Batch: 1784/2237, Loss: 0.035143181681632996
Epoch: 0/2, Batch: 2007/2237, Loss: 0.059345196932554245
Epoch: 0/2, Batch: 2230/2237, Loss: 0.03294423595070839
Epoch: 1/2, Batch: 0/2237, Loss: 0.020972566679120064
Epoch: 1/2, Batch: 223/2237, Loss: 0.012971581891179085
Epoch: 1/2, Batch: 446/2237, Loss: 0.02714782953262329
Epoch: 1/2, Batch: 669/2237, Loss: 0.02948751114308834
Epoch: 1/2, Batch: 892/2237, Loss: 0.025712875649333
Epoch: 1/2, Batch: 1115/2237, Loss: 0.01853383705019951
Epoch: 1/2, Batch: 1338/2237, Loss: 0.02030803635716

  _warn_prf(average, modifier, msg_start, len(result))


Run 10/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3848083019256592
Epoch: 0/2, Batch: 223/2237, Loss: 0.06975953280925751
Epoch: 0/2, Batch: 446/2237, Loss: 0.03455674275755882
Epoch: 0/2, Batch: 669/2237, Loss: 0.044937100261449814
Epoch: 0/2, Batch: 892/2237, Loss: 0.046358175575733185
Epoch: 0/2, Batch: 1115/2237, Loss: 0.03617854416370392
Epoch: 0/2, Batch: 1338/2237, Loss: 0.02188854292035103
Epoch: 0/2, Batch: 1561/2237, Loss: 0.04947226867079735
Epoch: 0/2, Batch: 1784/2237, Loss: 0.022064607590436935
Epoch: 0/2, Batch: 2007/2237, Loss: 0.04551277309656143
Epoch: 0/2, Batch: 2230/2237, Loss: 0.03140935301780701
Epoch: 1/2, Batch: 0/2237, Loss: 0.02447187528014183
Epoch: 1/2, Batch: 223/2237, Loss: 0.02643175609409809
Epoch: 1/2, Batch: 446/2237, Loss: 0.02746272087097168
Epoch: 1/2, Batch: 669/2237, Loss: 0.023923154920339584
Epoch: 1/2, Batch: 892/2237, Loss: 0.028636373579502106
Epoch: 1/2, Batch: 1115/2237, Loss: 0.02193189226090908
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0199341643

  _warn_prf(average, modifier, msg_start, len(result))


In [20]:
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.992356,0.992356,0.446461,0.992356,0.428879,0.992356,0.42607,1682.6967
1,0.992493,0.992493,0.45686,0.992493,0.429427,0.992493,0.435601,1587.31622
2,0.992098,0.992098,0.435759,0.992098,0.426446,0.992098,0.423515,1586.886021
3,0.992632,0.992632,0.454012,0.992632,0.455001,0.992632,0.442801,1589.593208
4,0.991207,0.991207,0.406379,0.991207,0.30043,0.991207,0.317384,1590.068523
5,0.991474,0.991474,0.465261,0.991474,0.312875,0.991474,0.346455,1588.66503
6,0.992092,0.992092,0.462284,0.992092,0.419814,0.992092,0.435875,1589.307721
7,0.992484,0.992484,0.462541,0.992484,0.426954,0.992484,0.438015,1589.917121
8,0.99209,0.99209,0.432031,0.99209,0.436874,0.99209,0.411031,1588.980866
9,0.992455,0.992455,0.447269,0.992455,0.450963,0.992455,0.442479,1588.134304


In [21]:
# Write the per-run metrics table to STATS_OUTPUT as CSV.
results.to_csv(STATS_OUTPUT)