In [None]:
import numpy as np
import torch
import torch.nn as nn
import math
import sympy as sp
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torcheval.metrics.functional import r2_score
import architectures as archit
import data as da
import stable as sta

### graph and data

In [None]:
# Reproducibility: seed the global RNGs so the synthetic data and the
# train/test split are identical on every "Restart & Run All".
# NOTE(review): assumes da.dataLab_cycles draws from the NumPy and/or torch
# global generators -- confirm against data.py.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

num_vertices = 300 # number of vertices for two circulant graphs
num_graph = 2 # two circulant graphs
not_moving_probabilities_vector = [0.05,0.05] # weight p for the diagonal element of graph 1 and 2
jump_sizes_vector = [1,30] # location of jump l
n_samples = 1000 # number of data pairs

# generate the circulant graph tuple
ts = da.cycle_operator_tuple(num_vertices=num_vertices,
                    not_moving_probabilities_vector=not_moving_probabilities_vector,
                    jump_sizes_vector=jump_sizes_vector)
operator_tuple = ts

# Normalize the graph tuple: rescale every operator by scale_all / num_vertices
# (brings the norm closer to 1 for faster convergence) and, if the spectral
# norm is still above 1, divide by that norm.
scale_all = 250
ts_normalized = []
for ind_g, g in enumerate(operator_tuple):
  g = g/num_vertices * scale_all
  # compute the spectral norm once and reuse it for logging and rescaling
  g_norm = torch.linalg.matrix_norm(g, ord = 2)
  print(g_norm)
  if g_norm > 1:
    print("Graph {:d} has norm {:.4f} bigger than 1. This graph is normalized".format(ind_g, g_norm.item()))
    g = g / g_norm
  ts_normalized.append(g)
# keep every operator, however many graphs the tuple contains
operator_tuple_normalized = tuple(ts_normalized)

# generate data using the circulant graph tuple
# y = 0.76*t1@t0@x + 0.33*t0@t1@x + 0.3*t0@t0@t0@x + noise
x, y = da.dataLab_cycles(num_vertices=num_vertices,
                      not_moving_probabilities_vector=not_moving_probabilities_vector,
                      jump_sizes_vector=jump_sizes_vector,
                      noise_stdev = 0.1,
                      n_samples = n_samples)

# train test split (seeded so the same samples are held out on every run)
test_data_ratio = 0.2
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=test_data_ratio, random_state=SEED)
print(X_train.shape) # (number of data) * (number of features) * (number of vertices)

In [None]:
# create small graph tuple for training
def _interval_overlap_matrix(num_coarse, num_fine):
    """Overlap lengths between two uniform partitions of [0, 1].

    Returns a float64 tensor W of shape (num_coarse, num_fine) where W[i, r]
    is the length of the intersection of [i/num_coarse, (i+1)/num_coarse]
    with [r/num_fine, (r+1)/num_fine] (zero when they do not overlap).
    """
    coarse_edges = torch.arange(num_coarse + 1, dtype=torch.float64) / num_coarse
    fine_edges = torch.arange(num_fine + 1, dtype=torch.float64) / num_fine
    left = torch.maximum(coarse_edges[:-1].unsqueeze(1), fine_edges[:-1].unsqueeze(0))
    right = torch.minimum(coarse_edges[1:].unsqueeze(1), fine_edges[1:].unsqueeze(0))
    # disjoint intervals give right < left; clip those to zero overlap
    return (right - left).clamp(min=0)


