In [35]:
import os
import rootutils

# Find the repository root by searching upward from the current working directory
# for the `.project-root` indicator; pythonpath=True also puts that root on
# sys.path, which is what lets the `src.*` imports below resolve. Must run first.
rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

from torchvision import transforms
import torch
from src.data.components.graphs_datamodules import (
    IMCBaseDictTransform,
    PickleDataset,
    PatchAugmentations,
)
from src.data.imc_datamodule import add_channel
import src.data.components.graphs_datamodules as gd
from torchvision.datasets import MNIST
from torch.utils.data import ConcatDataset, Dataset, random_split

In [36]:
from pathlib import Path

In [7]:
from torch.utils.data import Subset

from src.models.components.plot import restore_tensor
import matplotlib.pyplot as plt

# Loader settings for the cells below.
GRID_SIZE = 13  # used as both the augmentation size and the graph grid side; assumed equal
# NOTE(review): the forward pass at the end of this notebook fails with
# "input and weight.T shapes cannot be multiplied (169x10 and 1x512)": the batch
# carries 10 features per node while a linear layer of the checkpoint takes 1.
# Confirm which channel(s) the checkpoint was trained on and set CHANNELS to match.
CHANNELS = list(range(10))
BATCH_SIZE = 8
NUM_WORKERS = 7
SPLIT_SEED = 42

base_transforms = IMCBaseDictTransform()

aug_transforms_train = gd.PatchAugmentations(prob=1.0, size=GRID_SIZE, patch_size=1)
aug_transforms_val = gd.PatchAugmentations(
    prob=1.0, size=GRID_SIZE, patch_size=1, is_validation=True
)

dual_transforms_train = gd.DualOutputTransform(base_transforms, aug_transforms_train)
dual_transforms_val = gd.DualOutputTransform(base_transforms, aug_transforms_val)

data_dir = Path("../data") / "IMC-sample"
train_path = data_dir / "train.h5"
# NOTE(review): test_path is not used below — both datasets read train.h5 and
# val/test are carved out of it. Confirm this is intended.
test_path = data_dir / "test.h5"

# One file, two transform pipelines (training vs. validation augmentations).
trainset = PickleDataset(train_path, transform=dual_transforms_train)
testset = PickleDataset(train_path, transform=dual_transforms_val)

train_ratio, val_ratio, test_ratio, leftover_ratio = [3600, 1044, 0, 0]
size_testset = len(testset)
size_trainset = len(trainset)
if size_trainset < train_ratio:
    raise ValueError(
        f"train split needs {train_ratio} samples but the dataset has {size_trainset}"
    )
# testset is indexed with trainset's permutation below; that is only valid
# because both wrap the same file.
assert size_testset == size_trainset, "trainset and testset must wrap the same file"

# Draw ONE seeded permutation and cut consecutive, non-overlapping slices from it.
# Splitting trainset and testset independently with the same seed reuses the same
# permutation, so the first val_ratio shuffled samples would land in BOTH train and
# val (validation would be a subset of the training data). random_split draws its
# permutation with torch.randperm(len, generator=...) too, so data_train holds
# exactly the indices random_split(trainset, ..., seed SPLIT_SEED) would give.
# data_val / data_test are shorter if the file has fewer than
# train_ratio + val_ratio + test_ratio samples.
split_indices = torch.randperm(
    size_trainset, generator=torch.Generator().manual_seed(SPLIT_SEED)
).tolist()
val_end = train_ratio + val_ratio
data_train = Subset(trainset, split_indices[:train_ratio])
data_val = Subset(testset, split_indices[train_ratio:val_end])
data_test = Subset(testset, split_indices[val_end : val_end + test_ratio])

train_dataset = gd.GridGraphDataset(grid_size=GRID_SIZE, dataset=data_train, channels=CHANNELS)

train_loader = gd.DenseGraphDataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    # torch's DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=NUM_WORKERS > 0,
)

In [8]:
# Pull one batch to inspect. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving `el` undefined for the next cell.
el = next(iter(train_loader))
print(el)

<src.data.components.graphs_datamodules.DenseGraphBatch object at 0x1108be120>


In [9]:
# Shape of the batch's node-feature tensor (last dim presumably = number of channels — confirm).
el.node_features.size()

torch.Size([64, 169, 10])

In [1]:
# Imports for building and restoring PLGraphAE.
# Repeats the project-root setup from the first cell so the `src.*` imports below
# also resolve when this section is started from a fresh kernel.
import os
import rootutils

rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)

import torch
import lightning as L
from omegaconf import DictConfig, OmegaConf
from src.models.pigvae_auto_module import PLGraphAE
from src.models.components.modules import GraphAE
from src.models.components.schedulers import TemperatureScheduler, EntropyWeightScheduler
from src.models.components.model import Critic
from src.data.components.graphs_datamodules import DenseGraphBatch


In [28]:
import yaml

# Model config of the trained run (path relative to the notebook's working directory).
config_path = "../configs/model/model.yaml"

