# Loading libraries

In [1]:
# !pip install porespy
# !pip install pypardiso

In [1]:
import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,4' 
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".XX"
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"

import jax
from jax import lax, random, config, numpy as jnp
config.update('jax_enable_x64', True)
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
from jax.lax import with_sharding_constraint

import flax
from flax import struct, traverse_util, linen as nn
from flax.core import freeze, unfreeze
from flax.training import train_state, checkpoints, orbax_utils
from flax.training.train_state import TrainState

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
from matplotlib import cm

import optax
import orbax
import porespy as ps
import numpy as np
import pickle
import shutil
import time
import datetime
import tensorflow as tf
import tensorflow_datasets as tfds
from functools import partial
from scipy import ndimage
from typing import Callable, Any
from pprint import pprint
from tqdm import tqdm

import telegram
import asyncio
from datetime import datetime
from tqdm.notebook import trange, tqdm


# Telegram credentials for run-completion notifications are read from the
# environment instead of being hard-coded in the notebook. (The token that was
# previously committed here is exposed and should be revoked.)
TOKEN = os.environ.get('TELEGRAM_BOT_TOKEN')
CHAT_ID = os.environ.get('TELEGRAM_CHAT_ID')
if not TOKEN or not CHAT_ID:
    raise RuntimeError(
        "Set the TELEGRAM_BOT_TOKEN and TELEGRAM_CHAT_ID environment variables "
        "before running this notebook."
    )

bot = telegram.Bot(TOKEN)



# Hyperparams

In [2]:
# Model's hyperparams
num_epochs = 10         # passes over the training set per grid point (previously 50)
batch_size = 100        # examples per SGD step
nonlinearity = nn.relu  # hidden activation, passed to DNN(act_fn=...)
lrbatch = 2500          # max grid points trained per chunk in train(); ~9500-10000 crashed ("BooM!")

# Plotting's hyperparams
# log10 bounds [mn1, mx1, mn2, mx2]: mn1/mx1 -> vertical axis (learning rate),
# mn2/mx2 -> horizontal axis (input-layer weight offset)
mnmx = [-3, 6, -3, 6]
resolution = 128        # grid points per axis (previously 1024)
figsize, dpi = (8, 8), 100

# Dataset loading

In [3]:
train_ds = tfds.load('mnist', split='train')
test_ds = tfds.load('mnist', split='test')


def data_normalize(ds):
    """Cast each MNIST image to float32 and divide by 256; labels pass through unchanged."""
    def _scale(sample):
        pixels = tf.cast(sample['image'], tf.float32)
        return {'image': pixels / 256., 'label': sample['label']}

    return ds.map(_scale)


train_ds, test_ds = data_normalize(train_ds), data_normalize(test_ds)

# Shuffle with a 1024-element buffer, batch, and prefetch one batch ahead.
train_ds = train_ds.shuffle(1024).batch(batch_size).prefetch(1)
test_ds = test_ds.shuffle(1024).batch(batch_size).prefetch(1)

2024-03-24 14:59:47.451083: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


# Define the Sharding

In [4]:
P = PartitionSpec
# 2x2 device mesh with axes named 'x' and 'y' (expects four available devices).
mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), axis_names=('x', 'y'))


def with_mesh(f):
    """Decorator: call `f` inside the module-level `mesh` context.

    `functools.wraps` keeps the wrapped function's name, docstring and
    `__wrapped__`, so decorated functions still look like themselves in
    reprs, tracebacks and `inspect.signature`.
    """
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with mesh:
            return f(*args, **kwargs)
    return wrapper

# Build the model

In [5]:
# Sweep grid. Column 0 of `lrband` (from gg2, horizontal axis) is the input-layer
# weight offset; column 1 (from gg1, vertical axis) is the SGD learning rate.
# Both axes are log-spaced from 10**min to 10**max of the `mnmx` bounds.
mn1, mx1, mn2, mx2 = mnmx
gg1 = jnp.logspace(mn1, mx1, resolution)
gg2 = jnp.logspace(mn2, mx2, resolution)
lr0, lr1 = jnp.meshgrid(gg2, gg1, indexing='xy')
lrband = jnp.column_stack((lr0.ravel(), lr1.ravel()))

In [6]:
class DNN(nn.Module):
    """Two-layer MLP without biases used for the trainability sweep.

    Flattens every example, then Dense(width) -> act_fn -> Dense(10) -> act_fn.
    Submodules are auto-named in creation order ('Dense_0', 'Dense_1') and the
    training code looks parameters up by those names, so the layer order must
    not change.
    """

    width = 28*28       # hidden width (784); plain class attribute, not a dataclass field
    use_bias = True     # NOTE(review): unused - both Dense layers pass use_bias=False
    act_fn: Callable    # activation function, e.g. nn.relu

    @nn.compact
    def __call__(self, x):
        # print(x.shape)
        x = x.reshape((x.shape[0], -1))     # Normal mode: flatten all but the batch axis
        # x = x.reshape((x.shape[0] * x.shape[1], ))     # Vmap mode

        x = nn.Dense(self.width, use_bias=False)(x)
        x = self.act_fn(x)
        x = nn.Dense(10, use_bias=False)(x)
        # NOTE(review): the activation is also applied to the output logits
        # (with ReLU every negative logit becomes 0 before the softmax
        # cross-entropy). The BatchNorm variant later in the notebook omits
        # this - confirm which behaviour is intended.
        x = self.act_fn(x)
        # x = nn.BatchNorm(use_running_average=not train)(x)
        return x
        # return x / jnp.sqrt(self.width)


# 'act_fn' can be 'nn.relu'(ReLU), 'nn.activation.tanh'(tanh) or 'lambda x: x'(identity).
dnn = DNN(act_fn=nonlinearity)

# Initialise once on a dummy batch to inspect the parameter shapes.
x = jnp.zeros((batch_size, 28, 28, 1))
rng = jax.random.PRNGKey(42)
variables = dnn.init(rng, x)
# jax.tree_map is deprecated (and removed in newer JAX); use jax.tree_util.tree_map.
pprint(jax.tree_util.tree_map(jnp.shape, variables))



{'params': {'Dense_0': {'kernel': (784, 784)},
            'Dense_1': {'kernel': (784, 10)}}}


