In [1]:
import numpy as np
from src.NN import NN
import src.utils as utils

# Settings
csv_filename = "data/seeds_dataset.csv" # path of the CSV file handed to utils.read_csv
hidden_layers = [10,5] # number of nodes in hidden layers i.e. [layer1, layer2, ...]
eta = 0.1 # learning rate
n_epochs = 400 # number of training epochs
n_folds = 4 # number of folds for cross-validation
seed_crossval = 1 # seed for cross-validation
seed_weights = 1 # seed passed to NN(...) when a model is built (presumably for weight initialisation -- confirm in src.NN)

In [2]:

# Load the dataset: read_csv is asked to treat column "y" as the target and to
# normalize. Presumably X is the feature matrix, y the label vector and
# n_classes the number of distinct labels -- confirm against src.utils.read_csv.
print("Reading '{}'...".format(csv_filename))
X, y, n_classes = utils.read_csv(csv_filename, target_name="y", normalize=True)

Reading 'data/seeds_dataset.csv'...


In [3]:
# Show the first feature column of X as an (n_samples, 1) array.
# The number of sections is taken from X itself instead of a hard-coded 7, so
# the cell keeps yielding exactly one column per section if the dataset's
# width changes (np.hsplit raises when the width is not divisible by the count).
print(np.hsplit(X, X.shape[1])[0])

