This file is a pipeline for running the TVDN algorithm on the AD and Control data and saving the results.

Of course, you should tune some parameters for your data.

I think the most important parameters are

- `paras.lamb` in section 1.2

- `lowCpts` in section 2. 

For the meaning of the parameters of the TVDN algorithm, please refer to our GitHub repo. 

https://github.com/JINhuaqing/TVDN

In [1]:
import sys
sys.path.append("/home/huaqingj/MyResearch/TVDN-AD")

In [2]:
from pyTVDN import TVDNDetect
from pathlib import Path
from scipy.io import loadmat, savemat
import numpy as np
from easydict import EasyDict as edict
import matplotlib
import matplotlib.pyplot as plt
import os
from scipy import signal
from tqdm import tqdm
import pickle
%matplotlib inline

In [3]:
# Work from the project root so that the relative folders below resolve against it.
os.chdir("/home/huaqingj/MyResearch/TVDN-AD")
# Input data folder and output (results) folder, both relative to the project root.
dataDir, resDir = Path("./data"), Path("./results")

In [4]:
# Names of the MAT-files with the recordings.
# Subjects are indexed 0-86 in the AD file and 0-69 in the Ctrl file.
filCtrl = "70Ctrls_before_filter1000.mat"
filAD = "87ADs_before_filter1000.mat"

## Run data with TVDN

### Load the datasets


In [5]:
# Both MAT-files keep the recordings under the "dk10" variable.
# NOTE(review): the first axis is presumably the subject -- confirm against the data.
ADdatasets, Ctrldatasets = (
    loadmat(dataDir/matFil)["dk10"] for matFil in (filAD, filCtrl)
)

### Run TVDN 

#### Parameters for TVDN

**For below parameters, the most important one is `paras.lamb`.** 

It is the smoothing parameter for the B-spline; I suggest tuning it within the range [1e-8, 1e-1].


In [66]:
# All TVDN settings gathered in one place.
paras = edict({
    "Lmin": 200,
    "wh": 10,
    "MaxM": 20,
    "fct": 0.5,
    "r": None,
    "lamb": 1e-4,  # tuning this.
    "T": 2,
    "is_detrend": True,
    "decimateRate": 5,
    "kappa": 3.005,  # this kappa is not important, you will tune it later.
    "downRate": 20,
})

# Sampling frequency of your data after decimation (600 when no decimation is applied).
paras.freq = 600 if paras.decimateRate is None else int(600/paras.decimateRate)

In [67]:
# Grid of candidate kappa values: 1000 evenly spaced points from 1 to 4.
kps = np.linspace(start=1, stop=4, num=1000)

#### Print the data after Bspline (optional).


I suggest plotting the data after the B-spline smoothing to check whether `paras.lamb` is suitable.


In [None]:
# Fit a single subject so that the B-spline-smoothed data can be inspected.
idx = 0
curDat = ADdatasets[idx, :, :]
# To inspect a control subject instead, use:
#curDat = Ctrldatasets[idx, :, :]
demoKwargs = dict(
    Ymat=curDat,
    dataType="MEG",
    saveDir=None,
    showProgress=False,
    fName="demoMEG",
    MaxM=paras.MaxM,
    Lmin=paras.Lmin,
    lamb=paras.lamb,
    kappa=paras.kappa,
    freq=paras.freq,
    r=paras.r,
    T=paras.T,
    is_detrend=paras.is_detrend,
    decimateRate=paras.decimateRate,
    downRate=paras.downRate,
)
detRes = TVDNDetect(**demoKwargs)
detRes.Screening(wh=paras.wh)
    


In [None]:
# Plot every row of the data after the B-spline step.
# If the curves look weird, choose another `paras.lamb`.
for rowSeries in detRes.Xmat:
    plt.plot(rowSeries)

#### Run Kappa tuning algorithm 

This is to tune the kappa. 

Below, I just run the algorithm and save the results; the tuning step itself is done later.

In [10]:
# Settings shared by every sub-series fit. decimateRate is None here because
# the sub-sampling is done by slicing the data below.
tuneKwargsAD = dict(
    dataType="MEG",
    saveDir=None,
    showProgress=False,
    fName="demoMEG",
    MaxM=paras.MaxM,
    Lmin=paras.Lmin,
    lamb=paras.lamb,
    kappa=paras.kappa,
    freq=paras.freq,
    r=paras.r,
    T=paras.T,
    is_detrend=paras.is_detrend,
    decimateRate=None,
    downRate=paras.downRate,
)

