In [4]:
import numpy as np
import pandas as pd
import psutil as ps
import os
from keras.models import Sequential, load_model
from keras.layers import Dense, Activation
from keras.optimizers import Adam
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn import linear_model
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from Utility import draw_bitmap
import time

# Experiment-wide settings shared by the cells below.
nb_feature = 5       # number of feature columns sliced out of every row
nb_mdt = 479         # NOTE(review): presumably the row count of train_sm.csv -- confirm
split_ratio = 0.3    # fraction of rows passed to train_test_split as the test share

# Handle on the current process; the experiment loop samples its memory usage.
myProcess = ps.Process(os.getpid())

# Full training table as a raw NumPy array (column labels are discarded by .values).
df = pd.read_csv('./data/train_sm.csv').values

def load_data():
    """Randomly split the module-level ``df`` array into train and test parts.

    The split uses ``split_ratio`` as the test share and is not seeded, so
    every call yields a different partition.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)`` where the feature
        matrices hold the ``nb_feature`` columns starting at column index 2
        and the targets are the second-to-last column.
    """
    feature_cols = slice(2, 2 + nb_feature)
    target_col = -2
    train_rows, test_rows = train_test_split(df, test_size=split_ratio)
    return (train_rows[:, feature_cols], train_rows[:, target_col],
            test_rows[:, feature_cols], test_rows[:, target_col])

def load_pixel_data(nb_feature):
    """Read the heat-map pixel table and keep its leading feature columns.

    Args:
        nb_feature: number of leading columns to return.

    Returns:
        numpy.ndarray: every row of ``./data/heatmap_pixel_1m.csv`` restricted
        to its first ``nb_feature`` columns.
    """
    pixels = pd.read_csv('./data/heatmap_pixel_1m.csv').values
    return pixels[:, :nb_feature]

# Linear regression
def build_lin():
    """Return a fresh, unfitted ``LinearRegression`` estimator with default settings."""
    return linear_model.LinearRegression()

# K-NN
def build_knn():
    """Return a fresh, unfitted 3-nearest-neighbour regressor with uniform weights."""
    return KNeighborsRegressor(n_neighbors=3, weights='uniform')

# Random forest
def build_rf():
    """Return a fresh, unfitted random-forest regressor (100 trees, depth <= 30).

    ``max_features=None`` makes every split consider all features. That is
    what the former ``'auto'`` setting meant for regressors, but ``'auto'``
    was deprecated and later removed from scikit-learn, whereas ``None`` is
    accepted by both old and current releases.
    """
    return RandomForestRegressor(n_estimators=100, max_depth=30, max_features=None)

# Decision tree
def build_dt():
    """Return a fresh, unfitted decision-tree regressor limited to depth 30."""
    return DecisionTreeRegressor(max_depth=30)

# Gradient Boosting
def build_gb():
    """Return a fresh, unfitted gradient-boosting regressor.

    Uses 100 boosting stages, trees of depth 3 and a learning rate of 0.1.
    """
    return GradientBoostingRegressor(n_estimators=100, max_depth=3, learning_rate=0.1)

#Kriging
def build_kg():
    """Return a fresh, unfitted Gaussian-process regressor (kriging).

    The kernel is a scaled RBF term plus a white-noise term; ``alpha`` is set
    to 0.0 so no extra value is added to the kernel diagonal beyond what the
    ``WhiteKernel`` contributes.
    """
    signal_part = 1.0 * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e3))
    noise_part = WhiteKernel(noise_level=1e-5, noise_level_bounds=(1e-10, 1e+1))
    return GaussianProcessRegressor(kernel=signal_part + noise_part, alpha=0.0)

# Neural Network
def build_nn():
    """Return an unfitted ``KerasRegressor`` wrapping a dense feed-forward net.

    The network has four ReLU layers of 512 units (the first one takes
    ``nb_feature`` inputs) followed by a single linear output unit, and is
    compiled with mean-squared-error loss and the Adam optimizer. The wrapper
    trains for 1000 epochs with shuffled mini-batches of 32 samples.
    """
    def create_model():
        # Called by KerasRegressor to construct the network.
        net = Sequential()
        net.add(Dense(512, activation='relu', input_dim=nb_feature))
        for _ in range(3):
            net.add(Dense(512, activation='relu'))
        net.add(Dense(1))
        net.compile(loss='mean_squared_error', optimizer='adam')
        return net

    return KerasRegressor(build_fn=create_model, epochs=1000, batch_size=32,
                          shuffle=True, verbose=1)

