# Import dependencies

In [None]:
import os
import sys

# Put the parent of the current working directory first on the import path.
# NOTE(review): this depends on the directory the kernel was started from;
# presumably it is what makes the project-local imports below resolve — confirm.
sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [2]:
import time
import gc
import json

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

import matplotlib.pyplot as plt

from torch import cuda

from absa_models.ABSA_BERT_Dropout_CNN_BiLSTM_Linear import ABSA_BERT_Dropout_CNN_BiLSTM_Linear

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Only let error-level messages through from the transformers logging module.
logging.set_verbosity_error()

In [4]:
# Report the GPU that will be used. Querying device 0 raises on a machine
# without CUDA, so only do it when a GPU is actually available (the notebook
# otherwise falls back to device = 'cpu').
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")
else:
    print("CUDA is not available; running on CPU")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [5]:
def clear_memory():
    """Release memory that is no longer referenced, on both the Python and the CUDA side.

    The garbage collector runs first so that tensors kept alive only by
    reference cycles are destroyed; ``torch.cuda.empty_cache()`` can then
    release the cached blocks they occupied. Calling ``empty_cache()`` before
    collecting would leave those blocks cached.

    Returns:
        None.
    """
    gc.collect()
    # No-op when CUDA has not been initialised, so this is safe on CPU-only machines.
    torch.cuda.empty_cache()

# Load Data

In [6]:
# JSON file that is loaded into `df` below. The path is relative, so it is
# resolved against the kernel's current working directory.
DATASET = 'ABSA_MAMS_train.json'

In [7]:
# Load the JSON records and flatten them into a DataFrame. The context manager
# closes the file handle as soon as parsing is done instead of leaking it.
with open(DATASET) as dataset_file:
    df = pd.json_normalize(json.load(dataset_file))

In [8]:
# Preview the first rows to sanity-check the loaded columns (text, tokens, absa_tags).
df.head()

Unnamed: 0,text,tokens,absa_tags
0,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
1,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
2,The decor is not special at all but their food...,"[The, decor, is, not, special, at, all, but, t...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, ..."
3,"when tables opened up, the manager sat another...","[when, tables, opened, up, ,, the, manager, sa...","[0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"
4,"when tables opened up, the manager sat another...","[when, tables, opened, up, ,, the, manager, sa...","[0, 2, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]"


# Train & Validate

In [9]:
# Batch sizes handed to the training / validation DataLoaders.
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

# Number of passes over the training DataLoader within a single run.
EPOCHS = 2

# Learning rate of the Adam optimizer.
LEARNING_RATE = 1e-5

# Fraction of the rows of `df` sampled for training; the remaining rows are used for validation.
TRAIN_SPLIT = 0.8

# Number of independent split / train / evaluate repetitions.
NO_RUNS = 10

In [10]:
# Sequence length used when padding batches in create_mini_batch(); also passed
# to the model constructor as bert_seq_len and conv_out_channels.
SEQ_LEN = 512

In [11]:
# Checkpoint that is read with torch.load() at the start of every run and
# passed as the first argument to the model constructor.
BERT_FINE_TUNED_PATH = '../../../results/ABSA/MAMS/models/bert_fine_tuned_512.pth'

In [12]:
# Where the model of the run with the best validation accuracy is saved (torch.save).
MODEL_OUTPUT = '../../../results/ABSA/MAMS/models/bert_fine_tuned_dropout_cnn_bilstm_linear_512.pth'
# CSV file the per-run metrics table (`results`) is written to at the end of the notebook.
STATS_OUTPUT = '../../../results/ABSA/MAMS/stats/bert_fine_tuned_dropout_cnn_bilstm_linear_512.csv'

In [13]:
# Pretrained 'bert-base-uncased' tokenizer; handed to InputDataset for both the
# training and the validation split.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [14]:
# Best validation accuracy seen so far across runs; a run's model is saved to
# MODEL_OUTPUT only when it beats this value.
# NOTE: this is notebook-level state — re-run this cell before re-running the
# run loop, otherwise the threshold from the previous execution is kept.
best_accuracy = 0.0

