In [1]:
import numpy as np 
from tensorflow.keras.datasets import mnist
from tqdm import tqdm

# Dictionary learning step 

In [29]:
def soft_shrink(
        x:np.ndarray, 
        lambda_:float
):
    """Elementwise soft-thresholding.

    Every entry of ``x`` is moved towards zero by ``lambda_``; entries whose
    magnitude does not exceed ``lambda_`` become zero.

    Parameters
    ----------
    x : array to shrink.
    lambda_ : threshold.

    Returns
    -------
    Array with the shape of ``x``.
    """
    return np.sign(x) * np.maximum(0, np.abs(x) - lambda_)

def sparse_coding_IPM(
        A:np.ndarray, 
        D:np.ndarray, 
        Lambda1:float, 
        Lambda2:float, 
        eps = 1e-4, 
        max_iter = 10
):
    """Sparse-code the columns of ``A`` over ``D`` with proximal-gradient steps.

    ``X`` starts from a standard-normal draw of shape
    ``(D.shape[1], A.shape[1])``. Each iteration moves ``X`` along
    ``(D.T @ A - D.T @ D @ X) - Lambda2 * (X - M)``, which decreases both the
    reconstruction error and the ``Lambda2`` penalty, and then soft-thresholds
    the result with threshold ``Lambda1 * mu``.

    ``M`` holds, for every column of ``X``, the mean of that column.

    The tracked loss is

        ||A - D X||_F^2 + Lambda1 * sum(|X|) + Lambda2 * ||X - M||_F^2

    Parameters
    ----------
    A : data matrix, one sample per column.
    D : dictionary, one atom per column.
    Lambda1 : weight of the entrywise l1 penalty.
    Lambda2 : weight of the ``X - M`` penalty.
    eps : iterations stop as soon as the tracked loss is ``<= eps``.
    max_iter : maximum number of iterations.

    Returns
    -------
    X : array of shape ``(D.shape[1], A.shape[1])``.
    """

    def _column_means(codes):
        # Shape (1, n_samples); broadcasts against the codes in `codes - M`.
        # NOTE(review): this averages over axis 0 (over atoms, per sample), as
        # the code always did. If the penalty is meant to pull every code
        # towards the mean code of the class, the mean should run over axis 1
        # instead -- confirm against the intended model.
        return np.mean(codes, axis = 0, keepdims = True)

    def _loss(codes):
        # The l1 term must be the entrywise sum of absolute values:
        # np.linalg.norm(X, ord=1) on a 2-D array is the max column sum.
        return (
            np.linalg.norm(A - D @ codes)**2
            + Lambda1 * np.abs(codes).sum()
            + Lambda2 * np.linalg.norm(codes - _column_means(codes))**2
        )

    X = np.random.randn(D.shape[1], A.shape[1])

    # Useful precomputations
    B = D.T @ A
    C = D.T @ D

    # Step size: inverse of an upper bound on the curvature of the smooth part.
    # C is symmetric, so eigvalsh returns real eigenvalues directly; the
    # Lambda2 term adds at most Lambda2 to the largest eigenvalue.
    lipschitz = float(np.linalg.eigvalsh(C).max()) + Lambda2
    # A zero bound means the smooth part is flat; any finite step works then.
    mu = 1.0 / lipschitz if lipschitz > 0 else 1.0

    loss = _loss(X)
    iters = 0

    while loss > eps and iters < max_iter:

        M = _column_means(X)

        # Descent step: the Lambda2 term enters with a minus sign because the
        # loss *adds* Lambda2 * ||X - M||^2.
        X_hat = X + mu * ((B - C @ X) - Lambda2 * (X - M))

        X = soft_shrink(X_hat, Lambda1 * mu)

        # Loss tracking (recomputes M from the updated X)
        loss = _loss(X)
        iters += 1

    return X

