In [1]:
from json import load
import numpy as np
from neuralNetwork import NeuralNetwork, In_between_epochs
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from random import shuffle, randint
from transformers import RobertaModel, RobertaTokenizer, AutoTokenizer, AutoModel
from helper import dict_lists_to_list_of_dicts, get_train_valdiation_test_split, Dataset
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

In [2]:
f = open("ECHR_Corpus.json",encoding= 'utf-8')
dataset = load(f)
f.close()

# Dataset
ogni caso nel dataset è organizzato così:
```json
{
    "text":"il testo completo della sentenza"
    "clauses":[
        {
            "_id": "id della clause",
            "start": "index di inizio della clause",
            "end": "index di fine della cluause"
        },
        {
            "...":"..."
        }
    ]
    "argument":[
        {
            "premises":[
                "id premise1", "id premise2", "..."
            ],
            "conclusion": "id conclusion"
        },
        {
            "...":"..."
        }
    ]
}
```
In questo passaggio riorganizzo i dati per avere comunque la stessa struttura ma trasformo le clauses in un dizionario che ha come id la chiave della clause e come valore il testo (senza quindi dover usare start e end per cercarlo nel testo). Non tutte le clause sono parte di una premessa o di una conclusione.

In [3]:
refactored_dataset = []
for datapoint in dataset:
    text = datapoint["text"]
    dict_clauses = {}
    for clause in datapoint["clauses"]:
        start = clause["start"]
        end = clause["end"]
        id = clause["_id"]
        dict_clauses[id] = text[start:end]
    refactored_dataset.append({
        "text": text,
        "arguments": datapoint["arguments"],
        "n_clauses": len(datapoint["clauses"]),
        "all_clauses": dict_clauses
    })

Secondo me ha poco senso salvare il dataset come dataframe dato che è praticamente solo testo. Comunque non dovrebbe essere difficilissimo tirarci fuori qualche statistica. Ho fatto degli esempi scemi qua:

In [4]:
n_arguments = []
n_premises = []
n_clauses = []
for case in refactored_dataset:
    n_arguments.append(len(case["arguments"]))
    n_clauses.append(case["n_clauses"])
    for argument in case["arguments"]:
        n_premises.append(len(argument["premises"]))
print(f"""
On average, a case has {np.mean(n_clauses):.2f} clauses with a median of {np.median(n_clauses):.0f} clauses per case.
On average, a case has {np.mean(n_arguments):.2f} arguments with a median of {np.median(n_arguments):.0f} arguments per case.
Each argument, on average, has: {np.mean(n_premises):.2f} premises with a median of {np.median(n_premises):.0f} premises per argument.
""")


On average, a case has 248.95 clauses with a median of 226 clauses per case.
On average, a case has 17.69 arguments with a median of 14 arguments per case.
Each argument, on average, has: 2.63 premises with a median of 2 premises per argument.



Qua puoi vedere come tirare fuori il testo di una conclusion o di una premise: ti basta usare la conclusion/premise come indice nel dizionario delle clause

In [5]:
premise = "\n\t".join([refactored_dataset[0]["all_clauses"][premise] for premise in refactored_dataset[0]["arguments"][0]["premises"]])
conclusion = refactored_dataset[0]["all_clauses"][refactored_dataset[0]["arguments"][0]["conclusion"]]
print(f"""
Here an example of an argument:
    - Premises:
        {premise}
    - Conclusion:
        {conclusion}
""")


Here an example of an argument:
    - Premises:
        The Commission notes that the applicant was detained after having been sentenced by the first instance court to 18 months' imprisonment.
	He was released after the Court of Appeal reviewed this sentence, reducing it to 15 months' imprisonment, convertible to a fine.
    - Conclusion:
        The Commission finds that the applicant was deprived of his liberty "after conviction by a competent court" within the meaning of Article 5 para. 1 (a) (Art. 5-1-a) of the Convention.



## Prepare for a task

### 6.1 Argument Clause Recognition

In [6]:
cases = get_train_valdiation_test_split(refactored_dataset, [4])#, seed = 42) #42, 4, 2

In [7]:
for i in range(len(cases)):
    train_cases, validation_cases, test_cases = cases[i]
    print(f"""there are {len(train_cases)} cases in the training set (fold {i}), 
    {len(validation_cases)} cases in the validation set (fold {i}) and
    {len(test_cases)} cases in the test set (fold {i})
    """)

