In [1]:
# Training code for Learned AltGDmin
# Implemented by Silpa Babu
# Date: Feb 23, 2024
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import numpy as np
import scipy.io as sio
import scipy.sparse.linalg as lina
import time

## ================Preparations====================
#device = torch.device('cuda:0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
datatype = torch.float64

## ================Parameters======================
r 				=5		# underlying rank
n 				= 200	# size (num. of rows)
q 				= 200		# size (num. of columns)
m         = 60    # size(number of measurements)
step_initial 	= 0.5		# initial value of step size (eta in the paper)
maxIt 			= 100	# num. of layers you want to train
thr_initial = 0.4
## =============Generate Data y_k = A_k U b_k=============
def generate_problem(r,n,q):
    U0_t, _ 		= torch.qr(torch.randn(n,r,dtype = datatype)/math.sqrt(n))
    B0_t 		    = torch.randn(r,q,dtype = datatype)/math.sqrt(q)
    A           = torch.randn(m,n,q,dtype = datatype)/math.sqrt(m)
    noise       = torch.randn(n,q,dtype = datatype)/math.sqrt(n * q)
    Y0_t        = torch.zeros(m, q,dtype = datatype)
    for k in range(q):
      Y0_t[:,k]    = torch.matmul(A[:,:,k], torch.matmul(U0_t,B0_t[:,k]))
    return U0_t, B0_t, Y0_t, A

## ===================LRPCA model===================
class MatNet(nn.Module):
  # Define parameters to be learned
  def __init__(self):
    super(type(self),self).__init__()
    #self.thr 		= [nn.Parameter(Variable(torch.tensor(thr_initial, dtype=datatype), requires_grad=True)) for t in range( maxIt)]
    self.step 		= [nn.Parameter(Variable(torch.tensor(step_initial, dtype=datatype), requires_grad=True)) for t in range( maxIt)]
    self.thr 		= [nn.Parameter(torch.tensor(thr_initial, dtype=datatype), requires_grad=True) ]
    #self.thr        = [nn.Parameter(Variable(torch.tensor(thr_initial, dtype=datatype), requires_grad=True))]



  def lowrank(self, X0, threshold):
    X0 = X0.clone()
    Ut, St, Vt = torch.linalg.svd(X0)
    thres = threshold * St[0]
    thres = thres.unsqueeze(-1)
    St = F.relu(St - thres)
    St = St.float()

    Xinit = torch.matmul(torch.matmul(Ut, torch.diag(St)), torch.conj(Vt))
    Ut, St, Vt = torch.svd_lowrank(Xinit, niter = 4)
    return Ut.double()


  def forward(self, Y0_t, U0_t, B0_t, A, num_l):
    # Initialization
    n , r0 = U0_t.size()
    r0, q = B0_t.size()
    X0        = torch.zeros(n, q)
    for k in range(q):
      X0[:,k]    = torch.matmul(A[:,:,k].t(), Y0_t[:,k])
    U_t = self.lowrank(X0, self.thr[0])
    n , r = U_t.size()
    B_t        = torch.zeros(r, q, dtype = datatype)
    for k in range(q):
      B_t[:,k] = torch.matmul(torch.linalg.pinv(torch.matmul(A[:,:,k], U_t)),Y0_t[:,k]) # Column-wise Least Squares
    ## Main Loop
    for t in range(1,num_l):
      E_t = torch.zeros(n, r, dtype = datatype)
      for k in range(q):
        E_t = E_t + torch.matmul(A[:,:,k].t(),torch.matmul((Y0_t[:,k] - torch.matmul(A[:,:,k], torch.matmul(U_t, B_t[:,k]))).unsqueeze(1),B_t[:,k].unsqueeze(1).t())) #  Compute Gradient
      Unew, _ = torch.qr(U_t + self.step[t] * (E_t)) # Projected Gradient Descent
      Bnew        = torch.zeros(r, q, dtype = datatype)
      for k in range(q):
        Bnew[:,k] = torch.matmul(torch.linalg.pinv(torch.matmul(A[:,:,k], Unew)),Y0_t[:,k]) # Column-wise Least Squares
      U_t = Unew
      B_t = Bnew
    Y_t = torch.zeros(m, q,dtype = datatype)
    for k in range(q):
      Y_t[:,k]    = torch.matmul(A[:,:,k], torch.matmul(U_t,B_t[:,k]))
    loss = (Y_t - Y0_t).norm()

    return loss, U_t, B_t



################################################################################# Above Understood #######################################################################################
  def EnableSingleLayer(self,en_l):
    self.thr[0].requires_grad = True
    for t in range(maxIt):
      self.step[t].requires_grad = False
    self.step[en_l].requires_grad = True
    #self.thr[en_l].requires_grad = True
  def EnableLayers(self, num_l):
    self.thr[0].requires_grad = True
    for t in range(num_l):
      self.step[t].requires_grad = True
      #self.thr[t].requires_grad = True
    for t in range(num_l,maxIt):
      self.step[t].requires_grad = False
      #self.thr[t].requires_grad = False


## =================Training Scripts======================
Nepoches_pre 	= 4
Nepoches_full 	= 4
lr_fac 			= 1.0															# basic learning rate

