In [248]:
# Useful starting lines
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [267]:
# Import files to use in preprocessing and machine learning
from implementations import *
from proj1_helpers import *
from preprocess import *
from cross_validation import *
from helpers import *
from costs import *

## Load the training data into a feature matrix, class labels, and event ids

In [250]:
# Path to the downloaded training CSV, relative to this notebook
DATA_TRAIN_PATH = '../data/train.csv'
# Class labels, feature matrix and event ids of the training set
y, tX, ids = load_csv_data(DATA_TRAIN_PATH)

In [251]:
# Check the array shape of y, tX, and ids
print(y.shape)
print(tX.shape)
print(ids.shape)

# Labels, features and ids must describe the same events (one entry/row per event)
assert y.shape[0] == tX.shape[0] == ids.shape[0], "y, tX and ids have different numbers of rows"

(250000,)
(250000, 30)
(250000,)


# Initial Data Analysis

Inspecting the original training data, we found that missing values occur throughout tX; they are encoded as the value -999. Because the affected columns are important for model training, we cannot simply delete the rows that contain -999 values. We therefore preprocess the original training set before training any model.

First, we check each column of tX to get an overview of the missing data:

In [None]:
# Check whether the missing values are associated with the classification result.
# -999 marks a missing entry; labels y >= 0 are reported as y = 1 and y < 0 as y = -1.
n_rows = tX.shape[0]
for col in range(tX.shape[1]):
    # Boolean mask of rows whose value in column `col` is missing. Slicing the column
    # directly avoids re-transposing the whole matrix on every iteration.
    null = (tX[:, col] == -999)
    n_null = np.count_nonzero(null)

    if n_null > 0:
        # Count matching rows instead of materialising tX[mask] copies just to read .shape[0]
        n_null_s = np.count_nonzero(null & (y >= 0))
        n_null_b = np.count_nonzero(null & (y < 0))

        # Print the percentage of column 'col' having a -999 (missing) value
        print('Column', col, 'has {}% percentage of missing values'.format(n_null * 100 / n_rows))

        # Print the conditional probability of P(y = 1|x having -999)
        print('P(y = 1|x having -999) = {:.3f}%'.format(n_null_s * 100 / n_null))

        # Print the conditional probability of P(y = -1|x having -999)
        print('P(y = -1|x having -999) = {:.3f}% \n'.format(n_null_b * 100 / n_null))

We can see that 11 columns contain at least one -999 (missing) value. Next, we check whether some of the missing values depend on the column 'PRI_jet_num' (column index 22, i.e. the 23rd column of tX). 'PRI_jet_num' takes only the discrete values {0, 1, 2, 3}, and the first few rows of the data suggested that some missing values depend on its value.

In [None]:
# For each PRI_jet_num value, collect the distinct numbers of -999 (missing) columns per row
PRI_jet_range = list(range(0, 4))
PRI_jet_sum = []
PRI_jet_null = []

for value in PRI_jet_range:
    # Column 22 (0-based) is PRI_jet_num
    tX_PRI = tX[tX[:, 22] == value]

    # Row count for this PRI_jet_num; summed below to check every row of tX is covered
    PRI_jet_sum.append(len(tX_PRI))

    # Vectorised per-row count of -999 entries (replaces a Python loop over every row);
    # np.unique gives the distinct counts, sorted, and .tolist() makes them plain ints
    null_counts_per_row = np.count_nonzero(tX_PRI == -999, axis=1)
    PRI_jet_null.append(np.unique(null_counts_per_row).tolist())


print("Sum of rows for different PRI_jet_num: {} \n".format(sum(PRI_jet_sum)))
assert sum(PRI_jet_sum) == tX.shape[0], "some rows have a PRI_jet_num outside {0, 1, 2, 3}"

for value, null_counts in zip(PRI_jet_range, PRI_jet_null):
    print("PRI_jet_num =", value, "No. of columns having -999 (a missing value):{}".format(null_counts))

The analysis above shows that, for one of the columns containing -999 (missing) values, whether a value is missing does not depend on 'PRI_jet_num'. Checking the original training set, we find that this is the first column of tX, 'DER_mass_MMC'.

# Data Preprocessing

Based on the analysis above, we preprocess the training data as follows.

In [252]:
# Options forwarded to data_preprocess: METHOD as `replacing=` (missing-value strategy),
# MODE as `mode=` (presumably the feature-scaling mode — confirm in preprocess.py)
METHOD, MODE = 'median', 'std'

In [253]:
TX, Y, r_ids = data_preprocess(tX, y, ids, replacing=METHOD, mode=MODE)

# One feature matrix and label vector per split. Loop over however many splits were
# returned instead of hard-coding TX[0]..TX[3].
assert len(TX) == len(Y), "data_preprocess returned different numbers of feature and label splits"
for tx_split, y_split in zip(TX, Y):
    print(tx_split.shape)
    assert len(y_split) == tx_split.shape[0], "labels and features of a split have different row counts"

(99913, 18)
(77544, 22)
(50379, 29)
(22164, 29)


In [5]:
# (Not used)
# Alternative pipeline: split the data by 'PRI_jet_num' into three groups (0, 1 and 2&3).
# NOTE: this rebinds r_ids, replacing the ids returned by data_preprocess.
tX_jet_0, tX_jet_1, tX_jet_2, y_jet_0, y_jet_1, y_jet_2, r_ids = split_reformat_data(tX, y, ids)

In [None]:
# (Not used)
# Fill the missing values of the first feature column (index 0) in every split.
# Strategy options: 'median', 'mean', or 'lr' (linear regression).
# The return value is ignored, so replace_missing_value is expected to modify its argument.
for tX_jet in (tX_jet_0, tX_jet_1, tX_jet_2):
    replace_missing_value(tX_jet, 0, 'lr')

In [None]:
# (Not used)
# (Use the above block or this block)
# Replacing the missing values by k-means clustering.
# split_reformat_data returns the PRI_jet_num 2&3 split as tX_jet_2; the previous
# reference to an undefined `tX_jet_23` raised NameError.
for tX_jet in (tX_jet_0, tX_jet_1, tX_jet_2):
    k_means_replacing(tX_jet)

# Shape of the remaining -999 entries in the first feature column of each split;
# should be (0,) if k_means_replacing filled every missing value in place
for tX_jet in (tX_jet_0, tX_jet_1, tX_jet_2):
    print((tX_jet[tX_jet[:, 0] == -999][:, 0]).shape)

In [7]:
# (Not used)
# Standardize and/or normalize the split training data.
# The split arrays are named tX_jet_* (from split_reformat_data); the previous
# lowercase tx_jet_* names were undefined and raised NameError.
# NOTE: running this cell replaces TX and Y from data_preprocess with three splits,
# so later cells that select TX[3] (e.g. with i = 3) will raise IndexError.
tx_0 = std_norm_preprocess(tX_jet_0, 'std_norm')
tx_1 = std_norm_preprocess(tX_jet_1, 'std_norm')
tx_2 = std_norm_preprocess(tX_jet_2, 'std_norm')

TX = [tx_0, tx_1, tx_2]
Y = [y_jet_0, y_jet_1, y_jet_2]

# Model Training

## Least Squares

### Linear Regression with Gradient Descent

In [110]:
# Hyperparameters for least squares with gradient descent: step size, number of iterations
GAMMA_LRGD, MAX_ITERS_LRGD = 5e-4, 50

In [115]:
i = 3  # index of the preprocessed split (TX[i], Y[i]) to train on below

In [116]:
# Fit least squares by full-batch gradient descent on split i.
# NOTE(review): the weight vector has one more row than TX[i] has columns, presumably
# because least_squares_GD adds a bias column internally — confirm.
# NOTE(review): INITIAL_W is 2-D (D+1, 1). If Y[i] is 1-D and least_squares_GD computes
# y - tx @ w without reshaping, the residual broadcasts to an (N, N) matrix — confirm.
INITIAL_W = np.zeros(
    (TX[i].shape[1] + 1, 1)
)
w_lrgd, loss_lrgd = least_squares_GD(
    Y[i], TX[i], INITIAL_W, MAX_ITERS_LRGD, GAMMA_LRGD
)

print('Loss of tx_{},'.format(i), loss_lrgd, 'iterations: {}'.format(MAX_ITERS_LRGD))


KeyboardInterrupt: 

### Linear Regression with Stochastic Gradient Descent

In [105]:
# Hyperparameters for least squares with stochastic gradient descent:
# step size, number of iterations, mini-batch size
GAMMA_LRSGD, MAX_ITERS_LRSGD, BATCH_SIZE = 5e-4, 2000, 64

In [107]:
WEIGHT_LRSGD = []
LOSS_LRSGD = []

# The hyperparameters are the same for every split, so print them once rather than per split
print('Parameters for training:\n')
print('Gamma = {}'.format(GAMMA_LRSGD))
print('Iterations = {}'.format(MAX_ITERS_LRSGD))
print('Batch size = {}'.format(BATCH_SIZE))