there are 28 cases in the training set (fold 0), 
    6 cases in the validation set (fold 0) and
    8 cases in the test set (fold 0)
    
there are 28 cases in the training set (fold 1), 
    6 cases in the validation set (fold 1) and
    8 cases in the test set (fold 1)
    
there are 28 cases in the training set (fold 2), 
    6 cases in the validation set (fold 2) and
    8 cases in the test set (fold 2)
    
there are 28 cases in the training set (fold 3), 
    6 cases in the validation set (fold 3) and
    8 cases in the test set (fold 3)
    
there are 28 cases in the training set (fold 4), 
    6 cases in the validation set (fold 4) and
    8 cases in the test set (fold 4)
    


In [8]:
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

dato che non tutte le clause sono parte di una premise/conclusion, le dividiamo per cercare di tirare fuori un classificatore di clauses

In [9]:
def get_acr_dataloader(cases, tokenizer, batch_size=16, shuffle=False, verbose=True):
    ACR_x = []
    ACR_y = []
    for case in cases:
        n_clauses = case["n_clauses"]
        clauses = case["all_clauses"]
        args_set = set()
        splitter = "AS TO THE LAW" if "AS TO THE LAW" in case["text"] else "THE LAW"
        law_section = case["text"].split(splitter)[1]
        for argument in case["arguments"]:
            if not argument["conclusion"] in args_set and clauses[argument["conclusion"]] in law_section:
                ACR_x.append(clauses[argument["conclusion"]])
                ACR_y.append(torch.tensor([1.,0.]))
                args_set.add(argument["conclusion"])
            for premise in argument["premises"]:
                if not premise in args_set and clauses[premise] in law_section:
                    ACR_x.append(clauses[premise])
                    ACR_y.append(torch.tensor([1.,0.]))
                    args_set.add(premise)
        for clause_id in clauses.keys():
            if not clause_id in args_set and clauses[clause_id] in law_section:
                ACR_x.append(clauses[clause_id])
                ACR_y.append(torch.tensor([0.,1.]))
    if verbose:
        print(f"""
    There are:
        - {len(ACR_x)} clause
        - {len([a for a in ACR_y if a[0] == 1])} true clause
        - {len([a for a in ACR_y if a[0] == 0])} fake clause
    """)
    ACR_x_tokenized = dict_lists_to_list_of_dicts(tokenizer(ACR_x, padding=True, truncation=True, return_tensors='pt'))
    dataset = Dataset(ACR_x_tokenized, ACR_y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
train_cases, validation_cases, test_cases = cases[-1]
print("train set:")
train_dataloader = get_acr_dataloader(train_cases, tokenizer, shuffle=True, batch_size=8)
print("validation set:")
validation_dataloder = get_acr_dataloader(validation_cases, tokenizer, batch_size=8)
print("test set:")
test_dataloader = get_acr_dataloader(test_cases, tokenizer, batch_size=8)

train set:

    There are:
        - 3355 clause
        - 1846 true clause
        - 1509 fake clause
    
validation set:

    There are:
        - 288 clause
        - 223 true clause
        - 65 fake clause
    
test set:

    There are:
        - 1046 clause
        - 624 true clause
        - 422 fake clause
    


# Pre-train

# Train

Da qui in avanti ti dovrebbe essere tutto abbastanza familiare dato che è praticamente lo stesso codice dell'assignment 2. Qua ho fatto solo 5 epoche e con un learning rate a caso ma è giusto per far vedere come funziona. Da ora si può iniziare a giocare...

In [10]:
class Model(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = 0.2) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        encoded_input, _ = self.encoder(**input, return_dict = False)
        encoded_input = self.dropout(encoded_input)
        encoded_input = torch.nn.functional.avg_pool1d(
            encoded_input.permute(0, 2, 1), 
            kernel_size=encoded_input.size(1)
        ).squeeze(2)
        if len(encoded_input.size()) != 2:
            print(encoded_input.size())
        return self.output_layer(encoded_input).float()

    def freeze_bert(self):
        for param in self.encoder.parameters():
                param.requires_grad = False
    def unfreeze_bert(self):
        for param in self.encoder.parameters():
                param.requires_grad = True
    
model = Model(2)

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("operating on device:", device)

operating on device: cuda:0


In [17]:
class EarlyStopping(In_between_epochs):
    def __init__(self, delta, patience):
        self.delta = delta
        self.patience = patience
        self.current_patience = 0
        self.best_valid_loss = 100000000000000000000
        self.best_model = Model(2)
        self.epochs = 0
    
    def __call__(self, model:torch.nn.Module, loaders:dict[str,torch.utils.data.DataLoader], device:'torch.device|str', output_extraction_function, losses:dict[str, float]) -> bool:
        self.epochs += 1
        if losses["validation"] < self.best_valid_loss - self.delta:
            self.best_valid_loss = losses["validation"]
            self.best_model.load_state_dict(model.state_dict())
            self.current_patience = 0
        else:
            self.current_patience += 1
            if self.current_patience >= self.patience:
                return True
        return False
    def reset(self):
        self.current_patience = 0
        self.epochs = 0

In [22]:
def train(model, train, validation, min_lr, start_lr, early_stopping, frac):
    tot_train_data, tot_val_data = [], []
    lr = start_lr
    while lr > min_lr:
        train_data, validation_data = model.train_network(train, 
                        validation, 
                        torch.optim.Adam, 
                        loss_function=nn.CrossEntropyLoss(),
                        device=device, 
                        batch_size=32,
                        verbose=True, 
                        output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                        metrics={
                        "accuracy": accuracy_score, 
                        "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                        in_between_epochs = {"early_stopping": early_stopping},
                        learning_rate=lr,
                        epochs=2)
        train_data["epochs"] = early_stopping.epochs
        train_data["lr"] = lr
        validation_data["epochs"] = early_stopping.epochs
        validation_data["lr"] = lr
        tot_train_data.append(train_data)
        tot_val_data.append(validation_data)
        model.load_state_dict(early_stopping.best_model.state_dict())
        lr = lr * frac
        early_stopping.reset()

In [23]:
def greedy_train(lrs:list[float], train, validation):
    trainings = []
    models = []
    for lr in lrs:
        model = Model(2)
        train_data, validation_data = model.train_network(train, 
                        validation, 
                        torch.optim.Adam, 
                        loss_function=nn.CrossEntropyLoss(),
                        device=device, 
                        batch_size=32,
                        verbose=True, 
                        output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                        metrics={
                        "accuracy": accuracy_score, 
                        "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                        learning_rate=lr,
                        epochs=15)
        for key in train_data:
                train_data[key] = [float(v) for v in train_data[key]]
                validation_data[key] = [float(v) for v in validation_data[key]]
        trainings.append({"train": train_data, "validation":validation_data, "lr":lr})
        models.append(model)
    return trainings, models

In [24]:
def fold_training(cases):
    best_models = []
    for cases_split in cases:
        train_cases, validation_cases, _ = cases_split
        train_dataloader = get_acr_dataloader(train_cases, tokenizer, shuffle=True, batch_size=8, verbose=False)
        validation_dataloder = get_acr_dataloader(validation_cases, tokenizer, batch_size=8, verbose=False)
        model = Model(2)
        early_stopping = EarlyStopping(.001, 3)
        train(model, train_dataloader, validation_dataloder, 5e-7, 1e-6, early_stopping, .6)
        model.load_state_dict(early_stopping.best_model.state_dict())
        best_models.append({"loss": early_stopping.best_valid_loss, "model": model})
    
    return best_models

In [25]:
models = fold_training(cases)

EPOCH 1 training loss:      0.589 - validation loss:      0.656                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.701 - validation accuracy:      0.651
EPOCH 1 training f1_score:      0.618 - validation f1_score:      0.489
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.499 - validation loss:      0.700                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.761 - validation accuracy:      0.660
EPOCH 2 training f1_score:      0.716 - validation f1_score:  

In [28]:
best_model_fold = min(models, key= lambda x: x["loss"]) 
print(f"best model loss: {best_model_fold['loss']}")

best model loss: 0.47013330459594727


In [32]:
best_model_fold['model']

Model(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [34]:
torch.save(best_model_fold['model'].state_dict(), 'best model')

In [36]:
best_model = Model(2)
best_model.load_state_dict(best_model_fold["model"].state_dict())
best_model = best_model.to(device)
best_model.eval()

Model(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

In [63]:
def full_test(model, cases):
    y_true = []
    y_pred = []
    with torch.no_grad():
        for case in cases:
            clauses = case["all_clauses"]
            args_set = set()
            splitter = "AS TO THE LAW" if "AS TO THE LAW" in case["text"] else "THE LAW"
            law_section = case["text"].split(splitter)[1]
            for argument in case["arguments"]:
                if (argument["conclusion"] not in args_set) and (clauses[argument["conclusion"]] in law_section):
                    tokenized_x = tokenizer(clauses[argument["conclusion"]],padding = True, truncation=True, return_tensors='pt')
                    tokenized_x = {key: tokenized_x[key].to(device) for key in tokenized_x.keys()}
                    pred = model(tokenized_x)
                    y_pred.append(torch.max(pred, -1)[1].view(-1).cpu().tolist()[0])
                    y_true.append(0)
                    keys = list(tokenized_x.keys())
                    for k in keys:
                        del tokenized_x[k]
                    del tokenized_x
                    args_set.add(argument["conclusion"])
                elif argument["conclusion"] not in args_set and clauses[argument["conclusion"]] not in law_section:
                    args_set.add(argument["conclusion"])
                    y_pred.append(1)
                    y_true.append(0)
                for premise in argument["premises"]:
                    if premise not in args_set and clauses[premise] in law_section:
                        tokenized_x = tokenizer(clauses[premise], padding = True, truncation=True, return_tensors='pt')
                        tokenized_x = {key: tokenized_x[key].to(device) for key in tokenized_x.keys()}
                        pred = model(tokenized_x)
                        y_pred.append(torch.max(pred, -1)[1].view(-1).cpu().tolist()[0])
                        y_true.append(0)
                        keys = list(tokenized_x.keys())
                        for k in keys:
                            del tokenized_x[k]
                        del tokenized_x
                        args_set.add(premise)
                        
                    elif premise not in args_set and clauses[premise] not in law_section:
                        args_set.add(premise)
                        y_pred.append(1)
                        y_true.append(0)
            for clause_id in clauses.keys():
                if (not clause_id in args_set) and (clauses[clause_id] in law_section):
                    tokenized_x = tokenizer(clauses[clause_id], padding = True, truncation=True, return_tensors='pt')
                    tokenized_x = {key: tokenized_x[key].to(device) for key in tokenized_x.keys()}
                    pred = model(tokenized_x)
                    y_pred.append(torch.max(pred, -1)[1].view(-1).cpu().tolist()[0])
                    y_true.append(1)
                    keys = list(tokenized_x.keys())
                    for k in keys:
                        del tokenized_x[k]
                    del tokenized_x
                    args_set.add(clause_id)
                    
                #else:
                if (not clause_id in args_set) and (clauses[clause_id] not in law_section):
                    y_pred.append(1)
                    y_true.append(1)
    print(y_true)
    print(y_pred)
    print(len(y_true))
    print(f"""accuracy: {accuracy_score(y_true, y_pred):.2f}
    f1: {f1_score(y_true, y_pred):.2f}
    precision: {precision_score(y_true, y_pred):.2f}
    recall: {recall_score(y_true, y_pred):.2f}""")
full_test(best_model, test_cases)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [58]:
clauses = test_cases[0]["all_clauses"]
clauses

{'5d4b80113e582511aa1cd850': " The\r\nCommission, taking into account the applicant's position in these\r\nproceedings and the impact of their outcome on her own financial\r\nsituation,",
 '5d4b802b3e582511aa1cd851': 'finds that the applicant is entitled to complain about the\r\nlength of the proceedings.\r',
 '5d4b804a3e582511aa1cd852': 'They submit that the applicant\r\nfailed to complain about the length in the course of the domestic court\r\nproceedings and did not lodge a request under S. 91 of the Court\r\nOrganisation Act 1990 (Gerichtsverfassungsgesetz) that, in view of the\r\nalleged delay, the superior court should fix an appropriate time-limit\r\nfor the conduct of the court proceedings.\r',
 '5d4b80613e582511aa1cd853': 'According to Article 26 (Art. 26), the "Commission may only deal\r\nwith the matter after all domestic remedies have been exhausted,\r\naccording to the generally recognised rules of international law".\r\n\r',
 '5d4b80763e582511aa1cd854': 'The Government ma

In [70]:
tokenized_x = tokenizer(test_cases[0]["all_clauses"][list(test_cases[0]["all_clauses"].keys())[2]], truncation=True, return_tensors='pt')
model(tokenized_x)

torch.max(model(tokenized_x), -1)[1].view(-1).cpu().tolist()[0]

0

## 6.2 Argument Relation Mining

In [18]:
def add_two_premises(idx_1, argument):
    idx_2 = randint(0, len(argument["premises"]) - 1)
    while idx_1 == idx_2:
       idx_2 = randint(0, len(argument["premises"]) - 1)
    premise_1 = argument["premises"][idx_1]
    premise_2 = argument["premises"][idx_2]
    return {
        "e1": clauses[premise_1],
        "e2": clauses[premise_2]
    }, (premise_1, premise_2)

def add_premise_conclusion(idx_1, argument):
    premise = argument["premises"][idx_1]
    conclusion = argument["conclusion"]
    if randint(1, 10) > 5: # with 50% chance we make the premise the first element
        return {
            "e1": clauses[premise],
            "e2": clauses[conclusion]
        }, (premise, conclusion)
    else: # with 50% chance we make the conclusion the first element
        return {
            "e2": clauses[conclusion],
            "e1": clauses[premise]
        }, (premise, conclusion)

In [62]:
args_set = set()    #Prepare ids of clauses that are parts of arguments
for case in refactored_dataset: # for each case
    for argument in case["arguments"]: 
        args_set.add(argument["conclusion"])
        for premise in argument["premises"]:
            args_set.add(premise)

ARM_x = []
ARM_y = []
for caseidx in range(len(dataset)):
    sorted_clauses = sorted(dataset[caseidx]['clauses'], key = lambda x: x['start'])
    for i in range(len(sorted_clauses)-1):
        if sorted_clauses[i]['_id'] in args_set:
            for el in sorted_clauses[i+1:i+6]:
                if el['_id'] in args_set:
                    ARM_x.append({'e1': refactored_dataset[caseidx]['all_clauses'][sorted_clauses[i]['_id']], 'e2': refactored_dataset[caseidx]['all_clauses'][el['_id']]})
                    y = torch.tensor([0., 1.])
                    for arg in refactored_dataset[caseidx]['arguments']:
                        if sorted_clauses[i]['_id'] in arg['premises'] or sorted_clauses[i]['_id'] == arg['conclusion']:
                            if el['_id'] in arg['premises'] or sorted_clauses[i]['_id'] == arg['conclusion']:
                                y = torch.tensor([1., 0.])
                    ARM_y.append(y)
    if caseidx == 33:
        print(caseidx, len(ARM_x), len(ARM_y))

print(len(ARM_x))
print(len([ARM_y[i] for i in range(len(ARM_y)) if ARM_y[i][0]]), len([ARM_y[i] for i in range(len(ARM_y)) if not ARM_y[i][0]]))

ARM_x_train = ARM_x[:6313] #Visto a mano, fino a indice 6312 sono nel train, fino a 7433 nel val, e il resto nel test
ARM_y_train = ARM_y[:6313]
ARM_x_val = ARM_x[6313:7434]
ARM_y_val = ARM_y[6313:7434]
ARM_x_test=ARM_x[7434:]
ARM_y_test=ARM_y[7434:]

ARM_x_train_tokenized = dict_lists_to_list_of_dicts({
    key : dict_lists_to_list_of_dicts(tokenizer([v[key] for v in ARM_x_train], padding=True, truncation=True, return_tensors='pt'))
for key in ARM_x_train[0].keys()})

ARM_x_val_tokenized = dict_lists_to_list_of_dicts({
    key : dict_lists_to_list_of_dicts(tokenizer([v[key] for v in ARM_x_val], padding=True, truncation=True, return_tensors='pt'))
for key in ARM_x_val[0].keys()})

ARM_x_test_tokenized = dict_lists_to_list_of_dicts({
    key : dict_lists_to_list_of_dicts(tokenizer([v[key] for v in ARM_x_test], padding=True, truncation=True, return_tensors='pt'))
for key in ARM_x_test[0].keys()})

33 7434 7434
10430
5161 5269


### Train

Model1 prende le due componenti e le analizza indipendentemente per poi dare un output

Molel2 concatena componente 1 e componente 2 per analizzarle assieme

In [63]:
class Model1(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = .3) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")#"FacebookAI/roberta-base")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size * 2, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        _, encoded_input_1 = self.encoder(**input["e1"], return_dict = False)
        _, encoded_input_2 = self.encoder(**input["e2"], return_dict = False)
        encoded_input = self.dropout(torch.cat((encoded_input_1, encoded_input_2), dim=1))
    
        return self.output_layer(encoded_input)
    
class Model2(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = .3) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        input = {key: torch.cat((input["e1"][key], input["e2"][key]), dim=1) for key in input["e1"].keys()}
        _, encoded_input = self.encoder(**input, return_dict = False)
        encoded_input = self.dropout(encoded_input)
    
        return self.output_layer(encoded_input)

model1 = Model1(2)
model2 = Model2(2)

In [67]:
#ARM_x_tokenized = dict_lists_to_list_of_dicts({
#    key : dict_lists_to_list_of_dicts(tokenizer([v[key] for v in ARM_x], padding=True, truncation=True, return_tensors='pt'))
#for key in ARM_x[0].keys()})
#train_dataloader, validation_dataloader, test_dataloader = get_dataloader(ARM_x_tokenized, ARM_y, 8, [9])
batch_size = 8
train_dataloader, validation_dataloader, test_dataloader = (DataLoader(Dataset(ARM_x_train_tokenized, ARM_y_train), batch_size = batch_size, shuffle = True),
                                                            DataLoader(Dataset(ARM_x_val_tokenized, ARM_y_val),batch_size = batch_size, shuffle = False), 
                                                            DataLoader(Dataset(ARM_x_test_tokenized, ARM_y_test),batch_size = batch_size, shuffle = False))

In [70]:
train_data, validation_data =   model1.train_network(train_dataloader, 
                    validation_dataloader, 
                    torch.optim.Adam, 
                    loss_function=nn.CrossEntropyLoss(),
                    device=device, 
                    batch_size=32,
                    verbose=True, 
                    output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                    metrics={
                     "accuracy": accuracy_score, 
                     "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                    learning_rate=1e-6,
                    epochs=5)

EPOCH 1 training loss:      0.442 - validation loss:      0.534                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.810 - validation accuracy:      0.726
EPOCH 1 training f1_score:      0.786 - validation f1_score:      0.634
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.408 - validation loss:      0.546                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.825 - validation accuracy:      0.723
EPOCH 2 training f1_score:      0.803 - validation f1_score:  

In [72]:
train_data, validation_data =   model2.train_network(train_dataloader, #rispetto a model1 sembra avere risultati simili, ma ci mette più tempo
                    validation_dataloader, 
                    torch.optim.Adam, 
                    loss_function=nn.CrossEntropyLoss(),
                    device=device, 
                    batch_size=32,
                    verbose=True, 
                    output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                    metrics={
                     "accuracy": accuracy_score, 
                     "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                    learning_rate=1e-6,
                    epochs=5)

EPOCH 1 training loss:      0.565 - validation loss:      0.659                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.730 - validation accuracy:      0.640
EPOCH 1 training f1_score:      0.702 - validation f1_score:      0.527
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.526 - validation loss:      0.624                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.762 - validation accuracy:      0.670
EPOCH 2 training f1_score:      0.732 - validation f1_score:  

## 6.3 Premise/Conclusion Recognition

In [12]:
PCR_x = []
PCR_y = []
all_premises = set()
all_conclusions = set()
all_clause = set()
for case in refactored_dataset:
    n_clauses = case["n_clauses"]
    clauses = case["all_clauses"]
    for argument in case["arguments"]:
        all_conclusions.add(clauses[argument["conclusion"]])
        all_clause.add(clauses[argument["conclusion"]])
        for premise in argument["premises"]:
            all_premises.add(clauses[premise])
            all_clause.add(clauses[premise])

for clause in all_clause:
    current_y = torch.zeros(2)
    if clause in all_premises:
        current_y[0] = 1
    if clause in all_conclusions:
        current_y[1] = 1
    PCR_x.append(clause)
    PCR_y.append(current_y)
print(f"""
There are:
      - {len([y for y in PCR_y if y[1] == 1])} conclusions
      - {len([y for y in PCR_y if y[0] == 1])} premises
""")


There are:
      - 662 conclusions
      - 1857 premises

