## Stochastic gradient descent-based inference for dynamic network models with attractors
## This script simulates a dynamic network with changing membership and fits the extended CLSNA model to the simulated data for inference.

In [None]:
from congress_utils import congress_clsna, preprocess, make_ar_pair, member_dict, ClsnaModelCongress
from utils import visualize_membership, visualize
import numpy as np
import torch
import math
import matplotlib.pyplot as plt
from scipy.linalg import orthogonal_procrustes

In [None]:
import time

In [None]:
def read_parameters(file_path):
    """Read integer-valued ``name=value`` settings from a text file.

    Each meaningful line must have the form ``name=value`` where ``value``
    parses as an ``int``. Blank lines and lines starting with ``#`` are
    skipped, and whitespace around names and values is ignored, so
    ``N_LEAVE = 50`` is read the same as ``N_LEAVE=50``.

    Args:
        file_path: Path to the parameters file.

    Returns:
        dict mapping each parameter name (str) to its integer value.

    Raises:
        ValueError: if a non-comment line has no ``=`` or its value is not
            an integer.
    """
    params = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_no, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            # Tolerate blank lines and comments instead of failing on unpacking.
            if not line or line.startswith('#'):
                continue
            name, sep, value = line.partition('=')
            if not sep:
                raise ValueError(
                    f"{file_path}:{line_no}: expected 'name=value', got {line!r}"
                )
            # Strip the name so 'N_LEAVE = 5' is stored under 'N_LEAVE'.
            params[name.strip()] = int(value)  # Convert the value to an integer
    return params

In [None]:
# Read N_LEAVE (forwarded to congress_clsna as n_leave) from the parameters
# file. Fail immediately if it is missing: .get() would return None and the
# problem would only surface later, e.g. as a TypeError in N - N_LEAVE.
parameters = read_parameters('parameters.txt')
if 'N_LEAVE' not in parameters:
    raise KeyError("parameters.txt does not define N_LEAVE (expected a line like 'N_LEAVE=100')")
N_LEAVE = parameters['N_LEAVE']

In [None]:
# Set global variables for the model
N = 1000
DIM = 2
T = 10
SIGMA = 1
TAU = 1
PHI = 1
ALPHA = 1
DELTA = 2
GAMMAW = 0.25
GAMMAB = 0.5

In [None]:
# Generate synthetic data for the model
start = time.time()
z,y,persist,Aw,Ab,leaves=congress_clsna(N=N, d=DIM, T=T, alpha=ALPHA, delta=DELTA, sigma=SIGMA, tau=TAU , phi=PHI, gammaw=GAMMAW, gammab=GAMMAB, n_leave=N_LEAVE)
end = time.time()
# print(end - start)

In [None]:
z = np.concatenate(z)
persist = np.concatenate(persist)

In [None]:
# visualize(z_hat=z,z_true=z,start=N*9,end=N*10)

