In [1]:
import jax
import jax.numpy as jnp

In [14]:
# JAX's PRNG is stateless: passing the same key to two sampling calls gives
# dependent (not independent) draws, so split the root key into one subkey per array.
rng = jax.random.PRNGKey(0)
key_a, key_b = jax.random.split(rng)
a = jax.random.uniform(key_a, (2,3,3))
b = jax.random.uniform(key_b, (2,3,4))

In [15]:
# Batched matrix multiplication: for every batch index n, c[n] = a[n] @ b[n].
# Index `j` is contracted; the leading `b` subscript is kept as a batch axis.
c = jnp.einsum('bij,bjk->bik', a, b)
print(*(arr.shape for arr in (a, b, c)))

(2, 3, 3) (2, 3, 4) (2, 3, 4)


In [31]:
# Trace via einsum: a repeated index (`ii`) with no output subscripts sums the
# diagonal, so this equals jnp.trace(a). (It is not a dot product.)
# NOTE(review): `rng` was created and already consumed in an earlier cell; JAX keys
# should not be reused across draws -- split off a fresh key if independence matters.
a = jax.random.uniform(rng, (4,4))
b = jnp.einsum('ii', a)
assert jnp.allclose(b, jnp.trace(a))  # sanity check of the einsum identity
print(b)

2.081306


In [32]:
# TODO: placeholder section heading -- no KNN code appears in this notebook yet.
'\nKNN in JAX\n'

'\nKNN in JAX\n'

In [33]:
# TODO: placeholder section heading -- no PCA code appears in this notebook yet.
'\nPCA in JAX\n'

'\nPCA in JAX\n'

In [44]:
import jax
import jax.numpy as jnp
from jax import jit

# K-Means setup: number of clusters to find.
k = 2

# Five integer 2-D points: two near (1, 2), two near (8, 9), and (5, 6) in between.
data = jnp.array([(1, 2), (2, 3), (8, 9), (9, 10), (5, 6)])

# Deterministic initialization: the first k points serve as the starting centroids.
centroids = data[:k, :]

def distance(x, centroids):
    """Euclidean distance from a single point ``x`` to every row of ``centroids``.

    Presumably ``x`` is a (d,) point and ``centroids`` a (k, d) array, giving a
    (k,) result. Not called elsewhere in this cell: ``assign_clusters`` computes
    the same quantity inline for all points at once.
    """
    offsets = x - centroids
    return jnp.linalg.norm(offsets, axis=1)

def assign_clusters(data, centroids):
    """Return, for each point, the index of its nearest centroid (Euclidean).

    Assumes ``data`` is (n, d) and ``centroids`` is (k, d); the result is an
    (n,) integer array of indices into ``centroids``.
    """
    # (n, 1, d) - (k, d) broadcasts to (n, k, d): every point minus every centroid.
    pairwise_offsets = data[:, None] - centroids
    pairwise_dist = jnp.linalg.norm(pairwise_offsets, axis=2)  # (n, k)
    return jnp.argmin(pairwise_dist, axis=1)

def update_centroids(data, cluster_assignments, num_clusters=None, previous_centroids=None):
    """Move each centroid to the mean of the points currently assigned to it.

    Args:
        data: points, assumed (n, d) as in this notebook.
        cluster_assignments: (n,) integer cluster index per point, e.g. the
            output of ``assign_clusters``.
        num_clusters: how many centroids to compute. Defaults to the
            notebook-level ``k`` so existing two-argument calls are unchanged.
        previous_centroids: optional (num_clusters, d) array. When given, a
            cluster that received no points keeps its previous centroid;
            otherwise its row is the mean of an empty selection (NaN).

    Returns:
        (num_clusters, d) array of updated centroids.
    """
    if num_clusters is None:
        num_clusters = k  # notebook-level cluster count (original behaviour)

    new_centroids = []
    for i in range(num_clusters):
        members = data[cluster_assignments == i]
        if members.shape[0] == 0 and previous_centroids is not None:
            # Empty cluster: averaging nothing yields NaN, so keep the old position.
            new_centroids.append(previous_centroids[i])
        else:
            new_centroids.append(jnp.mean(members, axis=0))
    return jnp.array(new_centroids)

