In [1]:
import sys
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import USPS

import matplotlib.pyplot as plt
import numpy as np

# from autograd import grad, jacobian
# import autograd.numpy as np

from source.models import AE_ReLU, AE_Sigm, AE_ReLU_Small
from source.data import get_train_test_dataloaders
from source.eval import reconstruction_loss

from functools import partial
import copy
import math
import gc
from tqdm import tqdm
# from multiprocessing import Pool
import torch.multiprocessing as mp
from collections import OrderedDict

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = 'cpu'
# torch.set_default_device('cuda')

  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'eval_loss' from 'source.eval' (/beegfs/home/daria.cherniuk/deep-optimization/examples/../source/eval.py)

In [None]:
# Display the installed PyTorch version as the cell output.
torch.__version__

In [None]:
# Display whether PyTorch can see a CUDA device in this kernel.
torch.cuda.is_available()

In [None]:
def plot_progress(x, z):
    """Draw two rows of 16x16 grayscale images: `x` on top, `z` underneath.

    At most the first 16 entries of each argument are shown. Every entry is
    reshaped to 16x16 before plotting. In this notebook `x` holds input
    images and `z` the corresponding model outputs.

    Args:
        x: indexable collection of tensors shown in the top row.
        z: indexable collection of tensors shown in the bottom row
            (detached from the autograd graph before plotting).
    """
    n_rows, n_cols = 2, 16
    top_images, bottom_images = x[:n_cols], z[:n_cols]
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True,
                             figsize=(n_cols, n_rows))

    # Top row: the images passed as `x`.
    for axis, image in zip(axes[0], top_images):
        axis.imshow(image.cpu().reshape([16, 16]), 'gray')
        axis.set_axis_off()

    # Bottom row: the images passed as `z`.
    for axis, image in zip(axes[1], bottom_images):
        axis.imshow(image.detach().cpu().numpy().reshape([16, 16]), 'gray')
        axis.set_axis_off()

    plt.show()

In [None]:
# Build the USPS train/test loaders through the project helper and show how
# many batches each one yields.
# NOTE(review): batch_size=7291 is presumably the size of the whole USPS
# training split (i.e. a single full batch) - confirm against the dataset.
train_loader, test_loader = get_train_test_dataloaders('..', 'USPS', 
                                                       batch_size=7291, 
                                                       drop_last=False,
                                                       num_workers=4)
len(train_loader), len(test_loader)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from scipy.optimize import line_search
from functools import partial

def w_jacob(w_h, layer_idx, n_samples, x, CACHE_RELUS):
    """Jacobian of the per-sample residuals of one output unit w.r.t. its weights.

    The residual of sample ``n`` is ``target_n - relu(input_n @ w_h)`` (see
    ``w_residuals``), so row ``n`` of the Jacobian is ``-input_n`` when the
    pre-activation ``input_n @ w_h`` is positive and zero otherwise.

    Args:
        w_h: 1-D weight row of the unit being updated.
        layer_idx: index of the layer the unit belongs to; ``0`` means the
            layer input is ``x``, otherwise it is ``CACHE_RELUS[layer_idx - 1]``.
        n_samples: number of leading samples (rows) to use.
        x: 2-D tensor of network inputs, one row per sample.
        CACHE_RELUS: mapping from layer index to a 2-D tensor of that layer's
            cached activations, one row per sample.

    Returns:
        Tensor of shape ``(n_samples, len(w_h))`` with the dtype and device of
        the layer input.
    """
    layer_input = x if layer_idx == 0 else CACHE_RELUS[layer_idx - 1]
    layer_input = layer_input[:n_samples]

    # The ReLU only lets the gradient through where its argument is positive.
    active = (layer_input @ w_h) > 0

    # One vectorised select instead of a Python loop over the samples.
    return torch.where(active.unsqueeze(1), -layer_input, torch.zeros_like(layer_input))

