# RiTINI: Inferring Dynamic Regulatory Interaction Graphs from Time-Series Data with Perturbations

Prerequisites:
- A trained MIOFlow model whose inferred trajectories have been decoded back to gene space.

In this notebook we will:
- Run RiTINI to infer the dynamics of genes in a gene regulatory network.

In [1]:
# Standard library imports
import warnings
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import torch

from torch.optim.lr_scheduler import StepLR


# Suppress specific warnings
warnings.filterwarnings(
    "ignore",
    category=FutureWarning,
    message=".*unique with argument that is not not a Series.*"
)

In [2]:
# Compute device: use the GPU when one is visible to torch.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed both RNGs used in this notebook.
seed = 3
np.random.seed(seed)
torch.manual_seed(seed)

# Load and Preprocess Dataset

Load the dataset (the output of MIOFlow), an array of shape `(n_timepoints, n_trajectories, n_genes)`.

In [3]:
import pickle

# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.

# MIOFlow-inferred trajectories, shape (n_timepoints, n_trajectories, n_genes).
with open('../data/trajectories/traj_data.pkl', 'rb') as f:
    trajectories = pickle.load(f)
assert trajectories.ndim == 3, (
    f"expected (n_timepoints, n_trajectories, n_genes), got shape {trajectories.shape}"
)

# Gene names, shape (n_genes,), used for plotting; strip any " (ENSG...)" identifier suffix.
with open('../data/trajectories/genes_data.pkl', 'rb') as f:
    genes = pickle.load(f)
genes = (
    pd.Series(genes).astype(str)
    .str.replace(r'\s*\(ENSG[^\)]*\)', '', regex=True)
    .to_numpy()
)

# For simplicity, train on the mean trajectory only (the trajectory axis is kept, with size 1).
mean_trajectories = trajectories.mean(axis=1, keepdims=True)

# Annotations assign each trajectory to a lineage cluster. There is a single lineage here,
# so every trajectory is labelled 0. They are sized by the trajectory axis (axis 1) of the
# data actually stored below, because later cells index them with trajectory indices.
annotations = np.zeros(mean_trajectories.shape[1])

traj_data = {
    'trajectories': mean_trajectories,
    'genes': genes,
    'annotations': annotations,
}

# Prior graph over the top genes (Granger-based, per the file name) and the matching gene list.
with open('../data/trajectories/cancer_granger_prior_graph_nx_20.pkl', 'rb') as f:
    G = pickle.load(f)

with open('../data/trajectories/cancer_granger_prior_top_genes_20.pkl', 'rb') as f:
    top_genes = pickle.load(f)


In [4]:
"""
Create the adjacency matrix.
"""
# Sparse adjacency matrix of the prior graph.
adjacency_matrix = nx.adjacency_matrix(G)

# Source (u) and destination (v) node ids of every edge, as int32 tensors.
edges = list(G.edges())
if not edges:
    u = v = torch.tensor([], dtype=torch.int32)
else:
    src_nodes, dst_nodes = zip(*edges)
    u = torch.tensor(src_nodes, dtype=torch.int32)
    v = torch.tensor(dst_nodes, dtype=torch.int32)

In [5]:
# Display the prior graph object (its repr shows the NetworkX graph class).
G

<networkx.classes.digraph.DiGraph at 0x13750ae90>

In [6]:
# Seeded spring layout of the prior graph (passed to `train_loop` later as `ref_pos`).
ref_pos = nx.spring_layout(G.to_undirected(), seed=seed)
top_genes_for_plotting = top_genes[:10]

# Attach a colour and a gene label to every node. Node i is labelled with top_genes[i].
# NOTE(review): this assumes the prior graph's node order matches top_genes — confirm.
assert G.number_of_nodes() <= len(top_genes), "prior graph has more nodes than top_genes"
node_cmap = plt.get_cmap('viridis', len(top_genes))  # build once, not per node
for idx, node in enumerate(G.nodes()):
    gene = str(top_genes[idx])
    G.nodes[node]['color'] = node_cmap(idx)
    # Strip the trailing '_y' suffix here as well, so the labels are the same whether or
    # not the top_genes clean-up cell has already run (avoids execution-order dependence).
    G.nodes[node]['label'] = gene[:-2] if gene.endswith('_y') else gene

