<h1><center>
"Short-and-Sparse Deconvolution -- A Geometric Approach"<br></center></h1>
<center>
Yenson Lau\*, Qing Qu\*, Han-Wen Kuo, Pengcheng Zhou, Yuqian Zhang, and John Wright<br>
(* denote equal contribution)<br>

We solve the short-and-sparse convolutional dictionary learning problem<br>

Code written by Qing Qu
</center>

#### <center> This notebook is a rewritten of the matlab code in python.</center>

#### Implementation of Alternating Desecent Method (ADM) in the paper
$y = \sum_{k=1}^K a_{0k} \circledast x_{0k} + b * 1 + n$<br>
with both $a_{0k}$ and $x_{0k}$ unknown, $b$ is a constant bias, $n$ is noise<br>

The algorithms solve the following 1D optimization problem<br>
$min F(A,X) = 0.5 * ||y - \sum_{k=1}^K a_k \circledast x_k||_2^2 + \lambda * ||X||_1$<br>
s.t. $||ak|| = 1$, $k = 1,...,K$<br>
$A = [a_1,a_2,...,a_K]$, $X = [x_1,x_2,...,x_K]$<br>
via alternating gradient descent:<br>

1. Fix $A$, and take a proximal gradient on $X$<br>
2. Fix $X$, and take a Riemannian gradient on $A$<br>

In [20]:
# Import dependencies
import numpy as np
from collections import defaultdict
##### The following functions is written outside in the auxilliary.py file,
##### but also have reference in this notebook at the very bottom
# from auxiliary import cconv
# from auxiliary import Log_map
# from auxiliary import Retract
# from auxiliary import backtracking
# from auxiliary import compute_error
# from auxiliary import compute_gradient
# from auxiliary import compute_y
# from auxiliary import gen_data
# from auxiliary import linesearch
# from auxiliary import reversal
# from auxiliary import shift_correlation

In [2]:
def ADM(y, opts):
    """
    Parameter:
    ---------------
    y: the output we observed, e.g., a blured picture
    opts: a dictionary containing all parameters
    """
    lamb = opts["lambda"] * opts["W"] # Penalty for sparsity
    def Psi_handle_function(v, u, V, lamb): # Loss function
        # Psi = @(v, u, V, Lambda) 0.5 * norm(v - u)^2 +  norm(Lambda .*V,1); % handle function ...
        return 0.5 * np.linalg.norm(v - u) ** 2 + np.linalg.norm(lamb * V, 1)
    
    # Evaluate the function value
    m = np.max(y.shape) # The number of measurements
    n, k = opts["A_init"].shape # n: the length of kernel, K: number of the kernels
    A = opts["A_init"] # Initialization for A
    X = opts["X_init"] # Initialization for X
    b = opts["b_init"] # Initialization of the bias
    t = 1 # Initialization of the stepsize
    
    # Record of the solution path
    Psi_val, psi_val, X_track, A_track, Err_A, Err_X = [], [], [], [], [], []
    
    # Main ADM algorithm
    for iteration in range(1, int(opts["MaxIter"]) + 1):
        # Given A fixed, take a descent step on X via proximal gradient descent
        y_hat = compute_y(A, X) # Compute y_hat = sum_k conv(a_k, x_k, m) ### Function to be import
        y_b = y - np.ones([m, 1]) * b # why not add
        Psi_X = Psi_handle_function(y_b, y_hat, X, opts["lambda"]) # Compute loss
        fx = 0.5 * np.linalg.norm(y_b - y_hat) ** 2 # ???
        grad_fx = compute_gradient(A, X, y_b, y_hat, 0) 
        # Backtracking for update X and update stepsize t
        X_old = X.copy()
        [X, t] = backtracking(y_b, A, X_old, fx, grad_fx, lamb, t, opts) # Function to be imported
        # Given X fixed, take a Riemannian gradient step on A
        y_hat = compute_y(A, X) 
        Psi_A = Psi_handle_function(y_b, y_hat, X, opts["lambda"])
        fa = 0.5 * np.linalg.norm(y_b - y_hat) ** 2 # ???
        grad_fa = compute_gradient(A, X, y_b, y_hat, 1)
        A_old = A.copy()
        [A, tau] = linesearch(y_b, A_old, X, fa, grad_fa) # line search for tau
        y_hat = compute_y(A, X) # Delete
        
        # Given A, X fixed, update the bias b # why are we doing this
        y_hat = compute_y(A, X)
        if opts["isbias"]:
            b = 1 / m * np.sum(y - y_hat)
            
        # Update results and check for stopping criteria
        Psi_val.append(Psi_handle_function(y_b, y_hat, X, opts["lambda"]))
        psi_val.append(0.5 * np.linalg.norm(y_b - y_hat) ** 2)
        X_track.append(X)
        A_track.append(A)
        
        # Calculate the distance between the groudtruth and the iterate
        if opts["err_truth"]:
            [err_A, err_X] = compute_error(A, X, opts)
            Err_A.append(err_A)
            Err_X.append(err_X)
            
        if opts["isprint"]:
            print(f"Running the {iteration}-th simulation, Psi_X = {Psi_X}, Psi_A = {Psi_A}")
            
        # Check for stopping criteria
        if np.linalg.norm(A_old - A, "fro") <= opts["tol"] and np.linalg.norm(X_old - X, "fro") <= opts["tol"]:
            # Why are we checking frobenius norm?
            break
          
    return [A, X, b, Psi_val, psi_val, Err_A, Err_X]

#### Implementation of homotopy acceleration in the paper

$y = \sum_{k=1}^K a_{0k} \circledast x_{0k} + b * 1 + n$ <br>
with both $a_{0k}$ and $x_{0k}$ unknown, $b$ is a constant bias, $n$ is noise<br>

The algorithms solve the following 1D optimization problem<br>
$min F(A,X) = 0.5 * ||y - \sum_{k=1}^K a_k \circledast x_k||_2^2 + \lambda * ||X||_1$ s.t. $||a_k|| = 1, k = 1,...,K$<br>
$A = [a_1,a_2,...,a_K]$, $X = [x_1,x_2,...,x_K]$<br>
homotopy chooses a sparse solution path by shrinking the $\lambda$:<br>
The algorithm starts with a large $\lambda$, and for each iteration it solves <br>
the problem with using a solver (e.g., ADM or iADM).<br>
It shrink the $\lambda$ geoemtrically and repeat until convergence.<br>

