In [1]:
from utils import simulate_clsna, preprocess, ClsnaModel
import numpy as np
import torch
from scipy.linalg import orthogonal_procrustes
import csv   

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"

In [3]:
def train(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N,index=None, fixed=None):
    t_index=torch.arange(start=0,end=N*T,device=device,requires_grad=False)
    optimizer.zero_grad()
    loss = model.loss(device=device,label=label,persist=persist,sample_edge=combination_N,T_index=t_index,ss=SIGMA,tt=TAU)
    loss.backward()
    model.para.grad = 0.1*((model.para.grad>0).bool().float()-0.5)
    optimizer.step()
    if index is not None:
        with torch.no_grad():
            model.para[index[0],index[1]] = fixed
    return loss.item()

In [4]:
def run(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N,index=None,fixed=None):
    for epoch in range(1,2000):
        loss = train(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N,index,fixed)
    return loss

# Step 1

In [5]:
def step1(device,N,T,ar_pair,Aw,Ab,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N):
    model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=3).to(device)
    
    optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])
    
    run(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N)
    PCA_p = torch.pca_lowrank(model.z.cpu())[2][:,[0,1]]
    init_z = (model.z.cpu().detach()@PCA_p).detach().numpy()
    init_para = model.para.detach().cpu().numpy()
    return init_z, init_para

# Step 2

In [6]:
def step2(device,N,T,ar_pair,Aw,Ab,init_z,init_para,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N):
    model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
    
    with torch.no_grad():       
        model.z[:,:] = torch.from_numpy(init_z).to(device)
        model.para[:,:] = torch.from_numpy(init_para).to(device)
    
    optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])
    
    loss = run(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N)
    init_z = model.z.cpu().detach().numpy()
    init_para = model.para.detach().cpu().numpy()
    return init_z, init_para, loss

# Step 3

In [7]:
def step3(device,N,T,ar_pair,Aw,Ab,init_z,init_para,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N):
    model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
    with torch.no_grad():       
        model.z[:,:] = torch.from_numpy(init_z).to(device)
        model.para[:,:] = torch.from_numpy(init_para).to(device)
    optimizer = torch.optim.SGD([
        {'params': model.z, "momentum": MOM, "lr": LR},
        {'params': model.para, "momentum": 0.0, "lr":LR_P}
        ])    
    logL = train(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N,index=None, fixed=None)

    
    delta_var = 0.3
    parad = {'alpha':(0,1),'delta':(2,1),'gw':(1,1),'gb':(2,0)}
    var_list = []
    for key, value in parad.items():
        model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
        with torch.no_grad():       
            model.z[:,:] = torch.from_numpy(init_z).to(device)
            model.para[:,:] = torch.from_numpy(init_para).to(device)
        optimizer = torch.optim.SGD([
        {'params': model.z, "momentum": MOM, "lr": LR},
        {'params': model.para, "momentum": 0.0, "lr":LR_P}
        ])
        newlogL = run(device,model,optimizer,N,T,SIGMA,TAU,label,persist,combination_N,index=value,fixed=init_para[value[0],value[1]]+delta_var)
        var_hat = delta_var/(newlogL-logL)**0.5/2**0.5
        var_list.append(round(var_hat,5))
    return var_list

