In [1]:
import os
import torch
import torch.utils
import torch.utils.data
from sklearn.datasets import make_sparse_coded_signal
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from contextlib import contextmanager
from timeit import default_timer
# from test_omp import omp_naive
# from test import *  # FIXME: better name


In [2]:
# Draw one synthetic sparse-coded signal together with its dictionary and its
# sparse code (15 non-zero coefficients, fixed seed for reproducibility).
# NOTE(review): the reshapes below assume the dictionary comes back as
# (n_features, n_components), as in the recorded output; confirm against the
# installed scikit-learn version.
y_0, X_0, w_0 = make_sparse_coded_signal(
    n_samples=1,
    n_components=256,
    n_features=96,
    n_nonzero_coefs=15,
    random_state=4)

# NOTE: every tensor below is passed through `.detach()`, so the resulting
# `y`, `X` and `w` are leaves that do NOT track gradients regardless of this
# flag; it is kept only so the construction can be toggled while experimenting.
requires_grad = True

# Signal as a batch of row vectors: (n_samples, n_features).
y = torch.tensor(y_0, dtype=torch.float32, requires_grad=requires_grad).reshape(y_0.shape[0], -1).T.detach()
# Dictionary with a leading singleton batch axis: (1, n_features, n_components).
X = torch.tensor(X_0, dtype=torch.float32, requires_grad=requires_grad).reshape(1, X_0.shape[0], -1).detach()
# Transpose on the NumPy side: `Tensor.T` on a tensor that is not 2-D is
# deprecated in PyTorch, and `w_0` is 1-D when n_samples == 1.  `ndarray.T`
# is a no-op for 1-D input and a regular transpose for 2-D input, so the
# resulting shape is the same as before.
w = torch.tensor(w_0.T, dtype=torch.float32, requires_grad=requires_grad).detach()

# y_hat = X @ w
y.shape, X.shape, w.shape, y.is_leaf, X.is_leaf, w.is_leaf
# ((y_hat-y)**2).sum()

  w = torch.tensor(w_0, dtype=torch.float32, requires_grad=requires_grad).T.detach()


(torch.Size([1, 96]),
 torch.Size([1, 96, 256]),
 torch.Size([256]),
 True,
 True,
 True)

In [36]:
'''Orthogonal Matching Pursuit (OMP) with a differentiable atom selection.

Greedily selects `n_nonzero_coefs` dictionary atoms for every signal in `y`.
At each step the atoms are scored by the absolute correlation with the current
residual; the scores go through a softmax with temperature `tau`.  With
`hard=True` the softmax is replaced in the forward pass by a one-hot vector
(straight-through estimator), so the arg-max atom is used while gradients
still flow through the soft scores.

Inputs (from the data cell): `X` (1, chunk_length, n_atoms) and
`y` (batch_sz, chunk_length).
Outputs: `nonzero_W` (batch_sz, n_nonzero_coefs, 1), the dense code
`W` (batch_sz, n_atoms), the final `residuals` (batch_sz, chunk_length),
the per-step selection vectors `max_score_indices` and the chosen atom ids
`detached_indices`.
'''
# NOTE: this name shadows the builtin `dict`; it is kept because later cells
# read `dict`.  Consider cloning the tensor.
dict = X[0, ...]  # (chunk_length, n_atoms)

chunk_length, n_atoms = dict.shape
batch_sz, chunk_length = y.shape
n_nonzero_coefs = 15
tau = 0.1    # softmax temperature of the atom scores
hard = True  # True: straight-through one-hot selection, False: soft selection

# DTD = dict.T @ dict

residuals = y.clone().detach()  # (batch_sz, chunk_length)
residuals.requires_grad = True
# One (batch_sz, 1, n_atoms) selection vector per iteration.
max_score_indices = []
# Chosen atom id per signal and iteration, as plain integers.
detached_indices = np.zeros((batch_sz, n_nonzero_coefs), dtype=np.int64)  # (batch_sz, n_nonzero_coefs)

selection_weights = []  # per iteration: (batch_sz, 1) selection value at the chosen atom
picked_atoms = []       # per iteration: (batch_sz, 1) chosen atom id (tensor)

