# Import dependencies

In [1]:
import os
import sys

sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [2]:
import time
import gc
import json

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

import matplotlib.pyplot as plt

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence

from torch import cuda

from ate_models.ATE_BERT_Dropout_BiLSTM_Linear import ATE_BERT_Dropout_BiLSTM_Linear

In [3]:
device = 'cuda' if cuda.is_available() else 'cpu'
logging.set_verbosity_error() 

In [4]:
print(torch.cuda.get_device_name(0))
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [5]:
def clear_memory():
    torch.cuda.empty_cache()

    with torch.no_grad():
        torch.cuda.empty_cache()

    gc.collect()

# Load Data

In [6]:
DATASET = 'ATE_SemEval16_Restaurants_train.json'

In [7]:
df = pd.json_normalize(json.load(open(DATASET)))

In [8]:
df.head()

Unnamed: 0,text,tokens,iob_aspect_tags
0,Judging from previous posts this used to be a ...,"[Judging, from, previous, posts, this, used, t...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ..."
1,"We, there were four of us, arrived at noon - t...","[We, ,, there, were, four, of, us, ,, arrived,...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"They never brought us complimentary noodles, i...","[They, never, brought, us, complimentary, nood...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"
4,The food was lousy - too sweet or too salty an...,"[The, food, was, lousy, -, too, sweet, or, too...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]"


# Train & Validate

In [9]:
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

EPOCHS = 2

LEARNING_RATE = 1e-5

TRAIN_SPLIT = 0.8

SEQ_LEN = 512

NO_RUNS = 10

In [10]:
MODEL_OUTPUT = '../../../results/ATE/SemEval16 - Task 5 - Restaurants/models/bert_pre_trained_dropout_bilstm_linear_512.pth'
STATS_OUTPUT = '../../../results/ATE/SemEval16 - Task 5 - Restaurants/stats/bert_pre_trained_dropout_bilstm_linear_512.csv'

In [11]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [12]:
best_accuracy = 0.0

In [13]:
def create_mini_batch(samples):
    ids_tensors = [s[1] for s in samples]
    ids_tensors[0] = torch.nn.ConstantPad1d((0, SEQ_LEN - len(ids_tensors[0])), 0)(ids_tensors[0])
    ids_tensors = pad_sequence(ids_tensors, batch_first=True).to(device)

    tags_tensors = [s[2] for s in samples]
    tags_tensors[0] = torch.nn.ConstantPad1d((0, SEQ_LEN - len(tags_tensors[0])), 0)(tags_tensors[0])
    tags_tensors = pad_sequence(tags_tensors, batch_first=True).to(device)
    
    masks_tensors = torch.zeros(ids_tensors.shape, dtype=torch.long).to(device)
    masks_tensors = masks_tensors.masked_fill(ids_tensors != 0, 1).to(device)
    
    return ids_tensors, tags_tensors, masks_tensors

In [14]:
def train(epoch, model, loss_fn, optimizer, dataloader):
    model.train()

    dataloader_len = len(dataloader)
    
    losses = []

    for _,data in enumerate(dataloader, 0):
        optimizer.zero_grad()

        ids_tensors, tags_tensors, masks_tensors = data

        outputs = model(ids_tensors, masks_tensors)

        loss = loss_fn(outputs.view(-1, 3), tags_tensors.view(-1))

        losses.append(loss_fn(outputs.view(-1, 3), tags_tensors.view(-1)).item())
        
        if _ % (dataloader_len // 10) == 0:
            print(f"Epoch: {epoch}/{EPOCHS}, Batch: {_}/{dataloader_len}, Loss: {loss.item()}")
        
        loss.backward()
        
        optimizer.step()
    
    return losses

In [15]:
def validation(model, dataloader):
    model.eval()
    
    fin_targets=[]
    fin_outputs=[]

    with torch.no_grad():
        for _, data in enumerate(dataloader, 0):
            ids_tensors, tags_tensors, masks_tensors = data
            ids_tensors = ids_tensors.to(device)
            tags_tensors = tags_tensors.to(device)
            masks_tensors = masks_tensors.to(device)

            outputs = model(ids_tensors, masks_tensors)
            
            _, predictions = torch.max(outputs, dim=2)

            fin_outputs += list([int(p) for pred in predictions for p in pred])
            fin_targets += list([int(tag) for tags_tensor in tags_tensors for tag in tags_tensor])

    return fin_outputs, fin_targets

In [16]:
results = pd.DataFrame(columns=['accuracy','precision_score_micro','precision_score_macro','recall_score_micro','recall_score_macro','f1_score_micro','f1_score_macro', 'execution_time'])

In [17]:
for i in range(NO_RUNS):
    # clear cache cuda
    torch.cuda.empty_cache()
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()

    start_time = time.time()

    print(f"Run {i + 1}/{NO_RUNS}")

    train_dataset = df.sample(frac=TRAIN_SPLIT)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        sampler = RandomSampler(train_dataset),
        batch_size = TRAIN_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    validation_dataloader = DataLoader(
        testing_set,
        sampler = SequentialSampler(testing_set),
        batch_size = VALID_BATCH_SIZE,
        drop_last = True,
        collate_fn=create_mini_batch
    )

    model = ATE_BERT_Dropout_BiLSTM_Linear(BertModel.from_pretrained('bert-base-uncased'), dropout=0.3, bilstm_in_features=256, no_out_labels=3, device=device).to(device)

    optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.CrossEntropyLoss()

    train_losses = []
    for epoch in range(EPOCHS):
        losses = train(epoch, model, loss_fn, optimizer, train_dataloader)

        train_losses += losses

    outputs, targets = validation(model, validation_dataloader)

    plt.title(f'Train Loss for run {i + 1}/{NO_RUNS}')
    plt.plot(train_losses)
    plt.savefig(f'../../../results/ATE/SemEval16 - Task 5 - Restaurants/plots/bert_pt_do_bilstm_lin/train_loss_run_{i + 1}.png')

    plt.clf()
    
    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model, MODEL_OUTPUT)

    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/501, Loss: 1.2063688039779663
Epoch: 0/2, Batch: 50/501, Loss: 0.0737944096326828
Epoch: 0/2, Batch: 100/501, Loss: 0.04147915914654732
Epoch: 0/2, Batch: 150/501, Loss: 0.026891347020864487
Epoch: 0/2, Batch: 200/501, Loss: 0.03880473971366882
Epoch: 0/2, Batch: 250/501, Loss: 0.05755380541086197
Epoch: 0/2, Batch: 300/501, Loss: 0.03910692036151886
Epoch: 0/2, Batch: 350/501, Loss: 0.06783652305603027
Epoch: 0/2, Batch: 400/501, Loss: 0.02082475833594799
Epoch: 0/2, Batch: 450/501, Loss: 0.014848523773252964
Epoch: 0/2, Batch: 500/501, Loss: 0.01815630868077278
Epoch: 1/2, Batch: 0/501, Loss: 0.00851296354085207
Epoch: 1/2, Batch: 50/501, Loss: 0.023495078086853027
Epoch: 1/2, Batch: 100/501, Loss: 0.029160290956497192
Epoch: 1/2, Batch: 150/501, Loss: 0.00688580609858036
Epoch: 1/2, Batch: 200/501, Loss: 0.00995971355587244
Epoch: 1/2, Batch: 250/501, Loss: 0.017245374619960785
Epoch: 1/2, Batch: 300/501, Loss: 0.01528118085116148
Epoch: 1/2, Batch: 350

  _warn_prf(average, modifier, msg_start, len(result))


Run 2/10
Epoch: 0/2, Batch: 0/501, Loss: 1.1922657489776611
Epoch: 0/2, Batch: 50/501, Loss: 0.0787394642829895
Epoch: 0/2, Batch: 100/501, Loss: 0.06290058046579361
Epoch: 0/2, Batch: 150/501, Loss: 0.0365154966711998
Epoch: 0/2, Batch: 200/501, Loss: 0.033398084342479706
Epoch: 0/2, Batch: 250/501, Loss: 0.023028777912259102
Epoch: 0/2, Batch: 300/501, Loss: 0.01619075983762741
Epoch: 0/2, Batch: 350/501, Loss: 0.022594241425395012
Epoch: 0/2, Batch: 400/501, Loss: 0.04971852898597717
Epoch: 0/2, Batch: 450/501, Loss: 0.028270291164517403
Epoch: 0/2, Batch: 500/501, Loss: 0.02184840478003025
Epoch: 1/2, Batch: 0/501, Loss: 0.016185887157917023
Epoch: 1/2, Batch: 50/501, Loss: 0.029658226296305656
Epoch: 1/2, Batch: 100/501, Loss: 0.02600334770977497
Epoch: 1/2, Batch: 150/501, Loss: 0.01872141659259796
Epoch: 1/2, Batch: 200/501, Loss: 0.014292692765593529
Epoch: 1/2, Batch: 250/501, Loss: 0.023432036861777306
Epoch: 1/2, Batch: 300/501, Loss: 0.021326348185539246
Epoch: 1/2, Batch: 

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/501, Loss: 1.1573076248168945
Epoch: 0/2, Batch: 50/501, Loss: 0.07576719671487808
Epoch: 0/2, Batch: 100/501, Loss: 0.03587526082992554
Epoch: 0/2, Batch: 150/501, Loss: 0.03740232437849045
Epoch: 0/2, Batch: 200/501, Loss: 0.028516853228211403
Epoch: 0/2, Batch: 250/501, Loss: 0.028219707310199738
Epoch: 0/2, Batch: 300/501, Loss: 0.02535332925617695
Epoch: 0/2, Batch: 350/501, Loss: 0.023552529513835907
Epoch: 0/2, Batch: 400/501, Loss: 0.026185275986790657
Epoch: 0/2, Batch: 450/501, Loss: 0.015478670597076416
Epoch: 0/2, Batch: 500/501, Loss: 0.02337920479476452
Epoch: 1/2, Batch: 0/501, Loss: 0.012239694595336914
Epoch: 1/2, Batch: 50/501, Loss: 0.0245584174990654
Epoch: 1/2, Batch: 100/501, Loss: 0.007301578298211098
Epoch: 1/2, Batch: 150/501, Loss: 0.018940480425953865
Epoch: 1/2, Batch: 200/501, Loss: 0.017207473516464233
Epoch: 1/2, Batch: 250/501, Loss: 0.015594132244586945
Epoch: 1/2, Batch: 300/501, Loss: 0.012094859965145588
Epoch: 1/2, Batc

  _warn_prf(average, modifier, msg_start, len(result))


Run 4/10
Epoch: 0/2, Batch: 0/501, Loss: 1.0058449506759644
Epoch: 0/2, Batch: 50/501, Loss: 0.08923798054456711
Epoch: 0/2, Batch: 100/501, Loss: 0.04624124616384506
Epoch: 0/2, Batch: 150/501, Loss: 0.048654649406671524
Epoch: 0/2, Batch: 200/501, Loss: 0.041631318628787994
Epoch: 0/2, Batch: 250/501, Loss: 0.026344450190663338
Epoch: 0/2, Batch: 300/501, Loss: 0.06966008245944977
Epoch: 0/2, Batch: 350/501, Loss: 0.030354542657732964
Epoch: 0/2, Batch: 400/501, Loss: 0.02425297163426876
Epoch: 0/2, Batch: 450/501, Loss: 0.024195801466703415
Epoch: 0/2, Batch: 500/501, Loss: 0.02524140104651451
Epoch: 1/2, Batch: 0/501, Loss: 0.02003200724720955
Epoch: 1/2, Batch: 50/501, Loss: 0.015851007774472237
Epoch: 1/2, Batch: 100/501, Loss: 0.027127645909786224
Epoch: 1/2, Batch: 150/501, Loss: 0.015916693955659866
Epoch: 1/2, Batch: 200/501, Loss: 0.023409340530633926
Epoch: 1/2, Batch: 250/501, Loss: 0.00825424212962389
Epoch: 1/2, Batch: 300/501, Loss: 0.017187902703881264
Epoch: 1/2, Batc

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/501, Loss: 1.1548683643341064
Epoch: 0/2, Batch: 50/501, Loss: 0.0788545310497284
Epoch: 0/2, Batch: 100/501, Loss: 0.052866119891405106
Epoch: 0/2, Batch: 150/501, Loss: 0.0435960479080677
Epoch: 0/2, Batch: 200/501, Loss: 0.038500603288412094
Epoch: 0/2, Batch: 250/501, Loss: 0.023997534066438675
Epoch: 0/2, Batch: 300/501, Loss: 0.022731035947799683
Epoch: 0/2, Batch: 350/501, Loss: 0.049260396510362625
Epoch: 0/2, Batch: 400/501, Loss: 0.016222096979618073
Epoch: 0/2, Batch: 450/501, Loss: 0.009482599794864655
Epoch: 0/2, Batch: 500/501, Loss: 0.01934005320072174
Epoch: 1/2, Batch: 0/501, Loss: 0.0299686286598444
Epoch: 1/2, Batch: 50/501, Loss: 0.027496760711073875
Epoch: 1/2, Batch: 100/501, Loss: 0.019040068611502647
Epoch: 1/2, Batch: 150/501, Loss: 0.037503499537706375
Epoch: 1/2, Batch: 200/501, Loss: 0.015362799167633057
Epoch: 1/2, Batch: 250/501, Loss: 0.007967686280608177
Epoch: 1/2, Batch: 300/501, Loss: 0.01667335443198681
Epoch: 1/2, Batch

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/501, Loss: 1.1045013666152954
Epoch: 0/2, Batch: 50/501, Loss: 0.0670560747385025
Epoch: 0/2, Batch: 100/501, Loss: 0.03397452086210251
Epoch: 0/2, Batch: 150/501, Loss: 0.02515069954097271
Epoch: 0/2, Batch: 200/501, Loss: 0.025508342310786247
Epoch: 0/2, Batch: 250/501, Loss: 0.016570085659623146
Epoch: 0/2, Batch: 300/501, Loss: 0.04859165474772453
Epoch: 0/2, Batch: 350/501, Loss: 0.022208470851182938
Epoch: 0/2, Batch: 400/501, Loss: 0.013768485747277737
Epoch: 0/2, Batch: 450/501, Loss: 0.021118639037013054
Epoch: 0/2, Batch: 500/501, Loss: 0.019485848024487495
Epoch: 1/2, Batch: 0/501, Loss: 0.031694862991571426
Epoch: 1/2, Batch: 50/501, Loss: 0.02423218823969364
Epoch: 1/2, Batch: 100/501, Loss: 0.011454087682068348
Epoch: 1/2, Batch: 150/501, Loss: 0.011909719556570053
Epoch: 1/2, Batch: 200/501, Loss: 0.019987722858786583
Epoch: 1/2, Batch: 250/501, Loss: 0.01491188071668148
Epoch: 1/2, Batch: 300/501, Loss: 0.016313910484313965
Epoch: 1/2, Batc

  _warn_prf(average, modifier, msg_start, len(result))


<Figure size 432x288 with 0 Axes>

In [18]:
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.995379,0.995379,0.473221,0.995379,0.439346,0.995379,0.453845,376.056191
1,0.995535,0.995535,0.511154,0.995535,0.356073,0.995535,0.372995,369.115474
2,0.995477,0.995477,0.480656,0.995477,0.464795,0.995477,0.47224,381.904849
3,0.996051,0.996051,0.496419,0.996051,0.500367,0.996051,0.498375,380.993049
4,0.995367,0.995367,0.466067,0.995367,0.452757,0.995367,0.459024,372.416976
5,0.996555,0.996555,0.826181,0.996555,0.53825,0.996555,0.578448,369.571975
6,0.995363,0.995363,0.470954,0.995363,0.427095,0.995363,0.444711,368.585389
7,0.996148,0.996148,0.824313,0.996148,0.540726,0.996148,0.537783,370.620469
8,0.996223,0.996223,0.779856,0.996223,0.63389,0.996223,0.606805,369.080026
9,0.99518,0.99518,0.43821,0.99518,0.354422,0.99518,0.367819,368.632803


In [19]:
results.to_csv(STATS_OUTPUT)