for i in range(len(TX)):
    # Header separating the per-iteration log of each split
    print('################### {}'.format(i))

    # Fine-tuning the data for each splitted training set.
    # NOTE(review): INITIAL_W is 2-D (D+1, 1); the +1 presumably matches a bias column added
    # inside least_squares_SGD, and if Y[i] is 1-D the residual y - tx @ w may broadcast
    # to (batch, batch) — confirm against implementations.py.
    INITIAL_W = np.zeros((TX[i].shape[1] + 1, 1))
    w_lrsgd, loss_lrsgd = least_squares_SGD(Y[i], TX[i], INITIAL_W, MAX_ITERS_LRSGD,
                                            GAMMA_LRSGD, BATCH_SIZE, num_batch=50)
    print(loss_lrsgd)
    LOSS_LRSGD.append(loss_lrsgd)
    WEIGHT_LRSGD.append(w_lrsgd)

################### 0
Parameters for training:

Gamma = 0.0005
Iterations = 2000
Batch size = 64
loss=0.4885398008220875, iteration=0
loss=0.47344628225752466, iteration=1
loss=0.4560795417810354, iteration=2
loss=0.44111168923248134, iteration=3
loss=0.42494361812940823, iteration=4
loss=0.4136109822833537, iteration=5
loss=0.40050370747505626, iteration=6
loss=0.41894431651809283, iteration=7
loss=0.387054838257324, iteration=8
loss=0.428696352657047, iteration=9
loss=0.39915228621792176, iteration=10
loss=0.42455661327489447, iteration=11
loss=0.3831815922598795, iteration=12
loss=0.3312663421839633, iteration=13
loss=0.4370303732802314, iteration=14
loss=0.34725275719743276, iteration=15
loss=0.36663315984985195, iteration=16
loss=0.3488022812954269, iteration=17
loss=0.33033791893821596, iteration=18
loss=0.3081761073717261, iteration=19
loss=0.34995923448186, iteration=20
loss=0.3316876959243519, iteration=21
loss=0.3498679621467521, iteration=22
loss=0.3291713961281561, iteratio

loss=0.3126706516882824, iteration=208
loss=0.25808167765835344, iteration=209
loss=0.33577953564556884, iteration=210
loss=0.30492543966619085, iteration=211
loss=0.2689744089692922, iteration=212
loss=0.3070377695616374, iteration=213
loss=0.3080504368154054, iteration=214
loss=0.21874303774410964, iteration=215
loss=0.28409816046617387, iteration=216
loss=0.3390708016369932, iteration=217
loss=0.28143543813373023, iteration=218
loss=0.2752377974198499, iteration=219
loss=0.39501337660211744, iteration=220
loss=0.2812178732521071, iteration=221
loss=0.2788875511573512, iteration=222
loss=0.33444262816218573, iteration=223
loss=0.30460905851373143, iteration=224
loss=0.2493317925341864, iteration=225
loss=0.23089519563687522, iteration=226
loss=0.33023521865088334, iteration=227
loss=0.3109079912324956, iteration=228
loss=0.3291293940174402, iteration=229
loss=0.2434025865323864, iteration=230
loss=0.2692719096183516, iteration=231
loss=0.3148447836165055, iteration=232
loss=0.3591772

loss=0.2463146645094279, iteration=418
loss=0.24285052124269976, iteration=419
loss=0.2982151093717014, iteration=420
loss=0.25232026824084686, iteration=421
loss=0.28434239568582376, iteration=422
loss=0.3495899175611604, iteration=423
loss=0.3533074184072042, iteration=424
loss=0.21971091917460825, iteration=425
loss=0.25596068210702116, iteration=426
loss=0.28935531066699305, iteration=427
loss=0.2891166494514741, iteration=428
loss=0.3111018077647531, iteration=429
loss=0.24949417006705624, iteration=430
loss=0.2852787988541974, iteration=431
loss=0.3753596549404875, iteration=432
loss=0.30828209029739695, iteration=433
loss=0.2886432511098851, iteration=434
loss=0.3141503294715892, iteration=435
loss=0.347775768424149, iteration=436
loss=0.2713005986402753, iteration=437
loss=0.24683112338776472, iteration=438
loss=0.28122736508584034, iteration=439
loss=0.2375319921462996, iteration=440
loss=0.2854406038666891, iteration=441
loss=0.23973605508040224, iteration=442
loss=0.24248675

loss=0.19979801368711142, iteration=628
loss=0.2988037623041737, iteration=629
loss=0.2604035988847206, iteration=630
loss=0.2270580903075503, iteration=631
loss=0.2778323701779431, iteration=632
loss=0.2809584081860919, iteration=633
loss=0.28889567100380265, iteration=634
loss=0.3078163972263699, iteration=635
loss=0.32736321203470026, iteration=636
loss=0.24302758395046448, iteration=637
loss=0.23140297554482825, iteration=638
loss=0.2339740109966798, iteration=639
loss=0.251206696675684, iteration=640
loss=0.2591011821481547, iteration=641
loss=0.2836401844673937, iteration=642
loss=0.30834174119591695, iteration=643
loss=0.2735803240058784, iteration=644
loss=0.2962169673951376, iteration=645
loss=0.313459467302762, iteration=646
loss=0.319995197769289, iteration=647
loss=0.2648286785840478, iteration=648
loss=0.30778374689006843, iteration=649
loss=0.29460998212310385, iteration=650
loss=0.24077106085575645, iteration=651
loss=0.3672516641260736, iteration=652
loss=0.279205609579

loss=0.3280607753513025, iteration=841
loss=0.2992679327786951, iteration=842
loss=0.25997824762191585, iteration=843
loss=0.2716530249276888, iteration=844
loss=0.24152769739252566, iteration=845
loss=0.3472712372760266, iteration=846
loss=0.3061463299386613, iteration=847
loss=0.2552789884904685, iteration=848
loss=0.23539376714425528, iteration=849
loss=0.2578266330584987, iteration=850
loss=0.21255331215472356, iteration=851
loss=0.273388187535285, iteration=852
loss=0.28803693254951546, iteration=853
loss=0.3460635122659248, iteration=854
loss=0.2854423936600833, iteration=855
loss=0.2893395076407299, iteration=856
loss=0.3355322283243216, iteration=857
loss=0.23891739311967616, iteration=858
loss=0.23463310475368765, iteration=859
loss=0.2707543902648123, iteration=860
loss=0.22444823265366906, iteration=861
loss=0.3460499593115166, iteration=862
loss=0.33081033904981094, iteration=863
loss=0.22426220763215227, iteration=864
loss=0.2612849571931457, iteration=865
loss=0.364792842

loss=0.23214395284300518, iteration=1052
loss=0.24290939900126235, iteration=1053
loss=0.22521871181985345, iteration=1054
loss=0.27190996485111746, iteration=1055
loss=0.1986327154713307, iteration=1056
loss=0.21806387838065064, iteration=1057
loss=0.19830937928895742, iteration=1058
loss=0.2662166011955108, iteration=1059
loss=0.2854826485819775, iteration=1060
loss=0.22903205577982963, iteration=1061
loss=0.22959801700531335, iteration=1062
loss=0.2799824771088615, iteration=1063
loss=0.3211235467047793, iteration=1064
loss=0.22198471057124886, iteration=1065
loss=0.21834040589664255, iteration=1066
loss=0.3002569885290214, iteration=1067
loss=0.2732902223375303, iteration=1068
loss=0.2472640966166778, iteration=1069
loss=0.2810382788488569, iteration=1070
loss=0.20715930408398858, iteration=1071
loss=0.26781749190783244, iteration=1072
loss=0.3390388817608594, iteration=1073
loss=0.20081522693530374, iteration=1074
loss=0.25638335431845904, iteration=1075
loss=0.313278333058152, it

loss=0.2653448962516881, iteration=1259
loss=0.22328055481817874, iteration=1260
loss=0.26025052024834266, iteration=1261
loss=0.28227950035723764, iteration=1262
loss=0.25916590911884696, iteration=1263
loss=0.3208395923516205, iteration=1264
loss=0.2530911612384591, iteration=1265
loss=0.3030344422867566, iteration=1266
loss=0.2948141703364653, iteration=1267
loss=0.33239167197096886, iteration=1268
loss=0.24145834582853182, iteration=1269
loss=0.33686384015633924, iteration=1270
loss=0.2944331692590967, iteration=1271
loss=0.1802890878992033, iteration=1272
loss=0.2559079645780564, iteration=1273
loss=0.32527039551448467, iteration=1274
loss=0.22354997297742585, iteration=1275
loss=0.2120288400526273, iteration=1276
loss=0.28808686053017163, iteration=1277
loss=0.2678234924199473, iteration=1278
loss=0.22913991981810136, iteration=1279
loss=0.21971434037918655, iteration=1280
loss=0.3428680128183481, iteration=1281
loss=0.2174029225657078, iteration=1282
loss=0.2693410234179894, ite

