In [1]:
from json import load
import numpy as np
from neuralNetwork import NeuralNetwork, In_between_epochs
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from random import shuffle, randint
from transformers import RobertaModel, RobertaTokenizer, AutoTokenizer, AutoModel
from helper import dict_lists_to_list_of_dicts, get_train_valdiation_test_split, Dataset
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

In [2]:
f = open("ECHR_Corpus.json")
dataset = load(f)
f.close()

# Dataset
ogni caso nel dataset è organizzato così:
```json
{
    "text":"il testo completo della sentenza"
    "clauses":[
        {
            "_id": "id della clause",
            "start": "index di inizio della clause",
            "end": "index di fine della cluause"
        },
        {
            "...":"..."
        }
    ]
    "argument":[
        {
            "premises":[
                "id premise1", "id premise2", "..."
            ],
            "conclusion": "id conclusion"
        },
        {
            "...":"..."
        }
    ]
}
```
In questo passaggio riorganizzo i dati per avere comunque la stessa struttura ma trasformo le clauses in un dizionario che ha come id la chiave della clause e come valore il testo (senza quindi dover usare start e end per cercarlo nel testo). Non tutte le clause sono parte di una premessa o di una conclusione.

In [3]:
refactored_dataset = []
for datapoint in dataset:
    text = datapoint["text"]
    dict_clauses = {}
    for clause in datapoint["clauses"]:
        start = clause["start"]
        end = clause["end"]
        id = clause["_id"]
        dict_clauses[id] = text[start:end]
    refactored_dataset.append({
        "text": text,
        "arguments": datapoint["arguments"],
        "n_clauses": len(datapoint["clauses"]),
        "all_clauses": dict_clauses
    })

Secondo me ha poco senso salvare il dataset come dataframe dato che è praticamente solo testo. Comunque non dovrebbe essere difficilissimo tirarci fuori qualche statistica. Ho fatto degli esempi scemi qua:

In [4]:
n_arguments = []
n_premises = []
n_clauses = []
for case in refactored_dataset:
    n_arguments.append(len(case["arguments"]))
    n_clauses.append(case["n_clauses"])
    for argument in case["arguments"]:
        n_premises.append(len(argument["premises"]))
print(f"""
On average, a case has {np.mean(n_clauses):.2f} clauses with a median of {np.median(n_clauses):.0f} clauses per case.
On average, a case has {np.mean(n_arguments):.2f} arguments with a median of {np.median(n_arguments):.0f} arguments per case.
Each argument, on average, has: {np.mean(n_premises):.2f} premises with a median of {np.median(n_premises):.0f} premises per argument.
""")


On average, a case has 248.95 clauses with a median of 226 clauses per case.
On average, a case has 17.69 arguments with a median of 14 arguments per case.
Each argument, on average, has: 2.63 premises with a median of 2 premises per argument.



Qua puoi vedere come tirare fuori il testo di una conclusion o di una premise: ti basta usare la conclusion/premise come indice nel dizionario delle clause

In [5]:
premise = "\n\t".join([refactored_dataset[0]["all_clauses"][premise] for premise in refactored_dataset[0]["arguments"][0]["premises"]])
conclusion = refactored_dataset[0]["all_clauses"][refactored_dataset[0]["arguments"][0]["conclusion"]]
print(f"""
Here an example of an argument:
    - Premises:
        {premise}
    - Conclusion:
        {conclusion}
""")


Here an example of an argument:
    - Premises:
        The Commission notes that the applicant was detained after having been sentenced by the first instance court to 18 months' imprisonment.
	He was released after the Court of Appeal reviewed this sentence, reducing it to 15 months' imprisonment, convertible to a fine.
    - Conclusion:
        The Commission finds that the applicant was deprived of his liberty "after conviction by a competent court" within the meaning of Article 5 para. 1 (a) (Art. 5-1-a) of the Convention.



## Prepare for a task

### 6.1 Argument Clause Recognition

In [6]:
#9, 4, 2

In [7]:
train_cases, validation_cases, test_cases = get_train_valdiation_test_split(refactored_dataset, [9])

In [8]:
print(f"""there are {len(train_cases)} cases in the training set, 
{len(validation_cases)} cases in the validation set and
{len(test_cases)} cases in the test set""")

