In [1]:
import numpy as np
import keras

2024-09-25 16:13:15.784026: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-25 16:13:15.807798: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-25 16:13:15.814966: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-25 16:13:15.837796: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import numpy.random as npr


def init_random_params(layer_sizes, scale=0.1, rng=npr.RandomState(0)):
    return[(scale*rng.randn(m, n), scale*rng.rand(n)) for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]

In [9]:
from jax import jit, grad
from jax import numpy as jnp
from jax.nn import softmax
from jax.scipy.special import logsumexp

def predict(params, inputs):
    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = softmax(outputs)
    
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)

In [10]:
# Loss measures the difference between the guess and the label.
# Here we can use CrossEntropyLoss. We currently have the log pred.

def loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis =1))


In [11]:
def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis = 1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)

In [13]:
import time


if __name__ == "__main__":
    layers = [784, 1024, 1024, 10]
    param_scale = 0.3
    learning_rate = 1e-2
    num_epochs = 20
    batch_size = 128

    (train_images, train_labels), (eval_images, eval_labels) = keras.datasets.mnist.load_data()

    train_images = train_images.astype("float32") / 255
    eval_images = eval_images.astype("float32") / 255
    train_images = np.reshape(train_images, (train_images.shape[0], 784))
    eval_images = np.reshape(eval_images, (eval_images.shape[0], 784))
    b = np.zeros((eval_labels.size, eval_labels.max() + 1))
    b[np.arange(eval_labels.size), eval_labels] = 1
    eval_labels = b

    b = np.zeros((train_labels.size, train_labels.max() + 1))
    b[np.arange(train_labels.size), train_labels] = 1
    train_labels = b
    
    dataset_size = train_images.shape[0]
    num_complete_batches, leftover = divmod(dataset_size, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(dataset_size) # Shuffle the dataset
            for i in range(num_batches):
                batch_idx = perm[i * num_batches: (i+1)*num_batches]
                yield (train_images[batch_idx], train_labels[batch_idx])

                
    batches = data_stream()

    @jit
    def update(params, batch):
        grads = grad(loss)(params, batch)
        return [(w - learning_rate * dw, b - learning_rate * db) for (w, b), (dw, db) in zip(params, grads)]

    params = init_random_params(layers, scale=param_scale)
    for epoch in range(num_epochs):
        start_time = time.time()
        for batch in batches:
            params = update(params, batch)
            train_acc = accuracy(params, (train_images, train_labels))
            print(f"Training accuracy: {train_acc:0.4f}")

        epoch_time = time.time() - start_time

        train_acc = accuracy(params, (train_images, train_labels))
        eval_acc = accuracy(params, (eval_images, eval_labels))
        print(f"Epoch {epoch} in {epoch_time:0.2f} seconds")
        print(f"Training accuracy: {train_acc:0.4f}")
        print(f"Eval accuracy: {eval_acc:0.4f}")

    

Training accuracy: 0.1041
Training accuracy: 0.1043
Training accuracy: 0.1043
Training accuracy: 0.1043
Training accuracy: 0.1043


KeyboardInterrupt: 

In [None]:
dot_img_file = 'model_graph.png'
keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

In [21]:
import jax 
def dlfn(params, batch):
    return  grad(loss)(params, batch)

z=jax.xla_computation(dlfn)(params, batch)
with open("t.txt", "w") as f:
    f.write(z.as_hlo_text())

  z=jax.xla_computation(dlfn)(params, batch)


In [22]:
with open("t.dot", "w") as f:
    f.write(z.as_hlo_dot_graph())