def create_small_tuple(operator_tuple, num_vertices_small, scale_all = 250):
    """Coarse-grain every operator of a tuple onto a smaller vertex grid.

    Each operator is viewed as a piecewise-constant kernel on the unit square
    (cell (r, c) covers [r/n, (r+1)/n] x [c/n, (c+1)/n]).  Entry (i, j) of the
    small operator is the area-weighted sum of the large entries overlapping
    coarse cell (i, j), multiplied by num_vertices_small**2.  Only entries
    strictly greater than 1e-8 contribute (smaller and negative entries are
    ignored).  The result is then scaled by scale_all / num_vertices_small
    and, if its spectral norm still exceeds 1, divided by that norm.

    Args:
        operator_tuple: non-empty sequence of square 2-D tensors; the vertex
            count n is read from the first operator.
        num_vertices_small: number of vertices of the coarse graphs.
        scale_all: common scale factor applied before the norm check.

    Returns:
        Tuple with one (num_vertices_small, num_vertices_small) tensor per
        input operator, in torch's default floating dtype.

    Side effects:
        Prints the spectral norm of every scaled small operator and a message
        whenever an operator had to be normalized.
    """
    num_vertices = operator_tuple[0].shape[0]
    out_dtype = torch.get_default_dtype()

    # The row and column weights are the same for every graph, so the double
    # sum over (row, col) factorizes into overlap @ g @ overlap.T.
    overlap = _interval_overlap_matrix(num_vertices_small, num_vertices)

    ts_small_normalized = []
    for ind_g, g in enumerate(operator_tuple):
        weights = overlap.to(g.device)
        # keep only entries above the threshold, accumulate in float64
        g_kept = torch.where(g > 1e-8, g, torch.zeros_like(g)).to(torch.float64)
        g_small = (weights @ g_kept @ weights.T) * (num_vertices_small * num_vertices_small)

        # normalize the coarse graph
        g_small = g_small.to(out_dtype) / num_vertices_small * scale_all
        spectral_norm = torch.linalg.matrix_norm(g_small, ord = 2)
        print(spectral_norm)
        if spectral_norm > 1:
            print("Graph {:d} has norm {:.4f} bigger than 1. This graph is normalized".format(ind_g, spectral_norm.item()))
            g_small = g_small / spectral_norm
        ts_small_normalized.append(g_small)

    # return every graph, not just the first two
    operator_tuple_small = tuple(ts_small_normalized)

    return operator_tuple_small


In [None]:
# create small size data for training
def create_small_data(X_train, y_train, num_vertices_small):
    """Resample signals from the full vertex grid onto a coarser one.

    The last axis of X_train / y_train is treated as a uniform partition of
    [0, 1] into X_train.shape[2] intervals.  Each fine value is distributed to
    the coarse interval(s) it overlaps, weighted by the overlap length times
    num_vertices_small.

    Args:
        X_train: 3-D tensor; the last axis indexes vertices.
        y_train: 3-D tensor with the same last-axis length as X_train.
        num_vertices_small: number of vertices of the coarse grid.

    Returns:
        (X_train_small, y_train_small): newly allocated tensors whose last
        axis has length num_vertices_small.
    """
    fine_size = X_train.shape[2]
    coarse_size = num_vertices_small

    X_train_small = torch.zeros(X_train.shape[0], X_train.shape[1], coarse_size)
    y_train_small = torch.zeros(y_train.shape[0], y_train.shape[1], coarse_size)

    for fine_idx in range(fine_size):
        # coarse interval containing the left end of this fine interval
        coarse_idx = math.floor(fine_idx / fine_size * coarse_size)
        coarse_upper = (coarse_idx + 1) / coarse_size
        fine_lower = fine_idx / fine_size
        fine_upper = (fine_idx + 1) / fine_size

        if coarse_upper < fine_upper:
            # the coarse boundary cuts this fine interval: split it in two
            shares = ((coarse_idx, coarse_upper - fine_lower),
                      (coarse_idx + 1, fine_upper - coarse_upper))
        else:
            # the fine interval lies entirely inside one coarse interval
            shares = ((coarse_idx, 1 / fine_size),)

        for target_idx, weight in shares:
            X_train_small[:, :, target_idx] += weight * X_train[:, :, fine_idx] * coarse_size
            y_train_small[:, :, target_idx] += weight * y_train[:, :, fine_idx] * coarse_size

    return X_train_small, y_train_small

