# Import dependencies

In [None]:
import os
import sys

sys.path.insert(0, os.path.dirname(os.getcwd())) 

In [1]:
import time
import gc

import numpy as np
import pandas as pd

from transformers import BertTokenizer, BertModel
from transformers import logging

from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

from semeval_reader import SemevalReader

from InputDataset import InputDataset

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from torch import cuda

from models.BERT_Dropout_Linear import BERT_Dropout_Linear

In [2]:
device = 'cuda' if cuda.is_available() else 'cpu'
logging.set_verbosity_error() 

In [3]:
print(torch.cuda.get_device_name(0))
print(f"Memory: {torch.cuda.get_device_properties(0).total_memory // 1024 ** 3} GB")

NVIDIA GeForce RTX 2060 SUPER
Memory: 8 GB


In [4]:
def clear_memory():
    torch.cuda.empty_cache()

    with torch.no_grad():
        torch.cuda.empty_cache()

    gc.collect()

# Load Data

In [5]:
def get_target_list_for_polarity(polarity):
    if polarity == 'positive':
        return [0, 0, 1]
    if polarity == 'negative':
        return [1, 0, 0]
    return [0, 1, 0]

In [6]:
semeval_reader = SemevalReader('../../../data/semeval16_restaurants_train.xml')

reviews = semeval_reader.read_reviews()
absolute_polarity_sentences = semeval_reader.get_absolute_polarity_sentences()

df = pd.DataFrame(map(lambda x: (x.text, x.opinions[0].polarity), absolute_polarity_sentences))
df.rename(columns={0: 'text'}, inplace=True)
df['target_list'] = df.apply(lambda row: get_target_list_for_polarity(row[1]), axis=1)

absolute_polarity_df = df.drop(columns=[1])

# Train & Validate

In [7]:
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 4

EPOCHS = 2

LEARNING_RATE = 1e-5

TRAIN_SPLIT = 0.8

NO_RUNS = 10

In [8]:
MODEL_OUTPUT = '../../../results/SA/SemEval16 - Task 5 - Restaurants/models/bert_fine_tuned.pth'
STATS_OUTPUT = '../../../results/SA/SemEval16 - Task 5 - Restaurants/stats/bert_pre_trained_dropout_linear.csv'

In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [10]:
best_accuracy = 0.0

