In [None]:
pip install jax



In [None]:
import jax
import jax.numpy as jnp
import time
from functools import partial
import jax.lax as lax
# ====== BLOQUES COMO FUNCIONES ======

def matmul_qkv(WQ, WK, WV, EI):
    """Project the input ``EI`` with the query, key and value weights.

    Args:
        WQ: Query projection matrix, left-multiplied onto ``EI``.
        WK: Key projection matrix, left-multiplied onto ``EI``.
        WV: Value projection matrix, left-multiplied onto ``EI``.
        EI: Input matrix shared by the three projections.

    Returns:
        The tuple ``(WQ @ EI, WK @ EI, WV @ EI)``, in that order.
    """
    return tuple(weight @ EI for weight in (WQ, WK, WV))

def attention_per_head(K, Q, V, l, ddh, b, h):
    """Compute ``V_blk @ (K_blk^T @ Q_blk)`` for every (sequence, head) block.

    ``K``, ``Q`` and ``V`` are 2-D matrices that are read in tiles of
    ``ddh`` rows by ``l`` columns. The tile for head ``i`` of sequence ``j``
    starts at row ``i * ddh`` and column ``j * l``. For each of the ``b * h``
    tiles the product ``V_tile @ (K_tile^T @ Q_tile)`` is written to the same
    position of the output. No scaling or softmax is applied.

    Args:
        K: Key matrix.
        Q: Query matrix.
        V: Value matrix.
        l: Number of columns per tile (must be a static Python int).
        ddh: Number of rows per tile (must be a static Python int).
        b: Number of column tiles.
        h: Number of row tiles.

    Returns:
        A matrix with ``K.shape[0]`` rows and ``l * b`` columns holding the
        per-tile results; positions not covered by any tile stay zero.
    """
    lb = l * b

    def process_single_head(acc, idx):
        # Flat index -> (column tile j, row tile i).
        j, i = idx // h, idx % h
        col_start = j * l
        row_start = i * ddh

        K_slice = lax.dynamic_slice(K, (row_start, col_start), (ddh, l))
        Q_slice = lax.dynamic_slice(Q, (row_start, col_start), (ddh, l))
        V_slice = lax.dynamic_slice(V, (row_start, col_start), (ddh, l))

        # E1 = (K^T) @ Q  -> (l, l)
        E1 = K_slice.transpose() @ Q_slice
        # E2 tile = V @ E1 -> (ddh, l)
        result = V_slice @ E1

        acc = lax.dynamic_update_slice(acc, result, (row_start, col_start))
        return acc, None

    # Size the accumulator from the inputs themselves instead of a
    # module-level ``d``, and use the dtype the matmuls will produce so the
    # scan carry keeps a consistent type for any input precision.
    out_dtype = jnp.result_type(K, Q, V)
    E2_init = jnp.zeros((K.shape[0], lb), dtype=out_dtype)
    idxs = jnp.arange(b * h)
    E2, _ = lax.scan(process_single_head, E2_init, idxs)

    return E2

def mha_output(WO, E2, EI):
    """Apply the output projection ``WO`` to ``E2`` and add ``EI`` to it.

    Returns:
        ``EI + WO @ E2``.
    """
    projected = WO @ E2
    return EI + projected

def ffn_forward(W1, W2, AO):
    """Run the two-matrix feed-forward step and add its input back.

    ``AO`` is multiplied by ``W1``, the result by ``W2``, and ``AO`` is added
    to that product. No activation function is applied in between.

    Returns:
        ``AO + W2 @ (W1 @ AO)``.
    """
    update = W2 @ (W1 @ AO)
    return AO + update

# ====== FUNCION PRINCIPAL ======
@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4))
def transformer_block_jax(b, l, d, h, f, EI, WQ, WK, WV, WO, W1, W2):
    """Jitted forward pass: QKV projection, per-head attention, output
    projection with residual, then the feed-forward step.

    The first five arguments (``b``, ``l``, ``d``, ``h``, ``f``) are static
    under ``jax.jit``, so a new combination of them triggers a recompile.
    ``f`` is accepted for signature symmetry but is not read in the body.

    Args:
        b: Number of column tiles passed to ``attention_per_head``.
        l: Columns per tile passed to ``attention_per_head``.
        d: Model dimension; only used to derive the per-head size ``d // h``.
        h: Number of heads.
        f: Unused here.
        EI: Input matrix.
        WQ, WK, WV: Query/key/value projection matrices.
        WO: Output projection matrix.
        W1, W2: Feed-forward matrices.

    Returns:
        The result of ``ffn_forward`` applied to the attention output.
    """
    head_dim = d // h

    queries, keys, values = matmul_qkv(WQ, WK, WV, EI)
    per_head = attention_per_head(keys, queries, values, l, head_dim, b, h)
    attn_out = mha_output(WO, per_head, EI)
    return ffn_forward(W1, W2, attn_out)

def generate_transformer_block_jax(b, l, d, h, f, key):
    """Draw random inputs and weights for one transformer block.

    ``key`` is split into seven sub-keys, one per array, and every array is
    sampled with ``jax.random.uniform`` using ``minval=-0.5`` and
    ``maxval=0.5``. ``h`` does not influence any shape.

    Args:
        b: Factor of the input's column count (``l * b``).
        l: Factor of the input's column count (``l * b``).
        d: Model dimension.
        h: Unused here.
        f: Feed-forward hidden dimension.
        key: JAX PRNG key.

    Returns:
        ``(EI, WQ, WK, WV, WO, W1, W2)`` with shapes ``(d, l*b)``, four times
        ``(d, d)``, ``(f, d)`` and ``(d, f)`` respectively.
    """
    # Order matters: the n-th sub-key feeds the n-th shape below.
    shapes = (
        (d, l * b),  # EI
        (d, d),      # WQ
        (d, d),      # WK
        (d, d),      # WV
        (d, d),      # WO
        (f, d),      # W1
        (d, f),      # W2
    )
    subkeys = jax.random.split(key, 7)
    return tuple(
        jax.random.uniform(subkey, shape, minval=-0.5, maxval=0.5)
        for subkey, shape in zip(subkeys, shapes)
    )

# ====== USO FINAL ======

# Problem size: b sequences of length l, model width d, h heads, FFN width f.
b = 32
l = 128
d = 1024
h = 16
f = 4 * d
key = jax.random.PRNGKey(0)

EI, WQ, WK, WV, WO, W1, W2 = generate_transformer_block_jax(b, l, d, h, f, key)
# JAX dispatches work asynchronously; wait for the inputs to exist so their
# generation is not billed to either timer below.
for _arr in (EI, WQ, WK, WV, WO, W1, W2):
    _arr.block_until_ready()

# First call: tracing + XLA compilation + one execution of the jitted block.
start_compilado = time.perf_counter()
EO = transformer_block_jax(b, l, d, h, f, EI, WQ, WK, WV, WO, W1, W2)
EO.block_until_ready()
end_compilado = time.perf_counter()

print(f"Tiempo de compilación: {end_compilado-start_compilado:.6f} s")

# Second call: reuses the cached executable, so this is pure execution time.
# block_until_ready() keeps the timer running until the result is computed.
start = time.perf_counter()
EO = transformer_block_jax(b, l, d, h, f, EI, WQ, WK, WV, WO, W1, W2)
EO.block_until_ready()
end = time.perf_counter()

print(f"Tiempo de ejecución (compilado): {end-start:.6f} s")


Tiempo de compilación: 2.256962 s
Tiempo de ejecución (compilado): 0.659483 s