def w_residuals(w_h, layer_idx, h, x, CACHE_RELUS):
    """Per-sample residuals ``target - relu(input @ w_h)`` for output unit ``h``.

    The target is column ``h`` of ``CACHE_RELUS[layer_idx]``, except for the
    layer whose index equals ``len(CACHE_RELUS)``, whose target is column ``h``
    of ``x``. The input is ``x`` for layer 0 and
    ``CACHE_RELUS[layer_idx - 1]`` for every other layer.

    Returns:
        1-D tensor with one residual per sample.
    """
    is_first_layer = layer_idx == 0
    targets_come_from_x = (not is_first_layer) and layer_idx == len(CACHE_RELUS)

    if targets_come_from_x:
        target = x[:, h]
    else:
        target = CACHE_RELUS[layer_idx][:, h]

    layer_input = x if is_first_layer else CACHE_RELUS[layer_idx - 1]
    prediction = torch.relu(torch.einsum('bq,q->b', layer_input, w_h))

    return target - prediction

def w_f(w_h, layer_idx, h, x, CACHE_RELUS):
    """Objective of the W-step: half the sum of squared residuals of unit ``h``."""
    unit_residuals = w_residuals(w_h, layer_idx, h, x, CACHE_RELUS)
    return unit_residuals.square().sum().div(2)

def w_step(w_h, layer_idx, h, n_samples, x, CACHE_RELUS,
           c=None, tau=None, alpha_init=10.0, max_iter=40):
    """One Gauss-Newton update with Armijo backtracking for a single weight row.

    Minimises ``w_f`` (half the squared residuals of output unit ``h``) with
    respect to ``w_h`` while every other quantity is held fixed.

    Args:
        w_h: current 1-D weight row of the unit.
        layer_idx: index of the layer the unit belongs to.
        h: index of the unit inside its layer.
        n_samples: number of samples used to build the Jacobian.
        x: 2-D tensor of network inputs, one row per sample.
        CACHE_RELUS: mapping from layer index to cached activations.
        c: Armijo sufficient-decrease constant; ``None`` uses the
            notebook-level ``line_search_c``.
        tau: backtracking multiplier; ``None`` uses the notebook-level
            ``line_search_tau``.
        alpha_init: initial step length.
        max_iter: maximum number of backtracking reductions.

    Returns:
        The updated weight row ``w_h + alpha * p``.

    Raises:
        AssertionError: if the search direction or the new weights contain NaN.
    """
    # Fall back to the notebook-level settings so existing calls keep working.
    c = line_search_c if c is None else c
    tau = line_search_tau if tau is None else tau

    residuals = w_residuals(w_h, layer_idx, h, x, CACHE_RELUS)

    # Jacobian of the residuals w.r.t. w_h (one row per sample).
    J = w_jacob(w_h, layer_idx, n_samples, x, CACHE_RELUS)

    # Gradient of the objective.
    grad = J.T @ residuals

    # Gauss-Newton direction: minimum-norm solution of J^T J p = -J^T r.
    # J^T J is singular whenever fewer samples are active than there are
    # weights, so a plain inverse is not available, and on CUDA
    # torch.linalg.lstsq only offers the 'gels' driver, which assumes a
    # full-rank matrix. The pseudo-inverse yields the same minimum-norm
    # direction without that restriction.
    p = -(torch.linalg.pinv(J) @ residuals)

    # Directional derivative along p; non-positive for a descent direction.
    descent_inner_prod = grad @ p

    # Backtracking line search on the Armijo condition.
    alpha = alpha_init
    f = w_f(w_h, layer_idx, h, x, CACHE_RELUS)
    f_new = w_f(w_h + alpha * p, layer_idx, h, x, CACHE_RELUS)

    counter = 0
    # Written as `not (... <= ...)` so that a NaN objective keeps shrinking
    # the step instead of being accepted (every comparison with NaN is False).
    while not (f_new - f <= alpha * c * descent_inner_prod) and counter < max_iter:
        alpha *= tau
        f_new = w_f(w_h + alpha * p, layer_idx, h, x, CACHE_RELUS)
        counter += 1

    new_w_h = w_h + alpha * p

    assert not torch.linalg.norm(p).isnan()
    assert not torch.linalg.norm(new_w_h).isnan()

    return new_w_h

