# Text grade classification

# Trying different models and fragment lengths



### Environment setup 

In [None]:
#!g1.1
%pip install transformers
%pip install sentencepiece
%pip install pymorphy2

In [322]:
#!g1.1

# Root directory that is searched recursively for the *.txt dataset files.
DATASET_PATH = "../dataset"

### Loading file list, model, tokenizer

In [323]:
#!g1.1
# Load the files
import glob 

# glob returns matches in arbitrary, filesystem-dependent order. Sorting keeps
# the fragment order stable, so the seeded shuffles/splits performed later are
# reproducible from one machine (or run) to the next.
files = sorted(glob.glob(f"{DATASET_PATH}/**/*.txt", recursive=True))
print(files)

['./dataset/year_ben_4u.txt', './dataset/year_lut_1u.txt', './dataset/year_rag_1u.txt', './dataset/year_rog_1u.txt', './dataset/year_rud_1u.txt', './dataset/year_vah_1u.txt', './dataset/year_vah_1pu.txt', './dataset/year_kur_1u.txt', './dataset/year_gor_4u.txt', './dataset/year_vah_2pu.txt', './dataset/year_vah_2u.txt', './dataset/yead_rud_2u.txt', './dataset/year_rud_3u.txt', './dataset/year_vah_4u.txt', './dataset/year_uch_2u.txt', './dataset/year_vah_4pu.txt', './dataset/year_uch_2pu.txt', './dataset/year_unk_10.txt', './dataset/year_petrov_11.txt', './dataset/year_ponom_11.txt', './dataset/year_plenko_11.txt', './dataset/year_guryan_11.txt', './dataset/year_sobol_10.txt', './dataset/year_klimov_10.txt', './dataset/year_nik_07.txt', './dataset/year_nik_05.txt', './dataset/year_nik_06.txt', './dataset/year_nik_11.txt', './dataset/year_nik_09.txt', './dataset/year_nik_10.txt', './dataset/year_nik_08.txt', './dataset/year_bog_07.txt', './dataset/year_bog_06.txt', './dataset/year_bog_08

In [None]:
#!g1.1
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification

def get_model(name: str):
    """Load a pretrained checkpoint with a single-output classification head.

    Args:
        name: Hugging Face model identifier or local checkpoint path.

    Returns:
        The model created with ``num_labels=1``, moved to the GPU when CUDA is
        available and left on the CPU otherwise.
    """
    # Local imports keep the function usable regardless of notebook cell order.
    import torch
    from transformers import AutoModelForSequenceClassification

    # The Auto class selects the architecture from the checkpoint's own config,
    # so BERT checkpoints still load as BERT while non-BERT checkpoints (e.g.
    # the XLM-RoBERTa entry in `model_names`) are no longer forced into a BERT
    # class that does not match their weights.
    model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=1)

    # Do not crash on CPU-only machines.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device)

def get_tokenizer(name: str):
    """Return the pretrained tokenizer that belongs to checkpoint ``name``."""
    tokenizer = AutoTokenizer.from_pretrained(name)
    return tokenizer

In [None]:
#!g1.1
import torch
import random
from transformers.file_utils import is_tf_available, is_torch_available, is_torch_tpu_available


def set_seed(seed: int):
    """Seed every random number generator used by the experiments.

    Seeds ``random`` and ``numpy`` unconditionally, ``torch`` (CPU and all CUDA
    devices) when it is available, and ``tensorflow`` when it is available.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    # Imported locally: the notebook imports numpy only in a later cell, so a
    # module-level `np` is not guaranteed to exist when this helper is called.
    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    if is_torch_available():
        torch.manual_seed(seed)
        # Safe to call even if CUDA is not available.
        torch.cuda.manual_seed_all(seed)
    if is_tf_available():
        import tensorflow as tf

        tf.random.set_seed(seed)


### Making dataset and training

In [None]:
#!g1.1
class SchoolTextDataset(torch.utils.data.Dataset):
    """Pairs tokenizer output with one numeric grade label per fragment.

    ``encodings`` is a mapping from field name to an indexable tensor (one row
    per fragment); ``labels`` is an indexable sequence with one grade per row.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # The number of samples is defined by the label sequence.
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        for field, rows in self.encodings.items():
            # Detached copy of the row so the sample does not alias the batch.
            sample[field] = rows[idx].clone().detach()
        # The label is returned as a float tensor of shape (1,).
        target = torch.tensor([self.labels[idx]])
        sample["labels"] = target.float()
        return sample

In [None]:
#!g1.1
import os
from spacy.lang.ru import Russian