In [None]:
# Pick the compute device: the first CUDA GPU when requested and available,
# otherwise the CPU.
USE_GPU = True
device = torch.device("cuda:0" if (USE_GPU and torch.cuda.is_available()) else "cpu")

# Move the held-out data and the full-size normalized operators to that device.
X_test, y_test = X_test.to(device), y_test.to(device)
operator_tuple_device = tuple(op.to(device) for op in operator_tuple_normalized)

### training function

In [None]:
def train_network(model, operator_tuple_train, operator_tuple_test, X_train, y_train, X_test, y_test, 
                  Penalty_lam = 0, Stable_Penalty = False, constrain_C_Flag = True, constrain_Cj_Flag = True, target_Upr_Cj_vec = 1, target_Upr_C = 1,
                  n_epochs = 5000, lr = 0.01, verbose = True, Plot_loss = True):
  """Train `model` with Adam on full-batch MSE and evaluate it after every epoch.

  Each epoch the model's monomial word support is re-evaluated on
  `operator_tuple_train` for the gradient step and then on
  `operator_tuple_test` for evaluation, so training and testing may use
  different operator tuples.

  Args:
    model: module exposing `monomial_word_support`,
      `change_monomial_word_support`, `forward` and `parameters`.
    operator_tuple_train / operator_tuple_test: operator tuples passed to
      `evaluate_at_operator_tuple` for the training / testing pass.
    X_train, y_train, X_test, y_test: 3-D input / target tensors.
    Penalty_lam: weight of the stability penalty in the loss.
    Stable_Penalty: when True, add `Penalty_lam * sta.compute_penalty(...)`
      to both the training loss and the reported test loss.
    constrain_C_Flag, constrain_Cj_Flag, target_Upr_Cj_vec, target_Upr_C:
      forwarded unchanged to `sta.compute_penalty`.
    n_epochs: number of epochs (one optimizer step per epoch).
    lr: Adam learning rate.
    verbose: print metrics every 10th epoch.
    Plot_loss: plot training and test loss from epoch 100 onward.

  Returns:
    (model, R2_test, epoch_ts_loss, R2_ts): the trained model, the test R^2
    of the last epoch, and CPU tensors stacking the per-epoch test loss and
    per-epoch test R^2.
  """
  criterion = nn.MSELoss()
  word_support = model.monomial_word_support
  adam = torch.optim.Adam(model.parameters(), lr)

  def flatten_per_sample(batch):
    # collapse the last two axes so every sample becomes one row
    return torch.reshape(batch, (batch.shape[0], batch.shape[1] * batch.shape[2]))

  def stability_penalty():
    return sta.compute_penalty(model, target_Upr_Cj_vec, Upr_C = target_Upr_C,
                               constrain_C_Flag = constrain_C_Flag, constrain_Cj_Flag = constrain_Cj_Flag)

  log_template = "Epoch {:05d} | Train Loss {:.4f} | Train_R^2 {:.4f} | Test Loss {:.4f} | Test_R^2 {:.4f} | Test GNN loss {:.4f} | Test penalty loss {:.4f} |"

  epoch_tr_loss, epoch_ts_loss, R2_ts = [], [], []

  for epoch in range(n_epochs):
    # ---- training step: monomials evaluated on the training operators ----
    word_support.evaluate_at_operator_tuple(operator_tuple = operator_tuple_train)
    model.change_monomial_word_support(word_support)
    model.train()
    outs_train = model.forward(X_train)
    loss = criterion(outs_train, y_train)
    if Stable_Penalty:
      loss = loss + Penalty_lam * stability_penalty()

    adam.zero_grad()
    loss.backward()
    adam.step()

    # the recorded training loss includes the penalty term when enabled
    epoch_tr_loss.append(loss.item())
    R2_train = r2_score(flatten_per_sample(outs_train), flatten_per_sample(y_train))

    # ---- evaluation: monomials re-evaluated on the testing operators ----
    model.eval()
    with torch.no_grad():
      word_support.evaluate_at_operator_tuple(operator_tuple = operator_tuple_test)
      model.change_monomial_word_support(word_support)
      outs_test = model.forward(X_test)
      R2_test = r2_score(flatten_per_sample(outs_test), flatten_per_sample(y_test))
      GNN_test_loss = criterion(outs_test, y_test)
      Penalty_test_loss = stability_penalty() if Stable_Penalty else 0
      epoch_ts_loss.append(GNN_test_loss + Penalty_lam * Penalty_test_loss)
      R2_ts.append(R2_test)

    if verbose and (epoch % 10 == 0):
      print(log_template.format(epoch, epoch_tr_loss[epoch], R2_train, epoch_ts_loss[epoch],
                                R2_test, GNN_test_loss, Penalty_test_loss))

  # per-epoch test metrics as CPU tensors for plotting / saving
  epoch_ts_loss = torch.stack(epoch_ts_loss).cpu()
  R2_ts = torch.stack(R2_ts).cpu()

  if Plot_loss:
    # skip the first 99 epochs so the early transient does not dominate the axis
    plt.figure()
    epoch_seq = np.arange(100, len(epoch_tr_loss) + 1)
    plt.plot(epoch_seq, epoch_tr_loss[99:], '.-')
    plt.plot(epoch_seq, epoch_ts_loss[99:], 'r.-')
    plt.grid()
    plt.xlabel('epochs')
    plt.ylabel('Loss')
    plt.title('Training Curve')
    plt.legend(['Training', 'Test'])
    plt.show()

  return model, R2_test, epoch_ts_loss, R2_ts

