<a href="https://colab.research.google.com/github/Zhuoyue-Huang/urop_2022_ml/blob/main/Try_and_prove.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""try and prove different results"""
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import statistics
import numpy as np
import pandas as pd 
import scipy.stats as stats
from copy import deepcopy

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Linear auto-encoder model
class LAE(nn.Module):
    def __init__(self, n, p):
        super(LAE, self).__init__()
        self.n = n
        self.p = p
        self.w1 = nn.Linear(n, p, bias=False)
        self.w2 = nn.Linear(p, n, bias=False)

    def forward(self, y):
        y = self.w1(y)
        y = self.w2(y)
        return y

## Masked lienar auto-encoder model
# Define different types of masks
def mask_basic(prob, sample_num, x_dim, cst=0, device='cuda'):
    return torch.zeros(sample_num, x_dim, device=device).bernoulli_(prob)*(1-cst) + cst

def mask_probs(prob_list: torch.Tensor, sample_num, x_dim, cst=0, device='cuda'):
    return torch.zeros(sample_num, x_dim, device=device).bernoulli_(prob_list)*(1-cst) + cst

# def mask_patches_block(prob, patch_size: torch.Tensor, sample_num, sample_dim: torch.Tensor, device='cuda'):
#     x_dim = sample_dim[0]*sample_dim[1]
#     div_check = sample_dim % patch_size == torch.zeros(2)
#     if torch.all(div_check):
#         pix_num = torch.div(sample_dim, patch_size, rounding_mode='floor')
#         mat_patches = torch.zeros(sample_num, *pix_num).bernoulli_(prob).to(device)
#         mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=2)
#         return mat_patches.repeat_interleave(patch_size[0], dim=1).view(sample_num, x_dim)
#     else:
#         raise NotImplementedError(f"Both height ({sample_dim[0]}) and width ({sample_dim[1]}) should be divisible by patch_size ({patch_size}).")

def mask_patches_plain(prob, patch_size, sample_num, x_dim, cst=0, device='cuda'):
    if not x_dim % patch_size:
        pix_num = x_dim // patch_size
        mat_patches = torch.zeros(sample_num, pix_num).bernoulli_(prob)
        return mat_patches.repeat_interleave(patch_size, dim=1).to(device)*(1-cst) + cst
    else:
        raise NotImplementedError

# Masked autoencoder (linear)
class M_LAE(nn.Module):
    def __init__(self, prob, x_dim, f_dim, type='basic', patch_size=None, sample_dim=None, cst=0):
        super(M_LAE, self).__init__()
        self.prob = prob
        if sample_dim is not None:
            self.sample_dim = sample_dim
            self.H, self.W = sample_dim
            self.n = self.H * self.W
        self.n = x_dim
        self.p = f_dim
        self.cst = cst
        if type not in ['basic', 'probs', 'patches']:
            raise NotImplementedError("Could only implement 'basic', 'probs' and 'patches' type of masking.")
        else:
            self.masking_type = type
        if patch_size is not None:
            self.patch_size = patch_size
        w1 = nn.Linear(self.n, self.p, bias=False)
        w2 = nn.Linear(self.p, self.n, bias=False)
        self.body = nn.Sequential(*[w1, w2])
    
    def forward(self, X, mask=None):
        m = X.shape[0]
        if mask is None:
            if self.masking_type == 'basic':
                mask = mask_basic(self.prob, m, self.n, cst=self.cst)
            elif self.masking_type == 'probs':
                mask = mask_probs(self.prob, m, self.n, cst=self.cst)
            elif self.masking_type == 'patches':
                mask = mask_patches_plain(self.prob, self.patch_size, m, self.n, cst=self.cst)
        Y = mask * X
        Y = self.body(Y)
        return Y

class FE_Net(nn.Module):
    def __init__(self, f_dim, y_dim):
        super(FE_Net, self).__init__()
        self.theta = nn.Linear(f_dim, y_dim, bias=False)

    def forward(self, W):
        return self.theta(W)

