In [1]:
import jax
import jax.numpy as jnp

In [14]:
# Derive one subkey per draw: sampling twice from the same PRNG key yields
# correlated (not independent) arrays in JAX.
rng = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(rng)
a = jax.random.uniform(key_a, (2, 3, 3))
b = jax.random.uniform(key_b, (2, 3, 4))

In [31]:
'''
einsum examples
'''
import jax.numpy as jnp

# Matrix multiplication: C[i, k] = sum_j A[i, j] * B[j, k]
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = jnp.einsum('ij,jk->ik', A, B)

# Element-wise multiplication of two vectors
A = jnp.array([1, 2, 3])
B = jnp.array([4, 5, 6])
C = jnp.einsum('i,i->i', A, B)

# Sum of all elements
A = jnp.array([[1, 2], [3, 4]])
sum_all = jnp.einsum('ij->', A)

# Sum along an axis. The index that is dropped from the output is the one
# being summed, so summing over axis 0 keeps 'j' and summing over axis 1
# keeps 'i' (matches A.sum(axis=0) / A.sum(axis=1)).
A = jnp.array([[1, 2], [3, 4]])
sum_axis0 = jnp.einsum('ij->j', A)
sum_axis1 = jnp.einsum('ij->i', A)

# Transpose
A = jnp.array([[1, 2], [3, 4]])
transpose = jnp.einsum('ij->ji', A)

# Trace (implicit mode: a repeated index with no '->' is summed)
A = jnp.array([[1, 2], [3, 4]])
trace = jnp.einsum('ii', A)

# Dot product
A = jnp.array([1, 2, 3])
B = jnp.array([4, 5, 6])
dot_product = jnp.einsum('i,i->', A, B)

# Batch matrix multiplication (leading index 'i' is the batch)
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
C = jnp.einsum('ijk,ikl->ijl', A, B)

# Outer product
A = jnp.array([1, 2, 3])
B = jnp.array([4, 5, 6])
outer_product = jnp.einsum('i,j->ij', A, B)

# Matrix-vector multiplication
A = jnp.array([[1, 2], [3, 4]])
v = jnp.array([5, 6])
C = jnp.einsum('ij,j->i', A, v)

# Diagonal of a matrix
A = jnp.array([[1, 2], [3, 4]])
diagonal = jnp.einsum('ii->i', A)

# Kronecker product. einsum yields the 4-D tensor T[i, k, j, l] = A[i, j] * B[k, l];
# flattening (i, k) into rows and (j, l) into columns gives the usual 2-D
# Kronecker matrix (same as jnp.kron(A, B)).
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
kron_product = jnp.einsum('ij,kl->ikjl', A, B).reshape(
    A.shape[0] * B.shape[0], A.shape[1] * B.shape[1]
)

# Batch-wise dot product over the last axis (implicit output keeps '...')
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
batch_dot_product = jnp.einsum('...i,...i', A, B)

# Batch-wise matrix multiplication with an arbitrary number of batch axes
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
batch_matrix_product = jnp.einsum('...ik,...kj->...ij', A, B)

# Matrix multiplication followed by a transpose: C = (A @ B).T
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = jnp.einsum('ij,jk->ki', A, B)

# Frobenius inner product of every matrix in a batch with one shared matrix:
# C[i] = sum_{j,k} A[i, j, k] * B[j, k]
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[9, 10], [11, 12]])
C = jnp.einsum('ijk,jk->i', A, B)

# Trace of a batch of matrices: the batch index is kept, the repeated matrix
# index is summed.
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
trace_batch = jnp.einsum('bii->b', A)

# Matrix multiplication and sum of all entries of the product
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
sum_product = jnp.einsum('ij,jk->', A, B)

# Element-wise multiplication and sum (same contraction as the dot product)
A = jnp.array([1, 2, 3])
B = jnp.array([4, 5, 6])
sum_product_elements = jnp.einsum('i,i->', A, B)

# Outer product and sum
A = jnp.array([1, 2, 3])
B = jnp.array([4, 5, 6])
sum_outer_product = jnp.einsum('i,j->', A, B)

# Hadamard (element-wise) product of two matrices.
# NOTE: einsum cannot concatenate arrays; use jnp.concatenate for that.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
C = jnp.einsum('ij,ij->ij', A, B)

# Matrix-vector multiplication and sum
A = jnp.array([[1, 2], [3, 4]])
v = jnp.array([5, 6])
sum_mv_product = jnp.einsum('ij,j->', A, v)

# Tensor contraction over the shared index 'k' (batched over 'i')
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
contracted_tensor = jnp.einsum('ijk,ikl->ijl', A, B)

# Element-wise multiplication and sum along axis 0 (index 'i' is summed)
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
sum_product_elements_axis0 = jnp.einsum('ij,ij->j', A, B)

# Diagonal of the matrix product: diag(A @ B)[i] = sum_j A[i, j] * B[j, i]
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
diagonal_product = jnp.einsum('ij,ji->i', A, B)

# Sum of all entries of the Kronecker product (equals sum(A) * sum(B))
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
sum_kron_product = jnp.einsum('ij,kl->', A, B)

# Row-wise dot product: element-wise product summed over index 'j' (axis 1).
# NOTE(review): despite the variable name, index 'i' (axis 0) is the one kept.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
sum_concatenation_axis0 = jnp.einsum('ij,ij->i', A, B)

# Batch-wise matrix multiplication, then sum of each product matrix
A = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = jnp.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])
sum_batch_matrix_product = jnp.einsum('...ik,...kj->...', A, B)


2.081306


In [32]:
'''
KNN in JAX
'''
import jax
import jax.numpy as jnp
from jax import jit, random

def euclidean_distance(x1, x2):
    """Return the L2 distance between ``x1`` and ``x2``.

    The squared differences are summed over every element, so the result
    is always a scalar.
    """
    delta = x1 - x2
    squared_norm = jnp.sum(delta ** 2)
    return jnp.sqrt(squared_norm)