for idx in tqdm(range(ADdatasets.shape[0])):
    ADdataset = ADdatasets[idx, :, :]
    detADa = []
    # One fit per sub-sampling offset: columns offset, offset+rate, offset+2*rate, ...
    for offset in range(paras.decimateRate):
        subSeries = ADdataset[:, offset::paras.decimateRate]
        detADt = TVDNDetect(Ymat=subSeries, **tuneKwargsAD)
        detADt.Screening(wh=paras.wh)
        detADt()
        detADt.TuningKappa(kps)
        detADa.append(detADt)

    # One pickle per subject, holding the list of fitted detectors.
    saveFil = f"AD_data_det_{idx}_lamb{paras.lamb:.1E}_decimate{paras.decimateRate:.0f}_tuning.pkl"
    with open(resDir/saveFil, "wb") as f:
        pickle.dump(detADa, f)

100%|██████████| 87/87 [3:22:12<00:00, 139.45s/it]  


In [58]:
# Settings shared by every sub-series fit. decimateRate is None here because
# the sub-sampling is done by slicing the data below.
tuneKwargsCtrl = dict(
    dataType="MEG",
    saveDir=None,
    showProgress=False,
    fName="demoMEG",
    MaxM=paras.MaxM,
    Lmin=paras.Lmin,
    lamb=paras.lamb,
    kappa=paras.kappa,
    freq=paras.freq,
    r=paras.r,
    T=paras.T,
    is_detrend=paras.is_detrend,
    decimateRate=None,
    downRate=paras.downRate,
)

for idx in tqdm(range(Ctrldatasets.shape[0])):
    Ctrldataset = Ctrldatasets[idx, :, :]
    detCa = []
    # One fit per sub-sampling offset: columns offset, offset+rate, offset+2*rate, ...
    for offset in range(paras.decimateRate):
        subSeries = Ctrldataset[:, offset::paras.decimateRate]
        detCt = TVDNDetect(Ymat=subSeries, **tuneKwargsCtrl)
        detCt.Screening(wh=paras.wh)
        detCt()
        detCt.TuningKappa(kps)
        detCa.append(detCt)

    # One pickle per subject, holding the list of fitted detectors.
    saveFil = f"Ctrl_data_det_{idx}_lamb{paras.lamb:.1E}_decimate{paras.decimateRate:.0f}_tuning.pkl"
    with open(resDir/saveFil, "wb") as f:
        pickle.dump(detCa, f)

100%|██████████| 70/70 [2:25:15<00:00, 124.50s/it]  


#### Run Main TVDN without tuning kappa

Here we run the main results.
We can update kappa later, so here I use an arbitrary kappa. 

In [12]:
# Settings shared by all subjects; here TVDNDetect is given the decimation rate itself.
mainKwargsAD = dict(
    dataType="MEG",
    saveDir=None,
    showProgress=False,
    fName="demoMEG",
    MaxM=paras.MaxM,
    Lmin=paras.Lmin,
    lamb=paras.lamb,
    kappa=paras.kappa,
    freq=paras.freq,
    r=paras.r,
    T=paras.T,
    is_detrend=paras.is_detrend,
    decimateRate=paras.decimateRate,
    downRate=paras.downRate,
)

for idx in tqdm(range(ADdatasets.shape[0])):
    ADdataset = ADdatasets[idx, :, :]
    detAD = TVDNDetect(Ymat=ADdataset, **mainKwargsAD)
    detAD.Screening(wh=paras.wh)
    detAD()
    # Run over the whole kappa grid so that kappa can be switched later
    # without refitting (later cells read `numchgs` from the saved object).
    detAD.TuningKappa(kps)

    # One pickle per subject.
    saveFil = f"AD_data_det_{idx}_lamb{paras.lamb:.1E}_decimate{paras.decimateRate:.0f}.pkl"
    with open(resDir/saveFil, "wb") as f:
        pickle.dump(detAD, f)

  0%|          | 0/87 [00:00<?, ?it/s]R[write to console]: 
Attaching package: ‘signal’