### create monomial

In [None]:
# Monomial word support shared by all models below: two variables and words up
# to degree three.
# NOTE(review): presumably one variable per graph operator -- confirm against
# architectures.MonomialWordSupport.
num_variable, allowed_degree = 2, 3
M = archit.MonomialWordSupport(
    num_variables=num_variable,
    allowed_degree=allowed_degree,
    device=device,
)

### 1-layer GtNN

In [None]:
# Settings for the plain (unpenalized) 1-layer GtNN.
n_epochs = 200
num_features_in = 1
num_features_out = 1
Stable_Penalty = False

# Coarse graph sizes to train on; one model is trained per size and every
# model is tested on the full-size operators in operator_tuple_device.
num_vertices_small_all = [100,150,200,250,300]
model_GNN_small_all, result_GNN_small_all = [], []
ts_loss_GNN_small_all, ts_R2_GNN_small_all = [], []

for num_vertices_small in num_vertices_small_all:
    print(num_vertices_small)

    # coarse training graphs and data, moved to the compute device
    operator_tuple_small = create_small_tuple(operator_tuple, num_vertices_small, scale_all = scale_all)
    X_train_small, y_train_small = create_small_data(X_train, y_train, num_vertices_small)
    X_train_small, y_train_small = X_train_small.to(device), y_train_small.to(device)
    operator_tuple_small_device = tuple(op.to(device) for op in operator_tuple_small)

    # evaluate the monomials on the training operators before building the layer
    M.evaluate_at_operator_tuple(operator_tuple=operator_tuple_small_device)

    # 1-layer GtNN model
    model_GNN_small = archit.OperatorFilterLayer(
        num_features_in = num_features_in,
        num_features_out = num_features_out,
        monomial_word_support = M,
    )
    model_GNN_small.to(device)

    # train on the coarse graph, test on the full-size graph
    model_GNN_small, result_GNN_small, ts_loss_GNN_small, R2_ts_GNN_small = train_network(
        model_GNN_small, operator_tuple_small_device, operator_tuple_device,
        X_train_small, y_train_small, X_test, y_test,
        Stable_Penalty = Stable_Penalty, n_epochs = n_epochs)

    model_GNN_small_all.append(model_GNN_small)
    result_GNN_small_all.append(result_GNN_small)
    ts_loss_GNN_small_all.append(ts_loss_GNN_small)
    ts_R2_GNN_small_all.append(R2_ts_GNN_small)
    

