# Install Libraries

# Import Libraries

In [60]:
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
from abc import ABC, abstractmethod
from transformers import BertTokenizerFast, DataCollatorWithPadding, PreTrainedModel, AdamW, get_linear_schedule_with_warmup
from datasets import load_metric
import pandas as pd
import os
from torchtext.data import get_tokenizer
from torchtext.vocab import GloVe
from torch.utils.data import DataLoader
import torch
from torch.nn.utils.rnn import pad_sequence
import random
import numpy as np



# Data Reader

In [46]:
class DataReader:
    def __init__(self, data_folder, output_file):
        self.data_folder = data_folder
        self.output_file = output_file

    def process(self):
        files = [pd.read_xml(os.path.join(self.data_folder, file)) for file in os.listdir(self.data_folder) if file.endswith(".xml")]
        data = pd.concat(files, ignore_index=True)
        data.rename(columns={'t1': 'premise', 't2': 'hypothesis'}, inplace=True)
        
        # Convert unique string labels to integers
        unique_labels = data['label'].unique()
        data['labels'] = data['label']
        label_to_int = {label: idx for idx, label in enumerate(unique_labels)}
        data['label'] = data['label'].map(label_to_int)
        
        data.dropna(inplace=True)
        data.reset_index(drop=True, inplace=True)
        data.to_csv(self.output_file, index=False)


# Data Processing

In [54]:
class DataProcessor(ABC):
    def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        self.config = config

    @abstractmethod
    def tokenize_and_cut(self, sentence):
        pass

    @abstractmethod
    def preprocess(self, premise, hypothesis, label):
        pass

    @abstractmethod
    def collate_fn(self, batch):
        pass

    def split_dataset(self, dataset, train_val_ratio=0.9):
        train_val_split_idx = int(len(dataset) * train_val_ratio)
        train_val_dataset, test_dataset = dataset[:train_val_split_idx], dataset[train_val_split_idx:]
        
        train_split_idx = int(len(train_val_dataset) * train_val_ratio)
        train_dataset, val_dataset = train_val_dataset[:train_split_idx], train_val_dataset[train_split_idx:]
        
        return train_dataset, val_dataset, test_dataset

    def get_data_loaders(self, csv_file):
        df = pd.read_csv(csv_file)
        dataset = [{"premise": row["premise"], "hypothesis": row["hypothesis"], "label": row["label"]} for _, row in df.iterrows()]
        dataset = [self.preprocess(data["premise"], data["hypothesis"], data["label"]) for data in dataset]

        
        train_dataset, val_dataset, test_dataset = self.split_dataset(dataset)

        train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=self.collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.collate_fn)
        test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.collate_fn)

        return train_loader, val_loader, test_loader


class BiLSTMDataProcessor(DataProcessor):
    def __init__(self, tokenizer, embedding, config):
        super().__init__(tokenizer, config)
        self.embedding = embedding

    def tokenize_and_cut(self, sentence):
        tokens = self.tokenizer(sentence)
        tokens = tokens[:self.config.max_length-2]
        return tokens

    def preprocess(self, premise, hypothesis, label):
        # Convert tokens to their respective indices
        premise = [self.embedding.stoi.get(token, self.embedding.stoi.get("<unk>", 0)) for token in self.tokenize_and_cut(premise)]
        hypothesis = [self.embedding.stoi.get(token, self.embedding.stoi.get("<unk>", 0)) for token in self.tokenize_and_cut(hypothesis)]
        
        return {
            "premise": torch.LongTensor(premise),
            "hypothesis": torch.LongTensor(hypothesis),
            "labels": label
        }

    def collate_fn(self, batch):
        premise = pad_sequence([data['premise'] for data in batch], batch_first=True)
        hypothesis = pad_sequence([data['hypothesis'] for data in batch], batch_first=True)
        labels = torch.tensor([data['labels'] for data in batch])
        return {"premise": premise, "hypothesis": hypothesis, "labels": labels}


