In [None]:
import os
# CHANGE: remove logger/logging

import numpy as np
from tqdm import trange
import tensorflow as tf

from utils import *
from network import Network
from statistic import Statistic

import network
import statistic
import ops
import utils

# network
# CHANGE: replaced flags with variables
model = "pixel_rnn"
batch_size = 100          # images per train/test step (leading dim of each reshaped batch)
hidden_dims = 16
recurrent_length = 7
out_hidden_dims = 32
out_recurrent_length = 2
use_residual = False

# training
max_epoch = 2             # kept small for a quick run; 100000 was the previous value
test_step = 100
save_step = 1000
learning_rate = 1e-3
grad_clip = 1
use_gpu = True

# data
data = "mnist"            # main() handles "mnist" or "cifar"
data_dir = "data"
sample_dir = "samples"

# Debug
is_train = True           # False: skip training and only generate samples
display = False
log_level = "INFO"
random_seed = 123

# random seed
# Start from an empty default graph so re-running this cell in the same kernel
# does not build a second copy of the network on top of the first one.
# tf.set_random_seed sets the graph-level seed of the *current* default graph,
# so it must be called after the reset, not before.
tf.reset_default_graph()
tf.set_random_seed(random_seed)
np.random.seed(random_seed)

def _mean_batch_cost(net, next_batch, num_steps, batch_shape, with_update):
    """Feed ``num_steps`` binarized batches through ``net.test`` and average the costs.

    Args:
        net: the ``Network`` instance built in ``main``.
        next_batch: callable returning the next ``batch_shape[0]`` images.
        num_steps: number of batches to run.
        batch_shape: ``[batch, height, width, channel]`` each batch is reshaped to.
        with_update: forwarded to ``net.test``; ``True`` for the training pass.

    Returns:
        ``np.mean`` of the per-batch costs (``nan`` if ``num_steps`` is 0).
    """
    costs = []
    for _step in range(num_steps):
        images = binarize(next_batch(batch_shape[0])).reshape(batch_shape)
        costs.append(net.test(images, with_update=with_update))
    return np.mean(costs)


def main(_):
    """Train and/or sample a PixelRNN on the dataset selected by ``data``.

    Reads the module-level configuration variables (``data``, ``data_dir``,
    ``sample_dir``, ``batch_size``, ``max_epoch``, ``test_step``, ``is_train``).

    Training mode (``is_train``): for every epoch, run one pass over the
    training batches (``with_update=True``), one pass over the test batches
    (``with_update=False``), report both average costs to ``Statistic`` and
    save generated samples (``save_images`` with a 10, 10 layout) into the
    sample directory. Generation mode: only generate and save samples once.

    Args:
        _: unused; present so the function can serve as a ``tf.app.run`` main.

    Raises:
        ValueError: if ``data`` is neither ``"mnist"`` nor ``"cifar"``.
        NotImplementedError: if training is requested for ``"cifar"``, for
            which no batch iterators are set up here.
    """
    model_dir = "model"

    DATA_DIR = os.path.join(data_dir, data)
    SAMPLE_DIR = os.path.join(sample_dir, data, model_dir)

    check_and_create_dir(DATA_DIR)
    # save_images below writes into SAMPLE_DIR, so make sure it exists.
    check_and_create_dir(SAMPLE_DIR)

    # 0. prepare datasets
    if data == "mnist":
        from tensorflow.examples.tutorials.mnist import input_data
        mnist = input_data.read_data_sets(DATA_DIR, one_hot=True)

        next_train_batch = lambda n: mnist.train.next_batch(n)[0]
        next_test_batch = lambda n: mnist.test.next_batch(n)[0]

        height, width, channel = 28, 28, 1

        # Whole batches only; a trailing partial batch is skipped.
        train_step_per_epoch = mnist.train.num_examples // batch_size
        test_step_per_epoch = mnist.test.num_examples // batch_size
    elif data == "cifar":
        if is_train:
            # Only the MNIST branch defines the batch iterators and step counts
            # the training loop needs; fail early with a clear message.
            raise NotImplementedError("Training is only implemented for data='mnist'")

        from cifar10 import IMAGE_SIZE, inputs

        maybe_download_and_extract(DATA_DIR)
        images, labels = inputs(eval_data=False,
            data_dir=os.path.join(DATA_DIR, 'cifar-10-batches-bin'), batch_size=batch_size)

        height, width, channel = IMAGE_SIZE, IMAGE_SIZE, 3
    else:
        raise ValueError("Unknown dataset %r; expected 'mnist' or 'cifar'" % (data,))

    batch_shape = [batch_size, height, width, channel]

    with tf.Session() as sess:
        # Named `net` so it does not shadow the imported `network` module.
        net = Network(sess, height, width, channel)

        stat = Statistic(sess, data, model_dir, tf.trainable_variables(), test_step)
        stat.load_model()

        if is_train:
            print("Training starts!")

            initial_step = stat.get_t() if stat else 0
            iterator = trange(max_epoch, ncols=70, initial=initial_step)

            for epoch in iterator:
                # 1. train -- one pass over the training split with updates
                avg_train_cost = _mean_batch_cost(
                    net, next_train_batch, train_step_per_epoch, batch_shape, with_update=True)

                # 2. test -- one pass over the test split, no updates
                avg_test_cost = _mean_batch_cost(
                    net, next_test_batch, test_step_per_epoch, batch_shape, with_update=False)

                stat.on_step(avg_train_cost, avg_test_cost)
                iterator.set_description("train l: %.3f, test l: %.3f" % (avg_train_cost, avg_test_cost))

                # 3. generate samples once per epoch
                samples = net.generate()
                save_images(samples, height, width, 10, 10,
                    directory=SAMPLE_DIR, prefix="epoch_%s" % epoch)
        else:
            print("Image generation starts!")

            samples = net.generate()
            save_images(samples, height, width, 10, 10, directory=SAMPLE_DIR)

