In [1]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("CUDA is available. Using GPU.")
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU.")

CUDA is available. Using GPU.


In [2]:
import pynvml

def get_memory_free_MiB(gpu_index):
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(int(gpu_index))
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return mem_info.free // 1024 ** 2

In [3]:
# I write this myself to select the GPU with highest available VRAM on puffer

total_gpus = torch.cuda.device_count()
largest_vram = 0
gpu_index = 0

for i in range(total_gpus):
    new_vram = get_memory_free_MiB(i)
    if new_vram > largest_vram:
        largest_vram = new_vram
        gpu_index = i
    print(f'GPU {i}: {torch.cuda.get_device_name(i)}')
    print(f'available memory of GPU {i}: {new_vram} MiB \n')

print(f'GPU {gpu_index} has the largest available VRAM: {largest_vram} MiB')

GPU 0: NVIDIA GeForce RTX 2080 Ti
available memory of GPU 0: 11000 MiB 

GPU 1: NVIDIA GeForce RTX 2080 Ti
available memory of GPU 1: 738 MiB 

GPU 2: NVIDIA GeForce RTX 2080 Ti
available memory of GPU 2: 4158 MiB 

GPU 3: NVIDIA GeForce RTX 2080 Ti
available memory of GPU 3: 3598 MiB 

GPU 0 has the largest available VRAM: 11000 MiB


In [4]:
torch.cuda.set_device(gpu_index)
print(f'current cuda device is set to: {torch.cuda.current_device()}')

current cuda device is set to: 0


In [5]:
tensor = torch.randn(3, 3, device=device)

print(f"Tensor is on: {tensor.device}")

Tensor is on: cuda:0


In [6]:
SEED = 69
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## Data

### Train, Test, Validation 

In [7]:
import datasets
mnli = datasets.load_dataset('glue', 'mnli')
mnli_matched = mnli.filter(lambda x: x['label'] >= 0)

In [8]:
print(mnli)

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9832
    })
    test_matched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9796
    })
    test_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label', 'idx'],
        num_rows: 9847
    })
})


In [9]:
print(mnli['train'].features)

{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None), 'idx': Value(dtype='int32', id=None)}


In [10]:
import numpy as np
np.unique(mnli['train']['label'])

array([0, 1, 2])

In [11]:
from datasets import DatasetDict

train_range = 1000
validation_range = 1000
test_range = 100

# The test set in both mnli_matched and mnli_mismatched does not have gold labels.
# In MNLI, the test set is unannotated, so label = -1 is used as a placeholder.

raw_dataset = DatasetDict({
    'train': mnli['train'].shuffle(seed=SEED).select(list(range(train_range))),
    'validation': mnli['validation_matched'].shuffle(seed=SEED).select(list(range(validation_range))),
    'test': mnli['validation_mismatched'].shuffle(seed=SEED).select(list(range(test_range))),
})

In [12]:
print("Unique Labels in Train Set:", np.unique(raw_dataset['train']['label']))
print("Unique Labels in Validation Set:", np.unique(raw_dataset['validation']['label']))
print("Unique Labels in Test Set:", np.unique(raw_dataset['test']['label']))

Unique Labels in Train Set: [0 1 2]
Unique Labels in Validation Set: [0 1 2]
Unique Labels in Test Set: [0 1 2]


## Preprocessing

In [13]:
import json

with open('models/my_tokenizer.json', 'r') as f:
    tokenizer = json.load(f)
word2id = tokenizer['word2id']
id2word = tokenizer['id2word']
vocab_size = len(word2id)

