### Setup

In [9]:
"""Verify the gradient derivation of linear masked autoencoder."""
import torch
import torch.nn as nn
import torch.optim as optim

In [15]:
# Initialise matrix and vector
# prob is un-masking prob
sample_num = 3
H = 4
W = 1
sample_dim = [H, W]
feature_num = H * W

m = sample_num
n = feature_num
p = feature_num // 2
prob = 0.75
X = torch.rand(m, n)
W1 = torch.rand(p, n, requires_grad=True)
W2 = torch.rand(n, p, requires_grad=True)
#print(X)
#print(W1)
#print(W2)

In [16]:
# Define different types of masks
def mask_basic(prob, sample_num, feature_num):
    return torch.zeros(sample_num, feature_num).bernoulli_(prob)

def mask_dropping_probs(prob_list: torch.Tensor, sample_num, feature_num):
    return torch.zeros(sample_num, feature_num).bernoulli_(prob_list)

def mask_patches(prob, patch_size, sample_num, sample_dim):
    patch_size = torch.tensor(patch_size)
    sample_dim = torch.tensor(sample_dim)
    feature_num = sample_dim[0]*sample_dim[1]
    div_check = sample_dim % patch_size == torch.zeros(2)
    if torch.all(div_check):
        pix_num = torch.div(sample_dim, patch_size, rounding_mode='floor')
        mat_patches = torch.zeros(sample_num, *pix_num).bernoulli_(prob)
        mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=2)
        return mat_patches.repeat_interleave(patch_size[0], dim=1).view(sample_num, feature_num)
    else:
        raise NotImplementedError(f"Both height ({H}) and width ({W}) should be divisible by patch_size ({patch_size}).")

In [17]:
# Masked autoencoder (linear)
class M_LAE(nn.Module):
    def __init__(self, prob, sample_num, sample_dim, reduction_dim, type='basic', patch_size=None):
        super(M_LAE, self).__init__()
        self.prob = prob
        self.m = sample_num
        self.sample_dim = sample_dim
        self.H, self.W = sample_dim
        self.n = self.H * self.W
        self.p = reduction_dim
        if type not in ['basic', 'probs', 'patches']:
            raise NotImplementedError("Could only implement 'basic', 'probs' and 'patches' type of masking.")
        else:
            self.masking_type = type
        if patch_size is not None:
            self.patch_size = patch_size
        w1 = nn.Linear(self.n, self.p, bias=False)
        w2 = nn.Linear(self.p, self.n, bias=False)
        self.body = nn.Sequential(*[w1, w2])
    
    def forward(self, X, mask=None):
        if mask is None:
            if self.masking_type == 'basic':
                mask = mask_basic(self.prob, self.m, self.n)
            elif self.masking_type == 'probs':
                mask = mask_dropping_probs(self.prob, self.m, self.n)
            elif self.masking_type == 'patches':
                mask = mask_patches(self.prob, self.patch_size, self.m, self.sample_dim)
        Y = mask * X
        Y = self.body(Y)
        return Y

### Basic Masking