In [8]:
def simulation_study():
    #initialize parameters
    N = 100
    DIM = 2
    T = 10
    SIGMA = 1
    TAU = 1
    ALPHA = 1
    DELTA = 2
    GAMMAW = 0.25
    GAMMAB = 0.5
    #simulate data
    z,y,Aw,Ab=simulate_clsna(N=N,d=DIM,T=T,alpha=ALPHA,delta=DELTA,sigma=SIGMA, tau=TAU, gammaw=GAMMAW, gammab=GAMMAB)
    #preprocess
    z = np.concatenate(z)
    label, persist, Aw, Ab, combination_N=preprocess(y, Aw, Ab, N, T)
    combination_N = combination_N.to(device)
    label = label.to(device)
    persist = persist.to(device)
    #ar_pair
    _s = torch.arange(0,N*(T-1), requires_grad = False)
    ar_pair = torch.stack((_s,_s+N), dim = 1)
    #set learning rate
    LR = 2e-3
    MOM = 0.99
    LR_P = 1e-2
    
    print('Round 1')
    init_z, init_para = step1(device,N,T,ar_pair,Aw,Ab,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N)
    print(init_para)
    init_z, init_para, loss = step2(device,N,T,ar_pair,Aw,Ab,init_z,init_para,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N)
    print(init_para)
    print(loss)
    for jj in range(2):
        print('Round ',jj+2)
        new_init_z, new_init_para = step1(device,N,T,ar_pair,Aw,Ab,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N)
        print(new_init_para)
        new_init_z, new_init_para, new_loss = step2(device,N,T,ar_pair,Aw,Ab,new_init_z,new_init_para,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N)
        print(new_init_para)
        print(new_loss)
        if new_loss<loss:
            init_z = new_init_z
            init_para = new_init_para
            loss=new_loss
    
    zz=init_z
    for i in range(T-1):
        d1=zz[N*i:N*(i+1)]
        d2=zz[N*(i+1):N*(i+2)]
        c1=d1-(d1.mean(axis=0))[np.newaxis,:]
        c2=d2-(d2.mean(axis=0))[np.newaxis,:]
        R,_ = orthogonal_procrustes(c1,c2)
        with open('rotate001', 'a') as f:
            writer = csv.writer(f)
            writer.writerow(np.diag(R).round(2))
    
    var_list = step3(device,N,T,ar_pair,Aw,Ab,init_z,init_para,LR,LR_P,MOM,SIGMA,TAU,label,persist,combination_N)
    print(var_list)
    
    init_para = init_para.round(3)
    printdict = {'a':init_para[0,1],'d':init_para[2,1],'gw':init_para[1,1],'gb':init_para[2,0]}
    fields=var_list
    with open('var001', 'a') as f:
        writer = csv.writer(f)
        writer.writerow(fields)
    fields=list(printdict.values())
    with open('theta001', 'a') as f:
        writer = csv.writer(f)
        writer.writerow(fields)

In [None]:
for study in range(100):
    simulation_study()

Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.17549953]
 [0.43649617 0.9995187 ]]
[[1.9989463  0.84001124]
 [1.9989463  0.16599965]
 [0.48399556 1.9569494 ]]
25785.514938643333
Round  2
[[0.9995187  0.9995187 ]
 [0.9995187  0.17549953]
 [0.43249622 0.9995187 ]]
[[1.9989463  0.84001124]
 [1.9989463  0.16599965]
 [0.48299557 1.9569494 ]]
25785.513513836
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.16849962]
 [0.4254963  0.9995187 ]]
[[1.9989463  0.84001124]
 [1.9989463  0.15699977]
 [0.4889955  1.9569494 ]]
25786.03436832583
[0.02485, 0.02474, 0.08707, 0.09074]
Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.14249995]
 [0.5014954  0.9995187 ]]
[[1.9989463  0.81000984]
 [1.9989463  0.19199932]
 [0.4839956  1.9419504 ]]
26004.93213007297
Round  2
[[0.9995187  0.9995187 ]
 [0.9995187  0.1385    ]
 [0.49249545 0.9995187 ]]
[[1.9989463  0.81000984]
 [1.9989463  0.19699925]
 [0.48299557 1.9419504 ]]
26003.830592638813
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.13750002]
 [0.4934954

[[1.9989463  0.82501054]
 [1.9989463  0.09500012]
 [0.5439974  1.9379507 ]]
25749.061220313015
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.07250007]
 [0.4994954  0.9995187 ]]
[[1.9989463  0.8240105 ]
 [1.9989463  0.09100011]
 [0.5419973  1.9379507 ]]
