In [None]:
%pip install scikit-learn
%pip install datasets
%pip install transformers
%pip install transformers[torch]
%pip install evaluate
%pip install pandas==2.0.3
%pip install torch
%pip install joblib
%pip install tqdm
%pip install progressbar
%pip install seaborn

In [None]:
# this cell is required for running in google collab VM
import os
if os.getenv("COLAB_RELEASE_TAG"):
    print("Running in Colab")
    import sys
    from google.colab import drive
    drive.mount('/content/drive/')
    sys.path.append('/content/drive/')
    %cd /content/drive/MyDrive/Faks/research_uiktp
else:
   print("NOT in Colab")

In [3]:
import dataclasses
import os

import joblib
import torch
import sklearn
import evaluate
import numpy as np
import pandas as pd
import seaborn as sns
from train_model import softmax, validate, predict_durations_for_tokenized_tensor_inputs as run_prediction
from get_task_durations import plot_durations_histogram
from data_utils import rename_columns, get_global_constants, balance_dataframe
from datasets import DatasetDict, Dataset
from make_dataset import split_dataset
from transformers import BertTokenizer, BertModel, AutoModelForSequenceClassification, TrainingArguments, Trainer

In [None]:
# Project-wide configuration object; later cells read its dataset/model paths
# (CSV_DATASET_PATH, MODEL_PATH), device settings (DEVICE, DEVICE_STRING) and
# the RANDOM_STATE seed.
GLOBAL_CONSTANTS = get_global_constants()
print(GLOBAL_CONSTANTS)

In [None]:
# Build the processed Jira dataset only when it is not already saved at
# GLOBAL_CONSTANTS.CSV_DATASET_PATH, so "Run All" does not regenerate it every
# time. Set FORCE_REBUILD_DATASET = True to regenerate anyway.
# NOTE(review): assumes get_jira_tasks() writes its result to CSV_DATASET_PATH,
# which the next cell reads — confirm in make_dataset.
FORCE_REBUILD_DATASET = False

if FORCE_REBUILD_DATASET or not os.path.exists(GLOBAL_CONSTANTS.CSV_DATASET_PATH):
    from make_dataset import get_jira_tasks
    generated_dataframe = get_jira_tasks()
    print(generated_dataframe)
else:
    print("Processed dataset found at {}; skipping regeneration".format(GLOBAL_CONSTANTS.CSV_DATASET_PATH))

In [None]:
# Load the processed dataset, normalise its column names, then balance and
# shuffle it. Each stage gets its own name instead of repeatedly overwriting
# `dataframe`; `dataframe` is the final frame consumed by the split cell below.
raw_dataframe = rename_columns(pd.read_csv(GLOBAL_CONSTANTS.CSV_DATASET_PATH))

# Later cells read the 'text' and 'label' columns; fail here with a clear
# message rather than deep inside tokenization/training.
missing_columns = {'text', 'label'} - set(raw_dataframe.columns)
assert not missing_columns, "dataset is missing expected columns: {}".format(sorted(missing_columns))

print(raw_dataframe)
plot_durations_histogram(raw_dataframe, column_name='label')

balanced_dataframe = balance_dataframe(raw_dataframe)
# Shuffle deterministically (seeded) and renumber rows 0..n-1.
dataframe = balanced_dataframe.sample(frac=1, random_state=GLOBAL_CONSTANTS.RANDOM_STATE).reset_index(drop=True)
assert len(dataframe) > 0, "balance_dataframe returned an empty dataset"

print(dataframe)
plot_durations_histogram(dataframe, column_name='label')

In [None]:
# Split the shuffled, balanced frame into train / test / validation parts using
# the 0.8 / 0.1 / 0.1 values below (presumably fractions of the rows).
train_set, test_set, validation_set = split_dataset(
    dataframe,
    train_set_length=0.8,
    test_set_length=0.1,
    validation_set_length=0.1,
    axis=0,
)
print(train_set, test_set, validation_set, sep="\n")

In [None]:
# Wrap each pandas split in a Hugging Face Dataset (built from a
# column-name -> list-of-values dict) and group them under their split names.
dataset = DatasetDict({
    split_name: Dataset.from_dict(split_frame.to_dict('list'))
    for split_name, split_frame in (
        ("train", train_set),
        ("test", test_set),
        ("validation", validation_set),
    )
})
print(dataset)

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(jira_tasks, column_name="text"):
    """Tokenize a batch of Jira tasks for BERT.

    Args:
        jira_tasks: batch mapping handed over by ``Dataset.map(batched=True)``;
            ``jira_tasks[column_name]`` holds the raw task texts.
        column_name: name of the column whose text is tokenized.

    Returns:
        The tokenizer's encoding of the batch. Every sequence is truncated or
        padded to the tokenizer's maximum length, which is what lets later cells
        build rectangular tensors with ``torch.tensor``.
    """
    texts = jira_tasks[column_name]
    return tokenizer(texts, truncation=True, padding="max_length")

