# Results

In [30]:
import warnings

import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

from src.data.dataset import SNLIVocabulary, SNLIDataset
from src.models.net import NLIModel
from src.models.classifiers import Classifier
from src.models.encoders import BiLSTMEncoder

warnings.filterwarnings("ignore")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Setup

In [3]:
# Build the training split only to reuse its vocabulary: it supplies the word
# embeddings for the encoder and the tokenizer/indexer used by `predict` below.
vocab = SNLIDataset(
    "data/",
    split="train",
).vocab

In [6]:
# Hyper-parameters of the checkpoint being restored. They must match the values
# the checkpoint was trained with, otherwise load_state_dict (strict by default)
# fails on mismatching parameter shapes.
embedding_dim = 300
hidden_size = 300

encoder = BiLSTMEncoder(
    word_embeddings=vocab.word_embedding,
    input_size=embedding_dim,
    hidden_size=hidden_size,
    max_pooling=True,
)
# The classifier is given the sentence-embedding size; for a bidirectional
# encoder this is presumably forward + backward states, i.e. 2 * hidden_size
# (NOTE(review): confirm against BiLSTMEncoder's output size).
classifier = Classifier(2 * hidden_size)
model = NLIModel(encoder, classifier).to(device)

# map_location remaps the stored tensors onto `device` (e.g. a GPU-saved
# checkpoint on a CPU-only machine). load_state_dict copies the weights into the
# model's existing parameters, which already live on `device`, so no second
# .to(device) call is needed afterwards.
model.load_state_dict(torch.load("models/bilstm-max/best_model.pt", map_location=device))

In [37]:
def _to_single_batch(
    vocab: SNLIVocabulary,
    text: str,
    device: torch.device,
    name: str,
) -> "tuple[torch.Tensor, torch.Tensor]":
    """Turn one sentence into a batch of size one plus its length tensor.

    A single sequence needs no padding, so the token indices simply get a
    leading batch dimension, giving shape (1, num_tokens). Both returned
    tensors are ``torch.long`` and are moved to ``device``.

    Args:
        vocab (SNLIVocabulary): Vocabulary providing ``tokenize_and_index``.
        text (str): The sentence to encode.
        device (torch.device): The device to move the tensors to.
        name (str): Human-readable role of ``text`` ("premise"/"hypothesis"),
            used only in the error message.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The (1, num_tokens) index tensor and
        a one-element tensor holding ``num_tokens``.

    Raises:
        ValueError: If ``text`` tokenizes to no tokens. An empty sequence is
            rejected up front rather than being passed to the encoder.
    """
    indices = vocab.tokenize_and_index(text)
    if len(indices) == 0:
        raise ValueError(f"The {name} {text!r} contains no tokens.")

    batch = torch.tensor(indices, dtype=torch.long).unsqueeze(0).to(device)
    lengths = torch.tensor([len(indices)], dtype=torch.long).to(device)
    return batch, lengths


def predict(
    model: nn.Module,
    vocab: SNLIVocabulary,
    device: torch.device,
    premise: str,
    hypothesis: str,
) -> str:
    """Predict the entailment label of the given premise and hypothesis.

    The premise and hypothesis are each encoded as a batch of size one and
    passed to ``model(premises, premise_lengths, hypotheses,
    hypothesis_lengths)``. The arg-max over the last dimension of the returned
    logits is mapped to a label name.

    Note: this switches ``model`` to evaluation mode (``model.eval()``) and
    leaves it in that mode.

    Args:
        model (nn.Module): The model (encoder + classifier).
        vocab (SNLIVocabulary): The vocabulary; only its ``tokenize_and_index``
            method is used.
        device (torch.device): The device the input tensors are moved to.
        premise (str): The premise.
        hypothesis (str): The hypothesis.

    Returns:
        str: The predicted label: "entailment", "neutral" or "contradiction".

    Raises:
        ValueError: If the premise or the hypothesis tokenizes to no tokens.
    """
    id_to_label = {
        0: "entailment",
        1: "neutral",
        2: "contradiction",
    }

    model.eval()
    with torch.no_grad():
        premises, premise_lengths = _to_single_batch(vocab, premise, device, "premise")
        hypotheses, hypothesis_lengths = _to_single_batch(vocab, hypothesis, device, "hypothesis")

        logits = model(premises, premise_lengths, hypotheses, hypothesis_lengths)

        # The batch holds a single example, so the arg-max has exactly one element.
        prediction = torch.argmax(logits, dim=-1).item()

    return id_to_label[int(prediction)]


## Predictions

Predictions using the BiLSTM model with max pooling.

In [38]:
# Example 1: the premise says nothing about who is (or is not) in the shade.
premise_1, hypothesis_1 = (
    "Two men sitting in the sun",
    "Nobody is sitting in the shade",
)
label_1 = predict(model, vocab, device, premise=premise_1, hypothesis=hypothesis_1)
print(f"Label: {label_1}")

Label: contradiction


In [39]:
# Example 2: the premise says nothing about cats. Uses its own `premise_2` name
# so that example 1's `premise_1` is not overwritten.
premise_2 = "A man is walking a dog"
hypothesis_2 = "No cat is outside"
label_2 = predict(model, vocab, device, premise_2, hypothesis_2)
print(f"Label: {label_2}")

Label: contradiction


Both predictions are wrong: the correct label for each pair is *neutral*, because the premise says nothing about who is sitting in the shade (first pair) or about cats (second pair). A possible reason for the failure is the presence of negations in the hypotheses, which might lead the model to focus on opposing aspects of the premise and the hypothesis. The model may be overly sensitive to negation words such as "nobody" and "no" in the hypothesis, causing it to predict a contradiction where none exists. Additionally, the model might struggle to understand the relationships between different entities in the sentences, such as "men" and "nobody", or "dog" and "cat". This difficulty in capturing the semantic relationships between entities could lead the model to misjudge the relationship between the premise and the hypothesis.

## Results

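The following table shows the results of each model on the SNLI dev and test sets, together with the micro- and macro-averaged results on the SentEval tasks. Note that the SNLI scores are reported as fractions between 0 and 1, whereas the SentEval averages are percentages.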

| **Model** | **SNLI Dev** | **SNLI Test** | **Micro** | **Macro** |
|---|---|---|---|---|
| Baseline | 0.657 | 0.654 | 80.498 | 79.155 |
| LSTM | 0.815 | 0.814 | 76.088 | 75.567 |
| BiLSTM | 0.807 | 0.816 | 79.128 | 78.634 |
| BiLSTM (max) | 0.847 | 0.841 | 79.583 | 79.006 |