In [3]:
# Fetch the CapsNet implementation used below (download_data.py, and the
# config / utils / capsNet modules imported in a later cell).
# NOTE: `git clone` refuses to run if ./CapsNet-Tensorflow already exists.
!git clone https://github.com/Matimath/CapsNet-Tensorflow

Cloning into 'CapsNet-Tensorflow'...
remote: Enumerating objects: 7, done.[K
remote: Counting objects:  14% (1/7)[Kremote: Counting objects:  28% (2/7)[Kremote: Counting objects:  42% (3/7)[Kremote: Counting objects:  57% (4/7)[Kremote: Counting objects:  71% (5/7)[Kremote: Counting objects:  85% (6/7)[Kremote: Counting objects: 100% (7/7)[Kremote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 570 (delta 1), reused 5 (delta 1), pack-reused 563[K
Receiving objects: 100% (570/570), 1.60 MiB | 3.61 MiB/s, done.
Resolving deltas: 100% (281/281), done.


In [0]:
# Remove artifacts left by a previous run: the downloaded dataset (data/),
# checkpoints (logdir/) and CSV results (results/).
# NOTE(review): logdir/ holds the checkpoints that evaluation() restores
# (see the "Restoring parameters from logdir/..." log lines below). Running
# this cell before evaluating, without retraining, deletes the trained model.
# `rm -r` also reports an error for any of these directories that is missing.
!rm -r data logdir results

In [0]:
# Rename the clone to ./capsnet, the directory that is put on sys.path next.
# Relies on IPython's built-in `mv` alias (no leading `!`). If ./capsnet
# already exists, `mv` moves the clone *inside* it instead of renaming it.
mv CapsNet-Tensorflow capsnet

In [0]:
import sys

# Make the cloned repository's modules (config, utils, capsNet) importable.
# The membership check keeps this cell idempotent: re-running it no longer
# inserts a duplicate entry into sys.path.
CAPSNET_DIR = "./capsnet"
if CAPSNET_DIR not in sys.path:
    sys.path.insert(1, CAPSNET_DIR)

In [6]:
# Download and extract Fashion-MNIST into data/fashion-mnist.
# NOTE(review): train()/evaluation() load `cfg.dataset`; confirm that flag is
# 'fashion-mnist' so this is the dataset actually used.
!python capsnet/download_data.py --dataset fashion-mnist --save_to data/fashion-mnist

>> Downloading train-images-idx3-ubyte.gz 100.0%
Successfully Downloaded train-images-idx3-ubyte.gz
Extracting  train-images-idx3-ubyte.gz
Successfully extracted

>> Downloading train-labels-idx1-ubyte.gz 111.0%
Successfully Downloaded train-labels-idx1-ubyte.gz
Extracting  train-labels-idx1-ubyte.gz
Successfully extracted

>> Downloading t10k-images-idx3-ubyte.gz 100.0%
Successfully Downloaded t10k-images-idx3-ubyte.gz
Extracting  t10k-images-idx3-ubyte.gz
Successfully extracted

>> Downloading t10k-labels-idx1-ubyte.gz 159.1%
Successfully Downloaded t10k-labels-idx1-ubyte.gz
Extracting  t10k-labels-idx1-ubyte.gz
Successfully extracted



In [7]:
import os
import sys
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from config import cfg
from utils import load_data
from capsNet import CapsNet

In [0]:
def save_to():
    """Create ``cfg.results`` and open fresh CSV log files for this run.

    Reads ``cfg.results`` (output directory) and ``cfg.is_training``.

    Returns:
        If ``cfg.is_training`` is truthy, a tuple
        ``(fd_train_acc, fd_loss, fd_val_acc)`` of files opened for writing at
        ``train_acc.csv``, ``loss.csv`` and ``val_acc.csv``. Otherwise, a single
        file opened for writing at ``test_acc.csv``. Each file already
        contains its CSV header row. The caller must close every handle.
    """
    # makedirs (unlike mkdir) also creates missing parent directories and is a
    # no-op when the directory already exists.
    os.makedirs(cfg.results, exist_ok=True)

    def _fresh_csv(name, header):
        # Delete any file left by a previous run, then start it with its header.
        path = os.path.join(cfg.results, name)
        if os.path.exists(path):
            os.remove(path)
        fd = open(path, 'w')
        fd.write(header)
        return fd

    if cfg.is_training:
        fd_train_acc = _fresh_csv('train_acc.csv', 'step,train_acc\n')
        fd_loss = _fresh_csv('loss.csv', 'step,loss\n')
        fd_val_acc = _fresh_csv('val_acc.csv', 'step,val_acc\n')
        return (fd_train_acc, fd_loss, fd_val_acc)
    return _fresh_csv('test_acc.csv', 'test_acc\n')


def train(model, supervisor, num_label):
    """Train ``model`` for ``cfg.epoch`` epochs in a supervisor-managed session.

    Every ``cfg.train_sum_freq`` global steps, the loss, the training accuracy
    (``model.accuracy / cfg.batch_size``) and a TensorBoard summary are
    recorded. On all other steps only ``model.train_op`` runs. A checkpoint is
    saved under ``cfg.logdir`` every ``cfg.save_freq`` epochs.

    Args:
        model: graph object exposing ``train_op``, ``total_loss``, ``accuracy``
            and ``train_summary``.
        supervisor: ``tf.train.Supervisor`` built on ``model.graph``.
        num_label: number of classes; not used by this function.

    Raises:
        AssertionError: if a recorded loss is NaN.

    No tensors are fed here. NOTE(review): this presumes the training graph
    reads its batches from its own input pipeline; confirm in capsNet.py.
    The validation split is not evaluated, so ``val_acc.csv`` only ever
    receives its header row.
    """
    # Only the batch count is needed; the data arrays themselves are unused here.
    _, _, num_tr_batch, _, _, _ = load_data(cfg.dataset, cfg.batch_size, is_training=True)
    fd_train_acc, fd_loss, fd_val_acc = save_to()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    try:
        with supervisor.managed_session(config=config) as sess:
            print("\nNote: all of results will be saved to directory: " + cfg.results)
            for epoch in range(cfg.epoch):
                print("Training for epoch %d/%d:" % (epoch, cfg.epoch))
                if supervisor.should_stop():
                    print('supervisor stoped!')
                    break
                for step in tqdm(range(num_tr_batch), total=num_tr_batch, ncols=70, leave=False, unit='b'):
                    global_step = epoch * num_tr_batch + step

                    if global_step % cfg.train_sum_freq == 0:
                        _, loss, train_acc, summary_str = sess.run([model.train_op, model.total_loss, model.accuracy, model.train_summary])
                        assert not np.isnan(loss), 'Something wrong! loss is nan...'
                        supervisor.summary_writer.add_summary(summary_str, global_step)

                        fd_loss.write(str(global_step) + ',' + str(loss) + "\n")
                        fd_loss.flush()
                        fd_train_acc.write(str(global_step) + ',' + str(train_acc / cfg.batch_size) + "\n")
                        fd_train_acc.flush()
                    else:
                        sess.run(model.train_op)

                if (epoch + 1) % cfg.save_freq == 0:
                    # Global step of this epoch's last batch. It is computed directly
                    # so it is defined even when num_tr_batch == 0.
                    last_step = (epoch + 1) * num_tr_batch - 1
                    supervisor.saver.save(sess, cfg.logdir + '/model_epoch_%04d_step_%02d' % (epoch, last_step))
    finally:
        # Close the CSV logs even if training aborts (e.g. the NaN-loss assert).
        fd_val_acc.close()
        fd_train_acc.close()
        fd_loss.close()


def evaluation(model, supervisor, num_label, rots=0):
    """Restore the latest checkpoint and report test accuracy on rotated images.

    Args:
        model: graph object exposing the ``X`` and ``labels`` inputs and an
            ``accuracy`` tensor. NOTE(review): the division below treats
            ``accuracy`` as a per-batch count of correct predictions; confirm
            in capsNet.py.
        supervisor: ``tf.train.Supervisor`` built on ``model.graph``.
        num_label: number of classes; not used by this function.
        rots: number of 90-degree rotations applied to every test image via
            ``np.rot90`` in the plane of axes 1 and 2. Defaults to 0 (no
            rotation).

    Returns:
        The test accuracy as a float.

    The accuracy is also written to ``test_acc.csv`` in ``cfg.results``.
    ``save_to()`` truncates that file on every call, so it holds only the
    most recent call's result. Only ``num_te_batch`` full batches are scored;
    any remaining test samples are not evaluated.
    """
    teX, teY, num_te_batch = load_data(cfg.dataset, cfg.batch_size, is_training=False)
    # Assumes teX is laid out as (N, H, W, C), so axes (1, 2) rotate each
    # image; TODO confirm against utils.load_data.
    teX = np.rot90(teX, rots, (1, 2))
    # `with` guarantees the CSV handle is closed (it previously leaked).
    with save_to() as fd_test_acc:
        with supervisor.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # The log shows the supervisor already restoring from cfg.logdir
            # when the session is created; this explicitly (re)loads the latest
            # checkpoint.
            supervisor.saver.restore(sess, tf.train.latest_checkpoint(cfg.logdir))
            tf.logging.info('Model restored!')

            test_acc = 0
            for i in tqdm(range(num_te_batch), total=num_te_batch, ncols=70, leave=False, unit='b'):
                start = i * cfg.batch_size
                end = start + cfg.batch_size
                acc = sess.run(model.accuracy, {model.X: teX[start:end], model.labels: teY[start:end]})
                test_acc += acc
            test_acc = test_acc / (cfg.batch_size * num_te_batch)
            fd_test_acc.write(str(test_acc) + '\n')
            print("Test accuracy for images rotated {} degrees is: {}".format(rots * 90, str(test_acc)))
    return test_acc



In [0]:
def main(_, rotations=(0, 1, 2, 3), do_train=False):
    """Build the CapsNet graph, optionally train it, then evaluate rotated test sets.

    Args:
        _: argv list passed in by ``tf.app.run``; unused.
        rotations: iterable of quarter-turn counts to evaluate. By default all
            four orientations are evaluated (0, 90, 180 and 270 degrees).
        do_train: when True, run ``train`` before evaluating. Evaluation then
            uses the checkpoints that training writes to ``cfg.logdir``.
    """
    tf.logging.info(' Loading Graph...')
    num_label = 10
    model = CapsNet()
    tf.logging.info(' Graph loaded')

    sv = tf.train.Supervisor(graph=model.graph, logdir=cfg.logdir, save_model_secs=0)

    if do_train:
        tf.logging.info(' Start training...')
        train(model, sv, num_label)
        tf.logging.info('Training done')

    for rots in rotations:
        evaluation(model, sv, num_label, rots)

# tf.app.run() presumably parses the command-line flags that `cfg` reads
# (confirm in config.py) and then calls main(). NOTE(review): it finishes
# with sys.exit(), which shows up as a SystemExit in a notebook.
tf.app.run()

INFO:tensorflow: Loading Graph...


I0204 20:37:52.349667 139797565028224 <ipython-input-10-996d58c29946>:2]  Loading Graph...


INFO:tensorflow:Seting up the main structure


I0204 20:37:53.548284 139797565028224 capsNet.py:54] Seting up the main structure


INFO:tensorflow: Graph loaded


I0204 20:37:53.550242 139797565028224 <ipython-input-10-996d58c29946>:5]  Graph loaded


INFO:tensorflow:Restoring parameters from logdir/model_epoch_0049_step_23399


I0204 20:37:54.242643 139797565028224 saver.py:1284] Restoring parameters from logdir/model_epoch_0049_step_23399


Instructions for updating:
Use standard file utilities to get mtimes.


W0204 20:37:54.440993 139797565028224 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/training/saver.py:1069: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.


INFO:tensorflow:Running local_init_op.


I0204 20:37:54.451571 139797565028224 session_manager.py:500] Running local_init_op.


INFO:tensorflow:Done running local_init_op.


I0204 20:37:54.474522 139797565028224 session_manager.py:502] Done running local_init_op.


INFO:tensorflow:Starting standard services.


I0204 20:38:29.256532 139797565028224 supervisor.py:737] Starting standard services.


INFO:tensorflow:Starting queue runners.


I0204 20:38:29.760130 139797565028224 supervisor.py:743] Starting queue runners.


INFO:tensorflow:Restoring parameters from logdir/model_epoch_0049_step_23399


I0204 20:38:29.895231 139797565028224 saver.py:1284] Restoring parameters from logdir/model_epoch_0049_step_23399


