In [1]:
import jax.numpy as jnp
import jax
from jax import lax
from jax import Array

In [2]:
# Linear Regression Using the Normal Equation
def linear_regression_normal_equation(X: Array, y: Array) -> Array:
    """Fit ordinary least squares via the normal equations.

    Finds ``theta`` satisfying ``(X.T @ X) @ theta = X.T @ y``.

    Args:
        X: Design matrix with one row per sample. No intercept column is
            added here; callers include a column of ones if they want one.
        y: Targets with one entry per row of ``X``.

    Returns:
        The coefficient array ``theta`` (same trailing shape as ``X.T @ y``).
    """
    gram = X.T @ X
    moment = X.T @ y
    # Solve the linear system directly rather than forming the explicit
    # inverse: it is cheaper and loses less precision.
    return jnp.linalg.solve(gram, moment)
# Test the function
X = jnp.array([[1, 1], [1, 2], [1, 3]])
y = jnp.array([1, 2, 3])
linear_regression_normal_equation(X, y)

Array([9.536743e-07, 1.000000e+00], dtype=float32)

In [9]:
# Linear Regression Using Gradient Descent
def linear_regression_gradient_descent(X: Array, y: Array, alpha: float = 0.1, n_iterations: int = 10000, seed: int = 0) -> Array:
    """Fit linear regression with batch gradient descent.

    Each iteration computes the gradient ``X.T @ (X @ theta - y) / n_samples``
    and moves ``theta`` by ``-alpha`` times that gradient.

    Args:
        X: Design matrix of shape (n_samples, n_features).
        y: Targets; reshaped internally to a column of shape (n_samples, 1).
        alpha: Learning rate.
        n_iterations: Number of gradient steps to take.
        seed: Seed for the random initialisation of ``theta``.

    Returns:
        Coefficients as a column array of shape (n_features, 1).
    """
    # Initialize theta at random
    key = jax.random.key(seed)
    theta_init = jax.random.normal(key, (X.shape[1], 1))

    # Cast once up front so the loop carry keeps a single floating dtype
    # (integer inputs would otherwise be promoted on every iteration).
    features = jnp.asarray(X, dtype=theta_init.dtype)
    targets = jnp.asarray(y, dtype=theta_init.dtype).reshape(-1, 1)
    n_samples = X.shape[0]

    def step(_, theta):
        error = features @ theta - targets
        gradient = features.T @ error / n_samples
        return theta - alpha * gradient

    # A single compiled loop instead of n_iterations Python-level dispatches.
    return lax.fori_loop(0, n_iterations, step, theta_init)
# Test the function
X = jnp.array([[1, 1], [1, 2], [1, 3]])
y = jnp.array([1, 2, 3])
linear_regression_gradient_descent(X, y)

Array([[-1.4583919e-06],
       [ 1.0000007e+00]], dtype=float32)

In [15]:
# Feature Scaling Implementation 
def feature_scaling(X: Array) -> tuple[Array, Array]:
    """Scale each column of ``X`` in two ways.

    Args:
        X: Array whose first axis indexes samples; statistics are taken
            along ``axis=0``.

    Returns:
        A pair ``(standardized_data, normalized_data)``:
        ``standardized_data`` is ``(X - mean) / std`` per column, and
        ``normalized_data`` is ``(X - min) / (max - min)`` per column.
        A column whose std (or range) is exactly zero maps to zeros instead
        of NaN.
    """
    # Standardization
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # A constant column has zero spread; divide by 1 so it becomes 0, not NaN.
    safe_std = jnp.where(std == 0, 1, std)
    standardized_data = (X - mean) / safe_std

    # Min-Max Normalization
    min_value = X.min(axis=0)
    value_range = X.max(axis=0) - min_value
    safe_range = jnp.where(value_range == 0, 1, value_range)
    normalized_data = (X - min_value) / safe_range

    return standardized_data, normalized_data
    
# Test the function
X = jnp.array([[1, 2], [3, 4], [5, 6]])
feature_scaling(X)

[1 2]


(Array([[-1.2247448, -1.2247448],
        [ 0.       ,  0.       ],
        [ 1.2247448,  1.2247448]], dtype=float32),
 Array([[0. , 0. ],
        [0.5, 0.5],
        [1. , 1. ]], dtype=float32))