loss=0.3084207859781524, iteration=1464
loss=0.30653908521318296, iteration=1465
loss=0.2293807719953053, iteration=1466
loss=0.252614483178802, iteration=1467
loss=0.3403918143970637, iteration=1468
loss=0.20218111832819596, iteration=1469
loss=0.3148496317505213, iteration=1470
loss=0.2606303435627182, iteration=1471
loss=0.184702812841947, iteration=1472
loss=0.26793328231514235, iteration=1473
loss=0.27145824248711303, iteration=1474
loss=0.26076482822125935, iteration=1475
loss=0.291667829819198, iteration=1476
loss=0.1980936477576417, iteration=1477
loss=0.22326240419087562, iteration=1478
loss=0.28765811584157847, iteration=1479
loss=0.2759615849925052, iteration=1480
loss=0.21145624367302024, iteration=1481
loss=0.24351217981638557, iteration=1482
loss=0.3062633537744025, iteration=1483
loss=0.21001858688310668, iteration=1484
loss=0.2680562037265844, iteration=1485
loss=0.28601486370737195, iteration=1486
loss=0.29751106148180917, iteration=1487
loss=0.31498910896872234, itera

loss=0.2304603104786354, iteration=1668
loss=0.28341980272384804, iteration=1669
loss=0.32127023617545347, iteration=1670
loss=0.30932977974970377, iteration=1671
loss=0.17775827597736474, iteration=1672
loss=0.2722515675546402, iteration=1673
loss=0.2298276620279867, iteration=1674
loss=0.2713185480714505, iteration=1675
loss=0.25790508991776406, iteration=1676
loss=0.2884379001383737, iteration=1677
loss=0.2680081010221014, iteration=1678
loss=0.2391602558576471, iteration=1679
loss=0.34516641350807464, iteration=1680
loss=0.27817156708904756, iteration=1681
loss=0.2381135046261088, iteration=1682
loss=0.44221581397897186, iteration=1683
loss=0.281297396991428, iteration=1684
loss=0.21380580238165903, iteration=1685
loss=0.3120533800269954, iteration=1686
loss=0.3322468151574089, iteration=1687
loss=0.2615795857002516, iteration=1688
loss=0.24456661385274542, iteration=1689
loss=0.20408144012538082, iteration=1690
loss=0.32245218380787377, iteration=1691
loss=0.27589760879475306, ite

loss=0.2953457359138968, iteration=1873
loss=0.21145829038470204, iteration=1874
loss=0.1789487585821035, iteration=1875
loss=0.27143625195016896, iteration=1876
loss=0.2561902656370185, iteration=1877
loss=0.2793956405119853, iteration=1878
loss=0.30825049597281706, iteration=1879
loss=0.2384380454884842, iteration=1880
loss=0.32749532049448893, iteration=1881
loss=0.32369089061061757, iteration=1882
loss=0.32373997493134105, iteration=1883
loss=0.29083454668289777, iteration=1884
loss=0.3248551172774493, iteration=1885
loss=0.26938818336021075, iteration=1886
loss=0.29588382298127514, iteration=1887
loss=0.21233967075756568, iteration=1888
loss=0.20832898277695336, iteration=1889
loss=0.28645834168761597, iteration=1890
loss=0.3340029274415578, iteration=1891
loss=0.3222375383611972, iteration=1892
loss=0.28206935325750493, iteration=1893
loss=0.2623000248594426, iteration=1894
loss=0.3997585499620151, iteration=1895
loss=0.31916297862164367, iteration=1896
loss=0.28664651703389876, 

loss=0.353918489229038, iteration=81
loss=0.3921065940785637, iteration=82
loss=0.42676527081572085, iteration=83
loss=0.3282229208630498, iteration=84
loss=0.37116412790892395, iteration=85
loss=0.2852227483895673, iteration=86
loss=0.4218409422159504, iteration=87
loss=0.36040549501173036, iteration=88
loss=0.2969452435793678, iteration=89
loss=0.33826007546257947, iteration=90
loss=0.42608859313817415, iteration=91
loss=0.38800625218873325, iteration=92
loss=0.3749676466635078, iteration=93
loss=0.3373909212484051, iteration=94
loss=0.45956508921445116, iteration=95
loss=0.3992828503304862, iteration=96
loss=0.39253993766771167, iteration=97
loss=0.42852103403823594, iteration=98
loss=0.3280881201183864, iteration=99
loss=0.39753868657139935, iteration=100
loss=0.44707070299915097, iteration=101
loss=0.42954666845818246, iteration=102
loss=0.31141792904342325, iteration=103
loss=0.38460830943164076, iteration=104
loss=0.39559020402214945, iteration=105
loss=0.35414983561291763, iter

loss=0.350316230455744, iteration=290
loss=0.34736340893317513, iteration=291
loss=0.3176148231587037, iteration=292
loss=0.33846952115062723, iteration=293
loss=0.3267457181833675, iteration=294
loss=0.43071885925665365, iteration=295
loss=0.2888012932943678, iteration=296
loss=0.38907295645919876, iteration=297
loss=0.3661500643618383, iteration=298
loss=0.3864090201231554, iteration=299
loss=0.3382313394117357, iteration=300
loss=0.3564786316797638, iteration=301
loss=0.32701345177249086, iteration=302
loss=0.42652618473386306, iteration=303
loss=0.3648754098493169, iteration=304
loss=0.37054938467780785, iteration=305
loss=0.3548142501570042, iteration=306
loss=0.41387811593459123, iteration=307
loss=0.385026918987203, iteration=308
loss=0.3936263579089493, iteration=309
loss=0.3396948643621327, iteration=310
loss=0.35721100567929787, iteration=311
loss=0.39468731344692876, iteration=312
loss=0.30515486514012913, iteration=313
loss=0.37390749169530857, iteration=314
loss=0.35059734

loss=0.4482364527466053, iteration=498
loss=0.4077247452826256, iteration=499
loss=0.3999916517562402, iteration=500
loss=0.366088390384836, iteration=501
loss=0.3726697142332711, iteration=502
loss=0.3722120939113331, iteration=503
loss=0.29909043901079535, iteration=504
loss=0.3295778650048461, iteration=505
loss=0.3556729566342905, iteration=506
loss=0.4055189315242145, iteration=507
loss=0.3463150612795787, iteration=508
loss=0.38115888329953657, iteration=509
loss=0.30954899728102, iteration=510
loss=0.33215397869637114, iteration=511
loss=0.3894328443126281, iteration=512
loss=0.3128044296946748, iteration=513
loss=0.3264041859465914, iteration=514
loss=0.33450934853325115, iteration=515
loss=0.3523132030870966, iteration=516
loss=0.3411841933963887, iteration=517
loss=0.3967392504731315, iteration=518
loss=0.47745496531012815, iteration=519
loss=0.4028272089809334, iteration=520
loss=0.4793535283869511, iteration=521
loss=0.43496286073392953, iteration=522
loss=0.314995701846002

loss=0.3429822234624591, iteration=709
loss=0.36569992906074555, iteration=710
loss=0.3366254126126963, iteration=711
loss=0.3901667635714393, iteration=712
loss=0.4388095621335414, iteration=713
loss=0.3530823937653771, iteration=714
loss=0.40739171041189104, iteration=715
loss=0.35659560100974513, iteration=716
loss=0.37726910081852916, iteration=717
loss=0.36590552433393636, iteration=718
loss=0.4369247174379782, iteration=719
loss=0.34585030943836004, iteration=720
loss=0.2892811255135326, iteration=721
loss=0.36670806254269117, iteration=722
loss=0.37424043068702817, iteration=723
loss=0.3263109911518179, iteration=724
loss=0.32493615908352974, iteration=725
loss=0.37802093126178193, iteration=726
loss=0.4151416351172981, iteration=727
loss=0.3354485947137747, iteration=728
loss=0.3411192325086977, iteration=729
loss=0.4411090743258588, iteration=730
loss=0.35837088360448843, iteration=731
loss=0.40376651001636477, iteration=732
loss=0.34709694961861814, iteration=733
loss=0.33484

loss=0.3033387516414095, iteration=922
loss=0.4161986237102759, iteration=923
loss=0.3823728190604403, iteration=924
loss=0.4181349858188182, iteration=925
loss=0.4266512862615356, iteration=926
loss=0.3470445870387726, iteration=927
loss=0.36852616780137554, iteration=928
loss=0.33760416884680167, iteration=929
loss=0.38427952625779693, iteration=930
loss=0.29664196112538993, iteration=931
loss=0.3916413154181685, iteration=932
loss=0.31872877681850165, iteration=933
loss=0.37125640041251284, iteration=934
loss=0.38317015556015, iteration=935
loss=0.4018261495313106, iteration=936
loss=0.36547933613205935, iteration=937
loss=0.3438683310317565, iteration=938
loss=0.3940512941873915, iteration=939
loss=0.3715161742210634, iteration=940
loss=0.336911754641389, iteration=941
loss=0.40780938491189095, iteration=942
loss=0.3309849966466496, iteration=943
loss=0.37262671055014807, iteration=944
loss=0.3291694283167407, iteration=945
loss=0.32418833476499737, iteration=946
loss=0.37345425432

