In [11]:
## classic pydata stack
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline 

from NN_LSTM import *

plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = (15,7)

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import itertools as it
from datetime import datetime

## torch
import torch
import torch.nn as nn

from torch.utils.data import random_split

## SEEDING
# Seed torch's global random number generator so that repeated runs draw the
# same random numbers from torch.

torch.manual_seed(1)


# NOTE(review): REBUILD_DATA is not referenced in any of the visible cells --
# confirm where it is consumed, or remove it.
REBUILD_DATA = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
def test_model(model, data_loader):
    predictions = np.array([])
    labels = np.array([])

    with torch.no_grad():
        for X, y in iter(data_loader):
            probs = model(X)
            preds = torch.argmax(probs, dim=1, keepdim=False)
            predictions = np.concatenate((predictions,preds), axis=None)
            labels= np.concatenate((labels,y),axis=None)

    return(accuracy_score(labels,predictions), f1_score(labels,predictions))

    print(confusion_matrix(labels,predictions))


In [16]:


# ---------------------------------------------------------------------------
# Search space for the feature-engineering (dataset) parameters.
# Every key maps to the list of candidate values to explore.
# ---------------------------------------------------------------------------
dataset_params_list = {
    "psd": [False],
    "auto_corr": [False],
    "num_blocks": [30, 20],
}

# ---------------------------------------------------------------------------
# Search space for the optimizer and model parameters.
# ---------------------------------------------------------------------------
model_params_list = {
    # Each entry is (linear layer dimensions, activation modules).
    # Both lists must have the same length: one activation per linear layer.
    "nn_specs": [
        ([8, 8], [nn.Tanh(), nn.Tanh()]),
    ],
    "hidden_dim": [20, 10],
    "lr": [0.001],
    "rstm_layers": [5, 3, 1],
    "num_epochs": [400],
}

# Cartesian product of the candidate values: one tuple per configuration,
# with tuple positions following the key order of the dictionaries above.
dataset_param_combinations = list(it.product(*dataset_params_list.values()))
model_param_combinations = list(it.product(*model_params_list.values()))

# Total number of (dataset, model) configurations described by the two grids.
print(len(model_param_combinations) * len(dataset_param_combinations))

24


