In [None]:
"""Verify the gradient derivation of linear masked autoencoder."""
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

### Model initialisation

In [None]:
# Linear auto-encoder model
class LAE(nn.Module):
    """Plain linear autoencoder ``n -> p -> n`` built from two bias-free linear maps.

    Args:
        n: input/output feature count.
        p: bottleneck width.
    """

    def __init__(self, n, p):
        super().__init__()
        self.n, self.p = n, p
        # Encoder is registered first, so it is parameters()[0]; decoder second.
        self.w1 = nn.Linear(n, p, bias=False)
        self.w2 = nn.Linear(p, n, bias=False)

    def forward(self, y):
        """Encode ``y`` with ``w1`` and decode the resulting code with ``w2``."""
        code = self.w1(y)
        return self.w2(code)

In [None]:
## Masked linear auto-encoder model
# Define different types of masks
def mask_basic(prob, sample_num, feature_num):
    """Return a ``(sample_num, feature_num)`` float mask of i.i.d. Bernoulli entries.

    Each entry is 1 with probability ``prob`` and 0 otherwise.
    """
    mask = torch.zeros(sample_num, feature_num)
    mask.bernoulli_(prob)
    return mask

def mask_dropping_probs(prob_list: torch.Tensor, sample_num, feature_num):
    """Return a ``(sample_num, feature_num)`` mask with per-feature Bernoulli probabilities.

    ``prob_list`` is broadcast against the mask shape (e.g. one probability per
    feature); entry ``[i, j]`` is 1 with the probability broadcast to that position.
    """
    mask_shape = (sample_num, feature_num)
    return torch.zeros(mask_shape).bernoulli_(prob_list)

