In [1]:
# Cell 1: Imports and Setup
# Import necessary standard libraries
import pickle
import urllib.request
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal, Optional, Any # Added Any

# Import Click for command-line interface (though we'll run it directly)
import click

# Import PyTorch and PyTorch Lightning
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities import rank_zero_only
from tqdm import tqdm

# Import Boltz specific modules (assuming they are in the PYTHONPATH or same directory)
from boltz.data import const
from boltz.data.module.inference import BoltzInferenceDataModule
from boltz.data.msa.mmseqs2 import run_mmseqs2
from boltz.data.parse.a3m import parse_a3m
from boltz.data.parse.csv import parse_csv
from boltz.data.parse.fasta import parse_fasta
from boltz.data.parse.yaml import parse_yaml
from boltz.data.types import MSA, Manifest, Record
from boltz.data.write.writer import BoltzWriter
from boltz.model.model import Boltz1

# Import PyTorch Profiler
import torch.profiler
import os # For creating profiler output directory

# Define constants for URLs
CCD_URL = "https://huggingface.co/boltz-community/boltz-1/resolve/main/ccd.pkl"
MODEL_URL = (
    "https://huggingface.co/boltz-community/boltz-1/resolve/main/boltz1_conf.ckpt"
)


In [2]:
# Cell 2: Dataclass Definitions
@dataclass
class BoltzProcessedInput:
    """Processed input data."""

    manifest: Manifest
    targets_dir: Path
    msa_dir: Path


@dataclass
class BoltzDiffusionParams:
    """Diffusion process parameters."""

    gamma_0: float = 0.605
    gamma_min: float = 1.107
    noise_scale: float = 0.901
    rho: float = 8
    step_scale: float = 1.638
    sigma_min: float = 0.0004
    sigma_max: float = 160.0
    sigma_data: float = 16.0
    P_mean: float = -1.2
    P_std: float = 1.5
    coordinate_augmentation: bool = True
    alignment_reverse_diff: bool = True
    synchronize_sigmas: bool = True
    use_inference_model_cache: bool = True


