# REINVENT4 LibInvent - NVIDIA DGX Spark
## DENV NS2B-NS3 Protease Inhibitor Generation

**Hardware**: NVIDIA DGX Spark (128GB unified memory, GB10 GPU)

**Target**: Generate up to ~2.56M molecules (batch size 512 × 5,000 steps) in 8-12 hours

**Scaffold**: Pyrrolidine dual aromatic

---

### Quick Start
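1. 确保 Cell 3 中列出的所有必需文件（Prior/Agent 模型、Scaffold 文件、QSAR 模型和 QSAR 插件）已复制到 `REINVENT_HOME` 下的对应位置
2. 点击 `Run All` 依次运行所有单元格；若 Cell 3 检测到文件缺失，会抛出异常并停止执行
3. 训练过程中可以在 TensorBoard 中查看实时进度（启动命令见 Cell 10）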

---

In [None]:
# Cell 1: 环境检查和导入
import os
import sys
import shutil
import subprocess
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

_RULE = "=" * 80
print(_RULE)
print("REINVENT4 LibInvent - DGX Spark Setup")
print(_RULE)
print(f"\n时间: {datetime.now():%Y-%m-%d %H:%M:%S}")
print(f"Python版本: {sys.version}")
print(f"工作目录: {os.getcwd()}")

# PyTorch / CUDA diagnostics. Any failure here (torch missing, driver problems)
# is reported as a warning instead of stopping the notebook.
try:
    import torch

    print(f"\nPyTorch版本: {torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA可用: {has_cuda}")
    if has_cuda:
        print(f"CUDA版本: {torch.version.cuda}")
        print(f"GPU设备: {torch.cuda.get_device_name(0)}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"GPU内存: {total_gb:.1f} GB")
except Exception as err:
    print(f"\n⚠️ PyTorch检查失败: {err}")

print("\n✅ 环境检查完成")

In [None]:
# Cell 2: path configuration
# Adjust these variables to match your local installation.
REINVENT_HOME = Path.home() / "aidd" / "REINVENT4"
EXP_NAME = "spark_run1"
EXP_DIR = REINVENT_HOME.joinpath("experiments", "runs", EXP_NAME)

print("路径配置:")
for label, location in (("REINVENT根目录", REINVENT_HOME), ("实验目录", EXP_DIR)):
    print(f"  {label}: {location}")

# Create the run directory and any missing parents; a no-op if it already exists.
EXP_DIR.mkdir(parents=True, exist_ok=True)
print(f"\n✅ 实验目录已创建: {EXP_DIR}")

In [None]:
# Cell 3: required-file integrity check
# Verifies every input the training run needs before anything is launched.
print("检查必需文件...\n")

required_files = {
    "Prior模型": REINVENT_HOME / "priors" / "libinvent.prior",
    "Agent模型": REINVENT_HOME / "priors" / "denv_libinvent_model.model",
    "Scaffold文件": REINVENT_HOME / "data" / "pyrrolidine_dual_aryl.smi",
    "QSAR模型": REINVENT_HOME / "models" / "random_forest_champion.joblib",
    # NOTE(review): in the upstream REINVENT4 repository `reinvent_plugins/` is a
    # top-level package next to `reinvent/`; confirm this path matches where the
    # plugin is actually installed.
    "QSAR插件": REINVENT_HOME / "reinvent" / "reinvent_plugins" / "components" / "comp_qsar_scorer.py",
}

# Collect the labels of missing inputs. is_file() rather than exists(): a
# directory that happens to carry the expected name must not pass the check.
missing = []
for name, path in required_files.items():
    present = path.is_file()
    print(f"{'✅' if present else '❌'} {name}: {path}")
    if not present:
        missing.append(name)
        print("   ⚠️ 文件不存在，请手动复制！")

if missing:
    print("\n❌ 有文件缺失，请先复制文件再继续")
    print("\n停止执行。请复制文件后重新运行。")
    # Raising aborts "Run All" before training starts; the message names the
    # missing inputs so the traceback alone is actionable.
    raise FileNotFoundError(f"缺少必需文件: {', '.join(missing)}")

print("\n✅ 所有必需文件检查通过")

In [None]:
# Cell 4: generate the REINVENT4 staged-learning configuration (TOML)
# Core run parameters. They are interpolated into the TOML below AND into the
# printed summary, so the summary cannot drift from the file actually written.
BATCH_SIZE = 512
MAX_STEPS = 5000
LEARNING_RATE = 0.00015
SIGMA = 60

# All paths in the TOML are relative; the training cell runs REINVENT from
# REINVENT_HOME so they resolve against that directory.
#
# There is deliberately no [inception] section. REINVENT4 documents inception as
# available for Reinvent (de novo) runs only. The LibInvent scaffold file is also
# not a set of complete molecules suitable for seeding it.
# NOTE(review): REINVENT4 may treat CustomAlerts as a multiplicative filter whose
# weight is ignored — confirm against the REINVENT4 scoring documentation.
config_content = f"""# REINVENT4 LibInvent Configuration - NVIDIA DGX Spark Optimized
# Generated at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

run_type = "staged_learning"
device = "cuda:0"
tb_logdir = "experiments/runs/{EXP_NAME}/tensorboard"
json_out_config = "experiments/runs/{EXP_NAME}/_config.json"

[parameters]
summary_csv_prefix = "experiments/runs/{EXP_NAME}/results"
use_checkpoint = true
purge_memories = false

prior_file = "priors/libinvent.prior"
agent_file = "priors/denv_libinvent_model.model"
smiles_file = "data/pyrrolidine_dual_aryl.smi"

# DGX Spark优化
batch_size = {BATCH_SIZE}
unique_sequences = true
randomize_smiles = true

[learning_strategy]
type = "dap"
sigma = {SIGMA}
rate = {LEARNING_RATE}

[diversity_filter]
type = "IdenticalMurckoScaffold"
bucket_size = 15
minscore = 0.6

[[stage]]
max_score = 0.85
min_steps = 1000
max_steps = {MAX_STEPS}
chkpt_file = "experiments/runs/{EXP_NAME}/checkpoint_stage1.chkpt"

[stage.scoring]
type = "arithmetic_mean"

# QSAR Activity Prediction
[[stage.scoring.component]]
[stage.scoring.component.QSARScorer]
[[stage.scoring.component.QSARScorer.endpoint]]
name = "DENV_Activity"
weight = 3.0
[stage.scoring.component.QSARScorer.endpoint.params]
model_path = "models/random_forest_champion.joblib"
[stage.scoring.component.QSARScorer.endpoint.transform]
type = "double_sigmoid"
low = 4.5
high = 9.0
coef_div = 9.0
coef_si = 8.0
coef_se = 8.0

# Stereochemistry
[[stage.scoring.component]]
[stage.scoring.component.NumAtomStereoCenters]
[[stage.scoring.component.NumAtomStereoCenters.endpoint]]
name = "Stereo_Centers"
weight = 0.6
[stage.scoring.component.NumAtomStereoCenters.endpoint.transform]
type = "reverse_sigmoid"
low = 0
high = 3
k = 1.0

# Chemical Stability Alerts
[[stage.scoring.component]]
[stage.scoring.component.CustomAlerts]
[[stage.scoring.component.CustomAlerts.endpoint]]
name = "Stability_Alerts"
weight = 1.2
[stage.scoring.component.CustomAlerts.endpoint.params]
smarts = [
    "[*;r8]", "[*;r9]", "[*;r10]", "[*;r11]",
    "[#8][#8]", "[#6;+]", "C#C", "[NX3][NX3]",
    "[SH]", "[N+](=O)[O-]", "S(=O)(=O)Cl",
    "[F,Cl,Br,I][C,c][F,Cl,Br,I]",
]

# Drug-likeness
[[stage.scoring.component]]
[stage.scoring.component.Qed]
[[stage.scoring.component.Qed.endpoint]]
name = "QED"
weight = 0.4

[[stage.scoring.component]]
[stage.scoring.component.SAScore]
[[stage.scoring.component.SAScore.endpoint]]
name = "SA"
weight = 0.5
[stage.scoring.component.SAScore.endpoint.transform]
type = "reverse_sigmoid"
low = 1.0
high = 5.0
k = 0.8

# Molecular Properties
[[stage.scoring.component]]
[stage.scoring.component.MolecularWeight]
[[stage.scoring.component.MolecularWeight.endpoint]]
name = "MW"
weight = 0.5
[stage.scoring.component.MolecularWeight.endpoint.transform]
type = "double_sigmoid"
low = 280.0
high = 550.0
coef_div = 550.0
coef_si = 15.0
coef_se = 15.0

[[stage.scoring.component]]
[stage.scoring.component.SlogP]
[[stage.scoring.component.SlogP.endpoint]]
name = "LogP"
weight = 0.4
[stage.scoring.component.SlogP.endpoint.transform]
type = "double_sigmoid"
low = 1.0
high = 4.5
coef_div = 4.5
coef_si = 10.0
coef_se = 10.0
"""

config_path = EXP_DIR / "config.toml"
# TOML must be UTF-8 and the content contains non-ASCII comments, so do not rely
# on the platform's default encoding.
config_path.write_text(config_content, encoding="utf-8")

print(f"✅ 配置文件已生成: {config_path}")
print("\n配置摘要:")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Max steps: {MAX_STEPS}")
# Upper bound: batch_size * max_steps. Presumably fewer molecules if the stage
# terminates early (max_score reached after min_steps).
print(f"  - 预期分子数: ~{BATCH_SIZE * MAX_STEPS:,}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Sigma: {SIGMA}")

In [None]:
# Cell 5: launch REINVENT4 training
print("=" * 80)
print("启动REINVENT4训练")
print("=" * 80)
print(f"\n开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("\n预计运行时间: 8-12 小时")
print("预计生成分子: ~2,560,000 个\n")

# The generated config uses paths relative to REINVENT_HOME (priors/, data/,
# models/, experiments/...), so the run must start from that directory.
os.chdir(REINVENT_HOME)
print(f"当前目录: {os.getcwd()}\n")

config_rel_path = config_path.relative_to(REINVENT_HOME)
log_path = EXP_DIR / "training.log"

# argv list instead of a shell string: no shell parsing, so paths containing
# spaces or shell metacharacters are passed through unchanged.
cmd = ["reinvent", str(config_rel_path)]
print(f"执行命令: {' '.join(cmd)}")
print(f"日志文件: {log_path}\n")
print("=" * 80)
print("训练开始...")
print("=" * 80)
print()

# Stream the child's combined stdout/stderr to the notebook and to the log file,
# line by line. Decode explicitly as UTF-8 and replace undecodable bytes so
# unusual output cannot abort the stream with a UnicodeDecodeError.
with open(log_path, "w", encoding="utf-8") as log_file:
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        encoding="utf-8",
        errors="replace",
        bufsize=1,  # line-buffered (valid because the pipe is in text mode)
    )
    try:
        for line in process.stdout:
            print(line, end="")
            log_file.write(line)
            log_file.flush()
        process.wait()
    except KeyboardInterrupt:
        # Kernel interrupt: make sure the training process is stopped as well.
        process.terminate()
        try:
            process.wait(timeout=30)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        raise
    finally:
        process.stdout.close()

succeeded = process.returncode == 0
print("\n" + "=" * 80)
if succeeded:
    print("✅ 训练成功完成！")
else:
    print(f"❌ 训练失败 (返回码: {process.returncode})")
print("=" * 80)

print(f"\n结束时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Stop "Run All" here on failure; the analysis cells would otherwise run on
# missing or partial results.
if not succeeded:
    raise RuntimeError(f"REINVENT4 exited with code {process.returncode}; see {log_path}")

In [None]:
# Cell 6: results analysis - load the summary CSV and print basic statistics
print("=" * 80)
print("结果分析")
print("=" * 80)


def _optional_column(frame, *candidates):
    """Return the first name in *candidates* that is a column of *frame*, else None."""
    return next((name for name in candidates if name in frame.columns), None)


def _print_stats(label, series):
    """Print mean / std / min / max of *series* under a *label* heading."""
    print(f"\n{label}:")
    print(f"  Mean: {series.mean():.4f}")
    print(f"  Std: {series.std():.4f}")
    print(f"  Min: {series.min():.4f}")
    print(f"  Max: {series.max():.4f}")


# summary_csv_prefix is ".../results", so the summary files match results_*.csv.
# Sort so the chosen file is deterministic (glob order is filesystem-dependent).
csv_files = sorted(EXP_DIR.glob("results_*.csv"))
if not csv_files:
    print("\n❌ 未找到结果CSV文件")
    # Fail here with a clear message: every later cell needs `df` and would
    # otherwise stop with a NameError.
    raise FileNotFoundError(f"No results_*.csv found in {EXP_DIR}")

csv_file = csv_files[0]
if len(csv_files) > 1:
    print(f"\n⚠️ 找到 {len(csv_files)} 个结果文件，仅分析第一个")
print(f"\n加载结果文件: {csv_file.name}")

df = pd.read_csv(csv_file)
print(f"\n数据维度: {df.shape}")
print(f"总生成分子数: {len(df):,}")
if df.empty:
    raise ValueError(f"{csv_file.name} contains no rows")

print("\n" + "=" * 80)
print("基本统计")
print("=" * 80)

# SMILES_state codes as interpreted throughout this notebook:
# 1 = valid, 0 = invalid, 2 = duplicate within a batch.
valid = df[df['SMILES_state'] == 1]
invalid = df[df['SMILES_state'] == 0]
duplicates = df[df['SMILES_state'] == 2]

n_total = len(df)
print(f"\nValid SMILES: {len(valid):,} ({len(valid)/n_total*100:.2f}%)")
print(f"Invalid SMILES: {len(invalid):,} ({len(invalid)/n_total*100:.2f}%)")
print(f"Batch duplicates: {len(duplicates):,} ({len(duplicates)/n_total*100:.2f}%)")

# Count distinct structures among valid molecules only; invalid SMILES strings
# are not molecules and must not inflate the "unique" count.
unique_smiles = valid['SMILES'].nunique()
print(f"\nUnique SMILES: {unique_smiles:,} ({unique_smiles/n_total*100:.2f}%)")

print("\n" + "=" * 80)
print("分数统计")
print("=" * 80)

# The total-score column may be "total_score" or "Score" (the name REINVENT4's
# summary CSV is expected to use — verify against the header); accept either.
score_col = _optional_column(df, 'total_score', 'Score')
if score_col is not None:
    _print_stats("Total Score", df[score_col])

# Assumed REINVENT4 layout (verify against the CSV header): a bare component
# column holds the transformed 0-1 score, and the raw model output (here pIC50)
# is in "<name> (raw)". Fall back to the bare name if no raw column exists.
activity_col = _optional_column(df, 'DENV_Activity (raw)', 'DENV_Activity')
if activity_col is not None:
    _print_stats("DENV Activity (pIC50)", df[activity_col])

In [None]:
# Cell 7: gold-standard candidate extraction
print("=" * 80)
print("金标准候选分子提取")
print("=" * 80)


def _require_column(frame, *candidates):
    """Return the first of *candidates* that is a column of *frame*.

    Raises:
        KeyError: if none of the candidates exists (message lists available columns).
    """
    for name in candidates:
        if name in frame.columns:
            return name
    raise KeyError(f"None of {candidates} found in CSV columns: {list(frame.columns)}")


# Gold-standard thresholds in RAW units (predicted pIC50, QED, SA score), not in
# the transformed 0-1 component scores.
thresholds = {
    'pIC50': 7.5,        # IC50 < ~32 nM
    'QED': 0.6,
    'SA': 5.0,           # SA score: lower = easier to synthesise
    'total_score': 0.7,
}

# Assumed REINVENT4 layout (verify against the CSV header): bare component
# columns ("DENV_Activity", "SA", ...) hold transformed 0-1 scores, raw values
# live in "<name> (raw)", and the total score is "Score". Comparing a 0-1 score
# against pIC50 >= 7.5 or SA <= 5.0 would make the filter meaningless.
score_col = _require_column(df, 'total_score', 'Score')
activity_col = _require_column(df, 'DENV_Activity (raw)', 'DENV_Activity')
qed_col = _require_column(df, 'QED (raw)', 'QED')
sa_col = _require_column(df, 'SA (raw)', 'SA')

print("\n金标准阈值:")
for key, value in thresholds.items():
    print(f"  {key}: {value}")

gold_mask = (
    (df['SMILES_state'] == 1)
    & (df[activity_col] >= thresholds['pIC50'])
    & (df[qed_col] >= thresholds['QED'])
    & (df[sa_col] <= thresholds['SA'])
    & (df[score_col] >= thresholds['total_score'])
)

gold_df = df[gold_mask].copy()
print(f"\n金标准候选数: {len(gold_df):,}")

# Rank by activity before deduplicating, so each SMILES keeps its most active
# row and the saved file is already ordered.
gold_df_unique = (
    gold_df.sort_values(activity_col, ascending=False)
    .drop_duplicates(subset=['SMILES'])
)
print(f"去重后: {len(gold_df_unique):,}")

if len(gold_df_unique) > 0:
    gold_output = EXP_DIR / "gold_standard_candidates.csv"
    gold_df_unique.to_csv(gold_output, index=False)
    print(f"\n✅ 金标准候选已保存: {gold_output.name}")

    print("\nTop 10金标准候选 (按pIC50排序):")
    top10 = gold_df_unique.head(10)
    print(top10[['SMILES', activity_col, qed_col, sa_col, score_col]].to_string())
else:
    print("\n⚠️ 未找到符合金标准的分子")

In [None]:
# Cell 8: visual analysis of the generated molecules
print("=" * 80)
print("可视化分析")
print("=" * 80)


def _require_column(frame, *candidates):
    """Return the first of *candidates* that is a column of *frame*.

    Raises:
        KeyError: if none of the candidates exists (message lists available columns).
    """
    for name in candidates:
        if name in frame.columns:
            return name
    raise KeyError(f"None of {candidates} found in CSV columns: {list(frame.columns)}")


def _histogram(ax, values, *, color, xlabel, title, marker, marker_label):
    """Draw a 50-bin histogram of *values* (NaNs dropped) with a dashed red line at *marker*."""
    ax.hist(values.dropna(), bins=50, edgecolor='black', alpha=0.7, color=color)
    ax.axvline(marker, color='red', linestyle='--', label=marker_label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.legend()


# Assumed REINVENT4 layout (verify against the CSV header): raw values are in
# "<name> (raw)", bare component columns are transformed 0-1 scores, and the
# total score is "Score".
score_col = _require_column(df, 'total_score', 'Score')
activity_col = _require_column(df, 'DENV_Activity (raw)', 'DENV_Activity')
qed_col = _require_column(df, 'QED (raw)', 'QED')
sa_col = _require_column(df, 'SA (raw)', 'SA')

# Property/activity plots use valid molecules only; invalid rows have no
# meaningful property values.
mol_df = df[df['SMILES_state'] == 1]

sns.set_style("whitegrid")

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('REINVENT4 Generation Results - DGX Spark', fontsize=16, fontweight='bold')

# 1. Total score over all generated rows (as the training run saw them)
score_mean = df[score_col].mean()
_histogram(axes[0, 0], df[score_col], color=None, xlabel='Total Score',
           title='Total Score Distribution', marker=score_mean,
           marker_label=f'Mean: {score_mean:.3f}')

# 2. Predicted pIC50, with the gold-standard cut-off (`thresholds` comes from Cell 7)
_histogram(axes[0, 1], mol_df[activity_col], color='green', xlabel='Predicted pIC50',
           title='DENV Activity Distribution', marker=thresholds['pIC50'],
           marker_label=f"Gold threshold ({thresholds['pIC50']})")

# 3. QED
qed_mean = mol_df[qed_col].mean()
_histogram(axes[0, 2], mol_df[qed_col], color='orange', xlabel='QED',
           title='QED Distribution', marker=qed_mean,
           marker_label=f'Mean: {qed_mean:.3f}')

# 4. SA score
sa_mean = mol_df[sa_col].mean()
_histogram(axes[1, 0], mol_df[sa_col], color='purple', xlabel='SA Score',
           title='Synthetic Accessibility Distribution', marker=sa_mean,
           marker_label=f'Mean: {sa_mean:.3f}')

# 5. Activity vs drug-likeness; dashed lines mark the gold-standard thresholds
axes[1, 1].scatter(mol_df[activity_col], mol_df[qed_col], alpha=0.3, s=1)
axes[1, 1].axhline(thresholds['QED'], color='red', linestyle='--', alpha=0.5)
axes[1, 1].axvline(thresholds['pIC50'], color='red', linestyle='--', alpha=0.5)
axes[1, 1].set_xlabel('Predicted pIC50')
axes[1, 1].set_ylabel('QED')
axes[1, 1].set_title('Activity vs Drug-likeness')

# 6. Mean total score per training step
step_stats = df.groupby('step')[score_col].mean()
axes[1, 2].plot(step_stats.index, step_stats.values, linewidth=2)
axes[1, 2].set_xlabel('Training Step')
axes[1, 2].set_ylabel('Mean Total Score')
axes[1, 2].set_title('Training Progress')
axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plot_path = EXP_DIR / "analysis_plots.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\n✅ 分析图表已保存: {plot_path.name}")
plt.show()

In [None]:
# Cell 9: build and save the final text report
print("=" * 80)
print("生成最终报告")
print("=" * 80)


def _require_column(frame, *candidates):
    """Return the first of *candidates* that is a column of *frame*.

    Raises:
        KeyError: if none of the candidates exists (message lists available columns).
    """
    for name in candidates:
        if name in frame.columns:
            return name
    raise KeyError(f"None of {candidates} found in CSV columns: {list(frame.columns)}")


def _pct(count):
    """Return *count* as a percentage of all rows in `df` (0.0 for an empty frame)."""
    return count / len(df) * 100 if len(df) else 0.0


def _mean_std(column):
    """Format mean ± std of df[column] with 4 decimals."""
    return f"{df[column].mean():.4f} ± {df[column].std():.4f}"


# Assumed REINVENT4 layout (verify against the CSV header): raw values are in
# "<name> (raw)" and the total score is "Score".
score_col = _require_column(df, 'total_score', 'Score')
activity_col = _require_column(df, 'DENV_Activity (raw)', 'DENV_Activity')
qed_col = _require_column(df, 'QED (raw)', 'QED')
sa_col = _require_column(df, 'SA (raw)', 'SA')

# pIC50 = -log10(IC50 in M), so IC50 in nM = 10 ** (9 - pIC50).
ic50_nm = 10 ** (9 - thresholds['pIC50'])

# NOTE: the Configuration values mirror those written into config.toml in Cell 4;
# keep the two in sync when changing the run.
report = f"""
REINVENT4 LibInvent Generation Report
========================================
Experiment: {EXP_NAME}
Hardware: NVIDIA DGX Spark (128GB, GB10 GPU)
Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

Configuration
-------------
Batch Size: 512
Max Steps: 5000
Learning Rate: 0.00015
Sigma: 60
Diversity Filter Bucket Size: 15

Generation Statistics
--------------------
Total Molecules Generated: {len(df):,}
Valid SMILES: {len(valid):,} ({_pct(len(valid)):.2f}%)
Invalid SMILES: {len(invalid):,} ({_pct(len(invalid)):.2f}%)
Batch Duplicates: {len(duplicates):,} ({_pct(len(duplicates)):.2f}%)
Unique Molecules: {unique_smiles:,} ({_pct(unique_smiles):.2f}%)

Score Statistics
----------------
Total Score: {_mean_std(score_col)}
DENV Activity (pIC50): {_mean_std(activity_col)}
QED: {_mean_std(qed_col)}
SA Score: {_mean_std(sa_col)}

Gold Standard Candidates
------------------------
Total Gold Candidates: {len(gold_df_unique):,}
Thresholds:
  - pIC50 ≥ {thresholds['pIC50']} (IC50 < ~{ic50_nm:.0f} nM)
  - QED ≥ {thresholds['QED']}
  - SA ≤ {thresholds['SA']}
  - Total Score ≥ {thresholds['total_score']}

Output Files
------------
1. {csv_file.name}
2. gold_standard_candidates.csv
3. analysis_plots.png
4. training.log
5. config.toml

Next Steps
----------
1. Review gold standard candidates
2. Run virtual screening with LigUnity
3. Select candidates for synthesis
4. Experimental validation

========================================
"""

report_path = EXP_DIR / "REPORT.txt"
# The report contains non-ASCII symbols (±, ≥, ≤); write UTF-8 explicitly so
# a non-UTF-8 locale cannot raise UnicodeEncodeError.
report_path.write_text(report, encoding="utf-8")

print(report)
print(f"\n✅ 报告已保存: {report_path.name}")
print("\n" + "=" * 80)
print("分析完成！")
print("=" * 80)

In [None]:
# Cell 10: how to start TensorBoard for this run
_rule = "=" * 80
tb_instructions = [
    _rule,
    "TensorBoard可视化",
    _rule,
    "\n要启动TensorBoard实时监控，请在终端运行：\n",
    f"cd {REINVENT_HOME}",
    f"tensorboard --logdir experiments/runs/{EXP_NAME}/tensorboard --bind_all\n",
    "然后在浏览器打开: http://localhost:6006\n",
    _rule,
]
for text in tb_instructions:
    print(text)