# Predict Inference (Multi-GPU)
* To run a single inference request, invoke torchrun directly from the command line (see below).
* The input file can also contain multiple requests as a JSONL file (newline-delimited JSON, one JSON object per line).

    torchrun --nproc_per_node=8 examples/inference.py \
    -i inference_parameters.json --checkpoint-path checkpoint.pt \
    --experiment predict2_lora_training_2b_cosmos_nemo_assets

* To call inference repeatedly on a model launched with torchrun, the model processes must stay alive, block while waiting for incoming requests, and send back their outputs.
* We use the helper classes ModelServer/ModelWorker, which sit on top of the model and handle this basic synchronization.

In [None]:
import os
from pathlib import Path

# Relative to the working directory at use time, i.e. the repo root after the chdir below.
output_dir = Path("outputs")

# The notebook is expected to live two directories below the repository root.
# Move to the root so relative paths such as "datasets/..." and "checkpoints/..." resolve.
if not (Path.cwd() / "cosmos_predict2").is_dir():
    os.chdir(Path.cwd().parent.parent)  # Change working directory to root
    if not (Path.cwd() / "cosmos_predict2").is_dir():
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError(
            f"Working directory change failed: {Path.cwd()} does not contain 'cosmos_predict2'."
        )

# Make the repository root importable by the worker processes launched later via torchrun.
# Prepend instead of overwriting so existing PYTHONPATH entries are kept, and drop any
# previous copy of the root so re-running this cell does not keep growing the variable.
_repo_root = str(Path.cwd())
_other_entries = [
    entry
    for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep)
    if entry and entry != _repo_root
]
os.environ["PYTHONPATH"] = os.pathsep.join([_repo_root, *_other_entries])

To use pre-trained checkpoints, just specify the model:

In [None]:
from cosmos_predict2.config import SetupArguments

# Pre-trained weights: naming the model is enough, no checkpoint path is passed.
setup_args = SetupArguments(
    model="2B/pre-trained",
    experiment="predict2_lora_training_2b_cosmos_nemo_assets",
    output_dir=output_dir,
    context_parallel_size=8,
    keep_going=True,
)

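To use a local checkpoint (e.g. from post-training), also specify the checkpoint path. Run either this cell or the previous one; whichever runs last defines `setup_args`.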

In [None]:
from cosmos_predict2.config import SetupArguments

# Local model file to load instead of relying on the model name alone.
checkpoint_path = "checkpoints/nvidia/Cosmos-Predict2.5-2B/consolidated/model.pt"

# Same configuration as the pre-trained variant, plus the explicit checkpoint.
setup_args = SetupArguments(
    model="2B/pre-trained",
    experiment="predict2_lora_training_2b_cosmos_nemo_assets",
    checkpoint_path=checkpoint_path,
    output_dir=output_dir,
    context_parallel_size=8,
    keep_going=True,
)

Create a simple server/worker setup:
* the server creates the worker processes with torchrun
* each worker waits for input requests from the server

In [None]:
from cosmos_gradio.model_ipc.model_server import ModelServer

from cosmos_predict2.gradio.video2world_worker import save_setup_args

# Persist the setup arguments; the torchrun worker processes pick them up from there.
save_setup_args(setup_args)

# The server imports `factory_function` from `factory_module` to build each worker.
# The GPU count matches the context-parallel size configured in setup_args.
worker_factory = dict(
    factory_module="cosmos_predict2.gradio.video2world_worker",
    factory_function="create_worker",
)
server = ModelServer(num_gpus=setup_args.context_parallel_size, **worker_factory)

# Inference

In [None]:
from cosmos_predict2.config import InferenceArguments

asset_dir = "datasets"
prompt = "A nighttime city bus terminal gradually shifts from stillness to subtle movement. At first, multiple double-decker buses are parked under the glow of overhead lights, with a central bus labeled '87D' facing forward and stationary. As the video progresses, the bus in the middle moves ahead slowly, its headlights brightening the surrounding area and casting reflections onto adjacent vehicles. The motion creates space in the lineup, signaling activity within the otherwise quiet station. It then comes to a smooth stop, resuming its position in line. Overhead signage in Chinese characters remains illuminated, enhancing the vibrant, urban night scene."

# Parameters for one image-to-world request, conditioned on a single input image.
args = dict(
    inference_type="image2world",
    name="bus_terminal",
    input_path=os.path.join(asset_dir, "base/bus_terminal.jpg"),
    prompt=prompt,
)

# Validate locally first, then send the arguments to the server as a JSON-mode dict.
validated_args = InferenceArguments(**args)
request = validated_args.model_dump(mode="json")
status = server.infer(request)

In [None]:
from IPython.display import Image, Video, display

# Outputs may be nested under "result"; otherwise they are read from the status dict itself.
result = status.get("result", status)
output_videos = result.get("videos", None)
output_images = result.get("images", None)

# `or []` skips missing (None) or empty output lists.
for image in output_images or []:
    display(Image(image))
for video in output_videos or []:
    display(Video(video, embed=True))

# Alternatively, read the inference arguments from a JSON file

In [None]:
from IPython.display import Video, display

# Same kind of request as above, but the inference arguments are read from a JSON file.
# server.infer returns a status dict (not the list of videos), so unwrap it the same way
# as the display cell before using the generated videos.
status = server.infer({"args_json_file": "datasets/base/bus_terminal.json"})
result = status.get("result", status)
output_videos = result.get("videos", None)

for video in output_videos or []:
    display(Video(video, embed=True))