In [None]:
import pandas as pd
import numpy as np
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from datasets import load_dataset
import random
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Lambda, Dense, Dropout, Concatenate, Embedding
from tensorflow.keras.utils import Sequence
import time
import string
import json
import re
import nltk
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW
from tqdm import tqdm
import random
import torch.nn.functional as F


In [None]:
# Download a file from Google Drive by its file id.
# NOTE(review): presumably this is the `checkpoint.pt` later loaded from
# /kaggle/working (gdown saves into the current working directory) — confirm.
!gdown 1CUu7zxsFFd9CiWArZeTx7Xm2Ul6JVPwc

In [None]:
class SNLIDataset(Dataset):
    """Map-style dataset of (premise, hypothesis, label) triples for a Siamese encoder.

    Premise and hypothesis are tokenized independently (not as one sentence
    pair), each padded/truncated to ``max_length`` tokens, so each side can be
    fed through the shared encoder on its own.

    Args:
        premises: indexable sequence of premise strings.
        hypotheses: indexable sequence of hypothesis strings, aligned with ``premises``.
        labels: indexable sequence of integer class ids, aligned with ``premises``.
        tokenizer: callable with the Hugging Face tokenizer call signature
            (``padding``, ``truncation``, ``max_length``, ``return_tensors``).
        max_length: fixed token length used for both sides.
    """

    def __init__(self, premises, hypotheses, labels, tokenizer, max_length=128):
        self.premises, self.hypotheses, self.labels = premises, hypotheses, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Dataset size is taken from the label list.
        return len(self.labels)

    def _encode(self, text):
        """Tokenize one sentence to a fixed length and drop the leading batch dim."""
        encoded = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        ids_a, mask_a = self._encode(self.premises[idx])
        ids_b, mask_b = self._encode(self.hypotheses[idx])
        return {
            'input_ids_a': ids_a,
            'attention_mask_a': mask_a,
            'input_ids_b': ids_b,
            'attention_mask_b': mask_b,
            'label': torch.tensor(self.labels[idx], dtype=torch.long),
        }

# 2. Siamese BERT with MLP
class SiameseBertClassifier(nn.Module):
    """Siamese BERT encoder with an MLP head for NLI-style pair classification.

    Both sentences go through the *same* BERT instance (shared weights). The
    position-0 vectors ``u`` and ``v`` are combined as ``[u, v, |u - v|, u * v]``
    and classified by a small MLP.

    Args:
        pretrained_model: name/path passed to ``BertModel.from_pretrained``.
        num_labels: number of output classes.
        num_trainable_layers: how many of the *last* encoder layers stay
            trainable; every other BERT parameter is frozen. The default of 1
            fine-tunes only the final encoder layer. 0 freezes all of BERT.
    """

    def __init__(self, pretrained_model='bert-base-uncased', num_labels=3, num_trainable_layers=1):
        super(SiameseBertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)

        # Freeze everything, then unfreeze the top `num_trainable_layers`
        # encoder layers. Counting from the end (rather than the hard-coded
        # index 11) stays correct for encoders that do not have 12 layers.
        # max(0, ...) avoids the `[-0:]` pitfall, which would select every layer.
        for param in self.bert.parameters():
            param.requires_grad = False
        encoder_layers = self.bert.encoder.layer
        first_trainable = max(0, len(encoder_layers) - num_trainable_layers)
        for layer in encoder_layers[first_trainable:]:
            for param in layer.parameters():
                param.requires_grad = True

        hidden_size = self.bert.config.hidden_size  # 768 for bert-base
        combined_dim = hidden_size * 4  # u, v, |u-v|, u*v

        self.classifier = nn.Sequential(
            nn.Linear(combined_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_labels),
        )

    def forward(self, input_ids_a, attention_mask_a, input_ids_b, attention_mask_b):
        """Return ``(batch, num_labels)`` logits for a batch of sentence pairs.

        Each side is encoded separately by the shared BERT and represented by
        its hidden state at position 0 (the [CLS] token for BERT tokenizers).
        """
        u = self.bert(input_ids=input_ids_a, attention_mask=attention_mask_a).last_hidden_state[:, 0, :]
        v = self.bert(input_ids=input_ids_b, attention_mask=attention_mask_b).last_hidden_state[:, 0, :]

        combined = torch.cat([u, v, torch.abs(u - v), u * v], dim=1)
        return self.classifier(combined)

