In [2]:
%pip install jax jaxlib
import jax
from jax import numpy as jnp

Collecting jax
  Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib
  Downloading jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl.metadata (983 bytes)
Collecting ml-dtypes>=0.4.0 (from jax)
  Using cached ml_dtypes-0.5.0-cp312-cp312-macosx_10_9_universal2.whl.metadata (21 kB)
Collecting numpy>=1.24 (from jax)
  Downloading numpy-2.1.2-cp312-cp312-macosx_14_0_arm64.whl.metadata (60 kB)
Collecting opt-einsum (from jax)
  Using cached opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Collecting scipy>=1.10 (from jax)
  Using cached scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl.metadata (60 kB)
Downloading jax-0.4.35-py3-none-any.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m23.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl (68.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.1/68.1 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01

In [49]:

class MemristorCrossbar(object):
    def __init__(self, shape, min_conductance=0, max_conductance=1, key=jax.random.PRNGKey(0)):
        rows, cols = shape
        _key, self.key = jax.random.split(key, 2)
        self.min_conductance = min_conductance
        self.max_conductance = max_conductance
        
        conductance_range = max_conductance - min_conductance
        self.coefficients = jax.random.uniform(_key, (rows, cols), dtype=jnp.float32)*conductance_range + min_conductance

    def forward(self, x):
        return jnp.dot(self.coefficients, x)

    def backward(self, y):
        return jnp.dot(y.T, self.coefficients.T)
    
    def outer_product_update(self, inp, outp, learning_rate=0.01, noise_std=0.001):
        # compute the output for the given input
        pred_outp = self.forward(inp)

        # compute the ternarized error
        E = jnp.sign(outp - pred_outp)
        print(f"got: {pred_outp}, expected: {outp}, error: {E}")

        # compute the update
        _key, self.key = jax.random.split(self.key, 2)
        dW = learning_rate * jnp.outer(inp, E) + jax.random.normal(_key, self.coefficients.shape) * noise_std

        self.coefficients = jnp.clip(self.coefficients + dW, self.min_conductance, self.max_conductance)
        return self.coefficients, jnp.linalg.norm(E)

In [50]:
key0, key1, key2 = jax.random.split(jax.random.PRNGKey(0), 3)
num_rows, num_cols = 10, 10
num_samples = 100

thresholds = jnp.array([-1, 1])
cb = MemristorCrossbar((10, 10), key=key0, min_conductance=-10, max_conductance=10)

# target matrix
mat = jax.random.normal(key1, (num_rows, num_cols))*10

# random input vectors (ternary)
inputs = jnp.round(jax.random.uniform(key2, (10,num_samples), minval=-1.5, maxval=1.5))

# target output vectors
outputs = jnp.digitize(jnp.dot(mat, inputs), thresholds) - 1

# update the crossbar
for i in range(num_samples):
    _, E = cb.outer_product_update(inputs[:,i], outputs[:,i], learning_rate=0.1, noise_std=0.01)
    print(f"Error: {E:.2f}")



got: [ -0.81951    10.024758   -5.388677    6.927127  -18.699364  -31.740871
   4.0758185 -18.690172   16.867758    7.5898595], expected: [-1 -1  1 -1  1 -1 -1 -1  1  1], error: [-1. -1.  1. -1.  1.  1. -1.  1. -1. -1.]
Error: 3.16
got: [ 4.42084    -0.04894328 -6.7270713  -3.2126017   3.391058   -3.0326514
  9.086334   20.47334    -3.8854926   2.0964413 ], expected: [-1 -1 -1 -1 -1 -1  1  1 -1 -1], error: [-1. -1.  1.  1. -1.  1. -1. -1.  1. -1.]
Error: 3.16
got: [  4.9479647 -12.681738   22.485924    3.2268515   6.532053   16.758652
 -10.199177   23.683868  -12.560626   -6.729405 ], expected: [ 1 -1 -1  1  1  1  1  1  1 -1], error: [-1.  1. -1. -1. -1. -1.  1. -1.  1.  1.]
Error: 3.16
got: [  9.915655  -10.851304    8.216419   12.175535  -14.637548  -10.950789
  -7.3319697  11.300847   -9.851976  -19.550179 ], expected: [ 1  1  1 -1  1  1 -1 -1 -1 -1], error: [-1.  1. -1. -1.  1.  1.  1. -1.  1.  1.]
Error: 3.16
got: [ -7.4106607 -19.169764    6.721541  -30.194351  -20.131084    9.69

In [39]:
cb.coefficients

Array([[  0.27303347,   3.4133468 ,   2.698047  ,  -5.321483  ,
         -2.4610946 ,  -4.6736965 ,  -0.9453821 ,   9.233441  ,
          6.608251  ,  -0.44214743],
       [  2.890135  ,  -1.7186968 ,  -8.531701  ,   7.1469827 ,
          6.261838  ,  -5.9661026 ,  -1.0238335 ,  -2.7529078 ,
         -0.8153183 ,  -2.1493304 ],
       [  4.25529   ,  10.        ,  -4.8442416 ,  -8.413846  ,
         -0.5819839 ,   8.3721075 ,  -9.9814005 , -10.        ,
          2.6175122 ,   0.02008253],
       [  8.908697  ,  -6.3992767 ,   3.8701587 ,   1.5764375 ,
         10.        ,   0.38979635,  -8.327979  ,   1.8886642 ,
          4.336371  ,   8.95297   ],
       [  9.711518  ,  -8.779551  ,   1.7742808 , -10.        ,
          5.1011457 ,  -7.2443075 ,  -5.4573045 ,  -8.197834  ,
          9.641652  ,  -3.6498408 ],
       [  2.819837  ,  -1.1989489 ,  -3.4030905 ,  -9.5967865 ,
          2.4149792 ,   5.590087  ,   8.7535715 ,  -5.883714  ,
         -9.921344  ,  -8.790252  ],
       [ -