# Parameters 

In [1]:
# Model
model_checkpoint = 'bert-base-cased'  # checkpoint id used for both the tokenizer and the model
batch_size = 2                          # per-device train AND eval batch size
metric_name = "accuracy"                # metric_for_best_model when restoring the best epoch
num_epoch = 5                           # training epochs per fold

# Fold
num_folds = 4                           # number of cross-validation folds (data/.../train.{i}.csv, test.{i}.csv)

# Experiment
# Class names; position i is the name of integer label i.
labels = [
    "none",
    "attribution",
    "causal",
    "conditional",
    "contrast",
    "description",
    "equivalence",
    "fulfillment",
    "identity",
    "purpose",
    "summary",
    "temporal",
]

# Import

In [2]:
import torch
import numpy as np
import random
import pandas as pd
from IPython.display import display, HTML

In [3]:
def import_fold(path, fold):
    """Load one cross-validation fold from ``{path}/train.{fold}.csv`` and ``{path}/test.{fold}.csv``.

    Each CSV must have ``origin``, ``target`` and ``label`` columns.

    Returns:
        6-tuple of plain Python lists:
        (train_origin, train_target, train_labels, test_origin, test_target, test_labels)
    """
    # Read both splits up front, then pull the three columns out of each.
    frames = (
        pd.read_csv(f"{path}/train.{fold}.csv"),
        pd.read_csv(f"{path}/test.{fold}.csv"),
    )
    return tuple(
        frame[column].tolist()
        for frame in frames
        for column in ("origin", "target", "label")
    )

# Model setup

## Metric

In [4]:
from sklearn.metrics import classification_report
import collections

#classification_threshold = 0.

