In [1]:
import os
import torch
import torch.utils
import torch.utils.data
from sklearn.datasets import make_sparse_coded_signal
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import OrthogonalMatchingPursuit
from contextlib import contextmanager
from timeit import default_timer
# from test_omp import omp_naive
# from test import *  # FIXME: better name


In [2]:
# Synthetic sparse-coding problem: one signal built from 15 of the 256
# dictionary atoms, each atom having 96 features.
y_0, X_0, w_0 = make_sparse_coded_signal(
    n_samples=1,
    n_components=256,
    n_features=96,
    n_nonzero_coefs=15,
    random_state=4)

requires_grad = True
# `.detach()` turns the reshaped / transposed views back into leaf tensors.
# Be aware that a detached tensor never tracks gradients, so despite the flag
# above y, X and w are constants for autograd.
y = torch.tensor(y_0, dtype=torch.float32, requires_grad=requires_grad).reshape(y_0.shape[0], -1).T.detach()  # (1, 96)
X = torch.tensor(X_0, dtype=torch.float32, requires_grad=requires_grad).reshape(1, X_0.shape[0], -1).detach()  # (1, 96, 256)
# Transpose on the NumPy side: `ndarray.T` is a no-op for the 1-D code vector,
# whereas `Tensor.T` on a tensor that is not 2-D is deprecated and warns.
w = torch.tensor(w_0.T, dtype=torch.float32, requires_grad=requires_grad).detach()  # (256,)

# y_hat = X @ w
y.shape, X.shape, w.shape, y.is_leaf, X.is_leaf, w.is_leaf
# ((y_hat-y)**2).sum()

  w = torch.tensor(w_0, dtype=torch.float32, requires_grad=requires_grad).T.detach()


(torch.Size([1, 96]),
 torch.Size([1, 96, 256]),
 torch.Size([256]),
 True,
 True,
 True)

In [36]:
'''Orthogonal Matching Pursuit (OMP) with a straight-through atom selection.

Each iteration picks the atom with the largest absolute correlation with the
current residual, re-fits the coefficients of all atoms picked so far through
the normal equations, and recomputes the residual. The pick is expressed as a
(soft or straight-through one-hot) selection vector so gradients can flow
through it.
'''
# (chunk_length, n_atoms). NOTE: this shadows the builtin `dict`; the name is
# kept because a later cell inspects `dict.is_leaf`.
dict = X[0, ...] # consider cloning the tensor

chunk_length, n_atoms = dict.shape
batch_sz, chunk_length = y.shape
n_nonzero_coefs = 15  # number of greedy iterations / selected atoms
tau = 0.1             # softmax temperature applied to the atom scores
hard = True           # True: one-hot forward pass with soft gradients

# DTD = dict.T @ dict

residuals = y.clone().detach() # (batch_sz, chunk_length)
residuals.requires_grad = True
# One (batch_sz, 1, n_atoms) selection vector is appended per iteration.
max_score_indices = []
detached_indices = np.zeros((batch_sz, n_nonzero_coefs), dtype=np.int64) # (batch_sz, n_nonzero_coefs)


