<a href="https://colab.research.google.com/github/1kaiser/test2022/blob/main/MLPCopy_of_Untitled17.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#  run from here

### Model and training code

Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or any $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

[--xla_force_host_platform_device_count](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#:~:text=When%20running%20on-,CPU,-you%20can%20always)

In [None]:
#✅
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
import jax
jax.devices()

In [None]:
#✅
import jax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(inputs):
    print("positional_encoding start")
    batch_size, _ = inputs.shape;print(inputs.shape)
    inputs_freq = jax.vmap(lambda x: inputs * 2.0 ** x)(jnp.arange(positional_encoding_dims));print(inputs_freq.shape)
    x = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)]);print(x.shape)
    x = x.swapaxes(0, 2);print(x.shape)
    x = x.reshape([batch_size, -1]);print(x.shape)
    x = jnp.concatenate([inputs, x], axis=-1);print(x.shape)
    print("positional_encoding end")
    return x

# y = np.ones((256, 256, 3))
# print(y.shape)
# image_height, image_width, cha = y.shape
# size = image_height * image_width
# yt = np.ones((size, cha))
# print(yt.shape)
# positional_encoding(yt)


In [None]:
# positional_encoding_vmap = jax.vmap(positional_encoding)
# ######################################
# y = jnp.ones((8, 256, 256, 3))
# print(y.shape)
# batchsize, image_height, image_width, cha = y.shape
# size = image_height * image_width
# yt = jnp.ones((batchsize, size, cha))
# ######################################
# print("vmap >>>")
# print(positional_encoding_vmap(yt).shape)

# positional_encoding_pmap = jax.pmap(positional_encoding)
# print("pmap >>>")
# print(positional_encoding_pmap(yt).shape)

### MLP MODEL
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [None]:
#✅
!python -m pip install -q -U flax

import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils
apply_positional_encoding = True # Apply posittional encoding to the input or not
num_dense_layers = 8 # Number of dense layers in MLP
dense_layer_width = 256 # Dimentionality of dense layers' output space 
##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding
    @nn.compact
    def __call__(self, input_points):
        x = positional_encoding(input_points) if self.apply_positional_encoding else input_points
        print("network model start")
        print(x.shape)
        for i in range(num_dense_layers):
            x = nn.Dense(
                dense_layer_width,
                dtype=self.dtype,
                precision=self.precision
                )(x)
            x = nn.relu(x)
            x = jnp.concatenate([x, input_points], axis=-1) if i == 4 else x
            print(x.shape)
  
        x = nn.Dense(1, dtype=self.dtype, precision=self.precision)(x)
        print(x.shape)
        print("network model end")
        return x



In [None]:
#✅
# !wget https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz
# from google.colab import output
# output.clear() #to_clear_the_output_console_everytime
# import jax.numpy as jnp

# data = jnp.load("tiny_nerf_data.npz")
# images = data["images"]
# poses = data["poses"]
# focal = float(data["focal"])
# _, image_height, image_width, _ = images.shape
# train_images, train_poses = images[:100], poses[:100]
# val_image, val_pose = images[101], poses[101]
# ############################################
# def initialize_model(key, input_pts_shape):
#     model = MLPModel()
#     initial_params = jax.jit(model.init)({"params": key},jnp.ones(input_pts_shape),)
#     return model, initial_params["params"]
# #############################################
# n_devices = jax.local_device_count()
# key, rng = jax.random.split(jax.random.PRNGKey(0))
# model, params = initialize_model(key, (image_height * image_width, 3))


In [None]:
# #✅
# import jax.numpy as jnp
# import jax
# key, rng = jax.random.split(jax.random.PRNGKey(0))
# batch_size_no = 64
# x = jnp.ones(shape=(batch_size_no, 32, 32, 3)) # Dummy Input
# BATCH, image_height, image_width, channel = x.shape
# size = image_height * image_width
# yt = jnp.ones((size, channel))
# model = MLPModel() # Instantiate the Model

# params = model.init(rng, yt) # Initialize the parameters
# print(type(params))

# params1 = model.apply # Initialize the parameters
# print(type(params1))

# jax.tree_map(lambda x: x.shape, params) # Check the parameters


In [None]:
#@title MODEL SUMMARY { vertical-output: true }
#✅
# import flax.linen as nn
# nn.tabulate(model, rng)(jnp.ones((image_height * image_width, channels)))

### initialize the module

In [None]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def init_train_state(model, r_key, shape, learning_rate ) -> train_state.TrainState:
    print(shape)
    # BATCH, image_height, image_width, cha = shape
    # size = image_height * image_width
    # yt = jnp.ones((size, cha))
    init_variables = model.init(r_key, jnp.ones(shape))  # Initialize the Model
    optimizer = optax.adam(learning_rate) # Create the optimizer
    # Create a State
    return train_state.TrainState.create(
        apply_fn = model.apply,
        tx=optimizer,
        params=init_variables['params']
    )

learning_rate = 1e-4
batch_size_no = 64
model = MLPModel() # Instantiate the Model
key, rng = jax.random.split(jax.random.PRNGKey(0))
x = jnp.ones(shape=(batch_size_no, 28, 28, 1)) # Dummy Input
_, image_height, image_width, channels = x.shape
state = init_train_state( model, rng, (image_height * image_width, channels), learning_rate )


In [33]:
def cross_entropy_loss(*, logits, labels):
    one_hot_encoded_labels = jax.nn.one_hot(labels, num_classes=10)
    return optax.softmax_cross_entropy(
        logits=logits, labels=one_hot_encoded_labels
    ).mean()
def compute_metrics(*, logits, labels):
  loss = .5 * jnp.mean((logits - labels) ** 2)
  loss = lax.pmean(loss, axis_name="batch");print("ok4")
  metrics = {
      'loss': loss,
      'logits': logits,
      'labels': labels
  }
  return metrics

