# S-BERT


In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import pickle

In [2]:
# Set GPU device
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# torch.cuda.get_device_name(0)

cuda


In [3]:
torch.cuda.device(1)

<torch.cuda.device at 0x75054bf75a60>

## 1. Data

I will use SNLI data for e Natural Language Inference (NLI) task. I will use only 1 percent of the whole data due to limited computational resources

In [4]:
from datasets import load_dataset

# Load SNLI dataset
dataset = load_dataset('snli',split='train[:1%]')
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 5502
})

In [5]:
dataset = dataset.train_test_split(test_size=0.1)

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 4951
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 551
    })
})

In [7]:
dataset['train'][0]

{'premise': 'A person on skis on a rail at night.',
 'hypothesis': 'They are fantastic skiiers',
 'label': 1}

In [8]:
next(iter(dataset['train']))

{'premise': 'A person on skis on a rail at night.',
 'hypothesis': 'They are fantastic skiiers',
 'label': 1}

In [9]:
len(dataset['train']['premise'])

4951

In [10]:
dataset['train'].features.keys()

dict_keys(['premise', 'hypothesis', 'label'])

In [11]:
dataset['train']['premise'][0]

'A person on skis on a rail at night.'

### Clean sentences

In [3]:
def clean_sentences(example):
    example["premise"] = example["premise"].lower()
    example["hypothesis"] = example["hypothesis"].lower()
    example["premise"] = re.sub("[.,!?\\-]", '',example["premise"])
    example["hypothesis"] = re.sub("[.,!?\\-]", '',example["hypothesis"])
    return example

In [13]:
dataset = dataset.map(clean_sentences)

Map:   0%|          | 0/4951 [00:00<?, ? examples/s]

Map:   0%|          | 0/551 [00:00<?, ? examples/s]

In [14]:
dataset['train'][0]

{'premise': 'a person on skis on a rail at night',
 'hypothesis': 'they are fantastic skiiers',
 'label': 1}

In [15]:
set(dataset['train']['label'])

{-1, 0, 1, 2}

In [16]:
dataset = dataset.filter(lambda x: 0 if x['label'] == -1 else 1)

Filter:   0%|          | 0/4951 [00:00<?, ? examples/s]

Filter:   0%|          | 0/551 [00:00<?, ? examples/s]

In [17]:
set(dataset['train']['label'])

{0, 1, 2}

### Tokenizing and Numericalizing

Load data saved from the pretrained BERT to tokenize and numericalize.

In [13]:
data = pickle.load(open('./h_data.pkl', 'rb'))
word2id = data['word2id']
max_len = data['max_len']
max_mask = data['max_mask']
vocab_size = data['vocab_size']

In [19]:
word2id.get('test',word2id['[UNK]'])

27065

In [18]:
def tokenize(example):
        output = {}
        output['input_ids'] = []
        output['att_mask'] = []
        input_ids = [word2id.get(word, word2id['[UNK]']) for word in example.split()]
        n_pad = max_len - len(input_ids)
        input_ids.extend([0] * n_pad)
        att_mask = [1 if idx != 0 else 0 for idx in input_ids]  # Create attention mask
        output['input_ids'].append(torch.tensor(input_ids))  # Convert to tensor
        output['att_mask'].append(torch.tensor(att_mask))  # Convert to tensor
        return output

In [21]:
def preprocess(example):
    # Tokenize the premise
    premise_result = tokenize(example['premise'])
    # Tokenize the hypothesis
    hypothesis_result = tokenize(example['hypothesis'])

    # Extract labels
    labels = example["label"]
    #num_rows
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_att_mask": premise_result["att_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_att_mask": hypothesis_result["att_mask"],
        "labels" : labels
    }

In [22]:
tokenized_datasets = dataset.map(preprocess)

Map:   0%|          | 0/4946 [00:00<?, ? examples/s]

Map:   0%|          | 0/550 [00:00<?, ? examples/s]

In [23]:
tokenized_datasets = tokenized_datasets.remove_columns(['premise','hypothesis','label'])

In [24]:
tokenized_datasets.set_format("torch")

In [25]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['premise_input_ids', 'premise_att_mask', 'hypothesis_input_ids', 'hypothesis_att_mask', 'labels'],
        num_rows: 4946
    })
    test: Dataset({
        features: ['premise_input_ids', 'premise_att_mask', 'hypothesis_input_ids', 'hypothesis_att_mask', 'labels'],
        num_rows: 550
    })
})

## 2. Data loader

In [26]:
batch_size = 6

In [27]:
from torch.utils.data import DataLoader

# initialize the dataloader
batch_size = 2
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)
# test_dataloader = DataLoader(
#     tokenized_datasets['validation'], 
#     batch_size=batch_size
# )

In [28]:
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)
    print(batch['premise_att_mask'].shape)
    print(batch['hypothesis_input_ids'].shape)
    print(batch['hypothesis_att_mask'].shape)
    print(batch['labels'].shape)
    break