# Plain YAML load: any OmegaConf interpolations in the file are still unresolved here.
with open(config_path, "r") as f:
    model_config_raw = yaml.safe_load(f)

print("Converting configurations to DictConfig...")

from omegaconf import OmegaConf

# Custom resolvers used by the config's interpolations. register_new_resolver raises
# if a name is already registered, so skip any that a previous run of this cell added.
# NOTE(review): "divide" is floor division (x // y); confirm no interpolation
# expects a true (float) quotient.
custom_resolvers = {
    "multiply": lambda x, y: x * y,
    "divide": lambda x, y: x // y,
}
for resolver_name, resolver_fn in custom_resolvers.items():
    if not OmegaConf.has_resolver(resolver_name):
        OmegaConf.register_new_resolver(resolver_name, resolver_fn)

# The model YAML interpolates values from other config groups (e.g.
# data.hparams.num_aug_per_sample), so nest it under "model" next to stand-ins for
# those groups. Keep these values in line with the training run's config.
full_config = OmegaConf.create({
    "model": model_config_raw,
    "data": {
        "hparams": {
            "num_aug_per_sample": 8,
            "batch_size": 8,
        }
    },
    "trainer": {
        "max_epochs": 500,
        "min_epochs": 1,
    },
})

# Resolve every interpolation in place so the constructors receive plain values.
OmegaConf.resolve(full_config)

print("Creating GraphAE from resolved config...")
graph_ae = GraphAE(hparams=full_config.model.graph_ae.hparams)
print("GraphAE created successfully!")

print("Creating schedulers...")
temperature_scheduler = TemperatureScheduler(hparams=full_config.model.temperature_scheduler.hparams)
entropy_weight_scheduler = EntropyWeightScheduler(hparams=full_config.model.entropy_weight_scheduler.hparams)

print("Creating critic...")
critic = Critic(hparams=full_config.model.critic.hparams)

# PLGraphAE's constructor takes an optimizer and an LR scheduler; these are
# placeholders. The checkpoint cell restores only checkpoint['state_dict'] (module
# weights), so the optimizer state saved in the checkpoint is NOT loaded into them.
print("Creating optimizer (dummy for loading)...")
optimizer = torch.optim.Adam(graph_ae.parameters(), lr=0.0001)

print("Creating scheduler (dummy for loading)...")
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

print("All components created successfully!")


Converting configurations to DictConfig...
Creating GraphAE from resolved config...
GraphAE created successfully!
Creating schedulers...
Creating critic...
Creating optimizer (dummy for loading)...
Creating scheduler (dummy for loading)...
All components created successfully!


In [29]:
# Assemble the Lightning module from the components built in the previous cell.
# compile=False presumably disables torch.compile inside PLGraphAE — confirm there.
print("Creating PLGraphAE model...")
plgraphae_components = dict(
    graph_ae=graph_ae,
    critic=critic,
    temperature_scheduler=temperature_scheduler,
    entropy_weight_scheduler=entropy_weight_scheduler,
    optimizer=optimizer,
    scheduler=scheduler,
)
model = PLGraphAE(**plgraphae_components, compile=False)

print("PLGraphAE model created successfully!")


Creating PLGraphAE model...


/Users/tomasznocon/Documents/MIM/Repositories/Master thesis/immuvis/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'critic' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['critic'])`.
/Users/tomasznocon/Documents/MIM/Repositories/Master thesis/immuvis/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'temperature_scheduler' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['temperature_scheduler'])`.
/Users/tomasznocon/Documents/MIM/Repositories/Master thesis/immuvis/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'entropy_weight_scheduler' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparame

PLGraphAE model created successfully!


