In [1]:
from chemCPA.notebooks_utils import load_dataset, load_smiles
from chemCPA.embedding import get_chemical_representation
from chemCPA.model import ComPert
import yaml
import torch

In [2]:
def load_model(config, canon_smiles_unique_sorted, checkpoint_path):
    """Rebuild a ``ComPert`` model on CPU from a saved checkpoint.

    Args:
        config: Nested experiment configuration. This function reads
            ``config["model"]["embedding"]["model"]``,
            ``config["model"]["embedding"]["directory"]`` (only for a
            non-"vanilla" embedding), ``config["model"]["append_ae_layer"]``,
            ``config["model"]["load_pretrained"]`` and, when both of those
            flags are truthy, ``config["dataset"]["n_vars"]``.
        canon_smiles_unique_sorted: SMILES collection forwarded as ``smiles``
            to ``get_chemical_representation``.
        checkpoint_path: Location of the checkpoint, forwarded to
            ``torch.load``.

    Returns:
        Tuple ``(model, embedding)``: the ``ComPert`` instance in eval mode and
        the drug embedding that was passed to it (``None`` for "vanilla").

    Raises:
        ValueError: If the checkpoint does not hold 3, 4 or 5 entries.
        AssertionError: If the state dict does not line up with the model as
            expected (see the checks at the end of the function).
    """
    embedding_model = config["model"]["embedding"]["model"]
    is_vanilla = embedding_model == "vanilla"
    if is_vanilla:
        embedding = None
    else:
        embedding = get_chemical_representation(
            smiles=canon_smiles_unique_sorted,
            embedding_model=embedding_model,
            data_dir=config["model"]["embedding"]["directory"],
            device="cpu",
        )

    # NOTE(security): torch.load unpickles the file, so only open checkpoints
    # that come from a trusted source.
    dumped_model = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # The checkpoint layout is identified purely by its number of entries.
    cov_emb_state_dicts = None
    cov_emb_available = False
    n_entries = len(dumped_model)
    if n_entries == 3:
        print("This model does not contain the covariate embeddings or adversaries.")
        state_dict, init_args, _history = dumped_model
    elif n_entries == 4:
        print("This model does not contain the covariate embeddings.")
        state_dict, _cov_adv_state_dicts, init_args, _history = dumped_model
    elif n_entries == 5:
        (
            state_dict,
            _cov_adv_state_dicts,
            cov_emb_state_dicts,
            init_args,
            _history,
        ) = dumped_model
        cov_emb_available = True
        assert len(cov_emb_state_dicts) == 1
    else:
        # Without this branch the names above stay unbound and the function
        # fails later with an unrelated-looking UnboundLocalError.
        raise ValueError(
            f"Unsupported checkpoint layout: expected 3, 4 or 5 entries, got {n_entries}."
        )

    append_layer_width = (
        config["dataset"]["n_vars"]
        if (config["model"]["append_ae_layer"] and config["model"]["load_pretrained"])
        else None
    )

    if not is_vanilla:
        # Drop the stored embedding so the freshly computed one is kept. The
        # default avoids a KeyError when the checkpoint never stored it; the
        # missing-key assertion below accepts that case.
        state_dict.pop("drug_embeddings.weight", None)

    model = ComPert(
        **init_args, drug_embeddings=embedding, append_layer_width=append_layer_width, device="cpu"
    )
    model = model.eval()

    if cov_emb_available:
        for embedding_cov, state_dict_cov in zip(
            model.covariates_embeddings, cov_emb_state_dicts
        ):
            embedding_cov.load_state_dict(state_dict_cov)

    incomp_keys = model.load_state_dict(state_dict, strict=False)
    if is_vanilla:
        assert (
            len(incomp_keys.unexpected_keys) == 0 and len(incomp_keys.missing_keys) == 0
        )
    else:
        # Make sure we didn't accidentally load the embedding from the state_dict.
        # torch.testing.assert_allclose is deprecated; prefer assert_close and
        # only fall back to the old name on PyTorch builds that lack it.
        assert_close = getattr(torch.testing, "assert_close", None)
        if assert_close is not None:
            # check_dtype=False / equal_nan=True mirror assert_allclose's defaults.
            assert_close(
                model.drug_embeddings.weight,
                embedding.weight,
                check_dtype=False,
                equal_nan=True,
            )
        else:
            torch.testing.assert_allclose(model.drug_embeddings.weight, embedding.weight)
        assert (
            len(incomp_keys.missing_keys) == 1
            and "drug_embeddings.weight" in incomp_keys.missing_keys
        ), incomp_keys.missing_keys
        # Unexpected keys are intentionally not asserted on here; the original
        # check was disabled (reason not recorded in this notebook).

    return model, embedding