In [18]:
# Random Shuffle of Dataset
def shuffle_data(X: Array, y: Array, seed: int = 42) -> tuple[Array, Array]:
    """Shuffle ``X`` and ``y`` along their first axis with one shared permutation.

    Args:
        X: Array whose first axis indexes samples.
        y: Array with the same first-axis length as ``X``.
        seed: Seed for the permutation; the same seed gives the same order.

    Returns:
        ``(X[perm], y[perm])`` for a random permutation ``perm`` of the rows.

    Raises:
        ValueError: If ``X`` and ``y`` differ in first-axis length.
    """
    n_samples = X.shape[0]
    # JAX does not raise on out-of-range gather indices, so a length mismatch
    # would otherwise produce silently wrong pairs.
    if y.shape[0] != n_samples:
        raise ValueError(
            f"X and y must have the same number of samples, got {n_samples} and {y.shape[0]}"
        )
    key = jax.random.PRNGKey(seed)
    perm = jax.random.permutation(key, n_samples)
    return X[perm], y[perm]
# Test the function
X = jnp.array([[1, 2], [3, 4], [5, 6]])
y = jnp.array([1, 2, 3])
shuffle_data(X, y)

(Array([[5, 6],
        [1, 2],
        [3, 4]], dtype=int32),
 Array([3, 1, 2], dtype=int32))

In [21]:
# Batch Iterator for Dataset
def batch_iterator(X: Array, y: Array, batch_size: jnp.int32 = 32) -> [Array, Array]:
    """Yield consecutive ``(X_batch, y_batch)`` slices of at most ``batch_size`` rows.

    Batches are taken in order along the first axis; the final batch is
    shorter when the row count of ``X`` is not a multiple of ``batch_size``.
    """
    total_rows = X.shape[0]
    for start in range(0, total_rows, batch_size):
        window = slice(start, start + batch_size)
        yield X[window], y[window]
        
# Test the function
X = jnp.array([[1, 2],
               [3, 4],
               [5, 6],
               [7, 8],
               [9, 10]])
y = jnp.array([1, 2, 3, 4, 5])
batch_gen = batch_iterator(X, y, 2)

for batch_X, batch_y in batch_gen:
    print("Batch X:\n", batch_X)
    print("Batch y:\n", batch_y)

Batch X:
 [[1 2]
 [3 4]]
Batch y:
 [1 2]
Batch X:
 [[5 6]
 [7 8]]
Batch y:
 [3 4]
Batch X:
 [[ 9 10]]
Batch y:
 [5]


In [25]:
# One-Hot Encoding of Nominal Values
def to_categorical(X: Array, n_classes: int) -> Array:
    """One-hot encode integer class labels.

    Args:
        X: Integer label array; it is flattened, so the result has one row
            per element.
        n_classes: Number of columns in the encoding.

    Returns:
        Floating-point array of shape ``(X.size, n_classes)`` where row ``i``
        has a 1 in column ``X[i]`` and 0 elsewhere.
    """
    labels = X.reshape(-1)
    rows = jnp.arange(labels.size)
    # One vectorized scatter instead of a Python loop that creates a new
    # array for every label.
    return jnp.zeros((labels.size, n_classes)).at[rows, labels].set(1)
# Test the function
X = jnp.array([0, 1, 2, 1, 0])
to_categorical(X, 3)

Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.],
       [0., 1., 0.],
       [1., 0., 0.]], dtype=float32)

In [27]:
# Calculate Accuracy Score
def accuracy_score(y_true: Array, y_pred: Array) -> jnp.float32:
    """Return the fraction of positions where ``y_true`` equals ``y_pred``."""
    is_correct = y_true == y_pred
    return jnp.mean(is_correct)
# Test the function
y_true = jnp.array([1, 0, 1, 1, 0, 1])
y_pred = jnp.array([1, 0, 0, 1, 0, 1])
accuracy_score(y_true, y_pred)

Array(0.8333334, dtype=float32)

In [29]:
# Implement Ridge Regression Loss Function
def ridge_loss(X: Array, w: Array, y_true: Array, alpha: jnp.float32) -> jnp.float32:
    """Mean squared error of ``X @ w`` against ``y_true`` plus an L2 penalty.

    The penalty is ``alpha`` times the sum of squared entries of ``w``.
    """
    residuals = X @ w - y_true
    data_term = jnp.mean(residuals ** 2)
    penalty = alpha * jnp.sum(w ** 2)
    return data_term + penalty
# Test the function
X = jnp.array([[1, 1], [2, 1], [3, 1], [4, 1]])
w = jnp.array([0.2, 2])
y_true = jnp.array([2, 3, 4, 5])
alpha = 0.1
ridge_loss(X, w, y_true, alpha)

Array(2.204, dtype=float32)

In [30]:
# Linear Kernel Function
def kernel_function(X1: Array, X2: Array) -> jnp.float32:
    """Linear kernel: the inner product of ``X1`` and ``X2`` via ``jnp.inner``."""
    similarity = jnp.inner(X1, X2)
    return similarity
# Test the function
X1 = jnp.array([1, 2, 3])
X2 = jnp.array([4, 5, 6])
kernel_function(X1, X2)

Array(32, dtype=int32)

