In [1]:
import numpy as np
from scipy.linalg import norm
import matplotlib.pyplot as plt

# Define OMP function
def omp(A, y, k):
    n, p = A.shape
    x = np.zeros(p)
    r = y.copy()
    idx = np.zeros(k, dtype=int)
    for i in range(k):
        idx[i] = np.argmax(np.abs(np.dot(A.T, r)))
        x_new = np.zeros(p)
        x_new[idx[:i+1]] = np.dot(np.linalg.pinv(A[:, idx[:i+1]]), y)
        r = y - np.dot(A, x_new)
        normr = norm(r)
        if normr < 1e-6:
            break
    return x_new

# Define GPSR function
def gpsr(A, y, lambda_, tol=1e-5, max_iter=1000):
    x = np.zeros(A.shape[1])
    z = np.zeros(A.shape[1])
    t = 1.0
    iter = 0
    residual = np.inf
    while residual > tol and iter < max_iter:
        x_old = x.copy()
        z_old = z.copy()
        grad = np.dot(A.T, np.dot(A, x) - y) + lambda_ * np.sign(z)
        z -= t * grad
        x = z[:A.shape[1]]
        z[A.shape[1]:] = np.maximum(np.abs(z[A.shape[1]:]) - lambda_ * t, 0) * np.sign(z[A.shape[1]:])
        t *= 1.01
        iter += 1
        residual = norm(x - x_old) / norm(x_old)
    return x

# Generate random sparse signal
n = 100  # signal length
k = 10  # sparsity level
x = np.zeros(n)
x[:k] = np.random.randn(k)

# Generate measurement matrix
m = 50  # number of measurements
A = np.random.randn(m, n)

# Generate noisy measurements
sigma = 0.1  # noise level
e = sigma * np.random.randn(m)
y = np.dot(A, x) + e

# Solve OMP problem
x_omp = omp(A, y, k)

# Solve GPSR problem
lambda_ = sigma * np.sqrt(2 * np.log(n))  # regularization parameter
x_gpsr = gpsr(A, y, lambda_)

# Solve BP problem
from cvxpy import Variable, Problem, Minimize, norm1
x_bp = Variable(n)
prob = Problem(Minimize(norm1(x_bp)), [A @ x_bp == y])
prob.solve()

# Print results
print("Original signal: ", x)
print("Recovered signal (OMP):", x_omp)
print("Error (OMP):", norm(x - x_omp))
print("Recovered signal (GPSR):", x_gpsr)
print("Error (GPSR):", norm(x - x_gpsr))
print("Recovered signal (BP):", x_bp.value)
print("Error (BP):", norm(x - x_bp.value))

# Plot results
fig, ax = plt.subplots()
ax.plot(x, label='Original signal')
ax.plot(x_omp, label='OMP')
ax.plot(x_gpsr, label='GPSR')
ax.plot(x_bp.value, label='BP')
ax.legend()


ZeroDivisionError: ignored