In [14]:
# from transformers import BertTokenizer

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def custom_tokenizer(sentences, max_length, padding='max_length', truncation=True):
    tokenized_outputs = {"input_ids": [], "attention_mask": []}
    for sentence in sentences:
        tokens = sentence.lower().split()
        token_ids = [word2id.get(token, word2id['[UNK]']) for token in tokens]
        
        #add [CLS] at the start and [SEP] at the end for BERT compatibility
        token_ids = [word2id['[CLS]']] + token_ids + [word2id['[SEP]']]
        
        #truncate if longer than max_length
        if truncation and len(token_ids) > max_length:
            token_ids = token_ids[:max_length-1] + [word2id['[SEP]']]
        
        attention_mask = [1] * len(token_ids)
        
        #pad if shorter than max_length
        if padding == 'max_length':
            padding_length = max_length - len(token_ids)
            token_ids += [word2id['[PAD]']] * padding_length
            attention_mask += [0] * padding_length
        
        tokenized_outputs["input_ids"].append(token_ids)
        tokenized_outputs["attention_mask"].append(attention_mask)
    return tokenized_outputs

In [15]:
def preprocess_function(examples):
    max_seq_length = 1000
    padding = 'max_length'
    
    #tokenize premise
    premise_result = custom_tokenizer(
        examples['premise'], 
        max_length=max_seq_length, 
        padding=padding, 
        truncation=True
    )
    
    #tokenize hypothesis
    hypothesis_result = custom_tokenizer(
        examples['hypothesis'], 
        max_length=max_seq_length, 
        padding=padding, 
        truncation=True
    )
    
    #extract labels
    labels = examples["label"]
    
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels": labels
    }


tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

tokenized_datasets = tokenized_datasets.remove_columns(['premise', 'hypothesis', 'label'])
tokenized_datasets.set_format("torch")


In [16]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['idx', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['idx', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['idx', 'premise_input_ids', 'premise_attention_mask', 'hypothesis_input_ids', 'hypothesis_attention_mask', 'labels'],
        num_rows: 100
    })
})

## Data loader

In [17]:
from torch.utils.data import DataLoader

# initialize the dataloaders
batch_size = 2
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)

In [18]:
# iterate through the training dataloader
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)           # premise input IDs
    print(batch['premise_attention_mask'].shape)      # premise attention mask
    print(batch['hypothesis_input_ids'].shape)        # hypothesis input IDs
    print(batch['hypothesis_attention_mask'].shape)   # hypothesis attention mask
    print(batch['labels'].shape)                      # labels for classification
    break

torch.Size([2, 1000])
torch.Size([2, 1000])
torch.Size([2, 1000])
torch.Size([2, 1000])
torch.Size([2])


## Model

In [19]:
from torch import nn


class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)  # positional embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment embedding
        self.norm = nn.LayerNorm(d_model)  # layer normalization
        self.device = device

    def forward(self, x, seg):
        seq_len = x.size(1)
        pos = (
            torch.arange(seq_len, dtype=torch.long)
            .to(self.device)
            .unsqueeze(0)
            .expand_as(x)
        )  # create position indices
        embedding = (
            self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        )  # sum all embeddings
        return self.norm(embedding)  # apply layer normalization


class BERT(nn.Module):
    def __init__(
        self,
        n_layers,
        n_heads,
        d_model,
        d_ff,
        d_k,
        n_segments,
        vocab_size,
        max_len,
        device,
    ):
        super(BERT, self).__init__()
        self.embedding = Embedding(
            vocab_size, max_len, n_segments, d_model, device
        )  # embedding layer
        self.layers = nn.ModuleList(
            [EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)]
        )  # transformer encoder layers
        self.fc = nn.Linear(d_model, d_model)  # fully connected layer for hidden states
        self.activ = nn.Tanh()  # activation function
        self.linear = nn.Linear(d_model, d_model)  # another linear layer
        self.norm = nn.LayerNorm(d_model)  # layer normalization
        self.classifier = nn.Linear(d_model, 2)  # classifier head for predictions
        self.decoder = nn.Linear(
            d_model, vocab_size, bias=False
        )  # decoder for language modeling
        self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))  # bias for decoder
        self.device = device

    def forward(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)  # get embeddings
        enc_self_attn_mask = get_attn_pad_mask(
            input_ids, input_ids, self.device
        )  # attention mask
        for layer in self.layers:
            output, _ = layer(
                output, enc_self_attn_mask
            )  # pass through transformer layers
        return output  # return hidden states

    def get_last_hidden_state(self, input_ids):
        segment_ids = torch.zeros_like(input_ids).to(
            self.device
        )  # default segment ids as zeros
        output = self.embedding(input_ids, segment_ids)  # get embeddings
        enc_self_attn_mask = get_attn_pad_mask(
            input_ids, input_ids, self.device
        )  # attention mask
        for layer in self.layers:
            output, _ = layer(
                output, enc_self_attn_mask
            )  # pass through transformer layers
        return output  # return last hidden state