In [7]:
# Just one img
@jax.jit
def train_step(state, lrband, batch):
    """One SGD step on `batch` after shifting the input-layer kernel by `lrband[0]`.

    Args:
        state: flax TrainState of the first DNN variant (params 'Dense_0', 'Dense_1').
        lrband: one grid point `[offset, learning_rate]`; only `lrband[0]` is used here.
        batch: dict with 'image' and integer 'label' arrays.

    Returns:
        (new_state, mean softmax cross-entropy loss on `batch`).
    """
    # NOTE(review): the offset is added on *every* call, so after k steps the
    # kernel has been shifted by k * lrband[0] on top of the SGD updates. If a
    # one-off shift of the initial weights is intended, apply it once at
    # initialisation instead - TODO confirm.
    state = state.replace(params={
        'Dense_0': {'kernel': jnp.asarray(state.params['Dense_0']['kernel']) + jnp.asarray(lrband[0])},
        'Dense_1': {'kernel': state.params['Dense_1']['kernel']},
    })  # new theta

    def loss_fn(params):
        # Use the differentiated argument `params`. The earlier version read
        # `state.params`, which made the loss constant w.r.t. the argument and
        # therefore every gradient zero (SGD never moved the weights).
        logits = state.apply_fn({'params': params}, x=batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()    # mean over the batch
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, _logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

@jax.jit
def eval_step(state, batch):
    """Mean cross-entropy loss of `state` on `batch`, without updating parameters.

    Works for both model variants in this notebook. If `state` has a
    `batch_stats` field (BatchNorm model), the stored running statistics are
    used with `train=False`. Otherwise the plain model is applied to the
    parameters alone; its `__call__` takes no `train` argument.

    Returns:
        (state, loss) - `state` is returned unchanged.
    """
    batch_stats = getattr(state, 'batch_stats', None)
    if batch_stats is None:
        logits = state.apply_fn({'params': state.params}, x=batch['image'])
    else:
        logits = state.apply_fn(
            {'params': state.params, 'batch_stats': batch_stats},
            x=batch['image'],
            train=False
        )
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['label']
    ).mean()
    return state, loss

@partial(jax.jit, static_argnums=(0,))
@partial(jax.vmap, in_axes=(None, 0))
def ready_to_train(num_epochs, lrband):
    """Train one fresh model per row of `lrband` and record its per-epoch loss.

    Under the vmap each call sees one row `[offset, learning_rate]`.
    `lrband[0]` is passed to `train_step`, which adds it to the input-layer
    kernel. `lrband[1]` is the SGD learning rate.

    Args:
        num_epochs: number of passes over `train_ds` (static under jit).
        lrband: array of shape (num_points, 2), mapped over axis 0.

    Returns:
        Array of shape (num_points, num_epochs): mean training loss per epoch.
    """
    # Fixed seed, matching the rest of the notebook. The previous
    # PRNGKey(int(time.time())) was evaluated only once, at trace time, and then
    # frozen into the compiled function, so runs were not reproducible.
    variables = dnn.init(jax.random.PRNGKey(42), jnp.ones((batch_size, 28, 28, 1)))
    state = TrainState.create(      # The states are vmapped by lrband.
        apply_fn=dnn.apply,
        params=variables['params'],
        tx=optax.sgd(lrband[1])     # learning rate for this grid point
    )

    # NOTE: this Python loop over the tf.data pipeline runs at trace time, so
    # every batch of every epoch is unrolled into one XLA program; this is what
    # makes compiling this function very slow.
    num_batches = train_ds.cardinality().numpy()
    loss_archive = []
    for _ in range(num_epochs):
        epoch_loss = 0.
        for batch in train_ds.as_numpy_iterator():
            state, loss = train_step(state, lrband, batch)
            epoch_loss += loss
        loss_archive.append(epoch_loss / num_batches)

    return jnp.array(loss_archive)


@jax.jit
@partial(jax.vmap, in_axes=(0,), out_axes=0)
def convergence_measure_vmap(v, max_val=1e6):
    '''Score each loss curve (one row of `v`) as converged or diverged.

    Each curve is normalised by its first value. Non-finite entries and values
    above `max_val` are replaced by `max_val`. A curve counts as converged when
    the mean of its last (up to) 20 normalised values is below 1.

    Returns:
        Per row: -sum(normalised curve) if converged (negative for positive
        losses), otherwise sum(1 / normalised curve) (positive).
    '''
    # Replace NaN/inf before normalising. The earlier `v*fin + max_val*(1-fin)`
    # masking could not do this: inf*0 and nan*0 are both nan, so diverged runs
    # produced NaN scores instead of large finite ones.
    v = jnp.where(jnp.isfinite(v), v, max_val)

    v = v / v[0]
    # Clamp after normalising; this also absorbs inf/nan produced when v[0] == 0.
    v = jnp.where(v <= max_val, v, max_val)

    converged = (jnp.mean(v[-20:]) < 1)
    return jnp.where(converged, -jnp.sum(v), jnp.sum(1/v))


def train(num_epochs, lrband, lrbatch=100):
    """Run `ready_to_train` over `lrband`, splitting it to bound memory.

    While the band has more than `lrbatch` rows it is halved recursively, and
    the two halves are trained in order. Their results are concatenated along
    the row axis, so row order matches `lrband`.
    """
    n_rows = lrband.shape[0]
    if n_rows <= lrbatch:
        return ready_to_train(num_epochs, lrband)

    print(f"[WARNING!!] The band of learning rate is too long. It will be cut. Current length: {n_rows} -> {n_rows//2}")
    half = n_rows // 2
    first = jnp.array(train(num_epochs, lrband[:half], lrbatch))
    second = jnp.array(train(num_epochs, lrband[half:], lrbatch))
    return jnp.concatenate((first, second), axis=0)


def gen_img(saveas=None):
    """Train over the global `lrband` grid and return the convergence image.

    Args:
        saveas: optional .npy path. When given, the image is saved there; any
            missing parent directories are created first, so a long run is not
            lost to a FileNotFoundError at the very end.

    Returns:
        Array of shape (resolution, resolution) of convergence scores.
        Negative entries mark converged grid points and positive entries
        diverged ones (see convergence_measure_vmap).
    """
    losses = train(num_epochs, lrband, lrbatch)
    # `losses` is already (num_points, num_epochs); asarray avoids a redundant re-stack.
    V = convergence_measure_vmap(jnp.asarray(losses))
    V = V.reshape((resolution, resolution))
    if saveas is not None:
        out_dir = os.path.dirname(saveas)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        jnp.save(saveas, V)
    return V

In [8]:
# losses = train(num_epochs, lrband, lrbatch)
# V = convergence_measure_vmap(jnp.stack(losses, axis=-1))
# V.reshape((resolution, resolution))


# train_step --> no problem
# variables = dnn.init(jax.random.PRNGKey(int(time.time())), jnp.ones((1, 28, 28, 1)))
# state = TrainState.create(      # The states are vmapped by lrband.
#     apply_fn=dnn.apply,
#     params=variables['params'],
#     tx=optax.sgd(lrband[0, 1])     # learning rates, applied 'hparams_f'
#     )

# for batch in train_ds.as_numpy_iterator():
#     break
# state, loss1 = train_step(state, lrband[0, :], batch)
# state, loss2 = train_step(state, lrband[0, :], batch)

# print((loss1 + loss2) / 2)

# ready_to_train --> no problem, but need to shrink 'lrbatch' threshold
# v = ready_to_train(2, lrband[:4, :])
# jnp.stack(v, axis=-1)

# train --> lrbatch: 2500 is safe! But over 5000, BooM!
# lrbatch = 2500
# v = train(2, lrband, lrbatch=lrbatch)

recoding_time = datetime.now().strftime("%Y%m%d_%H%M%S")
saveas = f'./output/npy/lossmap_DNN_epoch{num_epochs}_batch{batch_size}_actFN{nonlinearity.__name__}_resolution{resolution}_dpi{dpi}_time.npy'
V = gen_img(saveas=saveas)

