# Train and test BERT models using different parts of the case

In [2]:
from echr import *
from nb_tfidf import *
from bert import *
import os
import re
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random
import time
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from csv import DictWriter
from sklearn.model_selection import KFold

%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings("ignore")

from transformers import logging
logging.set_verbosity_error()


In [4]:
# --- Dataset locations ------------------------------------------------------
path = 'datasets/Medvedeva/'              # directory of the Medvedeva dataset
json_path = 'datasets/echrod/cases.json'  # ECHR-OD cases file

# --- Output location --------------------------------------------------------
result_dir = 'results/parameter_optimization/'

# --- Experiment grid --------------------------------------------------------
# Article identifiers to iterate over; 'All' is passed through as-is.
articles = [
    '2', '3', '5', '6', '8',
    '10', '11', '13', '14',
    'All',
]
use_parts = 'facts'

# --- Run settings -----------------------------------------------------------
# NOTE(review): none of the settings below are read by the cells visible in
# this notebook export; presumably they are consumed by imported helpers or
# kept for parity with sibling notebooks -- confirm before relying on them.
num_runs = 5
cv = 10
n_jobs = -1
debug = False

In [5]:
# Pick the compute device once: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    # Report how many GPUs are visible and the name of device index 0.
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')

There are 1 GPU(s) available.
Device name: NVIDIA GeForce RTX 3090


