# Parameters 

In [1]:
# --- Model ---------------------------------------------------------------
model_checkpoint = 'bert-base-uncased'  # pretrained checkpoint for model and tokenizer
batch_size = 2                          # per-device batch size (train and eval)
metric_name = "accuracy"                # metric used to select the best checkpoint
num_epoch = 10                          # training epochs per fold

# --- Cross-validation folds ----------------------------------------------
num_folds = 4

# --- Experiment ----------------------------------------------------------
# Relation classes; a label's position in this list is its integer class id.
labels = [
    "none",
    "causal",
    "contrast",
    "equivalence",
    "identity",
    "temporal",
    "others",
]
def index_of_label(val, label_list=None):
    """Map a label name to its integer class id.

    Names that are not in the label list fall back to the LAST class
    (``"others"`` for the default ``labels``).

    Args:
        val: label name to look up.
        label_list: optional list of label names; defaults to the
            module-level ``labels``.

    Returns:
        ``label_list.index(val)``, or ``len(label_list) - 1`` when ``val`` is
        not present.

    Raises:
        ValueError: if ``label_list`` is empty (there is no fallback class).
    """
    if label_list is None:
        label_list = labels
    if not label_list:
        raise ValueError("label list is empty; cannot map labels to ids")
    try:
        return label_list.index(val)
    except ValueError:
        # Fall back to the catch-all last class. (This used to reference the
        # undefined name `labels_subset`, which raised NameError.)
        return len(label_list) - 1

# Import

In [2]:
import torch
import numpy as np
import random
import pandas as pd
from IPython.display import display, HTML

In [3]:
def import_fold(path, fold):
    """Load one cross-validation fold from ``{path}/train.{fold}.csv`` and ``{path}/test.{fold}.csv``.

    Both CSVs are read before any column is extracted.

    Returns:
        A 6-tuple ``(train_origin, train_target, train_labels,
        test_origin, test_target, test_labels)``; each element is a plain
        list built from the ``origin``, ``target`` and ``label`` columns.
    """
    frames = [pd.read_csv(f"{path}/{split}.{fold}.csv") for split in ("train", "test")]
    columns = ("origin", "target", "label")
    return tuple(frame[column].tolist() for frame in frames for column in columns)

# Model Setup

## Metric

In [4]:
from sklearn.metrics import classification_report
import collections

#classification_threshold = 0.