from datetime import datetime
recoding_time = datetime.now().strftime("%Y.%m.%d %H:%M:%S")
msg = f"[{recoding_time}]\nLossmap was generated and saved.\nSaved as: {saveas}"
await bot.sendMessage(chat_id=CHAT_ID, text=msg)



2024-03-24 19:45:57.066085: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_ready_to_train] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


## Calculate the Fractal Dimension

In [8]:
def extract_edges(X):
    """Mark cells of `X` that lie on a converge/diverge boundary.

    For every 2x2 window of neighbouring values, the window is an edge when its
    largest and smallest values have strictly opposite signs (training
    converges on one side and diverges on the other). Zeros never form an edge.

    Returns:
        Boolean array of shape (X.shape[0] - 1, X.shape[1] - 1).
    """
    bottom_right, top_right = X[1:, 1:], X[:-1, 1:]
    bottom_left, top_left = X[1:, :-1], X[:-1, :-1]

    hi = jnp.maximum(jnp.maximum(bottom_right, top_right), jnp.maximum(bottom_left, top_left))
    lo = jnp.minimum(jnp.minimum(bottom_right, top_right), jnp.minimum(bottom_left, top_left))

    return jnp.sign(hi * lo) < 0

def estimate_fractal_dimension(hist_video, show_plot=True):
    """Estimate the fractal dimension of the converge/diverge boundary.

    For every entry of `hist_video` the boundary mask of its image is built
    with `extract_edges` and box-counted with porespy. The estimate is the
    median of all box-counting slopes across all entries.

    Args:
        hist_video: iterable of entries whose first element is a 2-D
            convergence-score image (e.g. (image, mnmx) pairs).
        show_plot: if True, plot box counts and slopes against box size.

    Returns:
        The median slope (also printed).
    """
    box_counts = []
    for entry in hist_video:
        boundary = extract_edges(entry[0])
        box_counts.append(ps.metrics.boxcount(boundary))
    slopes = np.concatenate([counts.slope for counts in box_counts])

    if show_plot:
        fig, (ax_counts, ax_slopes) = plt.subplots(1, 2, figsize=(8, 4))
        ax_counts.set_yscale('log')
        ax_counts.set_xscale('log')
        ax_counts.set_xlabel('box edge length')
        ax_counts.set_ylabel('number of boxes spanning phases')
        ax_slopes.set_xlabel('box edge length')
        ax_slopes.set_ylabel('image')
        ax_slopes.set_xscale('log')

        for counts in box_counts:
            ax_counts.plot(counts.size, counts.count, '-o')
            ax_slopes.plot(counts.size, counts.slope, '-o')

    median_dim = np.median(slopes)
    print(f'median fractal dimension estimate {median_dim}')
    return median_dim

## Image re-stretching (CDF-based color normalization)

In [9]:
def cdf_img(x, x_ref, buffer=0.25):
    """Sign-aware histogram equalisation of `x` against the values of `x_ref`.

    The sorted negative values of `x_ref` are spread evenly over
    [-1, -buffer] and its non-negative values over [buffer, 1]. `x` is mapped
    through that piecewise-linear CDF by interpolation, giving more visible
    structure than a linear colour scale. The result is *negated* before it is
    returned, so negative inputs come out positive and vice versa.
    """
    ref_sorted = jnp.sort(jnp.ravel(x_ref))

    n_neg = jnp.sum(ref_sorted < 0)
    n_nonneg = ref_sorted.shape[0] - n_neg
    targets = jnp.concatenate(
        (jnp.linspace(-1, -buffer, n_neg), jnp.linspace(buffer, 1, n_nonneg)),
        axis=0,
    )

    return -jnp.interp(x, ref_sorted, targets)

# Interactive Figure

In [10]:
def truncate_sci_notation(numbers):
    """Format two numbers as short strings that still tell them apart.

    Both numbers are written in scientific notation. Their mantissas are
    truncated to keep four characters past the first position where the two
    representations differ (or immediately, if their exponents differ).
    Exponents are written without leading zeros (e.g. "e-3", "e+15"). When
    both exponents are zero the exponent suffix is dropped entirely.

    Args:
        numbers: sequence whose first two entries are the numbers to format.

    Returns:
        List of two strings.
    """
    # Convert numbers to scientific notation
    n1_sci, n2_sci = "{:.15e}".format(numbers[0]), "{:.15e}".format(numbers[1])

    # Extract the significant parts and exponents
    sig_n1, exp_n1 = n1_sci.split('e')
    sig_n2, exp_n2 = n2_sci.split('e')

    # Find the first position at which they disagree
    min_len = min(len(sig_n1), len(sig_n2))
    truncate_index = min_len

    for i in range(min_len):
        if (sig_n1[i] != sig_n2[i]) or (exp_n1 != exp_n2):
            # +4 accounts for 4 digits after the first disagreement
            truncate_index = i + 4
            if i == 0:
                truncate_index += 1  # Account for decimal point
            break

    # Normalise exponents such as "+06", "-03", "+15", "+100" to "+6", "-3",
    # "+15", "+100". The previous `exp[0] + exp[2]` kept only the sign and the
    # third character, which dropped digits from exponents >= 10 (e+15 -> e+5)
    # and made the zero-exponent check below unreachable.
    exp_n1 = "{:+d}".format(int(exp_n1))
    exp_n2 = "{:+d}".format(int(exp_n2))
    if (exp_n1 == "+0") and (exp_n2 == "+0"):
        # don't bother with scientific notation if exponent is 0
        return [sig_n1[:truncate_index], sig_n2[:truncate_index]]

    # Truncate and reconstruct the scientific notation
    truncated_n1 = "{}e{}".format(sig_n1[:truncate_index], exp_n1)
    truncated_n2 = "{}e{}".format(sig_n2[:truncate_index], exp_n2)

    return [truncated_n1, truncated_n2]

def tickslabels(mnmx):
    """Tick positions and labels for a pair of log10 bounds.

    Returns `(mnmx, labels)`, where `labels` are 10**bound formatted by
    `truncate_sci_notation`; suitable for `ax.set_xticks(*tickslabels(...))`.
    """
    powers_of_ten = 10. ** np.array(mnmx)
    labels = truncate_sci_notation(powers_of_ten)
    return mnmx, labels

In [11]:
# Mutable module-level state shared by the interactive zoom handlers:
# `cids` holds connected matplotlib callback ids; `click_event[0]` holds the
# (x, y) data coordinates of the last mouse press, or None.
cids = list()
click_event = [None]


def onclick(event):
    """Mouse-press handler: remember where the button went down (data coords)."""
    pressed_at = (event.xdata, event.ydata)
    click_event[0] = pressed_at

