In [2]:
import sys
import os

# Put the parent of the current working directory on the import path
# (presumably so the `src.*` packages imported below can be found).
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
# Guard the append so that re-running this cell does not keep stacking
# duplicate entries onto sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

import torch

from src.dynamics.masks import make_signed_mask
from src.dynamics.ground_truth import build_gt_W
from src.dynamics.simulate import simulate_trajectories, make_transition_dataset
from src.models.masked_aann import LinearAANN, SigmoidAANN
# from src.training.trainer import train_model
# from src.evaluation.metrics import mse, corr_per_protein

In [3]:
# Build the signed mask for d=12 and inspect its structure:
# the full tensor, how many nonzero entries each row has, and which values occur.
# NOTE(review): if make_signed_mask draws random numbers, seed the RNG before
# this call so the printed mask is reproducible — confirm against its source.
S = make_signed_mask(d=12)
print(S)
print("Num nonzeros per row:", torch.count_nonzero(S, dim=1))
print("Unique values:", torch.unique(S))


tensor([[ 0.,  0.,  0., -1.,  0., -1.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  1.,  1.,  0., -1.,  0.,  0.,  0.,  1.,  0.,  1.,  0.],
        [ 1.,  0.,  0.,  0.,  0.,  0., -1.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  1.,  1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0., -1.,  0., -1.,  0., -1.,  0.,  0.,  0., -1., -1.],
        [ 0.,  0.,  0.,  1.,  0.,  0., -1., -1.,  0.,  0.,  0.,  1.],
        [ 0., -1.,  0.,  0.,  0., -1.,  0.,  0., -1.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  1.,  0., -1.,  0.,  0.,  0.,  0., -1.,  0.],
        [ 0.,  1.,  0.,  1., -1.,  0.,  1.,  1.,  0.,  0.,  0.,  0.],
        [ 0.,  1., -1.,  0.,  1.,  0.,  0.,  1.,  1.,  0.,  0.,  0.],
        [ 0.,  0., -1.,  1.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.],
        [ 1., -1.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0., -1.]])
Num nonzeros per row: tensor([3, 5, 3, 3, 5, 4, 3, 3, 5, 5, 3, 4])
Unique values: tensor([-1.,  0.,  1.])


In [4]:
# Derive W_eff from the signed mask S via build_gt_W and show it.
W_eff = build_gt_W(S)
print(W_eff)

# Spectral radius: the largest magnitude among the (possibly complex)
# eigenvalues of W_eff.
eigenvalues = torch.linalg.eigvals(W_eff)
spectral_radius = eigenvalues.abs().max().item()
print("Spectral radius of W_eff:", spectral_radius)

tensor([[ 0.0000,  0.0000,  0.0000, -0.5615,  0.0000, -0.2235,  0.0000,  0.0000,
         -0.2828,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.5848,  0.2882,  0.0000, -0.1836,  0.0000,  0.0000,  0.0000,
          0.4655,  0.0000,  0.2294,  0.0000],
        [ 0.4415,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.1230,  0.0000,
         -0.4032,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.3651,  0.6046,  0.0000,  0.0000,  0.1755,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.9234,  0.0000, -0.8280,  0.0000, -0.1765,  0.0000,
          0.0000,  0.0000, -0.1593, -0.6380],
        [ 0.0000,  0.0000,  0.0000,  0.2303,  0.0000,  0.0000, -0.3180, -0.4730,
          0.0000,  0.0000,  0.0000,  0.1084],
        [ 0.0000, -0.2523,  0.0000,  0.0000,  0.0000, -0.1758,  0.0000,  0.0000,
         -0.2243,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.0938,  0.0000, -0.7532,  0.0000,  0.0000,
          0.0000,  0.0000, -0.

In [5]:
# Roll out trajectories under W_eff; T and n_seqs are kept small so this
# cell stays a quick smoke test.
trajectories = simulate_trajectories(W_eff, T=10, n_seqs=2)
print("Simulated trajectories (list of tensors):")
for i, traj in enumerate(trajectories):
    print(f"Trajectory {i} shape:", traj.shape)

# Convert the trajectories into a transition dataset, which unpacks into
# three (X, Y) pairs: train, validation and test.
datasets = make_transition_dataset(trajectories)
train_split, val_split, test_split = datasets
train_X, train_Y = train_split
val_X, val_Y = val_split
test_X, test_Y = test_split
print("Train set shapes:", train_X.shape, train_Y.shape)
print("Val set shapes:", val_X.shape, val_Y.shape)
print("Test set shapes:", test_X.shape, test_Y.shape)

Simulated trajectories (list of tensors):
Trajectory 0 shape: torch.Size([11, 12])
Trajectory 1 shape: torch.Size([11, 12])
Train set shapes: torch.Size([14, 12]) torch.Size([14, 12])
Val set shapes: torch.Size([3, 12]) torch.Size([3, 12])
Test set shapes: torch.Size([3, 12]) torch.Size([3, 12])


In [6]:
# Dimensionality of the state vector; also the size passed to the mask builder.
state_dim = 12

# NOTE(review): this builds a fresh mask instead of reusing `S` from the earlier
# cell. If make_signed_mask is not deterministic, M may differ from the mask
# that W_eff was built from — confirm this is intended.
M = make_signed_mask(d=state_dim)

# Use Apple's MPS backend when it is available, otherwise fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Both model variants are constructed with exactly the same arguments.
shared_model_kwargs = dict(state_dim=state_dim, mask=M, device=device)
linear_model = LinearAANN(**shared_model_kwargs)
sigmoid_model = SigmoidAANN(**shared_model_kwargs)