In [None]:
from utils import simulate_clsna, visualize, visualize_membership, preprocess, ClsnaModel
import numpy as np
import torch
import math
import matplotlib.pyplot as plt

In [None]:
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter(log_dir='../runs/exp2', comment = "")

In [None]:
import time

In [None]:
# Simulation size: N nodes per time slice, DIM latent dimensions, T time slices
# (presumably; see how z is stacked and sliced in the cells below).
N, DIM, T = 10_000, 2, 10

# True parameter values handed to simulate_clsna. Judging from the commented-out
# initialisation cell, SIGMA/TAU enter the model as log-variances 2*log(value).
SIGMA, TAU = 5, 1
ALPHA, DELTA = -1, 2
GAMMAW, GAMMAB = 0.55, -0.5

In [None]:
# Simulate the networks and time the call. perf_counter is a monotonic,
# high-resolution clock meant for measuring durations; time.time() is wall-clock
# time and can jump (NTP adjustments), so it is unreliable for elapsed time.
start = time.perf_counter()
z, y, Aw, Ab = simulate_clsna(
    N=N, d=DIM, T=T,
    alpha=ALPHA, delta=DELTA,
    sigma=SIGMA, tau=TAU,
    gammaw=GAMMAW, gammab=GAMMAB,
)
end = time.perf_counter()
print(f"simulate_clsna took {end - start:.2f} s")

In [None]:
# Cross-product of the latent positions at time slices 3 and 4 (a DIM x DIM
# matrix); presumably a quick look at the temporal dependence in the simulation.
z[3].T @ z[4]

In [None]:
# Stack the per-time-slice position matrices row-wise, so slice t occupies rows
# t*N .. (t+1)*N (presumably an (N*T, DIM) array).
# Guarded so that re-running this cell does not call np.concatenate on the
# already-stacked 2-D array, which would silently flatten it to 1-D.
if not (isinstance(z, np.ndarray) and z.ndim == 2):
    z = np.concatenate(z)

In [None]:
# Plot time slice 3 (rows 3*N .. 4*N of the stacked positions); the simulated
# positions are passed as both the estimate and the reference.
visualize(
    z_hat=z,
    z_true=z,
    start=3 * N,
    end=4 * N,
)

