In [2]:
import numpy as np
import dace as dc
from dace.autodiff import add_backward_pass
from dace.transformation.auto.auto_optimize import auto_optimize
from dace.dtypes import DeviceType
import jax
import jax.numpy as jnp

# Define matrix dimensions

In [3]:

NI, NJ, NK, NL = 32, 36, 38, 42


# Define the DaCe program for the k2mm computation

In [4]:
@dc.program
def k2mm(alpha: dc.float64, beta: dc.float64, A: dc.float64[NI, NK], B: dc.float64[NK, NJ], C: dc.float64[NJ, NL],
         D: dc.float64[NI, NL], S: dc.float64[1]):
    """
    Computes D = alpha * A @ B @ C + beta * D
    Computes S = sum(D)
    """
    D[:] = alpha * A @ B @ C + beta * D
    S[0] = np.sum(D)


# Initialize scalar parameters and matrices

In [5]:
alpha, beta = 0.2, 1.2
A = np.ones((NI, NK))
B = np.ones((NK, NJ))
C = np.ones((NJ, NL))
D = np.ones((NI, NL))
gradient_A = np.zeros((NI, NK))
gradient_S = np.ones((1))
S = np.zeros((1))


# Convert DaCe program to an SDFG and save it

In [6]:
sdfg = k2mm.to_sdfg(alpha=alpha, beta=beta, A=A, B=B, C=C, D=D)
sdfg.save("log_sdfgs/k2mm_forward.sdfg")


'bd9e4008e341d48fc0fba4437169472f243f6ee1cc442100e35835f01ab6e0c4'

# Add backward pass and optimize

In [7]:
add_backward_pass(sdfg=sdfg, state=sdfg.states()[0], inputs=["A"], outputs=["S"])
sdfg.simplify()
sdfg_bwd_ao = auto_optimize(sdfg, device=DeviceType.CPU)
sdfg.save("log_sdfgs/k2mm_backward.sdfg")


'68262a02837e7c9d01092c506345a7e1ab48f5f06a262e988fb170952f563f85'

# Execute the SDFG

In [8]:
sdfg(alpha, beta, A, B, C, D, S, gradient_A=gradient_A, gradient_S=gradient_S)


# Define equivalent JAX function for comparison

In [9]:
def k2mm_jax(alpha, beta, A, B, C, D):
    """
    JAX implementation of the k2mm operation
    """
    return jnp.sum(alpha * A @ B @ C + beta * D)


# Compute gradients using JAX and compare

In [10]:
target_grad = jax.grad(k2mm_jax, argnums=[2])  # Compute gradient w.r.t A
A = jnp.ones((NI, NK))
B = jnp.ones((NK, NJ))
C = jnp.ones((NJ, NL))
D = jnp.ones((NI, NL))
gradient_A_jax = target_grad(alpha, beta, A, B, C, D)
print(gradient_A_jax)
assert np.allclose(gradient_A_jax, gradient_A)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
  setattr(self, word, getattr(machar, word).flat[0])
  return self._float_to_str(self.smallest_subnormal)


(Array([[302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4],
       [302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4],
       [302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4],
       ...,
       [302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4],
       [302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4],
       [302.4, 302.4, 302.4, ..., 302.4, 302.4, 302.4]], dtype=float32),)
