In [24]:
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import torch.optim as optim
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# Data

In [25]:
# Create a dataset that feeds BERT by concatenating two paragraphs

class ParagraphDataset(Dataset):
    """Dataset that encodes two paragraphs per row as one BERT sentence pair.

    Every row of ``data`` must provide the columns ``'paragraph 1'``,
    ``'paragraph 2'`` and ``'label'``, where the label is either
    ``'reverse'`` (class 0) or ``'correct'`` (class 1).

    The two paragraphs are handed to the tokenizer together (``text`` and
    ``text_pair``), so the model sees a single
    ``[CLS] paragraph 1 [SEP] paragraph 2 [SEP]`` sequence with one block of
    trailing padding and segment ids that tell the paragraphs apart.
    Encoding each paragraph on its own and concatenating the results would
    instead put a second ``[CLS]`` and a run of padding in the middle of the
    sequence.

    Args:
        data: pandas DataFrame holding the rows; accessed through ``.iloc``.
        tokenizer: callable tokenizer accepting ``(text, text_pair, **kwargs)``
            and returning a mapping with ``'input_ids'``, ``'attention_mask'``
            and ``'token_type_ids'`` tensors (e.g. ``BertTokenizer``).
        max_len: total length of every encoded example, special tokens and
            padding included.
    """

    def __init__(self, data, tokenizer, max_len):
        self.data = data
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_dict = {'reverse': 0, 'correct': 1}

    def __len__(self):
        """Return the number of rows (examples) in the underlying frame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode row ``idx`` and return its tensors.

        Returns:
            dict with 1-D ``'input_ids'``, ``'attention_mask'`` and
            ``'token_type_ids'`` tensors plus a scalar long ``'labels'`` tensor.

        Raises:
            KeyError: if the row's label is not a key of ``label_dict``.
        """
        row = self.data.iloc[idx]
        label = self.label_dict[row['label']]

        # One call with both texts builds a proper sentence pair.
        # truncation=True trims the longer paragraph first, so both paragraphs
        # share the max_len budget; padding is appended once, at the end.
        encoding = self.tokenizer(
            row['paragraph 1'],
            row['paragraph 2'],
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            padding='max_length',
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # return_tensors='pt' yields a leading batch axis of size 1; flatten it
        # away so the DataLoader can stack examples itself.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'token_type_ids': encoding['token_type_ids'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }

In [None]:
# Pretrained BERT tokenizer and sequence classifier, both loaded from the same
# checkpoint; the classification head has two output classes.
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

In [None]:
# Sequence length and batch size used for every split built in this notebook
max_len = 128
batch_size = 16

# Training split: read from train.csv, reshuffled by the loader
train_data = ParagraphDataset(
    data=pd.read_csv('train.csv'),
    tokenizer=tokenizer,
    max_len=max_len,
)
# Validation split: read from val.csv, kept in file order
val_data = ParagraphDataset(
    data=pd.read_csv('val.csv'),
    tokenizer=tokenizer,
    max_len=max_len,
)

train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=batch_size, shuffle=False)

# Model and training

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Place the model on its device, then build the optimizer over its parameters
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

# Stand-alone cross-entropy criterion. The loops visible in this notebook read
# the loss from the model output instead, so nothing shown here calls it.
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Training loop: fine-tune the classifier and report the mean batch loss and
# the running accuracy after every epoch.
num_epochs = 2
for epoch in range(num_epochs):
    model.train()  # training mode (dropout active)
    total_loss = 0
    total_correct = 0
    total = 0

    for batch in tqdm(train_loader):
        # Move every tensor the dataset yields to the model's device and pass
        # all of them on, so any extra model input the dataset may provide
        # (e.g. token_type_ids) reaches the model without touching this loop.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = batch['labels']

        optimizer.zero_grad()
        # Because `labels` is part of the inputs, the model output carries the loss.
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch summary
        total_loss += loss.item()
        preds = outputs.logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / len(train_loader)  # mean over batches
    accuracy = total_correct / total           # fraction of examples predicted correctly
    print(f"Train Epoch: {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")

# Prediction

In [None]:
# Test split: read from test.csv with the same tokenizer and length settings,
# iterated in file order.
test_data = ParagraphDataset(
    data=pd.read_csv('test.csv'),
    tokenizer=tokenizer,
    max_len=max_len,
)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Evaluate on the test set: collect one prediction per example and report the
# mean batch loss and the accuracy.
model.eval()  # inference mode (turns dropout off)

test_output = []  # 0-d prediction tensors, one per test example, in loader order
# Accumulators are initialised here so the cell runs on a fresh kernel.
# The `val` names are kept from the original cell; they hold test-set values.
total_val_loss = 0
total_val_correct = 0
total_val = 0

with torch.no_grad():  # no gradients are needed for evaluation
    for batch in test_loader:
        # Forward every tensor field the dataset yields (labels included, so
        # the model output carries the loss).
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        labels = batch['labels']

        outputs = model(**batch)
        val_loss = outputs.loss
        total_val_loss += val_loss.item()

        val_preds = outputs.logits.argmax(dim=1)
        # extend() appends in place instead of rebuilding the list every batch
        test_output.extend(torch.unbind(val_preds))
        total_val_correct += (val_preds == labels).sum().item()
        total_val += labels.size(0)

# Average over the batches of the loader that was actually iterated
avg_val_loss = total_val_loss / len(test_loader)
val_accuracy = total_val_correct / total_val
print(f"Test Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")