### Setup

In [1]:
"""Verify the gradient derivation of linear masked autoencoder."""
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
# Initialise matrix and vector
# prob is un-masking prob
sample_num = 3
H = 4
W = 1
sample_dim = [H, W]
feature_num = H * W

m = sample_num
n = feature_num
prob = 0.75
X = torch.rand(m, n)
W1 = torch.rand(n, n, requires_grad=True)
W2 = torch.rand(n, n, requires_grad=True)
#print(X)
#print(W1)
#print(W2)

In [3]:
# Define different types of masks
def mask_basic(prob, sample_num, feature_num):
    return torch.zeros(sample_num, feature_num).bernoulli_(prob)

def mask_dropping_probs(prob_list: torch.Tensor, sample_num, feature_num):
    return torch.zeros(sample_num, feature_num).bernoulli_(prob_list)

def mask_patches(prob, patch_size, sample_num, sample_dim):
    patch_size = torch.tensor(patch_size)
    sample_dim = torch.tensor(sample_dim)
    feature_num = sample_dim[0]*sample_dim[1]
    div_check = sample_dim % patch_size == torch.zeros(2)
    if torch.all(div_check):
        pix_num = torch.div(sample_dim, patch_size, rounding_mode='floor')
        mat_patches = torch.zeros(sample_num, *pix_num).bernoulli_(prob)
        mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=2)
        return mat_patches.repeat_interleave(patch_size[0], dim=1).view(sample_num, feature_num)
    else:
        raise NotImplementedError(f"Both height ({H}) and width ({W}) should be divisible by patch_size ({patch_size}).")

In [4]:
# Masked autoencoder (linear)
class M_LAE(nn.Module):
    def __init__(self, prob, sample_num, sample_dim, type='basic', patch_size=None):
        super(M_LAE, self).__init__()
        self.prob = prob
        self.m = sample_num
        self.sample_dim = sample_dim
        self.H, self.W = sample_dim
        self.n = self.H * self.W
        if type not in ['basic', 'probs', 'patches']:
            raise NotImplementedError("Could only implement 'basic', 'probs' and 'patches' type of masking.")
        else:
            self.masking_type = type
        if patch_size is not None:
            self.patch_size = patch_size
        w1 = nn.Linear(self.n, self.n, bias=False)
        w2 = nn.Linear(self.n, self.n, bias=False)
        self.body = nn.Sequential(*[w1, w2])
    
    def forward(self, X, mask=None):
        if mask is None:
            if self.masking_type == 'basic':
                mask = mask_basic(self.prob, self.m, self.n)
            elif self.masking_type == 'probs':
                mask = mask_dropping_probs(self.prob, self.m, self.n)
            elif self.masking_type == 'patches':
                mask = mask_patches(self.prob, self.patch_size, self.m, self.sample_dim)
        Y = mask * X
        Y = self.body(Y)
        return Y

### Basic Masking

In [5]:
# define loss function in terms of W1 and W2
def loss_func_W1_basic(W1):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_basic(W2):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [6]:
# find the theoretical and numerical solutions of W1 and W2
# mask_basic
mean_m_basic = torch.ones(m, n) * prob
square_m_basic = torch.ones(n, n) * prob**2
square_m_basic.fill_diagonal_(prob)

grad_w1_theory_basic = W2.T @ (W2@W1@(square_m_basic*(X.T@X)) - X.T@(mean_m_basic*X)) * (2/m/n)
grad_w2_theory_basic = (W2@W1@(square_m_basic*(X.T@X))-X.T@(mean_m_basic*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_basic = 0
grad_w2_numer_basic = 0
# Sampling process
for i in range(N):
    grad_w1_numer_basic += torch.autograd.functional.jacobian(loss_func_W1_basic, W1)
    grad_w2_numer_basic += torch.autograd.functional.jacobian(loss_func_W2_basic, W2)
    #W1.detach()
    #W2.detach()
grad_w1_numer_basic /= N
grad_w2_numer_basic /= N

# absolute value check
abs(grad_w1_theory_basic-grad_w1_numer_basic)

tensor([[0.0040, 0.0010, 0.0056, 0.0029],
        [0.0043, 0.0010, 0.0055, 0.0030],
        [0.0036, 0.0008, 0.0048, 0.0025],
        [0.0047, 0.0011, 0.0071, 0.0036]], grad_fn=<AbsBackward0>)

In [7]:
# matrix norm check
norm_diff_w1_basic = torch.linalg.matrix_norm(grad_w1_numer_basic-grad_w1_theory_basic)
norm_diff_w2_basic = torch.linalg.matrix_norm(grad_w2_numer_basic-grad_w2_theory_basic)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_basic.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_basic.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.015652021393179893
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.00986055750399828


##### gradient decent check

In [8]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob, m, sample_dim)
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_basic = W20.T @ (W20@W10@(square_m_basic*(X.T@X)) - X.T@(mean_m_basic*X)) * (2/m/n)
grad_w2_theory_basic = (W20@W10@(square_m_basic*(X.T@X))-X.T@(mean_m_basic*X)) @ W10.T * (2/m/n)

# autograd gradients
grad_w1_numer_basic = torch.autograd.functional.jacobian(loss_func_W1_basic, W10)
grad_w2_numer_basic = torch.autograd.functional.jacobian(loss_func_W2_basic, W20)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[0.7996, 0.1777, 0.3846, 0.4695],
        [0.6469, 0.1431, 0.4264, 0.4093],
        [0.5162, 0.1140, 0.1997, 0.2960],
        [0.9573, 0.1967, 0.5212, 0.5788]])