In [31]:
# Implement Precision Metric
def precision(y_true: Array, y_pred: Array) -> jnp.float32:
    """Precision for binary labels: TP / (TP + FP), with 1 as the positive class.

    Args:
        y_true: Ground-truth labels; entries equal to 1 are positives and
            entries equal to 0 are negatives.
        y_pred: Predicted labels, compared against 1.

    Returns:
        The precision, or 0.0 when there are no counted positive predictions
        (instead of the NaN produced by 0/0).
    """
    true_positives = jnp.sum((y_true == 1) & (y_pred == 1))
    false_positives = jnp.sum((y_true == 0) & (y_pred == 1))
    predicted_positives = true_positives + false_positives
    # Divide by a non-zero stand-in so the discarded branch never evaluates 0/0.
    safe_denominator = jnp.where(predicted_positives == 0, 1, predicted_positives)
    return jnp.where(predicted_positives == 0, 0.0, true_positives / safe_denominator)
# Test the function
y_true = jnp.array([1, 0, 1, 1, 0, 1])
y_pred = jnp.array([1, 0, 1, 0, 0, 1])
precision(y_true, y_pred)

Array(1., dtype=float32)

In [32]:
# Implement Recall Metric in Binary Classification
def recall(y_true: Array, y_pred: Array) -> jnp.float32:
    """Recall for binary labels: TP / (TP + FN), with 1 as the positive class.

    Args:
        y_true: Ground-truth labels, compared against 1.
        y_pred: Predicted labels; entries equal to 1 are positive
            predictions and entries equal to 0 are negative predictions.

    Returns:
        The recall, or 0.0 when there are no counted actual positives
        (instead of the NaN produced by 0/0).
    """
    true_positives = jnp.sum((y_true == 1) & (y_pred == 1))
    false_negatives = jnp.sum((y_true == 1) & (y_pred == 0))
    actual_positives = true_positives + false_negatives
    # Divide by a non-zero stand-in so the discarded branch never evaluates 0/0.
    safe_denominator = jnp.where(actual_positives == 0, 1, actual_positives)
    return jnp.where(actual_positives == 0, 0.0, true_positives / safe_denominator)
# Test the function
y_true = jnp.array([1, 0, 1, 1, 0, 1])
y_pred = jnp.array([1, 0, 1, 0, 0, 1])
recall(y_true, y_pred)

Array(0.75, dtype=float32)

In [33]:
# Implement F-Score Calculation for Binary Classification
def f_score(y_true: Array, y_pred: Array, beta: jnp.float32 = 1) -> jnp.float32:
    """F-beta score computed from ``precision`` and ``recall``.

    Evaluates ``(1 + beta**2) * p * r / (beta**2 * p + r)``.

    Args:
        y_true: Ground-truth binary labels.
        y_pred: Predicted binary labels.
        beta: Weight of recall relative to precision; 1 gives the F1 score.

    Returns:
        The F-beta score, or 0.0 when the denominator is zero (for example
        when precision and recall are both zero or undefined).
    """
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    # An undefined (NaN) precision or recall counts as 0 so it cannot
    # propagate into the score.
    p = jnp.where(jnp.isnan(p), 0.0, p)
    r = jnp.where(jnp.isnan(r), 0.0, r)
    beta_sq = beta ** 2
    denominator = beta_sq * p + r
    # Divide by a non-zero stand-in so the discarded branch never evaluates 0/0.
    safe_denominator = jnp.where(denominator == 0, 1.0, denominator)
    return jnp.where(denominator == 0, 0.0, (1 + beta_sq) * (p * r) / safe_denominator)
# Test the function
y_true = jnp.array([1, 0, 1, 1, 0, 1])
y_pred = jnp.array([1, 0, 1, 0, 0, 1])
f_score(y_true, y_pred)

Array(0.85714287, dtype=float32)

In [34]:
# Implement Gini Impurity Calculation for a Set of Classes
def gini_impurity(y: Array) -> jnp.float32:
    """Gini impurity of the labels in ``y``: one minus the sum of squared class frequencies.

    Class frequencies come from ``jnp.bincount(y)`` divided by ``y.size``.
    """
    class_counts = jnp.bincount(y)
    class_frequencies = class_counts / y.size
    return 1 - jnp.sum(class_frequencies ** 2)
# Test the function
y = jnp.array([0, 1, 1, 1, 0])
gini_impurity(y)

Array(0.47999996, dtype=float32)

In [35]:
# Calculate R-squared for Regression Analysis