# Tokenize every split in batches; each split gains the tokenizer's output
# columns (input_ids, attention_mask, ...) next to its original columns.
# NOTE(review): this rebinds train_set/test_set/validation_set from DataFrames to
# datasets.Dataset objects. Re-running the DatasetDict cell after this one would
# call .to_dict('list') on Datasets instead of DataFrames — re-run from the
# split cell instead.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
train_set, test_set, validation_set = (
    tokenized_datasets[split_name] for split_name in ("train", "test", "validation")
)
print(train_set, test_set, validation_set, sep="\n")

In [None]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Accuracy callback passed to the Trainer.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class of each
    sample is the arg-max over the last axis of its logits.
    """
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    return metric.compute(references=labels, predictions=predicted_classes)

In [None]:
# Fine-tune bert-base-cased as an 11-class sequence classifier. The Trainer
# evaluates on `test_set` after every epoch; `validation_set` is kept aside for
# the final evaluation further down.
# NOTE(review): the saved prediction/label lists later in this notebook only
# contain values 1..10, so class 0 may never occur — confirm the label range.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=11)

model.to(GLOBAL_CONSTANTS.DEVICE)

# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (the old name is deprecated and scheduled for removal). Pass whichever field
# the installed TrainingArguments dataclass declares so the cell runs on both.
training_arg_fields = {field.name for field in dataclasses.fields(TrainingArguments)}
eval_strategy_arg = "eval_strategy" if "eval_strategy" in training_arg_fields else "evaluation_strategy"

training_args = TrainingArguments(
    output_dir="training_logs",
    num_train_epochs=4,
    learning_rate=2.5e-5,
    save_steps=3250,
    save_total_limit=1,
    **{eval_strategy_arg: "epoch"},
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=test_set,
    compute_metrics=compute_metrics,
)
trainer.train()
print("Training finished")

In [None]:
# Persist only the fine-tuned weights (the state_dict), not the whole model
# object; the next cell rebuilds the architecture from "bert-base-cased" and
# loads these weights back into it.
torch.save(model.state_dict(), GLOBAL_CONSTANTS.MODEL_PATH)
print("Model serialized")

In [None]:
# Rebuild the classifier architecture (num_labels must match training: 11) and
# restore the fine-tuned weights saved by the previous cell.
loaded_model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=11)
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids executing arbitrary pickled code.
# NOTE(review): map_location only controls where the loaded tensors are created;
# load_state_dict copies them into loaded_model's existing parameters, so the
# model itself is not moved to GLOBAL_CONSTANTS.DEVICE here — confirm that
# run_prediction handles device placement.
loaded_model.load_state_dict(
    torch.load(GLOBAL_CONSTANTS.MODEL_PATH, map_location=GLOBAL_CONSTANTS.DEVICE_STRING, weights_only=True)
)
print("Model loaded")

In [None]:
# Evaluate the reloaded model on the held-out validation split.
# Set VALIDATION_SAMPLE_SIZE to an int to run on only the first N tasks (quick
# checks); None uses the whole split.
VALIDATION_SAMPLE_SIZE = None

loaded_model.eval()
# Slicing a datasets.Dataset yields a dict of column name -> list of values;
# a [:None] slice covers every row.
validation_rows = validation_set[:VALIDATION_SAMPLE_SIZE]
input_text = validation_rows['text']
true_durations = validation_rows['label']
input_ids = torch.tensor(validation_rows['input_ids'])
input_masks = torch.tensor(validation_rows['attention_mask'])
print("Input sample (of type {}): {}".format(type(input_text), input_text))
print("Input ids (of type {}): {}".format(type(input_ids), input_ids))
print("Input masks (of type {}): {}".format(type(input_masks), input_masks))
# Inference only: no_grad skips building the autograd graph (and its memory)
# in case run_prediction does not already disable it.
# NOTE(review): the whole selection is passed to run_prediction in one call;
# confirm it batches internally, otherwise a large split may not fit in memory.
with torch.no_grad():
    predicted_durations = run_prediction(loaded_model, input_ids, input_masks)
print("Predicted durations: {}".format(predicted_durations))
print("Expected durations: {}".format(true_durations))

In [None]:
#predicted_durations = [7, 10, 10, 2, 9, 1, 10, 8, 2, 10, 1, 3, 1, 4, 8, 1, 8, 8, 2, 8, 10, 9, 7, 4, 8, 10, 5, 10, 1, 8, 8, 8, 8, 10, 2, 2, 8, 10, 7, 4, 1, 7, 5, 4, 8, 10, 1, 1, 2, 8, 8, 10, 9, 10, 1, 8, 10, 8, 2, 7, 5, 1, 2, 8, 8, 8, 10, 8, 7, 7, 8, 7, 9, 1, 7, 1, 6, 8, 10, 8, 10, 8, 7, 9, 7, 7, 8, 1, 5, 10, 8, 8, 7, 4, 1, 7, 7, 2, 5, 8, 8, 2, 10, 1, 1, 10, 1, 4, 2, 7, 10, 1, 9, 4, 8, 1, 1, 4, 8, 1, 1, 8, 10, 1, 8, 10, 9, 1, 8, 8, 1, 8, 7, 10, 10, 1, 4, 8, 1, 7, 9, 8, 10, 7, 5, 5, 1, 1, 9, 10, 7, 10, 8, 10, 10, 10, 7, 7, 8, 7, 9, 8, 7, 8, 10, 8, 5, 5, 1, 7, 7, 4, 4, 5, 8, 8, 7, 1, 8, 7, 3, 1, 2, 10, 8, 9, 10, 8, 10, 10, 10, 8, 10, 8, 8, 8, 9, 2, 10, 5, 1, 2, 10, 9, 7, 1, 8, 2, 7, 8, 8, 6, 1, 3, 5, 10, 7, 8, 1, 8, 7, 8, 4, 4, 1, 10, 8, 10, 8, 7, 3, 8, 7, 8, 8, 10, 1, 1, 8, 7, 8, 1, 10, 1, 8, 8, 1, 3, 1, 9, 7, 1, 8, 5, 10, 1, 10, 2, 10, 7, 3, 8, 7, 8, 1, 10, 3, 9, 10, 8, 7, 1, 1, 1, 1, 7, 1, 8, 3, 1, 7, 8, 9, 10, 2, 10, 10, 5, 7, 8, 4, 4, 7, 5, 9, 2, 5, 8, 2, 1, 1, 7, 8, 10, 2, 10, 1, 1, 7, 8, 10, 1, 8, 5, 10, 5, 7, 8, 7, 10, 5, 4, 7, 7, 8, 1, 1, 1, 8, 1, 8, 10, 10, 1, 1, 5, 9, 4, 8, 1, 10, 9, 2, 7, 9, 8, 3, 8, 10, 9, 8, 7, 8, 4, 7, 8, 8, 8, 1, 9, 9, 8, 7, 7, 1, 4, 5, 8, 1, 8, 7, 1, 2, 8, 1, 1, 1, 10, 7, 8, 10, 1, 1, 8, 1, 10, 8, 7, 8, 8, 1, 2, 4, 8, 8, 7, 1, 8, 1, 7, 4, 6, 10, 8, 8, 4, 8, 4, 4, 1, 2, 1, 2, 8, 7, 9, 2, 9, 9, 10, 8, 8, 1, 8, 8, 6, 1, 1, 8, 7, 2, 10, 5, 8, 5, 8, 7, 10, 8, 2, 9, 1, 5, 1, 1, 8, 10, 1, 2, 8, 8, 1, 7, 6, 8, 8, 8, 10, 10, 4, 5, 7, 8, 7, 10, 8, 7, 2, 10, 2, 8, 2, 8, 10, 10, 10, 1, 10, 8, 1, 8, 8, 10, 8, 4, 1, 7, 8, 8, 8, 1, 8, 4, 1, 10, 8, 9, 10, 8, 8, 1, 1, 10, 10, 8, 7, 9, 10, 4, 10, 10, 10, 8, 1, 7, 1, 8, 8, 1, 10, 2, 8, 10, 1, 8, 4, 8, 8, 3, 1, 8, 10, 5, 8, 4, 7, 7, 2, 8, 8, 3, 8, 8, 5, 7, 2, 10, 4, 7, 1, 1, 9, 4, 7, 1, 8, 10, 7, 1, 8, 3, 9, 7, 9, 10, 7, 3, 8, 9, 2, 10, 1, 10, 10, 10, 4, 7, 8, 1, 1, 7, 4, 1, 4, 7, 10, 7, 6, 7, 1, 1, 9, 1, 10, 10, 9, 10, 8, 2, 5, 10, 2, 8, 8, 1, 8, 1, 8, 7, 2, 10, 7, 1, 1, 10, 8, 8, 10, 10, 1, 7, 10, 1, 
2, 7, 8, 8, 10, 10, 10, 1, 10, 1, 6, 1, 8, 7, 10, 1, 7, 4, 3, 10, 2, 2, 8, 7, 7, 10, 1, 7, 1, 4, 8, 8, 7, 10, 2, 10, 4, 1, 7, 8, 1, 5, 8, 7, 10, 1, 1, 4, 8, 1, 5, 4, 9, 8, 1, 8, 7, 1, 10, 1, 1, 8, 8, 9, 8, 9, 4, 10, 1, 7, 8, 8, 8, 8, 7, 10, 1, 8, 7, 1, 3, 8, 10, 9, 10, 8, 1, 8, 8, 4, 7, 8, 8, 10, 7, 2, 9, 10, 10, 5, 10, 7, 4, 7, 1, 5, 1, 9, 8, 7, 8, 4, 8, 7, 10, 8, 8, 2, 10, 8, 1, 2, 10, 5, 4, 7, 8, 10, 8, 5, 10, 6, 5, 1, 2, 2, 1, 7, 7, 7, 4, 1, 8, 1, 9, 2, 1, 6, 3, 8, 1, 10, 10, 4, 10, 3, 8, 10, 2, 10, 8, 7, 10, 10, 10, 4, 7, 10, 8, 2, 7, 7, 3, 1, 8, 8, 8, 7, 10, 9, 1, 7, 8, 7, 7, 3, 10, 8, 5, 8, 2, 1, 10, 6, 1, 10, 7, 8, 7, 7, 10, 8, 8, 9, 7, 1, 10, 1, 9, 8, 9, 8, 10, 2, 8, 10, 10, 9, 10, 8, 8, 4, 8, 7, 8, 1, 5, 3, 1, 10, 1, 1, 1, 3, 8, 7, 7, 7, 7, 10, 1, 8, 8, 1, 5, 10, 2, 7, 1, 7, 4, 7, 8, 1, 9, 7, 7, 10, 1, 1, 8, 3, 1, 10, 1, 1, 1, 7, 1, 4, 5, 9, 7, 2, 10, 8, 1, 8, 7, 10, 1, 9, 7, 1, 1, 4, 10, 2, 10, 10, 8, 9, 4, 8, 1, 9, 4, 1, 1, 4, 4, 1, 10, 2, 8, 8, 8, 10, 1, 8, 1, 5, 10, 2, 1, 8, 8, 7, 8, 8, 7, 8, 10, 1, 2, 7, 10, 8, 7, 8, 7, 8, 10, 4, 2, 8, 10, 5, 1, 1, 10, 8, 8, 1, 10, 9, 2, 10, 10, 2, 1, 8, 5, 10, 10, 5, 8, 6, 8, 10, 5, 1, 4, 8, 1, 1, 7, 7, 1, 8, 1, 2, 9, 1, 5, 7, 8, 1, 2, 1, 5, 4, 7, 1, 8, 10, 7, 2, 2, 2, 2, 5, 3, 10, 7, 7, 2, 8, 8, 8, 5, 10, 1, 6, 7, 1, 10, 10, 8, 10, 8, 8, 7, 1, 8, 10, 7, 8, 7, 1, 8, 9, 7, 1, 10, 5, 1, 7, 8, 2, 8, 10, 8, 1, 5, 4, 1, 1, 8, 10, 7, 9, 9, 9, 1, 8, 9, 8, 10, 9, 10, 8, 10, 4, 8, 10, 1, 8, 1, 1, 1, 10, 7, 8, 5, 10, 3, 4, 10, 2, 5, 8, 8, 1, 10, 10, 10, 5, 1, 8, 7, 1, 8, 8, 3, 10, 1, 10, 8, 7, 10, 7, 1, 7, 1, 5, 8, 10, 8, 1, 10, 8, 9, 2, 7, 7, 7, 1, 7, 3, 10, 7, 9, 8, 10, 2, 1, 7, 8, 1, 6, 1, 8, 5, 10, 1, 10, 5, 6, 1, 10, 10, 2, 4, 1, 8, 7, 9, 8, 3, 1, 9, 8, 1, 3, 7, 8, 1, 5, 9, 2, 9, 2, 7, 4, 8, 10, 10, 8, 1, 8, 8, 8, 8, 1, 8, 10, 1, 10, 10, 7, 1, 7, 10, 1, 10, 7, 7, 4, 9, 1, 7, 9, 6, 10, 9, 9, 10, 1, 8, 7, 2, 5, 7, 10, 3, 9, 1, 10, 7, 10, 7, 5, 10, 2, 8, 4, 9, 8, 10, 1, 10, 8, 1, 1, 8, 8, 1, 1, 1, 2, 10, 7, 10, 8, 1, 8, 7, 
2, 7, 8, 8, 10, 10, 10, 10, 2, 7, 8, 1, 1, 9, 10, 10, 9, 8, 10, 2, 8, 4, 1, 8, 2, 8, 10, 1, 7, 8, 9, 2, 8, 3, 3, 7, 10, 10, 1, 8, 5, 8, 10, 1, 4, 4, 7, 8, 1, 1, 8, 4, 8, 5, 8, 5, 8, 8, 2, 1, 7, 9, 1, 8, 7, 8, 1, 1, 8, 1, 8, 3, 10, 10, 4, 3, 4, 8, 7, 7, 9, 8, 1, 1, 4, 1, 8, 10, 8, 10, 10, 8, 7, 8, 3, 9, 10, 8, 7, 5, 1, 10, 10, 10, 7, 10, 5, 8, 9, 7, 3, 1, 1, 9, 7, 7, 4, 7, 4, 1, 3, 1, 10, 2, 10, 8, 5, 4, 10, 10, 8, 5, 8, 10, 1, 9, 7, 1, 10, 3, 8, 2, 1, 9, 7, 2, 2, 1, 9, 8, 10, 7, 8, 4, 10, 4, 8, 10, 6, 1, 1, 1, 2, 10, 10, 10, 10, 9, 10, 1, 7, 1, 10, 1, 7, 1, 1, 8, 7, 8, 7, 4, 2, 1, 7, 9, 5, 6, 7, 4, 8, 8, 1, 1, 7, 8, 7, 10, 7, 7, 8, 8, 7, 1, 1, 1, 10, 8, 1, 7, 1, 1, 1, 7, 3, 4, 7, 10, 10, 10, 10, 4, 1, 10, 8, 4, 5, 1, 10, 1, 1, 5, 8, 10, 7, 7, 1, 1, 8, 5, 1, 8, 1, 8, 7, 7, 8, 8, 8, 10, 10, 8, 6, 4, 1, 7, 8, 1, 9, 3, 4, 8, 1, 1, 8, 7, 8, 1, 7, 7, 1, 8, 7, 8, 10, 1, 1, 4, 8, 2, 1, 8, 1, 3, 7, 10, 10, 7, 1, 1, 10, 6, 7, 7, 7, 1, 8, 10, 1, 8, 3, 1, 1, 10, 7, 5, 1, 8, 10, 1, 1, 10, 7, 1, 10, 1, 10, 4, 7, 3, 2, 2, 10, 7, 7, 9, 5, 5, 8, 2, 8, 10, 8, 8, 3, 8, 3, 8, 8, 3, 7, 2, 8, 8, 7, 10, 10, 10, 8, 1, 3, 8, 1, 7, 8, 4, 5, 8, 7, 7, 1, 7, 10, 8, 1, 10, 1, 7, 2, 9, 10, 10, 8, 7, 8, 4, 1, 3, 7, 7, 1, 8, 4, 8, 1, 7, 4, 9, 7, 8, 8, 6, 10, 8, 10, 1, 8, 8, 1, 7, 7, 9, 10, 1, 10, 4, 8, 7, 1, 8, 7, 8, 8, 1, 7, 3, 2, 2, 9, 1, 1, 8, 10, 1, 6, 9, 9, 2, 10, 8, 8, 10, 1, 8, 4, 10, 10, 8, 5, 1, 2, 2, 8, 5, 7, 9, 10, 10, 10, 7, 10, 1, 7, 7, 1, 7, 4, 8, 10, 1, 8, 4, 10, 9, 1, 1, 8, 7, 10, 4, 1, 1, 8, 8, 10, 8, 5, 7, 1, 8, 10, 1, 3, 8, 7, 2, 4, 2, 7, 8, 7, 2, 1, 9, 10, 1, 1, 8, 8, 4, 1, 7, 5, 8, 8, 9, 10, 10, 3, 1, 1, 2, 10, 8, 1, 2, 8, 8, 1, 1, 1, 4, 7, 1, 8, 8, 1, 7, 10, 2, 8, 5, 1, 10, 10, 1, 3, 7, 8, 7, 10, 7, 8, 7, 1, 9, 8, 7, 8, 8, 9, 10, 5, 7, 7, 3, 7, 10, 1, 1, 1, 10, 2, 8, 8, 4, 8, 8, 9, 9, 8, 1, 8, 1, 8, 2, 3, 10, 10, 8, 3, 10, 7, 7, 4, 10, 2, 7, 7, 8, 7, 8, 1, 1, 5, 9, 10, 1, 10, 10, 7, 8, 4, 7, 9, 10, 3, 10, 7, 1, 7, 8, 7, 7, 1, 7, 7, 1, 10, 7, 7, 7, 9, 10, 4, 1, 8, 1, 8, 2, 1, 
1, 4, 10, 4, 4, 1, 9, 8, 4, 9, 10, 2, 1, 7, 8, 2, 1, 1, 3, 8, 8, 8, 7, 7, 8, 8, 2, 2, 7, 8, 10, 7, 5, 7, 5, 5, 10, 4, 1, 10, 10, 1, 10, 10, 1, 7, 8, 1, 8, 1, 7, 8, 1, 1, 1, 8]
#true_durations = [4, 10, 9, 5, 7, 2, 7, 8, 10, 3, 6, 1, 5, 6, 3, 7, 7, 4, 3, 9, 5, 6, 9, 7, 6, 4, 4, 9, 3, 10, 9, 8, 5, 9, 5, 9, 8, 10, 8, 4, 4, 4, 8, 10, 9, 10, 4, 1, 2, 6, 6, 10, 9, 7, 8, 5, 7, 2, 10, 9, 5, 3, 3, 9, 8, 4, 10, 9, 8, 3, 4, 9, 7, 7, 8, 5, 5, 8, 10, 4, 6, 10, 5, 8, 7, 4, 1, 1, 2, 8, 3, 3, 10, 6, 6, 1, 3, 6, 4, 8, 9, 8, 6, 8, 8, 2, 4, 8, 5, 5, 4, 4, 3, 4, 9, 1, 10, 6, 9, 6, 1, 2, 9, 4, 3, 4, 4, 9, 7, 7, 9, 9, 8, 7, 1, 3, 6, 6, 9, 7, 6, 3, 10, 5, 2, 8, 3, 4, 2, 7, 1, 2, 10, 6, 6, 4, 7, 2, 5, 5, 2, 8, 3, 2, 2, 6, 1, 9, 2, 7, 4, 4, 6, 6, 9, 8, 3, 9, 8, 3, 7, 2, 5, 10, 10, 7, 9, 5, 7, 4, 2, 3, 2, 6, 5, 4, 9, 9, 3, 4, 5, 4, 9, 10, 10, 8, 1, 9, 5, 5, 7, 5, 9, 3, 5, 2, 9, 1, 7, 2, 7, 6, 9, 6, 7, 2, 7, 10, 6, 7, 3, 4, 2, 2, 8, 10, 2, 8, 5, 2, 7, 5, 5, 2, 8, 8, 2, 8, 3, 4, 5, 4, 8, 7, 8, 6, 10, 4, 10, 1, 2, 4, 6, 8, 5, 10, 6, 7, 6, 2, 10, 8, 1, 8, 4, 5, 2, 1, 7, 6, 7, 1, 9, 7, 3, 10, 10, 8, 7, 5, 6, 9, 1, 9, 9, 7, 8, 9, 10, 2, 10, 7, 5, 1, 1, 4, 10, 2, 10, 1, 8, 7, 4, 4, 1, 9, 5, 5, 5, 10, 10, 9, 4, 9, 10, 3, 9, 2, 9, 6, 8, 6, 3, 8, 10, 9, 7, 4, 1, 9, 9, 8, 4, 1, 10, 1, 1, 7, 1, 8, 9, 3, 4, 7, 7, 5, 7, 7, 7, 10, 8, 6, 5, 3, 7, 9, 3, 4, 6, 5, 1, 2, 2, 1, 6, 3, 8, 8, 5, 3, 3, 10, 9, 6, 8, 4, 7, 5, 6, 9, 8, 7, 3, 8, 10, 10, 8, 6, 5, 3, 1, 10, 3, 7, 2, 9, 6, 2, 6, 1, 6, 7, 1, 1, 7, 3, 7, 5, 9, 6, 9, 7, 7, 1, 4, 8, 6, 9, 1, 3, 4, 3, 9, 7, 5, 7, 6, 8, 9, 6, 4, 3, 1, 6, 1, 7, 5, 2, 10, 5, 7, 1, 1, 1, 10, 6, 2, 8, 2, 9, 7, 5, 1, 10, 2, 4, 1, 6, 10, 6, 4, 4, 9, 1, 6, 1, 10, 10, 8, 9, 7, 9, 9, 1, 4, 3, 5, 3, 9, 2, 5, 2, 8, 9, 2, 9, 4, 4, 7, 3, 2, 7, 10, 1, 1, 10, 9, 8, 6, 3, 7, 4, 9, 10, 8, 1, 7, 1, 3, 4, 2, 6, 2, 6, 4, 4, 8, 4, 4, 5, 5, 4, 4, 2, 8, 2, 5, 5, 3, 7, 8, 6, 2, 3, 2, 6, 10, 4, 4, 4, 3, 9, 1, 8, 3, 8, 3, 4, 3, 7, 8, 7, 4, 9, 4, 1, 4, 8, 9, 3, 9, 7, 10, 3, 2, 4, 4, 4, 2, 5, 4, 3, 3, 10, 2, 10, 7, 4, 2, 6, 3, 2, 7, 1, 1, 9, 6, 5, 1, 6, 2, 3, 5, 5, 6, 1, 8, 10, 1, 8, 1, 1, 7, 5, 6, 10, 8, 9, 9, 10, 6, 4, 5, 2, 5, 10, 4, 3, 1, 6, 6, 2, 8, 10, 4, 8, 10, 2, 9, 7, 
3, 2, 10, 1, 7, 3, 10, 10, 6, 8, 4, 4, 5, 8, 9, 5, 1, 6, 3, 9, 10, 6, 7, 1, 10, 5, 8, 4, 4, 9, 2, 3, 10, 10, 4, 9, 7, 8, 3, 2, 2, 10, 8, 5, 8, 6, 9, 2, 10, 4, 4, 9, 5, 6, 5, 7, 8, 5, 9, 5, 7, 4, 5, 6, 9, 3, 7, 9, 9, 3, 1, 5, 1, 8, 4, 2, 4, 9, 9, 3, 3, 2, 1, 6, 6, 5, 5, 8, 3, 1, 3, 1, 7, 7, 8, 2, 3, 10, 5, 3, 1, 10, 3, 10, 1, 8, 5, 1, 9, 6, 10, 8, 8, 2, 4, 6, 1, 10, 9, 3, 9, 4, 2, 4, 1, 7, 2, 9, 1, 7, 6, 1, 6, 2, 10, 10, 2, 9, 6, 5, 10, 4, 10, 8, 10, 1, 10, 10, 9, 2, 10, 3, 5, 5, 5, 2, 5, 2, 2, 1, 10, 5, 8, 9, 7, 8, 4, 7, 9, 4, 8, 6, 10, 7, 1, 5, 3, 10, 6, 1, 8, 1, 5, 10, 8, 8, 9, 9, 10, 6, 7, 1, 8, 7, 1, 5, 6, 9, 1, 8, 9, 9, 9, 10, 4, 6, 6, 8, 6, 10, 6, 6, 7, 7, 3, 7, 4, 10, 6, 7, 7, 7, 3, 1, 5, 8, 5, 6, 6, 7, 5, 7, 8, 3, 4, 7, 9, 3, 8, 2, 2, 6, 9, 10, 3, 5, 5, 10, 9, 3, 7, 4, 3, 2, 1, 3, 10, 8, 5, 5, 8, 8, 9, 6, 9, 8, 10, 10, 4, 10, 2, 5, 4, 1, 7, 10, 3, 8, 10, 2, 1, 8, 6, 9, 7, 4, 2, 8, 9, 7, 10, 6, 7, 5, 10, 3, 2, 9, 5, 3, 1, 1, 2, 6, 5, 3, 8, 10, 4, 1, 8, 3, 4, 7, 5, 1, 7, 9, 9, 8, 1, 8, 7, 7, 6, 3, 6, 10, 9, 4, 7, 10, 10, 9, 2, 7, 4, 4, 5, 10, 7, 6, 3, 2, 7, 3, 6, 6, 5, 4, 7, 1, 8, 4, 8, 7, 5, 1, 6, 7, 8, 6, 5, 10, 8, 3, 3, 10, 1, 9, 2, 6, 8, 4, 5, 8, 10, 10, 2, 10, 2, 2, 10, 5, 5, 2, 1, 2, 2, 7, 7, 8, 9, 6, 4, 1, 8, 10, 4, 3, 8, 1, 8, 4, 3, 4, 3, 10, 6, 2, 8, 4, 4, 6, 8, 6, 9, 7, 2, 7, 6, 2, 6, 2, 5, 9, 9, 7, 9, 9, 5, 7, 9, 10, 4, 10, 3, 1, 2, 4, 10, 7, 6, 10, 10, 9, 5, 10, 6, 5, 3, 1, 4, 9, 5, 2, 9, 1, 4, 7, 7, 6, 8, 5, 1, 6, 2, 2, 9, 3, 10, 2, 2, 8, 8, 7, 1, 8, 1, 10, 2, 1, 8, 1, 6, 2, 4, 7, 10, 5, 10, 1, 1, 9, 5, 4, 10, 9, 10, 7, 7, 5, 7, 3, 7, 8, 9, 3, 2, 3, 7, 6, 10, 10, 1, 9, 8, 8, 8, 3, 6, 5, 1, 9, 2, 1, 5, 1, 10, 3, 9, 9, 6, 4, 3, 1, 1, 9, 2, 10, 5, 3, 6, 5, 8, 9, 1, 2, 2, 3, 5, 4, 4, 5, 10, 10, 8, 6, 9, 6, 9, 9, 9, 2, 8, 1, 6, 1, 2, 8, 9, 3, 7, 7, 3, 4, 8, 9, 8, 4, 1, 9, 3, 5, 6, 5, 10, 6, 6, 4, 7, 5, 10, 8, 10, 1, 7, 5, 6, 10, 3, 9, 1, 5, 5, 4, 4, 6, 10, 10, 9, 8, 1, 3, 8, 10, 3, 3, 10, 1, 6, 5, 3, 5, 2, 8, 4, 4, 8, 3, 9, 10, 9, 1, 10, 6, 2, 7, 4, 
3, 3, 6, 1, 8, 10, 3, 4, 6, 4, 3, 4, 1, 8, 6, 2, 4, 10, 5, 1, 10, 3, 5, 9, 8, 5, 1, 8, 6, 3, 4, 4, 8, 10, 2, 5, 9, 2, 1, 7, 4, 6, 2, 3, 5, 7, 4, 4, 4, 8, 1, 5, 4, 6, 7, 1, 5, 7, 8, 5, 6, 9, 10, 5, 4, 3, 1, 7, 8, 8, 3, 10, 2, 10, 10, 5, 5, 3, 7, 7, 7, 1, 10, 6, 3, 1, 1, 5, 2, 2, 8, 7, 7, 1, 8, 1, 4, 7, 6, 1, 8, 1, 5, 7, 8, 6, 10, 2, 1, 1, 10, 5, 4, 1, 6, 1, 2, 3, 10, 6, 9, 1, 10, 2, 4, 9, 6, 6, 8, 3, 4, 2, 8, 6, 1, 6, 9, 8, 9, 5, 4, 5, 1, 4, 10, 6, 7, 2, 1, 5, 9, 6, 5, 3, 8, 4, 6, 7, 9, 8, 7, 8, 7, 2, 10, 1, 3, 9, 10, 6, 3, 8, 5, 2, 6, 9, 2, 4, 9, 9, 1, 7, 5, 5, 2, 1, 7, 2, 2, 4, 7, 7, 5, 4, 8, 7, 2, 6, 8, 10, 10, 2, 5, 4, 7, 7, 7, 4, 6, 6, 1, 4, 5, 10, 8, 3, 9, 6, 1, 5, 7, 3, 10, 2, 5, 3, 1, 3, 3, 7, 9, 6, 2, 3, 3, 8, 1, 7, 8, 8, 1, 4, 1, 5, 8, 3, 4, 3, 10, 5, 5, 7, 3, 8, 5, 6, 6, 5, 1, 3, 6, 10, 4, 6, 6, 7, 8, 10, 8, 1, 3, 1, 3, 4, 2, 2, 9, 10, 9, 2, 10, 2, 6, 2, 4, 4, 3, 6, 6, 6, 6, 6, 2, 5, 6, 10, 5, 7, 8, 5, 3, 3, 1, 1, 7, 7, 5, 6, 1, 5, 8, 6, 2, 7, 1, 2, 5, 2, 7, 1, 10, 4, 3, 9, 1, 1, 1, 3, 5, 4, 1, 7, 3, 4, 1, 10, 5, 4, 8, 4, 9, 5, 4, 8, 9, 4, 5, 6, 3, 3, 6, 2, 5, 3, 9, 4, 3, 7, 1, 3, 3, 3, 3, 9, 2, 10, 10, 2, 1, 8, 3, 6, 1, 8, 10, 10, 3, 4, 2, 7, 5, 2, 7, 8, 8, 8, 1, 3, 7, 2, 9, 3, 2, 3, 10, 7, 5, 1, 9, 3, 1, 1, 8, 10, 3, 3, 2, 9, 1, 3, 6, 2, 9, 2, 9, 2, 5, 7, 10, 1, 5, 9, 3, 4, 7, 3, 6, 9, 9, 2, 10, 10, 10, 3, 2, 10, 3, 3, 10, 7, 10, 4, 9, 5, 8, 8, 1, 8, 5, 6, 8, 3, 2, 3, 7, 2, 1, 4, 6, 5, 7, 9, 7, 2, 6, 7, 8, 9, 8, 2, 8, 5, 4, 8, 9, 4, 4, 6, 7, 8, 2, 5, 3, 3, 4, 8, 9, 4, 4, 1, 1, 5, 7, 6, 6, 2, 3, 1, 3, 9, 5, 1, 5, 10, 10, 10, 9, 1, 6, 6, 9, 1, 6, 5, 8, 5, 1, 8, 2, 9, 4, 6, 9, 1, 2, 7, 7, 5, 3, 9, 2, 1, 5, 6, 2, 3, 3, 6, 3, 5, 6, 8, 5, 3, 4, 7, 7, 7, 4, 7, 4, 7, 9, 5, 9, 8, 4, 4, 2, 6, 6, 10, 10, 3, 6, 3, 1, 2, 3, 6, 10, 4, 3, 9, 10, 9, 3, 9, 3, 3, 8, 6, 2, 6, 3, 1, 10, 5, 8, 6, 9, 4, 9, 10, 9, 5, 7, 3, 8, 9, 8, 3, 2, 2, 2, 8, 4, 8, 4, 1, 4, 4, 5, 2, 5, 5, 3, 3, 8, 6, 10, 4, 6, 6, 8, 8, 9, 3, 3, 1, 8, 4, 1, 9, 7, 3, 9, 5, 4, 5, 2, 1, 6, 8, 8, 5, 8, 6, 1, 
10, 10, 3, 1, 10, 8, 6, 1, 3, 7]

# Signed prediction error per validation task (predicted label - true label).
# Fail loudly on a length mismatch instead of raising IndexError midway or
# silently comparing misaligned samples.
assert len(predicted_durations) == len(true_durations), (
    "got {} predictions for {} labels".format(len(predicted_durations), len(true_durations))
)
distribution = [predicted - expected for predicted, expected in zip(predicted_durations, true_durations)]

# The labels are integer classes (assuming run_prediction returns class labels,
# as the saved lists above suggest), so the errors are integers. Draw one bar
# per integer value (discrete=True); a fixed bin count such as 7 merges
# neighbouring error values unevenly.
ax = sns.histplot(distribution, discrete=True, kde=True)
ax.set(title="Prediction error on the validation set", xlabel="predicted - true label", ylabel="count");

In [None]:
# Score the predictions against the true labels with the project's validate()
# helper; the result is indexable by 'mean_absolute_error', which the next cell
# uses as its outlier threshold.
metrics = validate(true_durations, predicted_durations)
print(metrics)

In [None]:
# List the validation tasks whose absolute prediction error exceeds the mean
# absolute error, i.e. the worst-predicted samples, as "<error>: <task text>".
# A plain loop is used so the cell does not also display the list of print()
# return values ([None, None, ...]) that a list comprehension would produce.
mae_threshold = metrics['mean_absolute_error']
for error, task_text in zip(distribution, input_text):
    if abs(error) > mae_threshold:
        print("{}: {}".format(error, task_text))