In [3]:
def homotopy(y_0, opts):
    # homotopy chooses a sparse solution path by shrinking the lambda:
    # The algorithm starts with a large lambda, and for each iteration it solves 
    # the problem with using a solver (e.g., ADM or iADM).
    # It shrink the lambda geoemtrically and repeat until convergence.
    
    n,k = opts["A_init"].shape
    m = np.max(y_0.shape)
    Psi_Val, psi_Val, Err_A, Err_X = [], [], [], []
    
    homo_opts = opts.copy() # Since we'll change the dict accordingly later, make a copy
    
    # Setting parameters
    case_now = opts["homo_alg"].lower()
    if case_now == "adm":
        eta = 0.8
        delta = 8e-2
        homo_opts["MaxIter"] = 2e2
    elif case_now == "iadm":
        eta = 0.85
        delta = 5e-1
        homo_opts["MaxIter"] = 1e2
    elif case_now == "reweight":
        eta = 0.8
        delta = 0.1
        homo_opts["MaxIter"] = 2e2
    else:
        raise ValueError("Wrong algorithm")
    
    lamb_0 = 1 # Initial lambda
    # lambda_0 = norm(cconv(reversal(y_0),opts.A_init,m),'inf'); % initial lambda
    lamb_0 = np.linalg.norm(cconv(reversal(y_0), opts["A_init"], m).flatten(), float('inf'))
    lamb_tgt = opts["lambda"] # target lambda
    
    homo_opts["lambda"] = lamb_0
    homo_opts["tol"] = delta * lamb_0 # should be the criteria, but why this value
    
    N_stages = int(np.floor(np.log(lamb_0/lamb_tgt) / np.log(1.0/eta))) # What's this
    lamb = lamb_0
    
    # Running the algorithm
    for k in range(1, N_stages + 1):
        case = opts["homo_alg"].lower()
        if case == "adm":
            [A, X, b, Psi_val,psi_val, Err_a, Err_x] = ADM(y_0, homo_opts)
        elif case == "iadm":
            [A, X, b, Psi_val,psi_val, Err_a, Err_x] = iADM(y_0, homo_opts)
            # opts.count = opts.count + length(f_val)
        elif case == "reweight":
            [A, X, b, Psi_val,psi_val,W] = reweighting(y_0, homo_opts)
            homo_opts["W"] = W
        else:
            raise ValueError("Wrong algorithm")
            
        # Record result
        Psi_Val.append(Psi_val)
        psi_Val.append(psi_val)
        Err_A.append(Err_a)
        Err_X.append(Err_x)
        
        # Update the parameters of opts
        homo_opts["A_init"] = A
        homo_opts["X_init"] = X
        homo_opts["b_init"] = b
        
        lamb = lamb * eta
        tol = delta * lamb
        homo_opts["lambda"] = lamb
        homo_opts["tol"] = tol
    
    # Solving the final stage to precision tol
    homo_opts["lambda"] = lamb_tgt
    homo_opts["tol"] = opts["opt"]
    homo_opts["MaxIter"] = opts["MaxIter"]
    
    case_now =  opts["homo_alg"].lower()
    if case_now == 'adm':
        [A, X, b, Psi_val,psi_val,Err_a,Err_x] = ADM(y_0, homo_opts)
    elif case_now == 'iadm':
        [A, X, b, Psi_val,psi_val,Err_a,Err_x] = iADM(y_0, homo_opts)
    elif case_now == 'reweight':
        [A, X, b, Psi_val,psi_val] = reweighting(y_0, homo_opts)
    else:
        raise ValueError("Wrong algorithm")
        
    Psi_Val.append(Psi_val)
    psi_Val.append(psi_val)
    Err_A.append(Err_a)
    Err_X.append(Err_x)
        
    return [A, X, b, Psi_Val, psi_Val, Err_A, Err_X]

#### Implementation of inertial Alternating Desecent Method (iADM) in the paper
via alternating acclerated gradient descent <br>

 1. Fix $A$, and take a proximal gradient on $X$ with momentum acceleration
 2. Fix $X$, and take a Riemannian gradient on $A$ with momentum acceleration

In [6]:
def iADM(y, opts):
    # via alternating acclerated gradient descent
    # 1. Fix A, and take a proximal gradient on X with momentum acceleration
    # 2. Fix X, and take a Riemannian gradient on A with momentum acceleration
    
    lamb = opts["lambda"] * opts["W"] # Penalty for sparsity
    def Psi_handle_function(v, u, V, lamb): # Loss function
        # Psi = @(v, u, V, Lambda) 0.5 * norm(v - u)^2 +  norm(Lambda .*V,1); % handle function ...
        return 0.5 * np.linalg.norm(v - u) ** 2 + np.linalg.norm(lamb * V, 1)
    
    # Evaluate the function value
    m = np.max(y.shape) # The number of measurements
    n, K = opts["A_init"].shape # n: the length of kernel, K: number of the kernels
    A = opts["A_init"] # Initialization for A
    A_old = opts["A_init"].copy() # Initialization for A_hat, auxiliary variable for acceleration
    X = opts["X_init"] # Initialization for X
    X_old = opts["X_init"].copy()
    b = opts["b_init"] # Initialization of the bias
    
    t = 0.5 # Initialization of the stepsize for X
    Psi_val, psi_val, X_track, A_track, Err_A, Err_X = [], [], [], [], [], []
    
    for iteration in range(1, int(opts["MaxIter"]) + 1):
        # Given A fixed, take a descent step on X via proximal gradient descent
        beta = 0.85
        X_hat = X + beta * (X - X_old)
        y_hat = compute_y(A, X_hat) # Compute y_hat = sum_k conv(a_k, x_k, m) ### Function to be import
        y_b = y - np.ones([m, 1]) * b 
        Psi_X = Psi_handle_function(y_b, y_hat, X_hat, opts["lambda"]) # Compute loss
        fx = 0.5 * np.linalg.norm(y_b - y_hat) ** 2 
        grad_fx = compute_gradient(A, X_hat, y_b, y_hat, 0) 
        
        # Backtracking for update X and update stepsize t
        X_old = X.copy()
        
        # Line search
        case = opts["t_linesearch"].lower()
        if case == "fixed":
            t = opts["t_fixed"]
            # Apply proximal mapping of Psi and compute gradient mapping
            X = soft_thres(X_hat - t * grad_fx, lamb * t) # Funciton to import
        elif case == "bt":
            [X,t] = backtracking(y_b, A, X_hat, fx, grad_fx, lamb, t, opts) # Function to import
        else:
            raise ValueError("Line search method not implemented")
        
        # Given X fixed, take a Riemannian gradient step on A
        D = Log_map(A_old, A) # Function to import
        Norm_D = np.zeros([K, 1])
        for k in range(K):
            Norm_D[k] = np.linalg.norm(D[:, k])
        
        A_hat = Retract(A, beta * D, beta * Norm_D) # Function to import
        
        y_hat = compute_y(A_hat, X)
        y_b = y - np.ones([m, 1]) * b
        Psi_A = Psi_handle_function(y_b, y_hat, X, lamb)
        fa = 0.5 * np.linalg.norm(y_b - y_hat) ** 2
        grad_fa = compute_gradient(A_hat, X, y_b, y_hat, 1)
        
        A_old = A.copy()
        [A, tau] = linesearch(y_b, A_hat, X, fa, grad_fa) ### Function to import
        
        # Given A, X fixed, update the bias b # why are we doing this
        y_hat = compute_y(A, X)
        if opts["isbias"]:
            b = 1 / m * np.sum(y - y_hat)
        y_b = y - np.ones([m, 1]) * b
            
        # Update results and check for stopping criteria
        Psi_val.append(Psi_handle_function(y_b, y_hat, X, opts["lambda"]))
        psi_val.append(0.5 * np.linalg.norm(y_b - y_hat) ** 2)
        X_track.append(X)
        A_track.append(A)
        
        # Calculate the distance between the groudtruth and the iterate
        if opts["err_truth"]:
            [err_A, err_X] = compute_error(A, X, opts)
            Err_A.append(err_A)
            Err_X.append(err_X)
            
        if opts["isprint"]:
            print(f"Running the {iteration}-th simulation, Psi_X = {Psi_X}, Psi_A = {Psi_A}")
            
        # Check for stopping criteria
        if np.linalg.norm(A_old - A, "fro") <= opts["tol"] and np.linalg.norm(X_old - X, "fro") <= opts["tol"]:
            # Why are we checking frobenius norm?
            break
          
        
    return [A, X, b, Psi_val, psi_val, Err_A, Err_X]

