# ▲ This is an example of generating single-cell data conditioned on EHR embeddings

In [1]:
import os,sys
root = os.path.dirname(os.getcwd())
sys.path.insert(0, root)
import yaml
import numpy as np
import torch
from src.diffusion_src.models.gaussian_diffusion import GaussianDiffusion
from src.diffusion_src.models.scAttNet import SingleCellAN

## Generating data by denoising process

In [2]:
def sample_cells_chunked(model, cd_dict, gaussian_diffusion, device,
                         num_cells_total, cell_num_per_sample,
                         feature_num, output_dir):
    """
    Generate synthetic single-cell measurements in chunks for each donor,
    saving each donor's array to ``<output_dir>/<donor_id>.npy``.

    For every donor the sampler is called repeatedly until at least
    ``num_cells_total`` cells have been produced; the surplus from the last
    chunk is discarded before saving.

    Args:
        model: Trained SingleCellAN model; it is put into eval mode here.
        cd_dict: Mapping donor_id -> precomputed EHR embedding tensor.
        gaussian_diffusion: Sampler object exposing
            ``sample(model=, batch_size=, cell_num=, dims=, cd=)``.
        device: torch.device the conditioning tensors are moved to.
        num_cells_total: Total number of cells to keep per donor.
        cell_num_per_sample: Cells requested per call to ``sample()``.
        feature_num: Dimensionality of each cell feature vector.
        output_dir: Directory in which to save .npy files (created if missing).

    Raises:
        ValueError: If cells are requested but ``cell_num_per_sample`` is not
            positive (the sampling loop could never make progress).
        RuntimeError: If the sampler returns a chunk containing no cells.
    """
    # Fail fast instead of spinning forever in the sampling loop below.
    if num_cells_total > 0 and cell_num_per_sample <= 0:
        raise ValueError(
            "cell_num_per_sample must be positive when num_cells_total > 0, "
            f"got {cell_num_per_sample}"
        )

    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    print("Genaration process Begin")
    with torch.no_grad():
        # Dict keys are unique, so every donor is visited exactly once.
        for did, donor_emb in cd_dict.items():
            # Conditioning pair handed to the sampler: the embedding with an
            # extra axis at position 1, plus an all-True mask of matching length.
            ehr_emb = donor_emb.unsqueeze(1).to(device)
            mask = torch.ones(ehr_emb.size(0), 1, dtype=torch.bool, device=device)

            chunks = []
            cells_so_far = 0
            while cells_so_far < num_cells_total:
                gen = gaussian_diffusion.sample(
                    model=model,
                    batch_size=1,
                    cell_num=cell_num_per_sample,
                    dims=feature_num,
                    cd=(ehr_emb, mask)
                )
                # Assumes the sampler returns (1, cell_num, dims) - TODO confirm.
                # Keep the chunk as a float32 ndarray rather than nested Python
                # lists, which are far slower and heavier for large cell counts.
                chunk = (
                    gen.squeeze(0)
                    .detach()
                    .to(device="cpu", dtype=torch.float32)
                    .numpy()
                )
                if chunk.shape[0] == 0:
                    raise RuntimeError(
                        f"sampler returned no cells for donor {did!r}"
                    )
                chunks.append(chunk)
                cells_so_far += chunk.shape[0]

            if chunks:
                # Drop the overshoot from the final chunk.
                arr = np.concatenate(chunks, axis=0)[:num_cells_total]
            else:
                # Nothing was requested: save an empty array, as before.
                arr = np.array([], dtype=np.float32)
            np.save(os.path.join(output_dir, f"{did}.npy"), arr)
            print(f"generated {did}")
        print("Genaration process Finish!")

## Load the EHR condition embeddings from contrastive pre-training

In [3]:
# Load the donor_id -> EHR embedding mapping from disk.
cd_dict = torch.load("../data/cd_dict.pt")
# All embeddings are taken to share one width, so read it off the first entry.
cd_dict_v = list(cd_dict.values())
emb_dim = cd_dict_v[0].shape[1]

  cd_dict = torch.load("../data/cd_dict.pt")