def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # mask padding tokens
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # expand mask to all heads


class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(
            n_heads, d_model, d_k, device
        )  # multi-head self-attention
        self.pos_ffn = PoswiseFeedForwardNet(
            d_model, d_ff
        )  # position-wise feed-forward network

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )  # self-attention mechanism
        enc_outputs = self.pos_ffn(enc_outputs)  # position-wise feed-forward
        return enc_outputs, attn  # return outputs and attention weights


class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)  # query projection
        self.W_K = nn.Linear(d_model, d_k * n_heads)  # key projection
        self.W_V = nn.Linear(d_model, d_k * n_heads)  # value projection
        self.device = device

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)  # residual connection
        q_s = (
            self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        )  # project queries
        k_s = (
            self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        )  # project keys
        v_s = (
            self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        )  # project values
        attn_mask = attn_mask.unsqueeze(1).repeat(
            1, self.n_heads, 1, 1
        )  # repeat mask for all heads
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(
            q_s, k_s, v_s, attn_mask
        )  # scaled dot-product attention
        context = (
            context.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.n_heads * self.d_k)
        )  # concatenate attention heads
        output = nn.Linear(self.n_heads * self.d_k, self.d_model).to(self.device)(
            context
        )  # final linear layer
        return (
            nn.LayerNorm(self.d_model).to(self.device)(output + residual),
            attn,
        )  # add & normalize


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)  # scaling factor

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale  # scaled dot-product
        scores.masked_fill_(attn_mask, -1e9)  # apply attention mask
        attn = nn.Softmax(dim=-1)(scores)  # softmax to get attention weights
        context = torch.matmul(attn, V)  # weighted sum of values
        return context, attn  # return context and attention weights


class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # first feed-forward layer
        self.fc2 = nn.Linear(d_ff, d_model)  # second feed-forward layer

    def forward(self, x):
        return self.fc2(nn.functional.gelu(self.fc1(x)))  # apply gelu activation

In [20]:
max_len = 1000
n_layers = 12
n_heads = 12
d_model = 768
d_ff = d_model * 4
d_k = d_v = 64
n_segments = 2

In [21]:
# initialize and load BERT model
model = BERT(
    n_layers=n_layers,
    n_heads=n_heads,
    d_model=d_model,
    d_ff=d_ff,
    d_k=d_k,
    n_segments=n_segments,
    vocab_size=vocab_size,
    max_len=max_len,
    device=device
).to(device)

model.load_state_dict(torch.load('models/bert_model_1.pth', map_location=device))
model.eval()

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(23069, 768)
    (pos_embed): Embedding(1000, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-11): 12 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (de

### Pooling
SBERT adds a pooling operation to the output of BERT / RoBERTa to derive a fixed sized sentence embedding

In [22]:
# define mean pooling function
def mean_pool(token_embeds, attention_mask):
    # reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(
        token_embeds.size()
    ).float()
    # perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(
        in_mask.sum(1), min=1e-9
    )
    return pool

## Loss Function

## Classification Objective Function 
We concatenate the sentence embeddings $u$ and $v$ with the element-wise difference  $\lvert u - v \rvert $ and multiply the result with the trainable weight  $ W_t ∈  \mathbb{R}^{3n \times k}  $:

$ o = \text{softmax}\left(W^T \cdot \left(u, v, \lvert u - v \rvert\right)\right) $

where $n$ is the dimension of the sentence embeddings and k the number of labels. We optimize cross-entropy loss. This structure is depicted in Figure 1.

## Regression Objective Function. 
The cosine similarity between the two sentence embeddings $u$ and $v$ is computed (Figure 2). We use means quared-error loss as the objective function.

(Manhatten / Euclidean distance, semantically  similar sentences can be found.)