def get_label_from_path(path):
    """Parse the numeric grade encoded in a dataset file name.

    The grade is the last ``_``-separated token of the file name (text before
    the first ``.``). Every ``u`` is dropped and every ``p`` is replaced by
    ``.5`` before conversion, e.g. ``year_vah_1pu.txt`` -> ``1.5`` and
    ``year_nik_07.txt`` -> ``7.0``.

    Args:
        path: Path to a dataset file; only its base name is inspected.

    Returns:
        The grade as a float.

    Raises:
        ValueError: If the extracted token is not a valid number.
    """
    filename = os.path.basename(path)
    grade = filename.split('.')[0].split('_')[-1]
    grade = grade.replace('u', '').replace('p', '.5')
    try:
        return float(grade)
    except ValueError as err:
        # Only conversion failures are expected here; name the offending file
        # instead of catching everything and printing a bare token.
        raise ValueError(
            f"Cannot parse a grade from file name {filename!r} (got {grade!r})"
        ) from err

def get_next_fragment(dataset_files, fragment_length):
    """Yield consecutive, non-overlapping fragments of every dataset file.

    Each file is tokenized with spaCy's ``Russian`` tokenizer and cut into
    chunks of ``fragment_length`` tokens; the final chunk of a file may be
    shorter. The grade is parsed from the file name.

    Args:
        dataset_files: Iterable of paths to the text files.
        fragment_length: Number of spaCy tokens per fragment.

    Yields:
        ``(fragment_text, grade)`` tuples.
    """
    spacy_tokenizer = Russian()
    for path in dataset_files:
        current_label = get_label_from_path(path)
        # Explicit encoding so the result does not depend on the platform's
        # default locale. Assumes the dataset is UTF-8 -- TODO confirm.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = spacy_tokenizer(text)
        for start in range(0, len(tokens), fragment_length):
            # Slice ends are exclusive: the previous `start + fragment_length - 1`
            # bound silently dropped one token between consecutive fragments.
            stop = min(start + fragment_length, len(tokens))
            fragment = tokens[start:stop]
            yield ''.join(token.text_with_ws for token in fragment), current_label

# Default number of tokens the sliding window advances between fragments.
sliding_window_step = 20


def get_next_fragment_sliding_window(dataset_files, fragment_length, step=None):
    """Yield overlapping fragments of every dataset file (sliding window).

    Each file is tokenized with spaCy's ``Russian`` tokenizer. A window of
    ``fragment_length`` tokens is moved over the tokens in increments of
    ``step``; windows near the end of a file are shorter because they are
    clipped at the last token.

    Args:
        dataset_files: Iterable of paths to the text files.
        fragment_length: Number of spaCy tokens per window.
        step: Number of tokens to advance per window. Defaults to the
            module-level ``sliding_window_step``.

    Yields:
        ``(fragment_text, grade)`` tuples.
    """
    if step is None:
        step = sliding_window_step
    spacy_tokenizer = Russian()
    for path in dataset_files:
        current_label = get_label_from_path(path)
        # Explicit encoding so the result does not depend on the platform's
        # default locale. Assumes the dataset is UTF-8 -- TODO confirm.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = spacy_tokenizer(text)
        for start in range(0, len(tokens), step):
            # Slice ends are exclusive, so no `- 1` is needed to obtain
            # exactly `fragment_length` tokens.
            stop = min(start + fragment_length, len(tokens))
            window = tokens[start:stop]
            yield ''.join(token.text_with_ws for token in window), current_label

In [None]:
#!g1.1
import pandas as pd
from sklearn.model_selection import train_test_split, KFold



# Putting all of the dataset preparation in one function

def _tokenize_split(tokenizer, x_train, y_train, x_test, y_test):
    """Tokenize one train/test split and wrap both halves in SchoolTextDataset."""

    def encode(texts):
        return tokenizer(list(texts), truncation=True, padding=True, return_tensors="pt", max_length=512)

    train_dataset = SchoolTextDataset(encode(x_train), list(y_train))
    test_dataset = SchoolTextDataset(encode(x_test), list(y_test))
    return train_dataset, test_dataset


def _iter_kfold_datasets(x, y, fragment_length, tokenizer, n_splits):
    """Yield one (train_dataset, test_dataset) pair per cross-validation fold."""
    kf = KFold(n_splits=n_splits, shuffle=True)
    for train_indices, test_indices in kf.split(x):
        train_dataset, test_dataset = _tokenize_split(
            tokenizer,
            [x[i] for i in train_indices],
            [y[i] for i in train_indices],
            [x[i] for i in test_indices],
            [y[i] for i in test_indices],
        )

        print(f"Dataset of fragment length {fragment_length} ({len(train_dataset)} cross-validation instances)")

        yield train_dataset, test_dataset