In [None]:
# Visualize membership
membership = np.concatenate((np.ones(N//2),np.zeros(N//2)))
visualize_membership(z=z,membership=np.tile(membership,T),start=9*N,end=10*N)

In [None]:
label, persist, Aw, Ab, combination_N=preprocess(y, Aw, Ab, N, T, persist)

In [None]:
# label.sum()/label.size(0)

In [None]:
# Set device for computation (GPU if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = "cpu"

In [None]:
ar_pair = make_ar_pair(device,leaves,N,T)
new_at_t = member_dict(device,leaves,N,T)

In [None]:
combination_N = combination_N.to(device)
label = label.to(device)
persist = persist.to(device)

In [None]:
# Set learning rates
LR = 2e-3
MOM = 0.99
LR_P = 1e-2

# Step 1

In [None]:
# Step 1 fits the model in DIM+1 latent dimensions and then projects the
# fitted positions down to DIM dimensions with PCA.
print("Step 1: Fitting initial CLSNA model with higher-dimensional space...")

In [None]:
# Train the model: a single optimisation step on the global `model`.
def train(optimizer,index=None, fixed=None):
    """Run one optimisation step on the global ``model`` and return its loss.

    Reads the notebook globals ``model``, ``device``, ``label``, ``persist``,
    ``combination_N``, ``N``, ``T``, ``SIGMA``, ``TAU`` and ``PHI``.

    Args:
        optimizer: torch optimizer over ``model.z`` / ``model.para``.
        index: optional ``(row, col)`` into ``model.para``. When given, that
            entry is overwritten with ``fixed`` after the step, so it stays
            constant while everything else is refitted (used in Step 3).
        fixed: value written into ``model.para[index[0], index[1]]`` when
            ``index`` is not None.

    Returns:
        float: the loss evaluated *before* this call's optimizer step.
    """
    # Indices 0 .. N*T-1, passed to model.loss as T_index (presumably every
    # node/time row, i.e. a full batch -- confirm against ClsnaModelCongress).
    t_index=torch.arange(start=0,end=N*T,device=device,requires_grad=False)
    optimizer.zero_grad()
    loss = model.loss(device=device,label=label,persist=persist,sample_edge=combination_N,T_index=t_index,ss=SIGMA,tt=TAU,pp=PHI)
    loss.backward()
    # Replace the parameter gradient with a fixed-magnitude sign: +0.05 where
    # grad > 0 and -0.05 otherwise (an exactly-zero gradient also maps to
    # -0.05). With the zero momentum used for `para` in this notebook, each
    # parameter therefore moves by lr * 0.05 per step.
    model.para.grad = 0.1*((model.para.grad>0).bool().float()-0.5)
    optimizer.step()
    if index is not None:
        # Undo the step for the held-fixed entry.
        with torch.no_grad():
            model.para[index[0],index[1]] = fixed
    return loss.item()

In [None]:
def check_alignment(zz, n_nodes=None, n_steps=None):
    """Check whether consecutive time steps of a latent embedding are aligned.

    ``zz`` is treated as ``n_steps`` stacked blocks of ``n_nodes`` rows
    (block ``i`` is rows ``n_nodes*i : n_nodes*(i+1)``). For each pair of
    consecutive blocks, both are centred and the orthogonal Procrustes
    rotation mapping the earlier block onto the later one is computed. The
    pair counts as aligned when every diagonal entry of that rotation rounds
    (to one decimal) to 1, i.e. the best rotation is essentially the identity.

    Args:
        zz: 2-D array with at least ``n_nodes * n_steps`` rows.
        n_nodes: rows per time step; defaults to the global ``N``.
        n_steps: number of time steps; defaults to the global ``T``.

    Returns:
        int: 1 if every consecutive pair is aligned, otherwise 0.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if n_nodes is None:
        n_nodes = N
    if n_steps is None:
        n_steps = T
    for i in range(n_steps - 1):
        prev = zz[n_nodes * i:n_nodes * (i + 1)]
        curr = zz[n_nodes * (i + 1):n_nodes * (i + 2)]
        # Centre each block so a translation does not affect the rotation.
        prev_c = prev - prev.mean(axis=0)[np.newaxis, :]
        curr_c = curr - curr.mean(axis=0)[np.newaxis, :]
        rotation, _ = orthogonal_procrustes(prev_c, curr_c)
        if not np.all(np.diag(rotation.round(1)) == 1):
            return 0
    return 1

In [None]:
#run the optimization process
def run(optimizer, index=None, fixed=None, *, window=300, stable_range=0.2,
        max_epochs=9999, log_every=111):
    """Call ``train`` repeatedly until the loss stabilises or the epoch cap is hit.

    Once more than ``window`` losses have been recorded, each epoch checks
    whether the spread (max - min) of the last ``window`` losses is below
    ``stable_range``. Training stops after ``window`` consecutive epochs pass
    that check.

    Args:
        optimizer: optimizer forwarded to ``train``.
        index: optional ``(row, col)`` of ``model.para`` to hold fixed;
            forwarded to ``train`` (``None`` trains every parameter).
        fixed: value the held entry is reset to after each step.
        window: number of recent epochs inspected for stabilisation.
        stable_range: maximum loss spread inside the window to count as stable.
        max_epochs: maximum number of epochs to run.
        log_every: print the loss every this many epochs.

    Returns:
        float: the loss of the last epoch run (None if ``max_epochs`` < 1).
    """
    loss_history = []
    stable_count = 0  # consecutive epochs whose trailing window was stable
    loss = None
    for epoch in range(1, max_epochs + 1):
        loss = train(optimizer=optimizer, index=index, fixed=fixed)
        loss_history.append(loss)

        if len(loss_history) > window:
            recent_losses = loss_history[-window:]
            if max(recent_losses) - min(recent_losses) < stable_range:
                stable_count += 1
            else:
                stable_count = 0  # reset if loss is not within the range

            if stable_count >= window:  # check if stabilization period is reached
                print(f"Loss has stabilized for {stable_count} epochs within range of {stable_range}. Stopping training.")
                break

        if epoch % log_every == 0:
            print(f"Epoch {epoch}: Loss = {loss}")
    else:
        # Reached only when the loop was not broken out of.
        print(f"Loss did not stabilize within {max_epochs} epochs; returning the last loss.")
    return loss


In [None]:
# Initialize and train the first model
model = ClsnaModelCongress(device,N,T,ar_pair,Aw,Ab,new_at_t,DIM+1).to(device)

In [None]:
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": 0.99, "lr": LR},
    {'params': model.para, "momentum": 0, "lr":LR_P}
    ])

In [None]:
run(optimizer)

In [None]:
model.para

In [None]:
# Perform PCA to reduce dimensionality
PCA_p = torch.pca_lowrank(model.z.cpu())[2][:,[0,1]]
zz=(model.z.cpu().detach()@PCA_p).detach().numpy()
init_z = zz
init_para = model.para.detach().cpu().numpy()

# Step 2

In [None]:
# Step 2 refits the model in DIM dimensions, warm-started from the Step 1
# PCA projection and parameter values.
print("Step 2: Fitting CLSNA model with targeted dimension and estimating model parameters...")


In [None]:
# Build the DIM-dimensional model and load the Step 1 estimates into it.
model = ClsnaModelCongress(device, N, T, ar_pair, Aw, Ab, new_at_t, DIM).to(device)
with torch.no_grad():
    # copy_ casts dtype and moves across devices, like the slice assignment.
    model.z.copy_(torch.from_numpy(init_z))
    model.para.copy_(torch.from_numpy(init_para))
# Latent positions use momentum MOM; parameters use plain SGD.
optimizer = torch.optim.SGD([
    dict(params=model.z, momentum=MOM, lr=LR),
    dict(params=model.para, momentum=0.0, lr=LR_P),
])

In [None]:
def run(optimizer, index=None, fixed=None, *, window=300, stable_range=0.2,
        max_epochs=9999, log_every=111):
    """Call ``train`` repeatedly until the loss stabilises or the epoch cap is hit.

    Once more than ``window`` losses have been recorded, each epoch checks
    whether the spread (max - min) of the last ``window`` losses is below
    ``stable_range``. Training stops after ``window`` consecutive epochs pass
    that check.

    Args:
        optimizer: optimizer forwarded to ``train``.
        index: optional ``(row, col)`` of ``model.para`` to hold fixed;
            forwarded to ``train`` (``None`` trains every parameter).
        fixed: value the held entry is reset to after each step.
        window: number of recent epochs inspected for stabilisation.
        stable_range: maximum loss spread inside the window to count as stable.
        max_epochs: maximum number of epochs to run.
        log_every: print the loss every this many epochs.

    Returns:
        float: the loss of the last epoch run (None if ``max_epochs`` < 1).
    """
    loss_history = []
    stable_count = 0  # consecutive epochs whose trailing window was stable
    loss = None
    for epoch in range(1, max_epochs + 1):
        loss = train(optimizer=optimizer, index=index, fixed=fixed)
        loss_history.append(loss)

        if len(loss_history) > window:
            recent_losses = loss_history[-window:]
            if max(recent_losses) - min(recent_losses) < stable_range:
                stable_count += 1
            else:
                stable_count = 0  # reset if loss is not within the range

            if stable_count >= window:  # check if stabilization period is reached
                print(f"Loss has stabilized for {stable_count} epochs within range of {stable_range}. Stopping training.")
                break

        if epoch % log_every == 0:
            print(f"Epoch {epoch}: Loss = {loss}")
    else:
        # Reached only when the loop was not broken out of.
        print(f"Loss did not stabilize within {max_epochs} epochs; returning the last loss.")
    return loss


In [None]:
run(optimizer)

In [None]:
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR/2},
    {'params': model.para, "momentum": 0.0, "lr":LR_P/2}
    ])

In [None]:
run(optimizer)

In [None]:
zz = model.z.cpu().detach().numpy()

In [None]:
for i in range(T):
    visualize(z_hat=zz,z_true=z[:,[1,0]],start=N*i,end=N*(i+1))

In [None]:
for i in range(T):
    visualize(z_hat=zz,z_true=z[:,[0,1]],start=N*i,end=N*(i+1))

In [None]:
init_z = zz
init_para = model.para.detach().cpu().numpy()

In [None]:
init_para

# Step 3

In [None]:
# Step 3 refits with one parameter at a time held at (estimate + delta_var)
# and derives an uncertainty estimate from the resulting change in loss.
print("Step 3: Performing variance/covariance estimation for the parameters of interest...")

In [None]:
def run(optimizer, index=None, fixed=None, *, window=300, stable_range=0.2,
        max_epochs=9999, log_every=111):
    """Call ``train`` repeatedly until the loss stabilises or the epoch cap is hit.

    Once more than ``window`` losses have been recorded, each epoch checks
    whether the spread (max - min) of the last ``window`` losses is below
    ``stable_range``. Training stops after ``window`` consecutive epochs pass
    that check.

    Args:
        optimizer: optimizer forwarded to ``train``.
        index: optional ``(row, col)`` of ``model.para`` to hold fixed;
            forwarded to ``train`` (``None`` trains every parameter).
        fixed: value the held entry is reset to after each step.
        window: number of recent epochs inspected for stabilisation.
        stable_range: maximum loss spread inside the window to count as stable.
        max_epochs: maximum number of epochs to run.
        log_every: print the loss every this many epochs.

    Returns:
        float: the loss of the last epoch run (None if ``max_epochs`` < 1).
    """
    loss_history = []
    stable_count = 0  # consecutive epochs whose trailing window was stable
    loss = None
    for epoch in range(1, max_epochs + 1):
        loss = train(optimizer=optimizer, index=index, fixed=fixed)
        loss_history.append(loss)

        if len(loss_history) > window:
            recent_losses = loss_history[-window:]
            if max(recent_losses) - min(recent_losses) < stable_range:
                stable_count += 1
            else:
                stable_count = 0  # reset if loss is not within the range

            if stable_count >= window:  # check if stabilization period is reached
                print(f"Loss has stabilized for {stable_count} epochs within range of {stable_range}. Stopping training.")
                break

        if epoch % log_every == 0:
            print(f"Epoch {epoch}: Loss = {loss}")
    else:
        # Reached only when the loop was not broken out of.
        print(f"Loss did not stabilize within {max_epochs} epochs; returning the last loss.")
    return loss


In [None]:
model = ClsnaModelCongress(device,N,T,ar_pair,Aw,Ab,new_at_t,DIM).to(device)
with torch.no_grad():       
    model.z[:,:] = torch.from_numpy(init_z).to(device)
    model.para[:,:] = torch.from_numpy(init_para).to(device)
optimizer = torch.optim.SGD([
    {'params': model.z, "momentum": MOM, "lr": LR/2},
    {'params': model.para, "momentum": 0.0, "lr":LR_P/2}
    ])    
logL = train(optimizer)

In [None]:
delta_var = 0.05/((N-N_LEAVE)/200)**0.5

In [None]:
# Estimate the uncertainty of each parameter of interest. For each one, the
# parameter is held at (estimate + delta_var), everything else is refitted,
# and the refitted loss is compared with logL. The estimate is
# delta_var / sqrt(2 * (newlogL - logL)). Assuming model.loss is a negative
# log-likelihood that is locally quadratic, this is a standard-deviation-scale
# quantity (despite the "var" naming) -- NOTE(review): confirm.
parad = {'alpha':(0,1),'delta':(2,1),'gw':(1,1),'gb':(2,0)}
var_list = {'alpha':0,'delta':0,'gw':0,'gb':0}
for key, value in parad.items():
    # Fresh model warm-started at the Step 2 estimates. The global name
    # `model` must be reassigned because train() reads it.
    model = ClsnaModelCongress(device,N,T,ar_pair,Aw,Ab,new_at_t,DIM).to(device)
    with torch.no_grad():
        model.z[:,:] = torch.from_numpy(init_z).to(device)
        model.para[:,:] = torch.from_numpy(init_para).to(device)
    optimizer = torch.optim.SGD([
        {'params': model.z, "momentum": MOM, "lr": LR},
        {'params': model.para, "momentum": 0.0, "lr": LR_P},
    ])
    newlogL = run(optimizer, value, init_para[value[0], value[1]] + delta_var)
    loss_increase = newlogL - logL
    print("--------------------------------")
    if loss_increase > 0:
        var_hat = delta_var / loss_increase**0.5 / 2**0.5
    else:
        # The constrained refit did not end above the reference loss (e.g.
        # optimizer noise). The square root of a non-positive number would
        # be complex (or a division by zero), and round() would then fail,
        # so report NaN instead.
        var_hat = math.nan
        print(f"warning: loss did not increase for {key} (change = {loss_increase:.4g}); estimate set to NaN")
    print(key, var_hat)
    var_list[key] = round(var_hat, 4)

In [None]:
print("var: ",var_list)

In [None]:
init_para = init_para.round(3)
printdict = {'a':init_para[0,1],'d':init_para[2,1],'gw':init_para[1,1],'gb':init_para[2,0]}
print("point estimate: ",printdict)

In [None]:
with open('estvar', 'a') as file:
    # Convert dictionary to string and write it to the file
    file.write(str(var_list) + '\n')

In [None]:
with open('est', 'a') as file:
    # Convert dictionary to string and write it to the file
    file.write(str(printdict) + '\n')

In [1]:
# import csv   
# fields=list(var_list.values())
# with open('var001', 'a') as f:
#     writer = csv.writer(f)
#     writer.writerow(fields)

In [None]:
# import csv   
# fields=list(printdict.values())
# with open('theta001', 'a') as f:
#     writer = csv.writer(f)
#     writer.writerow(fields)