<a href="https://colab.research.google.com/github/1kaiser/Media-Segment-Depth-MLP/blob/main/MLP_Image_Train_Inference_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive under /content/drive; a later cell reads its test video
# from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

### **Downloading the dataset**

Download the [flower dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition?resource=download) from Kaggle.

In [None]:
from google.colab import output
# Fetch the flower-image archive from the project's GitHub release into the
# current working directory.
# NOTE(review): the unzip below uses the absolute path /content/archive.zip, so
# this cell presumably expects the working directory to be /content - confirm.
!wget https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/archive.zip -O archive.zip
!unzip /content/archive.zip #unzipping the flower images from archive..
# Drop the download/unzip log from the cell output.
output.clear()
##########################<< copy all varieties into a single folder >>################
!mkdir -p /content/flowers/all
!cp /content/flowers/daisy/* /content/flowers/all
!cp /content/flowers/dandelion/* /content/flowers/all
!cp /content/flowers/rose/* /content/flowers/all
!cp /content/flowers/sunflower/* /content/flowers/all
!cp /content/flowers/tulip/* /content/flowers/all
##########################<< end of block >>################
print("creating single image folder complete >>>")


## **RUN** 

**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or any $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [None]:
#✅
import jax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of sin/cos frequency bands applied per input channel


def positional_encoding(args):
    """Append sin/cos Fourier features to a 2-D array of input points.

    Every channel of ``args`` is scaled by ``2.0 ** k`` for
    ``k = 0 .. positional_encoding_dims - 1`` and passed through ``sin`` and ``cos``.

    Args:
        args: array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns followed by the encoded columns.
    """
    # Unpacking two values also rejects any input that is not 2-D.
    num_points, _ = args.shape

    def scale_by_band(band):
        return args * 2.0 ** band

    bands = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_band)(bands)                      # (bands, points, channels)
    features = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])     # (2, bands, points, channels)
    features = jnp.swapaxes(features, 0, 2)                      # (points, bands, 2, channels)
    features = features.reshape([num_points, -1])
    return jnp.concatenate([args, features], axis=-1)



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [None]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio

import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply positional encoding to the input or not
ndl = 8 # num_dense_layers: number of hidden dense layers in the MLP
dlw = 256 # dense_layer_width: dimensionality of each hidden dense layer's output space

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Coordinate MLP: optional positional encoding, ``ndl`` ReLU layers, one linear output.

    The hidden width (``dlw``) and depth (``ndl``) are read from module-level
    constants each time the model is called.
    """
    # dtype and matmul precision forwarded to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # When False the raw input points are fed to the first Dense layer.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` of shape ``(num_points, channels)`` to ``(num_points, 1)``."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_index in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_index == 4:
                # Skip connection: re-inject the raw (un-encoded) inputs after the fifth layer.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        output_layer = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return output_layer(hidden)
##########################################<< MLP MODEL >>#########################################

**initialize the module**

In [None]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialise ``model`` on a dummy input and wrap it in a Flax ``TrainState``.

    Args:
        r_key: PRNG key passed to ``model.init``.
        model: Flax module to initialise.
        shape: shape of the all-ones array used to trace the model.
        learning_rate: step size for the Adam optimizer.

    Returns:
        A ``TrainState`` holding the initial parameters and an Adam optimizer.
    """
    print(shape)
    dummy_input = jnp.ones(shape)
    initial_params = model.init(r_key, dummy_input)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4  # step size handed to Create_train_state (Adam)
batch_size_no = 64  # NOTE(review): not referenced anywhere else in the visible notebook

model = MLPModel() # Instantiate the Model

**defining loss function**

In [None]:
#serial
def image_difference_loss(logits, labels):
    """Return half the mean squared difference between ``logits`` and ``labels``.

    The two arrays are subtracted with ordinary broadcasting before averaging.
    """
    residual = logits - labels
    return .5 * jnp.mean(residual ** 2)
def compute_metrics(*, logits, labels):
  """Bundle the loss with the arrays it was computed from.

  Returns a dict with keys ``'loss'`` (scalar from ``image_difference_loss``),
  ``'logits'`` (the predicted image) and ``'labels'`` (the target image).
  """
  return {
      'loss': image_difference_loss(logits, labels),  # LOSS
      'logits': logits,                               # PREDICTED IMAGE
      'labels': labels,                               # ACTUAL IMAGE
  }

**train step definition**

In [None]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Run one gradient step on a single ``(inputs, targets)`` batch.

    Args:
        state: current train state (parameters plus optimizer state).
        batch: pair ``(inputs, targets)``.
        rng: accepted for call-site compatibility; not used inside this function.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` is the dict built by
        ``compute_metrics`` from the pre-update predictions.
    """
    inputs, targets = batch

    def loss_and_predictions(params):
        predictions = state.apply_fn({'params': params}, inputs)
        return image_difference_loss(predictions, targets), predictions

    differentiate = jax.value_and_grad(loss_and_predictions, has_aux=True)
    (_, predictions), grads = differentiate(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, compute_metrics(logits=predictions, labels=targets)

import jax
@jax.jit
def eval_step(state, image):
    """Apply the model to ``image`` and return the ``compute_metrics`` dict.

    NOTE(review): the input ``image`` is also passed as ``labels``, so the
    reported loss compares the prediction with the input itself - confirm
    that this is the intended reference.
    """
    variables = {'params': state.params}
    predicted = state.apply_fn(variables, image)
    return compute_metrics(logits=predicted, labels=image)


**image conversion functions**

In [None]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Load an image file as 8-bit grayscale, resized to the global ``newsize``.

    Args:
        argv: path of the image file.

    Returns:
        ``(pixels, flat)``: the resized image as an array, and the same data
        reshaped to one row per pixel with a single channel column.
    """
    # The context manager closes the file handle; the image is resized once
    # and both return values are derived from that single result.
    with Image.open(argv) as source:
        resized = source.convert('L').resize(newsize)
    pixels = jnp.asarray(resized)
    return pixels, pixels.reshape(-1, 1)
def imageRGB(argv):
    """Load an image file as 3-channel RGB, resized to the global ``newsize``.

    Args:
        argv: path of the image file.

    Returns:
        ``(pixels, flat)``: the resized image as an array, and the same data
        reshaped to one row per pixel with three channel columns.
    """
    # convert('RGB') guarantees exactly three channels, so grayscale, palette
    # or RGBA files no longer break (or silently scramble) the reshape below.
    with Image.open(argv) as source:
        resized = source.convert('RGB').resize(newsize)
    pixels = jnp.asarray(resized)
    return pixels, pixels.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 1

import os
image_dir = r'/content/flowers/tulip/'
# ---------------------------------------------------------------------------
# Select the files whose name contains the chosen extension and sort them so
# that list positions are stable between runs.
bandend = ["c",".jpg", "b02"]
expression_b2 = bandend[1]
total_images = sorted(name for name in os.listdir(image_dir) if expression_b2 in name)
total_images_path = [os.path.join(image_dir, name) for name in total_images if name != 'outputs']
no_of_batches = len(total_images_path) // batch_size

######################################## making 8 array of input for each device >>>
def batchedimages(image_locations):
  """Build one ``(rgb, gray)`` training pair from the first index in ``image_locations``.

  Only ``image_locations[0]`` is read; both arrays come from the same file in
  ``total_images_path`` and hold one row per pixel.
  """
  image_path = total_images_path[image_locations[0]]
  rgb_pixels = jnp.asarray(imageRGB(image_path)[1])
  gray_pixels = jnp.asarray(imageGRAY(image_path)[1])
  return rgb_pixels, gray_pixels

def data_stream():
  """Yield ``no_of_batches`` batches, visiting the images in a shuffled order.

  The permutation is drawn from a fixed ``PRNGKey(0)``, so every call walks
  the images in the same order.
  """
  shuffle_key = random.PRNGKey(0)
  order = random.permutation(shuffle_key, len(total_images_path))
  for start in range(0, no_of_batches * batch_size, batch_size):
    yield batchedimages(order[start:start + batch_size])

batches = data_stream()

In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import jax
from jax import random
from tqdm import tqdm
import re
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
count = 0
if count == 0 :
  HxW, Channels = next(batches)[0].shape
  state = Create_train_state( rng, model, (HxW, Channels), learning_rate ) 
  count = 1
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

###################### 
checkpoint_available = 0
pattern = re.compile("checkpoint_\d+")   # to search for "checkpoint_*munerical value*" numerical value of any length is denoted by regular expression "\d+"
dir = "/content/ckpts/"
isFile = os.path.isdir(dir)
if isFile:
  for filepath in os.listdir(dir):
      if pattern.match(filepath):
          checkpoint_available = 1

total_epochs = 50
for epochs in tqdm(range(total_epochs)):  
  batches = data_stream() 

  if checkpoint_available:
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = 0 # << Flag updated >>> to stop loading the same checkpoint in the next iteration then remove the checkpoint directory
    !rm -r {dir}

  for bbb in tqdm(range(no_of_batches-5)):
    state, metrics = train_step(state, next(batches), rng)
    output.clear()
    print("loss: ",metrics['loss']," <<< ") # naming of the checkpoint is "checkpoint_*"  where "*" => value of the steps variable, i.e. 'epochs'
  orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  # restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state) # using to get the checkpoint loaded , it can be latest one , or if already available as checkpoint in the "CKPT_DIR" directory then take the file from directory then save in >> restored_checkpoints
  ##################################################



**inference engine**

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output
from flax.training import checkpoints

!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    im = Image.open(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,3)
    return tvt, tvu
image = jnp.asarray((imageRGB(image_in)[1]))
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
#initialize
HxW, Channels = image.shape
state = Create_train_state( model, rng, (HxW, Channels), learning_rate ) 

state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
prediction = eval_step(state, image)
prediction['loss']


predicted_image = np.array(prediction['logits'],  dtype=np.uint8).reshape(newsize) 
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))

## **Test dataset segmentation: creation and download section**

In [None]:
##################################<<< MEDIAPIPE LIBRARY INSTALLATION >>>#############################
!python -m pip install mediapipe
##################################<<< FRAME EXTRACTION >>>#############################
# Source video; it lives on the Google Drive mounted under /content/drive.
video_location = '/content/drive/MyDrive/OUT/data/machine_learning_test_dataset/test.mp4'
import os

 
# Read images with OpenCV.
#images= None
# Folder receiving the extracted frames. This rebinds the `image_dir` name that
# the flower-dataset cell earlier in the notebook also assigns.
image_dir = '/content/MEDIAPIPEinput/'
os.makedirs(image_dir, exist_ok=True)
# Folder receiving the segmented (annotated) frames.
image_dir_out = '/content/annotated_images'
os.makedirs(image_dir_out, exist_ok=True)
frame_rate = 25
# Write the video's frames as numbered PNG files at `frame_rate` frames per second.
# NOTE(review): "-hwaccel cuvid" presumably needs an NVIDIA GPU runtime - confirm.
!ffmpeg -y -hwaccel cuvid \
  -i {video_location} \
  -r {frame_rate} {image_dir}out_%09d.png

# Sorted frame names and their full paths; the two lists are index-aligned as
# long as no entry is literally named 'outputs'.
imgs_list = os.listdir(image_dir)
imgs_list.sort()
imgs_path = [os.path.join(image_dir, i) for i in imgs_list if i != 'outputs']
################################<<< SEGMENTATION USING MEDIAPIPE >>>###################################
import cv2
from google.colab.patches import cv2_imshow
import numpy as np
import mediapipe as mp
# mp_holistic = mp.solutions.holistic
mp_pose = mp.solutions.pose  # MediaPipe Pose solution used by the segmentation loop
# Remove a Jupyter checkpoint folder from the frame directory, if there is one.
# NOTE(review): `imgs_list` was already built before this removal, so it may
# still contain the ".ipynb_checkpoints" entry.
!rm -r {image_dir}.ipynb_checkpoints

# Run MediaPipe Pose with `enable_segmentation=True` to get pose segmentation.
with mp_pose.Pose(static_image_mode=True,
                  min_detection_confidence=0.2,
                  model_complexity=2,
                  enable_segmentation=True,) as pose:
  # Masks used so far, in frame order; the last one is the fallback for a frame
  # in which no segmentation mask is produced.
  temp_segmentation_mask = []
  for name, image_path in enumerate(imgs_path):
    frame_bgr = cv2.imread(image_path)
    if frame_bgr is None:
      # cv2.imread returns None for anything it cannot decode (for example a
      # stray directory entry); there is nothing to annotate in that case.
      print(f'Skipping unreadable input {image_path}')
      continue

    # OpenCV loads images as BGR; convert to RGB before handing them to MediaPipe Pose.
    results = pose.process(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

    # Draw pose segmentation.
    print(f'Pose segmentation of {name}:')
    white_background = np.full_like(frame_bgr, 255)

    segmentation_mask = results.segmentation_mask
    if segmentation_mask is None:
      print("true")
      if temp_segmentation_mask:
        # Reuse the mask of the most recent frame.
        segmentation_mask = temp_segmentation_mask[-1]
      else:
        # No earlier mask exists yet: treat the whole frame as background so
        # that an output file is still written for this frame.
        segmentation_mask = np.zeros(frame_bgr.shape[:2], dtype=np.float32)
    temp_segmentation_mask.append(segmentation_mask)

    # Blend: keep the frame where the mask is 1, show white where it is 0.
    segm_2class = np.repeat(segmentation_mask[..., np.newaxis], 3, axis=2)
    annotated_image_pose = frame_bgr * segm_2class + white_background * (1 - segm_2class)
    # Name the output after the input file itself so that input and annotated
    # frames always share a file name.
    cv2.imwrite(os.path.join(image_dir_out, os.path.basename(image_path)), annotated_image_pose)


## **RUN 2** 

**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or any $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [None]:
#✅
import jax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of sin/cos frequency bands applied per input channel


def positional_encoding(args):
    """Return ``args`` with sin/cos Fourier features appended along the last axis.

    Args:
        args: array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``.
    """
    # Unpacking two values also rejects any input that is not 2-D.
    num_points, _ = args.shape

    def scale_by_band(band):
        return args * 2.0 ** band

    bands = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_band)(bands)                      # (bands, points, channels)
    features = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])     # (2, bands, points, channels)
    features = jnp.swapaxes(features, 0, 2)                      # (points, bands, 2, channels)
    features = features.reshape([num_points, -1])
    return jnp.concatenate([args, features], axis=-1)



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [None]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio


import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply positional encoding to the input or not
ndl = 8 # num_dense_layers: number of hidden dense layers in the MLP
dlw = 256 # dense_layer_width: dimensionality of each hidden dense layer's output space

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Coordinate MLP with ``ndl`` ReLU layers of width ``dlw`` and a single linear output unit."""
    # dtype and matmul precision forwarded to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # When False the raw input points are fed to the first Dense layer.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` of shape ``(num_points, channels)`` to ``(num_points, 1)``."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_index in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_index == 4:
                # Skip connection: re-inject the raw (un-encoded) inputs after the fifth layer.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        output_layer = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return output_layer(hidden)
##########################################<< MLP MODEL >>#########################################

**initialize the module**

In [None]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialise ``model`` on an all-ones input of ``shape`` and return its ``TrainState``.

    ``r_key`` seeds ``model.init``; ``learning_rate`` configures the Adam optimizer.
    """
    print(shape)
    dummy_input = jnp.ones(shape)
    initial_params = model.init(r_key, dummy_input)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4  # step size handed to Create_train_state (Adam)
batch_size_no = 64  # NOTE(review): not referenced anywhere else in the visible notebook

model = MLPModel() # Instantiate the Model

**defining loss function**

In [None]:
#serial
def image_difference_loss(logits, labels):
    """Half of the mean squared error between ``logits`` and ``labels`` (broadcast subtraction)."""
    residual = logits - labels
    return .5 * jnp.mean(residual ** 2)
def compute_metrics(*, logits, labels):
  """Return ``{'loss', 'logits', 'labels'}`` for one prediction/target pair."""
  return {
      'loss': image_difference_loss(logits, labels),  # LOSS
      'logits': logits,                               # PREDICTED IMAGE
      'labels': labels,                               # ACTUAL IMAGE
  }

**train step definition**

In [None]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Take one optimizer step on ``batch = (inputs, targets)``.

    ``rng`` is accepted but not used. Returns ``(new_state, metrics)``, where
    ``metrics`` comes from ``compute_metrics`` on the pre-update predictions.
    """
    inputs, targets = batch

    def loss_and_predictions(params):
        predictions = state.apply_fn({'params': params}, inputs)
        return image_difference_loss(predictions, targets), predictions

    differentiate = jax.value_and_grad(loss_and_predictions, has_aux=True)
    (_, predictions), grads = differentiate(state.params)
    updated_state = state.apply_gradients(grads=grads)
    return updated_state, compute_metrics(logits=predictions, labels=targets)

import jax
@jax.jit
def eval_step(state, image):
    """Apply the model to ``image`` and return the ``compute_metrics`` dict.

    NOTE(review): ``image`` is also used as ``labels``, so the reported loss
    compares the prediction with the input itself - confirm this is intended.
    """
    variables = {'params': state.params}
    predicted = state.apply_fn(variables, image)
    return compute_metrics(logits=predicted, labels=image)


**image conversion functions**

In [None]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Load the image at ``argv`` as 8-bit grayscale resized to the global ``newsize``.

    Returns ``(pixels, flat)``: the resized image array and the same data with
    one row per pixel and a single channel column.
    """
    # Close the file deterministically and resize only once.
    with Image.open(argv) as source:
        resized = source.convert('L').resize(newsize)
    pixels = jnp.asarray(resized)
    return pixels, pixels.reshape(-1, 1)
def imageRGB(argv):
    """Load the image at ``argv`` as 3-channel RGB resized to the global ``newsize``.

    Returns ``(pixels, flat)``: the resized image array and the same data with
    one row per pixel and three channel columns.
    """
    # PNG frames may carry an alpha channel or a palette; convert('RGB') makes
    # the three-column reshape below valid for every input mode.
    with Image.open(argv) as source:
        resized = source.convert('RGB').resize(newsize)
    pixels = jnp.asarray(resized)
    return pixels, pixels.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 1

import os
image_dir = r'/content/MEDIAPIPEinput'
annotated_image_dir = r'/content/annotated_images'

#############################################################################
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]

# Input frames and annotated frames are paired by list position further down,
# so keep only the file names that exist in BOTH folders. Building the two
# lists independently would silently mispair every later frame as soon as one
# folder misses a file.
annotated_names = {f for f in os.listdir(annotated_image_dir) if expression_b2 in f}
total_images = sorted(
    f for f in os.listdir(image_dir) if expression_b2 in f and f in annotated_names
)
total_images_path = [os.path.join(image_dir, i) for i in total_images if i != 'outputs']

# Same names, same order: position k in both path lists is the same frame.
annotated_total_images = list(total_images)
annotated_total_images_path = [os.path.join(annotated_image_dir, i) for i in annotated_total_images if i != 'outputs']

no_of_batches = len(total_images_path) // batch_size

######################################## making 8 array of input for each device >>>
def batchedimages(image_locations):
  """Build one ``(input, annotated)`` pair from the first index in ``image_locations``.

  Only ``image_locations[0]`` is read. The same position is looked up in
  ``total_images_path`` and ``annotated_total_images_path``; both images are
  loaded through ``imageRGB`` and returned with one row per pixel.
  """
  position = image_locations[0]
  input_pixels = jnp.asarray(imageRGB(total_images_path[position])[1])
  annotated_pixels = jnp.asarray(imageRGB(annotated_total_images_path[position])[1])
  return input_pixels, annotated_pixels

def data_stream():
  """Yield ``no_of_batches`` batches in a shuffled order.

  The permutation is drawn from a fixed ``PRNGKey(0)``, so every call walks
  the images in the same order.
  """
  shuffle_key = random.PRNGKey(0)
  order = random.permutation(shuffle_key, len(total_images_path))
  for start in range(0, no_of_batches * batch_size, batch_size):
    yield batchedimages(order[start:start + batch_size])

batches = data_stream()

In [None]:
# Show the shape of the input half of one batch. Note that this consumes one
# item from the `batches` generator.
next(batches)[0].shape

In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import jax
from jax import random
from tqdm import tqdm
import re
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
count = 0
if count == 0 :
  HxW, Channels = next(batches)[0].shape
  state = Create_train_state( rng, model, (HxW, Channels), learning_rate ) 
  count = 1
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

###################### 
checkpoint_available = 0
pattern = re.compile("checkpoint_\d+")   # to search for "checkpoint_*munerical value*" numerical value of any length is denoted by regular expression "\d+"
dir = "/content/ckpts/"
isFile = os.path.isdir(dir)
if isFile:
  for filepath in os.listdir(dir):
      if pattern.match(filepath):
          checkpoint_available = 1

total_epochs = 50
for epochs in tqdm(range(total_epochs)):  
  batches = data_stream() 

  if checkpoint_available:
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = 0 # << Flag updated >>> to stop loading the same checkpoint in the next iteration then remove the checkpoint directory
    !rm -r {dir}

  for bbb in tqdm(range(no_of_batches-5)):
    state, metrics = train_step(state, next(batches), rng)
    output.clear()
    print("loss: ",metrics['loss']," <<< ") # naming of the checkpoint is "checkpoint_*"  where "*" => value of the steps variable, i.e. 'epochs'
  orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  # restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state) # using to get the checkpoint loaded , it can be latest one , or if already available as checkpoint in the "CKPT_DIR" directory then take the file from directory then save in >> restored_checkpoints
  ##################################################



**inference engine**

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output

!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    im = Image.open(argv)
    tvt, tvu = jnp.asarray(im.resize(newsize)),jnp.asarray(im.resize(newsize)).reshape(-1,3)
    return tvt, tvu
image = jnp.asarray((imageRGB(image_in)[1]))
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
prediction = eval_step(state, image)
prediction['loss']


predicted_image = np.array(prediction['logits'],  dtype=np.uint8).reshape(newsize) 
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))

## **Ensemble**

In [None]:
#✅
!python -m pip install -q -U flax
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training import train_state, common_utils
import functools

positional_encoding_dims = 6  # Number of sin/cos frequency bands applied per input channel


def positional_encoding(args):
    """Concatenate ``args`` with sin/cos features of ``args * 2.0 ** k``.

    ``k`` runs over ``0 .. positional_encoding_dims - 1``. ``args`` must be 2-D,
    ``(num_points, channels)``; the result has
    ``channels * (1 + 2 * positional_encoding_dims)`` columns.
    """
    # Unpacking two values also rejects any input that is not 2-D.
    num_points, _ = args.shape

    def scale_by_band(band):
        return args * 2.0 ** band

    bands = jnp.arange(positional_encoding_dims)
    scaled = jax.vmap(scale_by_band)(bands)                      # (bands, points, channels)
    features = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])     # (2, bands, points, channels)
    features = jnp.swapaxes(features, 0, 2)                      # (points, bands, 2, channels)
    features = features.reshape([num_points, -1])
    return jnp.concatenate([args, features], axis=-1)

apply_positional_encoding = True # Apply positional encoding to the input or not
ndl = 8 # num_dense_layers: number of hidden dense layers in the MLP
dlw = 256 # dense_layer_width: dimensionality of each hidden dense layer's output space

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """Coordinate MLP with ``ndl`` ReLU layers of width ``dlw`` and a single linear output unit."""
    # dtype and matmul precision forwarded to every Dense layer.
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    # When False the raw input points are fed to the first Dense layer.
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Map ``input_points`` of shape ``(num_points, channels)`` to ``(num_points, 1)``."""
        if self.apply_positional_encoding:
            hidden = positional_encoding(input_points)
        else:
            hidden = input_points

        for layer_index in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_index == 4:
                # Skip connection: re-inject the raw (un-encoded) inputs after the fifth layer.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        output_layer = nn.Dense(1, dtype=self.dtype, precision=self.precision)
        return output_layer(hidden)
##########################################<< MLP MODEL >>#########################################

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def Create_train_state(r_key, shape, learning_rate ):
    """Per-device train-state constructor (pmapped over the leading axis of ``r_key``).

    ``shape`` and ``learning_rate`` are static, broadcast arguments. Each device
    initialises its own ``MLPModel`` from its own key.

    NOTE: this definition replaces the earlier four-argument
    ``Create_train_state(r_key, model, shape, learning_rate)`` once the cell runs.
    """
    print(shape)
    network = MLPModel()
    initial_params = network.init(r_key, jnp.ones(shape))['params']
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )

# learning_rate = 1e-4
# batch_size_no = 64

# model = MLPModel() # Instantiate the Model

In [None]:
@functools.partial(jax.pmap, axis_name='ensemble')
def apply_model(state, batch: jnp.asarray):
  """Compute per-device gradients and loss for ``batch = (image, label)``.

  Returns ``(grads, loss)``; the parameters in ``state`` are not modified here.
  """
  image, label = batch

  def loss_and_logits(params):
    logits = MLPModel().apply({'params': params}, image)
    return image_difference_loss(logits, label), logits

  differentiate = jax.value_and_grad(loss_and_logits, has_aux=True)
  (loss, _), grads = differentiate(state.params)
  return grads, loss

@jax.pmap
def update_model(state, grads):
  """Apply ``grads`` to ``state`` on every device and return the updated state."""
  updated_state = state.apply_gradients(grads=grads)
  return updated_state

In [None]:
def train_epoch(state, train_ds, batch_size, rng):
  """Train the replicated ``state`` for one pass over ``train_ds``.

  Args:
    state: replicated (per-device) train state.
    train_ds: mapping with ``'image'`` and ``'label'`` arrays indexed along axis 0.
    batch_size: number of examples per step; a trailing partial batch is dropped.
    rng: PRNG key used to shuffle the examples.

  Returns:
    ``(state, train_loss)`` where ``train_loss`` is the mean of the per-step losses.
  """
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  # Shuffle, drop the remainder, and group the indices into one row per step.
  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []

  for perm in perms:
    # Every device receives the same batch.
    batch_images = flax.jax_utils.replicate(train_ds['image'][perm, ...])
    batch_labels = flax.jax_utils.replicate(train_ds['label'][perm, ...])
    # apply_model takes (state, batch) with batch = (image, label): pass ONE tuple.
    grads, loss = apply_model(state, (batch_images, batch_labels))
    state = update_model(state, grads)
    epoch_loss.append(flax.jax_utils.unreplicate(loss))
  train_loss = np.mean(epoch_loss)
  return state, train_loss

In [None]:
import logging

# NOTE(review): get_datasets() and num_epochs are not defined in the visible
# part of this notebook; they must be supplied before this cell can run.
train_ds, test_ds = get_datasets()
test_ds = flax.jax_utils.replicate(test_ds)
rng = jax.random.PRNGKey(0)

rng, init_rng = jax.random.split(rng)

HxW, Channels = next(batches)[0].shape
# One PRNG key per device, so each device initialises its own parameters.
# The constructor is spelled Create_train_state (capital C).
state = Create_train_state(jax.random.split(init_rng, jax.device_count()),(HxW, Channels),learning_rate)

for epoch in range(1, num_epochs + 1):
  rng, input_rng = jax.random.split(rng)
  state, train_loss = train_epoch(state, train_ds, batch_size, input_rng)

  # _, test_loss = flax.jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))

  logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))

In [None]:
# same as before, but using @pad_shard_unshard decorator

# manually padding
# => precise & allows for data parallelism

# NOTE(review): `tfds`, `dataset_name`, `per_device_batch_size` and `vs` are not
# defined in the visible part of this notebook before this cell, and the MLP
# defined above has a single output unit, so the argmax-based accuracy below
# looks like a classification snippet carried over from elsewhere - confirm.
@jax.pmap
def get_preds(variables, inputs):
  """Apply the global ``model`` to ``inputs`` on every device."""
  # The print runs when the function is traced, i.e. on (re)compilation.
  print('retrigger compilation', inputs.shape)
  return model.apply(variables, inputs)

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)

# Count correct predictions over all batches of the split.
correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

In [None]:
# NOTE: running this cell rebinds `eval_step`, replacing the earlier
# eval_step(state, image) that the inference cells call.
def eval_step(metrics, variables, batch):
  """Add this batch's correct/total counts (summed over the 'batch' axis) to ``metrics``."""
  # The print runs when the function is traced, i.e. on (re)compilation.
  print('retrigger compilation', {k: v.shape for k, v in batch.items()})
  preds = model.apply(variables, batch['image'])
  # Only entries whose 'mask' is set are counted.
  # NOTE(review): 'mask' is presumably supplied by pad_shard_unpad to mark
  # non-padding rows - confirm.
  correct = (batch['mask'] & (batch['label'] == preds.argmax(axis=-1))).sum()
  total = batch['mask'].sum()
  return dict(
      correct=metrics['correct'] + jax.lax.psum(correct, axis_name='batch'),
      total=metrics['total'] + jax.lax.psum(total, axis_name='batch'),
  )

eval_step = jax.pmap(eval_step, axis_name='batch')
eval_step = flax.jax_utils.pad_shard_unpad(
    eval_step, static_argnums=(0, 1), static_return=True)

In [None]:
# Zip the dataset staging folder; the archive is written into its parent directory.
# NOTE(review): `total_files` is assigned only in a later cell of this notebook,
# so this cell fails on a fresh top-to-bottom run.
%cd {total_files}
%cd ..
!zip -r folder.zip {total_files}

In [None]:
# Copy the zipped dataset to the mounted Google Drive folder.
!cp -r /content/folder.zip /content/drive/MyDrive/OUT/data/machine_learning_test_dataset

In [None]:
# Download the s.zip archive from the project's GitHub release into the current directory.
!wget https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip

In [None]:
# Unpack /content/s.zip (the archive fetched by the wget cell above, not the flower archive).
!unzip /content/s.zip #unzipping the flower images from archive..

In [None]:
# Stage the input frames and the annotated frames side by side under /content/s.
total_files= '/content/s'
input_images = '/content/MEDIAPIPEinput'
out_images = '/content/annotated_images'
!mkdir -p {total_files}
!cp -r {input_images} {total_files}
!cp -r {out_images} {total_files}

In [None]:
# Run `tfds new` inside the staging folder to create the `my_dataset` scaffold.
%cd {total_files}
!tfds new my_dataset

In [None]:
# Build the dataset from inside the my_dataset folder.
total_files= '/content/s'
%cd {total_files}/my_dataset/
!tfds build

In [None]:
# NOTE(review): this removes /content/t/my_dataset, while the cells above work
# under /content/s - confirm the intended path.
!rm -r /content/t/my_dataset

In [None]:
import os
def _generate_examples(self, path):
  """Yields examples."""
  # TODO(my_dataset): Yields (key, example) tuples from the dataset
  for f in path.glob('*.png'):
    yield 'key', {
        'MEDIAPIPEinput': f,
        'annotated_images': f,
    }
    
os.path = r'/content/s'
_generate_examples(path / 'MEDIAPIPEinput')

In [None]:
import tensorflow_datasets as tfds
# Download s.zip through the TFDS download manager and extract it; `path` is
# the location of the extracted archive.
dl_manager = tfds.download.DownloadManager(download_dir='/content')
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
path = dl_manager.extract(dl_manager.download(urls))

In [None]:
def _generate_examples( img_path):
  # Read the input data out of the source files
  # with img_path.open() as f:
    yield {
        'image': img_path / '*.png',
    }

def _split_generators():
    """Download and extract the archive, then return one example generator per sub-folder."""
    manager = tfds.download.DownloadManager(download_dir='/content')
    archive_url = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
    # The download manager returns pathlib-like objects (`read_text()`, `iterdir()`, ...).
    extracted_root = manager.extract(manager.download(archive_url))
    generators = {}
    generators['in_image'] = _generate_examples(extracted_root / 'MEDIAPIPEinput')
    generators['out_image'] = _generate_examples(extracted_root / 'annotated_images')
    return generators
# _generate_examples(path/'MEDIAPIPEinput')
print(_split_generators()['in_image'])

In [None]:
# Display the extracted-archive location as a plain string.
str(path)

In [None]:
# Display the 'image' entry of the first record the generator yields.
next(_generate_examples(path / 'MEDIAPIPEinput'))['image']

In [None]:
from typing import Any, Dict, Iterator, Tuple


class Builder(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for my_dataset dataset.

  Every example pairs one picture from the ``MEDIAPIPEinput`` folder (feature
  ``image``) with one picture from the ``annotated_images`` folder (feature
  ``label``), matching the two features declared in ``_info``.
  """

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }

  def _info(self) -> tfds.core.DatasetInfo:
    """Dataset metadata (homepage, citation,...)."""
    return self.dataset_info_from_configs(
        features=tfds.features.FeaturesDict({
            'image': tfds.features.Image(shape=(256, 256, 3)),
            'label': tfds.features.Image(shape=(256, 256, 3)),
        }),
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Download the data and define splits."""
    urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
    extracted_path = dl_manager.download_and_extract(urls)
    # dl_manager returns pathlib-like objects with `path.read_text()`,
    # `path.iterdir()`,...
    # One split whose examples carry both declared features. Putting inputs in
    # 'train' and annotations in 'test' would leave every example without a
    # 'label' entry.
    return {
        'train': self._generate_examples(path=extracted_path),
    }

  def _generate_examples(self, path) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield ``(key, {'image': ..., 'label': ...})`` pairs for one split.

    Args:
      path: Root of the extracted archive; it is expected to contain the
        ``MEDIAPIPEinput`` and ``annotated_images`` folders.

    Yields:
      The input file name as key, with the input path and the annotation path.
    """
    by_name = lambda item: item[0]
    inputs = sorted(
        self._generate_input_examples(path=path / 'MEDIAPIPEinput'), key=by_name)
    labels = sorted(
        self._generate_output_examples(path=path / 'annotated_images'), key=by_name)
    # NOTE(review): pairing is positional after sorting by file name, i.e. it
    # assumes both folders sort into the same order - confirm. zip() stops at
    # the shorter list, so unmatched trailing files are dropped.
    for (key, input_example), (_, label_example) in zip(inputs, labels):
      yield key, {
          'image': input_example['image'],
          'label': label_example['image'],
      }

  def _generate_input_examples(self, path) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield ``(file name, {'image': path})`` for every PNG directly in ``path``."""
    for img_path in path.glob('*.png'):
      # Yields (key, example)
      yield img_path.name, {
          'image': img_path,
      }

  def _generate_output_examples(self, path) -> Iterator[Tuple[str, Dict[str, Any]]]:
    """Yield ``(file name, {'image': path})`` for every PNG directly in ``path``."""
    for img_path in path.glob('*.png'):
      # Yields (key, example)
      yield img_path.name, {
          'image': img_path,
      }

In [None]:
# import the modules
import os
from os import listdir
 
# get the path or directory
# Folder to scan, built from the `path` variable set in an earlier cell.
folder_dir = str(path)+'/MEDIAPIPEinput/'
for images in os.listdir(folder_dir):
    # Skip every entry that does not carry one of the image extensions.
    if not images.endswith((".png", ".jpg", ".jpeg")):
        continue
    print(images)

In [None]:
path

In [None]:
batch_size = 32
img_height = 180
img_width = 180
import tensorflow as tf

# Keep a handle on the dataset: the following cell reads `train_ds`, so the
# result must be bound to that name instead of being discarded.
# Training subset (80 %, given validation_split=0.2) of the images under
# `path`, resized to img_height x img_width and batched.
train_ds = tf.keras.utils.image_dataset_from_directory(
  path,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)


In [None]:
train_ds

In [None]:
import pathlib
import tensorflow as tf

# URL of the s.zip release asset.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
# get_file downloads `origin`; extract=True additionally asks Keras to unpack
# the archive. cache_subdir is given as an absolute path here.
# NOTE(review): where the file named 's' and the extracted folders end up on
# disk depends on the installed Keras version - confirm before using data_dir.
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# Wrap the returned location in a Path object.
data_dir = pathlib.Path(data_dir)

In [None]:
!rm -r /content/biy

In [None]:
data_dir

In [None]:
image_count = len(list(data_dir.glob('*/*.png')))
print(image_count)