loss=0.31685033983376293, iteration=1130
loss=0.36788532106839955, iteration=1131
loss=0.37116103425977476, iteration=1132
loss=0.31206795099920714, iteration=1133
loss=0.28360435687872454, iteration=1134
loss=0.48858750111368243, iteration=1135
loss=0.36919127755916736, iteration=1136
loss=0.35927629535883904, iteration=1137
loss=0.3627808635263049, iteration=1138
loss=0.38350796669482934, iteration=1139
loss=0.3415418843098892, iteration=1140
loss=0.33969804325245223, iteration=1141
loss=0.34318733177004235, iteration=1142
loss=0.3649452813844399, iteration=1143
loss=0.3941200917841606, iteration=1144
loss=0.4104422511955943, iteration=1145
loss=0.34515676861997924, iteration=1146
loss=0.3198642084356922, iteration=1147
loss=0.41568127275526656, iteration=1148
loss=0.3944819282850253, iteration=1149
loss=0.38976590458495763, iteration=1150
loss=0.33516625300358094, iteration=1151
loss=0.3029979576463526, iteration=1152
loss=0.3900646537920059, iteration=1153
loss=0.43051997856042185,

loss=0.41103942152487943, iteration=1541
loss=0.35650925364078234, iteration=1542
loss=0.3792919907421056, iteration=1543
loss=0.3269705742515423, iteration=1544
loss=0.3090319352291184, iteration=1545
loss=0.42152003662879267, iteration=1546
loss=0.31443199884033857, iteration=1547
loss=0.3887916830884563, iteration=1548
loss=0.3482821832384582, iteration=1549
loss=0.40748410563407034, iteration=1550
loss=0.3347256981008634, iteration=1551
loss=0.4273260065309906, iteration=1552
loss=0.4404411732724215, iteration=1553
loss=0.4696395071803192, iteration=1554
loss=0.3615394946096218, iteration=1555
loss=0.3876579568814299, iteration=1556
loss=0.37146389285911174, iteration=1557
loss=0.3509980765350567, iteration=1558
loss=0.31227687161405476, iteration=1559
loss=0.37997260965061735, iteration=1560
loss=0.3893369400050765, iteration=1561
loss=0.410841605042755, iteration=1562
loss=0.36140581415133477, iteration=1563
loss=0.372406224380201, iteration=1564
loss=0.3666673158610895, iteratio

loss=0.4144620622654335, iteration=1748
loss=0.4256016808451591, iteration=1749
loss=0.40404945607342957, iteration=1750
loss=0.39373690795086336, iteration=1751
loss=0.3618648659016886, iteration=1752
loss=0.3222286159166286, iteration=1753
loss=0.34838548527934166, iteration=1754
loss=0.3130930756370561, iteration=1755
loss=0.4106080400177058, iteration=1756
loss=0.37695862761118354, iteration=1757
loss=0.3804213809563125, iteration=1758
loss=0.37781293836769414, iteration=1759
loss=0.3921260949983022, iteration=1760
loss=0.3895431038500799, iteration=1761
loss=0.5194221764497123, iteration=1762
loss=0.3721756664675919, iteration=1763
loss=0.34101943031906046, iteration=1764
loss=0.34519786821009285, iteration=1765
loss=0.3068591741630393, iteration=1766
loss=0.3524310227618243, iteration=1767
loss=0.32752818959807284, iteration=1768
loss=0.3472545129072825, iteration=1769
loss=0.35141775187322943, iteration=1770
loss=0.4008914703314371, iteration=1771
loss=0.4349291715790332, iterat

loss=0.4275161815950119, iteration=1956
loss=0.40454369187971345, iteration=1957
loss=0.37589560917307124, iteration=1958
loss=0.3972469901669086, iteration=1959
loss=0.40471505557206433, iteration=1960
loss=0.3579824864655555, iteration=1961
loss=0.3609741083940574, iteration=1962
loss=0.4009293625179779, iteration=1963
loss=0.43187980569875595, iteration=1964
loss=0.3263082933527536, iteration=1965
loss=0.37399357687535195, iteration=1966
loss=0.3800801665862199, iteration=1967
loss=0.4202431322994284, iteration=1968
loss=0.3988558742589393, iteration=1969
loss=0.389676567087183, iteration=1970
loss=0.3971415856851489, iteration=1971
loss=0.36616183165305205, iteration=1972
loss=0.32929685133053327, iteration=1973
loss=0.43690837524859466, iteration=1974
loss=0.3898520469247105, iteration=1975
loss=0.43671266213974264, iteration=1976
loss=0.36530733778905666, iteration=1977
loss=0.4143690226861482, iteration=1978
loss=0.4745142293556809, iteration=1979
loss=0.32658892256287536, itera

loss=0.3661129035187223, iteration=164
loss=0.3900605234401008, iteration=165
loss=0.4064935261164377, iteration=166
loss=0.3793991736354838, iteration=167
loss=0.3420901122822017, iteration=168
loss=0.35152803803571475, iteration=169
loss=0.28371497092530684, iteration=170
loss=0.3072320217780352, iteration=171
loss=0.4463483887512769, iteration=172
loss=0.3950539035352565, iteration=173
loss=0.3392053930806923, iteration=174
loss=0.36653508651063493, iteration=175
loss=0.43052033549549557, iteration=176
loss=0.35385433105243974, iteration=177
loss=0.3752345446600171, iteration=178
loss=0.3411410117250683, iteration=179
loss=0.3097767170498253, iteration=180
loss=0.39012908838458704, iteration=181
loss=0.39301276674099644, iteration=182
loss=0.33048924519006606, iteration=183
loss=0.3287366685228301, iteration=184
loss=0.3427995885687295, iteration=185
loss=0.4072916900635346, iteration=186
loss=0.39791081314111576, iteration=187
loss=0.4189920538477034, iteration=188
loss=0.395730838

loss=0.39698966370279376, iteration=380
loss=0.4116004332935963, iteration=381
loss=0.3536293976295791, iteration=382
loss=0.3662113327041488, iteration=383
loss=0.2960665299896091, iteration=384
loss=0.38344659582100016, iteration=385
loss=0.3315625190400829, iteration=386
loss=0.31709859526155815, iteration=387
loss=0.3868387763021349, iteration=388
loss=0.33804083450722144, iteration=389
loss=0.34124030209339984, iteration=390
loss=0.3508055808306367, iteration=391
loss=0.37868451233879147, iteration=392
loss=0.3371877164948611, iteration=393
loss=0.3907579108631557, iteration=394
loss=0.37266496746352196, iteration=395
loss=0.34459614672218464, iteration=396
loss=0.3759963455569962, iteration=397
loss=0.3628299338634706, iteration=398
loss=0.36586046891364876, iteration=399
loss=0.3873602263111183, iteration=400
loss=0.3801981256546407, iteration=401
loss=0.39440941125247475, iteration=402
loss=0.313934766753492, iteration=403
loss=0.2897948809383555, iteration=404
loss=0.352305430

loss=0.33419333916245925, iteration=590
loss=0.3950519767977616, iteration=591
loss=0.3652017487229263, iteration=592
loss=0.2670044578883547, iteration=593
loss=0.3442881701112807, iteration=594
loss=0.31283792763790963, iteration=595
loss=0.43518294751928177, iteration=596
loss=0.33611743713938136, iteration=597
loss=0.34680901175539636, iteration=598
loss=0.3322233334622898, iteration=599
loss=0.33829364527656364, iteration=600
loss=0.34618513496925873, iteration=601
loss=0.31447677515196104, iteration=602
loss=0.3466623108441954, iteration=603
loss=0.31962745870429765, iteration=604
loss=0.3920942704245103, iteration=605
loss=0.3354951088805977, iteration=606
loss=0.4120856857894467, iteration=607
loss=0.3388005228814935, iteration=608
loss=0.3445032931770942, iteration=609
loss=0.42592967638558255, iteration=610
loss=0.38432202150150346, iteration=611
loss=0.3013877878116822, iteration=612
loss=0.3757965195588516, iteration=613
loss=0.3043970262357523, iteration=614
loss=0.4133943

loss=0.45056212487725045, iteration=804
loss=0.3239870468256961, iteration=805
loss=0.42921862450364145, iteration=806
loss=0.278250756040912, iteration=807
loss=0.32085839866055466, iteration=808
loss=0.331912356939, iteration=809
loss=0.39523548428328414, iteration=810
loss=0.2794807381434258, iteration=811
loss=0.27267591586818873, iteration=812
loss=0.40474482359758585, iteration=813
loss=0.35543386291935025, iteration=814
loss=0.33160053119138255, iteration=815
loss=0.3473569535264793, iteration=816
loss=0.3637183474674307, iteration=817
loss=0.3073091348421222, iteration=818
loss=0.38933798072188336, iteration=819
loss=0.3781275567598613, iteration=820
loss=0.33085240281967854, iteration=821
loss=0.37381468695839404, iteration=822
loss=0.4022972358625291, iteration=823
loss=0.30890336787538575, iteration=824
loss=0.4012713275567393, iteration=825
loss=0.3308763388356768, iteration=826
loss=0.31377957657305355, iteration=827
loss=0.2954793150404257, iteration=828
loss=0.3772127686

