In [1]:
import os
import sys

import torch
import torch.optim as optim
import os.path as osp

sys.path.append(os.path.abspath(".."))

In [2]:
from CellGraphX.configs.config import Config

In [3]:
from CellGraphX.training.trainer import Trainer

In [4]:
from CellGraphX.configs.config import Config
from CellGraphX.data.data_loader import data_loader, edge_weight_dict_loader
from CellGraphX.evaluation.tsne_vis import TSNE_vis
from CellGraphX.models.model import model_builder
from CellGraphX.training.optimizer import build_optimizer

  from scipy.sparse import csr_matrix, issparse


In [5]:
# Load configuration
print("Loading configuration...", flush=True)

config = Config()
# Initialize the model
print("Initializing model...", flush=True)
model = model_builder(config=config.model)

Loading configuration...
Initializing model...


In [6]:
model

HeteroGNN(
  (convs): ModuleList(
    (0-1): 2 x HeteroConv(num_relations=4)
  )
  (lin): Linear(64, 24, bias=True)
)

In [7]:
config

<CellGraphX.configs.config.Config at 0x7f44b13e7230>

In [8]:
data_loader

<function CellGraphX.data.data_loader.data_loader(config)>

In [9]:
# Prepare the dataset and dataloaders
print("Loading data...")

Loading data...


In [10]:
data = data_loader(config=config.data)

Loading data from /nfs/research/irene/ysong/RESULTS/GeneCellType/GeneSpectra/MTG/CellGraphX/CellGraphX/data/mtg_all_sp_wilcox_data_with_og_ct_name.pt ...


  data = torch.load(data_path)


Split train val and test species
This is the HeteroData you are working with:
HeteroData(
  species_origin_index={
    0='H.sapiens',
    1='H.sapiens',
    2='H.sapiens',
    3='H.sapiens',
    4='H.sapiens',
    5='H.sapiens',
    6='H.sapiens',
    7='H.sapiens',
    8='H.sapiens',
    9='H.sapiens',
    10='H.sapiens',
    11='H.sapiens',
    12='H.sapiens',
    13='H.sapiens',
    14='H.sapiens',
    15='H.sapiens',
    16='H.sapiens',
    17='H.sapiens',
    18='H.sapiens',
    19='H.sapiens',
    20='H.sapiens',
    21='H.sapiens',
    22='H.sapiens',
    23='H.sapiens',
    24='M.mulatta',
    25='M.mulatta',
    26='M.mulatta',
    27='M.mulatta',
    28='M.mulatta',
    29='M.mulatta',
    30='M.mulatta',
    31='M.mulatta',
    32='M.mulatta',
    33='M.mulatta',
    34='M.mulatta',
    35='M.mulatta',
    36='M.mulatta',
    37='M.mulatta',
    38='M.mulatta',
    39='M.mulatta',
    40='M.mulatta',
    41='M.mulatta',
    42='M.mulatta',
    43='M.mulatta',
    44='M.mulat

In [11]:
data['cell_type'].val_mask

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])

In [12]:
edge_weight_dict = edge_weight_dict_loader(data) if config.model.edge_weight else None

# Optimizer and learning rate scheduler
optimizer = build_optimizer(model, config.training)

# Scheduler (optional, for learning rate decay)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Initialize the trainer with the model, data, optimizer, and config

In [13]:
edge_weight_dict

{('gene',
  'is_wilcox_marker_of',
  'cell_type'): tensor([4.0142, 4.0404, 4.0512,  ..., 6.7956, 7.1142, 7.1893]),
 ('cell_type',
  'rev_is_wilcox_marker_of',
  'gene'): tensor([4.0142, 4.0404, 4.0512,  ..., 6.7956, 7.1142, 7.1893])}

In [14]:
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 0.001
    lr: 0.001
    maximize: False
    weight_decay: 1e-05
)

In [15]:
scheduler

<torch.optim.lr_scheduler.StepLR at 0x7f44b18ba7e0>

In [16]:
print("Initializing trainer...", flush=True)
trainer = Trainer(
    model=model,
    data=data,
    scheduler=scheduler,
    config=config,
    edge_weight_dict=edge_weight_dict,
    # optimizer, loss func, logdir build from config
)

Initializing trainer...


In [17]:
trainer

<CellGraphX.training.trainer.Trainer at 0x7f45a675b290>

In [None]:
# Start the training
print("Starting training...", flush=True)
trainer.train(epochs=config.training.num_epochs)

print("Training completed with model saved.", flush=True)

Starting training...


2025-08-27 14:26:45,328 - INFO - Epoch 1/10
2025-08-27 14:27:28,208 - INFO - Train Loss: 118.9376 | Train Acc: 0.0417 | Val Acc: 0.0417 | Test Acc: 0.0417
2025-08-27 14:27:28,738 - INFO - Checkpoint saved to ./logs/best.pt
2025-08-27 14:27:29,253 - INFO - Checkpoint saved to ./logs/latest.pt
2025-08-27 14:27:29,255 - INFO - Epoch 2/10
2025-08-27 14:28:11,654 - INFO - Train Loss: 62.0057 | Train Acc: 0.0833 | Val Acc: 0.0417 | Test Acc: 0.0417
2025-08-27 14:28:12,124 - INFO - Checkpoint saved to ./logs/latest.pt
2025-08-27 14:28:12,126 - INFO - Epoch 3/10
2025-08-27 14:28:54,666 - INFO - Train Loss: 41.9153 | Train Acc: 0.1111 | Val Acc: 0.0417 | Test Acc: 0.0833
2025-08-27 14:28:55,185 - INFO - Checkpoint saved to ./logs/latest.pt
2025-08-27 14:28:55,186 - INFO - Epoch 4/10
2025-08-27 14:29:38,252 - INFO - Train Loss: 29.0866 | Train Acc: 0.2083 | Val Acc: 0.1667 | Test Acc: 0.1250
2025-08-27 14:29:38,733 - INFO - Checkpoint saved to ./logs/best.pt
2025-08-27 14:29:39,246 - INFO - Chec

In [None]:
# t-SNE visualization
vis = TSNE_vis(model_path=f"{trainer.log_dir}/best.pt", data=data)

In [None]:
vis.model_path

In [None]:
config = Config()
print("Initializing model...", flush=True)
model = model_builder(config=config.model)

In [None]:
model

In [None]:
print("Loading trainned model state dict...", flush=True)
if osp.isfile(vis.model_path):
    print(f"Loading model from {vis.model_path}")
else:
    raise FileNotFoundError(f"No model file found at {vis.model_path}")

In [None]:
model.load_state_dict(
    torch.load(vis.model_path, map_location=torch.device("cpu"))["model_state_dict"]
)

In [None]:
model.convs[0].parameters()

In [None]:
vis.plot_tsne(
        vis.get_tsne_embed(vis.load_model()),
        out_path=f"{trainer.log_dir}/tsne_plot.png",
    )