torch.Size([2, 1, 1000])
torch.Size([2, 1, 1000])
torch.Size([2, 1, 1000])
torch.Size([2, 1, 1000])
torch.Size([2])


## 3. Model

Recall that BERT only uses the encoder.

BERT has the following components:

- Embedding layers
- Attention Mask
- Encoder layer
- Multi-head attention
- Scaled dot product attention
- Position-wise feed-forward network
- BERT (assembling all the components)

## 3.1 Embedding

Here we simply generate the positional embedding, and sum the token embedding, positional embedding, and segment embedding together.

<img src = "./figures/BERT_embed.png" width=500>

In [4]:
class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)

## 3.2 Attention mask

In [5]:
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

### Testing the attention mask

In [31]:
# print(get_attn_pad_mask(batch['premise_input_ids'], batch['premise_input_ids'], device).shape)

## 3.3 Encoder

The encoder has two main components: 

- Multi-head Attention
- Position-wise feed-forward network

First let's make the wrapper called `EncoderLayer`

In [6]:
class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn

Let's define the scaled dot attention, to be used inside the multihead attention

In [7]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 

Let's define the parameters first

In [8]:
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

Here is the Multiheadattention.

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]

Here is the PoswiseFeedForwardNet.

In [10]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))

## 3.4 Putting them together

In [11]:
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output

## 4. Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

<img src="./figures/sbert-architecture.png" >

In [None]:
from tqdm.auto import tqdm

n_layers = 12    # number of Encoder of Encoder Layer
n_heads  = 12    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = d_model * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)  # Move model to GPU

In [39]:
model.load_state_dict(torch.load('./bert_model.pth'))

<All keys matched successfully>

In [23]:
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

In [24]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

In [42]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [26]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

## 5. Training

In [44]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(tokenized_datasets) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



In [46]:
from tqdm.auto import tqdm

num_epoch = 2

# need segment and masked for model input but not used in SBERT
# Create segment_ids tensor with shape (batch_size, max_len)
segment_ids = torch.tensor([0] * max_len).unsqueeze(0).repeat(batch_size, 1).to(device)

# Create masked_pos tensor with shape (batch_size, max_mask)
masked_pos = torch.tensor([0] * max_mask).unsqueeze(0).repeat(batch_size, 1).to(device)

accuracy = 0
count = 0
# 1 epoch should be enough, increase if wanted
for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()
    # initialize the dataloader loop with tqdm (tqdm == progress bar)
    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        # zero all gradients on each new step
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].squeeze(1).to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].squeeze(1).to(device)
        attention_a = batch['premise_att_mask'].squeeze(1).to(device)
        attention_b = batch['hypothesis_att_mask'].squeeze(1).to(device)
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)

        u_last_hidden_state = u # all token embeddings A = batch_size, seq_len, hidden_dim
        v_last_hidden_state = v # all token embeddings B = batch_size, seq_len, hidden_dim

         # get the mean pooled vectors
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
        
        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim
        
        # concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
        
        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer
        for out, lab in zip(x, label):
            count = count + 1
            if torch.argmax(out).item() == lab.item():
                accuracy = accuracy + 1
        # calculate the 'softmax-loss' between predicted and true label
        loss = criterion(x, label)
        
        # using loss, calculate gradients and then optimizerize
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        scheduler.step() # update learning rate scheduler
        scheduler_classifier.step()
        
    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f} | Accuracy = {(accuracy / count) * 100}%')
    

  0%|          | 0/2473 [00:00<?, ?it/s]

Epoch: 1 | loss = 5.844354 | Accuracy = 33.178325919935304%


  0%|          | 0/2473 [00:00<?, ?it/s]

Epoch: 2 | loss = 5.601725 | Accuracy = 33.178325919935304%


In [None]:
# Save the model after training
torch.save(model.state_dict(), 'sbert_model.pth')
print("Model saved to S-bert_model.pth")

## 6. Evaluation

In [49]:
model.eval()
classifier_head.eval()
total_similarity = 0
with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].squeeze(1).to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].squeeze(1).to(device)
        attention_a = batch['premise_att_mask'].squeeze(1).to(device)
        attention_b = batch['hypothesis_att_mask'].squeeze(1).to(device)
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        # u = model(inputs_ids_a, attention_mask=attention_a)[0]  # all token embeddings A = batch_size, seq_len, hidden_dim
        # v = model(inputs_ids_b, attention_mask=attention_b)[0]  # all token embeddings B = batch_size, seq_len, hidden_dim
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)
        # get the mean pooled vectors
        u_mean_pool = mean_pool(u, attention_a).detach().cpu().numpy().reshape(-1) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v, attention_b).detach().cpu().numpy().reshape(-1) # batch_size, hidden_dim

        similarity_score = cosine_similarity(u_mean_pool, v_mean_pool)
        total_similarity += similarity_score
    
average_similarity = total_similarity / len(eval_dataloader)
print(f"Average Cosine Similarity: {average_similarity:.4f}")