there are 35 cases in the training set, 
3 cases in the validation set and
4 cases in the test set


In [9]:
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

dato che non tutte le clause sono parte di una premise/conclusion, le dividiamo per cercare di tirare fuori un classificatore di clauses

In [10]:
def get_acr_dataloader(cases, tokenizer, batch_size=16, shuffle=False, verbose=True):
    ACR_x = []
    ACR_y = []
    for case in cases:
        n_clauses = case["n_clauses"]
        clauses = case["all_clauses"]
        args_set = set()
        splitter = "AS TO THE LAW" if "AS TO THE LAW" in case["text"] else "THE LAW"
        law_section = case["text"].split(splitter)[1]
        for argument in case["arguments"]:
            ACR_x.append(clauses[argument["conclusion"]])
            ACR_y.append(torch.tensor([1,0]))
            args_set.add(argument["conclusion"])
            for premise in argument["premises"]:
                ACR_x.append(clauses[premise])
                ACR_y.append(torch.tensor([1.,0]))
                args_set.add(premise)
        for clause_id in clauses.keys():
            if not clause_id in args_set and clauses[clause_id] in law_section:
                ACR_x.append(clauses[clause_id])
                ACR_y.append(torch.tensor([0,1.]))
    print(f"""
There are:
      - {len(ACR_x)} clause
      - {len([a for a in ACR_y if a[0] == 1])} true clause
      - {len([a for a in ACR_y if a[0] == 0])} fake clause
""")
    ACR_x_tokenized = dict_lists_to_list_of_dicts(tokenizer(ACR_x, padding=True, truncation=True, return_tensors='pt'))
    dataset = Dataset(ACR_x_tokenized, ACR_y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
print("train set:")
train_dataloader = get_acr_dataloader(train_cases, tokenizer, shuffle=True, batch_size=8)
print("validation set:")
validation_dataloder = get_acr_dataloader(validation_cases, tokenizer, batch_size=8)
print("test set:")
test_dataloader = get_acr_dataloader(test_cases, tokenizer, batch_size=8)

train set:

There are:
      - 3351 clause
      - 1953 true clause
      - 1398 fake clause

validation set:

There are:
      - 541 clause
      - 284 true clause
      - 257 fake clause

test set:

There are:
      - 798 clause
      - 457 true clause
      - 341 fake clause



# Pre-train

In [11]:
# def to_text_file(dataset):
#     text = ""
#     for case in dataset:
#         text += "\n\n" + case["text"]

#     f = open("pre_train_text.txt", "w")
#     f.write(text)

# to_text_file(refactored_dataset[:30])

In [12]:
# from transformers import BertTokenizer, BertForMaskedLM
# from transformers import LineByLineTextDataset, DataCollatorForLanguageModeling
# from transformers import Trainer, TrainingArguments
# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# train_data_file = "pre_train_text.txt"

# train_dataset = LineByLineTextDataset(
#     tokenizer=tokenizer,
#     file_path=train_data_file,
#     block_size=128,
# )

# data_collator = DataCollatorForLanguageModeling(
#     tokenizer=tokenizer, mlm=True, mlm_probability=0.15
# )

# training_args = TrainingArguments(
#     overwrite_output_dir=True,
#     num_train_epochs=8,  
#     per_device_train_batch_size=16,
#     save_steps=10_000,  
#     save_total_limit=2, 
#     prediction_loss_only=True,
#     report_to="none"
# )

# trainer = Trainer(
#     model=model,
#     args=training_args,
#     data_collator=data_collator,
#     train_dataset=train_dataset,
# )

# # Start pre-training
# trainer.train()
# trainer.save_model("./bert_pretrained/")

# Train

Da qui in avanti ti dovrebbe essere tutto abbastanza familiare dato che è praticamente lo stesso codice dell'assignment 2. Qua ho fatto solo 5 epoche e con un learning rate a caso ma è giusto per far vedere come funziona. Da ora si può iniziare a giocare...

In [13]:
class Model(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = .3) -> None:
        super().__init__()
        self.encoder = AutoModel.from_pretrained("nlpaueb/legal-bert-base-uncased")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        encoded_input, e2 = self.encoder(**input, return_dict = False)
        encoded_input = self.dropout(encoded_input)
        encoded_input = torch.nn.functional.avg_pool1d(
            encoded_input.permute(0, 2, 1), 
            kernel_size=encoded_input.size(1)
        ).squeeze(2)
        return self.output_layer(encoded_input)

    def freeze_bert(self):
        for param in self.encoder.parameters():
                param.requires_grad = False
    def unfreeze_bert(self):
        for param in self.encoder.parameters():
                param.requires_grad = True
    
model = Model(2)

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("operating on device:", device)

operating on device: cuda:0


In [15]:
class EarlyStopping(In_between_epochs):
    def __init__(self, delta, patience):
        self.delta = delta
        self.patience = patience
        self.current_patience = 0
        self.best_valid_loss = 100000000000000000000
        self.best_model = None
        self.epochs = 0
    
    def __call__(self, model:torch.nn.Module, loaders:dict[str,torch.utils.data.DataLoader], device:'torch.device|str', output_extraction_function, losses:dict[str, float]) -> bool:
        self.epochs += 1
        if losses["validation"] < self.best_valid_loss:
            print("00000000000000000000000000000000000000000000")
            print(f"new best loss found: {losses['validation']}")
            print("00000000000000000000000000000000000000000000")
            self.best_valid_loss = losses["validation"]
            self.best_model = model
            self.current_patience = 0
        else:
            self.current_patience += 1
            print("00000000000000000000000000000000000000000000")
            print(f"old loss was better... patience: {self.current_patience}/{self.patience}")
            print("00000000000000000000000000000000000000000000")
            if self.current_patience >= self.patience:
                return True
        return False
    def reset(self):
        self.current_patience = 0
        self.epochs = 0

In [16]:
def train(model, train, validation, min_lr, start_lr, early_stopping, frac):
    tot_train_data, tot_val_data = [], []
    lr = start_lr
    while lr > min_lr:
        train_data, validation_data = model.train_network(train, 
                        validation, 
                        torch.optim.SGD, 
                        loss_function=nn.CrossEntropyLoss(),
                        device=device, 
                        batch_size=32,
                        verbose=True, 
                        output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                        metrics={
                        "accuracy": accuracy_score, 
                        "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                        in_between_epochs = {"early_stopping": early_stopping},
                        learning_rate=lr,
                        epochs=30)
        train_data["epochs"] = early_stopping.epochs
        train_data["lr"] = lr
        validation_data["epochs"] = early_stopping.epochs
        validation_data["lr"] = lr
        tot_train_data.append(train_data)
        tot_val_data.append(validation_data)
        model = early_stopping.best_model
        lr = lr * frac
        early_stopping.reset()

In [19]:
_, (x, y) = next(enumerate(train_dataloader))
new_y = torch.max(y, -1)[1].view(-1).cpu()
for i in range(len(y)):
    print(f"orginal y: {y[i]}, trainsformed y: {new_y[i]}")

orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([0., 1.]), trainsformed y: 1
orginal y: tensor([0., 1.]), trainsformed y: 1
orginal y: tensor([0., 1.]), trainsformed y: 1
orginal y: tensor([1., 0.]), trainsformed y: 0


In [20]:
pred_y = model(x)

In [21]:
pred_y = model(x)
new_y = torch.max(pred_y, -1)[1].view(-1).cpu()
for i in range(len(y)):
    print(f"orginal y: {y[i]}, trainsformed y: {new_y[i]}")

orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 1
orginal y: tensor([1., 0.]), trainsformed y: 1
orginal y: tensor([0., 1.]), trainsformed y: 1
orginal y: tensor([0., 1.]), trainsformed y: 1
orginal y: tensor([0., 1.]), trainsformed y: 0
orginal y: tensor([1., 0.]), trainsformed y: 1


In [17]:
def greedy_train(lrs:list[float], train, validation):
    trainings = []
    models = []
    for lr in lrs:
        model = Model(2)
        train_data, validation_data = model.train_network(train, 
                        validation, 
                        torch.optim.SGD, 
                        loss_function=nn.CrossEntropyLoss(),
                        device=device, 
                        batch_size=32,
                        verbose=True, 
                        output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                        metrics={
                        "accuracy": accuracy_score, 
                        "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                        learning_rate=lr,
                        epochs=15)
        for key in train_data:
                train_data[key] = [float(v) for v in train_data[key]]
                validation_data[key] = [float(v) for v in validation_data[key]]
        trainings.append({"train": train_data, "validation":validation_data, "lr":lr})
        models.append(model)
    return trainings, models

In [18]:
es = EarlyStopping(.1, 3)
train(model, train_dataloader, validation_dataloder, 1e-6, 5e-4, es, .6)

EPOCH 1 training loss:      0.601 - validation loss:      0.751                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.693 - validation accuracy:      0.561
EPOCH 1 training f1_score:      0.580 - validation f1_score:      0.534
----------------------------------------------------------------------------------------------------

00000000000000000000000000000000000000000000
new best loss found: 0.7507253289222717
00000000000000000000000000000000000000000000
EPOCH 2 training loss:      0.510 - validation loss:      0.725                                                                                                                                                                                                                                                      
EPOC

KeyboardInterrupt: 

In [18]:
lrs = [1e-4, 5e-5, 2e-5, 1e-5, 5e-6, 2e-6, 1e-6, 5e-7, 2e-7, 1e-7]
output, models = greedy_train(lrs, train_dataloader, validation_dataloder)
import json 
f = open("trainings.json", "w")
json.dump(output, f)
f.close()
for i in range(len(models)):
    torch.save(models[i], f"model_lr_{lrs[i]}")

EPOCH 1 training loss:      0.675 - validation loss:      0.707                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.600 - validation accuracy:      0.539
EPOCH 1 training f1_score:      0.414 - validation f1_score:      0.522
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.660 - validation loss:      0.715                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.613 - validation accuracy:      0.537
EPOCH 2 training f1_score:      0.431 - validation f1_score:  

In [18]:
lrs = [1e-2, 5e-2, 2e-2, 2e-3, 5e-3, 2e-3, 2e-4, 5e-4, 2e-4]
output, models = greedy_train(lrs, train_dataloader, validation_dataloder)
import json 
f = open("trainings.json", "w")
json.dump(output, f)
f.close()
for i in range(len(models)):
    torch.save(models[i], f"model_lr_{lrs[i]}")

EPOCH 1 training loss:      0.682 - validation loss:      0.717                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.583 - validation accuracy:      0.522
EPOCH 1 training f1_score:      0.365 - validation f1_score:      0.507
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.689 - validation loss:      0.702                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.583 - validation accuracy:      0.522
EPOCH 2 training f1_score:      0.366 - validation f1_score:  

In [15]:
train_data, validation_data =   model.train_network(train_dataloader, 
                    test_dataloader, 
                    torch.optim.Adam, 
                    loss_function=nn.CrossEntropyLoss(),
                    device=device, 
                    batch_size=32,
                    verbose=True, 
                    output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                    metrics={
                     "accuracy": accuracy_score, 
                     "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                    learning_rate=1e-5,
                    epochs=5)

EPOCH 1 training loss:      0.685 - validation loss:      0.697                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.583 - validation accuracy:      0.571
EPOCH 1 training f1_score:      0.366 - validation f1_score:      0.563
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.681 - validation loss:      0.691                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.583 - validation accuracy:      0.571
EPOCH 2 training f1_score:      0.369 - validation f1_score:  

## 6.2 Argument Relation Mining

In [8]:
def add_two_premises(idx_1, argument):
    idx_2 = randint(0, len(argument["premises"]) - 1)
    while idx_1 == idx_2:
       idx_2 = randint(0, len(argument["premises"]) - 1)
    premise_1 = argument["premises"][idx_1]
    premise_2 = argument["premises"][idx_2]
    return {
        "e1": clauses[premise_1],
        "e2": clauses[premise_2]
    }, (premise_1, premise_2)

def add_premise_conclusion(idx_1, argument):
    premise = argument["premises"][idx_1]
    conclusion = argument["conclusion"]
    if randint(1, 10) > 5: # with 50% chance we make the premise the first element
        return {
            "e1": clauses[premise],
            "e2": clauses[conclusion]
        }, (premise, conclusion)
    else: # with 50% chance we make the conclusion the first element
        return {
            "e2": clauses[conclusion],
            "e1": clauses[premise]
        }, (premise, conclusion)

In [9]:
# ARM_x = []
# ARM_y = []
# FACTOR = .6
# n_related_two_premises = 0
# n_related_premise_conclusion = 0
# n_unrelated_two_premises = 0
# n_unrelated_two_conclusions = 0
# n_unrelated_premise_conclusion = 0
# for case in refactored_dataset: # for each case
#     n_clauses = case["n_clauses"]
#     clauses = case["all_clauses"]

#     args_set = set() # get the number of clauses in arguments
#     for argument in case["arguments"]: 
#         args_set.add(argument["conclusion"])
#         for premise in argument["premises"]:
#             args_set.add(premise)

#     n_related = int(len(args_set) * FACTOR) # get the number of elements we want from a given case 
#     n_unrelated = int(len(args_set) * FACTOR)
#     already_taken_couples = set()
#     related = []
#     unrelated = []
#     while len(related) < n_related: # generate untill we need more related elements 
#         while True:
#             idx_argument = randint(0, len(case["arguments"]) - 1) # take a random argument
#             argument = case["arguments"][idx_argument]
#             idx_1 = randint(0, len(argument["premises"]) - 1) #take a random premise (we know we need at least 1 premise since there is only 1 conclusion)
#             if len(argument["premises"]) > 1: # some arguments have only 1 premise
#                 two_premises = randint(0,10)
#                 if two_premises >= 4: # with 60% chance we get 2 random premises
#                     element, (id_1, id_2) = add_two_premises(idx_1, argument)
#                     if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples): #we add them only if we do not already have selected the combination
#                         related.append(element)
#                         n_related_two_premises += 1
#                         already_taken_couples.add(id_1 + id_2)
#                         break
#                 else: #with 40% chance we get a premise and a conclusion
#                     element, (id_1, id_2) = add_premise_conclusion(idx_1, argument)
#                     if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples): #only if the couple has not been already selected
#                         related.append(element)
#                         n_related_premise_conclusion += 1
#                         already_taken_couples.add(id_1 + id_2)
#                         break
#             else: # if we have only one premise we just get the premise and conclusion
#                 element, (id_1, id_2) = add_premise_conclusion(idx_1, argument)
#                 if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples):
#                     related.append(element)
#                     n_related_premise_conclusion += 1
#                     already_taken_couples.add(id_1 + id_2)
#                     break

#     while len(unrelated) < n_unrelated: # generate untill we need more unrelated elements 
#         while True:
#             idx_argument_1 = randint(0, len(case["arguments"]) - 1) # select 2 random arguments 
#             idx_argument_2 = randint(0, len(case["arguments"]) - 1)
#             while idx_argument_1 == idx_argument_2: # they must be different
#                 idx_argument_2 = randint(0, len(case["arguments"]) - 1)
#             argument_1 = case["arguments"][idx_argument_1]
#             argument_2 = case["arguments"][idx_argument_2]
#             type_relation = randint(0,9)
#             if type_relation < 3: # with 30% chance we select two premises
#                 id_1 = randint(0, len(argument_1["premises"]) - 1)
#                 id_2 = randint(0, len(argument_2["premises"]) - 1)
#                 id_1 = argument_1["premises"][id_1]
#                 id_2 = argument_2["premises"][id_2]
#                 if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples): # only if the couple is not already present in the dataset
#                     unrelated.append({
#                         "e2": clauses[id_1],
#                         "e1": clauses[id_2]
#                     })
#                     n_unrelated_two_premises += 1
#                     already_taken_couples.add(id_1 + id_2)
#                     break
#             elif type_relation < 6: # with 30% chance we select a premise and a conclusion
#                 prem = randint(0, 1) # we select the premise argument at random
#                 args = [argument_1, argument_2]
#                 id_1 = randint(0, len(args[prem]["premises"]) - 1)
#                 id_1 = args[prem]["premises"][id_1]
#                 id_2 = args[abs(prem - 1)]["conclusion"] # the other argument gives the conclusion
#                 if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples): # only if the couple is not already present in the dataset
#                     n_unrelated_premise_conclusion += 1
#                     already_taken_couples.add(id_1 + id_2)
#                     if randint(0, 5) > 5: # with 50% chance we put the premise first 
#                         unrelated.append({
#                             "e2": clauses[id_1],
#                             "e1": clauses[id_2]
#                         })
#                     else: # with 50% chance we put the conclusion first 
#                         unrelated.append({
#                             "e2": clauses[id_2],
#                             "e1": clauses[id_1]
#                         })
#                     break
#             else: # with 30% chance we select 2 conclusions
#                 id_1 = argument_1["conclusion"]
#                 id_2 = argument_2["conclusion"]
#                 if not (id_1 + id_2 in already_taken_couples or id_2 + id_1 in already_taken_couples): # only if the couple is not already present in the dataset
#                     unrelated.append({
#                         "e2": clauses[id_1],
#                         "e1": clauses[id_2]
#                     })
#                     n_unrelated_two_conclusions += 1
#                     already_taken_couples.add(id_1 + id_2)
#                     break

#     for related_elem in related: # add all the new elements to the dataset
#         ARM_x.append(related_elem)
#         ARM_y.append(torch.tensor([1., 0]))
#     for unrelated_elem in unrelated:
#         ARM_x.append(unrelated_elem)
#         ARM_y.append(torch.tensor([0, 1.]))
        
# assert len([y for y in ARM_y if y[0] == 0]) == (n_unrelated_two_premises + n_unrelated_premise_conclusion + n_unrelated_two_conclusions)
# assert len([y for y in ARM_y if y[0] == 1.]) == (n_related_premise_conclusion + n_related_two_premises) 
# print(f"""
# There are:
#       - {len(ARM_x)} couples
#       - {n_related_two_premises + n_related_premise_conclusion} related clauses:
#         - {n_related_two_premises} are couple of clauses both premises on the same argument
#         - {n_related_premise_conclusion} are couple of clauses premise and conclusion on the same argument
#       - {n_unrelated_two_premises + n_unrelated_premise_conclusion + n_unrelated_two_conclusions} unrelated clauses:
#         - {n_unrelated_two_premises} are couple with two premises from different arguments
#         - {n_unrelated_premise_conclusion} are couple with a premise and a conclusion from two different arguments
#         - {n_unrelated_two_conclusions} are couple with two conclusions from different arguments
# """)


There are:
      - 3194 couples
      - 1597 related clauses:
        - 777 are couple of clauses both premises on the same argument
        - 820 are couple of clauses premise and conclusion on the same argument
      - 1597 unrelated clauses:
        - 485 are couple with two premises from different arguments
        - 488 are couple with a premise and a conclusion from two different arguments
        - 624 are couple with two conclusions from different arguments



In [7]:
args_set = set()    #Prepare ids of clauses that are parts of arguments
for case in refactored_dataset: # for each case
    for argument in case["arguments"]: 
        args_set.add(argument["conclusion"])
        for premise in argument["premises"]:
            args_set.add(premise)

ARM_x = []
ARM_y = []
for caseidx in range(len(dataset)):
    sorted_clauses = sorted(dataset[caseidx]['clauses'], key = lambda x: x['start'])
    for i in range(len(sorted_clauses)-1):
        if sorted_clauses[i]['_id'] in args_set:
            for el in sorted_clauses[i+1:i+6]:
                if el['_id'] in args_set:
                    ARM_x.append({'e1': refactored_dataset[caseidx]['all_clauses'][sorted_clauses[i]['_id']], 'e2': refactored_dataset[caseidx]['all_clauses'][el['_id']]})
                    y = torch.tensor([0, 1.])
                    for arg in refactored_dataset[caseidx]['arguments']:
                        if sorted_clauses[i]['_id'] in arg['premises'] or sorted_clauses[i]['_id'] == arg['conclusion']:
                            if el['_id'] in arg['premises'] or sorted_clauses[i]['_id'] == arg['conclusion']:
                                y = torch.tensor([1., 0])
                    ARM_y.append(y)

print(len(ARM_x))
print(len([ARM_y[i] for i in range(len(ARM_y)) if ARM_y[i][0]]), len([ARM_y[i] for i in range(len(ARM_y)) if not ARM_y[i][0]]))

10430
5161 5269


### Train

Model1 prende le due componenti e le analizza indipendentemente per poi dare un output

Molel2 concatena componente 1 e componente 2 per analizzarle assieme

In [73]:
class Model1(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = .3) -> None:
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("FacebookAI/roberta-base")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size * 2, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        _, encoded_input_1 = self.encoder(**input["e1"], return_dict = False)
        _, encoded_input_2 = self.encoder(**input["e2"], return_dict = False)
        encoded_input = self.dropout(torch.cat((encoded_input_1, encoded_input_2), dim=1))
    
        return self.output_layer(encoded_input)
    
class Model2(NeuralNetwork):
    def __init__(self, out_features:int, dropout:float = .3) -> None:
        super().__init__()
        self.encoder = RobertaModel.from_pretrained("FacebookAI/roberta-base")
        self.output_layer = nn.Linear(self.encoder.config.hidden_size, out_features)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, input):
        input = {key: torch.cat((input["e1"][key], input["e2"][key]), dim=1) for key in input["e1"].keys()}
        _, encoded_input = self.encoder(**input, return_dict = False)
        encoded_input = self.dropout(encoded_input)
    
        return self.output_layer(encoded_input)

model1 = Model1(2)
model2 = Model2(2)

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [71]:
ARM_x_tokenized = dict_lists_to_list_of_dicts({
    key : dict_lists_to_list_of_dicts(tokenizer([v[key] for v in ARM_x], padding=True, truncation=True, return_tensors='pt'))
for key in ARM_x[0].keys()})
train_dataloader, validation_dataloader, test_dataloader = get_dataloader(ARM_x_tokenized, ARM_y, 8, [9])

In [74]:
train_data, validation_data =   model1.train_network(train_dataloader, 
                    validation_dataloader, 
                    torch.optim.SGD, 
                    loss_function=nn.CrossEntropyLoss(),
                    device=device, 
                    batch_size=32,
                    verbose=True, 
                    output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                    metrics={
                     "accuracy": accuracy_score, 
                     "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                    learning_rate=1e-5,
                    epochs=5)

EPOCH 1 training loss:      0.696 - validation loss:      0.711                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.500 - validation accuracy:      0.527
EPOCH 1 training f1_score:      0.328 - validation f1_score:      0.513
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.696 - validation loss:      0.712                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.501 - validation accuracy:      0.527
EPOCH 2 training f1_score:      0.324 - validation f1_score:  

In [75]:
train_data, validation_data =   model2.train_network(train_dataloader, 
                    validation_dataloader, 
                    torch.optim.SGD, 
                    loss_function=nn.CrossEntropyLoss(),
                    device=device, 
                    batch_size=32,
                    verbose=True, 
                    output_extraction_function= lambda x: torch.max(x, -1)[1].view(-1).cpu(), 
                    metrics={
                     "accuracy": accuracy_score, 
                     "f1_score": lambda y_true, y_pred: f1_score(y_true, y_pred, average="macro")},
                    learning_rate=1e-5,
                    epochs=5)

EPOCH 1 training loss:      0.699 - validation loss:      0.721                                                                                                                                                                                                                                                      
EPOCH 1 training accuracy:      0.500 - validation accuracy:      0.473
EPOCH 1 training f1_score:      0.323 - validation f1_score:      0.445
----------------------------------------------------------------------------------------------------

EPOCH 2 training loss:      0.697 - validation loss:      0.717                                                                                                                                                                                                                                                      
EPOCH 2 training accuracy:      0.500 - validation accuracy:      0.473
EPOCH 2 training f1_score:      0.324 - validation f1_score:  

## 6.3 Premise/Conclusion Recognition

In [12]:
PCR_x = []
PCR_y = []
all_premises = set()
all_conclusions = set()
all_clause = set()
for case in refactored_dataset:
    n_clauses = case["n_clauses"]
    clauses = case["all_clauses"]
    for argument in case["arguments"]:
        all_conclusions.add(clauses[argument["conclusion"]])
        all_clause.add(clauses[argument["conclusion"]])
        for premise in argument["premises"]:
            all_premises.add(clauses[premise])
            all_clause.add(clauses[premise])

for clause in all_clause:
    current_y = torch.zeros(2)
    if clause in all_premises:
        current_y[0] = 1
    if clause in all_conclusions:
        current_y[1] = 1
    PCR_x.append(clause)
    PCR_y.append(current_y)
print(f"""
There are:
      - {len([y for y in PCR_y if y[1] == 1])} conclusions
      - {len([y for y in PCR_y if y[0] == 1])} premises
""")


There are:
      - 662 conclusions
      - 1857 premises

