In [1]:
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from easydict import EasyDict as edict
from collections import defaultdict as ddict
import torch
import time
from tqdm.autonotebook import tqdm
from scipy import signal
import time
%matplotlib inline

In [2]:
from cUtils import *

In [3]:
class OneStepOpt():
    """
        One block update (the mu or the nu parameters) of a group-SCAD penalised
        least-squares fit, solved with ADMM-style iterations
        (gamma step -> half dual step -> theta step -> dual step).

        The complex parameters are handled as real vectors: for each of the D
        positions the R real parts are followed by the R imaginary parts, so a
        gamma vector has length 2R*D and a difference (theta) vector 2R*(D-1).
    """
    def __init__(self, X, Y, pUinv, fixedParas, lastTheta, **paras):
        """
         Input:
             Y: A tensor with shape, d x dF x dT
             X: A tensor with shape, d x dF x dT
             pUinv: the first R rows of the inverse of the eigenvector matrix, R x d, complex data
             fixedParas: the parameters held fixed during this update, real data
                 when updating mu, 2R x dT
                 when updating nu, 2R x dF
             lastTheta: initial value of the difference parameters Theta, vector of length 2R(D-1), real data
             paras:
                 beta: tuning (penalty) parameter for the iterations
                 alp: step-size parameter for the dual updates
                 rho: a vector of length (D-1)2R, real data (defaults to ones)
                 lam: the parameter for SCAD
                 a: the parameter for SCAD, must be > 1+1/beta
                 iterNum: integer, maximal number of iterations
                 iterC: decimal, stopping rule on the relative change of gamma
                 eps: decimal, stopping rule for the conjugate gradient method
        """

        parasDefVs = {"a": 2.7,  "beta": 1, "alp": 0.9,  "rho": None,  "lam": 1e2,
                      "iterNum": 100, "iterC": 1e-4, "eps": 1e-6}

        self.paras = edict(parasDefVs)
        for key in paras.keys():
            self.paras[key] = paras[key]

        self.d, self.dF, self.dT = X.shape
        self.R2, _ = fixedParas.shape
        # D = number of positions of the parameter being optimised (dF for mu, dT for nu).
        self.D = int(lastTheta.shape[0]/self.R2+1)
        # Number of observations summed over in the least-squares term (the other axis).
        self.nD = self.dF if self.D == self.dT else self.dT

        self.pUinv = pUinv
        self.X = X.type_as(self.pUinv)
        self.Y = Y.type_as(self.pUinv)  # Make them complex

        R = int(self.R2/2)
        if self.paras.rho is None:
            self.paras.rho = torch.ones(self.R2*(self.D-1))

        # Rebuild the complex fixed parameters from their stacked real/imaginary halves.
        self.fixedParas = torch.complex(fixedParas[:R, :], fixedParas[R:, :]).type_as(self.pUinv)  # R x nD

        self.lastTheta = lastTheta

        self.newVecGam = None
        self.newVecGamStd = None
        self.halfRho = None
        self.rho = self.paras.rho
        self.lam = self.paras.lam
        self.a = self.paras.a
        self.iterNum = self.paras.iterNum
        self.iterC = self.paras.iterC

        self.leftMat = None          # dense-solver system matrix (updateVecGam), built lazily
        self.leftMatVec = None
        self.leftMatVecP1 = None     # diagonal part of the CG operator (updateVecGamConGra), built lazily
        self.NewXYR2Sum = None       # cached right-hand-side data term

    def obtainNewData(self):
        """
            Project X and Y onto the R directions given by pUinv, divide the projected
            Y by the fixed parameters, and store the real and imaginary parts as
            NewXr, NewXi, NewYr, NewYi (each dF x dT x R).
        """
        pY = self.Y.permute(1, 2, 0)  # dF x dT x d
        pX = self.X.permute(1, 2, 0)
        cNewX = pX.matmul(self.pUinv.T)  # dF x dT x R
        if self.D == self.dF:
            # Per the constructor docstring fixedParas is R x dT here; its transpose broadcasts over dF.
            cNewY = pY.matmul(self.pUinv.T) * (1/self.fixedParas.T)  # dF x dT x R
        else:
            # fixedParas is R x dF here; insert a dT axis so it broadcasts.
            cNewY = pY.matmul(self.pUinv.T) * (1/self.fixedParas.T).unsqueeze(1)  # dF x dT x R
        self.NewXr = cNewX.real
        self.NewYr = cNewY.real
        self.NewXi = cNewX.imag
        self.NewYi = cNewY.imag

    def _AmatOpt(self, vec):
        """
            Apply the gamma-step operator A = diag(leftMatVecP1) + beta * DiffMatT(DiffMat(.))
            to `vec` without forming A. Conjugate gradient needs A symmetric, which
            assumes DiffMatTOpt is the adjoint of DiffMatOpt (TODO confirm in cUtils).
        """
        rVec1 = self.leftMatVecP1 * vec
        rVec2 = self.paras.beta * DiffMatTOpt(DiffMatOpt(vec, self.R2), self.R2)
        return rVec1 + rVec2

    def _ConjuGrad(self, vec, maxIter=1000):
        """
            Solve A x = vec with the conjugate gradient method, A given by _AmatOpt.

            Args:
                vec: right-hand side, 1-D real tensor.
                maxIter: maximal number of CG iterations.
            Returns:
                x, stopping once the residual norm is <= paras.eps.
        """
        eps = self.paras.eps

        xk = torch.zeros_like(vec)
        rk = vec - self._AmatOpt(xk)
        if torch.norm(rk) <= eps:
            return xk

        pk = rk
        rsOld = torch.sum(rk**2)
        for _ in range(maxIter):
            # A·p is the expensive part: evaluate it once and reuse it for alpha and the residual.
            Apk = self._AmatOpt(pk)
            alpk = rsOld / torch.sum(pk * Apk)
            xk = xk + alpk * pk
            rk = rk - alpk * Apk

            if torch.norm(rk) <= eps:
                break

            rsNew = torch.sum(rk**2)
            pk = rk + (rsNew / rsOld) * pk
            rsOld = rsNew

        return xk

    def updateVecGam(self):
        """
            Update the Gamma vector (first step) with a dense linear solve:
                (diag(|X|^2)/nD + beta * DiffMatSq) gamma = XY/nD + DiffMatT(rho + beta*theta)
        """
        optAxis = 1 if self.D == self.dF else 0

        if self.leftMat is None:
            # The system matrix only depends on the data, so build it once.
            self.DiffMatSq = genDiffMatSqfn(self.R2, self.D)  # R2D x R2D
            NewXSq = self.NewXr**2 + self.NewXi**2
            NewXSqR2 = torch.cat((NewXSq, NewXSq), dim=2)  # dF x dT x 2R
            NewXSqR2Sum = NewXSqR2.sum(axis=optAxis)  # dF x 2R or dT x 2R
            self.leftMat = torch.diag(NewXSqR2Sum.flatten()).to_sparse()/self.nD + \
                    self.paras.beta * self.DiffMatSq

        if self.NewXYR2Sum is None:
            NewXY1 = self.NewXr * self.NewYr + self.NewXi * self.NewYi
            NewXY2 = -self.NewXi * self.NewYr + self.NewXr * self.NewYi
            NewXYR2 = torch.cat((NewXY1, NewXY2), dim=2)  # dF x dT x 2R
            self.NewXYR2Sum = NewXYR2.sum(axis=optAxis)  # dF x 2R or dT x 2R
        rightVec = self.NewXYR2Sum.flatten()/self.nD + \
                    DiffMatTOpt(self.rho + self.paras.beta * self.lastTheta, self.R2)

        # Solve instead of inverting. torch.solve is deprecated/removed; prefer
        # torch.linalg.solve (note the swapped argument order) when available.
        leftDense = self.leftMat.to_dense()
        rhs = rightVec.reshape(-1, 1)
        if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
            sol = torch.linalg.solve(leftDense, rhs)
        else:
            sol, _ = torch.solve(rhs, leftDense)
        self.newVecGam = sol.reshape(-1)

    def updateVecGamConGra(self):
        """
            Update the Gamma vector (first step) with the conjugate gradient method;
            solves the same system as updateVecGam without forming the matrix.
        """
        optAxis = 1 if self.D == self.dF else 0

        # Guard on the attribute this method actually caches (the original checked
        # leftMat, which is never set here, so it recomputed every sweep and broke
        # if updateVecGam had set leftMat first).
        if getattr(self, "leftMatVecP1", None) is None:
            NewXSq = self.NewXr**2 + self.NewXi**2
            NewXSqR2 = torch.cat((NewXSq, NewXSq), dim=2)  # dF x dT x 2R
            NewXSqR2Sum = NewXSqR2.sum(axis=optAxis)  # dF x 2R or dT x 2R
            self.leftMatVecP1 = NewXSqR2Sum.flatten()/self.nD

        if self.NewXYR2Sum is None:
            NewXY1 = self.NewXr * self.NewYr + self.NewXi * self.NewYi
            NewXY2 = -self.NewXi * self.NewYr + self.NewXr * self.NewYi
            NewXYR2 = torch.cat((NewXY1, NewXY2), dim=2)  # dF x dT x 2R
            self.NewXYR2Sum = NewXYR2.sum(axis=optAxis)  # dF x 2R or dT x 2R
        rightVec = self.NewXYR2Sum.flatten()/self.nD + \
                    DiffMatTOpt(self.rho + self.paras.beta * self.lastTheta, self.R2)

        self.newVecGam = self._ConjuGrad(rightVec)

    def updateHRho(self):
        """
            Update the dual vector rho at the 1/2 step (second step).
        """
        halfRho = self.rho - self.paras.alp * self.paras.beta * (DiffMatOpt(self.newVecGam, self.R2) - self.lastTheta)
        self.halfRho = halfRho

    def updateTheta(self):
        """
            Update the vector Theta (third step): group-wise SCAD thresholding of
            each 2R-length difference block of DiffMat(gamma) - halfRho/beta.
        """
        beta = self.paras.beta
        halfTheta = DiffMatOpt(self.newVecGam, self.R2) - self.halfRho/beta
        tranHTheta = halfTheta.reshape(-1, self.R2)  # D-1 x 2R
        hThetaL2Norm = tranHTheta.abs().square().sum(axis=1).sqrt()  # D-1

        c1 = (1+1/beta) * self.lam
        c2 = self.a * self.lam

        # Small norms: soft-threshold by lam/beta (clamped at zero).
        normC1 = (hThetaL2Norm - self.lam/beta).clamp(min=0)
        # Middle region: SCAD interpolation; denominator > 0 requires a > 1 + 1/beta.
        normC2 = (beta * (self.a - 1) * hThetaL2Norm - self.a * self.lam)/(beta * self.a - beta - 1)

        # Large norms (> c2) are left unchanged and take precedence over the c1 region.
        normCs = torch.where(hThetaL2Norm > c2, hThetaL2Norm,
                             torch.where(hThetaL2Norm <= c1, normC1, normC2))

        # Rescale each group to its new norm; fully shrunk groups stay exactly zero.
        scale = torch.zeros_like(normCs)
        nonZero = normCs != 0
        scale[nonZero] = normCs[nonZero] / hThetaL2Norm[nonZero]

        self.lastTheta = (tranHTheta * scale.reshape(-1, 1)).flatten()

    def updateRho(self):
        """
            Update the dual vector rho (fourth step).
        """
        newRho = self.halfRho - self.paras.alp * self.paras.beta * (DiffMatOpt(self.newVecGam, self.R2) - self.lastTheta)
        self.rho = newRho

    def _admmSweep(self):
        """Run the four update steps once, in order."""
        self.updateVecGamConGra()
        self.updateHRho()
        self.updateTheta()
        self.updateRho()

    def __call__(self, is_showProg=False, leave=False):
        """
            Run the optimisation: one warm-up sweep, then up to iterNum sweeps that
            stop once the relative change of gamma drops below iterC.

            Args:
                is_showProg: show a tqdm progress bar for the inner loop.
                leave: keep the progress bar after it finishes.

            Side effects: sets newVecGam, lastTheta, halfRho, rho and, when
            D == dF, newVecGamStd (gamma scaled so each of the R complex columns
            has unit norm).
        """
        self.obtainNewData()
        if self.iterC is None:
            self.iterC = 0

        chDiff = torch.tensor(1e10)
        # Warm-up sweep so there is a previous gamma to compare against.
        self._admmSweep()
        lastVecGam = self.newVecGam
        # A disabled bar is a no-op, so one loop serves both display modes.
        with tqdm(total=self.iterNum, leave=leave, disable=not is_showProg) as pbar:
            for i in range(self.iterNum):
                pbar.set_description(f"Inner Loop: The chdiff is {chDiff.item():.3e}.")
                pbar.update(1)
                self._admmSweep()
                chDiff = torch.norm(self.newVecGam-lastVecGam)/torch.norm(lastVecGam)
                lastVecGam = self.newVecGam
                if chDiff < self.iterC:
                    # Fill the bar to exactly 100% instead of overshooting the total.
                    pbar.update(self.iterNum - i - 1)
                    break

        if self.D == self.dF:
            R = int(self.R2/2)
            newGam = self.newVecGam.reshape(-1, self.R2)  # D x 2R
            newGamNorm2 = newGam.square().sum(axis=0)  # 2R
            # Combine real and imaginary parts into one norm per complex column.
            newGamNorm = torch.sqrt(newGamNorm2[:R] + newGamNorm2[R:])
            newGamNorm = torch.cat([newGamNorm, newGamNorm])
            newGam = newGam/newGamNorm
            self.newVecGamStd = newGam.flatten()