def onrelease(event, fig, im, rect, mnmx, img, recalculate_image=True):
    """Mouse-release handler: zoom to the box dragged since the last `onclick`.

    The new bounds are [y_min, y_max, x_min, x_max] of the drag, widened on
    each axis to at least 1/20 of the current span so that a plain click still
    gives a usable window. The selection rectangle is redrawn, and if
    `recalculate_image` is True the image is recomputed and shown for the new
    bounds.
    """
    if click_event[0] is None:
        return

    xs = [click_event[0][0], event.xdata]
    ys = [click_event[0][1], event.ydata]
    # Abort if either end of the drag fell outside the axes.
    if any(v is None for v in xs + ys):
        return

    newmnmx = [np.min(ys), np.max(ys), np.min(xs), np.max(xs)]

    min_w = (mnmx[1] - mnmx[0])/20
    if newmnmx[1] - newmnmx[0] < min_w:
        c = (newmnmx[1] + newmnmx[0])/2.
        newmnmx[0] = c - min_w/2
        newmnmx[1] = c + min_w/2
    min_w = (mnmx[3] - mnmx[2])/20
    # Test the x span here; this previously re-tested the y span
    # (newmnmx[1] - newmnmx[0]), so narrow horizontal drags were never widened.
    if newmnmx[3] - newmnmx[2] < min_w:
        c = (newmnmx[3] + newmnmx[2])/2.
        newmnmx[2] = c - min_w/2
        newmnmx[3] = c + min_w/2

    plot_img(img, mnmx, newmnmx, fig=fig, im=im, rect=rect)
    plt.draw()

    if recalculate_image:
        click_event[0] = None
        mnmx = newmnmx
        # NOTE(review): gen_img's only parameter is `saveas`, so this call
        # passes the bounds as a file name and does not recompute the image for
        # `mnmx`. gen_img needs a bounds argument for zooming to work - TODO.
        img = gen_img(mnmx)
        plot_img(img, mnmx, None, fig=fig, im=im, rect=rect)

def plot_img(image, mnmx, newmnmx=None, fig=None, im=None, rect=None,
            handler=True, savename=None,
            reference_scale=None,
            cmap='Spectral',
            title=""
            ):
    """Draw (or redraw in place) a convergence map with CDF-stretched colours.

    Args:
        image: 2-D convergence-score map (raw values, e.g. from gen_img).
        mnmx: [mn1, mx1, mn2, mx2] bounds of `image`; index 0/1 is the
            vertical axis, index 2/3 the horizontal axis.
        newmnmx: optional bounds for the red selection rectangle
            (defaults to `mnmx`).
        fig, im, rect: existing Figure / AxesImage / Rectangle to update in
            place. When `fig` is None a new figure and image are created.
        handler: if True, append (image, mnmx) to `image_history` (only when
            `newmnmx` is None) and (re)connect the mouse handlers used for
            interactive zooming.
        savename: if given, the figure is saved to this path.
        reference_scale: values whose distribution defines the colour stretch
            (defaults to `image`).
        cmap: matplotlib colormap name.
        title: figure title; "" for none, None for a default descriptive one.

    Returns:
        (fig, ax, im)
    """
    mn1, mx1, mn2, mx2 = mnmx
    raw_image = image

    if reference_scale is None:
        reference_scale = raw_image
    image = cdf_img(raw_image, reference_scale)

    if fig is None:
        fig, ax1 = plt.subplots(figsize=figsize, dpi=dpi)
        im = None
    else:
        # Recover the axes of the existing figure (ax1 used to stay None here,
        # which crashed on set_ylabel whenever a figure was passed in).
        ax1 = im.axes if im is not None else fig.gca()
    if im is None:
        im = ax1.imshow(image,
                    extent=[mn2, mx2, mn1, mx1],
                    origin='lower',
                    vmin=-1, vmax=1,
                    cmap=cmap,
                    aspect='auto',
                    interpolation='nearest'
                    )

    if title is None:
        title = f'Trainability dependence on parameter initialization and learning rate \n1 hidden layer, {nonlinearity}, batch size={batch_size}'
    if title != "":
        ax1.set_title(title)
    ax1.set_ylabel('Learning rate')
    ax1.set_xlabel('Input layer weight offset')

    # Reuse a rectangle that is passed in; creating one on every redraw
    # stacked up a new patch per call.
    if rect is None:
        rect = patches.Rectangle((mn2, mn1), mx2-mn2, mx1-mn1, linewidth=1, edgecolor='r', facecolor='none')
        ax1.add_patch(rect)

    im.set_extent([mn2, mx2, mn1, mx1])
    im.set_data(image)

    # Ticks only at the two bounds, labelled with 10**bound.
    ax1.set_xticks(*tickslabels([mn2, mx2]))
    ax1.set_yticks(*tickslabels([mn1, mx1]), rotation=90)

    labels = ax1.get_xticklabels()
    labels[0].set_horizontalalignment('left')
    labels[1].set_horizontalalignment('right')
    labels = ax1.get_yticklabels()
    labels[0].set_verticalalignment('bottom')
    labels[1].set_verticalalignment('top')

    if handler and (newmnmx is None):
        # Store the raw map: consumers (make_animator, zoom helpers) apply
        # cdf_img themselves, and cdf_img negates, so storing the stretched
        # image inverted the colours on re-stretch.
        image_history.append((raw_image, mnmx))

    if newmnmx:
        mn1, mx1, mn2, mx2 = newmnmx
    rect.set_xy((mn2, mn1))
    rect.set_width(mx2-mn2)
    rect.set_height(mx1-mn1)

    if handler:
        while len(cids) > 0:
            fig.canvas.mpl_disconnect(cids.pop())

        # Close over this call's raw image (previously the global `img`).
        def onrelease_partial(event):
            return onrelease(event, fig, im, rect, mnmx, raw_image)
        def onmotion_partial(event):
            return onrelease(event, fig, im, rect, mnmx, raw_image, recalculate_image=False)

        cids.append(fig.canvas.mpl_connect('button_press_event', onclick))
        cids.append(fig.canvas.mpl_connect('button_release_event', onrelease_partial))
        # cids.append(fig.canvas.mpl_connect('motion_notify_event', onmotion_partial))

    fig.tight_layout()
    fig.canvas.draw_idle()

    if savename:
        fig.savefig(savename)

    return fig, ax1, im

# Fractal Zoom

In [12]:
def zoom_out_sequence(hist_final, growth_factor=2., max_scale=6):
    """Build a zoom-out sequence of (image, bounds) ending at `hist_final`.

    Starting from the bounds of `hist_final`, the box is repeatedly enlarged
    about its centre by `growth_factor` until its smaller side reaches
    `max_scale`. Each enlarged box is prepended with a 2x2 placeholder image
    (to be filled in later, e.g. by increase_resolution).

    Args:
        hist_final: (image, mnmx) pair; mnmx = [y_min, y_max, x_min, x_max].
        growth_factor: enlargement per step; must be > 1.
        max_scale: stop once both box sides are at least this large.

    Returns:
        List of (image, mnmx) from the widest box down to `hist_final`.

    Raises:
        ValueError: if growth_factor <= 1 or the final box has a non-positive
            side (either would loop forever).
    """
    if growth_factor <= 1:
        raise ValueError(f"growth_factor must be > 1, got {growth_factor}")

    image, mnmx = hist_final

    center = np.array([(mnmx[0] + mnmx[1])/2., (mnmx[2] + mnmx[3])/2.])
    width = np.array([mnmx[1] - mnmx[0], mnmx[3] - mnmx[2]])
    if np.min(width) <= 0:
        raise ValueError(f"hist_final bounds must have positive width, got {mnmx}")

    hist = [(image, mnmx)]
    w_scale = 1.
    while np.min(width * w_scale) < max_scale:
        # Previously hard-coded to 2, silently ignoring `growth_factor`.
        w_scale *= growth_factor
        bounds = [
            center[0] - w_scale * width[0]/2.,
            center[0] + w_scale * width[0]/2.,
            center[1] - w_scale * width[1]/2.,
            center[1] + w_scale * width[1]/2.,
        ]
        hist.insert(0, (np.zeros((2, 2)), bounds))

    return hist

