<a href="https://colab.research.google.com/github/ShaliniAnandaPhD/PIXEL-PIONEERS-TUTORIALS/blob/main/Pixel_Pioneer_Tutorial_2_From_NumPy_to_JAX_using_Caloric_Counting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# High-Performance Computing with JAX for Smart Calorie Counting

## Introduction

This guide introduces JAX, a Python library for accelerated numerical computing and machine learning. Using smart calorie counting as a running example, we compare NumPy and JAX side by side and show how even simple tasks can benefit from JAX's capabilities: automatic differentiation, JIT compilation, vectorization, and parallelization.

## What is JAX?

JAX is a Python library for high-performance numerical computing and machine learning research. It provides a NumPy-like API (`jax.numpy`) and uses the XLA compiler to run the same code on CPUs, GPUs, and TPUs.

### Key Capabilities

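- **Automatic Differentiation** (`grad`): computes gradients of Python functions, which gradient-based optimization relies on.
- **Just-in-Time (JIT) Compilation** (`jit`): compiles functions with XLA so repeated calls run faster.
- **Vectorization** (`vmap`): turns a function written for one sample into one that processes a whole batch.
- **Parallelization** (`pmap`): splits a computation across multiple devices (GPUs/TPUs).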
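These capabilities are function transformations, so they compose. A minimal sketch, assuming a hypothetical per-item squared-error function (the same idea the scenario below develops) and made-up calorie values: `grad` differentiates the per-item error, `vmap` maps that gradient over a batch, and `jit` compiles the result.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-item error between an estimated and an actual calorie count
def squared_error(estimated, actual):
    return (estimated - actual) ** 2

# grad: derivative with respect to the first argument (the estimate)
# vmap: apply that per-item gradient across a batch of items
# jit:  compile the batched gradient function with XLA
batched_grad = jax.jit(jax.vmap(jax.grad(squared_error)))

estimates = jnp.array([230.0, 310.0, 180.0])  # illustrative estimates
actuals = jnp.array([250.0, 300.0, 180.0])    # illustrative actual values

grads = batched_grad(estimates, actuals)
# Each entry is 2 * (estimated - actual): -40, 20, 0
print(grads)
```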

In [47]:
# Setting up the Python Environment

### Essential Library Imports

import jax
import jax.numpy as jnp
import numpy as np

In [4]:
### Installing JAX

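##!pip install -U jax  # Standard (CPU) installation; jaxlib is installed automatically

# For GPU support (NVIDIA, CUDA 12):
##!pip install -U "jax[cuda12]"

# For TPU support:
##!pip install -U "jax[tpu]"
# Only on legacy Colab TPU runtimes with older JAX versions:
##import jax.tools.colab_tpu
##jax.tools.colab_tpu.setup_tpu()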

In [48]:
## Basic JAX Operations Compared with NumPy

### Comparing Arrays

arr = np.arange(5)
print("NumPy array:", arr)

arr_jax = jnp.arange(5)
print("JAX array:", arr_jax)

NumPy array: [0 1 2 3 4]
JAX array: [0 1 2 3 4]


In [6]:
### Array Transformations
squared_arr = arr ** 2
squared_arr_jax = arr_jax ** 2

print("Squared NumPy array:", squared_arr)
print("Squared JAX array:", squared_arr_jax)


Squared NumPy array: [ 0  1  4  9 16]
Squared JAX array: [ 0  1  4  9 16]


### Scenario: Caloric Adjustment Model

Suppose you're building a model that adjusts the estimated calorie count of a food item based on an error metric. You have an initial calorie estimate and want to minimize the error between this estimate and the actual calorie count through gradient descent.

### Using JAX for Automatic Differentiation

JAX suits this scenario because of its automatic differentiation: `grad` computes the gradients that optimization algorithms like gradient descent need, without deriving them by hand.

In [49]:
## Automatic Differentiation with `grad`

### Simplifying Gradient Computation
from jax import grad
import jax.numpy as jnp

# Function to compute the squared error
def squared_error(estimated_calories, actual_calories):
    return (estimated_calories - actual_calories) ** 2

# Automatic differentiation to find the gradient
grad_error = grad(squared_error)

# Example data
actual_calories = 250.0
estimated_calories = 230.0

# Compute the gradient
gradient = grad_error(estimated_calories, actual_calories)
print("JAX - Gradient of Error:", gradient)





JAX - Gradient of Error: -40.0


Simplification: JAX's `grad` derives the gradient automatically. The NumPy version below has to hard-code the derivative, 2 × (estimated − actual), which is easy here but becomes error-prone as functions grow.

In [50]:
import numpy as np

# Function to compute the squared error
def squared_error(estimated_calories, actual_calories):
    return (estimated_calories - actual_calories) ** 2

# Manually compute the gradient
def grad_error(estimated_calories, actual_calories):
    return 2 * (estimated_calories - actual_calories)

# Example data
actual_calories = 250.0
estimated_calories = 230.0

# Compute the gradient
gradient = grad_error(estimated_calories, actual_calories)
print("NumPy - Gradient of Error:", gradient)



NumPy - Gradient of Error: -40.0
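Both cells compute the gradient but stop short of the gradient descent the scenario describes. A minimal sketch of that loop, assuming a learning rate of 0.1 and 50 steps (both arbitrary choices for illustration):

```python
import jax

def squared_error(estimated_calories, actual_calories):
    return (estimated_calories - actual_calories) ** 2

grad_error = jax.grad(squared_error)  # derivative w.r.t. the estimate

actual_calories = 250.0
estimate = 230.0      # initial estimate
learning_rate = 0.1   # assumed step size

for step in range(50):
    # Move the estimate a small step against the gradient of the error
    estimate = estimate - learning_rate * grad_error(estimate, actual_calories)

print("Adjusted estimate:", float(estimate))
```

With this learning rate, each step multiplies the remaining error by 1 − 2 × 0.1 = 0.8, so the estimate converges toward 250.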


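### Just-in-Time Compilation (`jit`)

`jit` traces a function the first time it is called with a given input shape and dtype, compiles it with XLA, and reuses the compiled version on later calls with matching inputs.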

In [18]:
from jax import jit
import jax.numpy as jnp

def calculate_calories(batch):
    return jnp.sum(batch * 1.2)

# Just-In-Time compilation for single-device execution
jit_calculate_calories = jit(calculate_calories)

data = jnp.array([100, 150, 200, 120, 180, 220])
total_calories = jit_calculate_calories(data)
print("JAX - Total Calories:", total_calories)



JAX - Total Calories: 1164.0
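With an input this small, the cell above does not show what `jit` costs or saves. A rough timing sketch follows; the numbers depend entirely on your hardware, and `block_until_ready` is needed because JAX dispatches work asynchronously, so a timer could otherwise stop before the computation finishes. The `timed` helper is illustrative.

```python
import time

import jax
import jax.numpy as jnp

def calculate_calories(batch):
    return jnp.sum(batch * 1.2)

jit_calculate_calories = jax.jit(calculate_calories)
data = jnp.array([100, 150, 200, 120, 180, 220])

def timed(fn, x):
    """Run fn(x) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = fn(x).block_until_ready()  # wait for the asynchronous result
    return result, time.perf_counter() - start

first, t_first = timed(jit_calculate_calories, data)    # traces + compiles, then runs
second, t_second = timed(jit_calculate_calories, data)  # reuses the compiled program
eager, t_eager = timed(calculate_calories, data)        # no jit: op-by-op dispatch

print(f"jit, first call:  {t_first:.6f} s")
print(f"jit, second call: {t_second:.6f} s")
print(f"no jit:           {t_eager:.6f} s")
```

Expect the first `jit` call to be the slowest because it includes compilation; differences between the later calls only become meaningful for larger arrays or longer functions.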


`vmap` (vectorized map) transforms a function written for a single data point into one that operates on a whole batch, mapping over the leading axis (axis 0) by default. You write the per-item logic once, and `vmap` handles the batch without an explicit Python loop.

In [20]:
from jax import vmap
import jax.numpy as jnp

def calculate_calories(item):
    # Simulate calorie calculation for a single item
    return item * 1.2

# Vectorize the function
v_calculate_calories = vmap(calculate_calories)

# Example data
data = jnp.array([100, 150, 200, 120, 180, 220])  # All items in one array

# Apply the vectorized function
calories = v_calculate_calories(data)

# Aggregate results
total_calories = jnp.sum(calories)
print("JAX (vmap) - Total Calories:", total_calories)


JAX (vmap) - Total Calories: 1164.0


In this example, vmap is used to apply calculate_calories across each element of data. The function calculate_calories is written for a single item, and vmap automatically vectorizes it over the batch.

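### NumPy Equivalent

NumPy has no built-in transformations like JAX's `vmap` (vectorization) or `pmap` (multi-device parallelization). The NumPy version below relies on ordinary array arithmetic, which NumPy executes eagerly, one operation at a time, on a single CPU.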

In [21]:
import numpy as np

# Function to calculate calories for a batch of food items
def calculate_calories(batch):
    # Simulate a complex calorie calculation
    return np.sum(batch * 1.2)

# Example data
data = np.array([100, 150, 200, 120, 180, 220])  # All items in one array

# Sequential execution
total_calories = calculate_calories(data)
print("NumPy - Total Calories:", total_calories)


NumPy - Total Calories: 1164.0
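None of the examples so far uses parallelization across devices, which the introduction lists as a key capability. A minimal `pmap` sketch, assuming the data is padded with zero-calorie items so it splits evenly across however many local devices are available (a plain CPU runtime usually has one device, so it still runs). The names `n_devices`, `pad`, and `shards` are illustrative; newer JAX releases also offer sharding-based alternatives to `pmap`.

```python
import jax
import jax.numpy as jnp

def calculate_calories(batch):
    return jnp.sum(batch * 1.2)

data = jnp.array([100, 150, 200, 120, 180, 220])

n_devices = jax.local_device_count()
# Pad with zero-calorie items so the length is a multiple of the device count
pad = (-data.shape[0]) % n_devices
padded = jnp.pad(data, (0, pad))
# Leading axis = one shard per device
shards = padded.reshape(n_devices, -1)

# Each device sums its own shard; the partial totals are then combined
per_device_totals = jax.pmap(calculate_calories)(shards)
total_calories = jnp.sum(per_device_totals)
print(f"{n_devices} device(s), total calories:", total_calories)
```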


## Real-Life Scenario

In [22]:
# Comprehensive Python and JAX Tutorial for Advanced Calorie Counting

## Expanded Setup

### Enhanced Library Import

import jax  # Main JAX library for high-performance numerical computing
import jax.numpy as jnp  # JAX version of NumPy with GPU/TPU support
import numpy as np  # Standard NumPy for comparison and data manipulation
import pandas as pd  # Data analysis and manipulation tool

In [26]:
### Creating More Complex Sample Data
# Nutritional information (protein, carbs, fat, in grams) for 10 days

import numpy as np
import random

# Setting a seed for reproducibility
random.seed(0)

# Number of days
num_days = 10

# Nutrients: Protein, Carbs, Fat
num_nutrients = 3

# Generating random nutritional information for each day
nutrients = np.array([[random.randint(10, 50) for _ in range(num_nutrients)] for _ in range(num_days)])

# Densities in calories per gram for Protein, Carbs, and Fat
densities = np.array([4, 4, 9])  # 4 calories/gram for Protein and Carbs, 9 for Fat

# Displaying the nutritional information
print("Nutritional Information (Protein, Carbs, Fat) over 10 Days (in grams):")
print(nutrients)

# Calculating total calories for each day
total_calories_per_day = np.sum(nutrients * densities, axis=1)

# Displaying total calories consumed each day
print("\nTotal Calories Consumed Each Day:")
print(total_calories_per_day)

# Preparing data to save to file
data_to_save = np.hstack((nutrients, total_calories_per_day.reshape(-1, 1)))

# Specify the file path (saving in the current working directory)
file_path = 'nutritional_data.csv'

# Save to a CSV file
np.savetxt(file_path, data_to_save, delimiter=',', header='Protein(g),Carbs(g),Fat(g),Total Calories', comments='')

# Print the file path
print(f"\nNutritional data saved to: {file_path}")




Nutritional Information (Protein, Carbs, Fat) over 10 Days (in grams):
[[34 36 12]
 [26 42 41]
 [35 29 40]
 [32 47 23]
 [42 18 28]
 [18 16 49]
 [26 44 48]
 [19 29 16]
 [14 31 40]
 [45 16 32]]

Total Calories Consumed Each Day:
[388 641 616 523 492 577 712 336 540 532]

Nutritional data saved to: nutritional_data.csv


In [30]:
### Calorie Calculation with a Dot Product
# Dot product to calculate total calories
import numpy as np

# Setting a random seed for reproducibility (same data as the seeded cells below)
np.random.seed(0)

# Random data for demonstration
nutrients = np.random.randint(10, 50, size=(10, 3))  # 10 days, 3 nutrients

# Densities in calories per gram for Protein, Carbs, and Fat
densities = np.array([4, 4, 9])  # 4 calories/gram for Protein and Carbs, 9 for Fat

# Calorie calculation for each day
# np.dot performs a matrix-vector product: one total per row (day)
calories = np.dot(nutrients, densities)

print("NumPy Calories per day:")
print(calories)


NumPy Calories per day:
[209 533 605 506 564 634 611 335 359 392]


`nutrients` is a 2D array: each row is a day and each column a nutrient (protein, carbs, fat, in grams).
Because `nutrients` is 2D and `densities` is 1D, `np.dot` treats the operation as a matrix-vector product, producing a 1D array whose elements are each day's total calories: 4 × protein + 4 × carbs + 9 × fat.
The printed `calories` array shows the total for each day.
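To see that the dot product really is "multiply each nutrient by its density and add up the row", the sketch below computes the same totals three ways on the first two days of the data saved to `nutritional_data.csv` earlier (values copied from that printout). The loop version exists only for comparison; it is not how you would write this in practice.

```python
import numpy as np

# First two days from the earlier printout (grams of protein, carbs, fat)
nutrients = np.array([[34, 36, 12],
                      [26, 42, 41]])
densities = np.array([4, 4, 9])  # calories per gram

via_dot = np.dot(nutrients, densities)               # matrix-vector product
via_broadcast = (nutrients * densities).sum(axis=1)  # multiply, then sum each row
via_loop = np.array([sum(int(n) * int(d) for n, d in zip(row, densities))
                     for row in nutrients])          # explicit per-row loop

# All three give 34*4 + 36*4 + 12*9 = 388 and 26*4 + 42*4 + 41*9 = 641
print(via_dot, via_broadcast, via_loop)
```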

In [31]:
### Replicating Calculations in JAX

# Converting the data to JAX arrays
import jax.numpy as jnp
import numpy as np  # Used only for random number generation

# Setting a random seed for reproducibility
np.random.seed(0)

# Random data for demonstration
# Note: JAX has jax.random.randint, but it uses explicit PRNG keys and produces
# different numbers than NumPy for the same seed. To reuse NumPy's data exactly,
# we generate it with NumPy and then convert it to a JAX array.
nutrients = jnp.array(np.random.randint(10, 50, size=(10, 3)))  # 10 days, 3 nutrients

# Densities in calories per gram for Protein, Carbs, and Fat
densities = jnp.array([4, 4, 9])  # 4 calories/gram for Protein and Carbs, 9 for Fat

# Calorie calculation for each day
# jnp.dot performs a matrix-vector product: one total per row (day)
calories = jnp.dot(nutrients, densities)

print("JAX Calories per day:")
print(calories)



JAX Calories per day:
[209 533 605 506 564 634 611 335 359 392]


We use NumPy to generate the random data so it matches the NumPy cells that use the same seed. JAX does provide `jax.random.randint`, but it takes an explicit PRNG key and uses its own generator, so it would not reproduce NumPy's numbers. We then convert the NumPy data to a JAX array with `jnp.array`.
For the densities and the dot product, we use `jax.numpy`, which mirrors the NumPy API but runs on JAX's accelerated backends and works with transformations like `jit`, `grad`, and `vmap`.
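When the data does not need to match NumPy's random stream, JAX's own generator works as sketched below with `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.randint`. JAX has no global seed: the key is passed explicitly, and the same key always yields the same numbers.

```python
import jax

key = jax.random.PRNGKey(0)          # explicit PRNG state instead of a global seed
key, subkey = jax.random.split(key)  # split off a fresh key for each use

# 10 days x 3 nutrients, integers in [10, 50), like np.random.randint(10, 50, ...)
nutrients = jax.random.randint(subkey, shape=(10, 3), minval=10, maxval=50)

print(nutrients.shape, nutrients.dtype)
print(nutrients)
```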

In [32]:

## Advanced JAX Capabilities

### Deep Dive into Just-in-Time Compilation (`jit`)

import jax.numpy as jnp
from jax import jit
import numpy as np  # Used for random data generation

# Setting a random seed for reproducibility
np.random.seed(0)

# Generating random nutritional data (protein, carbs, fat) for 10 days
nutrients = jnp.array(np.random.randint(10, 50, size=(10, 3)))

# Densities in calories per gram for Protein, Carbs, and Fat
densities = jnp.array([4, 4, 9])  # 4 calories/gram for Protein and Carbs, 9 for Fat

# Function to calculate total calories
def calculate_calories(nutrients, densities):
    return jnp.dot(nutrients, densities)

# JIT compilation to optimize the function
jit_calculate_calories = jit(calculate_calories)

# JIT-compiled function call
calories = jit_calculate_calories(nutrients, densities)
print("JIT Calories per day:")
print(calories)


JIT Calories per day:
[209 533 605 506 564 634 611 335 359 392]
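To look inside `jit`, `jax.make_jaxpr` prints the traced program that gets handed to XLA, and a Python side effect (which runs only while JAX traces) shows when `jit` reuses its cache and when it retraces. A minimal sketch; `traced_calories` and `trace_count` are illustrative names, and `ones()` placeholders are enough because tracing only depends on shapes and dtypes:

```python
import jax
import jax.numpy as jnp

def calculate_calories(nutrients, densities):
    return jnp.dot(nutrients, densities)

densities = jnp.array([4.0, 4.0, 9.0])

# The traced program: jnp.dot shows up as a dot_general primitive
jaxpr = jax.make_jaxpr(calculate_calories)(jnp.ones((10, 3)), densities)
print(jaxpr)

# Count traces with a Python side effect that only runs during tracing
trace_count = 0

def traced_calories(nutrients, densities):
    global trace_count
    trace_count += 1
    return jnp.dot(nutrients, densities)

jit_traced = jax.jit(traced_calories)
jit_traced(jnp.ones((10, 3)), densities)  # first call: trace + compile
jit_traced(jnp.ones((10, 3)), densities)  # same shape/dtype: cached, no trace
jit_traced(jnp.ones((7, 3)), densities)   # new shape: traced again
print("Number of traces:", trace_count)
```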


In [34]:
## NumPy equivalent (no JIT: NumPy executes each operation eagerly)

import numpy as np

# Setting a random seed for reproducibility
np.random.seed(0)

# Generating random nutritional data (protein, carbs, fat) for 10 days
nutrients = np.random.randint(10, 50, size=(10, 3))

# Densities in calories per gram for Protein, Carbs, and Fat
densities = np.array([4, 4, 9])  # 4 calories/gram for Protein and Carbs, 9 for Fat

# Function to calculate total calories
def calculate_calories(nutrients, densities):
    return np.dot(nutrients, densities)

# Function call
calories = calculate_calories(nutrients, densities)
print("NumPy Calories per day:")
print(calories)


NumPy Calories per day:
[209 533 605 506 564 634 611 335 359 392]


Now consider a function that computes each day's total calories and returns a penalty: the sum of squared deviations of each day's total from a target calorie count. A function like this is useful in nutritional planning or diet optimization, where the goal is not only to count calories but to stay close to a dietary target. Its gradient with respect to `nutrients` tells us how changing each nutrient amount on each day would change the penalty.

In [36]:
from jax import grad
import jax.numpy as jnp
import numpy as np

# Setting a random seed for reproducibility
np.random.seed(0)

# Function to calculate calories with a penalty for deviation from a target
def calculate_calories_with_penalty(nutrients, densities, target_calories):
    total_calories = jnp.dot(nutrients, densities)
    penalty = jnp.sum((total_calories - target_calories)**2)
    return penalty

# Example data: Nutrients for 10 days, as floats (grad requires floating-point inputs for the argument being differentiated)
nutrients = jnp.array(np.random.randint(10, 50, size=(10, 3)), dtype=jnp.float32)

# Densities and target calories
densities = jnp.array([4, 4, 9], dtype=jnp.float32)  # Protein, Carbs, Fat densities
target_calories = 2000.0  # Example target calories, as float

# Function to compute the gradient of calculate_calories_with_penalty
grad_fn = grad(calculate_calories_with_penalty, argnums=0)  # Gradient w.r.t. nutrients

# Compute the gradient
gradient = grad_fn(nutrients, densities, target_calories)
print("Gradient with respect to Nutrients:")
print(gradient)



Gradient with respect to Nutrients:
[[-14328. -14328. -32238.]
 [-11736. -11736. -26406.]
 [-11160. -11160. -25110.]
 [-11952. -11952. -26892.]
 [-11488. -11488. -25848.]
 [-10928. -10928. -24588.]
 [-11112. -11112. -25002.]
 [-13320. -13320. -29970.]
 [-13128. -13128. -29538.]
 [-12864. -12864. -28944.]]
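`jax.value_and_grad` returns the penalty and its gradient in one call, and the result can be checked against the hand-derived formula used in the NumPy cell below. A small sketch on the first two days of the saved data (values copied from the earlier printout), with the same target of 2000 calories:

```python
import jax
import jax.numpy as jnp

def calculate_calories_with_penalty(nutrients, densities, target_calories):
    total_calories = jnp.dot(nutrients, densities)
    return jnp.sum((total_calories - target_calories) ** 2)

nutrients = jnp.array([[34.0, 36.0, 12.0],
                       [26.0, 42.0, 41.0]])  # daily totals: 388 and 641 calories
densities = jnp.array([4.0, 4.0, 9.0])
target_calories = 2000.0

# Penalty and gradient w.r.t. the first argument (nutrients) in one pass
penalty, grad_nutrients = jax.value_and_grad(calculate_calories_with_penalty)(
    nutrients, densities, target_calories)

# Hand-derived gradient: 2 * (day total - target), times each density
manual = 2 * (jnp.dot(nutrients, densities) - target_calories)[:, None] * densities

print("Penalty:", penalty)
print("Gradient:\n", grad_nutrients)
print("Matches manual formula:", bool(jnp.allclose(grad_nutrients, manual)))
```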


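In NumPy, we have to derive and code the gradient by hand. With daily totals $t_d = \sum_k n_{dk}\,\rho_k$ and penalty $P = \sum_d (t_d - T)^2$, the gradient is $\partial P / \partial n_{dk} = 2\,(t_d - T)\,\rho_k$, where $n_{dk}$ is the amount of nutrient $k$ on day $d$, $\rho_k$ its density, and $T$ the target.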

In [37]:
import numpy as np

# Setting a random seed for reproducibility
np.random.seed(0)

# Function to calculate calories with a penalty for deviation from a target
def calculate_calories_with_penalty(nutrients, densities, target_calories):
    total_calories = np.dot(nutrients, densities)
    penalty = np.sum((total_calories - target_calories)**2)
    return penalty

# Function to manually compute the gradient of the penalty w.r.t. nutrients
def grad_calories_with_penalty(nutrients, densities, target_calories):
    total_calories = np.dot(nutrients, densities)
    # Gradient calculation
    grad_penalty = 2 * (total_calories - target_calories).reshape(-1, 1) * densities
    return grad_penalty

# Example data: Nutrients for 10 days
nutrients = np.random.randint(10, 50, size=(10, 3))

# Densities and target calories
densities = np.array([4, 4, 9])  # Protein, Carbs, Fat densities
target_calories = 2000  # Example target calories

# Manually compute the gradient
gradient = grad_calories_with_penalty(nutrients, densities, target_calories)
print("Gradient with respect to Nutrients:")
print(gradient)


Gradient with respect to Nutrients:
[[-14328 -14328 -32238]
 [-11736 -11736 -26406]
 [-11160 -11160 -25110]
 [-11952 -11952 -26892]
 [-11488 -11488 -25848]
 [-10928 -10928 -24588]
 [-11112 -11112 -25002]
 [-13320 -13320 -29970]
 [-13128 -13128 -29538]
 [-12864 -12864 -28944]]



Vectorization, as implemented by JAX's `vmap` (vectorized map), is a powerful tool for processing batches of data. It transforms a function written for a single example into one that handles the whole batch with batched array operations, instead of iterating through the batch in a Python loop. This matters most for performance on large datasets or simulations. Let's look at this in the context of calorie counting.

### Example: Caloric Counting with Vectorization in JAX
Suppose you have a function `calculate_calories` that computes total calories from a set of nutrients and densities. As written, it works on a single set of nutrients and densities; with `vmap`, you can extend it to batches of data without changing its body.

First, let's define the `calculate_calories` function:

In [38]:
import jax.numpy as jnp

def calculate_calories(nutrients, densities):
    # Calculates total calories based on nutrients and densities
    return jnp.dot(nutrients, densities)


Now, let's use vmap to vectorize this function for batch processing:

In [39]:
from jax import vmap

# Batched data: multiple sets of nutrients and densities
batched_nutrients = jnp.array([[10, 20, 30], [1, 2, 3]])  # Two sets of nutrient data
batched_densities = jnp.array([[4, 4, 9], [4, 4, 9]])     # Corresponding densities for each set

# Vectorize the calorie calculation function
vectorized_calories = vmap(calculate_calories)

# Apply the vectorized function to the batched data
all_calories = vectorized_calories(batched_nutrients, batched_densities)
print("All Calories:", all_calories)


All Calories: [390  39]
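Here every row carries its own copy of the same densities. The `in_axes` argument of `vmap` avoids that duplication: `None` marks an argument that is shared across the batch instead of mapped over. A minimal sketch reusing the same numbers:

```python
import jax
import jax.numpy as jnp

def calculate_calories(nutrients, densities):
    return jnp.dot(nutrients, densities)

batched_nutrients = jnp.array([[10, 20, 30], [1, 2, 3]])  # two sets of nutrients
densities = jnp.array([4, 4, 9])                          # one shared density vector

# in_axes=(0, None): map over axis 0 of nutrients, pass densities unchanged
shared_calories = jax.vmap(calculate_calories, in_axes=(0, None))(
    batched_nutrients, densities)
print("All Calories:", shared_calories)
```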


In this code:

- `batched_nutrients` and `batched_densities` each hold two sets of data: two sets of nutrient amounts and a matching density vector for each.
- `vmap(calculate_calories)` creates a new function, `vectorized_calories`, that applies `calculate_calories` to each pair of rows (nutrients, densities) along axis 0.
- `all_calories` is the result of applying `vectorized_calories` to the batched data: one calorie count per set.

### Advantages of Using vmap

- **Performance**: `vmap` replaces a Python-level loop with batched array operations, which is typically much faster, especially for large batches and on GPUs/TPUs.
- **Code simplicity**: the function is written for one example, with no explicit loop over the batch, which keeps the code concise and easier to read.
- **Flexibility**: `vmap` works with most JAX-traceable functions and composes with `jit` and `grad`, so existing per-example code can be vectorized with little change.

## Training a Simple Model: TensorFlow vs. JAX

Next, we build and train a small neural network with TensorFlow (Keras) on NumPy data, then repeat the same task with a JAX-based approach (Flax and Optax) and compare training times.
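Because the `Total Calories` column was generated as 4 × protein + 4 × carbs + 9 × fat, this task is exactly linear, so a least-squares fit makes a useful sanity baseline before training neural networks: it should recover the densities. A minimal sketch using `jnp.linalg.lstsq`, with the 10 days copied from the printout of the data saved earlier:

```python
import jax.numpy as jnp

# The 10 days saved to nutritional_data.csv earlier (grams of protein, carbs, fat)
nutrients = jnp.array([
    [34, 36, 12], [26, 42, 41], [35, 29, 40], [32, 47, 23], [42, 18, 28],
    [18, 16, 49], [26, 44, 48], [19, 29, 16], [14, 31, 40], [45, 16, 32],
], dtype=jnp.float32)
total_calories = jnp.array([388, 641, 616, 523, 492, 577, 712, 336, 540, 532],
                           dtype=jnp.float32)

# Solve min ||nutrients @ w - total_calories||^2 for the per-gram weights w
weights, residuals, rank, singular_values = jnp.linalg.lstsq(nutrients, total_calories)
print("Recovered calories per gram (protein, carbs, fat):", weights)
```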


In [41]:
import pandas as pd
import numpy as np
import jax.numpy as jnp
from sklearn.model_selection import train_test_split

# Load and preprocess meal data (the CSV written earlier in this notebook;
# in Colab the working directory is /content)
df = pd.read_csv("nutritional_data.csv")
X = df[['Protein(g)', 'Carbs(g)', 'Fat(g)']]  # Feature extraction
y = df['Total Calories']  # Target variable (calories)

# Splitting the dataset for training and testing (fixed seed for reproducibility)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Convert to NumPy arrays for the TensorFlow model
X_train_np, y_train_np = X_train.to_numpy(), y_train.to_numpy()

# Convert to JAX arrays for the JAX model
X_train_jax, y_train_jax = jnp.array(X_train_np), jnp.array(y_train_np)



In [42]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import time

# Building a neural network model using TensorFlow
model_np = Sequential([
    Dense(32, activation='relu', input_shape=(3,)),
    Dense(1)
])
model_np.compile(optimizer='adam', loss='mse')

# Training the model with NumPy data
start_time = time.time()
model_np.fit(X_train_np, y_train_np, epochs=50)
end_time = time.time()

print("Training time with NumPy data:", end_time - start_time)


Epoch 1/50
...
Epoch 50/50
Training time with NumPy data: 1.558666467666626


In [46]:
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from flax import linen as nn
from jax import grad, jit
from flax.training import train_state
import optax
import time
from sklearn.model_selection import train_test_split

# Load and preprocess meal data (the CSV written earlier; Colab's working directory is /content)
df = pd.read_csv("nutritional_data.csv")
X = df[['Protein(g)', 'Carbs(g)', 'Fat(g)']].to_numpy(dtype=np.float32)  # Extract features
y = df['Total Calories'].to_numpy(dtype=np.float32)  # Extract target variable

# Split with scikit-learn on NumPy arrays first (same seed as the TensorFlow cell),
# then convert to JAX arrays
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
X_train_jax, X_test_jax = jnp.array(X_train), jnp.array(X_test)
y_train_jax, y_test_jax = jnp.array(y_train), jnp.array(y_test)

# Define a simple neural network model using Flax
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

# Initialize the model
rng = jax.random.PRNGKey(0)
input_shape = (1, 3)  # Shape for model initialization
model = SimpleNN()
params = model.init(rng, jnp.ones(input_shape))['params']

# Define the optimizer
optimizer = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optimizer)

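# Define the loss function (mean squared error)
def loss_fn(params, inputs, targets):
    # The model outputs shape (batch, 1); squeeze to (batch,) so it lines up with
    # targets instead of broadcasting to a (batch, batch) matrix
    predictions = model.apply({'params': params}, inputs).squeeze(-1)
    return jnp.mean((predictions - targets) ** 2)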

# Training step
@jit
def train_step(state, inputs, targets):
    grads = grad(loss_fn)(state.params, inputs, targets)
    return state.apply_gradients(grads=grads)

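# Training loop with timing (50 full-batch steps; the first step includes JIT compilation)
start_time = time.time()
for epoch in range(50):
    state = train_step(state, X_train_jax, y_train_jax)
# JAX dispatches work asynchronously: wait for the final update before stopping the clock
jax.block_until_ready(state.params)
end_time = time.time()

print("JAX model trained.")
print("Training time:", end_time - start_time, "seconds")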


JAX model trained.
Training time: 1.3899955749511719 seconds


### Why the Difference Might Not Be Pronounced in This Case

- **Small dataset and simple model**: with 10 rows (7 used for training under the default 75/25 split) and a two-layer network, the task is far too small to show JAX's advantages on large-scale data and heavy computation.
- **JIT compilation overhead**: the JAX timing includes tracing and compiling `train_step` on the first call, which offsets some of the gain in a short, one-off run.
- **Different training setups**: Keras and the Flax/Optax loop differ in defaults (weight initialization, optimizer settings, batching, per-epoch logging), so the two timings do not measure identical work.
- **Hardware utilization**: a task this small cannot saturate a GPU or TPU, so accelerator speedups barely show.
- **Measurement noise**: each number comes from a single wall-clock run, so a gap of roughly 0.17 s (1.56 s vs. 1.39 s) may be within normal run-to-run variation.

### Conclusion

JAX was slightly faster in this test, but its real strengths show in workloads with large-scale data, heavy numerical computation, automatic differentiation, and parallelization or vectorization across batches and devices. For small datasets and simple tasks like this one, those advantages shrink, and the choice between JAX and NumPy/TensorFlow may depend more on familiarity, existing codebases, or specific library features.








Next walkthrough: using NumPy and JAX in a more complex calorie-counting use case.