def flatten(d, parent_key='', sep='__'):
    """Flatten a nested mapping into a single-level dict.

    Nested keys are joined with ``sep``, e.g. ``{'a': {'b': 1}}`` -> ``{'a__b': 1}``.
    Used by ``compute_metrics`` to turn the nested ``classification_report``
    dict into flat ``name -> value`` metric entries.

    Args:
        d: mapping to flatten; values that are themselves mappings are recursed into.
        parent_key: prefix prepended to every key of ``d`` (used by the recursion).
        sep: separator placed between a parent key and a child key.

    Returns:
        A new flat dict (``d`` is not modified). Later duplicates overwrite earlier ones.
    """
    # ``collections.MutableMapping`` was removed in Python 3.10 (the alias only
    # ever re-exported ``collections.abc``). ``Mapping`` also accepts read-only
    # mappings, so this is a superset of the old check.
    from collections.abc import Mapping

    flat = {}
    for k, v in d.items():
        # f-string so non-str keys (e.g. ints) can be nested too.
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, Mapping):
            flat.update(flatten(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat

def compute_metrics(eval_pred):
    """Per-class and aggregate classification metrics for the HF ``Trainer``.

    Args:
        eval_pred: ``(predictions, true_labels)`` pair. ``predictions`` holds one
            score per class along the last axis (argmax is taken there);
            ``true_labels`` holds integer class ids indexing the global ``labels``.

    Returns:
        The ``classification_report`` dict flattened with ``flatten``, e.g.
        ``{'accuracy': ..., 'causal__precision': ..., 'macro avg__f1-score': ...}``.
    """
    predictions, true_labels = eval_pred
    # take most probable guess
    predictions = np.argmax(predictions, axis=-1)
    return flatten(classification_report(
        y_true=true_labels,
        y_pred=predictions,
        # Pin the full label set. Otherwise sklearn infers the classes from
        # y_true/y_pred and raises ValueError whenever some class appears in
        # neither (easy with 5-6 example classes), because the inferred count no
        # longer matches len(target_names). Pinning also keeps the metric keys
        # identical across epochs and folds (absent classes get support 0), and
        # since the list covers every class the 'accuracy' entry is still emitted.
        labels=list(range(len(labels))),
        target_names=labels,
        zero_division=0,
        output_dict=True))

In [5]:
#TEST
#flatten(classification_report(
#    y_true=[0,1,2,3,4,5,6,7,8,9,10,11,12],
#    y_pred=[0,0,0,1,3,0,0,0,0,0,0,0,0],
#    target_names=labels,
#    zero_division=0,
#    output_dict=True))

## Model Settings

In [6]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Training configuration shared by every fold of the cross-validation loop.
args = TrainingArguments(
    # output_dir. NOTE(review): shared by all folds, so checkpoint directories of
    # later folds may overwrite earlier ones.
    "semantic-test",
    evaluation_strategy = "epoch",
    # load_best_model_at_end can only restore checkpoints that were saved, so save
    # at the same points where metrics are computed. Newer transformers releases
    # raise a ValueError when the eval and save strategies differ.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epoch,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

## Tokenize

In [7]:
from transformers import AutoTokenizer, BertTokenizerFast, DebertaTokenizerFast
# Resolve the tokenizer class from the checkpoint itself, mirroring the
# AutoModelForSequenceClassification used for the model. Changing
# `model_checkpoint` (e.g. to a DeBERTa checkpoint) then keeps tokenizer and
# model in sync without editing this cell.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

## Print Examples

In [8]:
#train_encodings

In [9]:
def show_random_elements(origin_list, target_list, label_list, encodings, num_examples=10):
    """Render ``num_examples`` distinct, randomly chosen examples as an HTML table.

    Args:
        origin_list / target_list: the two sentences of each pair.
        label_list: integer label ids, mapped to names through the global ``labels``.
        encodings: tokenizer output exposing ``input_ids``, ``token_type_ids`` and
            ``attention_mask`` as per-example sequences.
        num_examples: number of rows to show.

    Raises:
        AssertionError: if ``num_examples`` exceeds the number of examples.
    """
    assert num_examples <= len(origin_list), "Can't pick more elements than there are in the dataset."
    # random.sample draws distinct indices directly. The previous
    # redraw-until-unseen loop was quadratic and slowed sharply as
    # num_examples approached the dataset size.
    picks = random.sample(range(len(origin_list)), num_examples)
    data = [
        [n, origin_list[n], labels[label_list[n]], target_list[n],
         encodings.input_ids[n], encodings.token_type_ids[n], encodings.attention_mask[n]]
        for n in picks
    ]
    df = pd.DataFrame(data, columns=['index', 'Origin', 'Label', 'Target', 'Input_ids', 'Token_type_ids', 'Attention_mask'])
    display(HTML(df.to_html()))

In [10]:
# show_random_elements(train_origin, train_target, train_labels, train_encodings)
# Output adjusted to folds
#show_random_elements(k_fold_origin[0][0], k_fold_target[0][0], k_fold_labels[0][0], train_encodings[0])

## Create Dataset

In [11]:
class SemanticDataset(torch.utils.data.Dataset):
    """Map-style torch dataset pairing tokenizer encodings with class labels.

    ``encodings`` maps a feature name (e.g. ``input_ids``) to per-example values;
    ``labels`` holds one class id per example. Each item is a dict of tensors,
    one per encoding feature, plus a ``'labels'`` entry.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings  # feature name -> per-example values
        self.labels = labels        # one label per example; also defines the length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for feature, values in self.encodings.items():
            sample[feature] = torch.tensor(values[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Cross-validation training

## Train & Evaluate

In [12]:
from transformers import set_seed

result = []
num_labels = len(labels)
models = []

for i in range(num_folds):
    # Seed before building the model. from_pretrained randomly initialises the
    # new classification head (see the "were not initialized from the model
    # checkpoint" warning in the output), and that happens before the Trainer
    # exists. Without an explicit seed, every run starts from a different head.
    set_seed(args.seed)
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
    # import Fold data
    # NOTE(review): data is read from "export-ohnetime" while the save cells write
    # to "*-ohnealles" directories. Confirm the experiment naming is consistent.
    train_origin, train_target, train_labels, test_origin, test_target, test_labels = import_fold("data/export-ohnetime", i)
    # tokenize (origin, target) sentence pairs
    train_encodings = tokenizer(train_origin, train_target, truncation=True, padding=True, return_token_type_ids=True)
    test_encodings = tokenizer(test_origin, test_target, truncation=True, padding=True, return_token_type_ids=True)
    # dataset creation
    train_dataset = SemanticDataset(train_encodings, train_labels)
    test_dataset = SemanticDataset(test_encodings, test_labels)
    # create Trainer
    # NOTE(review): the test fold is also the eval_dataset that
    # load_best_model_at_end uses to pick the best epoch. The score reported
    # below is therefore selected on the same data it is measured on
    # (optimistic). TODO: select the epoch on a validation split carved from
    # the training fold.
    trainer = Trainer(
        model,
        args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )
    # train & evaluate
    trainer.train()
    ev = trainer.evaluate(test_dataset)
    acc = ev["eval_accuracy"]
    print(f"Accuracy: {acc}")
    result.append(ev)
    # Trainers are kept so the save cell can call save_model on each fold.
    # NOTE(review): this keeps every fold's model alive in memory.
    models.append(trainer)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Attribution Precision,Attribution Recall,Attribution F1-score,Attribution Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Conditional Precision,Conditional Recall,Conditional F1-score,Conditional Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Description Precision,Description Recall,Description F1-score,Description Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Fulfillment Precision,Fulfillment Recall,Fulfillment F1-score,Fulfillment Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Purpose Precision,Purpose Recall,Purpose F1-score,Purpose Support,Summary Precision,Summary Recall,Summary F1-score,Summary Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.2565,1.223622,0.722593,0.955864,0.823018,793,0.0,0.0,0.0,6,0.11215,0.110092,0.111111,109,0.0,0.0,0.0,14,0.0,0.0,0.0,42,0.0,0.0,0.0,13,0.181818,0.023529,0.041667,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.25,0.142857,0.181818,147,0.633893,0.105547,0.102695,0.096468,1251,0.509548,0.633893,0.555582,1251
2,0.9951,1.271319,0.84129,0.822194,0.831633,793,0.0,0.0,0.0,6,0.290323,0.412844,0.340909,109,0.0,0.0,0.0,14,0.0,0.0,0.0,42,0.0,0.0,0.0,13,0.27,0.635294,0.378947,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.396694,0.326531,0.358209,147,0.638689,0.149859,0.183072,0.159142,1251,0.623543,0.638689,0.624709,1251
3,0.6784,1.434299,0.822542,0.865069,0.84327,793,0.0,0.0,0.0,6,0.335714,0.431193,0.37751,109,0.0,0.0,0.0,14,0.0,0.0,0.0,42,0.0,0.0,0.0,13,0.339623,0.635294,0.442623,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.381356,0.306122,0.339623,147,0.665068,0.156603,0.186473,0.166919,1251,0.618542,0.665068,0.637417,1251
4,0.5555,1.478074,0.839853,0.86633,0.852886,793,0.0,0.0,0.0,6,0.402985,0.495413,0.444444,109,0.0,0.0,0.0,14,0.333333,0.02381,0.044444,42,0.0,0.0,0.0,13,0.416107,0.729412,0.529915,85,0.0,0.0,0.0,11,0.434783,0.5,0.465116,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.41129,0.346939,0.376384,147,0.691447,0.236529,0.246825,0.226099,1251,0.662233,0.691447,0.668524,1251
5,0.3642,1.597566,0.841388,0.856242,0.84875,793,0.0,0.0,0.0,6,0.377358,0.550459,0.447761,109,0.25,0.071429,0.111111,14,0.571429,0.095238,0.163265,42,0.0,0.0,0.0,13,0.435185,0.552941,0.487047,85,0.0,0.0,0.0,11,0.434783,0.5,0.465116,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.391608,0.380952,0.386207,147,0.685052,0.275146,0.250605,0.242438,1251,0.670748,0.685052,0.669665,1251


  if isinstance(v, collections.MutableMapping):


Accuracy: 0.6914468425259792


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Attribution Precision,Attribution Recall,Attribution F1-score,Attribution Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Conditional Precision,Conditional Recall,Conditional F1-score,Conditional Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Description Precision,Description Recall,Description F1-score,Description Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Fulfillment Precision,Fulfillment Recall,Fulfillment F1-score,Fulfillment Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Purpose Precision,Purpose Recall,Purpose F1-score,Purpose Support,Summary Precision,Summary Recall,Summary F1-score,Summary Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.2129,1.084501,0.721956,0.949559,0.820261,793,0.0,0.0,0.0,6,0.195122,0.293578,0.234432,109,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.0,0.0,0.0,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.363636,0.108844,0.167539,147,0.640288,0.106726,0.112665,0.101853,1251,0.517373,0.640288,0.560071,1251
2,1.0431,1.064208,0.798165,0.87768,0.836036,793,0.0,0.0,0.0,6,0.324074,0.321101,0.322581,109,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.30597,0.482353,0.374429,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.445255,0.414966,0.429577,147,0.665867,0.156122,0.174675,0.163552,1251,0.607297,0.665867,0.633983,1251
3,0.8876,1.290893,0.871728,0.839849,0.855491,793,0.0,0.0,0.0,6,0.375887,0.486239,0.424,109,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.36478,0.682353,0.47541,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.42246,0.537415,0.473054,147,0.684253,0.169571,0.212155,0.185663,1251,0.65976,0.684253,0.667122,1251
4,0.6098,1.543676,0.836735,0.878941,0.857319,793,0.0,0.0,0.0,6,0.367021,0.633028,0.464646,109,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.40404,0.470588,0.434783,85,0.0,0.0,0.0,11,1.0,0.1,0.181818,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.387597,0.340136,0.362319,147,0.685851,0.249616,0.201891,0.19174,1251,0.651364,0.685851,0.658956,1251
5,0.5061,1.696877,0.843521,0.870113,0.856611,793,0.0,0.0,0.0,6,0.342857,0.550459,0.422535,109,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.493506,0.447059,0.469136,85,0.0,0.0,0.0,11,0.5,0.55,0.52381,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.375796,0.401361,0.388158,147,0.685851,0.212973,0.234916,0.221687,1251,0.650259,0.685851,0.665676,1251


Accuracy: 0.6858513189448441


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Attribution Precision,Attribution Recall,Attribution F1-score,Attribution Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Conditional Precision,Conditional Recall,Conditional F1-score,Conditional Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Description Precision,Description Recall,Description F1-score,Description Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Fulfillment Precision,Fulfillment Recall,Fulfillment F1-score,Fulfillment Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Purpose Precision,Purpose Recall,Purpose F1-score,Purpose Support,Summary Precision,Summary Recall,Summary F1-score,Summary Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.2983,1.241064,0.669795,0.989912,0.798982,793,0.0,0.0,0.0,5,0.0,0.0,0.0,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.294872,0.270588,0.282209,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.0,0.0,0.0,148,0.6464,0.080389,0.105042,0.090099,1250,0.444969,0.6464,0.526064,1250
2,1.0758,1.309541,0.808962,0.865069,0.836076,793,0.0,0.0,0.0,5,0.291139,0.212963,0.245989,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.276316,0.741176,0.402556,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.315789,0.202703,0.246914,148,0.6416,0.141017,0.168493,0.144295,1250,0.594539,0.6416,0.608268,1250
3,0.8579,1.400006,0.836735,0.827238,0.831959,793,0.0,0.0,0.0,5,0.29661,0.324074,0.309735,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,13,0.397959,0.458824,0.42623,85,0.0,0.0,0.0,11,0.0,0.0,0.0,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.308,0.52027,0.386935,148,0.6456,0.153275,0.177534,0.162905,1250,0.61998,0.6456,0.629353,1250
4,0.6048,1.723444,0.832099,0.849937,0.840923,793,0.0,0.0,0.0,5,0.352941,0.444444,0.393443,108,0.0,0.0,0.0,14,0.2,0.02439,0.043478,41,0.0,0.0,0.0,13,0.404959,0.576471,0.475728,85,0.0,0.0,0.0,11,0.3125,0.25,0.277778,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.345679,0.378378,0.36129,148,0.6664,0.204015,0.210302,0.199387,1250,0.638403,0.6664,0.648472,1250
5,0.4243,1.870731,0.852523,0.831021,0.841635,793,0.0,0.0,0.0,5,0.327778,0.546296,0.409722,108,0.0,0.0,0.0,14,0.625,0.121951,0.204082,41,0.0,0.0,0.0,13,0.473118,0.517647,0.494382,85,0.0,0.0,0.0,11,0.310345,0.45,0.367347,20,0.0,0.0,0.0,6,0.0,0.0,0.0,6,0.347305,0.391892,0.368254,148,0.6672,0.244672,0.238234,0.223785,1250,0.667919,0.6672,0.659124,1250


Accuracy: 0.6672


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Epoch,Training Loss,Validation Loss,None Precision,None Recall,None F1-score,None Support,Attribution Precision,Attribution Recall,Attribution F1-score,Attribution Support,Causal Precision,Causal Recall,Causal F1-score,Causal Support,Conditional Precision,Conditional Recall,Conditional F1-score,Conditional Support,Contrast Precision,Contrast Recall,Contrast F1-score,Contrast Support,Description Precision,Description Recall,Description F1-score,Description Support,Equivalence Precision,Equivalence Recall,Equivalence F1-score,Equivalence Support,Fulfillment Precision,Fulfillment Recall,Fulfillment F1-score,Fulfillment Support,Identity Precision,Identity Recall,Identity F1-score,Identity Support,Purpose Precision,Purpose Recall,Purpose F1-score,Purpose Support,Summary Precision,Summary Recall,Summary F1-score,Summary Support,Temporal Precision,Temporal Recall,Temporal F1-score,Temporal Support,Accuracy,Macro avg Precision,Macro avg Recall,Macro avg F1-score,Macro avg Support,Weighted avg Precision,Weighted avg Recall,Weighted avg F1-score,Weighted avg Support
1,1.3381,1.201117,0.739979,0.907945,0.815402,793,0.0,0.0,0.0,6,0.083333,0.009259,0.016667,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,14,0.202381,0.4,0.268775,85,0.0,0.0,0.0,10,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.381443,0.25,0.302041,148,0.6336,0.117261,0.1306,0.116907,1250,0.535568,0.6336,0.572769,1250
2,1.1451,1.275135,0.827543,0.84111,0.834271,793,0.0,0.0,0.0,6,0.277778,0.324074,0.299145,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,14,0.243478,0.329412,0.28,85,0.0,0.0,0.0,10,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.384236,0.527027,0.444444,148,0.6464,0.14442,0.168469,0.154822,1250,0.611044,0.6464,0.62677,1250
3,0.9185,1.364766,0.841902,0.825977,0.833864,793,0.0,0.0,0.0,6,0.335484,0.481481,0.395437,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,14,0.301205,0.588235,0.398406,85,0.0,0.0,0.0,10,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.403974,0.412162,0.408027,148,0.6544,0.15688,0.192321,0.169645,1250,0.631401,0.6544,0.638571,1250
4,0.7565,1.570095,0.835411,0.844893,0.840125,793,0.0,0.0,0.0,6,0.306878,0.537037,0.390572,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,14,0.355372,0.505882,0.417476,85,0.0,0.0,0.0,10,0.0,0.0,0.0,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.405797,0.378378,0.391608,148,0.6616,0.158622,0.188849,0.169982,1250,0.628711,0.6616,0.641476,1250
5,0.5618,1.74383,0.832099,0.849937,0.840923,793,0.0,0.0,0.0,6,0.319767,0.509259,0.392857,108,0.0,0.0,0.0,14,0.0,0.0,0.0,41,0.0,0.0,0.0,14,0.363636,0.517647,0.427184,85,0.0,0.0,0.0,10,1.0,0.1,0.181818,20,0.0,0.0,0.0,5,0.0,0.0,0.0,6,0.37931,0.371622,0.375427,148,0.664,0.241234,0.195705,0.184851,1250,0.641149,0.664,0.643833,1250


Accuracy: 0.664


## Interpret evaluation

### Helper functions

In [13]:
def mean(data):
    """Return the sample arithmetic mean of data."""
    n = len(data)
    if n < 1:
        raise ValueError('mean requires at least one data point')
    return sum(data)/n # in Python 2 use sum(data)/float(n)

def _ss(data):
    """Return sum of square deviations of sequence data."""
    c = mean(data)
    ss = sum((x-c)**2 for x in data)
    return ss

def stddev(data, ddof=0):
    """Calculates the population standard deviation
    by default; specify ddof=1 to compute the sample
    standard deviation."""
    n = len(data)
    if n < 2:
        raise ValueError('variance requires at least two data points')
    ss = _ss(data)
    pvar = ss/(n-ddof)
    return pvar**0.5

### Prepare Data

In [14]:
# Raw evaluation dicts returned by trainer.evaluate(), one per fold in fold order;
# they are tabulated together with avg/std rows in the next cell.
result

[{'eval_loss': 1.478074312210083,
  'eval_none__precision': 0.8398533007334963,
  'eval_none__recall': 0.8663303909205549,
  'eval_none__f1-score': 0.8528864059590316,
  'eval_none__support': 793,
  'eval_attribution__precision': 0.0,
  'eval_attribution__recall': 0.0,
  'eval_attribution__f1-score': 0.0,
  'eval_attribution__support': 6,
  'eval_causal__precision': 0.40298507462686567,
  'eval_causal__recall': 0.4954128440366973,
  'eval_causal__f1-score': 0.4444444444444445,
  'eval_causal__support': 109,
  'eval_conditional__precision': 0.0,
  'eval_conditional__recall': 0.0,
  'eval_conditional__f1-score': 0.0,
  'eval_conditional__support': 14,
  'eval_contrast__precision': 0.3333333333333333,
  'eval_contrast__recall': 0.023809523809523808,
  'eval_contrast__f1-score': 0.044444444444444446,
  'eval_contrast__support': 42,
  'eval_description__precision': 0.0,
  'eval_description__recall': 0.0,
  'eval_description__f1-score': 0.0,
  'eval_description__support': 13,
  'eval_equival

In [15]:

def transform_to_regular_dict(result):
    """Turn a list of per-fold metric dicts into a column-oriented dict.

    ``[{'m': a1, 'n': b1}, {'m': a2, 'n': b2}]`` -> ``{'m': [a1, a2], 'n': [b1, b2]}``,
    which ``pd.DataFrame`` turns into one row per fold.

    Every value is converted to ``float`` so each column has a single numeric
    type. A metric missing from some fold is filled with ``nan`` rather than
    raising ``KeyError``, so all columns keep the same length.

    Args:
        result: iterable of dicts, one per fold (e.g. ``trainer.evaluate()`` output).

    Returns:
        dict mapping each metric name, in first-seen order, to a list holding
        one float per fold.
    """
    folds = list(result)
    # Union of keys across all folds, preserving first-seen order.
    keys = list(dict.fromkeys(key for fold in folds for key in fold))
    return {
        key: [float(fold.get(key, float("nan"))) for fold in folds]
        for key in keys
    }
            
eval_dict = transform_to_regular_dict(result)
eval_df = pd.DataFrame(eval_dict)

def add_mean_std_row(df):
    """Append an ``avg`` row and a sample-``std`` row to a one-row-per-fold frame.

    Args:
        df: DataFrame with one row per fold; ``mean``/``stddev`` are applied to
            every column, so all columns must be numeric.

    Returns:
        A new DataFrame indexed by ``fold``: labels ``"1"``..``"n"`` for the
        original rows, followed by ``"avg"`` and ``"std"``. ``df`` is not modified.

    Raises:
        ValueError: from ``stddev`` when a column has fewer than two values.
    """
    num_fold_rows = len(df)
    row_mean = []
    row_std = []
    for column in df:
        row_mean.append(mean(df[column]))
        # ddof=1: sample standard deviation across folds
        row_std.append(stddev(df[column], ddof=1))
    # DataFrame.append was removed in pandas 2.0; concat builds the same frame.
    summary_rows = pd.DataFrame([row_mean, row_std], columns=df.columns)
    df = pd.concat([df, summary_rows], ignore_index=True)
    # Readable index derived from the actual row count, instead of a hard-coded
    # list that only fit exactly four folds.
    df["fold"] = [str(n) for n in range(1, num_fold_rows + 1)] + ["avg", "std"]
    df = df.set_index("fold")
    return df

eval_df = add_mean_std_row(eval_df)
display(HTML(eval_df.to_html()))

Unnamed: 0_level_0,eval_loss,eval_none__precision,eval_none__recall,eval_none__f1-score,eval_none__support,eval_attribution__precision,eval_attribution__recall,eval_attribution__f1-score,eval_attribution__support,eval_causal__precision,eval_causal__recall,eval_causal__f1-score,eval_causal__support,eval_conditional__precision,eval_conditional__recall,eval_conditional__f1-score,eval_conditional__support,eval_contrast__precision,eval_contrast__recall,eval_contrast__f1-score,eval_contrast__support,eval_description__precision,eval_description__recall,eval_description__f1-score,eval_description__support,eval_equivalence__precision,eval_equivalence__recall,eval_equivalence__f1-score,eval_equivalence__support,eval_fulfillment__precision,eval_fulfillment__recall,eval_fulfillment__f1-score,eval_fulfillment__support,eval_identity__precision,eval_identity__recall,eval_identity__f1-score,eval_identity__support,eval_purpose__precision,eval_purpose__recall,eval_purpose__f1-score,eval_purpose__support,eval_summary__precision,eval_summary__recall,eval_summary__f1-score,eval_summary__support,eval_temporal__precision,eval_temporal__recall,eval_temporal__f1-score,eval_temporal__support,eval_accuracy,eval_macro avg__precision,eval_macro avg__recall,eval_macro avg__f1-score,eval_macro avg__support,eval_weighted avg__precision,eval_weighted avg__recall,eval_weighted avg__f1-score,eval_weighted avg__support,eval_runtime,eval_samples_per_second,epoch
fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1
1,1.478074,0.839853,0.86633,0.852886,793.0,0.0,0.0,0.0,6.0,0.402985,0.495413,0.444444,109.0,0.0,0.0,0.0,14.0,0.333333,0.02381,0.044444,42.0,0.0,0.0,0.0,13.0,0.416107,0.729412,0.529915,85.0,0.0,0.0,0.0,11.0,0.434783,0.5,0.465116,20.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,6.0,0.41129,0.346939,0.376384,147.0,0.691447,0.236529,0.246825,0.226099,1251.0,0.662233,0.691447,0.668524,1251.0,9.7601,128.175,5.0
2,1.543676,0.836735,0.878941,0.857319,793.0,0.0,0.0,0.0,6.0,0.367021,0.633028,0.464646,109.0,0.0,0.0,0.0,14.0,0.0,0.0,0.0,41.0,0.0,0.0,0.0,13.0,0.40404,0.470588,0.434783,85.0,0.0,0.0,0.0,11.0,1.0,0.1,0.181818,20.0,0.0,0.0,0.0,6.0,0.0,0.0,0.0,6.0,0.387597,0.340136,0.362319,147.0,0.685851,0.249616,0.201891,0.19174,1251.0,0.651364,0.685851,0.658956,1251.0,9.601,130.299,5.0
3,1.870731,0.852523,0.831021,0.841635,793.0,0.0,0.0,0.0,5.0,0.327778,0.546296,0.409722,108.0,0.0,0.0,0.0,14.0,0.625,0.121951,0.204082,41.0,0.0,0.0,0.0,13.0,0.473118,0.517647,0.494382,85.0,0.0,0.0,0.0,11.0,0.310345,0.45,0.367347,20.0,0.0,0.0,0.0,6.0,0.0,0.0,0.0,6.0,0.347305,0.391892,0.368254,148.0,0.6672,0.244672,0.238234,0.223785,1250.0,0.667919,0.6672,0.659124,1250.0,9.7479,128.232,5.0
4,1.74383,0.832099,0.849937,0.840923,793.0,0.0,0.0,0.0,6.0,0.319767,0.509259,0.392857,108.0,0.0,0.0,0.0,14.0,0.0,0.0,0.0,41.0,0.0,0.0,0.0,14.0,0.363636,0.517647,0.427184,85.0,0.0,0.0,0.0,10.0,1.0,0.1,0.181818,20.0,0.0,0.0,0.0,5.0,0.0,0.0,0.0,6.0,0.37931,0.371622,0.375427,148.0,0.664,0.241234,0.195705,0.184851,1250.0,0.641149,0.664,0.643833,1250.0,12.4092,100.731,5.0
avg,1.659078,0.840302,0.856557,0.848191,793.0,0.0,0.0,0.0,5.75,0.354388,0.545999,0.427918,108.5,0.0,0.0,0.0,14.0,0.239583,0.03644,0.062132,41.25,0.0,0.0,0.0,13.25,0.414226,0.558824,0.471566,85.0,0.0,0.0,0.0,10.75,0.686282,0.2875,0.299025,20.0,0.0,0.0,0.0,5.5,0.0,0.0,0.0,6.0,0.381376,0.362647,0.370596,147.5,0.677125,0.243013,0.220664,0.206619,1250.5,0.655666,0.677125,0.657609,1250.5,10.37955,121.85925,5.0
std,0.180794,0.008748,0.020756,0.008189,0.0,0.0,0.0,0.0,0.5,0.038419,0.061868,0.03257,0.57735,0.0,0.0,0.0,0.0,0.301184,0.058102,0.096925,0.5,0.0,0.0,0.0,0.5,0.045221,0.115869,0.049152,0.0,0.0,0.0,0.0,0.5,0.365795,0.217466,0.141102,0.0,0.0,0.0,0.0,0.57735,0.0,0.0,0.0,0.0,0.026449,0.02373,0.006604,0.57735,0.013565,0.005524,0.025616,0.021365,0.57735,0.011867,0.013565,0.010215,0.57735,1.35503,14.120115,0.0


# Save

## Model

In [16]:
# Save each fold's trainer. With load_best_model_at_end=True this is the
# restored best model. Each fold goes to its own fold_<i> directory.
# Distinct loop names avoid rebinding the global `model` (an HF model in the
# training cell) to Trainer objects.
# NOTE(review): absolute, machine-specific path, while the metrics cell uses a
# relative "data/..." path. Consider a configurable base directory.
checkpoint_tag = model_checkpoint.replace(r'/', '-')
for fold_idx, fold_trainer in enumerate(models):
    fold_trainer.save_model(f"/data/experiments/raring/semantic_storytelling/data/model-ohnealles/{checkpoint_tag}/epoch_{num_epoch}/fold_{fold_idx}")

## Metrics

In [17]:
from pathlib import Path

# Per-fold metrics plus avg/std rows: one CSV per (checkpoint, epoch count).
eval_csv_path = Path("data/eval-ohnealles") / f"{model_checkpoint.replace(r'/', '-')}_epoch_{num_epoch}.csv"
# DataFrame.to_csv does not create missing parent directories.
eval_csv_path.parent.mkdir(parents=True, exist_ok=True)
eval_df.to_csv(eval_csv_path)