In [None]:
# Persist the GtNN results (final test R^2, per-epoch test loss, per-epoch
# test R^2) next to the notebook.
for payload, filename in (
    (result_GNN_small_all, 'result_GNN_small_all.pt'),
    (ts_loss_GNN_small_all, 'ts_loss_GNN_small_all.pt'),
    (ts_R2_GNN_small_all, 'ts_R2_GNN_small_all.pt'),
):
    torch.save(payload, filename)

### 1-layer stable GtNN

In [None]:
# Settings for the stability-penalized 1-layer GtNN.
n_epochs = 200
Penalty_lam = 10
Stable_Penalty = True

model_stable_small_all, result_stable_small_all = [], []
ts_loss_stable_small_all, ts_R2_stable_small_all = [], []

for ind_size, num_vertices_small in enumerate(num_vertices_small_all):
  print(num_vertices_small)

  # coarse training graphs and data, moved to the compute device
  operator_tuple_small = create_small_tuple(operator_tuple, num_vertices_small, scale_all = scale_all)
  X_train_small, y_train_small = create_small_data(X_train, y_train, num_vertices_small)
  X_train_small, y_train_small = X_train_small.to(device), y_train_small.to(device)
  operator_tuple_small_device = tuple(op.to(device) for op in operator_tuple_small)

  # expansion constants of the already-trained plain GtNN of the same size
  C_max_sum_small, C_j_max_sum_small = sta.compute_constrain_param(model_GNN_small_all[ind_size])
  print(C_max_sum_small)
  print(C_j_max_sum_small)

  # target expansion constants for the stable GtNN: half of the plain GtNN's
  target_Upr_C_small = [layer_C.data/2 for layer_C in C_max_sum_small]
  target_Upr_Cj_vec_small = [layer_Cj.data/2 for layer_Cj in C_j_max_sum_small]
  print("target:")
  print(target_Upr_C_small)
  print(target_Upr_Cj_vec_small)

  # evaluate the monomials on the training operators before building the layer
  M.evaluate_at_operator_tuple(operator_tuple=operator_tuple_small_device)

  # 1-layer stable GtNN model
  model_stable_small = archit.OperatorFilterLayer(
      num_features_in = num_features_in,
      num_features_out = num_features_out,
      monomial_word_support = M,
  )
  model_stable_small.to(device)

  # train with the stability penalty, test on the full-size graph
  model_stable_small, result_stable_small, ts_loss_stable_small, R2_ts_stable_small = train_network(
      model_stable_small, operator_tuple_small_device, operator_tuple_device,
      X_train_small, y_train_small, X_test, y_test,
      Penalty_lam = Penalty_lam, Stable_Penalty = Stable_Penalty,
      target_Upr_Cj_vec = target_Upr_Cj_vec_small, target_Upr_C = target_Upr_C_small,
      n_epochs = n_epochs)

  model_stable_small_all.append(model_stable_small)
  result_stable_small_all.append(result_stable_small)
  ts_loss_stable_small_all.append(ts_loss_stable_small)
  ts_R2_stable_small_all.append(R2_ts_stable_small)

  # print the expansion constants to check whether the constraints are satisfied
  C_max_sum_small_stable, C_j_max_sum_small_stable = sta.compute_constrain_param(model_stable_small)
  print(C_max_sum_small_stable)
  print(C_j_max_sum_small_stable)


In [None]:
# Persist the stable-GtNN results (final test R^2, per-epoch test loss,
# per-epoch test R^2) next to the notebook.
for payload, filename in (
    (result_stable_small_all, 'result_stable_small_all.pt'),
    (ts_loss_stable_small_all, 'ts_loss_stable_small_all.pt'),
    (ts_R2_stable_small_all, 'ts_R2_stable_small_all.pt'),
):
    torch.save(payload, filename)

### plot figure

In [None]:
start_point = 0
epoch_seq_all = np.arange(start_point,n_epochs)

