In [1]:
# coding: utf-8
import os
import sys
# Append the parent of the current working directory (os.pardir) to sys.path
# so the project-local `dataset` and `common` packages below can be imported.
sys.path.append(os.pardir)

import matplotlib.pyplot as plt
import numpy as np

from dataset.mnist import load_mnist
from common.util import smooth_curve
from common.multi_layer_net import MultiLayerNet
from common.optimizer import *  # supplies SGD, Momentum, AdaGrad, Adam, RMSprop

# Load MNIST (normalize=True presumably rescales pixel values -- see
# dataset.mnist.load_mnist) and fix the mini-batch training hyperparameters.
(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True)

train_size = x_train.shape[0]  # number of training samples batches are drawn from
batch_size = 128
max_iterations = 2000

# Optimizers under comparison. Dict insertion order is the order in which they
# are trained, reported and plotted. Uncomment RMSprop to add it to the run.
optimizers = {
    'SGD': SGD(),
    'Momentum': Momentum(),
    'AdaGrad': AdaGrad(),
    'Adam': Adam(),
    # 'RMSprop': RMSprop(),
}

# Each optimizer trains its own network with the same architecture
# (784 inputs -> four hidden layers of 100 units -> 10 outputs) and gets its
# own list of per-iteration training losses.
networks = {
    key: MultiLayerNet(input_size=784, hidden_size_list=[100] * 4, output_size=10)
    for key in optimizers
}
train_loss = {key: [] for key in optimizers}


# Train every network on the same randomly sampled mini-batch each iteration,
# so the optimizers are compared on identical data, and record each network's
# loss on that batch right after its parameter update.
for i in range(max_iterations):
    # np.random.choice samples indices with replacement by default.
    batch_mask = np.random.choice(train_size, batch_size)
    x_batch = x_train[batch_mask]
    t_batch = t_train[batch_mask]

    for key, optimizer in optimizers.items():
        grads = networks[key].gradient(x_batch, t_batch)
        optimizer.update(networks[key].params, grads)

        loss = networks[key].loss(x_batch, t_batch)
        train_loss[key].append(loss)

    # Progress report every 100 iterations. The loss just appended above was
    # computed with the same parameters on the same batch, so reuse it instead
    # of running a second forward pass per network.
    if i % 100 == 0:
        print(f"===========iteration:{i}===========")
        for key in optimizers:
            print(f"{key}:{train_loss[key][-1]}")


# Plot the smoothed training-loss curve of every optimizer, one line each.
# Optimizers without an entry in `markers` are drawn without markers
# (marker=None) instead of raising KeyError, so enabling another optimizer in
# the `optimizers` dict does not break the plot.
markers = {"SGD": "o", "Momentum": "x", "AdaGrad": "s", "Adam": "D", "RMSprop": "^"}
x = np.arange(max_iterations)
for key in optimizers:
    plt.plot(x, smooth_curve(train_loss[key]), marker=markers.get(key),
             markevery=100, label=key)
plt.xlabel("iterations")
plt.ylabel("loss")
# Zoom into the low-loss region; the initial losses (above 2) fall off-axis.
plt.ylim(0, 1)
plt.legend()
plt.show()


SGD:2.4793477853087866
Momentum:2.378656957870902
AdaGrad:2.7306228981849765
Adam:2.195724823003241
SGD:1.6657326187916768
Momentum:0.3220152106327263
AdaGrad:0.1641503628335228
Adam:0.2437048340252706
SGD:0.8008886425154433
Momentum:0.20891521542718755
AdaGrad:0.07884540703857049
Adam:0.13171135574491963
SGD:0.5253857854297265
Momentum:0.17499406477607915
AdaGrad:0.08162995104459606
Adam:0.10413157056938158
SGD:0.42937224137471985
Momentum:0.21044823032471094
AdaGrad:0.1019022405859088
Adam:0.2077314003582647
SGD:0.4338221354520376
Momentum:0.17224632393513073
AdaGrad:0.07687377219323319
Adam:0.12511835409858327
SGD:0.42977725746077866
Momentum:0.1728351503393628
AdaGrad:0.07750002649274504
Adam:0.12548465833647404
SGD:0.3091942867161237
Momentum:0.14080953648120392
AdaGrad:0.05166662977425281
Adam:0.07594330312009537
SGD:0.21092047035801403
Momentum:0.08947693099639042
AdaGrad:0.03673286441568563
Adam:0.07266056941460382
SGD:0.399041288372961
Momentum:0.19234826847788006
AdaGrad:0.06

<Figure size 640x480 with 1 Axes>