In [30]:
"""Examine the effect of detach."""
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
import statistics
import numpy as np
import pandas as pd 
import scipy.stats as stats
from copy import deepcopy

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Linear auto-encoder model
class LAE(nn.Module):
    def __init__(self, n, p):
        super(LAE, self).__init__()
        self.n = n
        self.p = p
        self.w1 = nn.Linear(n, p, bias=False)
        self.w2 = nn.Linear(p, n, bias=False)

    def forward(self, y):
        y = self.w1(y)
        y = self.w2(y)
        return y

class FE_Net(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(FE_Net, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.theta = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, W):
        return self.theta(W)

In [4]:
def train_loop(data_dict, model, criterion, optimizer, type, epochs=10, sample_average=10, record=True):
    train_loss = []
    val_loss = []
    if type=='encoder':
        train_inputs = data_dict['train_inputs']
        train_targets = train_inputs
    elif type=='fe':
        train_inputs = data_dict['train_inputs']
        train_targets = data_dict['train_targets']
        sample_average = 1

    for epoch in range(epochs+1):
        loss_total = 0
        optimizer.zero_grad()
        for i in range(sample_average):
            train_outputs = model(train_inputs)
            loss = criterion(train_outputs, train_targets)
            loss_total += loss
        loss_total /= sample_average
        loss_total.backward()
        optimizer.step()
        if record:
          train_loss.append(loss_total.item())
          if epochs>=5 and epoch%(epochs//5)==0:
              v_loss = test_loop(data_dict, model, criterion, type)
              val_loss.append(v_loss)
              print('epoch: ', epoch, ', train loss: ', loss.item(), ', val loss', v_loss)
    if record:
        return {'train_loss': train_loss, 'val_loss': val_loss}
    else:
        return loss_total.item()

def test_loop(data_dict, model, criterion, type):
    if type=='encoder':
        val_inputs = data_dict['val_inputs']
        val_targets = val_inputs
    elif type=='fe':
        val_inputs = data_dict['val_inputs']
        val_targets = data_dict['val_targets']

    with torch.no_grad():
        val_outputs = model(val_inputs)
        loss = criterion(val_outputs, val_targets)
    return loss.item()

In [43]:
# feature extraction
def feature_extraction(data_dict, model_parameters, criterion, type, epochs=200, device='cuda'):
    train_inputs, train_targets, val_inputs, val_targets = data_dict.values()

    params = list(model_parameters)
    W1 = params[0]
    W10 = deepcopy(W1)
    W10 = W10.clone().detach()
    print(W10)
    train_inputs_fe0 = train_inputs @ W10.T
    val_inputs_fe0 = val_inputs @ W10.T

    W11 = deepcopy(W1)
    W11 = W11.clone()
    print(W11)
    train_inputs_fe1 = (train_inputs @ W11.T).detach()
    val_inputs_fe1 = (val_inputs @ W11.T).detach()

    reduction_dim = train_inputs_fe1.shape[1]
    target_dim = train_targets.shape[1]

    if type=='ls':
        param_fe0 = (torch.inverse(train_inputs_fe0.T@train_inputs_fe0) @ train_inputs_fe0.T @ train_targets).T
        loss0 = criterion(val_targets, val_inputs_fe0 @ param_fe0.T)
        print('test0', loss0.item())

        param_fe1 = (torch.inverse(train_inputs_fe1.T@train_inputs_fe1) @ train_inputs_fe1.T @ train_targets).T
        loss1 = criterion(val_targets, val_inputs_fe1 @ param_fe1.T)
        print('test1', loss1.item())
        #return loss.item()
    elif type=='gd':
        net_fe = FE_Net(reduction_dim, target_dim).to(device)
        net_fe0 = deepcopy(net_fe)
        net_fe1 = deepcopy(net_fe)

        data_dict_fe0 = {'train_inputs': train_inputs_fe0, 'train_targets': train_targets,
                        'val_inputs': val_inputs_fe0, 'val_targets': val_targets}
        param_fe0 = list(net_fe0.parameters())
        optimizer0 = optim.Adam(param_fe0, lr=0.0001)
        ### TRAINING ###
        loss_fe0 = train_loop(data_dict_fe0, net_fe0, criterion, optimizer0, epochs=epochs, record=False, type='fe')
        print('test0', loss_fe0)
        data_dict_fe1 = {'train_inputs': train_inputs_fe1, 'train_targets': train_targets,
                        'val_inputs': val_inputs_fe1, 'val_targets': val_targets}
        param_fe1 = list(net_fe1.parameters())
        optimizer1 = optim.Adam(param_fe1, lr=0.0001)
        ### TRAINING ###
        loss_fe1 = train_loop(data_dict_fe1, net_fe1, criterion, optimizer1, epochs=epochs, record=False, type='fe')
        print('test1', loss_fe1)
        #return loss_fe

In [44]:
train_num = 50
val_num = 20
H = 8
W = 8
sample_dim = torch.tensor([H, W])
feature_num = H * W
reduction_dim = feature_num
target_dim = feature_num

prob = 0.75
prob_list = torch.rand(feature_num)*0.2 + 0.65
patch_size = torch.div(sample_dim, 2, rounding_mode='floor')

train_inputs = torch.rand(train_num, feature_num) * 2
train_inputs = train_inputs.to(device)
train_targets = torch.rand(train_num, target_dim)
train_targets = train_targets.to(device)
val_inputs = torch.rand(val_num, feature_num) * 2
val_inputs = val_inputs.to(device)
val_targets = torch.rand(val_num, target_dim)
val_targets = val_targets.to(device)
data_dict = {'train_inputs': train_inputs, 'train_targets': train_targets,
             'val_inputs': val_inputs, 'val_targets': val_targets}

In [45]:
net_LAE = LAE(feature_num, reduction_dim).to(device)

params = list(net_LAE.parameters())
criterion = nn.MSELoss()
optimizer = optim.Adam(params, lr=0.001)

### TRAINING ###
train_loop(data_dict, net_LAE, criterion, optimizer, epochs=500, record=False, type='encoder')

# feature extraction
loss_ls = feature_extraction(data_dict, net_LAE.parameters(), criterion, type='ls', epochs=500)
loss_gd = feature_extraction(data_dict, net_LAE.parameters(), criterion, type='gd', epochs=500)

tensor([[-0.1336,  0.0542,  0.0287,  ..., -0.1842,  0.0315,  0.1009],
        [-0.0978, -0.0365, -0.1108,  ..., -0.0541, -0.1404,  0.0120],
        [ 0.1117,  0.0700,  0.0037,  ...,  0.0655,  0.0070, -0.1101],
        ...,
        [ 0.1055,  0.0057, -0.0793,  ..., -0.0329,  0.1920,  0.0874],
        [-0.1235, -0.0956,  0.1405,  ..., -0.0077,  0.0426, -0.1306],
        [ 0.1496, -0.1518, -0.1922,  ..., -0.0546, -0.0099, -0.0903]],
       device='cuda:0')
tensor([[-0.1336,  0.0542,  0.0287,  ..., -0.1842,  0.0315,  0.1009],
        [-0.0978, -0.0365, -0.1108,  ..., -0.0541, -0.1404,  0.0120],
        [ 0.1117,  0.0700,  0.0037,  ...,  0.0655,  0.0070, -0.1101],
        ...,
        [ 0.1055,  0.0057, -0.0793,  ..., -0.0329,  0.1920,  0.0874],
        [-0.1235, -0.0956,  0.1405,  ..., -0.0077,  0.0426, -0.1306],
        [ 0.1496, -0.1518, -0.1922,  ..., -0.0546, -0.0099, -0.0903]],
       device='cuda:0', grad_fn=<CloneBackward0>)
test0 606.8356323242188
test1 606.8356323242188
tensor([[-