In [None]:
#@title LICENSE
# Licensed under the Apache License, Version 2.0

In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] ='0'

In [2]:
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10

2023-06-30 13:37:01.553089: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
import torch
torch.cuda.is_available()

True

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/mnist_pruning.ipynb)

## Imports / Helpers

In [4]:
from absl import logging
import flax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import optax
import tensorflow as tf
from tqdm import tqdm
# Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.experimental.set_visible_devices([], "GPU")


logging.set_verbosity(logging.INFO)

In [5]:
!pip3 install git+https://github.com/google-research/jaxpruner
import jaxpruner
import ml_collections

Collecting git+https://github.com/google-research/jaxpruner
  Cloning https://github.com/google-research/jaxpruner to /tmp/pip-req-build-fysmeqm7
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/jaxpruner /tmp/pip-req-build-fysmeqm7
  Resolved https://github.com/google-research/jaxpruner to commit d92a781c7e7c55a06c8b88d5bcf22b51cf44a890
  Preparing metadata (setup.py) ... [?25ldone


## Dataset

In [6]:
DATASET_PATH = "./data"
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0,1,2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0,1,2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)

Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]


In [7]:
def image_to_numpy(img):
    img = np.array(img, dtype=np.float32)
    img = (img / 255. - DATA_MEANS) / DATA_STD
    return img

# We need to stack the batch elements
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)


In [8]:
test_transform = image_to_numpy
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32),scale=(0.8,1.0),ratio=(0.9,1.1)),
                                      image_to_numpy
                                     ])
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(42))

Files already downloaded and verified
Files already downloaded and verified


In [9]:
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)
train_loader = data.DataLoader(train_set,
                               batch_size=128,
                               shuffle=True,
                               drop_last=True,
                               collate_fn=numpy_collate,
                               num_workers=8,
                               persistent_workers=True)
val_loader   = data.DataLoader(val_set,
                               batch_size=128,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=4,
                               persistent_workers=True)
test_loader  = data.DataLoader(test_set,
                               batch_size=128,
                               shuffle=False,
                               drop_last=False,
                               collate_fn=numpy_collate,
                               num_workers=4,
                               persistent_workers=True)

Files already downloaded and verified


In [10]:
# def get_datasets():
#   """Load MNIST train and test datasets into memory."""
#   ds_builder = tfds.builder('mnist')
#   ds_builder.download_and_prepare()
#   train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))
#   test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))
#   train_ds['image'] = jnp.float32(train_ds['image']) / 255.
#   test_ds['image'] = jnp.float32(test_ds['image']) / 255.
#   return train_ds, test_ds

# # Helper functions for images.

# def show_img(img, ax=None, title=None):
#   """Shows a single image."""
#   if ax is None:
#     ax = plt.gca()
#   ax.imshow(img[..., 0], cmap='gray')
#   ax.set_xticks([])
#   ax.set_yticks([])
#   if title:
#     ax.set_title(title)

# def show_img_grid(imgs, titles):
#   """Shows a grid of images."""
#   n = int(np.ceil(len(imgs)**.5))
#   _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n))
#   for i, (img, title) in enumerate(zip(imgs, titles)):
#     show_img(img, axs[i // n][i % n], title)

In [11]:
# Get datasets as dict of JAX arrays.
# train_ds, test_ds = get_datasets()

## Training

In [12]:
import ml_collections
from typing import Any
from collections import defaultdict
from flax.training import train_state, checkpoints
import time
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any
# class CNN(nn.Module):
#   """A simple CNN model."""

#   @nn.compact
#   def __call__(self, x):
#     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
#     x = nn.relu(x)
#     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
#     x = nn.relu(x)
#     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
#     x = x.reshape((x.shape[0], -1))  # flatten
#     x = nn.Dense(features=256)(x)
#     x = nn.relu(x)
#     x = nn.Dense(features=10)(x)
#     return x
from resnet import ResNet20 as res20


@jax.jit
def apply_model_train(state, images, labels):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
  
    logits,updates = state.apply_fn({'params': params, 'batch_stats': state.batch_stats}, images, train = True, mutable = ['batch_stats'])
    
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, (logits, updates['batch_stats'])

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, new_batch_stats)), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy, new_batch_stats
@jax.jit
def apply_model_test(state, images, labels,batch_stats):
  """Computes gradients, loss and accuracy for a single batch."""
  def loss_fn(params):
  
    logits = state.apply_fn({'params': params, 'batch_stats': batch_stats}, images, train = False)
    one_hot = jax.nn.one_hot(labels, 10)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return grads, loss, accuracy


@jax.jit
def update_model(state, grads, batch_stats):
  # updates, new_opt_state = state.tx.update(
  #     grads, state.opt_state, state.params)
  # new_params = optax.apply_updates(state.params, updates)
  state = state.apply_gradients(grads=grads, batch_stats=batch_stats)
  return state
def test_epoch(state,testloader, batch_stats):
  epoch_loss = []
  epoch_accuracy = []
  for batch in tqdm(testloader,desc = 'Testing' , leave = False):
    batch_images = batch[0]
    batch_labels = batch[1]
    _,loss,accuracy = apply_model_test(state,batch_images, batch_labels, batch_stats = batch_stats)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
  return np.mean(epoch_loss), np.mean(epoch_accuracy)

def train_epoch(state, trainloader, sparsity_updater):
  """Train for a single epoch."""
  # train_ds_size = len(train_ds['image'])
  # steps_per_epoch = train_ds_size // batch_size

  # perms = jax.random.permutation(rng, len(train_ds['image']))
  # perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  # perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []
  epoch_accuracy = []
  is_ste = isinstance(sparsity_updater, (jaxpruner.SteMagnitudePruning,
                                         jaxpruner.SteRandomPruning))
  pre_op = jax.jit(sparsity_updater.pre_forward_update)
  # for i, perm in enumerate(perms):
  i = 1
  for batch in tqdm(trainloader, desc = 'Training' , leave = False):
    i+=1;
    # print(i)
    batch_images = batch[0]
    batch_labels = batch[1]
    # Following is only needed for STE.

    new_params = pre_op(state.params, state.opt_state[2])
    forward_state = state.replace(params=new_params)

    grads, loss, accuracy, batch_stats = apply_model_train(forward_state, batch_images,
                                        batch_labels)
    state = update_model(state, grads, batch_stats)
    # print(len(state.opt_state))
    post_params = sparsity_updater.post_gradient_update(
        state.params, state.opt_state[2])
    state = state.replace(params=post_params)
    epoch_loss.append(loss)
    epoch_accuracy.append(accuracy)
    if i % 100 == 0:
      if is_ste:
        print(jaxpruner.summarize_sparsity(
            new_params, only_total_sparsity=True))
      else:
        print(jaxpruner.summarize_sparsity(
            state.params, only_total_sparsity=True))
  train_loss = np.mean(epoch_loss)
  train_accuracy = np.mean(epoch_accuracy)
  return state, train_loss, train_accuracy


def create_train_state(rng, config, num_epochs, num_steps_per_epoch):
  """Creates initial `TrainState`."""
  cnn = res20()
  variables = cnn.init(rng, jnp.ones([1, 32,32, 3]),train = False)
  params = variables['params']
  batch_stats = variables['batch_stats']
  sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
  # tx = optax.adam(config.learning_rate, config.momentum)
  # tx = sparsity_updater.wrap_optax(tx)
  opt_class = optax.sgd
  lr_schedule = optax.piecewise_constant_schedule(
            init_value=1e-1,
            boundaries_and_scales=
                {int(num_steps_per_epoch*num_epochs*0.5): 0.1,
                 int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
  transf = [optax.clip(1.0)]
  hparam_dict = { "momentum" : config.momentum, "weight_decay" : 10}
  if opt_class == optax.sgd :  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(2e-4))
  optimizer = optax.chain(
            *transf,
            sparsity_updater.wrap_optax(opt_class(lr_schedule))
        )
        # Initialize training state
  # self.state = TrainState.create(apply_fn=self.model.apply,
  #                                      params=self.init_params if self.state is None else self.state.params,
  #                                      batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
  #                                      tx=optimizer)


  return TrainState.create(
      apply_fn=cnn.apply, params=params, tx=optimizer, batch_stats = batch_stats), sparsity_updater


import time

def train_and_evaluate(trainloader,testloader,config: ml_collections.ConfigDict
                       ) -> train_state.TrainState:
  """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
  # train_ds, test_ds = get_datasets()
  rng = jax.random.PRNGKey(0)

  rng, init_rng = jax.random.split(rng)

  state, sparsity_updater= create_train_state(init_rng, config, config.num_epochs, len(trainloader))

  # for epoch in range(1, config.num_epochs + 1):
  for epoch in tqdm(range(1,config.num_epochs+1)):
    s_time = time.time()
    rng, input_rng = jax.random.split(rng)
    state, train_loss, train_accuracy = train_epoch(state, trainloader,
                                                    sparsity_updater)
    # Following is only needed for STE.
    new_params = sparsity_updater.pre_forward_update(
        state.params, state.opt_state[2])
    forward_state = state.replace(params=new_params)
      
    test_loss, test_accuracy = test_epoch(forward_state, testloader,forward_state.batch_stats)

    print(
        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, time: %.2f'
        % (epoch, train_loss, train_accuracy * 100, test_loss,
           test_accuracy * 100, time.time() - s_time))

  return state



# Jaxpruner API

In [13]:
jaxpruner.ALGORITHMS

('no_prune',
 'magnitude',
 'random',
 'saliency',
 'magnitude_ste',
 'random_ste',
 'global_magnitude',
 'global_saliency',
 'static_sparse',
 'rigl',
 'set')

In [14]:
config = ml_collections.ConfigDict()

config.learning_rate = 0.01
config.momentum = 0.9
config.batch_size = 256
config.num_epochs = 200 # 1 epoch is 468 steps for bs=128

config.sparsity_config = ml_collections.ConfigDict()
config.sparsity_config.algorithm = 'rigl'
config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 200
config.sparsity_config.sparsity = 0.5
config.sparsity_config.dist_type = 'erk'

In [15]:
jaxpruner.create_updater_from_config(config.sparsity_config)

RigL(scheduler=PolynomialSchedule(update_freq=10, update_start_step=200, update_end_step=1000, power=3), skip_gradients=True, is_sparse_gradients=True, sparsity_type=Unstructured(), sparsity_distribution_fn=functools.partial(<function erk at 0x7efe119ccaf0>, sparsity=0.5), rng_seed=Array([0, 8], dtype=uint32), use_packed_masks=False, eps=1e-05, drop_fraction_fn=<function cosine_decay_schedule.<locals>.schedule at 0x7effb1720040>, is_debug=False)

In [16]:
state = train_and_evaluate(train_loader,test_loader,config)

  0%|          | 0/200 [00:00<?, ?it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  0%|          | 1/200 [00:38<2:07:29, 38.44s/it]

epoch:  1, train_loss: 1.6303, train_accuracy: 39.52, test_loss: 1.3553, test_accuracy: 50.27, time: 38.44




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  1%|          | 2/200 [00:52<1:18:56, 23.92s/it]

epoch:  2, train_loss: 1.2468, train_accuracy: 54.73, test_loss: 1.2243, test_accuracy: 56.66, time: 13.76




{'_total_sparsity': Array(0.5014719, dtype=float32)}




{'_total_sparsity': Array(0.49858725, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  2%|▏         | 3/200 [01:05<1:02:36, 19.07s/it]

epoch:  3, train_loss: 1.0654, train_accuracy: 61.81, test_loss: 1.1924, test_accuracy: 58.39, time: 13.29




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  2%|▏         | 4/200 [01:19<55:17, 16.93s/it]  

epoch:  4, train_loss: 0.9435, train_accuracy: 66.73, test_loss: 0.9951, test_accuracy: 65.56, time: 13.64




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  2%|▎         | 5/200 [01:32<50:27, 15.53s/it]

epoch:  5, train_loss: 0.8445, train_accuracy: 69.92, test_loss: 0.8453, test_accuracy: 69.75, time: 13.04




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  3%|▎         | 6/200 [01:45<47:18, 14.63s/it]

epoch:  6, train_loss: 0.7814, train_accuracy: 72.50, test_loss: 0.8081, test_accuracy: 72.44, time: 12.89




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  4%|▎         | 7/200 [01:57<45:02, 14.00s/it]

epoch:  7, train_loss: 0.7297, train_accuracy: 74.19, test_loss: 0.7650, test_accuracy: 73.85, time: 12.70




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  4%|▍         | 8/200 [02:10<43:34, 13.62s/it]

epoch:  8, train_loss: 0.6859, train_accuracy: 76.18, test_loss: 0.7690, test_accuracy: 73.31, time: 12.80




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  4%|▍         | 9/200 [02:23<42:42, 13.41s/it]

epoch:  9, train_loss: 0.6530, train_accuracy: 77.23, test_loss: 0.7751, test_accuracy: 73.90, time: 12.97




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  5%|▌         | 10/200 [02:36<42:15, 13.34s/it]

epoch: 10, train_loss: 0.6210, train_accuracy: 78.30, test_loss: 0.8970, test_accuracy: 71.42, time: 13.19




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  6%|▌         | 11/200 [02:49<41:41, 13.24s/it]

epoch: 11, train_loss: 0.5980, train_accuracy: 79.30, test_loss: 0.7326, test_accuracy: 75.12, time: 12.99




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  6%|▌         | 12/200 [03:02<41:19, 13.19s/it]

epoch: 12, train_loss: 0.5753, train_accuracy: 80.02, test_loss: 0.6859, test_accuracy: 76.66, time: 13.07




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  6%|▋         | 13/200 [03:15<41:02, 13.17s/it]

epoch: 13, train_loss: 0.5508, train_accuracy: 80.80, test_loss: 0.7220, test_accuracy: 75.43, time: 13.12




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  7%|▋         | 14/200 [03:28<40:42, 13.13s/it]

epoch: 14, train_loss: 0.5319, train_accuracy: 81.59, test_loss: 0.7728, test_accuracy: 74.53, time: 13.06




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  8%|▊         | 15/200 [03:42<40:32, 13.15s/it]

epoch: 15, train_loss: 0.5201, train_accuracy: 81.92, test_loss: 0.5823, test_accuracy: 80.21, time: 13.18




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  8%|▊         | 16/200 [03:54<39:57, 13.03s/it]

epoch: 16, train_loss: 0.4991, train_accuracy: 82.59, test_loss: 0.7240, test_accuracy: 76.61, time: 12.76




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  8%|▊         | 17/200 [04:07<39:47, 13.04s/it]

epoch: 17, train_loss: 0.4894, train_accuracy: 83.05, test_loss: 0.7303, test_accuracy: 76.09, time: 13.07




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


  9%|▉         | 18/200 [04:21<39:55, 13.16s/it]

epoch: 18, train_loss: 0.4741, train_accuracy: 83.52, test_loss: 0.5792, test_accuracy: 80.49, time: 13.43




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 10%|▉         | 19/200 [04:34<39:37, 13.13s/it]

epoch: 19, train_loss: 0.4662, train_accuracy: 83.85, test_loss: 0.6755, test_accuracy: 77.02, time: 13.07




{'_total_sparsity': Array(0.49755895, dtype=float32)}



Training:  58%|█████▊    | 204/351 [00:07<00:05, 26.63it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


[A

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 10%|█         | 20/200 [04:47<39:27, 13.15s/it]

epoch: 20, train_loss: 0.4488, train_accuracy: 84.42, test_loss: 0.5358, test_accuracy: 81.93, time: 13.20




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 10%|█         | 21/200 [05:00<38:54, 13.04s/it]

epoch: 21, train_loss: 0.4355, train_accuracy: 84.67, test_loss: 0.6203, test_accuracy: 79.60, time: 12.79




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 11%|█         | 22/200 [05:13<38:28, 12.97s/it]

epoch: 22, train_loss: 0.4303, train_accuracy: 85.18, test_loss: 0.5436, test_accuracy: 81.69, time: 12.80




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 12%|█▏        | 23/200 [05:26<38:05, 12.91s/it]

epoch: 23, train_loss: 0.4179, train_accuracy: 85.42, test_loss: 0.5846, test_accuracy: 80.89, time: 12.77




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 12%|█▏        | 24/200 [05:39<38:04, 12.98s/it]

epoch: 24, train_loss: 0.4068, train_accuracy: 85.82, test_loss: 0.5420, test_accuracy: 81.74, time: 13.14




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 12%|█▎        | 25/200 [05:52<37:45, 12.95s/it]

epoch: 25, train_loss: 0.3938, train_accuracy: 86.13, test_loss: 0.6154, test_accuracy: 80.18, time: 12.87




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 13%|█▎        | 26/200 [06:04<37:26, 12.91s/it]

epoch: 26, train_loss: 0.3933, train_accuracy: 86.20, test_loss: 0.5355, test_accuracy: 82.42, time: 12.83




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 14%|█▎        | 27/200 [06:18<37:25, 12.98s/it]

epoch: 27, train_loss: 0.3840, train_accuracy: 86.51, test_loss: 0.5140, test_accuracy: 82.67, time: 13.14




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 14%|█▍        | 28/200 [06:30<36:57, 12.89s/it]

epoch: 28, train_loss: 0.3777, train_accuracy: 86.86, test_loss: 0.4961, test_accuracy: 83.16, time: 12.68




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 14%|█▍        | 29/200 [06:43<36:29, 12.80s/it]

epoch: 29, train_loss: 0.3701, train_accuracy: 87.07, test_loss: 0.5817, test_accuracy: 80.59, time: 12.60




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 15%|█▌        | 30/200 [06:56<36:13, 12.79s/it]

epoch: 30, train_loss: 0.3600, train_accuracy: 87.26, test_loss: 0.5379, test_accuracy: 81.96, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 16%|█▌        | 31/200 [07:09<36:12, 12.86s/it]

epoch: 31, train_loss: 0.3540, train_accuracy: 87.54, test_loss: 0.5217, test_accuracy: 82.64, time: 13.02




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 16%|█▌        | 32/200 [07:22<36:30, 13.04s/it]

epoch: 32, train_loss: 0.3489, train_accuracy: 87.89, test_loss: 0.5444, test_accuracy: 82.79, time: 13.46




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 16%|█▋        | 33/200 [07:35<36:27, 13.10s/it]

epoch: 33, train_loss: 0.3401, train_accuracy: 88.06, test_loss: 0.5341, test_accuracy: 82.59, time: 13.24




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 17%|█▋        | 34/200 [07:49<36:24, 13.16s/it]

epoch: 34, train_loss: 0.3399, train_accuracy: 88.09, test_loss: 0.6228, test_accuracy: 79.99, time: 13.31




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 18%|█▊        | 35/200 [08:02<36:30, 13.27s/it]

epoch: 35, train_loss: 0.3316, train_accuracy: 88.48, test_loss: 0.4764, test_accuracy: 84.48, time: 13.54




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 18%|█▊        | 36/200 [08:16<36:39, 13.41s/it]

epoch: 36, train_loss: 0.3309, train_accuracy: 88.44, test_loss: 0.6057, test_accuracy: 80.70, time: 13.72




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 18%|█▊        | 37/200 [08:29<36:17, 13.36s/it]

epoch: 37, train_loss: 0.3212, train_accuracy: 88.70, test_loss: 0.5388, test_accuracy: 82.87, time: 13.25




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 19%|█▉        | 38/200 [08:42<35:46, 13.25s/it]

epoch: 38, train_loss: 0.3161, train_accuracy: 88.88, test_loss: 0.4731, test_accuracy: 84.26, time: 12.99




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 20%|█▉        | 39/200 [08:55<35:23, 13.19s/it]

epoch: 39, train_loss: 0.3136, train_accuracy: 88.99, test_loss: 0.5016, test_accuracy: 84.12, time: 13.06




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 20%|██        | 40/200 [09:08<34:43, 13.02s/it]

epoch: 40, train_loss: 0.3071, train_accuracy: 89.26, test_loss: 0.5324, test_accuracy: 82.89, time: 12.62




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 20%|██        | 41/200 [09:21<34:39, 13.08s/it]

epoch: 41, train_loss: 0.3031, train_accuracy: 89.38, test_loss: 0.5322, test_accuracy: 83.30, time: 13.21




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 21%|██        | 42/200 [09:34<34:05, 12.95s/it]

epoch: 42, train_loss: 0.2978, train_accuracy: 89.60, test_loss: 0.5152, test_accuracy: 83.66, time: 12.64




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 22%|██▏       | 43/200 [09:47<34:24, 13.15s/it]

epoch: 43, train_loss: 0.2936, train_accuracy: 89.64, test_loss: 0.5773, test_accuracy: 82.50, time: 13.62




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 22%|██▏       | 44/200 [10:00<34:13, 13.16s/it]

epoch: 44, train_loss: 0.2947, train_accuracy: 89.75, test_loss: 0.6714, test_accuracy: 79.72, time: 13.20




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 22%|██▎       | 45/200 [10:13<33:52, 13.11s/it]

epoch: 45, train_loss: 0.2837, train_accuracy: 90.18, test_loss: 0.4974, test_accuracy: 84.31, time: 12.99




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 23%|██▎       | 46/200 [10:26<33:31, 13.06s/it]

epoch: 46, train_loss: 0.2826, train_accuracy: 89.85, test_loss: 0.4681, test_accuracy: 84.92, time: 12.94




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 24%|██▎       | 47/200 [10:40<33:31, 13.14s/it]

epoch: 47, train_loss: 0.2818, train_accuracy: 90.28, test_loss: 0.4859, test_accuracy: 84.40, time: 13.34




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 24%|██▍       | 48/200 [10:53<33:17, 13.14s/it]

epoch: 48, train_loss: 0.2796, train_accuracy: 90.05, test_loss: 0.5325, test_accuracy: 82.68, time: 13.14




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 24%|██▍       | 49/200 [11:05<32:41, 12.99s/it]

epoch: 49, train_loss: 0.2746, train_accuracy: 90.41, test_loss: 0.5176, test_accuracy: 83.47, time: 12.63




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 25%|██▌       | 50/200 [11:18<32:17, 12.91s/it]

epoch: 50, train_loss: 0.2699, train_accuracy: 90.39, test_loss: 0.4681, test_accuracy: 84.55, time: 12.73




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 26%|██▌       | 51/200 [11:31<32:06, 12.93s/it]

epoch: 51, train_loss: 0.2648, train_accuracy: 90.71, test_loss: 0.5017, test_accuracy: 84.35, time: 12.96




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 26%|██▌       | 52/200 [11:44<31:45, 12.88s/it]

epoch: 52, train_loss: 0.2626, train_accuracy: 90.78, test_loss: 0.4460, test_accuracy: 85.47, time: 12.76




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 26%|██▋       | 53/200 [11:58<32:03, 13.09s/it]

epoch: 53, train_loss: 0.2603, train_accuracy: 90.91, test_loss: 0.5099, test_accuracy: 84.27, time: 13.57




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 27%|██▋       | 54/200 [12:10<31:29, 12.94s/it]

epoch: 54, train_loss: 0.2577, train_accuracy: 90.97, test_loss: 0.5809, test_accuracy: 82.08, time: 12.61




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 28%|██▊       | 55/200 [12:23<31:14, 12.93s/it]

epoch: 55, train_loss: 0.2550, train_accuracy: 91.03, test_loss: 0.4603, test_accuracy: 85.27, time: 12.89




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 28%|██▊       | 56/200 [12:36<31:01, 12.93s/it]

epoch: 56, train_loss: 0.2512, train_accuracy: 91.23, test_loss: 0.5022, test_accuracy: 84.32, time: 12.93



Training:  30%|██▉       | 105/351 [00:03<00:08, 29.60it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


[A

{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 28%|██▊       | 57/200 [12:49<30:43, 12.89s/it]

epoch: 57, train_loss: 0.2476, train_accuracy: 91.23, test_loss: 0.5147, test_accuracy: 83.88, time: 12.79




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 29%|██▉       | 58/200 [13:01<30:20, 12.82s/it]

epoch: 58, train_loss: 0.2476, train_accuracy: 91.27, test_loss: 0.4775, test_accuracy: 85.33, time: 12.65




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 30%|██▉       | 59/200 [13:14<30:16, 12.88s/it]

epoch: 59, train_loss: 0.2460, train_accuracy: 91.18, test_loss: 0.5130, test_accuracy: 84.52, time: 13.03




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 30%|███       | 60/200 [13:27<29:54, 12.82s/it]

epoch: 60, train_loss: 0.2488, train_accuracy: 91.16, test_loss: 0.5315, test_accuracy: 83.75, time: 12.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 30%|███       | 61/200 [13:40<29:53, 12.90s/it]

epoch: 61, train_loss: 0.2426, train_accuracy: 91.48, test_loss: 0.5167, test_accuracy: 84.08, time: 13.10




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 31%|███       | 62/200 [13:53<29:53, 12.99s/it]

epoch: 62, train_loss: 0.2394, train_accuracy: 91.57, test_loss: 0.5307, test_accuracy: 83.76, time: 13.21




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 32%|███▏      | 63/200 [14:06<29:31, 12.93s/it]

epoch: 63, train_loss: 0.2352, train_accuracy: 91.66, test_loss: 0.4890, test_accuracy: 84.66, time: 12.79




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 32%|███▏      | 64/200 [14:19<29:07, 12.85s/it]

epoch: 64, train_loss: 0.2351, train_accuracy: 91.66, test_loss: 0.4625, test_accuracy: 85.85, time: 12.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 32%|███▎      | 65/200 [14:32<28:49, 12.81s/it]

epoch: 65, train_loss: 0.2305, train_accuracy: 91.84, test_loss: 0.4919, test_accuracy: 84.75, time: 12.72




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 33%|███▎      | 66/200 [14:45<28:43, 12.86s/it]

epoch: 66, train_loss: 0.2274, train_accuracy: 92.04, test_loss: 0.4797, test_accuracy: 85.04, time: 12.97




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 34%|███▎      | 67/200 [14:57<28:20, 12.78s/it]

epoch: 67, train_loss: 0.2237, train_accuracy: 92.03, test_loss: 0.4931, test_accuracy: 84.49, time: 12.60




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 34%|███▍      | 68/200 [15:10<28:08, 12.79s/it]

epoch: 68, train_loss: 0.2264, train_accuracy: 91.94, test_loss: 0.4596, test_accuracy: 85.20, time: 12.81




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 34%|███▍      | 69/200 [15:23<27:46, 12.72s/it]

epoch: 69, train_loss: 0.2205, train_accuracy: 92.30, test_loss: 0.5329, test_accuracy: 84.06, time: 12.57




{'_total_sparsity': Array(0.49755895, dtype=float32)}





{'_total_sparsity': Array(0.49755895, dtype=float32)}


Training:  58%|█████▊    | 204/351 [00:07<00:05, 25.98it/s][A

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 35%|███▌      | 70/200 [15:36<28:04, 12.96s/it]

epoch: 70, train_loss: 0.2248, train_accuracy: 91.93, test_loss: 0.4885, test_accuracy: 85.19, time: 13.50




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 36%|███▌      | 71/200 [15:48<27:31, 12.80s/it]

epoch: 71, train_loss: 0.2238, train_accuracy: 92.02, test_loss: 0.4893, test_accuracy: 84.53, time: 12.45




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 36%|███▌      | 72/200 [16:02<27:36, 12.94s/it]

epoch: 72, train_loss: 0.2188, train_accuracy: 92.23, test_loss: 0.5182, test_accuracy: 83.56, time: 13.27




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 36%|███▋      | 73/200 [16:14<27:14, 12.87s/it]

epoch: 73, train_loss: 0.2110, train_accuracy: 92.64, test_loss: 0.4800, test_accuracy: 85.14, time: 12.69




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 37%|███▋      | 74/200 [16:27<26:55, 12.82s/it]

epoch: 74, train_loss: 0.2161, train_accuracy: 92.23, test_loss: 0.4823, test_accuracy: 85.22, time: 12.71




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 38%|███▊      | 75/200 [16:40<26:38, 12.79s/it]

epoch: 75, train_loss: 0.2082, train_accuracy: 92.55, test_loss: 0.4861, test_accuracy: 84.97, time: 12.71




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 38%|███▊      | 76/200 [16:53<26:49, 12.98s/it]

epoch: 76, train_loss: 0.2150, train_accuracy: 92.24, test_loss: 0.4293, test_accuracy: 86.26, time: 13.42




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 38%|███▊      | 77/200 [17:06<26:30, 12.93s/it]

epoch: 77, train_loss: 0.2079, train_accuracy: 92.68, test_loss: 0.4657, test_accuracy: 85.85, time: 12.82




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 39%|███▉      | 78/200 [17:19<26:14, 12.90s/it]

epoch: 78, train_loss: 0.2077, train_accuracy: 92.65, test_loss: 0.4416, test_accuracy: 86.52, time: 12.84




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 40%|███▉      | 79/200 [17:32<25:50, 12.81s/it]

epoch: 79, train_loss: 0.2059, train_accuracy: 92.87, test_loss: 0.5042, test_accuracy: 84.58, time: 12.60




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 40%|████      | 80/200 [17:44<25:33, 12.78s/it]

epoch: 80, train_loss: 0.2051, train_accuracy: 92.70, test_loss: 0.4404, test_accuracy: 86.53, time: 12.70




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 40%|████      | 81/200 [17:57<25:19, 12.77s/it]

epoch: 81, train_loss: 0.1992, train_accuracy: 92.73, test_loss: 0.4775, test_accuracy: 85.30, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 41%|████      | 82/200 [18:10<25:15, 12.84s/it]

epoch: 82, train_loss: 0.1995, train_accuracy: 93.00, test_loss: 0.4793, test_accuracy: 85.57, time: 13.01




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 42%|████▏     | 83/200 [18:23<25:10, 12.91s/it]

epoch: 83, train_loss: 0.2023, train_accuracy: 92.81, test_loss: 0.4575, test_accuracy: 86.03, time: 13.09




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 42%|████▏     | 84/200 [18:36<24:51, 12.86s/it]

epoch: 84, train_loss: 0.2017, train_accuracy: 92.77, test_loss: 0.4485, test_accuracy: 86.16, time: 12.73




{'_total_sparsity': Array(0.49755895, dtype=float32)}



Training:  58%|█████▊    | 205/351 [00:07<00:04, 29.46it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


[A

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 42%|████▎     | 85/200 [18:49<24:53, 12.99s/it]

epoch: 85, train_loss: 0.2037, train_accuracy: 92.63, test_loss: 0.4777, test_accuracy: 85.35, time: 13.28




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 43%|████▎     | 86/200 [19:03<25:04, 13.20s/it]

epoch: 86, train_loss: 0.1959, train_accuracy: 93.05, test_loss: 0.4602, test_accuracy: 85.93, time: 13.70




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 44%|████▎     | 87/200 [19:16<24:49, 13.18s/it]

epoch: 87, train_loss: 0.1974, train_accuracy: 92.96, test_loss: 0.4779, test_accuracy: 85.22, time: 13.15




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 44%|████▍     | 88/200 [19:29<24:35, 13.18s/it]

epoch: 88, train_loss: 0.1916, train_accuracy: 93.25, test_loss: 0.4608, test_accuracy: 86.20, time: 13.16




{'_total_sparsity': Array(0.49755895, dtype=float32)}



Training:  58%|█████▊    | 205/351 [00:07<00:04, 29.38it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


[A

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 44%|████▍     | 89/200 [19:42<24:16, 13.12s/it]

epoch: 89, train_loss: 0.1944, train_accuracy: 93.09, test_loss: 0.4954, test_accuracy: 85.12, time: 13.00




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 45%|████▌     | 90/200 [19:55<23:52, 13.02s/it]

epoch: 90, train_loss: 0.1932, train_accuracy: 93.01, test_loss: 0.4925, test_accuracy: 85.33, time: 12.78




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 46%|████▌     | 91/200 [20:07<23:24, 12.89s/it]

epoch: 91, train_loss: 0.1934, train_accuracy: 93.16, test_loss: 0.4578, test_accuracy: 85.99, time: 12.57




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 46%|████▌     | 92/200 [20:21<23:17, 12.94s/it]

epoch: 92, train_loss: 0.1921, train_accuracy: 93.18, test_loss: 0.4619, test_accuracy: 85.81, time: 13.05




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 46%|████▋     | 93/200 [20:33<23:00, 12.90s/it]

epoch: 93, train_loss: 0.1923, train_accuracy: 93.22, test_loss: 0.4399, test_accuracy: 86.28, time: 12.81




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 47%|████▋     | 94/200 [20:46<22:37, 12.80s/it]

epoch: 94, train_loss: 0.1842, train_accuracy: 93.40, test_loss: 0.5282, test_accuracy: 84.97, time: 12.59




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 48%|████▊     | 95/200 [20:59<22:34, 12.90s/it]

epoch: 95, train_loss: 0.1872, train_accuracy: 93.25, test_loss: 0.4861, test_accuracy: 85.39, time: 13.13




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 48%|████▊     | 96/200 [21:12<22:23, 12.92s/it]

epoch: 96, train_loss: 0.1874, train_accuracy: 93.32, test_loss: 0.4810, test_accuracy: 85.84, time: 12.96




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 48%|████▊     | 97/200 [21:25<22:18, 13.00s/it]

epoch: 97, train_loss: 0.1846, train_accuracy: 93.34, test_loss: 0.4495, test_accuracy: 86.16, time: 13.17




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 49%|████▉     | 98/200 [21:38<22:09, 13.03s/it]

epoch: 98, train_loss: 0.1892, train_accuracy: 93.27, test_loss: 0.4813, test_accuracy: 85.22, time: 13.12




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 50%|████▉     | 99/200 [21:51<22:01, 13.09s/it]

epoch: 99, train_loss: 0.1764, train_accuracy: 93.77, test_loss: 0.4568, test_accuracy: 86.39, time: 13.21




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 50%|█████     | 100/200 [22:04<21:36, 12.97s/it]

epoch: 100, train_loss: 0.1784, train_accuracy: 93.63, test_loss: 0.4705, test_accuracy: 85.85, time: 12.69




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 50%|█████     | 101/200 [22:17<21:14, 12.88s/it]

epoch: 101, train_loss: 0.1485, train_accuracy: 94.95, test_loss: 0.4002, test_accuracy: 87.67, time: 12.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 51%|█████     | 102/200 [22:30<21:19, 13.06s/it]

epoch: 102, train_loss: 0.1136, train_accuracy: 96.18, test_loss: 0.3859, test_accuracy: 88.14, time: 13.48




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 52%|█████▏    | 103/200 [22:44<21:13, 13.12s/it]

epoch: 103, train_loss: 0.1084, train_accuracy: 96.48, test_loss: 0.4042, test_accuracy: 88.04, time: 13.28




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 52%|█████▏    | 104/200 [22:57<20:59, 13.12s/it]

epoch: 104, train_loss: 0.1047, train_accuracy: 96.50, test_loss: 0.3828, test_accuracy: 88.32, time: 13.09




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 52%|█████▎    | 105/200 [23:11<21:22, 13.50s/it]

epoch: 105, train_loss: 0.1002, train_accuracy: 96.82, test_loss: 0.3950, test_accuracy: 88.17, time: 14.38




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 53%|█████▎    | 106/200 [23:24<21:06, 13.47s/it]

epoch: 106, train_loss: 0.0971, train_accuracy: 96.82, test_loss: 0.4007, test_accuracy: 88.05, time: 13.41




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 54%|█████▎    | 107/200 [23:37<20:31, 13.24s/it]

epoch: 107, train_loss: 0.0950, train_accuracy: 96.86, test_loss: 0.4082, test_accuracy: 87.94, time: 12.70




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 54%|█████▍    | 108/200 [23:51<20:29, 13.37s/it]

epoch: 108, train_loss: 0.0915, train_accuracy: 97.11, test_loss: 0.4078, test_accuracy: 87.88, time: 13.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 55%|█████▍    | 109/200 [24:04<20:18, 13.39s/it]

epoch: 109, train_loss: 0.0924, train_accuracy: 96.96, test_loss: 0.4024, test_accuracy: 88.34, time: 13.45




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 55%|█████▌    | 110/200 [24:17<19:49, 13.22s/it]

epoch: 110, train_loss: 0.0884, train_accuracy: 97.14, test_loss: 0.3953, test_accuracy: 88.38, time: 12.82




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 56%|█████▌    | 111/200 [24:31<19:40, 13.27s/it]

epoch: 111, train_loss: 0.0891, train_accuracy: 97.13, test_loss: 0.3941, test_accuracy: 88.35, time: 13.38




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 56%|█████▌    | 112/200 [24:43<19:18, 13.17s/it]

epoch: 112, train_loss: 0.0896, train_accuracy: 97.10, test_loss: 0.4159, test_accuracy: 87.90, time: 12.93




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 56%|█████▋    | 113/200 [24:57<19:09, 13.21s/it]

epoch: 113, train_loss: 0.0900, train_accuracy: 97.21, test_loss: 0.4192, test_accuracy: 87.99, time: 13.31




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 57%|█████▋    | 114/200 [25:10<18:52, 13.17s/it]

epoch: 114, train_loss: 0.0886, train_accuracy: 97.12, test_loss: 0.4204, test_accuracy: 87.60, time: 13.06




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 57%|█████▊    | 115/200 [25:23<18:49, 13.29s/it]

epoch: 115, train_loss: 0.0859, train_accuracy: 97.18, test_loss: 0.3900, test_accuracy: 88.46, time: 13.57




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 58%|█████▊    | 116/200 [25:37<18:34, 13.27s/it]

epoch: 116, train_loss: 0.0834, train_accuracy: 97.26, test_loss: 0.3895, test_accuracy: 88.52, time: 13.24




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 58%|█████▊    | 117/200 [25:50<18:15, 13.20s/it]

epoch: 117, train_loss: 0.0834, train_accuracy: 97.34, test_loss: 0.3995, test_accuracy: 88.40, time: 13.02




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 59%|█████▉    | 118/200 [26:03<18:01, 13.19s/it]

epoch: 118, train_loss: 0.0812, train_accuracy: 97.46, test_loss: 0.3939, test_accuracy: 88.54, time: 13.18




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 60%|█████▉    | 119/200 [26:16<17:46, 13.16s/it]

epoch: 119, train_loss: 0.0832, train_accuracy: 97.38, test_loss: 0.4070, test_accuracy: 87.92, time: 13.09




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 60%|██████    | 120/200 [26:29<17:24, 13.06s/it]

epoch: 120, train_loss: 0.0808, train_accuracy: 97.39, test_loss: 0.4088, test_accuracy: 88.11, time: 12.82




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 60%|██████    | 121/200 [26:43<17:30, 13.29s/it]

epoch: 121, train_loss: 0.0809, train_accuracy: 97.38, test_loss: 0.4056, test_accuracy: 88.26, time: 13.84




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 61%|██████    | 122/200 [26:56<17:12, 13.23s/it]

epoch: 122, train_loss: 0.0769, train_accuracy: 97.63, test_loss: 0.3937, test_accuracy: 88.37, time: 13.09




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 62%|██████▏   | 123/200 [27:09<16:56, 13.20s/it]

epoch: 123, train_loss: 0.0801, train_accuracy: 97.45, test_loss: 0.4062, test_accuracy: 88.28, time: 13.13




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 62%|██████▏   | 124/200 [27:22<16:43, 13.21s/it]

epoch: 124, train_loss: 0.0778, train_accuracy: 97.57, test_loss: 0.4069, test_accuracy: 88.24, time: 13.22




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 62%|██████▎   | 125/200 [27:36<16:50, 13.47s/it]

epoch: 125, train_loss: 0.0777, train_accuracy: 97.48, test_loss: 0.4097, test_accuracy: 88.06, time: 14.08




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 63%|██████▎   | 126/200 [27:49<16:33, 13.42s/it]

epoch: 126, train_loss: 0.0764, train_accuracy: 97.56, test_loss: 0.3924, test_accuracy: 88.42, time: 13.32




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 64%|██████▎   | 127/200 [28:03<16:14, 13.35s/it]

epoch: 127, train_loss: 0.0744, train_accuracy: 97.65, test_loss: 0.4059, test_accuracy: 88.32, time: 13.19




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 64%|██████▍   | 128/200 [28:15<15:44, 13.12s/it]

epoch: 128, train_loss: 0.0776, train_accuracy: 97.44, test_loss: 0.4016, test_accuracy: 88.45, time: 12.58




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 64%|██████▍   | 129/200 [28:28<15:35, 13.17s/it]

epoch: 129, train_loss: 0.0749, train_accuracy: 97.68, test_loss: 0.4095, test_accuracy: 88.26, time: 13.29




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 65%|██████▌   | 130/200 [28:42<15:24, 13.21s/it]

epoch: 130, train_loss: 0.0738, train_accuracy: 97.69, test_loss: 0.4057, test_accuracy: 88.35, time: 13.29




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 66%|██████▌   | 131/200 [28:54<14:57, 13.01s/it]

epoch: 131, train_loss: 0.0734, train_accuracy: 97.77, test_loss: 0.4027, test_accuracy: 88.43, time: 12.54




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 66%|██████▌   | 132/200 [29:07<14:36, 12.90s/it]

epoch: 132, train_loss: 0.0747, train_accuracy: 97.64, test_loss: 0.3990, test_accuracy: 88.46, time: 12.64




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 66%|██████▋   | 133/200 [29:20<14:17, 12.80s/it]

epoch: 133, train_loss: 0.0738, train_accuracy: 97.64, test_loss: 0.4144, test_accuracy: 88.28, time: 12.57




{'_total_sparsity': Array(0.49755895, dtype=float32)}



Training:  58%|█████▊    | 205/351 [00:06<00:04, 30.00it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


[A

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 67%|██████▋   | 134/200 [29:32<14:03, 12.77s/it]

epoch: 134, train_loss: 0.0708, train_accuracy: 97.71, test_loss: 0.4177, test_accuracy: 88.11, time: 12.72




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 68%|██████▊   | 135/200 [29:45<13:53, 12.82s/it]

epoch: 135, train_loss: 0.0707, train_accuracy: 97.73, test_loss: 0.4137, test_accuracy: 88.09, time: 12.91




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 68%|██████▊   | 136/200 [29:58<13:36, 12.76s/it]

epoch: 136, train_loss: 0.0709, train_accuracy: 97.71, test_loss: 0.4155, test_accuracy: 88.31, time: 12.64




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 68%|██████▊   | 137/200 [30:11<13:29, 12.84s/it]

epoch: 137, train_loss: 0.0691, train_accuracy: 97.81, test_loss: 0.4178, test_accuracy: 88.29, time: 13.03




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 69%|██████▉   | 138/200 [30:24<13:17, 12.87s/it]

epoch: 138, train_loss: 0.0701, train_accuracy: 97.79, test_loss: 0.4269, test_accuracy: 88.01, time: 12.92




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 70%|██████▉   | 139/200 [30:37<13:12, 12.99s/it]

epoch: 139, train_loss: 0.0697, train_accuracy: 97.76, test_loss: 0.4084, test_accuracy: 88.44, time: 13.26




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 70%|███████   | 140/200 [30:50<12:55, 12.92s/it]

epoch: 140, train_loss: 0.0700, train_accuracy: 97.78, test_loss: 0.4215, test_accuracy: 88.25, time: 12.76




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 70%|███████   | 141/200 [31:02<12:37, 12.84s/it]

epoch: 141, train_loss: 0.0684, train_accuracy: 97.78, test_loss: 0.4222, test_accuracy: 88.33, time: 12.66




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 71%|███████   | 142/200 [31:15<12:19, 12.74s/it]

epoch: 142, train_loss: 0.0671, train_accuracy: 97.89, test_loss: 0.4019, test_accuracy: 88.64, time: 12.52




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 72%|███████▏  | 143/200 [31:29<12:20, 12.99s/it]

epoch: 143, train_loss: 0.0681, train_accuracy: 97.81, test_loss: 0.4146, test_accuracy: 88.28, time: 13.56




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 72%|███████▏  | 144/200 [31:41<12:02, 12.90s/it]

epoch: 144, train_loss: 0.0692, train_accuracy: 97.69, test_loss: 0.4262, test_accuracy: 88.00, time: 12.69




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 72%|███████▎  | 145/200 [31:54<11:45, 12.83s/it]

epoch: 145, train_loss: 0.0660, train_accuracy: 97.86, test_loss: 0.4373, test_accuracy: 87.68, time: 12.66




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 73%|███████▎  | 146/200 [32:06<11:29, 12.76s/it]

epoch: 146, train_loss: 0.0668, train_accuracy: 97.86, test_loss: 0.4157, test_accuracy: 88.34, time: 12.61




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 74%|███████▎  | 147/200 [32:20<11:27, 12.97s/it]

epoch: 147, train_loss: 0.0656, train_accuracy: 97.89, test_loss: 0.4219, test_accuracy: 88.37, time: 13.44




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 74%|███████▍  | 148/200 [32:33<11:13, 12.96s/it]

epoch: 148, train_loss: 0.0690, train_accuracy: 97.82, test_loss: 0.4677, test_accuracy: 87.10, time: 12.95




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 74%|███████▍  | 149/200 [32:47<11:13, 13.21s/it]

epoch: 149, train_loss: 0.0671, train_accuracy: 97.90, test_loss: 0.4300, test_accuracy: 88.20, time: 13.79




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 75%|███████▌  | 150/200 [33:00<11:03, 13.28s/it]

epoch: 150, train_loss: 0.0671, train_accuracy: 97.89, test_loss: 0.4232, test_accuracy: 88.41, time: 13.44




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 76%|███████▌  | 151/200 [33:13<10:39, 13.06s/it]

epoch: 151, train_loss: 0.0636, train_accuracy: 97.91, test_loss: 0.4230, test_accuracy: 88.22, time: 12.54




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 76%|███████▌  | 152/200 [33:25<10:23, 12.99s/it]

epoch: 152, train_loss: 0.0642, train_accuracy: 97.92, test_loss: 0.4122, test_accuracy: 88.38, time: 12.82




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 76%|███████▋  | 153/200 [33:38<10:09, 12.96s/it]

epoch: 153, train_loss: 0.0650, train_accuracy: 97.90, test_loss: 0.4209, test_accuracy: 88.31, time: 12.89




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 77%|███████▋  | 154/200 [33:51<09:54, 12.92s/it]

epoch: 154, train_loss: 0.0661, train_accuracy: 97.92, test_loss: 0.4259, test_accuracy: 87.98, time: 12.84




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 78%|███████▊  | 155/200 [34:04<09:38, 12.85s/it]

epoch: 155, train_loss: 0.0623, train_accuracy: 98.13, test_loss: 0.4336, test_accuracy: 87.92, time: 12.69




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 78%|███████▊  | 156/200 [34:17<09:24, 12.82s/it]

epoch: 156, train_loss: 0.0646, train_accuracy: 97.99, test_loss: 0.4110, test_accuracy: 88.57, time: 12.73




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 78%|███████▊  | 157/200 [34:30<09:18, 12.98s/it]

epoch: 157, train_loss: 0.0638, train_accuracy: 97.94, test_loss: 0.4677, test_accuracy: 87.18, time: 13.35




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 79%|███████▉  | 158/200 [34:43<09:07, 13.02s/it]

epoch: 158, train_loss: 0.0624, train_accuracy: 98.05, test_loss: 0.4282, test_accuracy: 87.86, time: 13.13




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 80%|███████▉  | 159/200 [34:56<08:56, 13.08s/it]

epoch: 159, train_loss: 0.0642, train_accuracy: 98.06, test_loss: 0.4125, test_accuracy: 88.54, time: 13.21




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 80%|████████  | 160/200 [35:10<08:44, 13.12s/it]

epoch: 160, train_loss: 0.0623, train_accuracy: 98.03, test_loss: 0.4332, test_accuracy: 88.02, time: 13.20




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 80%|████████  | 161/200 [35:23<08:32, 13.15s/it]

epoch: 161, train_loss: 0.0645, train_accuracy: 97.94, test_loss: 0.4297, test_accuracy: 88.34, time: 13.23




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 81%|████████  | 162/200 [35:35<08:15, 13.03s/it]

epoch: 162, train_loss: 0.0610, train_accuracy: 98.08, test_loss: 0.4666, test_accuracy: 87.66, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 82%|████████▏ | 163/200 [35:49<08:05, 13.12s/it]

epoch: 163, train_loss: 0.0612, train_accuracy: 98.04, test_loss: 0.4404, test_accuracy: 88.07, time: 13.32




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 82%|████████▏ | 164/200 [36:02<07:54, 13.19s/it]

epoch: 164, train_loss: 0.0642, train_accuracy: 97.92, test_loss: 0.4295, test_accuracy: 87.97, time: 13.36




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 82%|████████▎ | 165/200 [36:15<07:38, 13.09s/it]

epoch: 165, train_loss: 0.0611, train_accuracy: 98.07, test_loss: 0.4219, test_accuracy: 88.27, time: 12.86




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}



Training:  87%|████████▋ | 305/351 [00:10<00:01, 29.12it/s]

{'_total_sparsity': Array(0.49755895, dtype=float32)}


 83%|████████▎ | 166/200 [36:28<07:21, 12.99s/it]

epoch: 166, train_loss: 0.0625, train_accuracy: 98.10, test_loss: 0.4366, test_accuracy: 88.06, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 84%|████████▎ | 167/200 [36:41<07:08, 12.97s/it]

epoch: 167, train_loss: 0.0606, train_accuracy: 98.13, test_loss: 0.4278, test_accuracy: 88.19, time: 12.94




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 84%|████████▍ | 168/200 [36:54<06:56, 13.02s/it]

epoch: 168, train_loss: 0.0598, train_accuracy: 98.18, test_loss: 0.4510, test_accuracy: 87.66, time: 13.14




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 84%|████████▍ | 169/200 [37:07<06:46, 13.13s/it]

epoch: 169, train_loss: 0.0608, train_accuracy: 98.04, test_loss: 0.4413, test_accuracy: 87.95, time: 13.37




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 85%|████████▌ | 170/200 [37:20<06:29, 12.99s/it]

epoch: 170, train_loss: 0.0594, train_accuracy: 98.10, test_loss: 0.4526, test_accuracy: 87.64, time: 12.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 86%|████████▌ | 171/200 [37:33<06:14, 12.93s/it]

epoch: 171, train_loss: 0.0586, train_accuracy: 98.20, test_loss: 0.4230, test_accuracy: 88.38, time: 12.77




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 86%|████████▌ | 172/200 [37:45<05:59, 12.82s/it]

epoch: 172, train_loss: 0.0559, train_accuracy: 98.26, test_loss: 0.4354, test_accuracy: 88.26, time: 12.58




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 86%|████████▋ | 173/200 [37:58<05:44, 12.75s/it]

epoch: 173, train_loss: 0.0540, train_accuracy: 98.35, test_loss: 0.4546, test_accuracy: 87.67, time: 12.59




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 87%|████████▋ | 174/200 [38:11<05:34, 12.86s/it]

epoch: 174, train_loss: 0.0565, train_accuracy: 98.25, test_loss: 0.4260, test_accuracy: 88.44, time: 13.09




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 88%|████████▊ | 175/200 [38:25<05:31, 13.25s/it]

epoch: 175, train_loss: 0.0548, train_accuracy: 98.40, test_loss: 0.4299, test_accuracy: 88.40, time: 14.19




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 88%|████████▊ | 176/200 [38:39<05:23, 13.47s/it]

epoch: 176, train_loss: 0.0574, train_accuracy: 98.22, test_loss: 0.4301, test_accuracy: 88.26, time: 13.97




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 88%|████████▊ | 177/200 [38:52<05:04, 13.22s/it]

epoch: 177, train_loss: 0.0549, train_accuracy: 98.30, test_loss: 0.4255, test_accuracy: 88.77, time: 12.65




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 89%|████████▉ | 178/200 [39:05<04:48, 13.10s/it]

epoch: 178, train_loss: 0.0565, train_accuracy: 98.27, test_loss: 0.4328, test_accuracy: 88.23, time: 12.80




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 90%|████████▉ | 179/200 [39:17<04:33, 13.02s/it]

epoch: 179, train_loss: 0.0562, train_accuracy: 98.30, test_loss: 0.4431, test_accuracy: 87.69, time: 12.86




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 90%|█████████ | 180/200 [39:31<04:24, 13.22s/it]

epoch: 180, train_loss: 0.0545, train_accuracy: 98.37, test_loss: 0.4329, test_accuracy: 88.40, time: 13.66




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 90%|█████████ | 181/200 [39:44<04:08, 13.09s/it]

epoch: 181, train_loss: 0.0549, train_accuracy: 98.30, test_loss: 0.4499, test_accuracy: 87.76, time: 12.78




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 91%|█████████ | 182/200 [39:57<03:57, 13.22s/it]

epoch: 182, train_loss: 0.0542, train_accuracy: 98.33, test_loss: 0.4402, test_accuracy: 87.95, time: 13.53




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 92%|█████████▏| 183/200 [40:11<03:45, 13.25s/it]

epoch: 183, train_loss: 0.0551, train_accuracy: 98.34, test_loss: 0.4408, test_accuracy: 88.00, time: 13.32




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 92%|█████████▏| 184/200 [40:23<03:29, 13.09s/it]

epoch: 184, train_loss: 0.0534, train_accuracy: 98.42, test_loss: 0.4368, test_accuracy: 88.05, time: 12.71




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 92%|█████████▎| 185/200 [40:36<03:15, 13.01s/it]

epoch: 185, train_loss: 0.0522, train_accuracy: 98.42, test_loss: 0.4394, test_accuracy: 87.96, time: 12.83




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 93%|█████████▎| 186/200 [40:49<03:02, 13.00s/it]

epoch: 186, train_loss: 0.0542, train_accuracy: 98.37, test_loss: 0.4323, test_accuracy: 88.42, time: 12.98




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 94%|█████████▎| 187/200 [41:02<02:48, 12.93s/it]

epoch: 187, train_loss: 0.0524, train_accuracy: 98.41, test_loss: 0.4551, test_accuracy: 87.83, time: 12.76




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 94%|█████████▍| 188/200 [41:15<02:34, 12.88s/it]

epoch: 188, train_loss: 0.0537, train_accuracy: 98.37, test_loss: 0.4202, test_accuracy: 88.63, time: 12.76




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 94%|█████████▍| 189/200 [41:28<02:21, 12.86s/it]

epoch: 189, train_loss: 0.0549, train_accuracy: 98.30, test_loss: 0.4348, test_accuracy: 88.03, time: 12.82




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 95%|█████████▌| 190/200 [41:41<02:09, 12.99s/it]

epoch: 190, train_loss: 0.0515, train_accuracy: 98.49, test_loss: 0.4343, test_accuracy: 88.39, time: 13.28




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 96%|█████████▌| 191/200 [41:54<01:56, 13.00s/it]

epoch: 191, train_loss: 0.0538, train_accuracy: 98.40, test_loss: 0.4181, test_accuracy: 88.19, time: 13.03




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 96%|█████████▌| 192/200 [42:07<01:44, 13.05s/it]

epoch: 192, train_loss: 0.0553, train_accuracy: 98.38, test_loss: 0.4348, test_accuracy: 88.36, time: 13.16




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 96%|█████████▋| 193/200 [42:20<01:30, 12.93s/it]

epoch: 193, train_loss: 0.0516, train_accuracy: 98.50, test_loss: 0.4280, test_accuracy: 88.40, time: 12.67




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 97%|█████████▋| 194/200 [42:32<01:17, 12.88s/it]

epoch: 194, train_loss: 0.0537, train_accuracy: 98.37, test_loss: 0.4363, test_accuracy: 88.37, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 98%|█████████▊| 195/200 [42:45<01:03, 12.79s/it]

epoch: 195, train_loss: 0.0542, train_accuracy: 98.38, test_loss: 0.4189, test_accuracy: 88.40, time: 12.60




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 98%|█████████▊| 196/200 [42:58<00:50, 12.75s/it]

epoch: 196, train_loss: 0.0519, train_accuracy: 98.49, test_loss: 0.4394, test_accuracy: 87.95, time: 12.65




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 98%|█████████▊| 197/200 [43:10<00:38, 12.75s/it]

epoch: 197, train_loss: 0.0546, train_accuracy: 98.28, test_loss: 0.4375, test_accuracy: 87.97, time: 12.74




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


 99%|█████████▉| 198/200 [43:24<00:26, 13.10s/it]

epoch: 198, train_loss: 0.0527, train_accuracy: 98.43, test_loss: 0.4436, test_accuracy: 88.00, time: 13.93




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


100%|█████████▉| 199/200 [43:37<00:13, 13.01s/it]

epoch: 199, train_loss: 0.0549, train_accuracy: 98.32, test_loss: 0.4570, test_accuracy: 87.67, time: 12.78




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}




{'_total_sparsity': Array(0.49755895, dtype=float32)}


100%|██████████| 200/200 [43:50<00:00, 13.15s/it]

epoch: 200, train_loss: 0.0536, train_accuracy: 98.37, test_loss: 0.4279, test_accuracy: 88.48, time: 12.65





# One Shot pruning

In [None]:
config.sparsity_config.update_start_step = 0
config.sparsity_config.update_end_step = 0
config.sparsity_config.skip_gradients=True
state = train_and_evaluate(config)

# STE

In [None]:
# STE also supports gradual pruning schedules. 
# Here we train weights with sparse forward pass from the start.
config.sparsity_config.algorithm = 'magnitude_ste'
config.sparsity_config.sparsity = 0.95
config.sparsity_config.update_end_step = 0
config.sparsity_config.update_start_step = 0
config.sparsity_config.dist_type = 'erk'
state = train_and_evaluate(config)

# Global Pruning

In [None]:
config.sparsity_config.algorithm = 'global_magnitude'
config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 200
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

state = train_and_evaluate(config)

In [None]:
jaxpruner.summarize_sparsity(state.opt_state.masks)

# Dynamic Sparse Training

In [None]:
config.sparsity_config.algorithm = 'rigl'

config.sparsity_config.update_freq = 10
config.sparsity_config.update_end_step = 1000
config.sparsity_config.update_start_step = 1
config.sparsity_config.sparsity = 0.95
config.sparsity_config.dist_type = 'erk'

state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'set'
state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'static_sparse'

state = train_and_evaluate(config)

# Pruning After Training

In [None]:
config.sparsity_config.algorithm = 'no_prune'
state = train_and_evaluate(config)

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity = 0.9
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, _ = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

# N:M sparsity

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity_type = 'nm_1,4'
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, masks = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
masks['Dense_0']['kernel'][0][:16]

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

# Block Sparsity

In [None]:
config.sparsity_config.algorithm = 'magnitude'
config.sparsity_config.sparsity = 0.7
config.sparsity_config.sparsity_type = 'block_2,2'
sparsity_updater = jaxpruner.create_updater_from_config(config.sparsity_config)
pruned_params, masks = sparsity_updater.instant_sparsify(state.params)
print(jaxpruner.summarize_sparsity(pruned_params, only_total_sparsity=True))

In [None]:
masks['Dense_0']['kernel'][:4, :16]

In [None]:
_, test_ds = get_datasets()
pruned_state = state.replace(params=pruned_params)
_, _, test_accuracy  = apply_model(pruned_state, test_ds['image'], test_ds['label'])
print(test_accuracy*100)

In [11]:
from typing import Any
from collections import defaultdict
from flax.training import train_state, checkpoints
import time



In [None]:
class TrainerModule:

    def __init__(self,
                 model_name : str,
                 model_class : nn.Module,
                 model_hparams : dict,
                 optimizer_name : str,
                 optimizer_hparams : dict,
                 exmp_imgs : Any,
                 seed=42):
        """
        Module for summarizing all training functionalities for classification on CIFAR10.

        Inputs:
            model_name - String of the class name, used for logging and saving
            model_class - Class implementing the neural network
            model_hparams - Hyperparameters of the model, used as input to model constructor
            optimizer_name - String of the optimizer name, supporting ['sgd', 'adam', 'adamw']
            optimizer_hparams - Hyperparameters of the optimizer, including learning rate as 'lr'
            exmp_imgs - Example imgs, used as input to initialize the model
            seed - Seed to use in the model initialization
        """
        super().__init__()
        self.model_name = model_name
        self.model_class = model_class
        self.model_hparams = model_hparams
        self.optimizer_name = optimizer_name
        self.optimizer_hparams = optimizer_hparams
        self.seed = seed
        # Create empty model. Note: no parameters yet
        self.model = res20()
        # Prepare logging
        self.log_dir = os.path.join('./', self.model_name)
        self.logger = SummaryWriter(log_dir=self.log_dir)
        # Create jitted training and eval functions
        self.create_functions()
        # Initialize model
        self.init_model(exmp_imgs)

    def create_functions(self):
        # Function to calculate the classification loss and accuracy for a model
        def calculate_loss(params, batch_stats, batch, train):
            imgs, labels = batch
            # Run model. During training, we need to update the BatchNorm statistics.
            outs = self.model.apply({'params': params, 'batch_stats': batch_stats},
                                    imgs,
                                    train=train,
                                    mutable=['batch_stats'] if train else False)
            logits, new_model_state = outs if train else (outs, None)
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            acc = (logits.argmax(axis=-1) == labels).mean()
            return loss, (acc, new_model_state)
        # Training function
        def train_step(state, batch):
            loss_fn = lambda params: calculate_loss(params, state.batch_stats, batch, train=True)
            # Get loss, gradients for loss, and other outputs of loss function
            ret, grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
            loss, acc, new_model_state = ret[0], *ret[1]
            # Update parameters and batch statistics
            state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
            return state, loss, acc
        # Eval function
        def eval_step(state, batch):
            # Return the accuracy for a single batch
            _, (acc, _) = calculate_loss(state.params, state.batch_stats, batch, train=False)
            return acc
        # jit for efficiency
        self.train_step = jax.jit(train_step)
        self.eval_step = jax.jit(eval_step)

    def init_model(self, exmp_imgs):
        # Initialize model
        init_rng = jax.random.PRNGKey(self.seed)
        variables = self.model.init(init_rng, exmp_imgs, train=True)
        self.init_params, self.init_batch_stats = variables['params'], variables['batch_stats']
        self.state = None

    def init_optimizer(self, num_epochs, num_steps_per_epoch):
        # Initialize learning rate schedule and optimizer
        if self.optimizer_name.lower() == 'adam':
            opt_class = optax.adam
        elif self.optimizer_name.lower() == 'adamw':
            opt_class = optax.adamw
        elif self.optimizer_name.lower() == 'sgd':
            opt_class = optax.sgd
        else:
            assert False, f'Unknown optimizer "{opt_class}"'
        # We decrease the learning rate by a factor of 0.1 after 60% and 85% of the training
        lr_schedule = optax.piecewise_constant_schedule(
            init_value=self.optimizer_hparams.pop('lr'),
            boundaries_and_scales=
                {int(num_steps_per_epoch*num_epochs*0.6): 0.1,
                 int(num_steps_per_epoch*num_epochs*0.85): 0.1}
        )
        # Clip gradients at max value, and evt. apply weight decay
        # transf = [optax.clip(1.0)]
        if opt_class == optax.sgd and 'weight_decay' in self.optimizer_hparams:  # wd is integrated in adamw
            transf.append(optax.add_decayed_weights(self.optimizer_hparams.pop('weight_decay')))
        optimizer = optax.chain(
            *transf,
            opt_class(lr_schedule, **self.optimizer_hparams)
        )
        # Initialize training state
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=self.init_params if self.state is None else self.state.params,
                                       batch_stats=self.init_batch_stats if self.state is None else self.state.batch_stats,
                                       tx=optimizer)

    def train_model(self, train_loader, val_loader, num_epochs=200):
        # Train model for defined number of epochs
        # We first need to create optimizer and the scheduler for the given number of epochs
        self.init_optimizer(num_epochs, len(train_loader))
        # Track best eval accuracy
        best_eval = 0.0
        for epoch_idx in tqdm(range(1, num_epochs+1)):
            self.train_epoch(train_loader, epoch=epoch_idx)
            if epoch_idx % 2 == 0:
                eval_acc = self.eval_model(val_loader)
                self.logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
                if eval_acc >= best_eval:
                    best_eval = eval_acc
                    self.save_model(step=epoch_idx)
                self.logger.flush()

    def train_epoch(self, train_loader, epoch):
        # Train model for one epoch, and log avg loss and accuracy
        metrics = defaultdict(list)
        for batch in tqdm(train_loader, desc='Training', leave=False):
            self.state, loss, acc = self.train_step(self.state, batch)
            metrics['loss'].append(loss)
            metrics['acc'].append(acc)
        for key in metrics:
            avg_val = np.stack(jax.device_get(metrics[key])).mean()
            self.logger.add_scalar('train/'+key, avg_val, global_step=epoch)

    def eval_model(self, data_loader):
        # Test model on all images of a data loader and return avg loss
        correct_class, count = 0, 0
        for batch in data_loader:
            acc = self.eval_step(self.state, batch)
            correct_class += acc * batch[0].shape[0]
            count += batch[0].shape[0]
        eval_acc = (correct_class / count).item()
        return eval_acc

    def save_model(self, step=0):
        # Save current model at certain training iteration
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target={'params': self.state.params,
                                            'batch_stats': self.state.batch_stats},
                                    step=step,
                                   overwrite=True)

    def load_model(self, pretrained=False):
        # Load model. We use different checkpoint for pretrained models
        if not pretrained:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir, target=None)
        else:
            state_dict = checkpoints.restore_checkpoint(ckpt_dir=os.path.join('./', f'{self.model_name}.ckpt'), target=None)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=state_dict['params'],
                                       batch_stats=state_dict['batch_stats'],
                                       tx=self.state.tx if self.state else optax.sgd(0.1)   # Default optimizer
                                      )

    def checkpoint_exists(self):
        # Check whether a pretrained model exist for this autoencoder
        return os.path.isfile(os.path.join('./', f'{self.model_name}.ckpt'))

In [None]:
class TrainState(train_state.TrainState):
    # A simple extension of TrainState to also include batch statistics
    batch_stats: Any

In [None]:
def train_classifier(*args, num_epochs=200, **kwargs):
    # Create a trainer module with specified hyperparameters
    trainer = TrainerModule(*args, **kwargs)
    if not trainer.checkpoint_exists():  # Skip training if pretrained model exists
        trainer.train_model(train_loader, val_loader, num_epochs=num_epochs)
        trainer.load_model()
    else:
        trainer.load_model(pretrained=True)
    # Test trained model
    val_acc = trainer.eval_model(val_loader)
    test_acc = trainer.eval_model(test_loader)
    return trainer, {'val': val_acc, 'test': test_acc}

In [None]:
resnet_kernel_init = nn.initializers.variance_scaling(2.0, mode='fan_out', distribution='normal')

class ResNetBlock(nn.Module):
    act_fn : callable  # Activation function
    c_out : int   # Output feature size
    subsample : bool = False  # If True, we apply a stride inside F

    @nn.compact
    def __call__(self, x, train=True):
        # Network representing F
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    strides=(1, 1) if not self.subsample else (2, 2),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(x)
        z = nn.BatchNorm()(z, use_running_average=not train)
        z = self.act_fn(z)
        z = nn.Conv(self.c_out, kernel_size=(3, 3),
                    kernel_init=resnet_kernel_init,
                    use_bias=False)(z)
        z = nn.BatchNorm()(z, use_running_average=not train)

        if self.subsample:
            x = nn.Conv(self.c_out, kernel_size=(1, 1), strides=(2, 2), kernel_init=resnet_kernel_init)(x)

        x_out = self.act_fn(z + x)
        return x_out

In [None]:
class ResNet(nn.Module):
    num_classes : int
    act_fn : callable
    block_class : nn.Module
    num_blocks : tuple = (3, 3, 3)
    c_hidden : tuple = (16, 32, 64)

    @nn.compact
    def __call__(self, x, train=True):
        # A first convolution on the original image to scale up the channel size
        x = nn.Conv(self.c_hidden[0], kernel_size=(3, 3), kernel_init=resnet_kernel_init, use_bias=False)(x)
        if self.block_class == ResNetBlock:  # If pre-activation block, we do not apply non-linearities yet
            x = nn.BatchNorm()(x, use_running_average=not train)
            x = self.act_fn(x)

        # Creating the ResNet blocks
        for block_idx, block_count in enumerate(self.num_blocks):
            for bc in range(block_count):
                # Subsample the first block of each group, except the very first one.
                subsample = (bc == 0 and block_idx > 0)
                # ResNet block
                x = self.block_class(c_out=self.c_hidden[block_idx],
                                     act_fn=self.act_fn,
                                     subsample=subsample)(x, train=train)

        # Mapping to classification output
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x

In [None]:
resnet_trainer, resnet_results = train_classifier(model_name="ResNet",
                                                  model_class=ResNet,
                                                  model_hparams={"num_classes": 10,
                                                                 "c_hidden": (16, 32, 64),
                                                                 "num_blocks": (3, 3, 3),
                                                                 "act_fn": nn.relu,
                                                                 "block_class": ResNetBlock},
                                                  optimizer_name="SGD",
                                                  optimizer_hparams={"lr": 0.1,
                                                                     "momentum": 0.9,
                                                                     "weight_decay": 1e-4},
                                                  exmp_imgs=jax.device_put(
                                                      next(iter(train_loader))[0]),
                                                  num_epochs=200)


In [None]:
resnet_results