def knn(train_X, train_y, test_X, k):
    """Classify ``test_X`` by majority vote among the ``k`` nearest training points.

    Args:
    - train_X: array whose first axis indexes training samples.
    - train_y: 1-D array of non-negative integer class labels, one per
      training sample.
    - test_X: array whose first axis indexes test samples; the remaining
      axes must match ``train_X``.
    - k: number of neighbours that vote (must be >= 1). Values larger than
      the training-set size use every training point.

    Returns:
    - 1-D integer array with one predicted label per test sample. Ties are
      resolved in favour of the smallest label.

    Raises:
    - ValueError: if ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    train_X = jnp.asarray(train_X)
    train_y = jnp.asarray(train_y)
    test_X = jnp.asarray(test_X)

    # Flatten the per-sample features so one broadcasted subtraction yields
    # every train/test pair, instead of one tiny JAX call per pair.
    train_flat = train_X.reshape(train_X.shape[0], -1)
    test_flat = test_X.reshape(test_X.shape[0], -1)
    diffs = train_flat[:, None, :] - test_flat[None, :, :]
    distances = jnp.sqrt(jnp.sum(diffs ** 2, axis=-1))  # (n_train, n_test)

    nearest_indices = jnp.argsort(distances, axis=0)[:k]
    nearest_labels = train_y[nearest_indices]  # (k, n_test)

    # Count votes per class for all test points at once.
    num_classes = int(jnp.max(train_y)) + 1
    votes = jnp.sum(
        nearest_labels[:, :, None] == jnp.arange(num_classes), axis=0
    )  # (n_test, num_classes)
    return jnp.argmax(votes, axis=1)

# Generate synthetic data.
# Each draw gets its own subkey: reusing one PRNG key for several samples
# makes the training and test sets statistically dependent.
rng_key = random.PRNGKey(0)
train_key, label_key, test_key = random.split(rng_key, 3)
train_X = random.normal(train_key, (100, 2))
train_y = random.randint(label_key, (100,), 0, 2)
test_X = random.normal(test_key, (10, 2))

# Perform KNN
k = 3
predicted_labels = knn(train_X, train_y, test_X, k)

print("Predicted Labels:", predicted_labels)


'\nKNN in JAX\n'

In [1]:
'''
Naive Bayes in JAX
'''
import jax
import jax.numpy as jnp
from jax import random

def gaussian_pdf(x, mean, std):
    """Evaluate the normal density with the given ``mean`` and ``std`` at ``x``.

    All arguments broadcast against each other element-wise.
    """
    standardized = (x - mean) / std
    normalizer = jnp.sqrt(2 * jnp.pi) * std
    return 1.0 / normalizer * jnp.exp(-0.5 * standardized ** 2)

def prior_probabilities(y):
    """Return the relative frequency of every distinct label in ``y``.

    Frequencies are ordered by the sorted unique label values.
    """
    _, label_counts = jnp.unique(y, return_counts=True)
    total = len(y)
    return label_counts / total

def fit(X, y, min_std=1e-9):
    """Estimate Gaussian Naive Bayes parameters from labelled data.

    Args:
    - X: array of shape (n_samples, n_features).
    - y: 1-D array of class labels, one per row of ``X``.
    - min_std: lower bound applied to the per-feature standard deviations so
      that a constant feature does not cause a division by zero later.

    Returns:
    - prior: array of shape (n_classes,) with the class frequencies.
    - mean: array of shape (n_classes, n_features) with per-class means.
    - std: array of shape (n_classes, n_features) with per-class standard
      deviations (clamped to ``min_std``).

    Classes are ordered by their sorted unique label values.

    Raises:
    - ValueError: if ``X`` contains no samples.
    """
    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("fit requires at least one sample")

    classes, counts = jnp.unique(y, return_counts=True)
    prior = counts / n_samples

    # JAX arrays are immutable, so the per-class statistics are collected in
    # lists and stacked instead of being written into a preallocated array.
    class_means = []
    class_stds = []
    for c in classes:
        X_c = X[y == c]
        class_means.append(jnp.mean(X_c, axis=0))
        class_stds.append(jnp.std(X_c, axis=0))

    mean = jnp.stack(class_means)
    std = jnp.maximum(jnp.stack(class_stds), min_std)
    return prior, mean, std

def predict(X, prior, mean, std):
    """Return the most probable class index for every row of ``X``.

    Args:
    - X: array of shape (n_samples, n_features).
    - prior: array of shape (n_classes,) with class prior probabilities.
    - mean: array of shape (n_classes, n_features) with per-class means.
    - std: array of shape (n_classes, n_features) with per-class standard
      deviations (must be positive).

    Returns:
    - 1-D integer array of length n_samples holding the arg-max class index.
    """
    X = jnp.asarray(X)
    prior = jnp.asarray(prior)
    mean = jnp.asarray(mean)
    std = jnp.asarray(std)

    # Work in log space: multiplying many small densities underflows to 0,
    # while summing log-densities keeps the class ranking intact.
    standardized = (X[:, None, :] - mean[None, :, :]) / std[None, :, :]
    log_density = (
        -0.5 * standardized ** 2
        - jnp.log(std)[None, :, :]
        - 0.5 * jnp.log(2 * jnp.pi)
    )
    log_posterior = jnp.sum(log_density, axis=-1) + jnp.log(prior)[None, :]

    return jnp.argmax(log_posterior, axis=1)

# Generate synthetic data
rng_key = random.PRNGKey(0)
data_key, split_key = random.split(rng_key)  # one subkey per random operation
X = random.normal(data_key, (100, 2))
y = jnp.concatenate([jnp.zeros(50), jnp.ones(50)]).astype(jnp.int32)

# Split the data into training and test sets.
# `train_test_split` was never imported in this file, so the shuffle and
# split are done directly with JAX.
test_size = 0.2
num_test = int(round(test_size * X.shape[0]))
shuffled_indices = random.permutation(split_key, X.shape[0])
test_indices = shuffled_indices[:num_test]
train_indices = shuffled_indices[num_test:]
X_train, X_test = X[train_indices], X[test_indices]
y_train, y_test = y[train_indices], y[test_indices]

# Train the model
prior, mean, std = fit(X_train, y_train)

# Make predictions
y_pred = predict(X_test, prior, mean, std)

# Calculate accuracy
accuracy = jnp.mean(y_pred == y_test)
print(f"Accuracy: {accuracy}")

ModuleNotFoundError: No module named 'jax'

In [44]:
'''
K-Means clustering in JAX
'''
import jax
import jax.numpy as jnp
from jax import jit

# Toy data set: five 2-D points
data = jnp.array([
    [1, 2],
    [2, 3],
    [8, 9],
    [9, 10],
    [5, 6],
])

# Number of clusters; the first k points serve as the starting centroids
k = 2
centroids = data[0:k]

# Distance from one point to every centroid
def distance(x, centroids):
    """Return the Euclidean distance from ``x`` to each row of ``centroids``."""
    offsets = x - centroids
    return jnp.linalg.norm(offsets, axis=1)

# Label every data point with the index of its closest centroid
def assign_clusters(data, centroids):
    """Return, for each row of ``data``, the index of the nearest centroid."""
    # Broadcasting (n, 1, d) - (k, d) gives all point/centroid offsets at once.
    offsets = data[:, None] - centroids
    point_to_centroid = jnp.linalg.norm(offsets, axis=2)
    return jnp.argmin(point_to_centroid, axis=1)

# Update centroids to the mean of assigned data points
def update_centroids(data, cluster_assignments, num_clusters=None):
    """Return the mean of the points assigned to each cluster.

    Args:
    - data: array of shape (n_points, n_features).
    - cluster_assignments: 1-D integer array with one cluster index per point.
    - num_clusters: number of clusters to produce. When omitted, the
      module-level ``k`` is used, which is what the original code did.

    Returns:
    - Array of shape (num_clusters, n_features). A cluster with no assigned
      points yields a row of NaN (mean of an empty set).
    """
    if num_clusters is None:
        num_clusters = k  # backward-compatible fallback to the global setting

    # One-hot membership matrix lets a single matmul compute every cluster sum.
    membership = (
        cluster_assignments[:, None] == jnp.arange(num_clusters)[None, :]
    ).astype(jnp.float32)
    cluster_sums = membership.T @ data
    cluster_sizes = jnp.sum(membership, axis=0)
    new_centroids = cluster_sums / cluster_sizes[:, None]
    return new_centroids

# Run Lloyd iterations (at most 100) until the centroids stop moving
for _ in range(100):
    cluster_assignments = assign_clusters(data, centroids)
    new_centroids = update_centroids(data, cluster_assignments)

    # Keep iterating only while at least one centroid coordinate changed
    if not jnp.all(centroids == new_centroids):
        centroids = new_centroids
        continue

    break

# Final cluster assignments
print("Cluster Assignments:", cluster_assignments)
print("Cluster Centroids:", centroids)


(5, 2) (2, 2)
(5, 2, 2)


In [37]:
'''
Linear regression in JAX
'''
import jax
import jax.numpy as jnp
from jax import grad

# Synthetic data: four samples with two identical features, target y = 2 * feature
X = jnp.array([[value, value] for value in (1, 2, 3, 4)])
y = 2 * jnp.arange(1, 5)

# Model parameters start at zero (one weight per feature)
params = jnp.zeros(2)

# Linear model without intercept
def linear_regression(params, x):
    """Return ``params · x``.

    With ``x`` laid out as (n_features, n_samples) this gives one prediction
    per sample.
    """
    weighted_sum = jnp.dot(params, x)
    return weighted_sum

# Mean squared error of the linear model
def mse(params, x, y):
    """Return the mean squared difference between model output and ``y``."""
    residuals = linear_regression(params, x) - y
    return jnp.mean(residuals ** 2)

# Plain gradient descent: 100 steps with a fixed learning rate
learning_rate = 0.01
for _ in range(100):
    # X is (n_samples, n_features); the model expects features first
    gradients = grad(mse)(params, X.T, y)
    params = params - learning_rate * gradients

# Predictions of the fitted model on the training inputs
predictions = linear_regression(params, X.T)
print(predictions)


[1.9999999 3.9999998 5.9999995 7.9999995]


In [38]:
'''
Logistic regression in JAX
'''
import jax
import jax.numpy as jnp
from jax import grad

# Synthetic data: four 2-feature samples with binary targets
X = jnp.array([
    [1, 2],
    [2, 3],
    [3, 4],
    [5, 6],
])
y = jnp.array([0, 0, 1, 1], dtype=jnp.float32)

# Model parameters start at zero (one weight per feature)
params = jnp.zeros(2)

# Logistic model without intercept
def logistic_regression(params, x):
    """Return ``sigmoid(params · x)``, i.e. a value in (0, 1) per sample."""
    logits = jnp.dot(params, x)
    denominator = 1.0 + jnp.exp(-logits)
    return 1.0 / denominator

# Define the cross-entropy loss
def cross_entropy(params, x, y):
    """Return the mean binary cross-entropy of the logistic model.

    Args:
    - params: weight vector of the model.
    - x: inputs laid out so that ``jnp.dot(params, x)`` gives one logit per
      sample.
    - y: targets in {0, 1}, one per sample.

    Returns:
    - Scalar mean cross-entropy.

    The loss is computed from the logits rather than from the sigmoid output:
    ``softplus(z) - y * z`` is algebraically equal to
    ``-(y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z)))`` but never
    evaluates ``log(0)``, which produced NaN once the sigmoid saturated.
    """
    logits = jnp.dot(params, x)
    softplus = jnp.logaddexp(0.0, logits)  # log(1 + exp(z)), overflow-safe
    return jnp.mean(softplus - y * logits)

# Plain gradient descent: 100 steps with a fixed learning rate
learning_rate = 0.01
for _ in range(100):
    # X is (n_samples, n_features); the model expects features first
    gradients = grad(cross_entropy)(params, X.T, y)
    params = params - learning_rate * gradients

# Predicted probabilities of the fitted model on the training inputs
predictions = logistic_regression(params, X.T)
print(predictions)


[0.5671718  0.6211104  0.67221195 0.7624377 ]


In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, random

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, applied element-wise."""
    denominator = 1.0 + jnp.exp(-x)
    return 1.0 / denominator