In [4]:
# Pin computations to a specific GPU, but only when that device actually exists:
# an unconditional set_device(2) raises on machines with fewer than 3 GPUs.
GPU_INDEX = 2
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    torch.cuda.set_device(GPU_INDEX)

In [5]:
dataPath = Path("../data")
# glob() order is filesystem-dependent; sort so the chosen file is deterministic,
# and fail with a clear message instead of an IndexError when nothing matches.
matFiles = sorted(dataPath.glob("*.mat"))
if not matFiles:
    raise FileNotFoundError(f"No .mat files found in {dataPath.resolve()}")
datF = matFiles[0]

In [6]:
# Read every variable stored in the .mat file, then keep the DK time course.
rawDat = loadmat(file_name=datF)
dat = rawDat["DK_timecourse"]

In [21]:
# Tuning constants for TVDNextOpt. Two-element lists carry one value per stage
# (presumably the two alternating parameter updates -- TODO confirm in cUtils).
fs = 600          # presumably the sampling rate in Hz
T = 6
Rn = 5
q = 1
outIterC = 1e-5   # outer-loop stopping threshold

lams = [500.0, 100.0]
betas = [10] * 2
hs = [0.1] * 2
downrates = [1] * 2
iterNums = [100] * 2
iterCs = [0.0001, 0.001]