loss=0.3428609114492627, iteration=1016
loss=0.34163286314645464, iteration=1017
loss=0.35737644668316193, iteration=1018
loss=0.3426502024692115, iteration=1019
loss=0.395477485835659, iteration=1020
loss=0.3577264342169193, iteration=1021
loss=0.3633390631232931, iteration=1022
loss=0.38619968839480134, iteration=1023
loss=0.352414765559308, iteration=1024
loss=0.39572069740183446, iteration=1025
loss=0.32716326530610756, iteration=1026
loss=0.3972101615308824, iteration=1027
loss=0.3361265528666375, iteration=1028
loss=0.29942538453504236, iteration=1029
loss=0.3844778357693043, iteration=1030
loss=0.28682046755638063, iteration=1031
loss=0.3525638458782582, iteration=1032
loss=0.38496929714070816, iteration=1033
loss=0.43265551668428603, iteration=1034
loss=0.42038428792431637, iteration=1035
loss=0.3790290578136035, iteration=1036
loss=0.4157210672609146, iteration=1037
loss=0.3745474707813311, iteration=1038
loss=0.34228897565943983, iteration=1039
loss=0.32758914496875824, itera

loss=0.3843499275245002, iteration=1222
loss=0.3959626867718806, iteration=1223
loss=0.3802817747804661, iteration=1224
loss=0.385331526984948, iteration=1225
loss=0.4243566933215026, iteration=1226
loss=0.3211429752196696, iteration=1227
loss=0.3003472413877109, iteration=1228
loss=0.28110480279693795, iteration=1229
loss=0.2720180656018926, iteration=1230
loss=0.4046993033145394, iteration=1231
loss=0.30546607499783396, iteration=1232
loss=0.4225395874716985, iteration=1233
loss=0.3757501551444612, iteration=1234
loss=0.3724635691010724, iteration=1235
loss=0.34705839757499823, iteration=1236
loss=0.3766377172970097, iteration=1237
loss=0.410921132855534, iteration=1238
loss=0.3369113706727483, iteration=1239
loss=0.3559585964260348, iteration=1240
loss=0.45524664115090663, iteration=1241
loss=0.3885577392404527, iteration=1242
loss=0.3490821715624549, iteration=1243
loss=0.2944712951936529, iteration=1244
loss=0.3683293092937756, iteration=1245
loss=0.3805355097591949, iteration=124

loss=0.36701995633396534, iteration=1428
loss=0.38582476702111657, iteration=1429
loss=0.4090644425133874, iteration=1430
loss=0.40094638892497414, iteration=1431
loss=0.25835316087302773, iteration=1432
loss=0.32708508724934704, iteration=1433
loss=0.3513042788192812, iteration=1434
loss=0.3366051479534836, iteration=1435
loss=0.30347009782123635, iteration=1436
loss=0.35214885288151376, iteration=1437
loss=0.32924386520162974, iteration=1438
loss=0.34726384686597644, iteration=1439
loss=0.33519381171767787, iteration=1440
loss=0.41166307443499955, iteration=1441
loss=0.37578457816635324, iteration=1442
loss=0.2723731957559287, iteration=1443
loss=0.28165267587909226, iteration=1444
loss=0.3961836141951745, iteration=1445
loss=0.32242504696550045, iteration=1446
loss=0.3665700291615489, iteration=1447
loss=0.2997498316626301, iteration=1448
loss=0.37920038504255715, iteration=1449
loss=0.3525847518713333, iteration=1450
loss=0.3705737075579343, iteration=1451
loss=0.34200713785624504,

loss=0.34983507333345376, iteration=1636
loss=0.37810441633679304, iteration=1637
loss=0.34435392743215854, iteration=1638
loss=0.3228579491377574, iteration=1639
loss=0.33379554136438616, iteration=1640
loss=0.36403341561985647, iteration=1641
loss=0.3549312241460774, iteration=1642
loss=0.2710163287605428, iteration=1643
loss=0.34301633619127786, iteration=1644
loss=0.3171061346426074, iteration=1645
loss=0.31628651741317376, iteration=1646
loss=0.3587459152040713, iteration=1647
loss=0.39404758800836187, iteration=1648
loss=0.34296725589680666, iteration=1649
loss=0.3335603381290969, iteration=1650
loss=0.38164137559815936, iteration=1651
loss=0.31405280144514813, iteration=1652
loss=0.32721166062098284, iteration=1653
loss=0.3649158033327793, iteration=1654
loss=0.27806146292136014, iteration=1655
loss=0.31640826420828966, iteration=1656
loss=0.39125411002042865, iteration=1657
loss=0.3862379381791031, iteration=1658
loss=0.3401869988292251, iteration=1659
loss=0.3547269622704853, 

loss=0.3199189288757767, iteration=1840
loss=0.3237706501216211, iteration=1841
loss=0.31948691630638604, iteration=1842
loss=0.33579354009357354, iteration=1843
loss=0.33149835183695275, iteration=1844
loss=0.372035690106019, iteration=1845
loss=0.3563136773360559, iteration=1846
loss=0.34627162775445136, iteration=1847
loss=0.5315907692800929, iteration=1848
loss=0.351877497295111, iteration=1849
loss=0.3482136264500401, iteration=1850
loss=0.3167145096553776, iteration=1851
loss=0.3602530508844739, iteration=1852
loss=0.3573158494452885, iteration=1853
loss=0.33686213020195244, iteration=1854
loss=0.33391736976792763, iteration=1855
loss=0.41251893140225593, iteration=1856
loss=0.3175656070351649, iteration=1857
loss=0.3236009079700518, iteration=1858
loss=0.344593051470439, iteration=1859
loss=0.3048477420723532, iteration=1860
loss=0.3352761845823111, iteration=1861
loss=0.28204538600491624, iteration=1862
loss=0.2775867189367107, iteration=1863
loss=0.3433436846007831, iteration=

loss=0.42277037408482554, iteration=47
loss=0.42095466298823864, iteration=48
loss=0.38837135386088795, iteration=49
loss=0.34795977061667527, iteration=50
loss=0.3687628812552951, iteration=51
loss=0.4203552753533256, iteration=52
loss=0.3521108759527501, iteration=53
loss=0.3747604986244943, iteration=54
loss=0.3347751160240754, iteration=55
loss=0.32295329829488273, iteration=56
loss=0.34826270108483515, iteration=57
loss=0.32740766465376636, iteration=58
loss=0.3979011139989065, iteration=59
loss=0.37923156992996576, iteration=60
loss=0.3782353472607368, iteration=61
loss=0.3131601159758728, iteration=62
loss=0.39123994067947815, iteration=63
loss=0.4161062756669297, iteration=64
loss=0.3898828386320453, iteration=65
loss=0.3639183842966264, iteration=66
loss=0.37354975682463376, iteration=67
loss=0.3246337390537761, iteration=68
loss=0.3764216540827461, iteration=69
loss=0.3287137988205816, iteration=70
loss=0.3686104918679437, iteration=71
loss=0.3413140795674522, iteration=72
lo

loss=0.34958193490807243, iteration=268
loss=0.2883559097355609, iteration=269
loss=0.3454779970419898, iteration=270
loss=0.29940504966728076, iteration=271
loss=0.361663784174984, iteration=272
loss=0.3717488474419842, iteration=273
loss=0.39483055149070667, iteration=274
loss=0.40023557549320793, iteration=275
loss=0.3567297609681855, iteration=276
loss=0.3622880103696797, iteration=277
loss=0.3661456886180767, iteration=278
loss=0.34641952568816387, iteration=279
loss=0.32384788074240917, iteration=280
loss=0.3978427301722033, iteration=281
loss=0.4054951652486122, iteration=282
loss=0.39337880046525275, iteration=283
loss=0.2998510863340239, iteration=284
loss=0.3683790565200661, iteration=285
loss=0.3430937109010004, iteration=286
loss=0.3335711002249354, iteration=287
loss=0.3297214108494237, iteration=288
loss=0.33913328098678236, iteration=289
loss=0.2896266461180549, iteration=290
loss=0.38409705862668886, iteration=291
loss=0.4114509166957968, iteration=292
loss=0.3672859237

loss=0.36476420714047025, iteration=496
loss=0.31732950770183654, iteration=497
loss=0.29882076021585335, iteration=498
loss=0.3514600416244643, iteration=499
loss=0.3417715112597337, iteration=500
loss=0.356353750511835, iteration=501
loss=0.33647671739900076, iteration=502
loss=0.3760672385742495, iteration=503
loss=0.323479479298678, iteration=504
loss=0.36439015049380785, iteration=505
loss=0.3707848896500289, iteration=506
loss=0.3128388266359602, iteration=507
loss=0.3478047139641196, iteration=508
loss=0.4012567574855094, iteration=509
loss=0.3123381302006482, iteration=510
loss=0.4768559103170409, iteration=511
loss=0.41156064291037875, iteration=512
loss=0.4642004881812821, iteration=513
loss=0.4072986852603783, iteration=514
loss=0.4660351961676262, iteration=515
loss=0.3506861709034066, iteration=516
loss=0.32874968953277794, iteration=517
loss=0.44174557313693774, iteration=518
loss=0.34884008873428035, iteration=519
loss=0.3920276067699156, iteration=520
loss=0.35858040563