def build_model(option):
    """Build the unfitted estimator selected by a short option code.

    Args:
        option: one of ``'lin'``, ``'knn'``, ``'rf'``, ``'dt'``, ``'nn'``,
            ``'gb'`` or ``'kg'``.

    Returns:
        The object produced by the matching ``build_*`` factory.

    Raises:
        KeyError: if ``option`` is not one of the codes above.
    """
    factories = {
        'lin': build_lin,
        'knn': build_knn,
        'rf': build_rf,
        'dt': build_dt,
        'nn': build_nn,
        'gb': build_gb,
        'kg': build_kg,
    }
    factory = factories[option]
    return factory()

In [10]:
# ---- Experiment configuration ----
model_option = 'nn'         # lin, knn, rf, dt, gb, nn, kg
iteration = 100             # number of random train/test splits to average over

# ---- Accumulators: summed inside the loop, divided by `iteration` afterwards ----
memory = 0
rmse = 0
err_mean = 0
err_std = 0
train_time = begin_time = end_time = 0
hmap_test = load_pixel_data(nb_feature)
# Size the accumulators from the data rather than from hard-coded constants
# (105*27 pixels, round(nb_mdt*split_ratio) test rows), which break with a
# shape error as soon as either CSV or the split ratio changes.
hmap_z = np.zeros(len(hmap_test))
err_cdf = None              # allocated on the first iteration, once the test size is known

for _ in range(iteration):
    X_train, y_train, X_test, y_test = load_data()
    begin_time = time.time()
    model = build_model(model_option)
    model.fit(X_train, y_train)
    end_time = time.time()       # timed span covers model construction and fitting
    # ravel() keeps the predictions 1-D so the subtraction below stays element-wise.
    y_pred = np.ravel(model.predict(X_test))
    hmap_pred = np.ravel(model.predict(hmap_test))

    # Sorted absolute errors of this split; summing them position by position
    # across iterations gives the averaged error curve stored in err_cdf.
    testerr = np.sort(np.abs(y_pred.astype(float) - np.asarray(y_test, dtype=float)))
    err_mean += np.mean(testerr)
    err_std  += np.std(testerr)          # first calculate std for each iteration, then avg all stds
    err_cdf = testerr if err_cdf is None else np.add(err_cdf, testerr)
    rmse += (mean_squared_error(y_test, y_pred))**(0.5)
    train_time += (end_time - begin_time)
    memory += (myProcess.memory_info().rss/2.**20)   # RSS in MB
    hmap_z = np.add(hmap_z, hmap_pred)

np.set_printoptions(suppress=True)    #suppress scientific notation
[err_mean, err_std, rmse, train_time, memory] = np.divide([err_mean, err_std, rmse, train_time, memory], iteration)
err_cdf = np.divide(err_cdf, iteration)
hmap_z  = np.divide(hmap_z, iteration)

# Single-argument print(...) calls behave the same on Python 2 and Python 3;
# the % formatting is applied to the string, not to print's return value.
print('------' + model_option + '------')
print('mean:%f \nstd:%f \nrmse:%f \ntime:%f \nmemory:%f' % (err_mean, err_std, rmse, train_time, memory))
print('total memory:%f' % (myProcess.memory_info().rss/2.**20))
print(list(err_cdf))