Average Cosine Similarity: 0.9992


In [None]:
model.eval()
classifier_head.eval()
total_similarity = 0
with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # prepare batches and more all to the active device
        inputs_ids_a = batch['premise_input_ids'].squeeze(1).to(device)
        inputs_ids_b = batch['hypothesis_input_ids'].squeeze(1).to(device)
        attention_a = batch['premise_att_mask'].squeeze(1).to(device)
        attention_b = batch['hypothesis_att_mask'].squeeze(1).to(device)
        label = batch['labels'].to(device)
        
        # extract token embeddings from BERT at last_hidden_state
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)

        u_last_hidden_state = u # all token embeddings A = batch_size, seq_len, hidden_dim
        v_last_hidden_state = v # all token embeddings B = batch_size, seq_len, hidden_dim

         # get the mean pooled vectors
        u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
        v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
        
        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
        uv_abs = torch.abs(uv) # batch_size,hidden_dim
        
        # concatenate u, v, |u-v|
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
        
        # process concatenated tensor through classifier_head
        x = classifier_head(x) #batch_size, classifer
        for out, lab in zip(x, label):
            count = count + 1
            if torch.argmax(out).item() == lab.item():
                accuracy = accuracy + 1

In [63]:
test_acc = accuracy / len(dataset['test']) * 100
print(test_acc)

35.090909090909086


In [67]:
import pandas as pd

eval_result = {'Model Type' : ['S-BERT(scratch)'],
               'SNLI Performance': [test_acc]}


pd.DataFrame(eval_result)

Unnamed: 0,Model Type,SNLI Performance
0,S-BERT(scratch),35.090909


## 7. Inference

Creating a inference function to be used in web app

In [63]:
def infer(model, tokenizer, sentence_a, sentence_b, device):
    model.eval()
    sentence_a = sentence_a.lower()
    sentence_b = sentence_b.lower()
    sentence_a = re.sub("[.,!?\\-]", '',sentence_a)
    sentence_b = re.sub("[.,!?\\-]", '',sentence_b)
    inputs_a = tokenizer(sentence_a)
    # Tokenize the hypothesis
    inputs_b = tokenizer(sentence_b)
    
     # Move input IDs and attention masks to the active device
    inputs_ids_a = inputs_a['input_ids'][0].unsqueeze(0).to(device)
    attention_a = inputs_a['att_mask'][0].unsqueeze(0).to(device)
    inputs_ids_b = inputs_b['input_ids'][0].unsqueeze(0).to(device)
    attention_b = inputs_b['att_mask'][0].unsqueeze(0).to(device)
    
    with torch.no_grad():
        # extract token embeddings from BERT at last_hidden_state
        u = model.get_last_hidden_state(inputs_ids_a, segment_ids)
        v = model.get_last_hidden_state(inputs_ids_b, segment_ids)

    u_last_hidden_state = u # all token embeddings A = batch_size, seq_len, hidden_dim
    v_last_hidden_state = v # all token embeddings B = batch_size, seq_len, hidden_dim

     # get the mean pooled vectors
    u_mean_pool = mean_pool(u_last_hidden_state, attention_a) # batch_size, hidden_dim
    v_mean_pool = mean_pool(v_last_hidden_state, attention_b) # batch_size, hidden_dim
    
    # build the |u-v| tensor
    uv = torch.sub(u_mean_pool, v_mean_pool)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    
    # process concatenated tensor through classifier_head
    y = classifier_head(x) #batch_size, classifer

    return torch.argmax(nn.functional.softmax(y,1)).item()
    

In [3]:
s_model.load_state_dict(torch.load('./sbert_model.pth'))

<All keys matched successfully>

In [4]:
# Create segment_ids tensor with shape (batch_size, max_len)
segment_ids = torch.tensor([0] * max_len).unsqueeze(0).repeat(1, 1).to(device)

# Create masked_pos tensor with shape (batch_size, max_mask)
masked_pos = torch.tensor([0] * max_mask).unsqueeze(0).repeat(1, 1).to(device)

In [68]:
# sentence_a = 'It is wrong'
# sentence_b = 'I will not do it'
sentence_a = 'A man is playing a guitar on stage'
sentence_b = 'The man is performing music'
infer(s_model, tokenize, sentence_a, sentence_b, device)

0

## 8. Discussion

In training BERT, only one percent of BookCorpus dataset is trained with 100 epochs due to the limitation of computational resources. Then, S-BERT is retrained on one percent of SNLI data. Hence, the performance of the model is not good enough to make distinction various inputs. This can be seen in the average similarity scores which is 0.99 as most of the sentences look similar to the model. However, the accuracy of prediction the labels is acceptable with 35 percent. Here, the obvious limitation is computational resource, causing the training data size reduced.

To improve the performance, we need to scale up the training data and the model size by adding more layers. Moreover, the epochs is just two as it is expensive. This is also another factor of hindrance of model performance. I believe scaling up with more epochs will improve the model performance significantly.