In [None]:
from utils import simulate_clsna, visualize, visualize_membership, preprocess, ClsnaModel
import numpy as np
import torch
import math

In [None]:
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter(log_dir='../runs/exp2', comment = "")

In [None]:
import time

In [None]:
# Simulation settings. Each value is forwarded to simulate_clsna by keyword
# (N, d, T, sigma, tau, alpha, delta, gammaw, gammab) and several are reused
# later as the target values written into model.para.
# N*T is the total number of rows indexed by the training step; plotting cells
# slice rows in blocks of N.
# presumably N = number of nodes, T = number of time steps, DIM = latent
# dimension -- TODO confirm against simulate_clsna.
N = 100
DIM = 2
T = 10
# SIGMA and TAU enter the model as 2*log(value), i.e. as log-variances.
SIGMA = 1
TAU = 1
ALPHA = 1
DELTA = 2
GAMMAW = 0.25
GAMMAB = 0.5

In [None]:
# Run the simulation and report how long it took, in seconds.
# perf_counter() is a monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if the system
# time is adjusted while the simulation runs.
# NOTE(review): no random seed is set in the visible cells, so each run
# presumably draws a different dataset -- confirm whether simulate_clsna seeds itself.
start = time.perf_counter()
z,y,Aw,Ab=simulate_clsna(N=N,d=DIM,T=T,alpha=ALPHA,delta=DELTA,sigma=SIGMA, tau=TAU, gammaw=GAMMAW, gammab=GAMMAB)
end = time.perf_counter()
print(end - start)

In [None]:
# Join the sequence returned by simulate_clsna along its first axis into one
# array; later cells slice it in row blocks of N (e.g. rows N*3 .. N*4).
# NOTE(review): z is rebound in place, so this cell is not idempotent --
# re-running it without re-running the simulation concatenates the already
# merged array a second time.
z = np.concatenate(z)

In [None]:
# Plot rows 3*N .. 4*N of z. The same array is passed as both z_hat and
# z_true, so this shows the simulated positions against themselves.
# presumably rows 3*N .. 4*N are the fourth time slice -- confirm against visualize.
visualize(z_hat=z,z_true=z,start=N*3,end=N*4)