In [None]:
def z_residuals(z_n, h_k, x, WEIGHTS):
    """Layer-wise residuals ``z_k - relu(W_k z_{k-1})`` for one sample.

    ``z_n`` stacks the activations of the layers described by ``h_k``; the
    input of layer 0 is ``x``, the input of layer ``k > 0`` is the slice of
    ``z_n`` that belongs to layer ``k - 1``.

    Args:
        z_n: 1-D tensor of length ``h_k[-1]`` with the stacked activations.
        h_k: cumulative layer widths; layer ``k`` occupies
            ``z_n[h_k[k-1]:h_k[k]]`` (starting at 0 for the first layer).
        x: 1-D input of the sample.
        WEIGHTS: sequence of weight matrices, ``WEIGHTS[k]`` of shape
            ``(width_k, width_{k-1})``.

    Returns:
        1-D tensor of length ``h_k[-1]`` with the dtype and device of ``z_n``.

    NOTE(review): only the layers listed in ``h_k`` contribute. The previous
    version had a branch for ``layer_idx == len(h_k)`` (a residual against
    ``x``) that the loop could never reach; confirm whether such a
    reconstruction term is supposed to be part of the Z-step objective.
    """
    residuals = torch.zeros(h_k[-1], dtype=z_n.dtype, device=z_n.device)

    prev_start, start = 0, 0
    for layer_idx, stop in enumerate(h_k):
        layer_input = x if layer_idx == 0 else z_n[prev_start:start]
        pre_activation = torch.einsum('q,hq->h', layer_input, WEIGHTS[layer_idx])
        # torch.relu takes a single argument (unlike np.maximum(a, 0)).
        residuals[start:stop] = z_n[start:stop] - torch.relu(pre_activation)
        prev_start, start = start, stop

    return residuals

def z_jacob(z_n, mu, h_k, WEIGHTS):
    """Jacobian of the penalty-scaled residuals of one sample w.r.t. ``z_n``.

    The scaled residual of layer ``k`` is
    ``s_k * (z_k - relu(W_k z_{k-1}))`` with ``s_k = sqrt(mu)`` for every
    layer except the last one in ``h_k`` and ``s_k = 1`` for the last one
    (the same scaling ``z_step`` applies to the residual vector). Hence

    * the diagonal entries of the rows of layer ``k`` are ``s_k``;
    * for ``k > 0`` the block of those rows over the columns of layer
      ``k - 1`` is ``-s_k * W_k`` on rows whose pre-activation
      ``W_k z_{k-1}`` is positive and zero on the others.

    Args:
        z_n: 1-D tensor of length ``h_k[-1]`` with the stacked activations.
        mu: quadratic-penalty multiplier.
        h_k: cumulative layer widths (see ``z_residuals``).
        WEIGHTS: sequence of weight matrices; ``WEIGHTS[0]`` is not used
            because the first layer's input is not part of ``z_n``.

    Returns:
        Square tensor of shape ``(h_k[-1], h_k[-1])`` with the dtype and
        device of ``z_n``.
    """
    n_total = h_k[-1]
    sqrt_mu = math.sqrt(mu)
    last_layer = len(h_k) - 1
    J = torch.zeros((n_total, n_total), dtype=z_n.dtype, device=z_n.device)

    prev_start, start = 0, 0
    for layer_idx, stop in enumerate(h_k):
        scale = 1.0 if layer_idx == last_layer else sqrt_mu

        # d(z_k)/d(z_k): identity, scaled like the residual itself.
        rows = torch.arange(start, stop, device=z_n.device)
        J[rows, rows] = scale

        if layer_idx > 0:
            W = WEIGHTS[layer_idx]
            # The ReLU derivative depends on the sign of its argument
            # W_k z_{k-1}, not on the current value of z_k.
            active = (W @ z_n[prev_start:start]) > 0
            J[start:stop, prev_start:start] = torch.where(
                active.unsqueeze(1), -scale * W, torch.zeros_like(W)
            )

        prev_start, start = start, stop

    return J