In [3]:
# Cell 3: Download Function
@rank_zero_only
def download(cache: Path) -> None:
    """Download all the required data.

    Parameters
    ----------
    cache : Path
        The cache directory.

    """
    # Download CCD
    ccd = cache / "ccd.pkl"
    if not ccd.exists():
        click.echo(
            f"Downloading the CCD dictionary to {ccd}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(CCD_URL, str(ccd))

    # Download model
    model = cache / "boltz1_conf.ckpt"
    if not model.exists():
        click.echo(
            f"Downloading the model weights to {model}. You may "
            "change the cache directory with the --cache flag."
        )
        urllib.request.urlretrieve(MODEL_URL, str(model))


In [4]:
# Cell 4: Input Checking Function
def check_inputs(
    data: Path,
    outdir: Path,
    override: bool = False,
) -> list[Path]:
    """Check the input data and output directory.

    If the input data is a directory, it will be expanded
    to all files in this directory. Then, we check if there
    are any existing predictions and remove them from the
    list of input data, unless the override flag is set.

    Parameters
    ----------
    data : Path
        The input data.
    outdir : Path
        The output directory.
    override: bool
        Whether to override existing predictions.

    Returns
    -------
    list[Path]
        The list of input data.

    """
    click.echo("Checking input data.")

    if data.is_dir():
        data_files: list[Path] = list(data.glob("*"))
        filtered_data = []
        for d_file in data_files:
            if d_file.suffix in (".fa", ".fas", ".fasta", ".yml", ".yaml"):
                filtered_data.append(d_file)
            elif d_file.is_dir():
                msg = f"Found directory {d_file} instead of .fasta or .yaml."
                raise RuntimeError(msg)
            else:
                msg = (
                    f"Unable to parse filetype {d_file.suffix}, "
                    "please provide a .fasta or .yaml file."
                )
                raise RuntimeError(msg)
        data_list = filtered_data
    else:
        data_list = [data]

    existing = (outdir / "predictions").rglob("*")
    existing_stems = {e.stem for e in existing if e.is_dir()} # Use stem for comparison

    if existing_stems and not override:
        original_count = len(data_list)
        data_list = [d for d in data_list if d.stem not in existing_stems]
        num_skipped = original_count - len(data_list)
        if num_skipped > 0:
            msg = (
                f"Found {num_skipped} existing predictions, "
                f"skipping and running only the missing ones ({len(data_list)}), "
                "if any. If you wish to override these existing "
                "predictions, please set the --override flag."
            )
            click.echo(msg)
    elif existing_stems and override:
        msg = f"Found {len(existing_stems)} existing predictions, will override."
        click.echo(msg)

    return data_list


In [5]:
# Cell 5: MSA Computation Function
def compute_msa(
    data: dict[str, str],
    target_id: str,
    msa_dir: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
) -> None:
    """Compute the MSA for the input data.

    Parameters
    ----------
    data : dict[str, str]
        The input protein sequences.
    target_id : str
        The target id.
    msa_dir : Path
        The msa directory.
    msa_server_url : str
        The MSA server URL.
    msa_pairing_strategy : str
        The MSA pairing strategy.

    """
    if len(data) > 1:
        paired_msas = run_mmseqs2(
            list(data.values()),
            msa_dir / f"{target_id}_paired_tmp",
            use_env=True,
            use_pairing=True,
            host_url=msa_server_url,
            pairing_strategy=msa_pairing_strategy,
        )
    else:
        paired_msas = [""] * len(data)

    unpaired_msa = run_mmseqs2(
        list(data.values()),
        msa_dir / f"{target_id}_unpaired_tmp",
        use_env=True,
        use_pairing=False,
        host_url=msa_server_url,
        pairing_strategy=msa_pairing_strategy,
    )

    for idx, name in enumerate(data):
        paired = paired_msas[idx].strip().splitlines()
        paired = paired[1::2]
        paired = paired[: const.max_paired_seqs]

        keys = [i for i, s in enumerate(paired) if s != "-" * len(s)]
        paired = [s for s in paired if s != "-" * len(s)]

        unpaired = unpaired_msa[idx].strip().splitlines()
        unpaired = unpaired[1::2]
        unpaired = unpaired[: (const.max_msa_seqs - len(paired))]
        if paired:
            unpaired = unpaired[1:]

        seqs = paired + unpaired
        keys = keys + [-1] * len(unpaired)

        csv_str = ["key,sequence"] + [f"{key},{seq}" for key, seq in zip(keys, seqs)]

        msa_path = msa_dir / f"{name}.csv"
        with msa_path.open("w") as f:
            f.write("\n".join(csv_str))

In [19]:
# Cell 6: Input Processing Function
@rank_zero_only
def process_inputs(
    data: list[Path],
    out_dir: Path,
    ccd_path: Path,
    msa_server_url: str,
    msa_pairing_strategy: str,
    max_msa_seqs: int = 4096,
    use_msa_server: bool = False,
) -> None:
    """Process the input data and output directory.
    (Function body remains the same as in your original script)
    """
    click.echo("Processing input data.")
    existing_records = None

    manifest_path = out_dir / "processed" / "manifest.json"
    if manifest_path.exists():
        click.echo(f"Found a manifest file at output directory: {out_dir}")
        manifest_data: Manifest = Manifest.load(manifest_path)
        input_ids = [d.stem for d in data]
        
        # Ensure manifest_data.records is not None before proceeding
        if manifest_data.records is None:
            manifest_data.records = []

        processed_records_tuples = [
            (record, record.id)
            for record in manifest_data.records
            if record.id in input_ids
        ]
        
        if processed_records_tuples:
            existing_records, processed_ids = zip(*processed_records_tuples)
            existing_records = list(existing_records)
        else:
            existing_records = []
            processed_ids = []


        missing = len(input_ids) - len(processed_ids)
        if not missing:
            click.echo("All examples in data are processed. Updating the manifest.")
            if existing_records is not None:
                 updated_manifest = Manifest(existing_records)
                 updated_manifest.dump(out_dir / "processed" / "manifest.json")
            return

        click.echo(f"{missing} missing ids. Preprocessing these ids.")
        missing_ids = list(set(input_ids).difference(set(processed_ids)))
        data = [d for d in data if d.stem in missing_ids]
        assert len(data) == len(missing_ids)

    msa_dir = out_dir / "msa"
    structure_dir = out_dir / "processed" / "structures"
    processed_msa_dir = out_dir / "processed" / "msa"
    predictions_dir = out_dir / "predictions"

    out_dir.mkdir(parents=True, exist_ok=True)
    msa_dir.mkdir(parents=True, exist_ok=True)
    structure_dir.mkdir(parents=True, exist_ok=True)
    processed_msa_dir.mkdir(parents=True, exist_ok=True)
    predictions_dir.mkdir(parents=True, exist_ok=True)

    with ccd_path.open("rb") as file:
        ccd = pickle.load(file)

    if existing_records:
        click.echo(f"Found {len(existing_records)} records. Adding them to records.")

    records: list[Record] = list(existing_records) if existing_records is not None else []
    for path in tqdm(data):
        try:
            if path.suffix in (".fa", ".fas", ".fasta"):
                target = parse_fasta(path, ccd)
            elif path.suffix in (".yml", ".yaml"):
                target = parse_yaml(path, ccd)
            else:
                msg = f"Unsupported file type: {path.suffix}"
                raise RuntimeError(msg)

            target_id = target.record.id
            to_generate = {}
            prot_id = const.chain_type_ids["PROTEIN"]
            for chain in target.record.chains:
                if (chain.mol_type == prot_id) and (chain.msa_id == 0):
                    entity_id = chain.entity_id
                    msa_id_name = f"{target_id}_{entity_id}"
                    to_generate[msa_id_name] = target.sequences[entity_id]
                    chain.msa_id = msa_dir / f"{msa_id_name}.csv"
                elif chain.msa_id == 0:
                    chain.msa_id = -1 # type: ignore

            if to_generate and not use_msa_server:
                msg = "Missing MSA's in input and --use_msa_server flag not set."
                raise RuntimeError(msg)

            if to_generate:
                click.echo(f"Generating MSA for {path} with {len(to_generate)} protein entities.")
                compute_msa(
                    data=to_generate,
                    target_id=target_id,
                    msa_dir=msa_dir,
                    msa_server_url=msa_server_url,
                    msa_pairing_strategy=msa_pairing_strategy,
                )

            msas_paths = sorted({c.msa_id for c in target.record.chains if c.msa_id != -1 and c.msa_id != 0}) # Ensure c.msa_id is Path-like
            msa_id_map = {}
            for msa_idx, msa_p_obj in enumerate(msas_paths):
                msa_path = Path(msa_p_obj) # Ensure it's a Path object
                if not msa_path.exists():
                    msg = f"MSA file {msa_path} not found."
                    raise FileNotFoundError(msg)

                processed = processed_msa_dir / f"{target_id}_{msa_idx}.npz"
                msa_id_map[msa_p_obj] = f"{target_id}_{msa_idx}" # type: ignore
                if not processed.exists():
                    if msa_path.suffix == ".a3m":
                        msa: MSA = parse_a3m(msa_path, taxonomy=None, max_seqs=max_msa_seqs)
                    elif msa_path.suffix == ".csv":
                        msa: MSA = parse_csv(msa_path, max_seqs=max_msa_seqs)
                    else:
                        msg = f"MSA file {msa_path} not supported, only a3m or csv."
                        raise RuntimeError(msg)
                    msa.dump(processed)

            for c in target.record.chains:
                if (c.msa_id != -1 and c.msa_id != 0) and (c.msa_id in msa_id_map):
                    c.msa_id = msa_id_map[c.msa_id] # type: ignore

            records.append(target.record)
            struct_path = structure_dir / f"{target.record.id}.npz"
            target.structure.dump(struct_path)

        except Exception as e:
            if len(data) > 1:
                print(f"Failed to process {path}. Skipping. Error: {e}.")
            else:
                raise e

    manifest = Manifest(records)
    manifest.dump(out_dir / "processed" / "manifest.json")


In [26]:
# Cell 7: Prediction and Profiling Function
def predict_and_profile( # Renamed for clarity
    data_path_str: str, # Changed 'data' to 'data_path_str' to avoid conflict
    out_dir_str: str,   # Changed 'out_dir' to 'out_dir_str'
    cache_str: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    # devices: int = 1, # Simplified for single device profiling for now
    accelerator: str = "gpu", # "cpu" or "gpu"
    recycling_steps: int = 3,
    sampling_steps: int = 200, # Reduced for faster profiling example
    diffusion_samples: int = 1, # Reduced for faster profiling
    step_scale: float = 1.638,
    # write_full_pae: bool = False, # Commented out unused params for this example
    # write_full_pde: bool = False,
    # output_format: Literal["pdb", "mmcif"] = "mmcif",
    num_workers: int = 0, # Often 0 for notebooks to avoid issues with multiprocessing
    override: bool = False,
    seed_val: Optional[int] = None, # Renamed from 'seed'
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
    profile_output_dir: str = "./profiler_output" # Directory for profiler traces
) -> None:
    """Run a single prediction step with Boltz-1 and profile memory usage."""
    
    # Determine device for PyTorch
    if accelerator == "cpu":
        torch_device = torch.device("cpu")
        profiler_activities = [torch.profiler.ProfilerActivity.CPU]
        click.echo("Running on CPU. Profiling CPU activity.")
    elif accelerator == "gpu" and torch.cuda.is_available():
        torch_device = torch.device("cuda:0") # Assuming first GPU for simplicity
        profiler_activities = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
        click.echo(f"Running on GPU: {torch_device}. Profiling CPU and CUDA activity.")
    else:
        if accelerator == "gpu":
            click.echo("CUDA not available, falling back to CPU.")
        torch_device = torch.device("cpu")
        profiler_activities = [torch.profiler.ProfilerActivity.CPU]
        accelerator = "cpu" # Ensure consistency
        click.echo("Running on CPU. Profiling CPU activity.")

    # Set no grad
    torch.set_grad_enabled(False)

    # Ignore matmul precision warning
    torch.set_float32_matmul_precision("highest")

    # Set seed if desired
    if seed_val is not None:
        seed_everything(int(seed_val))

    # Set cache path
    cache = Path(cache_str).expanduser()
    cache.mkdir(parents=True, exist_ok=True)

    # Create output directories
    data_path = Path(data_path_str).expanduser()
    out_dir = Path(out_dir_str).expanduser()
    # Use a more specific output directory name to avoid clashes if data_path_str is a dir
    name_stem = data_path.stem if data_path.is_file() else data_path.name
    out_dir = out_dir / f"boltz_results_{name_stem}"
    out_dir.mkdir(parents=True, exist_ok=True)
    
    # Create profiler output directory
    Path(profile_output_dir).mkdir(parents=True, exist_ok=True)


    # Download necessary data and model
    download(cache)

    # Validate inputs
    # For profiling a single forward pass, we expect a single input file
    # The check_inputs function might filter it out if results exist and override is False.
    # We'll ensure we get at least one item or handle it.
    input_files = check_inputs(data_path, out_dir, override)
    if not input_files:
        click.echo("No input files to process after checking existing outputs. Exiting.")
        # If you expect `check_inputs` to return the original path even if processed (for re-profiling)
        # you might need to adjust `check_inputs` or handle it here.
        # For now, we'll assume if `input_files` is empty, we can't proceed.
        if data_path.is_file() and not override:
             click.echo(f"The file {data_path} might have been processed. Use --override or clean outputs.")
        return
    
    # We will profile prediction for the first file in the list.
    # If `data_path_str` was a directory, `process_inputs` will handle all,
    # but for focused profiling, we'll get a batch from the datamodule based on this processing.
    # The current `process_inputs` is @rank_zero_only and processes all `input_files`.

    click.echo(f"Running predictions for {len(input_files)} structure(s)")
    if len(input_files) > 1:
        click.echo("Note: Profiling will focus on one batch from the first processed input.")

    # Process inputs (generates manifest.json, etc.)
    ccd_path = cache / "ccd.pkl"
    process_inputs( # This will process all files in input_files
        data=input_files, # Pass the potentially filtered list
        out_dir=out_dir,
        ccd_path=ccd_path,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
    )

    # Load processed data
    processed_dir = out_dir / "processed"
    manifest_file = processed_dir / "manifest.json"
    if not manifest_file.exists():
        click.echo(f"Manifest file not found at {manifest_file} after processing. Cannot proceed with profiling.")
        return
        
    processed = BoltzProcessedInput(
        manifest=Manifest.load(manifest_file),
        targets_dir=processed_dir / "structures",
        msa_dir=processed_dir / "msa",
    )
    
    if not processed.manifest.records:
        click.echo("No records found in the manifest. Cannot create dataloader for profiling.")
        return

    # Create data module
    data_module = BoltzInferenceDataModule(
        manifest=processed.manifest,
        target_dir=processed.targets_dir,
        msa_dir=processed.msa_dir,
        num_workers=num_workers,
        # For profiling, usually batch_size = 1 is good for isolating single sample behavior
        # However, BoltzInferenceDataModule might have its own batch size logic.
        # We will take the default from the module.
    )
    data_module.setup(stage='predict') # Call setup to prepare dataloaders

    # Load model
    if checkpoint is None:
        checkpoint_path = cache / "boltz1_conf.ckpt"
    else:
        checkpoint_path = Path(checkpoint)

    predict_args = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "write_confidence_summary": False, # Keep false for focused profiling
        "write_full_pae": False,
        "write_full_pde": False,
    }
    diffusion_params = BoltzDiffusionParams()
    diffusion_params.step_scale = step_scale
    
    model_module: Boltz1 = Boltz1.load_from_checkpoint(
        checkpoint_path,
        strict=True,
        predict_args=predict_args,
        map_location="cpu", # Load to CPU first, then move to device
        diffusion_process_args=asdict(diffusion_params),
        ema=False,
        steering_args={"fk_steering": False,
                       "guidance_update": False,
                       "num_particles": 3,
                       "fk_lambda": 4.0},
    )
    model_module.eval()
    model_module.to(torch_device) # Move model to the target device

    click.echo(f"Model loaded and moved to {torch_device}.")

    # Get a single batch from the dataloader
    try:
        dataloader = data_module.predict_dataloader()
        if not dataloader:
            click.echo("Predict dataloader is None or empty. Cannot get a batch for profiling.")
            return
        batch = next(iter(dataloader))
    except Exception as e:
        click.echo(f"Error getting a batch from the dataloader: {e}")
        click.echo("Ensure your input data can be processed into at least one batch.")
        return

    # Move batch to the target device
    # The batch structure depends on BoltzInferenceDataModule's collate_fn
    # Assuming batch is a dict, list, or tensor that needs to be moved
    if isinstance(batch, dict):
        batch = {k: v.to(torch_device) if hasattr(v, 'to') else v for k, v in batch.items()}
    elif isinstance(batch, (list, tuple)):
        batch = [b.to(torch_device) if hasattr(b, 'to') else b for b in batch]
    elif hasattr(batch, 'to'):
        batch = batch.to(torch_device)
    
    click.echo("Single batch obtained and moved to device.")

    # Define the profiler trace handler
    # trace_handler = torch.profiler.tensorboard_trace_handler(profile_output_dir)
    
    click.echo(f"Starting profiling for a single predict_step. Trace will be saved to {profile_output_dir}")

    # # Profile the model's predict_step (which includes the forward pass)
    # with torch.profiler.profile(
    #     activities=profiler_activities,
    #     record_shapes=True,      # Records tensor shapes
    #     profile_memory=True,     # Enables memory profiling
    #     with_stack=True,         # Records callstacks
    #     on_trace_ready=trace_handler # Handles the trace when profiling is done
    # ) as prof:
    #     with torch.no_grad(): # Ensure no gradients are computed during this specific step
    #         # The predict_step in LightningModules usually takes (self, batch, batch_idx, dataloader_idx=0)
    #         # We simulate a call for the first batch (batch_idx=0)
    torch.cuda.memory._record_memory_history()
    _ = model_module.predict_step(batch, batch_idx=0) 
    torch.cuda.memory._dump_snapshot(f"{profile_output_dir}/result.pickle") # Dump memory snapshot
    # We don't need the output for profiling, just the execution.

    click.echo("Profiling finished.")
    click.echo(f"Profiler trace saved in '{profile_output_dir}'.")
    click.echo("You can view it using TensorBoard: tensorboard --logdir=" + profile_output_dir)

    # # Print a summary of memory usage to the console
    # click.echo("\n--- CPU Memory Usage Summary (Self) ---")
    # click.echo(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cpu_memory_usage", row_limit=15))
    
    # if torch.profiler.ProfilerActivity.CUDA in profiler_activities:
    #     click.echo("\n--- CUDA Memory Usage Summary (Self) ---")
    #     click.echo(prof.key_averages(group_by_input_shape=True).table(sort_by="self_cuda_memory_usage", row_limit=15))
        
    # click.echo("\n--- Operator Table (CPU time) ---")
    # click.echo(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [27]:
# Cell 8: Main execution block (for script or notebook cell)
# if __name__ == "__main__":
# Create dummy input file for the example if it doesn't exist
# In a real scenario, replace this with your actual input data path
example_input_dir = Path("./inputs_prediction_profiling")
example_input_dir.mkdir(parents=True, exist_ok=True)
example_fasta_file = example_input_dir / "example.fasta"

if not example_fasta_file.exists():
    with open(example_fasta_file, "w") as f:
        f.write(">protein1\n")
        f.write("MILKADLINSLKNVFKSLENSESGSESSENSKENESGHSGSKRKRKPKSSSLLEARMELLLEKREKTKK\n") # Example sequence
    print(f"Created dummy input file: {example_fasta_file}")

In [28]:
# --- Parameters for the prediction and profiling ---
# Adjust these paths and parameters as needed for your setup

# Path to your input data (single .fasta, .yaml, or a directory)
# For this example, using the dummy file created above.
# IMPORTANT: For MSA generation to work without a server (use_msa_server=False),
# you would typically provide precomputed MSAs or point to a local MMseqs2 setup.
# The default `compute_msa` will try to call `run_mmseqs2`.
# If you don't have MMseqs2 setup and `use_msa_server=False`, `process_inputs` might fail
# during MSA generation. For simplicity, ensure your input YAML specifies MSA files,
# or use `use_msa_server=True` if ColabFold API is accessible and suitable.
#
# For this example to run somewhat out-of-the-box, let's try with use_msa_server=True.
# If you have local MSAs, adjust your input or set use_msa_server=False.