In [None]:
# Per-row group labels for one time slice: the first N//2 rows get 1, the
# remaining rows get 0. Using N - N//2 for the second group keeps the vector
# at length N when N is odd (N//2 + N//2 would be one short and no longer
# line up with z after tiling).
# NOTE(review): which group receives the extra row for odd N is an assumption;
# confirm against how simulate_clsna splits the groups.
membership = np.concatenate((np.ones(N//2),np.zeros(N - N//2)))
# Plot the last time slice. (T-1)*N .. T*N equals the former hard-coded
# 9*N .. 10*N for T == 10 and stays correct when T changes.
visualize_membership(z=z,membership=np.tile(membership,T),start=(T-1)*N,end=T*N)

In [None]:
# Turn the simulated outputs into the inputs used for fitting: label, persist
# and combination_N are later moved to the device and passed to model.loss;
# Aw and Ab are passed to the ClsnaModel constructor.
# NOTE(review): Aw and Ab are rebound to preprocess's outputs, so re-running
# this cell alone feeds already-processed values back into preprocess.
label, persist, Aw, Ab, combination_N=preprocess(y, Aw, Ab, N, T)

In [None]:
# Build an (N*(T-1), 2) integer tensor of row-index pairs (i, i+N): every row
# index except the last block of N, paired with the index N rows later.
# presumably this links each node at time t to itself at time t+1 -- confirm
# against ClsnaModel.
_s = torch.arange(N * (T - 1))
_t = _s + N
ar_pair = torch.stack([_s, _t], dim=1)

In [None]:
# Use the GPU when PyTorch reports one is available, otherwise the CPU.
# (Assign 'cpu' unconditionally here to force CPU execution.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Move the three tensors consumed by model.loss onto the selected device,
# rebinding each name to the moved tensor.
combination_N, label, persist = (
    tensor.to(device) for tensor in (combination_N, label, persist)
)

In [None]:
# Build the model from the problem size (N, T), the row-index pairs in ar_pair
# and the preprocessed Aw / Ab, then move it to the selected device.
# The training cells read this notebook-level `model` directly.
model = ClsnaModel(device,N,T,ar_pair,Aw,Ab).to(device)

In [None]:
# Start the model at the ground truth: latent positions from the simulation
# and every global parameter at the value used to simulate the data.
# Each slot is assigned directly. The previous clip(min=c, max=c) calls had
# the same effect (clamping to a single value yields that value), but the
# delta line read model.para[0,1] (the alpha slot) instead of its own slot
# [2,1] -- a copy-paste slip that direct assignment removes.
with torch.no_grad():
    model.z[:,:] = torch.from_numpy(z).to(device)
    # log(sigma^2)
    model.para[0,0] = 2*math.log(SIGMA)
    # log(tau^2)
    model.para[1,0] = 2*math.log(TAU)
    # gamma (within / between)
    model.para[1,1] = GAMMAW
    model.para[2,0] = GAMMAB
    # alpha
    model.para[0,1] = ALPHA
    # delta
    model.para[2,1] = DELTA

In [None]:
# SGD with momentum over the two trainable tensors: the latent positions
# (model.z) and the global parameters (model.para), in a single parameter
# group starting at lr=0.2.
# The training step overwrites model.para.grad with a fixed-magnitude signed
# value before stepping, so for model.para this acts as a sign-based update.
optimizer = torch.optim.SGD([model.z, model.para], lr=0.2, momentum = 0.95)

In [None]:
def train(optimizer):
    """Perform one full-batch optimisation step and return the loss as a float.

    The loss is evaluated on the notebook-level ``model`` over all N*T rows.
    After back-propagation the gradient of ``model.para`` is replaced by a
    constant-magnitude signed value (+0.05 where the gradient is positive,
    -0.05 otherwise, including exactly-zero entries), so the global
    parameters move by sign only. ``model.z`` keeps its true gradient.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose ``zero_grad`` / ``step`` are called around the
        backward pass.

    Returns
    -------
    float
        Loss value computed before the parameter update.
    """
    optimizer.zero_grad()
    every_row = torch.arange(N*T, device=device)
    objective = model.loss(
        device=device,
        label=label,
        persist=persist,
        sample_edge=combination_N,
        T_index=every_row,
    )
    objective.backward()
    # Keep only the sign of the global-parameter gradient.
    is_positive = model.para.grad > 0
    model.para.grad = 0.1 * (is_positive.float() - 0.5)
    optimizer.step()
    return objective.item()

In [None]:
def run(optimizer, min_lr=0.01):
    """Train for up to 999,999 epochs with a stepped learning-rate decay and periodic plots.

    Every 2000 epochs the learning rate of the optimizer's first parameter
    group is lowered by 0.01, but never below ``min_lr``. Without the floor a
    starting rate of 0.2 reaches zero around epoch 40000 and then turns
    negative, which makes SGD step *up* the gradient for the rest of the run.

    Every 1111 epochs the current estimates are read from ``model()`` and the
    rows of time slice ``epoch % T`` are plotted with ``visualize``; the
    caption shows the epoch, slice, rounded global parameters, learning rate
    and loss.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer passed to ``train``; ``param_groups[0]['lr']`` is modified
        in place.
    min_lr : float, optional
        Lower bound for the decayed learning rate (default 0.01).
    """
    for epoch in range(1,1000000):
        loss = train(optimizer)
        if epoch%2000==0:
            group = optimizer.param_groups[0]
            # Clamp so the schedule can never cross zero.
            group['lr'] = max(group['lr']-0.01, min_lr)
        if epoch%1111== 0:
            tt = epoch%T
            z_hat, p_hat = model()
            z_hat = z_hat.detach().cpu().numpy()
            p_hat = p_hat.detach().cpu().numpy().round(2)
            caption_dict = {'E':epoch,
                            'T':tt,
                            'a':p_hat[0,1],
                            'd':p_hat[2,1],
                            'gw':p_hat[1,1],
                            'gb':p_hat[2,0],
#                             's':round(math.exp(p_hat[0,0])**0.5,1),
#                             't':round(math.exp(p_hat[1,0])**0.5,1),
                           'lr':round(optimizer.param_groups[0]['lr'],2),
                           'loss':round(loss,1)}
            start = tt*N
            end = (tt+1)*N
            # NOTE(review): z_true is passed with its two columns swapped;
            # confirm this is intended rather than a leftover experiment.
            visualize(z_hat=z_hat,z_true=z[:,[1,0]],start=start,end=end,caption=str(caption_dict))

In [None]:
# Start training with the optimizer built above. The loop runs for up to
# 999,999 epochs and plots progress periodically, so expect this cell to be
# long-running (interrupt the kernel to stop early).
run(optimizer)

In [None]:
def fix_train(optimizer,index,fixed):
    """Perform one optimisation step, then pin one global parameter to ``fixed``.

    Works like a normal training step on the notebook-level ``model``: the
    loss is evaluated over all N*T rows, the gradient of ``model.para`` is
    replaced by a signed constant (+0.1 where positive, -0.1 otherwise) and
    the optimizer steps. Afterwards the entry of ``model.para`` at row
    ``index // 2``, column ``index % 2`` is overwritten with ``fixed``, so
    that parameter stays constant across steps.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer to zero and step.
    index : int
        Flat index into ``model.para``, read as two columns per row.
    fixed : float
        Value the selected parameter is reset to after the step.

    Returns
    -------
    float
        Loss value computed before the parameter update.
    """
    optimizer.zero_grad()
    every_row = torch.arange(N*T, device=device)
    objective = model.loss(
        device=device,
        label=label,
        persist=persist,
        sample_edge=combination_N,
        T_index=every_row,
    )
    objective.backward()
    # Keep only the sign of the global-parameter gradient.
    is_positive = model.para.grad > 0
    model.para.grad = 0.2 * (is_positive.float() - 0.5)
    optimizer.step()
    row = index // 2
    col = index % 2
    with torch.no_grad():
        model.para[row, col] = fixed
    return objective.item()

In [None]:
def fix_run(index,fixed,logL_df):
    """Refit the model with one global parameter held at ``fixed`` and record the loss trace.

    A fresh model is initialised from the saved estimates ``embed_star`` /
    ``para_star`` and trained with ``fix_train``. After a burn-in of 4000
    epochs the loss is recorded every 100 epochs; the trace is stored in
    ``logL_df`` under the key ``fixed`` and the per-column means are printed.

    Parameters
    ----------
    index : int
        Flat index of the parameter to hold fixed (forwarded to ``fix_train``).
    fixed : float
        Value the parameter is pinned to; also used as the column key.
    logL_df : mapping supporting item assignment and ``.mean()``
        Collector for the loss traces (presumably a pandas DataFrame -- confirm).
    """
    # fix_train() computes the loss and gradients on the notebook-level
    # `model`. Rebind that global to the fresh model so the optimizer created
    # below steps the same tensors that receive gradients; with a purely local
    # model the optimizer would update tensors whose .grad is never populated.
    global model
    #initialize model
    model = ClsnaModel(device,N,T,ar_pair,Aw,Ab).to(device)
    with torch.no_grad():
        # NOTE(review): other cells address the latent positions as model.z,
        # and embed_star / para_star are not defined in the visible cells --
        # confirm the attribute name and where the saved estimates come from.
        model.embedding[:,:] = embed_star.detach().clone()
        model.para[:,:] = para_star.detach().clone()
    #create optimizer
    optimizer = torch.optim.SGD([model.embedding, model.para], lr=1e-1, momentum = 0.97)
    #initialize list
    logL = []
    for epoch in range(1,1000000):
        loss = fix_train(optimizer,index,fixed)
        # Skip the burn-in, then thin the trace to every 100th epoch.
        if (epoch>4000) and (epoch%100 == 0):
            logL.append(loss)
    logL_df[fixed] = logL
    print(logL_df.mean().tolist())