def make_dataset(dataset_files, fragment_length, tokenizer, next_fragment, max_instances=None, kfold=None):
    """Build tokenized train/test datasets from the raw text files.

    Args:
        dataset_files: Paths of the text files to load.
        fragment_length: Fragment size passed on to ``next_fragment``.
        tokenizer: Callable tokenizer applied to the fragment texts.
        next_fragment: Generator function ``(dataset_files, fragment_length)``
            yielding ``(text, grade)`` pairs.
        max_instances: Optional upper bound on the number of fragments; a
            random subset is drawn without replacement when it is set.
        kfold: Number of cross-validation folds, or ``None`` for a single
            80/20 split.

    Returns:
        With ``kfold`` set: an iterator of ``(train_dataset, test_dataset)``
        pairs, one per fold. Otherwise: a single
        ``(train_dataset, test_dataset)`` tuple.

    NOTE(review): splitting happens per fragment, so fragments of one source
    file can end up on both sides of a split.
    """
    # Load from files and divide to fragments
    df = pd.DataFrame(list(next_fragment(dataset_files, fragment_length)), columns=['text', 'grade'])

    # Fixed seed so subsampling and splitting are repeatable.
    set_seed(1)

    x = df['text'].values
    y = df['grade'].values

    # Limit the number of instances
    if max_instances is not None:
        # Never request more samples than exist; sampling without replacement
        # would raise otherwise.
        sample_size = min(max_instances, len(x))
        chosen = np.random.choice(range(len(x)), size=sample_size, replace=False)
        x, y = [x[i] for i in chosen], [y[i] for i in chosen]

    if kfold is not None:
        # The folds live in a separate generator. Keeping `yield` inside this
        # function would turn it into a generator as a whole and make the
        # plain `return` of the single-split case below unusable for callers.
        return _iter_kfold_datasets(x, y, fragment_length, tokenizer, kfold)

    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    train_dataset, test_dataset = _tokenize_split(tokenizer, x_train, y_train, x_test, y_test)

    print(f"Dataset of fragment length {fragment_length} ({len(train_dataset)} instances)")

    return train_dataset, test_dataset

In [None]:
#!g1.1
from sklearn.metrics import mean_squared_error, accuracy_score, mean_absolute_error, confusion_matrix, precision_recall_fscore_support

def grade_class(values: list):
    """Map each numeric grade to one of four coarse classes.

    Values below 3.5 give class 1, below 6.5 class 2, below 9.5 class 3 and
    everything else class 4.
    """
    # (exclusive upper bound, class id), checked in ascending order.
    bounds = ((3.5, 1), (6.5, 2), (9.5, 3))

    def to_class(grade):
        for upper, class_id in bounds:
            if grade < upper:
                return class_id
        return 4

    return [to_class(grade) for grade in values]

def grade_class_granular(values: list):
    """Map each numeric grade to the nearest integer class in 1..11.

    Values are clamped to [0.5, 11.49] and rounded with ``np.round``.

    Args:
        values: Iterable of numeric grades or predictions.

    Returns:
        A list of ints, each between 1 and 11 inclusive.
    """
    result = []
    for i in values:
        x = max(0.5, min(11.49, i))
        # np.round rounds halves to the nearest even integer, so the lower
        # clamp value 0.5 becomes 0 -- a class that does not exist in the
        # 1..11 label set. Pin such values to the lowest valid class.
        result.append(max(1, int(np.round(x))))
    return result

def compute_metrics(pred):
    """Compute regression and derived classification metrics for an evaluation.

    Reads the first element of every row of ``pred.label_ids`` and
    ``pred.predictions``, then reports MAE and the correlation coefficient on
    the raw values plus confusion matrix, accuracy and weighted
    precision/recall/F1 for two discretizations: the four coarse classes of
    ``grade_class`` and the per-grade classes of ``grade_class_granular``.
    """
    labels = [row[0] for row in pred.label_ids]
    preds = [row[0] for row in pred.predictions]

    mae = mean_absolute_error(labels, preds)
    correlation = np.corrcoef([labels, preds])[0, 1]

    def classification_summary(to_classes, class_ids):
        # Discretize both series identically, then score them as a
        # classification problem over `class_ids`.
        predicted = to_classes(preds)
        expected = to_classes(labels)
        matrix = confusion_matrix(y_true=expected, y_pred=predicted, labels=class_ids)
        accuracy = accuracy_score(expected, predicted)
        precision, recall, f1, _support = precision_recall_fscore_support(
            y_true=expected, y_pred=predicted, average='weighted')
        return matrix.tolist(), accuracy, precision, recall, f1

    coarse = classification_summary(grade_class, [1, 2, 3, 4])
    granular = classification_summary(grade_class_granular, range(1, 12))

    return {
        'mean_absolute_error': mae,
        'correlation_coefficient': correlation,
        'confusion_matrix': coarse[0],
        'accuracy': coarse[1],
        'precision': coarse[2],
        'recall': coarse[3],
        'f1': coarse[4],
        'confusion_matrix_granular': granular[0],
        'accuracy_granular': granular[1],
        'precision_granular': granular[2],
        'recall_granular': granular[3],
        'f1_granular': granular[4],
    }

