In [1]:
import torch
import pandas as pd
import numpy as np
from wave.model import WAVE
from wave.utils import morgan_fp
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

def load_wave_model(wave_model_path, genevae_model_path, device, **kwargs):
    """Construct a WAVE network on *device*, restore its weights, and return it.

    Args:
        wave_model_path: Checkpoint file read with ``torch.load``; its contents
            are handed to ``WAVE.load_state_dict``.
        genevae_model_path: Accepted for call compatibility; not read here.
        device: Passed both to ``.to()`` and as ``map_location`` so the stored
            tensors are placed on the same device as the network.
        **kwargs: Accepted for call compatibility; ignored.

    Returns:
        The ``WAVE`` instance with loaded weights, switched to eval mode.
    """
    # NOTE(review): genevae_model_path and kwargs are never used in this
    # function -- confirm whether the gene VAE weights are meant to be loaded.
    net = WAVE().to(device)
    # torch.load unpickles the file, so only point it at trusted checkpoints.
    checkpoint = torch.load(wave_model_path, map_location=device)
    net.load_state_dict(checkpoint)
    net.eval()
    return net

def predict_expression(unpert_expr, smiles_list, model, device="cpu"):
    """Predict expression profiles for paired (expression, drug) inputs.

    Row ``i`` of *unpert_expr* is paired with ``smiles_list[i]``. All pairs are
    sent to *model* as a single batch.

    Args:
        unpert_expr: Unperturbed expression values. Anything ``np.asarray`` can
            turn into a float array (ndarray, nested list, pandas DataFrame), or
            a ``torch.Tensor``. When 2-D it must have one row per SMILES string.
        smiles_list: Iterable of SMILES strings; each is converted to a
            fingerprint with ``morgan_fp``.
        model: Module called with a dict holding ``"unpert_expr"`` and
            ``"drug_fp"`` tensors. It is moved to *device* and put into eval
            mode, which mutates the caller's model object.
        device: Device used for the model and both input tensors.

    Returns:
        A DataFrame of the model output with columns ``Gene_1`` ... ``Gene_k``,
        where ``k`` is the size of the output's second dimension.

    Raises:
        ValueError: If a 2-D *unpert_expr* does not have exactly one row per
            entry in *smiles_list*.
    """
    if isinstance(unpert_expr, torch.Tensor):
        # copy=True matches torch.tensor(), which always copied its input.
        expr = unpert_expr.detach().to(device=device, dtype=torch.float32, copy=True)
    else:
        # Going through numpy also accepts DataFrames, which torch.tensor rejects.
        expr = torch.tensor(np.asarray(unpert_expr, dtype=np.float32), device=device)

    # Materialize once so generators still work and len() is available.
    smiles_list = list(smiles_list)
    if expr.dim() == 2 and expr.shape[0] != len(smiles_list):
        raise ValueError(
            f"unpert_expr has {expr.shape[0]} rows but smiles_list has "
            f"{len(smiles_list)} entries; expected one row per SMILES string"
        )

    model = model.to(device)
    model.eval()

    with torch.no_grad():
        drug_fps = torch.stack(
            [torch.tensor(morgan_fp(s), dtype=torch.float32) for s in smiles_list]
        ).to(device)
        batch = {"unpert_expr": expr, "drug_fp": drug_fps}
        predicted = model(batch).cpu().numpy()

    columns = [f"Gene_{i + 1}" for i in range(predicted.shape[1])]
    return pd.DataFrame(predicted, columns=columns)

# Load the model.
# Use one device for both loading and inference: hard-coding "cuda" fails on
# machines without a GPU, and inference previously moved the model back to the
# CPU anyway.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = load_wave_model(
    wave_model_path="D:/git_down/wave/best_model.pth",
    # Plain path string: nesting quotes ('"..."') would make the double-quote
    # characters part of the file name.
    genevae_model_path="D:/git_down/vae_model.pth",
    device=device,
)

# Prepare the input data.
smiles_list = [
    'CC(=O)OC1=CC=CC=C1C(=O)O',
    'CC(C)CC(C)C(=O)O',
]
# Example data -- replace with your actual unperturbed expression matrix.
# One row per SMILES string. The 978 columns presumably match the model's
# expected gene count (e.g. L1000 landmark genes) -- TODO confirm.
unpert_expr = np.random.rand(len(smiles_list), 978)

# Run the prediction.
predicted_expression = predict_expression(unpert_expr, smiles_list, model, device=device)
# Despite its name, output_dir is the full path of the CSV file to write.
output_dir = 'D:/git_down/wave/predicted_expression.csv'
predicted_expression.to_csv(output_dir)



ModuleNotFoundError: No module named 'torch'