------gb------
mean:1.006660 
std:1.241169 
rmse:1.600007 
time:0.033264 
memory:244.019531
total memory:244.019531
[0.0079229204767899825, 0.014536662153290223, 0.02117700054736403, 0.02900987095971928, 0.035422433914060039, 0.042324581894559545, 0.049771748717911068, 0.057030717386956836, 0.065064845657799991, 0.07197289273330966, 0.07835355062644539, 0.085138064969516453, 0.090685558229239974, 0.097208646619388897, 0.1039177677980426, 0.11157626925836013, 0.11763191938366845, 0.124509720511642, 0.1308996477417618, 0.13815432999484656, 0.14509086354214049, 0.15188620830141361, 0.15773524779520856, 0.16371957546576127, 0.17006699231554406, 0.17582674970405038, 0.18200723471296384, 0.18916700985209559, 0.19743540496245146, 0.20515692091290971, 0.21260328801499143, 0.21916659305875527, 0.2268212815818319, 0.23487478862392108, 0.24231318514267217, 0.24897078837585979, 0.25497973398359347, 0.26172128458968758, 0.2683902675714569, 0.27606276572691951, 0.28407050750194218, 0.294340972322960

In [5]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import spline
import random as rd

# Hard-coded per-model summary values; every list follows the same model
# order as `objects`.
objects = ['Random\n forest', 'Neural\n network', 'Linear\n regression', 'Decision\n tree', 'Gradient\n boosting', 'K-NN']
means  = [4.191894, 4.368361, 10.152151, 5.151650, 4.947673, 4.354931]  # rf, nn, lin, dt, gb, knn
stds   = [4.084162, 4.010076, 8.428471, 5.276139, 4.478879, 4.254475]
rmse   = [5.858216, 5.941823, 13.198297, 7.385763, 6.678244, 6.094503]
# Named `train_times` rather than `time`: rebinding `time` would shadow the
# `time` module, and the experiment cell's `time.time()` calls would then
# fail when that cell is re-run in the same kernel.
train_times = [0.280681, 208.940886, 0.000600, 0.001238, 0.029086, 0.000590]
memory = [223.407500, 1680.747617, 214.925781, 215.144531, 214.632812, 215.431953]
cdf_rf  = [0.030394668539006488, 0.064223997263993762, 0.099004665254000909, 0.13470266620500723, 0.17087466255399661, 0.21299933272199867, 0.25129866727399502, 0.28166599945099252, 0.31578466604199451, 0.35204667130200562, 0.38712332637099661, 0.42527466693399363, 0.45903399946699008, 0.49631866663599938, 0.53537599921700429, 0.57013866376399558, 0.61126132930600263, 0.64574332873800477, 0.68118266726299448, 0.71657466466300268, 0.75331533323800004, 0.79105199919800373, 0.82880933612400187, 0.86721466700000105, 0.90472866603600122, 0.94660066683399979, 0.98350466496400757, 1.0235819984699959, 1.0575293304279993, 1.0935633315300033, 1.1402753373439949, 1.1779953317090066, 1.2137660010149989, 1.2541779982920096, 1.2915233307210019, 1.3274799933329819, 1.3738826671500066, 1.4167179986200029, 1.4527640034359997, 1.4935399986000084, 1.5362953341120125, 1.5767106652539959, 1.613603336580004, 1.6534853336799986, 1.6995726678570173, 1.7373833361580047, 1.7830806651199971, 1.8279306686780057, 1.86786400174101, 1.9091933343130123, 1.9566353328650086, 2.0006586700120006, 2.0456073349010064, 2.0937060015749869, 2.1341546687600021, 2.1718726676440023, 2.2207793359049988, 2.2708726691130035, 2.312041335727014, 2.3648273321749946, 2.4104813307410069, 2.4664846697190148, 2.5224766666879992, 2.5732473344429985, 2.6313953338480069, 2.6693240003030008, 2.7214593341789977, 2.7666706697879988, 2.8232100015880035, 2.878996671965997, 2.9470700012310034, 3.0030813325590056, 3.0703360021889989, 3.1272079981640051, 3.1856966701620069, 3.234466002125997, 3.2919000003019971, 3.3516813338189957, 3.4060580017709947, 3.4692979965520068, 3.5300020004630004, 3.5853193332030098, 3.6627706666060091, 3.7373806659969864, 3.7970013343469997, 3.8740233321029951, 3.9352906661969809, 3.9933566654559933, 4.053313329983002, 4.1364093313980019, 4.2095726677700132, 4.28878600231499, 4.3590213304350014, 4.4324233355089993, 4.5040326632640033, 4.596373332420999, 4.6841426651959992, 4.778245997804996, 
4.8492879996460099, 4.9289819980720049, 5.0076213345529963, 5.0924339979329956, 5.1854706657019998, 5.2848760005589908, 5.357465997197, 5.447188661179994, 5.532090664059, 5.6343826654649991, 5.7280733318899877, 5.8470193311810057, 5.9541186653330058, 6.054118666589992, 6.1739899982990005, 6.3025206681170012, 6.4343186670020005, 6.5628993313419963, 6.7063453308320051, 6.8489719963619917, 7.0020686620950139, 7.1597953324549932, 7.3357746667709964, 7.5137993310259938, 7.7101060002139992, 7.8864293314809917, 8.0583059993309991, 8.2230726674760106, 8.4426099968609929, 8.6436286680880023, 8.914990002243, 9.2124106703649957, 9.4764066656360022, 9.8028639993900004, 10.142022666217004, 10.504326007641009, 10.873413337610005, 11.282387339999021, 11.828559337214999, 12.465029341244003, 13.284773335373004, 14.368998671504015, 15.637563999326003, 17.431640009419013, 20.041804006819014, 24.140507341608021]
cdf_nn  = [0.036539720629102275, 0.08003805621875032, 0.12387777649843798, 0.16356288922382717, 0.20332944070761713, 0.24575970373925743, 0.2841710829935542, 0.32512407911972646, 0.36330652300644628, 0.40473119446718797, 0.44841117678769576, 0.48980858919707049, 0.54105364004707057, 0.59100128383945416, 0.64243337207167983, 0.68049060864004007, 0.72607050415957053, 0.76792588088271441, 0.8165836640416011, 0.86908615982890536, 0.90646393467675845, 0.94807262703359385, 0.99299900292460963, 1.033817041516796, 1.0802286513282224, 1.1302882428771486, 1.174751238413672, 1.2118139388296878, 1.2544706343345708, 1.303314328827637, 1.338524390297851, 1.3764882828005864, 1.4243109107859369, 1.4710766076763679, 1.5108779048228516, 1.5506285274308595, 1.5945303776478519, 1.6368362789957041, 1.6717199452373053, 1.7136670649791013, 1.7522716531205083, 1.7922116888335926, 1.8329892452066394, 1.8712070683710942, 1.9194940202064459, 1.960601702179688, 2.0056764799519531, 2.0435672728986325, 2.0889220357902345, 2.1327163010384771, 2.17956952379707, 2.2308715974394531, 2.2715464604621105, 2.3221287409539073, 2.3746759979753906, 2.4242765545458975, 2.4756584346091786, 2.5241752765087897, 2.5814574851630856, 2.6351853117378896, 2.6923954245138675, 2.7364488577414061, 2.7956255940174812, 2.8547313361009752, 2.89957862428711, 2.9555565686851559, 3.007969532349998, 3.0624752190265614, 3.1267247997863272, 3.1670599570339859, 3.2236251854572253, 3.2907594622503895, 3.3426224915562495, 3.4055437324400373, 3.4554280187367197, 3.520410027276172, 3.5897346872322267, 3.6518068129246082, 3.7245448400775403, 3.7874719288328098, 3.8583036204859389, 3.9243900972964831, 3.9873317665470696, 4.0612497249912094, 4.1242036312917971, 4.1877317043601598, 4.2549068568492201, 4.3269064181972654, 4.3812383384310536, 4.4529360268527318, 4.5146228025849586, 4.5873627779593766, 4.6568816056062481, 4.7390890390871112, 4.8087542264351564, 4.9002316545837905, 4.9638603719248033, 5.0400909134169902, 
5.1210163525689474, 5.201841208937501, 5.2866425381015638, 5.3798026687869154, 5.4626522865160174, 5.5402221448558597, 5.614383620615822, 5.7184178345479495, 5.8251604107964843, 5.9038617331570347, 6.008153287541993, 6.0928553934054674, 6.2029071154992188, 6.2972449552455041, 6.4245546571330081, 6.5501666280283199, 6.6639591586757829, 6.7747877508947285, 6.8846992726804714, 7.0395004387037101, 7.1705794920380841, 7.3130472970437506, 7.4655867785816374, 7.6405948023564427, 7.7973661093306577, 7.9525213491447229, 8.1207858990765622, 8.3324320472087887, 8.5429557959374964, 8.7982841928945277, 9.0391061937177692, 9.3150462481378842, 9.5777483287378917, 9.8283300893814456, 10.147772088221872, 10.450451306752345, 10.830854314034958, 11.257000890036812, 11.730325631260939, 12.342048790950983, 13.018453225078618, 13.999347563793753, 15.333997368934172, 16.905939066301169, 19.754916110024809, 23.777640603482411]
cdf_lin = [0.11573903750635452, 0.24307314420224016, 0.3688469834183577, 0.47477275840821664, 0.59058717094690294, 0.69644527752940388, 0.82255627539386467, 0.93732425690684773, 1.0653734281920213, 1.186165543942681, 1.2969355503120412, 1.4032198920040058, 1.5194102588916976, 1.6329114032356478, 1.7506707595871824, 1.8514952574215864, 1.9394167343593995, 2.056090757586158, 2.1360345261273896, 2.2223029450296234, 2.3010552065384005, 2.4008688844644621, 2.4950739248849887, 2.5981556304873767, 2.6902199473837105, 2.7747446534757345, 2.8679657708812876, 2.942279771107418, 3.0255991872801253, 3.1244094772721009, 3.2030977369867184, 3.2786199085301062, 3.3603700867162036, 3.4474951430118699, 3.5360483576510808, 3.6287001727102268, 3.7051411097611022, 3.7949563390631367, 3.8920012783096065, 3.9888588075643145, 4.0929987248832989, 4.2003569533097771, 4.2964001705626993, 4.3805286044964866, 4.4941811083143159, 4.6276086852425253, 4.7362965642258397, 4.850168033679032, 4.9508496574107896, 5.0638087641028013, 5.1762221260271586, 5.2842030587683135, 5.4045905315342067, 5.5100304496509436, 5.6055263885368571, 5.7276800263467429, 5.8268381582604993, 5.9304541262613553, 6.0486188556964331, 6.1630614951532303, 6.2775145942267532, 6.4004080771679863, 6.5191414686263691, 6.6564656309129244, 6.7754123106511095, 6.8862146860223961, 7.0356195473910184, 7.1482092660075889, 7.2799660570377727, 7.387191289265445, 7.533975339804595, 7.6649529140249841, 7.7768717386507369, 7.9083933104548985, 8.0605456048730488, 8.2083253191073613, 8.3581047192790585, 8.5140658937978078, 8.6343963960983725, 8.7981117308881345, 8.9289234350463627, 9.0782996349147886, 9.2136585817802388, 9.3543701013959133, 9.4983266702983524, 9.6583918512835645, 9.8184973071503983, 10.000245971955797, 10.140654285357828, 10.27694062134838, 10.434758846279344, 10.606801853865399, 10.77332693994342, 10.960855822516343, 11.139365463739152, 11.289769864831182, 11.464549725634152, 11.628721341816485, 11.795750830353949, 
11.987073634158405, 12.209927866214329, 12.394565278440497, 12.623297422241583, 12.921088351117941, 13.174470926006268, 13.392835164215404, 13.758473674700836, 14.053879088565274, 14.390266273041197, 14.764926053975401, 15.073277542472919, 15.450822664485537, 15.8075454411495, 16.169142935730857, 16.539328325502559, 16.979826488799628, 17.351698393938999, 17.736762957098264, 18.187331795499151, 18.549159532678338, 18.981487558451292, 19.413886624205354, 19.833884859232363, 20.242765345863869, 20.60663158106027, 20.993530196870253, 21.391764472064509, 21.871571774672812, 22.370997370035102, 22.818277329333085, 23.430221969023975, 24.076272130182847, 24.714861948435342, 25.434256898596189, 26.327127102319636, 26.920161237387614, 27.720756378932979, 28.408521695321269, 29.064021273102249, 29.820636771399027, 31.327069342235273, 32.976118735978929, 35.628968996024611, 38.398648633214698]
cdf_dt  = [0.048133326499998789, 0.11186665719999553, 0.17266666269999745, 0.21573333410000062, 0.27366666269999923, 0.3160666574999999, 0.36746666239999937, 0.41826665839999966, 0.46099999320000024, 0.50779999730000036, 0.55626666150000004, 0.59986666360000063, 0.64606666400000079, 0.69333332989999941, 0.74113332609999916, 0.78113333440000121, 0.8355333339000004, 0.8824666624999995, 0.9244666648999994, 0.96439999970000168, 1.0067333309999993, 1.0551333282999982, 1.0993333304000006, 1.1403333316000002, 1.1902666641999993, 1.239266668700002, 1.2840666621999992, 1.3205999971000011, 1.3695333279999999, 1.4096666721000028, 1.4577333369000025, 1.4959333348000023, 1.5384666663000004, 1.5709333356000017, 1.6147999989000004, 1.6508000047000035, 1.6907333334000014, 1.7314666699000008, 1.7651333407000032, 1.8089999953999989, 1.8461333372000006, 1.8959999995000019, 1.9426000055000003, 1.9855333349000017, 2.0334666668000017, 2.0800000009000015, 2.1303333370000015, 2.1700666684000027, 2.2163333356000026, 2.2628000068000023, 2.3060000018000024, 2.3586000029000007, 2.4156666703000034, 2.4786000051000001, 2.5330000046999999, 2.5830666668000002, 2.6494666703000003, 2.7056000005000005, 2.7483333380000015, 2.7990666703000007, 2.8504666699000007, 2.9043333402000018, 2.9549333389000001, 3.0056000033999997, 3.0662000006000003, 3.1265999991999989, 3.185533334900001, 3.2489333334999992, 3.3059333313000012, 3.3611333314999978, 3.4113999930999985, 3.4641333342000014, 3.5407333279999973, 3.6065999983000006, 3.6762000007000015, 3.7419999956999992, 3.8005333295999986, 3.8739333330999979, 3.9553333260999968, 4.0293999937999985, 4.109599993999999, 4.2021999964999992, 4.2673333273999985, 4.3480666669000003, 4.4570666638000001, 4.548933331799998, 4.6505999952999977, 4.7346666670999973, 4.836066663099996, 4.9339333322999996, 5.0051999958999955, 5.1087999963999993, 5.2109333346999973, 5.3077999990000002, 5.421533325299996, 5.4991333322999933, 5.5867333299999959, 5.6928666695000008, 
5.7804666622999967, 5.8860000044999969, 6.0044666639999953, 6.1019333350999965, 6.2286666651999978, 6.3569333376999975, 6.4757333359999993, 6.6092000016999926, 6.7150666667999985, 6.8214666650999991, 6.9433333315999937, 7.0898666627999969, 7.2152666618999977, 7.3321999970999991, 7.4542666639000004, 7.5592666667999993, 7.6847999965999962, 7.8221999993999987, 7.9672666630000002, 8.1644000023000007, 8.3467999970999998, 8.520933335099997, 8.6883333287000042, 8.8778000009999989, 9.1150666696000027, 9.3329333334000015, 9.6255333344999983, 9.903200002200002, 10.204466666399998, 10.532133332000003, 10.862266667000002, 11.282133331599999, 11.7318666675, 12.169066670699998, 12.6550666711, 13.145933338599997, 13.619733330200006, 14.342200004100004, 15.121266669300006, 16.35386666929999, 17.478400003200004, 18.910466671999995, 20.465666666799994, 23.114533339899996, 26.3592000031, 31.736666676299997]
cdf_gb  = [0.046194996724912014, 0.10062134870372802, 0.14987639790219817, 0.1975327063936706, 0.23832212043136636, 0.28144640447846297, 0.32600158032682486, 0.37473430055454687, 0.42184333510978889, 0.46825138694756491, 0.51391506222121375, 0.55229428250354473, 0.60763602132172867, 0.66875004683451866, 0.7098862993018612, 0.75594085271893063, 0.80326299108420929, 0.84915029913155959, 0.89400189371489813, 0.93935942313646148, 0.9846412316728923, 1.0324208614226691, 1.0827927594663691, 1.1256691455864525, 1.1727042886380823, 1.2158670685213608, 1.2650633578055162, 1.3111541927627939, 1.3511326987852368, 1.3996267791658723, 1.4487472635623937, 1.4984324762149357, 1.5373265640358893, 1.5868337135278332, 1.6330923852991135, 1.686406205755709, 1.7340976227902212, 1.7858556786314141, 1.834232875563357, 1.8913612575096175, 1.9374796462849557, 1.9950052803942671, 2.0395677799842127, 2.0926991904046379, 2.1390971474338749, 2.1879691073101246, 2.2370121412706125, 2.3044987466500171, 2.3627841199349593, 2.4205163269149348, 2.4764838192202165, 2.5334334161484118, 2.599237780799537, 2.6600180996196423, 2.7106089206198534, 2.7579312170230166, 2.8113169035414303, 2.8687324575626865, 2.920276180808512, 2.9805127745796756, 3.0461011459867775, 3.1020402423558888, 3.1634692201968573, 3.2325933506366038, 3.2946603687091112, 3.3612004655693921, 3.4215375033435032, 3.4882154489963177, 3.5571750206726178, 3.6237569454060257, 3.6926273551665658, 3.7605787550480545, 3.8286467783380611, 3.8937149789355954, 3.9682434102439199, 4.0259143340911718, 4.0840911405024904, 4.147372234649648, 4.2055088869772241, 4.2801792963200613, 4.3434018797434302, 4.4167335263137284, 4.4773612663338938, 4.5385470123391638, 4.6172240564544804, 4.6988656789220205, 4.7821267679099986, 4.8549744835888804, 4.9292971824062901, 5.0085083750876098, 5.0803253425596751, 5.1606820426895217, 5.2416049945308192, 5.3324005517800739, 5.42155577092639, 5.4979181453044408, 5.5808954342726125, 5.6680059481963712, 
5.7547558338339506, 5.8464918846329734, 5.9393598228534419, 6.0258538780281672, 6.1300652370665514, 6.2400763200531992, 6.3359258324795436, 6.4534026727777025, 6.5583995498388097, 6.660235072931064, 6.7544494506444721, 6.8739114561979351, 6.9896749185102003, 7.1014961292959837, 7.2325904856018601, 7.3808347118422128, 7.5069001134648792, 7.661134679484408, 7.8335250373333363, 8.0014423032656374, 8.1630929868589206, 8.3187616891098557, 8.4974372258288557, 8.7264314893546029, 8.9368764338031319, 9.1558149923470165, 9.3756083121029654, 9.6470837878521571, 9.9035511058155183, 10.171246313894848, 10.473341164849888, 10.811698066457488, 11.186713113558286, 11.512755101366409, 11.90949220367736, 12.382336410840072, 12.810088900360329, 13.242102468277423, 13.804180595751776, 14.355573840325071, 15.107828102604628, 16.116751491054245, 17.211120117575277, 19.110768103660394, 21.223976505576044, 24.690004110521045]
cdf_knn = [0.047955557466668922, 0.090866667000001233, 0.13244444523333571, 0.16851111269999805, 0.20673333776666922, 0.24106666539999794, 0.26837777866666729, 0.30255555626666691, 0.33811110816666601, 0.37339999780000055, 0.40808888516666625, 0.44375555386666721, 0.47351111263333351, 0.50844444366666597, 0.55162222223333413, 0.59184444703333372, 0.63251111753333433, 0.67600000423333473, 0.72364444576666576, 0.76880000039999841, 0.80973333583333429, 0.8487999981333314, 0.89497778013333407, 0.93751111376666363, 0.98066666776666511, 1.0183555564666646, 1.0585111052666627, 1.0990444413999965, 1.1355111012333279, 1.1719999999666653, 1.2139333300999984, 1.2583111089999981, 1.3030666657999956, 1.3393111109666638, 1.382511107799997, 1.4279111108999973, 1.4644666659999996, 1.5015111087666642, 1.5448444419333316, 1.5902222206333319, 1.6333555570999989, 1.6785333318333318, 1.7229333305333312, 1.7629111132333346, 1.8090222197999992, 1.8541111119000016, 1.9029777756333317, 1.9440888866333323, 1.9885111114666671, 2.0403777766333309, 2.085066666933336, 2.1325111064666671, 2.1745555554000013, 2.2194666663999993, 2.2711999962666654, 2.3132888848333351, 2.3555777756000023, 2.4079777770999997, 2.4615111079666674, 2.5128444444333327, 2.572111111066667, 2.6228444437333343, 2.6778000023333353, 2.7206888879000006, 2.7684000000999989, 2.8125111133333349, 2.8704444487333318, 2.9256666656333334, 2.987199999933333, 3.0343555557000017, 3.0885555573333341, 3.1393777780333325, 3.2002222242333356, 3.2668000020666672, 3.3294888874999979, 3.3983333340999993, 3.4639999998999982, 3.530999997033331, 3.5888444406666644, 3.6492666640999971, 3.7084222162999971, 3.7600666678000003, 3.8183777765666651, 3.8773333271999979, 3.9286888912666642, 3.9896666693333329, 4.0518444382666647, 4.1083999973333309, 4.1776888842999975, 4.2340000000333342, 4.3032666658333332, 4.3701333318333306, 4.4457777745333349, 4.5263111126666651, 4.5976666671666662, 4.669844443999998, 4.7657333323333324, 4.8374222218333323, 
4.9260888872666664, 5.0134444472333319, 5.1181555532999967, 5.2026444469999982, 5.3119777771000001, 5.4031333310999994, 5.4954000003333316, 5.5990666642333338, 5.6956000007333296, 5.7984222187333305, 5.8963555543000021, 5.9861111101999969, 6.1032222189666676, 6.2101999995333346, 6.3454888861666641, 6.4653999966666662, 6.6140666655000011, 6.7567111097999986, 6.8741777770999999, 7.0164222225333353, 7.2075555522666654, 7.3578666663666672, 7.5350222226333328, 7.6826666664666661, 7.8449333358000057, 8.0639111115000013, 8.2438000016000075, 8.4913777796333303, 8.7236888895333351, 8.9646222250333309, 9.2654222216333366, 9.5378666713333384, 9.7932444445666658, 10.075533334533334, 10.409088889866661, 10.825666668233334, 11.267066669600004, 11.77066667003333, 12.410311115799994, 13.033311116566667, 13.701777783699999, 14.734666674666675, 16.071866670633334, 18.291200006966665, 21.3426666758, 25.613422239066679]

# CDF: one curve per model, sorted error values on x against the rank fraction on y.
plt.figure(rd.randint(0, 100000))
N = len(cdf_rf)
yvals = np.arange(N) / float(N)
cdf_curves = [
    (cdf_rf,  '--', 'Random forest'),
    (cdf_nn,  ':',  'Neural network'),
    (cdf_lin, '-.', 'Linear regression'),
    (cdf_dt,  '-',  'Decision tree'),
    (cdf_gb,  '--', 'Gradient boosting'),
    (cdf_knn, ':',  'K-NN'),
]
for curve_x, curve_style, curve_label in cdf_curves:
    plt.plot(curve_x, yvals, linestyle=curve_style, label=curve_label)
plt.legend(loc='lower right')
plt.xlabel('Prediction Error (dBm)')
plt.ylabel('CDF')
plt.savefig('cdf', dpi=200)
plt.show()

# Error Bar: mean prediction error per model with its standard deviation as the whisker.
# plt.subplots() already creates the figure; an extra plt.figure(...) call
# beforehand would only leave an empty figure for plt.show() to render.
idx = list(range(len(objects)))
fig, ax = plt.subplots()
ax.bar(idx, means, yerr=stds, alpha=0.7, color=['r', 'b', 'g', 'k', 'm', 'c'], error_kw=dict(ecolor='gray', lw=1.5, capsize=5, capthick=2))
ax.set_ylabel('Prediction Error (dBm)')
ax.set_xticks(idx)
xtickNames = ax.set_xticklabels(objects)
plt.setp(xtickNames, rotation=45, fontsize=10)
# Save through the figure handle so the file always contains this chart.
fig.savefig('errorbar', dpi=200, bbox_inches="tight")
plt.show()

# RMSE: one bar per model, each annotated with its numeric value.
plt.figure(rd.randint(0, 100000))
bar_colors = ['r', 'b', 'g', 'k', 'm', 'c']
plt.bar(idx, rmse, alpha=0.7, color=bar_colors)
plt.xticks(idx, objects)
for bar_pos, bar_height in zip(idx, rmse):
    # Offsets place the label slightly left of the bar centre and just above its top.
    plt.text(bar_pos - 0.45, bar_height + 0.15, str(bar_height))
plt.ylabel('RMSE (dBm)')
plt.savefig('rmse', dpi=200)
plt.show()

In [3]:
import matplotlib.pyplot as plt
import random as rd

# Filled-contour heat map of the averaged pixel predictions held in hmap_z.
plt.figure(rd.randint(0, 100000))
x_resolution, y_resolution = 104, 26
n_cols, n_rows = x_resolution + 1, y_resolution + 1    # grid points per axis
xaxis = np.linspace(0., x_resolution, n_cols)
yaxis = np.linspace(0., y_resolution, n_rows)
x, y = np.meshgrid(xaxis, yaxis)
# hmap_z is flat; fold it into (rows, columns) to match the mesh grid.
z = np.reshape(hmap_z, (n_rows, n_cols))
plt.contourf(x, y, z, 500, cmap='jet')
plt.colorbar()
# NOTE(review): the file name is spelled 'heapmap' (probably meant 'heatmap');
# left unchanged in case other material refers to the existing file.
plt.savefig('heapmap', dpi=200)
plt.show()

In [3]:
# Pass the training features to draw_bitmap (imported from the Utility module).
# X_train is whatever the final iteration of the experiment loop left behind,
# so that cell has to be executed before this one.
draw_bitmap(X_train)