In [None]:
!pip install torch jax jaxlib



In [None]:
"""
Matrix multiplication forward and backward pass implementation.

Correctness is validated against PyTorch and, when available, JAX.
"""

import numpy as np
import torch
from typing import Tuple

try:
    import jax
    import jax.numpy as jnp
    from jax import config as jax_config
    jax_config.update("jax_enable_x64", True)
    JAX_AVAILABLE = True
except Exception:
    JAX_AVAILABLE = False


In [None]:
def matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """
    Multiply two matrices using ``np.dot``.

    Args:
        A: Left operand, shape (m, k)
        B: Right operand, shape (k, n)

    Returns:
        The product of A and B, shape (m, n)
    """
    product = np.dot(A, B)
    return product


In [None]:
def forward(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """
    Forward pass of a linear map: Y = X @ W.

    Args:
        X: Inputs, shape (batch_size, input_dim)
        W: Weights, shape (input_dim, output_dim)

    Returns:
        Outputs, shape (batch_size, output_dim)
    """
    # Delegate to the shared matmul helper so both passes use one product routine.
    Y = matmul(X, W)
    return Y


def backward(X: np.ndarray, W: np.ndarray, dY: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Backward pass for matrix multiplication.

    Given the gradient of the loss with respect to the output (dY),
    computes the gradients with respect to the input (dX) and weights (dW).

    For Y = X @ W:
    - dW = X.T @ dY
    - dX = dY @ W.T

    Args:
        X: Input matrix of shape (batch_size, input_dim)
        W: Weight matrix of shape (input_dim, output_dim)
        dY: Gradient of loss w.r.t. output, shape (batch_size, output_dim)

    Returns:
        Tuple of (dW, dX) where:
        - dW: Gradient w.r.t. weights, shape (input_dim, output_dim)
        - dX: Gradient w.r.t. input, shape (batch_size, input_dim)
    """
    # Each weight W[i, j] contributes to Y[b, j] through X[b, i]; summing over
    # the batch gives (input_dim, batch) @ (batch, output_dim).
    dW = np.dot(X.T, dY)
    # Each input X[b, i] contributes to Y[b, j] through W[i, j]; summing over
    # the outputs gives (batch, output_dim) @ (output_dim, input_dim).
    dX = np.dot(dY, W.T)

    return dW, dX


In [None]:
def check() -> None:
    """
    Test the manual forward and backward implementations.

    Builds seeded random matrices, runs the manual forward and backward
    passes, and compares the resulting gradients against PyTorch autograd.
    When JAX was imported successfully (JAX_AVAILABLE), the forward output
    and gradients are also compared against ``jax.vjp``.

    Results are printed only: a gradient mismatch is reported as ``False``
    together with the maximum absolute difference and does not raise.
    """
    # Set random seed for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Create random input and weight matrices
    batch_size, input_dim, output_dim = 500, 1000, 100
    X = np.random.randn(batch_size, input_dim)
    W = np.random.randn(input_dim, output_dim)

    # Manual forward pass
    Y = forward(X, W)

    # Create random gradient for backward pass
    dY = np.random.randn(batch_size, output_dim)

    # Manual backward pass
    dW_manual, dX_manual = backward(X, W, dY)

    # PyTorch validation (float64 to match NumPy's default precision)
    X_torch = torch.tensor(X, dtype=torch.float64, requires_grad=True)
    W_torch = torch.tensor(W, dtype=torch.float64, requires_grad=True)
    dY_torch = torch.tensor(dY, dtype=torch.float64)

    # PyTorch forward pass
    Y_torch = torch.matmul(X_torch, W_torch)

    # PyTorch backward pass, seeded with the same upstream gradient
    Y_torch.backward(dY_torch)

    # Extract gradients
    assert W_torch.grad is not None, "W_torch.grad should not be None after backward()"
    assert X_torch.grad is not None, "X_torch.grad should not be None after backward()"
    dW_torch = W_torch.grad.detach().numpy()
    dX_torch = X_torch.grad.detach().numpy()

    if JAX_AVAILABLE:
        X_jax = jnp.asarray(X)
        W_jax = jnp.asarray(W)
        dY_jax = jnp.asarray(dY)

        # jax.vjp returns the forward output together with the pullback, so
        # the forward product does not need to be computed a second time.
        Y_jax, vjp_fun = jax.vjp(jnp.matmul, X_jax, W_jax)
        # Cotangents are returned in argument order: (X, W).
        dX_jax, dW_jax = vjp_fun(dY_jax)

        Y_jax_np = np.asarray(Y_jax)
        dW_jax_np = np.asarray(dW_jax)
        dX_jax_np = np.asarray(dX_jax)

    # Compare results.
    # NOTE: these calls rely on np.allclose's default atol, whereas the JAX
    # comparisons below pass atol=0 explicitly.
    print("Gradient comparisons:")
    print(f"dW matches PyTorch: {np.allclose(dW_manual, dW_torch, rtol=1e-10)}")
    print(f"dX matches PyTorch: {np.allclose(dX_manual, dX_torch, rtol=1e-10)}")

    # Print maximum absolute differences
    print("\nMaximum absolute differences:")
    print(f"dW max diff: {np.max(np.abs(dW_manual - dW_torch))}")
    print(f"dX max diff: {np.max(np.abs(dX_manual - dX_torch))}")

    if JAX_AVAILABLE:
        print("\nJAX forward comparison:")
        print(f"Y (NumPy vs JAX) matches: {np.allclose(Y, Y_jax_np, rtol=1e-10, atol=0)}")

        print("\nGradient comparisons vs JAX:")
        print(f"dW (manual vs JAX) matches: {np.allclose(dW_manual, dW_jax_np, rtol=1e-10, atol=0)}")
        print(f"dX (manual vs JAX) matches: {np.allclose(dX_manual, dX_jax_np, rtol=1e-10, atol=0)}")

        print("\nMaximum absolute differences (manual vs JAX):")
        print(f"dW max diff: {np.max(np.abs(dW_manual - dW_jax_np))}")
        print(f"dX max diff: {np.max(np.abs(dX_manual - dX_jax_np))}")
    else:
        print("\n[JAX not available] Skipping JAX validation. Install jax & jaxlib to enable this check.")

In [None]:
# Run the validation: prints whether the manual gradients match PyTorch
# (and JAX, when it is available) along with the maximum absolute differences.
check()

Gradient comparisons:
dW matches PyTorch: False
dX matches PyTorch: False

Maximum absolute differences:
dW max diff: 93.59434845453889
dX max diff: 47.27757285580921

JAX forward comparison:
Y (NumPy vs JAX) matches: True

Gradient comparisons vs JAX:
dW (manual vs JAX) matches: False
dX (manual vs JAX) matches: False

Maximum absolute differences (manual vs JAX):
dW max diff: 93.59434845453896
dX max diff: 47.27757285580921
