# Sentiment Classification Using a BERT Model

We use Hugging Face's `bert-base-uncased` model through the `pytorch_transformers` package, which you need to install first (https://huggingface.co/transformers/model_doc/bert.html).
The dataset is SST-2, the binary (two-class) version of the Stanford Sentiment Treebank.

In [1]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import torch.utils.data.dataloader as dataloader
from pytorch_transformers import BertTokenizer, BertModel, AdamW, WarmupLinearSchedule
import numpy as np
from tqdm import tqdm
import time

First we define the model architecture: the pretrained `bert-base-uncased` encoder from Hugging Face with a linear layer on top. Because this is a binary classification task, the output dimension can be either 1 or 2. The two choices differ only in how predictions are derived and which loss function is used: a single logit with `BCEWithLogitsLoss`, or two logits with `CrossEntropyLoss`.

In [2]:
class BertParams:
    """Fixed settings shared by the BERT classifier.

    Only ``pretrained`` is read elsewhere in this notebook. The remaining
    fields are kept for reference/compatibility.
    """

    def __init__(self):
        # Hugging Face identifier of the pretrained checkpoint to load.
        self.pretrained = 'bert-base-uncased'
        # Hidden width of bert-base models.
        self.hidden = 768
        # NOTE(review): not used by the visible data pipeline (SST uses a length of 20).
        self.max_sent_length = 254
        # NOTE(review): presumably a checkpoint file name; not referenced in this notebook.
        self.model_name = 'BERTModel.pt'