def neural_network(params, x):
    """Forward pass of a one-hidden-layer network with tanh activation.

    ``params`` unpacks to ``(w1, b1, w2, b2)``; the output layer is linear.
    """
    w1, b1, w2, b2 = params
    pre_activation = jnp.dot(x, w1) + b1
    hidden = jnp.tanh(pre_activation)
    return jnp.dot(hidden, w2) + b2

def loss(params, x, y):
    """Mean squared error between the network output and ``y``."""
    residual = neural_network(params, x) - y
    return jnp.mean(residual ** 2)

def init_params(rng_key, input_dim, hidden_dim, output_dim):
    """Draw standard-normal weights and biases for the two-layer network.

    Returns the tuple ``(w1, b1, w2, b2)`` with shapes
    ``(input_dim, hidden_dim)``, ``(hidden_dim,)``,
    ``(hidden_dim, output_dim)`` and ``(output_dim,)``.
    """
    subkeys = random.split(rng_key, 5)
    # subkeys[0] is the carry-over key; it is not used for sampling.
    shapes = (
        (input_dim, hidden_dim),
        (hidden_dim,),
        (hidden_dim, output_dim),
        (output_dim,),
    )
    return tuple(
        random.normal(subkeys[index + 1], shape)
        for index, shape in enumerate(shapes)
    )

