In [1]:
import numpy as np

from tqdm import tqdm
from matplotlib import pyplot as plt

import torch, numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

### Generate a dataset of signals and binary labels 

In [37]:
class SyntheticSignals():
    """Synthetic dataset of sparse signals, linear observations and binary labels.

    Each sample consists of
      * ``alpha``: a length-``n`` vector with exactly ``s`` non-zero Gaussian entries,
      * ``y``: a label drawn uniformly from {0, 1},
      * ``x``: the length-``m`` observation ``M @ alpha + noise`` where ``M`` is ``A``
        when the label is 1 and ``B`` when the label is 0.

    Parameters
    ----------
    A, B : torch.Tensor
        Measurement matrices; they are multiplied with a length-``n`` vector and the
        result is stored in a length-``m`` row, so they are expected to be ``(m, n)``.
    n : int
        Length of the sparse signal ``alpha``.
    m : int
        Length of the observation ``x``.
    s : int
        Number of non-zero entries in each ``alpha`` (must satisfy ``s <= n``).
    s_sigma : float
        Standard deviation of the non-zero entries of ``alpha``.
    eps_sigma : float
        Standard deviation of the additive observation noise.
    size : int
        Number of samples generated.
    batch_size : int
        Batch size used by :meth:`set_loader`.
    """

    def __init__(self, A, B, n, m, s, s_sigma = 0.5, eps_sigma = 0.01, size = 1000, batch_size = 512):

        self.n = n          # Number of samples in the original signal
        self.m = m          # Number of samples through the linear transformation

        self.size = size    # Size of the dataset

        self.alpha = torch.zeros(self.size, self.n)          # Underlying sparse vector
        self.x = torch.zeros(self.size, self.m)              # Observation
        self.y = torch.zeros(self.size)                      # Binary label (0.0 or 1.0)

        self.A = A          # Matrix for the linear observation - used when the label is 1
        self.B = B          # Matrix for the linear observation - used when the label is 0
        self.s = s          # Sparsity of the signal

        self.s_sigma = s_sigma
        self.eps_sigma = eps_sigma

        self.batch_size = batch_size

        # Generating the dataset
        self.set_data()


    def set_tuple(self, i):
        """Generate sample ``i`` in place (fills ``alpha[i]``, ``x[i]`` and ``y[i]``)."""

        # The support lives in the signal domain (length n), not in the observation
        # domain (length m): every coordinate of alpha must be eligible.
        idxs = torch.as_tensor(np.random.choice(self.n, self.s, replace=False), dtype=torch.long)
        peaks = np.random.normal(scale=self.s_sigma, size=self.s)

        # Labels are {0, 1} so that exactly one of the two matrices is selected below
        # and the targets are valid for a binary cross-entropy loss.
        y_ = int(np.random.choice([0, 1]))

        # Clear the row first so that regenerating a sample keeps exactly s non-zeros
        self.alpha[i, :] = 0.0
        self.alpha[i, idxs] = torch.from_numpy(peaks).to(self.alpha)

        # Noise is drawn with NumPy (same RNG as the rest of the sample) and converted
        # to a tensor before being mixed with torch operands.
        noise = torch.from_numpy(np.random.normal(scale=self.eps_sigma, size=self.m)).to(self.x)
        matrix = self.A if y_ == 1 else self.B

        self.x[i, :] = matrix.to(self.x) @ self.alpha[i, :] + noise
        self.y[i] = y_

    def set_data(self):
        """Generate all ``size`` samples."""
        for i in range(self.size):
            self.set_tuple(i)


    def set_loader(self):
        """Return a shuffling ``DataLoader`` over (observation, label) pairs."""

        # We need tuples observation/label
        return Data.DataLoader(dataset = Data.TensorDataset(self.x, self.y),
                               batch_size = self.batch_size,
                               shuffle = True)

In [38]:
# Dimensions for the signal, its sparsity and its observation

m_ = 150    # Length of each observation (rows of the measurement matrices)
n_ = 200    # Length of each sparse signal (columns of the measurement matrices)
s_ = 4      # Number of non-zero entries per signal

# Measurement matrices, each with unit-norm columns
A_ = torch.randn(m_, n_)
A_ /= torch.norm(A_, dim=0)

B_ = torch.randn(m_, n_)
# Normalise B_ by its OWN column norms (A_ is already unit-norm, so dividing by
# the norms of A_ would leave B_ unnormalised).
B_ /= torch.norm(B_, dim=0)

# Training and test loaders come from the same generator configuration;
# only the number of generated samples differs (800 vs 200).
_signal_kwargs = dict(A=A_, B=B_, n=n_, m=m_, s=s_)

# Each call builds a fresh dataset and wraps it in a shuffling DataLoader.
train_set = SyntheticSignals(size=800, **_signal_kwargs).set_loader()
test_set = SyntheticSignals(size=200, **_signal_kwargs).set_loader()

__________________________

___________________

### Model definition

