In [1]:
import torch, numpy as np
import torch.nn.functional as F
import torch.nn as nn

from utills import create_data_set, plot_admm_vs_admm_1d_reconstruction

torch.manual_seed(0)

torch.set_default_dtype(torch.float64)

The following createss a sparse signals databases. each database comprised of N signal matrix with each signal
is 200x1 vector which is k-sparse signal

In [2]:
n, m, k = 150, 200, 4

# Measurement matrix
H = torch.randn(n, m)
H /= torch.norm(H, dim=0)

train_loader = create_data_set(H, n=n, m=m, k=k, N=1000)

test_loader = create_data_set(H, n=n, m=m, k=k, N=1000)

Vanilla ADMM implementation

In [3]:
def vanilla_admm(x, H, lambda_=12.5, mu=0.00005, rho=0.01, max_itr=300, eps=10 ** -5):
    proj = torch.nn.Softshrink(rho / (2 * lambda_))

    # initial estimate
    s = torch.zeros((H.shape[1]))
    u = torch.zeros((H.shape[1]))
    v = torch.zeros((H.shape[1]))

    # left_term = (H^TH+2λI)^-1 2*lambda or rho?
    left_term = torch.linalg.inv(H.T @ H + rho * torch.eye(H.shape[1]))

    recovery_errors = []
    for k in range(max_itr):
        s_prev, v_prev, u_prev = s, v, u

        # Update s_k+1 = ((H^T)H+2λI)^−1(H^T x+2λ(vk−uk)).

        right_term = H.T @ x + rho * (v_prev - u_prev)
        s = left_term @ right_term

        # Update vk+1 = prox_(1/2λϕ)(sk+1 + uk)
        v = proj(s + u_prev)

        # Update uk+1 = uk + μ (sk+1 − vk+1).

        u = u_prev + mu * (s - v)

        # # cease if convergence achieved
        if torch.sum(torch.abs(s - s_prev)) <= eps:
            break

        # save recovery error
        recovery_errors.append(torch.sum((torch.matmul(H, s) - x) ** 2))

    return s

Model Based ADMM implementation

In [4]:
EPSILON = 10 ** -2

class LADMM_Model_1D(nn.Module):
    def __init__(self, n, m, max_iterations=1000, rho=0.01, H=None, lambda_=12.5, mu=0.00005, epsilon=EPSILON):
        super(LADMM_Model_1D, self).__init__()
        self.n, self.m = n, m
        self.H = H

        # admm(x, H, lambda_=12.5, mu=0.00005, rho=0.01, max_itr=300, eps=10 ** -5):

        # Initialization of 1 dimensional parameter
        self.rho = nn.Parameter(torch.ones(1) * rho, requires_grad=True)
        self.lambda_ = nn.Parameter(torch.ones(1) * lambda_, requires_grad=True)
        self.mu = nn.Parameter(torch.ones(1) * mu, requires_grad=True)

        self.max_iteration = max_iterations
        self.epsilon = epsilon

    def _shrink(self, s, beta, rho):
        return beta * F.softshrink(s / beta, lambd=rho)

    def forward(self, x):
        """
        Args:
            x:
        Returns:
        """


        s_prev = torch.zeros(x.shape[0], self.H.shape[1])
        u_prev = torch.zeros((x.shape[0], self.H.shape[1]))
        v_prev = torch.zeros((x.shape[0], self.H.shape[1]))

        #################### Iteration 0 ####################

        left_term = torch.linalg.inv(H.T @ H + self.rho * torch.eye(H.shape[1]))

        right_term = (H.T @ x.T).T + self.rho * (v_prev - u_prev)

        s = (left_term @ right_term.T).T
        v = self._shrink(s + u_prev, self.rho / (2 * self.lambda_), rho=self.rho.item())
        u = u_prev + self.mu * (s - v)

        ######################################################
        iteration = 0

        # Notice the stopping condition isn't fixed K iterations
        while (torch.norm(s_prev.detach() - s.detach()).item() > self.epsilon) and (iteration < self.max_iteration):
            s_prev, v_prev, u_prev = s, v, u

            # left_term = (H^TH+2λI)^-1 2*lambda or rho?
            left_term = torch.linalg.inv(H.T @ H + self.rho * torch.eye(H.shape[1]))

            right_term = (H.T @ x.T).T + self.rho * (v_prev - u_prev)

            # Update s_k+1 = ((H^T)H+2λI)^−1(H^T x+2λ(vk−uk)).
            s = (left_term @ right_term.T).T

            # Update vk+1 = prox_(1/2λϕ)(sk+1 + uk)
            v = self._shrink(s + u_prev, self.rho / (2 * self.lambda_), rho=self.rho.item())

            # Update uk+1 = uk + μ (sk+1 − vk+1).
            u = u_prev + self.mu * (s - v)
            iteration += 1

        print("Batch total iterations: {0}, parameters: mu:{1} lambda:{2} rho:{3}".format(iteration, self.mu.item(),
                                                                                          self.lambda_.item(),
                                                                                          self.rho.item()))
        return s