In [15]:
def create_mini_batch(samples, seq_len=None, target_device=None):
    """Collate dataset samples into fixed-width id / tag / attention-mask tensors.

    Every sample is cropped to ``seq_len`` and right-padded with zeros, so the
    batch is always exactly ``seq_len`` wide. (Previously only the first sample
    was padded or cropped, so a longer sample elsewhere in the batch produced a
    batch wider than ``SEQ_LEN``.)

    Args:
        samples: Non-empty sequence of indexable samples; ``sample[1]`` is used
            as the 1-D token-id tensor and ``sample[2]`` as the 1-D tag tensor.
        seq_len: Width of the batch. Defaults to the notebook-level ``SEQ_LEN``.
        target_device: Device the tensors are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Tuple ``(ids_tensors, tags_tensors, masks_tensors)``, each of shape
        ``(len(samples), seq_len)``. ``masks_tensors`` is a long tensor that is
        1 where the token id is non-zero and 0 elsewhere.

    Raises:
        ValueError: If ``samples`` is empty.
    """
    if not samples:
        raise ValueError("create_mini_batch() received an empty list of samples")

    if seq_len is None:
        seq_len = SEQ_LEN
    if target_device is None:
        target_device = device

    def _fit(sequence):
        # Crop the tail if the sequence is too long, then zero-pad on the right.
        # NOTE(review): cropping drops whatever tokens sit at the end of the
        # sequence (the original did the same for the first sample).
        sequence = sequence[:seq_len]
        return torch.nn.functional.pad(sequence, (0, seq_len - sequence.size(0)), value=0)

    ids_tensors = torch.stack([_fit(s[1]) for s in samples]).to(target_device)
    tags_tensors = torch.stack([_fit(s[2]) for s in samples]).to(target_device)

    # Padding uses id 0, so every non-zero id marks a real token.
    masks_tensors = (ids_tensors != 0).long()

    return ids_tensors, tags_tensors, masks_tensors