def z_f(z_n, mu, h_k, x, WEIGHTS):
    """Objective of the Z-step for one sample.

    Squared residuals of every layer before the last one in ``h_k`` are
    weighted by ``mu / 2``; those of the last layer by ``1 / 2``. This is half
    the squared norm of the scaled residual vector built in ``z_step``.
    """
    squares = torch.square(z_residuals(z_n, h_k, x, WEIGHTS))
    # The last layer's block starts exactly at h_k[-2] (the same boundary
    # z_step uses for its sqrt(mu) scaling), not at h_k[-2] + 1.
    split = h_k[-2]
    return squares[:split].sum() * mu / 2 + squares[split:].sum() / 2

def z_step(z_n, n_sample, mu, h_k, x, WEIGHTS, c=None, tau=None, max_iter=40):
    """One Gauss-Newton update with Armijo backtracking for one sample's activations.

    Args:
        z_n: 1-D tensor of length ``h_k[-1]`` with the sample's stacked
            activations.
        n_sample: index of the sample (only used in the diagnostic message).
        mu: quadratic-penalty multiplier.
        h_k: cumulative layer widths.
        x: 1-D input of the sample.
        WEIGHTS: sequence of weight matrices.
        c: Armijo sufficient-decrease constant; ``None`` uses the
            notebook-level ``line_search_c``.
        tau: backtracking multiplier; ``None`` uses the notebook-level
            ``line_search_tau``.
        max_iter: maximum number of backtracking reductions.

    Returns:
        Tuple ``(J, new_z_n)``: the Jacobian used for the step and the updated
        activations ``z_n + alpha * p``.
    """
    # Fall back to the notebook-level settings so existing calls keep working.
    c = line_search_c if c is None else c
    tau = line_search_tau if tau is None else tau

    residuals = z_residuals(z_n, h_k, x, WEIGHTS)
    assert len(residuals) == h_k[-1]

    # Weight every layer before the last one by sqrt(mu), so that half the
    # squared norm of the scaled residuals equals the objective z_f.
    mu_mult = torch.ones_like(residuals)
    mu_mult[:h_k[-2]] = math.sqrt(mu)
    residuals = mu_mult * residuals

    J = z_jacob(z_n, mu, h_k, WEIGHTS)

    # Gauss-Newton direction from the normal equations.
    grad = J.T @ residuals
    p = torch.linalg.lstsq(J.T @ J, -grad, rcond=None)[0]

    # Directional derivative along p; non-positive for a descent direction.
    descent_inner_prod = grad @ p
    # `not (... <= 0)` also reports a NaN value, as the former
    # try/assert/except did, without swallowing unrelated exceptions.
    if not (descent_inner_prod <= 0):
        print(f'sample {n_sample} descent direction should be <= 0', descent_inner_prod)

    # Backtracking line search on the Armijo condition.
    alpha = 1.0
    f = z_f(z_n, mu, h_k, x, WEIGHTS)
    f_new = z_f(z_n + alpha * p, mu, h_k, x, WEIGHTS)

    counter = 0
    # Written as `not (... <= ...)` so that a NaN objective keeps shrinking
    # the step instead of being accepted (every comparison with NaN is False).
    while not (f_new - f <= alpha * c * descent_inner_prod) and counter < max_iter:
        alpha *= tau
        f_new = z_f(z_n + alpha * p, mu, h_k, x, WEIGHTS)
        counter += 1

    return J, z_n + alpha * p

In [None]:
# %%time

# Mini-batch size; the training below only ever uses the first batch.
batch_size = 100

