In [1]:
import os
import pickle
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 
def cache(cache_path, fn, *args, **kwargs):
    """
    Cache-wrapper for a function or class. If the cache-file exists
    then the data is reloaded and returned, otherwise the function
    is called and the result is saved to cache. The fn-argument can
    also be a class instead, in which case an object-instance is
    created and saved to the cache-file.

    The cache-file is written to a temporary file first and then moved
    into place, so a failed or interrupted save never leaves a truncated
    cache-file behind. The directory of the cache-file is created if it
    does not exist.

    :param cache_path:
        File-path for the cache-file.
    :param fn:
        Function or class to be called.
    :param args:
        Arguments to the function or class-init.
    :param kwargs:
        Keyword arguments to the function or class-init.
    :return:
        The result of calling the function or creating the object-instance.
    """

    if os.path.exists(cache_path):
        # Load the cached data from the file.
        # NOTE: unpickling can execute arbitrary code, so only point
        # cache_path at files that this function wrote itself.
        with open(cache_path, mode='rb') as file:
            obj = pickle.load(file)

        print("- Data loaded from cache-file: " + cache_path)
    else:
        # The cache-file does not exist.
        # Call the function / class-init with the supplied arguments.
        obj = fn(*args, **kwargs)

        # Make sure the directory for the cache-file exists.
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        # Save to a temporary file and rename it afterwards, so the
        # cache-file only ever appears once it is completely written.
        tmp_path = cache_path + '.tmp'
        try:
            with open(tmp_path, mode='wb') as file:
                pickle.dump(obj, file)
            os.replace(tmp_path, cache_path)
        finally:
            # Only still present if saving failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print("- Data saved to cache-file: " + cache_path)

    return obj
def convert_numpy2pickle(in_path, out_path, allow_pickle=False):
    """
    Convert a numpy-file to pickle-file.
    The first version of the cache-function used numpy for saving the data.
    Instead of re-calculating all the data, you can just convert the
    cache-file using this function.

    :param in_path:
        Input file in numpy-format written using numpy.save().
    :param out_path:
        Output file written as a pickle-file.
    :param allow_pickle:
        Passed on to numpy.load(). Set to True when the numpy-file holds
        an object-array (e.g. a saved dict or class-instance), which
        numpy.load() otherwise refuses to read. Unpickling can execute
        arbitrary code, so only enable this for files you trust.
    :return:
        Nothing.
    """
    # Load the data using numpy.
    data = np.load(in_path, allow_pickle=allow_pickle)

    # Save the data using pickle.
    with open(out_path, mode='wb') as file:
        pickle.dump(data, file)

if __name__ == '__main__':
    # Demo 1: caching the return value of a function.

    # Stand-in for a slow computation. cache() only calls it when the
    # cache-file named below does not exist yet.
    def expensive_function(a, b):
        return a * b

    print('Computing expensive_function() ...')

    # Load the result from 'cache_expensive_function.pkl' if that file
    # exists; otherwise compute expensive_function(a=123, b=456) and
    # save the result to that file for next time.
    result = cache(cache_path='cache_expensive_function.pkl',
                   fn=expensive_function, a=123, b=456)

    print('result =', result)

    # Blank line between the two demos.
    print()

    # Demo 2: caching an object-instance.

    # cache() is given the class itself as fn, so the pickled value is
    # the instance created by ExpensiveClass(c=123, d=456).
    class ExpensiveClass:
        def __init__(self, c, d):
            self.c = c
            self.d = d
            # Computed once at construction and stored on the instance.
            self.result = c * d

        def print_result(self):
            print('c =', self.c)
            print('d =', self.d)
            print('result = c * d =', self.result)

    print('Creating object from ExpensiveClass() ...')

    # Load the instance from 'cache_ExpensiveClass.pkl' if that file
    # exists; otherwise create ExpensiveClass(c=123, d=456) and save
    # the instance to that file for next time.
    obj = cache(cache_path='cache_ExpensiveClass.pkl',
                fn=ExpensiveClass, c=123, d=456)

    obj.print_result()