The function trains a machine learning model

In [5]:
def train(model, train_loader, valid_loader, num_epochs=60):
    """Train a network.
    Returns:
        loss_test {numpy} -- loss function values on test set
    """
    # Initialization
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=5e-05,
        momentum=0.9,
        weight_decay=0,
    )

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
    loss_train = np.zeros((num_epochs,))
    loss_test = np.zeros((num_epochs,))

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for step, (b_x, b_H, b_s) in enumerate(train_loader):

            s_hat = model(b_x)
            loss = F.mse_loss(s_hat, b_s, reduction="sum")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            train_loss += loss.data.item()

        # Aggregate loss
        loss_train[epoch] = train_loss / len(train_loader.dataset)
        scheduler.step()

        # Validation
        model.eval()
        test_loss = 0
        for step, (b_x, b_H, b_s) in enumerate(valid_loader):
            s_hat = model(b_x)
            test_loss += F.mse_loss(s_hat, b_s, reduction="sum").data.item()
        loss_test[epoch] = test_loss / len(valid_loader.dataset)

        if epoch % 10 == 0:
            print("Epoch %d, Train loss %.8f, Validation loss %.8f" % (epoch, loss_train[epoch], loss_test[epoch]))

    return loss_test, b_x, b_s

The apply function perform L-ADMM or ADMM upon the whole sparse signal database and aggregate the loss.

In [6]:
def admm_1d_apply(train_loader, test_loader, max_iterations, H):
    n = H.shape[1]
    m = H.shape[1]

    ladmm = LADMM_Model_1D(n=n, m=m, max_iterations=max_iterations, H=H)

    loss_test, b_x, b_s = train(ladmm, train_loader, test_loader)
    error = loss_test[-1]

    return error, ladmm, b_x, b_s


def admm_apply(test_loader, T, H):

    loss = 0

    for step, (x, _, s) in enumerate(test_loader.dataset):
        s_hat = vanilla_admm(x=x, H=H, max_itr=T)
        loss += F.mse_loss(s_hat, s, reduction="sum").data.item()

    return loss / len(test_loader.dataset)

In [7]:
max_iter, epochs = 1000, 60


# b_x, b_s = a batch from the validation set
# Train and apply L-ADMM One Parameter with T iterations / layers
ladmm_mse, admm1d_model, b_x, b_s = admm_1d_apply(train_loader, test_loader, max_iter, H)

admm_mse = admm_apply(test_loader, max_iter, H)

######################### Visuallization #########################
b_x, s_gt = b_x[0], b_s[0]
s_hat_ladmm = admm1d_model(b_x)
s_hat_ladmm = s_hat_ladmm.detach().numpy()[0]

s_hat_admm = vanilla_admm(x=b_x, H=H, max_itr=max_iter)

plot_admm_vs_admm_1d_reconstruction(s_hat_admm=s_hat_admm,
                                  s_hat_ladmm=s_hat_ladmm, max_iter=max_iter,s_gt = s_gt)



Batch total iterations: 4, parameters: mu:5e-05 lambda:12.5 rho:0.01
Batch total iterations: 4, parameters: mu:5.298389142165073e-05 lambda:12.499999761252216 rho:0.010417189157524355
Batch total iterations: 5, parameters: mu:5.884584478711163e-05 lambda:12.499999292220467 rho:0.01123078291634062
Batch total iterations: 4, parameters: mu:5.884584478711163e-05 lambda:12.499999292220467 rho:0.01123078291634062
Epoch 0, Train loss 0.30019529, Validation loss 0.29377592
Batch total iterations: 5, parameters: mu:5.884584478711163e-05 lambda:12.499999292220467 rho:0.01123078291634062
Batch total iterations: 5, parameters: mu:6.88628106020087e-05 lambda:12.499998490702433 rho:0.012437596668718683
Batch total iterations: 5, parameters: mu:8.347796573597211e-05 lambda:12.499997321237489 rho:0.014040887751553384
Batch total iterations: 5, parameters: mu:8.347796573597211e-05 lambda:12.499997321237489 rho:0.014040887751553384
Batch total iterations: 5, parameters: mu:8.347796573597211e-05 lambda:

TypeError: plot_admm_vs_admm_1d_reconstruction() missing 1 required positional argument: 'epochs'

In [None]:
plot_admm_vs_admm_1d_reconstruction(s_hat_admm=s_hat_admm,
                                    s_hat_ladmm=s_hat_ladmm,
                                    max_iter=max_iter, s_gt = s_gt, epochs=epochs)