In [3]:
# Location of the experiment configuration. This is an absolute, machine-specific
# path: change this single constant to point at your own checkout.
CONFIG_PATH = "/u/adr/Code/enlight/enlight/OT_analysis/chemCPA/configs/chemCPA_on_cmonge_sciplex_9drugs_ood.yaml"

# Read as UTF-8 explicitly so parsing does not depend on the platform's default
# locale encoding; safe_load restricts the YAML to plain data types.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [4]:
# Load the dataset described by the config, then derive the SMILES lookups from it.
dataset, key_dict = load_dataset(config)
(
    canon_smiles_unique_sorted,
    smiles_to_pathway_map,
    smiles_to_drug_map,
) = load_smiles(config, dataset, key_dict, True)

In [None]:
# NOTE(review): the checkpoint path looks like a placeholder ("/path/to/...");
# point it at a real checkpoint before running this cell.
model_pretrained, embedding_pretrained = load_model(
    config,
    canon_smiles_unique_sorted,
    checkpoint_path="/path/to/checkpint/chemCPA/chemCPA_on_cmonge_sciplex/9drugs_ood/model_checkpoint.pt",
)

[32m2025-02-21 10:51:04.435[0m | [1mINFO    [0m | [36mchemCPA.model[0m:[36m__init__[0m:[36m327[0m - [1m{'adversary_depth': 2, 'adversary_lr': 0.0011926173789223548, 'adversary_steps': 2, 'adversary_wd': 9.846738873614555e-06, 'adversary_width': 128, 'autoencoder_depth': 4, 'autoencoder_lr': 6.251373574521742e-07, 'autoencoder_wd': 1.5702970884055366e-06, 'autoencoder_width': 256, 'batch_size': 32, 'dim': 32, 'dosers_depth': 3, 'dosers_lr': 0.0015751320499779737, 'dosers_wd': 6.251373574521742e-07, 'dosers_width': 64, 'dropout': 0.262378, 'embedding_encoder_depth': 4, 'embedding_encoder_width': 128, 'penalty_adversary': 0.4550475813202185, 'reg_adversary': 9.100951626404369, 'reg_adversary_cov': 26.975154833351137, 'reg_multi_task': 10.675229346653245, 'step_size_lr': 50} is a dict[0m
[32m2025-02-21 10:51:04.440[0m | [1mINFO    [0m | [36mchemCPA.model[0m:[36m__init__[0m:[36m354[0m - [1mEncoder: MLP(
  (network): Sequential(
    (0): Linear(in_features=1000, out_fe

In [6]:
# Show the module tree of the loaded model (cell output is its repr).
model_pretrained

ComPert(
  (encoder): MLP(
    (network): Sequential(
      (0): Linear(in_features=1000, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): ReLU()
      (9): Linear(in_features=256, out_features=256, bias=True)
      (10): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (11): ReLU()
      (12): Linear(in_features=256, out_features=32, bias=True)
    )
  )
  (decoder): MLP(
    (relu): ReLU()
    (network): Sequential(
      (0): Linear(in_features=32, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentu

In [7]:
# Total number of scalar entries across every parameter tensor of the model.
pytorch_total_params = sum(
    weight.numel() for weight in model_pretrained.parameters()
)

In [8]:
# Display the parameter count computed in the previous cell.
pytorch_total_params

1368709

In [9]:
# NOTE(review): 1374742 is presumably a parameter count taken from another
# model or run; its source is not recorded in this notebook. TODO confirm.
REFERENCE_PARAM_COUNT = 1374742
REFERENCE_PARAM_COUNT - pytorch_total_params

6033