def update(params, x, y, learning_rate):
    """Apply one gradient-descent step and return the new parameters as a list."""
    gradients = grad(loss)(params, x, y)
    stepped = []
    for value, slope in zip(params, gradients):
        stepped.append(value - learning_rate * slope)
    return stepped

# Generate synthetic data.
# The inputs, the noise and the parameters each get their own subkey. With a
# single shared key the "noise" was drawn with the same key and shape as X,
# so it was identical to X and the targets contained no noise at all.
rng_key = random.PRNGKey(0)
data_key, noise_key, param_key = random.split(rng_key, 3)
X = random.normal(data_key, (100, 1))
y = 3*X + 1 + random.normal(noise_key, (100, 1)) * 0.1  # y = 3x + 1 + noise

# Initialize parameters
input_dim = X.shape[1]
hidden_dim = 10
output_dim = y.shape[1]
params = init_params(param_key, input_dim, hidden_dim, output_dim)

# Train the model using gradient descent
num_epochs = 1000
learning_rate = 0.01
for epoch in range(num_epochs):
    params = update(params, X, y, learning_rate)

# Test the model
X_test = jnp.array([[1.0]])
y_pred = neural_network(params, X_test)
print(f"Predicted y for x=1.0: {y_pred[0][0]}")



In [2]:
'''
2D convolution in JAX
'''
import jax
import jax.numpy as jnp

def convolution2d(image, kernel):
    """
    Slide a kernel over a 2D image ('valid' mode, stride 1).

    The kernel is applied without flipping, i.e. this is the
    cross-correlation commonly called "convolution" in deep learning.

    Args:
    - image: 2D array representing the input image
    - kernel: 2D array representing the convolution kernel

    Returns:
    - result: 2D float array of shape
      (image_height - kernel_height + 1, image_width - kernel_width + 1)

    Raises:
    - ValueError: if the kernel is larger than the image along any axis
    """
    image = jnp.asarray(image)
    kernel = jnp.asarray(kernel)

    image_height, image_width = image.shape
    kernel_height, kernel_width = kernel.shape

    result_height = image_height - kernel_height + 1
    result_width = image_width - kernel_width + 1
    if result_height <= 0 or result_width <= 0:
        raise ValueError("kernel must not be larger than the image")

    result = jnp.zeros((result_height, result_width))

    # JAX arrays are immutable, so element-wise assignment is not possible.
    # Accumulate one shifted, weighted copy of the image per kernel tap
    # instead; this loops over the (small) kernel rather than every pixel.
    for di in range(kernel_height):
        for dj in range(kernel_width):
            window = image[di:di + result_height, dj:dj + result_width]
            result = result + kernel[di, dj] * window

    return result

# 5x5 sample image holding the integers 1..25 in row-major order
image = jnp.arange(1, 26).reshape(5, 5)

# 3x3 sample kernel: every row is [1, 0, -1]
kernel = jnp.array([[1, 0, -1]] * 3)

# Apply the kernel to the image
result = convolution2d(image, kernel)

print("Input Image:")
print(image)
print("\nConvolution Kernel:")
print(kernel)
print("\nConvolved Image:")
print(result)


ModuleNotFoundError: No module named 'jax'

In [3]:
'''
Adam Optimizer
'''
import jax.numpy as jnp
from jax import grad, jit, random

def adam(grad, init_params, step_size=0.001, b1=0.9, b2=0.999, eps=1e-8, num_iters=1000):
    """
    Minimise a function with the Adam update rule.

    Args:
    - grad: Callable returning the gradient of the loss at the given parameters.
    - init_params: Starting point of the optimisation (a single array).
    - step_size: Learning rate.
    - b1: Decay rate of the running mean of the gradient.
    - b2: Decay rate of the running mean of the squared gradient.
    - eps: Constant added to the denominator to avoid division by zero.
    - num_iters: Number of update steps to perform.

    Returns:
    - params: Parameters after ``num_iters`` updates.
    """
    params = init_params
    first_moment = jnp.zeros_like(init_params)
    second_moment = jnp.zeros_like(init_params)

    for step in range(1, num_iters + 1):
        g = grad(params)
        first_moment = b1 * first_moment + (1 - b1) * g
        second_moment = b2 * second_moment + (1 - b2) * (g ** 2)
        # Bias correction compensates for the zero-initialised moments.
        corrected_first = first_moment / (1 - b1 ** step)
        corrected_second = second_moment / (1 - b2 ** step)
        params = params - step_size * corrected_first / (jnp.sqrt(corrected_second) + eps)

    return params

# Try the optimizer on a one-dimensional quadratic
def f(x):
    """Quadratic with its minimum at ``x = 0``."""
    return x**2

init_params = jnp.array(10.0)
grad_f = grad(f)
optimized_params = adam(grad_f, init_params)

print(f"Optimized parameter value: {optimized_params}")


ModuleNotFoundError: No module named 'jax'

In [4]:
'''
LayerNorm in JAX
'''
import jax.numpy as jnp

def layer_norm(x, scale, bias, epsilon=1e-5):
    """
    Normalise each sample over its last axis, then scale and shift.

    Args:
    - x: Input array; statistics are taken over the last axis.
    - scale: Multiplicative parameter, broadcast against the normalised input.
    - bias: Additive parameter, broadcast against the normalised input.
    - epsilon: Constant added to the variance before the square root.

    Returns:
    - Array with the same shape as x.
    """
    feature_mean = jnp.mean(x, axis=-1, keepdims=True)
    feature_var = jnp.var(x, axis=-1, keepdims=True)
    centered = x - feature_mean
    normalized = centered / jnp.sqrt(feature_var + epsilon)
    return scale * normalized + bias

# Try layer normalisation on a 4x3 batch holding the values 1..12
batch_size = 4
features = 3
epsilon = 1e-5
scale = jnp.ones(features)
bias = jnp.zeros(features)
x = jnp.arange(1.0, 13.0).reshape(batch_size, features)

normalized_x = layer_norm(x, scale, bias, epsilon)
print("Input:")
print(x)
print("\nLayer Normalized Output:")
print(normalized_x)


