# FK-RFdiffusion

> **Warning:** This notebook is still a work in progress and may not run end to end yet. Bear with us!

This notebook demonstrates Feynman-Kac guided protein design using FK-RFdiffusion.

Based on [this notebook](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/rf/examples/diffusion.ipynb) but modified to use guided diffusion with reward functions.

## Example Usage

### Binder Design
Design a protein binder to a target structure with interface energy optimization:
```python
run_feynman_kac_design(
    contigs=["A1-50/0 20"],           # Target chain A (residues 1-50), then design 20-residue binder
    target_structure="target.pdb",    # Path to target PDB
    reward_function="interface_dG",   # Optimize binding energy
    n_particles=10,
    resampling_frequency=5,
    guidance_start_timestep=30
)
```

### Unconditional Design
Design a standalone protein with specific secondary structure:
```python
run_feynman_kac_design(
    contigs=["75"],                           # Design 75-residue protein
    reward_function="secondary_structure",    # Optimize secondary structure
    n_particles=10,
    resampling_frequency=5
)
```

In [None]:
#@title setup **FK-RFdiffusion** (~5min)
%%time
import os, time, signal
import sys, random, string, re

# Download RFdiffusion weights and dependencies
if not os.path.isdir("params"):
  os.system("apt-get install aria2")
  os.system("mkdir params")
  # send param download into background
  os.system("(\
  aria2c -q -x 16 https://files.ipd.uw.edu/krypton/schedules.zip; \
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/6f5902ac237024bdd0c176cb93063dc4/Base_ckpt.pt; \
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/e29311f6f1bf1af907f9ef9f44b8328b/Complex_base_ckpt.pt; \
  aria2c -q -x 16 http://files.ipd.uw.edu/pub/RFdiffusion/f572d396fae9206628714fb2ce00f72e/Complex_beta_ckpt.pt; \
  touch params/done.txt) &")

# Install FK-RFdiffusion from GitHub (includes RFdiffusion and ProteinMPNN as submodules)
# The presence of the `FK-RFdiffusion` directory is used as the "already
# installed" marker, so the commands below run at most once per working directory.
if not os.path.isdir("FK-RFdiffusion"):
  print("Installing FK-RFdiffusion from GitHub...")
  install_cmds = [
    "git clone --recurse-submodules https://github.com/ErikHartman/FK-RFdiffusion.git",
    "pip install jedi omegaconf hydra-core icecream pyrsistent pynvml decorator",
    "pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger",
    # 17Mar2024: adding --no-dependencies to avoid installing nvidia-cuda-* dependencies
    # 25Aug2025: updating dgl install to work with latest pytorch
    "pip install --no-dependencies dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html",
    "pip install --no-dependencies e3nn==0.5.5 opt_einsum_fx",
    "cd FK-RFdiffusion/externals/RFdiffusion/env/SE3Transformer; pip install .",
    "wget -qnc https://files.ipd.uw.edu/krypton/ananas",
    "chmod +x ananas",
    # Install FK-RFdiffusion package
    "cd FK-RFdiffusion && pip install -e .",
  ]

  # Run every step (best effort, same order as before) but remember which ones
  # returned a non-zero exit status instead of silently discarding it.
  failed_cmds = []
  for cmd in install_cmds:
    if os.system(cmd) != 0:
      failed_cmds.append(cmd)

  if failed_cmds:
    print("WARNING: FK-RFdiffusion installation finished with errors. Failed commands:")
    for cmd in failed_cmds:
      print(f"  {cmd}")
  else:
    print("FK-RFdiffusion installed successfully!")

