### preprocess_function

- preprocess_function: This function takes examples as input; each example typically contains a premise, a hypothesis, and a label. It tokenizes the premise and hypothesis with the tokenizer, padding or truncating the encoded sequences to a maximum length of 128 tokens. It then extracts the labels and returns a dictionary containing the tokenized premise and hypothesis (token IDs and attention masks) together with the labels.

- raw_dataset.map: This method applies preprocess_function to every example in raw_dataset and returns a new dataset, tokenized_datasets, that holds the tokenized premises and hypotheses and the labels. The original text columns are kept alongside the new ones, which is why they have to be removed in the next step.

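- tokenized_datasets.remove_columns: This method drops the 'premise', 'hypothesis', and 'label' columns, which are no longer needed once the text has been tokenized and the labels have been copied into the preprocessed output. It returns a new dataset instead of modifying the existing one, so the result must be assigned back to tokenized_datasets.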

- tokenized_datasets.set_format("torch"): This method changes the dataset's output format in place so that indexing it returns PyTorch tensors instead of Python lists, which lets batches go straight into a PyTorch model or DataLoader.
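The preprocessing steps above can be sketched without downloading a model or a dataset. In this minimal sketch the whitespace tokenizer `toy_tokenize`, the special-token IDs, and the `labels` output key are hypothetical stand-ins; a real pipeline would more likely call the model's tokenizer, for example `tokenizer(examples["premise"], examples["hypothesis"], padding="max_length", truncation=True, max_length=128)`, and pass the function to `raw_dataset.map(..., batched=True)`. Like a batched `map` call, the sketch receives a dictionary of lists and returns one, and the final dictionary comprehension roughly mimics what `set_format("torch")` yields.

```python
import torch

MAX_LENGTH = 128                  # maximum sequence length used in the text above
PAD_ID, CLS_ID, SEP_ID = 0, 1, 2  # hypothetical special-token IDs


def toy_tokenize(text, vocab):
    """Hypothetical whitespace tokenizer standing in for a real subword tokenizer.

    Unseen words get the next free ID; IDs 0-2 are reserved for special tokens.
    """
    ids = []
    for word in text.lower().split():
        if word not in vocab:
            vocab[word] = len(vocab) + 3
        ids.append(vocab[word])
    return ids


def preprocess_function(examples, vocab, max_length=MAX_LENGTH):
    """Encode premise/hypothesis pairs, pad or truncate them, and keep the labels.

    `examples` is a dict of lists, as passed by a batched `map` call.
    """
    input_ids, attention_mask = [], []
    for premise, hypothesis in zip(examples["premise"], examples["hypothesis"]):
        ids = ([CLS_ID] + toy_tokenize(premise, vocab) + [SEP_ID]
               + toy_tokenize(hypothesis, vocab) + [SEP_ID])
        ids = ids[:max_length]                               # truncate (may drop the final SEP)
        n_pad = max_length - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)             # pad to max_length
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": list(examples["label"])}


raw_batch = {
    "premise": ["A man is playing a guitar.", "Two dogs run in a field."],
    "hypothesis": ["A person makes music.", "The animals are asleep."],
    "label": [0, 2],
}
vocab = {}
features = preprocess_function(raw_batch, vocab)

# Roughly what set_format("torch") provides: numeric columns as tensors
tensors = {name: torch.tensor(values) for name, values in features.items()}
print(tensors["input_ids"].shape)        # torch.Size([2, 128])
print(tensors["attention_mask"].sum(1))  # tensor([13, 13])
```

Truncating by slicing can cut off the final separator token; library tokenizers truncate the text tokens before adding special tokens, which is one reason to rely on them outside of illustrations like this one.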

```python
import torch

a = torch.tensor([[3, 42, 4, 2, 34, 5]], dtype=torch.long)  # token IDs, shape [batch=1, seq_len=6]
a.shape
```

```
torch.Size([1, 6])
```

```python
a.unsqueeze(-1)  # insert a size-1 axis at the end: [1, 6] -> [1, 6, 1]
```

```
tensor([[[ 3],
         [42],
         [ 4],
         [ 2],
         [34],
         [ 5]]])
```
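A common reason to call `unsqueeze` in Transformer code is to build padding masks from token-ID tensors like `a`. The sketch below assumes 0 is the padding ID, and `make_pad_mask` is a hypothetical helper rather than a library function. It inserts a query axis with `unsqueeze(1)` and then `expand`s it, producing a boolean mask of shape [batch_size, len_q, len_k] in which True marks padded key positions, the layout the attention cell below expects.

```python
import torch

# Batch of two token-ID sequences; 0 is assumed to be the padding ID
ids = torch.tensor([[3, 42, 4, 2, 34, 5],
                    [7, 9, 11, 0, 0, 0]])

# unsqueeze(dim) inserts a size-1 axis at position dim
print(ids.unsqueeze(0).shape)   # torch.Size([1, 2, 6])
print(ids.unsqueeze(1).shape)   # torch.Size([2, 1, 6])
print(ids.unsqueeze(-1).shape)  # torch.Size([2, 6, 1])


def make_pad_mask(seq_q, seq_k, pad_id=0):
    """Hypothetical helper: True wherever the key token is padding."""
    batch_size, len_q = seq_q.shape
    len_k = seq_k.shape[1]
    pad = seq_k.eq(pad_id).unsqueeze(1)          # [batch_size, 1, len_k]
    return pad.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]


mask = make_pad_mask(ids, ids)
print(mask.shape)           # torch.Size([2, 6, 6])
print(mask[1, 0].tolist())  # [False, False, False, True, True, True]
```

Because `expand` returns a view instead of copying data, repeating the mask along the query axis costs no extra memory.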

```python
import math

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d_k)) V, ignoring positions where attn_mask is True."""

    def forward(self, Q, K, V, attn_mask):
        d_k = Q.size(-1)
        # Similarity of every query with every key: [..., len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        # A large negative score becomes (near-)zero weight after softmax
        scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)  # each row sums to 1 over the keys
        context = torch.matmul(attn, V)       # weighted sum of values: [..., len_q, d_v]
        return context, attn


# Example tensors
Q = torch.randn(2, 3, 4)  # batch size 2, 3 queries, each with 4 dimensions
K = torch.randn(2, 3, 4)  # batch size 2, 3 keys, each with 4 dimensions (must match Q's d_k)
V = torch.randn(2, 3, 5)  # batch size 2, 3 values (one per key), each with 5 dimensions

# Boolean attention mask; True marks a key position to ignore
attn_mask = torch.zeros(2, 3, 3, dtype=torch.bool)  # shape [batch_size, len_q, len_k]
attn_mask[:, :, -1] = True  # mask out the last key position for every batch and query

# Instantiate the attention module
attention = ScaledDotProductAttention()

# Apply attention
context, attn = attention(Q, K, V, attn_mask)

print("Context shape:", context.shape)
print("Attention weights shape:", attn.shape)
```

```
Context shape: torch.Size([2, 3, 5])
Attention weights shape: torch.Size([2, 3, 3])
```
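Because `torch.matmul` and `transpose(-1, -2)` act on the last two dimensions, the same computation also works for multi-head inputs of shape [batch_size, num_heads, len, d]. The minimal sketch below (head count and sizes chosen arbitrarily) repeats the module so it runs on its own, feeds it 4-D tensors with a per-head mask, and checks two properties: every row of attention weights sums to 1, and the masked key column receives zero weight. It also shows that a mask with a size-1 head axis broadcasts across all heads.

```python
import math

import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Same computation as the cell above: softmax(Q K^T / sqrt(d_k)) V with masking."""

    def forward(self, Q, K, V, attn_mask):
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k)
        scores = scores.masked_fill(attn_mask, -1e9)  # masked keys -> zero weight
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V), attn


torch.manual_seed(0)
batch, heads, len_q, len_k, d_k, d_v = 2, 4, 3, 3, 8, 5
Q = torch.randn(batch, heads, len_q, d_k)
K = torch.randn(batch, heads, len_k, d_k)
V = torch.randn(batch, heads, len_k, d_v)

attn_mask = torch.zeros(batch, heads, len_q, len_k, dtype=torch.bool)
attn_mask[..., -1] = True  # mask the last key position for every head and query

attention = ScaledDotProductAttention()
context, attn = attention(Q, K, V, attn_mask)
print(context.shape)  # torch.Size([2, 4, 3, 5])
print(attn.shape)     # torch.Size([2, 4, 3, 3])
print(torch.allclose(attn.sum(-1), torch.ones(batch, heads, len_q)))  # True
print(attn[..., -1].max().item())  # 0.0

# A mask of shape [batch, 1, len_q, len_k] broadcasts across all heads
shared_mask = attn_mask[:, :1]
_, attn_shared = attention(Q, K, V, shared_mask)
print(torch.allclose(attn_shared, attn))  # True
```

With a score of -1e9, the exponential in softmax underflows to exactly zero in float32, so masked keys contribute nothing to `context`.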
