In [7]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
from tqdm import tqdm

In [8]:
def load_paws_dataset(config_name="labeled_final"):
    """Load the PAWS paraphrase dataset from the Hugging Face Hub.

    Args:
        config_name: dataset configuration to load. Defaults to
            ``"labeled_final"``, the configuration this notebook trains and
            evaluates on.

    Returns:
        Whatever ``datasets.load_dataset`` returns for this configuration;
        the notebook indexes it by split name (``'train'``, ``'test'``).
    """
    return load_dataset("google-research-datasets/paws", config_name)

In [9]:
def create_dataloader(data, tokenizer, batch_size=16, shuffle=False, max_length=128, padding='max_length'):
    """Wrap a sentence-pair dataset in a DataLoader that tokenizes on the fly.

    Each row of ``data`` must provide ``'sentence1'``, ``'sentence2'`` and
    ``'label'``. Tokenization happens inside ``collate_fn``, so every batch is
    encoded as the pair (sentence1, sentence2) and truncated to ``max_length``.

    Args:
        data: indexable dataset of dict-like rows (e.g. a HF dataset split).
        tokenizer: HF-style tokenizer callable as ``tokenizer(texts_a, texts_b, ...)``.
        batch_size: examples per batch.
        shuffle: reshuffle each epoch (intended for the training split).
        max_length: truncation length; also the padded length when
            ``padding='max_length'``.
        padding: forwarded to the tokenizer. The default ``'max_length'`` pads
            every batch to ``max_length``; ``'longest'`` pads only to the
            longest pair in the batch, which is usually far cheaper.

    Returns:
        DataLoader yielding dicts with ``'input_ids'``, ``'attention_mask'``,
        ``'label'`` (float tensor) and, when the tokenizer produces them,
        ``'token_type_ids'``.

    NOTE(review): for BERT-style tokenizers ``token_type_ids`` mark which
    tokens belong to sentence1 vs sentence2. ``train_epoch``/``evaluate`` in
    this notebook read only ``input_ids``/``attention_mask``, so the model
    presumably sees every token as segment 0 — consider forwarding them.
    """
    def collate_batch(batch):
        encoded = tokenizer(
            [item['sentence1'] for item in batch],
            [item['sentence2'] for item in batch],
            max_length=max_length,
            padding=padding,
            truncation=True,
            return_tensors='pt'
        )

        collated = {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
            # Float so labels match the float model output under a regression
            # loss such as MSELoss.
            'label': torch.tensor([item['label'] for item in batch], dtype=torch.float),
        }
        # Not every tokenizer emits segment ids; only pass them through if present.
        if 'token_type_ids' in encoded:
            collated['token_type_ids'] = encoded['token_type_ids']
        return collated

    return DataLoader(
        data,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_batch
    )