tolerance = True
# Control stop iteration with norm thresh or sparsity
for i in range(n_nonzero_coefs):
    # Score every atom against the current residual.
    projections = dict.T @ residuals[:, :, None] # (batch_sz, n_atoms, 1)

    # A single argmax feeds both the integer bookkeeping and the one-hot
    # vector, so the two can never disagree (e.g. on a tie after the softmax).
    index = projections.abs().squeeze(-1).argmax(-1, keepdim=True) # (batch_sz, 1)
    detached_indices[:, i] = index.squeeze(-1).detach().cpu().numpy()
    soft_score_indices = (projections/tau).abs().squeeze(-1).softmax(-1) # (batch_sz, n_atoms)
    if hard:
        # copied and modified from https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#gumbel_softmax
        # Straight through: forward value is the one-hot vector, backward
        # gradient is the one of the soft scores.
        hard_score_indices = torch.zeros_like(soft_score_indices).scatter_(-1, index, 1.0)
        ret = hard_score_indices - soft_score_indices.detach() + soft_score_indices
        max_score_indices.append(ret[:, None, :])
    else:
        max_score_indices.append(soft_score_indices[:, None, :])

    # Weight the dictionary by every selection vector gathered so far, then
    # keep, for each selection, only the column of the atom it picked.
    selection = torch.cat(max_score_indices, dim=1) # (batch_sz, i+1, n_atoms)
    _selected_D = X[:, None, ...] * selection[:, :, None, :] # (batch_sz, i+1, chunk_length, n_atoms)
    _selected_D = _selected_D.permute(0, 2, 1, 3) # (batch_sz, chunk_length, i+1, n_atoms)
    ind_0 = torch.arange(batch_sz)[:, None, None]
    ind_1 = torch.arange(chunk_length)[None, :, None]
    ind_2 = torch.arange(i+1)[None, None, :]
    selected_D = _selected_D[ind_0, ind_1, ind_2, detached_indices[:, None, :i+1]] # (batch_sz, chunk_length, i+1)

    # calculate selected_DTy
    selected_DTy = selected_D.permute(0, 2, 1) @ y[:, :, None] # (batch_sz, i+1, 1)

    # calculate selected_DTD
    selected_DTD = selected_D.permute(0, 2, 1) @ selected_D # (batch_sz, i+1, i+1)

    # find W = (selected_DTD)^-1 @ selected_DTy
    # selected_DTD.diagonal(dim1=-2, dim2=-1).add_(1.) # TODO: add multipath OMP
    nonzero_W = torch.linalg.solve(selected_DTD, selected_DTy) # (batch_sz, i+1, 1)

    # finally get residuals r = y - D_selected @ W. Only the trailing
    # singleton axis is dropped, so size-1 batch / chunk axes survive.
    residuals = y - (selected_D @ nonzero_W).squeeze(-1) # (batch_sz, chunk_length)

# Scatter the fitted coefficients back into a dense (batch_sz, n_atoms) code.
W = torch.zeros(batch_sz, n_atoms, dtype=selected_D.dtype, device=selected_D.device)
W[torch.arange(batch_sz)[:, None], detached_indices] = nonzero_W.squeeze(-1)


In [28]:
# Dummy scalar objective, used only to push gradients back through the OMP
# graph. Alternative objective: W.sum() + 100
fake_loss = nonzero_W.pow(5).sum() + 100
fake_loss.backward()

In [30]:
# Show the gradient accumulated on `z`.
# NOTE(review): `z` is not defined in any cell visible here; presumably it
# comes from a deleted or out-of-view cell. Confirm before relying on it.
z.grad

tensor(0.0102)

In [29]:
# Report which tensors are autograd leaves.
# NOTE(review): `z` is not defined in any cell visible here; confirm where it
# is created.
residuals.is_leaf, X.is_leaf, y.is_leaf, W.is_leaf, dict.is_leaf, z.is_leaf

(False, True, True, False, True, True)

In [47]:
# Minimal autograd sanity check: for y = x**2 at x = 2 the gradient is 4.
# Dedicated names are used so that the signal `y` consumed by the OMP cells is
# not overwritten with a scalar when this cell runs.
x_scalar = torch.tensor(2.0, requires_grad=True)
print("x:", x_scalar)
y_scalar = x_scalar**2
y_scalar.backward()
# Result of an op is not a leaf; the input and its accumulated grad are.
y_scalar.is_leaf, x_scalar.is_leaf, x_scalar.grad.is_leaf

x: tensor(2., requires_grad=True)


(False, True, True)

In [6]:
# Non-zero entries of the ground-truth code.
# NOTE(review): the saved output carries a grad_fn, which a detached `w`
# would not produce; presumably it stems from an earlier run. Confirm.
w[w.ne(0)]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       grad_fn=<IndexBackward0>)