In [None]:
import jax
@jax.jit
def train_step(state: train_state.TrainState, batch: jnp.ndarray, rng):
    print(batch)
    image, label = batch
    print(image,"<<<image")
    print(label,"<<<label")    
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, image);print("done1",logits.shape)
        loss =  .5 * jnp.mean((logits - label) ** 2);print("done2",loss.shape)
        return loss, logits

    # def loss_fn(params):
    #     model_fn = lambda x: state.apply_fn({"params": params}, x)
    #     ray_origins, ray_directions = inputs
    #     print(ray_origins)
    #     print(ray_directions)
    #     rgb, *_ = perform_volume_rendering(
    #         model_fn, ray_origins, ray_directions, rng
    #     )
    #     return jnp.mean((rgb - targets) ** 2)  
    print("ok1really")
    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True);print("ok1")
    (_, logits), grads = gradient_fn(state.params);print("ok2")
    #train_loss, gradients_each = jax.value_and_grad(loss_fn)(state.params);print("ok3")
    grads = lax.pmean(grads,"batch");print("ok4")
    # grads = jnp.mean(grads);print("ok4")
    state = state.apply_gradients(grads=grads);print("ok5")
    # train_loss = jnp.mean(train_loss);print("ok6")
    metrics = compute_metrics(logits=logits, labels=label);print("ok7")
    return state, metrics

parallel_train_step = jax.pmap(train_step, "batch")
# parallel_train_step = jax.pmap(train_step, axis_name="batch", in_axes = (0, 0, 0))

import jax
@jax.jit
def eval_step(state, batch):
    image, label = batch
    logits = state.apply_fn({'params': state.params}, image)
    return compute_metrics(logits=logits, labels=label)


### checkpoints management

In [None]:
# def save_checkpoint(ckpt_path, state):
#     with open(ckpt_path, "wb") as outfile:
#         outfile.write(msgpack_serialize(to_state_dict(state)))
    


# def load_checkpoint(ckpt_path, ckpt_file, state):
#     ckpt_path = os.path.join(ckpt_path, ckpt_file)
#     with open(ckpt_path, "rb") as data_file:
#         byte_data = data_file.read()
#     return from_bytes(state, byte_data)


# def accumulate_metrics(metrics):
#     metrics = jax.device_get(metrics)
#     return {
#         k: np.mean([metric[k] for metric in metrics])
#         for k in metrics[0]
#     }

### train & evaluation function

In [None]:
!gdown https://drive.google.com/uc?id=1UgWEotThxnP-Vh-h83-VcTPMkKWmgCDe #downloading MAP-DEM 

In [None]:
!python -m pip install -U tifffile imagecodecs matplotlib lxml zarr fsspec


In [None]:
!mv /content/HLSL30.020_B04_doy2021057_aid0001_43N.tif /content/a.tif #renaming file

In [None]:
import tifffile
import imagecodecs
from imagecodecs import imread, imwrite
fp = r'/content/a.tif'
image = imread("/content/a.tif")
b = image.reshape(-1,1)
b.shape
newsize = (28, 28)
c = jnp.asarray(image.resize(newsize, refcheck=False)).reshape(-1,1)
c.shape

In [None]:
!python -m pip install rasterio


In [None]:
import rasterio
from rasterio.plot import show
# fp = r'/content/a.tif'
# img = rasterio.open(fp)
# show(image)
# print(img.count) #to print number of bands
newsize = (28, 28)
def imageGRAY(argv):
    im = imread(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,1)
    return tvt, tvu
x, x1 = imageGRAY("/content/a.tif")
print(x.shape, x1.shape)

In [None]:
!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg

newsize = (50, 50)
from PIL import Image, ImageFilter
import jax.numpy as jnp

def imageGRAY(argv):
    im = Image.open(argv).convert('L')
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,1)
    return tvt, tvu
def imageRGB(argv):
    im = Image.open(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,3)
    return tvt, tvu
x, x1 = imageGRAY("/content/a.jpg")
print(x.shape)# mnist IMAGES are 28x28=784 pixels
y, y1 = imageRGB("/content/a.jpg")
print(y.shape,y1.shape,"<< y shape",x.shape,x1.shape,"<< y shape")# mnist IMAGES are 28x28=784 pixels
# from matplotlib import pyplot as plt
# plt.imshow(x);plt.show()
# plt.imshow(y);plt.show()
# plt.imshow((y1[0]-x1).reshape(28,-1));plt.show()
# plt.imshow((y1[1]-x1).reshape(28,-1));plt.show()
# plt.imshow((y1[2]-x1).reshape(28,-1));plt.show()


# batch = y1, x1  # jnp.ones((28*28,1)),jnp.ones((28*28,1)) OR jnp.ones((2, 28*28, 1))
# shapea, channels = y1.shape
# state = init_train_state( model, rng, (shapea, channels), learning_rate )
# for e in range(150):
#   state, metrics = train_step(state, batch, rng)
#   print(metrics)

In [None]:
# #✅
# import os
# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
# import jax
# jax.devices()
# import jax.numpy as jnp
# from jax import pmap
# a = jnp.arange(8*10).reshape((8, 2,5))
# b = 2
# print(type(a));print(a)
# def f(x,y):a = x**2+y**2;return a
# ff = pmap(f, in_axes=(0,None))
# result = ff(a,b)
# print(type(result));print(result)
# num_devices = jax.device_count()

# shape_prefix = (num_devices, 1);print(shape_prefix)

download [flower dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition?resource=download) from kaggle.

In [None]:
!gdown https://drive.google.com/uc?id=1SynswUkxdl6B3c6Uc7Q6MhLG3XirrHd9 # downloading from google drive saved location..
!unzip /content/archive.zip #unzipping the flower images from archive..