<img src="./figures/sbert-architecture.png" >

In [23]:
def configurations(u,v):
    # build the |u-v| tensor
    uv = torch.sub(u, v)   # batch_size,hidden_dim
    uv_abs = torch.abs(uv) # batch_size,hidden_dim
    
    # concatenate u, v, |u-v|
    x = torch.cat([u, v, uv_abs], dim=-1) # batch_size, 3*hidden_dim
    return x

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

<img src="./figures/sbert-ablation.png" width="350" height="300">

In [24]:
classifier_head = torch.nn.Linear(768*3, 3).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
optimizer_classifier = torch.optim.Adam(classifier_head.parameters(), lr=2e-5)

criterion = nn.CrossEntropyLoss()

In [25]:
from transformers import get_linear_schedule_with_warmup

# and setup a warmup for the first ~10% steps
total_steps = int(len(raw_dataset) / batch_size)
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
		optimizer, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler.step()

scheduler_classifier = get_linear_schedule_with_warmup(
		optimizer_classifier, num_warmup_steps=warmup_steps,
  	num_training_steps=total_steps - warmup_steps
)

# then during the training loop we update the scheduler per step
scheduler_classifier.step()



In [26]:
premise_input_ids = batch['premise_input_ids']
print("Premise Max input_id:", torch.max(premise_input_ids))
print("Premise Min input_id:", torch.min(premise_input_ids))

hypothesis_input_ids = batch['hypothesis_input_ids']
print("Hypothesis Max input_id:", torch.max(hypothesis_input_ids))
print("Hypothesis Min input_id:", torch.min(hypothesis_input_ids))

print("Vocab size:", vocab_size)


Premise Max input_id: tensor(22619)
Premise Min input_id: tensor(0)
Hypothesis Max input_id: tensor(21349)
Hypothesis Min input_id: tensor(0)
Vocab size: 23069


## Training

In [27]:
from tqdm.auto import tqdm

num_epoch = 4  # 1 is enough, can increase if you want more epoches
for epoch in range(num_epoch):
    model.train()  
    classifier_head.train()

    # training loop with tqdm for progress bar
    for step, batch in enumerate(tqdm(train_dataloader, leave=True)):
        # zero gradients
        optimizer.zero_grad()
        optimizer_classifier.zero_grad()
        
        # move tensors to GPU (or CPU) device
        premise_input_ids = batch['premise_input_ids'].to(device)
        hypothesis_input_ids = batch['hypothesis_input_ids'].to(device)
        premise_attention_mask = batch['premise_attention_mask'].to(device)
        hypothesis_attention_mask = batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)

        premise_segment_ids = torch.zeros_like(premise_input_ids).to(device)
        hypothesis_segment_ids = torch.ones_like(hypothesis_input_ids).to(device)
        
        u = model(premise_input_ids, premise_segment_ids)
        v = model(hypothesis_input_ids, hypothesis_segment_ids)
        
        u_mean_pool = mean_pool(u, premise_attention_mask)
        v_mean_pool = mean_pool(v, hypothesis_attention_mask)
        
        uv = torch.sub(u_mean_pool, v_mean_pool)
        uv_abs = torch.abs(uv)
        
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)
        
        logits = classifier_head(x)
        loss = criterion(logits, labels)
        
        #backpropagation and optimization
        loss.backward()
        optimizer.step()
        optimizer_classifier.step()

        # Update learning rate scheduler
        scheduler.step()
        scheduler_classifier.step()
        
    print(f'Epoch: {epoch + 1} | loss = {loss.item():.6f}')


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 1 | loss = 3.666999


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 2 | loss = 3.084264


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 3 | loss = 3.534136


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 4 | loss = 2.575287


