In [1]:
from utils import simulate_clsna, preprocess, ClsnaModel
import numpy as np
import torch
from scipy.linalg import orthogonal_procrustes

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = "cpu"

In [3]:
N = 100
DIM = 2
T = 10
SIGMA = 1
TAU = 1
ALPHA = 1
DELTA = 2
GAMMAW = 0.25
GAMMAB = 0.5

In [4]:
z,y,Aw,Ab=simulate_clsna(N=N,d=DIM,T=T,alpha=ALPHA,delta=DELTA,sigma=SIGMA, tau=TAU, gammaw=GAMMAW, gammab=GAMMAB)

In [5]:
z = np.concatenate(z)
label, persist, Aw, Ab, combination_N=preprocess(y, Aw, Ab, N, T)

In [6]:
_s = torch.arange(0,N*(T-1), requires_grad = False)
ar_pair = torch.stack((_s,_s+N), dim = 1)

In [7]:
combination_N = combination_N.to(device)
label = label.to(device)
persist = persist.to(device)

In [8]:
LR = 2e-3
MOM = 0.99
LR_P = 1e-2

In [9]:
def train(optimizer, index=None, fixed=None):
    t_index=torch.arange(start=0,end=N*T,device=device,requires_grad=False)
    optimizer.zero_grad()
    loss = model.loss(device=device,label=label,persist=persist,sample_edge=combination_N,T_index=t_index,ss=SIGMA,tt=TAU)
    loss.backward()
    model.para.grad = 0.1*((model.para.grad>0).bool().float()-0.5)
    optimizer.step()
    if index is not None:
        with torch.no_grad():
            model.para[index[0],index[1]] = fixed
    return loss.item()

In [10]:
def run(optimizer,index=None,fixed=None):
    for epoch in range(1,3000):
        loss = train(optimizer=optimizer,index=index,fixed=fixed)
    return loss

# Step 1

In [11]:
model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=3).to(device)

In [12]:
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR*4},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])

In [13]:
run(optimizer)

28556.388548057832

In [14]:
PCA_p = torch.pca_lowrank(model.z.cpu())[2][:,[0,1]]
init_z = (model.z.cpu().detach()@PCA_p).detach().numpy()
init_para = model.para.detach().cpu().numpy()

# Step 2

In [15]:
model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
with torch.no_grad():       
    model.z[:,:] = torch.from_numpy(init_z).to(device)
    model.para[:,:] = torch.from_numpy(init_para).to(device)

In [16]:
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR/4},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])

In [17]:
run(optimizer)

26720.70103666426

In [18]:
init_z = model.z.cpu().detach().numpy()
init_para = model.para.detach().cpu().numpy()

# Step 3

In [19]:
model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
with torch.no_grad():       
    model.z[:,:] = torch.from_numpy(init_z).to(device)
    model.para[:,:] = torch.from_numpy(init_para).to(device)
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])    
logL = train(optimizer)

In [20]:
delta_var = 0.1

In [21]:
parad = {'alpha':(0,1),'delta':(2,1),'gw':(1,1),'gb':(2,0)}
var_list = []

for key, value in parad.items():
    model = ClsnaModel(device,N,T,ar_pair,Aw,Ab,D=2).to(device)
    with torch.no_grad():       
        model.z[:,:] = torch.from_numpy(init_z).to(device)
        model.para[:,:] = torch.from_numpy(init_para).to(device)
    optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR},
    {'params': model.para, "momentum": 0.0, "lr":LR_P}
    ])
    newlogL=run(optimizer,value,init_para[value[0],value[1]]+delta_var)
    
    var_hat = delta_var/(newlogL-logL)**0.5/2**0.5
    var_list.append(round(var_hat,5))

TypeError: type complex doesn't define __round__ method

In [None]:
init_para = init_para.round(3)
printdict = {'a':init_para[0,1],'d':init_para[2,1],'gw':init_para[1,1],'gb':init_para[2,0]}

In [None]:
for i in range(T-1):
    d1=zz[N*i:N*(i+1)]
    d2=zz[N*(i+1):N*(i+2)]
    c1=d1-(d1.mean(axis=0))[np.newaxis,:]
    c2=d2-(d2.mean(axis=0))[np.newaxis,:]
    R,_ = orthogonal_procrustes(c1,c2)
    print(np.diag(R).round(2))

In [None]:
import csv   
fields=var_list
with open('var001', 'a') as f:
    writer = csv.writer(f)
    writer.writerow(fields)

In [None]:
import csv   
fields=list(printdict.values())
with open('theta001', 'a') as f:
    writer = csv.writer(f)
    writer.writerow(fields)