R[write to console]: The following objects are masked from ‘package:stats’:

    filter, poly


100%|██████████| 87/87 [43:39<00:00, 30.10s/it]  


In [60]:
# Settings shared by all subjects; here TVDNDetect is given the decimation rate itself.
mainKwargsCtrl = dict(
    dataType="MEG",
    saveDir=None,
    showProgress=False,
    fName="demoMEG",
    MaxM=paras.MaxM,
    Lmin=paras.Lmin,
    lamb=paras.lamb,
    kappa=paras.kappa,
    freq=paras.freq,
    r=paras.r,
    T=paras.T,
    is_detrend=paras.is_detrend,
    decimateRate=paras.decimateRate,
    downRate=paras.downRate,
)

for idx in tqdm(range(Ctrldatasets.shape[0])):
    Ctrldataset = Ctrldatasets[idx, :, :]
    detC = TVDNDetect(Ymat=Ctrldataset, **mainKwargsCtrl)
    detC.Screening(wh=paras.wh)
    detC()
    # Run over the whole kappa grid so that kappa can be switched later
    # without refitting (later cells read `numchgs` from the saved object).
    detC.TuningKappa(kps)

    # One pickle per subject.
    saveFil = f"Ctrl_data_det_{idx}_lamb{paras.lamb:.1E}_decimate{paras.decimateRate:.0f}.pkl"
    with open(resDir/saveFil, "wb") as f:
        pickle.dump(detC, f)

  0%|          | 0/70 [00:00<?, ?it/s]R[write to console]: 
Attaching package: ‘signal’


R[write to console]: The following objects are masked from ‘package:stats’:

    filter, poly


100%|██████████| 70/70 [30:17<00:00, 25.96s/it]


### Save results

**Order of the dataset matters**

I save the results one by one; if you read them with `resDir.glob`, the order of the results is arbitrary. 

So I sort them with the `sorted` function. 

Note, however, that

`sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))` only works while the subject index is the fourth `_`-separated field of the file name, so it may break if you change the names of the files. 

I suggest printing `sortedPs` to check whether the order is what you expect. 