In [7]:
# """
# """
# # Mapping for edge ids
# Randomly ordered edge ids (0 .. n_edges-1), drawn from NumPy's global (seeded) RNG.
edge_ids = np.random.permutation(G.number_of_edges())

# Number of shuffled edges to hold out (passed to `train_test` together with edge_ids).
test_size_percent = 30
test_size_fraction = test_size_percent / 100
edge_test_size = int(test_size_fraction * len(edge_ids))


In [8]:
# Drop the trailing '_y' suffix (presumably left by a pandas merge) so the names match
# traj_data['genes']. Only a trailing suffix is removed: str.replace would also mangle
# names that merely contain "_y" in the middle.
top_genes = [gene[:-2] if gene.endswith('_y') else gene for gene in top_genes]

# Train RiTINI
- Saves plots comparing the ground-truth dynamics with the predicted dynamics.

In [9]:
# Index into traj_data['genes'] for each gene of top_genes, in *top_genes order*.
# The DataFrame built later labels the selected columns with `top_genes`, so the
# column order must follow top_genes, not the order genes appear in traj_data['genes'].
# For duplicated gene names, the first occurrence is used.
first_index = {}
for i, name in enumerate(traj_data['genes']):
    first_index.setdefault(name, i)

missing = [g for g in top_genes if g not in first_index]
if missing:
    raise KeyError(f"top_genes not found in traj_data['genes']: {missing}")
gene_subset_indices = np.array([first_index[g] for g in top_genes])

# Random permutation of the trajectory axis (all trajectories are kept, in shuffled order).
cell_subset_indices = np.random.choice(traj_data['trajectories'].shape[1], traj_data['trajectories'].shape[1], replace=False)

In [10]:
# Sanity check: number of gene names available in the decoded trajectories.
traj_data['genes'].shape

(21465,)

In [11]:
# Keep every 3rd timepoint, then the selected trajectories and genes.
time_stride = 3
trajs = traj_data['trajectories'][::time_stride][:, cell_subset_indices][:, :, gene_subset_indices]

# Flatten to 2-D (n_timepoints * n_trajectories, n_genes). Rows are ordered
# timepoint-major: row = t * n_trajectories + trajectory.
traj_f = trajs.reshape(-1, trajs.shape[2])

Here we construct the pseudotime and cell-type annotations used to train RiTINI, the dynamic graph model.

In [12]:
# One evenly spaced pseudotime in [0, 1] per retained timepoint.
n_timepoints = trajs.shape[0]
pseudotimes = np.linspace(0.0, 1.0, num=n_timepoints)

In [13]:
# Per-row metadata for traj_f. traj_f = trajs.reshape(-1, n_genes) is timepoint-major
# (row = t * n_trajectories + trajectory), so each pseudotime is *repeated* once per
# trajectory and the trajectory annotations are *tiled* once per timepoint.
# (The previous tile/repeat pairing assumed a trajectory-major order, which mislabels
# rows whenever more than one trajectory is kept.)
n_traj = trajs.shape[1]
pt_repeated = np.repeat(pseudotimes, n_traj)
annot_repeated = np.tile(traj_data['annotations'][cell_subset_indices], trajs.shape[0])

df = pd.DataFrame(traj_f, columns=top_genes, index=[f'cell_{i}' for i in range(traj_f.shape[0])])
df['pseudotime'] = pt_repeated
df['cell_types'] = [f'cell_type_{a}' for a in annot_repeated]
num_cell_types = df['cell_types'].nunique()

In [14]:

# Random 85% / 15% row split. No random_state is given, so the draw comes from
# NumPy's global RNG (seeded at the top of the notebook).
df_train = df.sample(frac=0.85)
df_test = df.drop(index=df_train.index)