#### Implementation of reweighting method in the paper
Reweighting method starts with an all one weights $W$, and update the weights <br>
$W_{ij} = 1/(|X_{ij}| + \epsilon)$ for each iteration
We repeat the process until convergence

In [5]:
def reweighting(y_0, opts):
    # Reweighting method starts with an all one weights W, and update the weights
    # W_ij = 1/(|X_ij| + eps) for each iteration
    # We repeat the process until convergence
    
    n, k = opts["A_init"].shape
    m = np.max(y_0.shape)
    Psi_Val, psi_Val, Err_A, Err_X = [], [], [], []
    
    for iteration in range(1, opts["MaxIter_reweight"]):
        case = opts["reweight_alg"].lower()
        if case == "adm":
            [A, X, b, Psi_val,psi_val, Err_a, Err_x] = ADM(y_0, homo_opts) #### ? where does homo_opts come from
        elif case == "iadm":
            [A, X, b, Psi_val,psi_val, Err_a, Err_x] = iADM(y_0, opts)
        elif case == "homo":
            [A, X, b, Psi_val,psi_val, Err_a, Err_x] = homotopy(y_0, opts)
        else:
            raise ValueError("Wrong algorithm")
        
        # Record result
        Psi_Val.append(Psi_val)
        psi_Val.append(psi_val)
        Err_A.append(Err_a)
        Err_X.append(Err_x)
        
        if opts["isprint"]:
            print(f"Running the {iteration}-th round of reweighting")
        
        if np.linalg.norm(opts["A_init"] - A, "fro") <= opts["tol"] and \
                np.linalg.norm(opts["X_init"] - X, "fro") <= opts["tol"]:
            break
        
        # Update the initialization
        opts["A_init"] = A
        opts["X_init"] = X
        opts["b_init"] = b
        opts["count"] = opts["count"] + len(psi_val)
        
        # Update the weight matrix
        x = np.sort(np.abs(X.flatten('F')))[::-1] # Sort the flattened array in descending order
        thres = x[int(np.round(n / (4 * np.log(m / n))))] # What's this
        e = np.max(thres, 1e-3)
        
        opts["W"] = 1 / (np.abs(X) + e)
    
    W = opts["W"]
        
    return [A, X, b, Psi_Val, psi_Val, Err_A, Err_X]

#### Comparing the algorithmic performance of the proposed nonconvex optimization methods in the paper
Test the proposed Alternating desecent method (ADM), inertial ADM (iADM), homotopy acceleration and reweighting method.

In [24]:
# platform for simulation of Convolutionoal dictionary learning problem
# optimization parameters
opts = {}
opts["tol"] = 1e-6 # convergence tolerance
opts["isnonnegative"] = False # enforcing nonnegativity on X
opts["isupperbound"] = False # enforce upper bound on X
opts["upperbound"] = 1.5 # upper bound number
opts["hard_thres"] = False # hard-threshold on small entries of X to zero
opts["MaxIter"] = int(1e3) # number of maximum iterations
opts["MaxIter_reweight"] = 10 # reweighting iterations for reweighting algorithm
opts["isbias"] = True # enforce when there is a constant bias in y
opts["t_linesearch"] = 'bt' # linesearch for the stepsize t for X
opts["err_truth"] = True # enforce to compute error w.r.t. the groundtruth for (a0, x0)
opts["isprint"] = True # print the intermediate result

# Generate the measurements
# Set up parameters
n = int(1e2) # length of each kernel a0k
m = 1000 #int(1e4) # length of the measurements y
K = 1 # number of kernels
theta = n**(-3/4) # sparsity parameter for Bernoulli distribution
opts["lambda"] = 1e-2 # penalty parameter lambda

a_type = "randn" # Choose from 'randn', 'ar1', 'ar2', 'gaussian', 'sinc'
x_type = "bernoulli-rademacher" # choose 'bernoulli' or 'bernoulli-rademacher' or 'bernoulli-gaussian'
b_0 = 1 # bias
noise_level = 0 # noise level

# Generate the data
[A_0, X_0, y_0, y] = gen_data(theta, m, n, b_0, noise_level, a_type, x_type)
opts["truth"] = True
opts["A_0"] = A_0
opts["X_0"] = X_0
opts["b_0"] = b_0

# Initiation for A, X, b
# Initialize A
opts["A_init"] = np.zeros([3 * n, K])
for k in range(K):
    ind = np.random.permutation(m)[0]
    y_pad = np.vstack([y_0, y_0])
    a_init = y_pad[ind:ind+n]
    a_init = np.vstack([np.zeros([n, 1]), a_init, np.zeros([n,1])])
    a_init = a_init / np.linalg.norm(a_init, axis = 0)
    opts["A_init"][:, k] = a_init.flatten()

opts["X_init"] = np.zeros([m, K]) # Initialize X
opts["b_init"] = np.mean(y)
opts["W"] = np.ones([m, K]) # Initialize the weight matrix

# Run the optimization algorithms
Alg_num = 4
# Alg_type = ['ADM','iADM','homotopy-ADM','homotopy-iADM','reweighting']
Alg_type = ["ADM",'iADM','homotopy-ADM','homotopy-iADM','reweighting']#'ADM','
#Alg_type = ['homotopy-ADM']
Psi_min, psi_min = float("inf"), float("inf")
Psi = defaultdict(list)
psi = defaultdict(list)#[range(opts["MaxIter"])]
Err_A = defaultdict(list)
Err_X = defaultdict(list)

for k in range(len(Alg_type)):
    case = Alg_type[k].lower()
    if case == "adm":
        [A, X, b, Psi[k], psi[k], Err_A[k], Err_X[k]] = ADM(y_0, opts)
    elif case == "iadm":
        [A, X, b, Psi[k], psi[k], Err_A[k], Err_X[k]] = iADM(y_0, opts)
    elif case == "homotopy-adm":
        opts["homo_alg"] = "adm"
        [A, X, b, Psi[k], psi[k], Err_A[k], Err_X[k]] = homotopy(y_0, opts)
    elif case == "homotopy-iadm":
        opts["homo_alg"] = "iadm"
        [A, X, b, Psi[k], psi[k], Err_A[k], Err_X[k]] = homotopy(y_0, opts)
    elif case == "reweighting":
        opts["reweight_alg"] = "adm"
        [A, X, b, Psi[k], psi[k], Err_A[k], Err_X[k]] = reweighting(y_0, opts)
    
    if Psi[k][-1] <= Psi_min:
        Psi_min = Psi[k][-1]
    if psi[k][-1] <= psi_min:
        psi_min = psi[k][-1]
        