# representation learning
def repr_learning(data_dict, model_parameters, criterion, type, epochs=200, device='cuda'):
    X_train, y_train, _, _, X_test, y_test = data_dict.values()

    params = list(model_parameters)
    W1 = params[0].clone().detach()
    f_train = X_train @ W1.T
    f_test = X_test @ W1.T
    f_dim = f_train.shape[1]
    y_dim = y_train.shape[1]

    if type=='ls':
        theta = (torch.inverse(f_train.T @ f_train) @ f_train.T @ y_train).T
        loss = criterion(y_test, f_test @ theta.T)
        return loss.item()
    elif type=='gd':
        data_dict_fe = {'X_train': f_train, 'y_train': y_train,
                        'X_test': f_test, 'y_test': y_test}
        net_fe = FE_Net(f_dim, y_dim).to(device)
        theta = list(net_fe.parameters())
        optimizer = optim.Adam(theta, lr=0.0001)
        ### TRAINING ###
        train_loop(data_dict_fe, net_fe, criterion, optimizer, epochs=epochs, record=False, type='fe')
        loss_fe = test_loop(data_dict_fe, net_fe, criterion, type='fe')
        return loss_fe

### Examine the effect of detach

In [None]:
# Linear auto-encoder model
class LAE(nn.Module):
    def __init__(self, n, p):
        super(LAE, self).__init__()
        self.n = n
        self.p = p
        self.w1 = nn.Linear(n, p, bias=False)
        self.w2 = nn.Linear(p, n, bias=False)

    def forward(self, y):
        y = self.w1(y)
        y = self.w2(y)
        return y