In [16]:
def train(epoch, model, loss_fn, optimizer, dataloader, total_epochs=None):
    """Run one training epoch and return the loss of every batch.

    Args:
        epoch: Index of the current epoch; only used in the progress message.
        model: Module called as ``model(ids_tensors, masks_tensors)``; its
            output is flattened to ``(-1, number_of_labels)`` for the loss.
        loss_fn: Callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        optimizer: Optimizer holding the model parameters.
        dataloader: Sized iterable yielding ``(ids_tensors, tags_tensors,
            masks_tensors)`` batches.
        total_epochs: Epoch count shown in the progress message. Defaults to
            the notebook-level ``EPOCHS``.

    Returns:
        List with one float loss per batch, in iteration order.
    """
    model.train()

    if total_epochs is None:
        total_epochs = EPOCHS

    dataloader_len = len(dataloader)
    # Log roughly ten times per epoch. Clamp to 1 so that dataloaders with
    # fewer than ten batches do not trigger a modulo-by-zero.
    log_every = max(1, dataloader_len // 10)

    losses = []

    for batch_idx, (ids_tensors, tags_tensors, masks_tensors) in enumerate(dataloader):
        optimizer.zero_grad()

        outputs = model(ids_tensors, masks_tensors)

        # Flatten to (tokens, labels); the label count is read from the model
        # output instead of being hard-coded.
        # NOTE(review): tags are zero-padded by the collate function, so padded
        # positions contribute to the loss as label 0 — confirm this is intended.
        loss = loss_fn(outputs.view(-1, outputs.size(-1)), tags_tensors.view(-1))

        # .item() synchronises with the device, so read the value only once.
        loss_value = loss.item()
        losses.append(loss_value)

        if batch_idx % log_every == 0:
            print(f"Epoch: {epoch}/{total_epochs}, Batch: {batch_idx}/{dataloader_len}, Loss: {loss_value}")

        loss.backward()

        optimizer.step()

    return losses

In [17]:
def validation(model, dataloader, ignore_padding=False, target_device=None):
    """Run the model over ``dataloader`` and collect flat predictions and targets.

    Args:
        model: Module called as ``model(ids_tensors, masks_tensors)``; the
            predicted label is the index of the maximum along dimension 2.
        dataloader: Iterable yielding ``(ids_tensors, tags_tensors,
            masks_tensors)`` batches.
        ignore_padding: When True, positions whose attention mask is 0 are
            left out of both returned lists. The default (False) keeps every
            position, matching the previous behaviour.
        target_device: Device the batch tensors are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        Tuple ``(fin_outputs, fin_targets)`` of equally long lists of ints:
        predicted labels and gold labels.
    """
    model.eval()

    if target_device is None:
        target_device = device

    fin_targets = []
    fin_outputs = []

    with torch.no_grad():
        for ids_tensors, tags_tensors, masks_tensors in dataloader:
            ids_tensors = ids_tensors.to(target_device)
            tags_tensors = tags_tensors.to(target_device)
            masks_tensors = masks_tensors.to(target_device)

            outputs = model(ids_tensors, masks_tensors)

            predictions = torch.max(outputs, dim=2).indices

            if ignore_padding:
                # Boolean indexing keeps row-major order, so predictions and
                # targets stay aligned.
                keep = masks_tensors.bool()
                batch_outputs = predictions[keep]
                batch_targets = tags_tensors[keep]
            else:
                batch_outputs = predictions.reshape(-1)
                batch_targets = tags_tensors.reshape(-1)

            # One bulk transfer per batch instead of one int() call (and one
            # device synchronisation) per token.
            fin_outputs += batch_outputs.tolist()
            fin_targets += batch_targets.tolist()

    return fin_outputs, fin_targets

In [18]:
# Empty per-run metrics table; the run loop adds one row per run.
_result_columns = [
    'accuracy',
    'precision_score_micro',
    'precision_score_macro',
    'recall_score_micro',
    'recall_score_macro',
    'f1_score_micro',
    'f1_score_macro',
    'execution_time',
]
results = pd.DataFrame(columns=_result_columns)

In [19]:
# Repeat the whole split / train / evaluate cycle NO_RUNS times and record the
# metrics of every run in `results`; the model of the most accurate run is saved.
for i in range(NO_RUNS):
    # Release whatever the previous run left behind before allocating a new model.
    clear_memory()

    start_time = time.time()

    print(f"Run {i + 1}/{NO_RUNS}")

    # Seed with the run index: each run still gets a different split, shuffle
    # order and initialisation, but re-executing the notebook repeats the same
    # sequence of runs.
    # NOTE(review): some CUDA kernels are non-deterministic, so metrics may
    # still differ slightly between executions.
    torch.manual_seed(i)

    # Row-level random split into training and validation rows.
    # NOTE(review): the df.head() preview shows the same text on several rows,
    # so copies of one sentence may land on both sides of the split — confirm
    # that this is acceptable for the reported metrics.
    train_dataset = df.sample(frac=TRAIN_SPLIT, random_state=i)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        # Sample indices from the dataset that is actually being loaded
        # (previously the sampler was built from the DataFrame).
        sampler=RandomSampler(training_set),
        batch_size=TRAIN_BATCH_SIZE,
        drop_last=True,
        collate_fn=create_mini_batch
    )

    # NOTE(review): drop_last=True silently leaves up to VALID_BATCH_SIZE - 1
    # validation samples out of the metrics; kept because it is not visible
    # here whether the model accepts a smaller final batch.
    validation_dataloader = DataLoader(
        testing_set,
        sampler=SequentialSampler(testing_set),
        batch_size=VALID_BATCH_SIZE,
        drop_last=True,
        collate_fn=create_mini_batch
    )

    # SECURITY NOTE: torch.load() unpickles the checkpoint, which can execute
    # arbitrary code — only load checkpoints from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True, which
    # rejects fully pickled modules; pass weights_only=False there if needed.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model = ABSA_BERT_Dropout_CNN_BiLSTM_Linear(
        torch.load(BERT_FINE_TUNED_PATH, map_location=device),
        bert_seq_len=SEQ_LEN,
        dropout=0.3,
        bilstm_in_features=256,
        conv_out_channels=SEQ_LEN,
        conv_kernel_size=3,
        no_out_labels=4,
        device=device
    ).to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_losses = []

    for epoch in range(EPOCHS):
        losses = train(epoch, model, loss_fn, optimizer, train_dataloader)

        train_losses += losses

    # Plot on an explicit figure and close it afterwards, so no empty figure is
    # left behind and figures do not accumulate across runs.
    plot_dir = '../../../results/ABSA/MAMS/plots/bert_ft_do_cnn_bilstm_lin'
    os.makedirs(plot_dir, exist_ok=True)

    fig, ax = plt.subplots()
    ax.set_title(f'Train Loss for run {i + 1}/{NO_RUNS}')
    ax.set_xlabel('Batch')
    ax.set_ylabel('Loss')
    ax.plot(train_losses)
    fig.savefig(f'{plot_dir}/train_loss_run_{i + 1}.png')
    plt.close(fig)

    outputs, targets = validation(model, validation_dataloader)

    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    # Wall-clock time of the whole run: data preparation, training, plotting and validation.
    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    # Keep only the model of the most accurate run seen so far.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        model_dir = os.path.dirname(MODEL_OUTPUT)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        torch.save(model, MODEL_OUTPUT)

    # Drop every reference to this run's objects so clear_memory() can reclaim
    # them at the start of the next iteration. The dataloaders hold references
    # to the datasets, so they have to go as well.
    del train_dataloader
    del validation_dataloader
    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3803099393844604