In [None]:


import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey

x = jnp.empty((4, 28, 28, 1)) 

x.reshape((x.shape[0], -1)).shape

class MLP(nn.Module):                              # create a Flax Module dataclass
  """Eight 256-unit ReLU layers with one skip connection and a 1-unit output layer."""

  @nn.compact
  def __call__(self, input_points):
    """Run the network on ``input_points``.

    Args:
      input_points: Array whose last axis holds the input features; the dense
        layers act on that last axis.

    Returns:
      Array with the same leading axes as ``input_points`` and a last axis of
      size 1.
    """
    # Start from the input; reading `x` before assigning it would raise
    # UnboundLocalError.
    x = input_points
    for i in range(8):
      x = nn.Dense(256)(x)                           # create inline Flax Module submodules
      x = nn.relu(x)
      # Skip connection: re-append the raw input after the fifth hidden layer.
      x = jnp.concatenate([x, input_points], axis=-1) if i == 4 else x
    # The output layer and the return belong after the loop so that all eight
    # hidden layers are applied.
    x = nn.Dense(1)(x)
    return x
model = MLP()                           # instantiate the MLP model

x = jnp.empty((4, 28, 28, 1))                      # placeholder input batch of shape (4, 28, 28, 1), not random data
params = model.init(PRNGKey(42), x)["params"]      # initialize the weights
y = model.apply({"params":params}, x)  