In [28]:
model.eval()
classifier_head.eval()
total_similarity = 0.0
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for step, batch in enumerate(eval_dataloader):
        # move tensors to the GPU (or CPU) device
        premise_input_ids = batch['premise_input_ids'].to(device)
        hypothesis_input_ids = batch['hypothesis_input_ids'].to(device)
        premise_attention_mask = batch['premise_attention_mask'].to(device)
        hypothesis_attention_mask = batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # create segment IDs for premise and hypothesis
        premise_segment_ids = torch.zeros_like(premise_input_ids).to(device)
        hypothesis_segment_ids = torch.ones_like(hypothesis_input_ids).to(device)
        
        # extract last hidden states using custom BERT for premise and hypothesis
        u = model.get_last_hidden_state(premise_input_ids)
        v = model.get_last_hidden_state(hypothesis_input_ids)
        
        # get the mean pooled vectors for premise and hypothesis
        u_mean_pool = mean_pool(u, premise_attention_mask)  # [B, H]
        v_mean_pool = mean_pool(v, hypothesis_attention_mask)  # [B, H]
        
        # compute cosine similarity for each sample in the batch
        # regression objective function
        cos_sim = (u_mean_pool * v_mean_pool).sum(dim=1) / (
            torch.norm(u_mean_pool, dim=1) * torch.norm(v_mean_pool, dim=1) + 1e-8
        )
        # average the similarity over the batch (a scalar)
        similarity_score = cos_sim.mean().item()
        total_similarity += similarity_score
        
        # build the |u-v| tensor
        uv = torch.sub(u_mean_pool, v_mean_pool)
        uv_abs = torch.abs(uv)
        
        # concatenate u, v, |u-v|
        # classfication objective function
        x = torch.cat([u_mean_pool, v_mean_pool, uv_abs], dim=-1)
        
        # pass concatenated tensor through classifier head
        logits = classifier_head(x)
        
        # get the predicted class (0, 1, or 2 for MNLI)
        predictions = torch.argmax(logits, dim=1)
        
    
average_similarity = total_similarity / len(eval_dataloader)

print(f"Average Cosine Similarity: {average_similarity:.4f}")

Average Cosine Similarity: 0.9982


## Inference

In [29]:
def calculate_similarity(model, sentence_a, sentence_b, device, max_length=128):
    # use the custom tokenizer to tokenize premise and hypothesis separately
    inputs_a = custom_tokenizer([sentence_a], max_length=max_length, padding='max_length', truncation=True)
    inputs_b = custom_tokenizer([sentence_b], max_length=max_length, padding='max_length', truncation=True)
    
    # convert lists to torch tensors and move to device
    input_ids_a = torch.tensor(inputs_a['input_ids']).to(device)
    attention_a = torch.tensor(inputs_a['attention_mask']).to(device)
    input_ids_b = torch.tensor(inputs_b['input_ids']).to(device)
    attention_b = torch.tensor(inputs_b['attention_mask']).to(device)
    
    # create segment IDs for premise and hypothesis
    premise_segment_ids = torch.zeros_like(input_ids_a).to(device)
    hypothesis_segment_ids = torch.ones_like(input_ids_b).to(device)
    
    # use the model's helper method to get token embeddings
    u = model.get_last_hidden_state(input_ids_a)
    v = model.get_last_hidden_state(input_ids_b)
    
    # get mean pooled sentence embeddings
    u_mean = mean_pool(u, attention_a).detach().cpu().numpy().squeeze()
    v_mean = mean_pool(v, attention_b).detach().cpu().numpy().squeeze()
    
    # calculate cosine similarity (result is a scalar)
    similarity_score = np.dot(u_mean, v_mean) / (np.linalg.norm(u_mean) * np.linalg.norm(v_mean) + 1e-8)
    return similarity_score

# sentences for similarity calculation
sentence_a = 'Your contribution helped make it possible for us to provide our students with a quality education.'
sentence_b = "Your contributions were of no help with our students' education."

similarity = calculate_similarity(model, sentence_a, sentence_b, device)
print(f"cosine Similarity: {similarity:.4f}")



cosine Similarity: 0.9950


In [30]:
torch.save(model.state_dict(), 'models/custom_bert_mnli.pth')
torch.save(classifier_head.state_dict(), 'models/classifier_head.pth')

with open('models/custom_tokenizer.json', 'w') as f:
    json.dump(word2id, f)

print("Model and tokenizer saved successfully.")

Model and tokenizer saved successfully.