input_data_path = str(example_fasta_file) # Can be a dir too

output_directory = "./outputs_prediction_profiling"
cache_directory = "./rna-prediction-boltz-profiling/" # Separate cache for this run

# Profiler specific
profiler_trace_output_dir = "./profiler_trace_boltz"

# Model and run parameters (can be tuned)
# For profiling, it's often good to reduce steps for faster feedback initially
num_sampling_steps = 50  # Reduced from 200 for quicker profiling run
num_diffusion_samples = 1 # Usually 1 for structure prediction profiling

# Set override to True if you want to re-process inputs and re-predict
# Set to False to use existing processed data/predictions if available.
# For profiling, True is often useful to ensure a fresh run.
override_existing = True 

# Set a seed for reproducibility
random_seed = 42

# Choose accelerator: "cpu" or "gpu"
# If "gpu" is chosen, it will try to use CUDA.
# Ensure PyTorch with CUDA support is installed if you choose "gpu".
accelerator_type = "gpu" # Change to "gpu" if you have a compatible GPU

In [29]:
click.echo("--- Starting Prediction and Profiling ---")

# Ensure the dummy const module exists if your boltz.data.const is structured that way
# This is a placeholder for what might be in your actual `boltz.data.const`
if not hasattr(const, 'max_paired_seqs'):
    const.max_paired_seqs = 512 # Example value
