<a href="https://colab.research.google.com/github/1kaiser/Media-Segment-Depth-MLP/blob/main/MLP_Image_Train_Inference_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
# Mount Google Drive under /content/drive. Later cells read the test video from,
# and copy the zipped dataset to, /content/drive/MyDrive/OUT/data/machine_learning_test_dataset.
drive.mount('/content/drive')

### **Downloading the dataset**

The images come from the Kaggle [flower dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition?resource=download); the cell below downloads a copy of the archive from this repository's GitHub releases.

In [None]:
from google.colab import output
!wget https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/archive.zip -O archive.zip
!unzip /content/archive.zip #unzipping the flower images from archive..
output.clear()
##########################<< copying all varities into a single folder block >>################
!mkdir -p /content/flowers/all
!cp /content/flowers/daisy/* /content/flowers/all
!cp /content/flowers/dandelion/* /content/flowers/all
!cp /content/flowers/rose/* /content/flowers/all
!cp /content/flowers/sunflower/* /content/flowers/all
!cp /content/flowers/tulip/* /content/flowers/all
##########################<< end of block >>################
print("creating single image folder complete >>>")


## **RUN** 

**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or a single grayscale value $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [None]:
#✅
import jax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at power-of-two frequencies to a 2-D input.

    Args:
        args: array of shape (N, C); any other rank raises on unpacking.

    Returns:
        Array of shape (N, C + 2 * positional_encoding_dims * C). The first C
        columns are ``args`` itself; then, for each frequency ``2**d``
        (d = 0 .. positional_encoding_dims - 1), C sine columns followed by
        C cosine columns of ``args * 2**d``.
    """
    n_points, _ = args.shape
    # (D,) frequency multipliers 1, 2, 4, ...
    frequencies = 2.0 ** jnp.arange(positional_encoding_dims)
    # Broadcast to (D, N, C): every input scaled by every frequency.
    scaled = args[None, :, :] * frequencies[:, None, None]
    # (2, D, N, C) -> (N, D, 2, C) so each row groups [sin, cos] per frequency.
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)], axis=0)
    features = jnp.swapaxes(waves, 0, 2).reshape((n_points, -1))
    return jnp.concatenate([args, features], axis=-1)



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [None]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio

import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Stack of ``ndl`` ReLU-activated Dense layers of width ``dlw`` with a 1-unit Dense head.

    The input is optionally expanded by ``positional_encoding`` first, and the
    raw input is concatenated back onto the activations after the layer with
    index 4 (a skip connection).
    """
    # dtype and precision are forwarded unchanged to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # Default is taken from the module-level flag at class-definition time.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` (2-D when positional encoding is on) to one output value per row."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the un-encoded inputs mid-network.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        head = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

**initialize the module**

In [None]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialise ``model`` and wrap its parameters in a Flax ``TrainState`` with an Adam optimizer.

    Args:
        r_key: PRNG key passed to ``model.init``.
        model: the Flax module to initialise.
        shape: shape of the all-ones dummy input used for initialisation (also printed).
        learning_rate: learning rate handed to ``optax.adam``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial ``params`` and the optimizer.
    """
    print(shape)
    dummy_input = jnp.ones(shape)
    init_variables = model.init(r_key, dummy_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_variables['params'],
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4
batch_size_no = 64

model = MLPModel() # Instantiate the Model

**defining loss function**

In [None]:
#serial
def image_difference_loss(logits, labels):
    """Return half the mean squared difference between ``logits`` and ``labels``."""
    residual = logits - labels
    return 0.5 * jnp.mean(jnp.square(residual))

def compute_metrics(*, logits, labels):
  """Bundle the loss with the prediction and the target it was computed from.

  Returns a dict with keys ``'loss'``, ``'logits'`` (predicted image) and
  ``'labels'`` (actual image); the two arrays are passed through untouched.
  """
  return {
      'loss': image_difference_loss(logits, labels),  # LOSS
      'logits': logits,                               # PREDICTED IMAGE
      'labels': labels,                               # ACTUAL IMAGE
  }

**train step definition**

In [None]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Run one gradient step on a single ``(input, target)`` batch.

    Args:
        state: current train state (parameters, apply function, optimizer).
        batch: pair ``(image, label)``; ``image`` is fed to the model and
            ``label`` is the regression target.
        rng: accepted for call compatibility; not used inside this function.

    Returns:
        ``(new_state, logs)`` where ``logs`` is the ``compute_metrics`` dict for
        the predictions made before the update.
    """
    inputs, targets = batch

    def _loss_and_logits(params):
        predictions = state.apply_fn({'params': params}, inputs)
        return image_difference_loss(predictions, targets), predictions

    # has_aux=True lets the predictions ride along with the scalar loss.
    (_, logits), grads = jax.value_and_grad(_loss_and_logits, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, compute_metrics(logits=logits, labels=targets)

import jax
@jax.jit
def eval_step(state, image):
    """Apply the model held in ``state`` to ``image`` and return the metrics dict.

    NOTE(review): the input ``image`` is also passed as the label, so the
    reported loss compares predictions with the input itself - confirm this is
    intended.
    """
    predictions = state.apply_fn({'params': state.params}, image)
    metrics = compute_metrics(logits=predictions, labels=image)
    return metrics


**image conversion functions**

In [None]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Load an image as single-channel grayscale, resized to the global ``newsize``.

    Args:
        argv: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(pixels, flat)``: the resized pixel array and the same pixels
        reshaped to ``(-1, 1)``.
    """
    # The context manager closes the file handle; resize once and reuse the result.
    with Image.open(argv) as im:
        pixels = jnp.asarray(im.convert('L').resize(newsize))
    return pixels, pixels.reshape(-1, 1)
def imageRGB(argv):
    """Load an image as 3-channel RGB, resized to the global ``newsize``.

    Args:
        argv: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(pixels, flat)``: the resized pixel array and the same pixels
        reshaped to ``(-1, 3)``.
    """
    with Image.open(argv) as im:
        # Force RGB so grayscale, palette or RGBA files still give 3 values per
        # pixel; without this the reshape below fails or misaligns channels.
        pixels = jnp.asarray(im.convert('RGB').resize(newsize))
    return pixels, pixels.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 1

import os
image_dir = r'/content/flowers/tulip/'
#############################################################################
# File-name fragments; index 1 (".jpg") is the one used to select image files.
bandend = ["c",".jpg", "b02"]
expression_b2 = bandend[1]
# Sorted names of every directory entry containing the selected fragment.
total_images = sorted(f for f in os.listdir(image_dir) if expression_b2 in f)
total_images_path = [os.path.join(image_dir, fname) for fname in total_images if fname != 'outputs']
no_of_batches = len(total_images_path) // batch_size

######################################## one (RGB, GRAY) pair per batch >>>
def batchedimages(image_locations):
  """Return ``(rgb_pixels, gray_pixels)`` for the image at ``image_locations[0]``.

  Only the first index of ``image_locations`` is used; both arrays are the
  flattened per-pixel outputs of ``imageRGB`` / ``imageGRAY``.
  """
  first_path = total_images_path[image_locations[0]]
  rgb_pixels = jnp.asarray(imageRGB(first_path)[1])
  gray_pixels = jnp.asarray(imageGRAY(first_path)[1])
  return rgb_pixels, gray_pixels

def data_stream():
  """Yield ``no_of_batches`` batches in a shuffled order.

  The permutation is drawn from a fixed ``PRNGKey(0)``, so every new stream
  visits the images in the same order.
  """
  shuffled = random.permutation(random.PRNGKey(0), len(total_images_path))
  for start in range(0, no_of_batches * batch_size, batch_size):
    yield batchedimages(shuffled[start:start + batch_size])

batches = data_stream()

In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import jax
from jax import random
from tqdm import tqdm
import re
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

import shutil

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
# One batch is pulled only to learn the (pixels, channels) input shape; the
# stream is recreated at the start of every epoch below.
HxW, Channels = next(batches)[0].shape
state = Create_train_state( rng, model, (HxW, Channels), learning_rate )
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

######################
# Raw string: "\d" in a normal string literal is an invalid escape sequence.
pattern = re.compile(r"checkpoint_\d+")   # matches "checkpoint_<digits>"
ckpt_abs_dir = "/content/ckpts/"          # was named `dir`, which shadowed the builtin
checkpoint_available = os.path.isdir(ckpt_abs_dir) and any(
    pattern.match(entry) for entry in os.listdir(ckpt_abs_dir))

# The checkpointer is stateless across epochs, so build it once.
orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

total_epochs = 50
for epochs in tqdm(range(total_epochs)):
  batches = data_stream()

  if checkpoint_available:
    # Resume once from the existing checkpoint, then remove the directory so the
    # saves below (overwrite=False, steps restarting at 0) do not collide with it.
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = False
    shutil.rmtree(ckpt_abs_dir, ignore_errors=True)

  # NOTE(review): the last 5 batches of every epoch are never consumed - confirm intent.
  for bbb in tqdm(range(no_of_batches-5)):
    state, metrics = train_step(state, next(batches), rng)
    output.clear()
    print("loss: ",metrics['loss']," <<< ")
  # The checkpoint is named "checkpoint_<epochs>"; keep=1 retains only the newest.
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  ##################################################
  ##################################################



**inference engine**

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output
from flax.training import checkpoints

!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    im = Image.open(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,3)
    return tvt, tvu
image = jnp.asarray((imageRGB(image_in)[1]))
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
#initialize
HxW, Channels = image.shape
state = Create_train_state( model, rng, (HxW, Channels), learning_rate ) 

state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
prediction = eval_step(state, image)
prediction['loss']


predicted_image = np.array(prediction['logits'],  dtype=np.uint8).reshape(newsize) 
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))

## **Test dataset creation: frame extraction and segmentation**

In [None]:
##################################<<< MEDIAPIPE LIBRARY INSTALLATION >>>#############################
!python -m pip install mediapipe
##################################<<< FRAME EXTRACTION >>>#############################
# Source video on the mounted Google Drive; the drive-mount cell must have run first.
video_location = '/content/drive/MyDrive/OUT/data/machine_learning_test_dataset/test.mp4'
import os

 
# Read images with OpenCV.
#images= None
# Folder receiving the extracted frames; the trailing slash matters because the
# ffmpeg output pattern below is built by plain string concatenation.
image_dir = '/content/MEDIAPIPEinput/'
os.makedirs(image_dir, exist_ok=True)
# Folder receiving the segmented frames written by the MediaPipe loop below.
image_dir_out = '/content/annotated_images'
os.makedirs(image_dir_out, exist_ok=True)
frame_rate = 25
# Dump the video as numbered PNG frames (out_000000001.png, ...) at `frame_rate`.
# NOTE(review): `-hwaccel cuvid` presumably needs a GPU runtime - confirm.
!ffmpeg -y -hwaccel cuvid \
  -i {video_location} \
  -r {frame_rate} {image_dir}out_%09d.png

# Sorted frame names and their full paths (an entry named 'outputs' is skipped).
imgs_list = os.listdir(image_dir)
imgs_list.sort()
imgs_path = [os.path.join(image_dir, i) for i in imgs_list if i != 'outputs']
################################<<< SEGMENTATION USING MEDIAPIPE >>>###################################
import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import mediapipe as mp
# mp_holistic = mp.solutions.holistic
mp_pose = mp.solutions.pose
!rm -r {image_dir}.ipynb_checkpoints

# Run MediaPipe Pose with `enable_segmentation=True` to get pose segmentation.
with mp_pose.Pose(static_image_mode=True, 
                          min_detection_confidence=0.2,
                          model_complexity=2, 
                          enable_segmentation=True,) as pose:
  temp_segmentation_mask =[]                        
  for name, image in enumerate(imgs_path):
    !rm -r {image_dir}.ipynb_checkpoints
    # Convert the BGR image to RGB and process it with MediaPipe Pose.
    image = cv2.imread(image)
    results = pose.process(image)

    # Draw pose segmentation.
    print(f'Pose segmentation of {name}:')
    annotated_image_pose = image.copy()
    red_img = np.zeros_like(annotated_image_pose, dtype=np.uint8)
    red_img[:, :] = (255,255,255)
    ###check if segmentation_mask exists or not ## if exists then ok Else use previous mask temporarily
    if results.segmentation_mask is None:
      print("true")
      results.segmentation_mask = temp_segmentation_mask[-1]
      temp_segmentation_mask.append(results.segmentation_mask)
    else:
      temp_segmentation_mask.append(results.segmentation_mask)
    ###End check if segmentation_mask exists or not ## if exists then ok Else use previous mask temporarily
    segm_2class = 0.0 + 1.0 * results.segmentation_mask
    segm_2class = np.repeat(segm_2class[..., np.newaxis], 3, axis=2)
    annotated_image_pose = annotated_image_pose * segm_2class + red_img * (1 - segm_2class)
    #resize_and_show(annotated_image)
    cv2.imwrite('%s/%s' %(image_dir_out, imgs_list[name]), annotated_image_pose)
    !rm -r {image_dir_out}.ipynb_checkpoints


## **RUN 2** 

**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or a single grayscale value $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [None]:
#✅
import jax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at power-of-two frequencies to a 2-D input.

    Args:
        args: array of shape (N, C); any other rank raises on unpacking.

    Returns:
        Array of shape (N, C + 2 * positional_encoding_dims * C). The first C
        columns are ``args`` itself; then, for each frequency ``2**d``
        (d = 0 .. positional_encoding_dims - 1), C sine columns followed by
        C cosine columns of ``args * 2**d``.
    """
    n_points, _ = args.shape
    # (D,) frequency multipliers 1, 2, 4, ...
    frequencies = 2.0 ** jnp.arange(positional_encoding_dims)
    # Broadcast to (D, N, C): every input scaled by every frequency.
    scaled = args[None, :, :] * frequencies[:, None, None]
    # (2, D, N, C) -> (N, D, 2, C) so each row groups [sin, cos] per frequency.
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)], axis=0)
    features = jnp.swapaxes(waves, 0, 2).reshape((n_points, -1))
    return jnp.concatenate([args, features], axis=-1)



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [None]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio


import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Stack of ``ndl`` ReLU-activated Dense layers of width ``dlw`` with a 1-unit Dense head.

    The input is optionally expanded by ``positional_encoding`` first, and the
    raw input is concatenated back onto the activations after the layer with
    index 4 (a skip connection).
    """
    # dtype and precision are forwarded unchanged to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # Default is taken from the module-level flag at class-definition time.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` (2-D when positional encoding is on) to one output value per row."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the un-encoded inputs mid-network.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        head = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

**initialize the module**

In [None]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialise ``model`` and wrap its parameters in a Flax ``TrainState`` with an Adam optimizer.

    Args:
        r_key: PRNG key passed to ``model.init``.
        model: the Flax module to initialise.
        shape: shape of the all-ones dummy input used for initialisation (also printed).
        learning_rate: learning rate handed to ``optax.adam``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial ``params`` and the optimizer.
    """
    print(shape)
    dummy_input = jnp.ones(shape)
    init_variables = model.init(r_key, dummy_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_variables['params'],
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4
batch_size_no = 64

model = MLPModel() # Instantiate the Model

**defining loss function**

In [None]:
#serial
def image_difference_loss(logits, labels):
    """Return half the mean squared difference between ``logits`` and ``labels``."""
    residual = logits - labels
    return 0.5 * jnp.mean(jnp.square(residual))

def compute_metrics(*, logits, labels):
  """Bundle the loss with the prediction and the target it was computed from.

  Returns a dict with keys ``'loss'``, ``'logits'`` (predicted image) and
  ``'labels'`` (actual image); the two arrays are passed through untouched.
  """
  return {
      'loss': image_difference_loss(logits, labels),  # LOSS
      'logits': logits,                               # PREDICTED IMAGE
      'labels': labels,                               # ACTUAL IMAGE
  }

**train step definition**

In [None]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Run one gradient step on a single ``(input, target)`` batch.

    Args:
        state: current train state (parameters, apply function, optimizer).
        batch: pair ``(image, label)``; ``image`` is fed to the model and
            ``label`` is the regression target.
        rng: accepted for call compatibility; not used inside this function.

    Returns:
        ``(new_state, logs)`` where ``logs`` is the ``compute_metrics`` dict for
        the predictions made before the update.
    """
    inputs, targets = batch

    def _loss_and_logits(params):
        predictions = state.apply_fn({'params': params}, inputs)
        return image_difference_loss(predictions, targets), predictions

    # has_aux=True lets the predictions ride along with the scalar loss.
    (_, logits), grads = jax.value_and_grad(_loss_and_logits, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, compute_metrics(logits=logits, labels=targets)

import jax
@jax.jit
def eval_step(state, image):
    """Apply the model held in ``state`` to ``image`` and return the metrics dict.

    NOTE(review): the input ``image`` is also passed as the label, so the
    reported loss compares predictions with the input itself - confirm this is
    intended.
    """
    predictions = state.apply_fn({'params': state.params}, image)
    metrics = compute_metrics(logits=predictions, labels=image)
    return metrics


**image conversion functions**

In [None]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Load an image as single-channel grayscale, resized to the global ``newsize``.

    Args:
        argv: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(pixels, flat)``: the resized pixel array and the same pixels
        reshaped to ``(-1, 1)``.
    """
    # The context manager closes the file handle; resize once and reuse the result.
    with Image.open(argv) as im:
        pixels = jnp.asarray(im.convert('L').resize(newsize))
    return pixels, pixels.reshape(-1, 1)
def imageRGB(argv):
    """Load an image as 3-channel RGB, resized to the global ``newsize``.

    Args:
        argv: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(pixels, flat)``: the resized pixel array and the same pixels
        reshaped to ``(-1, 3)``.
    """
    with Image.open(argv) as im:
        # Force RGB so grayscale, palette or RGBA files (the PNG frames may carry
        # an alpha channel) still give 3 values per pixel for the reshape below.
        pixels = jnp.asarray(im.convert('RGB').resize(newsize))
    return pixels, pixels.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 1

import os
image_dir = r'/content/MEDIAPIPEinput'
annotated_image_dir = r'/content/annotated_images'

#############################################################################
# File-name fragments; index 1 (".png") is the one used to select image files.
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]

# Sorted names / full paths of the extracted frames.
total_images = sorted(f for f in os.listdir(image_dir) if expression_b2 in f)
total_images_path = [os.path.join(image_dir, fname) for fname in total_images if fname != 'outputs']

# Sorted names / full paths of the segmented frames. Input and annotated images
# are paired purely by position in these two sorted lists.
annotated_total_images = sorted(f for f in os.listdir(annotated_image_dir) if expression_b2 in f)
annotated_total_images_path = [os.path.join(annotated_image_dir, fname) for fname in annotated_total_images if fname != 'outputs']

no_of_batches = len(total_images_path) // batch_size

######################################## one (frame, annotated frame) pair per batch >>>
def batchedimages(image_locations):
  """Return ``(frame_pixels, annotated_pixels)`` for the index ``image_locations[0]``.

  Only the first index of ``image_locations`` is used; both arrays are the
  flattened per-pixel output of ``imageRGB``.
  """
  first_index = image_locations[0]
  frame_pixels = jnp.asarray(imageRGB(total_images_path[first_index])[1])
  annotated_pixels = jnp.asarray(imageRGB(annotated_total_images_path[first_index])[1])
  return frame_pixels, annotated_pixels

def data_stream():
  """Yield ``no_of_batches`` batches in a shuffled order.

  The permutation is drawn from a fixed ``PRNGKey(0)``, so every new stream
  visits the images in the same order.
  """
  shuffled = random.permutation(random.PRNGKey(0), len(total_images_path))
  for start in range(0, no_of_batches * batch_size, batch_size):
    yield batchedimages(shuffled[start:start + batch_size])

batches = data_stream()

In [None]:
next(batches)[0].shape

In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import jax
from jax import random
from tqdm import tqdm
import re
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

import shutil

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
# One batch is pulled only to learn the (pixels, channels) input shape; the
# stream is recreated at the start of every epoch below.
HxW, Channels = next(batches)[0].shape
state = Create_train_state( rng, model, (HxW, Channels), learning_rate )
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

######################
# Raw string: "\d" in a normal string literal is an invalid escape sequence.
pattern = re.compile(r"checkpoint_\d+")   # matches "checkpoint_<digits>"
ckpt_abs_dir = "/content/ckpts/"          # was named `dir`, which shadowed the builtin
checkpoint_available = os.path.isdir(ckpt_abs_dir) and any(
    pattern.match(entry) for entry in os.listdir(ckpt_abs_dir))

# The checkpointer is stateless across epochs, so build it once.
orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

total_epochs = 50
for epochs in tqdm(range(total_epochs)):
  batches = data_stream()

  if checkpoint_available:
    # Resume once from the existing checkpoint, then remove the directory so the
    # saves below (overwrite=False, steps restarting at 0) do not collide with it.
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = False
    shutil.rmtree(ckpt_abs_dir, ignore_errors=True)

  # NOTE(review): the last 5 batches of every epoch are never consumed - confirm intent.
  for bbb in tqdm(range(no_of_batches-5)):
    state, metrics = train_step(state, next(batches), rng)
    output.clear()
    print("loss: ",metrics['loss']," <<< ")
  # The checkpoint is named "checkpoint_<epochs>"; keep=1 retains only the newest.
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  ##################################################
  ##################################################



**inference engine**

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output

!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    im = Image.open(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,3)
    return tvt, tvu
image = jnp.asarray((imageRGB(image_in)[1]))
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
prediction = eval_step(state, image)
prediction['loss']


predicted_image = np.array(prediction['logits'],  dtype=np.uint8).reshape(newsize) 
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))

## **Ensemble**

In [None]:
#✅
!python -m pip install -q -U flax
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training import train_state, common_utils
import functools

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at power-of-two frequencies to a 2-D input.

    Args:
        args: array of shape (N, C); any other rank raises on unpacking.

    Returns:
        Array of shape (N, C + 2 * positional_encoding_dims * C). The first C
        columns are ``args`` itself; then, for each frequency ``2**d``
        (d = 0 .. positional_encoding_dims - 1), C sine columns followed by
        C cosine columns of ``args * 2**d``.
    """
    n_points, _ = args.shape
    # (D,) frequency multipliers 1, 2, 4, ...
    frequencies = 2.0 ** jnp.arange(positional_encoding_dims)
    # Broadcast to (D, N, C): every input scaled by every frequency.
    scaled = args[None, :, :] * frequencies[:, None, None]
    # (2, D, N, C) -> (N, D, 2, C) so each row groups [sin, cos] per frequency.
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)], axis=0)
    features = jnp.swapaxes(waves, 0, 2).reshape((n_points, -1))
    return jnp.concatenate([args, features], axis=-1)

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Stack of ``ndl`` ReLU-activated Dense layers of width ``dlw`` with a 1-unit Dense head.

    The input is optionally expanded by ``positional_encoding`` first, and the
    raw input is concatenated back onto the activations after the layer with
    index 4 (a skip connection).
    """
    # dtype and precision are forwarded unchanged to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # Default is taken from the module-level flag at class-definition time.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` (2-D when positional encoding is on) to one output value per row."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the un-encoded inputs mid-network.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        head = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def Create_train_state(r_key, shape, learning_rate ):
    """pmapped initialiser: build one ``MLPModel`` train state per leading entry of ``r_key``.

    ``shape`` and ``learning_rate`` are static, broadcast arguments. Note that
    this definition replaces the earlier 4-argument ``Create_train_state``
    (which also took the model) once this cell has been run.
    """
    print(shape)
    ensemble_member = MLPModel()
    init_variables = ensemble_member.init(r_key, jnp.ones(shape))
    return train_state.TrainState.create(
        apply_fn=ensemble_member.apply,
        params=init_variables['params'],
        tx=optax.adam(learning_rate),
    )

# learning_rate = 1e-4
# batch_size_no = 64

# model = MLPModel() # Instantiate the Model

In [None]:
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, batch: jnp.asarray):
  """pmapped forward/backward pass: return ``(grads, loss)`` for one ``(image, label)`` batch.

  A fresh ``MLPModel`` is applied with ``state.params``; the loss is
  ``image_difference_loss`` between its output and the label.
  """
  inputs, targets = batch

  def _loss_with_logits(params):
    predictions = MLPModel().apply({'params': params}, inputs)
    return image_difference_loss(predictions, targets), predictions

  # The predictions are computed as the auxiliary output but not returned.
  (loss, _), grads = jax.value_and_grad(_loss_with_logits, has_aux=True)(state.params)
  return grads, loss

@jax.pmap
def update_model(state, grads):
  """pmapped optimizer step: apply ``grads`` to ``state`` and return the new state."""
  stepped_state = state.apply_gradients(grads=grads)
  return stepped_state

In [None]:
def train_epoch(state, train_ds, batch_size, rng):
  """Train for one epoch over ``train_ds`` and return ``(state, mean_loss)``.

  Args:
    state: replicated train state, updated through the pmapped helpers.
    train_ds: mapping with ``'image'`` and ``'label'`` arrays of equal length.
    batch_size: number of examples per step; a trailing partial batch is dropped.
    rng: PRNG key used to shuffle the example order.

  Returns:
    The updated state and the mean of the per-step losses (``nan`` when the
    dataset is smaller than one batch, since no step runs).
  """
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  # Shuffle, drop the remainder, and lay the indices out as (steps, batch_size).
  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size].reshape((steps_per_epoch, batch_size))

  epoch_loss = []

  for perm in perms:
    batch_images = flax.jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = flax.jax_utils.replicate(train_ds['label'][perm, ...])
    # apply_model is declared as (state, batch) with batch an (image, label)
    # pair; passing images and labels as two positional arguments does not
    # match that signature.
    grads, loss = apply_model(state, (batch_images, batch_labels))
    state = update_model(state, grads)
    epoch_loss.append(flax.jax_utils.unreplicate(loss))
  train_loss = np.mean(epoch_loss)
  return state, train_loss

In [None]:
import logging

# NOTE(review): `get_datasets` and `num_epochs` are not defined anywhere in this
# notebook; they must be provided before this cell can run.
train_ds, test_ds = get_datasets()
test_ds = flax.jax_utils.replicate(test_ds)
rng = jax.random.PRNGKey(0)

rng, init_rng = jax.random.split(rng)

HxW, Channels = next(batches)[0].shape
# The pmapped initialiser is spelled `Create_train_state` (capital C); one PRNG
# key per device produces one ensemble member per device.
state = Create_train_state(jax.random.split(init_rng, jax.device_count()),(HxW, Channels),learning_rate)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss = train_epoch(state, train_ds, batch_size, input_rng)

  # _, test_loss = flax.jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))

  logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))

In [None]:
# same as before, but using the pad_shard_unpad wrapper

# manually padding
# => precise & allows for data parallelism

# NOTE(review): `tfds`, `dataset_name`, `per_device_batch_size` and `vs` are not
# defined anywhere in this notebook; this cell cannot run as-is.

@jax.pmap
def get_preds(variables, inputs):
  """Apply the global ``model`` to one device's shard of ``inputs``."""
  # Executes only while tracing, so each print marks a new compilation.
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
# Keep the final partial batch; pad_shard_unpad below is what handles it.
ds = ds.batch(per_host_batch_size, drop_remainder=False)

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

In [None]:
def eval_step(metrics, variables, batch):
  """Add this batch's correct/total counts (summed over the ``'batch'`` axis) to ``metrics``.

  Only examples whose ``batch['mask']`` entry is set are counted. Once this
  cell runs, the name replaces the earlier ``eval_step(state, image)``.
  """
  # Executes only while tracing, so each print marks a new compilation.
  print('retrigger compilation', {k: v.shape for k, v in batch.items()})
  predicted_labels = model.apply(variables, batch['image']).argmax(axis=-1)
  hits = (batch['mask'] & (batch['label'] == predicted_labels)).sum()
  valid = batch['mask'].sum()
  return {
      'correct': metrics['correct'] + jax.lax.psum(hits, axis_name='batch'),
      'total': metrics['total'] + jax.lax.psum(valid, axis_name='batch'),
  }

eval_step = jax.pmap(eval_step, axis_name='batch')
eval_step = flax.jax_utils.pad_shard_unpad(
    eval_step, static_argnums=(0, 1), static_return=True)

In [None]:
# Zip the assembled dataset folder from its parent directory.
# NOTE(review): `total_files` is only assigned in a later cell; on a fresh
# kernel that cell has to be run before this one.
%cd {total_files}
%cd ..
!zip -r folder.zip {total_files}

In [None]:
!cp -r /content/folder.zip /content/drive/MyDrive/OUT/data/machine_learning_test_dataset

In [None]:
!wget https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip

In [None]:
!unzip /content/s.zip #unzipping the s.zip archive downloaded in the previous cell

In [None]:
# Collect the extracted frames and their segmented counterparts under one
# folder (/content/s) so they can be zipped or packaged together.
total_files= '/content/s'
input_images = '/content/MEDIAPIPEinput'
out_images = '/content/annotated_images'
!mkdir -p {total_files}
!cp -r {input_images} {total_files}
!cp -r {out_images} {total_files}

In [None]:
%cd {total_files}
!tfds new my_dataset

In [None]:
total_files= '/content/s'
%cd {total_files}/my_dataset/
!tfds build

In [None]:
# NOTE(review): the other cells create my_dataset under /content/s, not
# /content/t - confirm this path is the intended one.
!rm -r /content/t/my_dataset

In [None]:
import os
from pathlib import Path

def _generate_examples(self, path):
  """Yield ``(key, example)`` pairs, one per ``*.png`` file directly inside ``path``.

  Args:
    self: unused; kept so the function can be moved into a builder class as a method.
    path: pathlib-like directory object supporting ``glob``.

  Yields:
    ``(file_name, example)`` where the key is the file name (unique within the
    directory) and the example maps both feature names to the file path.
  """
  # Sorted so the example order is reproducible across runs.
  for f in sorted(path.glob('*.png')):
    yield f.name, {
        'MEDIAPIPEinput': f,
        # NOTE(review): both features point at the same file; the annotated
        # image presumably lives in a sibling 'annotated_images' folder - confirm.
        'annotated_images': f,
    }

# Use a real path object. Assigning a string to `os.path` would replace the
# standard-library module and break every later `os.path.join` call.
dataset_root = Path('/content/s')
# Show the first example (None when the folder has no PNG files).
next(_generate_examples(None, dataset_root / 'MEDIAPIPEinput'), None)

In [None]:
import tensorflow_datasets as tfds
# Download s.zip into /content and extract it; `path` (the extraction result) is
# reused by later cells to locate the MEDIAPIPEinput folder.
dl_manager = tfds.download.DownloadManager(download_dir='/content')
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
path = dl_manager.extract(dl_manager.download(urls))

In [None]:
def _generate_examples( img_path):
  """Yield one ``{'image': <path>}`` dict per ``*.png`` file directly inside ``img_path``.

  ``img_path / '*.png'`` would only build a path whose last component is the
  literal text ``*.png``; the directory has to be globbed to reach real files.
  """
  # Sorted so the example order is reproducible across runs.
  for png_path in sorted(img_path.glob('*.png')):
    yield {
        'image': png_path,
    }

def _split_generators():
    """Download and extract s.zip, then return example generators for both image folders.

    Returns:
        Dict with ``'in_image'`` (examples from ``MEDIAPIPEinput``) and
        ``'out_image'`` (examples from ``annotated_images``).
    """
    manager = tfds.download.DownloadManager(download_dir='/content')
    archive_url = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
    # The extraction result is a pathlib-like object (`read_text()`, `iterdir()`, ...).
    extracted_root = manager.extract(manager.download(archive_url))
    in_examples = _generate_examples(extracted_root / 'MEDIAPIPEinput')
    out_examples = _generate_examples(extracted_root / 'annotated_images')
    return {
        'in_image': in_examples,
        'out_image': out_examples,
    }
# _generate_examples(path/'MEDIAPIPEinput')
print(_split_generators()['in_image'])

In [None]:
str(path)

In [None]:
next(_generate_examples(path / 'MEDIAPIPEinput'))['image']

In [None]:
from typing import Any, Dict, Iterator, Tuple


class Builder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for my_dataset dataset.

  Every example pairs one PNG from `MEDIAPIPEinput` (feature 'image') with one
  PNG from `annotated_images` (feature 'label').
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Dataset metadata (homepage, citation,...)."""
    # NOTE(review): both features are declared as fixed 256x256 RGB; confirm the
    # PNGs in the archive really have that size.
    return self.dataset_info_from_configs(
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(256, 256, 3)),
            'label': tfds.features.Image(shape=(256, 256, 3)),
        }),
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Downloads the archive and defines a single 'train' split of image/label pairs."""
    urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
    extracted_path = dl_manager.download_and_extract(urls)
    # dl_manager returns pathlib-like objects with `path.read_text()`,
    # `path.iterdir()`,...
    return {
        'train': self._generate_examples(
            image_dir=extracted_path / 'MEDIAPIPEinput',
            label_dir=extracted_path / 'annotated_images',
        ),
    }

  def _generate_examples(self, image_dir, label_dir) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yields `(key, example)` pairs carrying both declared features.

    Args:
      image_dir: Directory holding the input PNGs.
      label_dir: Directory holding the annotation PNGs.

    Yields:
      `(file name of the input image, {'image': ..., 'label': ...})`.

    Raises:
      ValueError: If the two directories hold a different number of PNGs.
    """
    image_paths = sorted(image_dir.glob('*.png'))
    label_paths = sorted(label_dir.glob('*.png'))
    if len(image_paths) != len(label_paths):
      raise ValueError(
          f'{len(image_paths)} input images but {len(label_paths)} annotated images')
    # NOTE(review): pairing is by sorted position, the same convention the other
    # loaders in this notebook use — confirm the file names line up.
    for img_path, label_path in zip(image_paths, label_paths):
      yield img_path.name, {
          'image': img_path,
          'label': label_path,
      }

In [None]:
# import the modules
import os
from os import listdir
 
# get the path or directory
# Print the name of every image file in the input folder of the extracted archive.
folder_dir = str(path)+'/MEDIAPIPEinput/'
for images in os.listdir(folder_dir):
    # Anything that is not a png/jpg/jpeg file is skipped.
    if not images.endswith((".png", ".jpg", ".jpeg")):
        continue
    print(images)

In [None]:
path

In [None]:
import tensorflow as tf

batch_size = 32
img_height = 180
img_width = 180

# Keep the returned dataset: the next cell displays `train_ds`, which was never
# bound because the result of this call used to be thrown away.
train_ds = tf.keras.utils.image_dataset_from_directory(
  path,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)


In [None]:
train_ds

In [None]:
import pathlib
import tensorflow as tf

# Download the release archive and let Keras extract it (extract=True).
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# NOTE(review): later cells read images from '/content/biy/' — confirm the
# archive's folders really end up directly under that directory.
data_dir = pathlib.Path(data_dir)  # wrapped so the next cells can call .glob() on it

In [None]:
!rm -r /content/biy

In [None]:
data_dir

In [None]:
image_count = len(list(data_dir.glob('*/*.png')))
print(image_count)

In [None]:


import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey

x = jnp.empty((4, 28, 28, 1)) 

x.reshape((x.shape[0], -1)).shape

class MLP(nn.Module):                              # create a Flax Module dataclass
  """Eight ReLU Dense(256) layers with one skip connection and a linear head.

  `out_dims` is the width of the final Dense layer; its default of 1 is the
  value that used to be hard-coded.
  """
  out_dims: int = 1

  @nn.compact
  def __call__(self, input_points):
    """Runs the network on `input_points` and returns the head's output.

    Args:
      input_points: Input array; the Dense layers act on its last axis.

    Returns:
      Array whose last axis has `out_dims` entries.
    """
    # Start from the input: `x` used to be read before it was ever assigned.
    x = input_points
    for i in range(8):
      x = nn.Dense(256)(x)                           # create inline Flax Module submodules
      x = nn.relu(x)
      if i == 4:
        # Skip connection: re-attach the raw input after the fifth layer.
        x = jnp.concatenate([x, input_points], axis=-1)
    # The head and the return sit after the loop; inside it they ended the
    # forward pass after the first hidden layer.
    return nn.Dense(self.out_dims)(x)
model = MLP()                           # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))                      # placeholder input (jnp.empty does not draw random values)
params = model.init(PRNGKey(42), x)["params"]      # initialize the weights
y = model.apply({"params":params}, x)  # forward pass with the weights created above

In [None]:
y.shape

In [None]:
positional_encoding_dims = 6  # Number of positional encodings applied
import jax
import jax.numpy as jnp

def positional_encoding(args):
    """Appends sin/cos features at `positional_encoding_dims` octaves to `args`.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    octaves = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_octave)(octaves)             # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    print(waves.shape)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

x = jnp.empty((4, 28, 28, 1))
positional_encoding(x.reshape(x.shape[0],-1)).shape

In [None]:
import glob, os
import tensorflow as tf
import numpy as np

PATH = '/content/biy/'
BATCH_SIZE = 12
IMAGE_SIZE = 140

def read_train_data():
    """Builds the training dataset of (input, annotation) image pairs.

    Reads the PNG files under `PATH + "MEDIAPIPEinput/"` and
    `PATH + "annotated_images/"`, resizes each to IMAGE_SIZE x IMAGE_SIZE,
    divides pixel values by 255, then shuffles (buffer 1000) and batches by
    BATCH_SIZE.

    Returns:
      A `tf.data.Dataset` yielding `(x_batch, y_batch)` tuples.

    Raises:
      ValueError: If the two folders hold a different number of PNG files.
    """
    # glob returns files in arbitrary order, so sort both listings; otherwise
    # the i-th input may be paired with an unrelated annotation.
    # NOTE(review): assumes matching files sort to the same position — confirm.
    x_files = sorted(glob.glob(PATH + "MEDIAPIPEinput/*.png"))
    y_files = sorted(glob.glob(PATH + "annotated_images/*.png"))
    if len(x_files) != len(y_files):
        raise ValueError(
            f"found {len(x_files)} input images but {len(y_files)} annotated images under {PATH}")

    def load_image(filename):
        """Reads one PNG as 3-channel data, resized and divided by 255."""
        # The files are PNGs, so use the PNG decoder like the later copy of this loader.
        decoded = tf.image.decode_png(tf.io.read_file(filename), channels=3)
        return tf.image.resize(decoded, [IMAGE_SIZE, IMAGE_SIZE]) / 255

    def read_image(x_filename, y_filename):
        return load_image(x_filename), load_image(y_filename)

    dataset = tf.data.Dataset.from_tensor_slices((x_files, y_files))
    return dataset.map(read_image).shuffle(1000).batch(BATCH_SIZE)

train_set = read_train_data()
for x, y in train_set.as_numpy_iterator():
  print(x.shape, y.shape)

In [None]:
!python -m pip install -q -U flax
import functools
from flax.training.train_state import TrainState

def dice_coef(y_true, y_pred):
    """Dice coefficient of two arrays, both flattened before comparison.

    Computes `2 * sum(t * p) / (sum(t) + sum(p) + 1)`; the `+ 1` keeps the
    denominator non-zero when both inputs sum to zero.
    """
    truth_flat, pred_flat = jnp.ravel(y_true), jnp.ravel(y_pred)
    overlap = jnp.sum(truth_flat * pred_flat)
    total = jnp.sum(truth_flat) + jnp.sum(pred_flat) + 1
    return 2.0 * overlap / total


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus `dice_coef(y_true, y_pred)`."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score


class CustomTrainState(TrainState):
    """TrainState with a helper that calls `apply_fn` using a fixed dropout key."""

    def apply_fn_with_bn(self, *args, is_training, **nargs):
        """Forwards `args`/`nargs` to `apply_fn` with `rngs={'dropout': PRNGKey(2)}`.

        `is_training` is a required keyword argument but the body never reads
        it, and the same dropout key is used on every call.
        """
        dropout_rngs = {'dropout': jax.random.PRNGKey(2)}
        return self.apply_fn(*args, rngs=dropout_rngs, **nargs)

@functools.partial(jax.jit, static_argnums=(3,))
def train_step(x, y, train_state, is_training=True):
    """Runs one step of Dice-loss training or evaluation.

    Args:
      x: Batch of inputs passed to the model.
      y: Batch of targets compared against the prediction.
      train_state: State object providing `apply_fn_with_bn`, `params` and
        `apply_gradients`.
      is_training: Static flag. True computes gradients and updates the state;
        False only evaluates the loss.

    Returns:
      `(loss, train_state)`; the state is updated only when `is_training`.
    """
    def loss_fn(params, is_training):
        y_pred = train_state.apply_fn_with_bn({"params": params}, x, is_training=is_training)
        return dice_coef_loss(y, y_pred)

    if is_training:
        # loss_fn returns a bare scalar, so has_aux must stay off: with
        # has_aux=True JAX expects an `(output, aux)` pair and rejects this function.
        loss, grads = jax.value_and_grad(loss_fn)(train_state.params, True)
        train_state = train_state.apply_gradients(grads=grads)
    else:
        loss = loss_fn(train_state.params, False)

    return loss, train_state

In [None]:
import optax
# The MLP class defined earlier in this notebook declares no `out_dims` field,
# so passing it to the constructor raises TypeError; use the default construction.
unet = MLP()

init_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}

unet_variables = unet.init(init_rngs, jnp.ones([1, IMAGE_SIZE, IMAGE_SIZE, 3]))

optimizer = optax.adam(learning_rate=0.001)

train_state = CustomTrainState.create(apply_fn=unet.apply, params=unet_variables["params"], tx=optimizer)


for e in range(20):
        # Accumulate the per-step losses so each epoch ends with a mean;
        # `loss_avg` used to be reset every epoch but never updated or shown.
        loss_sum = 0.0
        num_steps = 0
        for x, y in train_set.as_numpy_iterator():
            loss, train_state = train_step(x, y, train_state, True)
            loss_sum += float(loss)
            num_steps += 1
            print(f"epoch: {e}, loss: {loss:0.2f}")
        loss_avg = loss_sum / num_steps if num_steps else float("nan")
        print(f"epoch: {e}, mean loss: {loss_avg:0.2f}")

In [None]:
x = jnp.empty((4, 28, 28, 3))
c = x[2]
c.shape
c = c.reshape(-1, c.shape[2])
p = positional_encoding(c)
print(p.shape)

In [None]:
import jax
import jax.numpy as jnp
x = jnp.empty((4, 28, 28, 1))
print(x.shape)

def positional_encoding(args):
    """Appends sin/cos features at six octaves to `args`, printing shapes on the way.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k in 0..5, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    print(args.shape)
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    scaled = jax.vmap(scale_by_octave)(jnp.arange(6))       # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    print(waves.shape)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

    
img_list = []
for i in range(x.shape[0]):
  print(i)
  print(x[i].shape)
  c = x[i]
  c.shape
  c = c.reshape(-1, c.shape[2])
  p = positional_encoding(c)
  img_list.append(p)
  print(p.shape)

In [None]:
print(x.shape, jnp.array(img_list).shape)

##**starting here 🔻**

In [None]:
#✅
!python -m pip install -q -U flax
import flax.linen as nn
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Appends sin/cos features at `positional_encoding_dims` octaves to `args`.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    octaves = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_octave)(octaves)             # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

def batch_encoded(args):
    """Applies `positional_encoding` to every pixel of a batch of images.

    Args:
      args: 4-D array `(batch, height, width, channels)`.

    Returns:
      `(batch, height, width, features)` array, where `features` is the size of
      the last axis returned by `positional_encoding`.
    """
    batch, height, width, channels = args.shape
    # positional_encoding handles each row on its own, so every pixel of every
    # image can be encoded in one call. This replaces a per-image Python loop
    # that also rebuilt the stacked result on every iteration.
    encoded = positional_encoding(args.reshape(-1, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP of `ndl` ReLU Dense(`dlw`) layers, one skip connection and a 3-unit head."""
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Returns the 3-channel prediction for `input_points`.

        The input is first passed through `batch_encoded` when
        `apply_positional_encoding` is set; the Dense layers act on the last axis.
        """
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points
        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the raw (un-encoded) input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)
        head = nn.Dense(3, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

In [None]:
import optax
unet = MLPModel()

init_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}
IMAGE_SIZE = 140
unet_variables = unet.init(init_rngs, jnp.ones([7, IMAGE_SIZE, IMAGE_SIZE, 3]))

optimizer = optax.adam(learning_rate=0.001)

from flax.training.train_state import TrainState

class CustomTrainState(TrainState):
    """TrainState with a helper that calls `apply_fn` using a fixed dropout key."""

    def apply_fn_with_bn(self, *args, is_training, **nargs):
        """Forwards `args`/`nargs` to `apply_fn` with `rngs={'dropout': PRNGKey(2)}`.

        `is_training` is a required keyword argument but the body never reads
        it, and the same dropout key is used on every call.
        """
        dropout_rngs = {'dropout': jax.random.PRNGKey(2)}
        return self.apply_fn(*args, rngs=dropout_rngs, **nargs)

train_state = CustomTrainState.create(apply_fn=unet.apply, params=unet_variables["params"], tx=optimizer)

In [None]:
from jax.tree_util import tree_structure
print(tree_structure(train_state))

In [None]:
# Download the release archive and let Keras extract it (extract=True).
# NOTE(review): this cell uses `tf` and `pathlib` without importing them; it relies
# on an earlier cell having done so.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
data_dir = pathlib.Path(data_dir)  # wrap the returned location in a Path object

In [None]:
import functools
import glob
import tensorflow as tf

@functools.partial(jax.jit, static_argnums=(3,))
def train_step(x, y, train_state, is_training=True):
    """Runs one step of Dice-loss training or evaluation.

    Args:
      x: Batch of inputs passed to the model.
      y: Batch of targets compared against the prediction.
      train_state: State object providing `apply_fn_with_bn`, `params` and
        `apply_gradients`.
      is_training: Static flag. True computes gradients and updates the state;
        False only evaluates the loss.

    Returns:
      `(loss, train_state)`; the state is updated only when `is_training`.
    """
    def loss_fn(params, is_training):
        y_pred = train_state.apply_fn_with_bn({"params": params}, x, is_training=is_training)
        return dice_coef_loss(y, y_pred)

    if is_training:
        # loss_fn returns a bare scalar, so has_aux must stay off: with
        # has_aux=True JAX expects an `(output, aux)` pair and rejects this function.
        loss, grads = jax.value_and_grad(loss_fn)(train_state.params, True)
        train_state = train_state.apply_gradients(grads=grads)
    else:
        loss = loss_fn(train_state.params, False)

    return loss, train_state

PATH = '/content/biy/'
BATCH_SIZE = 12
IMAGE_SIZE = 140

def read_train_data():
    """Builds the training dataset of (input, annotation) image pairs.

    Reads the PNG files under `PATH + "MEDIAPIPEinput/"` and
    `PATH + "annotated_images/"`, resizes each to IMAGE_SIZE x IMAGE_SIZE,
    divides pixel values by 255, then shuffles (buffer 1000) and batches by
    BATCH_SIZE.

    Returns:
      A `tf.data.Dataset` yielding `(x_batch, y_batch)` tuples.

    Raises:
      ValueError: If the two folders hold a different number of PNG files.
    """
    # glob returns files in arbitrary order, so sort both listings; otherwise
    # the i-th input may be paired with an unrelated annotation.
    # NOTE(review): assumes matching files sort to the same position — confirm.
    x_files = sorted(glob.glob(PATH + "MEDIAPIPEinput/*.png"))
    y_files = sorted(glob.glob(PATH + "annotated_images/*.png"))
    if len(x_files) != len(y_files):
        raise ValueError(
            f"found {len(x_files)} input images but {len(y_files)} annotated images under {PATH}")

    def load_image(filename):
        """Reads one PNG as 3-channel data, resized and divided by 255."""
        decoded = tf.image.decode_png(tf.io.read_file(filename), channels=3)
        return tf.image.resize(decoded, [IMAGE_SIZE, IMAGE_SIZE]) / 255

    def read_image(x_filename, y_filename):
        return load_image(x_filename), load_image(y_filename)

    dataset = tf.data.Dataset.from_tensor_slices((x_files, y_files))
    return dataset.map(read_image).shuffle(1000).batch(BATCH_SIZE)

import pathlib
import tensorflow as tf





In [None]:
train_set = read_train_data()


def dice_coef(y_true, y_pred):
    """Dice coefficient of two arrays, both flattened before comparison.

    Computes `2 * sum(t * p) / (sum(t) + sum(p) + 1)`; the `+ 1` keeps the
    denominator non-zero when both inputs sum to zero.
    """
    truth_flat, pred_flat = jnp.ravel(y_true), jnp.ravel(y_pred)
    overlap = jnp.sum(truth_flat * pred_flat)
    total = jnp.sum(truth_flat) + jnp.sum(pred_flat) + 1
    return 2.0 * overlap / total


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus `dice_coef(y_true, y_pred)`."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score

for e in range(20):
        # Accumulate the per-step losses so each epoch ends with a mean;
        # `loss_avg` used to be reset every epoch but never updated or shown.
        loss_sum = 0.0
        num_steps = 0
        for x, y in train_set.as_numpy_iterator():
            loss, train_state = train_step(x, y, train_state, True)
            loss_sum += float(loss)
            num_steps += 1
            print(f"epoch: {e}, loss: {loss:0.2f}")
        loss_avg = loss_sum / num_steps if num_steps else float("nan")
        print(f"epoch: {e}, mean loss: {loss_avg:0.2f}")

##**ensemble test**

In [None]:
# Download the release archive and let Keras extract it (extract=True).
# NOTE(review): this cell uses `tf` and `pathlib` without importing them; it relies
# on an earlier cell having done so.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
data_dir = pathlib.Path(data_dir)  # wrap the returned location in a Path object

In [None]:
#✅
!python -m pip install -q -U flax
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training import train_state, common_utils
import functools

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Appends sin/cos features at `positional_encoding_dims` octaves to `args`.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    octaves = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_octave)(octaves)             # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

def batch_encoded(args):
    """Applies `positional_encoding` to every pixel of a batch of images.

    Args:
      args: 4-D array `(batch, height, width, channels)`.

    Returns:
      `(batch, height, width, features)` array, where `features` is the size of
      the last axis returned by `positional_encoding`.
    """
    batch, height, width, channels = args.shape
    # positional_encoding handles each row on its own, so every pixel of every
    # image can be encoded in one call. This replaces a per-image Python loop
    # that also rebuilt the stacked result on every iteration.
    encoded = positional_encoding(args.reshape(-1, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP of `ndl` ReLU Dense(`dlw`) layers, one skip connection and a 3-unit head."""
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Returns the 3-channel prediction for `input_points`.

        The input is first passed through `batch_encoded` when
        `apply_positional_encoding` is set; the Dense layers act on the last axis.
        """
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points
        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the raw (un-encoded) input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)
        head = nn.Dense(3, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def Create_train_state(r_key, shape, learning_rate ):
    """Builds a TrainState for a fresh `MLPModel`, mapped over the leading axis of `r_key`.

    Args:
      r_key: PRNG keys, one per mapped replica.
      shape: Static shape of the all-ones array used to initialise the model.
      learning_rate: Static Adam learning rate.

    Returns:
      A `TrainState` holding the initial parameters and an Adam optimizer.
    """
    print(shape)
    net = MLPModel()
    dummy_input = jnp.ones(shape)
    init_params = net.init(r_key, dummy_input)['params']
    return train_state.TrainState.create(
        apply_fn=net.apply,
        params=init_params,
        tx=optax.adam(learning_rate),
    )

# learning_rate = 1e-4
# batch_size_no = 64

# model = MLPModel() # Instantiate the Model

In [None]:
model = MLPModel()

rng = jax.random.PRNGKey(0)

# Split once and initialise with the sub-key. Passing `rng` itself after it has
# been split would reuse a key that already produced `rn_g` and `init_rng`.
rn_g, init_rng = jax.random.split(rng)
# NOTE(review): `data_stream` is defined in a later cell; run that cell first.
model.init(init_rng, jnp.ones(next(data_stream())[0].shape))

In [None]:
@functools.partial(jax.pmap, axis_name='batch')
def apply_model(state, batch: jnp.asarray):
  """Computes gradients and loss for one batch on every device.

  Args:
    state: Replicated train state; only `state.params` is read.
    batch: Per-device `(image, label)` pair.

  Returns:
    `(grads, loss)`, both averaged over the 'batch' device axis.
  """
  image, label = batch
  def loss_fn(params):
    logits = MLPModel().apply({'params': params}, image)
    loss = image_difference_loss(logits, label)
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, _), grads = grad_fn(state.params)
  # Average across devices so every replica applies the same update; without
  # this the replicated parameters drift apart after the first step.
  grads = jax.lax.pmean(grads, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')
  return grads, loss

@jax.pmap
def update_model(state, grads):
  """Applies `grads` to `state` on every device and returns the updated state."""
  updated_state = state.apply_gradients(grads=grads)
  return updated_state

In [None]:
def train_epoch(state, train_ds, rng):
  """Trains for one epoch over random mini-batches of `train_ds`.

  Args:
    state: Train state passed to `apply_model` / `update_model`.
    train_ds: Array indexed along axis 0; a trailing partial batch is dropped.
    rng: PRNG key used to shuffle the example order.

  Returns:
    `(state, train_loss)` where `train_loss` is the mean of the step losses.
  """
  import numpy as np  # local import: this cell never imported numpy

  train_ds_size = train_ds.shape[0]
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []

  for perm in perms:
    # Train on the selected mini-batch. The batch used to be computed and then
    # ignored, with the whole dataset passed to apply_model on every step.
    # NOTE(review): apply_model unpacks its batch as (image, label), but only one
    # array is available here — confirm how labels are meant to be supplied.
    batch = flax.jax_utils.replicate(train_ds[perm, ...])
    grads, loss = apply_model(state, batch)
    state = update_model(state, grads)
    epoch_loss.append(flax.jax_utils.unreplicate(loss))
  train_loss = np.mean(epoch_loss)
  return state, train_loss

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 10

import os
from PIL import Image
import jax.numpy as jnp

def imageGRAY(argv):
    """Loads an image as single-channel ('L') data resized to `newsize`.

    Args:
      argv: Path (or file object) accepted by `PIL.Image.open`.

    Returns:
      `(image, pixels)`: the resized image as an array, and the same data
      reshaped to `(-1, 1)`.
    """
    # The context manager closes the file once the pixels are in memory; the
    # resize now runs once instead of once per returned array.
    with Image.open(argv) as im:
        resized = jnp.asarray(im.convert('L').resize(newsize))
    return resized, resized.reshape(-1, 1)
def imageRGB(argv):
    """Loads an image as 3-channel RGB data resized to `newsize`.

    Args:
      argv: Path (or file object) accepted by `PIL.Image.open`.

    Returns:
      `(image, pixels)`: the resized image as an array, and the same data
      reshaped to `(-1, 3)`.
    """
    with Image.open(argv) as im:
        # Force three channels: a grayscale, palette or RGBA file would otherwise
        # give an array that `reshape(-1, 3)` cannot split into RGB pixels.
        resized = jnp.asarray(im.convert('RGB').resize(newsize))
    return resized, resized.reshape(-1, 3)

x_image_dir = r'/content/biy/MEDIAPIPEinput/'
y_image_dir = r'/content/biy/annotated_images/'

#############################################################################
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]  # file-name fragment a file must contain to be kept

# Input images: sorted file names, then full paths.
x_total_images = sorted(name for name in os.listdir(x_image_dir) if expression_b2 in name)
x_total_images_path = [os.path.join(x_image_dir, name) for name in x_total_images if name != 'outputs']
no_of_batches = int(len(x_total_images_path)/batch_size)

# Annotated images: same listing rules, so position i lines up with input i
# only if both folders sort the same way.
y_total_images = sorted(name for name in os.listdir(y_image_dir) if expression_b2 in name)
y_total_images_path = [os.path.join(y_image_dir, name) for name in y_total_images if name != 'outputs']


######################################## making 8 array of input for each device >>>
def batchedimages(total_images_path, image_locations):
  """Loads the image at the FIRST index of `image_locations`.

  Only `image_locations[0]` is read; the remaining indices are ignored, so the
  result is a single image rather than a stacked batch.

  Args:
    total_images_path: Sequence of image paths.
    image_locations: Sequence of indices into `total_images_path`.

  Returns:
    The first element returned by `imageRGB` for that path, as a jnp array.
  """
  first_index = image_locations[0]
  loaded = imageRGB(total_images_path[first_index])
  return jnp.asarray(loaded[0])


def data_stream():
  """Yields ONE `(x_images, y_images)` pair of stacked arrays, then stops.

  The permutation uses the fixed key `PRNGKey(0)`, so every call produces the
  same order. Each of the `no_of_batches` index slices contributes whatever
  `batchedimages` returns for it, for inputs and annotations alike.
  """
  key = random.PRNGKey(0)
  # Use the input path list: `total_images_path` is not defined anywhere in the
  # visible notebook, so referencing it raised NameError on first use.
  perm = random.permutation(key, len(x_total_images_path))
  x_image_list = []
  y_image_list = []

  for i in range(no_of_batches):
    batch_idx = perm[i * batch_size : (i + 1) * batch_size]
    x_image_list.append(batchedimages(x_total_images_path, batch_idx))
    y_image_list.append(batchedimages(y_total_images_path, batch_idx))
  # The yield sits after the loop on purpose: callers take `next(...)[0].shape`
  # and expect everything in a single item.
  yield jnp.array(x_image_list),jnp.array(y_image_list)

batches = data_stream()  


In [None]:
import logging  # used by the epoch log below; no visible cell imported it

# NOTE(review): data_stream yields (inputs, annotations), so `test_ds` holds the
# annotation images rather than a held-out test set.
train_ds, test_ds = next(data_stream())
# test_ds = jax_utils.replicate(test_ds)
rng = jax.random.PRNGKey(0)

rn_g, init_rng = jax.random.split(rng)

# Reuse the array already loaded instead of rebuilding the whole stream twice
# more just to read its shape.
BATCH, H, W, Channels = train_ds.shape
learning_rate = 1e-4

# state = Create_train_state(jax.random.split(init_rng, jax.device_count()),(BATCH, H, W, Channels),learning_rate)
state = Create_train_state(jax.random.split(init_rng, 1), (BATCH, H, W, Channels), learning_rate)
num_epochs = 3
for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss = train_epoch(state, train_ds, input_rng)

  # _, test_loss = jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))

  # NOTE(review): INFO records are hidden unless the logging level allows them.
  logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))

In [None]:
# same as before, but wrapping the function with flax.jax_utils.pad_shard_unpad

# manually padding
# => precise & allows for data parallelism

# NOTE(review): `dataset_name`, `per_device_batch_size` and `vs` are not defined in
# the part of the notebook visible here — confirm they exist before running.
@jax.pmap
def get_preds(variables, inputs):
  """Applies the module-level `model` to `inputs` with `variables` on each device."""
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)  # keep the final, smaller batch

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  # NOTE(review): argmax over the last axis assumes class scores; the models in
  # this notebook predict image channels — confirm this metric applies.
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

In [None]:
x_train_ds, y_train_ds = next(data_stream())
print(x_train_ds.shape, y_train_ds.shape)

In [None]:
x_train_ds[0].shape


In [None]:
import cv2
from google.colab.patches import cv2_imshow
import numpy as onp
img = onp.array(x_train_ds[0])  # first entry of the stacked inputs as a NumPy array
# NOTE(review): the images were loaded through PIL; presumably cv2_imshow expects
# BGR, which is why the channel order is swapped here — confirm.
cv2_imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = MLPModel()
x = jnp.ones((2,100, 100, 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx,)
state = jax_utils.replicate(state)

def loss_fn(state, x):
    """Mean of the squared model output for input `x`."""
    return (model.apply(state.params, x) ** 2.0).mean()

# pmap maps over axis 0 of EVERY argument. `state` was replicated to carry a
# device axis, so `x` needs one as well; passed as-is, its batch axis is taken
# for the device axis and each replica receives a 3-D array the model cannot encode.
x_replicated = jax_utils.replicate(x)
jax.pmap(loss_fn)(state, x_replicated)

Testing the model under `pmap` (known bug)

In [None]:

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Appends sin/cos features at `positional_encoding_dims` octaves to `args`.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    octaves = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_octave)(octaves)             # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

def batch_encoded(args):
    """Applies `positional_encoding` to every pixel of a batch of images.

    Args:
      args: 4-D array `(batch, height, width, channels)`.

    Returns:
      `(batch, height, width, features)` array, where `features` is the size of
      the last axis returned by `positional_encoding`.
    """
    batch, height, width, channels = args.shape
    # positional_encoding handles each row on its own, so every pixel of every
    # image can be encoded in one call. This replaces a per-image Python loop
    # that also rebuilt the stacked result on every iteration.
    encoded = positional_encoding(args.reshape(-1, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP of `ndl` ReLU Dense(`dlw`) layers, one skip connection and a 1-unit head."""
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Returns the single-channel prediction for `input_points`.

        The input is first passed through `batch_encoded` when
        `apply_positional_encoding` is set; the Dense layers act on the last axis.
        """
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points
        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the raw (un-encoded) input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)
        head = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

In [None]:
from jax.tree_util import tree_structure
print(tree_structure(state.params))

## **RUN 2 testing** 

In [1]:
import tensorflow as tf
import pathlib
# Download the release archive and let Keras extract it (extract=True).
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# NOTE(review): later cells read '/content/MEDIAPIPEinput/' and
# '/content/annotated_images/' — confirm the archive unpacks to those folders.
data_dir = pathlib.Path(data_dir)  # wrap the returned location in a Path object

Downloading data from https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip


**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or any $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [2]:
#✅
import jax
import jax.numpy as jnp


positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Appends sin/cos features at `positional_encoding_dims` octaves to `args`.

    Args:
      args: 2-D array `(num_points, channels)`.

    Returns:
      2-D array with `num_points` rows: the original columns followed by, for
      each octave k, `sin(args * 2**k)` then `cos(args * 2**k)`.
    """
    num_points, _ = args.shape  # unpacking also rejects anything that is not 2-D

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    octaves = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_octave)(octaves)             # (octaves, points, channels)
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])   # (2, octaves, points, channels)
    waves = jnp.swapaxes(waves, 0, 2)                       # (points, octaves, 2, channels)
    flat_waves = waves.reshape([num_points, -1])
    return jnp.concatenate([args, flat_waves], axis=-1)

def batch_encoded(args):
    """Applies `positional_encoding` to every pixel of a batch of images.

    Args:
      args: 4-D array `(batch, height, width, channels)`.

    Returns:
      `(batch, height, width, features)` array, where `features` is the size of
      the last axis returned by `positional_encoding`.
    """
    batch, height, width, channels = args.shape
    # positional_encoding handles each row on its own, so every pixel of every
    # image can be encoded in one call. This replaces a per-image Python loop
    # that also rebuilt the stacked result on every iteration.
    encoded = positional_encoding(args.reshape(-1, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [3]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio


import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP of `ndl` ReLU Dense(`dlw`) layers, one skip connection and a 1-unit head."""
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Returns the single-channel prediction for `input_points`.

        The input is first passed through `batch_encoded` when
        `apply_positional_encoding` is set; the Dense layers act on the last axis.
        """
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points
        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection: re-attach the raw (un-encoded) input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)
        head = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return head(hidden)
##########################################<< MLP MODEL >>#########################################

[K     |████████████████████████████████| 197 kB 27.6 MB/s 
[K     |████████████████████████████████| 66 kB 5.7 MB/s 
[K     |████████████████████████████████| 154 kB 54.4 MB/s 
[K     |████████████████████████████████| 238 kB 72.7 MB/s 
[K     |████████████████████████████████| 8.3 MB 52.3 MB/s 
[K     |████████████████████████████████| 51 kB 5.0 MB/s 
[K     |████████████████████████████████| 85 kB 4.9 MB/s 
[?25h

**initialize the module**

In [4]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialises `model` on an all-ones input and wraps it in a TrainState.

    Args:
      r_key: PRNG key passed to `model.init`.
      model: Module providing `init` and `apply`.
      shape: Shape of the all-ones array used for initialisation (also printed).
      learning_rate: Adam learning rate.

    Returns:
      A `TrainState` holding the initial parameters and an Adam optimizer.
    """
    print(shape)
    dummy_input = jnp.ones(shape)
    init_params = model.init(r_key, dummy_input)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=init_params,
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4
batch_size_no = 64

model = MLPModel() # Instantiate the Model

**defining loss function**

In [5]:
#serial
def image_difference_loss(logits, labels):
    """Half of the mean squared difference between `logits` and `labels`."""
    residual = logits - labels
    return .5 * jnp.mean(residual ** 2)
def compute_metrics(*, logits, labels):
  """Bundles the loss with the arrays it was computed from.

  Returns:
    Dict with 'loss' (from `image_difference_loss`), 'logits' (the predicted
    image) and 'labels' (the actual image).
  """
  return {
      'loss': image_difference_loss(logits, labels),
      'logits': logits,
      'labels': labels,
  }

**train step definition**

In [6]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Runs one gradient step on an `(image, label)` batch.

    Args:
      state: Train state providing `apply_fn`, `params` and `apply_gradients`.
      batch: Pair `(image, label)`.
      rng: Accepted but not read by the body.

    Returns:
      `(new_state, logs)` where `logs` is the dict built by `compute_metrics`
      from the predictions made BEFORE the update.
    """
    image, label = batch

    def loss_and_logits(params):
        predicted = state.apply_fn({'params': params}, image)
        return image_difference_loss(predicted, label), predicted

    (_, logits), grads = jax.value_and_grad(loss_and_logits, has_aux=True)(state.params)
    logs = compute_metrics(logits=logits, labels=label)
    new_state = state.apply_gradients(grads=grads)
    return new_state, logs

import jax
@jax.jit
def eval_step(state, image, label=None):
    """Evaluates the model on `image` and returns the metrics dict.

    Args:
      state: Train state providing `apply_fn` and `params`.
      image: Model input.
      label: Optional target. When omitted, the input image itself is used as
        the target, which is what this function always did before.

    Returns:
      The dict built by `compute_metrics`.
    """
    logits = state.apply_fn({'params': state.params}, image)
    # Lets callers score against the annotated image, as train_step does.
    target = image if label is None else label
    return compute_metrics(logits=logits, labels=target)


**image conversion functions**

In [7]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Loads an image as single-channel ('L') data resized to `newsize`.

    Args:
      argv: Path (or file object) accepted by `PIL.Image.open`.

    Returns:
      `(image, pixels)`: the resized image as an array, and the same data
      reshaped to `(-1, 1)`.
    """
    # The context manager closes the file once the pixels are in memory; the
    # resize now runs once instead of once per returned array.
    with Image.open(argv) as im:
        resized = jnp.asarray(im.convert('L').resize(newsize))
    return resized, resized.reshape(-1, 1)
def imageRGB(argv):
    """Loads an image as 3-channel RGB data resized to `newsize`.

    Args:
      argv: Path (or file object) accepted by `PIL.Image.open`.

    Returns:
      `(image, pixels)`: the resized image as an array, and the same data
      reshaped to `(-1, 3)`.
    """
    with Image.open(argv) as im:
        # Force three channels: a grayscale, palette or RGBA file would otherwise
        # give an array that `reshape(-1, 3)` cannot split into RGB pixels.
        resized = jnp.asarray(im.convert('RGB').resize(newsize))
    return resized, resized.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [8]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random
import jax.numpy as jnp

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 40

import os
x_image_dir = r'/content/MEDIAPIPEinput/'
y_image_dir = r'/content/annotated_images/'
#############################################################################
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]  # file-name fragment a file must contain to be kept

# Input images: sorted file names, then full paths.
x_total_images = sorted(name for name in os.listdir(x_image_dir) if expression_b2 in name)
x_total_images_path = [os.path.join(x_image_dir, name) for name in x_total_images if name != 'outputs']
no_of_batches = int(len(x_total_images_path)/batch_size)

# Annotated images: same listing rules, so position i lines up with input i
# only if both folders sort the same way.
y_total_images = sorted(name for name in os.listdir(y_image_dir) if expression_b2 in name)
y_total_images_path = [os.path.join(y_image_dir, name) for name in y_total_images if name != 'outputs']
######################################## making 8 array of input for each device >>>
def batchedimages(total_images_path, image_locations):
  """Loads the image at the FIRST index of `image_locations`.

  Only `image_locations[0]` is read; the remaining indices are ignored, so the
  result is a single image rather than a stacked batch.

  Args:
    total_images_path: Sequence of image paths.
    image_locations: Sequence of indices into `total_images_path`.

  Returns:
    The first element returned by `imageRGB` for that path, as a jnp array.
  """
  first_index = image_locations[0]
  loaded = imageRGB(total_images_path[first_index])
  return jnp.asarray(loaded[0])

def data_stream():
  """Yields ONE `(x_images, y_images)` pair of stacked arrays, then stops.

  The permutation uses the fixed key `PRNGKey(0)`, so every call produces the
  same order. Each of the `no_of_batches` index slices contributes whatever
  `batchedimages` returns for it, for inputs and annotations alike.
  """
  shuffle_key = random.PRNGKey(0)
  order = random.permutation(shuffle_key, len(x_total_images_path))
  starts = [batch_no * batch_size for batch_no in range(no_of_batches)]
  index_slices = [order[start : start + batch_size] for start in starts]
  x_stack = jnp.array([batchedimages(x_total_images_path, idx) for idx in index_slices])
  y_stack = jnp.array([batchedimages(y_total_images_path, idx) for idx in index_slices])
  # Single yield: callers take `next(...)` once and read everything from it.
  yield x_stack, y_stack



working here 🔻

In [None]:
batches = data_stream()  
next(batches)[0].shape

In [None]:
import numpy as onp
onp.array(jnp.array(x_img_list))


In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import os
import re
import shutil

import jax
from jax import random
from tqdm import tqdm
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
# Draw one item from a throwaway stream purely to learn the input shape.
# Create_train_state, model and learning_rate come from earlier cells.
BATCH, H, W, Channels = next(data_stream())[0].shape
state = Create_train_state( rng, model, (BATCH, H, W, Channels ), learning_rate )
count = 1  # flag the original cell left set after initialisation; kept for other cells that may read it
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

######################
# Detect an existing checkpoint: any entry in the directory whose name starts
# with "checkpoint_<digits>". Raw string so that \d reaches the regex engine
# instead of being an invalid Python string escape.
pattern = re.compile(r"checkpoint_\d+")
# NOTE(review): `dir` shadows the builtin and duplicates CKPT_DIR as an absolute
# path (the two only agree when the working directory is /content); the name and
# value are kept unchanged because they are module-level state.
dir = "/content/ckpts/"
checkpoint_available = 0
if os.path.isdir(dir):
  checkpoint_available = int(any(pattern.match(name) for name in os.listdir(dir)))

# Loop-invariant: one checkpointer is enough for every save below.
orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

total_epochs = 100  # optimisation steps run on each item drawn from the stream
batches = data_stream()  # one generator, so successive rounds consume successive items
# NOTE(review): why 5 rounds are held back from no_of_batches is not visible here.
for epochs in tqdm(range(no_of_batches-5)):

  if checkpoint_available:
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = 0 # << Flag updated >>> to stop loading the same checkpoint in the next iteration then remove the checkpoint directory
    # Presumably removed so the restored files cannot clash with the fresh
    # step-numbered saves below (overwrite=False) — confirm.
    shutil.rmtree(dir)

  try:
    input_data = next(batches)
  except StopIteration:
    # Stream exhausted: start over from its first item.
    batches = data_stream()
    input_data = next(batches)

  for bbb in tqdm(range(total_epochs)):
    # NOTE(review): the same `rng` key is passed on every step — confirm that
    # train_step does not rely on it for fresh randomness.
    state, metrics = train_step(state, input_data, rng)
    # output.clear()
    print("loss: ",metrics['loss']," <<< ")

  # naming of the checkpoint is "checkpoint_*"  where "*" => value of the steps variable, i.e. 'epochs'
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  # restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state) # using to get the checkpoint loaded , it can be latest one , or if already available as checkpoint in the "CKPT_DIR" directory then take the file from directory then save in >> restored_checkpoints
  ##################################################



(25, 140, 140, 3)


  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<01:34,  1.04it/s][A

loss:  1259.5117  <<< 



  2%|▏         | 2/100 [00:01<01:35,  1.02it/s][A

loss:  1222.6796  <<< 



  3%|▎         | 3/100 [00:02<01:35,  1.02it/s][A

loss:  1238.2913  <<< 



  4%|▍         | 4/100 [00:03<01:34,  1.02it/s][A

loss:  1270.0698  <<< 



  5%|▌         | 5/100 [00:04<01:32,  1.02it/s][A

loss:  1295.5383  <<< 



  6%|▌         | 6/100 [00:05<01:31,  1.03it/s][A

loss:  1250.8555  <<< 



  7%|▋         | 7/100 [00:06<01:30,  1.03it/s][A

loss:  1223.6423  <<< 



  8%|▊         | 8/100 [00:07<01:32,  1.01s/it][A

loss:  1228.4275  <<< 



  9%|▉         | 9/100 [00:08<01:30,  1.00it/s][A

loss:  1249.0997  <<< 



 10%|█         | 10/100 [00:09<01:29,  1.00it/s][A

loss:  1265.1815  <<< 



 11%|█         | 11/100 [00:10<01:31,  1.03s/it][A

loss:  1241.9519  <<< 



 12%|█▏        | 12/100 [00:11<01:28,  1.01s/it][A

loss:  1223.9166  <<< 



 13%|█▎        | 13/100 [00:12<01:27,  1.00s/it][A

loss:  1221.0298  <<< 



 14%|█▍        | 14/100 [00:13<01:26,  1.00s/it][A

loss:  1232.133  <<< 



 15%|█▌        | 15/100 [00:14<01:24,  1.00it/s][A

loss:  1246.2306  <<< 



 16%|█▌        | 16/100 [00:15<01:23,  1.01it/s][A

loss:  1241.279  <<< 



 17%|█▋        | 17/100 [00:16<01:22,  1.01it/s][A

loss:  1233.249  <<< 



 18%|█▊        | 18/100 [00:17<01:21,  1.01it/s][A

loss:  1221.5654  <<< 



 19%|█▉        | 19/100 [00:18<01:20,  1.01it/s][A

loss:  1217.9565  <<< 



 20%|██        | 20/100 [00:19<01:19,  1.00it/s][A

loss:  1222.0264  <<< 



 21%|██        | 21/100 [00:20<01:18,  1.01it/s][A

loss:  1228.3373  <<< 



 22%|██▏       | 22/100 [00:21<01:17,  1.01it/s][A

loss:  1234.6516  <<< 



 23%|██▎       | 23/100 [00:22<01:16,  1.01it/s][A

loss:  1232.4802  <<< 



 24%|██▍       | 24/100 [00:23<01:15,  1.00it/s][A

loss:  1230.0938  <<< 



 25%|██▌       | 25/100 [00:24<01:15,  1.00s/it][A

loss:  1223.572  <<< 



 26%|██▌       | 26/100 [00:25<01:15,  1.01s/it][A

loss:  1219.1497  <<< 



 27%|██▋       | 27/100 [00:26<01:14,  1.02s/it][A

loss:  1216.4454  <<< 



 28%|██▊       | 28/100 [00:27<01:13,  1.02s/it][A

loss:  1215.9941  <<< 



 29%|██▉       | 29/100 [00:28<01:11,  1.01s/it][A

loss:  1217.2577  <<< 



 30%|███       | 30/100 [00:29<01:10,  1.01s/it][A

loss:  1219.3777  <<< 



 31%|███       | 31/100 [00:30<01:09,  1.01s/it][A

loss:  1222.4026  <<< 



 32%|███▏      | 32/100 [00:31<01:08,  1.00s/it][A

loss:  1224.4747  <<< 



 33%|███▎      | 33/100 [00:33<01:07,  1.01s/it][A

loss:  1228.1887  <<< 



 34%|███▍      | 34/100 [00:33<01:06,  1.00s/it][A

loss:  1229.2269  <<< 



 35%|███▌      | 35/100 [00:34<01:04,  1.00it/s][A

loss:  1233.3761  <<< 



 36%|███▌      | 36/100 [00:35<01:03,  1.01it/s][A

loss:  1233.0011  <<< 



 37%|███▋      | 37/100 [00:36<01:01,  1.02it/s][A

loss:  1237.0192  <<< 



 38%|███▊      | 38/100 [00:37<01:00,  1.02it/s][A

loss:  1234.7771  <<< 



 39%|███▉      | 39/100 [00:38<01:00,  1.02it/s][A

loss:  1237.32  <<< 



 40%|████      | 40/100 [00:39<00:58,  1.02it/s][A

loss:  1233.0375  <<< 



 41%|████      | 41/100 [00:40<00:57,  1.02it/s][A

loss:  1232.8326  <<< 



 42%|████▏     | 42/100 [00:41<00:56,  1.02it/s][A

loss:  1227.5156  <<< 



 43%|████▎     | 43/100 [00:42<00:56,  1.02it/s][A

loss:  1225.0082  <<< 



 44%|████▍     | 44/100 [00:43<00:55,  1.01it/s][A

loss:  1220.483  <<< 



 45%|████▌     | 45/100 [00:44<00:54,  1.02it/s][A

loss:  1217.6836  <<< 



 46%|████▌     | 46/100 [00:45<00:52,  1.02it/s][A

loss:  1215.0942  <<< 



 47%|████▋     | 47/100 [00:46<00:52,  1.01it/s][A

loss:  1213.5414  <<< 



 48%|████▊     | 48/100 [00:47<00:51,  1.01it/s][A

loss:  1212.6837  <<< 



 49%|████▉     | 49/100 [00:48<00:50,  1.02it/s][A

loss:  1212.4182  <<< 



 50%|█████     | 50/100 [00:49<00:49,  1.01it/s][A

loss:  1212.591  <<< 



 51%|█████     | 51/100 [00:50<00:48,  1.02it/s][A

loss:  1213.0947  <<< 



 52%|█████▏    | 52/100 [00:51<00:47,  1.01it/s][A

loss:  1214.0013  <<< 



 53%|█████▎    | 53/100 [00:52<00:46,  1.01it/s][A

loss:  1215.1724  <<< 



 54%|█████▍    | 54/100 [00:53<00:45,  1.01it/s][A

loss:  1217.3993  <<< 



 55%|█████▌    | 55/100 [00:54<00:44,  1.01it/s][A

loss:  1219.9618  <<< 



 56%|█████▌    | 56/100 [00:55<00:43,  1.01it/s][A

loss:  1225.8223  <<< 



 57%|█████▋    | 57/100 [00:56<00:42,  1.01it/s][A

loss:  1231.548  <<< 



 58%|█████▊    | 58/100 [00:57<00:41,  1.02it/s][A

loss:  1247.0977  <<< 



 59%|█████▉    | 59/100 [00:58<00:40,  1.01it/s][A

loss:  1256.8674  <<< 



 60%|██████    | 60/100 [00:59<00:39,  1.01it/s][A

loss:  1291.1401  <<< 



 61%|██████    | 61/100 [01:00<00:38,  1.01it/s][A

loss:  1290.6344  <<< 



 62%|██████▏   | 62/100 [01:01<00:37,  1.02it/s][A

loss:  1326.964  <<< 



 63%|██████▎   | 63/100 [01:02<00:36,  1.01it/s][A

loss:  1285.8124  <<< 



 64%|██████▍   | 64/100 [01:03<00:35,  1.01it/s][A

loss:  1271.7229  <<< 



 65%|██████▌   | 65/100 [01:04<00:34,  1.02it/s][A

loss:  1231.0425  <<< 



 66%|██████▌   | 66/100 [01:05<00:33,  1.02it/s][A

loss:  1212.848  <<< 



 67%|██████▋   | 67/100 [01:06<00:32,  1.02it/s][A

loss:  1214.9047  <<< 



 68%|██████▊   | 68/100 [01:07<00:31,  1.03it/s][A

loss:  1229.3835  <<< 



 69%|██████▉   | 69/100 [01:08<00:30,  1.02it/s][A

loss:  1247.2987  <<< 



 70%|███████   | 70/100 [01:09<00:29,  1.03it/s][A

loss:  1240.8728  <<< 



 71%|███████   | 71/100 [01:10<00:28,  1.03it/s][A

loss:  1232.5109  <<< 



 72%|███████▏  | 72/100 [01:11<00:27,  1.03it/s][A

loss:  1216.6952  <<< 



 73%|███████▎  | 73/100 [01:12<00:26,  1.03it/s][A

loss:  1210.3041  <<< 



 74%|███████▍  | 74/100 [01:13<00:25,  1.03it/s][A

loss:  1213.5193  <<< 



 75%|███████▌  | 75/100 [01:14<00:24,  1.03it/s][A

loss:  1221.1761  <<< 



 76%|███████▌  | 76/100 [01:15<00:23,  1.04it/s][A

loss:  1230.0607  <<< 



 77%|███████▋  | 77/100 [01:16<00:22,  1.04it/s][A

loss:  1228.7643  <<< 



 78%|███████▊  | 78/100 [01:17<00:21,  1.03it/s][A

loss:  1227.2289  <<< 



 79%|███████▉  | 79/100 [01:18<00:20,  1.04it/s][A

loss:  1218.7838  <<< 



 80%|████████  | 80/100 [01:19<00:19,  1.03it/s][A

loss:  1212.9109  <<< 



 81%|████████  | 81/100 [01:20<00:18,  1.03it/s][A

loss:  1209.032  <<< 



 82%|████████▏ | 82/100 [01:21<00:17,  1.02it/s][A

loss:  1208.3474  <<< 



 83%|████████▎ | 83/100 [01:21<00:16,  1.02it/s][A

loss:  1210.1405  <<< 



 84%|████████▍ | 84/100 [01:22<00:15,  1.02it/s][A

loss:  1213.0765  <<< 



 85%|████████▌ | 85/100 [01:23<00:14,  1.02it/s][A

loss:  1217.1366  <<< 



 86%|████████▌ | 86/100 [01:24<00:13,  1.02it/s][A

loss:  1219.4619  <<< 



 87%|████████▋ | 87/100 [01:25<00:12,  1.03it/s][A

loss:  1223.8307  <<< 



 88%|████████▊ | 88/100 [01:26<00:11,  1.03it/s][A

loss:  1224.352  <<< 



 89%|████████▉ | 89/100 [01:27<00:10,  1.03it/s][A

loss:  1228.8749  <<< 



 90%|█████████ | 90/100 [01:28<00:09,  1.02it/s][A

loss:  1227.9198  <<< 



 91%|█████████ | 91/100 [01:29<00:08,  1.03it/s][A

loss:  1232.6031  <<< 



 92%|█████████▏| 92/100 [01:30<00:07,  1.02it/s][A

loss:  1230.2312  <<< 



 93%|█████████▎| 93/100 [01:31<00:06,  1.02it/s][A

loss:  1234.398  <<< 



 94%|█████████▍| 94/100 [01:32<00:05,  1.02it/s][A

loss:  1230.2783  <<< 



 95%|█████████▌| 95/100 [01:33<00:04,  1.02it/s][A

loss:  1232.3729  <<< 



 96%|█████████▌| 96/100 [01:34<00:04,  1.00s/it][A

loss:  1226.5486  <<< 



 97%|█████████▋| 97/100 [01:35<00:02,  1.00it/s][A

loss:  1225.4133  <<< 



 98%|█████████▊| 98/100 [01:36<00:01,  1.01it/s][A

loss:  1219.1761  <<< 



 99%|█████████▉| 99/100 [01:37<00:00,  1.01it/s][A

loss:  1215.8358  <<< 



100%|██████████| 100/100 [01:38<00:00,  1.01it/s]
  5%|▌         | 1/20 [01:39<31:37, 99.85s/it]

loss:  1211.1918  <<< 



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<01:34,  1.04it/s][A

loss:  1208.2802  <<< 



  2%|▏         | 2/100 [00:01<01:35,  1.03it/s][A

loss:  1206.1539  <<< 



  3%|▎         | 3/100 [00:02<01:35,  1.02it/s][A

loss:  1205.1465  <<< 



  4%|▍         | 4/100 [00:03<01:33,  1.02it/s][A

loss:  1204.9873  <<< 



  5%|▌         | 5/100 [00:04<01:32,  1.03it/s][A

loss:  1205.4528  <<< 



  6%|▌         | 6/100 [00:05<01:32,  1.01it/s][A

loss:  1206.4296  <<< 



  7%|▋         | 7/100 [00:06<01:35,  1.02s/it][A

loss:  1207.6721  <<< 



  8%|▊         | 8/100 [00:08<01:47,  1.17s/it][A

loss:  1209.7592  <<< 



  9%|▉         | 9/100 [00:09<01:42,  1.13s/it][A

loss:  1211.8568  <<< 



 10%|█         | 10/100 [00:10<01:37,  1.08s/it][A

loss:  1216.3792  <<< 



 11%|█         | 11/100 [00:11<01:33,  1.05s/it][A

loss:  1220.3798  <<< 



 12%|█▏        | 12/100 [00:12<01:31,  1.04s/it][A

loss:  1231.156  <<< 



 13%|█▎        | 13/100 [00:13<01:29,  1.03s/it][A

loss:  1238.1494  <<< 



 14%|█▍        | 14/100 [00:14<01:27,  1.01s/it][A

loss:  1261.756  <<< 



 15%|█▌        | 15/100 [00:15<01:25,  1.00s/it][A

loss:  1266.032  <<< 



 16%|█▌        | 16/100 [00:16<01:23,  1.01it/s][A

loss:  1299.3289  <<< 



 17%|█▋        | 17/100 [00:17<01:22,  1.01it/s][A

loss:  1277.879  <<< 



 18%|█▊        | 18/100 [00:18<01:21,  1.01it/s][A

loss:  1282.9119  <<< 



 19%|█▉        | 19/100 [00:19<01:20,  1.01it/s][A

loss:  1242.0597  <<< 



 20%|██        | 20/100 [00:20<01:19,  1.01it/s][A

loss:  1219.2134  <<< 



 21%|██        | 21/100 [00:21<01:18,  1.01it/s][A

loss:  1204.7316  <<< 



 22%|██▏       | 22/100 [00:22<01:16,  1.01it/s][A

loss:  1206.9447  <<< 



 23%|██▎       | 23/100 [00:23<01:16,  1.01it/s][A

loss:  1220.2944  <<< 



 24%|██▍       | 24/100 [00:24<01:15,  1.01it/s][A

loss:  1228.9155  <<< 



 25%|██▌       | 25/100 [00:25<01:14,  1.01it/s][A

loss:  1235.9418  <<< 



 26%|██▌       | 26/100 [00:26<01:13,  1.01it/s][A

loss:  1223.4343  <<< 



 27%|██▋       | 27/100 [00:27<01:11,  1.02it/s][A

loss:  1213.2592  <<< 



 28%|██▊       | 28/100 [00:28<01:11,  1.01it/s][A

loss:  1204.3217  <<< 



 29%|██▉       | 29/100 [00:29<01:10,  1.01it/s][A

loss:  1202.5845  <<< 



 30%|███       | 30/100 [00:30<01:09,  1.01it/s][A

loss:  1206.7028  <<< 



 31%|███       | 31/100 [00:31<01:09,  1.00s/it][A

loss:  1212.4816  <<< 



 32%|███▏      | 32/100 [00:32<01:08,  1.00s/it][A

loss:  1219.2816  <<< 



 33%|███▎      | 33/100 [00:33<01:06,  1.00it/s][A

loss:  1218.9899  <<< 



 34%|███▍      | 34/100 [00:34<01:05,  1.01it/s][A

loss:  1219.6685  <<< 



 35%|███▌      | 35/100 [00:35<01:04,  1.01it/s][A

loss:  1213.7745  <<< 



 36%|███▌      | 36/100 [00:36<01:02,  1.02it/s][A

loss:  1209.8986  <<< 



 37%|███▋      | 37/100 [00:37<01:02,  1.01it/s][A

loss:  1205.0001  <<< 



 38%|███▊      | 38/100 [00:38<01:01,  1.01it/s][A

loss:  1202.0261  <<< 



 39%|███▉      | 39/100 [00:39<01:00,  1.01it/s][A

loss:  1200.3899  <<< 



 40%|████      | 40/100 [00:40<00:59,  1.00it/s][A

loss:  1200.0679  <<< 



 41%|████      | 41/100 [00:41<00:58,  1.00it/s][A

loss:  1200.7251  <<< 



 42%|████▏     | 42/100 [00:42<00:57,  1.01it/s][A

loss:  1202.0593  <<< 



 43%|████▎     | 43/100 [00:43<00:56,  1.01it/s][A

loss:  1204.3625  <<< 



 44%|████▍     | 44/100 [00:44<00:55,  1.01it/s][A

loss:  1206.994  <<< 



 45%|████▌     | 45/100 [00:45<00:54,  1.00it/s][A

loss:  1212.3267  <<< 



 46%|████▌     | 46/100 [00:46<00:53,  1.00it/s][A

loss:  1217.271  <<< 



 47%|████▋     | 47/100 [00:47<00:52,  1.00it/s][A

loss:  1230.265  <<< 



 48%|████▊     | 48/100 [00:48<00:51,  1.01it/s][A

loss:  1238.4591  <<< 



 49%|████▉     | 49/100 [00:49<00:50,  1.01it/s][A

loss:  1266.773  <<< 



 50%|█████     | 50/100 [00:50<00:49,  1.00it/s][A

loss:  1269.1364  <<< 



 51%|█████     | 51/100 [00:51<00:48,  1.02it/s][A

loss:  1303.995  <<< 



 52%|█████▏    | 52/100 [00:52<00:47,  1.02it/s][A

loss:  1273.2338  <<< 



 53%|█████▎    | 53/100 [00:53<00:45,  1.02it/s][A

loss:  1268.1294  <<< 



 54%|█████▍    | 54/100 [00:54<00:45,  1.02it/s][A

loss:  1226.4584  <<< 



 55%|█████▌    | 55/100 [00:55<00:44,  1.02it/s][A

loss:  1204.5785  <<< 



 56%|█████▌    | 56/100 [00:56<00:43,  1.02it/s][A

loss:  1199.4435  <<< 



 57%|█████▋    | 57/100 [00:57<00:42,  1.02it/s][A

loss:  1210.083  <<< 



 58%|█████▊    | 58/100 [00:58<00:41,  1.01it/s][A

loss:  1227.5376  <<< 



 59%|█████▉    | 59/100 [00:59<00:40,  1.01it/s][A

loss:  1227.9078  <<< 



 60%|██████    | 60/100 [01:00<00:39,  1.01it/s][A

loss:  1225.1124  <<< 



 61%|██████    | 61/100 [01:00<00:38,  1.01it/s][A

loss:  1208.9165  <<< 



 62%|██████▏   | 62/100 [01:01<00:37,  1.01it/s][A

loss:  1199.4159  <<< 



 63%|██████▎   | 63/100 [01:02<00:36,  1.01it/s][A

loss:  1198.1625  <<< 



 64%|██████▍   | 64/100 [01:03<00:35,  1.01it/s][A

loss:  1203.7073  <<< 



 65%|██████▌   | 65/100 [01:04<00:34,  1.02it/s][A

loss:  1212.3632  <<< 



 66%|██████▌   | 66/100 [01:05<00:33,  1.02it/s][A

loss:  1215.0311  <<< 



 67%|██████▋   | 67/100 [01:06<00:32,  1.02it/s][A

loss:  1217.4529  <<< 



 68%|██████▊   | 68/100 [01:07<00:31,  1.02it/s][A

loss:  1210.8027  <<< 



 69%|██████▉   | 69/100 [01:08<00:30,  1.02it/s][A

loss:  1205.8062  <<< 



 70%|███████   | 70/100 [01:09<00:29,  1.02it/s][A

loss:  1199.8112  <<< 



 71%|███████   | 71/100 [01:10<00:28,  1.01it/s][A

loss:  1196.539  <<< 



 72%|███████▏  | 72/100 [01:11<00:27,  1.02it/s][A

loss:  1195.6744  <<< 



 73%|███████▎  | 73/100 [01:12<00:27,  1.01s/it][A

loss:  1196.8218  <<< 



 74%|███████▍  | 74/100 [01:13<00:26,  1.01s/it][A

loss:  1199.3591  <<< 



 75%|███████▌  | 75/100 [01:14<00:25,  1.01s/it][A

loss:  1202.0593  <<< 



 76%|███████▌  | 76/100 [01:15<00:23,  1.00it/s][A

loss:  1206.1356  <<< 



 77%|███████▋  | 77/100 [01:16<00:22,  1.00it/s][A

loss:  1208.4631  <<< 



 78%|███████▊  | 78/100 [01:17<00:21,  1.01it/s][A

loss:  1214.1466  <<< 



 79%|███████▉  | 79/100 [01:18<00:20,  1.00it/s][A

loss:  1215.8539  <<< 



 80%|████████  | 80/100 [01:19<00:19,  1.01it/s][A

loss:  1224.0408  <<< 



 81%|████████  | 81/100 [01:20<00:18,  1.00it/s][A

loss:  1224.3243  <<< 



 82%|████████▏ | 82/100 [01:21<00:17,  1.01it/s][A

loss:  1234.5118  <<< 



 83%|████████▎ | 83/100 [01:22<00:16,  1.01it/s][A

loss:  1230.6742  <<< 



 84%|████████▍ | 84/100 [01:23<00:15,  1.01it/s][A

loss:  1238.0249  <<< 



 85%|████████▌ | 85/100 [01:24<00:14,  1.01it/s][A

loss:  1227.3223  <<< 



 86%|████████▌ | 86/100 [01:25<00:13,  1.01it/s][A

loss:  1225.1288  <<< 



 87%|████████▋ | 87/100 [01:26<00:12,  1.02it/s][A

loss:  1211.8519  <<< 



 88%|████████▊ | 88/100 [01:27<00:11,  1.02it/s][A

loss:  1203.8438  <<< 



 89%|████████▉ | 89/100 [01:28<00:10,  1.02it/s][A

loss:  1196.495  <<< 



 90%|█████████ | 90/100 [01:29<00:09,  1.02it/s][A

loss:  1193.4752  <<< 



 91%|█████████ | 91/100 [01:30<00:08,  1.02it/s][A

loss:  1194.11  <<< 



 92%|█████████▏| 92/100 [01:31<00:07,  1.02it/s][A

loss:  1197.1222  <<< 



 93%|█████████▎| 93/100 [01:32<00:06,  1.03it/s][A

loss:  1201.5854  <<< 



 94%|█████████▍| 94/100 [01:33<00:05,  1.03it/s][A

loss:  1204.2616  <<< 



 95%|█████████▌| 95/100 [01:34<00:04,  1.02it/s][A

loss:  1208.0488  <<< 



 96%|█████████▌| 96/100 [01:35<00:03,  1.02it/s][A

loss:  1207.2361  <<< 



 97%|█████████▋| 97/100 [01:36<00:02,  1.02it/s][A

loss:  1208.5171  <<< 



 98%|█████████▊| 98/100 [01:37<00:01,  1.02it/s][A

loss:  1205.4496  <<< 



 99%|█████████▉| 99/100 [01:38<00:00,  1.02it/s][A

loss:  1204.8044  <<< 



100%|██████████| 100/100 [01:39<00:00,  1.01it/s]
 10%|█         | 2/20 [03:20<30:03, 100.17s/it]

loss:  1201.576  <<< 



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<01:36,  1.02it/s][A

loss:  1200.2821  <<< 



  2%|▏         | 2/100 [00:01<01:36,  1.01it/s][A

loss:  1197.8958  <<< 



  3%|▎         | 3/100 [00:02<01:35,  1.02it/s][A

loss:  1196.8004  <<< 



  4%|▍         | 4/100 [00:03<01:34,  1.02it/s][A

loss:  1195.3585  <<< 



  5%|▌         | 5/100 [00:04<01:33,  1.02it/s][A

loss:  1194.7257  <<< 



  6%|▌         | 6/100 [00:05<01:32,  1.02it/s][A

loss:  1193.9963  <<< 



  7%|▋         | 7/100 [00:06<01:31,  1.01it/s][A

loss:  1193.8644  <<< 



  8%|▊         | 8/100 [00:07<01:30,  1.01it/s][A

loss:  1193.7473  <<< 



  9%|▉         | 9/100 [00:08<01:30,  1.01it/s][A

loss:  1194.3485  <<< 



 10%|█         | 10/100 [00:09<01:29,  1.01it/s][A

loss:  1195.1195  <<< 



 11%|█         | 11/100 [00:10<01:28,  1.00it/s][A

loss:  1197.4545  <<< 



 12%|█▏        | 12/100 [00:11<01:27,  1.01it/s][A

loss:  1200.3206  <<< 



 13%|█▎        | 13/100 [00:12<01:26,  1.00it/s][A

loss:  1207.8235  <<< 



 14%|█▍        | 14/100 [00:13<01:25,  1.00it/s][A

loss:  1215.7152  <<< 



 15%|█▌        | 15/100 [00:14<01:24,  1.01it/s][A

loss:  1238.3599  <<< 



 16%|█▌        | 16/100 [00:15<01:23,  1.01it/s][A

loss:  1252.3296  <<< 



 17%|█▋        | 17/100 [00:16<01:22,  1.01it/s][A

loss:  1304.0339  <<< 



 18%|█▊        | 18/100 [00:17<01:20,  1.01it/s][A

loss:  1293.7585  <<< 



 19%|█▉        | 19/100 [00:18<01:19,  1.02it/s][A

loss:  1331.0033  <<< 



 20%|██        | 20/100 [00:19<01:19,  1.01it/s][A

loss:  1262.1931  <<< 



 21%|██        | 21/100 [00:20<01:18,  1.01it/s][A

loss:  1226.0994  <<< 



 22%|██▏       | 22/100 [00:21<01:17,  1.00it/s][A

loss:  1194.3551  <<< 



 23%|██▎       | 23/100 [00:22<01:16,  1.00it/s][A

loss:  1196.4424  <<< 



 24%|██▍       | 24/100 [00:23<01:15,  1.01it/s][A

loss:  1221.2712  <<< 



 25%|██▌       | 25/100 [00:24<01:14,  1.01it/s][A

loss:  1230.5938  <<< 



 26%|██▌       | 26/100 [00:25<01:12,  1.02it/s][A

loss:  1231.7915  <<< 



 27%|██▋       | 27/100 [00:26<01:11,  1.02it/s][A

loss:  1205.1682  <<< 



 28%|██▊       | 28/100 [00:27<01:11,  1.01it/s][A

loss:  1191.0369  <<< 



 29%|██▉       | 29/100 [00:28<01:10,  1.01it/s][A

loss:  1193.888  <<< 



 30%|███       | 30/100 [00:29<01:08,  1.02it/s][A

loss:  1205.979  <<< 



 31%|███       | 31/100 [00:30<01:07,  1.02it/s][A

loss:  1218.7888  <<< 



 32%|███▏      | 32/100 [00:31<01:06,  1.02it/s][A

loss:  1212.1079  <<< 



 33%|███▎      | 33/100 [00:32<01:06,  1.01it/s][A

loss:  1203.9835  <<< 



 34%|███▍      | 34/100 [00:33<01:04,  1.02it/s][A

loss:  1192.5986  <<< 



 35%|███▌      | 35/100 [00:34<01:04,  1.01it/s][A

loss:  1188.3925  <<< 



 36%|███▌      | 36/100 [00:35<01:03,  1.01it/s][A

loss:  1191.1738  <<< 



 37%|███▋      | 37/100 [00:36<01:02,  1.01it/s][A

loss:  1197.0535  <<< 



 38%|███▊      | 38/100 [00:37<01:01,  1.01it/s][A

loss:  1203.8452  <<< 



 39%|███▉      | 39/100 [00:38<01:00,  1.00it/s][A

loss:  1203.731  <<< 



 40%|████      | 40/100 [00:39<00:59,  1.00it/s][A

loss:  1203.5743  <<< 



 41%|████      | 41/100 [00:40<00:59,  1.00s/it][A

loss:  1197.6559  <<< 



 42%|████▏     | 42/100 [00:41<00:58,  1.01s/it][A

loss:  1193.4406  <<< 



 43%|████▎     | 43/100 [00:42<00:57,  1.01s/it][A

loss:  1189.3341  <<< 



 44%|████▍     | 44/100 [00:43<00:56,  1.01s/it][A

loss:  1187.1621  <<< 



 45%|████▌     | 45/100 [00:44<00:55,  1.01s/it][A

loss:  1186.5352  <<< 



 46%|████▌     | 46/100 [00:45<00:54,  1.01s/it][A

loss:  1187.1404  <<< 



 47%|████▋     | 47/100 [00:46<00:53,  1.01s/it][A

loss:  1188.683  <<< 



 48%|████▊     | 48/100 [00:47<00:52,  1.01s/it][A

loss:  1190.6946  <<< 



 49%|████▉     | 49/100 [00:48<00:53,  1.06s/it][A

loss:  1194.0024  <<< 



 50%|█████     | 50/100 [00:49<00:51,  1.03s/it][A

loss:  1196.9375  <<< 



 51%|█████     | 51/100 [00:50<00:49,  1.02s/it][A

loss:  1203.3043  <<< 



 52%|█████▏    | 52/100 [00:51<00:48,  1.01s/it][A

loss:  1207.3333  <<< 



 53%|█████▎    | 53/100 [00:52<00:47,  1.01s/it][A

loss:  1219.2136  <<< 



 54%|█████▍    | 54/100 [00:53<00:46,  1.01s/it][A

loss:  1222.5676  <<< 



 55%|█████▌    | 55/100 [00:54<00:45,  1.00s/it][A

loss:  1239.8767  <<< 



 56%|█████▌    | 56/100 [00:55<00:44,  1.00s/it][A

loss:  1234.9868  <<< 



 57%|█████▋    | 57/100 [00:56<00:42,  1.00it/s][A

loss:  1246.1343  <<< 



 58%|█████▊    | 58/100 [00:57<00:41,  1.00it/s][A

loss:  1226.8088  <<< 



 59%|█████▉    | 59/100 [00:58<00:40,  1.00it/s][A

loss:  1218.6786  <<< 



 60%|██████    | 60/100 [00:59<00:40,  1.01s/it][A

loss:  1199.4408  <<< 



 61%|██████    | 61/100 [01:00<00:39,  1.00s/it][A

loss:  1188.7303  <<< 



 62%|██████▏   | 62/100 [01:01<00:38,  1.01s/it][A

loss:  1185.0266  <<< 



 63%|██████▎   | 63/100 [01:02<00:36,  1.01it/s][A

loss:  1188.3539  <<< 



 64%|██████▍   | 64/100 [01:03<00:35,  1.01it/s][A

loss:  1195.8784  <<< 



 65%|██████▌   | 65/100 [01:04<00:34,  1.01it/s][A

loss:  1200.7589  <<< 



 66%|██████▌   | 66/100 [01:05<00:33,  1.00it/s][A

loss:  1205.9158  <<< 



 67%|██████▋   | 67/100 [01:06<00:32,  1.01it/s][A

loss:  1201.6593  <<< 



 68%|██████▊   | 68/100 [01:07<00:31,  1.01it/s][A

loss:  1198.7328  <<< 



 69%|██████▉   | 69/100 [01:08<00:30,  1.01it/s][A

loss:  1191.8424  <<< 



 70%|███████   | 70/100 [01:09<00:29,  1.01it/s][A

loss:  1187.3224  <<< 



 71%|███████   | 71/100 [01:10<00:28,  1.01it/s][A

loss:  1184.2806  <<< 



 72%|███████▏  | 72/100 [01:11<00:27,  1.01it/s][A

loss:  1183.3938  <<< 



 73%|███████▎  | 73/100 [01:12<00:26,  1.00it/s][A

loss:  1184.1941  <<< 



 74%|███████▍  | 74/100 [01:13<00:25,  1.01it/s][A

loss:  1186.0201  <<< 



 75%|███████▌  | 75/100 [01:14<00:24,  1.01it/s][A

loss:  1188.7819  <<< 



 76%|███████▌  | 76/100 [01:15<00:23,  1.01it/s][A

loss:  1191.1146  <<< 



 77%|███████▋  | 77/100 [01:16<00:22,  1.01it/s][A

loss:  1195.1267  <<< 



 78%|███████▊  | 78/100 [01:17<00:21,  1.01it/s][A

loss:  1197.1488  <<< 



 79%|███████▉  | 79/100 [01:18<00:20,  1.01it/s][A

loss:  1203.0527  <<< 



 80%|████████  | 80/100 [01:19<00:19,  1.01it/s][A

loss:  1204.9407  <<< 



 81%|████████  | 81/100 [01:20<00:18,  1.01it/s][A

loss:  1214.366  <<< 



 82%|████████▏ | 82/100 [01:21<00:17,  1.02it/s][A

loss:  1215.5947  <<< 



 83%|████████▎ | 83/100 [01:22<00:16,  1.01it/s][A

loss:  1229.1053  <<< 



 84%|████████▍ | 84/100 [01:23<00:15,  1.01it/s][A

loss:  1225.8145  <<< 



 85%|████████▌ | 85/100 [01:24<00:14,  1.01it/s][A

loss:  1237.5568  <<< 



 86%|████████▌ | 86/100 [01:25<00:13,  1.01it/s][A

loss:  1224.3041  <<< 



 87%|████████▋ | 87/100 [01:26<00:12,  1.01it/s][A

loss:  1222.991  <<< 



 88%|████████▊ | 88/100 [01:27<00:11,  1.01it/s][A

loss:  1204.8734  <<< 



 89%|████████▉ | 89/100 [01:28<00:10,  1.01it/s][A

loss:  1194.1887  <<< 



 90%|█████████ | 90/100 [01:29<00:09,  1.01it/s][A

loss:  1184.7056  <<< 



 91%|█████████ | 91/100 [01:30<00:08,  1.01it/s][A

loss:  1181.5845  <<< 



 92%|█████████▏| 92/100 [01:31<00:07,  1.01it/s][A

loss:  1183.7726  <<< 



 93%|█████████▎| 93/100 [01:32<00:07,  1.00s/it][A

loss:  1188.6926  <<< 



 94%|█████████▍| 94/100 [01:33<00:06,  1.00s/it][A

loss:  1195.0818  <<< 



 95%|█████████▌| 95/100 [01:34<00:05,  1.00s/it][A

loss:  1196.8284  <<< 



 96%|█████████▌| 96/100 [01:35<00:03,  1.01it/s][A

loss:  1199.4001  <<< 



 97%|█████████▋| 97/100 [01:36<00:02,  1.01it/s][A

loss:  1195.2664  <<< 



 98%|█████████▊| 98/100 [01:37<00:01,  1.00it/s][A

loss:  1192.8927  <<< 



 99%|█████████▉| 99/100 [01:38<00:00,  1.01it/s][A

loss:  1187.7955  <<< 



100%|██████████| 100/100 [01:39<00:00,  1.00it/s]
 15%|█▌        | 3/20 [05:00<28:24, 100.28s/it]

loss:  1184.5109  <<< 



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:01<01:53,  1.14s/it][A

loss:  1181.7524  <<< 



  2%|▏         | 2/100 [00:02<01:55,  1.17s/it][A

loss:  1180.2932  <<< 



  3%|▎         | 3/100 [00:03<01:45,  1.09s/it][A

loss:  1179.8046  <<< 



  4%|▍         | 4/100 [00:04<01:41,  1.05s/it][A

loss:  1180.0704  <<< 



  5%|▌         | 5/100 [00:05<01:38,  1.04s/it][A

loss:  1180.9178  <<< 



  6%|▌         | 6/100 [00:06<01:35,  1.02s/it][A

loss:  1182.1697  <<< 



  7%|▋         | 7/100 [00:07<01:34,  1.01s/it][A

loss:  1184.3086  <<< 



  8%|▊         | 8/100 [00:08<01:32,  1.01s/it][A

loss:  1186.7692  <<< 



  9%|▉         | 9/100 [00:09<01:31,  1.01s/it][A

loss:  1192.0046  <<< 



 10%|█         | 10/100 [00:10<01:30,  1.00s/it][A

loss:  1197.2717  <<< 



 11%|█         | 11/100 [00:11<01:28,  1.00it/s][A

loss:  1211.2594  <<< 



 12%|█▏        | 12/100 [00:12<01:27,  1.01it/s][A

loss:  1221.4425  <<< 



 13%|█▎        | 13/100 [00:13<01:26,  1.01it/s][A

loss:  1255.4884  <<< 



 14%|█▍        | 14/100 [00:14<01:24,  1.01it/s][A

loss:  1260.6683  <<< 



 15%|█▌        | 15/100 [00:15<01:23,  1.01it/s][A

loss:  1307.6439  <<< 



 16%|█▌        | 16/100 [00:16<01:22,  1.02it/s][A

loss:  1268.2832  <<< 



 17%|█▋        | 17/100 [00:17<01:21,  1.02it/s][A

loss:  1261.5024  <<< 



 18%|█▊        | 18/100 [00:18<01:20,  1.02it/s][A

loss:  1208.6304  <<< 



 19%|█▉        | 19/100 [00:19<01:19,  1.02it/s][A

loss:  1183.1962  <<< 



 20%|██        | 20/100 [00:20<01:18,  1.02it/s][A

loss:  1182.8252  <<< 



 21%|██        | 21/100 [00:21<01:17,  1.02it/s][A

loss:  1200.1794  <<< 



 22%|██▏       | 22/100 [00:22<01:16,  1.01it/s][A

loss:  1221.7684  <<< 



 23%|██▎       | 23/100 [00:23<01:16,  1.01it/s][A

loss:  1211.745  <<< 



 24%|██▍       | 24/100 [00:24<01:15,  1.01it/s][A

loss:  1198.8702  <<< 



 25%|██▌       | 25/100 [00:25<01:13,  1.02it/s][A

loss:  1182.3934  <<< 



 26%|██▌       | 26/100 [00:26<01:13,  1.01it/s][A

loss:  1179.8271  <<< 



 27%|██▋       | 27/100 [00:27<01:14,  1.03s/it][A

loss:  1188.8467  <<< 



 28%|██▊       | 28/100 [00:28<01:12,  1.01s/it][A

loss:  1197.4222  <<< 



 29%|██▉       | 29/100 [00:29<01:10,  1.00it/s][A

loss:  1204.178  <<< 



 30%|███       | 30/100 [00:30<01:09,  1.01it/s][A

loss:  1195.6786  <<< 



 31%|███       | 31/100 [00:31<01:08,  1.01it/s][A

loss:  1187.8695  <<< 



 32%|███▏      | 32/100 [00:32<01:07,  1.01it/s][A

loss:  1179.8862  <<< 



 33%|███▎      | 33/100 [00:33<01:06,  1.01it/s][A

loss:  1177.3699  <<< 



 34%|███▍      | 34/100 [00:34<01:05,  1.00it/s][A

loss:  1179.7302  <<< 



 35%|███▌      | 35/100 [00:35<01:04,  1.00it/s][A

loss:  1184.2455  <<< 



 36%|███▌      | 36/100 [00:36<01:04,  1.00s/it][A

loss:  1189.7057  <<< 



 37%|███▋      | 37/100 [00:37<01:02,  1.00it/s][A

loss:  1190.7487  <<< 



 38%|███▊      | 38/100 [00:38<01:01,  1.01it/s][A

loss:  1192.3915  <<< 



 39%|███▉      | 39/100 [00:39<01:00,  1.01it/s][A

loss:  1188.8188  <<< 



 40%|████      | 40/100 [00:40<00:59,  1.02it/s][A

loss:  1186.8363  <<< 



 41%|████      | 41/100 [00:41<00:58,  1.01it/s][A

loss:  1182.8735  <<< 



 42%|████▏     | 42/100 [00:41<00:56,  1.02it/s][A

loss:  1180.4232  <<< 



 43%|████▎     | 43/100 [00:42<00:55,  1.02it/s][A

loss:  1178.1194  <<< 



 44%|████▍     | 44/100 [00:43<00:54,  1.02it/s][A

loss:  1176.7272  <<< 



 45%|████▌     | 45/100 [00:44<00:53,  1.02it/s][A

loss:  1175.8457  <<< 



 46%|████▌     | 46/100 [00:45<00:53,  1.02it/s][A

loss:  1175.4128  <<< 



 47%|████▋     | 47/100 [00:46<00:51,  1.02it/s][A

loss:  1175.284  <<< 



 48%|████▊     | 48/100 [00:47<00:51,  1.01it/s][A

loss:  1175.3777  <<< 



 49%|████▉     | 49/100 [00:48<00:50,  1.01it/s][A

loss:  1175.6791  <<< 



 50%|█████     | 50/100 [00:49<00:49,  1.01it/s][A

loss:  1176.1879  <<< 



 51%|█████     | 51/100 [00:50<00:48,  1.01it/s][A

loss:  1177.1339  <<< 



 52%|█████▏    | 52/100 [00:51<00:47,  1.01it/s][A

loss:  1178.4565  <<< 



 53%|█████▎    | 53/100 [00:52<00:46,  1.01it/s][A

loss:  1181.1925  <<< 



 54%|█████▍    | 54/100 [00:53<00:45,  1.01it/s][A

loss:  1184.6569  <<< 



 55%|█████▌    | 55/100 [00:54<00:44,  1.01it/s][A

loss:  1192.81  <<< 



 56%|█████▌    | 56/100 [00:55<00:43,  1.00it/s][A

loss:  1201.1271  <<< 



 57%|█████▋    | 57/100 [00:56<00:42,  1.00it/s][A

loss:  1224.1812  <<< 



 58%|█████▊    | 58/100 [00:57<00:41,  1.01it/s][A

loss:  1236.7388  <<< 



 59%|█████▉    | 59/100 [00:58<00:40,  1.01it/s][A

loss:  1284.4452  <<< 



 60%|██████    | 60/100 [00:59<00:39,  1.02it/s][A

loss:  1271.9031  <<< 



 61%|██████    | 61/100 [01:00<00:38,  1.02it/s][A

loss:  1300.6088  <<< 



 62%|██████▏   | 62/100 [01:01<00:37,  1.02it/s][A

loss:  1239.0281  <<< 



 63%|██████▎   | 63/100 [01:02<00:36,  1.01it/s][A

loss:  1205.7709  <<< 



 64%|██████▍   | 64/100 [01:03<00:35,  1.01it/s][A

loss:  1178.1425  <<< 



 65%|██████▌   | 65/100 [01:04<00:34,  1.01it/s][A

loss:  1180.1709  <<< 



 66%|██████▌   | 66/100 [01:05<00:33,  1.01it/s][A

loss:  1202.1329  <<< 



 67%|██████▋   | 67/100 [01:06<00:32,  1.01it/s][A

loss:  1211.4456  <<< 



 68%|██████▊   | 68/100 [01:07<00:32,  1.00s/it][A

loss:  1213.7003  <<< 



 69%|██████▉   | 69/100 [01:08<00:31,  1.00s/it][A

loss:  1190.1093  <<< 



 70%|███████   | 70/100 [01:09<00:29,  1.01it/s][A

loss:  1176.2664  <<< 



 71%|███████   | 71/100 [01:10<00:28,  1.01it/s][A

loss:  1176.6497  <<< 



 72%|███████▏  | 72/100 [01:11<00:27,  1.00it/s][A

loss:  1186.852  <<< 



 73%|███████▎  | 73/100 [01:12<00:26,  1.01it/s][A

loss:  1199.0261  <<< 



 74%|███████▍  | 74/100 [01:13<00:25,  1.01it/s][A

loss:  1195.7559  <<< 



 75%|███████▌  | 75/100 [01:14<00:24,  1.01it/s][A

loss:  1190.3698  <<< 



 76%|███████▌  | 76/100 [01:15<00:23,  1.01it/s][A

loss:  1179.2596  <<< 



 77%|███████▋  | 77/100 [01:16<00:22,  1.01it/s][A

loss:  1173.5747  <<< 



 78%|███████▊  | 78/100 [01:17<00:21,  1.01it/s][A

loss:  1173.9446  <<< 



 79%|███████▉  | 79/100 [01:18<00:20,  1.02it/s][A

loss:  1178.412  <<< 



 80%|████████  | 80/100 [01:19<00:19,  1.01it/s][A

loss:  1184.543  <<< 



 81%|████████  | 81/100 [01:20<00:18,  1.01it/s][A

loss:  1186.4557  <<< 



 82%|████████▏ | 82/100 [01:21<00:17,  1.01it/s][A

loss:  1188.2642  <<< 



 83%|████████▎ | 83/100 [01:22<00:16,  1.01it/s][A

loss:  1184.1521  <<< 



 84%|████████▍ | 84/100 [01:23<00:15,  1.01it/s][A

loss:  1181.2579  <<< 



 85%|████████▌ | 85/100 [01:24<00:14,  1.01it/s][A

loss:  1176.8475  <<< 



 86%|████████▌ | 86/100 [01:25<00:13,  1.01it/s][A

loss:  1174.025  <<< 



 87%|████████▋ | 87/100 [01:26<00:12,  1.01it/s][A

loss:  1172.1326  <<< 



 88%|████████▊ | 88/100 [01:27<00:11,  1.01it/s][A

loss:  1171.3705  <<< 



 89%|████████▉ | 89/100 [01:28<00:10,  1.02it/s][A

loss:  1171.4568  <<< 



 90%|█████████ | 90/100 [01:29<00:09,  1.01it/s][A

loss:  1172.1604  <<< 



 91%|█████████ | 91/100 [01:30<00:08,  1.01it/s][A

loss:  1173.4442  <<< 



 92%|█████████▏| 92/100 [01:31<00:07,  1.01it/s][A

loss:  1175.0471  <<< 



 93%|█████████▎| 93/100 [01:32<00:06,  1.01it/s][A

loss:  1177.8646  <<< 



 94%|█████████▍| 94/100 [01:33<00:06,  1.00s/it][A

loss:  1180.6914  <<< 



 95%|█████████▌| 95/100 [01:34<00:04,  1.00it/s][A

loss:  1186.879  <<< 



 96%|█████████▌| 96/100 [01:35<00:04,  1.00s/it][A

loss:  1191.6693  <<< 



 97%|█████████▋| 97/100 [01:36<00:02,  1.01it/s][A

loss:  1204.859  <<< 



 98%|█████████▊| 98/100 [01:37<00:01,  1.01it/s][A

loss:  1210.4349  <<< 



 99%|█████████▉| 99/100 [01:38<00:00,  1.00it/s][A

loss:  1232.7003  <<< 



100%|██████████| 100/100 [01:39<00:00,  1.01it/s]
 20%|██        | 4/20 [06:40<26:44, 100.30s/it]

loss:  1229.1755  <<< 



  0%|          | 0/100 [00:00<?, ?it/s][A
  1%|          | 1/100 [00:00<01:36,  1.03it/s][A

loss:  1246.4827  <<< 



  2%|▏         | 2/100 [00:01<01:37,  1.01it/s][A

loss:  1221.8217  <<< 



  3%|▎         | 3/100 [00:02<01:35,  1.02it/s][A

loss:  1211.7102  <<< 



  4%|▍         | 4/100 [00:04<01:37,  1.02s/it][A

loss:  1186.413  <<< 



  5%|▌         | 5/100 [00:04<01:35,  1.00s/it][A

loss:  1173.2086  <<< 



  6%|▌         | 6/100 [00:05<01:33,  1.00it/s][A

loss:  1170.9827  <<< 



  7%|▋         | 7/100 [00:06<01:32,  1.00it/s][A

loss:  1178.0165  <<< 



  8%|▊         | 8/100 [00:07<01:31,  1.01it/s][A

loss:  1189.3075  <<< 



  9%|▉         | 9/100 [00:08<01:30,  1.01it/s][A

loss:  1192.0076  <<< 



 10%|█         | 10/100 [00:09<01:28,  1.02it/s][A

loss:  1193.4591  <<< 



 11%|█         | 11/100 [00:10<01:27,  1.02it/s][A

loss:  1183.6714  <<< 



 12%|█▏        | 12/100 [00:11<01:27,  1.00it/s][A

loss:  1176.2485  <<< 



 13%|█▎        | 13/100 [00:12<01:26,  1.01it/s][A

loss:  1170.579  <<< 



 14%|█▍        | 14/100 [00:13<01:25,  1.01it/s][A

loss:  1169.45  <<< 



 15%|█▌        | 15/100 [00:14<01:24,  1.01it/s][A

loss:  1171.9589  <<< 



 16%|█▌        | 16/100 [00:15<01:23,  1.01it/s][A

loss:  1175.8604  <<< 



 17%|█▋        | 17/100 [00:16<01:23,  1.00s/it][A

loss:  1180.7416  <<< 



 18%|█▊        | 18/100 [00:17<01:21,  1.00it/s][A

loss:  1182.098  <<< 



 19%|█▉        | 19/100 [00:18<01:20,  1.00it/s][A

loss:  1184.587  <<< 



 20%|██        | 20/100 [00:19<01:19,  1.00it/s][A

loss:  1182.1143  <<< 



 21%|██        | 21/100 [00:20<01:18,  1.01it/s][A

loss:  1181.503  <<< 



 22%|██▏       | 22/100 [00:21<01:17,  1.01it/s][A

loss:  1177.8066  <<< 



 23%|██▎       | 23/100 [00:22<01:16,  1.00it/s][A

loss:  1175.8566  <<< 



 24%|██▍       | 24/100 [00:23<01:16,  1.01s/it][A

loss:  1173.0167  <<< 



 25%|██▌       | 25/100 [00:24<01:15,  1.01s/it][A

loss:  1171.3425  <<< 



 26%|██▌       | 26/100 [00:25<01:14,  1.01s/it][A

loss:  1169.7805  <<< 



 27%|██▋       | 27/100 [00:26<01:13,  1.01s/it][A

loss:  1168.8292  <<< 



 28%|██▊       | 28/100 [00:27<01:12,  1.01s/it][A

loss:  1168.1246  <<< 



 29%|██▉       | 29/100 [00:28<01:11,  1.01s/it][A

loss:  1167.6865  <<< 


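**Inference engine** — download a test image, run the trained model on it with `eval_step`, and display the predicted image.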

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output

# Fetch a sample photo to run inference on; wget writes it to a.jpg in the current working directory.
!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
# NOTE(review): this absolute path only matches the download above if the working directory is /content (presumably the Colab default) — confirm.
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    """Load an image, resize it to ``newsize`` and return it in two layouts.

    Args:
        argv: Path or file object accepted by ``PIL.Image.open``.

    Returns:
        A ``(tvt, tvu)`` tuple of JAX arrays. ``tvt`` is the resized RGB
        image with shape ``(height, width, 3)``; ``tvu`` is the same data
        flattened to ``(height * width, 3)``, one row per pixel.

    Note:
        ``newsize`` is read from the notebook's global scope. PIL's
        ``resize`` interprets it as ``(width, height)``.
    """
    # The context manager releases the file handle once the pixels are read.
    with Image.open(argv) as im:
        # Force exactly three channels so grayscale / RGBA / palette inputs
        # do not break (or silently scramble) the reshape(-1, 3) below.
        tvt = jnp.asarray(im.convert("RGB").resize(newsize))
    # Resize and convert once, then derive the flat per-pixel view from it.
    tvu = tvt.reshape(-1, 3)
    return tvt, tvu
# Flattened per-pixel RGB rows of the test image (second value returned by imageRGB).
_, image = imageRGB(image_in)
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
prediction = eval_step(state, image)
# A bare expression in the middle of a cell is never displayed, so print the loss explicitly.
print("loss: ", prediction['loss'])

# Clip to the valid 8-bit range before casting: converting out-of-range floats
# to uint8 wraps around and shows up as speckle in the rendered image.
# NOTE(review): PIL's resize treats newsize as (width, height) while array shapes
# are (rows, cols); reshape(newsize) is therefore only right for square sizes — confirm.
predicted_image = np.clip(np.asarray(prediction['logits']), 0, 255).astype(np.uint8).reshape(newsize)
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))