# Rebuild the loaders with the smaller batch size (this replaces the loaders
# created in the earlier cell).
train_loader, test_loader = get_train_test_dataloaders('..', 'USPS', 
                                                       batch_size=batch_size, 
                                                       drop_last=False,
                                                       num_workers=4)
# Keep only the first training batch, moved to the compute device.
for x, _ in train_loader:
    x = x.to(device)
    break

# Armijo sufficient-decrease constant, read as a global by w_step / z_step.
line_search_c = pow(10,-4)
# backtracking multiplier
line_search_tau = 0.5

# Quadratic Penalty multiplier 
mu = 1.0

loss_function = torch.nn.MSELoss()
# Evaluation loss before training and after every epoch.
loss_hist = []

# creating model
model = AE_ReLU(bias=False).to(device)
# NOTE(review): `eval_loss` is not imported in the visible imports cell (which
# imports `reconstruction_loss` from source.eval), and the stale cell output
# reports that source.eval has no `eval_loss` - confirm which helper is meant.
loss = eval_loss(model, test_loader, loss_function, device=device)
loss_hist.append(loss)
print(f'Random-initialized model loss: {loss}')

# Detached weight tensors of the linear layers, encoder first, then decoder
# (positions 0, 2 and 4 of each sequential block).
WEIGHTS = [
    block[position].weight.detach()
    for block in (model.encoder, model.decoder)
    for position in (0, 2, 4)
]

# all output shapes: running sum of the layer widths, last layer excluded
cumulative_widths = []
running_total = 0
for W in WEIGHTS:
    running_total += W.shape[0]
    cumulative_widths.append(running_total)
h_k = cumulative_widths[:-1]

# # registering hooks to cache activations
# CACHE_RELUS = {}

# def cache_relu_hook(idx, module, input, output):
#     CACHE_RELUS[idx] = output.detach().cpu()

# i = 0 
# for m in model.modules():
#     if isinstance(m, nn.ReLU):
#         handle = m.register_forward_hook(partial(cache_relu_hook, i))
#         i += 1
        
# # caching activations
# with torch.no_grad():
#     for x, _ in train_loader:
#         x = x.to(device)
#         _ = model(x)
#         x = x.detach().cpu()
        
# # Removing Hooks
# for m in model.modules():
#     if isinstance(m, nn.ReLU):
#         m._forward_hooks = OrderedDict()

n_samples = train_loader.batch_size
# init z_n randomly: one (n_samples, width) tensor per layer except the last,
# drawn uniformly from [-0.5, 0.5)
CACHE_RELUS = {
    layer_idx: torch.rand((n_samples, weight.shape[0]), device=device) - 0.5
    for layer_idx, weight in enumerate(WEIGHTS[:-1])
}
        
def _get_n_sample_relus(n_sample, cache):
    """Stack one sample's cached activations of all layers in ``h_k`` into a 1-D tensor.

    The result is a new tensor on the same device as the cache entries, so it
    can be combined with the weights and inputs without a device mismatch
    (previously it was always allocated on the CPU).

    Args:
        n_sample: row index of the sample inside every cache entry.
        cache: mapping from layer index to a 2-D activation tensor.

    Returns:
        1-D tensor holding the activations of layers ``0 .. len(h_k) - 1`` in
        order.
    """
    return torch.cat([cache[layer_idx][n_sample] for layer_idx in range(len(h_k))])

        