if not hasattr(const, 'max_msa_seqs'):
    const.max_msa_seqs = 1024 # Example value
if not hasattr(const, 'chain_type_ids'):
    const.chain_type_ids = {"PROTEIN": 0, "RNA": 1, "DNA": 2, "LIGAND":3} # Example values

predict_and_profile(
    data_path_str=input_data_path,
    out_dir_str=output_directory,
    cache_str=cache_directory,
    # checkpoint=None, # Use default from cache
    accelerator=accelerator_type,
    recycling_steps=1, # Further reduce for speed
    sampling_steps=num_sampling_steps,
    diffusion_samples=num_diffusion_samples,
    # step_scale=1.638, # Default
    num_workers=0, # Safer for notebooks/profiling
    override=override_existing,
    seed_val=random_seed,
    use_msa_server=True, # TRY USING SERVER for easier setup, requires internet
                        # Set to False if you have local MMseqs2 and want to use it,
                        # or if your input YAMLs provide pre-computed MSAs.
    msa_server_url="https://api.colabfold.com", # Default
    msa_pairing_strategy="greedy", # Default
    profile_output_dir=profiler_trace_output_dir
)
click.echo("--- Prediction and Profiling Script Finished ---")

--- Starting Prediction and Profiling ---
Running on GPU: cuda:0. Profiling CPU and CUDA activity.


[rank: 0] Seed set to 42


Checking input data.
Running predictions for 1 structure(s)
Processing input data.
Found a manifest file at output directory: outputs_prediction_profiling/boltz_results_example
All examples in data are processed. Updating the manifest.
Model loaded and moved to cuda:0.
Single batch obtained and moved to device.
Starting profiling for a single predict_step. Trace will be saved to ./profiler_trace_boltz
Profiling finished.
Profiler trace saved in './profiler_trace_boltz'.
You can view it using TensorBoard: tensorboard --logdir=./profiler_trace_boltz
--- Prediction and Profiling Script Finished ---