class BERTDataProcessor(DataProcessor):
    def __init__(self, tokenizer, config):
        super().__init__(tokenizer, config)
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

    def tokenize_and_cut(self, premise, hypothesis):
        tokens = self.tokenizer(premise, hypothesis,
                                max_length=self.config.max_length,
                                truncation=True)
        return tokens

    def preprocess(self, premise, hypothesis, label):
        tokens = self.tokenize_and_cut(premise, hypothesis)
        tokens["labels"] = label
        return tokens

    def get_data_loaders(self, csv_file):
        df = pd.read_csv(csv_file)
        dataset = [{"premise": row["premise"], "hypothesis": row["hypothesis"], "label": row["label"]} for _, row in df.iterrows()]
        dataset = [self.preprocess(data["premise"], data["hypothesis"], data["label"]) for data in dataset]

        # split the dataset into training, validation and test sets

        train_dataset, val_dataset, test_dataset =  self.split_dataset(dataset)

        train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=self.data_collator)
        val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.data_collator)
        test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.data_collator)

        return train_loader, val_loader, test_loader


# Models

In [61]:
def set_seed(seed_value=42):
    """Set seed for reproducibility."""
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

class BiLSTM(torch.nn.Module):
    def __init__(self, vocab_size, options):
        super(BiLSTM, self).__init__()
        set_seed(options['seed_value'])
        self.embed_dim = 300
        self.hidden_size = options['d_hidden']
        self.num_classes = options['out_dim']
        self.directions = 2
        self.num_layers = 2
        self.concat = 4
        self.device = options['device']
        # Embedding layer
        self.embedding =  torch.nn.Embedding(vocab_size, self.embed_dim)
        self.projection = torch.nn.Linear(self.embed_dim, self.hidden_size)
        self.lstm = torch.nn.LSTM(self.hidden_size, self.hidden_size, self.num_layers,
                            bidirectional=True, batch_first=True, dropout=options['dp_ratio'])
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=options['dp_ratio'])

        self.lin1 = torch.nn.Linear(self.hidden_size * self.directions * self.concat, self.hidden_size)
        self.lin2 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.lin3 = torch.nn.Linear(self.hidden_size, options['out_dim'])

        for lin in [self.lin1, self.lin2, self.lin3]:
            torch.nn.init.xavier_uniform_(lin.weight)
            torch.nn.init.zeros_(lin.bias)

        self.out = torch.nn.Sequential(
            self.lin1,
            self.relu,
            self.dropout,
            self.lin2,
            self.relu,
            self.dropout,
            self.lin3
        )

        self.loss_fn = CrossEntropyLoss()

    def forward(self, premise, hypothesis, labels=None):
        premise_embed = self.embedding(premise)
        hypothesis_embed = self.embedding(hypothesis)

        premise_proj = self.relu(self.projection(premise_embed))
        hypothesis_proj = self.relu(self.projection(hypothesis_embed))

        h0 = c0 = torch.tensor([]).new_zeros((self.num_layers * self.directions, premise.size(0), self.hidden_size)).to(self.device)

        _, (premise_ht, _) = self.lstm(premise_proj, (h0, c0))
        _, (hypothesis_ht, _) = self.lstm(hypothesis_proj, (h0, c0))

        premise = premise_ht[-2:].transpose(0, 1).contiguous().view(premise.size(0), -1)
        hypothesis = hypothesis_ht[-2:].transpose(0, 1).contiguous().view(premise.size(0), -1)

        combined = torch.cat((premise, hypothesis, torch.abs(premise - hypothesis), premise * hypothesis), 1)
        logits = self.out(combined)

        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return {"loss": loss, "logits": logits}
        else:
            return {"logits": logits}

# Training

In [62]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = config.device
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)
        total_steps = len(train_loader) * config.num_epochs
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    def evaluate(self, data_loader):
        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                _, preds = torch.max(outputs["logits"], dim=1)
                self.metric.add_batch(predictions=preds, references=batch["labels"])
        return self.metric.compute(), outputs.get("loss", None)

    def train(self):
        best_val_accuracy = 0.0
        for epoch in range(self.config.num_epochs):
            print(f'Epoch {epoch+1}/{self.config.num_epochs}')

            self.metric = load_metric("accuracy")

            # Training
            self.model.train()
            for batch in tqdm(self.train_loader, desc="Training"):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                loss = outputs["loss"]
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            # Validation
            val_accuracy, val_loss = self.evaluate(self.val_loader)
            if val_loss is not None:
                print(f'Validation Loss: {val_loss}')
            print(f'Validation Accuracy: {val_accuracy["accuracy"]}')

            # Save the best model separately
            if val_accuracy["accuracy"] > best_val_accuracy:
                best_val_accuracy = val_accuracy["accuracy"]
                print(f'New best validation accuracy: {best_val_accuracy}')
                print(f'Saving model to {self.config.best_model_path}')
                if isinstance(self.model, PreTrainedModel):
                    self.model.save_pretrained(self.config.best_model_path)
                else:
                    torch.save(self.model.state_dict(), self.config.best_model_path)

