## Stochastic gradient descent-based inference for dynamic network models with attractors
## This script plots the trajectories of the mean latent positions of the members of each party, comparing the model fit on the reduced dataset in `207.ipynb` with the fit on the full dataset in `505.ipynb`.

In [None]:
import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt

In [None]:
# Node names that are looked up in every per-time-point name list further down
# (presumably the nodes of the reduced dataset; the matching step asserts 207
# matches per time point -- TODO confirm).
node_names = np.load("nodes_names.npy")

In [None]:
# Read the eleven name files name1.npy ... name11.npy, each converted to a plain list.
names = [list(np.load(f"name{i}.npy")) for i in range(1, 12)]

In [None]:
# Fitted latent positions. The plot at the end labels the trajectories derived
# from z1 as "Reduced Dataset" and those derived from z2 as "Full Dataset".
z1 = np.load("z1.npy")
z2 = np.load("z2.npy")

In [None]:
# Boolean mask: True wherever the label stored in pi.npy is anything other than 'D'.
membership = np.load("pi.npy") != 'D'

In [None]:
# For each name list, record the contiguous range of row indices it occupies
# when all lists are stacked one after another.
time_point = []
start_idx = 0
for arr in names:
    stop_idx = start_idx + len(arr)
    time_point.append(np.arange(start_idx, stop_idx))
    start_idx = stop_idx

In [None]:
def find_match(l1, l2):
    """Pair up the positions of the values that occur in both sequences.

    Args:
        l1: First sequence of hashable values.
        l2: Second sequence of hashable values.

    Returns:
        A list of ``[index_in_l1, index_in_l2]`` pairs, one per distinct value
        of ``l1`` that also occurs in ``l2``, ordered by first appearance in
        ``l1``. If a value is repeated within a sequence, the index of its
        last occurrence is the one reported.
    """
    # Later occurrences overwrite earlier ones, so each dict holds the last index.
    last_pos_1 = {value: idx for idx, value in enumerate(l1)}
    last_pos_2 = {value: idx for idx, value in enumerate(l2)}
    return [
        [idx_1, last_pos_2[value]]
        for value, idx_1 in last_pos_1.items()
        if value in last_pos_2
    ]

In [None]:
# For each of the 11 time points, find where the entries of node_names sit in
# that time point's name list and translate those positions into global row
# indices via time_point.
reduced_names = list(node_names)  # loop-invariant, so convert once
select_index = []
for i in range(11):
    match = find_match(reduced_names, list(names[i]))
    # Check the size before unpacking so that a failed match is reported by
    # this assertion rather than by an unrelated unpacking error.
    assert len(match) == 207, (
        "expected 207 matched nodes at time point %d, found %d" % (i, len(match))
    )
    # Second column: positions inside names[i], ordered like node_names
    # (find_match iterates over its first argument).
    _, pos = np.array(match).T
    select_index.append(time_point[i][pos])

In [None]:
# Join the per-time-point index arrays into one flat index vector; it is used
# below to pick rows out of z2.
select = np.concatenate(select_index)

In [None]:
# Keep only the selected rows of z2 so they can be compared against z1
# (presumably this puts the rows in the same order as z1 -- TODO confirm).
z2_compare = z2[select]

In [None]:
def soft_align(z_hat, z_true):
    """Rotate ``z_hat`` onto ``z_true`` with an orthogonal Procrustes fit.

    Args:
        z_hat: Matrix to be aligned.
        z_true: Reference matrix of the same shape.

    Returns:
        ``z_hat`` multiplied by the orthogonal matrix returned by
        ``scipy.linalg.orthogonal_procrustes(z_hat, z_true)``. No translation
        or rescaling is applied.
    """
    # The second return value (the scale) is not needed here.
    rotation, _ = linalg.orthogonal_procrustes(z_hat, z_true)
    return z_hat @ rotation

In [None]:
def normalize(z):
    """Divide ``z`` by its overall norm.

    The divisor is ``np.linalg.norm(z)`` with default arguments, i.e. a single
    scalar for the whole array (the Frobenius norm for a matrix), not a
    per-row or per-column norm.
    """
    scale = np.linalg.norm(z)
    return z / scale

In [None]:
# Write the selected rows to disk. As the cells are ordered in this file, this
# happens before the centering, alignment and normalization applied further down.
np.save('z2_compare.npy',z2_compare)

In [None]:
# Center both embeddings column-wise.
z1 = z1 - z1.mean(axis=0)
z2_compare = z2_compare - z2_compare.mean(axis=0)

