In [38]:
import numpy as np
 
def _advance_E(E_prev, w_row):
    """Apply one step of the E recursion.

    Parameters
    ----------
    E_prev : ndarray, shape (i, i)
        Matrix entering step ``i``.
    w_row : ndarray, shape (i,)
        ``W[i - 1, :i]``, the weights consumed by step ``i``.

    Returns
    -------
    ndarray, shape (i + 1, i + 1)
        Symmetric, zero-diagonal matrix leaving step ``i``.
    """
    i = w_row.shape[0]
    E_new = np.zeros((i + 1, i + 1), dtype=w_row.dtype)

    # Existing pairs (j < k < i) are scaled by 1 - 0.5 * (w_j + w_k); only the
    # strict lower triangle is filled, the transpose below mirrors it.
    decay = 1 - 0.5 * (w_row[:, None] + w_row[None, :])
    E_new[:i, :i] = np.tril(E_prev * decay, k=-1)

    # New row i: entry k is 0.5 * sum_m E_prev[m, k] * w_m + 0.25 * w_k.
    E_new[i, :i] = 0.5 * (w_row @ E_prev) + 0.25 * w_row

    return E_new + E_new.T


def forward_sqrtn_checkpoint(W):
    """Run the E recursion, keeping roughly sqrt(n) checkpoints for the backward pass.

    Parameters
    ----------
    W : ndarray, shape (n_leaves - 1, n_leaves - 1)
        Step ``i`` (1-based) reads ``W[i - 1, :i]``.

    Returns
    -------
    E : ndarray, shape (n_leaves, n_leaves)
        Matrix after the last step.
    cache : dict[int, ndarray]
        ``cache[i]`` (``i >= 1``) is the ``(i, i)`` matrix *entering* step ``i``,
        which is what the backward pass needs to differentiate that step.
        ``cache[0]`` is the ``(1, 1)`` zero seed, i.e. the matrix entering step 1.
    """
    n_leaves = len(W) + 1
    checkpoint_interval = int(np.sqrt(n_leaves))

    E = np.zeros((1, 1), dtype=W.dtype)
    cache = {0: E.copy()}

    for i in range(1, n_leaves):
        # Snapshot BEFORE the update: a post-update slice E[:i, :i] has already
        # been scaled and has lost row i, so step i cannot be replayed or
        # differentiated from it.
        if i % checkpoint_interval == 0 or i == n_leaves - 1:
            cache[i] = E.copy()
        E = _advance_E(E, W[i - 1, :i])

    return E, cache


def recompute_E_until(W, i, cache):
    """Rebuild the matrix entering step ``i`` from the nearest earlier checkpoint.

    Parameters
    ----------
    W : ndarray
        Same weight matrix that was passed to ``forward_sqrtn_checkpoint``.
    i : int
        Step whose input matrix is required.
    cache : dict[int, ndarray]
        Checkpoints as produced by ``forward_sqrtn_checkpoint``.

    Returns
    -------
    ndarray, shape (i, i)
        The matrix entering step ``i``.

    Raises
    ------
    ValueError
        If ``cache`` has no key ``<= i``.
    """
    # Find nearest checkpoint <= i
    nearest_cp = max(k for k in cache if k <= i)
    E = cache[nearest_cp].copy()

    # cache[cp] is the matrix entering step cp, so replay steps cp .. i - 1.
    # The seed stored under key 0 is the matrix entering step 1, hence max(.., 1).
    for step in range(max(nearest_cp, 1), i):
        E = _advance_E(E, W[step - 1, :step])

    # E is already (i, i) for i >= 1; the slice only matters for i == 0, where
    # it yields an empty (0, 0) array.
    return E[:i, :i]

In [40]:
def backward_sqrtn_checkpoint(W, D, cache):
    """Back-propagate an upstream gradient through the E recursion to obtain dW.

    Parameters
    ----------
    W : ndarray
        Weight matrix; step ``i`` reads ``W[i - 1, :i]``.
    D : ndarray, shape (len(W) + 1, len(W) + 1)
        Gradient with respect to the final E matrix. It is copied, not modified.
    cache : dict[int, ndarray]
        Checkpoints. ``cache[i]`` is read as the ``(i, i)`` matrix that step
        ``i`` was applied to; for steps without an entry the matrix is obtained
        from ``recompute_E_until(W, i, cache)``.

    Returns
    -------
    ndarray, same shape and dtype as ``W``
        Gradient with respect to ``W``.
    """
    n_leaves = len(W) + 1
    dW = np.zeros_like(W)
    dE = D.copy()

    # Visit the forward steps in reverse order: i = n_leaves - 1, ..., 1.
    for i in range(n_leaves - 1, 0, -1):
        row = i - 1
        w = W[row, :i]

        E_in = cache[i] if i in cache else recompute_E_until(W, i, cache)

        # Each step's output is E_new + E_new.T, so an E_new entry receives the
        # gradient of both mirrored output entries.
        g = dE + dE.T
        dE_in = np.zeros_like(E_in, dtype=W.dtype)

        # Scaled pairs: E_new[k, j] = E_in[k, j] * (1 - 0.5 * (w[j] + w[k])), j < k < i.
        for k in range(i):
            for j in range(k):
                upstream = g[k, j]
                w_term = upstream * (-0.5 * E_in[k, j])
                dW[row, j] += w_term
                dW[row, k] += w_term
                dE_in[k, j] += upstream * (1 - 0.5 * (w[j] + w[k]))

        # New row: E_new[i, k] = 0.5 * sum_m E_in[m, k] * w[m] + 0.25 * w[k].
        for k in range(i):
            upstream = g[i, k]
            for m in range(i):
                dW[row, m] += upstream * (0.5 * E_in[m, k])
                dE_in[m, k] += upstream * (0.5 * w[m])
            dW[row, k] += upstream * 0.25

        # Hand the gradient to the previous step and clear row/column i, which
        # did not exist before this step.
        dE[:i, :i] = dE_in
        dE[i, :i] = 0
        dE[:i, i] = 0

    return dW

In [41]:
# Smoke run of the checkpointed forward/backward pair on random inputs.
# A seeded generator keeps the numbers identical across Restart & Run All.
rng = np.random.default_rng(0)

W = rng.random((7, 7)).astype(np.float32)
E, cache = forward_sqrtn_checkpoint(W)

# Upstream gradient with respect to the final E; must match E's shape.
D = rng.random((8, 8)).astype(np.float32)
dW = backward_sqrtn_checkpoint(W, D, cache)

assert E.shape == D.shape
assert dW.shape == W.shape
