* Neural Network by Chainer

In [2]:
import time
import numpy as np
import pandas as pd
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions
from sklearn import datasets

In [3]:
# 教師データ

N = 100
in_size = 4
out_size = 3
iris = datasets.load_iris()
data = pd.DataFrame(data= np.c_[iris["data"], iris["target"]], columns= iris["feature_names"] + ["target"])
data = np.array(data.values)
dataset = []
for d in data:
    x = d[0:4]
    y = d[4]
    dataset.append((np.array(x, dtype="float32"), np.array(y, dtype="int32")))
N = len(dataset)

In [4]:
class NN(chainer.Chain):
    
    def __init__(self, in_size, hidden_size, out_size):
        
        super(NN, self).__init__(
            xh = L.Linear(in_size, hidden_size),
            hh = L.Linear(hidden_size, hidden_size),
            hy = L.Linear(hidden_size, out_size)
        )
 
    def __call__(self, x):
        
        h = F.relu(self.xh(x))
        h = F.relu(self.hh(h))
        y = self.hy(h)
        
        return y

In [6]:
HIDDEN_SIZE = 20

model = L.Classifier(NN(in_size=in_size, hidden_size=HIDDEN_SIZE, out_size=out_size))
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

<chainer.optimizers.adam.Adam at 0x7f900ea54828>

In [8]:
# train

EPOCH_NUM = 100
BATCH_SIZE = 20

train, test = chainer.datasets.split_dataset_random(dataset, N-50) # 100件を学習用、50件をテスト用
train_iter = chainer.iterators.SerialIterator(train, BATCH_SIZE)
test_iter = chainer.iterators.SerialIterator(test, BATCH_SIZE, repeat=False, shuffle=False)
updater = chainer.training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = chainer.training.Trainer(updater, (EPOCH_NUM, "epoch"), out="result")
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport(trigger=(10, "epoch"))) # 10エポックごとにログ出力
trainer.extend(extensions.PrintReport( ["epoch", "main/loss", "validation/main/loss", "main/accuracy", "validation/main/accuracy", "elapsed_time"])) # エポック、学習損失、テスト損失、学習正解率、テスト正解率、経過時間
#trainer.extend(extensions.ProgressBar()) # プログレスバー出力
trainer.run()

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J10          3.46867     2.65507               0.304          0.446667                  0.159232      
[J20          1.13221     1.0186                0.517          0.466667                  0.318512      
[J30          0.672347    0.719645              0.71           0.586667                  0.477058      
[J40          0.534068    0.594945              0.721          0.606667                  0.636286      
[J50          0.464032    0.526126              0.764          0.661667                  0.79505       
[J60          0.422643    0.48023               0.829          0.736667                  0.954213      
[J70          0.385826    0.439647              0.87           0.805                     1.1127        
[J80          0.344917    0.392585              0.909          0.918333                  1.272         
[J90          0.302845    0.345263              0.933      

In [9]:
# predict

print("x\ty\tpredict")
idx = np.random.choice(N, 10)
for i in idx:
    x = dataset[i][0]
    y_ = np.argmax(model.predictor(x=x.reshape(1,len(x))).data)
    y = dataset[i][1]
    print(x, "\t", y, "\t", y_)

x	y	predict
[5.5 2.3 4.  1.3] 	 1 	 1
[6.9 3.1 4.9 1.5] 	 1 	 1
[4.6 3.4 1.4 0.3] 	 0 	 0
[4.9 3.1 1.5 0.1] 	 0 	 0
[6.7 3.1 4.7 1.5] 	 1 	 1
[6.  3.4 4.5 1.6] 	 1 	 1
[6.2 2.8 4.8 1.8] 	 2 	 2
[5.2 3.5 1.5 0.2] 	 0 	 0
[6.2 3.4 5.4 2.3] 	 2 	 2
[6.7 2.5 5.8 1.8] 	 2 	 2


In [10]:
!python --version

Python 3.6.3


In [11]:
!pip freeze

absl-py==0.2.0
astor==0.6.2
backcall==0.1.0
bleach==1.5.0
boto==2.49.0
boto3==1.7.73
botocore==1.10.73
bz2file==0.98
certifi==2018.4.16
chainer==4.0.0
chainercv==0.10.0
chardet==3.0.4
cntk-gpu==2.5.1
cupy==4.0.0
cycler==0.10.0
Cython==0.28.2
decorator==4.3.0
dm-sonnet==1.20
docutils==0.14
edward==1.3.5
entrypoints==0.2.3
fastrlock==0.3
filelock==3.0.4
future==0.16.0
gast==0.2.0
gensim==3.5.0
graphviz==0.8.3
grpcio==1.11.0
h5py==2.7.1
html5lib==0.9999999
idna==2.6
ipykernel==4.8.2
ipython==6.3.1
ipython-genutils==0.2.0
ipywidgets==7.2.1
jedi==0.12.0
Jinja2==2.10
jmespath==0.9.3
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.3
jupyter-console==5.2.0
jupyter-core==4.4.0
Keras==2.1.6
kiwisolver==1.0.1
Lasagne==0.2.dev1
leveldb==0.194
Mako==1.0.7
Markdown==2.6.11
MarkupSafe==1.0
matplotlib==2.2.2
mecab-python3==0.7
mistune==0.8.3
mxnet-cu90==1.1.0.post0
nbconvert==5.3.1
nbformat==4.4.0
networkx==2.1
nose==1.3.7
notebook==5.4.1