In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import pandas as pd
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split

In [25]:
from functools import partial

In [2]:
# load the CSV file into a Pandas DataFrame
df = pd.read_csv("data.csv")

In [3]:
# extract features
elements = set([el for row in df['element_symbols'].str.split(',') for el in row])
data = []
for i, row in df.iterrows():
    element_counts = Counter(row['element_symbols'])
    feature = [element_counts[el] / row['nsites'] for el in elements]
    feature.extend([1 if el in row['element_symbols'] else 0 for el in elements])
    data.append(feature)

In [59]:
x = np.array(data)
y = df['spacegroup_num'].values

In [60]:
# split the data into training and testing sets
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

In [61]:
x.shape

(10628, 472)

In [62]:
x_train.shape

(8502, 472)

In [41]:
y_train.shape

(8502,)

In [63]:
def make_network(layer_sizes):
    
    def init(key, scale=1e-2):
        params = []
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            weight_key, bias_key = jax.random.split(key)
            weight = scale * jax.random.normal(weight_key, (n_in, n_out))
            bias = scale * jax.random.normal(bias_key, (n_out,))
            params.append((weight, bias))
        return params

    def relu(x):
        return jnp.maximum(0, x)

    def apply(params, x):
        for w, b in params[:-1]:
            x = relu(jnp.dot(x, w) + b)
        final_w, final_b = params[-1]
        return jnp.dot(x, final_w) + final_b

    return init, apply

In [74]:
layer_sizes = [472, 220, 128, 64, 1] 
init_fn, apply_fn = make_network(layer_sizes)
key = jax.random.PRNGKey(42)
params = init_fn(key)

In [75]:
from jax.flatten_util import ravel_pytree
ravel_pytree(params)[0].size

140669

In [76]:
def cross_entropy(params, x, y):
    logits = apply_fn(params, x)
    return jnp.sum(y * jax.nn.log_softmax(logits))

def cross_entropy_loss(params, x, y):
    return -jnp.mean(jax.vmap(cross_entropy, (None, 0, 0),0)(params, x, y))

In [96]:
def loss(params, x, y):
    """ Mean squared error loss function """
    preds = apply_fn(params, x)
    return jnp.mean((preds - y)**2)

In [92]:
# Define the update function
@jax.jit
def update(params, x, y, learning_rate):
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads) 

In [93]:
# Define the accuracy function
def accuracy(params, x, y):
    predictions = apply_fn(params, x)
    return jnp.mean(predictions == y)

In [95]:
learning_rate = 0.1
num_epochs = 10
batch_size = 128

train_size = X_train.shape[0]
num_complete_batches, leftover = divmod(train_size, batch_size)
num_batches = num_complete_batches + bool(leftover)

for epoch in range(num_epochs):
    # Shuffle the training data
    key, subkey = jax.random.split(key)
    permutation = jax.random.permutation(subkey, train_size)
    x_train = x_train[permutation]
    y_train = y_train[permutation]

    for i in range(num_batches):
        # Get batch data
        batch_start = i * batch_size
        batch_end = (i + 1) * batch_size
        x_batch = x_train[batch_start:batch_end]
        y_batch = y_train[batch_start:batch_end]

        # Update parameters
        params = update(params, x_batch, y_batch, learning_rate)

    # Compute accuracy on training and test sets
    train_accuracy = accuracy(params, x_train, y_train)
    test_accuracy = accuracy(params, x_test, y_test)
    lossf = loss(params, x_train, y_train)
    print(f"Epoch {epoch}: train accuracy = {train_accuracy:.10f}, lossf = {loss:.10f},test accuracy = {test_accuracy:.10f}")

TypeError: 'ArrayImpl' object is not callable

In [84]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [85]:
# Define the MLP model
def init_params(layer_sizes, key):
    """ Initialize the parameters of the MLP """
    params = []
    for i in range(1, len(layer_sizes)):
        key, subkey = random.split(key)
        # Use He initialization for the weights
        input_size, output_size = layer_sizes[i-1], layer_sizes[i]
        W = random.normal(subkey, (input_size, output_size)) * jnp.sqrt(2 / input_size)
        b = jnp.zeros((output_size,))
        params.append((W, b))
    return params

def relu(x):
    """ ReLU activation function """
    return jnp.maximum(0, x)

def predict(params, inputs):
    """ MLP prediction function """
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = relu(outputs)
    return outputs

def loss(params, inputs, targets):
    """ Mean squared error loss function """
    preds = predict(params, inputs)
    return jnp.mean((preds - targets)**2)

In [89]:
@jit
def update(params, x, y, learning_rate):
    """ Update step for the MLP parameters """
    grads = grad(loss)(params, x, y)
    return [(W - learning_rate*dW, b - learning_rate*db)
            for (W, b), (dW, db) in zip(params, grads)]

@jit
def train_mlp(train_X, train_Y, test_X, test_Y, layer_sizes, learning_rate, num_epochs):
    """ Train an MLP and evaluate its performance on train and test data """
    key = random.PRNGKey(42)
    params = init_params(layer_sizes, key)

    for epoch in range(num_epochs):
        for x, y in zip(train_X, train_Y):
            params = update(params, x, y, learning_rate)

        train_loss = loss(params, train_X, train_Y)
        test_loss = loss(params, test_X, test_Y)

        print("Epoch: {}, Train Loss: {}, Test Loss: {}".format(epoch+1, train_loss, test_loss))

    # Compute accuracy on the test set
    preds = predict(params, test_X)
    accuracy = jnp.mean(jnp.round(preds) == test_Y)
    print("Test Accuracy: ", accuracy)

In [90]:
# Train the MLP model
layer_sizes = [472, 220, 128, 64, 1] # Input size, hidden layer sizes, output size
learning_rate = 0.001
num_epochs = 100
train_mlp(x_train, y_train, x_test, y_test, layer_sizes, learning_rate, num_epochs)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function train_mlp at /tmp/ipykernel_10423/3470114896.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument layer_sizes[0].
The error occurred while tracing the function train_mlp at /tmp/ipykernel_10423/3470114896.py:8 for jit. This concrete value was not available in Python because it depends on the value of the argument layer_sizes[1].