# Plotting results



Running the 1-th simulation, Psi_X = 233.72363618675516, Psi_A = 220.33101655872096
Running the 2-th simulation, Psi_X = 216.42236158012554, Psi_A = 135.91357896325857
Running the 3-th simulation, Psi_X = 125.26621325071349, Psi_A = 114.42595973818324
Running the 4-th simulation, Psi_X = 108.85090058271739, Psi_A = 99.8280041842448
Running the 5-th simulation, Psi_X = 93.28358648759837, Psi_A = 79.59667582654471
Running the 6-th simulation, Psi_X = 73.929403634276, Psi_A = 57.448175590237895
Running the 7-th simulation, Psi_X = 52.12674519075298, Psi_A = 41.97207567302136
Running the 8-th simulation, Psi_X = 39.55094264327281, Psi_A = 32.7098669149857
Running the 9-th simulation, Psi_X = 31.0687245747311, Psi_A = 23.636614194254534
Running the 10-th simulation, Psi_X = 21.628536460161612, Psi_A = 17.6072990967223
Running the 11-th simulation, Psi_X = 16.079875265976412, Psi_A = 13.47484939494776
Running the 12-th simulation, Psi_X = 12.234649888153779, Psi_A = 10.385418303202805
Runnin

Running the 112-th simulation, Psi_X = 4.376183042367956, Psi_A = 4.375586684596126
Running the 113-th simulation, Psi_X = 4.374522467511592, Psi_A = 4.3736492039842245
Running the 114-th simulation, Psi_X = 4.372626934515717, Psi_A = 4.37194232499431
Running the 115-th simulation, Psi_X = 4.370153781993356, Psi_A = 4.369711915758023
Running the 116-th simulation, Psi_X = 4.368346643569856, Psi_A = 4.367843217681303
Running the 117-th simulation, Psi_X = 4.366710586993172, Psi_A = 4.366119208855112
Running the 118-th simulation, Psi_X = 4.365085660929327, Psi_A = 4.364217820339558
Running the 119-th simulation, Psi_X = 4.363219682556327, Psi_A = 4.362540536050918
Running the 120-th simulation, Psi_X = 4.3608002988099495, Psi_A = 4.360366151831698
Running the 121-th simulation, Psi_X = 4.359042711991387, Psi_A = 4.358551842537077
Running the 122-th simulation, Psi_X = 4.357457793280275, Psi_A = 4.356882300331949
Running the 123-th simulation, Psi_X = 4.355885386180776, Psi_A = 4.3550522

Running the 214-th simulation, Psi_X = 4.21445407611745, Psi_A = 4.214159643365658
Running the 215-th simulation, Psi_X = 4.213342252945498, Psi_A = 4.212988398962958
Running the 216-th simulation, Psi_X = 4.212262776921023, Psi_A = 4.211867588796526
Running the 217-th simulation, Psi_X = 4.211181390871254, Psi_A = 4.210766325706185
Running the 218-th simulation, Psi_X = 4.2101033548524915, Psi_A = 4.209534180939046
Running the 219-th simulation, Psi_X = 4.208866942699987, Psi_A = 4.2086256110764975
Running the 220-th simulation, Psi_X = 4.207677125989104, Psi_A = 4.207391803502411
Running the 221-th simulation, Psi_X = 4.206595850393478, Psi_A = 4.206252304776685
Running the 222-th simulation, Psi_X = 4.20554604615055, Psi_A = 4.20516098374809
Running the 223-th simulation, Psi_X = 4.204491805245769, Psi_A = 4.204086275343757
Running the 224-th simulation, Psi_X = 4.203438102607743, Psi_A = 4.2028804835998335
Running the 225-th simulation, Psi_X = 4.202224336288912, Psi_A = 4.20198750

Running the 322-th simulation, Psi_X = 4.1147360920817375, Psi_A = 4.114524326240481
Running the 323-th simulation, Psi_X = 4.114083313599358, Psi_A = 4.113847252070913
Running the 324-th simulation, Psi_X = 4.113428742735155, Psi_A = 4.113182756701532
Running the 325-th simulation, Psi_X = 4.112779147286051, Psi_A = 4.112528671568051
Running the 326-th simulation, Psi_X = 4.112136339557871, Psi_A = 4.11179951240093
Running the 327-th simulation, Psi_X = 4.111406142220407, Psi_A = 4.111274209966224
Running the 328-th simulation, Psi_X = 4.1107038235409155, Psi_A = 4.1105433575398385
Running the 329-th simulation, Psi_X = 4.110075860021733, Psi_A = 4.109878132559702
Running the 330-th simulation, Psi_X = 4.109463467984422, Psi_A = 4.109242476689879
Running the 331-th simulation, Psi_X = 4.1088506106905855, Psi_A = 4.108618295765844
Running the 332-th simulation, Psi_X = 4.108239501162267, Psi_A = 4.107920566498356
Running the 333-th simulation, Psi_X = 4.107542868127608, Psi_A = 4.10741

Running the 423-th simulation, Psi_X = 4.056139815901001, Psi_A = 4.055981733088589
Running the 424-th simulation, Psi_X = 4.055653445675834, Psi_A = 4.055477380961172
Running the 425-th simulation, Psi_X = 4.0551655451725725, Psi_A = 4.054980676145208
Running the 426-th simulation, Psi_X = 4.05467840771706, Psi_A = 4.054488577766535
Running the 427-th simulation, Psi_X = 4.054192037728516, Psi_A = 4.053999368415152
Running the 428-th simulation, Psi_X = 4.053706657682921, Psi_A = 4.05351231051315
Running the 429-th simulation, Psi_X = 4.053222226096086, Psi_A = 4.052961279420933
Running the 430-th simulation, Psi_X = 4.052659678362451, Psi_A = 4.05254922766825
Running the 431-th simulation, Psi_X = 4.052129044169664, Psi_A = 4.051999864403729
Running the 432-th simulation, Psi_X = 4.051639761465392, Psi_A = 4.051485998569783
Running the 433-th simulation, Psi_X = 4.0511653763449385, Psi_A = 4.050993526589392
Running the 434-th simulation, Psi_X = 4.050688014865397, Psi_A = 4.050507278

Running the 534-th simulation, Psi_X = 4.008545731290173, Psi_A = 4.008401782213728
Running the 535-th simulation, Psi_X = 4.008190260240161, Psi_A = 4.00804616013225
Running the 536-th simulation, Psi_X = 4.0078354113155035, Psi_A = 4.007691270252868
Running the 537-th simulation, Psi_X = 4.007481147798248, Psi_A = 4.007337032977013
Running the 538-th simulation, Psi_X = 4.007127427247298, Psi_A = 4.006983802670717
Running the 539-th simulation, Psi_X = 4.006775423652153, Psi_A = 4.0066321201134345
Running the 540-th simulation, Psi_X = 4.006424623208967, Psi_A = 4.0062815618049585
Running the 541-th simulation, Psi_X = 4.006074728441171, Psi_A = 4.005931918889187
Running the 542-th simulation, Psi_X = 4.005725637744418, Psi_A = 4.005583131613095
Running the 543-th simulation, Psi_X = 4.005377481145014, Psi_A = 4.005236415744466
Running the 544-th simulation, Psi_X = 4.0050328750991255, Psi_A = 4.004892252007699
Running the 545-th simulation, Psi_X = 4.004690230423125, Psi_A = 4.00450