# Figure 1: test MSE of every model; solid = plain GtNN, dashed = stable GtNN.
fig = plt.figure()
plt.rc('font',size=15)
for loss_curves, line_style in ((ts_loss_GNN_small_all, '-'), (ts_loss_stable_small_all, '--')):
    for ts_loss in loss_curves:
        plt.plot(epoch_seq_all, ts_loss[start_point:], line_style)
plt.grid()
plt.xlabel('epochs')
plt.ylabel('Testing MSE')
plt.ylim([0,0.3])
plt.legend(['GtNN:100', 'GtNN:150', 'GtNN:200', 'GtNN:250', 'GtNN:300','ST:100', 'ST:150', 'ST:200', 'ST:250', 'ST:300'])
plt.show()

# Figure 2: head-to-head comparison for the first (smallest) training size.
fig = plt.figure()
plt.rc('font',size=15)
for loss_curves in (ts_loss_GNN_small_all, ts_loss_stable_small_all):
    plt.plot(epoch_seq_all, loss_curves[0][start_point:], '-')
plt.grid()
plt.xlabel('epochs')
plt.ylabel('Testing MSE')
plt.ylim([0.04,0.2])
plt.legend(['GtNN:100','Stable:100'])
plt.show()

In [None]:
start_point = 0
epoch_seq_all = np.arange(start_point,n_epochs)

# One figure per model family: (curves, line style, legend labels, legend kwargs).
mse_figure_specs = (
    (ts_loss_GNN_small_all, '-',
     ['GtNN:100', 'GtNN:150', 'GtNN:200', 'GtNN:250', 'GtNN:300'],
     {}),
    (ts_loss_stable_small_all, '--',
     ['Stable:100', 'Stable:150', 'Stable:200', 'Stable:250', 'Stable:300'],
     {'loc': 'upper left', 'bbox_to_anchor': (0.08, 1.0)}),
)

for loss_curves, line_style, legend_labels, legend_options in mse_figure_specs:
    fig = plt.figure()
    plt.rc('font',size=15)
    for ts_loss in loss_curves:
        plt.plot(epoch_seq_all, ts_loss[start_point:], line_style)
    plt.grid()
    plt.xlabel('epochs')
    plt.ylabel('Testing MSE')
    plt.ylim([0.005,0.2])
    plt.legend(legend_labels, **legend_options)
    plt.show()

In [None]:
# Test R^2 per epoch, one figure per model family.  The first `start_point`
# epochs are skipped.
start_point = 100
epoch_seq_all = np.arange(start_point,n_epochs)

# (curves, line style, legend prefix) for each figure
r2_figure_specs = (
    (ts_R2_GNN_small_all, '-', 'GtNN'),
    (ts_R2_stable_small_all, '--', 'Stable'),
)

for R2_curves, line_style, label_prefix in r2_figure_specs:
    fig = plt.figure()
    for R2_ts in R2_curves:
        plt.plot(epoch_seq_all, R2_ts[start_point:], line_style)
    plt.grid()
    plt.xlabel('epochs')
    # the curves are R^2 scores, not losses
    plt.ylabel('$R^2$')
    plt.title('Testing $R^2$')
    # derive the labels from the trained sizes so the legend cannot drift
    # out of sync with num_vertices_small_all
    plt.legend(['{}:{}'.format(label_prefix, size) for size in num_vertices_small_all])
    plt.show()


In [None]:
# Head-to-head test R^2 for the first (smallest) training size.
# Define the plotting window here so the cell does not depend on variables
# left over from the previous cell.
start_point = 100
epoch_seq_all = np.arange(start_point,n_epochs)

fig = plt.figure()
for R2_curves in (ts_R2_GNN_small_all, ts_R2_stable_small_all):
    plt.plot(epoch_seq_all, R2_curves[0][start_point:],'-')
plt.grid()
plt.xlabel('epochs')
# the curves are R^2 scores, not losses
plt.ylabel('$R^2$')
plt.title('Testing $R^2$')
plt.legend(['GtNN:{}'.format(num_vertices_small_all[0]),
            'Stable:{}'.format(num_vertices_small_all[0])])
plt.show()