In [5]:
import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, AdamW, BertTokenizerFast, Trainer, TrainingArguments, DataCollatorWithPadding, AutoTokenizer, AutoModelForSequenceClassification
from abc import ABC, abstractmethod
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
import os
from datasets import Dataset

from tensorboard import notebook



  from .autonotebook import tqdm as notebook_tqdm


In [6]:
class DataReader:
    def __init__(self, data_folder, output_file):
        self.data_folder = data_folder
        self.output_file = output_file

    def process(self):
        files = [pd.read_xml(os.path.join(self.data_folder, file)) for file in os.listdir(self.data_folder) if file.endswith(".xml")]
        data = pd.concat(files, ignore_index=True)
        data.rename(columns={'t1': 'premise', 't2': 'hypothesis'}, inplace=True)

        # Convert unique string labels to integers
        unique_labels = data['label'].unique()
        data['labels'] = data['label']
        label_to_int = {label: idx for idx, label in enumerate(unique_labels)}
        data['label'] = data['label'].map(label_to_int)

        data.dropna(inplace=True)
        data.reset_index(drop=True, inplace=True)
        data.to_csv(self.output_file, index=False)

In [7]:
class CustomDataset(Dataset):
    def __init__(self, custom_data, device=None):
        self.custom_data = custom_data
        # if device:
        #     self.custom_data = [item.to(device) for item in self.custom_data]

    def __len__(self):
        return len(self.custom_data)

    def __getitem__(self, idx):
        return self.custom_data[idx]

class DataProcessor(ABC):
    def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        self.config = config

    @abstractmethod
    def tokenize_and_cut(self, sentence):
        pass

    @abstractmethod
    def preprocess(self, premise, hypothesis, label):
        pass

    def split_dataset(self, dataset, train_val_ratio=0.9):
        train_val_split_idx = int(len(dataset) * train_val_ratio)
        train_val_dataset, test_dataset = dataset[:train_val_split_idx], dataset[train_val_split_idx:]

        train_split_idx = int(len(train_val_dataset) * train_val_ratio)
        train_dataset, val_dataset = train_val_dataset[:train_split_idx], train_val_dataset[train_split_idx:]

        return train_dataset, val_dataset, test_dataset

    @abstractmethod
    def get_data_loaders(self, csv_file):
        pass

class BERTDataProcessor(DataProcessor):
    def __init__(self, tokenizer, config):
        super().__init__(tokenizer, config)
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

    def tokenize_and_cut(self, premise, hypothesis):
        tokens = self.tokenizer(premise, hypothesis,
                                max_length=self.config.max_length,
                                truncation=True)
        return tokens

    def preprocess(self, premise, hypothesis, label):
        tokens = self.tokenize_and_cut(premise, hypothesis)
        tokens["labels"] = label
        return tokens

    def get_data_loaders(self, csv_file):
        df = pd.read_csv(csv_file)
        dataset = [{"premise": row["premise"], "hypothesis": row["hypothesis"], "label": row["label"]} for _, row in df.iterrows()]
        dataset = [self.preprocess(data["premise"], data["hypothesis"], data["label"]) for data in dataset]

        # split the dataset into training, validation and test sets

        train_dataset, val_dataset, test_dataset =  self.split_dataset(dataset)

        # train_dataset = Dataset.from_dict(train_dataset)
        # val_dataset = Dataset.from_dict(val_dataset)
        # test_dataset = Dataset.from_dict(test_dataset)

        # train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=self.data_collator)
        # val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.data_collator)
        # test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False, collate_fn=self.data_collator)

        return train_dataset, val_dataset, test_dataset

In [13]:
class BertConfig:
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    max_length = 128
    batch_size = 8
    learning_rate = 1e-6
    num_epochs = 3
    num_labels = 2
    warmup_steps = 300
    weight_decay = 0.01
    log_steps = 10

In [9]:
# Load pretrained model and tokenizer
config  = BertConfig()
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels =config.num_labels)
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

In [14]:
train_folder = "data/COLIEE2021statute_data-English/train"
train_out_file = 'data/coliee_train/coliee_2021.csv'
data_reader = DataReader(train_folder, train_out_file)
data_reader.process()


data_processor = BERTDataProcessor(tokenizer, config)
train_loader, val_loader, test_loader = data_processor.get_data_loaders(train_out_file)

In [15]:
%tensorboard --logdir logs

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {'accuracy': accuracy_score(labels, predictions)}