/Users/tomasznocon/Documents/MIM/Repositories/Master thesis/immuvis/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'graph_ae' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['graph_ae'])`.


In [31]:
# Restore the trained weights from the run's Lightning checkpoint.
checkpoint_path = "../logs/train/runs/2025-10-14_23-57-01/checkpoints/last.ckpt"

print(f"Loading checkpoint from: {checkpoint_path}")
print("This may take a moment...")

# weights_only=False: a Lightning checkpoint may hold non-tensor objects that the
# restricted weights_only loader would reject. This unpickles arbitrary objects,
# so only use it on checkpoints you produced yourself.
# map_location="cpu": the model has not been moved to a device yet and
# load_state_dict copies values into its existing parameters, so deserialising
# onto an accelerator first only adds a round trip (and fails on machines without
# MPS). The model is moved to its device in the next cell.
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
print("Checkpoint loaded successfully!")

# Strict load (the default): checkpoint keys must match the freshly built model.
model.load_state_dict(checkpoint["state_dict"])
print("Model weights loaded successfully!")

# Switch modules such as dropout / batch norm to evaluation behaviour.
model.eval()
print("Model is ready for inference!")


Loading checkpoint from: ../logs/train/runs/2025-10-14_23-57-01/checkpoints/last.ckpt
This may take a moment...
Checkpoint loaded successfully!
Model weights loaded successfully!
Model is ready for inference!


In [32]:
# Use whichever accelerator is available instead of assuming Apple MPS; fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
model = model.to(device)
print(f"Model moved to device: {device}")

# Summary of the restored model.
model_params = list(model.parameters())
print("\nModel loaded successfully!")
print(f"- Checkpoint epoch: {checkpoint.get('epoch', 'Unknown')}")
print(f"- Global step: {checkpoint.get('global_step', 'Unknown')}")
print(f"- Model device: {model_params[0].device}")
print(f"- Model parameters: {sum(p.numel() for p in model_params):,}")
print(f"- Trainable parameters: {sum(p.numel() for p in model_params if p.requires_grad):,}")


Model moved to device: mps

Model loaded successfully!
- Checkpoint epoch: 499
- Global step: 510
- Model device: mps:0
- Model parameters: 18,023,049
- Trainable parameters: 18,022,921


In [38]:
# NOTE(review): this cell duplicates the earlier data-setup cell; keep a single copy.
from torch.utils.data import Subset

from src.models.components.plot import restore_tensor
import matplotlib.pyplot as plt

# Loader settings for the cells below.
GRID_SIZE = 13  # used as both the augmentation size and the graph grid side; assumed equal
# NOTE(review): the forward pass below fails with
# "input and weight.T shapes cannot be multiplied (169x10 and 1x512)": the batch
# carries 10 features per node while a linear layer of the checkpoint takes 1.
# Confirm which channel(s) the checkpoint was trained on and set CHANNELS to match.
CHANNELS = list(range(10))
BATCH_SIZE = 8
NUM_WORKERS = 7
SPLIT_SEED = 42

base_transforms = IMCBaseDictTransform()

aug_transforms_train = gd.PatchAugmentations(prob=1.0, size=GRID_SIZE, patch_size=1)
aug_transforms_val = gd.PatchAugmentations(
    prob=1.0, size=GRID_SIZE, patch_size=1, is_validation=True
)

dual_transforms_train = gd.DualOutputTransform(base_transforms, aug_transforms_train)
dual_transforms_val = gd.DualOutputTransform(base_transforms, aug_transforms_val)

data_dir = Path("../data") / "IMC-sample"
train_path = data_dir / "train.h5"
# NOTE(review): test_path is not used below — both datasets read train.h5 and
# val/test are carved out of it. Confirm this is intended.
test_path = data_dir / "test.h5"

# One file, two transform pipelines (training vs. validation augmentations).
trainset = PickleDataset(train_path, transform=dual_transforms_train)
testset = PickleDataset(train_path, transform=dual_transforms_val)

train_ratio, val_ratio, test_ratio, leftover_ratio = [3600, 1044, 0, 0]
size_testset = len(testset)
size_trainset = len(trainset)
if size_trainset < train_ratio:
    raise ValueError(
        f"train split needs {train_ratio} samples but the dataset has {size_trainset}"
    )
# testset is indexed with trainset's permutation below; that is only valid
# because both wrap the same file.
assert size_testset == size_trainset, "trainset and testset must wrap the same file"

# Draw ONE seeded permutation and cut consecutive, non-overlapping slices from it.
# Splitting trainset and testset independently with the same seed reuses the same
# permutation, so the first val_ratio shuffled samples would land in BOTH train and
# val (validation would be a subset of the training data). random_split draws its
# permutation with torch.randperm(len, generator=...) too, so data_train holds
# exactly the indices random_split(trainset, ..., seed SPLIT_SEED) would give.
# data_val / data_test are shorter if the file has fewer than
# train_ratio + val_ratio + test_ratio samples.
split_indices = torch.randperm(
    size_trainset, generator=torch.Generator().manual_seed(SPLIT_SEED)
).tolist()
val_end = train_ratio + val_ratio
data_train = Subset(trainset, split_indices[:train_ratio])
data_val = Subset(testset, split_indices[train_ratio:val_end])
data_test = Subset(testset, split_indices[val_end : val_end + test_ratio])

train_dataset = gd.GridGraphDataset(grid_size=GRID_SIZE, dataset=data_train, channels=CHANNELS)

train_loader = gd.DenseGraphDataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=False,
    # torch's DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=NUM_WORKERS > 0,
)

In [40]:
# Fresh batch from the rebuilt loader. Unlike `for ...: break`, next() raises on
# an empty loader rather than silently keeping a stale `el` from an earlier cell.
el = next(iter(train_loader))

In [41]:
# Put the batch on the model's device (set in the device cell) instead of hard-coding "mps".
el = el.to(device)

In [42]:
# Single forward pass without autograd tracking.
# NOTE(review): this currently raises "input and weight.T shapes cannot be
# multiplied (169x10 and 1x512)": the batch has 10 features per node but a linear
# layer of the checkpoint expects 1. Align the loader's channels with what the
# checkpoint was trained on before relying on these outputs.
model.eval()
with torch.no_grad():
    forward_outputs = model(el, training=False, tau=1.0)
graph_emb, graph_pred, soft_probs, perm, mu, logvar = forward_outputs

RuntimeError: linear(): input and weight.T shapes cannot be multiplied (169x10 and 1x512)