[[ 0.14209777]
 [ 0.01118803]
 [-0.19206658]
 [-0.34709127]
 [ 0.44525718]
 [-0.16106164]
 [-0.05426685]
 [-0.25407645]
 [ 0.61406184]
 [ 0.54860697]
 [ 0.14209777]
 [-0.2816364 ]
 [-0.32986631]
 [-0.36776123]
 [-0.3815412 ]
 [-0.08871678]
 [-0.29541637]
 [ 0.29023248]
 [-0.05082185]
 [-0.73293052]
 [-0.23685149]
 [-0.25407645]
 [ 0.35568735]
 [-0.95341008]
 [ 0.05597294]
 [ 0.46248214]
 [-0.62958072]
 [-0.72604053]
 [-0.25407645]
 [-0.48144601]
 [-0.58135081]
 [ 0.22133261]
 [-0.26096644]
 [-0.31264134]
 [ 0.06975291]
 [ 0.43836719]
 [ 0.46592713]
 [ 0.76908654]
 [-0.01637192]
 [-0.19551157]
 [-0.45044107]
 [-0.46422104]
 [-0.58135081]
 [ 0.22477761]
 [ 0.09042287]
 [-0.36087125]
 [ 0.1765477 ]
 [ 0.04908295]
 [-0.01981691]
 [ 0.00429804]
 [-0.14383667]
 [ 0.32123742]
 [-0.12316671]
 [-0.1782866 ]
 [-0.11283173]
 [ 0.06286292]
 [-0.13350169]
 [ 0.024968  ]
 [ 0.18343769]
 [-0.9430751 ]
 [-1.18077964]
 [-1.24623451]
 [-0.85695027]
 [-0.56068085]
 [-0.71226056]
 [-0.67781062]
 [-0.17484

In [4]:
# Unpack the two dimensions of X (m = first axis, n = second axis); the bare
# `n` on the last line makes the cell display the second dimension.
m,n=X.shape
n

7

In [19]:
def get_splits(X):
    """Cross-validate one single-input network per column of ``X``.

    ``X`` is cut into one-column slices with ``np.hsplit``. For every slice
    the model configuration is printed, cross-validation folds are requested
    from ``utils.crossval_folds`` and a fresh ``NN`` is trained and scored on
    each fold, with per-fold and average accuracies printed along the way.

    The function reads these notebook-level names: ``y``, ``n_classes``,
    ``hidden_layers``, ``eta``, ``n_epochs``, ``n_folds``, ``seed_crossval``
    and ``seed_weights``.

    Parameters
    ----------
    X : 2-D array
        Data matrix; its first axis is indexed with the same fold indices
        as ``y``.

    Returns
    -------
    list
        One entry per column of ``X``: the result of ``get_weights()`` on the
        model trained for the LAST fold of that column (models of earlier
        folds are overwritten and not kept).
    """
    _, n_columns = X.shape
    weights_per_column = []

    for col_no, column in enumerate(np.hsplit(X, n_columns)):
        n_samples, n_inputs = column.shape

        # Echo the configuration every model of this column is built with
        print("Neural network model:")
        banner = (
            (" input_dim = {}", n_inputs),
            (" hidden_layers = {}", hidden_layers),
            (" output_dim = {}", n_classes),
            (" eta = {}", eta),
            (" n_epochs = {}", n_epochs),
            (" n_folds = {}", n_folds),
            (" seed_crossval = {}", seed_crossval),
            (" seed_weights = {}\n", seed_weights),
        )
        for template, value in banner:
            print(template.format(value))

        # Row indices of the whole column and of each validation fold
        every_row = np.arange(0, n_samples)
        folds = utils.crossval_folds(n_samples, n_folds, seed=seed_crossval)

        train_scores, valid_scores = [], []
        print("Cross-validating with {} folds...".format(len(folds)))
        for fold_no, held_out in enumerate(folds, start=1):

            # Everything outside the held-out fold is training data
            kept = np.delete(every_row, held_out)
            X_train, y_train = column[kept], y[kept]
            X_valid, y_valid = column[held_out], y[held_out]

            # A new network is built and trained from scratch for each fold
            net = NN(input_dim=n_inputs, output_dim=n_classes,
                     hidden_layers=hidden_layers, seed=seed_weights)
            net.train(X_train, y_train, eta=eta, n_epochs=n_epochs)

            guess_train = net.predict(X_train)
            guess_valid = net.predict(X_valid)

            # Accuracy in percent: share of predictions equal to the labels
            train_scores.append(100*np.sum(y_train==guess_train)/len(y_train))
            valid_scores.append(100*np.sum(y_valid==guess_valid)/len(y_valid))

            print(" Fold {}/{}: acc_train = {:.2f}%, acc_valid = {:.2f}% (n_train = {}, n_valid {})".format(fold_no,n_folds, train_scores[-1], valid_scores[-1], len(X_train), len(X_valid)))

        # Column index followed by the accuracies averaged over the folds
        print(col_no," : ")
        print("  -> acc_train_avg = {:.2f}%, acc_valid_avg = {:.2f}%".format(sum(train_scores)/float(len(train_scores)), sum(valid_scores)/float(len(valid_scores))))

        # `net` is whatever model the final fold left behind
        weights_per_column.append(net.get_weights())
    return weights_per_column

In [20]:
# Run the per-column cross-validation; the call is the last expression of the
# cell, so the returned list of weights is displayed as the cell output.
get_splits(X)

Neural network model:
 input_dim = 1
 hidden_layers = [10, 5]
 output_dim = 3
 eta = 0.1
 n_epochs = 400
 n_folds = 4
 seed_crossval = 1
 seed_weights = 1

Cross-validating with 4 folds...
 Fold 1/4: acc_train = 86.71%, acc_valid = 84.62% (n_train = 158, n_valid 52)
 Fold 2/4: acc_train = 85.44%, acc_valid = 90.38% (n_train = 158, n_valid 52)
 Fold 3/4: acc_train = 85.44%, acc_valid = 84.62% (n_train = 158, n_valid 52)
 Fold 4/4: acc_train = 87.34%, acc_valid = 76.92% (n_train = 158, n_valid 52)
0  : 
  -> acc_train_avg = 86.23%, acc_valid_avg = 84.13%
Neural network model:
 input_dim = 1
 hidden_layers = [10, 5]
 output_dim = 3
 eta = 0.1
 n_epochs = 400
 n_folds = 4
 seed_crossval = 1
 seed_weights = 1

Cross-validating with 4 folds...
 Fold 1/4: acc_train = 84.81%, acc_valid = 84.62% (n_train = 158, n_valid 52)
 Fold 2/4: acc_train = 85.44%, acc_valid = 92.31% (n_train = 158, n_valid 52)
 Fold 3/4: acc_train = 88.61%, acc_valid = 86.54% (n_train = 158, n_valid 52)
 Fold 4/4: acc_tra

[[[[-1.66406618808289],
   [1.7401206554276416],
   [1.8034835632914996],
   [-0.19228893147114443],
   [0.300053500365052],
   [-1.4792857506676047],
   [1.6877894979496824],
   [1.713910044616189],
   [-1.2815484710376994],
   [-1.6131794076081907]],
  [[1.1007714443757592,
    0.1403067096035861,
    0.4610631947405939,
    -0.025970896436359865,
    0.3729601487480798,
    0.8588089818148654,
    -0.08028337645595728,
    0.6365290235998164,
    0.9906384599436214,
    0.26491614050573653],
   [-2.587396934508217,
    3.4506802466052147,
    3.8945512576333985,
    0.5464331315126455,
    0.9158606262240984,
    -1.1741011480670913,
    3.1556129029426674,
    3.3156957285125968,
    -0.6920280225633846,
    -1.9007071901595771],
   [-2.4455938162029818,
    0.9930932402424065,
    0.9898660757308997,
    -0.5103018078593474,
    -0.30294088023280974,
    -1.7047632804993238,
    1.7231734732922,
    1.45961227594573,
    -0.9731234208177453,
    -2.3257543431464165],
   [1.3274638

In [21]:
# Hard-coded nested list of network weights. Its leading values match the
# get_splits(X) result displayed by the previous cell, so it was presumably
# pasted from that output.
# Layout visible in the literal: each wt[k] holds three layers, and every layer
# is a list with one inner list of weights per node (for wt[0]: 10 lists of 1
# value, then 5 lists of 10 values, then 3 lists of 5 values).
# NOTE(review): a pasted literal goes stale as soon as the settings or the data
# change -- consider keeping the return value of get_splits(X) instead.
wt=[[[[-1.66406618808289],
   [1.7401206554276416],
   [1.8034835632914996],
   [-0.19228893147114443],
   [0.300053500365052],
   [-1.4792857506676047],
   [1.6877894979496824],
   [1.713910044616189],
   [-1.2815484710376994],
   [-1.6131794076081907]],
  [[1.1007714443757592,
    0.1403067096035861,
    0.4610631947405939,
    -0.025970896436359865,
    0.3729601487480798,
    0.8588089818148654,
    -0.08028337645595728,
    0.6365290235998164,
    0.9906384599436214,
    0.26491614050573653],
   [-2.587396934508217,
    3.4506802466052147,
    3.8945512576333985,
    0.5464331315126455,
    0.9158606262240984,
    -1.1741011480670913,
    3.1556129029426674,
    3.3156957285125968,
    -0.6920280225633846,
    -1.9007071901595771],
   [-2.4455938162029818,
    0.9930932402424065,
    0.9898660757308997,
    -0.5103018078593474,
    -0.30294088023280974,
    -1.7047632804993238,
    1.7231734732922,
    1.45961227594573,
    -0.9731234208177453,
    -2.3257543431464165],
   [1.3274638925721933,
    0.5294949301529805,
    -0.21567262736287485,
    0.33529335674970756,
    0.6486243399859926,
    0.9063743930534388,
    0.5960920631706478,
    0.0767177662285607,
    0.9981478210207537,
    0.9777964411228849],
   [0.7313565608549771,
    -0.34690428873357276,
    -0.05725418508055866,
    0.6226634287418991,
    0.10115916093635202,
    0.7778710646591547,
    -0.9157934015133674,
    -0.7182549634360712,
    0.943406227269809,
    0.8072162988577991]],
  [[-1.6559730830187964,
    5.887268017961344,
    -7.243324726783445,
    -1.1196680205225846,
    -0.4381079586442381],
   [-1.6161730629300024,
    2.5255449268290215,
    5.005089660018295,
    -2.199234761666272,
    -2.4294523740044554],
   [0.6886604594675587,
    -6.591998291103454,
    -3.1665254913996743,
    1.3736324818162018,
    1.4988381951265506]]],
 [[[-1.760957330886613],
   [2.095409754048746],
   [2.1674524501761194],
   [-0.46754004079758615],
   [0.05064136663351292],
   [-1.5775765308169327],
   [2.0275943042591287],
   [2.0714579606847385],
   [-1.3822020728422049],
   [-1.7165737362783624]],
  [[1.0718096260561958,
    0.2111389502263268,
    0.5331888444414365,
    0.0006617710123636117,
    0.41022096790906754,
    0.8411750849845167,
    -0.009094402293305647,
    0.7077282189928987,
    0.9837128003586778,
    0.2389626515561315],
   [-2.3753676159160957,
    3.859850443886569,
    4.311121179344818,
    0.5554266431812824,
    0.9405981677638754,
    -1.0204815145922252,
    3.5434544816730495,
    3.7138557477004195,
    -0.5845216332026061,
    -1.7015185486584903],
   [-3.217341734624179,
    0.9103977929816283,
    0.9210539644441336,
    -1.0940314227306467,
    -0.8213583482134296,
    -2.457479336366994,
    1.6320087734442636,
    1.3762872386689096,
    -1.6943309781266727,
    -3.091959143083834],
   [1.302806460276597,
    0.5768966450194037,
    -0.16809740834713435,
    0.36581110446488563,
    0.6915918364019714,
    0.8890974934385805,
    0.644225735458328,
    0.12424419347959816,
    0.9887170083130944,
    0.9553527883230541],
   [0.6282684156115664,
    -0.40074412125906217,
    -0.11411152684611132,
    0.6138390242033851,
    0.11444395363441016,
    0.6815819338575478,
    -0.9671689661043635,
    -0.7718660837019291,
    0.8530249925846741,
    0.7058777446165432]],
  [[-1.7621008149414583,
    5.210208280536859,
    -7.615667509232555,
    -1.186085325192577,
    -0.2755568933426562],
   [-1.5104424463833874,
    2.696374151947845,
    5.427173986114994,
    -2.108413666918169,
    -2.3501485086217957],
   [0.7600379315721768,
    -6.501839779868569,
    -3.0101275872720055,
    1.423694815759757,
    1.6147858473722108]]],
 [[[-1.3119510893104636],
   [1.668422911521331],
   [1.7612337846745683],
   [0.7596008901628563],
   [0.8282389909630394],
   [-0.581501858657278],
   [1.8102493891133222],
   [1.7744940350744645],
   [-0.18753621091810932],
   [-1.0014774273984761]],
  [[1.1444443502539179,
    -0.0221914951928602,
    0.3025639555098891,
    -0.2539514010118421,
    0.1491243657989406,
    0.8012380694823723,
    -0.22590821754327053,
    0.4863697104354413,
    0.8882678936318018,
    0.2809287828948376],
   [-1.8191598984264463,
    0.12499325549264252,
    0.5424637589559137,
    -0.7278563608408565,
    -0.7270793526310657,
    -0.6357652197332441,
    -0.3959800466383205,
    -0.1709394748888976,
    -0.7720469542890304,
    -1.1155371661539455],
   [-1.9104812745995077,
    1.4047835565525497,
    1.4320120300256338,
    0.43859381774214073,
    0.5377859072378582,
    -0.8401129751498112,
    2.0027827689756266,
    1.7641530479976057,
    -0.1486902853655892,
    -1.6117566720607197],
   [1.347505855505424,
    0.6447609945381041,
    -0.09788434288918517,
    0.2951541206050129,
    0.6436365784269436,
    0.8679982847304446,
    0.7213509497029768,
    0.2015712671072528,
    0.94869583133904,
    0.9746244374825398],
   [0.8980461827943879,
    0.16194590706506232,
    0.45477323444607537,
    0.7558470938008289,
    0.331490074199901,
    0.8269447104605501,
    -0.3827454679676895,
    -0.18982578112472145,
    0.9898077006577013,
    0.9292564665335149]],
  [[-0.5960870683022026,
    -1.8288172303484016,
    0.8688371630889319,
    -0.9095759730757182,
    -0.7236607802229361],
   [-1.2672041737338635,
    0.8446841657054889,
    3.973361887724883,
    -1.3636176014710588,
    -1.4281043661039838],
   [0.48263323449356826,
    -1.2127419835402622,
    -4.8320627913875045,
    1.3929363558027954,
    1.3592155595023412]]],
 [[[-3.1555310182018563],
   [3.3457136780333125],
   [3.8054003055748034],
   [-1.1067191825308247],
   [-1.0345170865232904],
   [-1.8286767274970261],
   [3.2644978982994624],
   [3.358882967858341],
   [-1.4466548922585296],
   [-2.4896548675351533]],
  [[1.041532423658349,
    -0.1090333111836504,
    0.19670997501605755,
    -0.11272280889858391,
    0.29270010028132215,
    0.778608465821822,
    -0.32976331102284673,
    0.38440313475632865,
    0.8930178439739422,
    0.19428535879400022],
   [-1.9837723697941243,
    1.9624745920458702,
    2.6885458859769797,
    0.717391316523552,
    0.7321447234134087,
    -0.143101125230583,
    1.6279081675376004,
    1.8499129005333605,
    0.3081338680322052,
    -1.031446071769588],
   [-4.102191041519183,
    0.7851665534327187,
    0.8380362678591857,
    -1.7063222852026387,
    -1.5058140787868624,
    -3.1820088739243872,
    1.5181902933090954,
    1.2691886982462102,
    -2.2875428756525307,
    -3.9212136266972193],
   [1.3022682125280267,
    0.5651571209953244,
    -0.18366451700628159,
    0.39927732342369815,
    0.7408863862754367,
    0.8922611324589292,
    0.6341151537328688,
    0.11239673936458783,
    0.990178075964497,
    0.9561180138002843],
   [0.5719870960275392,
    -0.4013934634217006,
    -0.13332836939061082,
    0.6276658187542217,
    0.19835655886301423,
    0.600716968301957,
    -0.9690197524625243,
    -0.7754001199581206,
    0.758492319640818,
    0.6303628174606803]],
  [[-1.385008521293307,
    3.653334247876212,
    -7.176057378880299,
    -0.7211744158061424,
    -0.8683219710861244],
   [-1.524778738671275,
    2.2027431304081255,
    5.344180468875113,
    -2.2102335739030283,
    -1.9431026677915166],
   [0.5882533327023871,
    -4.667507817370612,
    -2.816178754341709,
    1.1919223377968575,
    1.5731650913584008]]],
 [[[-2.033286306426715],
   [1.6375921504148567],
   [1.648763483084844],
   [-0.39063084942329934],
   [0.11732919271282037],
   [-1.8077377403051562],
   [1.7708982086899603],
   [1.7129779147598767],
   [-1.559540662274387],
   [-1.971750480134103]],
  [[1.1157927800389376,
    0.0462270107670841,
    0.36509528987416817,
    -0.07842390617564882,
    0.31644860046263185,
    0.8438977343522894,
    -0.1790431810424879,
    0.5396279341974715,
    0.9689296458478236,
    0.2692579291956562],
   [-2.240991311826739,
    3.3070813331918507,
    3.7407693250187397,
    0.5668885576382686,
    0.9491230629733836,
    -0.8795397823941886,
    2.96623105173844,
    3.1413421747375674,
    -0.5577104998359973,
    -1.549257710037491],
   [-3.2041943978750593,
    0.8934209937047668,
    0.8780748702878911,
    -0.89559636739607,
    -0.634607890483122,
    -2.382589112109414,
    1.6404277794617466,
    1.3620363343416197,
    -1.6140261776777065,
    -3.0461442206329203],
   [1.314705334433384,
    0.5272396763583299,
    -0.21771007569275871,
    0.32067765016658967,
    0.6425738190173625,
    0.8760006238289477,
    0.5891577199982165,
    0.0722871623973461,
    0.9688586198283062,
    0.9573138143371492],
   [0.73016610557961,
    -0.3959547604486958,
    -0.10563847659780928,
    0.5915929376913027,
    0.08734675722479218,
    0.7275297827594672,
    -0.9846753309645002,
    -0.7787041297753209,
    0.8878207092303694,
    0.7879518425279438]],
  [[-1.7751009984743618,
    5.062899309217602,
    -6.300858705739076,
    -1.3712569742880825,
    -0.7130240086728938],
   [-1.6091483719684052,
    2.56927590343735,
    4.24812324211991,
    -2.0919941291458914,
    -2.251613433518257],
   [0.9500851494256328,
    -6.098661793543682,
    -3.4271093100662116,
    1.6541199691689517,
    1.7833026019287042]]],
 [[[-0.9213987209140126],
   [1.5770137615818962],
   [1.934715628295204],
   [-0.1556859634991464],
   [0.9261557128541228],
   [1.3720388256338125],
   [-1.5247089317987315],
   [1.369473465548608],
   [-0.7738134822400246],
   [-1.927538283399191]],
  [[0.20315912497748148,
    0.6777274229823873,
    1.1159620501196044,
    -0.4605521958895023,
    0.2913500524302374,
    0.768978069392706,
    -0.5307836386383148,
    1.0314447440224084,
    0.3316930232476588,
    -1.1056618018908764],
   [0.10055188324782502,
    0.3956774574733501,
    0.747553583117622,
    0.4502042225953009,
    0.24666273648333642,
    0.3845444726738219,
    0.10495097291416658,
    0.15083145435022213,
    0.5407694855876719,
    0.7683561727334365],
   [0.37480967909551854,
    -1.6529202416326383,
    -1.7192247445111615,
    0.0744927536265941,
    -1.088388648211476,
    -1.6184038916004169,
    1.1440422703386204,
    -1.176635730948693,
    0.6384428488425629,
    0.9360407408903166],
   [1.0048744360112347,
    0.13712999035696904,
    -0.6674893473557687,
    0.17095002265040785,
    0.2084646619124073,
    0.0747322622191985,
    1.0775913158927206,
    -0.2319251867045642,
    0.7989905712913755,
    0.9783970286134107],
   [-1.2314543491332648,
    1.0232579304059297,
    1.516244016127292,
    -0.2548182232074643,
    0.17636130605677297,
    0.6272559048678618,
    -1.6709508665242472,
    0.3856092927824669,
    -0.6589237910589687,
    -2.02874982495607]],
  [[0.3571392004746639,
    -1.983139078689069,
    2.775298942966706,
    -0.6422759905880038,
    -0.8120660997709025],
   [-1.1715403016765908,
    0.037589599716815685,
    -1.7112992068473976,
    1.277431057649496,
    -1.75341070190685],
   [0.8685845512024103,
    0.9976103825985612,
    -3.4621532740753476,
    -1.2741963607202327,
    0.535346164019056]]],
 [[[-2.6729937646560944],
   [1.8344482491855414],
   [2.0055956988270887],
   [-0.4917993510478493],
   [-0.19307375050860298],
   [-1.7115771028653404],
   [2.1566086357176086],
   [2.0268831389269946],
   [-0.1743334152195069],
   [-2.3997188136166536]],
  [[0.5023446532457259,
    0.5965935684298702,
    0.8813350669279391,
    -0.3161517060154177,
    0.21700252849554902,
    0.19409468308282637,
    0.33507633033372974,
    1.063708703144458,
    0.6319857190070338,
    -0.3618182131543738],
   [-2.704386613541276,
    0.7746331738619904,
    1.2811138619544897,
    -0.7093962006795828,
    -0.7050117640273628,
    -1.3784579595558504,
    0.30558413310546767,
    0.5730826830666503,
    -0.7434428963017616,
    -2.146369785832613],
   [-3.6432163153602217,
    1.032341161650672,
    1.2171449287790714,
    -0.851038512131185,
    -0.7476117924859325,
    -2.3167921540414738,
    1.7328144444792324,
    1.5738778240012263,
    -0.8198893842696257,
    -3.5392641402305065],
   [1.335045529792765,
    0.7856307593357655,
    0.03957581895073402,
    0.4813129767448901,
    0.8299728939969323,
    0.9447500232445354,
    0.869229818598808,
    0.33712853347001714,
    0.9978909318716088,
    1.0154293582798721],
   [0.5975522412472681,
    0.43594395186088974,
    0.7106931301544064,
    0.913876791283105,
    0.5404314204089496,
    0.737845944294801,
    -0.12571022635229437,
    0.06789249285478301,
    0.8781394627289659,
    0.6968430585221433]],
  [[-1.397805561869284,
    -2.0714033311405817,
    -2.3573953308632527,
    0.4081931380662904,
    -0.233960775741613],
   [-0.7269092128033953,
    3.9419571150491874,
    5.981591479375553,
    -1.8845429704628647,
    -1.3602696657808884],
   [0.6953537686516581,
    -1.4194037857327066,
    -3.218695306157055,
    0.10568568795293108,
    0.22881878474540596]]]]

In [22]:
# Number of top-level entries in wt (the pasted get_splits output has one per column of X).
len(wt)

7

In [None]:
def set_w(net=None):
    """Collect the per-node weight lists of a model, grouped by layer.

    Despite its name this function only reads weights; it does not set any.

    Parameters
    ----------
    net : object, optional
        Anything whose ``get_weights()`` returns an iterable of layers, each
        layer being an iterable of nodes that can be indexed with
        ``"weights"``. When omitted, the notebook-level name ``model`` is
        used, which keeps the original zero-argument call working.
        NOTE(review): no visible cell defines a global ``model`` (it is local
        to ``get_splits``), so pass the model explicitly when running from a
        fresh kernel.

    Returns
    -------
    list
        One list per layer, holding ``node["weights"]`` for every node of
        that layer, in iteration order.
    """
    if net is None:
        # Backward-compatible fallback to the notebook-level model
        net = model
    return [[node["weights"] for node in layer] for layer in net.get_weights()]
# Transpose each layer's weight table. set_w() yields one row per node, so after
# `.T` every row groups the weights found at the same position across the
# layer's nodes (assumes all nodes of a layer carry equally many weights).
r=[]
for l in set_w():
    r.append(np.array(l).T.tolist())

In [None]:
# Row count of the second transposed layer, i.e. how many weights each node of
# that layer holds (presumably the size of the preceding layer -- confirm).
len(r[1])