In [10]:
class BertSimilarityClassifier(nn.Module):
    """BERT with a single-output head, used here as a sentence-pair similarity scorer.

    Wraps ``BertForSequenceClassification`` with ``num_labels=1`` so the model
    emits one unnormalised score per example; the notebook trains it with a
    regression loss and thresholds the score at evaluation time.
    """

    def __init__(self, model_name='bert-base-uncased'):
        super().__init__()
        # num_labels=1 -> one output column. The classification head is not in
        # the pretrained checkpoint, so it starts freshly initialised.
        self.bert = BertForSequenceClassification.from_pretrained(model_name, num_labels=1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return the raw logits (one column per example, since num_labels=1).

        Args:
            input_ids: token ids from the tokenizer.
            attention_mask: 1 for real tokens, 0 for padding.
            token_type_ids: optional segment ids distinguishing the two
                sentences of a pair; ``None`` keeps BERT's default behaviour.
        """
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).logits

def train_epoch(model, dataloader, optimizer, criterion, device, scheduler=None, max_grad_norm=None):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: module called as ``model(input_ids, attention_mask)``.
        dataloader: iterable of dicts with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss callable as ``criterion(outputs, labels)``.
        device: device the batch tensors are moved to.
        scheduler: optional LR scheduler stepped once per batch, after the
            optimizer step. ``None`` disables scheduling.
        max_grad_norm: if given, clip the global gradient norm to this value
            before each optimizer step. ``None`` disables clipping.

    Returns:
        The unweighted mean of the per-batch losses, or ``nan`` when the
        dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Add a trailing dim so 1-D labels line up with a (batch, 1) model output.
        labels = batch['label'].to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)

        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches ourselves: an empty loader would otherwise divide by zero,
    # and plain iterables need not support len().
    return total_loss / num_batches if num_batches else float('nan')

In [11]:
def evaluate(model, dataloader, device, threshold=0.5):
    """Run ``model`` over ``dataloader`` without gradients and score its outputs.

    The model's single output per example is treated as a similarity score.
    Scores are binarised at ``threshold`` and gold labels at 0.5 to compute
    accuracy / precision / recall / F1 (positive class = 1). ``correlation``
    is the Pearson correlation between the raw scores and the raw labels.

    Args:
        model: module called as ``model(input_ids, attention_mask)``; expected
            to return one score per example, e.g. shape ``(batch, 1)``.
        dataloader: iterable of dicts with ``'input_ids'``, ``'attention_mask'``
            and ``'label'`` tensors.
        device: device the model inputs are moved to.
        threshold: decision cut-off applied to the scores (default 0.5).

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'f1_score'`` and ``'correlation'``. ``correlation`` is ``nan`` when
        the scores or the labels are constant.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            outputs = model(input_ids, attention_mask)
            # reshape(-1), not squeeze(): squeeze() collapses a size-1 batch to
            # a 0-d array; reshape(-1) always yields a 1-d chunk to concatenate.
            score_chunks.append(outputs.reshape(-1).cpu().numpy())
            label_chunks.append(batch['label'].reshape(-1).cpu().numpy())

    if not score_chunks:
        raise ValueError("evaluate() got an empty dataloader; nothing to score")

    predictions = np.concatenate(score_chunks)
    true_labels = np.concatenate(label_chunks)

    predicted_labels = (predictions >= threshold).astype(int)
    # Gold labels are binarised at the fixed midpoint 0.5 so that tuning the
    # decision threshold never alters the ground truth.
    true_labels_binary = (true_labels >= 0.5).astype(int)

    accuracy = accuracy_score(true_labels_binary, predicted_labels)
    # zero_division=0 returns the same 0.0 sklearn's default ("warn") falls
    # back to, without emitting UndefinedMetricWarning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels_binary, predicted_labels, average='binary', zero_division=0
    )
    correlation = np.corrcoef(true_labels, predictions)[0, 1]

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'correlation': correlation
    }

In [12]:
# Tunable knobs for this run.
SEED = 42
MODEL_NAME = 'bert-base-uncased'
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
num_epochs = 2

# torch.manual_seed seeds the RNG on all devices, covering DataLoader shuffling,
# dropout and the freshly initialised classifier head, so reruns are reproducible.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
dataset = load_paws_dataset()

train_loader = create_dataloader(dataset['train'], tokenizer, batch_size=BATCH_SIZE, shuffle=True)
test_loader = create_dataloader(dataset['test'], tokenizer, batch_size=BATCH_SIZE)

model = BertSimilarityClassifier(MODEL_NAME).to(device)
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are passed explicitly because torch's defaults (1e-8, 0.01)
# differ from the transformers.AdamW defaults (1e-6, 0.0) used previously.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0)
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    print(f"Training Loss: {train_loss:.4f}")

    metrics = evaluate(model, test_loader, device)
    for metric_name, value in metrics.items():
        print(f"{metric_name}: {value:.4f}")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/2


Training:   4%|▍         | 134/3088 [00:47<17:21,  2.84it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:   6%|▋         | 193/3088 [01:08<17:02,  2.83it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  18%|█▊        | 564/3088 [03:15<14:54,  2.82it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  36%|███▌      | 1118/3088 [06:24<11:29,  2.86it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.

Training Loss: 0.1593


Evaluating: 100%|██████████| 500/500 [00:55<00:00,  9.01it/s]


accuracy: 0.8919
precision: 0.8512
recall: 0.9154
f1_score: 0.8821
correlation: 0.8210

Epoch 2/2


Training:   7%|▋         | 202/3088 [01:10<16:54,  2.85it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  24%|██▍       | 742/3088 [04:18<13:43,  2.85it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  28%|██▊       | 862/3088 [05:00<12:56,  2.87it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Training:  48%|████▊     | 1472/3088 [08:28<09:16,  2.90it/s]Be aware, overflowing tokens are not returned for the setting you have chosen, i.

Training Loss: 0.0570


Evaluating: 100%|██████████| 500/500 [00:57<00:00,  8.74it/s]

accuracy: 0.9019
precision: 0.8585
recall: 0.9316
f1_score: 0.8935
correlation: 0.8374