def increase_resolution(history, target_res):
    """
    Increase the resolution of images of a fractal landscape that we've already
    generated.

    Find the first entry in history with resolution below target_res, and
    increase its resolution (replacing the entry in `history` in place).

    Returns:
        True if an entry was upgraded, False if all images already have at
        least `target_res` rows.
    """
    # Local import: the notebook rebinds the global name `datetime` to the
    # class (`from datetime import datetime`), so `datetime.datetime.now()`
    # raised AttributeError.
    from datetime import datetime as _datetime

    for ii, (image, mnmx) in enumerate(history):
        if image.shape[0] < target_res:
            current_time = _datetime.now()
            print( f"increasing resolution of {ii} / {len(history)} at {current_time}, current resolution is {image.shape}")
            # NOTE(review): gen_img currently accepts only `saveas`; it must
            # grow bounds/resolution parameters for this call to work - TODO.
            image = gen_img(mnmx, resolution=target_res)
            history[ii] = (image, mnmx)
            return True
    return False


In [13]:
def interpolate_history(hist1, hist2, alpha):
    """
    Get the mnmx (hyperparameter bounding box) a fraction `alpha` of the way
    from hist1's box to hist2's box.

    Box widths are interpolated geometrically (gamma = (w2/w1)**alpha). The
    centre moves along the zoom about the fixed point `cstar` that both boxes
    share, so the zoom looks like a steady camera move. On an axis where the
    widths are equal (a pure pan) there is no fixed point, and the centre is
    interpolated linearly; this is the limit of the zoom formula as w2/w1 -> 1.
    Previously that case divided by zero and produced NaN bounds.

    Args:
        hist1, hist2: (image, mnmx) pairs; only mnmx is used.
        alpha: interpolation fraction in [0, 1].

    Returns:
        [y_min, y_max, x_min, x_max] at fraction `alpha`.
    """

    _, mnmx1 = hist1
    _, mnmx2 = hist2

    if alpha == 0:
        # avoid NaNs on very last frame
        return mnmx1

    w1 = np.array([mnmx1[1] - mnmx1[0], mnmx1[3] - mnmx1[2]])
    w2 = np.array([mnmx2[1] - mnmx2[0], mnmx2[3] - mnmx2[2]])
    c1 = np.array([(mnmx1[0] + mnmx1[1])/2, (mnmx1[2] + mnmx1[3])/2])
    c2 = np.array([(mnmx2[0] + mnmx2[1])/2, (mnmx2[2] + mnmx2[3])/2])

    ratio = w2 / w1
    gamma = np.exp(alpha*np.log(ratio))
    hwt = gamma*w1

    # ct = cstar + (c1 - cstar)*gamma
    # c1 = cstar + (c1 - cstar)*1
    # c2 = cstar + (c1 - cstar)*w2/w1
    pure_pan = np.isclose(ratio, 1.0)
    safe_denominator = np.where(pure_pan, 1.0, 1 - ratio)
    cstar = (c2 - c1*ratio) / safe_denominator
    ct = np.where(pure_pan, c1 + alpha*(c2 - c1), cstar + (c1 - cstar)*gamma)

    return [ct[0] - hwt[0]/2, ct[0] + hwt[0]/2, ct[1] - hwt[1]/2, ct[1] + hwt[1]/2]


def em(extent_rev):
    """Reorder bounds from mnmx order [y_min, y_max, x_min, x_max] to the
    [x_min, x_max, y_min, y_max] order expected by imshow's `extent`."""
    y_lo, y_hi = extent_rev[0], extent_rev[1]
    x_lo, x_hi = extent_rev[2], extent_rev[3]
    return [x_lo, x_hi, y_lo, y_hi]

def make_animator(history, timesteps_per_transition=60, reference_scale=None, cmap='Spectral'):
    """Build a matplotlib FuncAnimation that zooms smoothly through `history`.

    Args:
        history: list of (image, mnmx) pairs ordered from the first to the last
            zoom level; mnmx is [y_min, y_max, x_min, x_max]. At least two
            entries are required (history[1] is read during setup).
        timesteps_per_transition: number of frames spent moving from one
            entry to the next.
        reference_scale: forwarded to plot_img for the colour stretch of the
            initial frame.
        cmap: matplotlib colormap name.

    Returns:
        matplotlib.animation.FuncAnimation with
        timesteps_per_transition * (len(history) - 1) + 1 frames.
    """

    fig, ax, im1 = plot_img(history[0][0], history[0][1], newmnmx=None,
                            handler=False, reference_scale=reference_scale, cmap=cmap)

    # Placeholder layers for the next two zoom levels; their data and extents
    # are overwritten on every frame by `animate`.
    im2 = ax.imshow(
        jnp.zeros_like(history[1][0]), extent=em(history[1][1]), origin='lower',
        vmin = -1, vmax = 1,
        cmap=cmap,
        aspect='auto',
        interpolation='nearest'
        )

    im3 = ax.imshow(
        jnp.zeros_like(history[1][0]), extent=em(history[1][1]), origin='lower',
        vmin = -1, vmax = 1,
        cmap=cmap,
        aspect='auto',
        interpolation='nearest'
        )

    def animate(n):
        # Frame n sits between history[hist_index] and the next entry;
        # alpha in [0, 1) is the progress through that transition.
        hist_index = n // timesteps_per_transition
        alpha = (n % timesteps_per_transition) / timesteps_per_transition

        hist1 = history[hist_index]
        if hist_index >= len(history)-1:
            hist2 = hist1 # very last frame
        else:
            hist2 = history[hist_index+1]
        if hist_index >= len(history)-2:
            hist3 = hist2 # very last frame
        else:
            hist3 = history[hist_index+2]

        # Visible axis limits for this frame.
        lims = interpolate_history(hist1, hist2, alpha)

        # interpolation scheme for image restretch / colormap
        # (sin^2 easing from hist1's colour scale to hist2's)
        alpha_area = jnp.sin(alpha*np.pi/2)**2

        print(f'frame {n} / {timesteps_per_transition*len(history)}, zoom step {hist_index} / {len(history)}', end='\r', flush=True)

        # Each layer is stretched against a blend of the two reference images.
        img_1 = (1-alpha_area)*cdf_img(hist1[0], hist1[0]) + alpha_area*cdf_img(hist1[0], hist2[0])
        img_2 = (1-alpha_area)*cdf_img(hist2[0], hist1[0]) + alpha_area*cdf_img(hist2[0], hist2[0])
        img_3 = (1-alpha_area)*cdf_img(hist3[0], hist1[0]) + alpha_area*cdf_img(hist3[0], hist2[0])

        im1.set_data(img_1)
        im1.set_extent(em(hist1[1]))
        im2.set_data(img_2)
        im2.set_extent(em(hist2[1]))
        im3.set_data(img_3)
        im3.set_extent(em(hist3[1]))
        # The level after next fades in over the course of the transition.
        im3.set_alpha(alpha)

        ax.set_ylim(lims[0], lims[1])
        ax.set_xlim(lims[2], lims[3])

        # Set the new tick positions
        ax.set_xticks(*tickslabels([lims[2], lims[3]]))
        ax.set_yticks(*tickslabels([lims[0], lims[1]]), rotation=90)

        labels = ax.get_xticklabels()
        labels[0].set_horizontalalignment('left')
        labels[1].set_horizontalalignment('right')
        labels = ax.get_yticklabels()
        labels[0].set_verticalalignment('bottom')
        labels[1].set_verticalalignment('top')

        return fig,

    anim = animation.FuncAnimation(fig,animate,frames=timesteps_per_transition*(len(history)-1)+1, repeat=False)
    return anim