# Training the model
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=config.num_epochs,
    per_device_train_batch_size=config.batch_size,
    per_device_eval_batch_size=config.batch_size,
    warmup_steps=config.warmup_steps,
    weight_decay=config.weight_decay,
    logging_dir='./logs',
    logging_steps=config.log_steps,  # Log every 10 steps
    # save_steps=100,  # Save the model every 100 steps
    evaluation_strategy='epoch',  # Evaluate the model every 'logging_steps'
    save_strategy='epoch',  # Save the model every epoch
    load_best_model_at_end=True,
    use_mps_device=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_loader,
    eval_dataset=val_loader,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./results")

# # You can then load the model with
# loaded_model = BertForSequenceClassification.from_pretrained("./results")






huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/2460 [00:00<?, ?it/s]You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  0%|          | 10/2460 [00:05<18:28,  2.21it/s]

{'loss': 0.65, 'learning_rate': 1.6666666666666667e-06, 'epoch': 0.12}


  1%|          | 20/2460 [00:09<17:44,  2.29it/s]

{'loss': 0.776, 'learning_rate': 3.3333333333333333e-06, 'epoch': 0.24}


  1%|          | 30/2460 [00:14<18:21,  2.21it/s]

{'loss': 0.7766, 'learning_rate': 5e-06, 'epoch': 0.37}


  2%|▏         | 40/2460 [00:18<18:14,  2.21it/s]

{'loss': 0.8391, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.49}


  2%|▏         | 50/2460 [00:23<18:17,  2.20it/s]

{'loss': 0.8093, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.61}


  2%|▏         | 60/2460 [00:27<18:25,  2.17it/s]

{'loss': 0.745, 'learning_rate': 1e-05, 'epoch': 0.73}


  3%|▎         | 70/2460 [00:32<18:22,  2.17it/s]

{'loss': 0.7942, 'learning_rate': 1.1666666666666668e-05, 'epoch': 0.85}


  3%|▎         | 80/2460 [00:37<18:04,  2.19it/s]

{'loss': 0.7258, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.98}


                                                 
  3%|▎         | 82/2460 [00:39<16:56,  2.34it/s]

{'eval_loss': 0.7958950996398926, 'eval_accuracy': 0.4657534246575342, 'eval_runtime': 1.3631, 'eval_samples_per_second': 53.553, 'eval_steps_per_second': 7.336, 'epoch': 1.0}


  4%|▎         | 90/2460 [00:44<20:45,  1.90it/s]

{'loss': 0.7281, 'learning_rate': 1.5e-05, 'epoch': 1.1}


  4%|▍         | 100/2460 [00:48<17:53,  2.20it/s]

{'loss': 0.7453, 'learning_rate': 1.6666666666666667e-05, 'epoch': 1.22}


  4%|▍         | 110/2460 [00:53<18:02,  2.17it/s]

{'loss': 0.7807, 'learning_rate': 1.8333333333333333e-05, 'epoch': 1.34}


  5%|▍         | 120/2460 [00:58<17:49,  2.19it/s]

{'loss': 0.7256, 'learning_rate': 2e-05, 'epoch': 1.46}


  5%|▌         | 130/2460 [01:02<17:49,  2.18it/s]

{'loss': 0.721, 'learning_rate': 2.1666666666666667e-05, 'epoch': 1.59}


  6%|▌         | 140/2460 [01:07<17:39,  2.19it/s]

{'loss': 0.7629, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.71}


  6%|▌         | 150/2460 [01:11<17:45,  2.17it/s]

{'loss': 0.8141, 'learning_rate': 2.5e-05, 'epoch': 1.83}


  7%|▋         | 160/2460 [01:16<17:18,  2.22it/s]

{'loss': 0.7616, 'learning_rate': 2.6666666666666667e-05, 'epoch': 1.95}


                                                  
  7%|▋         | 164/2460 [01:19<14:35,  2.62it/s]

{'eval_loss': 0.7914258241653442, 'eval_accuracy': 0.4657534246575342, 'eval_runtime': 1.2149, 'eval_samples_per_second': 60.088, 'eval_steps_per_second': 8.231, 'epoch': 2.0}


  7%|▋         | 170/2460 [01:23<22:32,  1.69it/s]

{'loss': 0.795, 'learning_rate': 2.8333333333333335e-05, 'epoch': 2.07}


  7%|▋         | 180/2460 [01:28<17:16,  2.20it/s]

{'loss': 0.7528, 'learning_rate': 3e-05, 'epoch': 2.2}


  8%|▊         | 190/2460 [01:32<16:54,  2.24it/s]