In [59]:
class TDDL(nn.Module):
    """Dictionary learning driven by a binary classification task.

    A signal ``x`` is sparse-coded on the dictionary ``D`` with Orthogonal Matching
    Pursuit; the resulting code is fed to a two-layer classifier. Each training step
    updates the classifier by gradient descent and the dictionary by a projected
    gradient step.

    Parameters
    ----------
    D : torch.Tensor
        Initial dictionary; columns are atoms. ``D.shape[1]`` is the length of the
        sparse code and is expected to equal ``in_``.
    K : int
        Number of atoms selected by OMP for each signal.
    in_ : int
        Input size of the first linear layer.
    hidden_ : int
        Hidden size of the classifier.
    Lambda : float
        Stored on the instance; not used by the methods of this class.
    T : int
        Number of stochastic updates performed by :meth:`trainLoop`.
    t_0 : float
        Step index after which the learning rate starts decaying as ``LR * t_0 / t``.
    LR : float
        Base learning rate.

    Notes
    -----
    The loss is ``nn.BCELoss``; labels passed to :meth:`update` should lie in [0, 1].
    """

    def __init__(self, D, K, in_, hidden_, Lambda, T = 100, t_0 = 1, LR = 5e-03):
        super().__init__()

        # Dictionary initialization. A detached copy is kept as a plain tensor so it is
        # never registered as a module parameter: it is updated by hand in `update`.
        self.D = D.detach().clone()

        # Assumed sparsity
        self.K = K

        # Hyperparameters
        self.Lambda = Lambda
        self.t_0 = t_0
        self.T = T
        self.LR = LR

        # Define the neural architecture
        self.fc1 = nn.Linear(in_, hidden_)
        self.fc2 = nn.Linear(hidden_, 1)

        # Define the optimization utilities
        self.criterion = nn.BCELoss()

    def forward(self, alpha):
        """Map a sparse code to a probability in (0, 1)."""

        alpha = F.relu(self.fc1(alpha))
        alpha = torch.sigmoid(self.fc2(alpha))

        return alpha

    def OMP(self, x):
        """Sparse-code ``x`` on ``self.D`` with Orthogonal Matching Pursuit.

        Returns a new leaf tensor of length ``D.shape[1]`` with ``requires_grad=True``
        so that the gradient of the loss with respect to the code can be read back.
        """
        x = x.reshape(-1).to(self.D)
        n_atoms = self.D.shape[1]

        support = []
        alpha = torch.zeros(n_atoms, dtype=self.D.dtype, device=self.D.device)
        residual = x

        # The pursuit itself is not differentiated
        with torch.no_grad():
            for _ in range(min(self.K, n_atoms)):

                # Retrieve the maximum correlation between atoms and residuals of the previous iteration
                corr = torch.abs(self.D.t() @ residual)
                if support:
                    # Never pick the same atom twice (would make the Gram matrix singular)
                    corr[support] = -1.0
                support.append(int(torch.argmax(corr)))

                # Least-squares fit of the SIGNAL x on the selected atoms
                atoms = self.D[:, support]
                coef = torch.linalg.pinv(atoms.t() @ atoms) @ atoms.t() @ x
                alpha[support] = coef

                # Update the residuals
                residual = x - self.D @ alpha

        return alpha.clone().detach().requires_grad_(True)

    @staticmethod
    def activeSet(alpha):
        """Return the indices of the non-zero entries of a 1-D code as a 1-D tensor."""

        # squeeze(-1) keeps a 1-D result even when there is a single non-zero entry
        return torch.nonzero(alpha).squeeze(-1)

    def projD(self, D = None):
        """Project every column of ``D`` onto the unit L2 ball (norm at most one).

        Columns whose norm already is <= 1 are returned unchanged. When ``D`` is
        omitted, the current dictionary ``self.D`` is projected.
        """
        if D is None:
            D = self.D

        norms = torch.linalg.norm(D, dim=0, keepdim=True)

        return D / torch.clamp(norms, min=1.0)

    def update(self, x, y, t):
        """Perform one stochastic update of the classifier and of the dictionary.

        Parameters
        ----------
        x : torch.Tensor
            A single observation; it is flattened to 1-D.
        y : torch.Tensor or float
            Its label; it is reshaped to match the model output.
        t : int
            Index of the current step, used by the learning-rate schedule.
        """
        x = x.reshape(-1).to(self.D)

        # Compute the sparse approximation and its non-zero entries set
        alpha = self.OMP(x)
        L = self.activeSet(alpha)

        # Drop gradients left over from the previous step: backward() accumulates
        for param in self.parameters():
            param.grad = None

        # Forward this sparse feature vector enabling gradient computation with respect to the sparse vector
        y_hat = self.forward(alpha)

        # Loss computation and backpropagation
        target = torch.as_tensor(y, dtype=y_hat.dtype, device=y_hat.device).reshape(y_hat.shape)
        loss = self.criterion(y_hat, target)
        loss.backward()

        # Define the support vector for the D-gradient computation
        beta = torch.zeros_like(alpha)
        if L.numel() > 0:
            dic = self.D[:, L]
            # pinv instead of inv: stays defined if the selected atoms are collinear
            beta[L] = torch.linalg.pinv(dic.t() @ dic) @ alpha.grad[L]

        ##########################
        ### OPTIMIZATION PHASE ###
        ##########################

        # Learning rate heuristic: constant for early steps, then decaying as t_0 / t.
        # Plain Python arithmetic (the operands are floats); t == 0 uses the base rate.
        LR = min(self.LR, self.LR * self.t_0 / t) if t > 0 else self.LR

        with torch.no_grad():
            # Gradient descent for the model parameters
            for param in self.parameters():
                if param.grad is not None:
                    param -= LR * param.grad

            # Projected gradient descent for the dictionary.
            # Both terms are outer products with the same shape as D.
            code = alpha.detach()
            grad_D = - torch.outer(self.D @ beta, code) + torch.outer(x - self.D @ code, beta)
            self.D = self.projD(self.D - LR * grad_D)

    def trainLoop(self, train_set):
        """Run ``self.T`` single-sample updates on samples drawn uniformly at random.

        ``train_set`` is a ``DataLoader``; samples are read directly from its
        ``dataset`` attribute, so its batching and shuffling are not used.
        """

        # Set to train mode
        self.train()

        dataset = train_set.dataset

        # Main loop
        for t in range(self.T):
            # Integer index: the sample keeps its natural shape (no leading batch axis)
            idx = int(torch.randint(0, len(dataset), (1,)))
            x, y = dataset[idx]

            self.update(x, y, t)