Running the 647-th simulation, Psi_X = 3.9774588898387258, Psi_A = 3.977371209928836
Running the 648-th simulation, Psi_X = 3.9772448252444645, Psi_A = 3.9771572830872843
Running the 649-th simulation, Psi_X = 3.977031089729839, Psi_A = 3.9769436776312026
Running the 650-th simulation, Psi_X = 3.9768176548171903, Psi_A = 3.976730363038461
Running the 651-th simulation, Psi_X = 3.9766044954475497, Psi_A = 3.9765173142577463
Running the 652-th simulation, Psi_X = 3.976391587457278, Psi_A = 3.9763045078387314
Running the 653-th simulation, Psi_X = 3.9761789085097585, Psi_A = 3.976091921971427
Running the 654-th simulation, Psi_X = 3.975966437702987, Psi_A = 3.9758795362157944
Running the 655-th simulation, Psi_X = 3.975754155399409, Psi_A = 3.9756673313394786
Running the 656-th simulation, Psi_X = 3.97554204306835, Psi_A = 3.975455289176884
Running the 657-th simulation, Psi_X = 3.9753300831619773, Psi_A = 3.9752433925126063
Running the 658-th simulation, Psi_X = 3.975118259011884, Psi_A 

Running the 759-th simulation, Psi_X = 3.9567283765239973, Psi_A = 3.9566682932287702
Running the 760-th simulation, Psi_X = 3.9565799909965156, Psi_A = 3.95652001486515
Running the 761-th simulation, Psi_X = 3.956431940974307, Psi_A = 3.9563721271568157
Running the 762-th simulation, Psi_X = 3.9562842632788358, Psi_A = 3.9562246764444287
Running the 763-th simulation, Psi_X = 3.956137039062573, Psi_A = 3.956077605306133
Running the 764-th simulation, Psi_X = 3.9559901643828224, Psi_A = 3.9559308678077936
Running the 765-th simulation, Psi_X = 3.955843607269416, Psi_A = 3.9557844326674974
Running the 766-th simulation, Psi_X = 3.9556973392266075, Psi_A = 3.9556382765801277
Running the 767-th simulation, Psi_X = 3.9555513413433943, Psi_A = 3.9554925392290636
Running the 768-th simulation, Psi_X = 3.955405885150515, Psi_A = 3.9553472391494213
Running the 769-th simulation, Psi_X = 3.9552608300592613, Psi_A = 3.9552023063398978
Running the 770-th simulation, Psi_X = 3.955116108684648, Psi

Running the 883-th simulation, Psi_X = 3.940039598772584, Psi_A = 3.939992176727638
Running the 884-th simulation, Psi_X = 3.9399216633497933, Psi_A = 3.9398743156449014
Running the 885-th simulation, Psi_X = 3.9398039479834366, Psi_A = 3.9397566984752292
Running the 886-th simulation, Psi_X = 3.939686528873585, Psi_A = 3.939639496237337
Running the 887-th simulation, Psi_X = 3.939569681300787, Psi_A = 3.9395227138707263
Running the 888-th simulation, Psi_X = 3.939453156476752, Psi_A = 3.939406267236797
Running the 889-th simulation, Psi_X = 3.9393369146412676, Psi_A = 3.9392901100964486
Running the 890-th simulation, Psi_X = 3.939220933252436, Psi_A = 3.939174211818233
Running the 891-th simulation, Psi_X = 3.9391051916920086, Psi_A = 3.9390585515612315
Running the 892-th simulation, Psi_X = 3.938989694100962, Psi_A = 3.9389431283093272
Running the 893-th simulation, Psi_X = 3.9388744144399688, Psi_A = 3.93882791596873
Running the 894-th simulation, Psi_X = 3.9387593304002464, Psi_A =

Running the 981-th simulation, Psi_X = 3.9294600846608327, Psi_A = 3.929421006562197
Running the 982-th simulation, Psi_X = 3.929358936388008, Psi_A = 3.9293188361911446
Running the 983-th simulation, Psi_X = 3.9292576520891114, Psi_A = 3.929217104051259
Running the 984-th simulation, Psi_X = 3.929156825865149, Psi_A = 3.9291159424233117
Running the 985-th simulation, Psi_X = 3.9290562503605253, Psi_A = 3.9290151243729436
Running the 986-th simulation, Psi_X = 3.9289558232211443, Psi_A = 3.9289008632194666
Running the 987-th simulation, Psi_X = 3.9288381443476283, Psi_A = 3.928814618796373
Running the 988-th simulation, Psi_X = 3.9287270446802682, Psi_A = 3.9286995864615215
Running the 989-th simulation, Psi_X = 3.928627170348132, Psi_A = 3.9285940549967995
Running the 990-th simulation, Psi_X = 3.9285287825451225, Psi_A = 3.9284923022094156
Running the 991-th simulation, Psi_X = 3.9284298753622733, Psi_A = 3.928391703102605
Running the 992-th simulation, Psi_X = 3.928330838392233, Psi

Running the 105-th simulation, Psi_X = 4.394303761588309, Psi_A = 4.392110583832471
Running the 106-th simulation, Psi_X = 4.389984476452863, Psi_A = 4.388014167255743
Running the 107-th simulation, Psi_X = 4.385927345333433, Psi_A = 4.383753124722335
Running the 108-th simulation, Psi_X = 4.381632187197393, Psi_A = 4.379598412571034
Running the 109-th simulation, Psi_X = 4.377518741849796, Psi_A = 4.3755959675436795
Running the 110-th simulation, Psi_X = 4.373514088355042, Psi_A = 4.370974571210532
Running the 111-th simulation, Psi_X = 4.369416481401962, Psi_A = 4.367035563893653
Running the 112-th simulation, Psi_X = 4.365097954678332, Psi_A = 4.362851767700577
Running the 113-th simulation, Psi_X = 4.360975222386531, Psi_A = 4.358809092267019
Running the 114-th simulation, Psi_X = 4.356926389284761, Psi_A = 4.354763219890917
Running the 115-th simulation, Psi_X = 4.352897549956941, Psi_A = 4.350899421689974
Running the 116-th simulation, Psi_X = 4.349057333968241, Psi_A = 4.3472171

