In [119]:
import os
import torch
import torch.utils
import torch.utils.data
from sklearn.datasets import make_sparse_coded_signal
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from contextlib import contextmanager
from timeit import default_timer
# from test_omp import omp_naive
# from test import *  # FIXME: better name


In [120]:
# Synthetic sparse-coding problem, seeded so the notebook is reproducible:
# 1 sample, 256 dictionary components, 96 features, 15 non-zero coefficients.
sparse_problem = make_sparse_coded_signal(
    n_samples=1, n_components=256, n_features=96, n_nonzero_coefs=15, random_state=4
)
# Unpacked in the order returned by the call: signal, dictionary, sparse code.
y_0, X_0, w_0 = sparse_problem




In [131]:
# Lift the NumPy problem into float32 torch tensors with a leading batch axis.
#
# Gradient tracking is enabled *after* the reshape on purpose: reshaping a
# tensor that already requires grad yields a non-leaf tensor, so the names
# y / X / w would point at non-leaves and their `.grad` would stay None after
# backward(). With requires_grad_() applied last, y, X and w are leaf tensors.
requires_grad = True
y = torch.tensor(y_0, dtype=torch.float32).reshape(1, -1).requires_grad_(requires_grad)
X = torch.tensor(X_0, dtype=torch.float32).reshape(1, X_0.shape[0], -1).requires_grad_(requires_grad)
w = torch.tensor(w_0, dtype=torch.float32).reshape(1, -1, 1).requires_grad_(requires_grad)

# y_hat = X @ w
# ((y_hat-y)**2).sum()
y.shape, X.shape, w.shape

(torch.Size([1, 96]), torch.Size([1, 96, 256]), torch.Size([1, 256, 1]))

In [132]:
# Orthogonal Matching Pursuit (OMP): shared state used by the step-by-step cells.

# Dictionary taken from the first batch entry of X. This is a view, not a copy
# (consider cloning the tensor). NOTE: the name shadows the builtin `dict`; it
# is kept because the following cells refer to it.
dict = X[0, ...]

chunk_length, n_atoms = dict.shape
batch_sz, signal_length = y.shape
# The dictionary rows and the signals must have the same length; previously the
# second assignment silently overwrote chunk_length without any check.
assert signal_length == chunk_length, (
    f"signal length {signal_length} does not match dictionary rows {chunk_length}"
)
n_nonzero_coefs = 15  # number of atoms selected per signal
tau = 0.01            # softmax temperature for the selection scores
hard = True           # True: straight-through one-hot selection, False: soft scores

# DTD = dict.T @ dict

residuals = y.clone()  # (batch_sz, chunk_length)
# Per-step selection weights over the atoms. These hold softmax / straight-through
# values, so the buffer must be floating point (same dtype as y): a long buffer
# truncates the weights and cuts the autograd path through the selection.
max_score_indices = y.new_zeros((batch_sz, n_nonzero_coefs, n_atoms))  # (batch_sz, n_nonzero_coefs, n_atoms)
# Integer index of the atom picked at every step (not differentiable).
detached_indices = y.new_zeros((batch_sz, n_nonzero_coefs), dtype=torch.long)  # (batch_sz, n_nonzero_coefs)

In [133]:
i = 0
# Correlation of every atom with the current residual.
projections = dict.T @ residuals[:, :, None]  # (batch_sz, n_atoms, 1)

# OMP picks the atom with the largest *absolute* correlation. The same scores
# drive both the recorded index and the (soft / one-hot) selection weights, so
# the two can never disagree (they could when the softmax used signed values).
scores = projections.abs().squeeze(-1)   # (batch_sz, n_atoms)
index = scores.argmax(-1, keepdim=True)  # (batch_sz, 1)
detached_indices[:, i] = index.squeeze(-1)

soft_score_indices = (scores / tau).softmax(-1)  # (batch_sz, n_atoms)
if hard:
    # copied and modified from https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
    # Straight through: one-hot in the forward pass, softmax gradient in the backward pass.
    hard_score_indices = torch.zeros_like(soft_score_indices).scatter_(-1, index, 1.0)
    ret = hard_score_indices - soft_score_indices.detach() + soft_score_indices
else:
    ret = soft_score_indices
max_score_indices[:, i] = ret

In [134]:
# Shapes of the selection weights, the dictionary batch and the residuals.
tuple(t.shape for t in (max_score_indices, X, residuals))

(torch.Size([1, 15, 256]), torch.Size([1, 96, 256]), torch.Size([1, 96]))

In [135]:
# Shape of the per-step atom index buffer.
detached_indices.size()

