# Python API examples

## Set up the ConfRover cache directory

In [None]:
from pathlib import Path

# Local cache directory for ConfRover. The cells below load checkpoints from
# its `confrover_ckpts` subfolder and also pass it to `generate(cache_dir=...)`.
# `.resolve()` turns the relative name into an absolute path anchored at the
# current working directory.
cache_dir = Path("confrover_cache").resolve()
print(f"ConfRover cache dir: {cache_dir}")

## Forward simulation

In [None]:
from confrover.model import ConfRover

# Forward simulation with the base checkpoint: generate a trajectory that
# starts from the structure in 6j56_A_start.pdb.
# NOTE(review): "6j56_A_start.pdb" is a relative path, so it is resolved from
# the notebook's working directory — confirm the file is present there.
ckpt_root = cache_dir / "confrover_ckpts"
model = ConfRover.from_pretrained("confrover-base-20m-v1.0", ckpt_dir=ckpt_root)
model = model.to("cuda")

# Amino-acid sequence (one-letter codes) for case 6j56_A.
seqres_6j56_A = "ARQREIEMNRQQRFFRIPFIRPADQYKDPQSKKKGWWYAHFDGPWIARQMELHPDKPPILLVAGKDDMEMCELNLEETGLTRKRGAEILPRQFEEIWERCGGIQYLQNAIESRQARPTYATAMLQSLLK"

# Kept as the cell's last expression so Jupyter displays whatever it returns.
model.generate(
    case_id="6j56_A",
    seqres=seqres_6j56_A,
    task_mode="forward",
    conditions="6j56_A_start.pdb",  # starting structure
    n_frames=10,  # total frame count, starting frame included
    stride_in_10ps=120,  # spacing between frames in units of 10 ps (120 -> 1.2 ns)
    n_replicates=1,
    output_dir="./output/example_fwd/",
    cache_dir=cache_dir,
)


## Independent ensemble sampling

In [None]:
from confrover.model import ConfRover

# Independent ("iid") ensemble sampling with the base checkpoint. Loading the
# model again here keeps this cell runnable on its own.
ckpt_root = cache_dir / "confrover_ckpts"
model = ConfRover.from_pretrained("confrover-base-20m-v1.0", ckpt_dir=ckpt_root)
model = model.to("cuda")

# Amino-acid sequence (one-letter codes) for case 6j56_A.
seqres_6j56_A = "ARQREIEMNRQQRFFRIPFIRPADQYKDPQSKKKGWWYAHFDGPWIARQMELHPDKPPILLVAGKDDMEMCELNLEETGLTRKRGAEILPRQFEEIWERCGGIQYLQNAIESRQARPTYATAMLQSLLK"

# Kept as the cell's last expression so Jupyter displays whatever it returns.
# Unlike the forward/interp modes, no starting structure or frame settings
# are passed here.
model.generate(
    case_id="6j56_A",
    seqres=seqres_6j56_A,
    task_mode="iid",
    n_replicates=10,  # number of replicates to generate
    output_dir="./output/example_iid/",
    cache_dir=cache_dir,
)

## State interpolation

In [None]:
from confrover.model import ConfRover

# State interpolation with the dedicated "interp" checkpoint: generate frames
# that connect the structures in 6j56_A_start.pdb and 6j56_A_end.pdb.
# NOTE(review): both PDB paths are relative to the notebook's working
# directory — confirm the files are present there.
ckpt_root = cache_dir / "confrover_ckpts"
model = ConfRover.from_pretrained("confrover-interp-20m-v1.0", ckpt_dir=ckpt_root)
model = model.to("cuda")

# Amino-acid sequence (one-letter codes) for case 6j56_A.
seqres_6j56_A = "ARQREIEMNRQQRFFRIPFIRPADQYKDPQSKKKGWWYAHFDGPWIARQMELHPDKPPILLVAGKDDMEMCELNLEETGLTRKRGAEILPRQFEEIWERCGGIQYLQNAIESRQARPTYATAMLQSLLK"
endpoint_structures = [
    "6j56_A_start.pdb",
    "6j56_A_end.pdb",
]

# Kept as the cell's last expression so Jupyter displays whatever it returns.
model.generate(
    case_id="6j56_A",
    seqres=seqres_6j56_A,
    task_mode="interp",
    conditions=endpoint_structures,
    n_frames=9,  # total frame count, start and end frames included
    stride_in_10ps=256,  # spacing between frames in units of 10 ps (256 -> 2.56 ns)
    n_replicates=3,
    output_dir="./output/example_interp",
    cache_dir=cache_dir,
)