In [30]:
def get_Q(
        D:np.ndarray, 
        W:np.ndarray, 
        C:int,
        n:int
):
    """Weighted sum of atom outer products used by the dictionary update.

    Computes

        Q = sum_{m != n} coeff[m] * outer(D[:, m], D[:, m])
        coeff[m] = sum_j sum_{l != j} W[m, j] * W[n, l]      (j, l < C)

    Parameters
    ----------
    D : dictionary of shape (S, N), one atom per column.
    W : latent matrix; rows index atoms, the first ``C`` columns are used.
    C : number of classes.
    n : index of the atom that is excluded from the sum.

    Returns
    -------
    Q : float array of shape (S, S).
    """

    N = D.shape[1]
    W_c = np.asarray(W, dtype = float)[:N, :C]

    # sum_j sum_{l != j} W[m,j] W[n,l]
    #   = (sum_j W[m,j]) * (sum_l W[n,l]) - sum_j W[m,j] W[n,j]
    row_sums = W_c.sum(axis = 1)
    coeff = row_sums * row_sums[n] - W_c @ W_c[n]
    coeff[n] = 0.0  # atom n itself does not contribute

    # Scaling column m of D by coeff[m] and multiplying by D.T sums the
    # weighted outer products in a single matrix product.
    return (D * coeff) @ D.T

def dictionary_update(
        A:dict, 
        X:dict, 
        W:np.ndarray, 
        D:np.ndarray, 
        C:int, 
        Lambda3:float
):
    """Update every dictionary atom once and renormalise it to unit norm.

    For each class ``c`` the codes ``X[c]`` are row-scaled by ``W[:, c]``; the
    scaled codes and the data ``A[c]`` are concatenated over the classes
    ``0..C-1``. Every atom ``n`` is then replaced by the normalised solution of
    the diagonal system

        (G[n, n] + Lambda3 * diag(Q)) * u
            = L[:, n] - D @ G[:, n] - Lambda3 * Q @ D[:, n]

    with ``G = Y @ Y.T``, ``L = A @ Y.T`` and ``Q = get_Q(D, W, C, n)``.
    All atoms are computed from the incoming ``D``.

    Parameters
    ----------
    A : dict mapping class index to an (S, n_c) data matrix.
    X : dict mapping class index to an (N, n_c) code matrix.
    W : (N, C) latent matrix.
    D : (S, N) dictionary.
    C : number of classes.
    Lambda3 : weight of the ``Q`` term.

    Returns
    -------
    D_hat : array with the shape and dtype of ``D``. An atom whose system is
        singular, or whose solution has zero norm, is copied from ``D``.
    """

    # Row n of X[c] scaled by W[n, c]; same values as diag(W[:, c]) @ X[c]
    # without building the N x N diagonal matrix.
    Y = np.concatenate([W[:, c][:, None] * X[c] for c in range(C)], axis = 1)
    # Indexed by class so the columns line up with Y.
    A_all = np.concatenate([A[c] for c in range(C)], axis = 1)

    G = Y @ Y.T      # (N, N)
    L = A_all @ Y.T  # (S, N)

    N = D.shape[1]
    D_hat = np.zeros_like(D)

    for n in range(N):

        Q = get_Q(D, W, C, n)

        # G[n, n] is the normaliser. L is S x N, so L[n, n] is not a valid
        # diagonal entry and is out of bounds as soon as N > S.
        denom = G[n, n] + Lambda3 * np.diag(Q)
        rhs = L[:, n] - D @ G[:, n] - Lambda3 * (Q @ D[:, n])

        if np.any(denom == 0):
            # Singular system: keep the current atom.
            D_hat[:, n] = D[:, n]
            continue

        # The system matrix is diagonal, so solving it is an elementwise division.
        u = rhs / denom
        u_norm = np.linalg.norm(u)
        D_hat[:, n] = u / u_norm if u_norm > 0 else D[:, n]

    return D_hat