torch.Size([1, 15])

In [136]:
# Weight every dictionary entry by the selection scores of the first i+1 steps.
selection = max_score_indices[:, : i + 1].unsqueeze(2)  # (batch_sz, i+1, 1, n_atoms)
selected_D = X.unsqueeze(1) * selection  # (batch_sz, i+1, chunk_length, n_atoms)
selected_D.shape

torch.Size([1, 1, 96, 256])

In [137]:
# For every step keep only the column of the atom that was picked at that step.
step_atoms = detached_indices[:, : i + 1].unsqueeze(1)  # (batch_sz, 1, i+1)
ind_0 = torch.arange(batch_sz).view(-1, 1, 1)       # batch axis
ind_1 = torch.arange(chunk_length).view(1, -1, 1)   # signal axis
ind_2 = torch.arange(i + 1).view(1, 1, -1)          # step axis
# Move the step axis behind the signal axis, then index all four axes at once;
# the three broadcast index grids select one atom per (batch, sample, step).
selected_D = selected_D.permute(0, 2, 1, 3)[ind_0, ind_1, ind_2, step_atoms]  # (batch_sz, chunk_length, i+1)
selected_D.shape

torch.Size([1, 96, 1])

In [139]:
# Least-squares coefficients on the selected atoms: W = (D^T D + eps*I)^-1 D^T y
selected_DT = selected_D.permute(0, 2, 1)  # (batch_sz, i+1, chunk_length)

# calculate selected_DTy
# Uses the signal y, not the running residual: the residual is overwritten at
# the end of this cell, so projecting it would make a re-run of the cell (or any
# later OMP step) solve a different problem than the one the name describes.
selected_DTy = selected_DT @ y[:, :, None]  # (batch_sz, i+1, 1)

# calculate selected_DTD
selected_DTD = selected_DT @ selected_D  # (batch_sz, i+1, i+1)

# find W = (selected_DTD)^-1 @ selected_DTy
# A small ridge term keeps the Gram matrix invertible; it is added out of place
# so no tensor in the autograd graph is mutated.  TODO: add multipath OMP
ridge = 1e-5 * torch.eye(
    selected_DTD.shape[-1], dtype=selected_DTD.dtype, device=selected_DTD.device
)
selected_DTD = selected_DTD + ridge
nonzero_W = torch.linalg.solve(selected_DTD, selected_DTy)  # (batch_sz, i+1, 1)

# finally get residuals r = y - D W
# `residuals` is rebound instead of written in place: the previous tensor was an
# operand of earlier matmuls, and mutating it can invalidate what autograd saved
# for the backward pass.
residuals = y - (selected_D @ nonzero_W).squeeze(-1)  # (batch_sz, chunk_length)

In [142]:
# Dummy scalar objective, only used to check that gradients flow back through
# the OMP step. backward() frees the saved graph, so the forward cells have to
# be re-run before this cell is executed again (otherwise PyTorch raises the
# "backward through the graph a second time" RuntimeError).
fake_loss = torch.sum(nonzero_W)
torch.autograd.backward(fake_loss)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [2]:

# Full differentiable OMP loop: the single-step cells above, repeated for
# n_nonzero_coefs steps. Reads dict, y, tau, hard, batch_sz, n_atoms and
# n_nonzero_coefs from the setup cell; produces the dense code matrix W.
tolerance = True  # TODO: not used yet -- meant to control the stopping rule
# Control stop iteration with norm thresh or sparsity
reg_coeff = 1e-5  # ridge term on the Gram diagonal, same value as the single-step cell

# Start from the raw signal so the cell is idempotent (the step-by-step cells
# above leave `residuals` in a partially updated state).
residuals = y.clone()  # (batch_sz, chunk_length)
selected_atoms = y.new_zeros((batch_sz, n_nonzero_coefs), dtype=torch.long)  # atom index per step
atom_columns = []  # one gated dictionary column of shape (batch_sz, chunk_length, 1) per step

