In [1]:
import torch
import pandas as pd
import numpy as np
import scanpy as sc
import sys
sys.path.append("/path/to/wave")
from model import WAVE  
from utils import morgan_fp  

In [2]:
class WAVEPredictor:
    def __init__(self, model_path, device="cpu"):
        self.device = device
        self.model = self._load_model(model_path, device)

    def _load_model(self, model_path, device):
        """Load the pretrained WAVE model."""
        model = WAVE().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        return model

    def predict(self, unpert_expr, smiles_list):
        self.model.eval()

        with torch.no_grad():
            # Convert inputs to tensors
            unpert_expr_tensor = torch.tensor(unpert_expr, dtype=torch.float32).to(self.device)
            drug_fps = [torch.tensor(morgan_fp(s), dtype=torch.float32) for s in smiles_list]
            drug_fps_tensor = torch.stack(drug_fps).to(self.device)

            # Forward pass
            final_expr, _, _, delta_expr = self.model(unpert_expr_tensor, drug_fps_tensor)

            # Convert to numpy
            final_expr = final_expr.cpu().numpy()
            delta_expr = delta_expr.cpu().numpy()

        # Create DataFrames
        genes = [f"Gene_{i+1}" for i in range(final_expr.shape[1])]
        df_final = pd.DataFrame(final_expr, columns=genes)
        df_delta = pd.DataFrame(delta_expr, columns=genes)

        return df_final, df_delta

In [3]:
predictor = WAVEPredictor(model_path="/path/to/best_model.pth", device="cpu")

In [4]:
adata = sc.read_h5ad("test_expand_cp.h5ad")

In [8]:
adata = adata[adata.obs['smiles'] != 'restricted']

In [9]:
unpert_expr = adata.layers['unpert_expr']
smiles_list = adata.obs['smiles'].tolist()

In [10]:
final_expr, delta_expr = predictor.predict(unpert_expr, smiles_list)

In [11]:
adata.layers['pred'] = final_expr

  adata.layers['pred'] = final_expr


In [12]:
adata.write_h5ad("test_expand_cp_pred.h5ad")