In [2]:
import torch
import pathlib
from models.graph_learning import HiPoNet, MLPAutoEncoder
from utils.read_data import load_data

In [3]:
# ID of the (SLURM) job whose output directory holds the split indices and
# model checkpoints loaded later in this notebook.
slurm_job_id = "1779231"
weights_loc = pathlib.Path("model_weights") / slurm_job_id

In [4]:
# Inference device. Modules and probe inputs are placed on it, and it is also
# forwarded to HiPoNet's own `device` argument so the two can never disagree.
device = "cpu"

# NOTE(review): the meaning of load_data's second argument ("") is not visible here — confirm.
PC_gene, PC_spatial = load_data("data/sea", "")

# HiPoNet hyperparameters shared by both encoders. They must reproduce the
# architecture of the checkpoints loaded from `weights_loc` later on, because
# load_state_dict is strict by default.
HIPONET_KWARGS = dict(
    n_weights=1,
    threshold=0.15,
    K=1,
    J=3,
    sigma=10,
    pooling=False,
)


def _build_hiponet(dimension: int) -> HiPoNet:
    """Build a float32 HiPoNet on `device` for inputs with `dimension` columns per row."""
    return (
        HiPoNet(dimension=dimension, device=device, **HIPONET_KWARGS)
        .to(device)
        .float()
    )


def _probe_output_width(model, point_cloud, n_points: int = 5) -> int:
    """Run `model` on a one-sample batch of the first `n_points` rows; return output.shape[1]."""
    sample = point_cloud[:n_points].unsqueeze(0).to(device)
    # Size the mask from the actual sample so clouds with < n_points rows still line up.
    # NOTE(review): an all-False mask presumably means "no masked/padded points" — confirm.
    mask = torch.zeros((1, sample.shape[1]), dtype=torch.bool, device=device)
    return model(sample, mask).shape[1]


model_spatial = _build_hiponet(PC_spatial[0].shape[1])
model_gene = _build_hiponet(PC_gene[0].shape[1])

# Dry-run both encoders on a tiny sample to size the autoencoder input: its width
# is the sum of the two encoders' output widths (presumably they are concatenated
# downstream — confirm).
with torch.no_grad():
    input_dim = _probe_output_width(model_spatial, PC_spatial[0]) + _probe_output_width(
        model_gene, PC_gene[0]
    )

# NOTE(review): positional args (256, 4, 3) are MLPAutoEncoder-specific — presumably
# hidden width / latent size / depth; confirm against models.graph_learning.
autoencoder = MLPAutoEncoder(input_dim, 256, 4, 3, bn=False).to(device)

In [9]:
# Restore the saved train/test split and trained weights. map_location is passed
# to every torch.load so tensors saved on another device (e.g. a GPU training
# node) deserialize onto `device` instead of failing on a CPU-only machine.
# torch.load unpickles its input — only load checkpoints from trusted runs.
train_test_split = torch.load(weights_loc / "split_idx.pt", map_location=device)

# (module, checkpoint file) pairs restored in the same order as before.
checkpoint_targets = (
    (autoencoder, "autoenc.pt"),
    (model_gene, "model_gene.pt"),
    (model_spatial, "model_spatial.pt"),
)

# load_state_dict is strict by default, so missing/unexpected keys raise. The
# tuple of results is the cell's last expression so each match report is displayed.
tuple(
    module.load_state_dict(torch.load(weights_loc / filename, map_location=device))
    for module, filename in checkpoint_targets
)

(<All keys matched successfully>,
 <All keys matched successfully>,
 <All keys matched successfully>)