# Move RFdiffusion models to the right location
if not os.path.isdir("FK-RFdiffusion/externals/RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  os.system("mkdir FK-RFdiffusion/externals/RFdiffusion/models")
  models = ["Base_ckpt.pt","Complex_base_ckpt.pt","Complex_beta_ckpt.pt"]

  # The background job fetches the files one after another and touches
  # params/done.txt only after the last aria2c call has returned. Waiting on
  # that sentinel (rather than on the per-file *.aria2 control files) also
  # covers checkpoints whose download has not even started yet. The wait is
  # bounded so a download job that never ran cannot hang the cell forever.
  download_timeout_s = 30 * 60
  wait_deadline = time.time() + download_timeout_s
  while not os.path.isfile("params/done.txt") and time.time() < wait_deadline:
    time.sleep(5)

  # A missing checkpoint, or one that still has an aria2 control file next to
  # it, was not downloaded completely; say so instead of failing later on.
  incomplete_models = [
    m for m in models
    if not os.path.isfile(m) or os.path.isfile(f"{m}.aria2")
  ]
  if incomplete_models:
    print(f"WARNING: missing or incomplete downloads: {', '.join(incomplete_models)}")

  os.system(f"mv {' '.join(models)} FK-RFdiffusion/externals/RFdiffusion/models")
  os.system("unzip schedules.zip; rm schedules.zip")

# Set the DGLBACKEND environment variable ahead of the FK-RFdiffusion import below
# (presumably read by DGL when it is first imported - confirm)
os.environ.update(DGLBACKEND="pytorch")

# Absolute locations of the FK-RFdiffusion checkout and its bundled RFdiffusion
fk_path = os.path.abspath('FK-RFdiffusion')
rfd_path = os.path.join(fk_path, 'externals', 'RFdiffusion')

# Prepend each location to sys.path unless it is already listed; RFdiffusion is
# handled last and therefore ends up in front of FK-RFdiffusion.
for pkg_root in (fk_path, rfd_path):
  if pkg_root not in sys.path:
    sys.path.insert(0, pkg_root)

# Import necessary libraries
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, HTML
import ipywidgets as widgets

# Import FK-RFdiffusion modules (now should work)
from fk_rfdiffusion.run_inference_guided import run_feynman_kac_design

print("Setup complete!")

In [None]:
%%time
#@title run **FK-RFdiffusion** to generate a backbone with guided diffusion

#@markdown ### Basic Settings
name = "test" #@param {type:"string"}
contigs = "100" #@param {type:"string"}
pdb = "" #@param {type:"string"}
num_designs = 1 #@param ["1", "2", "4", "8"] {type:"raw"}

#@markdown ### FK Guidance Parameters
reward_function = "secondary_structure" #@param ["interface_dG", "interface_dSASA", "interface_dGdSASA", "secondary_structure", "positive_charge", "negative_charge"]
n_particles = 10 #@param {type:"integer"}
resampling_frequency = 5 #@param {type:"integer"}
guidance_start_timestep = 30 #@param {type:"integer"}
potential_mode = "difference" #@param ["immediate", "difference", "max", "sum", "blind"]
tau = None #@param {type:"raw"}
n_sequences = 1 #@param {type:"integer"}

#@markdown ### Advanced Settings
iterations = 50 #@param ["25", "50", "100", "150", "200"] {type:"raw"}
hotspot = "" #@param {type:"string"}
checkpoint = "base" #@param ["base", "beta"]
aggregation_mode = "mean" #@param ["mean", "max"]

#@markdown - `name`: Prefix used for the output files (a numeric suffix is added if `<name>_0.pdb` already exists)
#@markdown - `contigs`: Contig string describing what to design (e.g. `100` for a 100-residue protein)
#@markdown - `pdb`: Path to the target structure; leave empty to run without a target
#@markdown - `reward_function`: Which reward to optimize (interface_dG for binders, secondary_structure for unconditional)
#@markdown - `n_particles`: Number of particles in FK sampling (more = better but slower)
#@markdown - `resampling_frequency`: How often to resample particles (lower = more guidance)
#@markdown - `tau`: Temperature parameter (lower = stronger guidance, None = use default)
#@markdown - `potential_mode`: How to compute guidance potential
#@markdown - `n_sequences`: Number of ProteinMPNN sequences per evaluation
#@markdown - `hotspot`: Hotspot residues, separated by commas and/or spaces; leave empty for none
#@markdown - `iterations`: NOTE(review): this value is not passed to `run_feynman_kac_design` in this cell, so changing it currently has no effect here; confirm the intended wiring

# Parse contigs to determine mode.
# A single form string becomes a one-element list; surrounding whitespace from
# the form field is dropped (whitespace inside the contig string is kept).
contigs_list = [contigs.strip()] if isinstance(contigs, str) else contigs

# Use the whitespace-stripped target path, or None when the field is empty, so
# a stray space typed into the form does not end up in the file name.
target_structure = pdb.strip() if pdb and pdb.strip() else None
if target_structure is not None and not os.path.isfile(target_structure):
    # Best-effort early hint only; the design call below still decides what to do.
    print(f"WARNING: target structure file not found: {target_structure}")

# Prepare hotspot if provided: accept comma- and/or whitespace-separated residues
hotspot_res = None
if hotspot and hotspot.strip():
    hotspot_res = [h.strip() for h in hotspot.replace(",", " ").split()]

# Determine output path: append an increasing counter to `name` until no
# "<path>_0.pdb" exists, so an earlier run is not overwritten
path = name
counter = 0
while os.path.exists(f"{path}_0.pdb"):
    counter += 1
    path = f"{name}_{counter}"

# Echo the effective configuration before starting the run
run_summary = (
    ("Contigs", contigs_list),
    ("Reward", reward_function),
    ("Particles", n_particles),
    ("Resampling freq", resampling_frequency),
    ("Guidance start", guidance_start_timestep),
    ("Potential mode", potential_mode),
    ("Output", path),
)
print("Running FK-RFdiffusion with:")
for label, value in run_summary:
    print(f"  {label}: {value}")

# Collect every keyword argument for the guided design call in one place
design_kwargs = dict(
    contigs=contigs_list,
    target_structure=target_structure,
    hotspot_res=hotspot_res,
    num_designs=num_designs,
    output_prefix=path,
    n_particles=n_particles,
    resampling_frequency=resampling_frequency,
    guidance_start_timestep=guidance_start_timestep,
    potential_mode=potential_mode,
    tau=tau,
    checkpoint=checkpoint,
    reward_function=reward_function,
    n_sequences=n_sequences,
    aggregation_mode=aggregation_mode,
    save_full_trajectory=False,
)

# Run FK-guided design
run_feynman_kac_design(**design_kwargs)

print(f"\nDesign complete! Output saved to {path}_*.pdb")

In [None]:
#@title Visualize Reward Over Time

# Load the metadata CSV file written next to the designs of the last run
# (`path` is set by the design cell above).
metadata_file = f"{path}_metadata.csv"

if os.path.exists(metadata_file):
    df = pd.read_csv(metadata_file)

    # Fail early with a readable message if the CSV lacks a column used below.
    required_columns = {'design_idx', 'particle_name', 'timestep', 'reward'}
    missing_columns = required_columns - set(df.columns)
    if missing_columns:
        raise KeyError(
            f"{metadata_file} is missing expected column(s): {sorted(missing_columns)}"
        )

    # Get unique designs
    designs = df['design_idx'].unique()

    # Create figure: per-particle trajectories on the left, per-design mean on the right
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Reward trajectories for all particles
    ax1 = axes[0]
    for design_idx in designs:
        design_df = df[df['design_idx'] == design_idx]

        # One thin line per particle, ordered from high to low timestep
        for particle in design_df['particle_name'].unique():
            particle_df = design_df[design_df['particle_name'] == particle]
            particle_df = particle_df.sort_values('timestep', ascending=False)

            ax1.plot(particle_df['timestep'], particle_df['reward'],
                     alpha=0.3, linewidth=0.5, color='blue')

    ax1.set_xlabel('Timestep', fontsize=12)
    ax1.set_ylabel('Reward', fontsize=12)
    ax1.set_title('Reward Trajectories (All Particles)', fontsize=14)
    ax1.grid(True, alpha=0.3)
    ax1.invert_xaxis()  # Time goes backward in diffusion

    # Plot 2: Mean reward over time with a +/- one standard deviation band
    ax2 = axes[1]
    for design_idx in designs:
        design_df = df[df['design_idx'] == design_idx]

        # Group by timestep and compute statistics
        timestep_stats = design_df.groupby('timestep')['reward'].agg(['mean', 'std'])
        timestep_stats = timestep_stats.sort_index(ascending=False)

        timesteps = timestep_stats.index
        means = timestep_stats['mean']
        # The sample std is NaN when a timestep has a single row; draw a
        # zero-width band there instead of leaving a gap in the shaded area.
        stds = timestep_stats['std'].fillna(0)

        ax2.plot(timesteps, means, label=f'Design {design_idx}', linewidth=2)
        ax2.fill_between(timesteps, means - stds, means + stds, alpha=0.2)

    ax2.set_xlabel('Timestep', fontsize=12)
    ax2.set_ylabel('Mean Reward', fontsize=12)
    # "\u00b1" is the plus-minus sign; the escape keeps the title intact even
    # if the notebook file is re-encoded (the previous literal was mojibake).
    ax2.set_title('Mean Reward Over Time (\u00b11 std)', fontsize=14)
    ax2.grid(True, alpha=0.3)
    ax2.invert_xaxis()
    if len(designs) > 0:
        # Avoid matplotlib's "no artists with labels" warning on an empty file
        ax2.legend()

    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\n=== Reward Summary ===")
    for design_idx in designs:
        design_df = df[df['design_idx'] == design_idx]
        # Timesteps count down, so the smallest one is the end of the run and
        # the largest one is the first recorded step.
        last_timestep = design_df['timestep'].min()
        first_timestep = design_df['timestep'].max()
        final_rewards = design_df[design_df['timestep'] == last_timestep]['reward']

        print(f"\nDesign {design_idx}:")
        print(f"  Final reward (mean): {final_rewards.mean():.4f}")
        print(f"  Final reward (std):  {final_rewards.std():.4f}")
        print(f"  Final reward (max):  {final_rewards.max():.4f}")
        print(f"  Final reward (min):  {final_rewards.min():.4f}")

        # Show improvement: mean reward at the last step minus mean at the first
        initial_rewards = design_df[design_df['timestep'] == first_timestep]['reward']
        improvement = final_rewards.mean() - initial_rewards.mean()
        print(f"  Improvement: {improvement:.4f}")

else:
    print(f"Metadata file not found: {metadata_file}")
    print("Make sure to run the design first!")