In [None]:
#!g1.1
import numpy as np

# Grid of fragment lengths to evaluate: from `min_length` up to at most
# `max_length`, in increments of `step`.
min_length = 50
max_length = 512
step = 20

lengths = np.arange(min_length, max_length + 1, step)

print(lengths)

In [None]:
#!g1.1
# Checkpoints to evaluate. Commented-out entries are alternatives that are
# currently disabled.
model_names = [
    # "DeepPavlov/rubert-base-cased",
    # "DeepPavlov/rubert-base-cased-sentence",
    "DeepPavlov/xlm-roberta-large-en-ru",
]

In [None]:
#!g1.1
def dict_sum(original, new_dict):
    """Accumulate ``new_dict`` into ``original`` key by key, in place.

    Keys that are missing from ``original`` (or map to ``None``) are copied
    over. List values, possibly nested (e.g. confusion matrices), are added
    element-wise; all other values are added with ``+=``.

    Args:
        original: Accumulator dictionary; it is modified.
        new_dict: Dictionary whose values are added.

    Returns:
        ``original``, for convenience.
    """
    for key, value in new_dict.items():
        if original.get(key) is None:
            original[key] = value
        elif isinstance(value, list):
            # Go through numpy so nested lists are summed element-wise rather
            # than concatenated.
            original[key] = (np.array(value) + np.array(original[key])).tolist()
        else:
            original[key] += value
    return original

def dict_divide(x, n):
    """Divide every value of ``x`` by ``n``, in place.

    List values, possibly nested, are divided element-wise; all other values
    are divided with ``/``.

    Args:
        x: Dictionary to scale; it is modified.
        n: Divisor.

    Returns:
        ``x``, for convenience.
    """
    # Only existing keys are reassigned, so the dict keeps its size while it
    # is being iterated.
    for key, value in x.items():
        if isinstance(value, list):
            x[key] = (np.array(value) / n).tolist()
        else:
            x[key] = value / n
    return x

In [None]:
#!g1.1
from transformers import Trainer, TrainingArguments
from datetime import datetime
import time
import json
import gc

# Training configuration shared by every model / fragment length / fold below.
training_args = TrainingArguments(
    output_dir='../results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=5,  # batch size per device during training
    per_device_eval_batch_size=5,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='../logs',            # directory for storing logs
    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)
    # `metric_for_best_model` can be specified to select by accuracy or another metric instead
    logging_steps=1500,              # log every 1500 steps (checkpointing is governed by save_steps below)
    save_steps=12000,                # checkpoint interval in steps
    save_total_limit=1,              # presumably caps the number of checkpoints kept on disk -- confirm in the docs
    evaluation_strategy="steps",     # evaluate on a step schedule; eval_steps is not set here, presumably falls back to logging_steps -- TODO confirm
)

cross_validation_folds = 5

model_results = {}

# Create the output directory up front so that a missing folder cannot fail the
# run only after the (long) training has finished.
os.makedirs('../custom_results', exist_ok=True)

for model_name in model_names:
    # No subsampling: every fragment of the dataset is used.
    max_instances = None

    # Model ids such as "org/model" contain "/", which open() would interpret
    # as a sub-directory of the results folder.
    safe_model_name = model_name.replace('/', '_')

    # The tokenizer depends only on the model, not on the fragment length.
    tokenizer = get_tokenizer(model_name)

    evaluation_results = {}

    for fragment_length in reversed(lengths):
        # `lengths` is a numpy array; a numpy integer is not accepted by json
        # as a dict key, so use a plain int for the result keys.
        fragment_length = int(fragment_length)

        print(f"Length of fragment : {fragment_length} (max instances {max_instances})")
        average_result = {}
        gc.collect()

        # 1. Make the dataset for each length (one train/test pair per fold)
        folds = make_dataset(files, fragment_length, tokenizer, get_next_fragment,
                             max_instances=max_instances, kfold=cross_validation_folds)
        for train_dataset, test_dataset in folds:
            # 2. Load into trainer, train, evaluate then add to the running sum
            model = get_model(model_name)

            trainer = Trainer(
                model=model,                         # the instantiated Transformers model to be trained
                args=training_args,                  # training arguments, defined above
                train_dataset=train_dataset,         # training dataset
                eval_dataset=test_dataset,           # evaluation dataset
                compute_metrics=compute_metrics,     # the callback that computes metrics of interest
            )
            start_time = time.time()
            trainer.train()
            training_time = time.time() - start_time
            result = trainer.evaluate()
            result['training_time_s'] = training_time
            print(result)
            average_result = dict_sum(average_result, result)

            # Drop the references before the next fold builds a new model;
            # gc.collect() cannot reclaim objects that are still bound to names.
            del trainer, model
            gc.collect()
            torch.cuda.empty_cache()

        average_result = dict_divide(average_result, cross_validation_folds)
        evaluation_results[fragment_length] = average_result

        timestamp = datetime.now().strftime("%m_%d_%H_%M")
        # 'w' instead of 'a': appending a second JSON document to an existing
        # file would make the file unparseable.
        with open(f'../custom_results/model_fragment_results_{timestamp}_{fragment_length}.json', 'w') as f:
            json.dump(average_result, f)

    model_results[model_name] = evaluation_results

    timestamp = datetime.now().strftime("%m_%d_%H_%M")
    with open(f'../custom_results/model_results_{safe_model_name}_{timestamp}.json', 'w') as f:
        json.dump(model_results, f)