loss=0.3442893263588028, iteration=724
loss=0.4031939606508854, iteration=725
loss=0.40449814659577527, iteration=726
loss=0.35396847565010175, iteration=727
loss=0.3276272591072279, iteration=728
loss=0.37832508431030115, iteration=729
loss=0.31927668447961205, iteration=730
loss=0.363455148692747, iteration=731
loss=0.30861052927276844, iteration=732
loss=0.2971400438290469, iteration=733
loss=0.363781094351063, iteration=734
loss=0.3224390968990031, iteration=735
loss=0.3403314838315792, iteration=736
loss=0.35012555137253565, iteration=737
loss=0.3177291714377602, iteration=738
loss=0.33170759931585064, iteration=739
loss=0.2682946499351255, iteration=740
loss=0.42127940740575254, iteration=741
loss=0.39217366582373214, iteration=742
loss=0.39554880997439124, iteration=743
loss=0.38011490716895235, iteration=744
loss=0.44090277365737074, iteration=745
loss=0.38407564207845807, iteration=746
loss=0.42159819210382987, iteration=747
loss=0.3553568198222732, iteration=748
loss=0.440328

loss=0.3518552170604072, iteration=944
loss=0.2722524875294327, iteration=945
loss=0.40252995343176906, iteration=946
loss=0.3495414118706919, iteration=947
loss=0.4428187163268008, iteration=948
loss=0.4400655809176034, iteration=949
loss=0.3971793628470849, iteration=950
loss=0.3167168887214782, iteration=951
loss=0.3529519447901301, iteration=952
loss=0.43619760848618094, iteration=953
loss=0.35381120841209834, iteration=954
loss=0.3449168349846187, iteration=955
loss=0.3299624390117265, iteration=956
loss=0.34379652825473483, iteration=957
loss=0.42844381757318056, iteration=958
loss=0.3974288728531911, iteration=959
loss=0.3862572157267307, iteration=960
loss=0.3930079347460179, iteration=961
loss=0.35675821909745364, iteration=962
loss=0.3965541749817384, iteration=963
loss=0.38102262118856667, iteration=964
loss=0.3526664690179425, iteration=965
loss=0.3576981406684533, iteration=966
loss=0.36848993851012335, iteration=967
loss=0.35013428603714336, iteration=968
loss=0.360844388

loss=0.3487667821325464, iteration=1161
loss=0.33921828111591995, iteration=1162
loss=0.35871250906118857, iteration=1163
loss=0.3712593931614474, iteration=1164
loss=0.38317322759896866, iteration=1165
loss=0.4044155083121872, iteration=1166
loss=0.3369006551747401, iteration=1167
loss=0.35444953714201966, iteration=1168
loss=0.36501368118654776, iteration=1169
loss=0.32934727029120614, iteration=1170
loss=0.41123396904000065, iteration=1171
loss=0.3220112200425128, iteration=1172
loss=0.4002717743471813, iteration=1173
loss=0.3338211389114354, iteration=1174
loss=0.3089752875757782, iteration=1175
loss=0.34330765582295664, iteration=1176
loss=0.3640839556570078, iteration=1177
loss=0.39430894135481137, iteration=1178
loss=0.399030502730236, iteration=1179
loss=0.37002029324162855, iteration=1180
loss=0.32684393161466574, iteration=1181
loss=0.36437073809960435, iteration=1182
loss=0.300333131698472, iteration=1183
loss=0.3449568883966193, iteration=1184
loss=0.3542367860548659, itera

loss=0.357390984383663, iteration=1381
loss=0.3223990874902557, iteration=1382
loss=0.34800924519483534, iteration=1383
loss=0.3159891320118092, iteration=1384
loss=0.4093591401751032, iteration=1385
loss=0.39145251467369324, iteration=1386
loss=0.34202783227253913, iteration=1387
loss=0.38925331285588627, iteration=1388
loss=0.3700931589797375, iteration=1389
loss=0.3213275925961333, iteration=1390
loss=0.3444866950182204, iteration=1391
loss=0.3540158407992203, iteration=1392
loss=0.3649365318918412, iteration=1393
loss=0.3757806477922698, iteration=1394
loss=0.4000074767735795, iteration=1395
loss=0.37361921297853184, iteration=1396
loss=0.3608034886502479, iteration=1397
loss=0.2516880888518476, iteration=1398
loss=0.42404588594647447, iteration=1399
loss=0.36889465382206466, iteration=1400
loss=0.3716461422720232, iteration=1401
loss=0.34577677843707333, iteration=1402
loss=0.3826424440340535, iteration=1403
loss=0.4401604813330204, iteration=1404
loss=0.34872267783853095, iterati

loss=0.3151914960071358, iteration=1596
loss=0.38628510867574417, iteration=1597
loss=0.37167889088017525, iteration=1598
loss=0.3534101562232292, iteration=1599
loss=0.3277048929739678, iteration=1600
loss=0.359362358692495, iteration=1601
loss=0.35600986032280474, iteration=1602
loss=0.43287842378813124, iteration=1603
loss=0.4126152197642691, iteration=1604
loss=0.3645341783756694, iteration=1605
loss=0.36672710259161284, iteration=1606
loss=0.34558797321741325, iteration=1607
loss=0.34376807620546784, iteration=1608
loss=0.38811328858726657, iteration=1609
loss=0.38419204304309057, iteration=1610
loss=0.3958901232068152, iteration=1611
loss=0.3331247513311999, iteration=1612
loss=0.33348892475934866, iteration=1613
loss=0.3238611570467374, iteration=1614
loss=0.3224512774001749, iteration=1615
loss=0.43145750197528643, iteration=1616
loss=0.25766815349484384, iteration=1617
loss=0.33235833867379, iteration=1618
loss=0.3098861765704714, iteration=1619
loss=0.40967524065555355, itera

loss=0.30332453400785464, iteration=1816
loss=0.40743571243899795, iteration=1817
loss=0.37256662830222265, iteration=1818
loss=0.40602651221831987, iteration=1819
loss=0.3310300770651992, iteration=1820
loss=0.43446448947915217, iteration=1821
loss=0.397515438978177, iteration=1822
loss=0.3462950012111027, iteration=1823
loss=0.3889602298738687, iteration=1824
loss=0.37661257575838064, iteration=1825
loss=0.29433893481646595, iteration=1826
loss=0.38861282813946296, iteration=1827
loss=0.37267938392254363, iteration=1828
loss=0.29102117118202003, iteration=1829
loss=0.3757361947766559, iteration=1830
loss=0.3755732231027601, iteration=1831
loss=0.4589504098512714, iteration=1832
loss=0.3830536247157377, iteration=1833
loss=0.29937401353234094, iteration=1834
loss=0.38271354243382616, iteration=1835
loss=0.404153698030979, iteration=1836
loss=0.3537673131758383, iteration=1837
loss=0.3926359077129039, iteration=1838
loss=0.4524077833075396, iteration=1839
loss=0.350682816976377, iterat

In [108]:
# Score the WEIGHT_LRSGD weights on the full training set: data_pred combines the
# per-subset features TX with the ids r_ids returned by data_preprocess, and
# metric_pred compares the result with the unsplit labels y.
y_pred = data_pred(TX, WEIGHT_LRSGD, r_ids)   # predictions for every training sample
metric_pred(y_pred, y)                        # displayed as the cell output

0.758996

In [109]:
# Predict on the test set with the WEIGHT_LRSGD weights and write a submission file.
# Use a dedicated name for the test path: re-assigning DATA_TRAIN_PATH would leave it
# pointing at test.csv for any cell that later reloads the training data.
DATA_TEST_PATH = '../data/test.csv'
test_y, test_tX, test_ids = load_csv_data(DATA_TEST_PATH)
# Keep the raw ids from the CSV. In the training cells, data_pred's output is scored
# against the unsplit labels y, so its predictions are in original row order. They
# must therefore be paired with the original-order ids, not with the ids returned by
# data_preprocess (presumably regrouped per subset).
# NOTE(review): mode='std' is hard-coded; confirm it matches how TX was built.
test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, replacing=METHOD, mode='std')
test_pred = data_pred(test_TX, WEIGHT_LRSGD, test_r_ids)
OUTPUT_PATH = 'data/pred.csv'
create_csv_submission(test_ids, test_pred, OUTPUT_PATH)

### Least Squares with Normal Equation

In [122]:
# Set H-parameters: the degrees handed to build_poly by the tuning loop below,
# i.e. 2, 3, ..., 12 (the loop reports each one as deg - 1).
DEGREE = np.arange(2, 12 + 1)

