In [5]:
# Import required libraries
import os
import sys
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap, jacfwd, jacrev
import optax
import jaxopt
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib as mpl
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score, mean_squared_error

def _check_numpy():
    """Sum a small NumPy array and confirm the result is 6."""
    import numpy as np

    total = np.array([1, 2, 3]).sum()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if total != 6:
        raise ValueError(f"expected sum 6, got {total}")


def _check_jax():
    """Differentiate sum(x**2) with JAX and confirm the gradient equals 2*x."""
    import jax
    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0])
    gradient = jax.grad(lambda v: jnp.sum(v ** 2))(x)
    if not bool(jnp.allclose(gradient, 2.0 * x)):
        raise ValueError(f"expected gradient {2.0 * x}, got {gradient}")


def _check_optax():
    """Take one Adam step with Optax and confirm the parameter moved as expected."""
    import jax.numpy as jnp
    import optax

    optimizer = optax.adam(learning_rate=0.01)
    params = jnp.array([1.0])
    opt_state = optimizer.init(params)
    grads = jnp.array([0.1])
    updates, _ = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # Adam's bias-corrected first step is ~learning_rate * sign(grad), so a
    # positive gradient should move the parameter from 1.0 to about 0.99.
    new_value = float(params[0])
    if abs(new_value - 0.99) > 1e-4:
        raise ValueError(f"expected parameter ~0.99 after one step, got {new_value}")


def _check_matplotlib():
    """Draw a labelled line with Matplotlib and confirm the legend text."""
    import matplotlib.pyplot as plt

    fig = plt.figure()
    try:
        plt.plot([0, 1], [0, 1], label="Test Line")
        labels = [text.get_text() for text in plt.legend().get_texts()]
    finally:
        # Close even when plotting fails so the figure is never left open.
        plt.close(fig)
    if labels != ["Test Line"]:
        raise ValueError(f"unexpected legend labels {labels}")


def _check_sklearn():
    """Fit an exact linear relation with scikit-learn and confirm R^2 is 1."""
    import numpy as np
    from sklearn.linear_model import LinearRegression
    from sklearn.metrics import r2_score

    X = np.array([[1], [2], [3]])
    y = np.array([2, 4, 6])
    predictions = LinearRegression().fit(X, y).predict(X)
    r2 = r2_score(y, predictions)
    # y = 2x is exactly linear, so a correct fit explains all variance.
    if abs(r2 - 1.0) > 1e-9:
        raise ValueError(f"expected R^2 of 1.0, got {r2}")


def test_imports():
    """Smoke-test the notebook's core libraries and print one status line each.

    Each library is imported inside its own check and exercised with a tiny
    computation whose result is verified. A check therefore reports FAILED
    when the import fails, when the call raises, or when it returns a wrong
    value. A failing check does not stop the remaining ones.

    Prints a header, then ``"<Library>: OK"`` or ``"<Library>: FAILED - <error>"``
    for NumPy, JAX, Optax, Matplotlib and Scikit-learn (in that order), then a
    trailer. Returns None.
    """
    print("Testing imports and basic functionality...\n")

    checks = (
        ("NumPy", _check_numpy),
        ("JAX", _check_jax),
        ("Optax", _check_optax),
        ("Matplotlib", _check_matplotlib),
        ("Scikit-learn", _check_sklearn),
    )
    for label, check in checks:
        try:
            check()
        except Exception as e:  # smoke-test boundary: report and keep going
            print(f"{label}: FAILED - {e}")
        else:
            print(f"{label}: OK")

    print("\nTesting complete!")

# Execute the smoke test; it prints a status line for each library checked.
test_imports()


Testing imports and basic functionality...

NumPy: OK
JAX: OK
Optax: OK
Matplotlib: OK
Scikit-learn: OK

Testing complete!