In [31]:
def DictionaryLearning(
        A:dict, 
        W:np.array,
        D:np.array,
        C:int, 
        Lambda1:float, 
        Lambda2:float, 
        Lambda3:float, 
        eps = 1e-2, 
        MAXITER=10
):
    """Alternate sparse coding and dictionary updates for ``MAXITER`` rounds.

    In every round the codes of each class ``c`` in ``range(C)`` are
    recomputed from ``A[c]`` and the current dictionary with
    ``sparse_coding_IPM``; the dictionary is then refreshed with
    ``dictionary_update``.

    Parameters
    ----------
    A : dict mapping class index to that class's data matrix.
    W : latent matrix, forwarded to ``dictionary_update``.
    D : initial dictionary.
    C : number of classes.
    Lambda1, Lambda2 : forwarded to ``sparse_coding_IPM``.
    Lambda3 : forwarded to ``dictionary_update``.
    eps : accepted but not used by this function.
    MAXITER : number of alternation rounds.

    Returns
    -------
    (X, D) : the per-class codes of the last round and the final dictionary.
    """

    for _round in range(MAXITER):

        # Sparse coding step: one code matrix per class, in class order
        X = {
            c: sparse_coding_IPM(A[c], D, Lambda1, Lambda2)
            for c in range(C)
        }

        # Dictionary update step
        D = dictionary_update(A, X, W, D, C, Lambda3)

    return X, D

# Latent matrix update step 

In [32]:
def latent_vector_update(
        X:dict,
        k:int, 
        N:int,
        C:int,
        delta:float,
        sigma:float,
        Lambda3:float,
        A:np.ndarray,
        D:np.ndarray,
        W:np.ndarray,
        MAXITER = 10,
        ):
    """Update column ``k`` of the latent matrix by projected gradient steps.

    The column ``w`` is fitted so that ``sum_n w[n] * outer(D[:, n], X_k[n, :])``
    approximates ``A_k``, with a linear penalty ``Lambda3 * b @ w`` where

        b[n] = sum_{m != n} (D[:, m] @ D[:, n])**2 * sum_{j != k} W[m, j]

    Each iteration takes a gradient step of size ``1 / sigma``, shifts the
    result so its entries sum to ``delta``, clips negatives to zero and
    rescales so the entries again sum to ``delta``.

    Parameters
    ----------
    X : per-class codes. Either a dict (the entry ``X[k]`` is used) or the
        (N, n_k) code matrix of class ``k`` itself.
    k : class / column of ``W`` to update.
    N : number of atoms.
    C : number of classes.
    delta : target sum of the entries of the returned vector.
    sigma : inverse step size.
    Lambda3 : weight of the ``b`` penalty.
    A : per-class data. Either a dict (``A[k]`` is used) or the (S, n_k) data
        matrix of class ``k`` itself.
    D : (S, N) dictionary.
    W : (N, C) latent matrix; only read.
    MAXITER : number of projected-gradient iterations.

    Returns
    -------
    w : float array of length ``N``.
    """

    # Accept both the per-class dicts used by the pipeline and plain arrays.
    X_k = np.asarray(X[k] if isinstance(X, dict) else X, dtype = float)[:N]
    A_k = np.asarray(A[k] if isinstance(A, dict) else A, dtype = float)
    D_N = np.asarray(D, dtype = float)[:, :N]
    W_N = np.asarray(W, dtype = float)[:N, :C]

    gram = D_N.T @ D_N

    # b[n] = sum_{m != n} (d_m . d_n)^2 * sum_{j != k} W[m, j]
    # W is laid out (atoms, classes), so the atom index comes first.
    coherence = gram**2
    coherence = coherence - np.diag(np.diag(coherence))  # drop the m == n terms
    other_classes = W_N.sum(axis = 1) - W_N[:, k]
    b = coherence @ other_classes

    # With R[:, n] = vec(outer(D[:, n], X_k[n, :])) and a = vec(A_k):
    #   R.T @ R = (D.T D) * (X_k X_k.T)     (elementwise product)
    #   R.T @ a = column sums of D * (A_k X_k.T)
    # so the (S * n_k) x N matrix R never has to be materialised.
    RtR = gram * (X_k @ X_k.T)
    Rta = np.sum(D_N * (A_k @ X_k.T), axis = 0)

    w = np.array(W_N[:, k], dtype = float)

    for _ in range(MAXITER):

        # Gradient step on 0.5 * ||R w - a||^2 + Lambda3 * b @ w
        tau_0 = w - (RtR @ w - Rta + Lambda3 * b) / sigma
        # Shift onto the hyperplane sum(w) == delta
        tau_1 = tau_0 - (np.sum(tau_0) - delta) / N
        tau_2 = np.maximum(0, tau_1)

        total = np.sum(tau_2)
        if total <= 0:
            # Everything was clipped; keep the last valid iterate instead of
            # dividing by zero.
            break
        # Rescale so the entries sum to delta again
        w = delta * tau_2 / total

    return w

