In [None]:
from congress_utils import congress_clsna, get_always_in
from utils import preprocess, ClsnaModel
from utils import visualize_membership, visualize

import numpy as np
import torch
import math
import matplotlib.pyplot as plt
from scipy.linalg import orthogonal_procrustes

In [None]:
import time

In [None]:
# Train on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
N = 550
DIM = 2
T = 10
SIGMA = 1
TAU = 1
PHI = 1
ALPHA = 1
DELTA = 2
GAMMAW = 0.25
GAMMAB = 0.5
N_LEAVE = 50

In [None]:
# Simulate the data and report how long it took. perf_counter is the
# monotonic, high-resolution clock intended for measuring elapsed time
# (time.time() can jump if the system clock is adjusted).
start = time.perf_counter()
z, y, persist, Aw, Ab, leaves = congress_clsna(
    N=N, d=DIM, T=T, alpha=ALPHA, delta=DELTA, sigma=SIGMA, tau=TAU,
    phi=PHI, gammaw=GAMMAW, gammab=GAMMAB, n_leave=N_LEAVE,
)
end = time.perf_counter()
print(f"congress_clsna: {end - start:.2f} s")

In [None]:
# Restrict the simulated objects using get_always_in (presumably dropping the
# units listed in `leaves`). N and T here are still the simulation sizes.
(z, y, persist,
 Aw, Ab) = get_always_in(N, T, z, y, persist, Aw, Ab, leaves)

In [None]:
# After get_always_in, N is re-bound to the number of rows in the first time
# slice of z. DIM, T, SIGMA and TAU are taken from the settings cell. Hard-coding
# them again here would silently override any change made there.
N = z[0].shape[0]

In [None]:
# Stack the per-time-step position arrays into one array; later cells read
# time slice t as rows N*t : N*(t+1).
# NOTE: this re-binds z (sequence -> array) and persist/Aw/Ab, so re-running
# the cell would re-process already-processed values.
z = np.concatenate(z, axis=0)
(label, persist, Aw, Ab,
 combination_N) = preprocess(y, Aw, Ab, N, T)

In [None]:
# Autoregressive index pairs. With flat index k = t*N + i (row i of slice t),
# row k of ar_pair links k to k + N, the same row in slice t + 1. This covers
# every slice except the last.
_s = torch.arange(N * (T - 1), device=device)
ar_pair = torch.stack([_s, _s + N], dim=1)

In [None]:
# Move the preprocessed tensors used by the loss onto the training device.
combination_N, label, persist = (t.to(device) for t in (combination_N, label, persist))

In [None]:
# Optimiser settings:
#   LR   - base step for model.z (x4 in Step 1, /4 in Step 2, x1 in Step 3)
#   MOM  - momentum used for model.z
#   LR_P - step for model.para (always trained without momentum)
LR, MOM, LR_P = 2e-3, 0.99, 1e-2

In [None]:
def train(optimizer, index=None, fixed=None):
    """Run one optimisation step on the global ``model`` and return its loss.

    Reads the module-level ``model``, ``label``, ``persist``,
    ``combination_N``, ``N``, ``T``, ``device``, ``SIGMA`` and ``TAU``.

    Args:
        optimizer: torch optimizer over ``model.z`` / ``model.para``.
        index: optional ``(row, col)`` of one ``model.para`` entry to hold
            fixed; after the step it is reset to ``fixed``.
        fixed: value written to ``model.para[index]`` when ``index`` is given.

    Returns:
        The loss as a Python float, evaluated *before* this step's update.
    """
    t_index = torch.arange(N * T, device=device)
    optimizer.zero_grad()
    loss = model.loss(device=device, label=label, persist=persist,
                      sample_edge=combination_N, T_index=t_index,
                      ss=SIGMA, tt=TAU)
    loss.backward()
    # Sign-gradient update for para: every entry moves by a fixed ±0.05
    # (times its learning rate). torch.sign maps an exactly-zero gradient to
    # 0, whereas the former `(grad > 0) - 0.5` form mapped it to -0.05 and
    # drifted entries the loss does not depend on.
    model.para.grad = 0.05 * torch.sign(model.para.grad)
    optimizer.step()
    if index is not None:
        # Pin the selected entry (used by the perturbation refits).
        with torch.no_grad():
            model.para[index[0], index[1]] = fixed
    return loss.item()