Epoch: 0/2, Batch: 223/2237, Loss: 0.0768231526017189
Epoch: 0/2, Batch: 446/2237, Loss: 0.03971938416361809
Epoch: 0/2, Batch: 669/2237, Loss: 0.04627726599574089
Epoch: 0/2, Batch: 892/2237, Loss: 0.06225328892469406
Epoch: 0/2, Batch: 1115/2237, Loss: 0.058220360428094864
Epoch: 0/2, Batch: 1338/2237, Loss: 0.04164021834731102
Epoch: 0/2, Batch: 1561/2237, Loss: 0.02869509346783161
Epoch: 0/2, Batch: 1784/2237, Loss: 0.025328686460852623
Epoch: 0/2, Batch: 2007/2237, Loss: 0.03221874311566353
Epoch: 0/2, Batch: 2230/2237, Loss: 0.01737811416387558
Epoch: 1/2, Batch: 0/2237, Loss: 0.04337086156010628
Epoch: 1/2, Batch: 223/2237, Loss: 0.022622648626565933
Epoch: 1/2, Batch: 446/2237, Loss: 0.024187779054045677
Epoch: 1/2, Batch: 669/2237, Loss: 0.0333479642868042
Epoch: 1/2, Batch: 892/2237, Loss: 0.02607640065252781
Epoch: 1/2, Batch: 1115/2237, Loss: 0.021782856434583664
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0195704530924

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.4402340650558472
Epoch: 0/2, Batch: 223/2237, Loss: 0.1319902241230011
Epoch: 0/2, Batch: 446/2237, Loss: 0.07094556093215942
Epoch: 0/2, Batch: 669/2237, Loss: 0.042434707283973694
Epoch: 0/2, Batch: 892/2237, Loss: 0.03480217605829239
Epoch: 0/2, Batch: 1115/2237, Loss: 0.0378134623169899
Epoch: 0/2, Batch: 1338/2237, Loss: 0.045462146401405334
Epoch: 0/2, Batch: 1561/2237, Loss: 0.0371064729988575
Epoch: 0/2, Batch: 1784/2237, Loss: 0.036315858364105225
Epoch: 0/2, Batch: 2007/2237, Loss: 0.032301656901836395
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02697656862437725
Epoch: 1/2, Batch: 0/2237, Loss: 0.036575715988874435
Epoch: 1/2, Batch: 223/2237, Loss: 0.033089734613895416
Epoch: 1/2, Batch: 446/2237, Loss: 0.02898130565881729
Epoch: 1/2, Batch: 669/2237, Loss: 0.027515701949596405
Epoch: 1/2, Batch: 892/2237, Loss: 0.020675258710980415
Epoch: 1/2, Batch: 1115/2237, Loss: 0.023575514554977417
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0208939220

  _warn_prf(average, modifier, msg_start, len(result))