In [18]:
# define loss function in terms of W1 and W2
def loss_func_W1_basic(W1):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_basic(W2):
    z = (mask_basic(prob, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [19]:
# find the theoretical and numerical solutions of W1 and W2
# mask_basic
mean_m_basic = torch.ones(m, n) * prob
square_m_basic = torch.ones(n, n) * prob**2
square_m_basic.fill_diagonal_(prob)

grad_w1_theory_basic = W2.T @ (W2@W1@(square_m_basic*(X.T@X)) - X.T@(mean_m_basic*X)) * (2/m/n)
grad_w2_theory_basic = (W2@W1@(square_m_basic*(X.T@X))-X.T@(mean_m_basic*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_basic = 0
grad_w2_numer_basic = 0
# Sampling process
for i in range(N):
    grad_w1_numer_basic += torch.autograd.functional.jacobian(loss_func_W1_basic, W1)
    grad_w2_numer_basic += torch.autograd.functional.jacobian(loss_func_W2_basic, W2)
    #W1.detach()
    #W2.detach()
grad_w1_numer_basic /= N
grad_w2_numer_basic /= N

# absolute value check
abs(grad_w1_theory_basic-grad_w1_numer_basic)

tensor([[7.8443e-04, 7.4720e-04, 1.0729e-06, 6.1157e-04],
        [5.2887e-04, 9.8798e-04, 3.3244e-05, 9.1496e-04]],
       grad_fn=<AbsBackward0>)

In [20]:
# matrix norm check
norm_diff_w1_basic = torch.linalg.matrix_norm(grad_w1_numer_basic-grad_w1_theory_basic)
norm_diff_w2_basic = torch.linalg.matrix_norm(grad_w2_numer_basic-grad_w2_theory_basic)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_basic.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_basic.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.0019083305960521102
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.0018185757799074054


##### gradient decent check

In [21]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob, m, sample_dim, p)
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_basic = W20.T @ (W20@W10@(square_m_basic*(X.T@X)) - X.T@(mean_m_basic*X)) * (2/m/n)
grad_w2_theory_basic = (W20@W10@(square_m_basic*(X.T@X))-X.T@(mean_m_basic*X)) @ W10.T * (2/m/n)

# autograd gradients
grad_w1_numer_basic = torch.autograd.functional.jacobian(loss_func_W1_basic, W10)
grad_w2_numer_basic = torch.autograd.functional.jacobian(loss_func_W2_basic, W20)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[-0.0247,  0.0949,  0.0767,  0.0656],
        [ 0.2116,  0.1435,  0.1816,  0.4203]])
The difference between autograd and gradient decent for w2
 tensor([[ 0.2152,  0.1965],
        [-0.0246, -0.1369],
        [ 0.1903,  0.1562],
        [ 0.1057, -0.0219]])
difference in terms of autograd for w1
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w1
 tensor([[-0.0601,  0.0100, -0.0284, -0.0923],
        [ 0.0426, -0.0940, -0.0192,  0.0660]])
difference in terms of autograd for w2
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w2
 tensor([[ 0.0660, -0.0160],
        [ 0.0389, -0.0427],
        [ 0.0507, -0.0314],
        [ 0.0580, -0.0558]])


### Masking with different rate

In [22]:
# define loss function in terms of W1 and W2
def loss_func_W1_probs(W1):
    z = (mask_dropping_probs(prob_list, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_probs(W2):
    z = (mask_dropping_probs(prob_list, m, n)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [23]:
# mask_dropping_probs
#prob_list = torch.arange(1/(n+1), 1-0.01, 1/(n+1))
#random.shuffle(prob_list)
prob_list = torch.rand(n)*0.1 + 0.2
mean_m_probs = prob_list.repeat(m, 1)
square_m_probs = prob_list.view(n, 1) @ prob_list.view(1, n)
square_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)

grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X)) * (2/m/n)
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_probs = 0
grad_w2_numer_probs = 0
# Sampling process
for i in range(N):
    grad_w1_numer_probs += torch.autograd.functional.jacobian(loss_func_W1_probs, W1)
    grad_w2_numer_probs += torch.autograd.functional.jacobian(loss_func_W2_probs, W2)
grad_w1_numer_probs /= N
grad_w2_numer_probs /= N

abs(grad_w1_theory_probs-grad_w1_numer_probs)

tensor([[9.4836e-04, 1.3860e-03, 4.9757e-05, 1.9649e-03],
        [7.2881e-04, 1.0648e-03, 2.2769e-05, 2.2316e-03]],
       grad_fn=<AbsBackward0>)

In [24]:
norm_diff_w1_probs = torch.linalg.matrix_norm(grad_w1_numer_probs-grad_w1_theory_probs)
norm_diff_w2_probs = torch.linalg.matrix_norm(grad_w2_numer_probs-grad_w2_theory_probs)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_probs.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_probs.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.0036509279161691666
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.0023121535778045654


##### Small mask test

In [25]:
"""
m = 1
n = 2
X = torch.tensor([[2, 3]]).float()
W1 = torch.tensor([[1, 0], [1, 1]]).float()
W2 = torch.tensor([[1, -1], [0, 1]]).float()

prob_list = torch.tensor([0.5, 0.8])
mean_m_probs = prob_list.repeat(1, 1)
square_m_probs = prob_list.view(2, 1) @ prob_list.view(1, 2)
square_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)

grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X))
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T

print(grad_w1_theory_probs)
print(torch.autograd.functional.jacobian(loss_func_W1_probs, W1))
"""

'\nm = 1\nn = 2\nX = torch.tensor([[2, 3]]).float()\nW1 = torch.tensor([[1, 0], [1, 1]]).float()\nW2 = torch.tensor([[1, -1], [0, 1]]).float()\n\nprob_list = torch.tensor([0.5, 0.8])\nmean_m_probs = prob_list.repeat(1, 1)\nsquare_m_probs = prob_list.view(2, 1) @ prob_list.view(1, 2)\nsquare_m_probs = square_m_probs.fill_diagonal_(0) + torch.diag(prob_list)\n\ngrad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X))\ngrad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T\n\nprint(grad_w1_theory_probs)\nprint(torch.autograd.functional.jacobian(loss_func_W1_probs, W1))\n'

##### gradient decent

In [26]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob_list, m, sample_dim, p, type='probs')
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_probs = W2.T @ (W2@W1@(square_m_probs*(X.T@X)) - X.T@(mean_m_probs*X)) * (2/m/n)
grad_w2_theory_probs = (W2@W1@(square_m_probs*(X.T@X))-X.T@(mean_m_probs*X)) @ W1.T * (2/m/n)
# autograd gradients
grad_w1_numer_probs = torch.autograd.functional.jacobian(loss_func_W1_probs, W1)
grad_w2_numer_probs = torch.autograd.functional.jacobian(loss_func_W2_probs, W2)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[0.0901, 0.1308, 0.1312, 0.2385],
        [0.0398, 0.0421, 0.0918, 0.1560]])