In [None]:
y.shape

In [None]:
positional_encoding_dims = 6  # Number of positional encodings applied
import jax
import jax.numpy as jnp

def positional_encoding(args):
    """Append sine/cosine features at doubling frequencies to each row of ``args``.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns, then for each frequency ``2**k`` the sine and the
        cosine of the scaled channels. The intermediate shape is printed.
    """
    num_points, _ = args.shape

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    scaled = jax.vmap(scale_by_octave)(jnp.arange(positional_encoding_dims))
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])
    # Bring the point axis to the front: (2, dims, N, C) -> (N, dims, 2, C).
    waves = jnp.swapaxes(waves, 0, 2)
    print(waves.shape)
    features = waves.reshape([num_points, -1])
    return jnp.concatenate([args, features], axis=-1)

x = jnp.empty((4, 28, 28, 1))
positional_encoding(x.reshape(x.shape[0],-1)).shape

In [None]:
import glob, os
import tensorflow as tf
import numpy as np

PATH = '/content/biy/'
BATCH_SIZE = 12
IMAGE_SIZE = 140

def read_train_data():
    """Build a shuffled, batched ``tf.data`` pipeline of (input, annotation) pairs.

    Reads the module-level ``PATH``, ``BATCH_SIZE`` and ``IMAGE_SIZE``.

    Returns:
        A ``tf.data.Dataset`` of ``(x, y)`` batches. Each picture is decoded
        with 3 channels, resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and divided
        by 255.

    Raises:
        ValueError: If the two folders hold a different number of PNG files.
    """
    # glob() returns names in arbitrary order, so both lists are sorted to make
    # entry i of one list line up with entry i of the other.
    # NOTE(review): this assumes matching files sort to the same position in
    # both folders - confirm.
    x_files = sorted(glob.glob(PATH + "MEDIAPIPEinput/*.png"))
    y_files = sorted(glob.glob(PATH + "annotated_images/*.png"))
    if len(x_files) != len(y_files):
        raise ValueError(
            f"found {len(x_files)} input images but {len(y_files)} annotated images")

    def load_image(filename):
        # The files are PNGs, so the PNG decoder is used.
        decoded = tf.image.decode_png(tf.io.read_file(filename), channels=3)
        resized = tf.image.resize(decoded, [IMAGE_SIZE, IMAGE_SIZE])
        return resized / 255

    def read_image(x_filename, y_filename):
        return load_image(x_filename), load_image(y_filename)

    dataset = tf.data.Dataset.from_tensor_slices((x_files, y_files))

    dataset = dataset.map(read_image).shuffle(1000).batch(BATCH_SIZE)

    return dataset

train_set = read_train_data()
for x, y in train_set.as_numpy_iterator():
  print(x.shape, y.shape)

In [None]:
!python -m pip install -q -U flax
import functools
from flax.training.train_state import TrainState

def dice_coef(y_true, y_pred):
    """Dice coefficient of two arrays, with 1 added to the denominator.

    Both inputs are flattened first; the result is
    ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + 1)``.
    """
    flat_true, flat_pred = (jnp.ravel(values) for values in (y_true, y_pred))
    overlap = jnp.sum(flat_true * flat_pred)
    total = jnp.sum(flat_true) + jnp.sum(flat_pred) + 1
    return 2.0 * overlap / total


def dice_coef_loss(y_true, y_pred):
    """Loss form of the Dice coefficient: ``1 - dice_coef``."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score


class CustomTrainState(TrainState):
    """TrainState with a wrapper that calls ``apply_fn`` with a dropout RNG."""

    def apply_fn_with_bn(self, *args, is_training, **nargs):
        """Call ``apply_fn`` with the given arguments and a fixed dropout key.

        ``is_training`` is a required keyword argument, but it is not forwarded
        to ``apply_fn``. The dropout key is always ``PRNGKey(2)``.
        """
        dropout_key = jax.random.PRNGKey(2)
        return self.apply_fn(*args, rngs={'dropout': dropout_key}, **nargs)

@functools.partial(jax.jit, static_argnums=(3,))
def train_step(x, y, train_state, is_training=True):
    """Compute the Dice loss for one batch and, when training, apply one update.

    Args:
        x: Input batch passed to the model.
        y: Target batch compared against the prediction.
        train_state: State object providing ``params``, ``apply_fn_with_bn``
            and ``apply_gradients``.
        is_training: Static flag. When true the gradients are computed and
            applied; otherwise only the loss is evaluated.

    Returns:
        ``(loss, train_state)``; the state is the updated one when training.
    """
    def loss_fn(params, is_training):
        y_pred = train_state.apply_fn_with_bn({"params": params}, x, is_training=is_training)
        return dice_coef_loss(y, y_pred)

    if is_training:
        # loss_fn returns a bare scalar, so value_and_grad must not be asked
        # for an auxiliary output (has_aux=True expects a (loss, aux) pair).
        loss, grads = jax.value_and_grad(loss_fn)(train_state.params, True)
        train_state = train_state.apply_gradients(grads=grads)
    else:
        loss = loss_fn(train_state.params, False)

    return loss, train_state

In [None]:
import optax
# MLP declares no constructor fields, so it is created without arguments;
# passing out_dims raises a TypeError.
unet = MLP()

init_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}

# Initialise the parameters with an all-ones batch holding a single image.
unet_variables = unet.init(init_rngs, jnp.ones([1, IMAGE_SIZE, IMAGE_SIZE, 3]))

optimizer = optax.adam(learning_rate=0.001)

train_state = CustomTrainState.create(apply_fn=unet.apply, params=unet_variables["params"], tx=optimizer)


# 20 passes over the dataset; the loss of every step is printed.
for e in range(20):
        for x, y in train_set.as_numpy_iterator():
            loss, train_state = train_step(x, y, train_state, True)
            print(f"epoch: {e}, loss: {loss:0.2f}")

In [None]:
x = jnp.empty((4, 28, 28, 3))
c = x[2]
c.shape
c = c.reshape(-1, c.shape[2])
p = positional_encoding(c)
print(p.shape)

In [None]:
import jax
import jax.numpy as jnp
x = jnp.empty((4, 28, 28, 1))
print(x.shape)

def positional_encoding(args):
    """Append sine/cosine features at six doubling frequencies to each row.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, 13 * channels)``: the original columns,
        then for each frequency ``2**k`` (k = 0..5) the sine and the cosine of
        the scaled channels. The input shape and the intermediate shape are
        printed.
    """
    print(args.shape)
    num_points, _ = args.shape

    def scale_by_octave(octave):
        return args * 2.0 ** octave

    scaled = jax.vmap(scale_by_octave)(jnp.arange(6))
    waves = jnp.stack([jnp.sin(scaled), jnp.cos(scaled)])
    # Bring the point axis to the front: (2, 6, N, C) -> (N, 6, 2, C).
    waves = jnp.swapaxes(waves, 0, 2)
    print(waves.shape)
    features = waves.reshape([num_points, -1])
    return jnp.concatenate([args, features], axis=-1)

    
img_list = []
for i in range(x.shape[0]):
  print(i)
  print(x[i].shape)
  c = x[i]
  c.shape
  c = c.reshape(-1, c.shape[2])
  p = positional_encoding(c)
  img_list.append(p)
  print(p.shape)

In [None]:
print(x.shape, jnp.array(img_list).shape)

## **starting here 🔻**

In [None]:
#✅
!python -m pip install -q -U flax
import flax.linen as nn
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at doubling frequencies to each row of ``args``.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns, then for each frequency ``2**k`` the sine and the
        cosine of the scaled channels.
    """
    num_points, cha = args.shape
    inputs_freq = jax.vmap(lambda x: args * 2.0 ** x)(jnp.arange(positional_encoding_dims))
    x = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])  # (2, dims, N, C)
    x = x.swapaxes(0, 2)                                         # (N, dims, 2, C)
    # Spell out the width instead of using -1 so that an input with zero rows
    # can still be reshaped.
    x = x.reshape([num_points, 2 * positional_encoding_dims * cha])
    x = jnp.concatenate([args, x], axis=-1)
    return x