In [None]:
def run(optimizer, index=None, fixed=None, n_epochs=2999):
    """Call ``train`` repeatedly and return the loss from the last call.

    Args:
        optimizer: passed through to ``train``.
        index, fixed: passed through to ``train`` to hold one ``model.para``
            entry at ``fixed``.
        n_epochs: number of ``train`` calls. The default of 2999 matches the
            original ``range(1, 3000)`` loop.

    Returns:
        The float loss returned by the final ``train`` call.

    Raises:
        ValueError: if ``n_epochs < 1`` (there would be no loss to return).
    """
    if n_epochs < 1:
        raise ValueError(f"n_epochs must be >= 1, got {n_epochs}")
    for _ in range(n_epochs):
        loss = train(optimizer=optimizer, index=index, fixed=fixed)
    return loss

# Step 1: fit the model with `D=3` from its default initialisation

In [None]:
# Step 1 model with D=3 (presumably the latent dimension), default initialisation.
model = ClsnaModel(device, N, T, ar_pair, Aw, Ab, D=3)
model = model.to(device)

In [None]:
# model.z: momentum MOM, step 4*LR; model.para: no momentum, step LR_P.
optimizer = torch.optim.SGD([
    dict(params=model.z, momentum=MOM, lr=LR * 4),
    dict(params=model.para, momentum=0.0, lr=LR_P),
])

In [None]:
# Fit Step 1; the returned final loss is shown as the cell output.
run(optimizer=optimizer)

In [None]:
# Reduce the Step 1 positions to two columns: project them onto the first two
# right-singular directions returned by torch.pca_lowrank. Detaching first
# avoids building an autograd graph; the numbers are unchanged.
PCA_p = torch.pca_lowrank(model.z.detach().cpu())[2][:, :2]
init_z = (model.z.detach().cpu() @ PCA_p).numpy()
init_para = model.para.detach().cpu().numpy()

# Step 2: refit with `D=2`, starting from the PCA projection of the Step 1 fit

In [None]:
# Step 2 model (D=2), initialised from the projected Step 1 estimates.
model = ClsnaModel(device, N, T, ar_pair, Aw, Ab, D=2).to(device)
z_start = torch.from_numpy(init_z).to(device)
para_start = torch.from_numpy(init_para).to(device)
with torch.no_grad():
    model.z[:, :] = z_start
    model.para[:, :] = para_start

In [None]:
# Finer steps for model.z in Step 2 (LR/4); model.para unchanged (LR_P, no momentum).
optimizer = torch.optim.SGD([
    dict(params=model.z, momentum=MOM, lr=LR / 4),
    dict(params=model.para, momentum=0.0, lr=LR_P),
])

In [None]:
# Fit Step 2; the returned final loss is shown as the cell output.
run(optimizer=optimizer)

In [None]:
# Step 2 estimates as NumPy arrays; Step 3 starts every refit from these.
init_z = model.z.detach().cpu().numpy()
init_para = model.para.detach().cpu().numpy()

In [None]:
# Compare each time slice of the Step 2 estimate with the true positions
# (true columns in order 0, 1).
zz = init_z
for t in range(T):
    visualize(z_hat=zz, z_true=z[:, [0, 1]], start=t * N, end=(t + 1) * N)

In [None]:
# Same comparison with the true columns swapped (1, 0), presumably because
# the fitted axes may come out in the other order.
zz = init_z
for t in range(T):
    visualize(z_hat=zz, z_true=z[:, [1, 0]], start=t * N, end=(t + 1) * N)

# Step 3: baseline loss at the Step 2 estimates, then perturbation-based uncertainty estimates

In [None]:
# Rebuild the D=2 model at the Step 2 estimates and record the loss there.
model = ClsnaModel(device, N, T, ar_pair, Aw, Ab, D=2).to(device)
z_start = torch.from_numpy(init_z).to(device)
para_start = torch.from_numpy(init_para).to(device)
with torch.no_grad():
    model.z[:, :] = z_start
    model.para[:, :] = para_start