net = MatNet()
optimizers = []
optimizer = optim.Adam({net.thr[0]},lr = lr_fac * 0.01)
optimizers.append(optimizer)
for i in range(1,maxIt):
  optimizer = optim.Adam({net.step[i]},lr = lr_fac * 0.1)	# optimizer for each layer
  #optimizer.add_param_group({'params': [net.thr[i]], 'lr': lr_fac * 0.01})
  optimizers.append(optimizer)

## =================Layerwise Training======================
start = time.time()

for stage in range(1,maxIt):														# in k-th stage, we train the k-th layer

	## Pre-training: only train the k-th layer
	print('Layer ',stage,', Pre-training ======================')
	if(stage > 6):
		Nepoches_full = 2
	for epoch in range(Nepoches_pre):
		for i in range(maxIt):
			optimizers[i].zero_grad()


		U0_t,B0_t,Y0_t, A = generate_problem(r,n,q)
		net.EnableSingleLayer(stage)
		loss, U_t, B_t = net(Y0_t,  U0_t, B0_t,A, stage+1)
		loss.backward()

		optimizers[stage].step()




		if epoch % 2 == 0:
			print("epoch: " + str(epoch), "\t loss: " + str(loss.item()))

	# Full-training: train 0~k th layers
	print('Layer ',stage,', Full-training =====================')
	if stage == 0:
		continue

	for epoch in range(Nepoches_full):
		for i in range(maxIt):
			optimizers[i].zero_grad()

		U0_t,B0_t,Y0_t,A = generate_problem(r,n,q)
		net.EnableLayers(stage+1)
		loss, U_t, B_t = net(Y0_t,  U0_t, B0_t, A, stage+1)
		loss.backward()

		for i in range(stage+1):
			optimizers[i].step()

		if epoch % 2 == 0:
			print("epoch: " + str(epoch), "\t loss: " + str(loss.item()))


end = time.time()
print("Training end. Time: " + str(end - start))

## =====================Save model to .mat file ========================
result_stp = np.zeros((maxIt,))
result_thr = np.zeros((maxIt,))
result_thr 	= net.thr[0].data.cpu().numpy()
for i in range(maxIt):
  result_stp[i] 	= net.step[i].data.cpu().numpy()


spath = 'LRPCA_alpha'+'.mat'
sio.savemat(spath, {'ths':result_thr, 'step':result_stp})














The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at ../aten/src/ATen/native/BatchLinearAlgebra.cpp:2426.)
  U0_t, _ 		= torch.qr(torch.randn(n,r,dtype = datatype)/math.sqrt(n))


epoch: 0 	 loss: 0.46746613581645446
epoch: 2 	 loss: 0.374043391828708
epoch: 0 	 loss: 0.3365417842991157
epoch: 2 	 loss: 0.2546720950988543
epoch: 0 	 loss: 0.15843484035930547
epoch: 2 	 loss: 0.1342793161305323
epoch: 0 	 loss: 0.1281863899488382
epoch: 2 	 loss: 0.15272693914903157
epoch: 0 	 loss: 0.06620876698806762
epoch: 2 	 loss: 0.06748732717566594
epoch: 0 	 loss: 0.05734098622509836
epoch: 2 	 loss: 0.0773443560418312
epoch: 0 	 loss: 0.03921488220831694
epoch: 2 	 loss: 0.03300717673873461
epoch: 0 	 loss: 0.031142103180006653
epoch: 2 	 loss: 0.03685174565870248
epoch: 0 	 loss: 0.022121869839748592
epoch: 2 	 loss: 0.017249657282639882
epoch: 0 	 loss: 0.018910425988002445
epoch: 2 	 loss: 0.02395444846759116
epoch: 0 	 loss: 0.010282935276431182
epoch: 2 	 loss: 0.010886491809757915
epoch: 0 	 loss: 0.007586936830590376
epoch: 2 	 loss: 0.008053236827548122
epoch: 0 	 loss: 0.005997746262848515
epoch: 2 	 loss: 0.0064507351359293
epoch: 0 	 loss: 0.005762983499605485

KeyboardInterrupt: 

In [None]:
result_thr

array(0.51669041)

In [None]:
result_stp


array([0.5       , 1.93514147, 0.80864945, 0.60448286, 0.88496346,
       1.43840961, 0.59636465, 2.1705129 , 1.76584915, 0.40185209,
       0.39694755, 0.46756304, 2.16777527, 2.23668645, 2.07449664,
       1.83732582, 0.37635988, 1.4958818 , 0.6621005 , 0.01929649,
       0.32267303, 0.25271242, 1.19656133, 1.94216878, 1.82703888,
       1.07353684, 0.98512214, 0.53016262, 0.45006827, 0.7629641 ,
       0.71776654, 0.60472756, 0.84449897, 0.80453575, 0.82170544,
       0.72900139, 0.77497536, 0.68617869, 0.76580305, 0.61025142,
       0.59471921, 0.56268884, 0.55175302, 0.55308136, 0.75027799,
       0.52572224, 0.52085304, 0.54441848, 0.52629019, 0.50931355,
       0.51974804, 0.50695936, 0.50554592, 0.50342293, 0.50486421,
       0.50266811, 0.50171058, 0.5043831 , 0.50129951, 0.50122675,
       0.50130459, 0.50115624, 0.50119468, 0.50092603, 0.50119845,
       0.50062525, 0.5004605 , 0.50035924, 0.5003906 , 0.50030337,
       0.50033208, 0.500334  , 0.50019408, 0.50018676, 0.50017