25753.541131641563
[0.02499, 0.02494, 0.08396, 0.08361]
Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.15949973]
 [0.47649565 0.9995187 ]]
[[1.9989463  0.8110099 ]
 [1.9989463  0.16599965]
 [0.53399694 1.9539496 ]]
25976.059042395016
Round  2
[[0.9995187  0.9995187 ]
 [0.9995187  0.13949999]
 [0.5014954  0.9995187 ]]
[[1.9989463  0.8110099 ]
 [1.9989463  0.16699964]
 [0.53399694 1.9529496 ]]
25976.061596405612
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.16749963]
 [0.47449568 0.9995187 ]]
[[1.9989463  0.80400956]
 [1.9989463  0.19099933]
 [0.5379971  1.9519497 ]]
25988.47979917063
[0.0246, 0.02453, 0.09188, 0.08956]
Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.23549876]
 [0.41249648 0.9995187 ]]
[[1.9989463  0.81000984]
 [1.9

[[1.9989463  0.8150101 ]
 [1.9989463  0.24699861]
 [0.4329962  1.9159523 ]]
25859.862898282612
[0.02473, 0.02474, 0.09965, 0.09269]
Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.18049946]
 [0.43049625 0.9995187 ]]
[[1.9989463  0.8160101 ]
 [1.9989463  0.22499889]
 [0.45099598 1.9229518 ]]
25882.036100622077
Round  2
[[0.9995187 0.9995187]
 [0.9995187 0.1934993]
 [0.4254963 0.9995187]]
[[1.9989463  0.8160101 ]
 [1.9989463  0.22499889]
 [0.45199597 1.9229518 ]]
25881.33035955194
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.18449941]
 [0.4344962  0.9995187 ]]
[[1.9989463  0.8150101 ]
 [1.9989463  0.22299892]
 [0.45099598 1.9219519 ]]
25882.40571664469
[0.02473, 0.02453, 0.08869, 0.08806]
Round 1
[[0.9995187  0.9995187 ]
 [0.9995187  0.19749925]
 [0.4034966  0.9995187 ]]
[[1.9989463  0.82301044]
 [1.9989463  0.19599926]
 [0.45499593 1.9429504 ]]
25826.12305908271
Round  2
[[0.9995187 0.9995187]
 [0.9995187 0.1464999]
 [0.4574959 0.9995187]]
[[1.9989463  0.8070097 ]
 [1.9989463  0.183

[[1.9989463  0.7770083 ]
 [1.9989463  0.23499876]
 [0.48599553 1.9239517 ]]
25931.521139696313
Round  2
[[0.9995187  0.9995187 ]
 [0.9995187  0.24049869]
 [0.38249686 0.9995187 ]]
[[1.9989463  0.77000797]
 [1.9989463  0.26899832]
 [0.45599592 1.9249517 ]]
25948.41520247042
Round  3
[[0.9995187  0.9995187 ]
 [0.9995187  0.24849859]
 [0.38349685 0.9995187 ]]
[[1.9989463  0.7690079 ]
 [1.9989463  0.21999896]
 [0.46699578 1.920952  ]]
25969.098708595193


In [None]:
# for i in range(T-1):
#     d1=zz[N*i:N*(i+1)]
#     d2=zz[N*(i+1):N*(i+2)]
#     c1=d1-(d1.mean(axis=0))[np.newaxis,:]
#     c2=d2-(d2.mean(axis=0))[np.newaxis,:]
#     R,_ = orthogonal_procrustes(c1,c2)
#     print(np.diag(R).round(2))

In [None]:
# import csv   
# fields=var_list
# with open('var001', 'a') as f:
#     writer = csv.writer(f)
#     writer.writerow(fields)

In [None]:
# import csv   
# fields=list(printdict.values())
# with open('theta001', 'a') as f:
#     writer = csv.writer(f)
#     writer.writerow(fields)