def batch_encoded(args):
    """Apply ``positional_encoding`` to every pixel of a batch of images.

    Args:
        args: 4-D array of shape ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(batch, height, width, encoded_channels)``. An empty
        batch yields an empty result of the matching shape.
    """
    batch, height, width, channels = args.shape
    # Every row is encoded independently of the others, so the whole batch is
    # flattened and encoded in one call instead of image by image.
    encoded = positional_encoding(args.reshape(batch * height * width, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP with ``ndl`` hidden ReLU layers of width ``dlw`` and a 3-unit output.

    The input is optionally passed through ``batch_encoded`` first, and the raw
    input is concatenated back onto the features after the hidden layer with
    index 4.
    """
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Return the network output for ``input_points``."""
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection back to the un-encoded input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        return nn.Dense(3, dtype=self.dtype, precision=self.precision)(hidden)
##########################################<< MLP MODEL >>#########################################

In [None]:
import optax
unet = MLPModel()

init_rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}
IMAGE_SIZE = 140
unet_variables = unet.init(init_rngs, jnp.ones([7, IMAGE_SIZE, IMAGE_SIZE, 3]))

optimizer = optax.adam(learning_rate=0.001)

from flax.training.train_state import TrainState

class CustomTrainState(TrainState):
    """TrainState with a wrapper that calls ``apply_fn`` with a dropout RNG."""

    def apply_fn_with_bn(self, *args, is_training, **nargs):
        """Call ``apply_fn`` with the given arguments and a fixed dropout key.

        ``is_training`` is a required keyword argument, but it is not forwarded
        to ``apply_fn``. The dropout key is always ``PRNGKey(2)``.
        """
        dropout_key = jax.random.PRNGKey(2)
        return self.apply_fn(*args, rngs={'dropout': dropout_key}, **nargs)

train_state = CustomTrainState.create(apply_fn=unet.apply, params=unet_variables["params"], tx=optimizer)

In [None]:
from jax.tree_util import tree_structure
print(tree_structure(train_state))

In [None]:
# URL of the s.zip release asset.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
# get_file downloads `origin`; extract=True additionally asks Keras to unpack
# the archive. cache_subdir is given as an absolute path here.
# NOTE(review): where the file named 's' and the extracted folders end up on
# disk depends on the installed Keras version - confirm before using data_dir.
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# Wrap the returned location in a Path object.
data_dir = pathlib.Path(data_dir)

In [None]:
import functools
import glob
import tensorflow as tf

@functools.partial(jax.jit, static_argnums=(3,))
def train_step(x, y, train_state, is_training=True):
    """Compute the Dice loss for one batch and, when training, apply one update.

    Args:
        x: Input batch passed to the model.
        y: Target batch compared against the prediction.
        train_state: State object providing ``params``, ``apply_fn_with_bn``
            and ``apply_gradients``.
        is_training: Static flag. When true the gradients are computed and
            applied; otherwise only the loss is evaluated.

    Returns:
        ``(loss, train_state)``; the state is the updated one when training.
    """
    def loss_fn(params, is_training):
        y_pred = train_state.apply_fn_with_bn({"params": params}, x, is_training=is_training)
        return dice_coef_loss(y, y_pred)

    if is_training:
        # loss_fn returns a bare scalar, so value_and_grad must not be asked
        # for an auxiliary output (has_aux=True expects a (loss, aux) pair).
        loss, grads = jax.value_and_grad(loss_fn)(train_state.params, True)
        train_state = train_state.apply_gradients(grads=grads)
    else:
        loss = loss_fn(train_state.params, False)

    return loss, train_state

PATH = '/content/biy/'
BATCH_SIZE = 12
IMAGE_SIZE = 140

def read_train_data():
    """Build a shuffled, batched ``tf.data`` pipeline of (input, annotation) pairs.

    Reads the module-level ``PATH``, ``BATCH_SIZE`` and ``IMAGE_SIZE``.

    Returns:
        A ``tf.data.Dataset`` of ``(x, y)`` batches. Each picture is decoded
        as a 3-channel PNG, resized to ``IMAGE_SIZE`` x ``IMAGE_SIZE`` and
        divided by 255.

    Raises:
        ValueError: If the two folders hold a different number of PNG files.
    """
    # glob() returns names in arbitrary order, so both lists are sorted to make
    # entry i of one list line up with entry i of the other.
    # NOTE(review): this assumes matching files sort to the same position in
    # both folders - confirm.
    x_files = sorted(glob.glob(PATH + "MEDIAPIPEinput/*.png"))
    y_files = sorted(glob.glob(PATH + "annotated_images/*.png"))
    if len(x_files) != len(y_files):
        raise ValueError(
            f"found {len(x_files)} input images but {len(y_files)} annotated images")

    def load_image(filename):
        decoded = tf.image.decode_png(tf.io.read_file(filename), channels=3)
        resized = tf.image.resize(decoded, [IMAGE_SIZE, IMAGE_SIZE])
        return resized / 255

    def read_image(x_filename, y_filename):
        return load_image(x_filename), load_image(y_filename)

    dataset = tf.data.Dataset.from_tensor_slices((x_files, y_files))

    dataset = dataset.map(read_image).shuffle(1000).batch(BATCH_SIZE)

    return dataset

import pathlib
import tensorflow as tf





In [None]:
train_set = read_train_data()


def dice_coef(y_true, y_pred):
    """Dice coefficient of two arrays, with 1 added to the denominator.

    Both inputs are flattened first; the result is
    ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred) + 1)``.
    """
    flat_true, flat_pred = (jnp.ravel(values) for values in (y_true, y_pred))
    overlap = jnp.sum(flat_true * flat_pred)
    total = jnp.sum(flat_true) + jnp.sum(flat_pred) + 1
    return 2.0 * overlap / total


def dice_coef_loss(y_true, y_pred):
    """Loss form of the Dice coefficient: ``1 - dice_coef``."""
    score = dice_coef(y_true, y_pred)
    return 1.0 - score

for e in range(20):
        loss_avg = 0
        for x, y in train_set.as_numpy_iterator():
            loss, train_state = train_step(x, y, train_state, True)
            print(f"epoch: {e}, loss: {loss:0.2f}")

## **ensemble test**

In [None]:
# URL of the s.zip release asset.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
# get_file downloads `origin`; extract=True additionally asks Keras to unpack
# the archive. cache_subdir is given as an absolute path here.
# NOTE(review): where the file named 's' and the extracted folders end up on
# disk depends on the installed Keras version - confirm before using data_dir.
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/biy',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# Wrap the returned location in a Path object.
data_dir = pathlib.Path(data_dir)

In [None]:
#✅
!python -m pip install -q -U flax
from typing import Any
import jax
from jax import lax
import jax.numpy as jnp
import optax
import flax
import flax.linen as nn
from flax.training import train_state, common_utils
import functools

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at doubling frequencies to each row of ``args``.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns, then for each frequency ``2**k`` the sine and the
        cosine of the scaled channels.
    """
    num_points, cha = args.shape
    inputs_freq = jax.vmap(lambda x: args * 2.0 ** x)(jnp.arange(positional_encoding_dims))
    x = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])  # (2, dims, N, C)
    x = x.swapaxes(0, 2)                                         # (N, dims, 2, C)
    # Spell out the width instead of using -1 so that an input with zero rows
    # can still be reshaped.
    x = x.reshape([num_points, 2 * positional_encoding_dims * cha])
    x = jnp.concatenate([args, x], axis=-1)
    return x

def batch_encoded(args):
    """Apply ``positional_encoding`` to every pixel of a batch of images.

    Args:
        args: 4-D array of shape ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(batch, height, width, encoded_channels)``. An empty
        batch yields an empty result of the matching shape.
    """
    batch, height, width, channels = args.shape
    # Every row is encoded independently of the others, so the whole batch is
    # flattened and encoded in one call instead of image by image.
    encoded = positional_encoding(args.reshape(batch * height * width, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP with ``ndl`` hidden ReLU layers of width ``dlw`` and a 3-unit output.

    The input is optionally passed through ``batch_encoded`` first, and the raw
    input is concatenated back onto the features after the hidden layer with
    index 4.
    """
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Return the network output for ``input_points``."""
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection back to the un-encoded input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        return nn.Dense(3, dtype=self.dtype, precision=self.precision)(hidden)
##########################################<< MLP MODEL >>#########################################

@functools.partial(jax.pmap, static_broadcasted_argnums=(1, 2))
def Create_train_state(r_key, shape, learning_rate ):
    """Initialise an ``MLPModel`` and wrap it in a ``TrainState`` with Adam.

    Args:
        r_key: PRNG key used for parameter initialisation (mapped by ``pmap``).
        shape: Shape of the all-ones array used for initialisation (static).
        learning_rate: Adam learning rate (static).

    Returns:
        ``TrainState`` holding the model's ``apply`` function, its initial
        parameters and the optimizer. ``shape`` is printed as a side effect.
    """
    print(shape)
    network = MLPModel()
    dummy_batch = jnp.ones(shape)
    initial_params = network.init(r_key, dummy_batch)['params']
    return train_state.TrainState.create(
        apply_fn=network.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )

# learning_rate = 1e-4
# batch_size_no = 64

# model = MLPModel() # Instantiate the Model

In [None]:
model = MLPModel()

rng = jax.random.PRNGKey(0)

rn_g, init_rng = jax.random.split(rng)
model.init(rng, jnp.ones(next(data_stream())[0].shape)) 

In [None]:
@functools.partial(jax.pmap, axis_name='batch')
def apply_model(state, batch: jnp.asarray):
  """Compute loss and gradients for one per-device batch.

  Args:
    state: Replicated train state; only ``state.params`` is read.
    batch: Pair ``(image, label)`` holding this device's share of the data.

  Returns:
    ``(grads, loss)``, both averaged over the ``'batch'`` device axis.
  """
  image, label = batch
  def loss_fn(params):
    logits = MLPModel().apply({'params': params}, image)
    loss = image_difference_loss(logits, label)
    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  # Average across devices so every replica applies the same update and the
  # replicated parameters stay in sync.
  grads = jax.lax.pmean(grads, axis_name='batch')
  loss = jax.lax.pmean(loss, axis_name='batch')
  return grads, loss

@jax.pmap
def update_model(state, grads):
  """Apply ``grads`` to ``state`` on every device and return the new state."""
  updated_state = state.apply_gradients(grads=grads)
  return updated_state

In [None]:
def train_epoch(state, train_ds, rng):
  """Run one epoch of mini-batch updates over a random permutation of the data.

  Args:
    state: Replicated train state passed to ``apply_model``/``update_model``.
    train_ds: Array, or pytree of arrays such as ``(images, labels)``, whose
      leaves share the same leading example axis.
    rng: PRNG key used to shuffle the example order.

  Returns:
    ``(state, train_loss)``: the updated state and the mean of the step losses.
    Reads the module-level ``batch_size``; examples that do not fill a whole
    batch are dropped.
  """
  train_ds_size = jax.tree_util.tree_leaves(train_ds)[0].shape[0]
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]
  perms = perms.reshape((steps_per_epoch, batch_size))

  epoch_loss = []

  for perm in perms:
    # Select this step's examples from every leaf and hand that mini-batch
    # (not the whole dataset) to the model.
    minibatch = jax.tree_util.tree_map(lambda leaf: leaf[perm, ...], train_ds)
    minibatch = flax.jax_utils.replicate(minibatch)
    # NOTE(review): apply_model unpacks its batch as (image, label), so
    # train_ds presumably needs to be an (images, labels) pair - confirm.
    grads, loss = apply_model(state, minibatch)
    state = update_model(state, grads)
    epoch_loss.append(flax.jax_utils.unreplicate(loss))
  train_loss = np.mean(epoch_loss)
  return state, train_loss

In [None]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 10

import os
from PIL import Image
import jax.numpy as jnp

def imageGRAY(argv):
    """Load an image as single-channel grayscale, resized to ``newsize``.

    Args:
        argv: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(image, pixels)``: the resized picture as an array, and the same
        values reshaped to ``(-1, 1)``.
    """
    # The context manager closes the file once the pixels have been read.
    with Image.open(argv) as im:
        resized = jnp.asarray(im.convert('L').resize(newsize))
    return resized, resized.reshape(-1, 1)
def imageRGB(argv):
    """Load an image as 3-channel RGB, resized to ``newsize``.

    Args:
        argv: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(image, pixels)``: the resized picture as an array, and the same
        values reshaped to ``(-1, 3)``.
    """
    with Image.open(argv) as im:
        # Force three channels; for files stored as RGBA, palette or grayscale,
        # reshape(-1, 3) would otherwise fail or mix up the channels.
        resized = jnp.asarray(im.convert('RGB').resize(newsize))
    return resized, resized.reshape(-1, 3)

x_image_dir = r'/content/biy/MEDIAPIPEinput/'
y_image_dir = r'/content/biy/annotated_images/'

#############################################################################
# Only file names containing the second entry of `bandend` are kept.
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]

# Sorted names of the matching input files, then their full paths.
x_total_images = sorted(name for name in os.listdir(x_image_dir) if expression_b2 in name)
x_total_images_path = [os.path.join(x_image_dir, name) for name in x_total_images if name != 'outputs']
# Number of complete batches that fit into the input list.
no_of_batches = len(x_total_images_path) // batch_size

# Same listing for the annotated pictures.
y_total_images = sorted(name for name in os.listdir(y_image_dir) if expression_b2 in name)
y_total_images_path = [os.path.join(y_image_dir, name) for name in y_total_images if name != 'outputs']


######################################## making 8 array of input for each device >>>
def batchedimages(total_images_path, image_locations):
  """Load one RGB image selected by the first entry of ``image_locations``.

  Only ``image_locations[0]`` is read; the remaining indices are ignored.
  NOTE(review): the name suggests a whole batch was meant - confirm.

  Returns:
    The first element of what ``imageRGB`` returns for that path, as a JAX array.
  """
  first_index = image_locations[0]
  chosen_path = total_images_path[first_index]
  picture = imageRGB(chosen_path)[0]
  return jnp.asarray(picture)


def data_stream():
  """Yield a single ``(x_images, y_images)`` pair of stacked arrays.

  The image indices are permuted with a fixed key (``PRNGKey(0)``), so every
  call produces the same order. For each of the ``no_of_batches`` index
  windows, one image is loaded per folder through ``batchedimages``.
  """
  key = random.PRNGKey(0)
  # Permute over the input list; `total_images_path` is not defined anywhere.
  perm = random.permutation(key, len(x_total_images_path))
  x_image_list = []
  y_image_list = []

  for i in range(no_of_batches):
    batch_idx = perm[i * batch_size : (i + 1) * batch_size]; #print(batch_idx)
    # The same indices are used for both folders so inputs and annotations
    # stay paired.
    x_image_list.append(batchedimages(x_total_images_path, batch_idx))
    y_image_list.append(batchedimages(y_total_images_path, batch_idx))
  yield jnp.array(x_image_list),jnp.array(y_image_list)

batches = data_stream()  


In [None]:
# data_stream() yields one (x_images, y_images) pair, so `train_ds` receives
# the input images and `test_ds` the annotated images.
train_ds, test_ds = next(data_stream())
# test_ds = jax_utils.replicate(test_ds)
rng = jax.random.PRNGKey(0)

rn_g, init_rng = jax.random.split(rng)

# Shape of the stacked input images, read from a fresh data_stream() generator.
BATCH, H, W, Channels = next(data_stream())[0].shape
learning_rate = 1e-4

# state = Create_train_state(jax.random.split(init_rng, jax.device_count()),(BATCH, H, W, Channels),learning_rate)
# One key is passed (split into 1), i.e. the state is created for a single device.
state = Create_train_state(jax.random.split(init_rng, 1),(next(data_stream())[0].shape),learning_rate)
num_epochs = 3
for epoch in range(1, num_epochs + 1):
  # Fresh shuffling key for every epoch.
  rng, input_rng = jax.random.split(rng)
  state, train_loss = train_epoch(state, train_ds, input_rng)

  # _, test_loss = jax_utils.unreplicate(apply_model(state, test_ds['image'], test_ds['label']))

  # NOTE(review): `logging` is not imported in the cells visible here - confirm
  # it is imported before this cell runs.
  logging.info('epoch:% 3d, train_loss: %.4f ' % (epoch, train_loss))

In [None]:
# same as before, but using @pad_shard_unshard decorator

# manually padding
# => precise & allows for data parallelism

@jax.pmap
def get_preds(variables, inputs):
  """Apply the module-level ``model`` to ``inputs`` on every device.

  The per-device input shape is printed whenever the function body is executed
  by Python.
  """
  per_device_shape = inputs.shape
  print('retrigger compilation', per_device_shape)
  predictions = model.apply(variables, inputs)
  return predictions

ds = tfds.load(dataset_name, split=tfds.split_for_jax_process('test'))
per_host_batch_size = per_device_batch_size * jax.local_device_count()
ds = ds.batch(per_host_batch_size, drop_remainder=False)

correct = total = 0
for batch in ds.as_numpy_iterator():
  preds = flax.jax_utils.pad_shard_unpad(get_preds)(
      vs, batch['image'], min_device_batch=per_device_batch_size)
  total += len(batch['image'])
  correct += (batch['label'] == preds.argmax(axis=-1)).sum()

correct = correct.item()
correct, total, correct / total

In [None]:
x_train_ds, y_train_ds = next(data_stream())
print(x_train_ds.shape, y_train_ds.shape)

In [None]:
x_train_ds[0].shape


In [None]:
import cv2
from google.colab.patches import cv2_imshow
import numpy as onp
img = onp.array(x_train_ds[0])
cv2_imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

In [None]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = MLPModel()
x = jnp.ones((2,100, 100, 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx,)
state = jax_utils.replicate(state)

def loss_fn(state, x):
    """Mean of the squared outputs of the module-level ``model`` for ``x``."""
    outputs = model.apply(state.params, x)
    squared = outputs ** 2.0
    return squared.mean()

jax.pmap(loss_fn)(state, x)

Testing the model with `pmap` (bug reproduction)

In [None]:

positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at doubling frequencies to each row of ``args``.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns, then for each frequency ``2**k`` the sine and the
        cosine of the scaled channels.
    """
    num_points, cha = args.shape
    inputs_freq = jax.vmap(lambda x: args * 2.0 ** x)(jnp.arange(positional_encoding_dims))
    x = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])  # (2, dims, N, C)
    x = x.swapaxes(0, 2)                                         # (N, dims, 2, C)
    # Spell out the width instead of using -1 so that an input with zero rows
    # can still be reshaped.
    x = x.reshape([num_points, 2 * positional_encoding_dims * cha])
    x = jnp.concatenate([args, x], axis=-1)
    return x

