<a href="https://colab.research.google.com/github/ark-saini/COW/blob/main/w_solver_functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np

# -------------------------------
# Residual Computation
# -------------------------------
def compute_residual(Y, W_all, X_all, G_all, j_exclude):
    """
    Computes R^(j) = Y - sum_{k != j} G_k @ W_k @ x_k

    Y itself is never modified: the residual is accumulated in a copy.

    Parameters:
    - Y        : (N, T) Sensor observations. Integer/boolean data is
                 promoted to float so the in-place subtraction is valid;
                 floating and complex dtypes are kept as they are.
    - W_all    : list of (M_k, 1) or (M_k,) Weight vectors for all ROIs
    - X_all    : (P, T) Source activity time series (2-D array, or any
                 sequence whose k-th item is the length-T series of ROI k)
    - G_all    : list of (N, M_k) Lead field matrices
    - j_exclude: Index of ROI to exclude

    Returns:
    - residual : (N, T) Matrix with ROI j's contribution removed
    """
    residual = np.array(Y)  # np.array copies by default, so Y stays untouched
    if not np.issubdtype(residual.dtype, np.inexact):
        # An integer buffer cannot receive float results through `-=`
        # (NumPy refuses the float -> int cast), so promote it first.
        residual = residual.astype(float)

    for k, (G_k, W_k) in enumerate(zip(G_all, W_all)):
        if k == j_exclude:
            continue
        # Force a column / row layout so the product below is always an
        # (N, 1) @ (1, T) outer product, whether W_k is (M_k, 1) or (M_k,).
        pattern = np.reshape(G_k @ W_k, (-1, 1))   # (N, 1)
        x_k = np.reshape(X_all[k], (1, -1))        # (1, T)
        residual -= pattern @ x_k                  # (N, T)
    return residual

# -------------------------------
# Compute A^(j) Matrix
# -------------------------------
def compute_A(G_j, x_j, Sigma_inv):
    """
    A^(j) = (x_j @ x_j.T) * (G_j.T @ Sigma_inv @ G_j)

    x_j @ x_j.T is the scalar sum of squares of the time series, so it
    only rescales the (M_j, M_j) matrix G_j.T @ Sigma_inv @ G_j.

    Parameters:
    - G_j       : (N, M_j) Lead field for ROI j
    - x_j       : (T,) Time series for ROI j; (1, T) and (T, 1) layouts
                  are accepted as well and are flattened first
    - Sigma_inv : (N, N) Inverse noise covariance

    Returns:
    - A         : (M_j, M_j)
    """
    # Flatten so np.dot yields a scalar for row/column inputs as well;
    # on an un-flattened (1, T) array it would raise a shape error.
    x_flat = np.ravel(x_j)
    scale = np.dot(x_flat, x_flat)  # scalar
    return scale * (G_j.T @ Sigma_inv @ G_j)

# -------------------------------
# Compute b^(j) Vector
# -------------------------------
def compute_b(G_j, x_j, R_j, Sigma_inv):
    """
    b^(j) = G_j.T @ Sigma_inv @ R_j @ x_j.T

    The residual is first projected onto the ROI time series, then
    mapped through the noise-weighted lead field.

    Parameters:
    - G_j       : (N, M_j) Lead field for ROI j
    - x_j       : (T,) Time series for ROI j (reshaped to a column here)
    - R_j       : (N, T) Residual with ROI j's contribution left in
    - Sigma_inv : (N, N) Inverse noise covariance

    Returns:
    - b         : (M_j, 1)
    """
    weighted_leadfield = G_j.T @ Sigma_inv       # (M_j, N)
    series_column = x_j.reshape(-1, 1)           # (T, 1)
    projected_residual = R_j @ series_column     # (N, 1)
    return weighted_leadfield @ projected_residual

# -------------------------------
# Update W^(j)
# -------------------------------
def update_W_j(G_j, x_j, Y, W_all, X_all, G_all, Sigma, Sigma_j, j):
    """
    Re-estimates W^{(j)} while every other ROI's weights stay fixed.

    Solves (A^(j) + pinv(Sigma_j)) @ W = b^(j), where A^(j) and b^(j)
    come from compute_A / compute_b and the residual R^(j) comes from
    compute_residual with ROI j excluded.

    Parameters:
    - G_j       : (N, M_j) Lead field matrix for ROI j
    - x_j       : (T,) Time series of ROI j
    - Y         : (N, T) Sensor observations
    - W_all     : list of (M_k, 1) Weight vectors for all ROIs
    - X_all     : (P, T) Time series of all ROIs
    - G_all     : list of (N, M_k) Lead field matrices
    - Sigma     : (N, N) Measurement noise covariance
    - Sigma_j   : (M_j, M_j) Prior covariance for W^{(j)}
    - j         : ROI index

    Returns:
    - W_j       : (M_j, 1) Updated weight vector
    """
    # Pseudo-inverses are used for both covariances.
    noise_precision = np.linalg.pinv(Sigma)
    prior_precision = np.linalg.pinv(Sigma_j)

    # Data left over once every ROI except j has been subtracted.
    residual_j = compute_residual(Y, W_all, X_all, G_all, j)  # (N, T)

    # Left-hand side: data term plus prior term; right-hand side: b^(j).
    lhs = compute_A(G_j, x_j, noise_precision) + prior_precision
    rhs = compute_b(G_j, x_j, residual_j, noise_precision)
    return np.linalg.solve(lhs, rhs)

# -------------------------------
# Recursive W Solver
# -------------------------------
def recursive_W_solver(Y, X_all, G_all, Sigma, Sigma_j_all, num_iters=100, tol=1e-6):
    """
    Recursively updates all W^{(j)} using coordinate descent.

    Every sweep re-solves each ROI's weights in turn via update_W_j,
    starting from all-zero weights. Iteration stops as soon as the
    largest per-ROI change (Euclidean norm) in a sweep drops below `tol`,
    or after `num_iters` sweeps. Progress is printed once per sweep.

    If an update produces a non-finite change (NaN/inf), a RuntimeWarning
    is issued and the current weights are returned immediately instead of
    being reported as converged.

    Parameters:
    - Y             : (N, T) Sensor observations
    - X_all         : (P, T) Latent source activities
    - G_all         : list of (N, M_j) Lead field matrices
    - Sigma         : (N, N) Noise covariance
    - Sigma_j_all   : list of (M_j, M_j) Prior covariances for each ROI
    - num_iters     : Number of max iterations
    - tol           : Convergence threshold

    Returns:
    - W_all         : list of (M_j, 1) Estimated weight matrices
    """
    import warnings

    P = len(G_all)
    W_all = [np.zeros((G.shape[1], 1)) for G in G_all]
    for iteration in range(num_iters):
        max_change = 0.0
        for j in range(P):
            W_old = W_all[j].copy()
            W_new = update_W_j(G_all[j], X_all[j], Y, W_all, X_all, G_all, Sigma, Sigma_j_all[j], j)
            W_all[j] = W_new
            change = np.linalg.norm(W_new - W_old)
            if not np.isfinite(change):
                # Python's max(0, nan) returns 0, so a NaN update would
                # otherwise leave max_change at 0 and look like convergence.
                warnings.warn(
                    f"Non-finite update for ROI {j} at iteration {iteration+1}; "
                    "stopping without convergence.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                return W_all
            max_change = max(max_change, change)
        print(f"Iteration {iteration+1}, Max Change: {max_change:.2e}")
        if max_change < tol:
            print(f"Converged at iteration {iteration+1} with max change {max_change:.2e}")
            break
    return W_all