In [15]:
# Sorted unique timepoints and cell types present in the full dataset.
time_bins = np.sort(df.pseudotime.unique())
cell_types = np.sort(df.cell_types.unique())

t0, *_, tn = time_bins
time_tensor = torch.Tensor(time_bins)  # kept on CPU; move with .to(device) if needed

# Number of rows at the first timepoint. This is an explicit comparison with t0,
# rather than `value_counts()[0]`, which only worked through an implicit label lookup of 0.0.
n_cells_at_t = int((df['pseudotime'] == t0).sum())

# Input and output feature widths: number of cell types x rows per timepoint.
in_feats = cell_types.size * n_cells_at_t
out_feats = cell_types.size * n_cells_at_t

# Standardise gene columns with statistics from the training split only, then apply the
# same transform to the held-out split so both splits are on the same scale.
scaler = StandardScaler().fit(df_train[top_genes])
df_train[top_genes] = scaler.transform(df_train[top_genes])
if len(df_test) > 0:  # StandardScaler.transform rejects zero-row input
    df_test = df_test.copy()
    df_test[top_genes] = scaler.transform(df_test[top_genes])


We now initialize the RiTINI module and define the training hyperparameters, which can be adjusted to suit the experiment.

# Setting up RiTINI

In [None]:
# Configuration dict passed to RiTINI.RiTINI(config=...) below.
config = {
    # Model configuration
    'model_type': 'gcn',  # or 'gat', 'gde'
    'in_feats': in_feats,  # cell_types.size *  n_cells_at_t
    'out_feats': out_feats,  # cell_types.size * n_cells_at_t
    # NOTE(review): hard-coded to CPU rather than the auto-detected `device`;
    # use `device` here to run on a GPU when one is available.
    'device': 'cpu',

    # Training parameters
    'epochs': 100,
    'learning_rate': 0.1,
    'weight_decay': 5e-4,
    'step_size': 350,
    'gamma': 0.1,

    # Training loop parameters
    'steps': 100,
    'verbose_step': 1,
    'link_step': 2,
    'add_n': 5,
    'del_n': 5,
    'sample_size': 10,
    'lambda_l1': 10,

    # Data parameters
    'n_cells_at_t': n_cells_at_t,
    'num_cell_types': num_cell_types,
    # Reuse the value that sized the edge hold-out split, so the two cannot drift apart.
    'test_size_percent': test_size_percent,

    # Paths
    # NOTE(review): relative to the working directory, whereas this notebook reads its
    # data from '../data/trajectories' — confirm which location RiTINI expects.
    'data_dir': 'data/trajectories',
    'output_path': 'outputs'
}

In [34]:
# Build the RiTINI wrapper from the `config` dict defined in the previous cell.
# NOTE(review): this import belongs in the top imports cell.
import RiTINI
model = RiTINI.RiTINI(config=config)

In [32]:
# Train the model
model.train(
    data_path="single_cell_data.csv",
    model_type="gat",
    epochs=100,
    learning_rate=0.001
)

2025-09-26 13:41:21,826 - RiTINI.ritini - INFO - Starting training with model: gat
2025-09-26 13:41:21,828 - RiTINI.ritini - INFO - Loading training data...
2025-09-26 13:41:21,829 - RiTINI.ritini - INFO - Initializing gat model...
2025-09-26 13:41:21,829 - RiTINI.ritini - INFO - Starting training...
2025-09-26 13:41:21,830 - RiTINI.ritini - INFO - Training completed successfully!


In [25]:
import RiTINI

# Fix the device *before* building the trainer, so the trainer and the model it owns agree.
# Previously the trainer received the auto-detected device and only the model was moved
# to CPU afterwards.
device = 'cpu'

# NOTE(review): with the currently installed package this positional call raised
# "TypeError: RiTINI.__init__() takes from 1 to 2 positional arguments but 5 were given";
# the class now appears to take a single `config` (as in the RiTINI.RiTINI(config=config)
# cell). TODO: port this cell to that API.
graph_trainer = RiTINI.RiTINI(G, in_feats, out_feats, device)
model = graph_trainer.model.to(device)

