In [1]:
import math
import re
from   random import *
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

batch_size = 6
max_mask   = 5  # max masked tokens when 15% exceed, it will only be max_pred
max_len    = 1000 # maximum of length to be padded; 

class Embedding(nn.Module):
    def __init__(self, vocab_size, max_len, n_segments, d_model, device):
        super(Embedding, self).__init__()
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding
        self.pos_embed = nn.Embedding(max_len, d_model)      # position embedding
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding
        self.norm = nn.LayerNorm(d_model)
        self.device = device

    def forward(self, x, seg):
        #x, seg: (bs, len)
        seq_len = x.size(1)
        pos = torch.arange(seq_len, dtype=torch.long).to(self.device)
        pos = pos.unsqueeze(0).expand_as(x)  # (len,) -> (bs, len)
        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)
        return self.norm(embedding)
    
def get_attn_pad_mask(seq_q, seq_k, device):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    # eq(zero) is PAD token
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1).to(device)  # batch_size x 1 x len_k(=len_q), one is masking
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # batch_size x len_q x len_k

class EncoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, d_k, device):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(n_heads, d_model, d_k, device)
        self.pos_ffn       = PoswiseFeedForwardNet(d_model, d_ff)

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask) # enc_inputs to same Q,K,V
        enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
        return enc_outputs, attn
    
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k, device):
        super(ScaledDotProductAttention, self).__init__()
        self.scale = torch.sqrt(torch.FloatTensor([d_k])).to(device)

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores.masked_fill_(attn_mask, -1e9) # Fills elements of self tensor with value where mask is one.
        attn = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attn, V)
        return context, attn 
    
n_layers = 6    # number of Encoder of Encoder Layer
n_heads  = 8    # number of heads in Multi-Head Attention
d_model  = 768  # Embedding Size
d_ff = 768 * 4  # 4*d_model, FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_segments = 2