In [7]:
# Non-zero entries of the stacked selection vectors for the first batch item.
torch.cat(max_score_indices, dim=1)[0][torch.cat(max_score_indices, dim=1)[0].ne(0)]

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       grad_fn=<IndexBackward0>)

In [13]:
# Non-zero entries of the recovered code, to compare against `w`.
W[W.ne(0)]

tensor([ 0.6459,  1.4198, -0.1485, -1.1357,  0.8393,  1.0694, -0.0308, -0.0098,
         0.2885,  0.1948, -0.3049,  1.1214,  1.3752,  0.2172,  0.2122],
       grad_fn=<IndexBackward0>)

In [6]:
import torch
import torch.nn.functional as F
# 1 x 1 x 5 x 5 image holding the values 0 ... 24
src = torch.arange(25, dtype=torch.float).reshape(1, 1, 5, 5).requires_grad_()
# 1 x 1 x 2 x 2 grid: two normalized (x, y) sampling locations in [-1, 1]
indices = torch.tensor([[-1, 0],[1, 1]], dtype=torch.float).reshape(1, 1, -1, 2)
# align_corners=False is grid_sample's default, spelled out to avoid the
# "default has changed" warning. With it, -1 / 1 address the outer edges of
# the border pixels, so both samples blend with the zero padding.
output = F.grid_sample(src, indices, align_corners=False)
print(output, output.shape)  # tensor([[[[5., 6.]]]]) torch.Size([1, 1, 1, 2])

tensor([[[[5., 6.]]]], grad_fn=<GridSampler2DBackward0>) torch.Size([1, 1, 1, 2])


In [77]:
# Display the image tensor that grid_sample reads from.
src

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]], requires_grad=True)

In [75]:
# Display the sampling grid.
# NOTE(review): the saved output ([[[[0.0, 0.5]]]]) differs from the value
# assigned to `indices` in the grid_sample cell; presumably it was captured
# from an earlier experiment. Confirm / re-run.
indices

tensor([[[[0.0000, 0.5000]]]])

In [45]:
import torch

torch.manual_seed(0)
# Plain tensors with requires_grad replace the deprecated `Variable` wrapper
# and the legacy `torch.FloatTensor` constructor.
x = torch.arange(9, dtype=torch.float).reshape(-1, 3).requires_grad_()
idx = torch.tensor([0.0, 1.0], requires_grad=True)  # fractional row positions


# Integer neighbours of each position; detached so they act as constants.
i0 = idx.floor().detach()
i1 = i0 + 1
i_1 = i0 - 1

# Rows at floor(idx), the next row and the previous row.
# Caveat: for idx = 0 the "previous" index is -1, which wraps to the last row.
y0 = x[i0.long(), :]
y1 = x[i1.long(), :]
y_1 = x[i_1.long(), :]

# Linear interpolation weights towards the next (Wa, Wb) and the previous
# (Wa2, Wb2) row; they are the only path through which `idx` gets a gradient.
Wa = (i1 - idx).unsqueeze(1).expand_as(y0)
Wb = (idx - i0).unsqueeze(1).expand_as(y1)
Wa2 = (idx - i_1).unsqueeze(1).expand_as(y_1)
Wb2 = (i0 - idx).unsqueeze(1).expand_as(y_1)

# Average of the forward and backward interpolations.
out = (Wa * y0 + Wb * y1 + Wa2 * y0 + Wb2 * y_1)/2

print(out)
out.sum().backward()
print(idx.grad)

tensor([[0., 1., 2.],
        [3., 4., 5.]], grad_fn=<DivBackward0>)
tensor([-4.5000,  9.0000])


In [46]:
# Only the last expression of a cell is displayed: `y0, y1` is evaluated and
# discarded, and the output shows `Wa2, Wb2`.
y0, y1
Wa2, Wb2

(tensor([[1., 1., 1.],
         [1., 1., 1.]], grad_fn=<ExpandBackward0>),
 tensor([[0., 0., 0.],
         [0., 0., 0.]], grad_fn=<ExpandBackward0>))