In [136]:
# Fine-tune each split training subset. For every degree in DEGREE, fit least
# squares on build_poly(TX[i], deg), then keep the weights with the lowest loss.
# NOTE(review): the loss used for selection is the one returned by least_squares on
# the *training* data, so this favours higher degrees until the fit breaks down
# numerically. A cross-validated loss would be a safer criterion.
# Degrees are reported as deg - 1, following the original labelling convention
# (presumably build_poly(tx, deg) includes powers up to deg - 1; TODO confirm).
best_w = []
for i in range(len(TX)):
    W_LSNE_0 = []
    LOSS_LSNE_0 = []

    for deg in DEGREE:
        poly_tx = build_poly(TX[i], deg)
        w_lsne, loss_lsne = least_squares(Y[i], poly_tx)
        print('Degree={a}, Loss={b}'.format(a=deg-1, b=loss_lsne))

        W_LSNE_0.append(w_lsne)
        LOSS_LSNE_0.append(loss_lsne)

    # Obtain the optimal training weight. Index DEGREE with the argmin directly so
    # the reported degree stays correct even if DEGREE is not a contiguous step-1 range.
    best_idx = int(np.argmin(LOSS_LSNE_0))
    w_lsne_opt = W_LSNE_0[best_idx]
    best_w.append(w_lsne_opt)
    print('\nOptimal Degree={a}, Loss={b}'.format(a=DEGREE[best_idx]-1,
                                            b=LOSS_LSNE_0[best_idx]))

Degree=1, Loss=0.27213922342829283
Degree=2, Loss=0.25265036997564666
Degree=3, Loss=0.24463894415330834
Degree=4, Loss=0.24396587232083913
Degree=5, Loss=0.23822021202417515
Degree=6, Loss=0.23667387951094224
Degree=7, Loss=0.23472959958838002
Degree=8, Loss=0.23229137519463505
Degree=9, Loss=0.2311660825042048
Degree=10, Loss=0.2607976241655854
Degree=11, Loss=0.25354181918922797

Optimal Degree=9, Loss=0.2311660825042048
Degree=1, Loss=0.3726320875988558
Degree=2, Loss=0.33882921891271717
Degree=3, Loss=0.32928296043943894
Degree=4, Loss=0.321911824186863
Degree=5, Loss=0.31750676118344795
Degree=6, Loss=0.3126155386557532
Degree=7, Loss=0.305722279436587
Degree=8, Loss=0.2980324722525874
Degree=9, Loss=0.29368657005129056
Degree=10, Loss=0.2919209960317178
Degree=11, Loss=1.3522049211782425

Optimal Degree=10, Loss=0.2919209960317178
Degree=1, Loss=0.3553982628334921
Degree=2, Loss=0.3224218772271908
Degree=3, Loss=0.3036213114893265
Degree=4, Loss=0.2950969290322816
Degree=5, Loss

In [142]:
# Training-set score of the per-subset least-squares weights. poly=True is passed
# because best_w was fitted on polynomial expansions (build_poly) of TX.
pred_tr_lsne = data_pred(TX, best_w, r_ids, poly=True)
metric_pred(pred_tr_lsne, y)   # displayed as the cell output

4
18
163
22
221
29
320
29
320


0.828864

In [145]:
# Predict on the test set with the least-squares weights best_w and write a submission.
# Use a dedicated name for the test path: re-assigning DATA_TRAIN_PATH would leave it
# pointing at test.csv for any cell that later reloads the training data.
DATA_TEST_PATH = '../data/test.csv'
test_y, test_tX, test_ids = load_csv_data(DATA_TEST_PATH)
# Keep the raw ids from the CSV. In the training cells, data_pred's output is scored
# against the unsplit labels y, so its predictions are in original row order. They
# must therefore be paired with the original-order ids, not with the ids returned by
# data_preprocess (presumably regrouped per subset).
# NOTE(review): mode='std' is hard-coded; confirm it matches how TX was built.
test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, replacing=METHOD, mode='std')
test_pred = data_pred(test_TX, best_w, test_r_ids, poly=True)
OUTPUT_PATH = 'data/pred.csv'
create_csv_submission(test_ids, test_pred, OUTPUT_PATH)

18
163
22
221
29
320
29
320


## Ridge Regression

In [254]:
# Set H-parameters for the ridge-regression search (consumed by find_optimal /
# find_optimal_KMC in the next cell)
K_FOLD = 10                        # number of folds for the search
SEED = 5                           # seed passed to the search helpers
DEGREE = np.arange(1, 7)           # candidate polynomial degrees 1..6
LAMBDA = np.logspace(-6, -3, 30)   # 30 log-spaced regularisation strengths, 1e-6 .. 1e-3
K_CLUSTER = np.arange(1, 11)       # candidate cluster counts 1..10, used only when METHOD == 'k_means'

In [263]:
# For each split training subset, search the H-parameter grid for the best polynomial
# degree and ridge regularisation strength lambda (not a learning rate). When
# METHOD == 'k_means', the search also selects the number of clusters.
# The search presumably uses K_FOLD-fold cross-validation (find_optimal / find_optimal_KMC).
# Only the selected H-parameters are stored here; the weights are fitted in a later cell.
DEGREE_RIDGE = []
LAMBDA_RIDGE = []
K_CLUSTER_RIDGE = []   # stays empty unless METHOD == 'k_means'

for i in range(len(TX)):
    if METHOD == 'k_means':
        best_deg, best_lambda, best_k = find_optimal_KMC(Y[i], TX[i], DEGREE,
                                                         K_FOLD, LAMBDA, K_CLUSTER, SEED)
        K_CLUSTER_RIDGE.append(best_k)
    else:
        best_deg, best_lambda = find_optimal(Y[i], TX[i], DEGREE, K_FOLD, LAMBDA, SEED)

    print('The best degree of tx_{}:'.format(i), best_deg, 'with lambda:', best_lambda)
    if METHOD == 'k_means':
        # Previously the selected cluster count was stored but never reported.
        print('    best number of clusters:', best_k)

    DEGREE_RIDGE.append(best_deg)
    LAMBDA_RIDGE.append(best_lambda)

The best degree of tx_0: 2 with lambda: 0.001
The best degree of tx_1: 6 with lambda: 0.001
The best degree of tx_2: 6 with lambda: 1e-06
The best degree of tx_3: 5 with lambda: 0.001


In [274]:
if METHOD == 'k_means':
    # Rebuild the training subsets with the cluster counts picked by the search
    # before fitting the final ridge weights.
    TX, Y, r_ids = data_preprocess(tX, y, ids, k_list = K_CLUSTER_RIDGE,
                                   replacing=METHOD, mode=MODE)

# Fit one weight vector per subset from the selected degrees / lambdas / cluster counts.
WEIGHT = generate_weights(TX, Y, DEGREE_RIDGE, LAMBDA_RIDGE,
                          K_CLUSTER_RIDGE, r_ids)

# Training-set score (the 4th positional True mirrors the test-time call).
y_pred = data_pred(TX, WEIGHT, r_ids, True)
metric_pred(y_pred, y)

-1.0


0.803408

In [271]:
# Predict on the test set with the ridge weights WEIGHT and write a submission file.
# Use a dedicated name for the test path: re-assigning DATA_TRAIN_PATH would leave it
# pointing at test.csv for any cell that later reloads the training data.
DATA_TEST_PATH = '../data/test.csv'
test_y, test_tX, test_ids = load_csv_data(DATA_TEST_PATH)
if METHOD == 'k_means':
    # Mirror the training preprocessing: with k_means, WEIGHT was fitted on data built
    # with k_list=K_CLUSTER_RIDGE and mode=MODE, so build the test subsets the same way.
    test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, k_list=K_CLUSTER_RIDGE,
                                             replacing=METHOD, mode=MODE)
else:
    test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, replacing=METHOD, mode='std')
# Keep the raw ids from the CSV. In the training cells, data_pred's output is scored
# against the unsplit labels y, so its predictions are in original row order and are
# paired with the original-order ids.
test_pred = data_pred(test_TX, WEIGHT, test_r_ids, True)
OUTPUT_PATH = 'data/pred.csv'
create_csv_submission(test_ids, test_pred, OUTPUT_PATH)

## Logistic Regression

### Logistic Regression (SGD)

In [199]:
from implementations import *
from proj1_helpers import *
from preprocess import *
from cross_validation import *
from helpers import *
from costs import *

In [179]:
# Set H-parameters for (unregularised) logistic regression with mini-batches
GAMMA_LOGIC = 0.005        # gamma (step size)
BATCH_SIZE = 128           # mini-batch size
MAX_ITERS_LOGIC = 500      # number of iterations
THRESHOLD_LOGIC = 1e-8     # passed as the threshold argument

In [180]:
# Map the labels of every split subset from -1/+1 to 0/1 for logistic regression.
# Each Y_LOGIC[i] is a 1-D int array; the training cell reshapes it to a column.
# A vectorised np.where replaces the old per-element comprehension. That version
# built a list mixing Python ints with 1-element arrays, which only worked through
# NumPy's legacy ragged-list coercion; recent NumPy releases raise ValueError on it.
Y_LOGIC = []

for i in range(len(Y)):
    labels = np.asarray(Y[i]).ravel()
    y_binary = np.where(labels == -1, 0, labels).astype(int)
    Y_LOGIC.append(y_binary)