optimizer = torch.optim.SGD([
    dict(params=model.z, momentum=MOM, lr=LR),
    dict(params=model.para, momentum=0.0, lr=LR_P),
])
# train() returns the loss evaluated before its single update, i.e. at the
# Step 2 estimates. The perturbation refits below compare against this value.
logL = train(optimizer)

In [None]:
delta_var = 1e-1  # offset added to each held-fixed parameter in the refits below

In [None]:
# Perturbation-based uncertainty for each structural parameter:
#   1. Hold one entry of model.para at (estimate + delta_var).
#   2. Refit everything else from the Step 2 estimates.
#   3. Compare the refitted loss with the baseline logL.
# Assuming the loss is a locally quadratic negative log-likelihood,
# increase ≈ delta_var**2 / (2 * s**2), so s = delta_var / sqrt(2 * increase).
# This is a standard-deviation-scale quantity despite the name var_hat.
parad = {'alpha': (0, 1), 'delta': (2, 1), 'gw': (1, 1), 'gb': (2, 0)}
var_list = []

for key, value in parad.items():
    fixed_value = init_para[value[0], value[1]] + delta_var
    model = ClsnaModel(device, N, T, ar_pair, Aw, Ab, D=2).to(device)
    with torch.no_grad():
        model.z[:, :] = torch.from_numpy(init_z).to(device)
        model.para[:, :] = torch.from_numpy(init_para).to(device)
        # Start at the perturbed value so the very first step already sees it.
        model.para[value[0], value[1]] = fixed_value
    optimizer = torch.optim.SGD([
        {'params': model.z, "momentum": MOM, "lr": LR},
        {'params': model.para, "momentum": 0.0, "lr": LR_P},
    ])
    newlogL = run(optimizer, value, fixed_value)

    loss_increase = newlogL - logL
    if loss_increase > 0:
        var_hat = delta_var / loss_increase**0.5 / 2**0.5
    else:
        # A non-positive increase would make the square root complex (so
        # round() below raises TypeError) or divide by zero. Record NaN instead.
        print(f"warning: {key}: refit loss did not exceed baseline "
              f"(increase = {loss_increase:.3g}); estimate set to NaN")
        var_hat = math.nan
    print(key, var_hat)
    var_list.append(round(var_hat, 5))

In [None]:
# Point estimates, rounded to 3 d.p., for the results file. The order and
# indices match parad: alpha (0,1), delta (2,1), gw (1,1), gb (2,0).
# Rounding into a new name keeps init_para at full precision in case the
# refit cells above are re-run.
para_rounded = init_para.round(3)
printdict = {'a': para_rounded[0, 1], 'd': para_rounded[2, 1],
             'gw': para_rounded[1, 1], 'gb': para_rounded[2, 0]}

In [None]:
# For each pair of consecutive time slices of the Step 2 estimate, find the
# orthogonal matrix R that best maps the centred slice t onto the centred
# slice t + 1, and print its diagonal. Values close to 1 mean the two slices
# are nearly aligned without rotation.
for t in range(T - 1):
    cur = zz[N * t:N * (t + 1)]
    nxt = zz[N * (t + 1):N * (t + 2)]
    R, _ = orthogonal_procrustes(cur - cur.mean(axis=0), nxt - nxt.mean(axis=0))
    print(np.diag(R).round(2))

In [None]:
import csv

# Append this run's uncertainty estimates as one CSV row; each execution adds
# a row. The csv module requires files opened with newline='' so it controls
# line endings itself (otherwise blank lines appear between rows on Windows).
fields = var_list
with open('var001', 'a', newline='') as f:
    csv.writer(f).writerow(fields)

In [None]:
import csv

# Append this run's rounded point estimates (a, d, gw, gb) as one CSV row;
# each execution adds a row. newline='' is required by the csv module so it
# controls line endings itself.
fields = list(printdict.values())
with open('theta001', 'a', newline='') as f:
    csv.writer(f).writerow(fields)