In [1]:
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from easydict import EasyDict as edict
from collections import defaultdict as ddict
import torch
import time
from tqdm import tqdm
from scipy import signal
import time
%matplotlib inline

In [2]:
import warnings

# Index of the GPU this notebook should run on.
GPU_ID = 2
if torch.cuda.is_available():
    if torch.cuda.device_count() > GPU_ID:
        torch.cuda.set_device(GPU_ID)
    else:
        # Selecting a non-existent device raises; keep the default CUDA device instead.
        warnings.warn(f"GPU {GPU_ID} not found; using the default CUDA device instead.")

In [107]:
# Work in double precision throughout. New tensors go on the GPU when CUDA is available;
# otherwise fall back to the CPU so the notebook still runs on CPU-only machines.
if torch.cuda.is_available():
    torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
    torch.set_default_tensor_type(torch.DoubleTensor)
class OneStepOpt():
    """
        Optimize one parameter set (mu along frequency, or nu along time) with the other held fixed.

        Each iteration runs four updates: Gamma (first step), rho at the half step (second),
        Theta by group SCAD thresholding (third), and rho (fourth).
        I concatenate the real and image part into one vector.

        Note: which set is optimized is inferred from the length of lastTheta (D == dF -> mu,
        otherwise nu); when dF == dT the two cases cannot be told apart.
    """
    def __init__(self, X, Y, pUinv, fixedParas, lastTheta, **paras):
        """
         Input: 
             Y: A tensor with shape, d x dF x dT
             X: A tensor with shape, d x dF x dT
             pUinv: the first R rows of the inverse of the eigenvector matrix, R x d, complex data
             fixedParas: The fixed parameters when optimizing, real data
                 when updating mu, 2R x dT
                 when updating nu, 2R x dF
             lastTheta: The parameters for optimizing at the last time step, initial parameters, vector of 2R(D-1), real data
             paras:
                 beta: tuning parameter for iteration
                 alp: tuning parameter for iteration
                 rho: a vector of length (D-1)2R, real data
                 lam: the parameter for SCAD
                 a: the parameter for SCAD, > 1+1/beta
                 iterNum:  integer, number of iterations (takes precedence over iterC)
                 iterC: decimal, stopping rule on the change of Gamma between iterations
        """
        self.paras = edict()
        for key in paras.keys():
            self.paras[key] = paras[key]

        self.d, self.dF, self.dT = X.shape
        self.R2, _ = fixedParas.shape
        # lastTheta holds D-1 groups of 2R entries.
        self.D = lastTheta.shape[0] // self.R2 + 1
        # Length of the axis that is summed out: dT when optimizing mu, dF when optimizing nu.
        self.nD = self.dF if self.D == self.dT else self.dT

        self.pUinv = pUinv
        self.X = X.type_as(self.pUinv)
        self.Y = Y.type_as(self.pUinv) # Make them complex

        R = self.R2 // 2
        # Rebuild the complex fixed parameters from their stacked real/imag halves:
        # R x dT when optimizing mu, R x dF when optimizing nu.
        self.fixedParas = torch.complex(fixedParas[:R, :], fixedParas[R:, :]).type_as(self.pUinv)

        self.lastTheta = lastTheta

        self.DiffMatSq = genDiffMatSqfn(self.R2, self.D) # R2D x R2D

        self.newVecGam = None
        self.halfRho = None
        self.rho = self.paras.rho
        self.lam = self.paras.lam
        self.a = self.paras.a
        # Optional stopping controls; a missing key means "not used".
        self.iterNum = self.paras.get("iterNum")
        self.iterC = self.paras.get("iterC")

        # Caches filled lazily by the Gamma updates (the data part does not change across iterations).
        self.leftMat = None
        self.leftMatVec = None
        self.NewXYR2Sum = None

    def obtainNewData(self):
        """
            Project X and Y onto the R components (through pUinv), divide the projected Y by the
            fixed parameters, and store real/imag parts in NewXr/NewXi/NewYr/NewYi (dF x dT x R).
        """
        pY = self.Y.permute(1, 2, 0) # dF x dT x d
        pX = self.X.permute(1, 2, 0)
        cNewX = pX.matmul(self.pUinv.T)  # dF x dT x R
        if self.D == self.dF:
            # Fixed nu is R x dT: its transpose (dT x R) broadcasts over the dF axis.
            cNewY = pY.matmul(self.pUinv.T) * (1/self.fixedParas.T) # dF x dT x R
        else:
            # Fixed mu is R x dF: dF x 1 x R after unsqueeze, broadcasting over the dT axis.
            cNewY = pY.matmul(self.pUinv.T) * (1/self.fixedParas.T).unsqueeze(1) # dF x dT x R
        self.NewXr = cNewX.real
        self.NewYr = cNewY.real
        self.NewXi = cNewX.imag
        self.NewYi = cNewY.imag

    def _optAxis(self):
        """Axis of the dF x dT x 2R data to sum out: 1 (dT) when optimizing mu, 0 (dF) otherwise."""
        return 1 if self.D == self.dF else 0

    def _newXSqR2Sum(self):
        """|X_new|^2 duplicated for the real/imag halves and summed over the non-optimized axis; D x 2R."""
        NewXSq = self.NewXr**2 + self.NewXi**2
        NewXSqR2 = torch.cat((NewXSq, NewXSq), dim=2) # dF x dT x 2R
        return NewXSqR2.sum(axis=self._optAxis()) # dF x 2R or dT x 2R

    def _rightVec(self):
        """Right-hand side of the Gamma update; the data part is cached in self.NewXYR2Sum."""
        if self.NewXYR2Sum is None:
            NewXY1 = self.NewXr * self.NewYr + self.NewXi * self.NewYi
            NewXY2 = -self.NewXi * self.NewYr + self.NewXr * self.NewYi
            NewXYR2 = torch.cat((NewXY1, NewXY2), dim=2) # dF x dT x 2R
            self.NewXYR2Sum = NewXYR2.sum(axis=self._optAxis()) # dF x 2R or dT x 2R
        return self.NewXYR2Sum.flatten()/self.nD + \
                DiffMatTOpt(self.rho + self.paras.beta * self.lastTheta, self.R2)

    @staticmethod
    def _solveLinear(A, b):
        """Solve A x = b for square A and vector b; returns x as a vector."""
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "solve"):
            return linalg.solve(A, b.reshape(-1, 1)).reshape(-1)
        # Legacy API (torch < 1.8); note its (B, A) argument order.
        sol, _ = torch.solve(b.reshape(-1, 1), A)
        return sol.reshape(-1)

    def updateVecGam(self):
        """
            Update the Gamma vector exactly (first step) by solving leftMat @ gam = rightVec.
        """
        if self.leftMat is None:
            leftMat = torch.diag(self._newXSqR2Sum().flatten()).to_sparse()/self.nD + \
                    self.paras.beta * self.DiffMatSq
            # leftMat is the same for every iteration, so densify it only once.
            self.leftMat = leftMat.to_dense()
        # Solving is cheaper and more stable than forming inverse(leftMat).
        self.newVecGam = self._solveLinear(self.leftMat, self._rightVec())

    def updateVecGamApprox(self):
        """
            Not good
            Update the Gamma vector approximately (first step), treating leftMat as diagonal.
        """
        if self.leftMatVec is None:
            beVec = torch.ones(self.R2*self.D) * 2
            beVec[:self.R2] = 1
            beVec[-self.R2:] = 1
            # this step is approximate as leftMat is not an exact diag mat
            self.leftMatVec = self._newXSqR2Sum().flatten()/self.nD + self.paras.beta * beVec
        self.newVecGam = self._rightVec()/self.leftMatVec

    def updateHRho(self):
        """
            Update the vector rho at 1/2 step, second step
        """
        self.halfRho = self.rho - self.paras.alp * self.paras.beta * (DiffMatOpt(self.newVecGam, self.R2) - self.lastTheta)

    def updateTheta(self):
        """
            Update the vector Theta (third step) by group-wise SCAD thresholding.

            Theta is viewed as D-1 groups of 2R entries. Each group is rescaled by
            shrunkNorm / groupNorm, where shrunkNorm is
                groupNorm >  a*lam              -> groupNorm (group kept as is)
                groupNorm <= (1+1/beta)*lam     -> max(groupNorm - lam/beta, 0)
                otherwise                       -> (beta(a-1)*groupNorm - a*lam)/(beta*a - beta - 1)
        """
        beta, lam, a = self.paras.beta, self.lam, self.a
        halfTheta = DiffMatOpt(self.newVecGam, self.R2) - self.halfRho/beta
        tranHTheta = halfTheta.reshape(-1, self.R2) # D-1 x 2R
        hThetaL2Norm = tranHTheta.abs().square().sum(axis=1).sqrt() # D-1

        normC1 = (hThetaL2Norm - lam/beta).clamp(min=0)
        normC2 = (beta * (a - 1) * hThetaL2Norm - a * lam)/(beta * a - beta - 1)
        c1 = (1+1/beta) * lam
        c2 = a * lam

        # The "> c2" case takes precedence over "<= c1" if both hold.
        normCs = torch.where(hThetaL2Norm > c2, hThetaL2Norm,
                             torch.where(hThetaL2Norm <= c1, normC1, normC2))

        # Shrunk norm -> per-group scale factor; zero stays zero (avoids 0/0 for empty groups).
        scale = torch.where(normCs != 0, normCs/hThetaL2Norm, torch.zeros_like(normCs))
        self.lastTheta = (tranHTheta*scale.reshape(-1, 1)).flatten()

    def updateRho(self):
        """
            Update the vector rho, fourth step
        """
        self.rho = self.halfRho - self.paras.alp * self.paras.beta * (DiffMatOpt(self.newVecGam, self.R2) - self.lastTheta)

    def _oneIteration(self, is_approx):
        """Run the four update steps once."""
        if is_approx:
            self.updateVecGamApprox()
        else:
            self.updateVecGam()
        self.updateHRho()
        self.updateTheta()
        self.updateRho()

    def __call__(self, is_approx=False, is_showProg=False):
        """
            Run the optimization; the result is left in self.newVecGam.

            If iterNum is set, run exactly iterNum iterations (with a tqdm bar when is_showProg).
            Otherwise iterate until the L2 change of Gamma between consecutive iterations drops
            below iterC (no iteration cap). When optimizing mu (D == dF), each of the 2R columns
            of the final Gamma (D x 2R) is normalized to unit L2 norm.

            Raises ValueError if neither iterNum nor iterC is given.
        """
        if self.iterNum is None and self.iterC is None:
            raise ValueError("Either iterNum or iterC must be given.")

        self.obtainNewData()

        if self.iterNum is not None:
            steps = range(self.iterNum)
            if is_showProg:
                steps = tqdm(steps)
            for _ in steps:
                self._oneIteration(is_approx)
        else:
            self._oneIteration(is_approx)
            lastVecGam = self.newVecGam
            chDiff = 1e10
            while chDiff >= self.iterC:
                self._oneIteration(is_approx)
                chDiff = torch.norm(self.newVecGam-lastVecGam)
                lastVecGam = self.newVecGam

        if self.D == self.dF:
            newGam = self.newVecGam.reshape(-1, self.R2) # D x 2R
            newGamNorm = newGam.square().sum(axis=0).sqrt() # 2R
            self.newVecGam = (newGam/newGamNorm).flatten()