Run 4/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.4142146110534668
Epoch: 0/2, Batch: 223/2237, Loss: 0.06044238805770874
Epoch: 0/2, Batch: 446/2237, Loss: 0.041091401129961014
Epoch: 0/2, Batch: 669/2237, Loss: 0.03636251762509346
Epoch: 0/2, Batch: 892/2237, Loss: 0.028049683198332787
Epoch: 0/2, Batch: 1115/2237, Loss: 0.03447793424129486
Epoch: 0/2, Batch: 1338/2237, Loss: 0.02237161062657833
Epoch: 0/2, Batch: 1561/2237, Loss: 0.02564460225403309
Epoch: 0/2, Batch: 1784/2237, Loss: 0.02102893404662609
Epoch: 0/2, Batch: 2007/2237, Loss: 0.01863938197493553
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02164643630385399
Epoch: 1/2, Batch: 0/2237, Loss: 0.019479334354400635
Epoch: 1/2, Batch: 223/2237, Loss: 0.031192444264888763
Epoch: 1/2, Batch: 446/2237, Loss: 0.01534025464206934
Epoch: 1/2, Batch: 669/2237, Loss: 0.02649114839732647
Epoch: 1/2, Batch: 892/2237, Loss: 0.014657991006970406
Epoch: 1/2, Batch: 1115/2237, Loss: 0.01967111974954605
Epoch: 1/2, Batch: 1338/2237, Loss: 0.03261192142

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.340345859527588
Epoch: 0/2, Batch: 223/2237, Loss: 0.0587984062731266
Epoch: 0/2, Batch: 446/2237, Loss: 0.03605768829584122
Epoch: 0/2, Batch: 669/2237, Loss: 0.04394996166229248
Epoch: 0/2, Batch: 892/2237, Loss: 0.035256776958703995
Epoch: 0/2, Batch: 1115/2237, Loss: 0.04529440402984619
Epoch: 0/2, Batch: 1338/2237, Loss: 0.05555619299411774
Epoch: 0/2, Batch: 1561/2237, Loss: 0.029902007430791855
Epoch: 0/2, Batch: 1784/2237, Loss: 0.02429429441690445
Epoch: 0/2, Batch: 2007/2237, Loss: 0.0500834584236145
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02486586570739746
Epoch: 1/2, Batch: 0/2237, Loss: 0.06501729786396027
Epoch: 1/2, Batch: 223/2237, Loss: 0.026982784271240234
Epoch: 1/2, Batch: 446/2237, Loss: 0.02136586420238018
Epoch: 1/2, Batch: 669/2237, Loss: 0.034475553780794144
Epoch: 1/2, Batch: 892/2237, Loss: 0.028757765889167786
Epoch: 1/2, Batch: 1115/2237, Loss: 0.027234189212322235
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0173058286309

  _warn_prf(average, modifier, msg_start, len(result))


Run 6/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.4504588842391968
Epoch: 0/2, Batch: 223/2237, Loss: 0.08682814985513687
Epoch: 0/2, Batch: 446/2237, Loss: 0.03825804963707924
Epoch: 0/2, Batch: 669/2237, Loss: 0.03280135244131088
Epoch: 0/2, Batch: 892/2237, Loss: 0.02680203504860401
Epoch: 0/2, Batch: 1115/2237, Loss: 0.04336237162351608
Epoch: 0/2, Batch: 1338/2237, Loss: 0.03448481112718582
Epoch: 0/2, Batch: 1561/2237, Loss: 0.07036522030830383
Epoch: 0/2, Batch: 1784/2237, Loss: 0.022621147334575653
Epoch: 0/2, Batch: 2007/2237, Loss: 0.01624879240989685
Epoch: 0/2, Batch: 2230/2237, Loss: 0.021233398467302322
Epoch: 1/2, Batch: 0/2237, Loss: 0.034653156995773315
Epoch: 1/2, Batch: 223/2237, Loss: 0.027386115863919258
Epoch: 1/2, Batch: 446/2237, Loss: 0.015081459656357765
Epoch: 1/2, Batch: 669/2237, Loss: 0.04100402817130089
Epoch: 1/2, Batch: 892/2237, Loss: 0.03774920105934143
Epoch: 1/2, Batch: 1115/2237, Loss: 0.03357301279902458
Epoch: 1/2, Batch: 1338/2237, Loss: 0.02277744375

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3732680082321167
Epoch: 0/2, Batch: 223/2237, Loss: 0.06076937913894653
Epoch: 0/2, Batch: 446/2237, Loss: 0.04994494095444679
Epoch: 0/2, Batch: 669/2237, Loss: 0.03667077422142029
Epoch: 0/2, Batch: 892/2237, Loss: 0.04617001488804817
Epoch: 0/2, Batch: 1115/2237, Loss: 0.03946513310074806
Epoch: 0/2, Batch: 1338/2237, Loss: 0.02843860164284706
Epoch: 0/2, Batch: 1561/2237, Loss: 0.02297278307378292
Epoch: 0/2, Batch: 1784/2237, Loss: 0.027824850752949715
Epoch: 0/2, Batch: 2007/2237, Loss: 0.048174191266298294
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02388974279165268
Epoch: 1/2, Batch: 0/2237, Loss: 0.027286406606435776
Epoch: 1/2, Batch: 223/2237, Loss: 0.0478542223572731
Epoch: 1/2, Batch: 446/2237, Loss: 0.0282632764428854
Epoch: 1/2, Batch: 669/2237, Loss: 0.04132716730237007
Epoch: 1/2, Batch: 892/2237, Loss: 0.016588812693953514
Epoch: 1/2, Batch: 1115/2237, Loss: 0.024623656645417213
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0216680876910

  _warn_prf(average, modifier, msg_start, len(result))


