In [1]:
# !pip install transformers

In [2]:
# !pip install torchtext

In [3]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
import time as timer



# torchtext libraries
from torchtext.legacy.data import Field, TabularDataset, BucketIterator, Iterator, Dataset

# huggingface libraries
from transformers import BertForSequenceClassification
from torch.utils.data import DataLoader



In [4]:
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

In [5]:
device

device(type='cuda', index=1)

In [7]:
y_trn = pickle.load(open("../preprocessed_embeddings/elmo_trn_text_labels.pkl", "rb"))
y_val = pickle.load(open("../preprocessed_embeddings/elmo_val_text_labels.pkl", "rb"))
y_tst = pickle.load(open("../preprocessed_embeddings/elmo_tst_text_labels.pkl", "rb"))

x_trn = pickle.load(open("../preprocessed_embeddings/elmo_trn_text.pkl", "rb")).tolist()
x_val = pickle.load(open("../preprocessed_embeddings/elmo_val_text.pkl", "rb")).tolist()
x_tst = pickle.load(open("../preprocessed_embeddings/elmo_tst_text.pkl", "rb")).tolist()


In [8]:
batch_size = 8

### Training set
trn_dataset = []
for i in range(len(x_trn)):
    trn_dataset.append((torch.tensor(x_trn[i]), y_trn[i]))

del x_trn
del y_trn
trn_dataloader = DataLoader(trn_dataset, batch_size)

### Validation set
val_dataset = []
for i in range(len(x_val)):
    val_dataset.append((torch.tensor(x_val[i]), y_val[i]))

del x_val
del y_val
val_dataloader = DataLoader(val_dataset, batch_size)

### Test set
tst_dataset = []
for i in range(len(x_tst)):
    tst_dataset.append((torch.tensor(x_tst[i]), y_tst[i]))

del x_tst
del y_tst
tst_dataloader = DataLoader(tst_dataset, batch_size)

## **BERT MODEL**

In [9]:
class BERT(nn.Module):

    def __init__(self):
        super(BERT, self).__init__()

        self.encoder = BertForSequenceClassification.from_pretrained('bert-large-uncased')

    def forward(self, pretrained_embeddings, labels):
        loss, logits = self.encoder(inputs_embeds=pretrained_embeddings, labels=labels)[:2]
        pred = torch.argmax(logits, dim=1)
        return loss, pred

In [10]:
def train(model,
          optimizer,
          criterion = nn.BCELoss(),
          train_loader = trn_dataloader,
          valid_loader = val_dataloader,
          num_epochs = 10,):
    
    # initialize values
    running_loss = 0.0
    total_acc = 0
    # training loop
    # model.train()
    total_acc = 0
    best_val_acc = 0
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        train_acc = 0
        val_acc = 0
        train_seen = 0
        val_seen = 0
        model.train()
        for batch, labels in train_loader: 
            optimizer.zero_grad()  
            batch = batch.to(device)
            labels = labels.to(device)
            output = model(batch, labels)
            loss, pred = output
            train_acc += torch.eq(pred, labels).sum().item()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            epoch_loss += loss.item()
            train_seen += len(batch)
        train_acc = train_acc/train_seen
        print('The training accuracy for epoch {epoch} is {train_acc}'.format(epoch=epoch+1, train_acc=train_acc))
        print('The cumulative loss for epoch {epoch} is {epoch_loss}'.format(epoch=epoch+1, epoch_loss=epoch_loss))
        # validation 
        model.eval()
        for batch, labels in valid_loader:   
            labels = labels.to(device)
            batch = batch.to(device)
            output = model(batch, labels)
            loss, pred = output
            val_acc += torch.eq(pred, labels).sum().item()
            val_seen += len(batch)
        val_acc = val_acc/val_seen
        print('The validation accuracy for epoch {epoch} is {val_acc}'.format(epoch=epoch+1, val_acc=val_acc))
        if(val_acc > best_val_acc):
          best_val_acc = val_acc
          torch.save(model.state_dict(), 'weights.pt')

In [11]:
model = BERT()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

train(model, optimizer=optimizer, num_epochs=7, train_loader=trn_dataloader)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint a

The training accuracy for epoch 1 is 0.550571225714839
The cumulative loss for epoch 1 is 2632.1412913061195
The validation accuracy for epoch 1 is 0.4837908053767085
The training accuracy for epoch 2 is 0.5503453172400439
The cumulative loss for epoch 2 is 2626.495094367303
The validation accuracy for epoch 2 is 0.4837908053767085
The training accuracy for epoch 3 is 0.5487639579164784
The cumulative loss for epoch 3 is 2623.978348644596
The validation accuracy for epoch 3 is 0.4837908053767085
The training accuracy for epoch 4 is 0.5490866843090428
The cumulative loss for epoch 4 is 2622.8651397521608
The validation accuracy for epoch 4 is 0.4837908053767085
The training accuracy for epoch 5 is 0.5510230426644291
The cumulative loss for epoch 5 is 2617.037493823096
The validation accuracy for epoch 5 is 0.4837908053767085
The training accuracy for epoch 6 is 0.5508616794681469
The cumulative loss for epoch 6 is 2612.4942061265465
The validation accuracy for epoch 6 is 0.4837908053767