Running the 204-th simulation, Psi_X = 4.138497604840472, Psi_A = 4.1376162912607555
Running the 205-th simulation, Psi_X = 4.1368570938444025, Psi_A = 4.136072629309971
Running the 206-th simulation, Psi_X = 4.135331611479003, Psi_A = 4.1342909543616715
Running the 207-th simulation, Psi_X = 4.133759535914984, Psi_A = 4.132779463978723
Running the 208-th simulation, Psi_X = 4.132084325277354, Psi_A = 4.131095545959232
Running the 209-th simulation, Psi_X = 4.130452583117678, Psi_A = 4.1295070265232825
Running the 210-th simulation, Psi_X = 4.12883656784905, Psi_A = 4.127940763776244
Running the 211-th simulation, Psi_X = 4.127220457909345, Psi_A = 4.126410772467423
Running the 212-th simulation, Psi_X = 4.125715319800177, Psi_A = 4.124891913996522
Running the 213-th simulation, Psi_X = 4.124187619600031, Psi_A = 4.12346555399032
Running the 214-th simulation, Psi_X = 4.122759549436274, Psi_A = 4.121814876931467
Running the 215-th simulation, Psi_X = 4.121289969522983, Psi_A = 4.120399

Running the 306-th simulation, Psi_X = 4.018570063422802, Psi_A = 4.018153648948994
Running the 307-th simulation, Psi_X = 4.017782751632205, Psi_A = 4.017288341083823
Running the 308-th simulation, Psi_X = 4.017000502270443, Psi_A = 4.0165365474468375
Running the 309-th simulation, Psi_X = 4.016174307692678, Psi_A = 4.015746126914945
Running the 310-th simulation, Psi_X = 4.015392940152226, Psi_A = 4.014966438790517
Running the 311-th simulation, Psi_X = 4.014605251653055, Psi_A = 4.014183932967718
Running the 312-th simulation, Psi_X = 4.013826076087577, Psi_A = 4.013405606390231
Running the 313-th simulation, Psi_X = 4.013102715040018, Psi_A = 4.012640232102604
Running the 314-th simulation, Psi_X = 4.012300666283034, Psi_A = 4.01188433645903
Running the 315-th simulation, Psi_X = 4.011536626901895, Psi_A = 4.011133779984536
Running the 316-th simulation, Psi_X = 4.010792088558618, Psi_A = 4.010390480470226
Running the 317-th simulation, Psi_X = 4.010036565937977, Psi_A = 4.00965048

Running the 411-th simulation, Psi_X = 3.9519618063442703, Psi_A = 3.951627648707924
Running the 412-th simulation, Psi_X = 3.951376843015467, Psi_A = 3.9510850118867293
Running the 413-th simulation, Psi_X = 3.950830404914399, Psi_A = 3.950549989317298
Running the 414-th simulation, Psi_X = 3.9502934257066773, Psi_A = 3.9500228581271597
Running the 415-th simulation, Psi_X = 3.9497629123659634, Psi_A = 3.949493318246436
Running the 416-th simulation, Psi_X = 3.9492336058818096, Psi_A = 3.9489175548403592
Running the 417-th simulation, Psi_X = 3.948715270941541, Psi_A = 3.948415519224496
Running the 418-th simulation, Psi_X = 3.9481712039572407, Psi_A = 3.9478875830803535
Running the 419-th simulation, Psi_X = 3.947655650870578, Psi_A = 3.9473665630575847
Running the 420-th simulation, Psi_X = 3.947112217094208, Psi_A = 3.9468419333803815
Running the 421-th simulation, Psi_X = 3.9465848352820854, Psi_A = 3.9462676151884715
Running the 422-th simulation, Psi_X = 3.9460735681274484, Psi_

Running the 509-th simulation, Psi_X = 3.914155515794543, Psi_A = 3.913995925388537
Running the 510-th simulation, Psi_X = 3.9138791800258144, Psi_A = 3.9137421356598527
Running the 511-th simulation, Psi_X = 3.913623446303345, Psi_A = 3.913491545104178
Running the 512-th simulation, Psi_X = 3.9133729915699016, Psi_A = 3.9132132977734346
Running the 513-th simulation, Psi_X = 3.91309786875371, Psi_A = 3.9129526472550604
Running the 514-th simulation, Psi_X = 3.9128333724493167, Psi_A = 3.9127099287247904
Running the 515-th simulation, Psi_X = 3.912588122769831, Psi_A = 3.912459456127288
Running the 516-th simulation, Psi_X = 3.912337236818528, Psi_A = 3.9122114560702927
Running the 517-th simulation, Psi_X = 3.912105993233979, Psi_A = 3.9119637898524457
Running the 518-th simulation, Psi_X = 3.9118615487652244, Psi_A = 3.91172140768114
Running the 519-th simulation, Psi_X = 3.9116099029290066, Psi_A = 3.911479938377029
Running the 520-th simulation, Psi_X = 3.9113748928133196, Psi_A = 

Running the 618-th simulation, Psi_X = 3.892321872755426, Psi_A = 3.8922527447879207
Running the 619-th simulation, Psi_X = 3.8921927849431763, Psi_A = 3.892126392192371
Running the 620-th simulation, Psi_X = 3.892064764219633, Psi_A = 3.8919991592226673
Running the 621-th simulation, Psi_X = 3.8919372005120314, Psi_A = 3.89186138966136
Running the 622-th simulation, Psi_X = 3.891812769250016, Psi_A = 3.8917407040449548
Running the 623-th simulation, Psi_X = 3.8916823303402492, Psi_A = 3.8916144981392895
Running the 624-th simulation, Psi_X = 3.891555685542446, Psi_A = 3.8914903405905736
Running the 625-th simulation, Psi_X = 3.8914297592511176, Psi_A = 3.8913650117052376
Running the 626-th simulation, Psi_X = 3.891305426960954, Psi_A = 3.891240586939275
Running the 627-th simulation, Psi_X = 3.8911804243489323, Psi_A = 3.8911153373439062
Running the 628-th simulation, Psi_X = 3.8910568972923034, Psi_A = 3.890991543550969
Running the 629-th simulation, Psi_X = 3.890931254181669, Psi_A 

Running the 715-th simulation, Psi_X = 3.880593400004305, Psi_A = 3.880544108501649
Running the 716-th simulation, Psi_X = 3.8805093461407223, Psi_A = 3.8804535268259
Running the 717-th simulation, Psi_X = 3.8804170329728582, Psi_A = 3.8803679515074956
Running the 718-th simulation, Psi_X = 3.88033307630791, Psi_A = 3.880286126815882
Running the 719-th simulation, Psi_X = 3.880242533864921, Psi_A = 3.880191337955505
Running the 720-th simulation, Psi_X = 3.880145379774434, Psi_A = 3.880085808350136
Running the 721-th simulation, Psi_X = 3.8800484169655105, Psi_A = 3.879992666771014
Running the 722-th simulation, Psi_X = 3.879955970388901, Psi_A = 3.8798998195398524
Running the 723-th simulation, Psi_X = 3.8798659961954445, Psi_A = 3.879815149647081
Running the 724-th simulation, Psi_X = 3.8797765306966414, Psi_A = 3.8797192313988025
Running the 725-th simulation, Psi_X = 3.879677375190755, Psi_A = 3.8796222850333555
Running the 726-th simulation, Psi_X = 3.879574891827969, Psi_A = 3.87