# Presumably holds out `edge_test_size` of the shuffled `edge_ids` and returns the
# training graph — confirm against RiTINI.train_test.
train_g = graph_trainer.train_test(edge_ids, edge_test_size)

TypeError: RiTINI.__init__() takes from 1 to 2 positional arguments but 5 were given

In [20]:
# Optimiser, LR schedule, loss and training-loop settings. Values are read from `config`
# so they cannot silently drift from the ones declared there.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['weight_decay'],
)
# Multiply the learning rate by `gamma` every `step_size` scheduler steps.
scheduler = StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])
criterion = torch.nn.MSELoss()

steps = config['steps']
verbose_step = config['verbose_step']

lambda_l1 = config['lambda_l1']
add_n = config['add_n']
del_n = config['del_n']
link_step = config['link_step']
sample_size = config['sample_size']

Load the model and train RiTINI, which generates plots of predicted vs. ground-truth gene expression dynamics and saves them in the Results/ folder.

In [21]:
# NOTE(review): `ref_g` was never defined in this notebook, so this cell only ran thanks to
# leftover kernel state (it raises NameError on a fresh kernel). It is assumed here to be
# the prior graph G, whose layout is `ref_pos` — confirm against RiTINI's train_loop.
ref_g = G

# Train and plot predicted vs. ground-truth dynamics.
# NOTE(review): DATA_DIR is an absolute Colab path ('/content'). Also, `lambda_l1` and
# `sample_size` from the hyperparameter cell are not passed here.
graph_trainer.train_loop(
        model, optimizer, scheduler, criterion, top_genes,
        train_g, df_train, n_cells_at_t, time_bins, steps, link_step, add_n, del_n,
        verbose_step, num_cell_types, cell_types, ref_pos, ref_g, DATA_DIR='/content')

Dynamics Prediction loss: 0.27714046835899353
Dynamics Prediction loss: 0.27679741382598877
Dynamics Prediction loss: 0.27669191360473633
Dynamics Prediction loss: 0.2766408920288086
Dynamics Prediction loss: 0.27660971879959106
Dynamics Prediction loss: 0.2765827476978302
Dynamics Prediction loss: 0.2765650153160095
Dynamics Prediction loss: 0.27655354142189026
Dynamics Prediction loss: 0.2765420079231262
Dynamics Prediction loss: 0.27653568983078003
Dynamics Prediction loss: 0.2765243649482727
Dynamics Prediction loss: 0.27652519941329956
Dynamics Prediction loss: 0.2765128016471863
Dynamics Prediction loss: 0.27651914954185486
Dynamics Prediction loss: 0.2765093445777893
Dynamics Prediction loss: 0.2765124440193176
Dynamics Prediction loss: 0.27650386095046997
Dynamics Prediction loss: 0.27650687098503113
Dynamics Prediction loss: 0.27649930119514465
Dynamics Prediction loss: 0.2765020728111267
Dynamics Prediction loss: 0.2764952778816223
Dynamics Prediction loss: 0.2764977812767029

In [22]:
# Debug introspection: list the attribute and method names of the trainer object.
# NOTE(review): leftover exploration cell — consider removing it before sharing.
graph_trainer.__dir__()

['g',
 'in_feats',
 'out_feats',
 'device',
 'model',
 '__module__',
 '__init__',
 '_build_model',
 'train_test',
 'train_loop',
 'test_loop',
 '__dict__',
 '__weakref__',
 '__doc__',
 '__new__',
 '__repr__',
 '__hash__',
 '__str__',
 '__getattribute__',
 '__setattr__',
 '__delattr__',
 '__lt__',
 '__le__',
 '__eq__',
 '__ne__',
 '__gt__',
 '__ge__',
 '__reduce_ex__',
 '__reduce__',
 '__getstate__',
 '__subclasshook__',
 '__init_subclass__',
 '__format__',
 '__sizeof__',
 '__dir__',
 '__class__']