def batch_encoded(args):
    """Apply ``positional_encoding`` to every pixel of a batch of images.

    Args:
        args: 4-D array of shape ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(batch, height, width, encoded_channels)``. An empty
        batch yields an empty result of the matching shape.
    """
    batch, height, width, channels = args.shape
    # Every row is encoded independently of the others, so the whole batch is
    # flattened and encoded in one call instead of image by image.
    encoded = positional_encoding(args.reshape(batch * height * width, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP with ``ndl`` hidden ReLU layers of width ``dlw`` and a 1-unit output.

    The input is optionally passed through ``batch_encoded`` first, and the raw
    input is concatenated back onto the features after the hidden layer with
    index 4.
    """
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Return the network output for ``input_points``."""
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection back to the un-encoded input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        return nn.Dense(1, dtype=self.dtype, precision=self.precision)(hidden)
##########################################<< MLP MODEL >>#########################################

In [None]:
from jax.tree_util import tree_structure
print(tree_structure(state.params))

## **RUN 2 testing** 

In [1]:
import tensorflow as tf
import pathlib
# URL of the s.zip release asset.
urls = 'https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip'
# get_file downloads `origin`; extract=True additionally asks Keras to unpack
# the archive. cache_subdir is given as the absolute path /content/ here.
# NOTE(review): where the file named 's' and the extracted folders end up on
# disk depends on the installed Keras version - confirm before using data_dir.
data_dir = tf.keras.utils.get_file(origin=urls,
                                   fname='s',
                                   cache_subdir='/content/',
                                   archive_format='auto',
                                   untar=False,
                                   extract=True)
# Wrap the returned location in a Path object.
data_dir = pathlib.Path(data_dir)

Downloading data from https://github.com/1kaiser/Media-Segment-Depth-MLP/releases/download/v0.2/s.zip


**Model and training code**
Our model is a coordinate-based multilayer perceptron. In this example, for each input image coordinate $(x,y)$, the model predicts the associated color $(r,g,b)$ or a single grayscale value $(gray)$.

![Network diagram](https://user-images.githubusercontent.com/3310961/85066930-ad444580-b164-11ea-9cc0-17494679e71f.png)

**POSITIONAL ENCODING BLOCK** 

In [2]:
#✅
import jax
import jax.numpy as jnp


positional_encoding_dims = 6  # Number of positional encodings applied

def positional_encoding(args):
    """Append sine/cosine features at doubling frequencies to each row of ``args``.

    Args:
        args: 2-D array of shape ``(num_points, channels)``.

    Returns:
        Array of shape ``(num_points, channels * (1 + 2 * positional_encoding_dims))``:
        the original columns, then for each frequency ``2**k`` the sine and the
        cosine of the scaled channels.
    """
    num_points, cha = args.shape
    inputs_freq = jax.vmap(lambda x: args * 2.0 ** x)(jnp.arange(positional_encoding_dims))
    x = jnp.stack([jnp.sin(inputs_freq), jnp.cos(inputs_freq)])  # (2, dims, N, C)
    x = x.swapaxes(0, 2)                                         # (N, dims, 2, C)
    # Spell out the width instead of using -1 so that an input with zero rows
    # can still be reshaped.
    x = x.reshape([num_points, 2 * positional_encoding_dims * cha])
    x = jnp.concatenate([args, x], axis=-1)
    return x

def batch_encoded(args):
    """Apply ``positional_encoding`` to every pixel of a batch of images.

    Args:
        args: 4-D array of shape ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(batch, height, width, encoded_channels)``. An empty
        batch yields an empty result of the matching shape.
    """
    batch, height, width, channels = args.shape
    # Every row is encoded independently of the others, so the whole batch is
    # flattened and encoded in one call instead of image by image.
    encoded = positional_encoding(args.reshape(batch * height * width, channels))
    return encoded.reshape(batch, height, width, encoded.shape[-1])



**MLP MODEL DEFINITION**
Basically, passing input points through a simple Fourier Feature Mapping enables an MLP to learn high-frequency functions (such as an RGB image) in low-dimensional problem domains (such as a 2D coordinate of pixels).

In [3]:
#✅
!python -m pip install -qq -U flax orbax
# Orbax needs to enable asyncio in a Colab environment.
!python -m pip install -qq nest_asyncio


import jax
import jax.numpy as jnp

import flax
import optax
from typing import Any

from jax import lax
import flax.linen as nn
from flax.training import train_state, common_utils

apply_positional_encoding = True # Apply posittional encoding to the input or not
ndl = 8 # num_dense_layers Number of dense layers in MLP
dlw = 256 # dense_layer_width Dimentionality of dense layers' output space 

##########################################<< MLP MODEL >>#########################################
class MLPModel(nn.Module):
    """MLP with ``ndl`` hidden ReLU layers of width ``dlw`` and a 1-unit output.

    The input is optionally passed through ``batch_encoded`` first, and the raw
    input is concatenated back onto the features after the hidden layer with
    index 4.
    """
    dtype: Any = jnp.float32
    precision: Any = lax.Precision.DEFAULT
    apply_positional_encoding: bool = apply_positional_encoding

    @nn.compact
    def __call__(self, input_points):
        """Return the network output for ``input_points``."""
        if self.apply_positional_encoding:
            hidden = batch_encoded(input_points)
        else:
            hidden = input_points

        for layer_idx in range(ndl):
            dense = nn.Dense(dlw, dtype=self.dtype, precision=self.precision)
            hidden = nn.relu(dense(hidden))
            if layer_idx == 4:
                # Skip connection back to the un-encoded input.
                hidden = jnp.concatenate([hidden, input_points], axis=-1)

        return nn.Dense(1, dtype=self.dtype, precision=self.precision)(hidden)
##########################################<< MLP MODEL >>#########################################

[K     |████████████████████████████████| 197 kB 27.6 MB/s 
[K     |████████████████████████████████| 66 kB 5.7 MB/s 
[K     |████████████████████████████████| 154 kB 54.4 MB/s 
[K     |████████████████████████████████| 238 kB 72.7 MB/s 
[K     |████████████████████████████████| 8.3 MB 52.3 MB/s 
[K     |████████████████████████████████| 51 kB 5.0 MB/s 
[K     |████████████████████████████████| 85 kB 4.9 MB/s 
[?25h

**initialize the module**

In [4]:
#✅
!python -m pip install -q -U flax
import optax
from flax.training import train_state
import jax.numpy as jnp
import jax


def Create_train_state(r_key, model, shape, learning_rate ) -> train_state.TrainState:
    """Initialise ``model`` and wrap it in a ``TrainState`` with an Adam optimizer.

    Args:
        r_key: PRNG key used for parameter initialisation.
        model: Module providing ``init`` and ``apply``.
        shape: Shape of the all-ones array used for initialisation.
        learning_rate: Adam learning rate.

    Returns:
        ``TrainState`` holding ``model.apply``, the initial parameters and the
        optimizer. ``shape`` is printed as a side effect.
    """
    print(shape)
    dummy_batch = jnp.ones(shape)
    initial_params = model.init(r_key, dummy_batch)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=initial_params,
        tx=optax.adam(learning_rate),
    )

learning_rate = 1e-4
batch_size_no = 64

model = MLPModel() # Instantiate the Model

**defining loss function**

In [5]:
#serial
def image_difference_loss(logits, labels):
    """Half of the mean squared difference between ``logits`` and ``labels``."""
    residual = logits - labels
    return .5 * jnp.mean(residual ** 2)
def compute_metrics(*, logits, labels):
  """Bundle the loss with the prediction and the target.

  Returns:
    Dict with the keys ``'loss'`` (from ``image_difference_loss``),
    ``'logits'`` (predicted image) and ``'labels'`` (actual image).
  """
  return {
      'loss': image_difference_loss(logits, labels),  # LOSS
      'logits': logits,                               # PREDICTED IMAGE
      'labels': labels,                               # ACTUAL IMAGE
  }

**train step definition**

In [6]:
#cpu serial
import jax

def train_step(state: train_state.TrainState, batch: jnp.asarray, rng):
    """Run one gradient update on ``batch`` and report the metrics.

    Args:
        state: Current train state.
        batch: Pair ``(image, label)``.
        rng: Accepted but not used by the body.

    Returns:
        ``(new_state, logs)`` where ``logs`` is the dict built by
        ``compute_metrics`` from the prediction made before the update.
    """
    inputs, targets = batch

    def loss_and_prediction(params):
        prediction = state.apply_fn({'params': params}, inputs)
        return image_difference_loss(prediction, targets), prediction

    # has_aux=True: the prediction is returned alongside the loss.
    (_, prediction), gradients = jax.value_and_grad(
        loss_and_prediction, has_aux=True)(state.params)
    updated_state = state.apply_gradients(grads=gradients)
    return updated_state, compute_metrics(logits=prediction, labels=targets)

import jax
@jax.jit
def eval_step(state, image):
    """Run the model on ``image`` and return the metrics dict.

    NOTE(review): the input image itself is passed as ``labels``, so the loss
    measures the difference to the input rather than to an annotation - confirm
    that this is intended.
    """
    prediction = state.apply_fn({'params': state.params}, image)
    metrics = compute_metrics(logits=prediction, labels=image)
    return metrics


**image conversion functions**

In [7]:
from PIL import Image
import jax.numpy as jnp
def imageGRAY(argv):
    """Load an image as single-channel grayscale, resized to ``newsize``.

    Args:
        argv: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(image, pixels)``: the resized picture as an array, and the same
        values reshaped to ``(-1, 1)``.
    """
    # The context manager closes the file once the pixels have been read.
    with Image.open(argv) as im:
        resized = jnp.asarray(im.convert('L').resize(newsize))
    return resized, resized.reshape(-1, 1)
def imageRGB(argv):
    """Load an image as 3-channel RGB, resized to ``newsize``.

    Args:
        argv: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``(image, pixels)``: the resized picture as an array, and the same
        values reshaped to ``(-1, 3)``.
    """
    with Image.open(argv) as im:
        # Force three channels; for files stored as RGBA, palette or grayscale,
        # reshape(-1, 3) would otherwise fail or mix up the channels.
        resized = jnp.asarray(im.convert('RGB').resize(newsize))
    return resized, resized.reshape(-1, 3)

**image dataset, image size and batch size Setup**

In [8]:
########################################## to load the data in batches as mentioned single batch of images with already provided sizes 
import jax
from jax import random
import jax.numpy as jnp

newsize = (140,140) #(260, 260) # /.... 233 * 454
batch_size = 40

import os
x_image_dir = r'/content/MEDIAPIPEinput/'
y_image_dir = r'/content/annotated_images/'
#############################################################################
# Only file names containing the second entry of `bandend` are kept.
bandend = ["c",".png", "b02"]
expression_b2 = bandend[1]

# Sorted names of the matching input files, then their full paths.
x_total_images = sorted(name for name in os.listdir(x_image_dir) if expression_b2 in name)
x_total_images_path = [os.path.join(x_image_dir, name) for name in x_total_images if name != 'outputs']
# Number of complete batches that fit into the input list.
no_of_batches = len(x_total_images_path) // batch_size

# Same listing for the annotated pictures.
y_total_images = sorted(name for name in os.listdir(y_image_dir) if expression_b2 in name)
y_total_images_path = [os.path.join(y_image_dir, name) for name in y_total_images if name != 'outputs']
######################################## making 8 array of input for each device >>>
def batchedimages(total_images_path, image_locations):
  """Load one RGB image selected by the first entry of ``image_locations``.

  Only ``image_locations[0]`` is read; the remaining indices are ignored.
  NOTE(review): the name suggests a whole batch was meant - confirm.

  Returns:
    The first element of what ``imageRGB`` returns for that path, as a JAX array.
  """
  first_index = image_locations[0]
  chosen_path = total_images_path[first_index]
  picture = imageRGB(chosen_path)[0]
  return jnp.asarray(picture)

def data_stream():
  """Yield a single ``(x_images, y_images)`` pair of stacked arrays.

  The image indices are permuted with a fixed key (``PRNGKey(0)``), so every
  call produces the same order. For each of the ``no_of_batches`` index
  windows, one image is loaded per folder through ``batchedimages``.
  """
  shuffled = random.permutation(random.PRNGKey(0), len(x_total_images_path))
  windows = (
      shuffled[start : start + batch_size]
      for start in range(0, no_of_batches * batch_size, batch_size)
  )
  # Load input and annotation for the same window together so they stay paired.
  pairs = [
      (batchedimages(x_total_images_path, window),
       batchedimages(y_total_images_path, window))
      for window in windows
  ]
  x_stack = jnp.array([x_picture for x_picture, _ in pairs])
  y_stack = jnp.array([y_picture for _, y_picture in pairs])
  yield x_stack, y_stack



working here 🔻

In [None]:
batches = data_stream()  
next(batches)[0].shape

In [None]:
import numpy as onp
onp.array(jnp.array(x_img_list))


In [None]:
#@title # **👠HIGH HEELS RUN >>>>>>>>>>>** { vertical-output: true }
newsize = (140,140) #(260, 260) # /.... 233 * 454

import os
import re
import shutil

import jax
from jax import random
from tqdm import tqdm
from google.colab import output
import orbax.checkpoint as orbax
from flax.training import checkpoints

import optax
import nest_asyncio
nest_asyncio.apply()

rng = jax.random.PRNGKey(0)
CKPT_DIR = 'ckpts'

######################<<<< initiating train state
# The shape of the stacked x array from the data stream is the input shape the
# train state is created for.
batches = data_stream()
BATCH, H, W, Channels = next(batches)[0].shape
state = Create_train_state( rng, model, (BATCH, H, W, Channels ), learning_rate )
count = 1  # the original cell left this flag set to 1; kept in case another cell reads it
#✅✅🔻 state = flax.jax_utils.replicate(state)  # FLAX will replicate the state to every device so that updating can be made easy

######################
# Look for an existing "checkpoint_<number>" entry so training can resume from it.
# NOTE(review): detection uses the absolute path below while restore/save use
# the relative CKPT_DIR; they refer to the same folder only when the working
# directory is /content -- confirm.
checkpoint_available = 0
pattern = re.compile(r"checkpoint_\d+")   # raw string: "\d+" matches a numerical value of any length
dir = "/content/ckpts/"
isFile = os.path.isdir(dir)
if isFile:
  for filepath in os.listdir(dir):
      if pattern.match(filepath):
          checkpoint_available = 1

# One checkpointer is enough for the whole run; it is reused for every save.
orbax_checkpointer = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

total_epochs = 50
for epochs in tqdm(range(no_of_batches-5)):
  # NOTE(review): data_stream() seeds its permutation with a fixed key, so as
  # written every outer iteration receives the same arrays -- confirm intent.
  batches = data_stream()

  if checkpoint_available:
    state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
    checkpoint_available = 0 # << Flag updated >>> to stop loading the same checkpoint in the next iteration then remove the checkpoint directory
    shutil.rmtree(dir)  # pure-Python replacement for the shell call `rm -r`
  input_data = next(batches)
  for bbb in tqdm(range(total_epochs)):
    # NOTE(review): the same `rng` is passed on every step; if train_step uses
    # it for randomness, a per-step split may be needed -- confirm.
    state, metrics = train_step(state, input_data, rng)
    # output.clear()
    print("loss: ",metrics['loss']," <<< ") # naming of the checkpoint is "checkpoint_*"  where "*" => value of the steps variable, i.e. 'epochs'
  checkpoints.save_checkpoint(ckpt_dir=CKPT_DIR, target=state, step=epochs, prefix='checkpoint_', keep=1, overwrite=False, orbax_checkpointer=orbax_checkpointer)
  # restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state) # using to get the checkpoint loaded , it can be latest one , or if already available as checkpoint in the "CKPT_DIR" directory then take the file from directory then save in >> restored_checkpoints
  ##################################################



(25, 140, 140, 3)


  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:13<11:18, 13.85s/it][A

loss:  28754.871  <<< 



  4%|▍         | 2/50 [00:14<05:02,  6.30s/it][A

loss:  28537.637  <<< 



  6%|▌         | 3/50 [00:15<03:01,  3.87s/it][A

loss:  28328.125  <<< 



  8%|▊         | 4/50 [00:16<02:05,  2.72s/it][A

loss:  28122.154  <<< 



 10%|█         | 5/50 [00:17<01:33,  2.09s/it][A

loss:  27919.215  <<< 



 12%|█▏        | 6/50 [00:18<01:14,  1.70s/it][A

loss:  27715.63  <<< 



 14%|█▍        | 7/50 [00:19<01:02,  1.46s/it][A

loss:  27506.05  <<< 



 16%|█▌        | 8/50 [00:20<00:54,  1.31s/it][A

loss:  27288.76  <<< 



 18%|█▊        | 9/50 [00:21<00:49,  1.20s/it][A

loss:  27062.285  <<< 



 20%|██        | 10/50 [00:22<00:45,  1.13s/it][A

loss:  26827.566  <<< 



 22%|██▏       | 11/50 [00:23<00:41,  1.07s/it][A

loss:  26585.457  <<< 



 24%|██▍       | 12/50 [00:24<00:39,  1.03s/it][A

loss:  26336.986  <<< 



 26%|██▌       | 13/50 [00:25<00:37,  1.01s/it][A

loss:  26079.576  <<< 



 28%|██▊       | 14/50 [00:26<00:36,  1.00s/it][A

loss:  25807.354  <<< 



 30%|███       | 15/50 [00:27<00:34,  1.00it/s][A

loss:  25522.826  <<< 



 32%|███▏      | 16/50 [00:28<00:33,  1.01it/s][A

loss:  25220.564  <<< 



 34%|███▍      | 17/50 [00:29<00:32,  1.01it/s][A

loss:  24896.145  <<< 



 36%|███▌      | 18/50 [00:30<00:31,  1.01it/s][A

loss:  24546.928  <<< 



 38%|███▊      | 19/50 [00:31<00:30,  1.02it/s][A

loss:  24172.424  <<< 



 40%|████      | 20/50 [00:32<00:29,  1.00it/s][A

loss:  23771.852  <<< 



 42%|████▏     | 21/50 [00:33<00:28,  1.01it/s][A

loss:  23342.11  <<< 



 44%|████▍     | 22/50 [00:34<00:27,  1.01it/s][A

loss:  22878.37  <<< 



 46%|████▌     | 23/50 [00:35<00:29,  1.10s/it][A

loss:  22372.623  <<< 



 48%|████▊     | 24/50 [00:36<00:27,  1.07s/it][A

loss:  21823.19  <<< 



 50%|█████     | 25/50 [00:37<00:25,  1.04s/it][A

loss:  21235.846  <<< 



 52%|█████▏    | 26/50 [00:38<00:24,  1.02s/it][A

loss:  20607.326  <<< 



 54%|█████▍    | 27/50 [00:39<00:23,  1.00s/it][A

loss:  19936.164  <<< 



 56%|█████▌    | 28/50 [00:40<00:21,  1.00it/s][A

loss:  19220.477  <<< 



 58%|█████▊    | 29/50 [00:41<00:20,  1.01it/s][A

loss:  18459.773  <<< 



 60%|██████    | 30/50 [00:42<00:19,  1.02it/s][A

loss:  17655.35  <<< 



 62%|██████▏   | 31/50 [00:43<00:18,  1.02it/s][A

loss:  16814.012  <<< 



 64%|██████▍   | 32/50 [00:44<00:17,  1.01it/s][A

loss:  15943.295  <<< 



 66%|██████▌   | 33/50 [00:45<00:16,  1.02it/s][A

loss:  15046.98  <<< 



 68%|██████▊   | 34/50 [00:46<00:15,  1.02it/s][A

loss:  14145.505  <<< 



 70%|███████   | 35/50 [00:47<00:14,  1.02it/s][A

loss:  13263.837  <<< 



 72%|███████▏  | 36/50 [00:48<00:13,  1.02it/s][A

loss:  12431.478  <<< 



 74%|███████▍  | 37/50 [00:49<00:12,  1.02it/s][A

loss:  11685.08  <<< 



 76%|███████▌  | 38/50 [00:50<00:11,  1.02it/s][A

loss:  11067.611  <<< 



 78%|███████▊  | 39/50 [00:51<00:10,  1.02it/s][A

loss:  10627.331  <<< 



 80%|████████  | 40/50 [00:52<00:09,  1.00it/s][A

loss:  10407.995  <<< 



 82%|████████▏ | 41/50 [00:53<00:08,  1.01it/s][A

loss:  10430.742  <<< 



 84%|████████▍ | 42/50 [00:54<00:07,  1.01it/s][A

loss:  10666.375  <<< 



 86%|████████▌ | 43/50 [00:55<00:06,  1.01it/s][A

loss:  11010.676  <<< 



 88%|████████▊ | 44/50 [00:56<00:05,  1.01it/s][A

loss:  11323.494  <<< 



 90%|█████████ | 45/50 [00:57<00:04,  1.01it/s][A

loss:  11497.032  <<< 



 92%|█████████▏| 46/50 [00:58<00:03,  1.01it/s][A

loss:  11497.214  <<< 



 94%|█████████▍| 47/50 [00:59<00:02,  1.01it/s][A

loss:  11352.579  <<< 



 96%|█████████▌| 48/50 [01:00<00:01,  1.01it/s][A

loss:  11120.849  <<< 



 98%|█████████▊| 49/50 [01:01<00:01,  1.00s/it][A

loss:  10865.399  <<< 



100%|██████████| 50/50 [01:02<00:00,  1.25s/it]
  5%|▌         | 1/20 [01:03<20:11, 63.78s/it]

loss:  10632.769  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:48,  1.01it/s][A

loss:  10450.304  <<< 



  4%|▍         | 2/50 [00:01<00:47,  1.01it/s][A

loss:  10329.35  <<< 



  6%|▌         | 3/50 [00:02<00:46,  1.00it/s][A

loss:  10265.895  <<< 



  8%|▊         | 4/50 [00:04<00:46,  1.01s/it][A

loss:  10247.979  <<< 



 10%|█         | 5/50 [00:05<00:45,  1.01s/it][A

loss:  10258.464  <<< 



 12%|█▏        | 6/50 [00:06<00:44,  1.02s/it][A

loss:  10280.949  <<< 



 14%|█▍        | 7/50 [00:07<00:43,  1.01s/it][A

loss:  10301.476  <<< 



 16%|█▌        | 8/50 [00:08<00:42,  1.01s/it][A

loss:  10309.936  <<< 



 18%|█▊        | 9/50 [00:09<00:41,  1.01s/it][A

loss:  10300.92  <<< 



 20%|██        | 10/50 [00:10<00:41,  1.05s/it][A

loss:  10273.928  <<< 



 22%|██▏       | 11/50 [00:11<00:40,  1.04s/it][A

loss:  10231.542  <<< 



 24%|██▍       | 12/50 [00:12<00:38,  1.02s/it][A

loss:  10179.993  <<< 



 26%|██▌       | 13/50 [00:13<00:37,  1.01s/it][A

loss:  10127.582  <<< 



 28%|██▊       | 14/50 [00:14<00:36,  1.00s/it][A

loss:  10084.3125  <<< 



 30%|███       | 15/50 [00:15<00:35,  1.00s/it][A

loss:  10057.676  <<< 



 32%|███▏      | 16/50 [00:16<00:34,  1.01s/it][A

loss:  10049.546  <<< 



 34%|███▍      | 17/50 [00:17<00:33,  1.01s/it][A

loss:  10050.646  <<< 



 36%|███▌      | 18/50 [00:18<00:31,  1.00it/s][A

loss:  10046.13  <<< 



 38%|███▊      | 19/50 [00:19<00:30,  1.01it/s][A

loss:  10027.194  <<< 



 40%|████      | 20/50 [00:20<00:30,  1.00s/it][A

loss:  9994.737  <<< 



 42%|████▏     | 21/50 [00:21<00:33,  1.16s/it][A

loss:  9955.605  <<< 



 44%|████▍     | 22/50 [00:22<00:32,  1.18s/it][A

loss:  9917.939  <<< 



 46%|████▌     | 23/50 [00:24<00:31,  1.17s/it][A

loss:  9886.802  <<< 



 48%|████▊     | 24/50 [00:25<00:28,  1.10s/it][A

loss:  9861.686  <<< 



 50%|█████     | 25/50 [00:25<00:26,  1.07s/it][A

loss:  9840.782  <<< 



 52%|█████▏    | 26/50 [00:26<00:24,  1.04s/it][A

loss:  9820.378  <<< 



 54%|█████▍    | 27/50 [00:27<00:23,  1.01s/it][A

loss:  9797.54  <<< 



 56%|█████▌    | 28/50 [00:28<00:22,  1.01s/it][A

loss:  9772.052  <<< 



 58%|█████▊    | 29/50 [00:29<00:20,  1.00it/s][A

loss:  9744.899  <<< 



 60%|██████    | 30/50 [00:30<00:19,  1.01it/s][A

loss:  9717.051  <<< 



 62%|██████▏   | 31/50 [00:31<00:18,  1.01it/s][A

loss:  9690.079  <<< 



 64%|██████▍   | 32/50 [00:32<00:17,  1.02it/s][A

loss:  9665.462  <<< 



 66%|██████▌   | 33/50 [00:33<00:16,  1.02it/s][A

loss:  9639.849  <<< 



 68%|██████▊   | 34/50 [00:34<00:15,  1.03it/s][A

loss:  9611.937  <<< 



 70%|███████   | 35/50 [00:35<00:14,  1.02it/s][A

loss:  9582.74  <<< 



 72%|███████▏  | 36/50 [00:36<00:13,  1.02it/s][A

loss:  9551.5205  <<< 



 74%|███████▍  | 37/50 [00:37<00:12,  1.02it/s][A

loss:  9519.931  <<< 



 76%|███████▌  | 38/50 [00:38<00:11,  1.02it/s][A

loss:  9489.602  <<< 



 78%|███████▊  | 39/50 [00:39<00:10,  1.02it/s][A

loss:  9458.983  <<< 



 80%|████████  | 40/50 [00:40<00:09,  1.02it/s][A

loss:  9426.339  <<< 



 82%|████████▏ | 41/50 [00:41<00:08,  1.02it/s][A

loss:  9392.011  <<< 



 84%|████████▍ | 42/50 [00:42<00:07,  1.02it/s][A

loss:  9357.237  <<< 



 86%|████████▌ | 43/50 [00:43<00:06,  1.02it/s][A

loss:  9322.126  <<< 



 88%|████████▊ | 44/50 [00:44<00:05,  1.02it/s][A

loss:  9285.103  <<< 



 90%|█████████ | 45/50 [00:45<00:04,  1.02it/s][A

loss:  9247.359  <<< 



 92%|█████████▏| 46/50 [00:46<00:03,  1.01it/s][A

loss:  9210.663  <<< 



 94%|█████████▍| 47/50 [00:47<00:02,  1.01it/s][A

loss:  9172.862  <<< 



 96%|█████████▌| 48/50 [00:48<00:01,  1.02it/s][A

loss:  9132.825  <<< 



 98%|█████████▊| 49/50 [00:49<00:00,  1.01it/s][A

loss:  9091.436  <<< 



100%|██████████| 50/50 [00:50<00:00,  1.01s/it]
 10%|█         | 2/20 [01:55<16:56, 56.48s/it]

loss:  9050.038  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:46,  1.05it/s][A

loss:  9010.274  <<< 



  4%|▍         | 2/50 [00:01<00:46,  1.04it/s][A

loss:  8971.456  <<< 



  6%|▌         | 3/50 [00:02<00:45,  1.04it/s][A

loss:  8931.461  <<< 



  8%|▊         | 4/50 [00:03<00:44,  1.03it/s][A

loss:  8890.312  <<< 



 10%|█         | 5/50 [00:04<00:43,  1.03it/s][A

loss:  8848.832  <<< 



 12%|█▏        | 6/50 [00:05<00:42,  1.04it/s][A

loss:  8807.726  <<< 



 14%|█▍        | 7/50 [00:06<00:41,  1.03it/s][A

loss:  8766.142  <<< 



 16%|█▌        | 8/50 [00:07<00:40,  1.03it/s][A

loss:  8724.468  <<< 



 18%|█▊        | 9/50 [00:08<00:39,  1.04it/s][A

loss:  8682.933  <<< 



 20%|██        | 10/50 [00:09<00:38,  1.03it/s][A

loss:  8641.572  <<< 



 22%|██▏       | 11/50 [00:10<00:38,  1.03it/s][A

loss:  8600.185  <<< 



 24%|██▍       | 12/50 [00:11<00:37,  1.02it/s][A

loss:  8559.633  <<< 



 26%|██▌       | 13/50 [00:12<00:36,  1.02it/s][A

loss:  8519.321  <<< 



 28%|██▊       | 14/50 [00:13<00:35,  1.02it/s][A

loss:  8479.625  <<< 



 30%|███       | 15/50 [00:14<00:34,  1.02it/s][A

loss:  8441.1045  <<< 



 32%|███▏      | 16/50 [00:15<00:33,  1.02it/s][A

loss:  8403.968  <<< 



 34%|███▍      | 17/50 [00:16<00:32,  1.02it/s][A

loss:  8368.898  <<< 



 36%|███▌      | 18/50 [00:17<00:31,  1.01it/s][A

loss:  8332.963  <<< 



 38%|███▊      | 19/50 [00:18<00:30,  1.01it/s][A

loss:  8296.384  <<< 



 40%|████      | 20/50 [00:19<00:29,  1.01it/s][A

loss:  8259.176  <<< 



 42%|████▏     | 21/50 [00:20<00:28,  1.01it/s][A

loss:  8222.014  <<< 



 44%|████▍     | 22/50 [00:21<00:27,  1.01it/s][A

loss:  8185.247  <<< 



 46%|████▌     | 23/50 [00:22<00:26,  1.01it/s][A

loss:  8147.802  <<< 



 48%|████▊     | 24/50 [00:23<00:25,  1.01it/s][A

loss:  8109.408  <<< 



 50%|█████     | 25/50 [00:24<00:24,  1.01it/s][A

loss:  8070.344  <<< 



 52%|█████▏    | 26/50 [00:25<00:23,  1.01it/s][A

loss:  8030.9424  <<< 



 54%|█████▍    | 27/50 [00:26<00:22,  1.01it/s][A

loss:  7991.081  <<< 



 56%|█████▌    | 28/50 [00:27<00:21,  1.02it/s][A

loss:  7950.897  <<< 



 58%|█████▊    | 29/50 [00:28<00:20,  1.01it/s][A

loss:  7910.7666  <<< 



 60%|██████    | 30/50 [00:29<00:19,  1.01it/s][A

loss:  7870.08  <<< 



 62%|██████▏   | 31/50 [00:30<00:18,  1.01it/s][A

loss:  7828.41  <<< 



 64%|██████▍   | 32/50 [00:31<00:17,  1.00it/s][A

loss:  7785.983  <<< 



 66%|██████▌   | 33/50 [00:32<00:17,  1.01s/it][A

loss:  7743.4116  <<< 



 68%|██████▊   | 34/50 [00:33<00:16,  1.00s/it][A

loss:  7701.199  <<< 



 70%|███████   | 35/50 [00:34<00:14,  1.00it/s][A

loss:  7658.193  <<< 



 72%|███████▏  | 36/50 [00:35<00:13,  1.01it/s][A

loss:  7613.5107  <<< 



 74%|███████▍  | 37/50 [00:36<00:12,  1.01it/s][A

loss:  7568.5454  <<< 



 76%|███████▌  | 38/50 [00:37<00:12,  1.00s/it][A

loss:  7524.7764  <<< 



 78%|███████▊  | 39/50 [00:38<00:11,  1.00s/it][A

loss:  7482.6294  <<< 



 80%|████████  | 40/50 [00:39<00:09,  1.00it/s][A

loss:  7443.8965  <<< 



 82%|████████▏ | 41/50 [00:40<00:08,  1.00it/s][A

loss:  7410.9355  <<< 



 84%|████████▍ | 42/50 [00:41<00:07,  1.00it/s][A

loss:  7365.278  <<< 



 86%|████████▌ | 43/50 [00:42<00:06,  1.00it/s][A

loss:  7299.0063  <<< 



 88%|████████▊ | 44/50 [00:43<00:05,  1.00it/s][A

loss:  7260.2065  <<< 



 90%|█████████ | 45/50 [00:44<00:04,  1.01it/s][A

loss:  7227.212  <<< 



 92%|█████████▏| 46/50 [00:45<00:03,  1.00it/s][A

loss:  7163.0938  <<< 



 94%|█████████▍| 47/50 [00:46<00:03,  1.04s/it][A

loss:  7110.6885  <<< 



 96%|█████████▌| 48/50 [00:47<00:02,  1.06s/it][A

loss:  7076.2026  <<< 



 98%|█████████▊| 49/50 [00:48<00:01,  1.04s/it][A

loss:  7024.577  <<< 



100%|██████████| 50/50 [00:49<00:00,  1.01it/s]
 15%|█▌        | 3/20 [02:45<15:14, 53.78s/it]

loss:  6960.582  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:46,  1.05it/s][A

loss:  6907.1675  <<< 



  4%|▍         | 2/50 [00:01<00:46,  1.03it/s][A

loss:  6864.3955  <<< 



  6%|▌         | 3/50 [00:02<00:46,  1.02it/s][A

loss:  6814.68  <<< 



  8%|▊         | 4/50 [00:03<00:45,  1.02it/s][A

loss:  6751.002  <<< 



 10%|█         | 5/50 [00:04<00:44,  1.01it/s][A

loss:  6690.4854  <<< 



 12%|█▏        | 6/50 [00:05<00:43,  1.01it/s][A

loss:  6639.379  <<< 



 14%|█▍        | 7/50 [00:06<00:42,  1.01it/s][A

loss:  6592.505  <<< 



 16%|█▌        | 8/50 [00:07<00:41,  1.01it/s][A

loss:  6545.811  <<< 



 18%|█▊        | 9/50 [00:08<00:40,  1.01it/s][A

loss:  6485.8145  <<< 



 20%|██        | 10/50 [00:09<00:39,  1.01it/s][A

loss:  6419.796  <<< 



 22%|██▏       | 11/50 [00:10<00:39,  1.00s/it][A

loss:  6350.5176  <<< 



 24%|██▍       | 12/50 [00:12<00:45,  1.20s/it][A

loss:  6288.3726  <<< 



 26%|██▌       | 13/50 [00:14<00:48,  1.31s/it][A

loss:  6232.1113  <<< 



 28%|██▊       | 14/50 [00:15<00:48,  1.34s/it][A

loss:  6185.197  <<< 



 30%|███       | 15/50 [00:16<00:48,  1.38s/it][A

loss:  6164.734  <<< 



 32%|███▏      | 16/50 [00:17<00:42,  1.25s/it][A

loss:  6127.0674  <<< 



 34%|███▍      | 17/50 [00:18<00:38,  1.17s/it][A

loss:  6056.4023  <<< 



 36%|███▌      | 18/50 [00:19<00:35,  1.12s/it][A

loss:  5925.1157  <<< 



 38%|███▊      | 19/50 [00:20<00:33,  1.08s/it][A

loss:  5904.472  <<< 



 40%|████      | 20/50 [00:21<00:31,  1.05s/it][A

loss:  5899.2817  <<< 



 42%|████▏     | 21/50 [00:22<00:30,  1.04s/it][A

loss:  5748.6724  <<< 



 44%|████▍     | 22/50 [00:23<00:28,  1.02s/it][A

loss:  5698.7544  <<< 



 46%|████▌     | 23/50 [00:24<00:27,  1.01s/it][A

loss:  5701.4307  <<< 



 48%|████▊     | 24/50 [00:25<00:26,  1.00s/it][A

loss:  5568.737  <<< 



 50%|█████     | 25/50 [00:26<00:24,  1.00it/s][A

loss:  5489.5156  <<< 



 52%|█████▏    | 26/50 [00:27<00:23,  1.01it/s][A

loss:  5482.0767  <<< 



 54%|█████▍    | 27/50 [00:28<00:22,  1.01it/s][A

loss:  5402.6167  <<< 



 56%|█████▌    | 28/50 [00:29<00:21,  1.01it/s][A

loss:  5295.4844  <<< 



 58%|█████▊    | 29/50 [00:30<00:20,  1.01it/s][A

loss:  5221.9673  <<< 



 60%|██████    | 30/50 [00:31<00:19,  1.01it/s][A

loss:  5195.3657  <<< 



 62%|██████▏   | 31/50 [00:32<00:18,  1.01it/s][A

loss:  5184.3135  <<< 



 64%|██████▍   | 32/50 [00:33<00:17,  1.01it/s][A

loss:  5083.0835  <<< 



 66%|██████▌   | 33/50 [00:34<00:16,  1.02it/s][A

loss:  4974.543  <<< 



 68%|██████▊   | 34/50 [00:35<00:15,  1.02it/s][A

loss:  4887.237  <<< 



 70%|███████   | 35/50 [00:36<00:14,  1.02it/s][A

loss:  4841.6016  <<< 



 72%|███████▏  | 36/50 [00:37<00:13,  1.02it/s][A

loss:  4827.4736  <<< 



 74%|███████▍  | 37/50 [00:38<00:12,  1.01it/s][A

loss:  4782.0117  <<< 



 76%|███████▌  | 38/50 [00:39<00:11,  1.01it/s][A

loss:  4735.196  <<< 



 78%|███████▊  | 39/50 [00:40<00:10,  1.01it/s][A

loss:  4577.356  <<< 



 80%|████████  | 40/50 [00:41<00:10,  1.00s/it][A

loss:  4471.9688  <<< 



 82%|████████▏ | 41/50 [00:42<00:08,  1.01it/s][A

loss:  4430.9414  <<< 



 84%|████████▍ | 42/50 [00:43<00:07,  1.01it/s][A

loss:  4423.3174  <<< 



 86%|████████▌ | 43/50 [00:44<00:06,  1.02it/s][A

loss:  4458.381  <<< 



 88%|████████▊ | 44/50 [00:45<00:05,  1.01it/s][A

loss:  4278.5806  <<< 



 90%|█████████ | 45/50 [00:46<00:04,  1.01it/s][A

loss:  4138.3345  <<< 



 92%|█████████▏| 46/50 [00:47<00:03,  1.00it/s][A

loss:  4049.3872  <<< 



 94%|█████████▍| 47/50 [00:48<00:02,  1.01it/s][A

loss:  4024.031  <<< 



 96%|█████████▌| 48/50 [00:49<00:02,  1.00s/it][A

loss:  4069.0017  <<< 



 98%|█████████▊| 49/50 [00:50<00:00,  1.00it/s][A

loss:  4002.7722  <<< 



100%|██████████| 50/50 [00:51<00:00,  1.03s/it]
 20%|██        | 4/20 [03:38<14:12, 53.28s/it]

loss:  3986.744  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:47,  1.02it/s][A

loss:  3734.394  <<< 



  4%|▍         | 2/50 [00:01<00:48,  1.00s/it][A

loss:  3649.366  <<< 



  6%|▌         | 3/50 [00:03<00:47,  1.01s/it][A

loss:  3710.48  <<< 



  8%|▊         | 4/50 [00:04<00:46,  1.01s/it][A

loss:  3688.6672  <<< 



 10%|█         | 5/50 [00:05<00:44,  1.00it/s][A

loss:  3689.4858  <<< 



 12%|█▏        | 6/50 [00:06<00:43,  1.00it/s][A

loss:  3404.986  <<< 



 14%|█▍        | 7/50 [00:07<00:43,  1.00s/it][A

loss:  3378.1208  <<< 



 16%|█▌        | 8/50 [00:08<00:42,  1.01s/it][A

loss:  3549.7349  <<< 



 18%|█▊        | 9/50 [00:09<00:41,  1.01s/it][A

loss:  3319.4146  <<< 



 20%|██        | 10/50 [00:10<00:40,  1.00s/it][A

loss:  3163.1206  <<< 



 22%|██▏       | 11/50 [00:11<00:39,  1.00s/it][A

loss:  3099.8948  <<< 



 24%|██▍       | 12/50 [00:12<00:37,  1.00it/s][A

loss:  3126.727  <<< 



 26%|██▌       | 13/50 [00:13<00:36,  1.00it/s][A

loss:  3253.546  <<< 



 28%|██▊       | 14/50 [00:14<00:37,  1.04s/it][A

loss:  3042.1667  <<< 



 30%|███       | 15/50 [00:15<00:35,  1.02s/it][A

loss:  2915.2554  <<< 



 32%|███▏      | 16/50 [00:16<00:34,  1.02s/it][A

loss:  2823.1272  <<< 



 34%|███▍      | 17/50 [00:17<00:33,  1.01s/it][A

loss:  2794.6272  <<< 



 36%|███▌      | 18/50 [00:18<00:32,  1.00s/it][A

loss:  2831.4563  <<< 



 38%|███▊      | 19/50 [00:19<00:31,  1.00s/it][A

loss:  2856.0276  <<< 



 40%|████      | 20/50 [00:20<00:30,  1.00s/it][A

loss:  3068.698  <<< 



 42%|████▏     | 21/50 [00:21<00:29,  1.01s/it][A

loss:  2649.2012  <<< 



 44%|████▍     | 22/50 [00:22<00:28,  1.00s/it][A

loss:  2556.617  <<< 



 46%|████▌     | 23/50 [00:23<00:27,  1.00s/it][A

loss:  2704.0242  <<< 



 48%|████▊     | 24/50 [00:24<00:26,  1.00s/it][A

loss:  2668.3057  <<< 



 50%|█████     | 25/50 [00:25<00:25,  1.00s/it][A

loss:  2688.6885  <<< 



 52%|█████▏    | 26/50 [00:26<00:23,  1.00it/s][A

loss:  2400.0771  <<< 



 54%|█████▍    | 27/50 [00:27<00:22,  1.01it/s][A

loss:  2445.0557  <<< 



 56%|█████▌    | 28/50 [00:28<00:21,  1.01it/s][A

loss:  2733.8455  <<< 



 58%|█████▊    | 29/50 [00:29<00:20,  1.01it/s][A

loss:  2332.9575  <<< 



 60%|██████    | 30/50 [00:30<00:19,  1.02it/s][A

loss:  2314.4983  <<< 



 62%|██████▏   | 31/50 [00:31<00:18,  1.02it/s][A

loss:  2584.1904  <<< 



 64%|██████▍   | 32/50 [00:32<00:17,  1.02it/s][A

loss:  2253.763  <<< 



 66%|██████▌   | 33/50 [00:33<00:16,  1.01it/s][A

loss:  2188.7021  <<< 



 68%|██████▊   | 34/50 [00:33<00:15,  1.01it/s][A

loss:  2344.9114  <<< 



 70%|███████   | 35/50 [00:34<00:14,  1.01it/s][A

loss:  2209.1753  <<< 



 72%|███████▏  | 36/50 [00:35<00:13,  1.01it/s][A

loss:  2107.9197  <<< 



 74%|███████▍  | 37/50 [00:36<00:12,  1.02it/s][A

loss:  2062.5122  <<< 



 76%|███████▌  | 38/50 [00:37<00:11,  1.01it/s][A

loss:  2095.726  <<< 



 78%|███████▊  | 39/50 [00:38<00:10,  1.02it/s][A

loss:  2182.333  <<< 



 80%|████████  | 40/50 [00:39<00:09,  1.01it/s][A

loss:  2062.568  <<< 



 82%|████████▏ | 41/50 [00:40<00:08,  1.01it/s][A

loss:  1992.3486  <<< 



 84%|████████▍ | 42/50 [00:41<00:07,  1.02it/s][A

loss:  1946.2565  <<< 



 86%|████████▌ | 43/50 [00:42<00:06,  1.02it/s][A

loss:  1943.0908  <<< 



 88%|████████▊ | 44/50 [00:43<00:05,  1.01it/s][A

loss:  1982.1017  <<< 



 90%|█████████ | 45/50 [00:44<00:04,  1.01it/s][A

loss:  2000.2202  <<< 



 92%|█████████▏| 46/50 [00:45<00:03,  1.01it/s][A

loss:  2117.794  <<< 



 94%|█████████▍| 47/50 [00:46<00:02,  1.00it/s][A

loss:  1963.9856  <<< 



 96%|█████████▌| 48/50 [00:47<00:01,  1.00it/s][A

loss:  1906.8054  <<< 



 98%|█████████▊| 49/50 [00:48<00:00,  1.01it/s][A

loss:  1840.1284  <<< 



100%|██████████| 50/50 [00:49<00:00,  1.00it/s]
 25%|██▌       | 5/20 [04:28<13:05, 52.36s/it]

loss:  1817.4385  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:47,  1.04it/s][A

loss:  1826.3861  <<< 



  4%|▍         | 2/50 [00:01<00:46,  1.03it/s][A

loss:  1850.9569  <<< 



  6%|▌         | 3/50 [00:02<00:45,  1.03it/s][A

loss:  1939.1897  <<< 



  8%|▊         | 4/50 [00:03<00:45,  1.01it/s][A

loss:  1896.0958  <<< 



 10%|█         | 5/50 [00:04<00:44,  1.01it/s][A

loss:  1948.5027  <<< 



 12%|█▏        | 6/50 [00:05<00:43,  1.00it/s][A

loss:  1808.1699  <<< 



 14%|█▍        | 7/50 [00:06<00:42,  1.01it/s][A

loss:  1749.178  <<< 



 16%|█▌        | 8/50 [00:07<00:41,  1.01it/s][A

loss:  1729.9084  <<< 



 18%|█▊        | 9/50 [00:08<00:40,  1.00it/s][A

loss:  1747.6375  <<< 



 20%|██        | 10/50 [00:09<00:39,  1.00it/s][A

loss:  1804.9558  <<< 



 22%|██▏       | 11/50 [00:10<00:38,  1.00it/s][A

loss:  1802.5931  <<< 



 24%|██▍       | 12/50 [00:11<00:37,  1.00it/s][A

loss:  1859.5706  <<< 



 26%|██▌       | 13/50 [00:12<00:37,  1.00s/it][A

loss:  1754.2666  <<< 



 28%|██▊       | 14/50 [00:13<00:36,  1.00s/it][A

loss:  1707.2109  <<< 



 30%|███       | 15/50 [00:14<00:34,  1.00it/s][A

loss:  1675.412  <<< 



 32%|███▏      | 16/50 [00:15<00:33,  1.00it/s][A

loss:  1674.6133  <<< 



 34%|███▍      | 17/50 [00:16<00:32,  1.00it/s][A

loss:  1698.05  <<< 



 36%|███▌      | 18/50 [00:17<00:31,  1.01it/s][A

loss:  1714.7374  <<< 



 38%|███▊      | 19/50 [00:18<00:30,  1.01it/s][A

loss:  1767.2449  <<< 



 40%|████      | 20/50 [00:19<00:29,  1.01it/s][A

loss:  1723.7827  <<< 



 42%|████▏     | 21/50 [00:20<00:28,  1.02it/s][A

loss:  1724.6458  <<< 



 44%|████▍     | 22/50 [00:21<00:27,  1.01it/s][A

loss:  1668.2032  <<< 



 46%|████▌     | 23/50 [00:22<00:26,  1.01it/s][A

loss:  1641.1711  <<< 



 48%|████▊     | 24/50 [00:23<00:25,  1.01it/s][A

loss:  1623.8866  <<< 



 50%|█████     | 25/50 [00:24<00:24,  1.01it/s][A

loss:  1619.9192  <<< 



 52%|█████▏    | 26/50 [00:25<00:23,  1.01it/s][A

loss:  1626.6194  <<< 



 54%|█████▍    | 27/50 [00:26<00:22,  1.01it/s][A

loss:  1639.0599  <<< 



 56%|█████▌    | 28/50 [00:27<00:21,  1.01it/s][A

loss:  1676.9126  <<< 



 58%|█████▊    | 29/50 [00:28<00:20,  1.01it/s][A

loss:  1687.8793  <<< 



 60%|██████    | 30/50 [00:29<00:19,  1.00it/s][A

loss:  1769.3943  <<< 



 62%|██████▏   | 31/50 [00:30<00:18,  1.01it/s][A

loss:  1697.0363  <<< 



 64%|██████▍   | 32/50 [00:31<00:17,  1.01it/s][A

loss:  1691.7555  <<< 



 66%|██████▌   | 33/50 [00:32<00:16,  1.01it/s][A

loss:  1618.358  <<< 



 68%|██████▊   | 34/50 [00:33<00:15,  1.01it/s][A

loss:  1587.1947  <<< 



 70%|███████   | 35/50 [00:34<00:14,  1.01it/s][A

loss:  1583.5266  <<< 



 72%|███████▏  | 36/50 [00:35<00:13,  1.01it/s][A

loss:  1600.878  <<< 



 74%|███████▍  | 37/50 [00:36<00:12,  1.01it/s][A

loss:  1639.2325  <<< 



 76%|███████▌  | 38/50 [00:37<00:11,  1.01it/s][A

loss:  1639.813  <<< 



 78%|███████▊  | 39/50 [00:38<00:10,  1.01it/s][A

loss:  1666.7235  <<< 



 80%|████████  | 40/50 [00:39<00:09,  1.01it/s][A

loss:  1615.7023  <<< 



 82%|████████▏ | 41/50 [00:40<00:08,  1.01it/s][A

loss:  1591.0166  <<< 



 84%|████████▍ | 42/50 [00:41<00:07,  1.01it/s][A

loss:  1565.2914  <<< 



 86%|████████▌ | 43/50 [00:42<00:06,  1.01it/s][A

loss:  1557.6125  <<< 



 88%|████████▊ | 44/50 [00:43<00:05,  1.01it/s][A

loss:  1563.8755  <<< 



 90%|█████████ | 45/50 [00:44<00:05,  1.02s/it][A

loss:  1576.2045  <<< 



 92%|█████████▏| 46/50 [00:45<00:04,  1.01s/it][A

loss:  1602.3782  <<< 



 94%|█████████▍| 47/50 [00:46<00:02,  1.00it/s][A

loss:  1603.644  <<< 



 96%|█████████▌| 48/50 [00:47<00:01,  1.01it/s][A

loss:  1632.9375  <<< 



 98%|█████████▊| 49/50 [00:48<00:00,  1.01it/s][A

loss:  1602.9225  <<< 



100%|██████████| 50/50 [00:49<00:00,  1.01it/s]
 30%|███       | 6/20 [05:19<12:04, 51.75s/it]

loss:  1602.3812  <<< 



  0%|          | 0/50 [00:00<?, ?it/s][A
  2%|▏         | 1/50 [00:00<00:47,  1.03it/s][A

loss:  1567.2552  <<< 



  4%|▍         | 2/50 [00:01<00:47,  1.02it/s][A

**inference engine**

In [None]:

# newsize = (140,140) #(260, 260) # /.... 233 * 454
from google.colab.patches import cv2_imshow
import numpy as np 
from google.colab import output

# Download a sample photo to run the model on; it is written to a.jpg in the
# current working directory.
!wget https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg -O a.jpg
# NOTE(review): absolute path -- it matches the download above only when the
# working directory is /content.
image_in = '/content/a.jpg'

from PIL import Image
import jax.numpy as jnp
def imageRGB(argv):
    """Load an image file, resize it to `newsize`, and return it in two layouts.

    Args:
        argv: path of the image file to open.

    Returns:
        (tvt, tvu): `tvt` is the resized image as a JAX array and `tvu` is the
        same data reshaped to rows of 3 values (one row per pixel).

    The image is converted to RGB first so that the 3-column reshape is valid
    for grayscale, palette, or RGBA files as well. `newsize` is read from the
    notebook's global scope and is passed to PIL's `resize`, which interprets
    it as (width, height).

    NOTE(review): an earlier cell appears to define a helper ending in the
    same `return tvt, tvu`; once this cell runs, this definition is the one
    in effect -- confirm the two are meant to stay in sync.
    """
    # The context manager closes the underlying file; convert() and resize()
    # produce an in-memory image, so it remains usable afterwards.
    with Image.open(argv) as im:
        resized = im.convert('RGB').resize(newsize)
    tvt = jnp.asarray(resized)
    tvu = tvt.reshape(-1,3)  # resize once, derive the flat view from the same array
    return tvt, tvu
# The model input is the flattened per-pixel layout (second element returned
# by imageRGB).
image = jnp.asarray((imageRGB(image_in)[1]))
#restored_state = checkpoints.restore_checkpoint(ckpt_dir=CKPT_DIR, target=state)
#state = restored_state
prediction = eval_step(state, image)
# A bare expression in the middle of a cell is not displayed; print it instead.
print("loss: ", prediction['loss'])


# Clip before the uint8 cast so out-of-range values saturate instead of
# wrapping. `newsize` is (width, height) for PIL, whereas the NumPy image is
# laid out as (height, width), hence the reversal (no effect for square sizes).
# NOTE(review): assumes the logits hold one value per pixel -- confirm.
predicted_image = np.clip(np.asarray(prediction['logits']), 0, 255).astype(np.uint8).reshape(newsize[::-1])
cv2_imshow(predicted_image)


In [None]:
# import cv2
# from google.colab.patches import cv2_imshow
# import numpy as np 
# def show_image(argu):
#   L1 = argu[0]
#   predicted_image = np.array(argu[0],  dtype=np.uint8).reshape(newsize) # This would be your image array
#   a = predicted_image
#   for i in range(0,argu.shape[0]):
#     predicted_image = np.array(argu[i],  dtype=np.uint8).reshape(newsize) 
#     a = cv2.hconcat([a, predicted_image])
#   cv2_imshow(a)

# show_image(metrics['logits'])

In [None]:
# from jax.tree_util import tree_structure
# print(tree_structure(state))