In [None]:
!mkdir -p /content/flowers/all
!cp /content/flowers/daisy/* /content/flowers/all
!cp /content/flowers/dandelion/* /content/flowers/all
!cp /content/flowers/rose/* /content/flowers/all
!cp /content/flowers/sunflower/* /content/flowers/all
!cp /content/flowers/tulip/* /content/flowers/all

In [None]:
# testing batching of the dataset frfom the total dataset, by assuming batch size as 8 then running 50 epochs over the batch , then moving processing onto next batch>>>
batch_size = 8
import os
image_dir = r'/content/flowers/rose/'
#############################################################################
prefix = "sur_refl_"
bandend = ["c",".jpg", "b02", "b03", "b04", "b05", "b06", "b07", "day_of_year", "qc_500m", "raz", "state_500m", "szen", "vzen"]
DayOY = "_doy\[0-9]+_aid0001"
fileExt = r'.jpg'
expression_b2 = bandend[1]
total_images =  [f for f in os.listdir(image_dir) if f.__contains__(expression_b2)]
total_images.sort()
total_images_path = [os.path.join(image_dir, i) for i in total_images if i != 'outputs']




In [None]:
#adding loop function to loop over the total images and batch 8 of them together , also print path of 8 images of the batch
print(len(total_images_path))
liss = len(total_images_path)
print(total_images_path[3])

no_of_batches = int(len(total_images_path)/batch_size)
print(no_of_batches)
for e in range(no_of_batches):
  print(e*batch_size,":",(e+1)*batch_size,"\n")
  for x in range(e*batch_size,(e+1)*batch_size,1):
    print(total_images_path[x])


In [None]:
# #checking image loading to RGB and GRAY scale
# x, x1 = imageGRAY(total_images_path[3])
# print(x.shape)# mnist IMAGES are 28x28=784 pixels
# y, y1 = imageRGB(total_images_path[3])
# print(y.shape,y1.shape,"<< y shape",x.shape,x1.shape,"<< x shape")# mnist IMAGES are 28x28=784 pixels
# from matplotlib import pyplot as plt
# plt.imshow(x);plt.show()
# plt.imshow(y);plt.show()

In [None]:
#combining both above looping over the dataset and image viewability TOGETHER
print(len(total_images_path))
liss = len(total_images_path)
print(total_images_path[3])

no_of_batches = int(len(total_images_path)/batch_size)
print(no_of_batches)
for e in range(no_of_batches):
  print(e*batch_size,":",(e+1)*batch_size,"\n")
  for x in range(batch_size):
    print(total_images_path[e*batch_size + x])
    #checking image loading to RGB and GRAY scale
    xGRAY, x1GRAY = imageGRAY(total_images_path[e*batch_size + x])
    print(xGRAY.shape)# mnist IMAGES are 28x28=784 pixels
    yRGB, y1RGB = imageRGB(total_images_path[e*batch_size + x])
    print(yRGB.shape, y1RGB.shape,"<< y shape", xGRAY.shape, x1GRAY.shape,"<< x shape")# mnist IMAGES are 28x28=784 pixels
    # from matplotlib import pyplot as plt
    # plt.imshow(xGRAY);plt.show()
    # plt.imshow(yRGB);plt.show()

    

In [None]:
jax.device_count()

✅ Runs correctly >>> but commenting out

In [None]:
# #combining both above looping over the dataset and image viewability TOGETHER, REcalibrating to list images as batch array>>>
# print(len(total_images_path))
# liss = len(total_images_path)
# print(total_images_path[3])
# ##############
# num_devices = jax.device_count()   # adding number of devices to process in parallel
# ##############
# no_of_batches = int(len(total_images_path)/batch_size)
# print(no_of_batches)
# for e in range(no_of_batches):
#   print(e*batch_size,":",(e+1)*batch_size,"\n")
#   xeGRAY = []
#   yeRGB = []
#   for x in range(batch_size):
#     print(total_images_path[e*batch_size + x])
#     #checking image loading to RGB and GRAY scale
#     xGRAY, x1GRAY = imageGRAY(total_images_path[e*batch_size + x])
#     xeGRAY.append(x1GRAY)
#     print(xGRAY.shape)# mnist IMAGES are 28x28=784 pixels
#     yRGB, y1RGB = imageRGB(total_images_path[e*batch_size + x])
#     yeRGB.append(y1RGB)
#     print(yRGB.shape, yeRGB[x].shape,"<< y shape", xGRAY.shape, xeGRAY[x].shape,"<< x shape")# mnist IMAGES are 28x28=784 pixels
#     # from matplotlib import pyplot as plt
#     # plt.imshow(xGRAY);plt.show()
#     # plt.imshow(yRGB);plt.show()
#   ##########################################<<< PARALLEL CALCULATION OF MODEL
#   #adding jnp.asarray to batch of images, 
#   ddyz = jnp.asarray((yeRGB[0],yeRGB[1],yeRGB[2],yeRGB[3],yeRGB[4],yeRGB[5],yeRGB[6],yeRGB[7]))       # rgb images (width * height, 3)
#   print(ddyz.shape,"<<< ddy shape")
#   ddxz = jnp.asarray((xeGRAY[0],xeGRAY[1],xeGRAY[2],xeGRAY[3],xeGRAY[4],xeGRAY[5],xeGRAY[6],xeGRAY[7]))    #gray images (width * height, 1)
#   shape_prefix = (num_devices, 1);print(shape_prefix);print(ddyz.shape,"<<<< ddyz.shape???")
#   batch_ccc = ddyz, ddxz  # jnp.ones((28*28,1)),jnp.ones((28*28,1)) OR jnp.ones((2, 28*28, 1))
#   print(len(batch_ccc),"<<< batch")
#   vv, shapea, channels = ddyz.shape
#   rng = jax.random.PRNGKey(0)
#   # dropout_rngs = jax.random.split(rng, jax.local_device_count())
#   ######################
#   count = 0
#   if count == 0 :
#     state = init_train_state( model, rng, (shapea, channels), learning_rate ) 
#     count = 1
#   state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

#   for epochs in range(50):   # EPOCHS for training & updating the initiated state, metrics may show the loss in each epochs or iteration
#     dropout_rngs = jax.random.split(rng, jax.local_device_count())
#     state, metrics = parallel_train_step(state, batch_ccc, dropout_rngs)
#     print("<<✅✅✅epoc : ",epochs," complete✅✅✅>>\n",metrics)
#   ##########################################



    

### **tensorboard visualization of loss graph**

In [None]:
!rm -r /content/ckpts

In [None]:
%load_ext tensorboard

In [None]:
from torch.utils.tensorboard import SummaryWriter
logdir = "runs"

writer = SummaryWriter(logdir)

In [None]:
%tensorboard --logdir={logdir}

In [37]:
#@title ### **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
import jax
from jax import random
def batchedimages(image_locations):
  ddyss = jnp.asarray((imageRGB(total_images_path[image_locations[0]])[1],
                      imageRGB(total_images_path[image_locations[1]])[1],
                      imageRGB(total_images_path[image_locations[2]])[1],
                      imageRGB(total_images_path[image_locations[3]])[1],
                      imageRGB(total_images_path[image_locations[4]])[1],
                      imageRGB(total_images_path[image_locations[5]])[1],
                      imageRGB(total_images_path[image_locations[6]])[1],
                      imageRGB(total_images_path[image_locations[7]])[1]))
  ddxss = jnp.asarray((imageGRAY(total_images_path[image_locations[0]])[1],
                      imageGRAY(total_images_path[image_locations[1]])[1],
                      imageGRAY(total_images_path[image_locations[2]])[1],
                      imageGRAY(total_images_path[image_locations[3]])[1],
                      imageGRAY(total_images_path[image_locations[4]])[1],
                      imageGRAY(total_images_path[image_locations[5]])[1],
                      imageGRAY(total_images_path[image_locations[6]])[1],
                      imageGRAY(total_images_path[image_locations[7]])[1]))
  #print(ddyss.shape,"<<<< ddyss.shape???",ddxss.shape,"<<<< ddxss.shape???") #to check shape 
  batch_ccc = ddyss, ddxss
  return batch_ccc
  
def data_stream():
  key = random.PRNGKey(0)
  perm = random.permutation(key, len(total_images_path))
  for i in range(no_of_batches):
    batch_idx = perm[i * batch_size : (i + 1) * batch_size]; #print(batch_idx)
    yield batchedimages(batch_idx)

batches = data_stream()  ### this stream will utilize the array of paths of images to a folder, then "generate" batches into the variable

next(batches)[0].shape ### this command starts initial 8 image  stream, if callled inside a iteration loop then it will get next images for calculations>>>
vv, shapea, channels = next(batches)[0].shape # seitting values >>> [0] 8 784 3  [1] 8 784 1 ; RGB & GRAYSCALE versions 8 images each converted to 1-D array
print(vv, shapea, channels)
######################<<< summary writer for tensor board
# from torch.utils.tensorboard import SummaryWriter
# logdir = "runs"
# writer = SummaryWriter(logdir)
######################
######################
rng = jax.random.PRNGKey(0)
# dropout_rngs = jax.random.split(rng, jax.local_device_count())
######################
#################################<<< checking if checkpoint already available
import os # importing os module
import re # to find file using regular expression
checkpoint_available = 0
pattern = re.compile("checkpoint_\d+")   # to search for "checkpoint_*munerical value*" numerical value of any length is denoted by regular expression "\d+"
dir = "/content/ckpts/"
isFile = os.path.isdir(dir)
if isFile:
  for filepath in os.listdir(dir):
      if pattern.match(filepath):
          checkpoint_available = 1
#################################
##########################################<<< loading checkpoint by checking the Flag available
from flax.training import checkpoints
if checkpoint_available:
  CKPT_DIR = 'ckpts'
  restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
  #state = flax.jax_utils.replicate(restored_state)
  print("true <<< File loaded for and replicated to all devices")
##########################################
######################<<<< initiating train state
count = 0
if count == 0 :
  state = init_train_state( model, rng, (shapea, channels), learning_rate ) 
  count = 1
state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy
dropout_rngs = jax.random.split(rng, jax.local_device_count())
######################
from flax.training import checkpoints
import time
total_epochs = 50
for epochs in range(total_epochs):   # EPOCHS for training & updating the initiated state, metrics may show the loss in each epochs or iteration
  start_time = time.time()
  batches = data_stream()  ### this stream will utilize the array of paths of images to a folder, then "generate" batches into the variable
  if checkpoint_available:
    CKPT_DIR = 'ckpts'
    restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    state = restored_state
    checkpoint_available = 0 # << Flag updated >>> to stop loading the same checkpoint in the next iteration then remove the checkpoint directory
    !rm -r /content/ckpts
  for bbb in range(no_of_batches-5):
    print(bbb,"of total number of batches",no_of_batches)
    state, metrics = parallel_train_step(state, next(batches), dropout_rngs)
    print("<<✅✅✅epoc : ",epochs," complete✅✅✅>>\n",metrics['loss'][0]) #printing loss 1 out of 8 processed in 8 devices
    print("logits shape ⚡⚡⚡⚡", metrics['logits'][0].shape)
    
    L1 = metrics['logits'][0]
    print(L1)
    from matplotlib import pyplot as plt
    plt.imshow((L1).reshape(50,-1));plt.show()

    print("labels shape ⚡⚡⚡⚡", metrics['labels'][0].shape)
    L2 = metrics['labels'][0]
    plt.imshow((L2).reshape(50,-1));plt.show()

    writer.add_scalar('Loss', int(metrics['loss'][0]), epochs)
  epoch_time = time.time() - start_time
  print(f"Epoch {epochs} in {epoch_time:0.2f} sec")
  ##################################################<<< model saving mechanism for flax model state as checkpoints for each epochs,"checkpoint" is a terminology that means all the model weights and biases during the calculation till the completion of 1 epoch were being updated, then this final set of weights and biases including their placement inside the model will be saves as a (schema+weight values) saved as checkpoint in the mentioned <<CKPT_DIR = 'ckpts'>> mentioned folder.  
  CKPT_DIR = 'ckpts'
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step= epochs)     # naming of the checkpoint is "checkpoint_*"  where "*" => value of the steps variable, i.e. 'epochs'
  restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state) # using to get the checkpoint loaded , it can be latest one , or if already available as checkpoint in the "CKPT_DIR" directory then take the file from directory then save in >> restored_checkpoints
  ##################################################
  # images = total_images_path[batch_idx]
writer.flush()

8 2500 3
(2500, 3)
positional_encoding start
(2500, 3)
(6, 2500, 3)
(2, 6, 2500, 3)
(2500, 6, 2, 3)
(2500, 36)
(2500, 39)
positional_encoding end
network model start
(2500, 39)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 259)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 1)
network model end
0 of total number of batches 539
(Traced<ShapedArray(uint8[2500,3])>with<DynamicJaxprTrace(level=0/2)>, Traced<ShapedArray(uint8[2500,1])>with<DynamicJaxprTrace(level=0/2)>)
Traced<ShapedArray(uint8[2500,3])>with<DynamicJaxprTrace(level=0/2)> <<<image
Traced<ShapedArray(uint8[2500,1])>with<DynamicJaxprTrace(level=0/2)> <<<label
ok1really
ok1
positional_encoding start
(2500, 3)
(6, 2500, 3)
(2, 6, 2500, 3)
(2500, 6, 2, 3)
(2500, 36)
(2500, 39)
positional_encoding end
network model start
(2500, 39)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 259)
(2500, 256)
(2500, 256)
(2500, 256)
(2500, 1)
network model end
done1 (2500, 1)
done2 ()
ok2
ok4
ok5
ok4
ok7
<<✅✅✅epoc :  0  complete✅✅✅>

ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

1 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6174.291
logits shape ⚡⚡⚡⚡ (2500, 1)
[[0.17133436]
 [0.24452768]
 [0.2984809 ]
 ...
 [0.0348334 ]
 [0.01395282]
 [0.15540026]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

2 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6493.8022
logits shape ⚡⚡⚡⚡ (2500, 1)
[[0.3588239]
 [1.1156261]
 [1.7224472]
 ...
 [3.0389073]
 [3.1026576]
 [2.9304407]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

3 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6245.928
logits shape ⚡⚡⚡⚡ (2500, 1)
[[5.2705245]
 [5.639239 ]
 [5.8993278]
 ...
 [3.9876368]
 [4.557884 ]
 [4.701828 ]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

4 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 5569.6333
logits shape ⚡⚡⚡⚡ (2500, 1)
[[2.3868935]
 [3.1348755]
 [4.5863113]
 ...
 [0.8487298]
 [0.4793321]
 [1.1090873]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

5 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6291.8667
logits shape ⚡⚡⚡⚡ (2500, 1)
[[3.1025002 ]
 [4.916146  ]
 [7.125583  ]
 ...
 [0.17531306]
 [0.22072546]
 [0.22072546]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

6 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 7862.379
logits shape ⚡⚡⚡⚡ (2500, 1)
[[6.6890883]
 [2.8260956]
 [3.3520365]
 ...
 [1.0963753]
 [2.5741947]
 [6.971418 ]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

7 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 8486.705
logits shape ⚡⚡⚡⚡ (2500, 1)
[[10.892286 ]
 [ 7.7266064]
 [ 6.198224 ]
 ...
 [17.752333 ]
 [17.365118 ]
 [17.441517 ]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

8 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6475.593
logits shape ⚡⚡⚡⚡ (2500, 1)
[[24.788313]
 [22.460093]
 [ 9.97033 ]
 ...
 [13.966358]
 [12.176752]
 [12.938091]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

9 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 5412.152
logits shape ⚡⚡⚡⚡ (2500, 1)
[[23.917606]
 [22.165   ]
 [21.286896]
 ...
 [14.425289]
 [13.995896]
 [ 9.063686]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

10 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 7548.6206
logits shape ⚡⚡⚡⚡ (2500, 1)
[[32.654987]
 [32.68578 ]
 [32.838436]
 ...
 [32.909985]
 [32.909985]
 [32.909985]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

11 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6506.0537
logits shape ⚡⚡⚡⚡ (2500, 1)
[[8.510901]
 [8.510901]
 [8.510901]
 ...
 [8.510901]
 [8.510901]
 [8.510901]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

12 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6973.753
logits shape ⚡⚡⚡⚡ (2500, 1)
[[ 5.362956 ]
 [ 4.356979 ]
 [ 2.4005089]
 ...
 [12.853135 ]
 [13.422577 ]
 [13.823386 ]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

13 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 4691.961
logits shape ⚡⚡⚡⚡ (2500, 1)
[[ 1.487987 ]
 [ 1.487987 ]
 [ 1.487987 ]
 ...
 [16.78609  ]
 [ 2.3098562]
 [ 2.8033237]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

14 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 6286.885
logits shape ⚡⚡⚡⚡ (2500, 1)
[[11.582676 ]
 [10.4089365]
 [15.642454 ]
 ...
 [13.339455 ]
 [13.384721 ]
 [13.302502 ]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

15 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 4669.775
logits shape ⚡⚡⚡⚡ (2500, 1)
[[37.876434]
 [38.087914]
 [38.491024]
 ...
 [15.835135]
 [14.549774]
 [14.072151]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

labels shape ⚡⚡⚡⚡ (2500, 1)


ImportError: ignored

<Figure size 432x288 with 1 Axes>

16 of total number of batches 539
<<✅✅✅epoc :  0  complete✅✅✅>>
 

KeyboardInterrupt: ignored

In [106]:
print(L1.reshape(50,-1).shape)
print(L1.shape[0])
print(L1.shape)
print(int(L1[0]))
arrL1 = []
arrL2 = []
for b in range(0,2500):
  arrL1.append(int(L1[b]))
  arrL2.append(int(L2[b]))
print(arrL1)    #<<<<< is predicted value by model
print(arrL2)       #<<<< is real Gray image



(50, 50)
2500
(2500, 1)
37
[37, 38, 38, 38, 39, 39, 38, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 39, 40, 40, 40, 40, 40, 40, 40, 37, 38, 38, 38, 39, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 39, 39, 40, 40, 39, 39, 39, 37, 37, 38, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 39, 39, 39, 39, 39, 39, 37, 37, 38, 38, 38, 38, 37, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 38, 38, 38, 39, 39, 39, 39, 38, 37, 37, 38, 38, 38, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 

ImportError: ignored

<Figure size 432x288 with 1 Axes>

In [118]:

import numpy as np 
import matplotlib.pyplot as plt

H = np.array(np.array(arrL1,  dtype=np.uint8).reshape(50,-1)) 
print(H.shape)
print(H)
plt.imshow(H)
plt.imshow(H, cmap='gray')
plt.show()
# H = np.array(arrL2).reshape(50,-1)
# plt.imshow(H, interpolation='none')
# plt.show()

(50, 50)
[[37 38 38 ... 40 40 40]
 [37 38 38 ... 39 39 39]
 [37 37 38 ... 39 39 39]
 ...
 [13 13 11 ... 12 17 15]
 [19 17 19 ...  9  9  8]
 [18 15 18 ... 15 14 14]]


ImportError: ignored

<Figure size 432x288 with 1 Axes>

In [104]:
print(arr)
print(type(arr))

iarrayL1 = np.array(arrL1)
iarrayL2 = np.array(arrL2)
print(image_array_list.reshape(50,-1).shape)


[37, 38, 38, 38, 39, 39, 38, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 39, 40, 40, 40, 40, 40, 40, 40, 37, 38, 38, 38, 39, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 39, 39, 40, 40, 39, 39, 39, 37, 37, 38, 38, 38, 38, 38, 37, 37, 37, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 39, 39, 39, 39, 39, 39, 37, 37, 38, 38, 38, 38, 37, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 38, 38, 38, 39, 39, 39, 39, 38, 37, 37, 38, 38, 38, 37, 37, 37, 37, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 38, 38, 39, 39, 38, 38,

In [None]:
# !rm -r /content/runs
# !rm -r /content/ckpts

In [None]:
# ✅✅✅✅✅commented out but runs correctly
      # num_devices = jax.device_count()
      # ddy = jnp.asarray((y1,y1,y1,y1,y1,y1,y1,y1))       # rgb images (width * height, 3)
      # print(ddy.shape,"<<< ddy shape")
      # ddx = jnp.asarray((x1,x1,x1,x1,x1,x1,x1,x1))    #gray images (width * height, 1)
      # shape_prefix = (num_devices, 1);print(shape_prefix);print(ddy.shape,"<<<< ???")
      # # ddy = ddy.reshape(shape_prefix + ddy.shape[1:]);print(ddy.shape,"<< train_images_incorrect")
      # # ddx = ddx.reshape(shape_prefix + ddx.shape[1:]);print(ddx.shape,"<< train_images_incorrect")
      # batch = ddy, ddx  # jnp.ones((28*28,1)),jnp.ones((28*28,1)) OR jnp.ones((2, 28*28, 1))
      # print(len(batch),"<<< batch")
      # vv, shapea, channels = ddy.shape
      # ######################
      # rng = jax.random.PRNGKey(0)
      # # dropout_rngs = jax.random.split(rng, jax.local_device_count())
      # ######################
      # state = init_train_state( model, rng, (shapea, channels), learning_rate ) 
      # state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

      # for e in range(50):   # EPOCHS for training & updating the initiated state, metrics may show the loss in each epochs or iteration
      #   dropout_rngs = jax.random.split(rng, jax.local_device_count())
      #   state, metrics = parallel_train_step(state, batch, dropout_rngs)
      #   print("<<✅✅✅epoc : ",e," complete✅✅✅>>\n",metrics)

# ❌❌❌❌❌❌❌doesnot work >>>

In [None]:
batch_size= 32
import datasets
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)
import jax
num_devices = jax.device_count()
def data_stream():
  key, rng = jax.random.split(jax.random.PRNGKey(0))
  while True:
    perm = jax.random.permutation(rng, num_train); print(perm.shape)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]; print(batch_idx)
      images, labels = train_images[batch_idx], train_labels[batch_idx]; print(images.shape,"<< train_images_incorrect");print(labels.shape,"<< train_images_correct")
      # For this SPMD example, we reshape the data batch dimension into two
      # batch dimensions, one of which is mapped over parallel devices.
      batch_size_per_device, ragged = divmod(images.shape[0], num_devices);print(batch_size_per_device,"<<< batch_size_per_device")
      if ragged:
        msg = "batch size must be divisible by device count, got {} and {}."
        raise ValueError(msg.format(batch_size, num_devices))
      shape_prefix = (num_devices, batch_size_per_device);print(shape_prefix)
      images = images.reshape(shape_prefix + images.shape[1:]);print(images.shape,"<< train_images_incorrect")
      labels = labels.reshape(shape_prefix + labels.shape[1:]);print(labels.shape,"<< train_images_correct")
      return images, labels
batches = data_stream()


In [None]:
num_train_batches = tf.data.experimental.cardinality(train_dataset)
train_datagen = iter(tfds.as_numpy(train_dataset))
batch = next(train_datagen)
batch = jnp.ones((28*28,1)),jnp.ones((28*28,1))  # jnp.ones((28*28,1)),jnp.ones((28*28,1)) OR jnp.ones((2, 28*28, 1))
# state = flax.jax_utils.replicate(state)
state = init_train_state( model, rng, (image_height * image_width, channels), learning_rate )
for e in range(50):
  state, metrics = train_step(state, batch, rng)
  print(metrics)

In [None]:

def train_and_evaluate(train_dataset, eval_dataset, test_dataset, state: train_state.TrainState, epochs: int,):
    num_train_batches = tf.data.experimental.cardinality(train_dataset)
    num_eval_batches = tf.data.experimental.cardinality(eval_dataset)
    num_test_batches = tf.data.experimental.cardinality(test_dataset)
   
    for epoch in tqdm(range(1, epochs + 1)):
        best_eval_loss = 1e6
        # ============== Training ============== #
        train_batch_metrics = []
        train_datagen = iter(tfds.as_numpy(train_dataset))
        for batch_idx in range(num_train_batches):
            batch = next(train_datagen)
            state, metrics = train_step(state, batch, rng)
            train_batch_metrics.append(metrics)
        train_batch_metrics = accumulate_metrics(train_batch_metrics)
        print('TRAIN (%d/%d): Loss: %.4f' % (
                epoch, epochs, train_batch_metrics['loss'],
              ))
        # ============== Validation ============= #
        eval_batch_metrics = []
        eval_datagen = iter(tfds.as_numpy(eval_dataset))
        for batch_idx in range(num_eval_batches):
            batch = next(eval_datagen)
            metrics = eval_step(state, batch)
            eval_batch_metrics.append(metrics)
        eval_batch_metrics = accumulate_metrics(eval_batch_metrics)
        print('EVAL (%d/%d):  Loss: %.4f\n' % (
                epoch, epochs, eval_batch_metrics['loss'],
              ))    

        if eval_batch_metrics['loss'] < best_eval_loss:
            save_checkpoint("checkpoint.msgpack", state)

    restored_state = load_checkpoint("checkpoint.msgpack", state)
    test_batch_metrics = []
    test_datagen = iter(tfds.as_numpy(test_dataset))
    for batch_idx in range(num_test_batches):
        batch = next(test_datagen)
        metrics = eval_step(restored_state, batch)
        test_batch_metrics.append(metrics)
    
    test_batch_metrics = accumulate_metrics(test_batch_metrics)
    print(
        'Test: Loss: %.4f,' % (
            test_batch_metrics['loss'],
        )
    )
    # Log Metrics to Weights & Biases
    history = {
        "Train Loss": train_batch_metrics['loss'],
        "Validation Loss": eval_batch_metrics['loss'],
    }
    return state, restored_state, history



### loading 'mnist' data

In [None]:
#@title @inproceedings{zhou2017scene, title={Scene Parsing through ADE20K Dataset}, author={Zhou, Bolei and Zhao, Hang and Puig, Xavier and Fidler, Sanja and Barriuso, Adela and Torralba, Antonio}, booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, year={2017} }
#✅✅
import tensorflow_datasets as tfds
(full_train_set, test_dataset), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
validation_split = 0.2
def normalize_img(image, label):
    image = tf.cast(image, tf.float32) / 255.
    return image, label



In [None]:
#✅✅
import tensorflow as tf
batch_size = 64
full_train_set = full_train_set.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)

num_data = tf.data.experimental.cardinality(full_train_set).numpy()
print("Total number of data points:", num_data)
train_dataset = full_train_set.take(num_data * (1 - validation_split))
val_dataset = full_train_set.take(num_data * (validation_split))
print("Number of train data points:",tf.data.experimental.cardinality(train_dataset).numpy())
print("Number of val data points:",tf.data.experimental.cardinality(val_dataset).numpy())
#############TRAIN##################
train_dataset = train_dataset.cache();print(len(train_dataset))
train_dataset = train_dataset.shuffle(tf.data.experimental.cardinality(train_dataset).numpy());print(train_dataset)
train_dataset = train_dataset.batch(batch_size);print(train_dataset)
#############TRAIN(EVALUATE)##################
val_dataset = val_dataset.cache()
val_dataset = val_dataset.shuffle(tf.data.experimental.cardinality(val_dataset).numpy())
val_dataset = val_dataset.batch(batch_size)

#############TEST##################
test_dataset = test_dataset.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
print("Number of test data points:",tf.data.experimental.cardinality(test_dataset).numpy())
test_dataset = test_dataset.cache()
test_dataset = test_dataset.batch(batch_size)

In [None]:
num_train_batches = tf.data.experimental.cardinality(train_dataset)
num_train_batches

### run inferences

In [None]:
# state, inference_state, history = train_and_evaluate(state, parallelized_train_step, eval_step)


# train_dataset, eval_dataset, test_dataset, state, epochs
from tqdm.notebook import tqdm
epochs = 15
state, best_state, history = train_and_evaluate(
    train_dataset,
    val_dataset,
    test_dataset,
    state,
    epochs,
)

In [None]:
train_dataset

# go

In [None]:
import os
import time
import imageio
import requests
from typing import Any
import ipywidgets as widgets
from functools import partial
from tqdm.notebook import tqdm

!pip3 install -q -U flax

import jax
import flax
import optax
from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils


import numpy as np
import jax.numpy as jnp


from base64 import b64encode
from IPython.display import HTML
import plotly.express as px
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid


In [None]:
# Reference: https://www.kaggle.com/code/odins0n/jax-flax-tf-data-vision-transformers-tutorial


# Detect if Kaggle Notebook has access to TPUs or not
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    jax_xla_backend = "tpu_driver"
    jax_backend_target = os.environ['TPU_NAME']
    print("TPU DETECTED!")
    print('Registered TPU:', jax_backend_target)


# Detect if Google Colab Notebook has access to TPUs or not
elif "COLAB_TPU_ADDR" in os.environ:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()


else:
    print('No TPU detected.')


DEVICE_COUNT = len(jax.local_devices())
TPU = DEVICE_COUNT==8


if TPU:
    print("8 cores of TPU ( Local devices in Jax ):")
    print('\n'.join(map(str,jax.local_devices())))


In [None]:


# sync all experiment configs with Weights and Biases

near_bound = 2. # Near Bound of sample space for 3d points
far_bound = 6.  # Far Bound of sample space for 3d points
batch_size = int(1e4) # Batch Size
num_sample_points = 256 # Number of points to be sampled across the volume
epsilon = 1e10 # Hyperparameter for volume rendering
apply_positional_encoding = True # Apply posittional encoding to the input or not
positional_encoding_dims = 3 # Number of positional encodings applied
num_dense_layers = 5 # Number of dense layers in MLP
dense_layer_width = 256 # Dimentionality of dense layers' output space 
learning_rate = 5e-4 # Learning Rate
train_epochs = 1000 # Number of training epochs
plot_interval = 100 # Epoch interval for plotting results during training


In [None]:
positional_encoding_dims = 3 # Number of positional encodings applied
def positional_encoding(inputs):
    batch_size, _ = inputs.shape;                                                   print(inputs.shape)
    # Applying vmap transform to vectorize the multiplication operation
    f = jax.vmap(lambda x: inputs * 2.0 ** x)(jnp.arange(positional_encoding_dims));print(f.shape)
    fy = jnp.stack([jnp.sin(f), jnp.cos(f)]);                                       print(fy)
    fy = fy.swapaxes(0, 2).reshape([batch_size, -1]);                               print(fy)
    fy = jnp.concatenate([inputs, fy], axis=-1);                                    print(fy)
    return fy


(10000, 3)


(3, 10000, 3)

Traced<ShapedArray(float32[2,3,10000,3])>with<DynamicJaxprTrace(level=0/1)>

Traced<ShapedArray(float32[10000,18])>with<DynamicJaxprTrace(level=0/1)>

Traced<ShapedArray(float32[10000,21])>with<DynamicJaxprTrace(level=0/1)>

# test mnist pmap gpu 1000 epocs runtime 9 minutes

In [None]:
%%writefile datasets.py
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Datasets used in examples."""


import array
import gzip
import os
from os import path
import struct
import urllib.request

import numpy as np


_DATA = "/tmp/jax_example_data/"


def _download(url, filename):
  """Download a url to a file in the JAX data temp directory."""
  if not path.exists(_DATA):
    os.makedirs(_DATA)
  out_file = path.join(_DATA, filename)
  if not path.isfile(out_file):
    urllib.request.urlretrieve(url, out_file)
    print(f"downloaded {url} to {_DATA}")


def _partial_flatten(x):
  """Flatten all but the first dimension of an ndarray."""
  return np.reshape(x, (x.shape[0], -1))


def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)


def mnist_raw():
  """Download and parse the raw MNIST dataset."""
  # CVDF mirror of http://yann.lecun.com/exdb/mnist/
  base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

  def parse_labels(filename):
    with gzip.open(filename, "rb") as fh:
      _ = struct.unpack(">II", fh.read(8))
      return np.array(array.array("B", fh.read()), dtype=np.uint8)

  def parse_images(filename):
    with gzip.open(filename, "rb") as fh:
      _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
      return np.array(array.array("B", fh.read()),
                      dtype=np.uint8).reshape(num_data, rows, cols)

  for filename in ["train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
                   "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"]:
    _download(base_url + filename, filename)

  train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
  train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
  test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
  test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

  return train_images, train_labels, test_images, test_labels


def mnist(permute_train=False):
  """Download, parse and process MNIST data to unit scale and one-hot labels."""
  train_images, train_labels, test_images, test_labels = mnist_raw()

  train_images = _partial_flatten(train_images) / np.float32(255.)
  test_images = _partial_flatten(test_images) / np.float32(255.)
  train_labels = _one_hot(train_labels, 10)
  test_labels = _one_hot(test_labels, 10)

  if permute_train:
    perm = np.random.RandomState(0).permutation(train_images.shape[0])
    train_images = train_images[perm]
    train_labels = train_labels[perm]

  return train_images, train_labels, test_images, test_labels

[datasets](https://github.com/google/jax/blob/main/examples/datasets.py) link

[spmd_mnist_classifier_fromscratch](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) link

In [None]:
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An MNIST example with single-program multiple-data (SPMD) data parallelism.

The aim here is to illustrate how to use JAX's `pmap` to express and execute
SPMD programs for data parallelism along a batch dimension, while also
minimizing dependencies by avoiding the use of higher-level layers and
optimizers libraries.
"""


from functools import partial
import time

import numpy as np
import numpy.random as npr

import jax
from jax import jit, grad, pmap
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
from jax import lax
import jax.numpy as jnp
import datasets
from google.colab import output


def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
  return [(scale * rng.randn(m, n), scale * rng.randn(n))
          for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])]

def predict(params, inputs):
  activations = inputs
  for w, b in params[:-1]:
    outputs = jnp.dot(activations, w) + b
    activations = jnp.tanh(outputs)

  final_w, final_b = params[-1]
  logits = jnp.dot(activations, final_w) + final_b
  return logits - logsumexp(logits, axis=1, keepdims=True)

def loss(params, batch):
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

@jit
def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)


if __name__ == "__main__":
  layer_sizes = [784, 1024, 1024, 10]
  param_scale = 0.1
  step_size = 0.001
  num_epochs = 1000
  batch_size = 128

  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  # For this manual SPMD example, we get the number of devices (e.g. GPUs or
  # TPU cores) that we're using, and use it to reshape data minibatches.
  num_devices = jax.device_count()
  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        images, labels = train_images[batch_idx], train_labels[batch_idx]
        # For this SPMD example, we reshape the data batch dimension into two
        # batch dimensions, one of which is mapped over parallel devices.
        batch_size_per_device, ragged = divmod(images.shape[0], num_devices)
        if ragged:
          msg = "batch size must be divisible by device count, got {} and {}."
          raise ValueError(msg.format(batch_size, num_devices))
        shape_prefix = (num_devices, batch_size_per_device)
        images = images.reshape(shape_prefix + images.shape[1:])
        labels = labels.reshape(shape_prefix + labels.shape[1:])
        yield images, labels
  batches = data_stream()

  @partial(pmap, axis_name='batch')
  def spmd_update(params, batch):
    grads = grad(loss)(params, batch)
    # We compute the total gradients, summing across the device-mapped axis,
    # using the `lax.psum` SPMD primitive, which does a fast all-reduce-sum.
    grads = [(lax.psum(dw, 'batch'), lax.psum(db, 'batch')) for dw, db in grads]
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]

  # We replicate the parameters so that the constituent arrays have a leading
  # dimension of size equal to the number of devices we're pmapping over.
  init_params = init_random_params(param_scale, layer_sizes)
  replicate_array = lambda x: np.broadcast_to(x, (num_devices,) + x.shape)
  replicated_params = tree_map(replicate_array, init_params)

  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      replicated_params = spmd_update(replicated_params, next(batches))
    epoch_time = time.time() - start_time

    # We evaluate using the jitted `accuracy` function (not using pmap) by
    # grabbing just one of the replicated parameter values.
    params = tree_map(lambda x: x[0], replicated_params)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    output.clear() #to_clear_the_output_console_everytime
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")