INFO:tensorflow:Recording summary at step 23400.


I0204 20:38:32.440379 139793434392320 supervisor.py:1050] Recording summary at step 23400.


INFO:tensorflow:Model restored!


I0204 20:38:33.982596 139797565028224 <ipython-input-8-b8f5bda789e8>:77] Model restored!


Test accuracy for images rotated 0 degrees is: 0.90625
INFO:tensorflow:Restoring parameters from logdir/model_epoch_0049_step_23399


I0204 20:38:37.694609 139797565028224 saver.py:1284] Restoring parameters from logdir/model_epoch_0049_step_23399


INFO:tensorflow:Running local_init_op.


I0204 20:38:37.886617 139797565028224 session_manager.py:500] Running local_init_op.


INFO:tensorflow:Done running local_init_op.


I0204 20:38:37.915925 139797565028224 session_manager.py:502] Done running local_init_op.


INFO:tensorflow:Starting standard services.


I0204 20:39:12.880405 139797565028224 supervisor.py:737] Starting standard services.


INFO:tensorflow:Starting queue runners.


I0204 20:39:13.412388 139797565028224 supervisor.py:743] Starting queue runners.


INFO:tensorflow:Restoring parameters from logdir/model_epoch_0049_step_23399


I0204 20:39:13.449257 139797565028224 saver.py:1284] Restoring parameters from logdir/model_epoch_0049_step_23399


INFO:tensorflow:Model restored!


I0204 20:39:15.171910 139797565028224 <ipython-input-8-b8f5bda789e8>:77] Model restored!


Test accuracy for images rotated 90 degrees is: 0.0751201923076923
INFO:tensorflow:Recording summary at step 23400.


I0204 20:39:19.764406 139793425999616 supervisor.py:1050] Recording summary at step 23400.


INFO:tensorflow:Restoring parameters from logdir/model_epoch_0049_step_23399


I0204 20:39:19.883621 139797565028224 saver.py:1284] Restoring parameters from logdir/model_epoch_0049_step_23399


INFO:tensorflow:Running local_init_op.


I0204 20:39:20.090693 139797565028224 session_manager.py:500] Running local_init_op.


INFO:tensorflow:Done running local_init_op.


I0204 20:39:20.117396 139797565028224 session_manager.py:502] Done running local_init_op.