Running the 820-th simulation, Psi_X = 3.8717478133843244, Psi_A = 3.8717065961302533
Running the 821-th simulation, Psi_X = 3.871668267460071, Psi_A = 3.87162716680002
Running the 822-th simulation, Psi_X = 3.871596780956965, Psi_A = 3.8715556617590585
Running the 823-th simulation, Psi_X = 3.871524129581552, Psi_A = 3.871488193054963
Running the 824-th simulation, Psi_X = 3.8714576826798495, Psi_A = 3.871420687708499
Running the 825-th simulation, Psi_X = 3.871392792057843, Psi_A = 3.8713561031369723
Running the 826-th simulation, Psi_X = 3.871333713573832, Psi_A = 3.8712912769330847
Running the 827-th simulation, Psi_X = 3.871264433604632, Psi_A = 3.871228758601264
Running the 828-th simulation, Psi_X = 3.8712020665077223, Psi_A = 3.871168600743311
Running the 829-th simulation, Psi_X = 3.8711360166106727, Psi_A = 3.8710979046581655
Running the 830-th simulation, Psi_X = 3.8710632871705486, Psi_A = 3.8710191572282513
Running the 831-th simulation, Psi_X = 3.870990501356448, Psi_A = 

Running the 927-th simulation, Psi_X = 3.8646301641814906, Psi_A = 3.864591442461005
Running the 928-th simulation, Psi_X = 3.8645612373072558, Psi_A = 3.864523963080972
Running the 929-th simulation, Psi_X = 3.8644911867774105, Psi_A = 3.8644496672934348
Running the 930-th simulation, Psi_X = 3.864425265996103, Psi_A = 3.864384773131137
Running the 931-th simulation, Psi_X = 3.864359743531138, Psi_A = 3.864323439962568
Running the 932-th simulation, Psi_X = 3.864292638781164, Psi_A = 3.864257009952485
Running the 933-th simulation, Psi_X = 3.864225067346405, Psi_A = 3.864189808230459
Running the 934-th simulation, Psi_X = 3.8641570018155202, Psi_A = 3.8641225749238592
Running the 935-th simulation, Psi_X = 3.8640892957294306, Psi_A = 3.864054163269765
Running the 936-th simulation, Psi_X = 3.8640209217780743, Psi_A = 3.8639865597894474
Running the 937-th simulation, Psi_X = 3.863953243184742, Psi_A = 3.863917510691195
Running the 938-th simulation, Psi_X = 3.863888988130241, Psi_A = 3

  return ufunc.reduce(obj, axis, dtype, out, **passkwargs)


TypeError: unorderable types for comparison

In [10]:
opts["W"].shape

(100, 1)

In [77]:

%% plotting results
% figure;
% plot(A_0);

color = {'r','g','b','k'};

figure(1);
hold on;
for k = 1:length(Alg_type)
    plot(log( Psi{k} - Psi_min ), color{k}, 'LineWidth', 2);
end
leg1 = legend(Alg_type);
set(leg1,'FontSize',16); set(leg1,'Interpreter','latex');
xlabel('Iteration','Interpreter','latex','FontSize',16);
ylabel('$\log ( \Psi(${\boldmath$a$},{\boldmath$x$}$) - \Psi_{\min} )$',...
    'Interpreter','latex','FontSize',16);
xlim([0,opts.MaxIter]);
set(gca, 'FontName', 'Times New Roman','FontSize',14);
title('(a) function value convergence','Interpreter','latex','FontSize',20);
grid on;

figure(2);
hold on;
for k = 1:length(Alg_type)
    plot(log(Err_A{k}), color{k}, 'LineWidth', 2);
end
leg2 = legend(Alg_type);
set(leg2,'FontSize',16); set(leg2,'Interpreter','latex');
xlabel('Iteration','Interpreter','latex');
ylabel('$\log ( \min \{||${\boldmath$a$}$_\star-${\boldmath$a$}$_0 ||\;,||${\boldmath$a$}$_\star + ${\boldmath$a$}$_0  || \} )$',...
    'Interpreter','latex','FontSize',16);
xlim([0,opts.MaxIter]);
set(gca, 'FontName', 'Times New Roman','FontSize',14);
title('(b) iterate convergence','Interpreter','latex','FontSize',20);
grid on;

save('incoherent.mat','Psi_min','Alg_type','opts','Err_A','Psi');

In [8]:
import numpy as np

def Log_map(Z,D):
    proj_a = lambda w, z: z - np.inner(w, z) * w
    n, K = Z.shape
    T = np.zeros([n, K])
    for k in range(K):
        alpha = np.arccos(np.inner(Z[:, k], D[:, k]))
        proj_temp = proj_a(Z[:, k], D[:, k])
        T[:, k] = proj_temp * alpha / np.sin(alpha)
    
    return T

# -------------------------------
def Retract(Z, D, t):
    n, K = Z.shape
    T = np.zeros([n, K])

    for k in range(K):
        T[:, k] = Z[:, k] * np.cos(t[k]) + (D[:, k] / t[k]) * np.sin(t[k])
    T = T / np.linalg.norm(T, axis = 0) # Normalize T by column

    return T

# -------------------------------
def backtracking(y, A, X, fx, grad_fx, lamb, t, opts):
    # update X via backtracking linesearch
    m = np.max(y.shape)
    Q = lambda Z, tau: fx + np.linalg.norm(lamb * Z, 1) + innerprod(grad_fx, Z - X) \
        + 0.5 / tau * np.linalg.norm(Z - X, "fro") ** 2
    t = 8 * t
    X1 = soft_thres(X - t * grad_fx, lamb * t) # Proximal mapping
    if opts["isnonnegative"]:
        X1 = np.max(X1, 0)

    if opts["hard_thres"]:
        ind = (X1 <= opts["hard_threshold"])
        X1[ind] = 0
    while Psi_val(y, A, X1, lamb) > Q(X1,t):
        t = 1 / 2 * t
        X1 = soft_thres(X - t * grad_fx, lamb * t)
        if opts["isnonnegative"]:
            X1 = np.max(X1, 0)
        if opts["isupperbound"]:
            X1 = np.min(X1, opts["upperbound"])
        if opts["hard_thres"]:
            ind = (X1 <= opts["hard_threshold"])
            X1[ind] = 0
    return [X1, t]

def innerprod(U, V):
    # This function is a dependency for the backtracking function
    return np.sum(U * V)

def Psi_val(y, A, Z, lamb = None):
    # This function is a dependency for the backtracking function
    # Also a dependency for the linesearch function
    m = np.max(y.shape)
    n, K = A.shape
    y_hat = np.zeros(y.shape)

    for k in range(K):
        y_hat = y_hat + cconv(A[:, k], Z[:, k], m)

    if lamb is None:
        f = 0.5 * np.sum((y - y_hat) ** 2)
    else:
        f = 0.5 * np.linalg.norm(y - y_hat) ** 2 + np.linalg.norm(lamb * Z, 1)

    return f