Run 9/10
Epoch: 0/2, Batch: 0/2237, Loss: 1.3801532983779907
Epoch: 0/2, Batch: 223/2237, Loss: 0.09290043264627457
Epoch: 0/2, Batch: 446/2237, Loss: 0.039691053330898285
Epoch: 0/2, Batch: 669/2237, Loss: 0.05040424317121506
Epoch: 0/2, Batch: 892/2237, Loss: 0.03913619741797447
Epoch: 0/2, Batch: 1115/2237, Loss: 0.034096263349056244
Epoch: 0/2, Batch: 1338/2237, Loss: 0.032583002001047134
Epoch: 0/2, Batch: 1561/2237, Loss: 0.023557696491479874
Epoch: 0/2, Batch: 1784/2237, Loss: 0.031479932367801666
Epoch: 0/2, Batch: 2007/2237, Loss: 0.04015032947063446
Epoch: 0/2, Batch: 2230/2237, Loss: 0.02951274998486042
Epoch: 1/2, Batch: 0/2237, Loss: 0.02226678468286991
Epoch: 1/2, Batch: 223/2237, Loss: 0.022654790431261063
Epoch: 1/2, Batch: 446/2237, Loss: 0.03305863216519356
Epoch: 1/2, Batch: 669/2237, Loss: 0.024952178820967674
Epoch: 1/2, Batch: 892/2237, Loss: 0.027399688959121704
Epoch: 1/2, Batch: 1115/2237, Loss: 0.024063821882009506
Epoch: 1/2, Batch: 1338/2237, Loss: 0.0281659

<Figure size 432x288 with 0 Axes>

In [20]:
# Display the per-run metrics collected by the run loop.
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.992334,0.992334,0.588042,0.992334,0.439498,0.992334,0.44871,1616.284519
1,0.991831,0.991831,0.436373,0.991831,0.420946,0.991831,0.398167,1610.596591
2,0.9918,0.9918,0.472917,0.9918,0.376213,0.9918,0.39138,1610.118611
3,0.991696,0.991696,0.440245,0.991696,0.40754,0.991696,0.417682,1611.414271
4,0.992149,0.992149,0.43964,0.992149,0.409157,0.992149,0.400025,1611.455272
5,0.992199,0.992199,0.707548,0.992199,0.408342,0.992199,0.429811,1611.618326
6,0.991987,0.991987,0.435053,0.991987,0.396933,0.991987,0.406565,1611.704263
7,0.99188,0.99188,0.461722,0.99188,0.383317,0.99188,0.382939,1611.732044
8,0.991951,0.991951,0.470718,0.991951,0.398134,0.991951,0.409666,1614.435424
9,0.991604,0.991604,0.457174,0.991604,0.40153,0.991604,0.389872,1612.88677


In [21]:
# Persist the per-run metrics. Create the target directory first so that a
# missing folder does not make the write fail after all runs have finished.
stats_dir = os.path.dirname(STATS_OUTPUT)
if stats_dir:
    os.makedirs(stats_dir, exist_ok=True)
results.to_csv(STATS_OUTPUT)