<a href="https://colab.research.google.com/github/marekpiotradamczyk/ml_uwr_23/blob/main/Assignments/Assignment8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Learning a sphere

Below you can find a sphere given by the function $f(\phi, \theta) = (x,y,z) = (\sin \theta \cdot \cos \phi, \sin \theta \cdot \sin \phi, \cos \theta)$. Let's add some more jazz to it and add random normal noise to each point, so that the relation between the input $X = (\phi, \theta)$ and the output $Y$ is given by $Y = f(\phi, \theta) + \varepsilon$, where $\varepsilon \sim \mathcal{N}\left(\mathbf{0}, \sigma^2\cdot \begin{pmatrix}
1 & 0 & 0 \\
0 & 1 & 0 \\
0 & 0 & 1
\end{pmatrix}\right)$ is a 3-dimensional normal distribution, i.e. independent noise with standard deviation $\sigma$ in each coordinate (in the code below, $\sigma$ is `epsilon = 0.1`).

In [1]:
!pip install plotly



In [6]:
import plotly.graph_objects as go
import numpy as np

# Sample size for the noisy sphere
n_points = 1000

# Random directions: the azimuth phi is drawn uniformly from [0, 2*pi) and the
# polar angle theta is the arccos of a uniform draw from [-1, 1)
phi = np.random.uniform(0, 2 * np.pi, n_points)
theta = np.arccos(np.random.uniform(-1, 1, n_points))

# Spherical -> Cartesian coordinates on the unit sphere
sin_theta = np.sin(theta)
x = sin_theta * np.cos(phi)
y = sin_theta * np.sin(phi)
z = np.cos(theta)

# Perturb every coordinate with independent Gaussian noise of standard
# deviation epsilon (drawn for x first, then y, then z)
epsilon = 0.1
x, y, z = (coord + np.random.randn(n_points) * epsilon for coord in (x, y, z))

# Assemble the 3D scatter trace of the noisy samples
noisy_marker = dict(size=2, color='blue')
noisy_trace = go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=noisy_marker)
fig = go.Figure(data=[noisy_trace])

