# Performance comparaison

| Fingerprint | R2 | RMSE |  
|:-:|:-:|:-:|  
| ECFP| 0.765 | 0.9808 |
|Can2Can|0.7176|1.073|
|Enum2Enum|0.725|1.059|
|Transformer|0.862|0.750|
| NFP| 0.8845 | 0.6868 |

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score

import chainer
from chainer import serializers
from seq2seq import Seq2seq, load_vocabulary
import sys
sys.path.append('./transformer')
from net import Transformer
import preprocess

UNK, EOS = 0, 1

In [2]:
df_train = pd.read_csv('../data/sol_train.csv')
df_test = pd.read_csv('../data/sol_test.csv')
df_train.head()

Unnamed: 0,SMILES,unknown,solubility,processed_smiles,spaced
0,[nH0]1c(SC)c2c([nH0]cc[nH0]2)[nH0]c1,6966-78-5,-2.36,[ n H 0 ] 1 c ( S C ) c 2 c ( [ n H 0 ] c c [ ...,[ n H 0 ] 1 c ( S C ) c 2 c ( [ n H 0 ] c c [ ...
1,CCC(C)Cl,78-86-4,-1.96,C C C ( C ) Cl,C C C ( C ) C l
2,O=C(NC(=O)c1ccccc1)c1ccccc1,614-28-8,-2.27,O = C ( N C ( = O ) c 1 c c c c c 1 ) c 1 c c ...,O = C ( N C ( = O ) c 1 c c c c c 1 ) c 1 c c ...
3,CC(C(C)(C)C)O,464-07-3,-0.62,C C ( C ( C ) ( C ) C ) O,C C ( C ( C ) ( C ) C ) O
4,[O-][N+](c1c(O)cccc1)=O,88-75-5,-1.74,[ O- ] [ N+ ] ( c 1 c ( O ) c c c c 1 ) = O,[ O - ] [ N + ] ( c 1 c ( O ) c c c c 1 ) = O


In [4]:
x_train = df_train['processed_smiles']
y_train = df_train['solubility']
x_test = df_test['processed_smiles']
y_test = df_test['solubility']

In [20]:
def load_data(vocabulary, lst):
    data = []
    for l in lst:
        words = l.strip().split()
        array = np.array([vocabulary.get(w, UNK) for w in words], np.int32)
        data.append(array)
    return data

In [30]:
source_ids = load_vocabulary('../data/Enum2Enum/vocab2.txt')
xnum_train = load_data(source_ids, x_train)
xnum_test = load_data(source_ids, x_test)

# Encode to fingerprint

In [10]:
model = Seq2seq(1, 46, 46, 256)
chainer.serializers.load_npz('../result/can2can_iter_132000.npz', model)

In [31]:
model = Seq2seq(2, 43, 43, 256)
chainer.serializers.load_npz('../result/Enum2Enum/model_epoch_3.npz', model)

In [33]:
X_train = model.encode(xnum_train)[0].data
X_test = model.encode(xnum_test)[0].data

In [3]:
x_train = df_train['spaced']
y_train = df_train['solubility']
x_test = df_test['spaced']
y_test = df_test['solubility']

In [4]:
en_path = os.path.join('../data/Enum2Enum', 'sval.txt')
source_vocab = ['<eos>', '<unk>', '<bos>'] + \
    preprocess.count_words(en_path, 50)
source_ids = {word: index for index, word in enumerate(source_vocab)}
source_words = {i: w for w, i in source_ids.items()}

def encode(x):
    words = preprocess.split_sentence(x)
    x = model.xp.array([source_ids.get(w, 1) for w in words], 'i')
    h = model.encode([x])
    return np.mean(h.data[0], axis=1)

100% (500000 of 500000) |################| Elapsed Time: 0:00:39 Time:  0:00:39


In [5]:
# Transformer
model = Transformer(2, 38, 38,
        256,
        h=4,
        dropout=0.1,
        max_length=500,
        use_label_smoothing=False,
        embed_position=False)
chainer.serializers.load_npz('../result/Transformer/model_iter_706000.npz', model)

In [21]:
X_train = [encode(x_train[i]) for i in range(len(x_train))]
X_test = [encode(x_test[i]) for i in range(len(x_test))]

[[ 2 10  8 13  6 11  6  3  4 15  3  5  3  6  3  4 10  8 13  6 11  3  3 10
   8 13  6 11  6  5 10  8 13  6 11  3  6  0]]
(1, 256, 38)
[[ 2  3  3  3  4  3  5  3 18  0]]
(1, 256, 10)
[[2 7 9 3 4 8 3 4 9 7 5 3 6 3 3 3 3 3 6 5 3 6 3 3 3 3 3 6 0]]
(1, 256, 29)
[[2 3 3 4 3 4 3 5 4 3 5 3 5 7 0]]
(1, 256, 15)
[[ 2 10  7 14 11 10  8 19 11  4  3  6  3  4  7  5  3  3  3  3  6  5  9  7
   0]]
(1, 256, 25)
[[ 2  3  6  3  4  7  5  3  3 10  8 13  6 11  3  6  0]]
(1, 256, 17)
[[2 3 3 6 3 4 3 5 3 4 3 5 3 3 4 3 5 3 6 3 0]]
(1, 256, 21)
[[2 3 6 3 4 3 4 9 7 5 7 3 3 3 5 3 3 4 3 4 3 6 7 5 7 5 7 0]]
(1, 256, 28)
[[ 2 10  8 13  6 11  6  3  4  3  5  3  3  4 10  8 13  6 11  3  6  3  5  8
  10 15 19  6 11  4 10  7 14 11  5  4  3  6  3  3  3  4  3  3  6  5  8  5
  10  7 14 11  0]]
(1, 256, 53)
[[ 2  3 18  3  6  3  4  3 18  5  3  3  4  3  4  3  6  5  7  5  3 18  0]]
(1, 256, 23)
[[2 3 3 3 7 3 3 3 0]]
(1, 256, 9)
[[ 2  7  9  3  4  8  5  3  3 18  0]]
(1, 256, 11)
[[ 2  8 22  3  3  3  3 18  0]]
(1, 256, 9)
[[2 3 3 7 3

(1, 256, 36)
[[2 7 9 3 4 3 3 4 3 5 8 5 7 0]]
(1, 256, 14)
[[ 2  3  6  3  4  7  5  3  6  3  3  3 10  8 13  6 11  3  6  3  3  6  0]]
(1, 256, 23)
[[2 3 6 3 4 7 5 3 3 3 4 3 6 5 3 9 7 0]]
(1, 256, 18)
[[2 3 3 3 3 3 4 3 5 4 3 5 7 0]]
(1, 256, 14)
[[2 3 6 3 4 3 3 3 3 6 5 3 6 3 3 3 3 3 6 0]]
(1, 256, 20)
[[2 3 3 6 3 4 3 5 3 3 4 3 3 6 5 7 0]]
(1, 256, 17)
[[ 2  3 18  3  6  3  4  3 22  8  5  3  4  3  3  3  6  5  3 18  0]]
(1, 256, 21)
[[2 3 3 4 3 4 8 4 3 3 5 3 3 5 9 7 5 7 3 6 3 3 3 3 6 3 3 3 3 3 6 6 0]]
(1, 256, 33)
[[ 2 10  8 13  6 11  6  3  4  8  3  4  3  5  3  5 10  8 13  6 11  3  4  8
   3  4  3  5  3  5 10  8 13  6 11  3  6 15  3  0]]
(1, 256, 40)
[[ 2  3 18  3  3  3  3 18  0]]
(1, 256, 9)
[[ 2  3  6  3  4  3 18  5  3  3  6  3  4  3  3  3  3  6  5  3  6  0]]
(1, 256, 22)
[[ 2 10  7 14 11 10 15 19  6 11  4 10  7 14 11  5  4  8  3  6 10  8 13  6
  11  3  4  3  5  3  4  3  5  7  6  5  3  6  3  3  3  4  3  3  6  5  8  0]]
(1, 256, 48)
[[ 2  8 22  3  3  6  3  4  3 22  8  5  3  3  3  3  6  0]]
(

(1, 256, 19)
[[2 3 3 6 3 4 7 5 3 3 3 3 6 0]]
(1, 256, 14)
[[ 2  7  3  6  3  4  3  4 10  8 13  6 11  6  3  3  3  6  3  4  8  5 10  8
  13  6 11  3 10  8 13  6 11  3  6  6  5  7  3  6  3  7  5  7  0]]
(1, 256, 45)
[[ 2 10  7 14 11 10 15 19  6 11  4 10  7 14 11  5  4  8  4  3  6  3  4  3
   5  3  4  3  5 10  8 13  6 11  7  6  5  3  4  3  5  9  7  5  3  6  3  3
   3  4  3  3  6  5  8  0]]
(1, 256, 56)
[[ 2  7  9  3  4  3  5  8  3  6  3  3  3  4  3  3  6  5 16  0]]
(1, 256, 20)
[[2 3 6 3 4 3 3 5 3 3 6 3 4 3 3 6 3 4 3 6 5 3 3 3 3 6 5 3 6 0]]
(1, 256, 30)
[[ 2  3  6  3  4  8  3  4  9  7  5  3  3  5  3  3  4  3  4  3  6  5  3 18
   5  3 18  0]]
(1, 256, 28)
[[2 3 6 3 6 3 3 4 7 5 3 3 3 6 3 3 3 6 0]]
(1, 256, 19)
[[ 2 10  7 14 11 10 15 19  6 11  4 10  7 14 11  5  4  8  3  6  3  3  4  7
  10  8 13  6 11  6  5  3  5  3  6  3  3  3  4  3  3  6  5  8  0]]
(1, 256, 45)
[[ 2  3  6  3  6  3  4  3 10  8 13  6 11  3  6  3  6  3  3  3  3  6  5  3
   3  3  6  0]]
(1, 256, 28)
[[2 3 3 4 3 3 4 9 7 5 3 3 4 3 

(1, 256, 13)
[[ 2  3 18  3  6  3  4  3 18  5  3  4  3  6  3  3  4  3  3  3  6  3 18  5
   3 18  5  3  4  3  4  3  6  5  3 18  5  3 18  0]]
(1, 256, 40)
[[ 2  3  3  6  3  4  3  3  3  3  5  3  4  7  5 10  8 13  6 11  3  4 10  8
  13  6 11  6  5  8  4  3  5  3  0]]
(1, 256, 35)
[[ 2  7  3  6  3  4  3 18  5  3  3  4  3  3  6  5  3 18  0]]
(1, 256, 19)
[[2 7 9 3 4 8 3 5 8 0]]
(1, 256, 10)
[[2 7 9 3 6 8 3 4 9 7 5 8 3 6 0]]
(1, 256, 15)
[[2 3 3 6 3 4 7 5 3 3 3 4 3 5 3 6 0]]
(1, 256, 17)
[[2 3 3 7 3 4 9 7 5 3 9 3 0]]
(1, 256, 13)
[[ 2  3  6  3  4  3  5  3  3  4  3  5 10  8 13  6 11  3  6  0]]
(1, 256, 20)
[[ 2 15  6  3  4  8  6  3  4  9  7  5  8  3  3  6  5 10  8 13  6 11  3  3
   6 10  8 19 11  4 10  7 14 11  5  9  7  0]]
(1, 256, 38)
[[2 7 9 3 4 3 5 8 3 6 3 3 3 4 3 3 6 5 7 3 0]]
(1, 256, 21)
[[2 3 6 3 9 3 3 9 3 3 9 3 6 0]]
(1, 256, 14)
[[ 2  3  6  3  4  3 18  5  3  4  7  3  6  3  3  3  4  3  3  6  5 10  8 19
  11  4 10  7 14 11  5  9  7  5  3  3  3  6  3 18  0]]
(1, 256, 41)
[[ 2  3  3  3 10

(1, 256, 44)
[[ 2  7  9  3  6  3  9  3  3  6  4  3  5  3  6  3  3  3  6  4  3  5  3  4
   3  3  3  6  4  3  4  3  5  9  7  5  7  3  4  3  5  9  7  5  3  6  3  9
   3  4  3  6  9  3  6  5  3 18  0]]
(1, 256, 59)
[[2 3 3 7 3 4 9 7 5 3 3 3 4 9 7 5 7 3 3 0]]
(1, 256, 20)
[[ 2  7  9  3  6  3  9  3  4  8  3  4  9 15  5  8  6  5  3  0]]
(1, 256, 20)
[[2 3 9 3 4 3 5 3 4 7 5 9 7 0]]
(1, 256, 14)
[[2 3 3 3 4 3 5 3 3 0]]
(1, 256, 10)
[[2 3 6 3 4 3 3 5 3 3 3 3 6 0]]
(1, 256, 14)
[[ 2 10  7 14 11 10 15 19  6 11  4 10  7 14 11  5  4  8  3  4  8  5  9  8
   5  3  6  3  3  3  4  3  3  6  5  8  0]]
(1, 256, 37)
[[ 2 10  8 13  6 11  6  3  4  8  3  4  3  5  3  5 10  8 13  6 11  3  4  8
   3  3  5 10  8 13  6 11  3  6 15  3  0]]
(1, 256, 37)
[[2 3 3 4 3 3 4 3 5 4 3 5 7 5 3 0]]
(1, 256, 16)
[[2 7 3 4 3 7 5 3 4 3 7 5 7 0]]
(1, 256, 14)
[[2 7 3 6 3 4 8 5 3 3 3 3 6 0]]
(1, 256, 14)
[[ 2 10  8 13  6 11  6  3  4  8  3  4  3  5  4  3  5  3  5 10  8 13  6 11
   3  4 10  8 13  6 11  3  6  8  3  3  5 15  3  0]]
(1,

(1, 256, 35)
[[2 3 6 3 4 3 9 7 5 3 3 6 3 4 7 3 7 6 5 3 6 0]]
(1, 256, 22)
[[ 2  3 18  3  6  3  4  3 18  5  3  3  3  4  3  6  3  3  3  4  3  4  3 18
   5  3  6  3 18  5  3 18  5  3  6  3 18  0]]
(1, 256, 38)
[[ 2  3 18  3  6  3  4  3  6  3  4  3 18  5  3  3  3  3  6  5  3  3  3  3
   6  0]]
(1, 256, 26)
[[ 2  3  6  3  4  3  6  3  4  3 18  5  3  4  3 18  5  3  3  3  6  5  3  4
   3 18  5  3  4  3  3  6  5  3 18  0]]
(1, 256, 36)
[[2 7 9 3 3 6 3 3 3 7 6 0]]
(1, 256, 12)
[[ 2 10  7 14 11 10 15 19  6 11  4 10  7 14 11  5  4  8  3  4  8  3  3  3
   3  5  9  7  5  3  6  3  3  3  4  3  3  6  5  8  0]]
(1, 256, 41)
[[ 2  3  6  3  4 15  3 15 10 25 19 11  4 10 15 14 11  5  4  7  3  3  5  7
   3  3  5  3  3  3  4  3  6  5  3 18  0]]
(1, 256, 37)
[[ 2  3 18  3  6  3  4  3 18  5  3  3  4  3  3  6  5  8  0]]
(1, 256, 19)
[[2 3 6 3 4 7 5 3 3 3 3 6 3 0]]
(1, 256, 14)
[[ 2  3  6  3  4  7  3  5 10  8 13  6 11  3  6  3  4  3 10  8 13  6 11  3
  10  8 13  6 11  6  5 10  8 13  6 11  6  0]]
(1, 256, 38)
[[ 2

(1, 256, 16)
[[ 2  3  6  9  3  6  3  3  3  6  3  4  3  3  3  6  4  3  5  3  6  3  3  3
   6  4  3 22  3  5  7  5  3  6  3  3  3  6  9  7  0]]
(1, 256, 41)
[[2 3 3 4 7 3 3 5 3 0]]
(1, 256, 10)
[[ 2  7  9  3  6  3  9  3  3  6  4  3  5  3  4  3  4  3  5  3  3  6  3  6
   3  3  3  4  3  6  4  3  5  3  3  4  3  6  6 16  5  7  5  4  3  4  3  5
   9  7  5  7  5  9  3  6  0]]
(1, 256, 57)
[[2 3 6 3 4 3 3 3 3 6 5 3 6 3 3 3 4 3 3 6 5 3 6 3 3 3 3 3 6 0]]
(1, 256, 30)
[[ 2  3  3  3  3 18  0]]
(1, 256, 7)
[[ 2  3  6  3  4  3  6  3  4  3 18  5  3  3  3  3  6  5  3  3  3  4  3  6
   3 18  5  3 18  0]]
(1, 256, 30)
[[ 2  3  6  3  4  7  5  3  3  3  4  3  6  5 10  8 19 11  4 10  7 14 11  5
   9  7  0]]
(1, 256, 27)
[[2 3 6 3 6 3 3 3 4 3 3 6 3 3 3 6 7 3 5 3 4 3 5 3 4 7 5 9 7 0]]
(1, 256, 30)
[[2 7 9 3 4 3 3 4 9 7 5 7 3 3 5 7 3 3 0]]
(1, 256, 19)
[[ 2  3  6  3  4  7 10 25 19 11  4 10 15 14 11  5  4  7  3  3  5  7  3  3
   5  3  3  3  4  3  6  5 10 15 19 11  4  3  5 10  7 14 11  0]]
(1, 256, 44)
[[ 2  3 18

(1, 256, 32)
[[ 2  3 18  3  6  3  3  4  3  3  3  6  7  3  4  3  5  3  4  7  5  9  7  5
   3 18  0]]
(1, 256, 27)
[[2 3 6 3 4 3 5 3 6 3 3 3 3 3 6 3 4 3 6 5 3 0]]
(1, 256, 22)
[[2 3 8 6 3 6 3 3 4 3 3 6 3 6 3 6 7 6 5 7 3 4 3 4 3 6 3 3 3 3 3 6 5 3 7 5
  9 7 0]]
(1, 256, 39)
[[2 7 9 3 4 7 5 3 3 3 0]]
(1, 256, 11)
[[ 2 10  8 13  6 11  6  3  6  3  3  3  3  3  6 10  8 13 11 10  8 13  6 11
   6  0]]
(1, 256, 26)
[[ 2  3  6  3  4 15  3  5 10  8 13  6 11  3  6  3  4  3 10  8 13  6 11  3
  10  8 13  6 11  6  5 10  8 13  6 11  6  0]]
(1, 256, 38)
[[2 3 6 3 4 8 5 3 3 6 3 4 3 3 6 3 4 3 6 5 3 3 3 3 6 5 3 6 0]]
(1, 256, 29)
[[ 2 23 24  3  4 23 24  5  4 23 24  5 23 24  0]]
(1, 256, 15)
[[ 2  7  9  3  4  3  4  3  5  7  3  6  3  3  3  3  3  6  3 18  5  7  0]]
(1, 256, 23)
[[2 3 6 3 4 7 5 3 3 3 4 3 6 5 3 4 7 5 9 7 0]]
(1, 256, 21)
[[2 3 3 3 4 3 5 3 4 3 5 4 3 5 7 0]]
(1, 256, 16)
[[ 2  3  6  3  6 10  8 13 11  3  3  4  3  5  3  6  3  3  3  6  0]]
(1, 256, 21)
[[2 3 6 3 4 3 5 3 6 3 3 3 3 3 6 3 3 6 0]]
(1, 256

(1, 256, 8)
[[ 2 10  7 14 11 10 15 19  6 11  6  4 10  7 14 11  5  8  3  4  3  6  3  6
   3  3  3  3  6  5  9  7  0]]
(1, 256, 33)
[[2 3 6 3 4 3 3 9 3 5 3 3 3 4 3 6 5 7 3 0]]
(1, 256, 20)
[[2 3 6 9 3 3 9 3 6 3 6 3 3 4 3 8 3 6 5 3 8 6 3 6 9 7 0]]
(1, 256, 27)
[[2 3 6 3 6 3 3 3 6 3 4 3 3 3 6 4 3 5 3 6 3 3 3 6 9 7 5 3 6 4 3 5 3 3 3 6
  7 0]]
(1, 256, 38)
[[2 3 6 3 3 6 3 4 3 3 4 3 4 3 5 9 3 5 7 6 5 3 6 3 6 3 4 3 6 3 6 3 3 4 3 4
  3 3 6 7 3 3 6 7 6 5 7 3 5 7 3 5 9 7 0]]
(1, 256, 55)
[[2 7 3 6 3 4 3 7 3 4 3 6 7 5 7 5 7 0]]
(1, 256, 18)
[[ 2 10  8 13 11  6  3  3  3  3  6  0]]
(1, 256, 12)
[[ 2  3  6  3  6  3  3  4  7  5  3  3  3  6 10  8 13  6 11  3  3  6  0]]
(1, 256, 23)
[[2 3 3 3 3 4 3 5 3 6 4 3 3 5 3 4 8 3 4 8 3 6 9 7 5 9 7 5 9 7 0]]
(1, 256, 31)
[[ 2  3  6  3  4 15  5  3  3  3  3  6  0]]
(1, 256, 13)
[[2 3 3 4 9 7 5 7 3 3 3 3 3 0]]
(1, 256, 14)
[[2 3 6 3 4 7 3 3 5 3 3 3 4 3 6 5 8 3 4 3 5 9 7 0]]
(1, 256, 24)
[[ 2  3 18  3  3  9  3  0]]
(1, 256, 8)
[[2 7 3 6 3 9 3 3 6 3 6 6 3 3 8 4 3 5 3 6

(1, 256, 25)
[[ 2  3  6  3  4  3  5  3  3  3  4  3  6  5  3 18  0]]
(1, 256, 17)
[[ 2  7  9  3  4  3  5  8  3  6  3  3  3  4  3  3  6  5 26  0]]
(1, 256, 20)
[[ 2  3  3  3  3  3  3  3  3 22  3  0]]
(1, 256, 12)
[[2 8 3 6 3 4 8 5 3 3 3 3 6 0]]
(1, 256, 14)
[[2 3 3 8 4 3 3 7 3 4 3 6 3 3 3 4 3 3 6 5 7 3 3 3 3 5 9 7 5 3 3 0]]
(1, 256, 32)
[[ 2  7  9  3  4  8  5  3  6  3 10  8 13  6 11  3  3 10  8 13  6 11  6  0]]
(1, 256, 24)
[[2 7 9 3 4 8 5 7 3 3 3 3 0]]
(1, 256, 13)
[[ 2  3  6  9  3  6  3  3  3  6  3  6  3  3  4  3  5  3  4  3  6  4  3  5
   3  3  4  3  6  4  3  6  4  3  5  3  9  3  3  6  9  7  5 16  5  7  5  4
   3  4  3  7  5  9  7  5  7  3  4  3  5  9  7  0]]
(1, 256, 64)
[[ 2  3 18  3  6  3  4  3 18  5  3  4  3  4  3  4  3  6  3 18  5  3 18  5
   3 18  5  3 18  0]]
(1, 256, 30)
[[2 3 6 3 6 3 3 3 4 3 5 4 3 6 4 3 5 3 5 3 6 7 0]]
(1, 256, 23)
[[2 7 9 3 4 3 3 3 3 3 3 3 5 7 0]]
(1, 256, 15)
[[ 2  3  6  6  3  3  3  3  3  6 10  8 13 11  3  3  6  3  3  4  3  3  5  8
   0]]
(1, 256, 25)
[[ 2 

(1, 256, 18)
[[2 7 9 3 6 3 6 3 4 3 4 9 7 5 3 6 3 3 3 3 3 6 6 5 3 3 3 3 6 0]]
(1, 256, 30)
[[2 3 3 3 3 8 4 3 3 3 3 5 3 3 3 3 0]]
(1, 256, 17)
[[2 7 9 3 4 8 5 3 6 3 3 3 3 3 6 0]]
(1, 256, 16)
[[ 2 10  8 13  6 11  6  3 10  8 13  6 11  3 10  8 13  6 11  6  3  3  4  3
   6  3  3  3  4  3  3  6  5 16  5  4  3  6  3  3  3  3  3  6 16  5  7  0]]
(1, 256, 48)
[[2 7 9 3 6 3 6 3 4 3 4 9 7 5 3 6 3 6 3 3 3 3 6 5 3 4 3 4 3 3 6 5 7 5 7 0]]
(1, 256, 36)
[[2 3 3 4 3 5 4 3 5 3 6 3 3 3 3 3 6 0]]
(1, 256, 18)
[[2 7 9 3 3 9 3 4 3 3 3 9 3 4 3 5 3 5 3 0]]
(1, 256, 20)
[[ 2  3 18  3  6  3  4  3  6  3  4  3 18  5  3  4  3 18  5  3  4  3  4  3
   6  3 18  5  3 18  5  3 18  5  3  4  3 18  5  3  4  3  3  6  3 18  5  3
  18  0]]
(1, 256, 50)
[[2 7 9 3 4 3 5 7 0]]
(1, 256, 9)
[[2 3 3 3 3 3 3 3 4 3 5 7 0]]
(1, 256, 13)
[[2 7 9 3 6 3 4 3 3 3 3 5 3 4 8 4 3 6 3 3 3 4 3 3 6 5 7 5 8 6 3 6 3 3 3 3
  3 6 5 9 7 0]]
(1, 256, 42)
[[2 7 3 6 3 3 3 6 7 3 6 3 4 3 4 9 7 5 3 6 3 6 5 3 4 3 3 4 3 6 5 7 3 5 7 0]]
(1, 256, 36)
[[2 3 6 

(1, 256, 41)
[[ 2  3  6  3  4  3  6  3  4  3 18  5  3  4  3 18  5  3  4  3  4  3  6  3
  18  5  3 18  5  3 18  5  3  3  3  3  6  0]]
(1, 256, 38)
[[ 2  7  9  3  4  3  6  3  3  3  4  3 18  5  3  3  6  5 10  8 13  6 11  6
   3  4  3  5  3  4  3  3  4  7  5  9  7  5  3  6  3  3  4  3  3  3  6  6
   5  7  3  0]]
(1, 256, 52)
[[2 7 9 3 6 3 4 3 4 9 7 5 8 4 3 6 3 3 3 3 3 6 5 8 6 3 6 3 3 3 3 3 6 5 3 3
  3 3 0]]
(1, 256, 39)
[[ 2 10  7 14 11 10  8 19 11  4  9  7  5  3  6  3  3  3  3  6  3  6  3  3
   3  3  6  0]]
(1, 256, 28)
[[2 3 6 3 6 3 3 4 3 6 3 4 3 3 4 3 6 4 3 5 3 4 3 3 3 6 6 5 3 4 3 5 3 3 3 4
  7 5 9 7 5 7 5 3 6 4 3 5 3 3 3 6 7 5 7 0]]
(1, 256, 56)
[[ 2 23 24  3  3  3 23 24  0]]
(1, 256, 9)
[[2 3 6 3 4 8 3 6 3 3 3 3 3 6 5 3 3 3 3 6 0]]
(1, 256, 21)
[[ 2 15  6  3  4  3  3  5  3  3  3  6  0]]
(1, 256, 13)
[[2 7 9 3 4 3 3 3 3 3 5 7 3 0]]
(1, 256, 14)
[[2 3 3 4 7 3 3 3 5 3 0]]
(1, 256, 11)
[[2 7 9 3 4 3 5 8 3 6 3 3 3 4 3 3 6 5 7 3 4 3 5 9 7 0]]
(1, 256, 26)
[[ 2 23 24  3  3  4  3  5 23 24  0]

(1, 256, 29)
[[ 2 26  3  3  3  3  3  3  3  0]]
(1, 256, 10)
[[2 3 3 3 3 3 3 3 6 3 4 7 5 3 3 4 3 3 6 5 7 0]]
(1, 256, 22)
[[2 3 3 3 3 3 3 3 3 3 3 3 4 9 7 5 7 0]]
(1, 256, 18)
[[ 2  3  6  3  4  3 18  5  3  3  4  3  3  6  5  3  6  3  3  4  3  3  3  6
   3 18  5  3 18  0]]
(1, 256, 30)
[[2 7 6 3 4 3 5 3 4 3 3 6 5 3 4 8 3 6 3 3 3 3 3 6 5 9 7 0]]
(1, 256, 28)
[[ 2  3  6  3  4 23 24  5  3  3  3  3  6  0]]
(1, 256, 14)
[[2 7 9 3 4 3 5 8 3 6 3 3 3 3 3 6 0]]
(1, 256, 17)
[[ 2  7  3  6  3  4  8  3  6  3  3  3  4  3  3  6  3  4  3  6  3  3  3  3
   3  6  3 18  5  9  8  6  5  3 18  5  9  7  0]]
(1, 256, 39)
[[ 2  3  3  6  3  4  3 18  5  3  4  3  5  3  3  4  3  6  5  7  0]]
(1, 256, 21)
[[ 2  3  8  4  3  5  3  3  8  4  3  3  6  3  3  3  3  3  6  5  3  6  3  3
   3  3 10  8 13  6 11  6  0]]
(1, 256, 33)
[[ 2  3 18  3  6  3  4  3 18  5  3  3  4  3  4 10  8 13  6 11  6  5  7 10
  25 19 11  4  7  3  5  4  7  3  5 10 15 14 11  5  3 18  0]]
(1, 256, 43)
[[2 3 6 3 4 3 5 3 3 6 3 4 3 3 6 3 4 3 6 5 3 3 3 3 6 

# Prediction

In [22]:
MLP = MLPRegressor((1000, 1000, 1000))
MLP.fit(X_train, y_train)

y_train_pred = MLP.predict(X_train)
print("Train R2: {:.4f}".format(r2_score(y_train, y_train_pred)))
print("Train MSE: {:.4f}".format(mean_squared_error(y_train, y_train_pred)))

y_pred = MLP.predict(X_test)
print("Test R2: {:.4f}".format(r2_score(y_test, y_pred)))
print("Test MSE: {:.4f}".format(mean_squared_error(y_test, y_pred)))

Train R2: 0.9988
Train MSE: 0.0051
Test R2: 0.8619
Test MSE: 0.5636


In [34]:
MLP = MLPRegressor((1000, 1000, 1000))
MLP.fit(X_train, y_train)

y_train_pred = MLP.predict(X_train)
print("Train R2: {:.4f}".format(r2_score(y_train, y_train_pred)))
print("Train MSE: {:.4f}".format(mean_squared_error(y_train, y_train_pred)))

y_pred = MLP.predict(X_test)
print("Test R2: {:.4f}".format(r2_score(y_test, y_pred)))
print("Test MSE: {:.4f}".format(mean_squared_error(y_test, y_pred)))

Train R2: 0.9781
Train MSE: 0.0907
Test R2: 0.6986
Test MSE: 1.2306


In [29]:
0.5636**0.5

0.7507329751649384

In [24]:
param_grid = { "hidden_layer_sizes": [(100, 100), (500, 500), (1000, 1000), (100,100,100), (500,500,500), (1000,1000,1000), (100,100,100,100), (500, 500,500,500), (1000,1000,1000,1000)]}
MLP_grid = GridSearchCV(estimator=MLPRegressor(), param_grid=param_grid, cv=5, verbose=1, n_jobs=8)
MLP_grid.fit(X_train, y_train)

Fitting 5 folds for each of 9 candidates, totalling 45 fits


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  45 out of  45 | elapsed:  3.9min finished


GridSearchCV(cv=5, error_score='raise-deprecating',
       estimator=MLPRegressor(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,
       beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(100,), learning_rate='constant',
       learning_rate_init=0.001, max_iter=200, momentum=0.9,
       n_iter_no_change=10, nesterovs_momentum=True, power_t=0.5,
       random_state=None, shuffle=True, solver='adam', tol=0.0001,
       validation_fraction=0.1, verbose=False, warm_start=False),
       fit_params=None, iid='warn', n_jobs=8,
       param_grid={'hidden_layer_sizes': [(100, 100), (500, 500), (1000, 1000), (100, 100, 100), (500, 500, 500), (1000, 1000, 1000), (100, 100, 100, 100), (500, 500, 500, 500), (1000, 1000, 1000, 1000)]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=1)

In [25]:
MLP_grid.best_params_

{'hidden_layer_sizes': (1000, 1000, 1000, 1000)}

In [26]:
MLP_grid.best_estimator_.fit(X_train, y_train)
y_train_pred = MLP_grid.best_estimator_.predict(X_train)
y_test_pred = MLP_grid.best_estimator_.predict(X_test)

print('MSE train : %.3f, test : %.3f' % (mean_squared_error(y_train, y_train_pred), mean_squared_error(y_test, y_test_pred)) )
print('R2 train : %.3f, test : %.3f' % (r2_score(y_train, y_train_pred), r2_score(y_test, y_test_pred)) )

MSE train : 0.003, test : 0.570
R2 train : 0.999, test : 0.860


In [38]:
1.122**0.5

1.059245014149229

In [27]:
RF = RandomForestRegressor()
RF.fit(X_train, y_train)



RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=None,
           oob_score=False, random_state=None, verbose=0, warm_start=False)

In [28]:
y_train_pred = RF.predict(X_train)
y_test_pred = RF.predict(X_test)

print('MSE train : %.3f, test : %.3f' % (mean_squared_error(y_train, y_train_pred), mean_squared_error(y_test, y_test_pred)) )
print('R2 train : %.3f, test : %.3f' % (r2_score(y_train, y_train_pred), r2_score(y_test, y_test_pred)) )

MSE train : 0.181, test : 0.938
R2 train : 0.956, test : 0.770


In [14]:
param_grid = { "max_depth": [2,5,10, None],
                "n_estimators": [10,50,100,300],
                "max_features": [1, 3, 10],
                "min_samples_split": [2, 3, 10],
                "min_samples_leaf": [1, 3, 10]
}

In [15]:
RF_grid = GridSearchCV(estimator=RandomForestRegressor(), param_grid=param_grid, cv=5, n_jobs=8, verbose=1)
RF_grid.fit(X_train, y_train)

Fitting 5 folds for each of 432 candidates, totalling 2160 fits


[Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  34 tasks      | elapsed:    2.4s
[Parallel(n_jobs=8)]: Done 184 tasks      | elapsed:    6.9s
[Parallel(n_jobs=8)]: Done 434 tasks      | elapsed:   14.6s
[Parallel(n_jobs=8)]: Done 784 tasks      | elapsed:   26.4s
[Parallel(n_jobs=8)]: Done 1234 tasks      | elapsed:   45.4s
[Parallel(n_jobs=8)]: Done 1784 tasks      | elapsed:  1.2min
[Parallel(n_jobs=8)]: Done 2160 out of 2160 | elapsed:  1.5min finished


GridSearchCV(cv=5, error_score='raise-deprecating',
       estimator=RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
           max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators='warn', n_jobs=None,
           oob_score=False, random_state=None, verbose=0, warm_start=False),
       fit_params=None, iid='warn', n_jobs=8,
       param_grid={'max_depth': [2, 5, 10, None], 'n_estimators': [10, 50, 100, 300], 'max_features': [1, 3, 10], 'min_samples_split': [2, 3, 10], 'min_samples_leaf': [1, 3, 10]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=1)

In [16]:
RF_grid.best_params_

{'max_depth': None,
 'max_features': 10,
 'min_samples_leaf': 1,
 'min_samples_split': 2,
 'n_estimators': 300}

In [17]:
RF_grid.best_estimator_.fit(X_train, y_train)
y_train_pred = RF_grid.best_estimator_.predict(X_train)
y_test_pred = RF_grid.best_estimator_.predict(X_test)

print('MSE train : %.3f, test : %.3f' % (mean_squared_error(y_train, y_train_pred), mean_squared_error(y_test, y_test_pred)) )
print('R2 train : %.3f, test : %.3f' % (r2_score(y_train, y_train_pred), r2_score(y_test, y_test_pred)) )

MSE train : 0.204, test : 1.339
R2 train : 0.951, test : 0.672