In [None]:
# Group labels for the N nodes: first N//2 nodes are 1, the remaining N - N//2
# are 0 (using N - N//2 keeps the vector at length N even when N is odd; the
# old `np.zeros(N//2)` produced N-1 labels for odd N).
membership = np.concatenate((np.ones(N // 2), np.zeros(N - N // 2)))

# Repeat the labels for every time slice and plot the last slice
# (rows (T-1)*N .. T*N), instead of hard-coding 9*N..10*N, which assumed T == 10.
visualize_membership(
    z=z,
    membership=np.tile(membership, T),
    start=(T - 1) * N,
    end=T * N,
)

In [None]:
# Build the training samples from the simulated networks. preprocess also
# returns Aw and Ab, which overwrite the simulated ones from here on.
label, persist, Aw, Ab, combination_N = preprocess(
    y, Aw, Ab, N, T
)

In [None]:
# Share of positive labels among the sampled pairs (assuming 0/1 labels).
label.sum() / len(label)

In [None]:
# Pair every row of the stacked positions with the row N positions later, i.e.
# the same node in the next time slice (presumably the autoregressive links).
_s = torch.arange(N * (T - 1))
_t = _s + N
ar_pair = torch.stack([_s, _t], dim=1)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# To force CPU execution: device = "cpu"

In [None]:
# combination_N = combination_N.to(device)
# label = label.to(device)
# persist = persist.to(device)

In [None]:
# Instantiate the model from the AR pairs and the preprocessed Aw/Ab, then move
# it to the chosen device.
model = ClsnaModel(
    device, N, T, ar_pair, Aw, Ab
).to(device)

In [None]:
# Each training epoch is split into FOLDS mini-batches; BS is the nominal
# mini-batch size (number of samples // FOLDS).
FOLDS = 50
BS = len(label) // FOLDS

In [None]:
# with torch.no_grad():       
#     model.z[:,:] = torch.from_numpy(z).to(device)
#     #logsigma2
#     model.para[0,0] = model.para[0,0].clip(min=2*math.log(SIGMA), max=2*math.log(SIGMA))
#     #logtau2
#     model.para[1,0] = model.para[1,0].clip(min=2*math.log(TAU), max=2*math.log(TAU))
#     #gamma
#     model.para[1,1] = model.para[1,1].clip(min=GAMMAW, max=GAMMAW)
#     model.para[2,0] = model.para[2,0].clip(min=GAMMAB, max=GAMMAB)
#     #alpha
#     model.para[0,1] = model.para[0,1].clip(min=ALPHA, max=ALPHA)
#     #delta
#     model.para[2,1] = model.para[0,1].clip(min=DELTA, max=DELTA)

In [None]:
# Two parameter groups: the latent positions take small steps with heavy
# momentum; the model parameters take larger steps without momentum.
optimizer = torch.optim.SGD(
    [
        dict(params=model.z, momentum=0.995, lr=1e-3),
        dict(params=model.para, momentum=0.0, lr=1e-1),
    ]
)

In [None]:
def train(optimizer, update=False):
    """Run one epoch of mini-batch training on the global ``model``.

    The sample indices are shuffled and split into ``FOLDS`` mini-batches that
    together cover every sample. For each batch, the gradient of ``model.para``
    is replaced by a fixed-size step in the direction of its sign (sign-SGD),
    while ``model.z`` keeps its raw gradient.

    Args:
        optimizer: optimizer over ``model.z`` and ``model.para``.
        update: if True, on the last mini-batch plot histograms of the
            log10 squared z-gradient and the log10 squared z-update.

    Returns:
        float: loss of the last mini-batch of the epoch (not an epoch mean).
    """
    perm_i = torch.randperm(label.size(0))
    t_index = torch.arange(start=0, end=N * T, device=device, requires_grad=False)
    # tensor_split covers every sample; slicing with a fixed BS silently dropped
    # the label.size(0) % FOLDS tail of the permutation each epoch.
    batches = torch.tensor_split(perm_i, FOLDS)
    for fold, ii in enumerate(batches):
        print(fold)
        optimizer.zero_grad(set_to_none=True)
        loss = model.loss(
            device=device,
            label=label[ii].to(device),
            persist=persist[ii].to(device),
            sample_edge=combination_N[ii].to(device),
            T_index=t_index,
        )
        loss.backward()
        # Sign-SGD on para: +0.05 / -0.05 per entry, and 0 where the gradient is
        # exactly zero (or NaN). The old `(g > 0) - 0.5` form pushed those
        # entries in the negative direction.
        grad = model.para.grad
        model.para.grad = 0.05 * ((grad > 0).to(grad.dtype) - (grad < 0).to(grad.dtype))

        if update and fold == FOLDS - 1:
            _step_with_diagnostics(optimizer)
        else:
            optimizer.step()
    return loss.item()


def _step_with_diagnostics(optimizer):
    """Take one optimizer step, plotting z-gradient and z-update magnitudes."""
    fig, (ax_grad, ax_step) = plt.subplots(1, 2, figsize=(10, 4))
    grad_z = model.z.grad.detach().cpu()
    ax_grad.hist(torch.log10(grad_z.pow(2) + 1e-20).numpy().flatten(), bins=100)
    ax_grad.set(title="z gradient", xlabel="log10(grad^2)", ylabel="count")

    # Forward passes are only for inspection, so no autograd graph is needed.
    with torch.no_grad():
        z_before, _ = model()
        z_before = z_before.detach().clone()  # model() may return the live parameter
    optimizer.step()
    with torch.no_grad():
        z_after, _ = model()
    step_sq = (z_before - z_after).pow(2).cpu()
    ax_step.hist(torch.log10(step_sq + 1e-20).numpy().flatten(), bins=100)
    ax_step.set(title="z update", xlabel="log10(step^2)", ylabel="count")
    fig.tight_layout()
    plt.show()

In [None]:
# Inspect the per-group optimizer settings (lr, momentum, ...) set when the optimizer was created.
optimizer.param_groups

In [None]:
def run(optimizer, n_epochs=999_999, min_lr=1e-10):
    """Train repeatedly and visualize one time slice of the fit after each epoch.

    Each epoch calls ``train(optimizer, update=True)``, prints the low-rank PCA
    of ``model.z``, and plots time slice ``epoch % T`` of the estimated
    positions against the simulated ones. The caption holds the current
    parameter estimates.

    Args:
        optimizer: optimizer passed through to ``train``.
        n_epochs: maximum number of epochs. The default matches the original
            ``range(1, 1000000)`` loop.
        min_lr: stop as soon as the first param group's lr falls below this.
            Nothing in this notebook decays the lr, so in practice the loop is
            bounded by ``n_epochs``.
    """
    for epoch in range(1, n_epochs + 1):
        if optimizer.param_groups[0]['lr'] < min_lr:
            break
        loss = train(optimizer, update=True)
        tt = epoch % T  # time slice shown this epoch (cycles through 0..T-1)
        print(torch.pca_lowrank(model.z.detach().cpu()))

        # Inspection only: no autograd graph needed.
        with torch.no_grad():
            z_hat, p_hat = model()
        z_hat = z_hat.detach().cpu().numpy()
        p_hat = p_hat.detach().cpu().numpy().round(2)
        # para layout (per the commented-out initialisation cell):
        # [0,1]=alpha, [2,1]=delta, [1,1]=gamma_w, [2,0]=gamma_b.
        caption_dict = {
            'E': epoch,
            'T': tt,
            'a': p_hat[0, 1],
            'd': p_hat[2, 1],
            'gw': p_hat[1, 1],
            'gb': p_hat[2, 0],
            # 's': round(math.exp(p_hat[0, 0]) ** 0.5, 1),
            # 't': round(math.exp(p_hat[1, 0]) ** 0.5, 1),
            'lr': round(optimizer.param_groups[0]['lr'], 3),
            'loss': round(loss, 1),
        }
        start, end = tt * N, (tt + 1) * N
        visualize(z_hat=z_hat[:, :2], z_true=z[:, [1, 0]], start=start, end=end,
                  caption=str(caption_dict))
            
            
            

In [None]:
# Long-running: trains until the first group's lr drops below the stop threshold
# or the epoch cap is hit, plotting after every epoch.
run(optimizer=optimizer)

In [None]:
# pca_lowrank returns (U, S, V); keep the first two right-singular vectors of
# the fitted positions as a 2-D projection basis.
pca_directions = torch.pca_lowrank(model.z.cpu())[2]
PCA_p = pca_directions[:, [0, 1]]

In [None]:
# Singular values S (second output of pca_lowrank) of the fitted positions.
# NOTE(review): pca_lowrank is randomized, and this is a separate call from the
# one that produced PCA_p, so the two are not guaranteed to come from the same
# decomposition.
torch.pca_lowrank(model.z.cpu())[1]

In [None]:
# Project the fitted positions onto the 2-D PCA basis once. The original
# recomputed this loop-invariant N*T x dim matmul for every slice.
z_proj = (model.z.detach().cpu() @ PCA_p).detach().numpy()
z_true_2d = z[:, [0, 1]]
# Step through the time slices, pausing 2 s between plots.
for i in range(T):
    visualize(z_hat=z_proj, z_true=z_true_2d, start=N * i, end=N * (i + 1))
    time.sleep(2)

In [None]:
# Fitted positions expressed in the 2-D PCA basis, as a NumPy array.
zz = torch.matmul(model.z.cpu().detach(), PCA_p).detach().numpy()

In [None]:
# Un-normalised cross-products between consecutive time slices of the projected
# fit (despite the name `corr`, these are not correlations).
for i in range(T - 1):
    prev_slice = zz[N * i : N * (i + 1)]
    next_slice = zz[N * (i + 1) : N * (i + 2)]
    corr = prev_slice.T @ next_slice
    time.sleep(1)
    print(corr.round(1))

In [None]:
def fix_train(optimizer, index, fixed, net=None):
    """Take one full-batch sign-gradient step with one ``para`` entry pinned.

    Args:
        optimizer: optimizer over the parameters of the model being trained.
        index: flat index of the pinned ``para`` entry
            (row ``index // 2``, column ``index % 2``).
        fixed: value that entry is reset to after the step.
        net: model to train. Defaults to the global ``model`` for backward
            compatibility. It must be the model whose parameters ``optimizer``
            owns; otherwise the loss and gradients come from one model while the
            step is applied to another.

    Returns:
        float: full-batch loss before the step.
    """
    net = model if net is None else net
    optimizer.zero_grad()
    t_index = torch.arange(start=0, end=N * T, device=device, requires_grad=False)
    # Move inputs to the training device, consistent with `train`
    # (this is a no-op if they are already there).
    loss = net.loss(
        device=device,
        label=label.to(device),
        persist=persist.to(device),
        sample_edge=combination_N.to(device),
        T_index=t_index,
    )
    loss.backward()
    # Sign-SGD on para: +/-0.1 per entry, 0 where the gradient is exactly 0 (or NaN).
    grad = net.para.grad
    net.para.grad = 0.1 * ((grad > 0).to(grad.dtype) - (grad < 0).to(grad.dtype))
    optimizer.step()
    with torch.no_grad():
        net.para[index // 2, index % 2] = fixed
    return loss.item()

In [None]:
def fix_run(index, fixed, logL_df, n_epochs=999_999, burn_in=4000, record_every=100):
    """Record the loss profile with one entry of ``para`` pinned to ``fixed``.

    A fresh ``ClsnaModel`` is initialised from the globals ``embed_star`` and
    ``para_star``. These are not defined in the cells shown here; presumably
    they are an earlier fitted optimum (TODO confirm). The model is then trained
    with full-batch sign-gradient steps, and ``para[index // 2, index % 2]`` is
    reset to ``fixed`` after every step. The loss is recorded every
    ``record_every`` epochs once ``burn_in`` epochs have passed.

    The step is done inline on the freshly built model. ``fix_train`` with its
    default arguments computes the loss on the *global* ``model``. Calling it
    here left the local model without gradients, so it never moved, while the
    global model's gradients accumulated.

    Args:
        index: flat index of the pinned ``para`` entry
            (row ``index // 2``, column ``index % 2``).
        fixed: value the entry is pinned to; also the key under which the
            recorded losses are stored in ``logL_df``.
        logL_df: container (presumably a pandas DataFrame) that receives the
            recorded losses under key ``fixed``. If it is a DataFrame, its
            length must match the number of recorded values.
        n_epochs: number of training epochs. The default reproduces the
            original ``range(1, 1000000)`` loop.
        burn_in: epochs that pass before recording starts.
        record_every: recording interval in epochs.
    """
    profile_model = ClsnaModel(device, N, T, ar_pair, Aw, Ab).to(device)
    with torch.no_grad():
        # NOTE(review): the other cells use `model.z` for the latent positions;
        # confirm that ClsnaModel also exposes `embedding`.
        profile_model.embedding[:, :] = embed_star.detach().clone()
        profile_model.para[:, :] = para_star.detach().clone()
    optimizer = torch.optim.SGD(
        [profile_model.embedding, profile_model.para], lr=1e-1, momentum=0.97
    )

    # Loop-invariant inputs, prepared once instead of on every epoch.
    t_index = torch.arange(start=0, end=N * T, device=device, requires_grad=False)
    label_dev = label.to(device)
    persist_dev = persist.to(device)
    edges_dev = combination_N.to(device)
    row, col = index // 2, index % 2

    def _step():
        """One full-batch step on ``profile_model``; returns the pre-step loss."""
        optimizer.zero_grad()
        loss = profile_model.loss(
            device=device,
            label=label_dev,
            persist=persist_dev,
            sample_edge=edges_dev,
            T_index=t_index,
        )
        loss.backward()
        # Sign-SGD on para: +/-0.1 per entry, 0 where the gradient is exactly 0 (or NaN).
        grad = profile_model.para.grad
        profile_model.para.grad = 0.1 * ((grad > 0).to(grad.dtype) - (grad < 0).to(grad.dtype))
        optimizer.step()
        with torch.no_grad():
            profile_model.para[row, col] = fixed
        return loss.item()

    logL = []
    for epoch in range(1, n_epochs + 1):
        loss = _step()
        if epoch > burn_in and epoch % record_every == 0:
            logL.append(loss)
    logL_df[fixed] = logL
    print(logL_df.mean().tolist())