# Rotate z2_compare onto the centered z1, then rescale both to unit norm.
# z1 is normalized last, so the alignment sees the centered, unnormalized z1.
z2_compare = normalize(soft_align(z2_compare, z1))
z1 = normalize(z1)

In [None]:
# Aggregate position, per time point, of the nodes selected by `membership`
# (drawn with the "Republican" labels in the plot): mean1 from z1, mean2 from
# z2_compare. Rows are read as 11 consecutive blocks of 207.
# NOTE(review): np.sum is used, so these are sums rather than means, although
# the variable names and the legend (\bar{Z}) suggest means. With a fixed mask
# the two differ only by a constant factor per party -- confirm which is intended.
mean1 = [
    np.sum(z1[lo:lo + 207][membership], axis=0)
    for lo in range(0, 11 * 207, 207)
]
mean2 = [
    np.sum(z2_compare[lo:lo + 207][membership], axis=0)
    for lo in range(0, 11 * 207, 207)
]

In [None]:
# Aggregate position, per time point, of the nodes NOT selected by `membership`
# (drawn with the "Democrat" labels in the plot): mean3 from z1, mean4 from
# z2_compare. Rows are read as 11 consecutive blocks of 207.
# NOTE(review): np.sum is used, so these are sums rather than means, although
# the variable names and the legend (\bar{Z}) suggest means. With a fixed mask
# the two differ only by a constant factor per party -- confirm which is intended.
mean3 = [
    np.sum(z1[lo:lo + 207][np.logical_not(membership)], axis=0)
    for lo in range(0, 11 * 207, 207)
]
mean4 = [
    np.sum(z2_compare[lo:lo + 207][np.logical_not(membership)], axis=0)
    for lo in range(0, 11 * 207, 207)
]

In [None]:
def _draw_trajectory(ax, points, head_width, scatter_style, arrow_style, annotate=False):
    """Draw one trajectory: a scatter of its points plus arrows between consecutive points.

    Only the first two coordinates of each point are used.

    Args:
        ax: Matplotlib axes to draw on.
        points: Sequence of positions, one per time point, in temporal order.
        head_width: Arrow-head width passed to ``ax.arrow``.
        scatter_style: Keyword arguments forwarded to ``ax.scatter``.
        arrow_style: Extra keyword arguments forwarded to ``ax.arrow``.
        annotate: If true, label every second point with ``str(10 + index)``.
    """
    points = np.array(points)
    ax.scatter(points[:, 0], points[:, 1], **scatter_style)

    if annotate:
        # NOTE(review): the offset 10 presumably turns the index into a
        # session/period number -- confirm what it denotes.
        for i in range(0, len(points), 2):
            ax.annotate(str(10 + i), (points[i][0], points[i][1]))

    for start, end in zip(points[:-1], points[1:]):
        ax.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                 width=0.00001, alpha=0.5, linewidth=1, head_width=head_width,
                 length_includes_head=True, **arrow_style)


fig, ax = plt.subplots(figsize=(8, 8))

# One arrow-head size shared by all four trajectories, derived from the spread
# of mean1 (the original took min() of two identical expressions).
data = np.array(mean1)
head_width = np.var(data) ** 0.5 / 25

# Reduced-dataset trajectories: pale markers, dotted arrows, no annotations.
# Full-dataset trajectories: saturated 'x' markers, solid arrows, annotated.
_draw_trajectory(
    ax, mean1, head_width,
    scatter_style=dict(color='mistyrose', marker='o',
                       label=r'$\bar{Z}(Republican)$ Reduced Dataset', alpha=0.6),
    arrow_style=dict(linestyle=':', color='black'),
)
_draw_trajectory(
    ax, mean2, head_width,
    scatter_style=dict(color='red', marker='x',
                       label=r'$\bar{Z}(Republican)$ Full Dataset'),
    arrow_style=dict(color='black'),
    annotate=True,
)
_draw_trajectory(
    ax, mean3, head_width,
    scatter_style=dict(color='skyblue', marker='o',
                       label=r'$\bar{Z}(Democrat)$ Reduced Dataset', alpha=0.6),
    arrow_style=dict(linestyle=':', color='darkgray'),
)
_draw_trajectory(
    ax, mean4, head_width,
    scatter_style=dict(color='blue', marker='x',
                       label=r'$\bar{Z}(Democrat)$ Full Dataset'),
    arrow_style=dict(color='darkgray'),
    annotate=True,
)

ax.legend()
ax.grid(True, alpha=0.2)

# Axis labels
ax.set_xlabel(r'$Z_1$')
ax.set_ylabel(r'$Z_2$')

# Display the plot
plt.show()