ModuleNotFoundError: No module named 'jax'

In [5]:
'''
BatchNorm in JAX
'''
import jax.numpy as jnp
from jax import jit, grad, random

def batch_norm(x, scale, bias, epsilon=1e-5):
    """
    Normalise each feature over the batch axis (axis 0), then scale and shift.

    Args:
    - x: Input array; statistics are taken over axis 0.
    - scale: Multiplicative parameter, broadcast against the normalised input.
    - bias: Additive parameter, broadcast against the normalised input.
    - epsilon: Constant added to the variance before the square root.

    Returns:
    - Array with the same shape as x.
    """
    batch_mean = jnp.mean(x, axis=0, keepdims=True)
    batch_var = jnp.var(x, axis=0, keepdims=True)
    centered = x - batch_mean
    normalized = centered / jnp.sqrt(batch_var + epsilon)
    return scale * normalized + bias

# Try batch normalisation on a 4x3 batch holding the values 1..12
batch_size = 4
features = 3
epsilon = 1e-5
scale = jnp.ones(features)
bias = jnp.zeros(features)
x = jnp.arange(1.0, 13.0).reshape(batch_size, features)

normalized_x = batch_norm(x, scale, bias, epsilon)
print("Input:")
print(x)
print("\nBatch Normalized Output:")
print(normalized_x)


ModuleNotFoundError: No module named 'jax'

In [6]:
'''
Super resolution code in JAX
'''
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from jax.experimental import optimizers

def init_params(rng, layer_sizes):
    """Initialize one (weight, bias) pair per consecutive pair of layer sizes.

    Args:
    - rng: JAX PRNG key.
    - layer_sizes: sequence of layer widths; layer ``i`` gets a weight of
      shape ``(layer_sizes[i], layer_sizes[i + 1])`` and a zero bias of shape
      ``(layer_sizes[i + 1],)``.

    Returns:
    - List of ``(weight, bias)`` tuples. Weights are standard-normal samples
      scaled by ``1 / sqrt(layer_sizes[0])``.
    """
    scale = 1.0 / jnp.sqrt(layer_sizes[0])
    num_layers = len(layer_sizes) - 1
    # One subkey per layer: sampling every layer from the same key would give
    # all layers correlated (for equal shapes: identical) weights.
    layer_keys = random.split(rng, max(num_layers, 1))
    return [(random.normal(key, (m, n)) * scale, jnp.zeros(n))
            for key, m, n in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:])]

def relu(x):
    """ReLU activation function."""
    return jnp.maximum(0, x)

def upsampling_layer(params, x):
    """Upsample ``x`` by a factor of two with a strided transposed convolution.

    Args:
    - params: tuple ``(w, b)`` where ``w`` is a 4-D kernel in HWIO layout
      (kernel_height, kernel_width, in_channels, out_channels) and ``b`` is a
      bias of shape (out_channels,).
    - x: input batch in NHWC layout.

    Returns:
    - Output batch in NHWC layout. With a 2x2 kernel, stride 2 and 'VALID'
      padding the spatial size is exactly doubled.

    Raises:
    - ValueError: if ``w`` is not 4-dimensional.
    """
    w, b = params
    if w.ndim != 4:
        raise ValueError(
            "upsampling_layer expects a 4-D HWIO kernel, got shape %s" % (w.shape,)
        )
    # jax.numpy has no conv_transpose; the operation lives in jax.lax.
    upsampled = jax.lax.conv_transpose(
        x, w, strides=(2, 2), padding='VALID',
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )
    return upsampled + b

def super_resolution_model(params, x):
    """Super-resolution model architecture: one upsampling layer plus ReLU."""
    h1 = relu(upsampling_layer(params[0], x))
    return h1

# Generate synthetic low-resolution image
rng_key = random.PRNGKey(0)
image_key, param_key = random.split(rng_key)  # independent keys for data and weights
low_res_image = random.normal(image_key, (1, 16, 16, 3))

# Initialize model parameters.
# init_params produces a 2-D weight of shape (fan_in, out_channels); it is
# reshaped into the 2x2 HWIO kernel that the transposed convolution needs.
kernel_size = 2
channels = low_res_image.shape[-1]
layer_sizes = [kernel_size * kernel_size * channels, channels]
(flat_weight, bias_vector), = init_params(param_key, layer_sizes)
params = [(flat_weight.reshape(kernel_size, kernel_size, channels, channels), bias_vector)]

# Upsample the low-resolution image
upsampled_image = super_resolution_model(params, low_res_image)

print("Low-resolution image shape:", low_res_image.shape)
print("Upsampled image shape:", upsampled_image.shape)


ModuleNotFoundError: No module named 'jax'

In [7]:
'''
LSTM in JAX
'''