The difference between autograd and gradient decent for w2
 tensor([[0.7163, 0.7473, 1.0426, 0.5126],
        [0.6173, 0.6946, 0.8368, 0.3907],
        [0.4021, 0.4329, 0.6343, 0.2957],
        [0.1292, 0.0326, 0.3160, 0.1945]])
difference in terms of autograd for w1
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],
        [-0.7161, -0.2059, -0.3667, -0.4150],
        [-0.4824, -0.0917, -0.1752, -0.2933],
        [-0.9960, -0.2233, -0.4796, -0.5571]])
difference in terms of gradient decent for w1
 tensor([[-0.0459, -0.0321, -0.0110,  0.0003],
        [-0.0692, -0.0628,  0.0597, -0.0058],
        [ 0.0338,  0.0223,  0.0245,  0.0028],
        [-0.0388, -0.0266,  0.0416,  0.0217]])
difference in terms of autograd for w2
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],

### Masking with different rate

In [9]:
# define loss function in terms of W1 and W2
def loss_func_W1_probs(W1):
    z = (mask_dropping_probs(prob_list, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_probs(W2):
    z = (mask_dropping_probs(prob_list, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [10]:
# mask_dropping_probs
#prob_list = torch.arange(1/(n+1), 1-0.01, 1/(n+1))
#random.shuffle(prob_list)
prob_list = torch.rand(n)*0.1 + 0.2
mean_m_probs = prob_list.repeat(m, 1)
square_m_probs = prob_list.view(n, 1) @ prob_list.view(1, n)
square_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)

grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X)) * (2/m/n)
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_probs = 0
grad_w2_numer_probs = 0
# Sampling process
for i in range(N):
    grad_w1_numer_probs += torch.autograd.functional.jacobian(loss_func_W1_probs, W1)
    grad_w2_numer_probs += torch.autograd.functional.jacobian(loss_func_W2_probs, W2)
grad_w1_numer_probs /= N
grad_w2_numer_probs /= N

abs(grad_w1_theory_probs-grad_w1_numer_probs)

tensor([[0.0028, 0.0002, 0.0015, 0.0004],
        [0.0028, 0.0003, 0.0014, 0.0002],
        [0.0024, 0.0001, 0.0012, 0.0004],
        [0.0038, 0.0003, 0.0019, 0.0006]], grad_fn=<AbsBackward0>)

In [11]:
norm_diff_w1_probs = torch.linalg.matrix_norm(grad_w1_numer_probs-grad_w1_theory_probs)
norm_diff_w2_probs = torch.linalg.matrix_norm(grad_w2_numer_probs-grad_w2_theory_probs)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_probs.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_probs.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.006764405872672796
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.001338590169325471


##### Small mask test

In [12]:
"""
m = 1
n = 2
X = torch.tensor([[2, 3]]).float()
W1 = torch.tensor([[1, 0], [1, 1]]).float()
W2 = torch.tensor([[1, -1], [0, 1]]).float()

prob_list = torch.tensor([0.5, 0.8])
mean_m_probs = prob_list.repeat(1, 1)
square_m_probs = prob_list.view(2, 1) @ prob_list.view(1, 2)
square_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)

grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X))
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T

print(grad_w1_theory_probs)
print(torch.autograd.functional.jacobian(loss_func_W1_probs, W1))
"""

'\nm = 1\nn = 2\nX = torch.tensor([[2, 3]]).float()\nW1 = torch.tensor([[1, 0], [1, 1]]).float()\nW2 = torch.tensor([[1, -1], [0, 1]]).float()\n\nprob_list = torch.tensor([0.5, 0.8])\nmean_m_probs = prob_list.repeat(1, 1)\nsquare_m_probs = prob_list.view(2, 1) @ prob_list.view(1, 2)\nsquare_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)\n\ngrad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X))\ngrad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T\n\nprint(grad_w1_theory_probs)\nprint(torch.autograd.functional.jacobian(loss_func_W1_probs, W1))\n'

##### gradient decent

In [13]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob_list, m, sample_dim, type='probs')
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X)) * (2/m/n)
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T * (2/m/n)
# autograd gradients
grad_w1_numer_probs = torch.autograd.functional.jacobian(loss_func_W1_probs, W1)
grad_w2_numer_probs = torch.autograd.functional.jacobian(loss_func_W2_probs, W2)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[0.7237, 0.1594, 0.3120, 0.4054],
        [0.4711, 0.1159, 0.2449, 0.2504],
        [0.4813, 0.1316, 0.2401, 0.3043],
        [0.9083, 0.2050, 0.4371, 0.4898]])
