In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel, AdamW
import pandas as pd
import numpy as np

"""Local Imports"""
from src.preprocessing import remove_stop_words

In [25]:
data = pd.read_csv('data/train/total_data.csv').to_numpy()

In [3]:
class SiameseModel(nn.Module):
    def __init__(self, h_size, max_length):
        super(SiameseModel, self).__init__()
        # The hidden layer size for the classification token (CLS) in BERT
        # In this case, it is 768
        self.h_size = h_size
        
        # The max length a title could be for padding purposes
        self.max_length = max_length * 2
        
        # BERT tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        
        # BERT model
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        
        # We want to freeze all parameters except the last couple for training
        for idx, param in enumerate(self.bert.parameters()):
            if idx < 170:
                param.requires_grad = False
        
        # Fully-Connected layers
        self.fc1 = nn.Linear(self.h_size, 384)
        self.fc2 = nn.Linear(384, 2)
        
        # Dropout for overfitting
        self.dropout = nn.Dropout(p=0.2)
        
        # Softmax for prediction
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        x is going to be a numpy array of [sentenceA, sentenceB].
        Model using BERT to make a prediction of whether the two titles represent 
        the same entity.
        """
        # BERT for title similarity works having the two sentences (sentence1, sentence2)
        # and ordering them in both combinations that they could be (sentence1 + sentence2)
        # and (sentence2 + sentence1). That is why we do np.flip() on x (the input sentences)
        input1 = self.tokenizer(x.tolist(),
                                return_tensors='pt',
                                padding='max_length',
                                truncation=True,
                                max_length=self.max_length)

        input2 = self.tokenizer(np.flip(x, 1).tolist(),
                                return_tensors='pt',
                                padding='max_length',
                                truncation=True,
                                max_length=self.max_length)
        
        # Send the inputs through BERT
        # We index at 1 because that gives us the classification token (CLS)
        # that BERT talks about in the paper (as opposed to each hidden layer for each)
        # token embedding
        output1 = self.bert(**input1)[1]
        output2 = self.bert(**input2)[1]
        
        # BERT calls for the addition of both 
        addition = output1 + output2
        
        # Fully-Connected Layer 1 (input of 768 units and output of 384)
        addition = self.fc1(addition)
        
        # ReLU Activation
        additionn = F.relu(addition)
        
        # Dropout
        addition = self.dropout(addition)
        
        # Fully-Connected Layer 2 (input of 384 units, out of 2 for Softmax)
        addition = self.fc2(addition)
        
        # Dropout
        addition = self.dropout(addition)
        
        # Softmax Activation to get predictions
        addition = self.softmax(addition)
        
        return addition

net = SiameseModel(768, 44)
net.train()
print()

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=433.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=231508.0), HTML(value='')))




HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=440473133.0), HTML(value='')))





In [4]:
# Using cross-entropy because we are making a classifier
criterion = nn.CrossEntropyLoss()

# Using Adam optimizer
opt = AdamW(net.parameters(), lr=1e-6)

In [33]:
# 5 epochs
for epoch in range(10):
    # Keep track of loss for each mini-batch
    # Each mini-batch is going to be 32 examples
    BATCH_SIZE = 32# 5 epochs
for epoch in range(10):
    # Keep track of loss for each mini-batch
    # Each mini-batch is going to be 32 examples
    BATCH_SIZE = 32
    
    losses = 0
    accuracies = 0
    for i, position in enumerate(range(0, len(train_data), BATCH_SIZE)):
        if (position + BATCH_SIZE > len(train_data)):
            batch_data = train_data[position:]
            batch_labels = train_labels[position:]
        else:
            batch_data = train_data[position:position + BATCH_SIZE]
            batch_labels = train_labels[position:position + BATCH_SIZE]
            
        # Zero the parameter gradients
        opt.zero_grad()
        
        # Forward propagation
        forward = net(batch_data)

        # Calculate loss
        loss = criterion(forward, torch.from_numpy(batch_labels).view(-1).long())
        
        # Backprop
        loss.backward()
        
        # Apply the gradients
        opt.step()
        
        losses += loss.item()
        if (i % 10 == 9):
            # Print statistics every batch
            print('Epoch: %d, Batch %5d, loss: %.3f' %
                    (epoch + 1, i + 1, losses/10))
            
            losses = 0

print('Finished training')
    
    losses = 0
    accuracies = 0
    for i, position in enumerate(range(0, len(train_data), BATCH_SIZE)):
        if (position + BATCH_SIZE > len(train_data)):
            batch_data = train_data[position:]
            batch_labels = train_labels[position:]
        else:
            batch_data = train_data[position:position + BATCH_SIZE]
            batch_labels = train_labels[position:position + BATCH_SIZE]
            
        # Zero the parameter gradients
        opt.zero_grad()
        
        # Forward propagation
        forward = net(batch_data)

        # Calculate loss
        loss = criterion(forward, torch.from_numpy(batch_labels).view(-1).long())
        
        # Backprop
        loss.backward()
        
        # Apply the gradients
        opt.step()
        
        losses += loss.item()
        if (i % 10 == 9):
            # Print statistics every batch
            print('Epoch: %d, Batch %5d, loss: %.3f' %
                    (epoch + 1, i + 1, losses/10))
            
            losses = 0

print('Finished training')

Epoch: 1, Batch    10, loss: 0.747
Epoch: 1, Batch    20, loss: 0.763
Epoch: 1, Batch    30, loss: 0.766
Epoch: 1, Batch    40, loss: 0.696
Epoch: 1, Batch    50, loss: 0.745
Epoch: 1, Batch    60, loss: 0.716
Epoch: 1, Batch    70, loss: 0.696
Epoch: 1, Batch    80, loss: 0.699
Epoch: 1, Batch    90, loss: 0.712
Epoch: 1, Batch   100, loss: 0.704
Epoch: 1, Batch   110, loss: 0.697
Epoch: 1, Batch   120, loss: 0.694
Epoch: 1, Batch   130, loss: 0.710
Epoch: 1, Batch   140, loss: 0.694
Epoch: 1, Batch   150, loss: 0.699
Epoch: 1, Batch   160, loss: 0.694
Epoch: 1, Batch   170, loss: 0.688
Epoch: 1, Batch   180, loss: 0.695
Epoch: 1, Batch   190, loss: 0.683
Epoch: 1, Batch   200, loss: 0.690
Epoch: 1, Batch   210, loss: 0.693
Epoch: 1, Batch   220, loss: 0.707
Epoch: 1, Batch   230, loss: 0.691
Epoch: 1, Batch   240, loss: 0.703
Epoch: 1, Batch   250, loss: 0.697
Epoch: 1, Batch   260, loss: 0.696
Epoch: 1, Batch   270, loss: 0.696
Epoch: 1, Batch   280, loss: 0.702
Epoch: 1, Batch   29

Epoch: 2, Batch   490, loss: 0.431
Epoch: 2, Batch   500, loss: 0.433
Epoch: 2, Batch   510, loss: 0.435
Epoch: 2, Batch   520, loss: 0.459
Epoch: 2, Batch   530, loss: 0.418
Epoch: 2, Batch   540, loss: 0.444
Epoch: 2, Batch   550, loss: 0.471
Epoch: 2, Batch   560, loss: 0.469
Epoch: 2, Batch   570, loss: 0.430
Epoch: 2, Batch   580, loss: 0.402
Epoch: 2, Batch   590, loss: 0.419
Epoch: 2, Batch   600, loss: 0.458
Epoch: 2, Batch   610, loss: 0.401
Epoch: 2, Batch   620, loss: 0.450
Epoch: 2, Batch   630, loss: 0.435
Epoch: 2, Batch   640, loss: 0.430
Epoch: 2, Batch   650, loss: 0.439
Epoch: 2, Batch   660, loss: 0.405
Epoch: 2, Batch   670, loss: 0.381
Epoch: 2, Batch   680, loss: 0.442
Epoch: 2, Batch   690, loss: 0.410
Epoch: 2, Batch   700, loss: 0.383
Epoch: 2, Batch   710, loss: 0.426
Epoch: 2, Batch   720, loss: 0.459
Epoch: 2, Batch   730, loss: 0.432
Epoch: 2, Batch   740, loss: 0.413
Epoch: 2, Batch   750, loss: 0.421
Epoch: 2, Batch   760, loss: 0.418
Epoch: 2, Batch   77

Epoch: 3, Batch   970, loss: 0.367
Epoch: 3, Batch   980, loss: 0.353
Epoch: 3, Batch   990, loss: 0.335
Epoch: 3, Batch  1000, loss: 0.332
Epoch: 3, Batch  1010, loss: 0.375
Epoch: 3, Batch  1020, loss: 0.349
Epoch: 3, Batch  1030, loss: 0.383
Epoch: 3, Batch  1040, loss: 0.412
Epoch: 3, Batch  1050, loss: 0.347
Epoch: 3, Batch  1060, loss: 0.337
Epoch: 3, Batch  1070, loss: 0.368
Epoch: 3, Batch  1080, loss: 0.347
Epoch: 3, Batch  1090, loss: 0.427
Epoch: 3, Batch  1100, loss: 0.349
Epoch: 3, Batch  1110, loss: 0.380
Epoch: 3, Batch  1120, loss: 0.360
Epoch: 3, Batch  1130, loss: 0.393
Epoch: 3, Batch  1140, loss: 0.347
Epoch: 3, Batch  1150, loss: 0.346
Epoch: 3, Batch  1160, loss: 0.413
Epoch: 3, Batch  1170, loss: 0.346
Epoch: 3, Batch  1180, loss: 0.308
Epoch: 3, Batch  1190, loss: 0.358
Epoch: 3, Batch  1200, loss: 0.405
Epoch: 3, Batch  1210, loss: 0.386
Epoch: 3, Batch  1220, loss: 0.361
Epoch: 3, Batch  1230, loss: 0.342
Epoch: 3, Batch  1240, loss: 0.388
Epoch: 3, Batch  125

Epoch: 4, Batch  1450, loss: 0.319
Epoch: 4, Batch  1460, loss: 0.330
Epoch: 4, Batch  1470, loss: 0.336
Epoch: 4, Batch  1480, loss: 0.363
Epoch: 4, Batch  1490, loss: 0.371
Epoch: 4, Batch  1500, loss: 0.284
Epoch: 4, Batch  1510, loss: 0.318
Epoch: 4, Batch  1520, loss: 0.301
Epoch: 4, Batch  1530, loss: 0.333
Epoch: 4, Batch  1540, loss: 0.308
Epoch: 4, Batch  1550, loss: 0.348
Epoch: 4, Batch  1560, loss: 0.304
Epoch: 4, Batch  1570, loss: 0.283
Epoch: 4, Batch  1580, loss: 0.305
Epoch: 4, Batch  1590, loss: 0.270
Epoch: 4, Batch  1600, loss: 0.312
Epoch: 4, Batch  1610, loss: 0.386
Epoch: 4, Batch  1620, loss: 0.311
Epoch: 4, Batch  1630, loss: 0.292
Epoch: 4, Batch  1640, loss: 0.356
Epoch: 4, Batch  1650, loss: 0.314
Epoch: 4, Batch  1660, loss: 0.320
Epoch: 4, Batch  1670, loss: 0.309
Epoch: 4, Batch  1680, loss: 0.356
Epoch: 4, Batch  1690, loss: 0.315
Epoch: 4, Batch  1700, loss: 0.348
Epoch: 4, Batch  1710, loss: 0.334
Epoch: 4, Batch  1720, loss: 0.341
Epoch: 4, Batch  173

Epoch: 6, Batch    60, loss: 0.297
Epoch: 6, Batch    70, loss: 0.281
Epoch: 6, Batch    80, loss: 0.308
Epoch: 6, Batch    90, loss: 0.289
Epoch: 6, Batch   100, loss: 0.294
Epoch: 6, Batch   110, loss: 0.277
Epoch: 6, Batch   120, loss: 0.350
Epoch: 6, Batch   130, loss: 0.326
Epoch: 6, Batch   140, loss: 0.290
Epoch: 6, Batch   150, loss: 0.301
Epoch: 6, Batch   160, loss: 0.374
Epoch: 6, Batch   170, loss: 0.297
Epoch: 6, Batch   180, loss: 0.350
Epoch: 6, Batch   190, loss: 0.319
Epoch: 6, Batch   200, loss: 0.360
Epoch: 6, Batch   210, loss: 0.300
Epoch: 6, Batch   220, loss: 0.318
Epoch: 6, Batch   230, loss: 0.272
Epoch: 6, Batch   240, loss: 0.354
Epoch: 6, Batch   250, loss: 0.312
Epoch: 6, Batch   260, loss: 0.312
Epoch: 6, Batch   270, loss: 0.330
Epoch: 6, Batch   280, loss: 0.370
Epoch: 6, Batch   290, loss: 0.286
Epoch: 6, Batch   300, loss: 0.320
Epoch: 6, Batch   310, loss: 0.348
Epoch: 6, Batch   320, loss: 0.311
Epoch: 6, Batch   330, loss: 0.299
Epoch: 6, Batch   34

Epoch: 7, Batch   540, loss: 0.308
Epoch: 7, Batch   550, loss: 0.281
Epoch: 7, Batch   560, loss: 0.284
Epoch: 7, Batch   570, loss: 0.254
Epoch: 7, Batch   580, loss: 0.308
Epoch: 7, Batch   590, loss: 0.269
Epoch: 7, Batch   600, loss: 0.287
Epoch: 7, Batch   610, loss: 0.306
Epoch: 7, Batch   620, loss: 0.322
Epoch: 7, Batch   630, loss: 0.337
Epoch: 7, Batch   640, loss: 0.287
Epoch: 7, Batch   650, loss: 0.321
Epoch: 7, Batch   660, loss: 0.262
Epoch: 7, Batch   670, loss: 0.204
Epoch: 7, Batch   680, loss: 0.280
Epoch: 7, Batch   690, loss: 0.293
Epoch: 7, Batch   700, loss: 0.320
Epoch: 7, Batch   710, loss: 0.305
Epoch: 7, Batch   720, loss: 0.299
Epoch: 7, Batch   730, loss: 0.259
Epoch: 7, Batch   740, loss: 0.301
Epoch: 7, Batch   750, loss: 0.289
Epoch: 7, Batch   760, loss: 0.260
Epoch: 7, Batch   770, loss: 0.269
Epoch: 7, Batch   780, loss: 0.286
Epoch: 7, Batch   790, loss: 0.287
Epoch: 7, Batch   800, loss: 0.288
Epoch: 7, Batch   810, loss: 0.312
Epoch: 7, Batch   82

Epoch: 8, Batch  1020, loss: 0.281
Epoch: 8, Batch  1030, loss: 0.278
Epoch: 8, Batch  1040, loss: 0.295
Epoch: 8, Batch  1050, loss: 0.239
Epoch: 8, Batch  1060, loss: 0.275
Epoch: 8, Batch  1070, loss: 0.283
Epoch: 8, Batch  1080, loss: 0.272
Epoch: 8, Batch  1090, loss: 0.326
Epoch: 8, Batch  1100, loss: 0.299
Epoch: 8, Batch  1110, loss: 0.266
Epoch: 8, Batch  1120, loss: 0.267
Epoch: 8, Batch  1130, loss: 0.277
Epoch: 8, Batch  1140, loss: 0.305
Epoch: 8, Batch  1150, loss: 0.247
Epoch: 8, Batch  1160, loss: 0.336
Epoch: 8, Batch  1170, loss: 0.218
Epoch: 8, Batch  1180, loss: 0.235
Epoch: 8, Batch  1190, loss: 0.262
Epoch: 8, Batch  1200, loss: 0.304
Epoch: 8, Batch  1210, loss: 0.256
Epoch: 8, Batch  1220, loss: 0.280
Epoch: 8, Batch  1230, loss: 0.262
Epoch: 8, Batch  1240, loss: 0.279
Epoch: 8, Batch  1250, loss: 0.276
Epoch: 8, Batch  1260, loss: 0.266
Epoch: 8, Batch  1270, loss: 0.245
Epoch: 8, Batch  1280, loss: 0.269
Epoch: 8, Batch  1290, loss: 0.272
Epoch: 8, Batch  130

Epoch: 9, Batch  1500, loss: 0.253
Epoch: 9, Batch  1510, loss: 0.291
Epoch: 9, Batch  1520, loss: 0.230
Epoch: 9, Batch  1530, loss: 0.284
Epoch: 9, Batch  1540, loss: 0.270
Epoch: 9, Batch  1550, loss: 0.299
Epoch: 9, Batch  1560, loss: 0.258
Epoch: 9, Batch  1570, loss: 0.277
Epoch: 9, Batch  1580, loss: 0.258
Epoch: 9, Batch  1590, loss: 0.244
Epoch: 9, Batch  1600, loss: 0.272
Epoch: 9, Batch  1610, loss: 0.297
Epoch: 9, Batch  1620, loss: 0.269
Epoch: 9, Batch  1630, loss: 0.275
Epoch: 9, Batch  1640, loss: 0.289
Epoch: 9, Batch  1650, loss: 0.235
Epoch: 9, Batch  1660, loss: 0.222
Epoch: 9, Batch  1670, loss: 0.221
Epoch: 9, Batch  1680, loss: 0.243
Epoch: 9, Batch  1690, loss: 0.253
Epoch: 9, Batch  1700, loss: 0.274
Epoch: 9, Batch  1710, loss: 0.275
Epoch: 9, Batch  1720, loss: 0.296
Epoch: 9, Batch  1730, loss: 0.233
Epoch: 9, Batch  1740, loss: 0.219
Epoch: 9, Batch  1750, loss: 0.239
Epoch: 9, Batch  1760, loss: 0.230
Epoch: 9, Batch  1770, loss: 0.266
Epoch: 9, Batch  178

In [64]:
net.eval()
BATCH_SIZE = 64
accuracies = 0
for i, position in enumerate(range(0, len(test_data), BATCH_SIZE)):
    if (position + BATCH_SIZE > len(test_data)):
        batch_data = test_data[position:]
        batch_labels = test_labels[position:]
    else:
        batch_data = test_data[position:position + BATCH_SIZE]
        batch_labels = test_labels[position:position + BATCH_SIZE]
    
    batch_labels = torch.from_numpy(batch_labels).view(-1).long()
    
    # Forward propagation
    forward = net(batch_data)
    
    accuracy = torch.sum(torch.argmax(forward, dim=1) == batch_labels) / float(BATCH_SIZE)
    accuracies += accuracy
    print("Batch: {}".format(i + 1), "Accuracy: {}".format(accuracy))

print("Total Accuracy: {}".format(accuracies / (len(test_data) // BATCH_SIZE + 1)))


Batch: 1 Accuracy: 0.875
Batch: 2 Accuracy: 0.90625
Batch: 3 Accuracy: 0.890625
Batch: 4 Accuracy: 0.890625
Batch: 5 Accuracy: 0.984375
Batch: 6 Accuracy: 0.9375
Batch: 7 Accuracy: 0.921875
Batch: 8 Accuracy: 0.890625
Batch: 9 Accuracy: 0.90625
Batch: 10 Accuracy: 0.90625
Batch: 11 Accuracy: 0.9375
Batch: 12 Accuracy: 0.96875
Batch: 13 Accuracy: 0.875
Batch: 14 Accuracy: 0.859375
Batch: 15 Accuracy: 0.96875
Batch: 16 Accuracy: 0.921875
Batch: 17 Accuracy: 0.890625
Batch: 18 Accuracy: 0.953125
Batch: 19 Accuracy: 0.96875
Batch: 20 Accuracy: 0.921875
Batch: 21 Accuracy: 0.890625
Batch: 22 Accuracy: 0.859375
Batch: 23 Accuracy: 0.921875
Batch: 24 Accuracy: 0.90625
Batch: 25 Accuracy: 0.9375
Batch: 26 Accuracy: 0.9375
Batch: 27 Accuracy: 0.9375
Batch: 28 Accuracy: 0.84375
Batch: 29 Accuracy: 0.953125
Batch: 30 Accuracy: 0.90625
Batch: 31 Accuracy: 0.90625
Batch: 32 Accuracy: 0.90625
Batch: 33 Accuracy: 0.953125
Batch: 34 Accuracy: 0.9375
Batch: 35 Accuracy: 0.9375
Batch: 36 Accuracy: 0.906

In [65]:
torch.save(net.state_dict(), './bert_model.pt')

In [3]:
net.load_state_dict(torch.load('./models/bert_model.pt'))
net.eval()
print()




In [4]:
def inference():
    title1 = input('First title: ')
    title2 = input('Second title: ')
    
    title1 = remove_stop_words(title1)
    title2 = remove_stop_words(title2)
    
    data = np.array([title1, title2]).reshape(1, 2)
    forward = net(data)
    
    print('Output: {}'.format(torch.argmax(forward)))
    print('Softmax: {}'.format(forward))
    

In [8]:
inference()

First title: 128 gb ssd
Second title: 256 gb ssd
Output: 0
Softmax: tensor([[-0.0200, -3.9226]], grad_fn=<LogSoftmaxBackward>)


In [57]:
inference()

First title: ASUS F512DA-EB51 VivoBook 15 Thin And Light Laptop, 15.6” Full HD, AMD Quad Core R5-3500U CPU, 8GB DDR4 RAM
Second title: ASUS F512DA-EB51 VivoBook 15 Thin And Light Laptop, 15.6” Full HD, intel core i7 7700k cpu, 8GB DDR4 RAM
Output: 1
Softmax: tensor([[-2.4919, -0.0864]], grad_fn=<LogSoftmaxBackward>)