In [22]:
# Outer optimiser, configured from the constants defined in the config cell.
fOpt = TVDNextOpt(
    rawDat=dat,
    fs=fs,
    T=T,
    Rn=Rn,
    q=q,
    hs=hs,
    downrates=downrates,
    lams=lams,
    betas=betas,
    iterNums=iterNums,
    iterCs=iterCs,
    outIterC=outIterC,
    maxIter=1000,
)

In [23]:
# Run TVDNextOpt's two private preparation stages, in this order.
# NOTE(review): judging by the next cell, these presumably populate fOpt.X, fOpt.Y,
# fOpt.pUinv and fOpt.R2 (names suggest preprocessing, then estimating an "A" matrix)
# -- confirm in cUtils, and consider exposing a public method instead of calling privates.
fOpt._PreProcess()
fOpt._estAmat()

In [24]:
# Random initial values for the mu (per-frequency) and nu (per-time) parameters.
# A dedicated, seeded generator makes the initialisation reproducible without
# touching torch's global RNG state.
INIT_SEED = 0
# Do not unpack into `_`: in IPython that shadows the last-output cache variable.
_nChan, fOpt.dF, fOpt.dT = fOpt.X.shape
initGen = torch.Generator().manual_seed(INIT_SEED)
fOpt.paras.paraMuInit = torch.rand(fOpt.R2, fOpt.dF, generator=initGen)
fOpt.paras.paraNuInit = torch.rand(fOpt.R2, fOpt.dT, generator=initGen)
lastNuTheta = DiffMatOpt(colStackFn(fOpt.paras.paraNuInit), fOpt.R2)
fixedMuMat = fOpt.paras.paraMuInit