In [116]:
class TVDNextOpt():
    """
        The class to implement the full procedure of TVDNext method.

        The procedure preprocesses the raw data and estimates the A matrix. It then alternates
        OneStepOpt updates of mu (frequency axis) and nu (time axis) until convergence.
    """
    def __init__(self, rawDat, fs, T, R, hs, **paras):
        """
         Input: 
             rawDat: The raw dataset, tensor of d x dT+1
             fs: The sampling freq of the raw dataset
             T: Time course of the data
             R: The rank of A mat, R << d to reduce the computational burden 
             hs: the bandwidths for the kernel regression when estimating A matrix
             paras:
               For Preprocess:
                 is_detrend: Whether detrend the raw data or not
                 bandsCuts: the critical freqs to use
                 Nord: The order of the filter
                 q: The decimate rate
                 
               For A matrix:
                 downrates: The downrate factors for freq and time, determine how many A(s_i, t_i) matrix to be summed
                 
               For one-step Opt:
                 betas: list of two tuning parameter for iteration
                 alps: list of two tuning parameter for iteration
                 rhos: list of two vectors of length (dF-1)2R and (dT-1)2R, real data
                 lams: list of two parameters for SCAD, for mu and nu (required)
                 As: list of two parameters for SCAD for mu and nu, > 1+1/beta
                 iterNums:  integer or list of two integers, number of iterations for one-step-opt
                 iterCs: decimal or list of two decimals, stopping rule for one-step-opt
               
               For the outer optimization procedure:
                 paraMuInit: The initial value of mu parameters, along the freq axis
                 paraNuInit: The initial value of nu parameters, along the time axis
                 maxIter: Integer, the maximal times of iteration for the outer loop
                 outIterC:  decimal, stopping rule for the outer loop (None: always run maxIter iterations)
        """
        parasDefVs = {
                      "is_detrend": True, "bandsCuts": [[2, 3.5], [4, 7], [8, 12], [13, 30], [30, 80]], 
                      "Nord": None, "q": 10, 
                      "downrates": [1, 10],  "betas":[1, 1], "alps": [1, 1],  "rhos": None,  "lams": None, 
                      "As": [2.7, 2.7],  "iterNums": [1, 10], "iterCs": None, "paraMuInit": None,
                      "paraNuInit": None, "maxIter": 100, "outIterC": None
                    }
        self.paras = edict(parasDefVs)
        for key in paras.keys():
            self.paras[key] = paras[key]

        # iterNums / iterCs may be a single value (see above); then it is used for both mu and nu.
        self.paras.iterNums = self._asPair(self.paras.iterNums)
        self.paras.iterCs = self._asPair(self.paras.iterCs)

        self.rawDat = rawDat
        self.fs, self.T = fs, T
        self.R, self.R2 = R, R*2
        self.hs = hs 

        # Some none definitions
        self.X = self.Y = self.pUinv = None
        self.dF = self.dT = self.D = self.nD = None
        self.lastThetaMu = self.lastThetaNu = None # Vector of 2R(dF-1)/2R(dT-1)
        self.paraMu = self.paraNu = None # matrix of 2R x dF/dT

    @staticmethod
    def _asPair(value):
        """Return a list/tuple unchanged; otherwise (a scalar or None) repeat it as [value, value]."""
        if isinstance(value, (list, tuple)):
            return value
        return [value, value]

    def _PreProcess(self):
        """
        To preprocess the raw dataset, including 
            1. Detrend (only when paras.is_detrend is true), 
            2. Filter under bands
            3. Decimate
        Steps 2-3 are delegated to mat2Tensor (decimation rate paras.q).
        """
        dat = signal.detrend(self.rawDat) if self.paras.is_detrend else self.rawDat
        # NOTE(review): paras.bandsCuts and paras.Nord are not passed to mat2Tensor here;
        # confirm whether its band filtering should use them.
        cDat = mat2Tensor(dat, fs=self.fs, q=self.paras.q)
        # Avoid stride problem when convert numpy to tensor
        self.X = torch.tensor(cDat.X.copy())
        self.Y = torch.tensor(cDat.Y.copy())

    def _estAmat(self):
        """
        Estimate the A matrix by kernel regression. Keep the first R rows of the inverse of its
        eigenvector matrix as self.pUinv.
        """
        _, self.dF, self.dT = self.Y.shape
        # Use the instance's dT; a bare `dT` would silently pick up a notebook global.
        times = np.linspace(0, self.T, self.dT)
        freqs = np.array([np.mean(bandCut) for bandCut in self.paras.bandsCuts])
        self.Amat = GetAmatTorch(self.Y, self.X, times, freqs, self.paras.downrates, self.hs)

        _, eigVecs = np.linalg.eig(self.Amat)
        # NOTE(review): np.linalg.eig does not sort eigenvalues, so the "first R rows" are not
        # ordered by eigenvalue magnitude; confirm this is intended.
        Uinv = np.linalg.inv(eigVecs)
        self.pUinv = torch.tensor(Uinv[:self.R, :])

    def _runOneStep(self, idx, fixedParas, lastTheta):
        """Build and run OneStepOpt with the idx-th tuning parameters (0: mu, 1: nu)."""
        opt = OneStepOpt(X=self.X, Y=self.Y, pUinv=self.pUinv, fixedParas=fixedParas, lastTheta=lastTheta,
                         alp=self.paras.alps[idx], beta=self.paras.betas[idx], lam=self.paras.lams[idx],
                         a=self.paras.As[idx], iterNum=self.paras.iterNums[idx], rho=self.paras.rhos[idx],
                         iterC=self.paras.iterCs[idx])
        opt()
        return opt

    def __call__(self, show_prog=True):
        """
        Run the whole procedure; results are stored in self.paraMu and self.paraNu.

        Preprocessing and A-matrix estimation only run if they have not been done yet. The outer
        loop alternates a mu update (nu fixed) and a nu update (mu fixed). It stops after maxIter
        iterations, or once the summed relative change of mu and nu is <= outIterC.

        Raises ValueError when paras.lams is not given.
        """
        if self.paras.lams is None:
            raise ValueError("lams must be given: the SCAD parameters for mu and nu.")

        if self.X is None:
            self._PreProcess()
        if self.pUinv is None:
            self._estAmat()

        _, self.dF, self.dT = self.X.shape

        self.D = self.dF
        self.nD = int(self.dF*self.dT/self.D)
        # NOTE(review): random initial values are unseeded; set a torch seed beforehand for reproducibility.
        if self.paras.paraMuInit is None:
            self.paras.paraMuInit = torch.rand(self.R2, self.dF)
        if self.paras.paraNuInit is None:
            self.paras.paraNuInit = torch.rand(self.R2, self.dT)
        if self.paras.rhos is None:
            rho1 = torch.ones(self.R2*(self.dF-1))
            rho2 = torch.ones(self.R2*(self.dT-1))
            # NOTE(review): these same rho vectors are handed to every outer iteration
            # (OneStepOpt does not write them back); confirm the reset is intended.
            self.paras.rhos = [rho1, rho2]

        lastMuTheta = DiffMatOpt(colStackFn(self.paras.paraMuInit), self.R2)
        fixedNuMat = self.paras.paraNuInit

        stopLastMuMat = self.paras.paraMuInit
        stopLastNuMat = self.paras.paraNuInit

        for i in range(self.paras.maxIter):
            # Update mu with nu fixed, then nu with the new mu fixed.
            optMu = self._runOneStep(0, fixedNuMat, lastMuTheta)
            fixedMuMat = colStackFn(optMu.newVecGam, self.R2)
            lastNuTheta = DiffMatOpt(colStackFn(fixedNuMat), self.R2)

            optNu = self._runOneStep(1, fixedMuMat, lastNuTheta)
            fixedNuMat = colStackFn(optNu.newVecGam, self.R2)
            lastMuTheta = DiffMatOpt(colStackFn(fixedMuMat), self.R2)

            # Relative change of each parameter matrix since the previous outer iteration.
            chDiffMu = torch.norm(stopLastMuMat-fixedMuMat)/torch.norm(stopLastMuMat)
            chDiffNu = torch.norm(stopLastNuMat-fixedNuMat)/torch.norm(stopLastNuMat)
            chDiffBoth = chDiffMu + chDiffNu

            stopLastMuMat = fixedMuMat
            stopLastNuMat = fixedNuMat
            if show_prog:
                print(f"Current iteration is {i+1}/{self.paras.maxIter}, the change of diff is {chDiffBoth}")
            if self.paras.outIterC is not None and chDiffBoth <= self.paras.outIterC:
                break

        self.paraMu = stopLastMuMat
        self.paraNu = stopLastNuMat
        