class BertBase(nn.Module):
    """Pretrained BERT encoder with a linear classification head.

    The head is applied to BERT's pooled output (element 1 of its output
    tuple), i.e. the pooler's transform of the last layer's [CLS] position.

    Args:
        out_dim: number of output logits, chosen by the task (1 logit for
            ``BCEWithLogitsLoss``, 2 for ``CrossEntropyLoss``).
        cache_dir: directory where the pretrained tokenizer and weights are
            downloaded and cached.
    """

    def __init__(self, out_dim, cache_dir='/home/lyu/robustness/pretrained/bert/'):
        super(BertBase, self).__init__()
        self.Params = BertParams()
        self.tokenizer = BertTokenizer.from_pretrained(self.Params.pretrained, cache_dir=cache_dir)
        self.bert = BertModel.from_pretrained(self.Params.pretrained, cache_dir=cache_dir)
        self.bert.eval()  # start with dropout disabled; model.train() re-enables it
        # Size the head from the loaded config instead of a hard-coded 768, so
        # a different pretrained checkpoint cannot silently mismatch.
        self.fc = nn.Linear(self.bert.config.hidden_size, out_dim)

    def forward(self, input, attention_mask=None):
        """Return classification logits of shape (batch, out_dim).

        Args:
            input: either token ids of shape (batch, seq_len), or precomputed
                embeddings of shape (batch, seq_len, hidden). The embedding
                form exists so gradient-based feature methods can feed
                embedding vectors directly.
            attention_mask: optional (batch, seq_len) mask, 1 for real tokens
                and 0 for padding. When omitted, every position (including
                padding) is attended to, as before.
        """
        if input.dim() < 3:
            bert_out = self.bert(input, attention_mask=attention_mask)
        else:
            bert_out = self.get_output_from_embedding(input, attention_mask=attention_mask)

        encoded = bert_out[1]  # pooled [CLS] representation from the last layer
        return self.fc(encoded)

    def get_output_from_embedding(self, embedding, attention_mask=None, head_mask=None):
        """Run BERT's encoder and pooler on precomputed input embeddings.

        Follows the same steps as ``BertModel.forward`` after its embedding
        layer, so the embedding lookup can be skipped.

        Args:
            embedding: tensor of shape (batch, seq_len, hidden). Presumably the
                output of ``self.bert.embeddings``; TODO confirm with callers.
            attention_mask: optional (batch, seq_len) mask (1 = attend,
                0 = ignore). Defaults to all ones.
            head_mask: optional head mask, either 1-D (one value per head,
                shared by all layers) or 2-D (per layer, per head).

        Returns:
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.
        """
        if attention_mask is None:
            attention_mask = torch.ones(embedding.shape[0], embedding.shape[1]).to(embedding)
        # Broadcastable to (batch, heads, from_seq, to_seq).
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.bert.parameters()).dtype)  # fp16 compatibility
        # Masked positions get a large negative additive bias before softmax.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.bert.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(
                    -1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(
                dtype=next(self.bert.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            # The encoder indexes head_mask per layer, so it needs one entry per layer.
            head_mask = [None] * self.bert.config.num_hidden_layers
        encoder_outputs = self.bert.encoder(embedding,
                                            extended_attention_mask,
                                            head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        pooled_output = self.bert.pooler(sequence_output)
        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]
        return outputs

In [3]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# A single output logit, to be paired with a sigmoid-based binary loss.
model = BertBase(out_dim=1).to(device)

Next, we prepare the dataset used to train this BERT model.

In [4]:
class DataBinary(Dataset):
    """Map-style dataset of fixed-length BERT token-id arrays with labels.

    Each input item supplies its token list as ``item[0]`` and its label as
    ``item[1]``. Tokens are wrapped in [CLS] ... [SEP], cut to fit ``maxlen``,
    and right-padded with id 0.
    """

    def __init__(self, instances, tokenizer, maxlen):
        self.maxlen = maxlen
        encoded = []
        for item in instances:
            encoded.append([self.pad(item[0], tokenizer, maxlen), item[1]])
        self.instances = encoded

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        """Return the ``(token_ids, label)`` pair stored at ``idx``."""
        token_ids, label = self.instances[idx]
        return token_ids, label

    def pad(self, instance, tokenizer, maxlen):
        """Convert a token list to a length-``maxlen`` int64 id array.

        At most ``maxlen - 2`` tokens are kept so [CLS] and [SEP] still fit.
        Unused trailing positions stay 0, the pad value.
        """
        body = instance[:maxlen - 2]
        ids = tokenizer.convert_tokens_to_ids([tokenizer.cls_token] + body + [tokenizer.sep_token])
        padded = np.zeros((maxlen, ), dtype=np.int64)
        padded[:len(ids)] = ids
        return padded

    
# Labels are stored as strings in the JSON data; map them to integer class ids.
label2id_binary = {str(class_id): class_id for class_id in (0, 1)}
# Root folder holding the per-dataset "<name>data" directories.
data_root = '/home/lyu/robustness/Datasets/'

def load_splits_json(which_data, root=None):
    """Load the train/dev/test splits of ``which_data`` as (text, label) pairs.

    Expects the directory ``<root>/<which_data>data/`` to contain:
    - ``<which_data>_input.json``: all samples.
    - ``<which_data>_{train,dev,test}_ids.json``: the sample indices of each split.

    A sample's text is the first entry of its ``'en_defs'`` list and its label
    is its ``'label'`` field.

    Args:
        which_data: dataset name prefix, e.g. ``'SST'``.
        root: directory containing the dataset folder. Defaults to the
            module-level ``data_root``.

    Returns:
        ``(train_samples, dev_samples, test_samples)``, each a list of
        ``(text, label)`` tuples.
    """
    def load_json(path):
        # JSON is UTF-8 by specification; don't rely on the platform default encoding.
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def select(samples, indices):
        return [(samples[i]['en_defs'][0], samples[i]['label']) for i in indices]

    print('Loading {} dataset...'.format(which_data))
    if root is None:
        root = data_root
    directory = os.path.join(root, which_data + 'data')
    all_samples = load_json(os.path.join(directory, which_data + '_input.json'))
    split_indices = [
        load_json(os.path.join(directory, '{}_{}_ids.json'.format(which_data, split)))
        for split in ('train', 'dev', 'test')
    ]
    train_samples, dev_samples, test_samples = (select(all_samples, idx) for idx in split_indices)
    return train_samples, dev_samples, test_samples

    
def load_data_for_bert(which_data, tokenizer, max_length=None):
    """Tokenize a dataset's splits and wrap them as padded ``DataBinary`` datasets.

    Args:
        which_data: dataset name. Only ``'SST'`` is supported.
        tokenizer: BERT tokenizer. ``tokenize`` is called here, and the
            tokenizer is passed on to ``DataBinary`` for id conversion.
        max_length: sequence length including [CLS]/[SEP]. Defaults to the
            dataset preset (20 for SST). ``DataBinary`` truncates longer
            token lists.

    Returns:
        ``(train_data, dev_data, test_data)`` as ``DataBinary`` instances.

    Raises:
        ValueError: if ``which_data`` is not a supported dataset.
    """
    if which_data == 'SST':
        label2id = label2id_binary
        default_length = 20
    else:
        raise ValueError('Unsupported dataset {!r}; supported datasets: SST.'.format(which_data))
    MAX_LENGTH = default_length if max_length is None else max_length

    print('Loading {} data for bert model...'.format(which_data))
    train_samples, dev_samples, test_samples = load_splits_json(which_data)    # lists of (text, label) tuples

    def encode(samples):
        # [wordpiece tokens, integer label] pairs, the format DataBinary expects.
        return [[tokenizer.tokenize(sample[0]), label2id[sample[1]]] for sample in samples]

    train_ids = encode(train_samples)
    dev_ids = encode(dev_samples)
    test_ids = encode(test_samples)

    train_data = DataBinary(train_ids, tokenizer, MAX_LENGTH)
    dev_data = DataBinary(dev_ids, tokenizer, MAX_LENGTH)
    test_data = DataBinary(test_ids, tokenizer, MAX_LENGTH)
    return train_data, dev_data, test_data


In [5]:
# Tokenize and pad all SST splits with the model's own tokenizer.
train_data, dev_data, test_data = load_data_for_bert(which_data='SST', tokenizer=model.tokenizer)

Loading SST data for bert model...
Loading SST dataset...


Let's look at the first training instance: a padded array of token ids (101 and 102 are BERT's [CLS] and [SEP] tokens), followed by its label.

In [6]:
# DataBinary.__getitem__ returns a (token_ids, label) pair.
print(train_data[0])

(array([  101,  1037, 18385,  1010,  6057,  1998,  2633, 18276,  2128,
        1011, 16603,  1997,  5053,  1998,  1996,  6841,  1998,  5687,
        5469,   102]), 1)


Now we can train the model.

In [7]:
# The model emits one logit per example (out_dim=1), so use the sigmoid-based
# binary loss. With out_dim=2, nn.CrossEntropyLoss would be the matching choice.
criterion = nn.BCEWithLogitsLoss()
criterion = criterion.to(device)

In [8]:
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8

    Args:
        preds: raw logits of shape (batch, 1) or (batch,).
        y: integer 0/1 labels of shape (batch,).

    Returns:
        A 0-dim float tensor with the fraction of correct predictions.
    """
    # Threshold the sigmoid at 0.5; torch.round maps exactly 0.5 to 0.
    rounded_preds = torch.round(torch.sigmoid(preds)).reshape(-1)
    # Flatten both sides: comparing (batch,) with (batch, 1) would silently
    # broadcast to a (batch, batch) matrix and give a meaningless score.
    correct = (rounded_preds == y.reshape(-1).float()).float()
    return correct.sum() / correct.numel()


def train(train_data, optimizer, batch_size):
    """Run one training epoch of the global ``model`` over ``train_data``.

    Uses the module-level ``model``, ``criterion`` and ``device``. Labels are
    cast to float and reshaped to (batch, 1) to match the single-logit output.

    Returns:
        ``(mean loss, mean accuracy)`` over all training examples. Each batch
        is weighted by its size, so a smaller final batch is not over-counted.

    Raises:
        ValueError: if ``train_data`` yields no examples.
    """
    model.train()
    loader = dataloader.DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                   pin_memory=(device.type == 'cuda'))
    total_loss = 0.0
    total_acc = 0.0
    n_seen = 0
    for text, label in loader:
        text, label = text.to(device), label.to(device)
        optimizer.zero_grad()
        predictions = model(text)
        loss = criterion(predictions, label.unsqueeze(1).float())
        acc = binary_accuracy(predictions, label)
        loss.backward()
        optimizer.step()
        # Both metrics are batch means (assumes the criterion keeps its default
        # 'mean' reduction); rescale to sums before averaging over examples.
        n = label.size(0)
        total_loss += loss.item() * n
        total_acc += acc.item() * n
        n_seen += n
    if n_seen == 0:
        raise ValueError('train_data is empty')
    return total_loss / n_seen, total_acc / n_seen

    
def evaluate(eval_data, batch_size):
    """Compute loss and accuracy of the global ``model`` on ``eval_data``.

    Runs with ``model.eval()`` and without gradients, using the module-level
    ``criterion`` and ``device``. Batches are not shuffled, so the batch
    composition (and therefore the result) does not vary between calls.

    Returns:
        ``(mean loss, mean accuracy)`` over all examples, with each batch
        weighted by its size.

    Raises:
        ValueError: if ``eval_data`` yields no examples.
    """
    model.eval()
    loader = dataloader.DataLoader(eval_data, batch_size=batch_size, shuffle=False,
                                   pin_memory=(device.type == 'cuda'))
    total_loss = 0.0
    total_acc = 0.0
    n_seen = 0
    with torch.no_grad():
        for text, label in loader:
            text, label = text.to(device), label.to(device)
            predictions = model(text)
            loss = criterion(predictions, label.unsqueeze(1).float())
            acc = binary_accuracy(predictions, label)
            # Batch means -> sums (assumes the criterion's default 'mean' reduction).
            n = label.size(0)
            total_loss += loss.item() * n
            total_acc += acc.item() * n
            n_seen += n
    if n_seen == 0:
        raise ValueError('eval_data is empty')
    return total_loss / n_seen, total_acc / n_seen

In [9]:
# Optimisation settings: a small learning rate for the pretrained encoder
# (lrmain) and a larger one for the newly created linear head (lrlast).
lrmain = .00001
lrlast = .001
batch_size = 32
epochs_num = 5  # a short schedule: only 5 epochs
optimizer = optim.Adam([
    {"params": model.bert.parameters(), "lr": lrmain},
    {"params": model.fc.parameters(), "lr": lrlast},
])
best_valid_loss = float('inf')  # lowest dev loss seen so far; drives checkpointing
model_dir = '/home/lyu/robustness/SST/model/bert_sst.pt'  # checkpoint file path

# Train for `epochs_num` epochs, keeping the checkpoint with the lowest dev loss.
# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(model_dir) or '.', exist_ok=True)

for epoch in tqdm(range(epochs_num)):
    start_time = time.time()
    train_loss, train_acc = train(train_data, optimizer, batch_size)
    valid_loss, valid_acc = evaluate(dev_data, batch_size)
    end_time = time.time()
    # Split the elapsed wall-clock time into whole minutes and seconds.
    epoch_mins, epoch_secs = divmod(int(end_time - start_time), 60)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), model_dir)

    print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc * 100:.2f}%')

    


  0%|          | 0/5 [01:09<?, ?it/s]


KeyboardInterrupt: 