# Label the axes and remove the outer margins
fig.update_layout(
    title="3D Scatter Plot on a Unit Sphere",
    scene=dict(xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis'),
    margin=dict(l=0, r=0, b=0, t=0),
)

# Render the figure
fig.show()


# Simple neural network to learn a sphere
Below you can find a simple neural network that learns the function $f(\phi, \theta) =  (\sin \theta \cdot \cos \phi, \sin \theta \cdot \sin \phi, \cos \theta) = (x,y,z)$.

In [11]:
!pip install jax
!pip install jaxlib

Collecting jaxlib
  Downloading jaxlib-0.4.25-cp310-cp310-manylinux2014_x86_64.whl.metadata (2.1 kB)
Downloading jaxlib-0.4.25-cp310-cp310-manylinux2014_x86_64.whl (79.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.2/79.2 MB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: jaxlib
Successfully installed jaxlib-0.4.25


In [12]:
#CUT{
import jax
import jax.numpy as jnp
from jax import random


# Root PRNG key built from a fixed seed, so every run draws the same random numbers
key = random.PRNGKey(seed=0)

# Fully connected network with two tanh hidden layers and a linear output layer
def neural_network(params, x):
    """Evaluate the network on ``x``.

    Args:
        params: sequence of exactly six arrays ``(w0, b0, w1, b1, w2, b2)``:
            weight and bias of the first hidden layer, of the second hidden
            layer, and of the output layer.
        x: network input; it is multiplied with ``w0`` via ``jnp.dot``.

    Returns:
        The result of the final affine layer (no activation is applied to it).
    """
    w0, b0, w1, b1, w2, b2 = params
    activation = x
    # Both hidden layers have the same form: affine map followed by tanh
    for weight, bias in ((w0, b0), (w1, b1)):
        activation = jnp.tanh(jnp.dot(activation, weight) + bias)
    return jnp.dot(activation, w2) + b2

# Training objective: mean squared error between network output and targets
def mean_squared_error(params, x, y_true):
    """Return the mean squared error of the network's predictions.

    Args:
        params: network parameters, passed unchanged to ``neural_network``.
        x: network inputs.
        y_true: target values compared against the predictions for ``x``.

    Returns:
        The mean, over all elements, of the squared difference between the
        prediction and ``y_true``.
    """
    residual = neural_network(params, x) - y_true
    return jnp.mean(jnp.square(residual))

# Initialize the neural network parameters
key, subkey = random.split(key)

in_dim = 2
h1_dim = 24
h2_dim = 24
ot_dim = 3

# Shapes in the order neural_network unpacks them: w0, b0, w1, b1, w2, b2
param_shapes = [
    (in_dim, h1_dim), (h1_dim,),
    (h1_dim, h2_dim), (h2_dim,),
    (h2_dim, ot_dim), (ot_dim,),
]

# JAX PRNG keys are stateless: drawing twice with the same key yields the same
# numbers, so reusing one key for every tensor would make same-shaped tensors
# (e.g. the two hidden-layer biases) identical. Derive one key per tensor.
param_keys = random.split(subkey, len(param_shapes))
params = [
    random.normal(param_key, shape)
    for param_key, shape in zip(param_keys, param_shapes)
]

# Per-parameter accumulators for the optimizer, starting at zero
mems = [jnp.zeros_like(param) for param in params]

# Gradient of the loss with respect to its first argument, the parameter list
grad_loss = jax.grad(mean_squared_error, argnums=0)

#CUT}

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

In [9]:
#CUT{
import math

# Network inputs are the (phi, theta) angle pairs; targets are the noisy (x, y, z) points
phi_theta = np.column_stack((phi, theta))
xyz = np.column_stack((x, y, z))

# Training hyper-parameters
learning_rate = 0.01
num_epochs = 10000      # number of mini-batch updates
batch_size = 44         # samples drawn (without replacement) per update
mem_decay = 0.99        # decay applied to the running sum of squared gradients
mem_eps = 1e-8          # added under the square root to avoid division by zero
log_every = 1000        # evaluate and print the full-data loss every this many updates
loss_tolerance = 0.1    # stop once the printed loss rises by more than this

# Compile the gradient once; the batch shapes never change, so every later
# call reuses the compiled function instead of re-tracing it.
grad_step = jax.jit(grad_loss)

prev_loss = math.inf
for epoch in range(num_epochs):
    # Draw a mini-batch and compute the loss gradient on it
    random_indices = np.random.choice(n_points, batch_size, replace=False)
    grads = grad_step(params, phi_theta[random_indices], xyz[random_indices])

    # RMSProp-style step: each gradient is divided by the root of a decayed
    # running sum of its squares (the sum is not rescaled by 1 - mem_decay)
    mems = [mem_decay * mem + grad**2 for mem, grad in zip(mems, grads)]
    params = [
        param - learning_rate * grad_param / jnp.sqrt(mem + mem_eps)
        for param, grad_param, mem in zip(params, grads, mems)
    ]

    # Print the loss on the full data set every `log_every` updates
    if epoch % log_every == 0:
        loss_value = mean_squared_error(params, phi_theta, xyz)
        print(f"Epoch {epoch}, Loss: {loss_value}")

        # Stop early when the loss has clearly gone up since the previous report
        if loss_value > prev_loss + loss_tolerance:
            print("gradient descent ends")
            break
        prev_loss = loss_value

#CUT}

NameError: name 'grad_loss' is not defined

In [10]:
import plotly.graph_objects as go
import numpy as np

# Run the trained network on every (phi, theta) pair
predicted_points = neural_network(params, phi_theta)

# One 3D scatter trace; columns 0, 1 and 2 of the prediction are plotted as x, y and z
pred_marker = dict(size=2, color='blue')
pred_trace = go.Scatter3d(
    x=predicted_points[:, 0],
    y=predicted_points[:, 1],
    z=predicted_points[:, 2],
    mode='markers',
    marker=pred_marker,
)
fig = go.Figure(data=[pred_trace])

# Label the axes and remove the outer margins
fig.update_layout(
    title="3D Scatter Plot on a Unit Sphere",
    scene=dict(xaxis_title='X Axis', yaxis_title='Y Axis', zaxis_title='Z Axis'),
    margin=dict(l=0, r=0, b=0, t=0),
)

# Render the figure
fig.show()


NameError: name 'neural_network' is not defined