The difference between autograd and gradient decent for w2
 tensor([[0.7400, 0.8676, 0.9506, 0.4633],
        [0.6329, 0.7573, 0.8233, 0.4025],
        [0.4337, 0.5606, 0.5946, 0.2979],
        [0.1425, 0.1463, 0.1684, 0.0825]])
difference in terms of autograd for w1
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],
        [-0.7161, -0.2059, -0.3667, -0.4150],
        [-0.4824, -0.0917, -0.1752, -0.2933],
        [-0.9960, -0.2233, -0.4796, -0.5571]])
difference in terms of gradient decent for w1
 tensor([[-0.1218, -0.0504, -0.0836, -0.0638],
        [-0.2449, -0.0900, -0.1218, -0.1646],
        [-0.0011,  0.0399,  0.0649,  0.0111],
        [-0.0877, -0.0183, -0.0425, -0.0673]])
difference in terms of autograd for w2
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],

### Masking in terms of patches

In [14]:
# define loss function in terms of W1 and W2
def loss_func_W1_patches(W1):
    z = (mask_patches(prob, patch_size, m, sample_dim)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_patches(W2):
    z = (mask_patches(prob, patch_size, m, sample_dim)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [15]:
# mask_patches
mean_m_patches = torch.ones(m, n) * prob
patch_size = [2, 1]
pix_num = torch.div(torch.tensor(sample_dim), torch.tensor(patch_size), rounding_mode='floor')
mat_patches = torch.arange(pix_num[0]*pix_num[1]).view(*pix_num)
mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=1)
mat_patches = torch.repeat_interleave(mat_patches, patch_size[0], dim=0).view(n)
square_m_patches = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        if mat_patches[i] == mat_patches[j]:
            square_m_patches[i, j] = prob
        else:
            square_m_patches[i, j] = prob**2

grad_w1_theory_patches = W2.T @ (W2@W1@(square_m_patches*(X.T@X)) - X.T@(mean_m_patches*X)) * (2/m/n)
grad_w2_theory_patches = (W2@W1@(square_m_patches*(X.T@X))-X.T@(mean_m_patches*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_patches = 0
grad_w2_numer_patches = 0
# Sampling process
for i in range(N):
    grad_w1_numer_patches += torch.autograd.functional.jacobian(loss_func_W1_patches, W1)
    grad_w2_numer_patches += torch.autograd.functional.jacobian(loss_func_W2_patches, W2)
grad_w1_numer_patches /= N
grad_w2_numer_patches /= N

abs(grad_w1_theory_patches-grad_w1_numer_patches)

tensor([[0.0029, 0.0012, 0.0004, 0.0020],
        [0.0029, 0.0011, 0.0004, 0.0020],
        [0.0026, 0.0011, 0.0004, 0.0016],
        [0.0039, 0.0016, 0.0004, 0.0027]], grad_fn=<AbsBackward0>)

In [16]:
norm_diff_w1_patches = torch.linalg.matrix_norm(grad_w1_numer_patches-grad_w1_theory_patches)
norm_diff_w2_patches = torch.linalg.matrix_norm(grad_w2_numer_patches-grad_w2_theory_patches)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_patches.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_patches.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.007966650649905205
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.0035006017424166203


##### gradient decent

In [18]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob, m, sample_dim, type='patches', patch_size=[1, 1])
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_patches = W2.T @ (W2@W1@(square_m_patches*(X.T@X)) - X.T@(mean_m_patches*X)) * (2/m/n)
grad_w2_theory_patches = (W2@W1@(square_m_patches*(X.T@X))-X.T@(mean_m_patches*X)) @ W1.T * (2/m/n)
# autograd gradients
grad_w1_numer_patches = torch.autograd.functional.jacobian(loss_func_W1_patches, W1)
grad_w2_numer_patches = torch.autograd.functional.jacobian(loss_func_W2_patches, W2)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[0.6982, 0.1474, 0.3014, 0.3718],
        [0.5052, 0.0978, 0.2262, 0.2680],
        [0.7942, 0.2380, 0.3881, 0.4633],
        [0.7188, 0.0699, 0.2477, 0.3994]])
The difference between autograd and gradient decent for w2
 tensor([[0.6646, 0.9037, 0.8835, 0.5673],
        [0.5958, 0.7679, 0.7840, 0.4498],
        [0.3888, 0.5327, 0.5511, 0.3683],
        [0.1078, 0.1927, 0.1291, 0.1451]])
difference in terms of autograd for w1
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],
        [-0.7161, -0.2059, -0.3667, -0.4150],
        [-0.4824, -0.0917, -0.1752, -0.2933],
        [-0.9960, -0.2233, -0.4796, -0.5571]])
difference in terms of gradient decent for w1
 tensor([[-0.1473, -0.0624, -0.0942, -0.0974],
        [-0.2109, -0.1080, -0.1405, -0.1470],
        [ 0.3118,  0.1463,  0.2129,  0.1701],
        [-0.2772, -0.1533, -0.2319, -0.1577]])
difference in terms of autograd for w2
 tensor([[-0.8455, -0.2098, -0.3956, -0.4692],