In [None]:
#!g1.1
from datetime import datetime
import json

print(model_results)

# Fragment lengths may be numpy integers (they come from a numpy array), which
# json rejects as dict keys; convert them to plain ints before dumping.
serializable_results = {
    name: {int(length): metrics for length, metrics in per_length.items()}
    for name, per_length in model_results.items()
}

timestamp = datetime.now().strftime("%m_%d_%H_%M")
# Same output folder as the training cell ('custom_results'); 'w' so that the
# file always holds exactly one JSON document.
with open(f'../custom_results/model_results_{timestamp}.json', 'w') as f:
    json.dump(serializable_results, f)

In [None]:
#!g1.1
# model_results = {'DeepPavlov/rubert-base-cased': {510: {'eval_loss': 0.6438136696815491, 'eval_mean_absolute_error': 0.5549294352531433, 'eval_correlation_coefficient': 0.9584111857817864, 'eval_runtime': 11.9883, 'eval_samples_per_second': 54.303, 'eval_steps_per_second': 3.42, 'epoch': 3.0}, 490: {'eval_loss': 0.4725295901298523, 'eval_mean_absolute_error': 0.4326665997505188, 'eval_correlation_coefficient': 0.9682124476596938, 'eval_runtime': 12.499, 'eval_samples_per_second': 54.244, 'eval_steps_per_second': 3.44, 'epoch': 3.0}, 470: {'eval_loss': 0.4729669988155365, 'eval_mean_absolute_error': 0.4750634431838989, 'eval_correlation_coefficient': 0.9723216405236511, 'eval_runtime': 13.0773, 'eval_samples_per_second': 54.14, 'eval_steps_per_second': 3.441, 'epoch': 3.0}, 450: {'eval_loss': 0.5354031920433044, 'eval_mean_absolute_error': 0.43037524819374084, 'eval_correlation_coefficient': 0.9654460956485931, 'eval_runtime': 13.6221, 'eval_samples_per_second': 54.177, 'eval_steps_per_second': 3.45, 'epoch': 3.0}, 430: {'eval_loss': 0.5336100459098816, 'eval_mean_absolute_error': 0.4922550320625305, 'eval_correlation_coefficient': 0.9668407050639991, 'eval_runtime': 14.2423, 'eval_samples_per_second': 54.205, 'eval_steps_per_second': 3.44, 'epoch': 3.0}, 410: {'eval_loss': 0.4583664834499359, 'eval_mean_absolute_error': 0.46573591232299805, 'eval_correlation_coefficient': 0.9720261979387992, 'eval_runtime': 14.9826, 'eval_samples_per_second': 54.063, 'eval_steps_per_second': 3.404, 'epoch': 3.0}, 390: {'eval_loss': 0.41994309425354004, 'eval_mean_absolute_error': 0.42090439796447754, 'eval_correlation_coefficient': 0.9753245884255551, 'eval_runtime': 15.7174, 'eval_samples_per_second': 54.144, 'eval_steps_per_second': 3.436, 'epoch': 3.0}, 370: {'eval_loss': 0.5394895672798157, 'eval_mean_absolute_error': 0.49361708760261536, 'eval_correlation_coefficient': 0.9689686111352713, 'eval_runtime': 16.5899, 'eval_samples_per_second': 54.069, 'eval_steps_per_second': 
3.436, 'epoch': 3.0}, 350: {'eval_loss': 0.7105809450149536, 'eval_mean_absolute_error': 0.5651165246963501, 'eval_correlation_coefficient': 0.9525019604506539, 'eval_runtime': 17.5087, 'eval_samples_per_second': 54.144, 'eval_steps_per_second': 3.427, 'epoch': 3.0}, 330: {'eval_loss': 0.6363464593887329, 'eval_mean_absolute_error': 0.48796868324279785, 'eval_correlation_coefficient': 0.9603907352117782, 'eval_runtime': 17.0807, 'eval_samples_per_second': 58.838, 'eval_steps_per_second': 3.688, 'epoch': 3.0}, 310: {'eval_loss': 0.5006794929504395, 'eval_mean_absolute_error': 0.46303293108940125, 'eval_correlation_coefficient': 0.9677901061544582, 'eval_runtime': 16.2168, 'eval_samples_per_second': 65.981, 'eval_steps_per_second': 4.132, 'epoch': 3.0}, 290: {'eval_loss': 0.5548178553581238, 'eval_mean_absolute_error': 0.46474817395210266, 'eval_correlation_coefficient': 0.9697196908475955, 'eval_runtime': 16.3021, 'eval_samples_per_second': 70.114, 'eval_steps_per_second': 4.417, 'epoch': 3.0}, 270: {'eval_loss': 0.4739500880241394, 'eval_mean_absolute_error': 0.41211020946502686, 'eval_correlation_coefficient': 0.9698304652276473, 'eval_runtime': 17.3051, 'eval_samples_per_second': 70.904, 'eval_steps_per_second': 4.45, 'epoch': 3.0}, 250: {'eval_loss': 0.5357069969177246, 'eval_mean_absolute_error': 0.46953898668289185, 'eval_correlation_coefficient': 0.9666831805201189, 'eval_runtime': 16.823, 'eval_samples_per_second': 78.761, 'eval_steps_per_second': 4.934, 'epoch': 3.0}, 230: {'eval_loss': 0.631683349609375, 'eval_mean_absolute_error': 0.5498731136322021, 'eval_correlation_coefficient': 0.962849781422552, 'eval_runtime': 17.0232, 'eval_samples_per_second': 84.649, 'eval_steps_per_second': 5.346, 'epoch': 3.0}, 210: {'eval_loss': 0.5303255915641785, 'eval_mean_absolute_error': 0.4654589891433716, 'eval_correlation_coefficient': 0.9647626961119282, 'eval_runtime': 18.3498, 'eval_samples_per_second': 85.941, 'eval_steps_per_second': 5.395, 'epoch': 3.0}, 190: 
{'eval_loss': 0.5579525828361511, 'eval_mean_absolute_error': 0.44538575410842896, 'eval_correlation_coefficient': 0.9652222224965191, 'eval_runtime': 16.1201, 'eval_samples_per_second': 108.064, 'eval_steps_per_second': 6.762, 'epoch': 3.0}, 170: {'eval_loss': 0.5542275905609131, 'eval_mean_absolute_error': 0.4738572835922241, 'eval_correlation_coefficient': 0.9662785867842771, 'eval_runtime': 16.7985, 'eval_samples_per_second': 115.903, 'eval_steps_per_second': 7.263, 'epoch': 3.0}, 150: {'eval_loss': 0.6081576347351074, 'eval_mean_absolute_error': 0.49149081110954285, 'eval_correlation_coefficient': 0.9585183890915988, 'eval_runtime': 16.4075, 'eval_samples_per_second': 134.511, 'eval_steps_per_second': 8.411, 'epoch': 3.0}, 130: {'eval_loss': 0.7267868518829346, 'eval_mean_absolute_error': 0.547415554523468, 'eval_correlation_coefficient': 0.9554530574848882, 'eval_runtime': 18.307, 'eval_samples_per_second': 139.018, 'eval_steps_per_second': 8.74, 'epoch': 3.0}, 110: {'eval_loss': 0.7015063166618347, 'eval_mean_absolute_error': 0.4837210476398468, 'eval_correlation_coefficient': 0.9549796710608941, 'eval_runtime': 19.4681, 'eval_samples_per_second': 154.458, 'eval_steps_per_second': 9.657, 'epoch': 3.0}, 90: {'eval_loss': 0.6470832228660583, 'eval_mean_absolute_error': 0.4643601179122925, 'eval_correlation_coefficient': 0.9576133100589036, 'eval_runtime': 22.071, 'eval_samples_per_second': 166.463, 'eval_steps_per_second': 10.421, 'epoch': 3.0}, 70: {'eval_loss': 0.8725117444992065, 'eval_mean_absolute_error': 0.5471084713935852, 'eval_correlation_coefficient': 0.9421364653921181, 'eval_runtime': 19.0916, 'eval_samples_per_second': 247.387, 'eval_steps_per_second': 15.504, 'epoch': 3.0}, 50: {'eval_loss': 2.163191318511963, 'eval_mean_absolute_error': 1.0305832624435425, 'eval_correlation_coefficient': 0.8575062229558695, 'eval_runtime': 23.4923, 'eval_samples_per_second': 281.369, 'eval_steps_per_second': 17.623, 'epoch': 3.0}}}
# model_results = {'DeepPavlov/rubert-base-cased': {510: {'eval_loss': 0.33591960966587064, 'eval_mean_absolute_error': 0.37690727412700653, 'eval_correlation_coefficient': 0.9839640468492881, 'eval_runtime': 6.703, 'eval_samples_per_second': 97.12100000000001, 'eval_steps_per_second': 6.1166, 'epoch': 3.0}, 490: {'eval_loss': 0.34794422090053556, 'eval_mean_absolute_error': 0.38018497824668884, 'eval_correlation_coefficient': 0.9800381531213669, 'eval_runtime': 6.9529, 'eval_samples_per_second': 97.51320000000001, 'eval_steps_per_second': 6.1846, 'epoch': 3.0}, 470: {'eval_loss': 0.25607356876134874, 'eval_mean_absolute_error': 0.32758817374706267, 'eval_correlation_coefficient': 0.9852291944226597, 'eval_runtime': 7.2535, 'eval_samples_per_second': 97.526, 'eval_steps_per_second': 6.204000000000001, 'epoch': 3.0}, 450: {'eval_loss': 0.5038900047540664, 'eval_mean_absolute_error': 0.4382790386676788, 'eval_correlation_coefficient': 0.9747393305473271, 'eval_runtime': 7.568420000000001, 'eval_samples_per_second': 97.4846, 'eval_steps_per_second': 6.2097999999999995, 'epoch': 3.0}, 430: {'eval_loss': 0.38141776621341705, 'eval_mean_absolute_error': 0.39857401251792907, 'eval_correlation_coefficient': 0.9767342528888356, 'eval_runtime': 7.90808, 'eval_samples_per_second': 97.6226, 'eval_steps_per_second': 6.1964, 'epoch': 3.0}, 410: {'eval_loss': 0.3831910252571106, 'eval_mean_absolute_error': 0.40722976326942445, 'eval_correlation_coefficient': 0.9768684853448801, 'eval_runtime': 8.30764, 'eval_samples_per_second': 97.4766, 'eval_steps_per_second': 6.138999999999999, 'epoch': 3.0}, 390: {'eval_loss': 0.5406262338161468, 'eval_mean_absolute_error': 0.4885970771312714, 'eval_correlation_coefficient': 0.9696705493213669, 'eval_runtime': 8.710759999999999, 'eval_samples_per_second': 97.69539999999999, 'eval_steps_per_second': 6.1994, 'epoch': 3.0}, 370: {'eval_loss': 0.5919877976179123, 'eval_mean_absolute_error': 0.5212241470813751, 'eval_correlation_coefficient': 
0.9693880432696773, 'eval_runtime': 9.012560000000002, 'eval_samples_per_second': 99.4788, 'eval_steps_per_second': 6.2378, 'epoch': 3.0}, 350: {'eval_loss': 0.42964840233325957, 'eval_mean_absolute_error': 0.4400419622659683, 'eval_correlation_coefficient': 0.9751094358213068, 'eval_runtime': 9.29702, 'eval_samples_per_second': 102.07520000000001, 'eval_steps_per_second': 6.4605999999999995, 'epoch': 3.0}, 330: {'eval_loss': 0.1709460996091366, 'eval_mean_absolute_error': 0.23774559199810028, 'eval_correlation_coefficient': 0.9900742103040162, 'eval_runtime': 9.283219999999998, 'eval_samples_per_second': 108.293, 'eval_steps_per_second': 6.791200000000001, 'epoch': 3.0}, 310: {'eval_loss': 0.18852979317307472, 'eval_mean_absolute_error': 0.25065877139568327, 'eval_correlation_coefficient': 0.9889187076692714, 'eval_runtime': 9.37306, 'eval_samples_per_second': 114.6328, 'eval_steps_per_second': 7.1836, 'epoch': 3.0}, 290: {'eval_loss': 0.1920248217880726, 'eval_mean_absolute_error': 0.2625650465488434, 'eval_correlation_coefficient': 0.9883266187309181, 'eval_runtime': 9.06436, 'eval_samples_per_second': 126.56599999999999, 'eval_steps_per_second': 7.974200000000001, 'epoch': 3.0}, 270: {'eval_loss': 0.26097937524318693, 'eval_mean_absolute_error': 0.3040627181529999, 'eval_correlation_coefficient': 0.9851892801629459, 'eval_runtime': 9.09624, 'eval_samples_per_second': 135.073, 'eval_steps_per_second': 8.4766, 'epoch': 3.0}, 250: {'eval_loss': 0.23734421879053116, 'eval_mean_absolute_error': 0.255857814848423, 'eval_correlation_coefficient': 0.9846106498572513, 'eval_runtime': 9.11712, 'eval_samples_per_second': 145.6234, 'eval_steps_per_second': 9.122200000000001, 'epoch': 3.0}, 230: {'eval_loss': 0.3071304991841316, 'eval_mean_absolute_error': 0.3202324643731117, 'eval_correlation_coefficient': 0.98227223697466, 'eval_runtime': 8.69802, 'eval_samples_per_second': 165.8236, 'eval_steps_per_second': 10.384599999999999, 'epoch': 3.0}, 210: {'eval_loss': 
0.30425441041588785, 'eval_mean_absolute_error': 0.31183269917964934, 'eval_correlation_coefficient': 0.980916511649651, 'eval_runtime': 8.8199, 'eval_samples_per_second': 179.09859999999998, 'eval_steps_per_second': 11.249, 'epoch': 3.0}, 190: {'eval_loss': 0.2309357151389122, 'eval_mean_absolute_error': 0.27047823667526244, 'eval_correlation_coefficient': 0.9853722456149419, 'eval_runtime': 8.873000000000001, 'eval_samples_per_second': 196.69400000000002, 'eval_steps_per_second': 12.3102, 'epoch': 3.0}, 170: {'eval_loss': 0.24664568901062012, 'eval_mean_absolute_error': 0.2628838747739792, 'eval_correlation_coefficient': 0.9848558579991604, 'eval_runtime': 9.14264, 'eval_samples_per_second': 213.31980000000004, 'eval_steps_per_second': 13.366800000000001, 'epoch': 3.0}, 150: {'eval_loss': 0.23657941967248916, 'eval_mean_absolute_error': 0.2675287455320358, 'eval_correlation_coefficient': 0.9859411214448496, 'eval_runtime': 9.212639999999999, 'eval_samples_per_second': 241.44400000000002, 'eval_steps_per_second': 15.1012, 'epoch': 3.0}, 130: {'eval_loss': 0.2135237969458103, 'eval_mean_absolute_error': 0.22362841963768004, 'eval_correlation_coefficient': 0.9871307367771525, 'eval_runtime': 9.77538, 'eval_samples_per_second': 261.13620000000003, 'eval_steps_per_second': 16.3986, 'epoch': 3.0}, 110: {'eval_loss': 0.2859374389052391, 'eval_mean_absolute_error': 0.2696235179901123, 'eval_correlation_coefficient': 0.9822560924432681, 'eval_runtime': 9.84658, 'eval_samples_per_second': 305.4956, 'eval_steps_per_second': 19.101, 'epoch': 3.0}, 90: {'eval_loss': 0.4090846538543701, 'eval_mean_absolute_error': 0.3687426745891571, 'eval_correlation_coefficient': 0.9736618349565476, 'eval_runtime': 9.69474, 'eval_samples_per_second': 379.004, 'eval_steps_per_second': 23.729, 'epoch': 3.0}, 70: {'eval_loss': 0.38811469078063965, 'eval_mean_absolute_error': 0.321956941485405, 'eval_correlation_coefficient': 0.9752192893392853, 'eval_runtime': 11.29176, 
'eval_samples_per_second': 418.8161999999999, 'eval_steps_per_second': 26.2524, 'epoch': 3.0}, 50: {'eval_loss': 0.4108700931072235, 'eval_mean_absolute_error': 0.30514280796051024, 'eval_correlation_coefficient': 0.973251585879195, 'eval_runtime': 11.7428, 'eval_samples_per_second': 563.0508, 'eval_steps_per_second': 35.2664, 'epoch': 3.0}}}

