In [28]:
import jax.numpy as jnp
import jax
import random

Write a script to perform SGD on a polynomial with 3 variables x,y,z i.e $\sum a_{i,j,k}x^iy^jz^k$. The polynomial should have maximum degree $N_x$ for x, $N_y$ for y, and $N_z$ for z.

a. Disregarding t, what is the maximum number of coeffcients needed to fully specify $P(x, y, z)$ for given $N_x, N_y, N_z$?

The number of terms in $P(x,y,z)$ can be calculated as the product of possibilities for each variable that is $(N_x+1)(N_y+1)(N_z+1)$.
In JAX, a data structure for storing these coefficients is a 3D array, where each dimension corresponds to the power of $x,y,z$.

b. c. 

In [29]:
def gen_random(Nx,Ny,Nz,t):
    coefficients = jnp.zeros((Nx + 1, Ny + 1, Nz + 1))
    terms = set()

    while len(terms) < t:
        i = random.randint(0,Nx)
        j = random.randint(0,Ny)
        k = random.randint(0,Nz)
        if (i, j, k) not in terms:
            terms.add((i, j, k))
            
            coefficients = coefficients.at[i, j, k].set(random.uniform(-5, 5))
        return coefficients
    


def eval_poly(coefficients, x,y,z):
    Nx,Ny,Nz = coefficients.shape
    result = 0
    for i in range(Nx):
        for j in range(Ny):
            for k in range(Nz):
                result += coefficients[i, j, k] * (x ** i) * (y ** j) * (z ** k)
    return result

def gen_train_data(coefficients, N, noise_frac=0.1, rnd_seed=42):
    rng = jax.random.PRNGKey(rnd_seed)
    data = []

    for _ in range(N):
        x = random.uniform(-10,10)
        y = random.uniform(-10,10)
        z = random.uniform(-10,10)
        true_value = eval_poly(coefficients,x,y,z)
        noise = noise_frac * true_value * jax.random.normal(rng, ())
        noisy_value = true_value + noise
        data.append((x, y, z, noisy_value))

    return jnp.array(data)

In [30]:
#Loss funstion and its gradient

def loss(params, data):
    errors = [(eval_poly(params, x, y, z) - target) ** 2 for x, y, z, target in data]
    return jnp.log(jnp.sum(jnp.array(errors)))

grad_loss = jax.grad(loss)

def sgd_reconstruct(training_data, Nx, Ny, Nz, t, num_epochs=500, learning_rate=0.001):
    # Initialize random coefficients
    #params = jnp.zeros((Nx + 1, Ny + 1, Nz + 1))
    coefficients = gen_random(Nx, Ny, Nz, t)
    for epoch in range(num_epochs):
        for i in range(len(training_data)):
            batch = training_data[i:i+1]
            grad = grad_loss(coefficients, batch)
            coefficients = coefficients - learning_rate * grad
        if epoch % 2 == 0:
            print(f"Epoch {epoch}: Loss={loss(coefficients, training_data)}")
    return coefficients

In [27]:
#Testing

Nx1, Ny1, Nz1, t1 = 2, 4, 6, 12
Nx2, Ny2, Nz2, t2 = 3, 1, 2, 5
N_data = 10  # Training data size
noise_frac = 0.1

secret_poly1 = gen_random(Nx1, Ny1, Nz1, t1)
training_data1 = gen_train_data(secret_poly1, N_data, noise_frac)

secret_poly2 = gen_random(Nx2, Ny2, Nz2, t2)
training_data2 = gen_train_data(secret_poly2, N_data, noise_frac)

# Run SGD on each polynomial's training data
print("\nReconstructing coefficients for Polynomial 1")
reconstructed_poly1 = sgd_reconstruct(training_data1, Nx1, Ny1, Nz1, t1)

print("\nReconstructing coefficients for Polynomial 2")
reconstructed_poly2 = sgd_reconstruct(training_data2, Nx2, Ny2, Nz2, t2)






Reconstructing coefficients for Polynomial 1
Epoch 0: Loss=55.798526763916016
Epoch 2: Loss=55.797176361083984
Epoch 4: Loss=55.79582595825195
Epoch 6: Loss=55.794471740722656
Epoch 8: Loss=55.793113708496094
Epoch 10: Loss=55.791748046875
Epoch 12: Loss=55.790382385253906
Epoch 14: Loss=55.78900909423828
Epoch 16: Loss=55.78763198852539
Epoch 18: Loss=55.7862434387207
Epoch 20: Loss=55.784828186035156
Epoch 22: Loss=55.78434371948242
Epoch 24: Loss=55.78298568725586
Epoch 26: Loss=55.78162384033203
Epoch 28: Loss=55.7802619934082
Epoch 30: Loss=55.778900146484375
Epoch 32: Loss=55.777530670166016
Epoch 34: Loss=55.77616500854492
Epoch 36: Loss=55.7747917175293
Epoch 38: Loss=55.77342224121094
Epoch 40: Loss=55.77204513549805
Epoch 42: Loss=55.770668029785156
Epoch 44: Loss=55.769290924072266
Epoch 46: Loss=55.76791000366211
Epoch 48: Loss=55.76652526855469
Epoch 50: Loss=55.765140533447266
Epoch 52: Loss=55.763755798339844
Epoch 54: Loss=55.762367248535156
Epoch 56: Loss=55.760974884

TypeError: Cannot interpret 'Array(-1.1003532, dtype=float32)' as a data type