# Generate the Images & Movies

In [14]:
# (raw image, mnmx) pairs recorded by plot_img; consumed by the zoom/animation helpers.
image_history = list()

In [15]:
# Compute the trainability map over the default bounds and open it in an
# interactive figure (drag a box to zoom; see onclick / onrelease).
plt.close('all')
plt.ion()

img = gen_img()
plot_img(img, mnmx, newmnmx=None)
plt.show()



  1%|          | 12/1000 [10:06<13:52:23, 50.55s/it]


KeyboardInterrupt: 

# Build the Model (Batch inputs)

In [68]:
class DNN(nn.Module):
    """Two-layer MLP with BatchNorm after the hidden layer (no Dense biases).

    Flattens each example, then Dense(width) -> BatchNorm -> act_fn ->
    Dense(10) and returns the raw logits. `train` selects BatchNorm batch
    statistics (True) or the stored running averages (False). The
    auto-generated submodule names ('Dense_0', 'BatchNorm_0', 'Dense_1')
    depend on creation order and are referenced by the training code.
    """

    width = 28*28       # hidden width (784); plain class attribute, not a dataclass field
    use_bias = True     # NOTE(review): unused - both Dense layers pass use_bias=False
    act_fn: Callable    # activation function, e.g. nn.relu

    @nn.compact
    def __call__(self, x, train=True):
        x = x.reshape((x.shape[0], -1))     # Normal mode: flatten all but the batch axis
        x = nn.Dense(self.width, use_bias=False)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = self.act_fn(x)
        x = nn.Dense(10, use_bias=False)(x)
        return x
        # return x / jnp.sqrt(self.width)


# 'act_fn' can be 'nn.relu'(ReLU), 'nn.activation.tanh'(tanh) or 'lambda x: x'(identity).
dnn = DNN(act_fn=nn.relu)

# Initialise in inference mode on a dummy batch to inspect parameter shapes.
x = jnp.zeros((batch_size, 28, 28, 1))
rng = jax.random.PRNGKey(42)
variables = dnn.init(rng, x, train=False)
# jax.tree_map is deprecated (and removed in newer JAX); use jax.tree_util.tree_map.
pprint(jax.tree_util.tree_map(jnp.shape, variables))



{'batch_stats': {'BatchNorm_0': {'mean': (784,), 'var': (784,)}},
 'params': {'BatchNorm_0': {'bias': (784,), 'scale': (784,)},
            'Dense_0': {'kernel': (784, 784)},
            'Dense_1': {'kernel': (784, 10)}}}


In [76]:
# Just one batch
class TrainState(train_state.TrainState):
    """flax TrainState extended with BatchNorm running statistics.

    NOTE(review): this shadows the `TrainState` imported from flax at the top
    of the notebook. Earlier cells that call `TrainState.create(...)` without
    `batch_stats` will presumably fail with a missing-field error once this
    cell has run.
    """
    batch_stats: Any  # the 'batch_stats' collection from DNN.init / mutable apply

def train_step(state, lrband, batch):
    """One SGD step on `batch` for the BatchNorm model, after offsetting the
    input-layer kernel by `lrband[0]`.

    The forward pass runs in training mode with 'batch_stats' mutable, and the
    updated statistics are written back into `state.batch_stats`.

    Returns:
        (new_state, mean loss).
    """
    # NOTE(review): the offset is re-applied on every call, so it accumulates
    # across steps (k * lrband[0] after k steps) - confirm this is intended.
    bn_params = state.params['BatchNorm_0']
    shifted_kernel = jnp.array(state.params['Dense_0']['kernel']) + jnp.array(lrband[0])
    new_params = {
        'BatchNorm_0': {'bias': bn_params['bias'], 'scale': bn_params['scale']},
        'Dense_0': {'kernel': shifted_kernel},
        'Dense_1': {'kernel': state.params['Dense_1']['kernel']},
    }
    state = state.replace(params=new_params)  # new theta

    def loss_fn(params):
        model_vars = {'params': params, 'batch_stats': state.batch_stats}
        logits, updates = state.apply_fn(
            model_vars, x=batch['image'], train=True, mutable=['batch_stats']
        )
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        )
        return per_example.mean(), (logits, updates)

    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True, allow_int=True)
    (loss, (_logits, updates)), grads = value_and_grad_fn(state.params)
    state = state.apply_gradients(grads=grads).replace(batch_stats=updates['batch_stats'])
    return state, loss.mean()   # mean for tiled batch


@partial(jax.vmap, in_axes=(None, 0, None))
def ready_to_train(num_epochs, lrband, batch):
    """Train one BatchNorm model per row of `lrband` on the single `batch`.

    Under the vmap each call sees one row `[offset, learning_rate]`. The model
    is initialised with PRNGKey(42) and stepped `num_epochs` times on the same
    batch.

    Returns:
        List of `num_epochs` losses, each batched over the rows of `lrband`.
    """
    init_vars = dnn.init(jax.random.PRNGKey(42), jnp.ones((batch_size, 28, 28, 1)))
    # The states are vmapped by lrband.
    state = TrainState.create(
        apply_fn=dnn.apply,
        params=init_vars['params'],
        batch_stats=init_vars['batch_stats'],
        tx=optax.sgd(lrband[1]),    # learning rate for this grid point
    )

    step_losses = []
    for _step in range(num_epochs):
        state, step_loss = train_step(state, lrband, batch)
        step_losses.append(step_loss)
    return step_losses