In [117]:
dataPath = Path("../data")
# Sort so the chosen file does not depend on the filesystem's listing order.
matFiles = sorted(dataPath.glob("*.mat"))
if not matFiles:
    raise FileNotFoundError(f"No .mat files found in {dataPath.resolve()}")
datF = matFiles[0]

The time course lasts $60$ s, so the sampling frequency is $600$ Hz.

In [118]:
# Read all variables stored in the .mat file, then keep the "DK_timecourse" array.
rawDat = loadmat(file_name=datF)
dat = rawDat["DK_timecourse"]

In [119]:
# Sampling frequency of the raw data (Hz) and rank R of the A-matrix approximation.
fs, R = 600, 4
# Outer-loop stopping threshold and SCAD parameters [mu, nu].
outIterC = 1e-10
lams = [100.0, 100.0]
# Decimation rate, A-matrix downrate factors [freq, time], kernel-regression bandwidths.
q = 100
downrates = [1, 1]
hs = [1, 0.1]
# NOTE(review): the markdown above says the time course is 60 s; confirm T = 6 is an intentional rescaling.
T = 6

In [120]:
fOpt = TVDNextOpt(
    rawDat=dat, fs=fs, T=T, R=R, hs=hs,        # data, rank and kernel bandwidths
    q=q, downrates=downrates,                   # decimation and A-matrix downrates
    lams=lams, outIterC=outIterC, maxIter=10,   # penalties and outer-loop stopping
)

In [121]:
# Run the full procedure; the outer-loop change is printed after every iteration.
fOpt(show_prog=True)

Current iteration is 1/10, the change of diff is 2.962829556484455
Current iteration is 2/10, the change of diff is 0.03441779624658302
Current iteration is 3/10, the change of diff is 0.00048043776138255673
Current iteration is 4/10, the change of diff is 4.165814612636461e-07
Current iteration is 5/10, the change of diff is 4.895290021482425e-10
Current iteration is 6/10, the change of diff is 5.046226601092302e-13