def mask_patches(prob, patch_size, sample_num, sample_dim):
    """Sample a patch-wise binary mask for ``sample_num`` images of size ``sample_dim``.

    Each ``(H, W)`` image is tiled into non-overlapping ``(patch_h, patch_w)``
    patches. Every patch is independently 1 with probability ``prob`` (0
    otherwise), and all pixels of a patch share that value.

    Args:
        prob: probability that a patch is 1.
        patch_size: ``(patch_h, patch_w)`` as a sequence or tensor, or a single
            int / 0-d tensor used for both dimensions.
        sample_num: number of masks to draw (batch size).
        sample_dim: ``(H, W)`` as a sequence or tensor.

    Returns:
        Float tensor of shape ``(sample_num, H * W)``: each ``(H, W)`` mask
        flattened row-major.

    Raises:
        ValueError: if a patch dimension is not positive.
        NotImplementedError: if ``H`` or ``W`` is not divisible by the
            corresponding patch dimension.
    """
    # as_tensor (unlike torch.tensor) does not warn when handed a tensor; ints keep shapes simple.
    height, width = (int(d) for d in torch.as_tensor(sample_dim).flatten())
    patch = torch.as_tensor(patch_size).flatten()
    if patch.numel() == 1:
        # A scalar patch size applies to both spatial dimensions.
        patch = patch.expand(2)
    patch_h, patch_w = (int(d) for d in patch)

    if patch_h <= 0 or patch_w <= 0:
        raise ValueError(f"patch_size must be positive, got ({patch_h}, {patch_w}).")
    if height % patch_h or width % patch_w:
        raise NotImplementedError(f"Both height ({height}) and width ({width}) should be divisible by patch_size ({patch_h}, {patch_w}).")

    # One Bernoulli draw per patch: shape (sample_num, H // patch_h, W // patch_w).
    patch_grid = torch.zeros(sample_num, height // patch_h, width // patch_w).bernoulli_(prob)
    # Upsample every patch value so it covers its patch_h x patch_w pixels.
    pixel_mask = patch_grid.repeat_interleave(patch_h, dim=1).repeat_interleave(patch_w, dim=2)
    return pixel_mask.reshape(sample_num, height * width)


# Masked autoencoder (linear)
class M_LAE(nn.Module):
    """Linear autoencoder whose input is multiplied by a binary mask before encoding.

    Args:
        prob: probability that a mask entry (or patch) is 1. Entries where the
            mask is 1 keep their input value. A float for 'basic' and 'patches'
            masking. For 'probs' masking, a tensor of per-feature probabilities
            broadcastable to ``(batch, H * W)``.
        sample_dim: ``(H, W)`` of one sample. Inputs carry ``H * W`` features.
        reduction_dim: bottleneck width ``p``.
        type: masking strategy, one of 'basic', 'probs' or 'patches'.
        patch_size: patch size passed to ``mask_patches``. Required when
            ``type == 'patches'``.

    Raises:
        NotImplementedError: if ``type`` is not a supported masking strategy.
        ValueError: if ``type == 'patches'`` but ``patch_size`` is None.
    """

    def __init__(self, prob, sample_dim, reduction_dim, type='basic', patch_size=None):
        super().__init__()
        if type not in ['basic', 'probs', 'patches']:
            raise NotImplementedError("Could only implement 'basic', 'probs' and 'patches' type of masking.")
        if type == 'patches' and patch_size is None:
            # Fail here rather than with an AttributeError on the first forward pass.
            raise ValueError("patch_size must be given when type='patches'.")
        self.prob = prob
        self.sample_dim = sample_dim
        # int() so a tensor sample_dim (e.g. torch.tensor([H, W])) yields plain ints.
        self.H, self.W = (int(d) for d in sample_dim)
        self.n = self.H * self.W
        self.p = reduction_dim
        self.masking_type = type
        # Always defined (possibly None) so attribute access never fails.
        self.patch_size = patch_size
        w1 = nn.Linear(self.n, self.p, bias=False)
        w2 = nn.Linear(self.p, self.n, bias=False)
        self.body = nn.Sequential(w1, w2)

    def forward(self, X, mask=None):
        """Mask ``X`` and reconstruct it with the linear body.

        Args:
            X: batch whose first dimension is the sample count ``m`` and whose
                last dimension is ``H * W``.
            mask: optional mask multiplied element-wise into ``X``. If omitted, a
                fresh random mask of the configured type is drawn on every call.

        Returns:
            ``body(mask * X)``.
        """
        if mask is None:
            mask = self._sample_mask(X.shape[0])
        return self.body(mask * X)

    def _sample_mask(self, m):
        """Draw a random ``(m, H * W)`` mask of the configured masking type."""
        if self.masking_type == 'basic':
            return mask_basic(self.prob, m, self.n)
        if self.masking_type == 'probs':
            return mask_dropping_probs(self.prob, m, self.n)
        return mask_patches(self.prob, self.patch_size, m, self.sample_dim)

In [None]:
def train_loop(data_dict, model, criterion, optimizer, epochs=10, sample_average=20):
    test_loss = []
    val_loss = []

    test_inputs = data_dict['test_inputs']
    test_targets = data_dict['test_targets']

    for epoch in range(epochs+1):
        loss_total = 0
        optimizer.zero_grad()
        for i in range(sample_average):
            test_outputs = model(test_inputs)
            loss = criterion(test_outputs, test_targets)
            loss_total += loss
        loss_total /= sample_average
        loss_total.backward()
        optimizer.step()
        test_loss.append(loss_total.item())
        if epoch % (epochs//10) == 0:
            v_loss = test_loop(data_dict, model, criterion)
            val_loss.append(v_loss)
            print('epoch: ', epoch, ', test loss: ', loss.item(), ', val loss', v_loss)
    return {'test_loss': test_loss, 'val_loss': val_loss}

def test_loop(data_dict, model, criterion):
    val_inputs = data_dict['val_inputs']
    val_targets = data_dict['val_targets']

    with torch.no_grad():
        val_outputs = model(val_inputs)
        loss = criterion(val_outputs, val_targets)
    return loss.item()

In [None]:
# feature extraction
class FE_Net(nn.Module):
    """Bias-free linear head mapping ``in_dim`` input features to ``out_dim`` outputs."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.theta = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, W):
        """Apply the linear map, i.e. ``W @ theta.weight.T``."""
        head = self.theta
        return head(W)

### Dataset

In [None]:
SEED = 0
# Seed before any random draw so data, masks and weight inits are reproducible under Run All.
torch.manual_seed(SEED)

test_num = 80
val_num = 20
H = 8
W = 8
sample_dim = torch.tensor([H, W])
feature_num = H * W
reduction_dim = 3 * feature_num // 4  # autoencoder bottleneck width
target_dim = feature_num // 2         # dimension of the downstream regression targets

prob = 0.75  # probability that a mask entry is 1, i.e. the input value is kept
prob_list = torch.rand(feature_num)*0.2 + 0.65  # per-feature keep probabilities in [0.65, 0.85]
patch_size = torch.div(sample_dim, 4, rounding_mode='floor')  # (H // 4, W // 4)

test_inputs = torch.rand(test_num, feature_num) * 2
test_targets = torch.rand(test_num, target_dim)
val_inputs = torch.rand(val_num, feature_num) * 2
val_targets = torch.rand(val_num, target_dim)
# Autoencoder data: the inputs are their own reconstruction targets.
# test_targets / val_targets are used only by the feature-extraction heads.
# NOTE(review): those targets are drawn independently of the inputs -- confirm this is intended.
data_dict = {'test_inputs': test_inputs, 'test_targets': test_inputs,
             'val_inputs': val_inputs, 'val_targets': val_inputs}

### Get features from autoencoder

In [None]:
# Optimiser settings: autoencoders use (learning_rate, epochs);
# the downstream feature-extraction heads use (fe_learning_rate, fe_epochs).
learning_rate, fe_learning_rate = 1e-4, 3e-5
epochs, fe_epochs = 8_000, 5_000

##### Linear autoencoder

In [None]:
net_LAE = LAE(feature_num, reduction_dim)

# Adam on all autoencoder weights, mean-squared reconstruction error.
criterion = nn.MSELoss()
params = list(net_LAE.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

# Train: data_dict uses the inputs themselves as targets (reconstruction).
loss_LAE = train_loop(data_dict, net_LAE, criterion, optimizer, epochs=epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_LAE['test_loss'], label='test_loss')
# train_loop logs val_loss every epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, epochs + 1, max(1, epochs // 10)), loss_LAE['val_loss'], label='val_loss')
ax.set(title='Linear autoencoder: reconstruction loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

del loss_LAE

In [None]:
# Feature extraction: freeze the trained encoder weight and fit a linear head
# (FE_Net) from the encoded, unmasked inputs to the separate regression targets.
params_LAE = list(net_LAE.parameters())
W1_LAE = params_LAE[0].detach().clone()  # encoder weight, shape (reduction_dim, feature_num)
fe_LAE_test_inputs = torch.matmul(test_inputs, W1_LAE.T)
fe_LAE_val_inputs = torch.matmul(val_inputs, W1_LAE.T)

fe_dict = dict(
    test_inputs=fe_LAE_test_inputs,
    test_targets=test_targets,
    val_inputs=fe_LAE_val_inputs,
    val_targets=val_targets,
)

net_LAE_fe = FE_Net(reduction_dim, target_dim)

criterion = nn.MSELoss()
params = list(net_LAE_fe.parameters())
optimizer = optim.Adam(params, lr=fe_learning_rate)

# Train the head on the frozen features.
loss_LAE_fe = train_loop(fe_dict, net_LAE_fe, criterion, optimizer, epochs=fe_epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_LAE_fe['test_loss'], label='test_loss')
# train_loop logs val_loss every fe_epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, fe_epochs + 1, max(1, fe_epochs // 10)), loss_LAE_fe['val_loss'], label='val_loss')
ax.set(title='Linear head on LAE features: regression loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

# The models never leave the CPU in this notebook, so no .cpu() is needed before deleting.
# NOTE(review): params_LAE (and the current params/optimizer) still reference parameters.
del net_LAE
del net_LAE_fe

##### Masked linear autoencoder (basic)

In [None]:
net_MLAE_basic = M_LAE(prob, sample_dim, reduction_dim)  # default type='basic'

# Adam on the linear body's weights, mean-squared reconstruction error.
criterion = nn.MSELoss()
params = list(net_MLAE_basic.body.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

# Train: a fresh i.i.d. mask is drawn on every forward pass.
loss_MLAE_basic = train_loop(data_dict, net_MLAE_basic, criterion, optimizer, epochs=epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_basic['test_loss'], label='test_loss')
# train_loop logs val_loss every epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, epochs + 1, max(1, epochs // 10)), loss_MLAE_basic['val_loss'], label='val_loss')
ax.set(title='Masked linear autoencoder (basic): reconstruction loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

del loss_MLAE_basic

In [None]:
# Feature extraction: freeze the trained encoder weight (first layer of the body)
# and fit a linear head from the encoded, unmasked inputs to the regression targets.
params_MLAE_basic = list(net_MLAE_basic.body.parameters())
W1_MLAE_basic = params_MLAE_basic[0].detach().clone()  # shape (reduction_dim, feature_num)
fe_MLAE_basic_test_inputs = torch.matmul(test_inputs, W1_MLAE_basic.T)
fe_MLAE_basic_val_inputs = torch.matmul(val_inputs, W1_MLAE_basic.T)

fe_dict = dict(
    test_inputs=fe_MLAE_basic_test_inputs,
    test_targets=test_targets,
    val_inputs=fe_MLAE_basic_val_inputs,
    val_targets=val_targets,
)

net_MLAE_basic_fe = FE_Net(reduction_dim, target_dim)

criterion = nn.MSELoss()
params = list(net_MLAE_basic_fe.parameters())
optimizer = optim.Adam(params, lr=fe_learning_rate)

# Train the head on the frozen features.
loss_MLAE_basic_fe = train_loop(fe_dict, net_MLAE_basic_fe, criterion, optimizer, epochs=fe_epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_basic_fe['test_loss'], label='test_loss')
# train_loop logs val_loss every fe_epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, fe_epochs + 1, max(1, fe_epochs // 10)), loss_MLAE_basic_fe['val_loss'], label='val_loss')
ax.set(title='Linear head on masked-LAE (basic) features: regression loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

# The models never leave the CPU in this notebook, so no .cpu() is needed before deleting.
# NOTE(review): params_MLAE_basic (and the current params/optimizer) still reference parameters.
del net_MLAE_basic
del net_MLAE_basic_fe

##### Masked linear autoencoder (probs)

In [None]:
# Per-feature keep probabilities (prob_list) instead of a single scalar.
net_MLAE_probs = M_LAE(prob_list, sample_dim, reduction_dim, type='probs')

# Adam on the linear body's weights, mean-squared reconstruction error.
criterion = nn.MSELoss()
params = list(net_MLAE_probs.body.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

# Train: a fresh per-feature mask is drawn on every forward pass.
loss_MLAE_probs = train_loop(data_dict, net_MLAE_probs, criterion, optimizer, epochs=epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_probs['test_loss'], label='test_loss')
# train_loop logs val_loss every epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, epochs + 1, max(1, epochs // 10)), loss_MLAE_probs['val_loss'], label='val_loss')
ax.set(title='Masked linear autoencoder (probs): reconstruction loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

del loss_MLAE_probs

In [None]:
# Feature extraction: freeze the trained encoder weight (first layer of the body)
# and fit a linear head from the encoded, unmasked inputs to the regression targets.
params_MLAE_probs = list(net_MLAE_probs.body.parameters())
W1_MLAE_probs = params_MLAE_probs[0].detach().clone()  # shape (reduction_dim, feature_num)
fe_MLAE_probs_test_inputs = torch.matmul(test_inputs, W1_MLAE_probs.T)
fe_MLAE_probs_val_inputs = torch.matmul(val_inputs, W1_MLAE_probs.T)

fe_dict = dict(
    test_inputs=fe_MLAE_probs_test_inputs,
    test_targets=test_targets,
    val_inputs=fe_MLAE_probs_val_inputs,
    val_targets=val_targets,
)

net_MLAE_probs_fe = FE_Net(reduction_dim, target_dim)

criterion = nn.MSELoss()
params = list(net_MLAE_probs_fe.parameters())
optimizer = optim.Adam(params, lr=fe_learning_rate)

# Train the head on the frozen features.
loss_MLAE_probs_fe = train_loop(fe_dict, net_MLAE_probs_fe, criterion, optimizer, epochs=fe_epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_probs_fe['test_loss'], label='test_loss')
# train_loop logs val_loss every fe_epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, fe_epochs + 1, max(1, fe_epochs // 10)), loss_MLAE_probs_fe['val_loss'], label='val_loss')
ax.set(title='Linear head on masked-LAE (probs) features: regression loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

# The models never leave the CPU in this notebook, so no .cpu() is needed before deleting.
# NOTE(review): params_MLAE_probs (and the current params/optimizer) still reference parameters.
del net_MLAE_probs
del net_MLAE_probs_fe

##### Masked linear autoencoder (patches)

In [None]:
# Patch-wise masking: whole (patch_size) tiles are kept or dropped together.
net_MLAE_patches = M_LAE(prob, sample_dim, reduction_dim, type='patches', patch_size=patch_size)

# Adam on the linear body's weights, mean-squared reconstruction error.
criterion = nn.MSELoss()
params = list(net_MLAE_patches.body.parameters())
optimizer = optim.Adam(params, lr=learning_rate)

# Train: a fresh patch mask is drawn on every forward pass.
loss_MLAE_patches = train_loop(data_dict, net_MLAE_patches, criterion, optimizer, epochs=epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_patches['test_loss'], label='test_loss')
# train_loop logs val_loss every epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, epochs + 1, max(1, epochs // 10)), loss_MLAE_patches['val_loss'], label='val_loss')
ax.set(title='Masked linear autoencoder (patches): reconstruction loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

del loss_MLAE_patches

In [None]:
# Feature extraction: freeze the trained encoder weight (first layer of the body)
# and fit a linear head from the encoded, unmasked inputs to the regression targets.
params_MLAE_patches = list(net_MLAE_patches.body.parameters())
W1_MLAE_patches = params_MLAE_patches[0].detach().clone()  # shape (reduction_dim, feature_num)
fe_MLAE_patches_test_inputs = torch.matmul(test_inputs, W1_MLAE_patches.T)
fe_MLAE_patches_val_inputs = torch.matmul(val_inputs, W1_MLAE_patches.T)

fe_dict = dict(
    test_inputs=fe_MLAE_patches_test_inputs,
    test_targets=test_targets,
    val_inputs=fe_MLAE_patches_val_inputs,
    val_targets=val_targets,
)

net_MLAE_patches_fe = FE_Net(reduction_dim, target_dim)

criterion = nn.MSELoss()
params = list(net_MLAE_patches_fe.parameters())
optimizer = optim.Adam(params, lr=fe_learning_rate)

# Train the head on the frozen features.
loss_MLAE_patches_fe = train_loop(fe_dict, net_MLAE_patches_fe, criterion, optimizer, epochs=fe_epochs)

In [None]:
fig, ax = plt.subplots()
ax.plot(loss_MLAE_patches_fe['test_loss'], label='test_loss')
# train_loop logs val_loss every fe_epochs // 10 epochs starting at 0; max(1, ...) avoids a zero step.
ax.plot(range(0, fe_epochs + 1, max(1, fe_epochs // 10)), loss_MLAE_patches_fe['val_loss'], label='val_loss')
ax.set(title='Linear head on masked-LAE (patches) features: regression loss', xlabel='epoch', ylabel='MSE')
ax.legend()
plt.show()

# The models never leave the CPU in this notebook, so no .cpu() is needed before deleting.
# NOTE(review): params_MLAE_patches (and the current params/optimizer) still reference parameters.
del net_MLAE_patches
del net_MLAE_patches_fe