In [30]:
import torch
import torch.cuda.memory # For CUDA memory utilities
from pathlib import Path
import click # Or use regular function arguments if not using Click for this part
from pytorch_lightning import seed_everything
# ... other Boltz imports ...
from boltz.model.model import Boltz1
from boltz.data.module.inference import BoltzInferenceDataModule
from boltz.data.types import Manifest # Ensure Manifest is imported for BoltzProcessedInput


In [31]:


# New function for profiling a training step using CUDA memory history
def profile_training_step_cuda_memory(
    data_path_str: str,
    out_dir_str: str,
    cache_str: str = "~/.boltz",
    checkpoint: Optional[str] = None,
    accelerator: str = "gpu", # Must be "gpu" for this function
    num_workers: int = 0,
    override: bool = False,
    seed_val: Optional[int] = None,
    use_msa_server: bool = False,
    msa_server_url: str = "https://api.colabfold.com",
    msa_pairing_strategy: str = "greedy",
    profile_output_dir: str = "./cuda_memory_profile_output", # Directory for the snapshot pickle
    # Boltz1 model specific params (can be defaults or passed if needed for model loading)
    recycling_steps: int = 1, # Reduced for speed, adjust if necessary
    sampling_steps: int = 50,
    diffusion_samples: int = 1,
    step_scale: float = 1.638,
) -> None:
    """
    Simulates a training step (forward + backward) and profiles CUDA memory
    using torch.cuda.memory._record_memory_history and _dump_snapshot.
    Reports peak CUDA memory usage.
    """
    click.echo("--- Starting Training Step CUDA Memory Profiling ---")

    # This profiler is CUDA-specific
    if accelerator != "gpu":
        click.secho("Error: CUDA memory profiling requires accelerator='gpu'.", fg="red")
        return
    if not torch.cuda.is_available():
        click.secho("Error: CUDA is not available. Cannot perform CUDA memory profiling.", fg="red")
        return
    
    torch_device = torch.device("cuda:0") # Assuming first GPU, or allow selection
    click.echo(f"Using device: {torch_device}")

    # Enable gradients for training
    torch.set_grad_enabled(True)

    if seed_val is not None:
        seed_everything(int(seed_val))
        click.echo(f"Random seed set to: {seed_val}")

    cache = Path(cache_str).expanduser()
    cache.mkdir(parents=True, exist_ok=True)

    data_path = Path(data_path_str).expanduser()
    # Use a more specific output directory name
    name_stem = data_path.stem if data_path.is_file() else data_path.name
    # Main output directory for general results (if any beyond profile)
    out_dir = Path(out_dir_str).expanduser() / f"boltz_training_profile_results_{name_stem}"
    out_dir.mkdir(parents=True, exist_ok=True)
    
    # Specific directory for the memory snapshot
    profile_out_path = Path(profile_output_dir).resolve()
    profile_out_path.mkdir(parents=True, exist_ok=True)
    click.echo(f"Memory snapshot will be saved to: {profile_out_path}")

    download(cache)
    input_files = check_inputs(data_path, out_dir, override)
    if not input_files:
        click.echo("No input files to process. Exiting.")
        return

    click.echo(f"Processing inputs for {len(input_files)} structure(s)...")
    ccd_path = cache / "ccd.pkl"
    process_inputs(
        data=input_files,
        out_dir=out_dir,
        ccd_path=ccd_path,
        use_msa_server=use_msa_server,
        msa_server_url=msa_server_url,
        msa_pairing_strategy=msa_pairing_strategy,
    )

    processed_dir = out_dir / "processed"
    manifest_file = processed_dir / "manifest.json"
    if not manifest_file.exists():
        click.secho(f"Manifest file not found at {manifest_file}. Cannot proceed.", fg="red")
        return
        
    processed_input_data = BoltzProcessedInput(
        manifest=Manifest.load(manifest_file),
        targets_dir=processed_dir / "structures",
        msa_dir=processed_dir / "msa",
    )
    
    if not processed_input_data.manifest.records:
        click.secho("No records found in the manifest. Cannot create dataloader.", fg="red")
        return

    data_module = BoltzInferenceDataModule( # Using inference datamodule to get a batch
        manifest=processed_input_data.manifest,
        target_dir=processed_input_data.targets_dir,
        msa_dir=processed_input_data.msa_dir,
        num_workers=num_workers,
    )
    data_module.setup(stage='predict') 

    if checkpoint is None:
        checkpoint_path = cache / "boltz1_conf.ckpt"
    else:
        checkpoint_path = Path(checkpoint)
    
    click.echo(f"Loading model from checkpoint: {checkpoint_path}")
    # Predict_args for Boltz1 loading - may not be strictly necessary for fwd/bwd
    # but kept for consistency with how model might expect to be loaded.
    # These relate to *what* predict_step computes/returns.
    predict_args_for_model_load = {
        "recycling_steps": recycling_steps,
        "sampling_steps": sampling_steps,
        "diffusion_samples": diffusion_samples,
        "write_confidence_summary": True, # To ensure predict_step returns enough data
        "write_full_pae": True,
        "write_full_pde": True,
    }
    diffusion_params = BoltzDiffusionParams()
    diffusion_params.step_scale = step_scale

    model_module: Boltz1 = Boltz1.load_from_checkpoint(
        checkpoint_path,
        strict=True, # Be strict about loading
        predict_args=predict_args_for_model_load,
        map_location="cpu", # Load to CPU first
        diffusion_process_args=asdict(diffusion_params),
        ema=False, # Assuming no EMA for typical training profiling
         steering_args={"fk_steering": False,
                       "guidance_update": False,
                       "num_particles": 3,
                       "fk_lambda": 4.0},
    )
    model_module.to(torch_device)
    model_module.train() # Set model to training mode
    click.echo("Model loaded, moved to device, and set to train mode.")

    # Get a single batch
    try:
        dataloader = data_module.predict_dataloader()
        if not dataloader: # Should not happen if setup was okay
             click.secho("Predict dataloader is empty. Cannot get a batch.", fg="red")
             return
        batch = next(iter(dataloader))
    except Exception as e:
        click.secho(f"Error getting a batch: {e}", fg="red")
        return

    # Move batch to the target device (assuming batch is a dict or list/tuple of tensors)
    if isinstance(batch, dict):
        batch = {k: v.to(torch_device) if hasattr(v, 'to') else v for k, v in batch.items()}
    elif isinstance(batch, (list, tuple)):
        batch = [b.to(torch_device) if hasattr(b, 'to') else b for b in batch]
    elif hasattr(batch, 'to'): # Single tensor batch
        batch = batch.to(torch_device)
    click.echo("Single batch obtained and moved to device.")

    # Dummy optimizer (required for `loss.backward()` if parameters are involved)
    # Filter for parameters that require gradients
    trainable_params = filter(lambda p: p.requires_grad, model_module.parameters())
    optimizer = torch.optim.Adam(trainable_params, lr=1e-4) # lr doesn't matter for this
    optimizer.zero_grad() # Zero out gradients

    click.echo("Starting CUDA memory recording for simulated training step...")
    torch.cuda.synchronize(torch_device) # Wait for all kernels to finish before starting
    torch.cuda.reset_peak_memory_stats(torch_device) # Reset peak memory counter
    
    # Start recording memory history
    # This is a global recorder, affecting the specified device or default CUDA device.
    torch.cuda.memory._record_memory_history(enabled=True, device=torch_device)

    # --- Forward Pass ---
    # Use predict_step as it's a known interface for Boltz1 to get outputs.
    # These outputs will be used to form a dummy loss.
    # predict_step should be run with torch.enable_grad() if it's not already on.
    # (We did torch.set_grad_enabled(True) globally earlier)
    predictions_dict = model_module.predict_step(batch, batch_idx=0)
    
    # Create a dummy loss from one of the output tensors
    loss = None
    if isinstance(predictions_dict, dict):
        for key, value in predictions_dict.items():
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                # Ensure the tensor requires grad if it's an output of ops involving weights
                # If value itself doesn't have grad_fn, its sum might not create a graph to backprop.
                # A simple mean() should be fine if it's a result of model computations.
                try:
                    loss = value.abs().mean() # Take mean of absolute values to ensure it's scalar and positive
                    click.echo(f"Using output tensor from key '{key}' for dummy loss (shape: {value.shape}).")
                    break
                except Exception as e:
                    click.echo(f"Could not use tensor from key {key} for loss: {e}")
        if loss is None:
            click.secho("Error: Could not find a suitable tensor in predict_step output to form a dummy loss.", fg="red")
            torch.cuda.memory._record_memory_history(enabled=False, device=torch_device)
            return
    elif isinstance(predictions_dict, torch.Tensor) and predictions_dict.is_floating_point():
        loss = predictions_dict.abs().mean()
        click.echo(f"Using the direct tensor output from predict_step for dummy loss (shape: {predictions_dict.shape}).")
    else:
        click.secho("Error: predict_step output is not a dict or tensor suitable for dummy loss.", fg="red")
        torch.cuda.memory._record_memory_history(enabled=False, device=torch_device)
        return

    click.echo(f"Dummy loss computed: {loss.item()}")

    # --- Backward Pass ---
    loss.backward()
    click.echo("Backward pass completed.")

    # (Optional: optimizer.step() would go here if you wanted to profile it too)
    # optimizer.step() 
    # click.echo("Optimizer step completed.")

    torch.cuda.synchronize(torch_device) # Ensure all ops are done before dumping/stopping

    # --- Dump Snapshot and Stop Recording ---
    snapshot_file_name = "training_memory_snapshot.pickle"
    snapshot_path = profile_out_path / snapshot_file_name
    try:
        torch.cuda.memory._dump_snapshot(str(snapshot_path))
        click.echo(f"CUDA memory snapshot saved to: {snapshot_path}")
    except Exception as e:
        click.secho(f"Error dumping snapshot: {e}", fg="red")
    finally:
        # Stop recording memory history
        torch.cuda.memory._record_memory_history(enabled=False, device=torch_device)
        click.echo("CUDA memory recording stopped.")

    # --- Report Peak Memory ---
    peak_memory_bytes = torch.cuda.max_memory_allocated(torch_device)
    click.echo(f"Peak CUDA memory allocated during the step: {peak_memory_bytes / (1024**2):.2f} MB")
    
    # Clean up (optional, helps if running multiple times in a session)
    del batch, loss, predictions_dict, model_module, data_module, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    click.echo("--- Training Step CUDA Memory Profiling Finished ---")