for i in range(n_nonzero_coefs):
    # Compute the score of each atom: |<atom, residual>|
    projections = dict.T @ residuals[:, :, None]  # (batch_sz, n_atoms, 1)
    scores = projections.abs().squeeze(-1)        # (batch_sz, n_atoms)
    index = scores.argmax(-1, keepdim=True)       # (batch_sz, 1)
    selected_atoms[:, i] = index.squeeze(-1)

    soft_score_indices = (scores / tau).softmax(-1)  # (batch_sz, n_atoms)
    if hard:
        # copied and modified from https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
        # Straight through.
        hard_score_indices = torch.zeros_like(soft_score_indices).scatter_(-1, index, 1.0)
        ret = hard_score_indices - soft_score_indices.detach() + soft_score_indices
    else:
        ret = soft_score_indices

    # update selected_D: take the chosen dictionary column and scale it by its
    # selection weight, which is what carries the gradient of the selection.
    # Real integer indices are used here; the score weights are never used as
    # indices.
    gate = ret.gather(-1, index)           # (batch_sz, 1)
    column = dict[:, index.squeeze(-1)].T  # (batch_sz, chunk_length)
    atom_columns.append((column * gate)[:, :, None])
    selected_D = torch.cat(atom_columns, dim=-1)  # (batch_sz, chunk_length, i+1)
    selected_DT = selected_D.permute(0, 2, 1)     # (batch_sz, i+1, chunk_length)

    # calculate selected_DTy (against the signal, not the residual, so that
    # nonzero_W holds the full coefficients and not just an increment)
    selected_DTy = selected_DT @ y[:, :, None]  # (batch_sz, i+1, 1)

    # calculate selected_DTD
    selected_DTD = selected_DT @ selected_D  # (batch_sz, i+1, i+1)

    # find W = (selected_DTD)^-1 @ selected_DTy
    # TODO: add multipath OMP
    ridge = reg_coeff * torch.eye(i + 1, dtype=selected_DTD.dtype, device=selected_DTD.device)
    nonzero_W = torch.linalg.solve(selected_DTD + ridge, selected_DTy)  # (batch_sz, i+1, 1)

    # finally get residuals r = y - D W (rebound, not written in place, so
    # tensors used earlier in the autograd graph are left untouched)
    residuals = y - (selected_D @ nonzero_W).squeeze(-1)  # (batch_sz, chunk_length)

# Scatter the coefficients of the selected atoms into a dense code matrix.
W = torch.zeros(batch_sz, n_atoms, dtype=selected_D.dtype, device=selected_D.device)
batch_rows = torch.arange(batch_sz, device=W.device)[:, None]  # (batch_sz, 1)
# squeeze(-1) rather than squeeze(): only the trailing axis is dropped, so the
# batch axis survives even when batch_sz == 1.
W[batch_rows, selected_atoms] = nonzero_W.squeeze(-1)


In [3]:
# Random dictionary X (requires grad) and batch of signals y (no grad).
# Falls back to the CPU when no CUDA device is available, so the cell also
# runs on machines without a GPU instead of raising.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X = torch.randn(96, 256, device=device, requires_grad=True)
y = torch.randn(32, 96, device=device)

In [4]:
# Track gradients for X only; y is explicitly excluded.
X.requires_grad_(True)
y.requires_grad_(False)
# NOTE(review): run_omp is not defined in the visible cells -- presumably it
# comes from the commented-out `test` / `test_omp` imports; confirm.
sets = run_omp(X, y, n_nonzero_coefs=15)

In [5]:
# Batch of synthetic sparse signals, seeded for reproducibility:
# 32 samples, 256 dictionary components, 96 features, 15 non-zero coefficients.
y, X, w = make_sparse_coded_signal(
    n_samples=32,
    n_components=256,
    n_features=96,
    n_nonzero_coefs=15,
    random_state=4)

# Move signal and dictionary to the GPU when one is present, otherwise stay on
# the CPU (the hard-coded .cuda() call failed on GPU-less machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
y = torch.from_numpy(y).to(device)
X = torch.from_numpy(X).to(device)

# Non-zero ground-truth coefficients in the first column of w (w stays a NumPy array).
w[:, 0][w[:, 0] != 0]



array([ 0.64592679,  1.41983684, -0.14846366, -1.13573509,  0.83928862,
        1.06942006, -0.03079246, -0.00980617,  0.28853245,  0.19478957,
       -0.30494662,  1.12142266,  1.37522674,  0.21715859,  0.2122062 ])

In [6]:
# Run run_omp on the transposed signals and show the non-zero entries of the
# first row of the result, for comparison with the ground-truth coefficients.
sets = run_omp(X, y.T, n_nonzero_coefs=15)
first_code = sets[0]
first_code[first_code != 0]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       device='cuda:0', dtype=torch.float64)

In [7]:
# Same comparison using omp, which returns three values; only the non-zero
# entries of the first row of W are displayed.
# NOTE(review): omp is not defined in the visible cells -- confirm its origin.
W, sets2, solutions2 = omp(X, y.T, n_nonzero_coefs=15)
first_row = W[0]
first_row[first_row != 0]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       device='cuda:0', dtype=torch.float64)