# Test

# Analysis

In [63]:
# Define the configuration
class BiLSTMConfig:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_length = 128
    batch_size = 16
    learning_rate = 1e-5
    num_epochs = 10
    checkpoint_path = "checkpoint.pt"
    best_model_path = "best_model.pt"
    d_hidden = 128
    dp_ratio = 0.1
    out_dim = 2

data_reader = DataReader("data/COLIEE2021statute_data-English/train", 'data/coliee_train/coliee_2021.csv')
data_reader.process()

tokenizer = get_tokenizer('basic_english')
glove = GloVe(name='6B', dim=300)
vocab_size = len(glove.itos)

csv_file = 'data/coliee_train/coliee_2021.csv'
data_processor = BiLSTMDataProcessor(tokenizer, glove, BiLSTMConfig)
train_loader, val_loader, test_loader = data_processor.get_data_loaders(csv_file)

options = {'d_hidden': BiLSTMConfig.d_hidden, 'dp_ratio': BiLSTMConfig.dp_ratio, 'out_dim': BiLSTMConfig.out_dim, 'device': BiLSTMConfig.device, 'seed_value': 42}
model = BiLSTM(vocab_size, options).to(BiLSTMConfig.device)
trainer = Trainer(model, train_loader, val_loader, BiLSTMConfig)
trainer.train()
trainer.evaluate(test_loader)



Epoch 1/10


Training: 100%|██████████| 41/41 [00:29<00:00,  1.41it/s]


Validation Loss: 0.6939980983734131
Validation Accuracy: 0.4931506849315068
New best validation accuracy: 0.4931506849315068
Saving model to best_model.pt
Epoch 2/10


Training: 100%|██████████| 41/41 [00:26<00:00,  1.54it/s]


Validation Loss: 0.6962475776672363
Validation Accuracy: 0.4520547945205479
Epoch 3/10


Training: 100%|██████████| 41/41 [00:30<00:00,  1.35it/s]


Validation Loss: 0.6970223784446716
Validation Accuracy: 0.4520547945205479
Epoch 4/10


Training: 100%|██████████| 41/41 [00:27<00:00,  1.50it/s]


Validation Loss: 0.6967896223068237
Validation Accuracy: 0.4520547945205479
Epoch 5/10


Training: 100%|██████████| 41/41 [00:28<00:00,  1.43it/s]


Validation Loss: 0.6970521807670593
Validation Accuracy: 0.4520547945205479
Epoch 6/10


Training: 100%|██████████| 41/41 [00:34<00:00,  1.21it/s]


Validation Loss: 0.6968870759010315
Validation Accuracy: 0.4520547945205479
Epoch 7/10


Training: 100%|██████████| 41/41 [00:39<00:00,  1.04it/s]


Validation Loss: 0.6974933743476868
Validation Accuracy: 0.4520547945205479
Epoch 8/10


Training: 100%|██████████| 41/41 [00:40<00:00,  1.01it/s]


Validation Loss: 0.6977031230926514
Validation Accuracy: 0.4520547945205479
Epoch 9/10


Training: 100%|██████████| 41/41 [00:39<00:00,  1.05it/s]


Validation Loss: 0.6969847083091736
Validation Accuracy: 0.4520547945205479
Epoch 10/10


Training: 100%|██████████| 41/41 [00:40<00:00,  1.00it/s]


Validation Loss: 0.6973340511322021
Validation Accuracy: 0.4520547945205479


({'accuracy': 0.4691358024691358}, tensor(0.7201))