In [33]:
def LatentMatrixLearning(
        X:dict,
        N:int,
        C:int,
        delta:float,
        sigma:float,
        Lambda3:float,
        A:np.array,
        D:np.array,
        W:np.array,
        MAXITER = 100,
):
    """Recompute the first ``C`` columns of the latent matrix.

    Column ``k`` of the result is ``latent_vector_update`` evaluated for class
    ``k``. Every call receives the incoming ``W``, so the columns are updated
    independently of one another. The result has the shape and dtype of ``W``;
    columns with index ``>= C`` stay zero.

    Parameters
    ----------
    X, N, C, delta, sigma, Lambda3, A, D, W, MAXITER :
        forwarded unchanged to ``latent_vector_update``.

    Returns
    -------
    W_hat : array with the shape and dtype of ``W``.
    """

    refreshed = np.zeros_like(W)

    for class_idx in range(C):
        refreshed[:, class_idx] = latent_vector_update(
            X, class_idx, N, C, delta, sigma, Lambda3, A, D, W, MAXITER
        )

    return refreshed

# Full pipeline

In [34]:
def LDL(        
        N:int,
        C:int,
        delta:float,
        sigma:float,
        Lambda1:float,
        Lambda2:float,
        Lambda3:float,
        A:np.array,
        MAXITER = 100
        ):
    """Run the full pipeline: alternate dictionary and latent-matrix learning.

    The dictionary starts as a standard-normal ``(A[0].shape[0], N)`` matrix
    with every column scaled to unit norm. The latent matrix starts as a
    standard-normal ``(N, C)`` matrix with negative entries set to zero. Both
    draws use NumPy's global random state.

    Parameters
    ----------
    N : number of dictionary atoms.
    C : number of classes.
    delta, sigma : forwarded to ``LatentMatrixLearning``.
    Lambda1, Lambda2 : forwarded to ``DictionaryLearning``.
    Lambda3 : forwarded to both learning steps.
    A : per-class data, indexed by class; ``A[0].shape[0]`` sets the atom length.
    MAXITER : number of outer alternations (shown with a tqdm progress bar).

    Returns
    -------
    (D, W) : the learned dictionary and latent matrix.
    """

    def _to_unit_norm(column):
        return column / np.linalg.norm(column)

    # Dictionary initialization: random atoms, each rescaled to unit norm
    atom_length = A[0].shape[0]
    D = np.apply_along_axis(_to_unit_norm, 0, np.random.randn(atom_length, N))

    # Initialize the latent matrix with the non-negative part of a random draw
    W = np.maximum(0, np.random.randn(N, C))

    for _ in tqdm(range(MAXITER)):

        # Dictionary update
        X, D = DictionaryLearning(A, W, D, C, Lambda1, Lambda2, Lambda3)

        # Latent matrix update
        W = LatentMatrixLearning(X, N, C, delta, sigma, Lambda3, A, D, W)

    return D, W     

In [35]:
# Load both MNIST splits and pool them into one set of images and labels
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_data = np.concatenate((x_train, x_test))
y_data = np.concatenate((y_train, y_test))

# Flatten every image into one row vector
x_data = x_data.reshape(len(x_data), -1)

# One entry per digit class, filled in below
mnist_dict = dict.fromkeys(range(10))

# For every digit keep its first 100 images, stored one image per column
for class_label in range(10):
    mnist_dict[class_label] = x_data[np.where(y_data == class_label)[0][:100]].T

In [None]:
# Number of dictionary atoms and number of digit classes
N, C = 1000, 10

D, W = LDL(
    N=N,
    C=C,
    delta=1,
    sigma=0.01,
    Lambda1=0.15,
    Lambda2=0.25,
    Lambda3=0.35,
    A=mnist_dict,
)