# 3. Evaluation function
def evaluate(model, dataloader, device):
    """Compute and print the classification accuracy of `model` over `dataloader`.

    Switches the model to eval mode (disabling dropout) and runs without
    gradient tracking. The model is left in eval mode, so callers that keep
    training must call ``model.train()`` again.

    Args:
        model: module called as ``model(input_ids_a, attention_mask_a,
            input_ids_b, attention_mask_b)`` that returns ``(batch, num_labels)`` logits.
        dataloader: iterable of dict batches with the keys produced by ``SNLIDataset``.
        device: device the batch tensors are moved to.

    Returns:
        Accuracy as a float in [0, 1], or ``float('nan')`` when the dataloader
        yields no examples (instead of raising ``ZeroDivisionError``).
    """
    input_keys = ('input_ids_a', 'attention_mask_a', 'input_ids_b', 'attention_mask_b')
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = [batch[key].to(device) for key in input_keys]
            labels = batch['label'].to(device)

            preds = torch.argmax(model(*inputs), dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # An empty loader (e.g. a snowball round that selected no rows) has no
    # defined accuracy. Report NaN rather than crashing the whole run.
    acc = correct / total if total else float('nan')
    print(f"Accuracy: {acc:.4f}")
    return acc

# 4. Training function
def train(model, train_loader, optimizer, device, epochs=3, test_loader=None):
    """Fine-tune `model` with cross-entropy loss, reporting accuracy after every epoch.

    Args:
        model: Siamese classifier called as ``model(ids_a, mask_a, ids_b, mask_b)``.
        train_loader: iterable of dict batches as produced by ``SNLIDataset``.
        optimizer: optimizer over the model's trainable parameters.
        device: device the batch tensors are moved to.
        epochs: number of passes over ``train_loader``.
        test_loader: optional held-out loader. When given, its accuracy is also
            reported after every epoch.

    Returns:
        The same ``model`` object, trained in place.
    """
    loss_fn = nn.CrossEntropyLoss()
    input_keys = ('input_ids_a', 'attention_mask_a', 'input_ids_b', 'attention_mask_b')
    for epoch in range(epochs):
        # `evaluate` leaves the model in eval mode, so re-enable training mode
        # (dropout) at the start of every epoch.
        model.train()
        print(f"Epoch {epoch + 1}")
        loop = tqdm(train_loader, leave=True)
        for batch in loop:
            inputs = [batch[key].to(device) for key in input_keys]
            labels = batch['label'].to(device)

            optimizer.zero_grad()
            logits = model(*inputs)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

        # This accuracy is measured on the training data itself. The previous
        # "Test" label was misleading.
        print("Train")
        evaluate(model, train_loader, device)
        if test_loader is not None:
            print("Test")
            evaluate(model, test_loader, device)
    return model


In [None]:
# Shared tokenizer and compute device used by every later cell.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and restore previously fine-tuned weights onto `device`.
model = SiameseBertClassifier(num_labels=3).to(device)
checkpoint = torch.load("/kaggle/working/checkpoint.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Re-apply freezing: all of BERT is frozen except encoder layers 10 and 11.
model.bert.requires_grad_(False)
for encoder_layer in model.bert.encoder.layer[10:12]:
    encoder_layer.requires_grad_(True)

# NOTE(review): the snowball loop creates a fresh AdamW at the start of every
# round, so this lr=1e-6 optimizer is replaced before it is ever stepped when
# the cells run in order.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-6)

In [None]:
def get_train_data(round_number, random_state=None):
    """Build a class-balanced training set for one snowball round.

    Keeps only rows where the LLM label (``gpt_label``) and the model's
    ``predicted_label`` agree:

    * LLM and model both say ``"contradiction"``: label 2;
    * LLM says ``"not contradiction"`` and the model predicted anything other
      than ``"contradiction"``: label 1.

    The larger group is randomly down-sampled to the size of the smaller one.

    Args:
        round_number: 0 reads the initial LLM-labelled CSV. ``k > 0`` reads
            ``/kaggle/working/snowball_k.csv``.
        random_state: forwarded to ``DataFrame.sample`` so the down-sampling can
            be made reproducible. ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``(sentences_i, sentences_j, labels)`` as three aligned lists, with the
        contradiction rows first.
    """
    if round_number == 0:
        path = "/kaggle/input/snowball0/Labeled_data_by_LLM.csv"
    else:
        path = f"/kaggle/working/snowball_{round_number}.csv"
    data = pd.read_csv(path)

    both_contra = (data["gpt_label"] == "contradiction") & (data["predicted_label"] == "contradiction")
    both_not_contra = (data["gpt_label"] == "not contradiction") & (data["predicted_label"] != "contradiction")
    contradict_data = data[both_contra]
    not_contradict_data = data[both_not_contra]

    selected_length = min(len(contradict_data), len(not_contradict_data))
    contradict_data = contradict_data.sample(n=selected_length, random_state=random_state)
    not_contradict_data = not_contradict_data.sample(n=selected_length, random_state=random_state)
    print(f"number of both contra {selected_length}")
    print(f"number of both no contra {selected_length}")

    all_i = contradict_data["sentence_i"].tolist() + not_contradict_data["sentence_i"].tolist()
    all_j = contradict_data["sentence_j"].tolist() + not_contradict_data["sentence_j"].tolist()
    labels = [2] * selected_length + [1] * selected_length
    return all_i, all_j, labels

def predict(round_number, batch_size=64):
    """Label every pair in the original LLM-labelled CSV with the current model.

    Reads ``/kaggle/input/snowball0/Labeled_data_by_LLM.csv``, predicts a label
    for each ``(sentence_i, sentence_j)`` pair, stores it in a
    ``predicted_label`` column, and writes the frame to
    ``/kaggle/working/snowball_{round_number + 1}.csv``.

    Relies on the notebook-level ``model``, ``tokenizer`` and ``device``.

    Args:
        round_number: index of the round that just finished training.
        batch_size: number of pairs encoded per forward pass.
    """
    data = pd.read_csv("/kaggle/input/snowball0/Labeled_data_by_LLM.csv")
    all_i = data["sentence_i"].to_list()
    all_j = data["sentence_j"].to_list()
    to_label = {0: "entailment", 1: "neutral", 2: "contradiction"}

    # Disable dropout explicitly instead of relying on whatever mode the
    # previous training/evaluation call happened to leave the model in.
    model.eval()
    pred_ids = []
    with torch.no_grad():
        for start in tqdm(range(0, len(all_i), batch_size)):
            tokens_i = tokenizer(all_i[start:start + batch_size], return_tensors="pt",
                                 truncation=True, padding=True, max_length=128).to(device)
            tokens_j = tokenizer(all_j[start:start + batch_size], return_tensors="pt",
                                 truncation=True, padding=True, max_length=128).to(device)
            # Calling the model's forward keeps the [u, v, |u-v|, u*v] features
            # identical to training. Classifying per batch under no_grad avoids
            # holding every embedding (plus an autograd graph for the whole
            # dataset) on the device at once.
            logits = model(tokens_i["input_ids"], tokens_i["attention_mask"],
                           tokens_j["input_ids"], tokens_j["attention_mask"])
            # argmax over logits equals argmax over softmax(logits).
            pred_ids.extend(logits.argmax(dim=1).cpu().tolist())

    data["predicted_label"] = [to_label[p] for p in pred_ids]
    data.to_csv(f"/kaggle/working/snowball_{round_number+1}.csv", index=False)


# Snowball self-training: each round trains on pairs where the LLM label and
# the model's latest prediction agree. It then re-labels the full file with
# the updated model for the next round.
NUM_ROUNDS = 10
EPOCHS_PER_ROUND = 2
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
SEED = 42

# Seed every RNG the loop touches:
# - numpy: DataFrame.sample without random_state draws from it.
# - torch: DataLoader shuffling and dropout.
# - Python's `random`.
# Reruns then select and order the same examples, up to nondeterministic GPU kernels.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

for round_number in range(NUM_ROUNDS):
    print(f"in round {round_number} snowball")
    premises, hypotheses, labels = get_train_data(round_number)

    # Label ids: entailment=0, neutral=1, contradiction=2.
    train_dataset = SNLIDataset(premises, hypotheses, labels, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    # A fresh optimizer per round, so AdamW moment estimates reset between rounds.
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)
    model = train(model, train_loader, optimizer, device, epochs=EPOCHS_PER_ROUND)
    predict(round_number)

In [None]:
# Persist the model weights and the optimizer state from the last snowball round.
# NOTE(review): the file name says "round2" although the snowball loop runs 10
# rounds. Confirm the intended name before other code loads it.
finetuned_checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(finetuned_checkpoint, "checkpoint_finetuned_round2.pt")