In [None]:
# Gather, for every AD subject, the `numchgs` of each detector saved by the
# kappa-tuning run, and store them in a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5);
# adapt it if you changed `paras.lamb` or `paras.decimateRate`.
ps = list(resDir.glob("AD_*04_decimate5_tuning.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    # Fail loudly instead of naming the output after a stale `dets`
    # left over from a previously executed cell.
    raise FileNotFoundError(f"No AD tuning results matched in {resDir}")

detObjsNum = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        dets = pickle.load(f)
    detObjsNum.append([det.numchgs for det in dets])

# The lamb tag of the output is taken from the last loaded subject.
filName = f"AD_data_lamb{dets[0].paras.lamb:.1E}_tuningNum.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(detObjsNum, f)

In [11]:
# Gather, for every Ctrl subject, the `numchgs` of each detector saved by the
# kappa-tuning run, and store them in a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5);
# adapt it if you changed `paras.lamb` or `paras.decimateRate`.
ps = list(resDir.glob("Ctrl_*04_decimate5_tuning.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    # Fail loudly instead of naming the output after a stale `dets`
    # left over from a previously executed cell.
    raise FileNotFoundError(f"No Ctrl tuning results matched in {resDir}")

detObjsNumC = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        dets = pickle.load(f)
    detObjsNumC.append([det.numchgs for det in dets])

# The lamb tag of the output is taken from the last loaded subject.
filName = f"Ctrl_data_lamb{dets[0].paras.lamb:.1E}_tuningNum.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(detObjsNumC, f)

100%|██████████| 92/92 [12:07<00:00,  7.91s/it]


In [13]:
# Gather the `numchgs` of the main run for every AD subject into a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5);
# adapt it if you changed `paras.lamb` or `paras.decimateRate`.
ps = list(resDir.glob("AD_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    # Fail loudly instead of naming the output after a stale `det`
    # left over from a previously executed cell.
    raise FileNotFoundError(f"No AD results matched in {resDir}")

numchgss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    numchgss.append(det.numchgs)

# The lamb tag of the output is taken from the last loaded subject.
filName = f"AD_data_lamb{det.paras.lamb:.1E}_Num.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(numchgss, f)

100%|██████████| 88/88 [02:38<00:00,  1.80s/it]


In [14]:
# Gather the `numchgs` of the main run for every Ctrl subject into a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5);
# adapt it if you changed `paras.lamb` or `paras.decimateRate`.
ps = list(resDir.glob("Ctrl_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    # Fail loudly instead of naming the output after a stale `det`
    # left over from a previously executed cell.
    raise FileNotFoundError(f"No Ctrl results matched in {resDir}")

numchgss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    numchgss.append(det.numchgs)

# The lamb tag of the output is taken from the last loaded subject.
filName = f"Ctrl_data_lamb{det.paras.lamb:.1E}_Num.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(numchgss, f)

100%|██████████| 92/92 [02:28<00:00,  1.61s/it]


## Find the optimal kappa by Ctrl group

- The kappa tuning relies on your prior belief about the number of switch points,
so you may tune `lowCpts` and `upCpts` if you like. 
- I think `upCpts` does not matter much, though;
you may want to focus on `lowCpts`. 


In [None]:
# Load the Ctrl tuning summary. glob returns matches in arbitrary order, so
# sort them to make the choice reproducible when several files match
# (e.g. results for different lamb values).
tuningFils = sorted(resDir.glob("Ctrl_*_tuningNum.pkl"))
if not tuningFils:
    raise FileNotFoundError(f"No Ctrl_*_tuningNum.pkl file found in {resDir}")
fil = tuningFils[0]
with open(fil, "rb") as f:
    nchgAll = pickle.load(f)

In [None]:
# For every entry of nchgAll, take the mean and the variance along its first axis.
# NOTE(review): the first axis is presumably the sub-sampling offset of the
# tuning run -- confirm against how the file was written.
nchgsMeans, nchgsVars = [], []
for detObjNum in nchgAll:
    numMat = np.array(detObjNum)
    nchgsMeans.append(numMat.mean(axis=0))
    nchgsVars.append(numMat.var(axis=0))

In [None]:
lowCpts = 3 # lower bound of number of switches
upCpts = 19 # upper bound of number of switches
# Must be the same kappa grid that was passed to TuningKappa.
kps = np.linspace(1, 4, 1000)
# Average the per-subject means / variances over all subjects.
nchgsMM = np.array(nchgsMeans).mean(axis=0)
nchgsVarM = np.array(nchgsVars).mean(axis=0)
# Kappa positions whose average number of switches lies in [lowCpts, upCpts].
idxs = np.bitwise_and(nchgsMM >=lowCpts, nchgsMM <=upCpts)
candIdxs = np.where(idxs)[0]
if candIdxs.size == 0:
    raise ValueError(
        f"No kappa gives an average number of switches in [{lowCpts}, {upCpts}]; "
        "adjust lowCpts/upCpts or the kappa grid."
    )
# Among the candidates, take the smallest average variance; on ties keep the
# last one, i.e. the largest kappa. Indexing through candIdxs keeps this
# correct even when the candidates are not one contiguous run.
candVars = nchgsVarM[candIdxs]
optIdx = candIdxs[np.where(candVars == candVars.min())[0][-1]]

In [None]:
# This is the tuned kappa: the entry of the kappa grid at the selected index.
optKp = kps[optIdx]
print(optKp)

## Update the kappa

Please pay attention to the order of your results when loading them.

### The number of switches 

In [None]:
# Re-select the change points of every Ctrl subject for the tuned kappa and
# save them, in subject order, in a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5).
ps = list(resDir.glob("Ctrl_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    raise FileNotFoundError(f"No Ctrl results matched in {resDir}")

# Position of the tuned kappa on the grid, computed once. The nearest grid point
# is used instead of an exact float match, so a hand-set optKp also works.
optKpIdx = int(np.argmin(np.abs(kps - optKp)))

ecptss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    det.UpdateEcpts(det.numchgs[optKpIdx])
    ecptss.append(det.ecpts)

# The lamb tag of the output is taken from the last loaded subject.
filName = f"Ctrl_data_lamb{det.paras.lamb:.1E}_ecpts.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(ecptss, f)

In [None]:
# Re-select the change points of every AD subject for the tuned kappa and
# save them, in subject order, in a single file.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5).
ps = list(resDir.glob("AD_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    raise FileNotFoundError(f"No AD results matched in {resDir}")

# Position of the tuned kappa on the grid, computed once. The nearest grid point
# is used instead of an exact float match, so a hand-set optKp also works.
optKpIdx = int(np.argmin(np.abs(kps - optKp)))

ecptss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    det.UpdateEcpts(det.numchgs[optKpIdx])
    ecptss.append(det.ecpts)

# The lamb tag of the output is taken from the last loaded subject.
filName = f"AD_data_lamb{det.paras.lamb:.1E}_ecpts.pkl"
with open(resDir/filName, "wb") as f:
    pickle.dump(ecptss, f)

### Eigenvalues and modes

In [None]:
def GetFeatures(det):
    """
    Attach the eigenvectors and eigenvalues for the current ecpts to `det`.

    The results are stored on the object (nothing is returned):
      - det.curEigVecs: the first `det.paras.r` columns of `det.midRes.eigVecs`
      - det.curEigVals: a list with the column of `det.RecResCur.LamMs` taken
        at position 0 and at every entry of `det.ecpts`
    """
    # Make sure the current reconstruction result is available.
    if det.RecResCur is None:
        det.GetRecResCur()

    eigVecs = det.midRes.eigVecs[:, :det.paras.r]

    # Column 0 plus one column per change point.
    colPositions = np.concatenate([[0], det.ecpts])
    eigVals = [det.RecResCur.LamMs[:, int(pos)] for pos in colPositions]

    det.curEigVecs = eigVecs
    det.curEigVals = eigVals

In [None]:
# For every Ctrl subject: re-select the change points for the tuned kappa,
# extract the eigenvectors / eigenvalues with GetFeatures, and save both
# collections, in subject order.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5).
ps = list(resDir.glob("Ctrl_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    raise FileNotFoundError(f"No Ctrl results matched in {resDir}")

# Position of the tuned kappa on the grid, computed once. The nearest grid point
# is used instead of an exact float match, so a hand-set optKp also works.
optKpIdx = int(np.argmin(np.abs(kps - optKp)))

eigVecss = []
eigValss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    det.UpdateEcpts(det.numchgs[optKpIdx])
    GetFeatures(det)
    eigVecss.append(det.curEigVecs)
    eigValss.append(det.curEigVals)

# The lamb tag of the outputs is taken from the last loaded subject.
filNameVecs = f"Ctrl_data_lamb{det.paras.lamb:.1E}_eigVecs.pkl"
filNameVals = f"Ctrl_data_lamb{det.paras.lamb:.1E}_eigVals.pkl"
with open(resDir/filNameVecs, "wb") as f:
    pickle.dump(eigVecss, f)
with open(resDir/filNameVals, "wb") as f:
    pickle.dump(eigValss, f)

In [None]:
# For every AD subject: re-select the change points for the tuned kappa,
# extract the eigenvectors / eigenvalues with GetFeatures, and save both
# collections, in subject order.
# NOTE: the pattern is hard-coded (lamb tag ending in "04", decimate5).
ps = list(resDir.glob("AD_*04_decimate5.pkl"))
# The 4th "_"-separated field of the file stem is the subject index.
sortedPs = sorted(ps, key=lambda p:int(p.stem.split("_")[3]))
if not sortedPs:
    raise FileNotFoundError(f"No AD results matched in {resDir}")

# Position of the tuned kappa on the grid, computed once. The nearest grid point
# is used instead of an exact float match, so a hand-set optKp also works.
optKpIdx = int(np.argmin(np.abs(kps - optKp)))

eigVecss = []
eigValss = []
for fil in tqdm(sortedPs):
    with open(fil, "rb") as f:
        det = pickle.load(f)
    det.UpdateEcpts(det.numchgs[optKpIdx])
    GetFeatures(det)
    eigVecss.append(det.curEigVecs)
    eigValss.append(det.curEigVals)

# The lamb tag of the outputs is taken from the last loaded subject.
filNameVecs = f"AD_data_lamb{det.paras.lamb:.1E}_eigVecs.pkl"
filNameVals = f"AD_data_lamb{det.paras.lamb:.1E}_eigVals.pkl"
with open(resDir/filNameVecs, "wb") as f:
    pickle.dump(eigVecss, f)
with open(resDir/filNameVals, "wb") as f:
    pickle.dump(eigValss, f)