import jax
import jax.numpy as jnp
from jax import random

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, applied element-wise."""
    denominator = 1.0 + jnp.exp(-x)
    return 1.0 / denominator

def tanh(x):
    """Element-wise hyperbolic tangent (thin wrapper around ``jnp.tanh``)."""
    activated = jnp.tanh(x)
    return activated

def lstm_cell(prev_c, prev_h, x, params):
    """Run one LSTM step.

    Args:
    - prev_c: previous cell state, shape (batch, hidden_size).
    - prev_h: previous hidden state, shape (batch, hidden_size).
    - x: current input, shape (batch, input_size).
    - params: tuple ``(W_i, b_i, W_f, b_f, W_o, b_o, W_c, b_c)``. Every weight
      has shape (input_size + hidden_size, hidden_size) and acts on the
      concatenation ``[x, prev_h]``; every bias has shape (hidden_size,).

    Returns:
    - ``(c, h)``: the new cell state and hidden state, each of shape
      (batch, hidden_size).
    """
    W_i, b_i, W_f, b_f, W_o, b_o, W_c, b_c = params

    # The input and the previous hidden state have different widths, so they
    # cannot share an (input_size, hidden_size) matrix. Stacking them and
    # using one (input_size + hidden_size, hidden_size) matrix per gate is
    # equivalent to separate input and recurrent weights.
    xh = jnp.concatenate([x, prev_h], axis=-1)

    # Input gate
    i = jax.nn.sigmoid(jnp.dot(xh, W_i) + b_i)

    # Forget gate
    f = jax.nn.sigmoid(jnp.dot(xh, W_f) + b_f)

    # Output gate
    o = jax.nn.sigmoid(jnp.dot(xh, W_o) + b_o)

    # Candidate memory cell update
    c_tilde = jnp.tanh(jnp.dot(xh, W_c) + b_c)

    # Cell state update
    c = f * prev_c + i * c_tilde

    # Hidden state computation
    h = o * jnp.tanh(c)

    return c, h

# Initialize LSTM cell parameters
def init_lstm_params(rng_key, input_size, hidden_size):
    """Initialize LSTM cell parameters.

    Args:
    - rng_key: JAX PRNG key.
    - input_size: width of the input vector.
    - hidden_size: width of the hidden and cell states.

    Returns:
    - Tuple ``(W_i, b_i, W_f, b_f, W_o, b_o, W_c, b_c)``: standard-normal
      weights of shape (input_size + hidden_size, hidden_size) and zero
      biases of shape (hidden_size,), in the order ``lstm_cell`` unpacks.
    """
    # Only the four weight matrices are random, so four subkeys suffice.
    W_i_key, W_f_key, W_o_key, W_c_key = random.split(rng_key, 4)
    weight_shape = (input_size + hidden_size, hidden_size)
    W_i = random.normal(W_i_key, weight_shape)
    b_i = jnp.zeros(hidden_size)
    W_f = random.normal(W_f_key, weight_shape)
    b_f = jnp.zeros(hidden_size)
    W_o = random.normal(W_o_key, weight_shape)
    b_o = jnp.zeros(hidden_size)
    W_c = random.normal(W_c_key, weight_shape)
    b_c = jnp.zeros(hidden_size)
    return (W_i, b_i, W_f, b_f, W_o, b_o, W_c, b_c)

# Test the LSTM cell
input_size = 10
hidden_size = 5
rng_key = random.PRNGKey(0)
# Separate subkeys for the input sample and the parameters: a PRNG key
# should be consumed by exactly one random operation.
input_key, param_key = random.split(rng_key)
x = random.normal(input_key, (1, input_size))
prev_c = jnp.zeros((1, hidden_size))
prev_h = jnp.zeros((1, hidden_size))
params = init_lstm_params(param_key, input_size, hidden_size)
next_c, next_h = lstm_cell(prev_c, prev_h, x, params)

print("Input size:", input_size)
print("Hidden size:", hidden_size)
print("Input x shape:", x.shape)
print("Output c shape:", next_c.shape)
print("Output h shape:", next_h.shape)


ModuleNotFoundError: No module named 'jax'

In [ ]:
'''
Various distances
'''
import jax.numpy as jnp

def euclidean_distance(x1, x2):
    """Return the L2 distance between ``x1`` and ``x2`` (a scalar)."""
    delta = x1 - x2
    return jnp.sqrt(jnp.sum(delta ** 2))

def squared_euclidean_distance(x1, x2):
    """Return the sum of squared element-wise differences of ``x1`` and ``x2``."""
    delta = x1 - x2
    return jnp.sum(delta ** 2)

def manhattan_distance(x1, x2):
    """Return the L1 distance: the sum of absolute element-wise differences."""
    gaps = jnp.abs(x1 - x2)
    return jnp.sum(gaps)

def chebyshev_distance(x1, x2):
    """Return the L-infinity distance: the largest absolute element-wise difference."""
    gaps = jnp.abs(x1 - x2)
    return jnp.max(gaps)

def cosine_similarity(x1, x2):
    """Return the cosine of the angle between ``x1`` and ``x2``.

    A small constant in the denominator keeps the result finite when one of
    the vectors has zero length.
    """
    length_x1 = jnp.sqrt(jnp.sum(x1 ** 2))
    length_x2 = jnp.sqrt(jnp.sum(x2 ** 2))
    inner = jnp.dot(x1, x2)
    return inner / (length_x1 * length_x2 + 1e-8)

def minkowski_distance(x1, x2, p):
    """Return the order-``p`` Minkowski distance between ``x1`` and ``x2``."""
    powered_gaps = jnp.abs(x1 - x2) ** p
    return jnp.sum(powered_gaps) ** (1.0 / p)

def hamming_distance(x1, x2):
    """Return the number of positions at which ``x1`` and ``x2`` differ."""
    mismatches = x1 != x2
    return jnp.sum(mismatches)

def jaccard_distance(x1, x2):
    """Return the Jaccard distance ``1 - |x1 AND x2| / |x1 OR x2|``.

    Args:
    - x1, x2: membership indicators of equal shape. Any non-zero entry counts
      as "present", so boolean, integer and float arrays are all accepted.

    Returns:
    - Scalar in [0, 1]. Two empty sets give 1.0 because the small constant in
      the denominator turns 0/0 into 0.
    """
    # Convert to boolean masks first: the bitwise operators reject floats and
    # operate bit-by-bit on integers other than 0/1.
    mask1 = jnp.asarray(x1).astype(bool)
    mask2 = jnp.asarray(x2).astype(bool)
    intersection = jnp.sum(mask1 & mask2)
    union = jnp.sum(mask1 | mask2)
    return 1.0 - intersection / (union + 1e-8)

def dice_similarity(x1, x2):
    """Return ``1 - Dice coefficient`` of two membership indicators.

    NOTE: despite the function name, the returned value is a dissimilarity:
    0.0 for identical non-empty sets and 1.0 for disjoint sets.

    Args:
    - x1, x2: membership indicators of equal shape. Any non-zero entry counts
      as "present", so boolean, integer and float arrays are all accepted.

    Returns:
    - Scalar in [0, 1]. Two empty sets give 1.0 instead of NaN thanks to the
      small constant in the denominator (same guard as ``jaccard_distance``).
    """
    mask1 = jnp.asarray(x1).astype(bool)
    mask2 = jnp.asarray(x2).astype(bool)
    intersection = jnp.sum(mask1 & mask2)
    dice_coefficient = 2.0 * intersection / (jnp.sum(mask1) + jnp.sum(mask2) + 1e-8)
    return 1.0 - dice_coefficient

def kullback_leibler_divergence(p, q):
    """Return the Kullback-Leibler divergence ``sum(p * log(p / q))``.

    Args:
    - p, q: non-negative arrays of equal shape. For the result to be a
      divergence between distributions each should sum to 1; no
      normalisation is performed here.

    Returns:
    - Scalar divergence in nats. Entries with ``p == 0`` contribute 0 (the
      usual ``0 * log 0 = 0`` convention); an entry with ``p > 0`` and
      ``q == 0`` makes the result infinite.
    """
    p = jnp.asarray(p)
    q = jnp.asarray(q)
    support = p > 0
    # Substitute 1 outside the support so that neither log nor the division
    # ever sees 0/0, which previously produced NaN.
    safe_p = jnp.where(support, p, 1.0)
    safe_q = jnp.where(support, q, 1.0)
    terms = jnp.where(support, p * (jnp.log(safe_p) - jnp.log(safe_q)), 0.0)
    return jnp.sum(terms)

# Example usage
x1 = jnp.array([1, 2, 3])
x2 = jnp.array([4, 5, 6])

print("Euclidean Distance:", euclidean_distance(x1, x2))
print("Squared Euclidean Distance:", squared_euclidean_distance(x1, x2))
print("Manhattan Distance:", manhattan_distance(x1, x2))
print("Chebyshev Distance:", chebyshev_distance(x1, x2))
print("Cosine Similarity:", cosine_similarity(x1, x2))
print("Minkowski Distance (p=3):", minkowski_distance(x1, x2, 3))
print("Hamming Distance:", hamming_distance(x1, x2))
print("Jaccard Distance:", jaccard_distance(x1 > 2, x2 > 5))
print("Dice Similarity:", dice_similarity(x1 > 2, x2 > 5))
# KL divergence is defined between probability distributions, so the example
# vectors are normalised to sum to 1 before being compared.
print("Kullback-Leibler Divergence:",
      kullback_leibler_divergence(x1 / jnp.sum(x1), x2 / jnp.sum(x2)))


In [8]:
'''
Various evaluation metrics in JAX
'''
import jax.numpy as jnp

# Classification Metrics
def accuracy(y_true, y_pred):
    """Return the fraction of positions where ``y_pred`` equals ``y_true``."""
    hits = y_true == y_pred
    return jnp.mean(hits)

def precision(y_true, y_pred):
    """Return TP / (TP + FP) for the positive class (label 1).

    A small constant in the denominator gives 0 instead of NaN when nothing
    is predicted positive.
    """
    predicted_mask = y_pred == 1
    hits = jnp.sum((y_true == 1) & predicted_mask)
    return hits / (jnp.sum(predicted_mask) + 1e-8)

def recall(y_true, y_pred):
    """Return TP / (TP + FN) for the positive class (label 1).

    A small constant in the denominator gives 0 instead of NaN when there
    are no actual positives.
    """
    actual_mask = y_true == 1
    hits = jnp.sum(actual_mask & (y_pred == 1))
    return hits / (jnp.sum(actual_mask) + 1e-8)

def f1_score(y_true, y_pred):
    """Return the harmonic mean of ``precision`` and ``recall``."""
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    numerator = 2 * (p * r)
    return numerator / (p + r + 1e-8)

def confusion_matrix(y_true, y_pred):
    """Return the 2x2 binary confusion matrix laid out as ``[[TP, FP], [FN, TN]]``.

    Rows correspond to the predicted label (1, then 0) and columns to the
    true label (1, then 0).
    """
    true_pos, true_neg = y_true == 1, y_true == 0
    pred_pos, pred_neg = y_pred == 1, y_pred == 0
    tp = jnp.sum(true_pos & pred_pos)
    fp = jnp.sum(true_neg & pred_pos)
    fn = jnp.sum(true_pos & pred_neg)
    tn = jnp.sum(true_neg & pred_neg)
    return jnp.array([[tp, fp],
                      [fn, tn]])

# Regression Metrics
def mean_squared_error(y_true, y_pred):
    """Return the mean of the squared differences between targets and predictions."""
    errors = y_true - y_pred
    return jnp.mean(errors ** 2)

def mean_absolute_error(y_true, y_pred):
    """Return the mean of the absolute differences between targets and predictions."""
    errors = y_true - y_pred
    return jnp.mean(jnp.abs(errors))

def r_squared(y_true, y_pred):
    """Return the coefficient of determination ``1 - SS_res / SS_tot``.

    A small constant in the denominator keeps the result finite when
    ``y_true`` is constant.
    """
    residual_ss = jnp.sum((y_true - y_pred) ** 2)
    centered = y_true - jnp.mean(y_true)
    total_ss = jnp.sum(centered ** 2)
    return 1 - (residual_ss / (total_ss + 1e-8))

# NLP Metrics
def bleu_score(reference_corpus, candidate_corpus):
    """Placeholder for BLEU; always raises ``NotImplementedError``."""
    message = "BLEU score calculation is not implemented in JAX."
    raise NotImplementedError(message)

def rouge_score(reference_corpus, candidate_corpus):
    """Placeholder for ROUGE; always raises ``NotImplementedError``."""
    message = "ROUGE score calculation is not implemented in JAX."
    raise NotImplementedError(message)

# Other Metrics
def mean_iou(y_true, y_pred):
    """Return the intersection-over-union of two binary masks.

    NOTE: despite the name, a single IoU over all elements is returned; no
    averaging over classes takes place.

    Args:
    - y_true, y_pred: masks of equal shape. Any non-zero entry counts as
      foreground, so boolean, integer and float arrays are all accepted.

    Returns:
    - Scalar in [0, 1]; 0 when both masks are empty (guarded denominator).
    """
    # Bitwise & and | reject floats and work bit-by-bit on integers other
    # than 0/1, so the inputs are converted to boolean masks first.
    true_mask = jnp.asarray(y_true).astype(bool)
    pred_mask = jnp.asarray(y_pred).astype(bool)
    intersection = jnp.sum(true_mask & pred_mask)
    union = jnp.sum(true_mask | pred_mask)
    return intersection / (union + 1e-8)

def dice_coefficient(y_true, y_pred):
    """Return the Dice coefficient ``2 * |A AND B| / (|A| + |B|)`` of two masks.

    Args:
    - y_true, y_pred: masks of equal shape. Any non-zero entry counts as
      foreground, so boolean, integer and float arrays are all accepted.

    Returns:
    - Scalar in [0, 1]; 0 when both masks are empty (guarded denominator).
    """
    # Bitwise & rejects floats and works bit-by-bit on integers other than
    # 0/1, so the inputs are converted to boolean masks first.
    true_mask = jnp.asarray(y_true).astype(bool)
    pred_mask = jnp.asarray(y_pred).astype(bool)
    intersection = jnp.sum(true_mask & pred_mask)
    return 2 * intersection / (jnp.sum(true_mask) + jnp.sum(pred_mask) + 1e-8)

# Example usage: evaluate every implemented metric on one small label pair
y_true = jnp.array([1, 0, 1, 1, 0])
y_pred = jnp.array([1, 1, 0, 1, 0])

# (label, metric) pairs in the order in which they are reported
for label, metric in (
    ("Accuracy:", accuracy),
    ("Precision:", precision),
    ("Recall:", recall),
    ("F1 Score:", f1_score),
    ("Confusion Matrix:\n", confusion_matrix),
    ("Mean Squared Error:", mean_squared_error),
    ("Mean Absolute Error:", mean_absolute_error),
    ("R Squared:", r_squared),
    ("Mean IoU:", mean_iou),
    ("Dice Coefficient:", dice_coefficient),
):
    print(label, metric(y_true, y_pred))


ModuleNotFoundError: No module named 'jax'

In [ ]:
'''
mAP in JAX
'''
import jax.numpy as jnp
from jax import jit

def calculate_iou(box1, box2):
    """Calculate IoU (Intersection over Union) between two bounding boxes.

    Each box is ``(x_top_left, y_top_left, x_bottom_right, y_bottom_right)``.
    A small constant in the denominator gives 0 for two zero-area boxes.
    """
    x1_tl, y1_tl, x1_br, y1_br = box1
    x2_tl, y2_tl, x2_br, y2_br = box2

    intersection_width = jnp.maximum(0, jnp.minimum(x1_br, x2_br) - jnp.maximum(x1_tl, x2_tl))
    intersection_height = jnp.maximum(0, jnp.minimum(y1_br, y2_br) - jnp.maximum(y1_tl, y2_tl))
    intersection_area = intersection_width * intersection_height

    box1_area = (x1_br - x1_tl) * (y1_br - y1_tl)
    box2_area = (x2_br - x2_tl) * (y2_br - y2_tl)

    union_area = box1_area + box2_area - intersection_area

    return intersection_area / (union_area + 1e-8)

def calculate_precision_recall(gt_boxes, pred_boxes, iou_threshold=0.5):
    """Calculate cumulative precision and recall for a given IoU threshold.

    Predictions are processed in the order given (NOTE(review): presumably
    sorted by descending confidence - no scores are passed in, confirm with
    the caller). Each prediction is matched to the ground-truth box with the
    highest IoU; it is a true positive if that IoU reaches the threshold and
    the ground-truth box has not been claimed by an earlier prediction,
    otherwise a false positive.

    Args:
    - gt_boxes: sequence of ground-truth boxes.
    - pred_boxes: sequence of predicted boxes.
    - iou_threshold: minimum IoU for a match.

    Returns:
    - ``(precision, recall)``: two 1-D arrays with one entry per prediction,
      holding the running precision and recall after that prediction.
    """
    num_gt_boxes = len(gt_boxes)
    claimed_gt = set()  # indices of ground-truth boxes that are already matched
    tp_flags = []

    # Matching is inherently sequential and JAX arrays are immutable, so the
    # flags are collected in a Python list and converted once at the end.
    for pred_box in pred_boxes:
        is_true_positive = False
        if num_gt_boxes:
            ious = jnp.stack([calculate_iou(pred_box, gt_box) for gt_box in gt_boxes])
            best_gt = int(jnp.argmax(ious))
            if float(ious[best_gt]) >= iou_threshold and best_gt not in claimed_gt:
                claimed_gt.add(best_gt)
                is_true_positive = True
        tp_flags.append(1.0 if is_true_positive else 0.0)

    tp = jnp.array(tp_flags, dtype=jnp.float32)
    fp = 1.0 - tp

    tp_cumsum = jnp.cumsum(tp)
    fp_cumsum = jnp.cumsum(fp)

    precision = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-8)
    # Recall is measured against the total number of ground-truth boxes.
    recall = tp_cumsum / (num_gt_boxes + 1e-8)

    return precision, recall

def calculate_average_precision(gt_boxes, pred_boxes, iou_threshold=0.5):
    """Calculate average precision (AP) for a given IoU threshold.

    AP is the area under the interpolated precision-recall curve
    (all-point interpolation): the curve is extended to recall 0 and 1 and
    precision is replaced by its running maximum from the right, so a single
    perfect detection scores 1.0 rather than 0.

    Returns:
    - Scalar AP in [0, 1]; 0.0 when there are no predictions.
    """
    from jax import lax  # local import: only needed for the running maximum

    precision, recall = calculate_precision_recall(gt_boxes, pred_boxes, iou_threshold)
    if precision.shape[0] == 0:
        return jnp.array(0.0)

    # Sentinel points at both ends of the curve.
    padded_recall = jnp.concatenate([jnp.array([0.0]), recall, jnp.array([1.0])])
    padded_precision = jnp.concatenate([jnp.array([0.0]), precision, jnp.array([0.0])])

    # Precision envelope: for each recall level, the best precision achieved
    # at that recall or any higher one.
    envelope = lax.cummax(padded_precision, axis=0, reverse=True)

    recall_steps = padded_recall[1:] - padded_recall[:-1]
    ap = jnp.sum(recall_steps * envelope[1:])

    return ap

def calculate_mAP(gt_boxes_list, pred_boxes_list, iou_threshold=0.5):
    """Calculate mean Average Precision (mAP) across all classes.

    Args:
    - gt_boxes_list: one sequence of ground-truth boxes per class.
    - pred_boxes_list: one sequence of predicted boxes per class, aligned
      with ``gt_boxes_list``.
    - iou_threshold: minimum IoU for a match.

    Returns:
    - Scalar mean of the per-class average precisions.

    Raises:
    - ValueError: if no classes are given.
    """
    num_classes = len(gt_boxes_list)
    if num_classes == 0:
        raise ValueError("calculate_mAP requires at least one class")

    aps = [
        calculate_average_precision(gt_boxes_list[i], pred_boxes_list[i], iou_threshold)
        for i in range(num_classes)
    ]

    # jnp.mean does not accept Python lists, so stack the scalars first.
    mAP = jnp.mean(jnp.stack(aps))
    return mAP

# Example usage: three classes whose predictions coincide with the ground truth
gt_boxes_list = [
    [(0, 0, 1, 1), (0.5, 0.5, 1.5, 1.5)],
    [(1, 1, 2, 2)],
    [(0, 0, 1, 1)],
]
pred_boxes_list = [list(class_boxes) for class_boxes in gt_boxes_list]

mAP = calculate_mAP(gt_boxes_list, pred_boxes_list, iou_threshold=0.5)
print("mAP:", mAP)