In [33]:
# Cell for Main execution block (example)
# Create dummy input file for the example
example_input_dir = Path("./inputs_boltz_profiling")
example_input_dir.mkdir(parents=True, exist_ok=True)
example_fasta_file = example_input_dir / "example_protein.fasta"

if not example_fasta_file.exists():
    with open(example_fasta_file, "w") as f:
        # Using a header format that should trigger MSA generation if parse_fasta sets msa_id=None -> ChainInfo.msa_id=0
        f.write(">CHAIN_X|protein\n") # MSA_ID field omitted to trigger generation via default
        f.write("MGKVKVGVNGFGRIGRLVTRAAFNSGKVDIVAINDPFIDLNYMVYMFQYDSTHGKFHGTVK\n")
        f.write("AENGKLVINGMPTGIILLTEPVEDRAMAKAKAEMTGKEIKAAQNIIPSSTGAAKAVGKVLP\n")
        f.write("ELGKLTGMAFRVPTANVSVVDLTCRLEKPAKYDDIKKVVKQASEGPLKGILGYTEHQVVSS\n")
        f.write("DFNSDTHSSTFDAGAGIALNDHFVKLISWYDNEFGYSNRVVDLMAHMASKEALGGENGLYL\n")
        f.write("IHGSNVTANYLPADDRVRYTLYTIAALLGLSLFKGAKVGILNVSADCGLTDAFHQLDSLLG\n")
        f.write("GRRALKNIVIPTSTGAAKAHEIVLKAGQHAA\n") # Example: GAPDH_HUMAN sequence (335 AA)
    print(f"Created dummy input file: {example_fasta_file}")