In [11]:
def train(epoch, model, loss_fn, optimizer, dataloader):
    model.train()

    dataloader_len = len(dataloader)

    for _,data in enumerate(dataloader, 0):
        optimizer.zero_grad()

        ids = data['ids'].to(device, dtype = torch.long)
        mask = data['mask'].to(device, dtype = torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
        targets = data['targets'].to(device, dtype = torch.float)

        outputs = model(ids, mask, token_type_ids)

        loss = loss_fn(outputs, targets)
        
        if _ % (dataloader_len // 10) == 0:
            print(f"Epoch: {epoch}/{EPOCHS}, Batch: {_}/{dataloader_len}, Loss: {loss.item()}")
        
        loss.backward()
        
        optimizer.step()

In [12]:
def validation(model, dataloader):
    model.eval()
    
    fin_targets=[]
    fin_outputs=[]

    with torch.no_grad():
        for _, data in enumerate(dataloader, 0):
            ids = data['ids'].to(device, dtype = torch.long)
            mask = data['mask'].to(device, dtype = torch.long)
            token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
            targets = data['targets'].to(device, dtype = torch.float)

            outputs = model(ids, mask, token_type_ids)
            
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(torch.sigmoid(outputs).cpu().detach().numpy().tolist())

    return fin_outputs, fin_targets

In [13]:
results = pd.DataFrame(columns=['accuracy','precision_score_micro','precision_score_macro','recall_score_micro','recall_score_macro','f1_score_micro','f1_score_macro', 'execution_time'])

In [14]:
for i in range(10):
    # clear cache cuda
    torch.cuda.empty_cache()
    with torch.no_grad():
        torch.cuda.empty_cache()
    gc.collect()

    start_time = time.time()

    print(f"Run {i + 1}/10")

    train_dataset = df.sample(frac=TRAIN_SPLIT)
    test_dataset = df.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    training_set = InputDataset(train_dataset, tokenizer)
    testing_set = InputDataset(test_dataset, tokenizer)

    train_dataloader = DataLoader(
        training_set,
        sampler = RandomSampler(train_dataset),
        batch_size = TRAIN_BATCH_SIZE,
        drop_last = True
    )

    validation_dataloader = DataLoader(
        testing_set,
        sampler = SequentialSampler(testing_set),
        batch_size = VALID_BATCH_SIZE,
        drop_last = True
    )

    model = BERT_Dropout_Linear(BertModel.from_pretrained('bert-base-uncased'), dropout=0.3, no_out_labels=3, device=device).to(device)

    optimizer = torch.optim.Adam(params = model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    for epoch in range(EPOCHS):
        train(epoch, model, loss_fn, optimizer, train_dataloader)

    outputs, targets = validation(model, validation_dataloader)
    outputs = np.argmax(outputs, axis=1)
    targets = np.argmax(targets, axis=1)
    
    accuracy = accuracy_score(targets, outputs)
    precision_score_micro = precision_score(targets, outputs, average='micro')
    precision_score_macro = precision_score(targets, outputs, average='macro')
    recall_score_micro = recall_score(targets, outputs, average='micro')
    recall_score_macro = recall_score(targets, outputs, average='macro')
    f1_score_micro = f1_score(targets, outputs, average='micro')
    f1_score_macro = f1_score(targets, outputs, average='macro')

    execution_time = time.time() - start_time

    results.loc[i] = [accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro, execution_time]

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.bert, MODEL_OUTPUT)

    del train_dataset
    del test_dataset
    del training_set
    del testing_set
    del model
    del loss_fn
    del optimizer
    del outputs
    del targets

Run 1/10
Epoch: 0/2, Batch: 0/313, Loss: 0.8135584592819214
Epoch: 0/2, Batch: 31/313, Loss: 0.6470083594322205
Epoch: 0/2, Batch: 62/313, Loss: 0.499563068151474
Epoch: 0/2, Batch: 93/313, Loss: 0.4776678681373596
Epoch: 0/2, Batch: 124/313, Loss: 0.34995603561401367
Epoch: 0/2, Batch: 155/313, Loss: 0.1958431452512741
Epoch: 0/2, Batch: 186/313, Loss: 0.2779799699783325
Epoch: 0/2, Batch: 217/313, Loss: 0.20088832080364227
Epoch: 0/2, Batch: 248/313, Loss: 0.22771942615509033
Epoch: 0/2, Batch: 279/313, Loss: 0.1353072226047516
Epoch: 0/2, Batch: 310/313, Loss: 0.06596636027097702
Epoch: 1/2, Batch: 0/313, Loss: 0.06889650225639343
Epoch: 1/2, Batch: 31/313, Loss: 0.13402268290519714
Epoch: 1/2, Batch: 62/313, Loss: 0.05303209275007248
Epoch: 1/2, Batch: 93/313, Loss: 0.11233144253492355
Epoch: 1/2, Batch: 124/313, Loss: 0.09338028728961945
Epoch: 1/2, Batch: 155/313, Loss: 0.3873445391654968
Epoch: 1/2, Batch: 186/313, Loss: 0.06349997222423553
Epoch: 1/2, Batch: 217/313, Loss: 0.03

  _warn_prf(average, modifier, msg_start, len(result))


Run 2/10
Epoch: 0/2, Batch: 0/313, Loss: 0.6951073408126831
Epoch: 0/2, Batch: 31/313, Loss: 0.41582173109054565
Epoch: 0/2, Batch: 62/313, Loss: 0.514755129814148
Epoch: 0/2, Batch: 93/313, Loss: 0.42834794521331787
Epoch: 0/2, Batch: 124/313, Loss: 0.19163423776626587
Epoch: 0/2, Batch: 155/313, Loss: 0.1639644205570221
Epoch: 0/2, Batch: 186/313, Loss: 0.32411250472068787
Epoch: 0/2, Batch: 217/313, Loss: 0.17019063234329224
Epoch: 0/2, Batch: 248/313, Loss: 0.5374182462692261
Epoch: 0/2, Batch: 279/313, Loss: 0.1017962321639061
Epoch: 0/2, Batch: 310/313, Loss: 0.2698870599269867
Epoch: 1/2, Batch: 0/313, Loss: 0.6076891422271729
Epoch: 1/2, Batch: 31/313, Loss: 0.3124878704547882
Epoch: 1/2, Batch: 62/313, Loss: 0.13232654333114624
Epoch: 1/2, Batch: 93/313, Loss: 0.051659248769283295
Epoch: 1/2, Batch: 124/313, Loss: 0.0628635585308075
Epoch: 1/2, Batch: 155/313, Loss: 0.035960763692855835
Epoch: 1/2, Batch: 186/313, Loss: 0.08989375829696655
Epoch: 1/2, Batch: 217/313, Loss: 0.0

  _warn_prf(average, modifier, msg_start, len(result))


Run 3/10
Epoch: 0/2, Batch: 0/313, Loss: 0.7720322608947754
Epoch: 0/2, Batch: 31/313, Loss: 0.556869626045227
Epoch: 0/2, Batch: 62/313, Loss: 0.5700375437736511
Epoch: 0/2, Batch: 93/313, Loss: 0.5994336009025574
Epoch: 0/2, Batch: 124/313, Loss: 0.5111494064331055
Epoch: 0/2, Batch: 155/313, Loss: 0.3409075438976288
Epoch: 0/2, Batch: 186/313, Loss: 0.2673484981060028
Epoch: 0/2, Batch: 217/313, Loss: 0.19411611557006836
Epoch: 0/2, Batch: 248/313, Loss: 0.26490139961242676
Epoch: 0/2, Batch: 279/313, Loss: 0.6527553796768188
Epoch: 0/2, Batch: 310/313, Loss: 0.12804780900478363
Epoch: 1/2, Batch: 0/313, Loss: 0.3928905427455902
Epoch: 1/2, Batch: 31/313, Loss: 0.14158225059509277
Epoch: 1/2, Batch: 62/313, Loss: 0.12707266211509705
Epoch: 1/2, Batch: 93/313, Loss: 0.06748398393392563
Epoch: 1/2, Batch: 124/313, Loss: 0.30842864513397217
Epoch: 1/2, Batch: 155/313, Loss: 0.20613479614257812
Epoch: 1/2, Batch: 186/313, Loss: 0.3039252758026123
Epoch: 1/2, Batch: 217/313, Loss: 0.0575

  _warn_prf(average, modifier, msg_start, len(result))


Run 4/10
Epoch: 0/2, Batch: 0/313, Loss: 0.6393111348152161
Epoch: 0/2, Batch: 31/313, Loss: 0.5059399008750916
Epoch: 0/2, Batch: 62/313, Loss: 0.5409404635429382
Epoch: 0/2, Batch: 93/313, Loss: 0.4267091155052185
Epoch: 0/2, Batch: 124/313, Loss: 0.3505723178386688
Epoch: 0/2, Batch: 155/313, Loss: 0.8488353490829468
Epoch: 0/2, Batch: 186/313, Loss: 0.31003719568252563
Epoch: 0/2, Batch: 217/313, Loss: 0.13435609638690948
Epoch: 0/2, Batch: 248/313, Loss: 0.19667381048202515
Epoch: 0/2, Batch: 279/313, Loss: 0.1810721606016159
Epoch: 0/2, Batch: 310/313, Loss: 0.47617995738983154
Epoch: 1/2, Batch: 0/313, Loss: 0.155410498380661
Epoch: 1/2, Batch: 31/313, Loss: 0.3115243911743164
Epoch: 1/2, Batch: 62/313, Loss: 0.12059950083494186
Epoch: 1/2, Batch: 93/313, Loss: 0.035972341895103455
Epoch: 1/2, Batch: 124/313, Loss: 0.07057766616344452
Epoch: 1/2, Batch: 155/313, Loss: 1.054293155670166
Epoch: 1/2, Batch: 186/313, Loss: 0.3555651903152466
Epoch: 1/2, Batch: 217/313, Loss: 0.05210

  _warn_prf(average, modifier, msg_start, len(result))


Run 5/10
Epoch: 0/2, Batch: 0/313, Loss: 0.7282848358154297
Epoch: 0/2, Batch: 31/313, Loss: 0.4992045760154724
Epoch: 0/2, Batch: 62/313, Loss: 0.768202543258667
Epoch: 0/2, Batch: 93/313, Loss: 0.37085962295532227
Epoch: 0/2, Batch: 124/313, Loss: 0.3804178833961487
Epoch: 0/2, Batch: 155/313, Loss: 0.45496630668640137
Epoch: 0/2, Batch: 186/313, Loss: 0.31966665387153625
Epoch: 0/2, Batch: 217/313, Loss: 0.13001202046871185
Epoch: 0/2, Batch: 248/313, Loss: 0.3262258768081665
Epoch: 0/2, Batch: 279/313, Loss: 0.20892584323883057
Epoch: 0/2, Batch: 310/313, Loss: 0.09545151889324188
Epoch: 1/2, Batch: 0/313, Loss: 0.10654284805059433
Epoch: 1/2, Batch: 31/313, Loss: 0.11697613447904587
Epoch: 1/2, Batch: 62/313, Loss: 0.044925980269908905
Epoch: 1/2, Batch: 93/313, Loss: 0.22500726580619812
Epoch: 1/2, Batch: 124/313, Loss: 0.05307605862617493
Epoch: 1/2, Batch: 155/313, Loss: 0.06678053736686707
Epoch: 1/2, Batch: 186/313, Loss: 0.061111778020858765
Epoch: 1/2, Batch: 217/313, Loss:

  _warn_prf(average, modifier, msg_start, len(result))


Run 6/10
Epoch: 0/2, Batch: 0/313, Loss: 0.6581122875213623
Epoch: 0/2, Batch: 31/313, Loss: 0.49423402547836304
Epoch: 0/2, Batch: 62/313, Loss: 0.6310539245605469
Epoch: 0/2, Batch: 93/313, Loss: 0.34446173906326294
Epoch: 0/2, Batch: 124/313, Loss: 0.32063934206962585
Epoch: 0/2, Batch: 155/313, Loss: 0.14304882287979126
Epoch: 0/2, Batch: 186/313, Loss: 0.17665791511535645
Epoch: 0/2, Batch: 217/313, Loss: 0.1837436556816101
Epoch: 0/2, Batch: 248/313, Loss: 0.4151671826839447
Epoch: 0/2, Batch: 279/313, Loss: 0.051669031381607056
Epoch: 0/2, Batch: 310/313, Loss: 0.07439880073070526
Epoch: 1/2, Batch: 0/313, Loss: 0.1109483391046524
Epoch: 1/2, Batch: 31/313, Loss: 0.20460782945156097
Epoch: 1/2, Batch: 62/313, Loss: 0.11655651032924652
Epoch: 1/2, Batch: 93/313, Loss: 0.0628238245844841
Epoch: 1/2, Batch: 124/313, Loss: 0.07661507278680801
Epoch: 1/2, Batch: 155/313, Loss: 0.035909608006477356
Epoch: 1/2, Batch: 186/313, Loss: 0.6083745956420898
Epoch: 1/2, Batch: 217/313, Loss: 

  _warn_prf(average, modifier, msg_start, len(result))


Run 7/10
Epoch: 0/2, Batch: 0/313, Loss: 0.7746721506118774
Epoch: 0/2, Batch: 31/313, Loss: 0.5889847278594971
Epoch: 0/2, Batch: 62/313, Loss: 0.42400598526000977
Epoch: 0/2, Batch: 93/313, Loss: 0.2658620774745941
Epoch: 0/2, Batch: 124/313, Loss: 0.2650744915008545
Epoch: 0/2, Batch: 155/313, Loss: 0.46192190051078796
Epoch: 0/2, Batch: 186/313, Loss: 0.8510639071464539
Epoch: 0/2, Batch: 217/313, Loss: 0.6271886229515076
Epoch: 0/2, Batch: 248/313, Loss: 0.15678346157073975
Epoch: 0/2, Batch: 279/313, Loss: 0.3224421739578247
Epoch: 0/2, Batch: 310/313, Loss: 0.4628303050994873
Epoch: 1/2, Batch: 0/313, Loss: 0.2079296112060547
Epoch: 1/2, Batch: 31/313, Loss: 0.25567352771759033
Epoch: 1/2, Batch: 62/313, Loss: 0.44476351141929626
Epoch: 1/2, Batch: 93/313, Loss: 0.07715418189764023
Epoch: 1/2, Batch: 124/313, Loss: 0.08695300668478012
Epoch: 1/2, Batch: 155/313, Loss: 0.061430901288986206
Epoch: 1/2, Batch: 186/313, Loss: 0.09561901539564133
Epoch: 1/2, Batch: 217/313, Loss: 0.0

  _warn_prf(average, modifier, msg_start, len(result))


Run 8/10
Epoch: 0/2, Batch: 0/313, Loss: 0.6540887355804443
Epoch: 0/2, Batch: 31/313, Loss: 0.6353632807731628
Epoch: 0/2, Batch: 62/313, Loss: 0.324989378452301
Epoch: 0/2, Batch: 93/313, Loss: 0.522436261177063
Epoch: 0/2, Batch: 124/313, Loss: 0.442356675863266
Epoch: 0/2, Batch: 155/313, Loss: 0.24884924292564392
Epoch: 0/2, Batch: 186/313, Loss: 0.157729372382164
Epoch: 0/2, Batch: 217/313, Loss: 0.15961574018001556
Epoch: 0/2, Batch: 248/313, Loss: 0.33668774366378784
Epoch: 0/2, Batch: 279/313, Loss: 0.4965752959251404
Epoch: 0/2, Batch: 310/313, Loss: 0.4747634530067444
Epoch: 1/2, Batch: 0/313, Loss: 0.06401368230581284
Epoch: 1/2, Batch: 31/313, Loss: 0.37374186515808105
Epoch: 1/2, Batch: 62/313, Loss: 0.4173576235771179
Epoch: 1/2, Batch: 93/313, Loss: 0.34101957082748413
Epoch: 1/2, Batch: 124/313, Loss: 0.5942520499229431
Epoch: 1/2, Batch: 155/313, Loss: 0.0969112366437912
Epoch: 1/2, Batch: 186/313, Loss: 0.08985277265310287
Epoch: 1/2, Batch: 217/313, Loss: 0.09269888

  _warn_prf(average, modifier, msg_start, len(result))


Run 9/10
Epoch: 0/2, Batch: 0/313, Loss: 0.6632400751113892
Epoch: 0/2, Batch: 31/313, Loss: 0.6547996401786804
Epoch: 0/2, Batch: 62/313, Loss: 0.684213399887085
Epoch: 0/2, Batch: 93/313, Loss: 0.3147515058517456
Epoch: 0/2, Batch: 124/313, Loss: 0.7293826937675476
Epoch: 0/2, Batch: 155/313, Loss: 0.5695768594741821
Epoch: 0/2, Batch: 186/313, Loss: 0.1318027675151825
Epoch: 0/2, Batch: 217/313, Loss: 0.24810388684272766
Epoch: 0/2, Batch: 248/313, Loss: 0.16786710917949677
Epoch: 0/2, Batch: 279/313, Loss: 0.3092167377471924
Epoch: 0/2, Batch: 310/313, Loss: 0.09501457214355469
Epoch: 1/2, Batch: 0/313, Loss: 0.1528787612915039
Epoch: 1/2, Batch: 31/313, Loss: 0.4779481887817383
Epoch: 1/2, Batch: 62/313, Loss: 0.047969382256269455
Epoch: 1/2, Batch: 93/313, Loss: 0.13515228033065796
Epoch: 1/2, Batch: 124/313, Loss: 0.08887308835983276
Epoch: 1/2, Batch: 155/313, Loss: 0.07587829232215881
Epoch: 1/2, Batch: 186/313, Loss: 0.37368375062942505
Epoch: 1/2, Batch: 217/313, Loss: 0.062

  _warn_prf(average, modifier, msg_start, len(result))


Run 10/10
Epoch: 0/2, Batch: 0/313, Loss: 0.5847243666648865
Epoch: 0/2, Batch: 31/313, Loss: 0.6022611856460571
Epoch: 0/2, Batch: 62/313, Loss: 0.6459730863571167
Epoch: 0/2, Batch: 93/313, Loss: 0.4091607928276062
Epoch: 0/2, Batch: 124/313, Loss: 0.2826034426689148
Epoch: 0/2, Batch: 155/313, Loss: 0.4220062494277954
Epoch: 0/2, Batch: 186/313, Loss: 0.1523214876651764
Epoch: 0/2, Batch: 217/313, Loss: 0.6347696781158447
Epoch: 0/2, Batch: 248/313, Loss: 0.30869102478027344
Epoch: 0/2, Batch: 279/313, Loss: 0.3599624037742615
Epoch: 0/2, Batch: 310/313, Loss: 0.1985052525997162
Epoch: 1/2, Batch: 0/313, Loss: 0.04283856600522995
Epoch: 1/2, Batch: 31/313, Loss: 0.09483489394187927
Epoch: 1/2, Batch: 62/313, Loss: 0.48233968019485474
Epoch: 1/2, Batch: 93/313, Loss: 0.6010804176330566
Epoch: 1/2, Batch: 124/313, Loss: 0.046953123062849045
Epoch: 1/2, Batch: 155/313, Loss: 0.12951180338859558
Epoch: 1/2, Batch: 186/313, Loss: 0.06989099085330963
Epoch: 1/2, Batch: 217/313, Loss: 0.59

In [15]:
results

Unnamed: 0,accuracy,precision_score_micro,precision_score_macro,recall_score_micro,recall_score_macro,f1_score_micro,f1_score_macro,execution_time
0,0.86859,0.86859,0.570199,0.86859,0.611526,0.86859,0.588501,201.240559
1,0.884615,0.884615,0.584936,0.884615,0.612896,0.884615,0.598589,185.900803
2,0.871795,0.871795,0.563396,0.871795,0.607407,0.871795,0.582492,187.315874
3,0.884615,0.884615,0.579272,0.884615,0.600855,0.884615,0.589841,195.283949
4,0.88141,0.88141,0.578965,0.88141,0.608416,0.88141,0.593325,208.659036
5,0.919872,0.919872,0.608537,0.919872,0.611285,0.919872,0.609729,199.921068
6,0.894231,0.894231,0.587773,0.894231,0.6024,0.894231,0.594993,203.051747
7,0.88141,0.88141,0.565171,0.88141,0.623236,0.88141,0.587536,204.666527
8,0.929487,0.929487,0.616226,0.929487,0.623907,0.929487,0.61997,204.112702
9,0.910256,0.910256,0.928056,0.910256,0.680971,0.910256,0.709608,205.93971


In [16]:
results.to_csv(STATS_OUTPUT)