if __name__ == "__main__":
    # tf.app.run() parses command-line flags and then calls sys.exit(main(argv));
    # inside a Jupyter kernel that sys.exit surfaces as a SystemExit after every
    # run. No tf.app.flags are defined any more (they were replaced by plain
    # variables), so calling main directly is equivalent and notebook-friendly.
    main(None)

Skip creating directory: data\mnist
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting data\mnist\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting data\mnist\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting data\mnist\t10k-images-idx3-ubyte.gz
Extracting data\mnist\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
SESSION
<tensorflow.python.client.session.Session object at 0x0000018C81251630>
<class 'tensorflow.python.client.session.Session'>

Building pixel_rnn starts!
Building conv_inputs
[conv2d_a] conv_inputs : Placeholder:0 (?, 28, 28, 1) -> conv_inputs/outputs_plus_b:0

[skew] skewed_i : LSTM3/ReverseV2:0 (?, 28, 28, 16) -> LSTM3/output_state_bw/skewed_i/output:0 (?, 28, 55, 16)
[conv2d_b] i_to_s : LSTM3/output_state_bw/skewed_i/output:0 (?, 28, 55, 16) -> LSTM3/output_state_bw/i_to_s/outputs_plus_b:0 (?, 28, 55, 64)
[conv1d] s_to_s : LSTM3/output_state_bw/rnn/while/DiagonalBiLSTMCell/conv1d_inputs:0 (?, 28, 1, 16) -> LSTM3/output_state_bw/rnn/while/DiagonalBiLSTMCell/s_to_s/outputs_plus_b:0 (?, 28, 1, 64)
[DiagonalLSTMCell] DiagonalBiLSTMCell : LSTM3/output_state_bw/rnn/while/TensorArrayReadV3:0 (?, 1792) -> LSTM3/output_state_bw/rnn/while/DiagonalBiLSTMCell/hid:0 (?, 448)
[unskew] unskew : LSTM3/output_state_bw/transpose_1:0 (?, 28, 55, 16) -> LSTM3/output_state_bw/unskew/output:0 (?, 28, 28, 16)
Building LSTM3
[skew] skewed_i : LSTM3/add:0 (?, 28, 28, 16) -> LSTM4/output_state_fw/skewed_i/output:0 (?, 28, 55, 16)
[conv2d_b] i_to_s : LSTM4/output_state_fw/skewed_i/output:0 (?, 28, 55, 16) -> LSTM4/output_state_fw/i_to_s/outputs_plus_b:0 (?, 28, 55, 

[37] LSTM4/output_state_fw/rnn/DiagonalBiLSTMCell/s_to_s/biases:0 (64,) = 64
[38] LSTM4/output_state_bw/i_to_s/weights:0 (1, 1, 16, 64) = 1024
[39] LSTM4/output_state_bw/i_to_s/biases:0 (64,) = 64
[40] LSTM4/output_state_bw/rnn/DiagonalBiLSTMCell/s_to_s/weights:0 (2, 1, 16, 64) = 2048
[41] LSTM4/output_state_bw/rnn/DiagonalBiLSTMCell/s_to_s/biases:0 (64,) = 64
[42] LSTM5/output_state_fw/i_to_s/weights:0 (1, 1, 16, 64) = 1024
[43] LSTM5/output_state_fw/i_to_s/biases:0 (64,) = 64
[44] LSTM5/output_state_fw/rnn/DiagonalBiLSTMCell/s_to_s/weights:0 (2, 1, 16, 64) = 2048
[45] LSTM5/output_state_fw/rnn/DiagonalBiLSTMCell/s_to_s/biases:0 (64,) = 64
[46] LSTM5/output_state_bw/i_to_s/weights:0 (1, 1, 16, 64) = 1024
[47] LSTM5/output_state_bw/i_to_s/biases:0 (64,) = 64
[48] LSTM5/output_state_bw/rnn/DiagonalBiLSTMCell/s_to_s/weights:0 (2, 1, 16, 64) = 2048
[49] LSTM5/output_state_bw/rnn/DiagonalBiLSTMCell/s_to_s/biases:0 (64,) = 64
[50] LSTM6/output_state_fw/i_to_s/weights:0 (1, 1, 16, 64) = 1024

 50%|█████████████████████▌                     | 1/2 [00:00<?, ?it/s]

In [None]:
# NOTE(review): `stat` is assigned only as a local variable inside main(), so it
# is not defined at module scope after the training cell runs; on a fresh
# kernel this cell presumably raises NameError (unless some `stat` name leaks in
# via `from utils import *` -- confirm).
stat