tolerance = True  # placeholder, not used yet
# Control stop iteration with norm thresh or sparsity
for i in range(n_nonzero_coefs):
    # Score every atom against the current residual.
    projections = dict.T @ residuals[:, :, None]  # (batch_sz, n_atoms, 1)
    scores = projections.abs().squeeze(-1)         # (batch_sz, n_atoms)

    # A single arg-max drives both the integer bookkeeping and the one-hot
    # vector, so the two can never disagree (e.g. when the softmax saturates).
    index = scores.argmax(-1, keepdim=True)        # (batch_sz, 1)
    picked_atoms.append(index)
    detached_indices[:, i] = index.squeeze(-1).detach().cpu().numpy()

    soft_score_indices = (scores / tau).softmax(-1)  # (batch_sz, n_atoms)
    if hard:
        # copied and modified from https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
        # Straight through: one-hot value in the forward pass, softmax gradient in the backward pass.
        hard_score_indices = torch.zeros_like(soft_score_indices).scatter_(-1, index, 1.0)
        selection = hard_score_indices - soft_score_indices.detach() + soft_score_indices
    else:
        selection = soft_score_indices
    max_score_indices.append(selection[:, None, :])
    selection_weights.append(selection.gather(-1, index))

    # Sub-dictionary of the atoms chosen so far, each column scaled by its
    # selection weight.  This equals masking the whole dictionary and then
    # indexing the chosen columns, without building the
    # (batch_sz, i+1, chunk_length, n_atoms) intermediate.
    atom_ids = torch.cat(picked_atoms, dim=-1)       # (batch_sz, i+1)
    weights = torch.cat(selection_weights, dim=-1)   # (batch_sz, i+1)
    selected_D = dict[:, atom_ids].permute(1, 0, 2) * weights[:, None, :]  # (batch_sz, chunk_length, i+1)

    # calculate selected_DTy
    selected_DTy = selected_D.permute(0, 2, 1) @ y[:, :, None]  # (batch_sz, i+1, 1)

    # calculate selected_DTD
    selected_DTD = selected_D.permute(0, 2, 1) @ selected_D  # (batch_sz, i+1, i+1)

    # find W = (selected_DTD)^-1 @ selected_DTy
    # selected_DTD.diagonal(dim1=-2, dim2=-1).add_(1.) # TODO: add multipath OMP
    nonzero_W = torch.linalg.solve(selected_DTD, selected_DTy)  # (batch_sz, i+1, 1)

    # finally get residuals r = y - D_selected @ W; only the trailing singleton
    # axis is dropped so the batch axis survives for any batch size.
    residuals = y - (selected_D @ nonzero_W).squeeze(-1)  # (batch_sz, chunk_length)

# Scatter the solved coefficients back into a dense (batch_sz, n_atoms) code.
W = torch.zeros(batch_sz, n_atoms, dtype=selected_D.dtype, device=selected_D.device)
W[torch.arange(batch_sz)[:, None], detached_indices] = nonzero_W.squeeze(-1)


In [28]:
# Throw-away scalar objective built from the OMP coefficients, used only to
# trigger a backward pass through the solve.
# (Alternative objective: W.sum() + 100)
fake_loss = nonzero_W.pow(5).sum() + 100
fake_loss.backward()

In [30]:
# Show the gradient accumulated on `z`.
# NOTE(review): `z` is not defined in any cell visible in this notebook, so
# this cell presumably relies on leftover kernel state and will fail after
# Restart & Run All; confirm where `z` is supposed to come from.
z.grad

tensor(0.0102)

In [29]:
# Report which of the tensors are autograd leaves.
# NOTE(review): `z` is not defined in any cell visible in this notebook;
# confirm its origin before relying on this cell.
residuals.is_leaf, X.is_leaf, y.is_leaf, W.is_leaf, dict.is_leaf, z.is_leaf

(False, True, True, False, True, True)

In [47]:
# Minimal autograd sanity check: differentiate x**2 at x = 2 and inspect which
# tensors are leaves.
# Dedicated `*_demo` names are used so this scratch cell does not rebind the
# notebook-level signal `y` that the OMP cell reads.
x_demo = torch.tensor(2.0, requires_grad=True)
print("x:", x_demo)
y_demo = x_demo**2
y_demo.backward()  # populates x_demo.grad with d(x**2)/dx = 2 * x
y_demo.is_leaf, x_demo.is_leaf, x_demo.grad.is_leaf

x: tensor(2., requires_grad=True)


(False, True, True)

In [6]:
# Non-zero entries of the reference sparse code `w`.
w[w.ne(0)]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       grad_fn=<IndexBackward0>)

In [7]:
# Non-zero entries of the stacked per-iteration selection vectors for the
# first signal in the batch.
first_signal_selection = torch.cat(max_score_indices, dim=1)[0, :]
first_signal_selection[first_signal_selection != 0]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       grad_fn=<IndexBackward0>)

In [13]:
# Non-zero entries of the dense code `W` recovered by the OMP cell.
W[W.ne(0)]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       grad_fn=<IndexBackward0>)