In [20]:
from transformers import BertForSequenceClassification, AdamW

# Load pre-trained BERT model for sequence classification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)

# Training loop
epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Loss: 0.6532773574193319
Epoch 2, Loss: 0.6364534497261047
Epoch 3, Loss: 0.5506133139133453


In [26]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

class RTEDataset(Dataset):
    def __init__(self, premises, hypotheses, labels, tokenizer, max_length):
        assert len(premises) == len(hypotheses) == len(labels), "Dataset lengths are not equal."
        self.premises = premises
        self.hypotheses = hypotheses
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.premises)

    def __getitem__(self, idx):
        premise = self.premises[idx]
        hypothesis = self.hypotheses[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            premise,
            hypothesis,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Example data for RTE task
premises = [
    "Your privacy is protected.",
    "We do not collect and share your data with third parties.",
    "Users have full control over their data.",
    "All transactions are encrypted.",
    "We may share data with third-party partners."
]
hypotheses = [
    "We do not protect your privacy.",
    "We collect and share your data with third parties.",
    "Users have no control over their data.",
    "All transactions are insecure.",
    "We do not share data with third-party partners."
]
labels = [0, 0, 0, 0, 0]  # 1 for entailment, 0 for non-entailment

 # Adding both entailment and non-entailment examples

# Initialize tokenizer and dataset
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = RTEDataset(premises, hypotheses, labels, tokenizer, max_length=128)

# Split into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Data loaders
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2)

# Check a few batches
for batch in train_dataloader:
    print(batch)
    break


{'input_ids': tensor([[ 101, 5198, 2031, 2440, 2491, 2058, 2037, 2951, 1012,  102, 5198, 2031,
         2053, 2491, 2058, 2037, 2951, 1012,  102,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0],
        [ 101, 2057, 2089, 3745, 2951, 2007, 2353, 1011, 2283, 5826, 1012,  102,
         2057, 2079, 2025, 3745, 2951

In [27]:
def evaluate_model(model, dataloader):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1).flatten()
            
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    return predictions, true_labels

# Evaluate the model
predictions, true_labels = evaluate_model(model, dataloader)
for i, (premise, hypothesis) in enumerate(zip(premises, hypotheses)):
    print(f"Premise: {premise}")
    print(f"Hypothesis: {hypothesis}")
    print(f"Prediction: {'Entailment' if predictions[i] == 1 else 'Non-entailment'}")
    print(f"True Label: {'Entailment' if true_labels[i] == 1 else 'Non-entailment'}\n")


Premise: Your privacy is protected.
Hypothesis: We do not protect your privacy.
Prediction: Non-entailment
True Label: Non-entailment

Premise: We do not collect and share your data with third parties.
Hypothesis: We collect and share your data with third parties.
Prediction: Non-entailment
True Label: Non-entailment

Premise: Users have full control over their data.
Hypothesis: Users have no control over their data.
Prediction: Non-entailment
True Label: Non-entailment

Premise: All transactions are encrypted.
Hypothesis: All transactions are insecure.
Prediction: Non-entailment
True Label: Non-entailment

Premise: We may share data with third-party partners.
Hypothesis: We do not share data with third-party partners.
Prediction: Non-entailment
True Label: Entailment

