In [11]:
from sklearn.svm import SVR
from trainer.embed import load_embedding_data
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.metrics import r2_score
from PlanRegr.metrics import relative_error_mean, relative_error
import matplotlib.pyplot as plt

In [2]:
sample_name = "wikidata_0_1_10_v2_path_hybrid"
config = {
    "reg_train_path" : f"/data/{sample_name}/reg_train_sampled.pickle",
"reg_val_path" : f"/data/{sample_name}/reg_val_sampled.pickle",
"reg_test_path" : f"/data/{sample_name}/reg_test_sampled.pickle",
}

In [3]:
def prepare(path):
    ids, embeds, lats = load_embedding_data(path)
    X = np.vstack(embeds)
    y = np.vstack(lats)
    y = y.reshape((-1,))
    return X, y, ids
train_X, train_y, train_ids = prepare(config['reg_train_path'])
val_X, val_y, val_ids = prepare(config['reg_val_path'])
test_X, test_y, test_ids = prepare(config['reg_test_path'])

In [8]:
def print_message(preds, gt):
    r = relative_error_mean(preds,gt)
    print(f"Relative error is {r}, mse: {mean_squared_error(preds, gt)} mae: {mean_absolute_error(preds, gt)}, r2: {r2_score(preds, gt)}")
    pred2 = preds.copy()
    pred2[pred2<0] = 0
    r = relative_error_mean(pred2,gt)
    print(f"Non-negative: Relative error is {r}, mse: {mean_squared_error(pred2, gt)} mae: {mean_absolute_error(pred2, gt)}, r2: {r2_score(preds, gt)}")

In [6]:
svr = SVR(kernel='rbf').fit(train_X,train_y)
train_pred = svr.predict(train_X)
print_message(train_pred, train_y)

Relative error is 25.37749617473213, mse: 14089.387748337567 mae: 13.887717293018664
Non-negative: Relative error is 21.02073199196876, mse: 14089.385449972046 mae: 13.872708990592375


In [9]:
print_message(train_pred, train_y)


Relative error is 25.37749617473213, mse: 14089.387748337567 mae: 13.887717293018664, r2: -181257.6377322743
Non-negative: Relative error is 21.02073199196876, mse: 14089.385449972046 mae: 13.872708990592375, r2: -181257.6377322743


In [14]:
%%capture test_print
list(train_pred)

In [16]:
test_print.show()

[0.002913043927947667,
 -0.07671498753111194,
 -0.05435711315738301,
 0.002913043927947667,
 0.002913043927947667,
 -0.07348821644179493,
 -0.08284007230710166,
 0.002913043927947667,
 0.026981339919663894,
 0.36654704967462926,
 0.33517138954765235,
 0.021315756099080208,
 0.002913043927947667,
 0.03873245949226911,
 -0.0135404033737343,
 0.17729731900136136,
 -0.07999910840595459,
 0.1134080144759011,
 0.38948325823101215,
 -0.06462989084449644,
 -0.05631822638720596,
 -0.013499085400678013,
 0.026981339919663894,
 -0.018683272781275573,
 0.002913043927947667,
 -0.05911219078702845,
 0.030252261773844324,
 0.39891182358658783,
 -0.06462989084449644,
 0.014647703663287492,
 0.24140151140487698,
 0.37876619583662785,
 0.001944116833392684,
 -0.007928657565959174,
 0.04115153011605388,
 -0.07736160608326847,
 -0.02883300562786184,
 -0.07671498753111194,
 0.026981339919663894,
 0.02696320381321926,
 -0.07781747731327471,
 0.0738796724818731,
 0.06958471576106873,
 0.16496999929274114,
 -