## Load the trained model

In [4]:
def load_cfg(path: str) -> dict:
    """
    Read a YAML configuration file and return its parsed content.
    Args:
        path: Location of the YAML file; it is opened as UTF-8 text.
    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def strip_ddp_prefix(state_dict: dict) -> dict:
    """
    Return a copy of ``state_dict`` whose keys no longer carry the leading
    'module.' that DDP wrapping adds, so the weights fit a plain nn.Module.
    Args:
        state_dict: The raw state_dict, possibly with 'module.' prefixes.
    Returns:
        A new dict with the same values; only one leading 'module.' is
        removed per key, and keys without it are copied unchanged.
    """
    prefix = "module."
    cut = len(prefix)
    return {
        (name[cut:] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }

# Read the diffusion configuration and pick the compute device.
config_path = os.path.join("../configs", "diffusion.yaml")
cfg = load_cfg(config_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network from the "model" section of the config; the EHR
# embedding width comes from the loaded condition embeddings (emb_dim).
model_cfg = cfg["model"]
diff_model = SingleCellAN(
    feature_dims=model_cfg["feature_dims"],
    EHR_embdims=emb_dim,
    model_dims=model_cfg["model_dims"],
    dims_mult=tuple(model_cfg["dims_mult"]),
    num_res_blocks=model_cfg["num_res_blocks"],
    attention_resolutions=tuple(model_cfg["attention_resolutions"]),
    dropout=model_cfg["dropout"],
    dropoutAtt=model_cfg["dropout_att"],
    num_heads=model_cfg["num_heads"],
)
diff_model = diff_model.to(device)

# Restore the trained weights, dropping any 'module.' key prefixes first.
ckpt_path = os.path.join("../checkpoints/diffusion_ckpt", "best_diff_model.pth")
raw = torch.load(ckpt_path, map_location=device)
diff_model.load_state_dict(strip_ddp_prefix(raw))
diff_model.eval()

  raw = torch.load(os.path.join("../checkpoints/diffusion_ckpt", "best_diff_model.pth"),


SingleCellAN(
  (proteinEmb): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (InitEmb): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
  (down_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
    )
    (1-2): 2 x TimestepEmbedSequential(
      (0): ResidualBlock(
        (Linear1): Sequential(
          (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (1): SiLU()
          (2): Linear(in_features=128, out_features=128, bias=True)
        )
        (time_emb): Sequential(
          (0): SiLU()
          (1): Linear(in_features=512, out_features=128, bias=True)


## Generation process (the generated data files are written to `output_dir`)

In [5]:
# Sampler with its default settings.
gd = GaussianDiffusion()

# Generation settings live under the "evaluation" section of the config.
eval_cfg = cfg["evaluation"]
sample_cells_chunked(
    diff_model,
    cd_dict,
    gd,
    device,
    num_cells_total=eval_cfg["num_cells_total"],
    cell_num_per_sample=eval_cfg["cell_num_per_sample"],
    feature_num=cfg["model"]["feature_dims"],
    output_dir=eval_cfg["sample_dir"],
)

Genaration process Begin
generated HPAP-050
generated HPAP-129
generated HPAP-049
generated HPAP-044
generated HPAP-131
generated HPAP-072
generated HPAP-146
generated HPAP-045
generated HPAP-056
generated HPAP-139
generated HPAP-047
generated HPAP-114
generated HPAP-135
generated HPAP-140
generated HPAP-138
generated HPAP-087
generated HPAP-130
generated HPAP-107
generated HPAP-055
generated HPAP-136
generated HPAP-137
generated HPAP-113
generated HPAP-064
generated HPAP-132
generated HPAP-122
generated HPAP-092
generated HPAP-123
generated HPAP-043
Genaration process Finish!
