In [1]:
import pickle
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr

import pandas as pd
import torch


In [2]:
def smape(y_true, y_pred):
    return np.mean(np.abs((y_true - y_pred) / (y_true + y_pred)))

In [106]:
PICKLE_FILE_PATH = "/data/avirinchipur/EMI/datadicts/rift_PCL11_minmax_ans_avg.pkl"
datadict = pickle.load(open(PICKLE_FILE_PATH, "rb"))
print (datadict.keys())

dict_keys(['train_data', 'val_data', 'test_data'])


In [107]:
train_data, val_data = datadict['train_data'], datadict['val_data']
print ("Num train examples: {}".format(len(train_data['labels'])))
print ("Num val examples: {}".format(len(val_data['labels'])))


Num train examples: 122
Num val examples: 19


In [108]:
train_labels = list(map(lambda x: x[0], train_data['labels']))
val_labels = list(map(lambda x: x[0], val_data['labels']))
print ("Train Label distribution")
print ("Min: {}, Max: {}, Avg: {}, Median: {}".format(min(train_labels), max(train_labels), sum(train_labels)/len(train_labels), sorted(train_labels)[len(train_labels)//2]))
print ("Val Label distribution")
print ("Min: {}, Max: {}, Avg: {}, Median: {}".format(min(val_labels), max(val_labels), sum(val_labels)/len(val_labels), sorted(val_labels)[len(val_labels)//2]))

Train Label distribution
Min: 1.224744871391589, Max: 1.8602191717778138, Avg: 1.3581212167652628, Median: 1.311099796200312
Val Label distribution
Min: 1.2377967793203888, Max: 1.72086022041801, Avg: 1.3385628714282165, Median: 1.303390248687618


In [109]:
avg_train_label = np.array([sum(train_labels)/len(train_labels)]*len(val_labels))
median_train_label = np.array([sorted(train_labels)[len(train_labels)//2]]*len(val_labels))
val_labels = np.array(val_labels)

In [110]:
print (avg_train_label[0], median_train_label[0])

1.3581212167652628 1.311099796200312


In [111]:
mse_baseline_avg = mean_squared_error(val_labels, avg_train_label)
mae_baseline_avg = mean_absolute_error(val_labels, avg_train_label)
smape_baseline_avg = smape(val_labels, avg_train_label)

mse_baseline_median = mean_squared_error(val_labels, median_train_label)
mae_baseline_avg = mean_absolute_error(val_labels, median_train_label)
smape_baseline_median = smape(val_labels, median_train_label)

In [112]:
print ("Baseline Avg") 
print ("MSE: {}, MAE: {}, SMAPE: {}".format(round(mse_baseline_avg, 4), round(mae_baseline_avg, 4), round(smape_baseline_avg, 4)))
print ("Baseline Median")
print ("MSE: {}, MAE: {}, SMAPE: {}".format(round(mse_baseline_median, 4), round(mae_baseline_avg, 4), round(smape_baseline_median, 4)))

Baseline Avg
MSE: 0.0126, MAE: 0.0719, SMAPE: 0.0308
Baseline Median
MSE: 0.0129, MAE: 0.0719, SMAPE: 0.0261


In [113]:
val_labels, avg_train_label, median_train_label = ((val_labels/2)**2 - (3/8))*55 + 11, ((avg_train_label/2)**2 - (3/8))*55 + 11, ((median_train_label/2)**2 - (3/8))*55 + 11

mse_baseline_avg = mean_squared_error(val_labels, avg_train_label)
mae_baseline_avg = mean_absolute_error(val_labels, avg_train_label)
smape_baseline_avg = smape(val_labels, avg_train_label)

mse_baseline_median = mean_squared_error(val_labels, median_train_label)
mae_baseline_avg = mean_absolute_error(val_labels, median_train_label)
smape_baseline_median = smape(val_labels, median_train_label)

In [114]:
print ("Un Anscombed")
print ("Baseline Avg") 
print ("MSE: {}, MAE: {}, SMAPE: {}".format(round(mse_baseline_avg, 4), round(mae_baseline_avg, 4), round(smape_baseline_avg, 4)))
print ("Baseline Median")
print ("MSE: {}, MAE: {}, SMAPE: {}".format(round(mse_baseline_median, 4), round(mae_baseline_avg, 4), round(smape_baseline_median, 4)))

Un Anscombed
Baseline Avg
MSE: 19.9262, MAE: 2.7393, SMAPE: 0.0987
Baseline Median
MSE: 20.9791, MAE: 2.7393, SMAPE: 0.0828


In [3]:
PREDS_FILE_PATH = "/data/avirinchipur/EMI/outputs/voicemails/PCL11_minmax_ans_avg_roba128/voicemails_PCL11_minmax_ans_avg/7563578782e54beab8dcc3df3bb212e9/preds.pkl"
preds_dict = pickle.load(open(PREDS_FILE_PATH, "rb"))
preds_dict.keys()

dict_keys(['train', 'val', 'test'])

In [4]:
len(preds_dict['train']['preds'])

848

In [26]:
preds, target = preds_dict['train']['preds'][-10], preds_dict['train']['target'][-10]
# mask = target != 0
# # Calculate Pearson correlation
# pearsonr(preds[mask], target[mask])

In [25]:
preds, target = preds_dict['val']['preds'][-10][0], preds_dict['val']['target'][-10][0]
mask = target != 0
# Calculate Pearson correlation
pearsonr(preds[mask], target[mask])

(-0.19404317750494207, 0.002588534606951888)

In [5]:
timestep_level_preds = {}
epoch=225
for batch_pred, batch_target in zip(preds_dict['val']['preds'][epoch], preds_dict['val']['target'][epoch]):
    batch_mask = ~(batch_target==0)
    for time_step in range(batch_pred.shape[1]):
        if time_step not in timestep_level_preds: timestep_level_preds[time_step] = {'preds': [], 'target': [], 'mask': []}
        timestep_level_preds[time_step]['preds'].extend(batch_pred[:, time_step])
        timestep_level_preds[time_step]['target'].extend(batch_target[:, time_step])
        timestep_level_preds[time_step]['mask'].extend(batch_mask[:, time_step])

In [6]:
timestep_level_metrics = {}
for i in timestep_level_preds:
    preds, target, mask = timestep_level_preds[i]['preds'], timestep_level_preds[i]['target'], timestep_level_preds[i]['mask']
    preds, target, mask = torch.tensor(preds), torch.tensor(target), torch.tensor(mask)
    preds, target = ((preds/2)**2 - (3/8))*55 + 11, ((target/2)**2 - (3/8))*55 + 11
    mse_timestep = torch.sum(torch.square(preds - target)*mask)/torch.sum(mask)
    mae_timestep = torch.sum(torch.abs(preds - target)*mask)/torch.sum(mask)
    smape_timestep = torch.sum(torch.abs(preds - target)/(torch.abs(preds) + torch.abs(target) + 1e-8)*mask)/torch.sum(mask)
    count_timestep = torch.sum(mask)
    timestep_level_metrics[i] = {'mse': mse_timestep.item(), 'mae': mae_timestep.item(), 'smape': smape_timestep.item(), 'count': count_timestep.item()}

In [7]:
pd.DataFrame(timestep_level_metrics).T

Unnamed: 0,mse,mae,smape,count
0,233.162643,12.513783,0.571252,19.0
1,280.968994,14.042586,0.527226,19.0
2,260.24826,12.646894,0.479471,19.0
3,302.833984,13.782494,0.519594,18.0
4,321.125671,14.271898,0.516644,18.0
5,324.427063,15.089183,0.600448,18.0
6,309.886566,13.114624,0.459036,18.0
7,191.289673,10.11242,0.371442,18.0
8,227.627075,10.557377,0.368045,17.0
9,234.442093,10.958411,0.406806,16.0