The difference between autograd and gradient decent for w2
 tensor([[ 0.1920,  0.3329],
        [-0.1256,  0.0346],
        [ 0.1294,  0.3111],
        [ 0.0284,  0.1786]])
difference in terms of autograd for w1
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w1
 tensor([[ 0.0547,  0.0459,  0.0260,  0.0806],
        [-0.1292, -0.1954, -0.1091, -0.1984]])
difference in terms of autograd for w2
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w2
 tensor([[ 0.0428,  0.1204],
        [-0.0621,  0.1289],
        [-0.0102,  0.1235],
        [-0.0193,  0.1447]])


### Masking in terms of patches

In [27]:
# define loss function in terms of W1 and W2
def loss_func_W1_patches(W1):
    z = (mask_patches(prob, patch_size, m, sample_dim)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

def loss_func_W2_patches(W2):
    z = (mask_patches(prob, patch_size, m, sample_dim)*X) @ W1.T @ W2.T - X
    return sum(sum(z*z)) / m / n

In [28]:
# mask_patches
mean_m_patches = torch.ones(m, n) * prob
patch_size = [2, 1]
pix_num = torch.div(torch.tensor(sample_dim), torch.tensor(patch_size), rounding_mode='floor')
mat_patches = torch.arange(pix_num[0]*pix_num[1]).view(*pix_num)
mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=1)
mat_patches = torch.repeat_interleave(mat_patches, patch_size[0], dim=0).view(n)
square_m_patches = torch.zeros(n, n)
for i in range(n):
    for j in range(n):
        if mat_patches[i] == mat_patches[j]:
            square_m_patches[i, j] = prob
        else:
            square_m_patches[i, j] = prob**2

grad_w1_theory_patches = W2.T @ (W2@W1@(square_m_patches*(X.T@X)) - X.T@(mean_m_patches*X)) * (2/m/n)
grad_w2_theory_patches = (W2@W1@(square_m_patches*(X.T@X))-X.T@(mean_m_patches*X)) @ W1.T * (2/m/n)