# Alternating optimisation: each epoch first updates every weight row with the
# activations fixed (W-step), then every sample's activations with the weights
# fixed (Z-step).
with torch.no_grad():
    for epoch in range(10):

        # W-Step
        WEIGHTS_COPY = copy.deepcopy(WEIGHTS)

        for layer_idx in tqdm(range(len(WEIGHTS))):

            for h in range(WEIGHTS[layer_idx].shape[0]):
                for it in range(3):
                    w_h = WEIGHTS[layer_idx][h]
                    new_w_h = w_step(w_h, layer_idx, h, n_samples, x, CACHE_RELUS)

                    # Check before writing back: w_h is a view into WEIGHTS, so
                    # after the assignment it would already equal new_w_h.
                    assert not torch.linalg.norm(new_w_h).isnan()
                    assert not torch.linalg.norm(w_h).isnan()
                    assert not torch.linalg.norm(new_w_h - w_h).isnan()

                    WEIGHTS[layer_idx][h] = new_w_h

        # Checking Norm Difference
        for layer_idx in range(len(WEIGHTS)):
            print(f'Layer {layer_idx}, '
                  f'Weights Diff Norm (%): {torch.linalg.norm(WEIGHTS[layer_idx]-WEIGHTS_COPY[layer_idx]) / torch.linalg.norm(WEIGHTS_COPY[layer_idx])}')

        # Loading weights back to model to compute eval loss
        model.encoder[0].weight.data = WEIGHTS[0]
        model.encoder[2].weight.data = WEIGHTS[1]
        model.encoder[4].weight.data = WEIGHTS[2]
        model.decoder[0].weight.data = WEIGHTS[3]
        model.decoder[2].weight.data = WEIGHTS[4]
        model.decoder[4].weight.data = WEIGHTS[5]
        model = model.to(device)

        # Loss on eval dataset
        # NOTE(review): `eval_loss` is not imported in the visible imports cell
        # (which imports `reconstruction_loss` from source.eval) - confirm
        # which helper is meant.
        loss = eval_loss(model, test_loader, loss_function, device=device)
        loss_hist.append(loss)
        print(f'Epoch {epoch+1}, Eval loss: {loss}')

        # Z-Step
        CACHE_RELUS_COPY = copy.deepcopy(CACHE_RELUS)

        # process for each sample
        for n_sample in tqdm(range(n_samples)):
            for it in range(1):
                z_n = _get_n_sample_relus(n_sample, CACHE_RELUS)
                J, new_z_n = z_step(z_n, n_sample, mu, h_k, x[n_sample], WEIGHTS)
                # Scatter the updated stacked vector back into the per-layer cache.
                start = 0
                for layer_idx in range(len(h_k)):
                    CACHE_RELUS[layer_idx][n_sample] = new_z_n[start:h_k[layer_idx]]
                    start = h_k[layer_idx]

        # Computing Relus Norm Difference
        for layer_idx in range(len(CACHE_RELUS)):
            print(f'Layer {layer_idx} Activations Diff Norm (%): '
                  f'{torch.linalg.norm(CACHE_RELUS[layer_idx]-CACHE_RELUS_COPY[layer_idx]) / torch.linalg.norm(CACHE_RELUS_COPY[layer_idx])}')

        # Updating Quadratic Penalty Multiplier once the eval loss has stalled.
        # The test must look at the epoch counter: the inner-loop variable `it`
        # is always 0 at this point, so testing it would never raise mu.
        if epoch > 2 and abs(loss_hist[-2] - loss_hist[-1]) < 1e-3:
            mu *= 10
            print(f'New MU: {mu}')

        if epoch > 0 and epoch % 2 == 0:
            # Already inside torch.no_grad(); use the configured device rather
            # than a hard-coded .cuda() so the cell also runs on the CPU.
            y = model(x[:16].to(device))
            plot_progress(x[:16], y)
           
        # print()

In [None]:
# Plot the whole evaluation-loss history (index 0 is the randomly initialised
# model), not just the last scalar value held in `loss`.
plt.plot(loss_hist)
plt.xlabel('epoch')
plt.ylabel('eval loss')
plt.show()

In [None]:
# Sparsity pattern of the last Z-step Jacobian. J is a torch tensor that may
# live on the GPU, so convert it to a NumPy array on the CPU for matplotlib.
plt.spy(J.detach().cpu().numpy());

In [None]:
x.device, w_h.device, CA