In [None]:
# import packages
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import datasets
from sklearn.metrics import classification_report
import time
import os
import re

In [None]:
SEED = 42
# Seed every RNG the notebook touches so Restart & Run All is reproducible:
# torch on CPU, all CUDA devices, and numpy.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
# Force deterministic cuDNN kernels and disable autotuning, which may pick a
# different (non-deterministic) algorithm from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

# Data preprocessing

In [None]:
class CustomTokenizer:
    """Whitespace tokenizer backed by a fixed ``word2id`` vocabulary.

    Calling an instance on a list of strings returns a dict with ``input_ids``
    and ``attention_mask`` LongTensors. This mirrors the small part of the
    Hugging Face tokenizer call interface that this notebook uses.
    """

    def __init__(self, word2id):
        self.word2id = word2id
        # Special token IDs; fall back to conventional defaults when absent.
        self.pad_id = word2id.get("[PAD]", 0)
        self.cls_id = word2id.get("[CLS]", 1)
        self.sep_id = word2id.get("[SEP]", 2)
        self.mask_id = word2id.get("[MASK]", 3)
        # If '[UNK]' exists in vocabulary, use it; otherwise fall back to pad_id.
        # NOTE: with that fallback, unknown words are indistinguishable from
        # padding and therefore receive attention_mask == 0.
        self.unk_id = word2id.get("[UNK]", self.pad_id)

    def tokenize(self, text):
        """Lower-case ``text``, turn . , ! ? - and backslash into spaces, split on whitespace."""
        sent = text.lower()
        sent = re.sub(r"[.,!?\\-]", " ", sent)  # clean sentence
        return sent.split()

    def encode(self, text, max_length=128, padding=True, truncation=True):
        """Convert text to a list of token IDs wrapped in [CLS] ... [SEP].

        Args:
            text: input string.
            max_length: target length including [CLS] and [SEP].
            padding: if True, right-pad with ``pad_id`` up to ``max_length``.
            truncation: if True, keep at most ``max_length - 2`` word tokens.

        Returns:
            List of ints. It is never shorter than 2 ([CLS] and [SEP] are always present).
        """
        tokens = self.tokenize(text)
        if truncation:
            # Leave room for [CLS] and [SEP]. Clamp at 0 so a tiny max_length
            # does not become a negative slice, which would keep almost every token.
            tokens = tokens[: max(0, max_length - 2)]
        ids = (
            [self.cls_id]
            + [self.word2id.get(tok, self.unk_id) for tok in tokens]
            + [self.sep_id]
        )
        if padding:
            pad_len = max_length - len(ids)
            if pad_len > 0:
                ids += [self.pad_id] * pad_len
        return ids

    def __call__(
        self, texts, max_length=128, padding=True, truncation=True, return_tensors="pt"
    ):
        """Batch-encode ``texts``.

        Returns:
            ``{"input_ids": LongTensor (N, L), "attention_mask": LongTensor (N, L)}``.
            The mask is 1 wherever the id differs from ``pad_id``.

        ``return_tensors`` is accepted only for call compatibility with Hugging Face
        tokenizers; PyTorch tensors are always returned.
        """
        batch_ids = [
            self.encode(
                text, max_length=max_length, padding=padding, truncation=truncation
            )
            for text in texts
        ]
        # Rows can be ragged when padding=False (or truncation=False with long text),
        # and torch.tensor cannot stack ragged lists. Right-pad every row to the
        # longest one; with the default settings this is a no-op.
        longest = max((len(ids) for ids in batch_ids), default=0)
        batch_ids = [ids + [self.pad_id] * (longest - len(ids)) for ids in batch_ids]

        input_ids = torch.tensor(batch_ids, dtype=torch.long)
        attention_mask = (input_ids != self.pad_id).long()
        return {"input_ids": input_ids, "attention_mask": attention_mask}

In [None]:
def load_nli_dataset(dataset_name="snli", max_samples=100000):
    """Load the SNLI or MNLI dataset from the Hugging Face hub.

    Args:
        dataset_name: "snli" (case-insensitive) loads SNLI; any other value
            loads GLUE MNLI.
        max_samples: if truthy, keep only the first ``max_samples`` examples of
            the train split (after invalid labels have been removed).

    Returns:
        A ``datasets.DatasetDict`` with examples labelled -1 removed from every
        split. GLUE MNLI only provides "validation_matched" and
        "validation_mismatched". For MNLI a "validation" key pointing at
        "validation_matched" is therefore added, so callers can use the same
        split name for both datasets.
    """
    is_snli = dataset_name.lower() == "snli"
    if is_snli:
        dataset = datasets.load_dataset("snli")
    else:
        dataset = datasets.load_dataset("glue", "mnli")

    # Filter out invalid labels (-1)
    dataset = dataset.filter(lambda x: x["label"] != -1)

    # Align MNLI split names with SNLI so train_sentence_bert works for both.
    if not is_snli and "validation" not in dataset and "validation_matched" in dataset:
        dataset["validation"] = dataset["validation_matched"]

    # Take subset if needed
    if max_samples:
        dataset["train"] = dataset["train"].select(
            range(min(max_samples, len(dataset["train"])))
        )

    return dataset

def preprocess_nli_data(dataset, tokenizer, max_length=128):
    """Tokenize the premise/hypothesis columns of an NLI dataset.

    Args:
        dataset: a Hugging Face ``Dataset`` or ``DatasetDict`` with "premise",
            "hypothesis" and "label" columns.
        tokenizer: callable used as ``tokenizer(list_of_str, max_length=...,
            padding=True, truncation=True)`` that returns a dict with
            "input_ids" and "attention_mask".
        max_length: sequence length passed to the tokenizer.

    Returns:
        The mapped dataset, formatted as torch tensors over the five feature
        columns listed below.
    """
    feature_columns = [
        'premise_input_ids',
        'premise_attention_mask',
        'hypothesis_input_ids',
        'hypothesis_attention_mask',
        'labels',
    ]

    def encode_side(texts):
        return tokenizer(texts, max_length=max_length, padding=True, truncation=True)

    def tokenize_function(examples):
        premise_enc = encode_side(examples['premise'])
        hypothesis_enc = encode_side(examples['hypothesis'])
        values = [
            premise_enc['input_ids'],
            premise_enc['attention_mask'],
            hypothesis_enc['input_ids'],
            hypothesis_enc['attention_mask'],
            examples['label'],
        ]
        return dict(zip(feature_columns, values))

    mapped = dataset.map(tokenize_function, batched=True)
    mapped.set_format(type='torch', columns=list(feature_columns))
    return mapped

# BERT model

In [None]:
class Embedding(nn.Module):
    """Token + learned-position + segment embeddings, summed and then LayerNorm-ed.

    Args:
        vocab_size: number of token ids (V).
        d_model: embedding width (D).
        max_len: number of learned positions (M); inputs longer than this fail the lookup.
        n_segments: number of segment / token-type ids (S).
    """

    def __init__(self, vocab_size, d_model=768, max_len=512, n_segments=2):
        super().__init__()
        # Attribute names and creation order are kept unchanged so that saved
        # state_dicts still load and seeded initialisation stays the same.
        self.tok_embed = nn.Embedding(vocab_size, d_model)  # (V, D)
        self.pos_embed = nn.Embedding(max_len, d_model)  # (M, D)
        self.seg_embed = nn.Embedding(n_segments, d_model)  # (S, D)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed token ids ``x`` (B, L) with segment ids ``seg`` (B, L) into (B, L, D)."""
        length = x.size(1)
        # Position ids 0..L-1, broadcast to every row of the batch.
        positions = torch.arange(length, dtype=torch.long, device=x.device)
        positions = positions.unsqueeze(0).expand_as(x)

        token_part = self.tok_embed(x)
        position_part = self.pos_embed(positions)
        segment_part = self.seg_embed(seg)
        return self.norm(token_part + position_part + segment_part)


def get_attn_pad_mask(seq_q, seq_k, pad_id=0):
    """Build a boolean mask that marks padded key positions.

    Args:
        seq_q: (B, L_q) query token ids; only its shape is used.
        seq_k: (B, L_k) key token ids.
        pad_id: id treated as padding. It defaults to 0, the value previously
            hard-coded. Pass the tokenizer's pad id when it differs.

    Returns:
        (B, L_q, L_k) bool tensor, True where the key token equals ``pad_id``.
        This is the form expected by ``masked_fill_`` in the attention layers.
    """
    batch_size, len_q = seq_q.size()
    _, len_k = seq_k.size()
    # Integer ids never carry gradients, so `.data` is unnecessary here.
    pad_attn_mask = seq_k.eq(pad_id).unsqueeze(1)  # (B, 1, L_k)
    return pad_attn_mask.expand(batch_size, len_q, len_k)


class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / sqrt(d_k)) V, with masked key positions suppressed."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """
        Args:
            Q: (B, H, L_q, d_k) queries.
            K: (B, H, L_k, d_k) keys.
            V: (B, H, L_k, d_v) values.
            attn_mask: bool tensor broadcastable to (B, H, L_q, L_k). True entries
                are set to -1e9 before the softmax, so they get ~0 weight.

        Returns:
            (context, attn): context is (B, H, L_q, d_v), attn is (B, H, L_q, L_k).
        """
        scale = math.sqrt(Q.size(-1))
        logits = Q @ K.transpose(-1, -2) / scale  # (B, H, L_q, L_k)
        logits = logits.masked_fill(attn_mask, -1e9)
        weights = torch.softmax(logits, dim=-1)  # rows sum to 1 over L_k
        return weights @ V, weights


class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm.

    Args:
        d_model: model width D.
        n_heads: number of heads H.
        d_k: per-head query/key width.
        d_v: per-head value width.
    """

    def __init__(self, d_model=768, n_heads=12, d_k=64, d_v=64):
        super().__init__()
        self.n_heads = n_heads  # H
        self.d_k = d_k
        self.d_v = d_v

        # Q/K/V projections: (B, L, D) -> (B, L, H * d_k | H * d_v).
        # Attribute names and creation order are preserved for state_dict
        # compatibility and identical seeded initialisation.
        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)

        # Output projection back to D; LayerNorm is applied to (projection + residual).
        self.linear = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def _split_heads(self, projected, batch_size, head_dim):
        """Reshape (B, L, H * head_dim) into (B, H, L, head_dim)."""
        return projected.view(batch_size, -1, self.n_heads, head_dim).transpose(1, 2)

    def forward(self, Q, K, V, attn_mask):
        """
        Args:
            Q, K, V: (B, L, D) inputs; Q also serves as the residual branch.
            attn_mask: (B, L_q, L_k) bool mask, True for key positions to suppress.

        Returns:
            (output, attn): output is (B, L_q, D); attn is (B, H, L_q, L_k).
        """
        batch_size = Q.size(0)

        q_heads = self._split_heads(self.W_Q(Q), batch_size, self.d_k)  # (B, H, L, d_k)
        k_heads = self._split_heads(self.W_K(K), batch_size, self.d_k)  # (B, H, L, d_k)
        v_heads = self._split_heads(self.W_V(V), batch_size, self.d_v)  # (B, H, L, d_v)

        # Same mask for every head.
        head_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)  # (B, H, L_q, L_k)

        context, attn = ScaledDotProductAttention()(q_heads, k_heads, v_heads, head_mask)

        # Merge heads back: (B, H, L, d_v) -> (B, L, H * d_v).
        merged = (
            context.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.n_heads * self.d_v)
        )
        projected = self.linear(merged)
        return self.norm(projected + Q), attn


class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward net: Linear(D -> d_ff) -> GELU -> Linear(d_ff -> D)."""

    def __init__(self, d_model=768, d_ff=3072):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Apply the two projections independently at every position of ``x``."""
        hidden = self.fc1(x)
        hidden = torch.nn.functional.gelu(hidden)
        return self.fc2(hidden)


class EncoderLayer(nn.Module):
    """One encoder block: multi-head self-attention followed by the feed-forward net."""

    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads)  # self-attention sublayer
        self.pos_ffn = PoswiseFeedForwardNet(d_model, d_ff)  # feed-forward sublayer

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Return (layer_output, attention_weights) for ``enc_inputs``.

        ``enc_self_attn_mask`` is passed straight to the attention sublayer.
        """
        # Self-attention: the same tensor is used as queries, keys and values.
        attended, attn = self.enc_self_attn(
            enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask
        )
        # NOTE(review): unlike the attention sublayer, the FFN output has no
        # residual connection or LayerNorm. Presumably existing checkpoints
        # were trained this way; confirm before changing it.
        return self.pos_ffn(attended), attn


class BERT(nn.Module):
    """BERT encoder with an NSP head and a weight-tied MLM head.

    ``forward(input_ids, segment_ids, masked_pos=None)`` returns
    ``(mlm_logits, nsp_logits)`` when ``masked_pos`` is given, and
    ``(encoder_output, nsp_logits)`` otherwise.
    """

    def __init__(self, vocab_size, d_model=768, n_layers=12, n_heads=12, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.embedding = Embedding(vocab_size, d_model, max_len)  # (B, L) -> (B, L, D)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model, n_heads) for _ in range(n_layers)
        )

        # NOTE(review): `fc` and `activ` are not used in forward(). `fc` is still
        # part of the state_dict, so it stays for strict checkpoint loading.
        self.fc = nn.Linear(d_model, d_model)
        self.activ = nn.Tanh()
        # MLM transform applied to gathered masked positions: Linear -> GELU -> LayerNorm.
        self.linear = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

        # NSP head, applied directly to the hidden state at position 0 ([CLS]).
        self.nsp_classifier = nn.Linear(d_model, 2)

        # MLM decoder; its weight is tied to the token embedding matrix, bias is separate.
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        self.decoder.weight = self.embedding.tok_embed.weight
        self.decoder_bias = nn.Parameter(torch.zeros(vocab_size))

    def forward(self, input_ids, segment_ids, masked_pos=None):
        """
        Args:
            input_ids: (B, L) token ids; id 0 is masked out as padding.
            segment_ids: (B, L) segment ids.
            masked_pos: optional (B, N_mask) positions whose tokens the MLM head predicts.
        """
        hidden = self.embedding(input_ids, segment_ids)  # (B, L, D)
        pad_mask = get_attn_pad_mask(input_ids, input_ids)  # (B, L, L)
        for layer in self.layers:
            hidden, _ = layer(hidden, pad_mask)

        nsp_logits = self.nsp_classifier(hidden[:, 0, :])  # (B, 2)

        if masked_pos is None:
            return hidden, nsp_logits

        # Gather the hidden states at the masked positions: (B, N_mask, D).
        gather_index = masked_pos.unsqueeze(2).expand(-1, -1, self.d_model)
        h_masked = torch.gather(hidden, 1, gather_index)
        h_masked = self.norm(torch.nn.functional.gelu(self.linear(h_masked)))
        mlm_logits = self.decoder(h_masked) + self.decoder_bias  # (B, N_mask, V)
        return mlm_logits, nsp_logits

# Sentence-BERT model

In [None]:
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings over the positions where ``attention_mask`` is non-zero.

    Args:
        token_embeddings: (B, L, D) hidden states.
        attention_mask: (B, L) mask, 1 for real tokens and 0 for padding.

    Returns:
        (B, D) sentence embeddings. Rows whose mask is all zeros return zeros,
        because the denominator is clamped to 1e-9 instead of dividing by 0.
    """
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts


class SentenceBERT(nn.Module):
    """Siamese Sentence-BERT: one shared BERT encoder plus a softmax classifier.

    Each sentence is encoded separately and mean-pooled into u and v. The
    classifier sees the concatenation (u, v, |u - v|).
    """

    def __init__(self, bert_model, hidden_dim=768, num_classes=3):
        super().__init__()
        self.bert = bert_model
        # Three hidden_dim-wide features are concatenated, hence 3 * hidden_dim inputs.
        self.classifier = nn.Linear(hidden_dim * 3, num_classes)

    def forward(self, input_ids_a, attention_mask_a, input_ids_b, attention_mask_b):
        """Return (B, num_classes) logits for sentence pairs (a, b)."""
        u = self.get_sentence_embedding(input_ids_a, attention_mask_a)
        v = self.get_sentence_embedding(input_ids_b, attention_mask_b)
        features = torch.cat((u, v, (u - v).abs()), dim=-1)
        return self.classifier(features)

    def get_sentence_embedding(self, input_ids, attention_mask):
        """Encode ``input_ids`` with BERT (all segment ids 0) and mean-pool with ``attention_mask``."""
        segment_ids = torch.zeros_like(input_ids)
        token_states, _ = self.bert(input_ids, segment_ids)
        return mean_pooling(token_states, attention_mask)

    def encode(self, input_ids, attention_mask):
        """Gradient-free version of ``get_sentence_embedding`` for similarity search."""
        with torch.no_grad():
            return self.get_sentence_embedding(input_ids, attention_mask)

# Training

In [None]:
# # For this experiment we are using snli dataset
# dataset1 = load_nli_dataset("snli")
# dataset1

In [None]:
# dataset1['train']

In [None]:
# dataset1['train']['premise']

In [None]:
# del dataset1

In [None]:
def _nli_batch_to_device(batch):
    """Move the five NLI tensors of a DataLoader batch to ``device``.

    Returns (premise_ids, premise_mask, hypothesis_ids, hypothesis_mask, labels).
    """
    return tuple(
        batch[key].to(device)
        for key in (
            "premise_input_ids",
            "premise_attention_mask",
            "hypothesis_input_ids",
            "hypothesis_attention_mask",
            "labels",
        )
    )


def train_sentence_bert(
    bert_model,
    tokenizer,
    dataset_name="snli",
    num_epochs=5,
    batch_size=16,
    hidden_dim=768,
    max_train_samples=20000,
    print_every_epoch=True,
):
    """Fine-tune a SentenceBERT classifier, including the wrapped BERT, on SNLI/MNLI.

    Args:
        bert_model: BERT encoder; its weights are updated in place.
        tokenizer: CustomTokenizer-compatible callable.
        dataset_name: passed to ``load_nli_dataset``.
        num_epochs: number of passes over the training subset.
        batch_size: DataLoader batch size for both training and validation.
        hidden_dim: BERT hidden size; used to size the classifier input.
        max_train_samples: size cap for the training split.
        print_every_epoch: print per-epoch metrics when True.

    Returns:
        (model, epoch_train_losses, epoch_val_losses, epoch_val_accuracies, total_ms).
        The three lists have one entry per epoch, and both losses are
        per-example averages. total_ms is the wall-clock training time in
        milliseconds.
    """
    dataset = load_nli_dataset(dataset_name, max_samples=max_train_samples)
    tokenized_dataset = preprocess_nli_data(dataset, tokenizer)

    train_loader = DataLoader(
        tokenized_dataset["train"], batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(tokenized_dataset["validation"], batch_size=batch_size)

    model = SentenceBERT(bert_model, hidden_dim=hidden_dim).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    criterion = nn.CrossEntropyLoss()

    # Linear warm-up over the first 10% of steps, then a constant learning rate.
    # Clamp warmup_steps to >= 1 so runs with fewer than 10 steps do not divide by zero.
    total_steps = len(train_loader) * num_epochs
    warmup_steps = max(1, int(0.1 * total_steps))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
    )

    print("Starting Sentence-BERT training...")
    start_time = time.perf_counter()

    epoch_train_losses = []
    epoch_val_losses = []
    epoch_val_accuracies = []

    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        train_loss_sum = 0.0
        train_examples = 0

        for batch in train_loader:
            premise_ids, premise_mask, hypothesis_ids, hypothesis_mask, labels = (
                _nli_batch_to_device(batch)
            )

            optimizer.zero_grad()
            logits = model(premise_ids, premise_mask, hypothesis_ids, hypothesis_mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()  # update learning rate

            # Weight by batch size so the train loss is a per-example average,
            # like the validation loss below (the last batch may be smaller).
            n_in_batch = labels.size(0)
            train_loss_sum += loss.item() * n_in_batch
            train_examples += n_in_batch

        avg_train_loss = train_loss_sum / max(1, train_examples)
        epoch_train_losses.append(avg_train_loss)

        # ---- Validation ----
        model.eval()
        val_loss_sum = 0.0
        val_examples = 0
        correct = 0
        with torch.no_grad():
            for batch in val_loader:
                premise_ids, premise_mask, hypothesis_ids, hypothesis_mask, labels = (
                    _nli_batch_to_device(batch)
                )

                logits = model(
                    premise_ids, premise_mask, hypothesis_ids, hypothesis_mask
                )
                loss = criterion(logits, labels)
                # Separate name so the `batch_size` argument is not shadowed.
                n_in_batch = labels.size(0)
                val_loss_sum += loss.item() * n_in_batch
                val_examples += n_in_batch

                preds = torch.argmax(logits, dim=1)
                correct += (preds == labels).sum().item()

        avg_val_loss = val_loss_sum / max(1, val_examples)
        val_accuracy = correct / max(1, val_examples)

        epoch_val_losses.append(avg_val_loss)
        epoch_val_accuracies.append(val_accuracy)

        if print_every_epoch:
            print(
                f"Epoch {epoch+1}/{num_epochs}  |  train_loss: {avg_train_loss:.4f}  |  val_loss: {avg_val_loss:.4f}  |  val_acc: {val_accuracy:.4f}"
            )

    total_ms = (time.perf_counter() - start_time) * 1000.0
    return model, epoch_train_losses, epoch_val_losses, epoch_val_accuracies, total_ms

#### Load BERT model checkpoint

In [None]:
bert_model_path = os.path.join(os.getcwd(), "app", "saved_models", "bert_trained.pth")
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
checkpoint_bert = torch.load(bert_model_path, map_location=device)

In [None]:
# Inspect which entries the pretrained-BERT checkpoint provides (config values, weights, vocabulary).
checkpoint_bert.keys()

In [None]:
# Unpack the saved BERT configuration once and reuse it below.
vocab_size, d_model, n_layers, n_heads, max_len = (
    checkpoint_bert[key]
    for key in ("vocab_size", "d_model", "n_layers", "n_heads", "max_len")
)

# Recreate the BERT architecture and load the pretrained weights into it.
bert_model = BERT(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layers=n_layers,
    n_heads=n_heads,
    max_len=max_len,
)
bert_model.load_state_dict(checkpoint_bert["model_state_dict"])

# The tokenizer must use the same vocabulary BERT was pretrained with.
tokenizer = CustomTokenizer(checkpoint_bert["word2id"])

In [None]:
# Fine-tune Sentence-BERT on SNLI: 20k training pairs, 20 epochs, batch size 4.
(
    model,
    epoch_train_losses,
    epoch_val_losses,
    epoch_val_accuracies,
    total_ms,
) = train_sentence_bert(
    bert_model,
    tokenizer,
    dataset_name="snli",
    num_epochs=20,
    batch_size=4,
    hidden_dim=d_model,
    max_train_samples=20000,
)

In [None]:
# Summarise training by averaging each per-epoch metric over all epochs.
avg_train_loss, avg_val_loss, avg_val_acc = (
    sum(values) / len(values)
    for values in (epoch_train_losses, epoch_val_losses, epoch_val_accuracies)
)

print(f"{'Avg Train Loss':<18}{'Avg Val Loss':<18}{'Avg Val Accuracy'}")
print(f"{avg_train_loss:<18.4f}{avg_val_loss:<18.4f}{avg_val_acc:.4f}")

<h3>Train and Validation Loss</h3>
<img src="SBertLoss.png" alt="SBert" width="500" />


# Evaluation and Analysis

In [None]:
path_ = os.path.join(os.getcwd(), "app", "saved_models")
# Load onto the active device. A hard-coded 'cuda' would fail on CPU-only machines.
SBert_checkpoint = torch.load(
    os.path.join(path_, "sbert_full_model_new1.pth"), map_location=device
)

In [None]:
# Inspect the entries stored in the Sentence-BERT checkpoint (BERT weights/config, vocab, classifier head).
SBert_checkpoint.keys()

In [None]:
# Rebuild Sentence-BERT entirely from the saved checkpoint. The checkpoint holds
# the fine-tuned BERT weights ("bert_state_dict") as well as the classifier head.
# Loading only the head would pair it with whatever weights `bert_model`
# currently holds: pretrained-only if the training cell was skipped, or weights
# from a different training run.
hidden_dim = checkpoint_bert['d_model']
num_classes = SBert_checkpoint['num_classes']
bert_model.load_state_dict(SBert_checkpoint['bert_state_dict'])
model = SentenceBERT(bert_model, hidden_dim=hidden_dim, num_classes=num_classes)

model.classifier.load_state_dict(SBert_checkpoint['sbert_classifier_state_dict'])

model.to(device)
model.eval()

In [None]:
# Create custom tokenizer using saved word2id
tokenizer = CustomTokenizer(SBert_checkpoint["word2id"])

test_dataset = load_nli_dataset("snli")
test_data = preprocess_nli_data(test_dataset["test"], tokenizer, max_length=128)

# Create test loader
test_loader = DataLoader(test_data, batch_size=4)

all_preds = []
all_labels = []

with torch.no_grad():
    for batch in test_loader:
        premise_ids = batch["premise_input_ids"].to(device)
        premise_mask = batch["premise_attention_mask"].to(device)
        hypothesis_ids = batch["hypothesis_input_ids"].to(device)
        hypothesis_mask = batch["hypothesis_attention_mask"].to(device)
        # Labels are only consumed by scikit-learn, so they stay on the CPU.
        labels = batch["labels"]

        logits = model(premise_ids, premise_mask, hypothesis_ids, hypothesis_mask)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Generate classification report. Passing `labels` explicitly keeps the report
# aligned with `target_names` even if some class is missing from both the
# predictions and the ground truth.
label_names = ["entailment", "neutral", "contradiction"]

print("\nClassification Report on Test Set:")
print(
    classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(label_names))),
        target_names=label_names,
        digits=3,
    )
)

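<h3>Classification Report on Test Set</h3>
<!-- TODO: this image is the same loss plot shown above (SBertLoss.png); replace it with a capture of the classification report. -->
<img src="SBertLoss.png" alt="SBert" width="500" />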

# Inference

In [None]:
import torch
import torch.nn as nn
import ipywidgets as widgets
from IPython.display import display, clear_output
import re


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "app/saved_models/sbert_full_model_new1.pth"

# Load on the CPU first; the assembled model is moved to `device` at the end.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Tokenizer built from the vocabulary stored alongside the weights.
word2id = checkpoint["word2id"]
tokenizer = CustomTokenizer(word2id)
print(f"Vocabulary size: {len(word2id)}")

# Rebuild BERT from its saved hyper-parameters, then load the fine-tuned weights.
bert_config = checkpoint["bert_config"]
bert_model = BERT(
    **{
        key: bert_config[key]
        for key in ("vocab_size", "d_model", "n_layers", "n_heads", "max_len")
    }
)
bert_model.load_state_dict(checkpoint["bert_state_dict"])

# Quick sanity check: the first few embedding weights of token id 0.
first_weight = bert_model.embedding.tok_embed.weight[0, :5]
print("First few BERT weights (embedding layer):", first_weight)

# Wrap BERT in Sentence-BERT and restore the classifier head.
hidden_dim = bert_config["d_model"]
num_classes = checkpoint["num_classes"]
model = SentenceBERT(bert_model, hidden_dim=hidden_dim, num_classes=num_classes)
model.classifier.load_state_dict(checkpoint["sbert_classifier_state_dict"])
model.to(device)
model.eval()

print("Model and tokenizer loaded successfully!\n")


def predict_nli(premise, hypothesis):
    """Classify the relation between ``premise`` and ``hypothesis``.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Returns:
        (label, probs): label is "entailment", "neutral" or "contradiction";
        probs is a NumPy array of the softmax probabilities in that order.
    """

    def encode_on_device(text):
        # Batch of one; returns (input_ids, attention_mask) on `device`.
        enc = tokenizer(
            [text], max_length=128, padding=True, truncation=True, return_tensors="pt"
        )
        return enc["input_ids"].to(device), enc["attention_mask"].to(device)

    premise_ids, premise_mask = encode_on_device(premise)
    hypothesis_ids, hypothesis_mask = encode_on_device(hypothesis)

    with torch.no_grad():
        # A single forward pass through the Siamese model gives the logits;
        # there is no need to encode both sentences a second time.
        logits = model(premise_ids, premise_mask, hypothesis_ids, hypothesis_mask)
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
        pred_idx = torch.argmax(logits, dim=1).item()

    labels = ["entailment", "neutral", "contradiction"]
    return labels[pred_idx], probs


# WIDGETS
def _make_textarea(placeholder, description):
    """Full-width, 100px-tall text area for the premise/hypothesis inputs."""
    return widgets.Textarea(
        value="",
        placeholder=placeholder,
        description=description,
        disabled=False,
        layout=widgets.Layout(width="100%", height="100px"),
    )


def _make_button(description, button_style):
    """200px-wide action button."""
    return widgets.Button(
        description=description,
        button_style=button_style,
        layout=widgets.Layout(width="200px"),
    )


premise_input = _make_textarea("Enter premise here...", "Premise:")
hypothesis_input = _make_textarea("Enter hypothesis here...", "Hypothesis:")
predict_button = _make_button("Predict NLI", "primary")
# Output area that the prediction callback prints into.
output = widgets.Output()
example_button = _make_button("Load Example", "info")


def on_predict_clicked(b):
    """Predict button callback: validate both inputs, run predict_nli, print into `output`."""
    with output:
        clear_output(wait=True)  # replace the previous result
        premise, hypothesis = (
            widget.value.strip() for widget in (premise_input, hypothesis_input)
        )
        if premise and hypothesis:
            label, probs = predict_nli(premise, hypothesis)
            print(f"\nPrediction: {label.upper()}")
            print("Probabilities:")
            for idx, name in enumerate(["entailment", "neutral", "contradiction"]):
                print(f"  {name}: {probs[idx]:.3f}")
        else:
            print("Please enter both premise and hypothesis.")


def on_example_clicked(b):
    """Example button callback: fill both text areas with a sample sentence pair."""
    sample_pair = ("A man is playing a guitar on stage.", "The man is performing music.")
    premise_input.value, hypothesis_input.value = sample_pair


# Attach the callbacks, then lay out inputs, buttons and output vertically.
for button, handler in (
    (predict_button, on_predict_clicked),
    (example_button, on_example_clicked),
):
    button.on_click(handler)

buttons_row = widgets.HBox([predict_button, example_button])
ui = widgets.VBox([premise_input, hypothesis_input, buttons_row, output])

display(ui)

# Save model

In [None]:
# Save everything needed to rebuild Sentence-BERT for inference: BERT weights
# and config, the vocabulary, and the classifier head.
# The head is Linear(3 * hidden_dim, num_classes), so its dimensions are read
# from the layer itself instead of being hard-coded. A hard-coded 512 would
# mis-describe a model with a different d_model.
hidden_dim = model.classifier.in_features // 3
num_classes = model.classifier.out_features
save_path = "sbert_full_model_new.pth"

save_dict = {
    # BERT weights (the encoder actually wrapped by `model`) and config
    "bert_state_dict": model.bert.state_dict(),
    "word2id": checkpoint_bert["word2id"],
    "id2word": checkpoint_bert["id2word"],
    "bert_config": {
        "vocab_size": checkpoint_bert["vocab_size"],
        "d_model": checkpoint_bert["d_model"],
        "n_layers": checkpoint_bert["n_layers"],
        "n_heads": checkpoint_bert["n_heads"],
        "max_len": checkpoint_bert["max_len"],
    },
    # SentenceBERT classifier weights
    "sbert_classifier_state_dict": model.classifier.state_dict(),
    "hidden_dim": hidden_dim,
    "num_classes": num_classes,
}

torch.save(save_dict, save_path)
# Report the file that was actually written.
print(f"Model saved to {save_path}")