{'loss': 0.7442, 'learning_rate': 3.1666666666666666e-05, 'epoch': 2.32}


  8%|▊         | 200/2460 [01:37<17:08,  2.20it/s]

{'loss': 0.7906, 'learning_rate': 3.3333333333333335e-05, 'epoch': 2.44}


  9%|▊         | 210/2460 [01:41<17:08,  2.19it/s]

{'loss': 0.8153, 'learning_rate': 3.5e-05, 'epoch': 2.56}


  9%|▉         | 220/2460 [01:46<16:48,  2.22it/s]

{'loss': 0.7734, 'learning_rate': 3.6666666666666666e-05, 'epoch': 2.68}


  9%|▉         | 230/2460 [01:50<16:44,  2.22it/s]

{'loss': 0.7274, 'learning_rate': 3.8333333333333334e-05, 'epoch': 2.8}


 10%|▉         | 240/2460 [01:55<16:54,  2.19it/s]

{'loss': 0.7324, 'learning_rate': 4e-05, 'epoch': 2.93}


                                                  
 10%|█         | 246/2460 [01:59<14:13,  2.59it/s]

{'eval_loss': 0.7824887037277222, 'eval_accuracy': 0.4657534246575342, 'eval_runtime': 1.2227, 'eval_samples_per_second': 59.704, 'eval_steps_per_second': 8.179, 'epoch': 3.0}


 10%|█         | 250/2460 [02:02<27:09,  1.36it/s]

{'loss': 0.7287, 'learning_rate': 4.166666666666667e-05, 'epoch': 3.05}


 11%|█         | 260/2460 [02:07<17:03,  2.15it/s]

{'loss': 0.6898, 'learning_rate': 4.3333333333333334e-05, 'epoch': 3.17}


 11%|█         | 270/2460 [02:11<16:51,  2.17it/s]

{'loss': 0.714, 'learning_rate': 4.5e-05, 'epoch': 3.29}


 11%|█▏        | 280/2460 [02:16<16:48,  2.16it/s]

{'loss': 0.7688, 'learning_rate': 4.666666666666667e-05, 'epoch': 3.41}


 12%|█▏        | 290/2460 [02:20<16:13,  2.23it/s]

{'loss': 0.7822, 'learning_rate': 4.8333333333333334e-05, 'epoch': 3.54}


 12%|█▏        | 300/2460 [02:25<16:25,  2.19it/s]

{'loss': 0.7417, 'learning_rate': 5e-05, 'epoch': 3.66}


 13%|█▎        | 310/2460 [02:30<16:22,  2.19it/s]

{'loss': 0.6748, 'learning_rate': 4.976851851851852e-05, 'epoch': 3.78}


 13%|█▎        | 320/2460 [02:34<16:23,  2.18it/s]

{'loss': 0.7741, 'learning_rate': 4.9537037037037035e-05, 'epoch': 3.9}


                                                  
 13%|█▎        | 328/2460 [02:39<13:41,  2.60it/s]

{'eval_loss': 0.7676177024841309, 'eval_accuracy': 0.4794520547945205, 'eval_runtime': 1.2099, 'eval_samples_per_second': 60.334, 'eval_steps_per_second': 8.265, 'epoch': 4.0}


 13%|█▎        | 330/2460 [02:42<37:03,  1.04s/it]

{'loss': 0.7503, 'learning_rate': 4.930555555555556e-05, 'epoch': 4.02}


 14%|█▍        | 340/2460 [02:46<16:33,  2.13it/s]

{'loss': 0.7205, 'learning_rate': 4.9074074074074075e-05, 'epoch': 4.15}


 14%|█▍        | 350/2460 [02:51<15:52,  2.22it/s]

{'loss': 0.709, 'learning_rate': 4.8842592592592595e-05, 'epoch': 4.27}


 15%|█▍        | 360/2460 [02:55<15:53,  2.20it/s]

{'loss': 0.6732, 'learning_rate': 4.8611111111111115e-05, 'epoch': 4.39}


KeyboardInterrupt: 

In [8]:
# Make predictions on the test set
predictions = trainer.predict(test_loader)

# Log predictions and metrics
# print(predictions.predictions)
# print(predictions.label_ids)
print(predictions.metrics)

100%|██████████| 11/11 [00:01<00:00,  8.85it/s]

{'test_loss': 0.7022614479064941, 'test_accuracy': 0.4691358024691358, 'test_runtime': 1.9434, 'test_samples_per_second': 41.68, 'test_steps_per_second': 5.66}



