In [1]:
"""Verify the gradient derivation of linear masked autoencoder."""
import torch
import torch.nn as nn

In [2]:
# Initialise matrix and vector
m = 50
n = 10
prob = 0.6
X = torch.rand(m, n)
W1 = torch.rand(n, n, requires_grad=True)
W2 = torch.rand(n, n, requires_grad=True)

In [25]:
# Define different types of masks
def mask_basic(prob, m, n):
    return torch.zeros(m, n).bernoulli_(prob)

def mask_dropping_probs(prob_list: torch.Tensor, m, n):
    return torch.zeros(m, n).bernoulli_(prob_list)

def mask_patches(prob, patch_size, m, n):
    if not n % patch_size:
        pix_num = n // patch_size
        mat_patches = torch.zeros(m, pix_num).bernoulli_(prob)
        return mat_patches.repeat_interleave(patch_size, dim=1)
    else:
        NotImplementedError

In [13]:
# define loss function in terms of W1 and W2
def loss_func_W1_basic(W1):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_basic(W2):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [14]:
# find the theoretical and numerical solutions of W1 and W2
# mask_basic
mean_m_basic = torch.ones(m, n) * prob
square_m_basic = torch.ones(n, n) * prob**2
square_m_basic.fill_diagonal_(prob)

grad_w1_theory_basic = W2.T @ (W2@W1@(square_m_basic*(X.T@X)) - X.T@(mean_m_basic*X)) * (2/m/n)
grad_w2_theory_basic = (W2@W1@(square_m_basic*(X.T@X))-X.T@(mean_m_basic*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 1000
grad_w1_numer_basic = 0
grad_w2_numer_basic = 0
# Sampling process
for i in range(N):
    grad_w1_numer_basic += torch.autograd.functional.jacobian(loss_func_W1_basic, W1)
    grad_w2_numer_basic += torch.autograd.functional.jacobian(loss_func_W2_basic, W2)
    W1.detach()
    W2.detach()
grad_w1_numer_basic / N
grad_w2_numer_basic / N

grad_w1_numer_basic
#norm_diff_w1 = torch.linalg.matrix_norm(grad_w1-grad_w1_true)
#norm_diff_w2 = torch.linalg.matrix_norm(grad_w2-grad_w2_true)

#print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1.item())
#print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2.item())

tensor([[2928.7549, 2423.7131, 3062.7380, 2925.0562, 2736.2166, 2715.0981,
         3384.3022, 2748.9661, 2966.3354, 3231.7920],
        [2167.0989, 1778.6750, 2264.7625, 2161.1702, 2015.5416, 1990.7881,
         2496.6406, 2036.9464, 2189.6941, 2373.5215],
        [2415.8381, 1988.2748, 2549.7305, 2422.0857, 2263.9260, 2229.2139,
         2812.9221, 2285.7661, 2450.0088, 2674.1602],
        [3371.6895, 2769.1604, 3539.1252, 3373.8645, 3136.0845, 3125.5830,
         3912.2446, 3169.8306, 3409.8096, 3734.4739],
        [3059.6621, 2527.8501, 3202.0374, 3057.0481, 2868.6990, 2810.4558,
         3517.7339, 2880.7788, 3097.1609, 3357.7683],
        [2402.6104, 1984.2102, 2527.0803, 2401.4229, 2239.8945, 2238.6538,
         2805.3030, 2263.0488, 2441.6416, 2667.9233],
        [2034.5262, 1676.0745, 2123.3013, 2029.9707, 1896.3358, 1872.0820,
         2334.2024, 1908.7225, 2055.7942, 2236.1763],
        [2345.4126, 1942.9313, 2443.9529, 2342.9968, 2190.1743, 2162.2864,
         2695.7705, 21

In [15]:
# define loss function in terms of W1 and W2
def loss_func_W1_probs(W1):
    z = (mask_dropping_probs(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_probs(W2):
    z = (mask_dropping_probs(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [16]:
# mask_dropping_probs
prob_list = torch.rand(n)
mean_m_probs = prob_list.repeat(m, 1)
square_m_probs = prob_list.view(n, 1) @ prob_list.view(1, n)
square_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)

grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X)) * (2/m/n)
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 1000
grad_w1_numer_probs = 0
grad_w2_numer_probs = 0
# Sampling process
for i in range(N):
    grad_w1_numer_probs += torch.autograd.functional.jacobian(loss_func_W1_probs, W1)
    grad_w2_numer_probs += torch.autograd.functional.jacobian(loss_func_W2_probs, W2)
grad_w1_numer_probs / N
grad_w2_numer_probs / N

abs(grad_w1_theory_probs-grad_w1_numer_probs)

tensor([[2940.8464, 2433.2715, 3076.4978, 2943.4871, 2774.6506, 2761.4392,
         3396.5461, 2729.7273, 2978.4912, 3248.4111],
        [2175.8208, 1785.6799, 2274.4841, 2174.8479, 2043.7139, 2024.8901,
         2505.4998, 2022.4923, 2198.7373, 2385.6819],
        [2425.6426, 1995.9534, 2560.5955, 2437.2378, 2295.8110, 2267.5999,
         2822.8616, 2269.2683, 2460.1506, 2687.5793],
        [3385.2117, 2780.2659, 3554.6824, 3394.9675, 3180.3481, 3179.1077,
         3926.0044, 3147.3542, 3423.8721, 3753.2332],
        [3072.4275, 2537.5488, 3215.7354, 3076.3008, 2908.6423, 2858.6912,
         3530.5283, 2860.4102, 3109.9451, 3375.1941],
        [2412.2893, 1992.0916, 2538.6189, 2416.5273, 2271.6750, 2276.9592,
         2815.2249, 2247.0300, 2451.6768, 2681.3105],
        [2042.8871, 1682.6747, 2132.5168, 2042.7648, 1922.7485, 1904.0958,
         2342.5352, 1895.3468, 2064.2466, 2247.7844],
        [2355.2891, 1950.6976, 2454.7852, 2357.8203, 2220.8254, 2199.2422,
         2705.7156, 21

In [23]:
# define loss function in terms of W1 and W2
def loss_func_W1_patches(W1):
    z = (mask_patches(prob, patch_size, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_patches(W2):
    z = (mask_patches(prob, patch_size, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [26]:
# mask_patches
patch_size = 2
mean_m_patches = torch.ones(m, n) * prob
square_m_patches = torch.ones(n, n) * prob**2
for i in range(n):
    place_val = i // patch_size
    square_m_patches[place_val:place_val+patch_size, place_val:place_val+patch_size] = torch.ones(patch_size, patch_size) * prob

grad_w1_theory_patches = W2.T @ (W2@W1@(square_m_patches*(X.T@X)) - X.T@(mean_m_patches*X)) * (2/m/n)
grad_w2_theory_patches = (W2@W1@(square_m_patches*(X.T@X))-X.T@(mean_m_patches*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 1000
grad_w1_numer_patches = 0
grad_w2_numer_patches = 0
# Sampling process
for i in range(N):
    grad_w1_numer_patches += torch.autograd.functional.jacobian(loss_func_W1_patches, W1)
    grad_w2_numer_patches += torch.autograd.functional.jacobian(loss_func_W2_patches, W2)
grad_w1_numer_patches / N
grad_w2_numer_patches / N

abs(grad_w1_theory_patches-grad_w1_numer_patches)

tensor([[3055.3391, 2584.1592, 3207.7812, 3145.1729, 2969.8660, 2936.4583,
         3567.0972, 2904.6987, 3131.7996, 3385.6895],
        [2255.5510, 1899.5037, 2372.1047, 2327.7117, 2185.0222, 2153.2278,
         2634.0022, 2153.7358, 2307.0090, 2483.9072],
        [2511.2561, 2119.6418, 2673.9436, 2611.7866, 2448.3037, 2410.2981,
         2967.5339, 2419.1169, 2591.7263, 2800.4080],
        [3506.1997, 2953.7954, 3708.0137, 3634.6118, 3403.2258, 3372.7830,
         4120.5054, 3355.3027, 3606.8003, 3905.8286],
        [3190.8740, 2699.5879, 3356.3948, 3290.2515, 3104.3398, 3047.8022,
         3716.5864, 3037.8496, 3265.4751, 3517.6233],
        [2500.7712, 2112.6472, 2646.3867, 2586.7690, 2433.3599, 2413.7000,
         2952.4780, 2400.4111, 2582.7148, 2795.5845],
        [2121.7808, 1790.3220, 2224.9912, 2182.6672, 2055.8218, 2026.7932,
         2464.1260, 2012.4594, 2168.9951, 2341.5400],
        [2449.1279, 2074.3503, 2560.9006, 2518.9949, 2375.2085, 2343.0017,
         2842.0945, 23