In [6]:
def run_experiment(part, article):
    """Cross-validate a BERT classifier for one (case part, article) pair.

    Builds a balanced dataset for ``article`` restricted to ``part``, runs
    stratified 10-fold cross-validation (``initialize_model`` is called once
    per fold), averages accuracy / MCC / F1 over the folds and appends one
    summary row to ``results/BERT/parts/parts.csv``.

    The function reads ``json_path``, ``device``, ``batch_size`` and ``epochs``
    from the surrounding notebook scope and uses the helpers
    ``create_dataset``, ``balance_dataset``, ``preprocessing_for_bert``,
    ``initialize_model``, ``bert_predict``, ``return_metrics``, ``set_seed``
    and ``evaluate``.
    NOTE(review): ``batch_size`` and ``epochs`` are not assigned in the cells
    visible here -- presumably they come from one of the star imports; confirm.

    Args:
        part: Which part(s) of the case text to use; forwarded unchanged to
            ``create_dataset`` (the driver cell passes e.g. ``'facts'``).
        article: Article identifier forwarded unchanged to ``create_dataset``
            (an entry of ``articles``, e.g. ``'3'`` or ``'All'``).

    Returns:
        int: Always ``0``.
    """

    def train(model, train_dataloader, optimizer, scheduler, val_dataloader=None,
              epochs=4, evaluation=False, debug=0):
        """Train ``model`` for ``epochs`` passes over ``train_dataloader``.

        Args:
            model: Classifier called as ``model(input_ids, attention_mask)``
                and returning logits.
            train_dataloader: Yields ``(input_ids, attention_mask, labels)``
                batches.
            optimizer: Optimizer stepped once per batch.
            scheduler: Learning-rate scheduler stepped once per batch.
            val_dataloader: Validation batches; only used when ``evaluation``
                is truthy.
            epochs: Number of training epochs.
            evaluation: If truthy, call ``evaluate`` after every epoch.
            debug: If truthy, print a progress table while training.
        """
        # torch is imported at file level, so reference the loss through it
        # rather than depending on a bare ``nn`` leaking in from a star import.
        loss_fn = torch.nn.CrossEntropyLoss()
        last_step = len(train_dataloader) - 1

        if debug:
            print("Start training...\n")
        for epoch_i in range(epochs):
            # =======================================
            #               Training
            # =======================================
            if debug:
                print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
                print("-"*70)

            # Timers for the whole epoch and for the current reporting window.
            t0_epoch, t0_batch = time.time(), time.time()

            # total_loss covers the epoch; batch_loss / batch_counts cover the
            # current reporting window and are reset after every report.
            total_loss, batch_loss, batch_counts = 0, 0, 0

            model.train()

            for step, batch in enumerate(train_dataloader):
                batch_counts += 1
                # Move the batch tensors to the selected device.
                b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

                # Clear gradients left over from the previous step.
                model.zero_grad()

                # Forward pass -> logits, then loss.
                logits = model(b_input_ids, b_attn_mask)
                loss = loss_fn(logits, b_labels)
                batch_loss += loss.item()
                total_loss += loss.item()

                # Backward pass, with gradient-norm clipping at 1.0 to guard
                # against exploding gradients.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                # Update the parameters and the learning rate.
                optimizer.step()
                scheduler.step()

                # Report every 20 batches and on the final batch of the epoch.
                if (step % 20 == 0 and step != 0) or step == last_step:
                    time_elapsed = time.time() - t0_batch

                    if debug:
                        print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                    # Start a new reporting window.
                    batch_loss, batch_counts = 0, 0
                    t0_batch = time.time()

            # Mean training loss over all batches of this epoch.
            avg_train_loss = total_loss / len(train_dataloader)

            if debug:
                print("-"*70)
            # =======================================
            #               Evaluation
            # =======================================
            if evaluation:
                # Measure performance on the validation set after each epoch.
                val_loss, val_accuracy = evaluate(model, val_dataloader, device)
                time_elapsed = time.time() - t0_epoch

                if debug:
                    print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
                    print("-"*70)
            if debug:
                print("\n")
        if debug:
            print("Training complete!")

    n_splits = 10      # number of cross-validation folds
    threshold = 0.5    # minimum class-1 probability to predict a violation

    # Create the output directory up front so a missing folder fails (or is
    # fixed) immediately rather than after all folds have been trained.
    filename = 'results/BERT/parts/parts.csv'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Prepare the data (ECHR-OD json). To run on the Medvedeva dataset instead,
    # pass `path` in place of `json_path`.
    train_df = create_dataset(json_path, article, part)
    train_df = balance_dataset(train_df)
    X = train_df['text'].to_numpy()
    y = train_df['violation'].to_numpy()

    print('Created data')

    accs, mccs, f1s = [], [], []

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=1)

    # The tokenizer does not depend on the fold, so load it once.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

    print(f'Running {n_splits} folds')
    for train_index, test_index in tqdm(skf.split(X, y), total=n_splits):

        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        train_inputs, train_masks = preprocessing_for_bert(X_train, tokenizer)

        # Re-seed before every fold so each fold starts from the same RNG state.
        set_seed(42)

        # Train
        train_labels = torch.tensor(y_train)
        train_data = TensorDataset(train_inputs, train_masks, train_labels)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
        bert_classifier, optimizer, scheduler = initialize_model(device, train_dataloader, epochs=epochs)
        train(bert_classifier, train_dataloader, optimizer, scheduler, epochs=epochs)

        # Test
        test_inputs, test_masks = preprocessing_for_bert(X_test, tokenizer)
        test_dataset = TensorDataset(test_inputs, test_masks)
        test_sampler = SequentialSampler(test_dataset)
        test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=batch_size)
        probs = bert_predict(bert_classifier, test_dataloader)

        # Turn the class-1 probabilities into hard 0/1 predictions.
        preds = np.where(probs[:, 1] >= threshold, 1, 0)
        acc, mcc, f1 = return_metrics(preds, y_test, show=False)
        accs.append(acc)
        mccs.append(mcc)
        f1s.append(f1)

    field_names = ['article', 'accuracy', 'MCC', 'F1', 'part', 'batch_size', 'epochs',
                   'training_size', 'train_distribution']
    dct = {
        'article': article,
        # Metrics are averaged over the folds.
        'accuracy': np.mean(accs),
        'MCC': np.mean(mccs),
        'F1': np.mean(f1s),
        'part': part,
        'batch_size': batch_size,
        'epochs': epochs,
        # Size and positive-class percentage of the full balanced dataset
        # (before it is split into folds).
        'training_size': len(train_df),
        'train_distribution': round(train_df['violation'].mean()*100, 2),
    }

    # Write the header when the file is new or exists but is still empty.
    needs_header = not os.path.isfile(filename) or os.path.getsize(filename) == 0
    # newline='' is what the csv module expects; without it extra blank rows
    # appear on platforms that translate '\n'.
    with open(filename, 'a', newline='') as f_object:
        dictwriter_object = DictWriter(f_object, fieldnames=field_names)
        if needs_header:
            dictwriter_object.writeheader()
        dictwriter_object.writerow(dct)
    return 0

In [7]:
# Run all combinations of articles and parts
# Sweep every case-part variant against every configured article; each call
# appends one summary row to the results CSV.
part_variants = ('procedure+facts', 'procedure', 'facts')
for part in part_variants:
    print(part)
    for article in articles:
        print('\t', article)
        run_experiment(part, article)

procedure+facts
	 All
Created data
Running 10 folds


10it [5:24:47, 1948.74s/it]


procedure
	 All
Created data
Running 10 folds


10it [4:50:49, 1744.98s/it]


facts
	 All
Created data
Running 10 folds


10it [5:20:29, 1922.91s/it]
