In [None]:
import logging
# Globally suppress every log record of severity WARNING and below;
# only ERROR and CRITICAL records are still emitted.
logging.disable(logging.WARNING)

In [None]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel, BertConfig, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import os

# Read the tab-separated training and validation splits (in that order).
train_df, val_df = (
    pd.read_csv(csv_path, sep='\t')
    for csv_path in (
        './A3_task1_data_files/train.csv',
        './A3_task1_data_files/dev.csv',
    )
)

class TextSimilarityDataset(Dataset):
    """Sentence-pair regression dataset backed by a pandas DataFrame.

    Rows containing any missing value are discarded when the dataset is
    built. Every item is the tokenized (``sentence1``, ``sentence2``) pair,
    padded/truncated to ``max_length``, plus the ``score`` column as a float
    tensor.
    """

    def __init__(self, df, tokenizer, max_length=128):
        """Keep the tokenizer settings and a NaN-free copy of ``df``.

        Args:
            df: DataFrame with ``sentence1``, ``sentence2`` and ``score`` columns.
            tokenizer: Object exposing ``encode_plus`` for sentence pairs.
            max_length: Length every encoded pair is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Incomplete rows are removed once, up front.
        self.df = df.dropna()

    def __len__(self):
        """Return the number of rows left after dropping incomplete ones."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Return the encoded pair and target score for positional row ``idx``.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (leading batch
            dimension removed) and ``labels`` (scalar float tensor).
        """
        frame = self.df
        # Positional lookup, column by column, in the original read order.
        first, second, target = (
            frame[column].iloc[idx]
            for column in ('sentence1', 'sentence2', 'score')
        )

        encode_options = {
            'add_special_tokens': True,
            'max_length': self.max_length,
            'padding': 'max_length',
            'truncation': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer.encode_plus(first, second, **encode_options)

        # The tokenizer returns tensors with a leading batch axis of size 1;
        # strip it so the DataLoader can stack items into a batch itself.
        item = {
            key: encoded[key].squeeze(0)
            for key in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(target, dtype=torch.float)
        return item

# Pretrained tokenizer and encoder loaded from the same checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

# Freeze every pretrained weight. This happens BEFORE the head below is
# attached, so the head's parameters keep requires_grad=True.
model.requires_grad_(False)

# Attach a single-output linear head (hidden_size -> 1) for regression.
model.config.num_labels = 1
model.classifier = torch.nn.Linear(
    in_features=model.config.hidden_size,
    out_features=model.config.num_labels,
)

# Mean-squared-error objective for the similarity-score regression.
criterion = torch.nn.MSELoss()

# Build the datasets and loaders first: the learning-rate schedule below is
# sized from the number of batches per epoch.
train_dataset = TextSimilarityDataset(train_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataset = TextSimilarityDataset(val_df, tokenizer)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

num_epochs = 10

optimizer = AdamW(model.parameters(), lr=1e-5, no_deprecation_warning=True)

# The scheduler is stepped once per batch, so the schedule must span
# (batches per epoch) * (epochs) steps. Sizing it from the number of raw
# examples (len(train_df)) stretches the linear decay by roughly the batch
# size, and the learning rate never gets close to zero.
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# Per-epoch mean losses, filled in by the training loop.
train_losses = []
val_losses = []
def _predict_scores(batch):
    """Return one predicted similarity score per example in ``batch``.

    Runs the encoder, mean-pools the last hidden state over the sequence
    axis and applies the linear regression head.
    """
    outputs = model(batch['input_ids'], attention_mask=batch['attention_mask'])
    # NOTE(review): this mean runs over every position, including padded
    # ones (inputs are padded to max_length). A mask-weighted mean may be
    # preferable; behavior is left unchanged here.
    pooled = outputs.last_hidden_state.mean(dim=1)
    return model.classifier(pooled).squeeze(1)


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_train_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        labels = batch['labels']

        predictions = _predict_scores(batch)
        loss = criterion(predictions, labels)
        total_train_loss += loss.item()

        loss.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # one scheduler step per batch

    train_loss = total_train_loss / len(train_loader)
    train_losses.append(train_loss)

    # ---- validation pass (no gradients needed) ----
    model.eval()
    total_val_loss = 0
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in val_loader:
            labels = batch['labels']
            predictions = _predict_scores(batch)
            batch_val_loss = criterion(predictions, labels)
            total_val_loss += batch_val_loss.item()
            all_predictions.extend(predictions.tolist())
            all_labels.extend(labels.tolist())

    val_loss = total_val_loss / len(val_loader)
    val_losses.append(val_loss)
    pearson_corr, _ = pearsonr(all_labels, all_predictions)
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss}, Val Loss: {val_loss}, Pearson Correlation: {pearson_corr}")

    # Name the checkpoint with the same 1-based epoch number that is logged
    # above (it was previously off by one: epoch+2).
    checkpoint_path = f'bert_model_epoch_{epoch+1}.pt'
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved model checkpoint: {checkpoint_path}")

# Draw the per-epoch training and validation loss curves on one figure.
epoch_axis = range(1, num_epochs+1)
for loss_series, series_label in (
    (train_losses, 'Training Loss'),
    (val_losses, 'Validation Loss'),
):
    plt.plot(epoch_axis, loss_series, label=series_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Loss Plot: Training Loss and Validation Loss vs Epochs')
plt.legend()
plt.show()