def flatten(d, parent_key='', sep='__'):
    """Flatten a nested mapping into a single-level dict.

    Child keys are joined to their parent key with ``sep``, e.g.
    ``{"none": {"precision": 0.8}}`` becomes ``{"none__precision": 0.8}``.
    ``compute_metrics`` uses it to turn the nested
    ``classification_report(..., output_dict=True)`` into flat metric names.

    Args:
        d: mapping to flatten; values that are themselves mappings are
            recursed into.
        parent_key: prefix for every key produced at this level
            ('' at the top level, where keys are kept unchanged).
        sep: separator inserted between a parent key and a child key.

    Returns:
        A new flat ``dict``; ``d`` is not modified.
    """
    # collections.MutableMapping was removed in Python 3.10 (the ABCs live in
    # collections.abc). Mapping also accepts read-only mappings.
    from collections.abc import Mapping

    items = []
    for k, v in d.items():
        # f-string so non-str keys can be joined as well.
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, Mapping):
            items.extend(flatten(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

def compute_metrics(eval_pred):
    """Build a flat classification report for ``transformers.Trainer``.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair; ``predictions`` holds
            per-class scores with the class axis last.

    Returns:
        Flat dict such as ``{"none__precision": ..., "accuracy": ...,
        "macro avg__f1-score": ...}``. The Trainer reports these keys with an
        ``eval_`` prefix (e.g. ``eval_accuracy``).
    """
    predictions, true_labels = eval_pred
    # Take the most probable class for every example.
    predictions = np.argmax(predictions, axis=-1)
    return flatten(classification_report(
        y_true=true_labels,
        y_pred=predictions,
        # Pin the full class set so there is always one entry per name in
        # `labels`. Without it sklearn raises ValueError when a class is absent
        # from both y_true and y_pred of an eval split (class count would no
        # longer match len(target_names)).
        labels=list(range(len(labels))),
        target_names=labels,
        zero_division=0,
        output_dict=True))

In [5]:
#TEST
#flatten(classification_report(
#    y_true=[0,1,2,3,4,5,6,7,8,9,10,11,12],
#    y_pred=[0,0,0,1,3,0,0,0,0,0,0,0,0],
#    target_names=labels,
#    zero_division=0,
#    output_dict=True))

## Model Settings

In [6]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

args = TrainingArguments(
    output_dir="semantic-test",
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=num_epoch,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate after every epoch; at the end of training reload the
    # checkpoint that scored best on `metric_name`.
    # NOTE(review): newer transformers releases rename evaluation_strategy to
    # eval_strategy and require save_strategy to match when
    # load_best_model_at_end=True -- confirm against the installed version.
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

## Tokenize

In [7]:
from transformers import AutoTokenizer, BertTokenizerFast, DebertaTokenizerFast

# Resolve the tokenizer class from the checkpoint, matching the
# AutoModelForSequenceClassification used for the model. Changing
# `model_checkpoint` (e.g. to a DeBERTa checkpoint) then needs no edit here.
# For 'bert-base-uncased' this yields a BertTokenizerFast, as before.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

## Print Examples

In [8]:
#train_encodings

In [9]:
def show_random_elements(origin_list, target_list, label_list, encodings, num_examples=10, label_names=None):
    """Render an HTML table of randomly chosen, distinct examples and their encodings.

    Args:
        origin_list: first text of every pair.
        target_list: second text of every pair.
        label_list: integer class id of every pair (index into ``label_names``).
        encodings: tokenizer output for the same pairs; must expose
            ``input_ids``, ``token_type_ids`` and ``attention_mask``.
        num_examples: number of distinct rows to show.
        label_names: class names indexed by class id; defaults to the
            module-level ``labels``.

    Raises:
        AssertionError: if ``num_examples`` exceeds the number of examples.
    """
    if label_names is None:
        label_names = labels
    assert num_examples <= len(origin_list), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly, instead of re-rolling
    # until an unused index comes up.
    picks = random.sample(range(len(origin_list)), num_examples)
    data = [
        [n, origin_list[n], label_names[label_list[n]], target_list[n],
         encodings.input_ids[n], encodings.token_type_ids[n], encodings.attention_mask[n]]
        for n in picks
    ]
    df = pd.DataFrame(data, columns=['index', 'Origin', 'Label', 'Target', 'Input_ids', 'Token_type_ids', 'Attention_mask'])
    display(HTML(df.to_html()))

In [10]:
# show_random_elements(train_origin, train_target, train_labels, train_encodings)
# Output adjusted to folds
#show_random_elements(k_fold_origin[0][0], k_fold_target[0][0], k_fold_labels[0][0], train_encodings[0])

## Create Dataset

In [11]:
class SemanticDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output with integer class labels.

    ``encodings`` is a mapping of feature name -> per-example sequence (as
    produced by the tokenizer). ``labels`` holds one class id per example.
    Items are dicts of tensors: every encoding feature plus ``'labels'``.
    """

    def __init__(self, encodings, labels):
        self.encodings, self.labels = encodings, labels

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature_name, column in self.encodings.items():
            sample[feature_name] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Training

## Train & Evaluate

In [12]:
from transformers import set_seed

result = []
num_labels = len(labels)
models = []  # one Trainer per fold; saved to disk in the "Save" section

for i in range(num_folds):
    # Seed before building the model. from_pretrained randomly initialises the
    # new classification head, and Trainer only seeds in its own constructor
    # (after the model already exists), so without this runs are not reproducible.
    set_seed(args.seed)
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
    # import Fold data
    train_origin, train_target, train_labels, test_origin, test_target, test_labels = import_fold("data/export-subset", i)
    # tokenize (origin, target) text pairs
    train_encodings = tokenizer(train_origin, train_target, truncation=True, padding=True, return_token_type_ids=True)
    test_encodings = tokenizer(test_origin, test_target, truncation=True, padding=True, return_token_type_ids=True)
    # dataset creation
    train_dataset = SemanticDataset(train_encodings, train_labels)
    test_dataset = SemanticDataset(test_encodings, test_labels)
    # create Trainer; the test fold is also used as the per-epoch eval set
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )
    # train & evaluate
    trainer.train()
    # NOTE(review): with load_best_model_at_end=True the best epoch was already
    # chosen on this same test fold, so the score below is optimistically
    # biased. A separate validation split per fold would give an unbiased estimate.
    ev = trainer.evaluate(test_dataset)
    acc = ev["eval_accuracy"]
    print(f"Accuracy: {acc}")
    result.append(ev)
    models.append(trainer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Others Precision,Others Recall,Others F1-score,Others Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.1453,1.207678,0.76121,0.920555,0.833333,793,0.200743,0.495413,0.285714,109,0.0,0.0,0.0,42,0.5,0.023529,0.044944,85,0.0,0.0,0.0,20,0.263158,0.034014,0.060241,147,0.0,0.0,0.0,55,0.632294,0.246444,0.210502,0.17489,1251,0.564912,0.632294,0.563271,1251
2,0.9241,1.215784,0.805869,0.900378,0.850506,793,0.329545,0.266055,0.294416,109,0.0,0.0,0.0,42,0.321212,0.623529,0.424,85,0.0,0.0,0.0,20,0.464286,0.353741,0.401544,147,0.0,0.0,0.0,55,0.677858,0.274416,0.306243,0.281495,1251,0.615929,0.677858,0.640775,1251
3,0.8057,1.203074,0.888418,0.79319,0.838108,793,0.401786,0.412844,0.40724,109,0.0,0.0,0.0,42,0.378571,0.623529,0.471111,85,0.5,0.05,0.090909,20,0.355805,0.646259,0.458937,147,0.238095,0.090909,0.131579,55,0.661871,0.394668,0.373819,0.342555,1251,0.684163,0.661871,0.65993,1251
4,0.6205,1.479166,0.882759,0.807062,0.843215,793,0.38961,0.550459,0.456274,109,0.285714,0.142857,0.190476,42,0.425532,0.705882,0.530973,85,0.583333,0.35,0.4375,20,0.435294,0.503401,0.466877,147,0.214286,0.109091,0.144578,55,0.681855,0.459504,0.452679,0.438556,1251,0.701923,0.681855,0.684947,1251
5,0.4301,1.622784,0.845777,0.871375,0.858385,793,0.496,0.568807,0.529915,109,0.434783,0.238095,0.307692,42,0.643836,0.552941,0.594937,85,0.5,0.55,0.52381,20,0.424051,0.455782,0.439344,147,0.333333,0.2,0.25,55,0.718625,0.525397,0.491,0.500583,1251,0.710169,0.718625,0.71204,1251
6,0.2889,1.675451,0.872914,0.857503,0.86514,793,0.570093,0.559633,0.564815,109,0.339286,0.452381,0.387755,42,0.765625,0.576471,0.657718,85,0.5,0.6,0.545455,20,0.427632,0.442177,0.434783,147,0.26087,0.327273,0.290323,55,0.722622,0.533774,0.545062,0.535141,1251,0.73613,0.722622,0.7279,1251
7,0.2331,1.874238,0.846626,0.870113,0.858209,793,0.635417,0.559633,0.595122,109,0.633333,0.452381,0.527778,42,0.777778,0.576471,0.662162,85,0.518519,0.7,0.595745,20,0.448276,0.442177,0.445205,147,0.266667,0.363636,0.307692,55,0.733813,0.589516,0.566344,0.570273,1251,0.738832,0.733813,0.733942,1251
8,0.1023,1.987502,0.873533,0.844893,0.858974,793,0.565891,0.669725,0.613445,109,0.659091,0.690476,0.674419,42,0.691358,0.658824,0.674699,85,0.684211,0.65,0.666667,20,0.422819,0.428571,0.425676,147,0.387097,0.436364,0.410256,55,0.741807,0.612,0.62555,0.617734,1251,0.749776,0.741807,0.745147,1251
9,0.0215,2.066607,0.865311,0.858764,0.862025,793,0.544776,0.669725,0.600823,109,0.517241,0.714286,0.6,42,0.753425,0.647059,0.696203,85,0.684211,0.65,0.666667,20,0.482759,0.380952,0.425856,147,0.3125,0.363636,0.336134,55,0.741807,0.594318,0.61206,0.598244,1251,0.745943,0.741807,0.741706,1251
10,0.0125,2.080004,0.874674,0.844893,0.859525,793,0.541985,0.651376,0.591667,109,0.617021,0.690476,0.651685,42,0.726027,0.623529,0.670886,85,0.736842,0.7,0.717949,20,0.456376,0.462585,0.459459,147,0.333333,0.4,0.363636,55,0.741007,0.612323,0.624694,0.616401,1251,0.75178,0.741007,0.745316,1251


  if isinstance(v, collections.MutableMapping):


Accuracy: 0.7418065547561951


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Others Precision,Others Recall,Others F1-score,Others Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.1941,1.175283,0.77131,0.935687,0.845584,793,0.0,0.0,0.0,109,0.0,0.0,0.0,41,0.242215,0.823529,0.374332,85,0.0,0.0,0.0,20,0.0,0.0,0.0,147,0.0,0.0,0.0,56,0.649081,0.144789,0.251317,0.174274,1251,0.505385,0.649081,0.561444,1251
2,0.923,0.933314,0.869509,0.848676,0.858966,793,0.419355,0.238532,0.304094,109,0.0,0.0,0.0,41,0.268398,0.729412,0.392405,85,0.0,0.0,0.0,20,0.48913,0.612245,0.543807,147,0.0,0.0,0.0,56,0.680256,0.292342,0.346981,0.299896,1251,0.663426,0.680256,0.661551,1251
3,0.788,1.057299,0.897606,0.851198,0.873786,793,0.406015,0.495413,0.446281,109,0.0,0.0,0.0,41,0.393939,0.611765,0.479263,85,0.0,0.0,0.0,20,0.439252,0.639456,0.520776,147,0.25,0.089286,0.131579,56,0.703437,0.340973,0.383874,0.350241,1251,0.693935,0.703437,0.69242,1251
4,0.6824,1.381366,0.885604,0.868852,0.877148,793,0.423611,0.559633,0.482213,109,0.0,0.0,0.0,41,0.519608,0.623529,0.566845,85,1.0,0.1,0.181818,20,0.39899,0.537415,0.457971,147,0.259259,0.125,0.168675,56,0.71223,0.498153,0.402061,0.390667,1251,0.708069,0.71223,0.70082,1251
5,0.4563,1.483635,0.867485,0.891551,0.879353,793,0.452229,0.651376,0.533835,109,0.5,0.02439,0.046512,41,0.590476,0.729412,0.652632,85,0.428571,0.15,0.222222,20,0.503937,0.435374,0.467153,147,0.210526,0.142857,0.170213,56,0.732214,0.507604,0.432137,0.42456,1251,0.721294,0.732214,0.715862,1251
6,0.2369,1.656163,0.869193,0.896595,0.882682,793,0.514493,0.651376,0.574899,109,0.545455,0.146341,0.230769,41,0.662338,0.6,0.62963,85,0.7,0.35,0.466667,20,0.513158,0.530612,0.521739,147,0.333333,0.267857,0.29703,56,0.7506,0.591138,0.491826,0.514774,1251,0.745094,0.7506,0.742025,1251
7,0.1114,1.775691,0.88191,0.885246,0.883575,793,0.440678,0.715596,0.545455,109,0.482759,0.341463,0.4,41,0.679012,0.647059,0.662651,85,0.846154,0.55,0.666667,20,0.515625,0.44898,0.48,147,0.481481,0.232143,0.313253,56,0.7506,0.618231,0.545784,0.564514,1251,0.75506,0.7506,0.746835,1251
8,0.081,1.874965,0.893561,0.857503,0.875161,793,0.437186,0.798165,0.564935,109,0.435897,0.414634,0.425,41,0.776119,0.611765,0.684211,85,0.866667,0.65,0.742857,20,0.521739,0.489796,0.505263,147,0.46875,0.267857,0.340909,56,0.748201,0.62856,0.584246,0.591191,1251,0.76768,0.748201,0.750907,1251
9,0.0218,1.861705,0.860409,0.901639,0.880542,793,0.603448,0.642202,0.622222,109,0.52381,0.536585,0.53012,41,0.786667,0.694118,0.7375,85,0.8,0.6,0.685714,20,0.515873,0.442177,0.47619,147,0.456522,0.375,0.411765,56,0.770584,0.649533,0.598817,0.620579,1251,0.762447,0.770584,0.765218,1251
10,0.0185,1.91651,0.870807,0.883985,0.877347,793,0.561538,0.669725,0.610879,109,0.462963,0.609756,0.526316,41,0.794521,0.682353,0.734177,85,0.866667,0.65,0.742857,20,0.485507,0.455782,0.470175,147,0.555556,0.357143,0.434783,56,0.764988,0.656794,0.615535,0.628076,1251,0.765858,0.764988,0.763091,1251


Accuracy: 0.7705835331734612


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Others Precision,Others Recall,Others F1-score,Others Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.1473,1.092206,0.765756,0.919294,0.83553,793,0.241667,0.268519,0.254386,108,0.0,0.0,0.0,41,0.338983,0.470588,0.394089,85,0.0,0.0,0.0,20,0.3,0.121622,0.173077,148,0.0,0.0,0.0,55,0.6528,0.235201,0.254289,0.236726,1250,0.565247,0.6528,0.59933,1250
2,0.9586,1.343463,0.788889,0.895334,0.838748,793,0.224719,0.185185,0.203046,108,0.0,0.0,0.0,41,0.317365,0.623529,0.420635,85,0.0,0.0,0.0,20,0.340426,0.216216,0.264463,148,0.0,0.0,0.0,55,0.652,0.238771,0.274324,0.246699,1250,0.581774,0.652,0.60956,1250
3,0.7513,1.22751,0.837349,0.876419,0.856439,793,0.321918,0.435185,0.370079,108,0.0,0.0,0.0,41,0.388889,0.576471,0.464455,85,0.333333,0.1,0.153846,20,0.433628,0.331081,0.375479,148,0.275862,0.145455,0.190476,55,0.68,0.37014,0.352087,0.344396,1250,0.654285,0.68,0.662182,1250
4,0.6072,1.511092,0.858667,0.812106,0.834738,793,0.385714,0.5,0.435484,108,0.0,0.0,0.0,41,0.62963,0.6,0.614458,85,1.0,0.1,0.181818,20,0.368421,0.52027,0.431373,148,0.264706,0.327273,0.292683,55,0.6768,0.50102,0.408521,0.39865,1250,0.692147,0.6768,0.675828,1250
5,0.4741,1.856478,0.856404,0.834805,0.845466,793,0.357488,0.685185,0.469841,108,1.0,0.02439,0.047619,41,0.686567,0.541176,0.605263,85,0.7,0.35,0.466667,20,0.376712,0.371622,0.37415,148,0.26087,0.218182,0.237624,55,0.6856,0.605434,0.432194,0.435233,1250,0.720957,0.6856,0.681899,1250
6,0.2399,2.021365,0.84204,0.85372,0.84784,793,0.448276,0.601852,0.513834,108,0.6,0.219512,0.321429,41,0.753425,0.647059,0.696203,85,0.818182,0.45,0.580645,20,0.4,0.391892,0.395904,148,0.280702,0.290909,0.285714,55,0.7112,0.591803,0.493563,0.520224,1250,0.716636,0.7112,0.708886,1250
7,0.1317,2.159104,0.837209,0.862547,0.849689,793,0.507937,0.592593,0.547009,108,0.535714,0.365854,0.434783,41,0.761194,0.6,0.671053,85,0.785714,0.55,0.647059,20,0.360544,0.358108,0.359322,148,0.27451,0.254545,0.264151,55,0.7136,0.580403,0.51195,0.539009,1250,0.711682,0.7136,0.710716,1250
8,0.0438,2.19103,0.850126,0.851198,0.850662,793,0.551181,0.648148,0.595745,108,0.489796,0.585366,0.533333,41,0.771429,0.635294,0.696774,85,0.833333,0.5,0.625,20,0.386207,0.378378,0.382253,148,0.283019,0.272727,0.277778,55,0.7232,0.595013,0.553016,0.565935,1250,0.726977,0.7232,0.723487,1250
9,0.0414,2.370014,0.853846,0.839849,0.84679,793,0.543307,0.638889,0.587234,108,0.489362,0.560976,0.522727,41,0.782609,0.635294,0.701299,85,0.875,0.7,0.777778,20,0.361111,0.351351,0.356164,148,0.283582,0.345455,0.311475,55,0.7176,0.598402,0.581688,0.58621,1250,0.727123,0.7176,0.721093,1250
10,0.007,2.400868,0.84557,0.842371,0.843967,793,0.514925,0.638889,0.570248,108,0.5,0.634146,0.55914,41,0.809524,0.6,0.689189,85,0.866667,0.65,0.742857,20,0.368056,0.358108,0.363014,148,0.307692,0.290909,0.299065,55,0.7168,0.601776,0.573489,0.581069,1250,0.723349,0.7168,0.717912,1250


Accuracy: 0.7232


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Others Precision,Others Recall,Others F1-score,Others Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.1282,1.089031,0.770314,0.896595,0.828671,793,0.191126,0.518519,0.279302,108,0.0,0.0,0.0,41,0.0,0.0,0.0,85,0.0,0.0,0.0,20,0.5,0.114865,0.186813,148,0.0,0.0,0.0,55,0.6272,0.208777,0.218568,0.184969,1250,0.564401,0.6272,0.571959,1250
2,1.0029,1.085252,0.884097,0.827238,0.854723,793,0.34127,0.398148,0.367521,108,0.0,0.0,0.0,41,0.254054,0.552941,0.348148,85,0.0,0.0,0.0,20,0.436548,0.581081,0.498551,148,0.0,0.0,0.0,55,0.6656,0.27371,0.337058,0.295563,1250,0.65932,0.6656,0.656693,1250
3,0.7049,1.341398,0.809795,0.896595,0.850987,793,0.351351,0.240741,0.285714,108,0.0,0.0,0.0,41,0.358491,0.447059,0.397906,85,0.0,0.0,0.0,20,0.370166,0.452703,0.407295,148,0.2,0.036364,0.061538,55,0.6752,0.298543,0.296209,0.286206,1250,0.621096,0.6752,0.642541,1250
4,0.6216,1.542172,0.833537,0.858764,0.845963,793,0.416107,0.574074,0.48249,108,0.0,0.0,0.0,41,0.388889,0.494118,0.435233,85,0.0,0.0,0.0,20,0.41958,0.405405,0.412371,148,0.433333,0.236364,0.305882,55,0.6864,0.355921,0.366961,0.354563,1250,0.659937,0.6864,0.670245,1250
5,0.377,1.911284,0.825208,0.875158,0.849449,793,0.403361,0.444444,0.422907,108,0.0,0.0,0.0,41,0.758621,0.258824,0.385965,85,0.571429,0.4,0.470588,20,0.343434,0.459459,0.393064,148,0.234043,0.2,0.215686,55,0.6808,0.448014,0.376841,0.391094,1250,0.670052,0.6808,0.665234,1250
6,0.2216,1.86011,0.83774,0.878941,0.857846,793,0.478261,0.611111,0.536585,108,0.428571,0.292683,0.347826,41,0.771429,0.317647,0.45,85,0.705882,0.6,0.648649,20,0.421053,0.432432,0.426667,148,0.375,0.327273,0.349515,55,0.7168,0.573991,0.494298,0.516727,1250,0.716945,0.7168,0.708862,1250
7,0.0892,2.003365,0.853933,0.862547,0.858218,793,0.440252,0.648148,0.524345,108,0.444444,0.292683,0.352941,41,0.672414,0.458824,0.545455,85,0.611111,0.55,0.578947,20,0.473684,0.425676,0.448399,148,0.333333,0.327273,0.330275,55,0.7176,0.547024,0.509307,0.519797,1250,0.720603,0.7176,0.71531,1250
8,0.0546,2.063924,0.853015,0.856242,0.854626,793,0.503759,0.62037,0.556017,108,0.466667,0.341463,0.394366,41,0.722222,0.458824,0.561151,85,0.695652,0.8,0.744186,20,0.434783,0.472973,0.453074,148,0.377358,0.363636,0.37037,55,0.724,0.579065,0.559073,0.56197,1250,0.728308,0.724,0.723155,1250
9,0.0471,2.157789,0.833726,0.891551,0.86167,793,0.514286,0.666667,0.580645,108,0.472222,0.414634,0.441558,41,0.813953,0.411765,0.546875,85,0.666667,0.7,0.682927,20,0.483051,0.385135,0.428571,148,0.386364,0.309091,0.343434,55,0.7352,0.595753,0.539835,0.555097,1250,0.729048,0.7352,0.725262,1250
10,0.0237,2.116953,0.852357,0.86633,0.859287,793,0.514706,0.648148,0.57377,108,0.511628,0.536585,0.52381,41,0.78,0.458824,0.577778,85,0.681818,0.75,0.714286,20,0.457746,0.439189,0.448276,148,0.372549,0.345455,0.358491,55,0.7336,0.595829,0.57779,0.579385,1250,0.736526,0.7336,0.731453,1250


Accuracy: 0.7352


## Interpret evaluation

### Helper functions

In [13]:
def mean(data):
    """Arithmetic mean of a non-empty sized iterable.

    Raises:
        ValueError: if ``data`` is empty.
    """
    count = len(data)
    if count == 0:
        raise ValueError('mean requires at least one data point')
    total = sum(data)
    return total / count

def _ss(data):
    """Sum of squared deviations of ``data`` from its mean (via ``mean``)."""
    centre = mean(data)
    return sum((value - centre) ** 2 for value in data)

def stddev(data, ddof=0):
    """Standard deviation of ``data``.

    ``ddof=0`` (default) gives the population standard deviation; pass
    ``ddof=1`` for the sample standard deviation.

    Raises:
        ValueError: if ``data`` has fewer than two points.
    """
    size = len(data)
    if size < 2:
        raise ValueError('variance requires at least two data points')
    variance = _ss(data) / (size - ddof)
    return variance ** 0.5

### Prepare Data

In [14]:
# Raw evaluation dicts, one per fold, as returned by trainer.evaluate() in the training loop.
result

[{'eval_loss': 1.9875022172927856,
  'eval_none__precision': 0.8735332464146024,
  'eval_none__recall': 0.8448928121059268,
  'eval_none__f1-score': 0.8589743589743589,
  'eval_none__support': 793,
  'eval_causal__precision': 0.5658914728682171,
  'eval_causal__recall': 0.6697247706422018,
  'eval_causal__f1-score': 0.6134453781512605,
  'eval_causal__support': 109,
  'eval_contrast__precision': 0.6590909090909091,
  'eval_contrast__recall': 0.6904761904761905,
  'eval_contrast__f1-score': 0.6744186046511628,
  'eval_contrast__support': 42,
  'eval_equivalence__precision': 0.691358024691358,
  'eval_equivalence__recall': 0.6588235294117647,
  'eval_equivalence__f1-score': 0.674698795180723,
  'eval_equivalence__support': 85,
  'eval_identity__precision': 0.6842105263157895,
  'eval_identity__recall': 0.65,
  'eval_identity__f1-score': 0.6666666666666667,
  'eval_identity__support': 20,
  'eval_temporal__precision': 0.4228187919463087,
  'eval_temporal__recall': 0.42857142857142855,
  '

In [15]:

def transform_to_regular_dict(result):
    """Pivot a list of per-fold metric dicts into a column-oriented dict.

    ``[{"eval_loss": 1.9, ...}, {"eval_loss": 1.8, ...}]`` becomes
    ``{"eval_loss": [1.9, 1.8], ...}``, ready for ``pd.DataFrame``.

    Args:
        result: iterable of dicts that all share the same keys.

    Returns:
        Dict mapping each key (in the first dict's order) to a list holding
        one ``float`` per input dict; ``{}`` when ``result`` is empty.

    Raises:
        ValueError: if a dict's keys differ from those of the first dict.
    """
    output_dict = {}
    for position, eval_item in enumerate(result):
        if position == 0:
            output_dict = {key: [] for key in eval_item}
        elif set(eval_item) != set(output_dict):
            raise ValueError(f"metric dict at position {position} has different keys than the first one")
        # Cast every value (not only the first fold's) so each column has a
        # single numeric type.
        for key in output_dict:
            output_dict[key].append(float(eval_item[key]))
    return output_dict
            
# One column per metric, one row per fold.
eval_dict = transform_to_regular_dict(result)
eval_df = pd.DataFrame.from_dict(eval_dict)

def add_mean_std_row(df):
    """Append summary rows to a per-fold metrics table.

    Adds an ``avg`` row (column means) and a ``std`` row (sample standard
    deviation, ddof=1), then indexes the rows by a ``fold`` column:
    ``"1" .. "<n>"`` for the input rows, followed by ``"avg"`` and ``"std"``.

    Args:
        df: one row per fold, numeric columns only.

    Returns:
        A new DataFrame; ``df`` is not modified.

    Raises:
        ValueError: if ``df`` has fewer than two rows (no sample std).
    """
    if len(df) < 2:
        raise ValueError('variance requires at least two data points')
    # DataFrame.append was removed in pandas 2.0; build both summary rows at
    # once and concatenate instead.
    summary = df.agg(["mean", "std"])  # pandas std uses ddof=1
    out = pd.concat([df, summary], ignore_index=True)
    # Derive fold labels from the row count instead of hard-coding 4 folds.
    out["fold"] = [str(fold) for fold in range(1, len(df) + 1)] + ["avg", "std"]
    return out.set_index("fold")

# Attach avg/std summary rows and render the full table as HTML.
eval_df = eval_df.pipe(add_mean_std_row)
display(HTML(eval_df.to_html()))

Unnamed: 0_level_0,eval_loss,eval_none__precision,eval_none__recall,eval_none__f1-score,eval_none__support,eval_causal__precision,eval_causal__recall,eval_causal__f1-score,eval_causal__support,eval_contrast__precision,eval_contrast__recall,eval_contrast__f1-score,eval_contrast__support,eval_equivalence__precision,eval_equivalence__recall,eval_equivalence__f1-score,eval_equivalence__support,eval_identity__precision,eval_identity__recall,eval_identity__f1-score,eval_identity__support,eval_temporal__precision,eval_temporal__recall,eval_temporal__f1-score,eval_temporal__support,eval_others__precision,eval_others__recall,eval_others__f1-score,eval_others__support,eval_accuracy,eval_macro avg__precision,eval_macro avg__recall,eval_macro avg__f1-score,eval_macro avg__support,eval_weighted avg__precision,eval_weighted avg__recall,eval_weighted avg__f1-score,eval_weighted avg__support,eval_runtime,eval_samples_per_second,epoch
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
1,1.987502,0.873533,0.844893,0.858974,793.0,0.565891,0.669725,0.613445,109.0,0.659091,0.690476,0.674419,42.0,0.691358,0.658824,0.674699,85.0,0.684211,0.65,0.666667,20.0,0.422819,0.428571,0.425676,147.0,0.387097,0.436364,0.410256,55.0,0.741807,0.612,0.62555,0.617734,1251.0,0.749776,0.741807,0.745147,1251.0,9.7904,127.779,10.0
2,1.861705,0.860409,0.901639,0.880542,793.0,0.603448,0.642202,0.622222,109.0,0.52381,0.536585,0.53012,41.0,0.786667,0.694118,0.7375,85.0,0.8,0.6,0.685714,20.0,0.515873,0.442177,0.47619,147.0,0.456522,0.375,0.411765,56.0,0.770584,0.649533,0.598817,0.620579,1251.0,0.762447,0.770584,0.765218,1251.0,9.9274,126.014,10.0
3,2.19103,0.850126,0.851198,0.850662,793.0,0.551181,0.648148,0.595745,108.0,0.489796,0.585366,0.533333,41.0,0.771429,0.635294,0.696774,85.0,0.833333,0.5,0.625,20.0,0.386207,0.378378,0.382253,148.0,0.283019,0.272727,0.277778,55.0,0.7232,0.595013,0.553016,0.565935,1250.0,0.726977,0.7232,0.723487,1250.0,9.5813,130.463,10.0
4,2.157789,0.833726,0.891551,0.86167,793.0,0.514286,0.666667,0.580645,108.0,0.472222,0.414634,0.441558,41.0,0.813953,0.411765,0.546875,85.0,0.666667,0.7,0.682927,20.0,0.483051,0.385135,0.428571,148.0,0.386364,0.309091,0.343434,55.0,0.7352,0.595753,0.539835,0.555097,1250.0,0.729048,0.7352,0.725262,1250.0,12.2108,102.368,10.0
avg,2.049506,0.854449,0.87232,0.862962,793.0,0.558702,0.656685,0.603014,108.5,0.53623,0.556765,0.544858,41.25,0.765852,0.6,0.663962,85.0,0.746053,0.6125,0.665077,20.0,0.451987,0.408565,0.428173,147.5,0.37825,0.348295,0.360808,55.25,0.742698,0.613075,0.579305,0.589836,1250.5,0.742062,0.742698,0.739779,1250.5,10.377475,121.656,10.0
std,0.153696,0.016811,0.028448,0.012622,0.0,0.036893,0.013568,0.018538,0.57735,0.08466,0.114463,0.096275,0.5,0.052686,0.127797,0.082278,0.0,0.082976,0.085391,0.028007,0.0,0.058379,0.031571,0.038387,0.57735,0.071506,0.07238,0.063869,0.5,0.020123,0.025538,0.039868,0.034164,0.57735,0.017049,0.020123,0.019597,0.57735,1.230474,12.988116,0.0


# Save

## Model

In [16]:
# Persist the fine-tuned model of every fold (`models` holds one Trainer per fold).
# NOTE(review): this is an absolute path, unlike the relative data/ paths used
# elsewhere in this notebook -- confirm that is intended.
model_dir = (
    "/data/experiments/raring/semantic_storytelling/data/model-subset/"
    f"{model_checkpoint.replace(r'/', '-')}/epoch_{num_epoch}"
)
for fold, fold_trainer in enumerate(models):
    fold_trainer.save_model(f"{model_dir}/fold_{fold}")

## Metrics

In [17]:
from pathlib import Path

# Write the per-fold metrics plus avg/std rows. Create the folder first because
# DataFrame.to_csv does not create missing parent directories.
eval_csv_path = Path(f"data/eval-subset/{model_checkpoint.replace(r'/', '-')}_epoch_{num_epoch}.csv")
eval_csv_path.parent.mkdir(parents=True, exist_ok=True)
eval_df.to_csv(eval_csv_path)