# Perform K-Means clustering (Lloyd's algorithm): alternate the assignment and
# update steps until the centroids stop moving or the iteration budget runs out.
max_iters = 100
for _ in range(max_iters):
    cluster_assignments = assign_clusters(data, centroids)
    new_centroids = update_centroids(data, cluster_assignments)

    # A cluster that received no points has a NaN mean; keep its previous
    # centroid instead. Otherwise NaN propagates into later distances and the
    # equality check below can never succeed (NaN != NaN).
    new_centroids = jnp.where(jnp.isnan(new_centroids), centroids, new_centroids)

    # Converged once an update leaves every centroid unchanged.
    if jnp.all(centroids == new_centroids):
        break

    centroids = new_centroids

# Final cluster assignments: recompute against the final centroids so they stay
# consistent even if the loop exhausted max_iters without converging.
cluster_assignments = assign_clusters(data, centroids)
print("Cluster Assignments:", cluster_assignments)
print("Cluster Centroids:", centroids)


(5, 2) (2, 2)
(5, 2, 2)


In [37]:
'''
Linear regression in JAX
'''
import jax
import jax.numpy as jnp
from jax import grad

# Synthetic data: four samples whose two feature columns are identical (1..4),
# with targets equal to twice the feature value.
X = jnp.array([[i, i] for i in range(1, 5)])
y = 2 * jnp.arange(1, 5)

# Model parameters start at zero: one weight per feature.
params = jnp.zeros(2)

def linear_regression(params, x):
    """Linear model without a bias term: the weighted sum ``params · x``.

    Called in this notebook with ``x = X.T`` (features along axis 0, samples
    along axis 1), which yields one prediction per sample.
    """
    weighted_sum = jnp.dot(params, x)
    return weighted_sum

def mse(params, x, y):
    """Mean squared error of ``linear_regression(params, x)`` against targets ``y``."""
    squared_error = (linear_regression(params, x) - y) ** 2
    return squared_error.mean()

# Plain gradient descent on the MSE loss; X.T puts features along axis 0,
# matching how linear_regression multiplies params against x.
learning_rate = 0.01
mse_grad = grad(mse)  # differentiate once; reuse the gradient function every step
for _ in range(100):
    gradients = mse_grad(params, X.T, y)
    params = params - learning_rate * gradients

# Predictions on the training inputs with the fitted parameters.
predictions = linear_regression(params, X.T)
print(predictions)


[1.9999999 3.9999998 5.9999995 7.9999995]


In [38]:
'''
Logistic regression in JAX
'''
import jax
import jax.numpy as jnp
from jax import grad

# Synthetic data: each sample has features (v, v + 1); the first two samples
# are labelled 0 and the last two 1.
X = jnp.array([(v, v + 1) for v in (1, 2, 3, 5)])
y = jnp.asarray([0.0, 0.0, 1.0, 1.0], dtype=jnp.float32)

# Model parameters start at zero: one weight per feature, no bias.
params = jnp.zeros(2)

def logistic_regression(params, x):
    """Logistic-regression probabilities ``sigmoid(params · x)`` (no bias term).

    Called here with ``x = X.T`` (features along axis 0), giving one
    probability per sample.

    Uses ``jax.nn.sigmoid`` rather than ``1 / (1 + exp(-z))``. For very
    negative ``z``, the hand-written form overflows ``exp(-z)`` to ``inf``,
    and its gradient under ``jax.grad`` becomes ``0 * inf = NaN``.
    """
    z = jnp.dot(params, x)
    return jax.nn.sigmoid(z)

def cross_entropy(params, x, y):
    """Mean binary cross-entropy of the logistic model against 0/1 targets ``y``.

    Computed from the logits with ``jax.nn.log_sigmoid`` instead of taking the
    ``log`` of probabilities. Once a probability rounds to exactly 0 or 1,
    ``log(p)`` or ``log(1 - p)`` is ``-inf``, and the loss and gradients become
    ``inf``/``NaN``. Uses the identity log(1 - sigmoid(z)) = log_sigmoid(-z).
    """
    # Same linear score that logistic_regression passes through its sigmoid.
    z = jnp.dot(params, x)
    log_p = jax.nn.log_sigmoid(z)
    log_one_minus_p = jax.nn.log_sigmoid(-z)
    return -jnp.mean(y * log_p + (1.0 - y) * log_one_minus_p)

# Plain gradient descent on the cross-entropy loss; X.T puts features along
# axis 0, matching how the model multiplies params against x.
learning_rate = 0.01
cross_entropy_grad = grad(cross_entropy)  # differentiate once; reuse every step
for _ in range(100):
    gradients = cross_entropy_grad(params, X.T, y)
    params = params - learning_rate * gradients

# Predicted class-1 probabilities for the training samples.
predictions = logistic_regression(params, X.T)
print(predictions)


[0.5671718  0.6211104  0.67221195 0.7624377 ]