# --- Parameters for the training step profiling ---
input_data = str(example_fasta_file)
output_base_dir = "./outputs_boltz_training_profile" # General output for processed data etc.
profiler_snapshot_dir = "./cuda_memory_snapshots" # Specific for memory .pickle file
model_cache_dir = "./boltz_model_cache_training_profile/"

# Ensure const module has placeholder values if not fully set up
# This is for the `process_inputs` part if `boltz.data.const` is not fully available
if not hasattr(const, 'chain_type_ids'): const.chain_type_ids = {"PROTEIN": 0, "RNA": 1} # Simplified
if not hasattr(const, 'max_paired_seqs'): const.max_paired_seqs = 128
if not hasattr(const, 'max_msa_seqs'): const.max_msa_seqs = 256


profile_training_step_cuda_memory(
    data_path_str=input_data,
    out_dir_str=output_base_dir,
    cache_str=model_cache_dir,
    # checkpoint=None, # Uses default from cache
    accelerator="gpu", # MUST BE GPU
    num_workers=0,
    override=True, # Re-process inputs each time for consistency in profiling
    seed_val=42,
    use_msa_server=True, # For dummy example, set to true if MMseqs2 not local
                            # Ensure your FASTA header and parse_fasta setup allows this
    profile_output_dir=profiler_snapshot_dir,
    # Boltz1 specific params, ensure they are sensible for your model
    recycling_steps=1, 
    sampling_steps=10, # Reduced for faster test, may not be relevant for fwd/bwd of training
    diffusion_samples=1,
)

--- Starting Training Step CUDA Memory Profiling ---
Using device: cuda:0


[rank: 0] Seed set to 42


Random seed set to: 42
Memory snapshot will be saved to: /insomnia001/depts/edu/COMSE6998/yy3448/hpml_of3/cuda_memory_snapshots
Checking input data.
Processing inputs for 1 structure(s)...
Processing input data.


  0%|          | 0/1 [00:00<?, ?it/s]

Generating MSA for inputs_boltz_profiling/example_protein.fasta with 1 protein entities.


Sleeping for 10s. Reason: PENDING
Sleeping for 5s. Reason: RUNNING
Sleeping for 5s. Reason: RUNNING
Sleeping for 10s. Reason: RUNNING
Sleeping for 7s. Reason: RUNNING
Sleeping for 6s. Reason: RUNNING
Sleeping for 6s. Reason: RUNNING
Sleeping for 6s. Reason: RUNNING
Sleeping for 10s. Reason: RUNNING
Sleeping for 5s. Reason: RUNNING
Sleeping for 10s. Reason: RUNNING
Sleeping for 10s. Reason: RUNNING
COMPLETE: 100%|██████████| 150/150 [elapsed: 01:36 remaining: 00:00]
100%|██████████| 1/1 [01:36<00:00, 96.80s/it]


Loading model from checkpoint: boltz_model_cache_training_profile/boltz1_conf.ckpt
Model loaded, moved to device, and set to train mode.
Single batch obtained and moved to device.
Starting CUDA memory recording for simulated training step...




Using output tensor from key 'masks' for dummy loss (shape: torch.Size([1, 2528])).
Dummy loss computed: 0.9988133311271667


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn