# Installing packages

In [1]:
# !pip install porespy
# !pip install pypardiso
# !pip install ipympl

# Loading libraries

In [3]:
# GPU settings
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

# Computation settings
import jax
from jax import lax, random, config, numpy as jnp
config.update('jax_enable_x64', True)   # for double precision, but cannot run convolution!
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.lax import with_sharding_constraint
import optax
import orbax
import porespy as ps
import numpy as np

# Dataset settings
import tensorflow as tf
import tensorflow_datasets as tfds

# Plotting settings
import matplotlib
# from google.colab import output
# output.enable_custom_widget_manager()
%matplotlib ipympl
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
from matplotlib import cm

# ETC
import pickle
import shutil
import time
import datetime
from functools import partial
from typing import Callable, Any
from pprint import pprint
from tqdm.notebook import tqdm

# Hyperparams

In [4]:
# DNN
width = 784
depth = 2
target_dim = 10
num_epochs = 100
nonlinearity = 'relu'

# Datset
batch_size = 100
minibatch_size = None
default_outer_batch_size = 100

# Plotting
phase_space = 'paraminit_vs_lr'
mnmx = [-3, 6, -3, 6]
default_resolution = 1024
dpi = 100
figsize = (8, 8)
interactive_gui = True

In [5]:
def canonical_name():
    """
    turn hyperparameters in the previous cell into a canonical base filename to
    use for this experimental condition
    """
    return f'zoom_sequence_width-{width}_depth-{depth}_datasetparamratio-{dataset_param_multiple}_minibatch-{minibatch_size}_nonlinearity-{nonlinearity}_phasespace-{phase_space}{"_readout-probe_point" if readout == "probe_point" else ""}'

In [7]:
if interactive_gui:
    ## interactive plotting
    # from google.colab import output
    # output.enable_custom_widget_manager()
    %matplotlib ipympl
else:
    matplotlib.use('Agg')

# Modeling

In [8]:
# Inerternal computation part
def net(theta, X):
    '''Compute the node as dot-cross.'''
    X = X.reshape((X.shape[0], -1))
    for W in theta:
        Z = jnp.dot(X, W) / jnp.sqrt(W.shape[0])
        if nonlinearity == 'tanh':
            X = jax.nn.tanh(Z * jnp.sqrt(2))
        elif nonlinearity == 'relu':
            X = jax.nn.relu(Z) * jnp.wqrt(2)
        elif nonlinearity == 'identity':
            X = Z
        else:
            assert False
    return Z / jnp.sqrt(W.shape[0])

def init(rng, width, depth):
    '''Initialize the DNN. Initial weights are defined as rng.'''
    rng = jax.random.split(rng, depth)
    theta = []
    out_width = width
    for i in range(depth):
        if i == depth - 1:
            out_width = target_dim
        W = jax.random.normal(rng[i], (width, out_width))
        theta.append(W)
    return theta

def loss(theta, X, Y):
    Z = net(theta, X)
    return jnp.mean((Z-Y) ** 2)


# Learning rate selection
def hparams_f(hparams, theta):
    lr = []
    for i, t in enumerate(theta):
        lr.append(hparams[i % len(hparams)])
    return lr

################################### CPU ########################################
# # Training step definition
# @partial(jax.jit, donate_argnums=(0, 1))
# @partial(jax.vmap, in_axes=(None, 0, 0, None, None), out_axes=(0, 0))
# def train_step(rng, theta, hparams, X, Y):
#     if phase_space == 'lr_vs_lr':
#         learning_rates = hparams_f(hparams, theta)
#     elif phase_space == 'paraminit_vs_lr':
#         new_theta = []
#         for i, t in enumerate(theta):
#             # First weight will added learning rate. It looks like ont-step updated.
#             # But this time, this learning rate is used once, new learning rate will be used other steps.
#             if i == 0:
#                 t += hparams[0]
#             new_theta += [t]
#         theta = new_theta
#         learning_rates = hparams_f([hparams[1]], theta)
#     else:
#         assert False, f"Invalid phase space '{phase_space}'. Use 'lr_vs_vs' or 'paramint_vs_lr'."

#     # Minibatch division (< batch_size)
#     # You can reduce computing cost if using minibatch.
#     # The loss will be computed using all-in batch, but the grad will be computed using some minibatch.
#     # So, updating can be weird when minibatch is very low.
#     if minibatch_size is None:
#         _loss, _grad = jax.value_and_grad(loss)(theta, X, Y)
#     else:
#         idx = jax.random.randint(rng, (minibatch_size,), 0, X.shape[0])
#         _X, _Y = X[idx], Y[idx]
#         _loss = loss(theta, X, Y)
#         _grad = jax.grad(loss)(theta, _X, _Y)

#     return jax.tree_map(lambda t, g, lr: t - lr * g, theta, _grad, learning_rates), _loss

################################################################################

# Training step definition
# @partial(jax.jit, donate_argnums=(0, 1))
@partial(jax.vmap, in_axes=(None, 0, 0, None, None), out_axes=(0, 0))
def train_step(rng, theta, hparams, X, Y):
    new_theta = []
    for i, t in enumerate(theta):
        # First weight will added learning rate. It looks like ont-step updated.
        # But this time, this learning rate is used once, new learning rate will be used other steps.
        if i == 0:
            t += hparams[0]
        new_theta += [t]
    theta = new_theta
    learning_rates = hparams_f([hparams[1]], theta)

    @partial(jax.pmap, axis_name='num_devices')
    def updates(theta, X, Y, lr):
        _loss, _grad = jax.value_and_grad(loss)(theta, X, Y)
        _grad = jax.lax.pmean(_grad, axis_name='num_devices')
        _loss = jax.lax.pmean(_loss, axis_name='num_devices')
        return jax.tree_map(lambda t, g, lr: t - lr * g, theta, _grad, learning_rates), _loss
    return updates(theta, X, Y, learning_rates)

def eval_step(theta, X, Y):
    _val = net(theta, X)
    _loss = loss(theta, X, Y)
    return _val, _loss

def train_n_test(theta, hparams, X, Y, tX, tY, num_steps, outer_batch_size=None, **kwargs):

    # Random seed and ETC
    rng = jax.random.PRNGKey(42)
    if outer_batch_size is None:
        outer_batch_size = default_outer_batch_size

    # Dataset control
    if (X is None) & (Y is None):
        if "train_ds" in kwargs.keys():
            dataset = kwargs['train_ds']
            dataset = dataset.shuffle(buffer_size=10, seed=42).batch(batch_size).prefetch(1).as_numpy_iterator()
            X = [d['image'] for d in dataset]
            Y = [d['label'] for d in dataset]
    if (tX is None) & (tY is None):
        if "test_ds" in kwargs.keys():
            dataset = kwargs['test_ds']
            dataset = dataset.shuffle(buffer_size=10, seed=42).batch(batch_size).prefetch(1).as_numpy_iterator()
            tX = [d['image'] for d in dataset]
            tY = [d['label'] for d in dataset]
    if (X is None) & (Y is None) & (tX is None) & (tY is None) & (len(kwargs.keys())==0):
        raise Exception("You should give kwargs 'train_ds=..., test_ds=...' if not giving X and Y.")

    # Theta initialization
    if theta is None:
        theta = init(rng, width, depth)

    # Cut off the learning rates as bite size
    bs = hparams.shape[0]
    if bs > outer_batch_size:
        train_loss1, test_loss1 = train_n_test(theta, hparams[:bs//2, ], X, Y, tX, tY, num_steps)
        train_loss2, test_loss2 = train_n_test(theta, hparams[bs//2:, ], X, Y, tX, tY, num_steps)
        return jnp.concatenate((train_loss1, train_loss2), axis=0), jnp.concatenate((test_loss1, test_loss2), axis=0)

    # Train session
    rng = jax.random.split(rng, num_steps)
    loss_archive = []
    _theta = jax.tree_map(lambda u: jnp.tile(u, (bs,) * len(u.shape)), theta)

    for epoch, _rng in tqdm(enumerate(rng), total=num_steps, desc='Epochs', leave=True):
        # Need to divide X and Y as each batch.
        loss_epoch = 0.
        for batch in tqdm(range(len(X)), desc='Batches', leave=False):
            _theta, _loss = train_step(_rng, _theta, hparams, X[epoch], Y[epoch])
            loss_epoch += _loss
        loss_epoch /= len(len(X))
        loss_archive.append(loss_epoch)

    # Test session
    tloss_archive = []
    _, _tloss = eval_step(_theta, tX, tY)
    tloss_archive.append(_tloss)

    return convergence_measure(jnp.stack(loss_archive, axis=-1)), convergence_measure(jnp.stack(tloss_archive, axis=-1))


# Compensator
@jax.jit
@partial(jax.vmap, in_axes=(0,), out_axes=0)
def convergence_measure(v, max_val=1e6):
    fin = jnp.isfinite(v)
    v = v * fin + max_val * (1-fin)
    v /= v[0]
    exceeds = (v > max_val)
    v = v * (1-exceeds) + max_val * exceeds
    converged = (jnp.mean(v[-20:]) < 1)
    return jnp.where(converged, -jnp.sum(v), jnp.sum(1/v))

# Plotting and Measure the Fractal Dimension

In [9]:
# Generate the lossmap
def gen_img(mnmx, resolution=None):
    """
    generate an image of the hyperparameter landscape,
    for a range of hyperparameter values specified by mnmx
    """

    if resolution is None:
        resolution = default_resolution

    mn1, mx1, mn2, mx2 = mnmx
    gg1 = jnp.logspace(mn1, mx1, resolution)
    gg2 = jnp.logspace(mn2, mx2, resolution)
    lr0, lr1 = jnp.meshgrid(gg2, gg1)
    lr = jnp.stack([lr0.ravel(), lr1.ravel()], axis=-1)

    V, tV = train_n_test(
        theta=None,
        hparams=lr,
        X=None, Y=None,
        tX=None, tY=None,
        num_steps=num_epochs,
        train_ds=train_ds,
        test_ds=test_ds
        )

    return V.reshape((resolution, resolution)), tV.reshape((resolution, resolution))


# Measure the fractal dim
def extract_edges(X):
    """
    define edges as sign changes in the scalar representing convergence or
    divergence rate -- on one side of the edge training converges,
    while on the other side of the edge training diverges
    """

    Y = jnp.stack((X[1:,1:], X[:-1,1:], X[1:,:-1], X[:-1,:-1]), axis=-1)
    Z = jnp.sign(jnp.max(Y, axis=-1)*jnp.min(Y, axis=-1))
    return Z<0

def estimate_fractal_dimension(hist_video, show_plot=True):
    edges = [extract_edges(U[0]) for U in hist_video]
    box_counts = [ps.metrics.boxcount(U) for U in edges]
    all_images = np.concatenate([bc.slope for bc in box_counts])

    if show_plot:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
        ax1.set_yscale('log')
        ax1.set_xscale('log')
        ax1.set_xlabel('box edge length')
        ax1.set_ylabel('number of boxes spanning phases')
        ax2.set_xlabel('box edge length')
        ax2.set_ylabel('image')
        ax2.set_xscale('log')

        for bc in box_counts:
            ax1.plot(bc.size, bc.count,'-o')
            ax2.plot(bc.size, bc.slope,'-o');

    mfd = np.median(all_images)
    print(f'median fractal dimension estimate {mfd}')

    return mfd


# Interploating
def cdf_img(x, x_ref, buffer=0.25):
    """
    rescale x, relative to x_ref (x_ref is often the same as x), to achieve a uniform
    distribution over values with positive and negative intensities, but also to
    preserve the sign of x. This makes for a visualization that shows more
    structure.
    """
    u = jnp.sort(x_ref.ravel())
    num_neg = jnp.sum(u<0)
    num_nonneg = u.shape[0] - num_neg
    v = jnp.concatenate((jnp.linspace(-1,-buffer,num_neg), jnp.linspace(buffer,1,num_nonneg)), axis=0)
    y = jnp.interp(x, u, v)
    return -y


# Notation
def truncate_sci_notation(numbers):
    """
    keeping enough significant digits that the
    numbers disagree in four digits
    """

    # Convert numbers to scientific notation
    n1_sci, n2_sci = "{:.15e}".format(numbers[0]), "{:.15e}".format(numbers[1])

    # Extract the significant parts and exponents
    sig_n1, exp_n1 = n1_sci.split('e')
    sig_n2, exp_n2 = n2_sci.split('e')

    # Find the first position at which they disagree
    min_len = min(len(sig_n1), len(sig_n2))
    truncate_index = min_len

    for i in range(min_len):
        if (sig_n1[i] != sig_n2[i]) or (exp_n1 != exp_n2):
            # +4 accounts for 4 digits after the first disagreement
            truncate_index = i + 4
            if i == 0:
                truncate_index += 1 # Account for decimal point
        break

    exp_n1 = exp_n1[0] + exp_n1[2]
    exp_n2 = exp_n2[0] + exp_n2[2]
    if (exp_n1 == "+00") and (exp_n2 == "+00"):
        # don't bother with scientific notation if exponent is 0
        return [sig_n1[:truncate_index], sig_n2[:truncate_index]]

    # Truncate and reconstruct the scientific notation
    truncated_n1 = "{}e{}".format(sig_n1[:truncate_index], exp_n1)
    truncated_n2 = "{}e{}".format(sig_n2[:truncate_index], exp_n2)

    return [truncated_n1, truncated_n2]

def tickslabels(mnmx):
    return mnmx, truncate_sci_notation(10.**np.array(mnmx))

In [10]:
# Interactive plotting
cids = []
click_event = [None]

def onclick(event):
  click_event[0] = (event.xdata, event.ydata)

def onrelease(event, fig, im, rect, mnmx, img, recalculate_image=True):
  if click_event[0] is None:
    return

  e0 = [click_event[0][0], event.xdata]
  e1 = [click_event[0][1], event.ydata]

  for v in e0+e1:
    if v is None:
      return

  newmnmx = [np.min(e1), np.max(e1), np.min(e0), np.max(e0)]

  min_w = (mnmx[1] - mnmx[0])/20
  if newmnmx[1] - newmnmx[0] < min_w:
    c = (newmnmx[1] + newmnmx[0])/2.
    newmnmx[0] = c - min_w/2
    newmnmx[1] = c + min_w/2
  min_w = (mnmx[3] - mnmx[2])/20
  if newmnmx[1] - newmnmx[0] < min_w:
    c = (newmnmx[3] + newmnmx[2])/2.
    newmnmx[2] = c - min_w/2
    newmnmx[3] = c + min_w/2

  for v in newmnmx:
    if v is None:
      return
  plot_img(img, mnmx, newmnmx, fig=fig, im=im, rect=rect)
  plt.draw()

  if recalculate_image:
    click_event[0] = None
    mnmx = newmnmx
    img = gen_img(mnmx)
    plot_img(img, mnmx, None, fig=fig, im=im, rect=rect)

def plot_img(image, mnmx, newmnmx=None, fig=None, im=None, rect=None,
             handler=True, savename=None,
             reference_scale=None,
             cmap='Spectral',
             title=""
             ):
  mn1, mx1, mn2, mx2 = mnmx

  if reference_scale is None:
    reference_scale = image

  image = cdf_img(image, reference_scale)

  ax1 = None
  if fig is None:
    fig, (ax1) = plt.subplots(figsize=figsize, dpi=dpi)
    im = ax1.imshow(image,
                    extent=[mn2, mx2, mn1, mx1],
                    origin='lower',
                    vmin=-1, vmax=1,
                    cmap=cmap,
                    aspect='auto',
                    interpolation='nearest'
                    )
    batch_text = f"{batch_size}" + "" if minibatch_size is None else f"(mini-{minibatch_size})"
    title = f'Trainability dependence on parameter initialization and learning rate\n1 hidden layer, {nonlinearity}, {batch_text}'
    if not title == "":
      plt.title(title)
    if phase_space == 'lr_vs_lr':
      ax1.set_ylabel('Output layer learning rate')
      ax1.set_xlabel('Input layer learning rate')
    elif phase_space == 'paraminit_vs_lr':
      ax1.set_ylabel('Learning rate')
      ax1.set_xlabel('Input layer weight offset')

    rect = patches.Rectangle((mn2, mn1), mx2-mn2, mx1-mn1, linewidth=1, edgecolor='r', facecolor='none')
    ax1.add_patch(rect)

  im.set_extent([mn2, mx2, mn1, mx1])
  im.set_data(image)

  # Set the new tick positions on the x-axis
  aaxx = plt.gca()
  aaxx.set_xticks(*tickslabels([mn2, mx2]))
  aaxx.set_yticks(*tickslabels([mn1, mx1]), rotation=90)

  labels = aaxx.get_xticklabels()
  labels[0].set_horizontalalignment('left')
  labels[1].set_horizontalalignment('right')
  labels = aaxx.get_yticklabels()
  labels[0].set_verticalalignment('bottom')
  labels[1].set_verticalalignment('top')

  if handler and (newmnmx is None):
    image_history.append((image, mnmx))

  if newmnmx:
    mn1, mx1, mn2, mx2 = newmnmx
  rect.set_xy((mn2, mn1))
  rect.set_width(mx2-mn2)
  rect.set_height(mx1-mn1)

  if handler:
    while len(cids) > 0:
      fig.canvas.mpl_disconnect(cids.pop())

    def onrelease_partial(event):
      return onrelease(event, fig, im, rect, mnmx, img)
    def onmotion_partial(event):
      return onrelease(event, fig, im, rect, mnmx, img, recalculate_image=False)

    cids.append(fig.canvas.mpl_connect('button_press_event', onclick))
    cids.append(fig.canvas.mpl_connect('button_release_event', onrelease_partial))
    # cids.append(fig.canvas.mpl_connect('motion_notify_event', onmotion_partial))

  plt.tight_layout()

  plt.draw()

  if savename:
    plt.savefig(savename)

  return fig, ax1, im


# Animating zoom sequences and interpolating between frames
def zoom_out_sequence(hist_final, growth_factor=2., max_scale=6):
  """
  generate a sequence of (image, bounds) zooming out from the (image, bounds) in hist_final
  """

  image, mnmx = hist_final

  cT = np.array([(mnmx[0] + mnmx[1])/2., (mnmx[2] + mnmx[3])/2.])
  wT = np.array([mnmx[1] - mnmx[0], mnmx[3] - mnmx[2]])

  hist = [(image, mnmx)]
  w_scale = 1.
  while np.min(wT * w_scale) < max_scale:
    w_scale *= 2
    mnmx = [
        cT[0] - w_scale * wT[0]/2.,
        cT[0] + w_scale * wT[0]/2.,
        cT[1] - w_scale * wT[1]/2.,
        cT[1] + w_scale * wT[1]/2.,
    ]
    hist.insert(0, (np.zeros((2,2)), mnmx))

  return hist

def increase_resolution(history, target_res):
  """
  Increase the resolution of images of a fractal landscape that we've already
  generated.

  Find the first entry in history with resolution below target_res, and increase
  its resolution. If all images are already at least the target resolution,
  return False.
  """

  new_h = []
  for ii in range(len(history)):
    h = history[ii]
    image, mnmx = h
    if image.shape[0] < target_res:
      current_time = datetime.datetime.now()
      print( f"increasing resolution of {ii} / {len(history)} at {current_time}, current resolution is {image.shape}")
      image = gen_img(mnmx, resolution=target_res)
      history[ii] = (image, mnmx)
      return True
  return False

def interpolate_history(hist1, hist2, alpha):
  """
  get the mnmx (hyperparameter bounding box) value for a fraction alpha between
  two images
  """

  _, mnmx1 = hist1
  _, mnmx2 = hist2

  if alpha == 0:
    # avoid NaNs on very last frame
    return mnmx1

  w1 = np.array([mnmx1[1] - mnmx1[0], mnmx1[3] - mnmx1[2]])
  w2 = np.array([mnmx2[1] - mnmx2[0], mnmx2[3] - mnmx2[2]])
  c1 = np.array([(mnmx1[0] + mnmx1[1])/2, (mnmx1[2] + mnmx1[3])/2])
  c2 = np.array([(mnmx2[0] + mnmx2[1])/2, (mnmx2[2] + mnmx2[3])/2])

  gamma = np.exp((1-alpha)*0 + alpha*np.log(w2/w1))

  # ct = cstar + (c1 - cstar)*gamma
  # c1 = cstar + (c1 - cstar)*1
  # c2 = cstar + (c1 - cstar)*w2/w1
  cstar = (c2 - c1*w2/w1) / (1 - w2 / w1)

  ct = cstar + (c1 - cstar)*gamma
  hwt = gamma*w1

  return [ct[0] - hwt[0]/2, ct[0] + hwt[0]/2, ct[1] - hwt[1]/2, ct[1] + hwt[1]/2]


def em(extent_rev):
  return [extent_rev[2], extent_rev[3], extent_rev[0], extent_rev[1]]

def make_animator(history, timesteps_per_transition=60, reference_scale=None, cmap='Spectral'):

  fig, ax, im1 = plot_img(history[0][0], history[0][1], newmnmx=None,
                          handler=False, reference_scale=reference_scale, cmap=cmap)

  im2 = ax.imshow(
      jnp.zeros_like(history[1][0]), extent=em(history[1][1]), origin='lower',
      vmin = -1, vmax = 1,
      cmap=cmap,
      aspect='auto',
      interpolation='nearest'
      )

  im3 = ax.imshow(
      jnp.zeros_like(history[1][0]), extent=em(history[1][1]), origin='lower',
      vmin = -1, vmax = 1,
      cmap=cmap,
      aspect='auto',
      interpolation='nearest'
      )

  def animate(n):
    hist_index = n // timesteps_per_transition
    alpha = (n % timesteps_per_transition) / timesteps_per_transition

    hist1 = history[hist_index]
    if hist_index >= len(history)-1:
      hist2 = hist1 # very last frame
    else:
      hist2 = history[hist_index+1]
    if hist_index >= len(history)-2:
      hist3 = hist2 # very last frame
    else:
      hist3 = history[hist_index+2]

    lims = interpolate_history(hist1, hist2, alpha)

    # interpolation scheme for image restretch / colormap
    alpha_area = jnp.sin(alpha*np.pi/2)**2

    print(f'frame {n} / {timesteps_per_transition*len(history)}, zoom step {hist_index} / {len(history)}', end='\r', flush=True)

    img_1 = (1-alpha_area)*cdf_img(hist1[0], hist1[0]) + alpha_area*cdf_img(hist1[0], hist2[0])
    img_2 = (1-alpha_area)*cdf_img(hist2[0], hist1[0]) + alpha_area*cdf_img(hist2[0], hist2[0])
    img_3 = (1-alpha_area)*cdf_img(hist3[0], hist1[0]) + alpha_area*cdf_img(hist3[0], hist2[0])

    im1.set_data(img_1)
    im1.set_extent(em(hist1[1]))
    im2.set_data(img_2)
    im2.set_extent(em(hist2[1]))
    im3.set_data(img_3)
    im3.set_extent(em(hist3[1]))
    im3.set_alpha(alpha)

    ax.set_ylim(lims[0], lims[1])
    ax.set_xlim(lims[2], lims[3])

    # Set the new tick positions
    ax.set_xticks(*tickslabels([lims[2], lims[3]]))
    ax.set_yticks(*tickslabels([lims[0], lims[1]]), rotation=90)

    labels = ax.get_xticklabels()
    labels[0].set_horizontalalignment('left')
    labels[1].set_horizontalalignment('right')
    labels = ax.get_yticklabels()
    labels[0].set_verticalalignment('bottom')
    labels[1].set_verticalalignment('top')

    return fig,

  anim = animation.FuncAnimation(fig,animate,frames=timesteps_per_transition*(len(history)-1)+1, repeat=False)
  return anim

# Generating the lossmap

In [11]:
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')

def data_normalize(ds):
    return ds.map(lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 256.,
        'label': sample['label']
    })

train_ds = data_normalize(train_ds)
test_ds = data_normalize(test_ds)

2024-03-26 08:41:31.781176: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [12]:
# Get the lossmap
train_img, test_img = gen_img(mnmx)
jnp.save('./output/npy/train_lossmap.npy', train_img)
jnp.save('./output/npy/test_lossmap.npy', test_img)


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batches:   0%|          | 0/600 [00:00<?, ?it/s]

IndexError: list index out of range

## Train lossmap

In [None]:
# Train lossmap
image_history = []
plt.close('all')
plt.ion()
plot_img(train_img, mnmx, None)
plt.show()

## WARNING!!
## After finishing to interact the plot,
## - Should do next-step immediately!!!
## - No plotting before making current frames,

In [None]:
# Test lossmap
plt.close('all')
plt.ion()
plot_img(test_img, mnmx, None)
plt.show()

---

* 내일 할거

telegram mesg

hyperparam rearr.

[4, 2, 3] [1, 2, 3]