In [17]:
def grid_search(dataset_params_list, model_params_list):
    """Train and score one model per combination of dataset and model parameters.

    For every combination of dataset parameters a ``PolymerDataset`` is built
    and split 80/20 with ``random_split`` into a training part and a held-out
    part. For every combination of model parameters a model is then trained
    with ``LSTM.train`` on the training part and scored with ``test_model`` on
    the held-out part. Every result is appended to a timestamped
    ``cv_logs_<timestamp>.txt`` file, and ``best_params.txt`` is rewritten
    each time a new best accuracy is reached.

    Parameters
    ----------
    dataset_params_list : dict
        Maps a ``PolymerDataset`` keyword name to the list of values to try.
        Must contain the key ``"num_blocks"``, whose value is also forwarded
        to ``LSTM.train``.
    model_params_list : dict
        Maps an ``LSTM.train`` keyword name to the list of values to try.

    Returns
    -------
    tuple
        ``(best_model, best_params)``: the model with the highest held-out
        accuracy and the merged dataset + model parameters that produced it.

    Raises
    ------
    ValueError
        If a parameter has an empty list of candidates, so that there is no
        configuration to train.
    """
    dataset_param_combinations = list(it.product(*dataset_params_list.values()))
    model_param_combinations = list(it.product(*model_params_list.values()))

    # Fail early with a clear message instead of hitting unbound "best_*"
    # variables after an empty search.
    if not dataset_param_combinations or not model_param_combinations:
        raise ValueError(
            "grid_search needs at least one candidate value for every parameter"
        )

    best_accuracy = 0
    best_f1 = None
    best_model = None
    best_params = None

    # str(datetime.now()) contains spaces and colons, which are not valid in
    # file names on every platform, so format the timestamp explicitly.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    with open("cv_logs_{}.txt".format(timestamp), "a") as f_logs:
        for ds_params in dataset_param_combinations:
            ds_kwargs = dict(zip(dataset_params_list, ds_params))

            dataset = PolymerDataset(data_paths=["../data/AA66266AA.npy","../data/AA662266AA.npy"], lstm=True, **ds_kwargs)
            train_size = int(0.8 * len(dataset))
            test_size = len(dataset) - train_size
            train_data, test_data = random_split(dataset, [train_size, test_size])
            # The held-out loader is built once per dataset and shared by all
            # model configurations trained on that dataset.
            data_loader = DataLoader(test_data, batch_size=64, shuffle=False)
            # Second dimension of the first stored sample.
            # NOTE(review): presumably the per-step feature count -- confirm.
            num_features = dataset.data[0].shape[1]

            for model_params in model_param_combinations:
                model_kwargs = dict(zip(model_params_list, model_params))
                params = ds_kwargs | model_kwargs

                print(params)

                model = LSTM.train(dataset=train_data, num_features=num_features, batch_size=64, num_blocks=ds_kwargs["num_blocks"], **model_kwargs)
                test_accuracy, test_f1 = test_model(model, data_loader)

                # Flush after every run so partial results survive an
                # interrupted search.
                f_logs.write("Accuracy = {} | F1 = {} with params : {} \n".format(test_accuracy, test_f1, params))
                f_logs.flush()

                # NOTE(review): the held-out split scored here is also the one
                # used to pick the best configuration, so the reported best
                # accuracy is optimistically biased; consider a separate
                # validation split.
                # The first result is always accepted so that a model is
                # returned even if every accuracy is 0.
                if best_model is None or test_accuracy > best_accuracy:
                    best_accuracy = test_accuracy
                    best_f1 = test_f1
                    best_model = model
                    best_params = params

                    with open('best_params.txt', 'w') as f:
                        f.write("Best accuracy ({}) and f1 ({}) were reached with params {} \n".format(best_accuracy, best_f1, best_params))

    print("Best accuracy ({}) and f1 ({}) were reached with params {}".format(best_accuracy, best_f1, best_params))
    return best_model, best_params


In [18]:
# Run the grid search over the two search spaces defined above; one model is
# trained per (dataset, model) parameter combination.
best_model, best_params= grid_search(dataset_params_list, model_params_list)