In [200]:
# Train one unregularised logistic-regression model per split subset, starting from
# zero weights. INITIAL_W has one row more than TX[i] has columns (presumably for a
# bias term handled inside logistic_regression; TODO confirm).
WEIGHT_LOGIC = []
LOSS_LOGIC = []
for i in range(len(Y)):
    n_weights = TX[i].shape[1] + 1
    INITIAL_W = np.zeros((n_weights, 1))
    y_column = np.asarray(Y_LOGIC[i]).reshape((-1, 1))
    w_logic, loss_logic = logistic_regression(y_column, TX[i], INITIAL_W, MAX_ITERS_LOGIC,
                                              GAMMA_LOGIC, THRESHOLD_LOGIC, BATCH_SIZE)
    WEIGHT_LOGIC.append(w_logic)
    LOSS_LOGIC.append(loss_logic)

Current iteration=0, loss=88.722839111673
Current iteration=100, loss=53.655012723378306
Current iteration=200, loss=55.4918995919335
Current iteration=300, loss=58.33544673461169
Current iteration=400, loss=46.3122419828849
Current iteration=0, loss=88.722839111673
Current iteration=100, loss=69.62082062489979
Current iteration=200, loss=61.62998441464148
Current iteration=300, loss=70.28441520061615
Current iteration=400, loss=77.63005738428228
Current iteration=0, loss=88.72283911167298
Current iteration=100, loss=63.07054827604129
Current iteration=200, loss=70.63243652185477
Current iteration=300, loss=69.34078852536939
Current iteration=400, loss=66.49767907174837
Current iteration=0, loss=88.72283911167298
Current iteration=100, loss=59.742004131570894
Current iteration=200, loss=71.28726595604851
Current iteration=300, loss=65.36134739322296
Current iteration=400, loss=61.503698937266


In [201]:
# Training-set score of the per-subset logistic-regression weights.
y_pred = data_pred(TX, WEIGHT_LOGIC, r_ids)   # predictions for every training sample
metric_pred(y_pred, y)                        # displayed as the cell output

0.75836

In [202]:
# Predict on the test set with the logistic-regression weights and write a submission.
# Use a dedicated name for the test path: re-assigning DATA_TRAIN_PATH would leave it
# pointing at test.csv for any cell that later reloads the training data.
DATA_TEST_PATH = '../data/test.csv'
test_y, test_tX, test_ids = load_csv_data(DATA_TEST_PATH)
# Keep the raw ids from the CSV. In the training cells, data_pred's output is scored
# against the unsplit labels y, so its predictions are in original row order. They
# must therefore be paired with the original-order ids, not with the ids returned by
# data_preprocess (presumably regrouped per subset).
# NOTE(review): mode='std' is hard-coded; confirm it matches how TX was built.
test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, replacing=METHOD, mode='std')
test_pred = data_pred(test_TX, WEIGHT_LOGIC, test_r_ids)
OUTPUT_PATH = 'data/pred.csv'
create_csv_submission(test_ids, test_pred, OUTPUT_PATH)

### Regularized Logistic Regression (SGD)

In [244]:
from implementations import *
from proj1_helpers import *
from preprocess import *
from cross_validation import *
from helpers import *
from costs import *

In [241]:
# Set H-parameters for regularised logistic regression with mini-batches
GAMMA_RLOGIC = 0.01        # gamma (step size)
LAMBDA = 0.001             # regularisation strength
BATCH_SIZE = 128           # mini-batch size (same value as the logistic-regression cell)
MAX_ITERS_RLOGIC = 1000    # number of iterations
THRESHOLD_RLOGIC = 1e-8    # passed as the threshold argument
# NOTE(review): LAMBDA is also the name of the ridge lambda grid; this scalar
# overwrites it, so re-run the ridge H-parameter cell before repeating that search.

In [242]:
# Map the labels of every split subset from -1/+1 to 0/1 for regularised logistic
# regression, one 1-D int array per subset (the training cell reshapes it to a column).
# The previous duplicated `for i ...: for i ...:` loop appended len(Y)**2 entries;
# only the first len(Y) were ever used by the training cell.
# np.where also replaces the per-element comprehension, whose ragged int/array list
# is rejected by recent NumPy releases.
Y_REG_LOGIC = []

for i in range(len(Y)):
    labels = np.asarray(Y[i]).ravel()
    y_binary = np.where(labels == -1, 0, labels).astype(int)
    Y_REG_LOGIC.append(y_binary)

In [245]:
# Fine-tune each split subset with regularised logistic regression, starting from zero
# weights with TX[i].shape[1] + 1 rows. Arguments are positional: y, tx, lambda_,
# initial_w, max_iters, gamma, threshold, followed by BATCH_SIZE.
WEIGHT_RLOGIC = []
LOSS_RLOGIC = []
for i in range(len(Y)):
    INITIAL_W = np.zeros((TX[i].shape[1] + 1, 1))
    y_column = np.asarray(Y_REG_LOGIC[i]).reshape((-1, 1))
    w_rlogic, loss_rlogic = reg_logistic_regression(
        y_column, TX[i], LAMBDA, INITIAL_W,
        MAX_ITERS_RLOGIC, GAMMA_RLOGIC, THRESHOLD_RLOGIC, BATCH_SIZE)

    WEIGHT_RLOGIC.append(w_rlogic)
    LOSS_RLOGIC.append(loss_rlogic)

Current iteration=0, loss=88.72283911167298
Current iteration=100, loss=47.32069271471683
Current iteration=200, loss=49.71443139802818
Current iteration=300, loss=54.00090160179376
Current iteration=400, loss=56.369835547058095
Current iteration=500, loss=50.911846434284634
Current iteration=600, loss=49.412694370437706
Current iteration=700, loss=34.26368859356684
Current iteration=800, loss=39.75637813708755
Current iteration=900, loss=46.59108173235637
Current iteration=0, loss=88.72283911167298
Current iteration=100, loss=80.96947981954878
Current iteration=200, loss=69.73678744396683
Current iteration=300, loss=63.87935681529479
Current iteration=400, loss=73.48936243255092
Current iteration=500, loss=69.40203975264241
Current iteration=600, loss=73.51532079859786
Current iteration=700, loss=75.5314775148134
Current iteration=800, loss=79.35469123105283
Current iteration=900, loss=72.27375362235739
Current iteration=0, loss=88.72283911167298
Current iteration=100, loss=69.1037822

In [246]:
# Training-set score of the per-subset regularised logistic-regression weights.
y_pred = data_pred(TX, WEIGHT_RLOGIC, r_ids)   # predictions for every training sample
metric_pred(y_pred, y)                         # displayed as the cell output

0.75174

In [247]:
# Predict on the test set with the regularised logistic-regression weights and write
# a submission. Use a dedicated name for the test path: re-assigning DATA_TRAIN_PATH
# would leave it pointing at test.csv for any cell that later reloads the training data.
DATA_TEST_PATH = '../data/test.csv'
test_y, test_tX, test_ids = load_csv_data(DATA_TEST_PATH)
# Keep the raw ids from the CSV. In the training cells, data_pred's output is scored
# against the unsplit labels y, so its predictions are in original row order. They
# must therefore be paired with the original-order ids, not with the ids returned by
# data_preprocess (presumably regrouped per subset).
# NOTE(review): mode='std' is hard-coded; confirm it matches how TX was built.
test_TX, _, test_r_ids = data_preprocess(test_tX, test_y, test_ids, replacing=METHOD, mode='std')
test_pred = data_pred(test_TX, WEIGHT_RLOGIC, test_r_ids)
OUTPUT_PATH = 'data/pred.csv'
create_csv_submission(test_ids, test_pred, OUTPUT_PATH)

## Generate predictions and save output in CSV format for submission

In [3]:
# Load the *test* set. It is read from '../data/', the same directory as the training
# file ('../data/train.csv') and the per-model submission cells above.
DATA_TEST_PATH = '../data/test.csv'
Y_test, tX_test, ids_test = load_csv_data(DATA_TEST_PATH)

In [10]:
# Preprocess the test set and show the shape of every resulting subset.
# NOTE(review): replacing='lr' and mode='std_norm' differ from the settings the
# per-model cells use (replacing=METHOD, mode='std'). Whatever weights are passed to
# data_pred must have been trained on identically preprocessed data; confirm.
TX_test, _, r_ids_test = data_preprocess(tX_test, Y_test, ids_test, replacing='lr',
                                     mode='std_norm')

# Loop instead of three hard-coded prints, so every subset is shown however many
# data_preprocess returns (the training split in the ridge section produced four).
for tx_subset in TX_test:
    print(tx_subset.shape)

(227458, 18)
(175338, 22)
(165442, 29)


In [31]:
# Predict on the preprocessed *test* subsets and write the submission file.
# The previous version predicted on the training features TX and omitted the ids
# argument, so its predictions did not correspond to ids_test.
# TODO: before running, assign W_tr to the per-subset weights of the model to submit.
# Those weights must be trained on data preprocessed like TX_test. Pass the extra
# polynomial flag (True, as in the ridge / least-squares cells) if the model needs it.
OUTPUT_PATH = 'data/pred.csv' # TODO: fill in desired name of output file for submission

y_pred = data_pred(TX_test, W_tr, r_ids_test)
create_csv_submission(ids_test, y_pred, OUTPUT_PATH)