@jax.jit
@partial(jax.vmap, in_axes=(0,), out_axes=0)
def convergence_measure_vmap(v, max_val=1e6):
    '''Score each loss curve (one row of `v`) as converged or diverged.

    Each curve is normalised by its first value. Non-finite entries and values
    above `max_val` are replaced by `max_val`. A curve counts as converged when
    the mean of its last (up to) 20 normalised values is below 1.

    Returns:
        Per row: -sum(normalised curve) if converged (negative for positive
        losses), otherwise sum(1 / normalised curve) (positive).
    '''
    # Replace NaN/inf before normalising. The earlier `v*fin + max_val*(1-fin)`
    # masking could not do this: inf*0 and nan*0 are both nan, so diverged runs
    # produced NaN scores instead of large finite ones.
    v = jnp.where(jnp.isfinite(v), v, max_val)

    v = v / v[0]
    # Clamp after normalising; this also absorbs inf/nan produced when v[0] == 0.
    v = jnp.where(v <= max_val, v, max_val)

    converged = (jnp.mean(v[-20:]) < 1)
    return jnp.where(converged, -jnp.sum(v), jnp.sum(1/v))


def train(num_epochs, lrband, batch, lrbatch=100):
    """Run `ready_to_train` over `lrband`, halving it recursively until each
    piece has at most `lrbatch` rows.

    Returns:
        List of `num_epochs` arrays, one per step, each with one loss per row
        of `lrband` (the same structure `ready_to_train` returns).
    """
    bs = lrband.shape[0]
    if bs <= lrbatch:
        return ready_to_train(num_epochs, lrband, batch)

    first = train(num_epochs, lrband[:bs//2], batch, lrbatch)
    second = train(num_epochs, lrband[bs//2:], batch, lrbatch)
    # Join the halves per step along the row axis. The previous
    # jnp.concatenate((first, second), axis=0) treated each *list* as a
    # (num_epochs, rows) array and stacked along the step axis, which gave
    # (2*num_epochs, rows/2) or a shape error for odd splits.
    return [jnp.concatenate((a, b), axis=0) for a, b in zip(first, second)]




# Smoke test on four grid points. Fetch one concrete batch here so the cell
# does not depend on a `batch` variable leaked from another cell's loop.
batch = next(iter(train_ds.as_numpy_iterator()))
losses = train(100, lrband[:4, :], batch)
V = convergence_measure_vmap(jnp.stack(losses, axis=-1))
# V = V.reshape((resolution, resolution))

In [1]:
# Just one batch
class TrainState(train_state.TrainState):
    """flax TrainState extended with BatchNorm running statistics.

    NOTE(review): redefines `TrainState` yet again (shadowing the flax import
    and the earlier subclass); consider defining it once near the imports.
    """
    batch_stats: Any  # the 'batch_stats' collection from DNN.init / mutable apply

@with_mesh
@jax.jit
def train_step(state, lrband, batch):
    """Sharded SGD step for the BatchNorm model, after offsetting the
    input-layer kernel by `lrband[0]`.

    The batch is constrained to be split along its leading axis over both mesh
    axes ('x', 'y'). Gradients and the returned state are constrained to the
    partition specs derived from `state` with `nn.get_partition_spec`.

    Returns:
        (new_state, mean loss).
    """
    # NOTE(review): the offset is re-applied every call and so accumulates
    # across steps - confirm intent.
    bn_params = state.params['BatchNorm_0']
    shifted_kernel = jnp.array(state.params['Dense_0']['kernel']) + jnp.array(lrband[0])
    state = state.replace(params={
        'BatchNorm_0': {'bias': bn_params['bias'], 'scale': bn_params['scale']},
        'Dense_0': {'kernel': shifted_kernel},
        'Dense_1': {'kernel': state.params['Dense_1']['kernel']},
    })  # new theta

    @with_mesh
    @jax.jit
    def loss_fn(params):
        # `batch` is looked up when loss_fn is called, i.e. after the sharding
        # constraint below has been applied to it.
        model_vars = {'params': params, 'batch_stats': state.batch_stats}
        logits, updates = state.apply_fn(
            model_vars, x=batch['image'], train=True, mutable=['batch_stats']
        )
        mean_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()
        return mean_loss, (logits, updates)

    state_spec = nn.get_partition_spec(state)

    batch = jax.lax.with_sharding_constraint(batch, P(('x', 'y')))
    value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True, allow_int=True)
    (loss, (_logits, updates)), grads = value_and_grad_fn(state.params)
    grads = jax.lax.with_sharding_constraint(grads, state_spec.params)

    state = state.apply_gradients(grads=grads).replace(batch_stats=updates['batch_stats'])
    state = jax.lax.with_sharding_constraint(state, state_spec)
    return state, loss.mean()   # mean for tiled batch


@with_mesh
@partial(jax.jit, static_argnums=(0, ))
def create_state(module, lr):
    """Initialise `module` and wrap it in a sharding-constrained TrainState
    trained with SGD at learning rate `lr`.

    `lr` can be a traced value: create_state is jitted and is also called under
    vmap by ready_to_train. `optax.inject_hyperparams` keeps the learning rate
    in the optimizer *state* (a pytree leaf). A plain `optax.sgd(lr)` captured
    the tracer inside `tx`, which is a static (non-pytree) field of TrainState,
    so the tracer leaked out of the jit and the first `apply_gradients` was
    expected to fail with a leaked-tracer error.
    """
    variables = module.init(jax.random.PRNGKey(1), jnp.ones((batch_size, 28, 28, 1)), train=False)
    state = TrainState.create(
        apply_fn=module.apply,
        params=variables['params'],
        batch_stats=variables['batch_stats'],
        tx=optax.inject_hyperparams(optax.sgd)(learning_rate=lr)
    )
    state = jax.tree_util.tree_map(jnp.asarray, state)
    state_spec = nn.get_partition_spec(state)
    state = with_sharding_constraint(state, state_spec)
    return state

@with_mesh
@partial(jax.vmap, in_axes=(None, 0, None))
def ready_to_train(num_epochs, lrband, batch):
    """Sharded variant: train one model per row of `lrband` on the single `batch`.

    Under the vmap each call sees one row `[offset, learning_rate]`.
    `create_state(dnn, lrband[1])` builds the state, and `train_step` applies
    the offset `lrband[0]`.

    Returns:
        List of `num_epochs` per-step losses, each batched over the rows of
        `lrband`.
    """
    # The states are vmapped by lrband. (A dead `dnn.init(...)` call whose
    # result was never used has been removed; create_state does the init.)
    state = create_state(dnn, lrband[1])

    loss_archive = []
    for _ in range(num_epochs):
        state, loss = train_step(state, lrband, batch)
        loss_archive.append(loss)

    return loss_archive

@with_mesh
@jax.jit
@partial(jax.vmap, in_axes=(0,), out_axes=0)
def convergence_measure_vmap(v, max_val=1e6):
    '''Score each loss curve (one row of `v`) as converged or diverged.

    Each curve is normalised by its first value. Non-finite entries and values
    above `max_val` are replaced by `max_val`. A curve counts as converged when
    the mean of its last (up to) 20 normalised values is below 1.

    Returns:
        Per row: -sum(normalised curve) if converged (negative for positive
        losses), otherwise sum(1 / normalised curve) (positive).
    '''
    # Replace NaN/inf before normalising. The earlier `v*fin + max_val*(1-fin)`
    # masking could not do this: inf*0 and nan*0 are both nan, so diverged runs
    # produced NaN scores instead of large finite ones.
    v = jnp.where(jnp.isfinite(v), v, max_val)

    v = v / v[0]
    # Clamp after normalising; this also absorbs inf/nan produced when v[0] == 0.
    v = jnp.where(v <= max_val, v, max_val)

    converged = (jnp.mean(v[-20:]) < 1)
    return jnp.where(converged, -jnp.sum(v), jnp.sum(1/v))


def train(num_epochs, lrband, batch, lrbatch=100):
    """Run `ready_to_train` over `lrband`, halving it recursively until each
    piece has at most `lrbatch` rows.

    Returns:
        List of `num_epochs` arrays, one per step, each with one loss per row
        of `lrband` (the same structure `ready_to_train` returns).
    """
    bs = lrband.shape[0]
    if bs <= lrbatch:
        return ready_to_train(num_epochs, lrband, batch)

    first = train(num_epochs, lrband[:bs//2], batch, lrbatch)
    second = train(num_epochs, lrband[bs//2:], batch, lrbatch)
    # Join the halves per step along the row axis. The previous
    # jnp.concatenate((first, second), axis=0) treated each *list* as a
    # (num_epochs, rows) array and stacked along the step axis, which gave
    # (2*num_epochs, rows/2) or a shape error for odd splits.
    return [jnp.concatenate((a, b), axis=0) for a, b in zip(first, second)]




# Smoke test on four grid points. Fetch one concrete batch here so the cell
# does not depend on a `batch` variable leaked from another cell's loop.
batch = next(iter(train_ds.as_numpy_iterator()))
losses = train(100, lrband[:4, :], batch)
V = convergence_measure_vmap(jnp.stack(losses, axis=-1))
# V = V.reshape((resolution, resolution))

NameError: name 'train_state' is not defined

In [97]:
# Probe: build a single (un-vmapped) train state.
# Reference: https://www.kaggle.com/code/aakashnain/tf-jax-tutorials-part-8-vmap-pmap
# NOTE(review): lrband[0, 0] is from the weight-offset column; the learning-rate
# column is index 1.
create_state(dnn, lrband[0, 0])

TrainState(step=Array(0, dtype=int64, weak_type=True), apply_fn=<bound method Module.apply of DNN(
    # attributes
    act_fn = relu
)>, params={'BatchNorm_0': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
     

# Initialize the Model

In [23]:
# One forward pass in training mode to obtain updated BatchNorm statistics.
params = variables['params']
batch_stats = variables['batch_stats']
# `apply` must be called on the module *instance*; `DNN.apply(...)` on the
# class binds the variables dict as `self`.
y, updates = dnn.apply(
    {'params': params, 'batch_stats': batch_stats},
    x, mutable=['batch_stats']
)
batch_stats = updates['batch_stats']


# NOTE(review): this redefines `create_state` with a different signature
# (module only, fixed lr) and shadows the `create_state(module, lr)` used by the
# sharded ready_to_train; consider giving it a distinct name.
@with_mesh
@partial(jax.jit, static_argnums=(0,))
def create_state(module):
    """Wrap the module-level `params` / `batch_stats` in a TrainState (SGD,
    lr=1e-3), constrained to their partition specs.

    `module` must be a module instance: it is a static (hashed) argument and
    its bound `apply` becomes `apply_fn`. The former `module.init(...)` call
    was removed because its result was never used.
    """
    state = TrainState.create(
        apply_fn=module.apply,
        params=params,
        batch_stats=batch_stats,
        tx=optax.sgd(1e-3)
    )
    state = jax.tree_util.tree_map(jnp.asarray, state)
    state_spec = nn.get_partition_spec(state)
    state = with_sharding_constraint(state, state_spec)
    return state

state = create_state(dnn)
pprint(jax.tree_util.tree_map(jnp.shape, state))

TrainState(step=(),
           apply_fn=<bound method Module.apply of DNN(
    # attributes
    act_fn = relu
)>,
           params={'BatchNorm_0': {'bias': (10,), 'scale': (10,)},
                   'Dense_0': {'bias': (784,), 'kernel': (784, 784)},
                   'Dense_1': {'bias': (10,), 'kernel': (784, 10)}},
           tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x79c624207250>, update=<function chain.<locals>.update_fn at 0x79c624206b90>),
           opt_state=(EmptyState(), EmptyState()),
           batch_stats={'BatchNorm_0': {'mean': (10,), 'var': (10,)}})


# Parallelize

In [24]:
# Train step
@with_mesh
@jax.jit
def train_step(state, batch):
    """One sharded SGD step on `batch` for the BatchNorm model.

    Returns:
        (new_state, logs) where logs holds the batch 'loss' and 'accuracy'.
    """
    def loss_fn(params):
        # Use the state's running statistics. The module-level `batch_stats`
        # global was captured here before, so the running averages were
        # recomputed from the same initial values at every step instead of
        # accumulating.
        logits, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            x=batch['image'], train=True, mutable=['batch_stats']
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()
        return loss, (logits, updates)

    state_spec = nn.get_partition_spec(state)

    # Shard the batch along its leading axis over both mesh axes.
    batch = jax.lax.with_sharding_constraint(batch, P(('x', 'y')))
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (logits, updates)), grads = grad_fn(state.params)
    grads = jax.lax.with_sharding_constraint(grads, state_spec.params)

    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])

    logs = {
        'loss': loss,
        'accuracy': (jnp.argmax(logits, axis=-1) == batch['label']).mean()
    }
    state = jax.lax.with_sharding_constraint(state, state_spec)

    return state, logs

In [25]:
from tqdm import tqdm

state = create_state(dnn)   # pass the module instance, not the DNN class
history = []
total_steps = train_ds.cardinality().numpy()

mn1, mx1, mn2, mx2 = mnmx
gg1 = jnp.logspace(mn1, mx1, resolution)
gg2 = jnp.logspace(mn2, mx2, resolution)
lr0, lr1 = jnp.meshgrid(gg2, gg1)
# Was `lr0,ravel` (a NameError on `ravel`).
lr_band = jnp.stack([lr0.ravel(), lr1.ravel()], axis=-1)


for epoch in range(num_epochs):
    history_epoch = []
    # `step`, not `epoch`: the inner loop used to shadow the outer counter.
    for step, batch in enumerate(tqdm(train_ds.as_numpy_iterator(), total=total_steps)):
        state, logs = train_step(state, batch)
        history_epoch.append(jax.tree_util.tree_map(np.asarray, logs))
    history.append({
        'accuracy': [np.mean([he['accuracy'] for he in history_epoch])],
        'loss': [np.mean([he['loss'] for he in history_epoch])],
    })

    # NOTE: debugging stop after the first epoch; remove to train all `num_epochs`.
    break

100%|██████████| 3750/3750 [01:25<00:00, 43.77it/s]


In [42]:
jnp.shape(state.params['Dense_0']['kernel'])  # expected (784, 784)

(784, 784)