# Load libraries

In [1]:
# load pandas and numpy, plus os/sys for the path setup below
import pandas as pd
import numpy as np
import os
import sys
# put the parent directory on the module search path (at index 1)
sys.path.insert(1, '..')
# switch the working directory to the parent directory; the relative
# 'output/...' paths opened later in this notebook are resolved from there.
# NOTE: '..' is relative to the current directory, so re-running this cell in
# the same kernel moves up one more level each time -- run it once per kernel.
os.chdir('..')

# Read and create table for no-covariates models

In [6]:
# Read the per-model / per-dataset result logs (output/<model>_<dataset>.txt)
# and collect the ID and OOD errors into one table (rows: models,
# columns: <dataset>_<ID|OOD>_<RMSE|MAE>).
models = ['nhits', 'tft', 'linreg', 'xgboost', 'transformer']
datasets = ['weinstock', 'dubosson', 'colas', 'iglu', 'hall']

# One (initially empty) {model: value} mapping per table column, ordered as
# all *_ID_RMSE, then all *_ID_MAE, then all *_OOD_RMSE, then all *_OOD_MAE.
results = {f'{dataset}_{metric}': {}
           for metric in ('ID_RMSE', 'ID_MAE', 'OOD_RMSE', 'OOD_MAE')
           for dataset in datasets}

# Recognised log lines, as
# (line prefix, text right before the two numbers, text right after them, split).
# Two layouts are accepted for each split; if a file contains both, the line
# that appears later in the file overwrites the earlier value.
line_formats = [
    ('Key: median RS ID (MSE, MAE) stats', "'mean': [array([", "])], 'std'", 'ID'),
    ('Key: median RS OOD (MSE, MAE) stats', "'mean': [array([", "])], 'std'", 'OOD'),
    ('RS ID (MSE, MAE) errors stats', "'median': array([[", "]]), 'min'", 'ID'),
    ('RS OOD (MSE, MAE) errors stats', "'median': array([[", "]]), 'min'", 'OOD'),
]

for model in models:
    for dataset in datasets:
        # a missing log file raises FileNotFoundError here
        with open(f'output/{model}_{dataset}.txt', 'r') as f:
            for line in f:
                for prefix, before, after, split in line_formats:
                    if not line.startswith(prefix):
                        continue
                    # cut out the text between the two markers and split it
                    # into the individual numbers (kept in a separate name so
                    # `line` stays a string)
                    values = line.split(before)[1].split(after)[0].split(', ')
                    # first number is the MSE (stored as RMSE), second the MAE
                    results[f'{dataset}_{split}_RMSE'][model] = np.sqrt(float(values[0]))
                    results[f'{dataset}_{split}_MAE'][model] = float(values[1])
                    # the prefixes are mutually exclusive, so stop at the first match
                    break
results = pd.DataFrame(results)

In [7]:
# keep only the columns whose label contains 'hall'
results_hall = results.filter(like='hall', axis='columns')
results_hall

Unnamed: 0,hall_ID_RMSE,hall_ID_MAE,hall_OOD_RMSE,hall_OOD_MAE
nhits,6.617095,5.794658,8.327687,7.162722
tft,10.698462,8.69772,11.623917,9.502617
linreg,6.222508,5.236052,7.319772,6.202425
xgboost,6.027002,5.178988,7.658983,6.694784
transformer,6.695598,5.670795,8.222221,6.790585


# Read and create table for with-covariates models

In [2]:
# Read the per-model / per-dataset result logs (output/<model>_<dataset>.txt)
# of the with-covariates models and collect the ID and OOD errors into one
# table (rows: models, columns: <dataset>_<ID|OOD>_<RMSE|MAE>).
models = ['nhits_covariates', 'tft_covariates', 'linreg_covariates', 'xgboost_covariates', 'transformer_covariates']
datasets = ['weinstock', 'dubosson', 'colas', 'iglu', 'hall']

# One (initially empty) {model: value} mapping per table column, ordered as
# all *_ID_RMSE, then all *_ID_MAE, then all *_OOD_RMSE, then all *_OOD_MAE.
results = {f'{dataset}_{metric}': {}
           for metric in ('ID_RMSE', 'ID_MAE', 'OOD_RMSE', 'OOD_MAE')
           for dataset in datasets}

# Recognised log lines, as
# (line prefix, text right before the two numbers, text right after them, split).
# Two layouts are accepted for each split; if a file contains both, the line
# that appears later in the file overwrites the earlier value.
line_formats = [
    ('Key: median RS ID (MSE, MAE) stats', "'mean': [array([", "])], 'std'", 'ID'),
    ('Key: median RS OOD (MSE, MAE) stats', "'mean': [array([", "])], 'std'", 'OOD'),
    ('RS ID (MSE, MAE) errors stats', "'median': array([[", "]]), 'min'", 'ID'),
    ('RS OOD (MSE, MAE) errors stats', "'median': array([[", "]]), 'min'", 'OOD'),
]

for model in models:
    for dataset in datasets:
        path = f'output/{model}_{dataset}.txt'
        # model/dataset combinations without a log file are skipped; their
        # cells stay empty (NaN) in the final table
        if not os.path.exists(path):
            continue
        with open(path, 'r') as f:
            for line in f:
                for prefix, before, after, split in line_formats:
                    if not line.startswith(prefix):
                        continue
                    # cut out the text between the two markers and split it
                    # into the individual numbers (kept in a separate name so
                    # `line` stays a string)
                    values = line.split(before)[1].split(after)[0].split(', ')
                    # first number is the MSE (stored as RMSE), second the MAE
                    results[f'{dataset}_{split}_RMSE'][model] = np.sqrt(float(values[0]))
                    results[f'{dataset}_{split}_MAE'][model] = float(values[1])
                    # the prefixes are mutually exclusive, so stop at the first match
                    break
results = pd.DataFrame(results)

In [8]:
# keep only the columns whose label contains 'weinstock'
# NOTE(review): `results` is assigned by both table-building cells, so this
# shows whichever of them ran last -- re-run the intended one first.
results_weinstock = results.filter(like='weinstock', axis='columns')
results_weinstock

Unnamed: 0,weinstock_ID_RMSE,weinstock_ID_MAE,weinstock_OOD_RMSE,weinstock_OOD_MAE
nhits,14.982595,12.686479,15.338324,13.015024
tft,27.453884,20.490143,27.409369,20.433597
linreg,15.475566,13.032758,15.370515,12.938196
xgboost,15.748382,13.423946,15.744156,13.442318
transformer,14.382459,12.151109,14.257016,11.975372


In [24]:
# scratch check: float() accepts leading/trailing whitespace (' 106 ' -> 106.0)
float(' 106 ')

106.0