In [None]:
#!g1.1

# Print a readable summary per model, followed by two "length, value" listings
# (MAE first, then correlation) separated by blank lines.
for model_name, evaluation_results in model_results.items():
    print(f"Evaluation results for {model_name}")
    for length, metrics in evaluation_results.items():
        print(f"Fragment length {length}: MAE = {metrics['eval_mean_absolute_error']}, Correlation={metrics['eval_correlation_coefficient']}")

    for position, metric_key in enumerate(('eval_mean_absolute_error', 'eval_correlation_coefficient')):
        if position > 0:
            # Visual gap between the two listings.
            print("\n\n\n")
        for length, metrics in evaluation_results.items():
            print(f"{length}, {metrics[metric_key]}")

In [None]:
#!g1.1
import matplotlib.pyplot as plt

# One MAE figure and one correlation figure per model, both plotted against the
# fragment length.
for model_name, evaluation_results in model_results.items():
    fragment_sizes = list(evaluation_results.keys())
    figures = (
        ('eval_mean_absolute_error',
         f"MAE: {model_name}",
         "Fragment size (words)",
         "MAE"),
        ('eval_correlation_coefficient',
         f"Correlation between prediction and labels: {model_name}",
         "Fragment size (words) ",
         "Correlation coefficient"),
    )
    for metric_key, title, x_label, y_label in figures:
        series = [metrics[metric_key] for metrics in evaluation_results.values()]
        plt.plot(fragment_sizes, series)
        plt.title(title)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.show()

In [None]:
#!g1.1