{'psd': False, 'auto_corr': False, 'num_blocks': 20, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7309266328811646, accuracy=48.06749725341797
epoch=50/399, loss=0.1605028361082077, accuracy=90.58259582519531
epoch=100/399, loss=0.18715010583400726, accuracy=91.66252136230469
epoch=150/399, loss=0.1308804154396057, accuracy=92.37655639648438
epoch=200/399, loss=0.21947576105594635, accuracy=92.80994415283203
epoch=250/399, loss=0.2000802904367447, accuracy=93.03730010986328
epoch=300/399, loss=0.07575014978647232, accuracy=93.31438446044922
epoch=350/399, loss=0.25947287678718567, accuracy=93.53108215332031
epoch=399/399, loss=0.11465419828891754, accuracy=93.63410186767578
{'psd': False, 'auto_corr': False, 'num_blocks': 20, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 3, 'num_epochs': 400}
epoch=0/399, loss=0.47779297828674316, accuracy=67.95026397705078
epoch=5



epoch=0/399, loss=0.5021557807922363, accuracy=61.21847152709961
epoch=50/399, loss=0.3501514494419098, accuracy=90.89165496826172
epoch=100/399, loss=0.4295003116130829, accuracy=91.63765716552734
epoch=150/399, loss=0.28521978855133057, accuracy=92.0994644165039
epoch=200/399, loss=0.22034060955047607, accuracy=92.2806396484375
epoch=250/399, loss=0.19520454108715057, accuracy=92.53996276855469
epoch=300/399, loss=0.2753126919269562, accuracy=92.8703384399414
epoch=350/399, loss=0.21637852489948273, accuracy=93.00532531738281
epoch=399/399, loss=0.16919641196727753, accuracy=92.8490219116211
{'psd': False, 'auto_corr': False, 'num_blocks': 20, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 8, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7191191911697388, accuracy=49.28596878051758
epoch=50/399, loss=0.31978222727775574, accuracy=90.5186538696289
epoch=100/399, loss=0.25361740589141846, accuracy=91.80461883544922
epoch=150/399, loss=0.1754220575094223, 



epoch=0/399, loss=0.7120316028594971, accuracy=51.765541076660156
epoch=50/399, loss=0.22018998861312866, accuracy=90.38365936279297
epoch=100/399, loss=0.16834838688373566, accuracy=91.39964294433594
epoch=150/399, loss=0.15435785055160522, accuracy=91.6909408569336
epoch=200/399, loss=0.07727361470460892, accuracy=92.14209747314453
epoch=250/399, loss=0.17888562381267548, accuracy=92.10657501220703
epoch=300/399, loss=0.26252999901771545, accuracy=92.3516845703125
epoch=350/399, loss=0.22944270074367523, accuracy=92.45115661621094
epoch=399/399, loss=0.18949460983276367, accuracy=92.55062103271484
{'psd': False, 'auto_corr': False, 'num_blocks': 3, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7138363122940063, accuracy=50.10301971435547
epoch=50/399, loss=0.2881923317909241, accuracy=88.22024536132812
epoch=100/399, loss=0.2161921113729477, accuracy=89.60923767089844
epoch=150/399, loss=0.2450992017984



epoch=0/399, loss=0.7133114337921143, accuracy=51.31793975830078
epoch=50/399, loss=0.16322939097881317, accuracy=88.8134994506836
epoch=100/399, loss=0.3037637174129486, accuracy=90.071044921875
epoch=150/399, loss=0.33170533180236816, accuracy=90.62877655029297
epoch=200/399, loss=0.19436174631118774, accuracy=90.72468566894531
epoch=250/399, loss=0.2501072287559509, accuracy=90.95204162597656
epoch=300/399, loss=0.1490132063627243, accuracy=91.28596496582031
epoch=350/399, loss=0.2938686013221741, accuracy=91.31793975830078
epoch=399/399, loss=0.3057388961315155, accuracy=91.45293426513672
{'psd': False, 'auto_corr': False, 'num_blocks': 3, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 8, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.71634840965271, accuracy=49.9680290222168
epoch=50/399, loss=0.39988940954208374, accuracy=86.98401641845703
epoch=100/399, loss=0.2942923903465271, accuracy=88.72468566894531
epoch=150/399, loss=0.17623396217823029, accu



epoch=0/399, loss=0.7094389796257019, accuracy=49.62699890136719
epoch=50/399, loss=0.4291994571685791, accuracy=88.31261444091797
epoch=100/399, loss=0.36213135719299316, accuracy=89.52753448486328
epoch=150/399, loss=0.1671294867992401, accuracy=90.06039428710938
epoch=200/399, loss=0.16253109276294708, accuracy=90.50088500976562
epoch=250/399, loss=0.1490405648946762, accuracy=90.62877655029297
epoch=300/399, loss=0.18575610220432281, accuracy=90.71403503417969
epoch=350/399, loss=0.29380911588668823, accuracy=90.92362213134766
epoch=399/399, loss=0.2813403308391571, accuracy=91.20781707763672
{'psd': False, 'auto_corr': False, 'num_blocks': 5, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7173842191696167, accuracy=50.131439208984375
epoch=50/399, loss=0.2459491342306137, accuracy=89.26820373535156
epoch=100/399, loss=0.2649412453174591, accuracy=90.11722564697266
epoch=150/399, loss=0.203349500894546



epoch=0/399, loss=0.7254657745361328, accuracy=51.53108215332031
epoch=50/399, loss=0.32597827911376953, accuracy=89.33570098876953
epoch=100/399, loss=0.18401479721069336, accuracy=90.25577545166016
epoch=150/399, loss=0.19406235218048096, accuracy=91.01953887939453
epoch=200/399, loss=0.2731916904449463, accuracy=91.60568237304688
epoch=250/399, loss=0.2931877672672272, accuracy=91.6483154296875
epoch=300/399, loss=0.27602842450141907, accuracy=91.76909637451172
epoch=350/399, loss=0.19052278995513916, accuracy=91.89698028564453
epoch=399/399, loss=0.22847823798656464, accuracy=91.97868347167969
{'psd': False, 'auto_corr': False, 'num_blocks': 5, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 8, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7119053602218628, accuracy=48.0355224609375
epoch=50/399, loss=0.20070745050907135, accuracy=88.50444030761719
epoch=100/399, loss=0.3287748694419861, accuracy=89.69449615478516
epoch=150/399, loss=0.2226546853780746



epoch=0/399, loss=0.7110241055488586, accuracy=47.928951263427734
epoch=50/399, loss=0.20455148816108704, accuracy=89.41740417480469
epoch=100/399, loss=0.2015937715768814, accuracy=89.89698028564453
epoch=150/399, loss=0.463584303855896, accuracy=90.54706573486328
epoch=200/399, loss=0.1740243285894394, accuracy=90.88809967041016
epoch=250/399, loss=0.23816753923892975, accuracy=91.4813461303711
epoch=300/399, loss=0.2880527079105377, accuracy=91.55595397949219
epoch=350/399, loss=0.22248941659927368, accuracy=91.82238006591797
epoch=399/399, loss=0.41025522351264954, accuracy=92.00355529785156
{'psd': False, 'auto_corr': False, 'num_blocks': 10, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7187310457229614, accuracy=47.772647857666016
epoch=50/399, loss=0.13811884820461273, accuracy=90.7708740234375
epoch=100/399, loss=0.4052661955356598, accuracy=91.83658599853516
epoch=150/399, loss=0.183510988950729



epoch=0/399, loss=0.7230815291404724, accuracy=47.92539978027344
epoch=50/399, loss=0.23358191549777985, accuracy=90.58614349365234
epoch=100/399, loss=0.10513024777173996, accuracy=91.38898468017578
epoch=150/399, loss=0.26864707469940186, accuracy=91.97868347167969
epoch=200/399, loss=0.1348302960395813, accuracy=92.2877426147461
epoch=250/399, loss=0.08496806025505066, accuracy=92.5328598022461
epoch=300/399, loss=0.23176686465740204, accuracy=92.7779769897461
epoch=350/399, loss=0.13741666078567505, accuracy=92.94493865966797
epoch=399/399, loss=0.2447303831577301, accuracy=93.02664184570312
{'psd': False, 'auto_corr': False, 'num_blocks': 10, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 8, 'lr': 0.001, 'rstm_layers': 5, 'num_epochs': 400}
epoch=0/399, loss=0.7163046598434448, accuracy=49.7158088684082
epoch=50/399, loss=0.4321558177471161, accuracy=90.24866485595703
epoch=100/399, loss=0.18410737812519073, accuracy=91.00177764892578
epoch=150/399, loss=0.2627924680709839,



epoch=0/399, loss=0.7120599150657654, accuracy=48.68916702270508
epoch=50/399, loss=0.2652615010738373, accuracy=90.29129791259766
epoch=100/399, loss=0.1555413156747818, accuracy=91.24689483642578
epoch=150/399, loss=0.13259641826152802, accuracy=91.77264404296875
epoch=200/399, loss=0.30524271726608276, accuracy=92.2095947265625
epoch=250/399, loss=0.09019938856363297, accuracy=92.39431762695312
epoch=300/399, loss=0.16046041250228882, accuracy=92.38365936279297
epoch=350/399, loss=0.16849100589752197, accuracy=92.55772399902344
epoch=399/399, loss=0.180532768368721, accuracy=92.59325408935547
Best accuracy (0.927536231884058) and f1 (0.9280474040632054) were reached with params {'psd': False, 'auto_corr': False, 'num_blocks': 20, 'nn_specs': ([8, 8], [Tanh(), Tanh()]), 'hidden_dim': 10, 'lr': 0.001, 'rstm_layers': 1, 'num_epochs': 400}
