In [None]:
import numpy as np
from scipy.stats import ortho_group as og
from scipy.linalg import inv
from package.Metric import shift_matrix


In [None]:
# Fix the global NumPy seed so every subsequent np.random draw in this script
# is reproducible.
seed = 123
np.random.seed(seed)

# Load the causal matrices. Using the NpzFile as a context manager closes the
# underlying zip file; each array is read fully into memory when indexed, so
# W_ca and A_ca remain valid after the `with` block exits.
with np.load('eg_matrix.npz') as matrix:
    W_ca = matrix['W']  # presumably the R x R contemporaneous causal matrix — confirm
    A_ca = matrix['A']  # presumably the R x R lagged causal matrix — confirm


K = 100  # number of subjects
J = 12   # number of features
R = 4    # tensor rank

# Random R x R orthogonal matrix (orthonormal rows and columns). No
# random_state is passed, so scipy uses NumPy's global RNG (seeded above).
H = og.rvs(dim=R).T

# Factor matrices with entries drawn uniformly from [5, 10).
V = 5 + 5 * np.random.rand(J, R)
W = 5 + 5 * np.random.rand(K, R)

# Per-subject containers, filled in by the generation loop.
Q = [None] * K
U = [None] * K
X = [None] * K
S = [None] * K
U_g = [None] * K
U_ground = [None] * K
U_lag = [None] * K

# Give V a block structure: component r loads only on its own contiguous group
# of J // R features (the last component also absorbs any leftover features);
# every other entry is zeroed. For J = 12, R = 4 this keeps rows 0-2, 3-5, 6-8
# and 9-11 in columns 0, 1, 2 and 3 respectively.
block_size = max(J // R, 1)  # guard against J < R (zero-width blocks)
component_of_feature = np.minimum(np.arange(J) // block_size, R - 1)
V[component_of_feature[:, None] != np.arange(R)[None, :]] = 0


# Irregular tensor: one slice per subject. Slice k has I_k rows (a random
# "meeting size") and J feature columns.

# (I - W_ca)^{-1} does not depend on k, so compute it once instead of per subject.
causal_mix = inv(np.eye(R) - W_ca)

noise_scale = 0.1
# The original script drew noise but never added it; keep that as the default
# so the generated data is unchanged. Set to True to actually perturb X.
add_noise = False

for k in range(K):
    I_k = np.random.randint(10, 21)  # meeting size, 10..20 inclusive

    # NOTE(review): an earlier comment said "lag of 2" but the lag argument
    # passed is 1 — confirm which lag is intended.
    Mi = shift_matrix(I_k, 1)

    # Assign every row to one of the R components uniformly at random and
    # build the one-hot indicator matrix Q_k of shape (I_k, R).
    col_Q_k = np.random.randint(1, R + 1, size=I_k)
    Q_k = (col_Q_k[:, None] == np.arange(1, R + 1)[None, :]).astype(float)

    # Scale each column to unit norm; a column with no assigned rows stays
    # all-zero (its norm is replaced by 1 to avoid division by zero).
    norms = np.linalg.norm(Q_k, axis=0, keepdims=True)
    norms[norms == 0] = 1
    Q_k = Q_k / norms
    Q[k] = Q_k

    # U_k = Q_k H
    U_k = Q_k @ H
    U[k] = U_k

    S[k] = np.diag(W[k, :])
    data = U_k @ S[k]   # U_k S_k
    data1 = Mi @ data   # U_k S_k shifted by the shift matrix (time lag)

    # Causal mixing: the current term goes through (I - W_ca)^{-1}, the lagged
    # term through A_ca (presumably a structural-VAR-style model — confirm).
    X_k_p = data @ causal_mix + data1 @ A_ca
    X_k = X_k_p @ V.T
    U_lag_data = Mi @ X_k_p

    U_g[k] = X_k_p
    U_lag[k] = U_lag_data
    U_ground[k] = X_k_p @ inv(S[k])

    # Draw the noise unconditionally so the RNG stream (and therefore every
    # later subject) is identical whether or not the noise is applied.
    # Note: np.random.rand is uniform on [0, 1), not Gaussian.
    gNoise = noise_scale * np.random.rand(I_k, J)
    if add_noise:
        X_k = X_k + gNoise
    X[k] = X_k
    

def _to_object_array(items):
    """Pack a sequence of arrays into a 1-D NumPy object array.

    Entries are assigned one at a time so NumPy never tries to stack
    equal-shaped entries (e.g. the R x R matrices in S) into a single
    higher-dimensional numeric array; each element stays its own array.
    """
    packed = np.empty(len(items), dtype=object)
    for idx, item in enumerate(items):
        packed[idx] = item
    return packed


X_arr, S_arr, U_arr, U_ground_arr = (
    _to_object_array(seq) for seq in (X, S, U, U_ground)
)

# Persist the per-subject slices together with the generating factors and
# causal matrices (W_1, A_1) in a single .npz archive.
np.savez(
    'data_eg.npz',
    X=X_arr,
    W_1=W_ca,
    A_1=A_ca,
    V=V,
    S=S_arr,
    U=U_arr,
    U_ground=U_ground_arr,
    H=H,
)