class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads, d_model, d_k, device):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, self.d_v * n_heads)
        self.device = device
    def forward(self, Q, K, V, attn_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = Q, Q.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attn = ScaledDotProductAttention(self.d_k, self.device)(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.n_heads * self.d_v, self.d_model, device=self.device)(context)
        return nn.LayerNorm(self.d_model, device=self.device)(output + residual), attn # output: [batch_size x len_q x d_model]
    
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fc2(F.gelu(self.fc1(x)))
    
class BERT(nn.Module):
    def __init__(self, n_layers, n_heads, d_model, d_ff, d_k, n_segments, vocab_size, max_len, device):
        super(BERT, self).__init__()
        self.params = {'n_layers': n_layers, 'n_heads': n_heads, 'd_model': d_model,
                       'd_ff': d_ff, 'd_k': d_k, 'n_segments': n_segments,
                       'vocab_size': vocab_size, 'max_len': max_len}
        self.embedding = Embedding(vocab_size, max_len, n_segments, d_model, device)
        self.layers = nn.ModuleList([EncoderLayer(n_heads, d_model, d_ff, d_k, device) for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # decoder is shared with embedding layer
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.device = device

    def forward(self, input_ids, segment_ids, masked_pos):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        
        # 1. predict next sentence
        # it will be decided by first token(CLS)
        h_pooled   = self.activ(self.fc(output[:, 0])) # [batch_size, d_model]
        logits_nsp = self.classifier(h_pooled) # [batch_size, 2]

        # 2. predict the masked token
        masked_pos = masked_pos[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        h_masked = torch.gather(output, 1, masked_pos) # masking position [batch_size, max_pred, d_model]
        h_masked  = self.norm(F.gelu(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_nsp
    
    def get_last_hidden_state(self, input_ids, segment_ids):
        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids, self.device)
        for layer in self.layers:
            output, enc_self_attn = layer(output, enc_self_attn_mask)

        return output
    




In [2]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
n_layers = 12
n_heads = 12
d_model = 768
d_ff = d_model * 4
d_k = d_v = 64
n_segments = 2
vocab_size = 60305
max_len = 1000

# Initialize the model
model = BERT(
    n_layers, 
    n_heads, 
    d_model, 
    d_ff, 
    d_k, 
    n_segments, 
    vocab_size, 
    max_len, 
    device
).to(device)

# Load the saved weights
model.load_state_dict(torch.load('bert_model.pth', map_location=device))
model.eval()  # Set the model to evaluation mode

BERT(
  (embedding): Embedding(
    (tok_embed): Embedding(60305, 768)
    (pos_embed): Embedding(1000, 768)
    (seg_embed): Embedding(2, 768)
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (layers): ModuleList(
    (0-11): 12 x EncoderLayer(
      (enc_self_attn): MultiHeadAttention(
        (W_Q): Linear(in_features=768, out_features=768, bias=True)
        (W_K): Linear(in_features=768, out_features=768, bias=True)
        (W_V): Linear(in_features=768, out_features=768, bias=True)
      )
      (pos_ffn): PoswiseFeedForwardNet(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
      )
    )
  )
  (fc): Linear(in_features=768, out_features=768, bias=True)
  (activ): Tanh()
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (classifier): Linear(in_features=768, out_features=2, bias=True)
  (de

In [3]:
def mean_pool(token_embeds, attention_mask):
    # Reshape attention_mask to cover 768-dimension embeddings
    in_mask = attention_mask.unsqueeze(-1).expand(token_embeds.size()).float()
    # Perform mean-pooling but exclude padding tokens (specified by in_mask)
    pool = torch.sum(token_embeds * in_mask, 1) / torch.clamp(in_mask.sum(1), min=1e-9)
    return pool

class SentenceBERT(nn.Module):
    def __init__(self, bert_model):
        super(SentenceBERT, self).__init__()
        self.bert = bert_model  # Use the pre-trained BERT model from Task 1

    def forward(self, premise_input_ids, premise_attention_mask, hypothesis_input_ids, hypothesis_attention_mask):
        # Generate dummy segment IDs (tensor of zeros)
        premise_segment_ids = torch.zeros_like(premise_input_ids, device=self.bert.device)
        hypothesis_segment_ids = torch.zeros_like(hypothesis_input_ids, device=self.bert.device)

        # Encode the premise
        premise_embeds = self.bert.get_last_hidden_state(premise_input_ids, premise_segment_ids)
        u = mean_pool(premise_embeds, premise_attention_mask)  # Pooling to get sentence embedding

        # Encode the hypothesis
        hypothesis_embeds = self.bert.get_last_hidden_state(hypothesis_input_ids, hypothesis_segment_ids)
        v = mean_pool(hypothesis_embeds, hypothesis_attention_mask)  # Pooling to get sentence embedding

        return u, v
    
class SoftmaxLoss(nn.Module):
    def __init__(self, hidden_size, num_labels):
        super(SoftmaxLoss, self).__init__()
        self.fc = nn.Linear(hidden_size * 3, num_labels)  # Combine u, v, |u-v|

    def forward(self, u, v, labels):
        diff = torch.abs(u - v)  # Element-wise absolute difference
        features = torch.cat([u, v, diff], dim=1)  # Concatenate u, v, |u-v|
        logits = self.fc(features)  # Pass through a fully connected layer
        loss = F.cross_entropy(logits, labels)  # Compute cross-entropy loss
        return loss
    
    

In [4]:
from transformers import BertTokenizer
from datasets import load_dataset

# Load the SNLI dataset
raw_dataset = load_dataset('snli')

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define the preprocessing function
def preprocess_function(examples):
    max_seq_length = 128
    padding = 'max_length'
    # Tokenize the premise
    premise_result = tokenizer(
        examples['premise'], padding=padding, max_length=max_seq_length, truncation=True)
    # Tokenize the hypothesis
    hypothesis_result = tokenizer(
        examples['hypothesis'], padding=padding, max_length=max_seq_length, truncation=True)
    # Extract labels
    labels = examples["label"]
    return {
        "premise_input_ids": premise_result["input_ids"],
        "premise_attention_mask": premise_result["attention_mask"],
        "hypothesis_input_ids": hypothesis_result["input_ids"],
        "hypothesis_attention_mask": hypothesis_result["attention_mask"],
        "labels": labels
    }

# Apply the preprocessing function to the dataset
tokenized_datasets = raw_dataset.map(
    preprocess_function,
    batched=True,
)

# Remove unnecessary columns
tokenized_datasets = tokenized_datasets.remove_columns(['premise', 'hypothesis', 'label'])

# Set the format to PyTorch tensors
tokenized_datasets.set_format("torch")

In [5]:
from torch.utils.data import DataLoader

# Initialize the DataLoader
batch_size = 32
train_dataloader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=batch_size, 
    shuffle=True
)
eval_dataloader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=batch_size
)
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size
)


# Check the shapes of the batches
for batch in train_dataloader:
    print(batch['premise_input_ids'].shape)  # Should be (batch_size, max_seq_length)
    print(batch['premise_attention_mask'].shape)  # Should be (batch_size, max_seq_length)
    print(batch['hypothesis_input_ids'].shape)  # Should be (batch_size, max_seq_length)
    print(batch['hypothesis_attention_mask'].shape)  # Should be (batch_size, max_seq_length)
    print(batch['labels'].shape)  # Should be (batch_size,)
    break

torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32, 128])
torch.Size([32])


In [6]:
# Initialize SBERT and loss function
sbert = SentenceBERT(model).to(device)
criterion = SoftmaxLoss(hidden_size=d_model, num_labels=3).to(device)  # 3 labels for SNLI/MNLI
optimizer = optim.Adam(sbert.parameters(), lr=0.001)

from tqdm.auto import tqdm

# Training loop
num_epochs = 3  # Adjust as needed
for epoch in range(num_epochs):
    sbert.train()
    total_loss = 0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
    for batch in progress_bar:
        # Move batch to device
        premise_input_ids = batch['premise_input_ids'].to(device)
        premise_attention_mask = batch['premise_attention_mask'].to(device)
        hypothesis_input_ids = batch['hypothesis_input_ids'].to(device)
        hypothesis_attention_mask = batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        optimizer.zero_grad()
        u, v = sbert(premise_input_ids, premise_attention_mask, hypothesis_input_ids, hypothesis_attention_mask)
        loss = criterion(u, v, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})
    
    avg_train_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch + 1}, Average Training Loss: {avg_train_loss}")

    # Validation loop
    sbert.eval()
    total_val_loss = 0
    with torch.no_grad():
        for batch in eval_dataloader:
            # Move batch to device
            premise_input_ids = batch['premise_input_ids'].to(device)
            premise_attention_mask = batch['premise_attention_mask'].to(device)
            hypothesis_input_ids = batch['hypothesis_input_ids'].to(device)
            hypothesis_attention_mask = batch['hypothesis_attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            u, v = sbert(premise_input_ids, premise_attention_mask, hypothesis_input_ids, hypothesis_attention_mask)
            loss = criterion(u, v, labels)
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss / len(eval_dataloader)
    print(f"Epoch {epoch + 1}, Average Validation Loss: {avg_val_loss}")

Epoch 1:   0%|          | 0/17193 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Save the SBERT model
torch.save(sbert.state_dict(), 'sbert_model.pth')
print("SBERT model saved to sbert_model.pth")

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

# Example sentences
sentence1 = "I love programming."
sentence2 = "Coding is my passion."

# Tokenize and encode the sentences
inputs_1 = tokenizer(sentence1, return_tensors='pt', padding=True, truncation=True)
inputs_2 = tokenizer(sentence2, return_tensors='pt', padding=True, truncation=True)

premise_input_ids = inputs_1['input_ids'].to(device)
premise_attention_mask = inputs_1['attention_mask'].to(device)
hypothesis_input_ids = inputs_2['input_ids'].to(device)
hypothesis_attention_mask = inputs_2['attention_mask'].to(device)

# Get sentence embeddings
u, v = sbert(premise_input_ids, premise_attention_mask, hypothesis_input_ids, hypothesis_attention_mask)

# Compute cosine similarity
cos_sim = cosine_similarity(u.cpu().detach().numpy(), v.cpu().detach().numpy())
print(f"Cosine Similarity: {cos_sim[0][0]}")

In [None]:
test_dataloader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=batch_size, 
    shuffle=False
)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report

# Initialize lists to store predictions and labels
all_preds = []
all_labels = []

# Set the model to evaluation mode
sbert.eval()

# Evaluate on the test set
with torch.no_grad():
    for batch in test_dataloader:
        # Move batch to device
        premise_input_ids = batch['premise_input_ids'].to(device)
        premise_attention_mask = batch['premise_attention_mask'].to(device)
        hypothesis_input_ids = batch['hypothesis_input_ids'].to(device)
        hypothesis_attention_mask = batch['hypothesis_attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        u, v = sbert(premise_input_ids, premise_attention_mask, hypothesis_input_ids, hypothesis_attention_mask)
        logits = criterion.fc(torch.cat([u, v, torch.abs(u - v)], dim=1))  # Combine u, v, |u-v|
        preds = torch.argmax(logits, dim=1)

        # Store predictions and labels
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Compute metrics
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average='weighted')
recall = recall_score(all_labels, all_preds, average='weighted')
f1 = f1_score(all_labels, all_preds, average='weighted')

# Print classification report
print("Classification Report:")
print(classification_report(all_labels, all_preds, target_names=['entailment', 'neutral', 'contradiction']))

# Print overall metrics
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")