# -------------------------------
def compute_error(A, X, opts):
    n, K = A.shape
    m, m_hat = X.shape

    A_0 = np.vstack([np.zeros([int(n / 3), K]), opts["A_0"],
                     np.zeros([int(n / 3), K])]) # Why
    X_0 = opts["X_0"]
    err_A = 0
    err_X = 0
    for i in range(K):
        a = A[:, i]
        x = X[:, i]
        cor = np.zeros([K, 1])
        ind = np.zeros([K, 1])
        for j in range(K):
            #### Circular convolution again
            Corr = cconv(reversal(A_0[:, j]), a, m)
            cor[j] = np.max(np.abs(Corr), axis = 0)
            ind[j] = np.argmax(np.abs(Corr), axis = 0)
        Ind = np.argmax(cor)
        # Use np.roll to mimic circshift
        a_max = np.roll(A_0[:, Ind], int(ind[Ind] - 1), axis = 0)
        x_max = np.roll(X_0[:, Ind], int(-(ind[Ind] - 1)), axis = 0)
        err_A = err_A + np.min([np.linalg.norm(a_max - a), np.linalg.norm(a_max + a)])
        err_X = err_X + np.min([np.linalg.norm(x_max - x), np.linalg.norm(x_max + x)])
        
    return [err_A, err_X]

# -------------------------------
def compute_gradient(A, X, y_b, y_hat, gradient_case):
    # Compute (Riemannian) gradient
    proj_a = lambda w, z: z - np.inner(w, z) * w
    m, K = X.shape
    n, n_hat = A.shape

    if gradient_case == 0:
        Grad = np.zeros([m, K])
    elif gradient_case == 1:
        Grad = np.zeros([n, K])

    for k in range(K):
        if gradient_case == 0:
            Grad[:, k] = cconv(reversal(A[:, k], m), y_hat - y_b, m).flatten()
        elif gradient_case == 1:
            G = cconv(reversal(X[:, k], m), y_hat - y_b, m).flatten()
            Grad[:, k] = proj_a(A[:, k], G[:n])
    
    return Grad

# -------------------------------
def compute_y(A, X):
    # compute y = sum_k conv(a_k,x_k)
    m, K = X.shape
    y_hat = np.zeros([m, 1])
    for k in range(K):
        ### Circular convolution alert
        y_hat = y_hat + cconv(A[:, k], X[:, k], m)
    
    return y_hat

# -------------------------------
def gen_data(theta, m, n, b, noise_level, a_type,x_type):
    # generate the groudtruth data
    # y = sum_{k=1}^K a0k conv x0k + b*1 + n
    # s = rng(seed)
    
    # generate the kernel a_0
    gamma = [1.7, -0.712] # Parameter for AR2 model
    t = np.linspace(0, 1, n).reshape([n, 1]) # [0:1/(n-1):1]'
    case = a_type.lower()
    if case == "randn": # Random Gaussian
        a_0 = np.random.normal(size = [n, 1])
    elif case == "ar2": # AR2 kernel
        tau = 0.01 * ar2exp(gamma) # Function defined below
        a_0 = np.exp(-t / tau[0]) - np.exp(-t/tau[1])
    elif case == "ar1": # AR1 model
        tau = 0.25
        a_0 = np.exp(-t / tau)
    elif case == "gaussian":
        t = np.linspace(-2, 2, n).reshape([n, 1])
        a_0 = np.exp(-t**2)
    elif case == "sinc":
        sigma = 0.05
        a_0 = np.sinc((t-0.5)/sigma)
    else:
        raise ValueError("Wrong type")

    a_0 = a_0 / np.linalg.norm(a_0, axis = 0)  # Normalize kernel by column

    # Generate the spike train x_0
    case_x = x_type.lower()
    if case_x == "bernoulli":
        x_0 = (np.random.uniform(size = [m, 1]) <= theta) # Bernoulli spike train
    elif case_x == 'bernoulli-rademacher':
        x_0 = (np.random.uniform(size = [m, 1]) <= theta) * ((np.random.uniform(
            size = [m, 1]) < 0.5) - 0.5) * 2
    elif case_X == 'bernoulli-gaussian':
        # Gaussian-Bernoulli spike train
        x_0 = np.random.normal([m, 1]) * (np.random.uniform(m, 1) <= theta)
    else:
        raise ValueError("Wrong type")

    # generate the data y = a_0 conv b_0 + bias + noise
    ##### Circular convolution alert
    y_0 = cconv(a_0, x_0, m) + b * np.ones([m,1])
    y = y_0 + np.random.normal(size = [m, 1]) * noise_level
        
    return [a_0, x_0, y_0, y]

# -------------------------------
def ar2exp(g):
    # get parameters of the convolution kernel for AR2 process
    # Dependency of gen_data
    if len(g) == 1:
        g.append(0)
    temp = np.roots([1, -g[0], -g[1]]) # Polynomial roots
    d = np.max(temp)
    r = np.min(temp)
    tau_d = -1 / np.log(d)
    tau_r = -1 / np.log(r)

    tau_dr = [tau_d, tau_r]
    return tau_dr

# -------------------------------
def linesearch(y, A, X, fa, grad_a):
    # update A via Riemannian linsearch
    m = np.max(y.shape)
    K_hat, K = A.shape
    eta = 0.8
    tau = 1

    norm_grad = np.linalg.norm(grad_a, "fro")
    Norm_G = np.zeros([K, 1])
    for k in range(K):
        Norm_G[k] = np.linalg.norm(grad_a[:, k])

    A1 = Retract(A, -tau * grad_a, tau * Norm_G)

    count = 1
    while Psi_val(y, A1, X) > fa - eta * tau * norm_grad ** 2:
        tau = 0.5 * tau
        A1 = Retract(A, -tau*grad_a, tau*Norm_G)

        if count >= 100:
            break
        count += 1
    
    return [A1, tau]

# -------------------------------
def reversal(X, m = None):
    if len(X.shape) == 1:
        X = X.reshape([X.shape[0], 1])
    if m != None:
        X = np.vstack([X[:np.min([X.shape[0], m]), :], np.zeros(
            [np.max([m - X.shape[0], 0]), X.shape[1]])])
        
    revX = np.vstack([X[0, :], np.flipud(X[1:, :])])
    
    return [revX]

# -------------------------------
def shift_correlation(a, x, opts):
    a_0 = opts["A_0"]
    x_0 = opts["x_0"]

    n_0 = np.max(a_0.shape)
    n = np.max(a.shape)
    m = np.max(x.shape)

    if opts["ground_truth"]:
        ###### Circular convolution alert
        Corr = cconv(reversal(a_0), a, m)
        ind_hat, ind = np.max(np.abs(Corr))
        Corr_max = Corr[ind]

        if Corr_max > 0:
            ## not for sure since circshift shift the first dim that's not 1
            a_shift = np.roll(a, ind - 1)
            x_shift = np.roll(x, -(ind - 1))
        else:
            a_shift = -np.roll(a, ind - 1)
            x_shift = np.roll(x, -(ind - 1))
            # a_max = np.roll(A_0[:, Ind], ind[Ind] - 1, axis = 0)
    
    return [a_shift, x_shift]

# -------------------------------
def soft_thres(z, lamb):
    z = np.sign(z) * np.maximum(np.abs(z) - lamb, 0)
    
    return z

# -------------------------------
def cconv(vec1, vec2, length):
    # Since there's a lot of functions use circular function
    # and python doesn't have a function for that
    
    vec1 = np.array(vec1)
    vec2 = np.array(vec2)
    return (np.fft.ifft(np.fft.fft(vec1.flatten(), length) \
                        * np.fft.fft(vec2.flatten(), length))).reshape([length,1])
   