def r_squared(y_true: Array, y_pred: Array) -> jnp.float32:
    """Coefficient of determination: ``1 - SS_res / SS_tot``.

    ``SS_res`` is the sum of squared differences between ``y_true`` and
    ``y_pred``; ``SS_tot`` is the sum of squared deviations of ``y_true``
    from its own mean.
    """
    residual_ss = jnp.sum((y_true - y_pred) ** 2)
    centered = y_true - y_true.mean()
    total_ss = jnp.sum(centered ** 2)
    return 1 - residual_ss / total_ss
# Test the function
y_true = jnp.array([1, 2, 3, 4, 5])
y_pred = jnp.array([1.1, 2.1, 2.9, 4.2, 4.8])
r_squared(y_true, y_pred)

Array(0.989, dtype=float32)

In [36]:
# Calculate Root Mean Square Error (RMSE)
def rmse(y_true: Array, y_pred: Array) -> jnp.float32:
    """Root mean square error between ``y_true`` and ``y_pred``."""
    residuals = y_true - y_pred
    mean_squared_error = jnp.mean(residuals ** 2)
    return jnp.sqrt(mean_squared_error)
# Test the function
y_true = jnp.array([3, -0.5, 2, 7])
y_pred = jnp.array([2.5, 0.0, 2, 8])
rmse(y_true, y_pred)

Array(0.61237246, dtype=float32)

In [41]:
# Calculate Jaccard Index for Binary Classification
def jaccard_index(y_true: Array, y_pred: Array) -> jnp.float32:
    """Jaccard index of the positive class (label 1) for binary labels.

    Args:
        y_true: Ground-truth labels, compared against 1.
        y_pred: Predicted labels, compared against 1.

    Returns:
        ``|true AND pred| / |true OR pred|`` over positions labelled 1, or
        0.0 when neither array contains a 1.
    """
    intersection = jnp.sum((y_true == 1) & (y_pred == 1))
    union = jnp.sum((y_true == 1) | (y_pred == 1))
    # Select with jnp.where instead of a Python `if` on an array value, so the
    # function also works under jax.jit; the stand-in denominator keeps the
    # discarded branch from evaluating 0/0.
    safe_union = jnp.where(union == 0, 1, union)
    return jnp.where(union == 0, 0.0, intersection / safe_union)
y_true = jnp.array([1, 0, 1, 1, 0, 1])
y_pred = jnp.array([1, 0, 1, 0, 0, 1])
jaccard_index(y_true, y_pred)

Array(0.75, dtype=float32)

In [43]:
# Calculate Dice Score for Classification
def dice_score(y_true: Array, y_pred: Array) -> jnp.float32:
    """Dice score of the positive class (label 1) for binary labels.

    Args:
        y_true: Ground-truth labels, compared against 1.
        y_pred: Predicted labels, compared against 1.

    Returns:
        ``2 * |true AND pred| / (|true| + |pred|)`` over positions labelled 1.
        The result is 0.0 when either array contains no 1: the intersection
        is then empty, and the all-empty 0/0 case is guarded explicitly.
    """
    intersection = jnp.sum((y_true == 1) & (y_pred == 1))
    total_positives = jnp.sum(y_true == 1) + jnp.sum(y_pred == 1)
    # Select with jnp.where instead of a Python `if` on array values, so the
    # function also works under jax.jit; the stand-in denominator keeps the
    # discarded branch from evaluating 0/0.
    safe_total = jnp.where(total_positives == 0, 1, total_positives)
    return jnp.where(total_positives == 0, 0.0, 2 * intersection / safe_total)
# Test the function
y_true = jnp.array([1, 1, 0, 1, 0, 1])
y_pred = jnp.array([1, 1, 0, 0, 0, 1])
dice_score(y_true, y_pred)

Array(0.85714287, dtype=float32)

In [44]:
# Generate a Confusion Matrix for Binary Classification
def confusion_matrix(data: Array) -> Array:
    """Count the four 0/1 combinations of the two columns of ``data``.

    Args:
        data: 2-D array; column 0 and column 1 are each compared against
            0 and 1.

    Returns:
        A 2x2 array of counts laid out as::

            [[col0==1 & col1==1, col0==1 & col1==0],
             [col0==0 & col1==1, col0==0 & col1==0]]

    NOTE(review): which column holds the actual label and which the
    prediction is not visible here. If column 0 is the actual label, this
    layout reads [[TP, FN], [FP, TN]] - confirm against the caller.
    """
    first, second = data[:, 0], data[:, 1]

    def count(first_value, second_value):
        # Number of rows holding exactly this (column 0, column 1) pair.
        return jnp.sum((first == first_value) & (second == second_value))

    return jnp.array([
        [count(1, 1), count(1, 0)],
        [count(0, 1), count(0, 0)],
    ])
# Test the function
data = jnp.array([[1, 1], [1, 0], [0, 1], [0, 0], [0, 1]])
confusion_matrix(data)

Array([[1, 1],
       [2, 1]], dtype=int32)