In [48]:
# Rows gathered at index floor(idx) - 1; the -1 index wraps to the last row.
y_1

tensor([[6., 7., 8.],
        [0., 1., 2.]], grad_fn=<IndexBackward0>)

In [49]:
# Display the 3 x 3 source matrix used by the indexing experiment.
x

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]], requires_grad=True)

In [78]:
# Add leading batch and channel axes so `x` can be passed to grid_sample.
xx = x.reshape((1, 1, 3, 3))
xx

tensor([[[[0., 1., 2.],
          [3., 4., 5.],
          [6., 7., 8.]]]], grad_fn=<ReshapeAliasBackward0>)

In [81]:

# One sampling location at normalized (x=0, y=-1): horizontally centred, on
# the top edge of the image.
grid = torch.tensor([0.0, -1.0]).reshape(1, 1, 1, 2)
# align_corners=False is the default, made explicit to avoid the warning.
# With it, y=-1 lies half a pixel above the first row, so the sample blends
# that row with the zero padding.
F.grid_sample(xx, grid, align_corners=False)



tensor([[[[0.5000]]]], grad_fn=<GridSampler2DBackward0>)

In [124]:
# 1 x 1 x 5 x 5 image holding 0 ... 24.
# NOTE: `input` shadows the builtin; the name is kept because later cells
# read it.
input = torch.arange(5*5).view(1, 1, 5, 5).float()
print(input)

# Create a 10 x 10 grid of normalized coordinates to upsample input
d = torch.linspace(-1, 1, 10)
# indexing="ij" is the behaviour meshgrid had when the argument was omitted;
# passing it explicitly avoids the deprecation warning.
meshx, meshy = torch.meshgrid(d, d, indexing="ij")
# grid_sample expects (x, y) in the last axis: x varies along columns (meshy),
# y along rows (meshx).
grid = torch.stack((meshy, meshx), 2)
grid = grid.unsqueeze(0) # add batch dim

# align_corners=True maps -1 / 1 onto the centres of the border pixels, so the
# output corners equal the input corners.
output = torch.nn.functional.grid_sample(input, grid, align_corners=True)
print(output)

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]]]])
tensor([[[[ 0.0000,  0.4444,  0.8889,  1.3333,  1.7778,  2.2222,  2.6667,
            3.1111,  3.5556,  4.0000],
          [ 2.2222,  2.6667,  3.1111,  3.5556,  4.0000,  4.4444,  4.8889,
            5.3333,  5.7778,  6.2222],
          [ 4.4444,  4.8889,  5.3333,  5.7778,  6.2222,  6.6667,  7.1111,
            7.5556,  8.0000,  8.4444],
          [ 6.6667,  7.1111,  7.5556,  8.0000,  8.4444,  8.8889,  9.3333,
            9.7778, 10.2222, 10.6667],
          [ 8.8889,  9.3333,  9.7778, 10.2222, 10.6667, 11.1111, 11.5556,
           12.0000, 12.4444, 12.8889],
          [11.1111, 11.5556, 12.0000, 12.4444, 12.8889, 13.3333, 13.7778,
           14.2222, 14.6667, 15.1111],
          [13.3333, 13.7778, 14.2222, 14.6667, 15.1111, 15.5556, 16.0000,
           16.4444, 16.8889, 17.3333],
          [15.5556, 1

In [130]:
# Reinterpret the 5 x 5 image as 5 channels of one 1 x 5 row each and show
# the first channel.
input.reshape(1,5,1,5)[0,0,...]

tensor([[0., 1., 2., 3., 4.]])

In [136]:
# One sampling location at normalized (x=-1, y=0), applied to the 5-channel
# view of `input`; only the shape of the result is shown.
grid = torch.tensor([-1.0, 0.0]).reshape(1, 1, 1, 2)
F.grid_sample(input.reshape(1, 5, 1, 5), grid, align_corners=True).shape

torch.Size([1, 5, 1, 1])