In [25]:
# Inner solver for the nu parameters, with mu held fixed at its initial value.
# rho holds 2R entries for each of the dT-1 consecutive differences (presumably
# the same length as lastNuTheta -- confirm).
iterNum = 1000
lam = 1
rho = torch.ones((fOpt.dT - 1) * fOpt.R2)
optNu1 = OneStepOpt(
    X=fOpt.X,
    Y=fOpt.Y,
    pUinv=fOpt.pUinv,
    fixedParas=fixedMuMat,
    lastTheta=lastNuTheta,
    alp=0.9,
    beta=1,
    lam=lam,
    a=2.7,
    iterNum=iterNum,
    rho=rho,
    iterC=1e-5,
    eps=1e-6,
)

In [26]:
# Run the inner solver, keeping its progress bar visible afterwards.
optNu1(is_showProg=True, leave=True)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [29]:
# Inspect the fitted gamma vector; with R2 as second argument colStackFn presumably
# reshapes the flat 2R*D vector back into a matrix (TODO confirm in cUtils).
colStackFn(optNu1.newVecGam, optNu1.R2)

tensor([[-2.7399e-02, -1.9078e-01, -4.9737e-01,  ...,  1.0160e+00,
          1.3673e+00,  1.4837e+00],
        [-1.7804e-02, -2.2225e-01, -7.2426e-01,  ...,  6.9845e-01,
          6.7883e-01,  7.0824e-01],
        [-4.1351e-05,  7.8151e-06,  1.5950e-06,  ...,  1.4594e+00,
          1.3090e+00,  1.1016e+00],
        ...,
        [ 5.3989e-03,  6.7355e-02,  2.1943e-01,  ..., -9.1689e-01,
         -9.5823e-01, -9.9974e-01],
        [ 8.8812e-05, -1.6764e-05, -3.4021e-06,  ..., -3.7566e+00,
         -3.3763e+00, -2.8216e+00],
        [ 2.6780e-01,  2.8097e+00,  6.2698e+00,  ..., -8.2912e-01,
         -8.2856e-01, -6.5083e-01]])

In [30]:
# Theta after the final sweep: OneStepOpt.updateTheta overwrites lastTheta each
# iteration with the group-SCAD-thresholded differences.
optNu1.lastTheta

tensor([-1.6340e-01, -2.0449e-01,  5.2005e-05,  ..., -4.1522e-02,
         5.5475e-01,  1.7776e-01])