Computing expensive_function() ...
- Data loaded from cache-file: cache_expensive_function.pkl
result = 56088

Creating object from ExpensiveClass() ...
- Data loaded from cache-file: cache_ExpensiveClass.pkl
c = 123
d = 456
result = c * d = 56088


In [4]:
!pip install sobol_seq
!pip install tensorflow==1.15.5

Collecting sobol_seq
  Downloading https://files.pythonhosted.org/packages/e4/df/6c4ad25c0b48545a537b631030f7de7e4abb939e6d2964ac2169d4379c85/sobol_seq-0.2.0-py3-none-any.whl
Installing collected packages: sobol-seq
Successfully installed sobol-seq-0.2.0
Collecting tensorflow==1.15.5
[?25l  Downloading https://files.pythonhosted.org/packages/9a/51/99abd43185d94adaaaddf8f44a80c418a91977924a7bc39b8dacd0c495b0/tensorflow-1.15.5-cp37-cp37m-manylinux2010_x86_64.whl (110.5MB)
[K     |████████████████████████████████| 110.5MB 94kB/s 
Collecting tensorflow-estimator==1.15.1
[?25l  Downloading https://files.pythonhosted.org/packages/de/62/2ee9cd74c9fa2fa450877847ba560b260f5d0fb70ee0595203082dafcc9d/tensorflow_estimator-1.15.1-py2.py3-none-any.whl (503kB)
[K     |████████████████████████████████| 512kB 35.7MB/s 
Collecting keras-applications>=1.0.8
[?25l  Downloading https://files.pythonhosted.org/packages/71/e3/19762fdfc62877ae9102edf6342d71b28fbfd9dea3d2f96a882ce099b03f/Keras_Applications

In [2]:
import tensorflow as tf
from scipy.stats import norm
import sobol_seq

# Floating-point dtype passed to the tf.Variable / tf.cast / tf.constant
# calls in this notebook's model classes.
DTYPE = tf.float32

class Compressible(object):
    """
    Base class for a TensorFlow 1.x model whose weights are stored as blocks
    of a variational Gaussian and then replaced, block by block, with a
    sample selected from a fixed quasi-random candidate set.

    Subclasses call initialize_variables() to obtain their weight tensors,
    define ``self.loss`` and ``self.accuracy``, implement get_feed_dict()
    and get_train_op(), and finally call initialize_session().
    """

    # Root directory under which checkpoints are written by training_step()
    # and looked up by initialize_session(). Override on a subclass or an
    # instance to store checkpoints somewhere else.
    checkpoint_root = '/scratch/mh740/compression_models'

    def __init__(self, name, message_freq=1000):
        """
        :param name:
            Model name; used as a sub-directory of checkpoint_root.
        :param message_freq:
            Report progress and save a checkpoint every message_freq
            training steps (and on the very first step).
        """
        self.name = name
        self.message_freq = message_freq
        self.message_counter = 0

    def get_feed_dict(self, validation=False):
        """Return the feed_dict for one training or validation run."""
        raise NotImplementedError

    def get_train_op(self, training):
        """Return the op that performs one optimisation step for a phase."""
        raise NotImplementedError

    def _mean_validation_accuracy(self, n_runs):
        """Average self.accuracy over n_runs validation feeds."""
        total = 0.0
        for _ in range(n_runs):
            total += self.sess.run(self.accuracy,
                                   feed_dict=self.get_feed_dict(validation=True))
        return total / n_runs

    def training_step(self, training, extra_ops):
        """
        Run one optimisation step; periodically report and checkpoint.

        :param training:
            Phase label. It is forwarded to get_train_op() and also used as
            the checkpoint sub-directory name.
        :param extra_ops:
            Additional ops to run together with the train op.
        """
        self.message_counter += 1
        if self.message_counter % self.message_freq == 0 or self.message_counter == 1:
            loss, training_acc, kl, kl_loss = self.sess.run(
                [self.loss, self.accuracy, self.mean_kl, self.kl_loss],
                feed_dict=self.get_feed_dict())

            mean_validation_acc = self._mean_validation_accuracy(10)
            print("Iteration {}, Validation score = {}, Training score = {}, Loss = {}, KL-Loss = {}, KL_2 = {}".format(
                self.message_counter,
                mean_validation_acc,
                training_acc,
                loss,
                kl_loss,
                kl / np.log(2.)))  # mean block KL converted from nats to bits

            path = '{}/{}/{}/{}.ckpt'.format(self.checkpoint_root, self.name,
                                             training, self.message_counter)
            # `path` is the checkpoint file prefix, so it is the parent
            # directory that has to exist, not `path` itself.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            self.saver.save(self.sess, path)

        self.sess.run((self.get_train_op(training=training), extra_ops),
                      feed_dict=self.get_feed_dict())

    def train(self, iterations, enforce_kl):
        """
        Train for a number of steps.

        :param iterations:
            Number of training steps.
        :param enforce_kl:
            If True the KL term is enabled and the per-block KL penalties
            are updated after every step (phase 'training'); otherwise the
            KL term is switched off (phase 'pretrain').
        :return:
            Validation accuracy averaged over 20 runs.
        """
        if enforce_kl:
            self.sess.run(self.enable_kl_loss.assign(1.))
            # Make the penalty update run only after the train op has run.
            with tf.control_dependencies([self.get_train_op(training='training')]):
                extra_ops = [tf.identity(self.kl_penalty_update)]
            for _ in range(iterations):
                self.training_step(training='training', extra_ops=extra_ops)
        else:
            self.sess.run(self.enable_kl_loss.assign(0.))
            for _ in range(iterations):
                self.training_step(training='pretrain', extra_ops=[])

        return self._mean_validation_accuracy(20)

    def compress(self, retrain_iter, kl_penalty_step=1.0005):
        """
        Fix the blocks one after another, retraining the remaining blocks
        after each one.

        :param retrain_iter:
            Number of training steps run after each block is fixed.
        :param kl_penalty_step:
            Multiplicative step used by the KL penalty update.
        :return:
            Validation accuracy averaged over 100 runs.
        """
        self.sess.run(self.kl_penalty_step.assign(kl_penalty_step))
        n_blocks = self.fixed_weights.get_shape().as_list()[0]
        self.sess.run(self.enable_kl_loss.assign(1.))
        for i in range(n_blocks):
            self.sess.run(self.comp_ops, feed_dict={self.block_to_comp: i})
            print('Block {} of {} compressed'.format(i, n_blocks))
            for _ in range(retrain_iter):
                self.training_step(training='compression', extra_ops=self.kl_penalty_update)

        return self._mean_validation_accuracy(100)

    def initialize_variables(self,
                             dimensions,
                             initializers,
                             hash_group_sizes,
                             block_size,
                             bits_per_block,
                             weight_decay=5e-4,
                             kl_penalty_step=1.00005):
        """
        Create the block-structured variational parameters and return one
        weight tensor per requested shape.

        :param dimensions:
            List of shapes, one per model weight.
        :param initializers:
            List of (type, value) pairs, one per weight. type is 'normal'
            (value is the standard deviation), 'uni' (value is the half
            width of the interval around 0) or 'zero'.
        :param hash_group_sizes:
            One integer per weight. A weight with N elements is backed by
            N / group_size stored parameters; each stored parameter is
            expanded to group_size entries with random +/-1 signs.
        :param block_size:
            Number of stored parameters per block.
        :param bits_per_block:
            Target KL per block, in bits.
        :param weight_decay:
            Factor of the squared-mean term stored in self.weight_decay_loss.
        :param kl_penalty_step:
            Initial multiplicative step of the KL penalty update.
        :return:
            List of tensors shaped like the entries of dimensions.
        """
        assert len(initializers) == len(dimensions)

        # Number of stored parameters behind each weight.
        var_sizes = []
        for dim, group_size in zip(dimensions, hash_group_sizes):
            n_elements = int(np.prod(dim))
            assert n_elements % group_size == 0
            var_sizes.append(n_elements // group_size)
        num_vars = sum(var_sizes)

        # Round up to whole blocks; the surplus slots are zero padding.
        n_blocks = 1 + (num_vars - 1) // block_size
        shape = [n_blocks, block_size]
        num_vars_ub = n_blocks * block_size
        n_padding = num_vars_ub - num_vars
        print('Number of blocks: {}, Block size: {}, Bits per block: {}, Target KL: {}, Overall bits {}, Ratio: {}'.format(
            n_blocks, block_size, bits_per_block, bits_per_block, bits_per_block*n_blocks, np.sum([np.prod(dim) for dim in dimensions])*32. / (bits_per_block * n_blocks)
        ))

        # Fixed seed so the permutation, the initial means and the hashing
        # signs are the same on every run.
        np.random.seed(420)
        permutation = np.random.permutation(num_vars_ub)
        permutation_inv = np.argsort(permutation)

        # Sizes of the pieces of the permuted weight vector: one piece per
        # weight plus a trailing piece for the padding. This has to be a
        # list concatenation; adding to a NumPy array would broadcast.
        split_sizes = var_sizes + [n_padding]

        # One log-scale of the prior per weight, plus one for the padding.
        self.p_scale_vars = tf.Variable(tf.fill([len(dimensions) + 1], -2.), dtype=DTYPE)
        p_perm_inv = np.repeat(range(len(dimensions) + 1), split_sizes)[permutation_inv]
        self.p_scale = tf.reshape(tf.gather(tf.exp(self.p_scale_vars), p_perm_inv), shape)
        p = tf.contrib.distributions.Normal(loc=0., scale=self.p_scale)

        mu_init_list = []
        for (kind, val), size in zip(initializers, var_sizes):
            if kind == 'normal':
                mu_init_list.append(np.random.normal(size=size, loc=0., scale=val))
            elif kind == 'uni':
                mu_init_list.append(np.random.uniform(-val, val, size=size))
            elif kind == 'zero':
                mu_init_list.append(np.zeros(size))
            else:
                raise AssertionError('Unknown initializer type: {!r}'.format(kind))

        mu_init = np.concatenate(mu_init_list)
        init_inv_permuted = np.concatenate((mu_init,
                                            np.zeros(n_padding)),
                                           axis=0)[permutation_inv]

        mu = tf.Variable(init_inv_permuted.reshape(shape), dtype=DTYPE, name='mu')
        self.mu = mu
        self.weight_decay_loss = tf.reduce_sum(tf.square(mu)) * weight_decay
        # NOTE(review): name='sigma' is given to the cast, not to the
        # Variable. Left as is because renaming the variable would change
        # its name in saved checkpoints.
        self.sigma_var = tf.Variable(tf.fill(shape, tf.cast(-10., dtype=DTYPE, name='sigma')))
        sigma = tf.exp(self.sigma_var)
        self.sigma = sigma
        epsilon = tf.random_normal(shape)
        self.w_dist = tf.contrib.distributions.Normal(loc=mu, scale=sigma)
        variational_weights = mu + epsilon * sigma
        self.fixed_weights = tf.Variable(tf.zeros_like(variational_weights), trainable=False)
        # mask[b] is 1 while block b is still variational, 0 once fixed.
        self.mask = tf.Variable(tf.ones([n_blocks]), trainable=False)
        kl_penalties = tf.Variable(tf.fill([n_blocks], tf.cast(1e-8, dtype=DTYPE)), trainable=False)
        self.kl_penalties = kl_penalties

        # Target KL per block in nats.
        kl_target = tf.Variable(bits_per_block * np.log(2.), dtype=tf.float32, trainable=False)
        block_kl = tf.reduce_sum(tf.distributions.kl_divergence(self.w_dist, p), axis=1)
        self.mean_kl = tf.reduce_mean(block_kl)

        self.enable_kl_loss = tf.Variable(1., dtype=DTYPE, trainable=False)
        self.kl_loss = tf.reduce_sum(block_kl * self.mask * kl_penalties) * self.enable_kl_loss
        self.kl_penalty_step = tf.Variable(kl_penalty_step, trainable=False)
        # Grow the penalty of unfixed blocks whose KL exceeds the target,
        # shrink the penalty of all other blocks.
        self.kl_penalty_update = [kl_penalties.assign(tf.where(
            tf.logical_and(tf.cast(self.mask, tf.bool),
                           tf.greater(block_kl, kl_target)),
            kl_penalties * self.kl_penalty_step,
            kl_penalties / self.kl_penalty_step))]

        mask_expanded = tf.expand_dims(self.mask, 1)
        combined_weights = tf.reshape(mask_expanded * variational_weights
                                      + (1. - mask_expanded) * self.fixed_weights,
                                      [-1])

        permuted_weights = tf.gather(combined_weights, permutation)
        split_weights = tf.split(permuted_weights, split_sizes)

        # zip() stops after the last weight, so the padding piece is unused.
        result = []
        for dim, group_size, split in zip(dimensions, hash_group_sizes, split_weights):
            hashed = tf.expand_dims(split, axis=1) * np.random.choice([-1., 1.], size=group_size)
            result.append(tf.reshape(hashed, dim))

        self.initialize_compressor(bits_per_block)
        return result

    def initialize_compressor(self, bits_per_block):
        """
        Build self.comp_ops, which fix the block fed through
        self.block_to_comp: a candidate is drawn from 2**bits_per_block
        quasi-random normal samples, written to self.fixed_weights, and the
        block's mask entry is set to 0.
        """
        with tf.variable_scope('compressor'):
            self.block_to_comp = tf.placeholder(tf.int32)
            shape = self.fixed_weights.get_shape().as_list()

            sobol_dim = shape[1]
            assert sobol_dim <= 40
            # Sobol points mapped through the normal inverse CDF; one row
            # per candidate, one column per position inside a block.
            uni_quasi = np.array(sobol_seq.i4_sobol_generate(sobol_dim, np.power(2, bits_per_block), skip=1))
            normal_quasi = norm.ppf(uni_quasi)
            # Kept from the original as an untested alternative
            # ("helps but not exactly sure why"):
            # normal_quasi /= np.sqrt(np.mean(np.square(normal_quasi), axis=1))[:, None]
            sample_block = tf.constant(normal_quasi, dtype=DTYPE)

            block_p = self.p_scale[self.block_to_comp, :]
            block_mu = self.mu[self.block_to_comp, :]
            block_sigma = self.sigma[self.block_to_comp, :]
            nll_q = tf.reduce_sum(tf.square(block_mu - sample_block * block_p) / (2*tf.square(block_sigma)), axis=1)
            # NOTE(review): nll_q divides by 2 but nll_p does not - confirm
            # that this asymmetry is intended.
            nll_p = tf.reduce_sum(tf.square(sample_block), axis=1)
            dist = tf.distributions.Categorical(probs=tf.nn.softmax(nll_p - nll_q))
            index = dist.sample([])

            # Original author's note: taking tf.argmin(nll_q, axis=0) instead
            # of sampling "makes the algorithm objectively better. But we
            # cannot prove it theoretically."

            best_sample = sample_block[index, :] * block_p
            self.comp_ops = [
                tf.scatter_update(self.fixed_weights,
                                  [self.block_to_comp],
                                  [best_sample]),
                tf.scatter_update(self.mask, [self.block_to_comp], [0.]),
            ]

    def initialize_session(self, load_name=None):
        """
        Create the savers and the session and initialise all variables.

        :param load_name:
            Optional checkpoint path relative to checkpoint_root; if given,
            the variables are restored from it after initialisation.
        """
        self.saver = tf.train.Saver(max_to_keep=None)
        self.loader = tf.train.Saver(var_list=tf.global_variables())
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        self.sess = tf.Session(config=config)

        # Run the initializer
        self.sess.run(init)

        if load_name is not None:
            path = '{}/{}'.format(self.checkpoint_root, load_name)
            self.loader.restore(self.sess, path)

In [3]:
########################################################################
#
# Functions for downloading and extracting data-files from the internet.
#
# Implemented in Python 3.5
#
########################################################################
#
# This code is part of the TensorFlow Tutorials available at:
#
# https://github.com/Hvass-Labs/TensorFlow-Tutorials
#
# Published under the MIT License. See the file LICENSE for details.
#
# Copyright 2016 by Magnus Erik Hvass Pedersen
#
########################################################################

import sys
import os
import urllib
import tarfile
import zipfile

########################################################################


def _print_download_progress(count, block_size, total_size):
    """
    Function used for printing the download progress.
    Used as a call-back function in maybe_download_and_extract().
    """

    # Percentage completion.
    pct_complete = float(count * block_size) / total_size

    # Limit it because rounding errors may cause it to exceed 100%.
    pct_complete = min(1.0, pct_complete)

    # Status-message. Note the \r which means the line should overwrite itself.
    msg = "\r- Download progress: {0:.1%}".format(pct_complete)

    # Print it.
    sys.stdout.write(msg)
    sys.stdout.flush()


########################################################################

def download(base_url, filename, download_dir):
    """
    Download the given file if it does not already exist in the download_dir.

    The file is downloaded under a temporary name and renamed once the
    transfer has finished, so an interrupted download is not mistaken for
    a complete file on the next call.

    :param base_url: The internet URL without the filename.
    :param filename: The filename that will be added to the base_url.
    :param download_dir: Local directory for storing the file.
    :return: Nothing.
    """
    # In Python 3 urlretrieve() lives in the urllib.request sub-module,
    # which a plain "import urllib" does not load.
    import urllib.request

    # Path for local file.
    save_path = os.path.join(download_dir, filename)

    # Check if the file already exists, otherwise we need to download it now.
    if not os.path.exists(save_path):
        # Create the download directory if it does not exist.
        if download_dir:
            os.makedirs(download_dir, exist_ok=True)

        print("Downloading", filename, "...")

        # Download the file from the internet.
        url = base_url + filename
        tmp_path = save_path + '.part'
        try:
            urllib.request.urlretrieve(url=url,
                                       filename=tmp_path,
                                       reporthook=_print_download_progress)
            os.replace(tmp_path, save_path)
        finally:
            # Only still present if the download failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        print(" Done!")


def maybe_download_and_extract(url, download_dir):
    """
    Download and extract the data if it doesn't already exist.
    Supports .zip, .tar.gz and .tgz archives; any other file is only
    downloaded.

    The archive is downloaded under a temporary name and renamed once the
    transfer has finished, so an interrupted download is not mistaken for
    a complete archive on the next call.

    :param url:
        Internet URL for the archive to download.
        Example: "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    :param download_dir:
        Directory where the downloaded file is saved.
        Example: "data/CIFAR-10/"
    :return:
        Nothing.
    """
    # In Python 3 urlretrieve() lives in the urllib.request sub-module,
    # which a plain "import urllib" does not load.
    import urllib.request

    # Filename for saving the file downloaded from the internet.
    # Use the filename from the URL and add it to the download_dir.
    filename = url.split('/')[-1]
    file_path = os.path.join(download_dir, filename)

    # If the file exists then we assume it has also been extracted.
    if os.path.exists(file_path):
        print("Data has apparently already been downloaded and unpacked.")
        return

    # Create the download directory if it does not exist.
    if download_dir:
        os.makedirs(download_dir, exist_ok=True)

    # Download the file from the internet.
    tmp_path = file_path + '.part'
    try:
        urllib.request.urlretrieve(url=url,
                                   filename=tmp_path,
                                   reporthook=_print_download_progress)
        os.replace(tmp_path, file_path)
    finally:
        # Only still present if the download failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    print()
    print("Download finished. Extracting files.")

    # NOTE: extractall() trusts the member paths stored in the archive,
    # so only use this function with archives from a trusted source.
    if file_path.endswith(".zip"):
        # Unpack the zip-file.
        with zipfile.ZipFile(file=file_path, mode="r") as archive:
            archive.extractall(download_dir)
    elif file_path.endswith((".tar.gz", ".tgz")):
        # Unpack the tar-ball.
        with tarfile.open(name=file_path, mode="r:gz") as archive:
            archive.extractall(download_dir)

    print("Done.")


In [4]:
# Uncomment to install the TensorFlow version this notebook expects:
#!pip install tensorflow==1.15.5
import tensorflow as tf
print(tf.__version__) # expected output: 1.15.5 (the notebook relies on TF1 APIs such as tf.contrib, tf.placeholder and tf.Session)

1.15.5


In [6]:
from tensorflow.examples.tutorials.mnist import input_data
class Lenet5(Compressible):
    """
    LeNet-5 style convolutional network for MNIST whose weights are
    provided by Compressible.initialize_variables().
    """

    def conv2d(self, x, W, b, padding='SAME', strides=1):
        """2-D convolution followed by bias-add and ReLU."""
        x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding=padding)
        x = tf.nn.bias_add(x, b)
        return tf.nn.relu(x)

    def maxpool2d(self, x, k=2, padding='SAME'):
        """Max-pooling with a k x k window and stride k."""
        return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                              padding=padding)

    def conv_net(self, x, weights):
        """
        Build the network and return the logits.

        :param x:
            Batch of flattened images, 784 values per image.
        :param weights:
            Dict with the keys 'wc1', 'bc1', 'wc2', 'bc2', 'wd1', 'bd1',
            'out' and 'bout'.
        :return:
            Logits tensor.
        """
        # MNIST arrives as a flat vector of 784 values per image.
        # Reshape to the image layout [Batch Size, Height, Width, Channel].
        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # First convolution layer.
        conv1 = self.conv2d(x, weights['wc1'], weights['bc1'], padding='VALID')
        print(conv1.shape)
        # Max-pooling.
        conv1 = self.maxpool2d(conv1, k=2, padding='SAME')
        print(conv1.shape)
        # Second convolution layer.
        conv2 = self.conv2d(conv1, weights['wc2'], weights['bc2'], padding='VALID')
        print(conv2.shape)
        # Max-pooling.
        conv2 = self.maxpool2d(conv2, k=2, padding='SAME')
        print(conv2.shape)

        # Fully connected layer.
        # Flatten the conv2 output to match the input size of 'wd1'.
        fc1 = tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]])
        fc1 = tf.add(tf.matmul(fc1, weights['wd1']), weights['bd1'])
        fc1 = tf.nn.relu(fc1)

        # Output layer: class scores.
        out = tf.add(tf.matmul(fc1, weights['out']), weights['bout'])
        return out

    def __init__(self, bpb, load_name=None):
        """
        :param bpb:
            Bits per block, passed on to initialize_variables().
        :param load_name:
            Optional checkpoint to restore, passed on to
            initialize_session().
        """
        super(Lenet5, self).__init__('Lenet5')
        self.mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

        # Training parameters.
        self.batch_size = 256

        # Network parameters.
        num_input = 784  # MNIST data input (img shape: 28*28)
        num_classes = 10  # MNIST total classes (0-9 digits)

        # Graph inputs.
        # NOTE(review): self.X is the output of the subtraction, not the
        # placeholder, and get_feed_dict() feeds self.X directly. As far as
        # I can tell the fed batch then replaces the subtraction's result,
        # so the -0.5 shift is not applied - confirm the intended behaviour
        # before changing it, since it affects trained checkpoints.
        self.X = tf.placeholder(tf.float32, [None, num_input]) - 0.5
        self.Y = tf.placeholder(tf.float32, [None, num_classes])

        # Weights.
        weight_names = ['wc1', 'wc2', 'wd1', 'out', 'bc1', 'bc2', 'bd1', 'bout']
        weight_dims = [[5, 5, 1, 20], [5, 5, 20, 50], [4 * 4 * 50, 500],
                       [500, num_classes], [20], [50], [500], [num_classes]]
        weight_hash_groups = [1, 2, 50, 1, 1, 1, 1, 1]
        weight_initializers = []
        for d in weight_dims:
            if len(d) == 4:
                # Convolution kernel: scale by height * width * in_channels.
                weight_initializers.append(('normal', np.sqrt(1. / (d[0] * d[1] * d[2]))))
            else:
                weight_initializers.append(('normal', np.sqrt(1. / d[0])))

        weights = dict(zip(weight_names, self.initialize_variables(weight_dims,
                                                                   weight_initializers,
                                                                   weight_hash_groups,
                                                                   30, bpb,
                                                                   kl_penalty_step=1.0001)))
        # Build the model.
        logits = self.conv_net(self.X, weights)

        # Evaluate the model.
        prediction = tf.nn.softmax(logits)
        correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(self.Y, 1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, DTYPE))

        # Define the loss and the optimizer.
        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=self.Y)) + self.kl_loss

        global_step = tf.Variable(initial_value=0,
                                  name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            0.001,  # Base learning rate.
            global_step,  # Step counter.
            30 * self.mnist.train.images.shape[0] / self.batch_size,  # Decay steps.
            1.,  # Decay rate.
            staircase=True)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.train_op = optimizer.minimize(self.loss)
        # Second train op that leaves the prior scales untouched.
        no_scales_list = [v for v in tf.trainable_variables() if v is not self.p_scale_vars]
        assert len(no_scales_list) < len(tf.trainable_variables())
        self.train_op_no_scales = optimizer.minimize(self.loss, var_list=no_scales_list)

        self.initialize_session(load_name)

    def get_feed_dict(self, validation=False):
        """
        Return the full validation set if validation is True, otherwise
        the next training batch of self.batch_size examples.
        """
        if validation:
            batch_x, batch_y = self.mnist.validation.images, self.mnist.validation.labels
        else:
            batch_x, batch_y = self.mnist.train.next_batch(self.batch_size)
        return {self.X: batch_x, self.Y: batch_y}

    def get_train_op(self, training=True):
        """
        Select the train op for a training phase.

        Compressible passes the phase labels 'pretrain', 'training' and
        'compression'. Every non-empty string is truthy, so a plain
        truthiness test can never select train_op_no_scales. The
        'compression' phase is therefore mapped to it explicitly, so that
        the prior scales stay fixed while blocks are being fixed.

        :param training:
            Phase label, or a boolean (True selects train_op, False
            selects train_op_no_scales).
        :return:
            The op that performs one optimisation step.
        """
        if not training or training == 'compression':
            return self.train_op_no_scales
        return self.train_op

In [None]:
# Train LeNet-5 on MNIST.
model = Lenet5(bpb=10)

# First run: enforce_kl=False; second run: enforce_kl=True.
model.train(200000, False)
model.train(200000, True)
# Compress with retrain_iter=100 and print the value returned by compress().
print(model.compress(100))

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use urllib or similar directly.
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from t