class FE_Net(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(FE_Net, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.theta = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, W):
        return self.theta(W)

In [None]:
def train_loop(data_dict, model, criterion, optimizer, type, epochs=10, sample_average=10, record=True):
    train_loss = []
    val_loss = []
    if type=='encoder':
        train_inputs = data_dict['train_inputs']
        train_targets = train_inputs
    elif type=='fe':
        train_inputs = data_dict['train_inputs']
        train_targets = data_dict['train_targets']
        sample_average = 1

    for epoch in range(epochs+1):
        loss_total = 0
        optimizer.zero_grad()
        for i in range(sample_average):
            train_outputs = model(train_inputs)
            loss = criterion(train_outputs, train_targets)
            loss_total += loss
        loss_total /= sample_average
        loss_total.backward()
        optimizer.step()
        if record:
          train_loss.append(loss_total.item())
          if epochs>=5 and epoch%(epochs//5)==0:
              v_loss = test_loop(data_dict, model, criterion, type)
              val_loss.append(v_loss)
              print('epoch: ', epoch, ', train loss: ', loss.item(), ', val loss', v_loss)
    if record:
        return {'train_loss': train_loss, 'val_loss': val_loss}
    else:
        return loss_total.item()

def test_loop(data_dict, model, criterion, type):
    if type=='encoder':
        val_inputs = data_dict['val_inputs']
        val_targets = val_inputs
    elif type=='fe':
        val_inputs = data_dict['val_inputs']
        val_targets = data_dict['val_targets']

    with torch.no_grad():
        val_outputs = model(val_inputs)
        loss = criterion(val_outputs, val_targets)
    return loss.item()

In [None]:
# feature extraction
def feature_extraction_test(data_dict, model_parameters, criterion, type, epochs=200, device='cuda'):
    train_inputs, train_targets, val_inputs, val_targets = data_dict.values()

    params = list(model_parameters)
    W1 = params[0]
    W10 = deepcopy(W1)
    W10 = W10.clone().detach()
    print(W10)
    train_inputs_fe0 = train_inputs @ W10.T
    val_inputs_fe0 = val_inputs @ W10.T

    W11 = deepcopy(W1)
    W11 = W11.clone()
    print(W11)
    train_inputs_fe1 = (train_inputs @ W11.T).detach()
    val_inputs_fe1 = (val_inputs @ W11.T).detach()

    reduction_dim = train_inputs_fe1.shape[1]
    target_dim = train_targets.shape[1]

    if type=='ls':
        param_fe0 = (torch.inverse(train_inputs_fe0.T@train_inputs_fe0) @ train_inputs_fe0.T @ train_targets).T
        loss0 = criterion(val_targets, val_inputs_fe0 @ param_fe0.T)
        print('test0', loss0.item())

        param_fe1 = (torch.inverse(train_inputs_fe1.T@train_inputs_fe1) @ train_inputs_fe1.T @ train_targets).T
        loss1 = criterion(val_targets, val_inputs_fe1 @ param_fe1.T)
        print('test1', loss1.item())
        #return loss.item()
    elif type=='gd':
        net_fe = FE_Net(reduction_dim, target_dim).to(device)
        net_fe0 = deepcopy(net_fe)
        net_fe1 = deepcopy(net_fe)

        data_dict_fe0 = {'train_inputs': train_inputs_fe0, 'train_targets': train_targets,
                        'val_inputs': val_inputs_fe0, 'val_targets': val_targets}
        param_fe0 = list(net_fe0.parameters())
        optimizer0 = optim.Adam(param_fe0, lr=0.0001)
        ### TRAINING ###
        loss_fe0 = train_loop(data_dict_fe0, net_fe0, criterion, optimizer0, epochs=epochs, record=False, type='fe')
        print('test0', loss_fe0)
        data_dict_fe1 = {'train_inputs': train_inputs_fe1, 'train_targets': train_targets,
                        'val_inputs': val_inputs_fe1, 'val_targets': val_targets}
        param_fe1 = list(net_fe1.parameters())
        optimizer1 = optim.Adam(param_fe1, lr=0.0001)
        ### TRAINING ###
        loss_fe1 = train_loop(data_dict_fe1, net_fe1, criterion, optimizer1, epochs=epochs, record=False, type='fe')
        print('test1', loss_fe1)
        #return loss_fe

In [None]:
train_num = 50
val_num = 20
H = 8
W = 8
sample_dim = torch.tensor([H, W])
feature_num = H * W
reduction_dim = feature_num
target_dim = feature_num

prob = 0.75
prob_list = torch.rand(feature_num)*0.2 + 0.65
patch_size = torch.div(sample_dim, 2, rounding_mode='floor')

train_inputs = torch.rand(train_num, feature_num) * 2
train_inputs = train_inputs.to(device)
train_targets = torch.rand(train_num, target_dim)
train_targets = train_targets.to(device)
val_inputs = torch.rand(val_num, feature_num) * 2
val_inputs = val_inputs.to(device)
val_targets = torch.rand(val_num, target_dim)
val_targets = val_targets.to(device)
data_dict = {'train_inputs': train_inputs, 'train_targets': train_targets,
             'val_inputs': val_inputs, 'val_targets': val_targets}

In [None]:
net_LAE = LAE(feature_num, reduction_dim).to(device)

params = list(net_LAE.parameters())
criterion = nn.MSELoss()
optimizer = optim.Adam(params, lr=0.001)

### TRAINING ###
train_loop(data_dict, net_LAE, criterion, optimizer, epochs=500, record=False, type='encoder')

# feature extraction
loss_ls = feature_extraction_test(data_dict, net_LAE.parameters(), criterion, type='ls', epochs=500)
loss_gd = feature_extraction_test(data_dict, net_LAE.parameters(), criterion, type='gd', epochs=500)

tensor([[-0.1336,  0.0542,  0.0287,  ..., -0.1842,  0.0315,  0.1009],
        [-0.0978, -0.0365, -0.1108,  ..., -0.0541, -0.1404,  0.0120],
        [ 0.1117,  0.0700,  0.0037,  ...,  0.0655,  0.0070, -0.1101],
        ...,
        [ 0.1055,  0.0057, -0.0793,  ..., -0.0329,  0.1920,  0.0874],
        [-0.1235, -0.0956,  0.1405,  ..., -0.0077,  0.0426, -0.1306],
        [ 0.1496, -0.1518, -0.1922,  ..., -0.0546, -0.0099, -0.0903]],
       device='cuda:0')
tensor([[-0.1336,  0.0542,  0.0287,  ..., -0.1842,  0.0315,  0.1009],
        [-0.0978, -0.0365, -0.1108,  ..., -0.0541, -0.1404,  0.0120],
        [ 0.1117,  0.0700,  0.0037,  ...,  0.0655,  0.0070, -0.1101],
        ...,
        [ 0.1055,  0.0057, -0.0793,  ..., -0.0329,  0.1920,  0.0874],
        [-0.1235, -0.0956,  0.1405,  ..., -0.0077,  0.0426, -0.1306],
        [ 0.1496, -0.1518, -0.1922,  ..., -0.0546, -0.0099, -0.0903]],
       device='cuda:0', grad_fn=<CloneBackward0>)
test0 606.8356323242188
test1 606.8356323242188
tensor([[-

### Least square vs gradient decent

##### check whether least square estimator has zero gradient

In [None]:
def train_loop(data_dict, model, criterion, optimizer, type, epochs=10, sample_average=10, record=True):
    train_loss = []
    val_loss = []
    if type=='encoder':
        train_x = data_dict['train_x']
        train_y = train_x
    elif type=='fe':
        train_x = data_dict['train_x']
        train_y = data_dict['train_y']
        sample_average = 1

    for epoch in range(epochs+1):
        loss_total = 0
        optimizer.zero_grad()
        for i in range(sample_average):
            outputs = model(train_x)
            loss = criterion(outputs, train_y)
            loss_total += loss
        loss_total /= sample_average
        loss_total.backward()
        optimizer.step()
        if record:
          train_loss.append(loss_total.item())
          if epochs>=5 and epoch%(epochs//5)==0:
              v_loss = test_loop(data_dict, model, criterion, type)
              val_loss.append(v_loss)
              print('epoch: ', epoch, ', train loss: ', loss.item(), ', val loss', v_loss)
    if record:
        return {'train_loss': train_loss, 'val_loss': val_loss}
    else:
        return loss_total.item()

def test_loop(data_dict, model, criterion, type):
    if type=='encoder':
        val_x = data_dict['val_x']
        val_y = val_x
    elif type=='fe':
        val_x = data_dict['val_x']
        val_y = data_dict['val_y']

    with torch.no_grad():
        outputs = model(val_x)
        loss = criterion(outputs, val_y)
    return loss.item()

In [None]:
class FE_Net(nn.Module):
    def __init__(self, z_dim, y_dim):
        super(FE_Net, self).__init__()
        self.theta = nn.Linear(z_dim, y_dim, bias=False)

    def forward(self, W):
        return self.theta(W)

def feature_extraction(data_dict, model_parameters, criterion, type, epochs=200, device='cuda'):
    train_x, train_y, val_x, val_y, _, _ = data_dict.values()

    params = list(model_parameters)
    W1 = params[0].clone().detach()
    train_z = train_x @ W1.T
    val_z = val_x @ W1.T
    z_dim = train_z.shape[1]
    y_dim = train_y.shape[1]

    if type=='ls':
        print(torch.linalg.matrix_rank(train_z))
        theta = (torch.inverse(train_z.T @ train_z) @ train_z.T @ train_y).T
        print(-train_z.T @ train_y + train_z.T @ train_z @ theta.T)
        loss = criterion(val_y, val_z @ theta.T)
        return loss.item()
    elif type=='gd':
        data_dict_fe = {'train_x': train_z, 'train_y': train_y,
                        'val_x': val_z, 'val_y': val_y}
        net_fe = FE_Net(z_dim, y_dim).to(device)
        theta = list(net_fe.parameters())
        optimizer = optim.Adam(theta, lr=0.00001)
        ### TRAINING ###
        loss_fe = train_loop(data_dict_fe, net_fe, criterion, optimizer, epochs=epochs, record=False, type='fe')
        return loss_fe

In [None]:
train_num, val_num, test_num = (60, 20, 20)
sample_num_split = (train_num, train_num+val_num)
total = train_num+val_num+test_num

# need to consider x_dim < and > z_dim
z_dim = 10 # dimension of z
H = 4
W = 4
sample_dim = torch.tensor([H, W])
x_dim = H * W
y_dim = z_dim // 2

prob = 0.75
prob_list = torch.rand(x_dim)*0.2 + 0.65
patch_size = torch.tensor([2, 2])

In [None]:
z = torch.normal(mean=0, std=1, size=(total,z_dim)) # here distribution is high dimensional guassian
z = z.to(device)

U = torch.rand(x_dim, z_dim)
U = U.to(device)
V = torch.rand(y_dim, z_dim)
V = V.to(device)

x = z @ U.T
y = z @ V.T
train_x, val_x, test_x = torch.tensor_split(x, sample_num_split)
train_y, val_y, test_y = torch.tensor_split(y, sample_num_split)
data_dict = {'train_x': train_x, 'train_y': train_y, 'val_x': val_x, 'val_y': val_y, 'test_x': test_x, 'test_y': test_y}
fe_loss_dict = {'LAE': [], 'MLAE_basic': [], 'MLAE_probs': [], 'MLAE_patches': []}

In [None]:
learning_rate = 0.01
epochs = 500
epochs_fe = 250

In [None]:
net_LAE = LAE(x_dim, z_dim).to(device)

params = list(net_LAE.parameters())
criterion = nn.MSELoss()
optimizer = optim.Adam(params, lr=learning_rate)

### TRAINING ###
train_loop(data_dict, net_LAE, criterion, optimizer, epochs=epochs, record=False, type='encoder')

# feature extraction
loss_ls = feature_extraction(data_dict, net_LAE.parameters(), criterion, type='ls')
loss_gd = feature_extraction(data_dict, net_LAE.parameters(), criterion, type='gd', epochs=50000)
print('least square loss', loss_ls)
print('gradient decent loss', loss_gd)

tensor(10, device='cuda:0')
tensor([[ 5.3024e-04,  6.4087e-04,  6.7139e-04,  8.3160e-04,  6.7902e-04],
        [ 3.3569e-04,  3.6621e-04,  2.4414e-04,  5.1880e-04,  4.5776e-04],
        [ 8.0109e-05,  1.1253e-04,  1.2589e-04,  1.4019e-04,  1.2922e-04],
        [ 1.0681e-04,  1.6022e-04,  1.9836e-04,  1.8311e-04,  1.3733e-04],
        [-4.5776e-05, -3.4332e-05, -3.0518e-05, -6.8665e-05, -3.8147e-05],
        [-2.4796e-04, -2.7466e-04, -2.4414e-04, -3.9673e-04, -3.3569e-04],
        [ 1.1253e-04,  1.6022e-04,  1.5831e-04,  2.0981e-04,  1.7548e-04],
        [ 4.5776e-05,  3.0518e-05, -7.6294e-05,  9.1553e-05,  7.6294e-05],
        [ 1.9073e-05,  2.2888e-05,  3.0518e-05,  2.4319e-05,  2.2888e-05],
        [-2.0599e-04, -2.7466e-04, -1.9836e-04, -3.8147e-04, -3.2043e-04]],
       device='cuda:0')
least square loss 5.977891098796206e-11
gradient decent loss 0.09737139940261841


##### set gradient of network to the theoretical sol

In [None]:
n = 10
p = 5
m = 100
X = torch.rand(m, n)
y = torch.rand(m, p)
theta = (torch.inverse(X.T @ X) @ X.T @ y).T

net = FE_Net(n, p)
net.theta.weight.data = theta
#print(theta)
#print(list(net.parameters()))

In [None]:
criterion = nn.MSELoss()
params = list(net.parameters())
optimiser = optim.SGD(params, lr=1)
optimiser.zero_grad()
output = net(X)
loss = criterion(output, y)
loss.backward()
optimiser.step()

In [None]:
print(list(net.parameters())[0].clone().detach()-theta)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


### check the analytic solution of autoencoder
Linear autoencoder

In [4]:
train_num, val_num, test_num = (4000, 500, 2000) # larger number of dataset
sample_num_split = (train_num, train_num + val_num)
total = train_num+val_num+test_num

# need to consider f_dim <, =, > z_dim
z_dim = 10 # dimension of z
H = 4
W = 4
sample_dim = torch.tensor([H, W])
x_dim = H * W
f_dim = 8
y_dim = z_dim // 2

z = torch.normal(mean=0, std=1, size=(total,z_dim), generator=torch.random.manual_seed(1911)) # here distribution is high dimensional guassian
z = z.to(device)

U = torch.rand(x_dim, z_dim, generator=torch.random.manual_seed(1911))
U = U.to(device)
V = torch.rand(y_dim, z_dim, generator=torch.random.manual_seed(1911))
V = V.to(device)

x = z @ U.T
y = z @ V.T
X_train, X_val, X_test = torch.tensor_split(x, sample_num_split)
y_train, y_val, y_test = torch.tensor_split(y, sample_num_split)
data_dict = {'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val, 'X_test': X_test, 'y_test': y_test}

In [61]:
Sigma = X_train.T @ X_train
U, S, Vh = torch.linalg.svd(Sigma)
U = U[:, :f_dim]
W1 = U.T
W2 = U

In [6]:
net = LAE(x_dim, f_dim)
net.w1.weight.data = W1
net.w2.weight.data = W2

In [7]:
criterion = nn.MSELoss()
params = list(net.parameters())
optimiser = optim.SGD(params, lr=0.3)

optimiser.zero_grad()
output = net(X_train)
loss = criterion(output, X_train)
loss.backward()
optimiser.step()

In [8]:
print(list(net.parameters())[0].clone().detach()-W1)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')


In [9]:
print(list(net.parameters())[1].clone().detach()-W2)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')


Masked linear autoencoder

In [112]:
cst = 5
prob = 0
prob_list = torch.rand(x_dim)
patch_size = 4
model_type = 'MLAE_patches'
mask_dict = {}
if model_type == 'MLAE_basic':
    first_moment = prob + cst * (1-prob)
    second_moment = prob + cst**2 * (1-prob)
    mask_dict['mean'] = (torch.ones(train_num, x_dim, device=device) * first_moment)
    mask_dict['square'] = torch.ones(x_dim, x_dim, device=device) * first_moment**2
    mask_dict['square'].fill_diagonal_(second_moment)
elif model_type == 'MLAE_probs':
    first_moment = prob_list + cst * (1-prob_list)
    second_moment = prob_list + cst**2 * (1-prob_list)
    mask_dict['mean'] = first_moment.repeat(train_num, 1).to(device)
    mask_dict['square'] = first_moment.view(x_dim, 1) @ first_moment.view(1, x_dim)
    mask_dict['square'] = (mask_dict['square'].fill_diagonal_(0) + torch.diag(second_moment)).to(device)
elif model_type == 'MLAE_patches':
    first_moment = prob + cst * (1-prob)
    second_moment = prob + cst**2 * (1-prob)
    mask_dict['mean'] = torch.ones(train_num, x_dim, device=device) * first_moment
    # plain patches
    patch_block_mat = torch.ones(patch_size, patch_size, device=device)
    mask_dict['square'] = torch.block_diag(*[patch_block_mat]*(x_dim//patch_size))*(second_moment-first_moment**2) + first_moment**2
# block patches
# patch_size = [2, 2]
# pix_num = torch.div(torch.tensor(sample_dim), torch.tensor(patch_size), rounding_mode='floor')
# mat_patches = torch.arange(pix_num[0]*pix_num[1]).view(*pix_num)
# mat_patches = torch.repeat_interleave(mat_patches, patch_size[1], dim=1)
# mat_patches = torch.repeat_interleave(mat_patches, patch_size[0], dim=0).view(n)
# mask_dict['MLAE_patches']['square'] = torch.zeros(x_dim, x_dim)
# for i in range(x_dim):
#     for j in range(x_dim):
#         if mat_patches[i] == mat_patches[j]:
#             mask_dict['MLAE_patches']['square'][i, j] = prob
#         else:
#             mask_dict['MLAE_patches']['square'][i, j] = prob**2

In [113]:
X_train, _, X_val, _, _, _ = data_dict.values()
X = X_train
try:
    torch.inverse(mask_dict['square'] * (X.T @ X))
except torch._C._LinAlgError:
    # Sigma_XX = X.T@X
    # U_XX, S_XX, Vh_XX = torch.linalg.svd(Sigma_XX)
    # print(S_XX)
    # inv = Vh_XX.T*torch.inverse(torch.diag(S_XX+1e-4))*U_XX.T

    perturbation = torch.diag(torch.ones(x_dim, device=device)*1e-5)
    inv = torch.inverse(mask_dict['square']*(X.T@X) + perturbation)
else:
    inv = torch.inverse(mask_dict['square'] * (X.T @ X))
Sigma = X.T@(mask_dict['mean']*X) @ inv @ ((mask_dict['mean']*X).T@X)
U, _, _ = torch.linalg.svd(Sigma)
U = U[:, :f_dim]
W1 = U.T @ (X.T@(mask_dict['mean']*X)) @ inv
W2 = U
# if model_type == 'MLAE_basic':
#     loss = criterion((mask_basic(prob, val_num, x_dim, cst=cst)*X_val)@W1.T@W2.T, X_val)
# elif model_type == 'MLAE_probs':
#     loss = criterion((mask_probs(prob_list, val_num, x_dim, cst=cst)*X_val)@W1.T@W2.T, X_val)
# elif model_type == 'MLAE_patches':
#     loss = criterion((mask_patches_plain(prob, patch_size, val_num, x_dim, cst=cst)*X_val)@W1.T@W2.T, X_val)
W1

tensor([[-0.0625,  0.1250, -0.0312, -0.3125,  0.1250, -0.0312, -0.5000,  0.2500,
          0.3750,  0.0000,  0.6250,  0.1875,  0.1094, -0.1875,  0.0781,  0.1250],
        [-0.0703, -0.1172, -0.0635, -0.0078, -0.0391,  0.0000, -0.3438, -0.1016,
         -0.0078,  0.2188,  0.1250,  0.0391, -0.0859, -0.2500,  0.1328,  0.1562],
        [ 0.1641, -0.1406,  0.0625,  0.0938,  0.0000,  0.1211, -0.0625, -0.1328,
         -0.1406,  0.0938, -0.1250,  0.0430,  0.0078,  0.4062, -0.0703, -0.1406],
        [ 0.0312, -0.0938,  0.1875,  0.0312, -0.0391,  0.0703,  0.0000,  0.0000,
         -0.1250,  0.1875,  0.0000, -0.0312,  0.0391,  0.2344, -0.0391,  0.0078],
        [ 0.0195, -0.0508, -0.1621,  0.0938,  0.1133,  0.0137, -0.0508, -0.0234,
          0.0352, -0.0781,  0.1641, -0.0400, -0.0693, -0.1484,  0.0000,  0.1406],
        [ 0.0625,  0.0938,  0.1406, -0.1953, -0.1055,  0.0117,  0.1875, -0.0156,
          0.0000, -0.3125, -0.1562,  0.0625,  0.0742,  0.0938, -0.0430, -0.0469],
        [-0.0234, -0.1

In [114]:
rand_loop = 100
total_W1 = 0
total_W2 = 0
for i in range(rand_loop):
    net = M_LAE(prob, x_dim, f_dim, type='patches', patch_size=patch_size, cst=cst).to(device)
    net.body[0].weight.data = W1.clone().detach()
    net.body[1].weight.data = W2.clone().detach()
    criterion = nn.MSELoss()
    optimiser = optim.SGD(net.body.parameters(), lr=0.003)

    optimiser.zero_grad()
    output = net(X)
    loss = criterion(output, X)
    loss.backward()
    optimiser.step()

    total_W1 += net.body[0].weight.data.clone().detach().cpu()-W1.cpu()
    total_W2 += net.body[1].weight.data.clone().detach().cpu()-W2.cpu()

total_W1 /= rand_loop
total_W2 /= rand_loop

In [115]:
print(total_W1)

tensor([[-0.0033, -0.0014, -0.0014, -0.0016, -0.0028, -0.0032, -0.0019, -0.0039,
         -0.0024, -0.0027, -0.0031, -0.0031, -0.0011, -0.0037, -0.0030, -0.0044],
        [ 0.0036,  0.0043,  0.0038,  0.0052,  0.0030,  0.0034,  0.0046,  0.0043,
          0.0036,  0.0052,  0.0043,  0.0048,  0.0049,  0.0028,  0.0047,  0.0044],
        [ 0.0011,  0.0013,  0.0011,  0.0016,  0.0008,  0.0009,  0.0012,  0.0011,
          0.0012,  0.0014,  0.0010,  0.0017,  0.0012,  0.0009,  0.0016,  0.0015],
        [-0.0043, -0.0060, -0.0045, -0.0068, -0.0039, -0.0040, -0.0059, -0.0054,
         -0.0043, -0.0067, -0.0055, -0.0062, -0.0063, -0.0035, -0.0058, -0.0049],
        [ 0.0007,  0.0008,  0.0007,  0.0008,  0.0006,  0.0006,  0.0005,  0.0008,
          0.0006,  0.0008,  0.0005,  0.0009,  0.0006,  0.0005,  0.0007,  0.0006],
        [ 0.0021,  0.0024,  0.0021,  0.0028,  0.0018,  0.0018,  0.0024,  0.0024,
          0.0024,  0.0031,  0.0019,  0.0028,  0.0026,  0.0017,  0.0029,  0.0024],
        [-0.0016, -0.0

In [116]:
print(total_W2)

tensor([[-7.4002e-04,  2.4545e-04, -4.6641e-05, -3.4681e-04, -9.7010e-05,
          1.5509e-04, -8.5223e-05,  1.4651e-04],
        [ 9.3579e-06, -1.0353e-04, -1.7381e-04,  4.6581e-05, -9.5457e-05,
         -1.1124e-04,  9.6813e-05,  1.6436e-05],
        [-3.4104e-03,  1.1090e-03, -2.7287e-04, -1.6612e-03,  6.4969e-05,
          1.0751e-03, -4.6735e-04,  4.7761e-04],
        [-1.5506e-04, -4.1834e-04, -3.0786e-04,  1.7340e-04, -2.6630e-04,
         -1.3183e-04,  2.5612e-04,  2.2026e-05],
        [-9.6220e-04,  4.2802e-04, -3.5077e-05, -5.2085e-04,  4.2595e-05,
          2.4971e-04, -1.3104e-04,  1.7133e-04],
        [-1.2338e-03,  3.7853e-04, -1.5220e-04, -5.5110e-04, -1.6898e-05,
          3.7715e-04, -1.1849e-04,  1.9169e-04],
        [ 9.7531e-04, -3.8153e-04, -1.9763e-04,  5.5108e-04,  6.8458e-05,
         -3.2629e-04,  2.3210e-04, -1.6497e-04],
        [-5.1880e-03,  1.5033e-03, -2.6818e-04, -2.5878e-03,  1.1517e-04,
          1.7832e-03, -7.9195e-04,  6.9547e-04],
        [-1.5095