# Initialising sampling
N = 10000
grad_w1_numer_patches = 0
grad_w2_numer_patches = 0
# Sampling process
for i in range(N):
    grad_w1_numer_patches += torch.autograd.functional.jacobian(loss_func_W1_patches, W1)
    grad_w2_numer_patches += torch.autograd.functional.jacobian(loss_func_W2_patches, W2)
grad_w1_numer_patches /= N
grad_w2_numer_patches /= N

abs(grad_w1_theory_patches-grad_w1_numer_patches)

tensor([[0.0003, 0.0006, 0.0003, 0.0005],
        [0.0003, 0.0007, 0.0004, 0.0007]], grad_fn=<AbsBackward0>)

In [29]:
norm_diff_w1_patches = torch.linalg.matrix_norm(grad_w1_numer_patches-grad_w1_theory_patches)
norm_diff_w2_patches = torch.linalg.matrix_norm(grad_w2_numer_patches-grad_w2_theory_patches)

print("The matrix norm of the difference between theoretical and numerical solutions of W1:", norm_diff_w1_patches.item())
print("The matrix norm of the difference between theoretical and numerical solutions of W2:", norm_diff_w2_patches.item())

The matrix norm of the difference between theoretical and numerical solutions of W1: 0.0014025537529960275
The matrix norm of the difference between theoretical and numerical solutions of W2: 0.005967330187559128


##### gradient decent

In [30]:
# Initialising network
learning_rate = 0.01
criterion = nn.MSELoss()

m_net = M_LAE(prob, m, sample_dim, p, type='patches', patch_size=[1, 1])
inputs = X
targets = X

optimizer = optim.SGD(m_net.body.parameters(), lr=learning_rate)
params0 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W10 = params0[0].clone().detach()
W20 = params0[1].clone().detach()

# theoretical gradients
grad_w1_theory_patches = W2.T @ (W2@W1@(square_m_patches*(X.T@X)) - X.T@(mean_m_patches*X)) * (2/m/n)
grad_w2_theory_patches = (W2@W1@(square_m_patches*(X.T@X))-X.T@(mean_m_patches*X)) @ W1.T * (2/m/n)
# autograd gradients
grad_w1_numer_patches = torch.autograd.functional.jacobian(loss_func_W1_patches, W1)
grad_w2_numer_patches = torch.autograd.functional.jacobian(loss_func_W2_patches, W2)

# one-step gradient decent
optimizer.zero_grad()
outputs = m_net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

params1 = list(m_net.body.parameters())
#print(list(m_net.body.parameters()))
W11 = params1[0].clone().detach()
W21 = params1[1].clone().detach()

gradient_W1 = (W10 - W11) / learning_rate
gradient_W2 = (W20 - W21) / learning_rate
print('The difference between autograd and gradient decent for w1\n', gradient_W1-grad_w1_numer_basic)
print('The difference between autograd and gradient decent for w2\n', gradient_W2-grad_w2_numer_basic)
print('difference in terms of autograd for w1\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w1\n', gradient_W1-grad_w1_theory_basic)
print('difference in terms of autograd for w2\n', grad_w1_numer_basic-grad_w1_theory_basic)
print('difference in terms of gradient decent for w2\n', gradient_W2-grad_w2_theory_basic)

The difference between autograd and gradient decent for w1
 tensor([[ 0.0901, -0.0779,  0.0164,  0.0970],
        [ 0.0398,  0.0037,  0.0617,  0.1137]])
The difference between autograd and gradient decent for w2
 tensor([[ 0.1942,  0.3982],
        [-0.0475,  0.0602],
        [ 0.1653,  0.3566],
        [ 0.0752,  0.2145]])
difference in terms of autograd for w1
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w1
 tensor([[ 0.0547, -0.1628, -0.0888, -0.0608],
        [-0.1292, -0.2338, -0.1392, -0.2407]])
difference in terms of autograd for w2
 tensor([[-0.0354, -0.0849, -0.1052, -0.1579],
        [-0.1689, -0.2375, -0.2009, -0.3544]])
difference in terms of gradient decent for w2
 tensor([[0.0450, 0.1857],
        [0.0160, 0.1544],
        [0.0257, 0.1690],
        [0.0275, 0.1806]])
