In [69]:
import numpy as np
import torch

In [209]:
def quantization(T,z,b):
  scale = (torch.max(T)-torch.min(T))/(2**b-1)
  f = lambda x: np.clip(np.round(x/scale)+z,-2**(b-1),2**(b-1)-1)
  T.apply_(f)
  return T

def ADMM(B, U, K, G, rank, epsilon):

    ro = torch.trace(G) / rank

    LL = G + ro * torch.eye(*G.size())
    L = torch.linalg.cholesky(LL)
    LL_inv = torch.cholesky_inverse(L)

    r = torch.inf
    s = torch.inf

    while (r > epsilon) or (s > epsilon):
        B_ = LL_inv @ torch.transpose(K + ro * (B + U), 0, 1)
        B0 = torch.clone(B)

        B = quantization(torch.transpose(B_,0,1) - U,0,8)

        U = U + B - torch.transpose(B_,0,1)

        r = torch.norm(B - torch.transpose(B_,0,1), p='fro')**2 / torch.norm(B, p='fro')**2
        s = torch.norm(B - B0, p='fro')**2 / torch.norm(U, p='fro')**2

    return B, U

In [210]:
X = torch.rand(6, 7)
rank = 4
epsilon = 0.01
A, B = torch.linalg.qr(X)
B = torch.transpose(B, 0, 1)

A_old = torch.zeros_like(A)+1000
B_old = torch.zeros_like(B)+1000

eps = 1
n=0

while torch.max(abs(A_old-A))> eps or torch.max(abs(B_old-B))> eps:

    K = torch.transpose(X, 0, 1) @ A
    G = torch.transpose(A, 0, 1) @ A

    U = torch.rand(*B.size())
    B_old = torch.clone(B)
    B, U = ADMM(B, U, K, G, rank, epsilon)

    K = X @ B
    G = torch.transpose(B, 0, 1) @ B

    U = torch.rand(*A.size())
    A_old = torch.clone(A)
    A, U = ADMM(A, U, K, G, rank, epsilon)

In [214]:
A@B.T

tensor([[ 109.,   22.,   73.,   36.,   28.,  -70.,  218.],
        [ 121.,   34.,  -25.,   28.,   98.,  166., -110.],
        [ 176.,  -97.,  -32.,  -77.,  187.,   72.,   -6.],
        [ -20.,  -71.,  -67.,  -61.,  407.,  196.,   49.],
        [ 110.,  297.,  235.,  271., -131.,   82.,   98.],
        [  91.,  302.,   51.,  260., -135.,   58.,  158.]])

In [213]:
X

tensor([[0.8162, 0.0863, 0.0128, 0.1539, 0.7266, 0.3999, 0.8753],
        [0.1650, 0.7647, 0.0542, 0.2135, 0.5394, 0.7972, 0.2413],
        [0.1260, 0.2097, 0.0328, 0.0626, 0.7054, 0.2129, 0.8634],
        [0.0756, 0.3062, 0.0955, 0.0852, 0.4647, 0.3032, 0.2865],
        [0.1262, 0.3610, 0.0813, 0.0422, 0.1737, 0.1830, 0.2195],
        [0.0890, 0.3659, 0.9238, 0.3759, 0.5191, 0.1736, 0.8454]])