# Alpamayo Live Inference + MJPEG Stream

This notebook reuses the same dataset loading, preprocessing and inference calls from `notebooks/inference.ipynb` and adds a simple real-time simulation loop plus a Flask MJPEG server on localhost:8000 that always serves the most recent rendered frame.

Notes:
- `model loading` happens in cell 4
- `dataset loading` happens in cell 6
- `inference` happens in cell 8
- `rendering` happens in cell 8
- `streaming` (Flask MJPEG background server) starts in cell 10

Run cells in order, with one exception: start the Flask server (cell 10) before the loop in cell 8. The loop blocks the kernel, so cells queued after it do not run until it stops. The loop processes frames repeatedly to simulate a live feed. Stop it by interrupting the kernel.

In [None]:
# Cell 2: Imports and small utilities
import time
import threading
import io
import itertools
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import torch
import cv2
from flask import Flask, Response, make_response

# Repo utilities (reuse existing preprocessing/inference utilities)
from alpamayo_r1.models.alpamayo_r1 import AlpamayoR1
from alpamayo_r1.load_physical_aiavdataset import load_physical_aiavdataset
from alpamayo_r1 import helper

# Shared variable to hold the latest JPEG frame bytes served by MJPEG stream
latest_frame = { 'jpg': None, 'lock': threading.Lock() }
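The dict-plus-lock above works as a single-slot "latest value" mailbox. The producer (the inference loop) overwrites the slot, and readers (the MJPEG generator) always see the newest frame rather than a queue of stale ones. The following is a minimal sketch of the same idea as a small class. The `LatestFrame` name and its sequence counter are hypothetical illustrations and are not part of the repo:

```python
import threading
from typing import Optional, Tuple


class LatestFrame:
    """Hypothetical single-slot mailbox: writers overwrite, readers get the newest value."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._jpg: Optional[bytes] = None
        self._seq = 0  # incremented on every write so readers can skip frames they already sent

    def put(self, jpg: bytes) -> None:
        with self._lock:
            self._jpg = jpg
            self._seq += 1

    def get(self) -> Tuple[Optional[bytes], int]:
        # Return the newest frame together with its sequence number
        with self._lock:
            return self._jpg, self._seq


slot = LatestFrame()
slot.put(b'frame-1')
slot.put(b'frame-2')  # overwrites; frame-1 is never delivered to readers
print(slot.get())  # (b'frame-2', 2)
```

A reader that remembers the last sequence number it sent can skip re-sending an unchanged frame. The notebook's dict does not track this.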

## Model loading (observe model load time)

This cell loads the pretrained `AlpamayoR1` model and the `processor` used for preprocessing. Loading may be slow, depending on model size and whether a GPU is available.

In [None]:
# Cell 4: Load model (this may take time)
start = time.time()
# NOTE: uses the same call as the reference notebook
device = 'cuda' if torch.cuda.is_available() else 'cpu'
try:
    model = AlpamayoR1.from_pretrained('nvidia/Alpamayo-R1-10B', dtype=torch.bfloat16 if device == 'cuda' else torch.float32).to(device)
    print('Loaded model on', device)
except Exception as e:
    # If the model cannot be loaded (e.g. missing weights or insufficient memory), fail with a descriptive error
    raise RuntimeError('Model load failed; ensure weights are available and you have the right runtime. Error: ' + str(e)) from e
processor = helper.get_processor(model.tokenizer)
print(f'Model load time: {time.time() - start:.1f}s')

## Dataset loading and buffer initialization

This cell loads the same clip and data used in the reference notebook. We keep a rolling temporal buffer of camera frames (the same `num_frames` as the loader, typically 4). The buffer is advanced at each step to simulate real-time playback.
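The roll-and-overwrite step used in the loop can be checked on a toy array. The following is a minimal sketch that assumes a buffer laid out as `(N_cams, num_frames, ...)`, like `image_frames_np`. Each "frame" here is just its integer time index, so the ordering is visible:

```python
import itertools

import numpy as np

num_frames = 4
# Toy source: 1 camera, 4 time steps, each "frame" is a single value equal to its index
source = np.arange(num_frames).reshape(1, num_frames, 1)
buffer = source.copy()
time_idx_iter = itertools.cycle(range(num_frames))

history = []
for _ in range(3):
    next_t = next(time_idx_iter)
    buffer = np.roll(buffer, -1, axis=1)          # shift the time axis left; the oldest frame wraps to the end
    buffer[:, -1, ...] = source[:, next_t, ...]   # overwrite the last slot with the "newly arrived" frame
    history.append(buffer[0, :, 0].tolist())

print(history)  # [[1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2]]
```

Because the buffer already holds every loaded frame and the iterator starts at index 0, each step is a pure rotation. The model therefore sees the same frames in rotated, not chronological, order, which is acceptable for a playback demo but is not true new input.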

In [None]:
# Cell 6: Load dataset (same footage as original notebook)
import pandas as pd
# Read the list of clip ids; adjust the path if clip_ids.parquet lives elsewhere
try:
    clip_ids = pd.read_parquet('clip_ids.parquet')["clip_id"].tolist()
    clip_id = clip_ids[774]
except Exception as e:
    # No fallback clip: fail with a message pointing at the missing dataset file
    raise RuntimeError('Could not read clip_ids.parquet. Ensure the dataset is present. Error: ' + str(e)) from e
print('Using clip_id:', clip_id)
# load data using the repo helper
data = load_physical_aiavdataset(clip_id)
# data keys: image_frames (N_cameras, num_frames, 3, H, W), ego_history_xyz, ego_future_xyz, etc.
print({k: (v.shape if hasattr(v, 'shape') else type(v)) for k, v in data.items() if k in ['image_frames','ego_history_xyz','ego_future_xyz']})

# Determine camera we want to display as the main camera (front wide is camera index 1 by loader mapping)
cam_indices = data['camera_indices']
try:
    main_cam_pos = (cam_indices == 1).nonzero(as_tuple=True)[0].item()
except Exception:
    # Fallback to first camera available
    main_cam_pos = 0
print('Main camera position in returned tensor:', main_cam_pos)

# Image buffer: copy the tensor to CPU numpy for easy drawing/cycling and convert to HWC uint8 per frame
image_frames = data['image_frames']  # (N_cameras, num_frames, 3, H, W)
N_cams, num_frames, C, H, W = image_frames.shape
print(f'Loaded images: N_cams={N_cams}, num_frames={num_frames}, H={H}, W={W}')
# Convert to numpy HWC uint8 for each camera/time: shape (N_cams, num_frames, H, W, 3)
image_frames_np = image_frames.cpu().numpy().transpose(0,1,3,4,2).astype('uint8')
# Create a simple circular iterator over time indices to simulate playback (we'll advance buffer each step)
time_idx_iter = itertools.cycle(range(num_frames))

# Rolling buffer state: for each camera keep the loaded frames (we'll rotate them to simulate streaming)
buffer = image_frames_np.copy()  # shape (N_cams, num_frames, H, W, 3)

# Ego-motion history used as model input; move it to the device once here (the `.to(device)` calls in the loop are then no-ops)
ego_history_xyz = data['ego_history_xyz'].to(device)  # (1,1,T,3)
ego_history_rot = data['ego_history_rot'].to(device)  # (1,1,T,3,3)

print('Buffer initialized. Ready to run inference loop.')

## Inference + rendering loop

This cell runs the real-time simulation loop. Each iteration it:
- advances the rolling buffer of image frames (cycles through available frames),
- builds `messages` using `helper.create_message` (same preprocessing as reference),
- tokenizes with `processor.apply_chat_template`,
- calls `model.sample_trajectories_from_data_with_vlm_rollout` to get predicted trajectories and reasoning traces,
- renders a composed image (the current main-camera image with a small BEV inset of the predicted trajectory in the top-left corner and the reasoning trace, if any, along the bottom),
- updates `latest_frame` bytes for the MJPEG server to serve.

You can run this cell; it will loop until you interrupt the kernel. For demo purposes the loop cycles the small set of frames returned by the loader to simulate continuous playback.
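Because the loop in cell 8 blocks the kernel, a common alternative is to run the per-step work in a background thread and stop it with a `threading.Event`. The following sketch is a hypothetical variant and not the notebook's code. `start_loop` and `step_fn` are illustrative names, and `step_fn` is assumed to perform one iteration (advance the buffer, infer, render, publish to `latest_frame`):

```python
import threading
import time
from typing import Callable


def start_loop(step_fn: Callable[[int], None], stop_event: threading.Event) -> threading.Thread:
    """Run step_fn(step) repeatedly in a daemon thread until stop_event is set."""
    def worker() -> None:
        step = 0
        while not stop_event.is_set():
            try:
                step_fn(step)
            except Exception as exc:  # report and stop instead of letting the thread die silently
                print('Loop error:', exc)
                break
            step += 1

    thread = threading.Thread(target=worker, daemon=True)
    thread.start()
    return thread


# Usage with a dummy step that records the step index and sleeps briefly
stop = threading.Event()
ticks = []
worker_thread = start_loop(lambda step: (ticks.append(step), time.sleep(0.01)), stop)
time.sleep(0.1)
stop.set()                       # request shutdown; the worker exits after its current step
worker_thread.join(timeout=1.0)
print(len(ticks) > 0, worker_thread.is_alive())
```

With this variant, `display` calls made from the background thread may end up in whichever cell is currently executing. For that reason, the MJPEG stream is the more predictable way to watch the output.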

In [None]:
# Cell 8: Inference + rendering loop (run to start simulated realtime processing)
from copy import deepcopy
import base64
from IPython.display import display, Image as IPImage, clear_output

# Toggle: show live video inside the notebook output (clears and updates the cell output)
SHOW_IN_NOTEBOOK = True

def frame_to_jpeg_bytes(img_np, quality=90):
    # img_np: H,W,3 uint8 (RGB) -> JPEG bytes
    is_success, buffer = cv2.imencode('.jpg', cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR), [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    if not is_success:
        raise RuntimeError('JPEG encoding failed')
    return buffer.tobytes()

def render_bev(pred_xy, gt_xy=None, size=(H, W)):
    # Create a simple BEV panel as PIL image and draw the predicted trajectory (meters -> pixels scaled)
    bev_w, bev_h = size[1], size[0]
    img = Image.new('RGB', (bev_w, bev_h), (255,255,255))
    draw = ImageDraw.Draw(img)
    if pred_xy is None:
        return img
    pts = np.array(pred_xy)  # T,2
    # compute mins/maxs and elementwise span (avoid using Python max on arrays)
    mins = pts.min(axis=0)
    maxs = pts.max(axis=0)
    span = np.maximum(maxs - mins, 1e-3)
    mins = mins - 0.1 * span
    maxs = maxs + 0.1 * span
    span = maxs - mins
    def to_px(xy):
        rel = (xy - mins) / span
        px = np.stack([rel[:,0] * (bev_w-20) + 10, (1.0 - rel[:,1]) * (bev_h-20) + 10], axis=1)
        return px.astype(int)
    px = to_px(pts)
    for i in range(len(px)-1):
        draw.line([tuple(px[i]), tuple(px[i+1])], fill=(0,0,255), width=3)
    for p in px:
        draw.ellipse((p[0]-3,p[1]-3,p[0]+3,p[1]+3), fill=(0,0,255))
    if gt_xy is not None:
        gpx = to_px(np.array(gt_xy))
        for i in range(len(gpx)-1):
            draw.line([tuple(gpx[i]), tuple(gpx[i+1])], fill=(255,0,0), width=2)
    return img


def compose_frame(current_rgb, bev_img, text_lines=None, debug_stamp=None):
    # Convert main frame to PIL
    left = Image.fromarray(current_rgb)

    # Rotate BEV image 90° counterclockwise
    bev_rotated = bev_img.rotate(90, expand=True)

    # Scale BEV to an inset 1/5 of the main image width and 1/2 of its height
    overlay_w = left.width // 5
    overlay_h = left.height // 2
    bev_small = bev_rotated.resize((overlay_w, overlay_h))

    # Create overlay background (white) slightly bigger than BEV
    padding = 10
    bg_w = overlay_w + padding * 2
    bg_h = overlay_h + padding * 2 + 20  # extra space for title
    overlay_bg = Image.new('RGB', (bg_w, bg_h), (255, 255, 255))

    # Paste BEV onto background
    overlay_bg.paste(bev_small, (padding, padding + 20))

    
    # Add title
    draw_overlay = ImageDraw.Draw(overlay_bg)
    font = ImageFont.load_default(size = 28)
    title = "Projected path"
    
    # Use textbbox to measure text size
    bbox = draw_overlay.textbbox((0, 0), title, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]
    
    draw_overlay.text(((bg_w - text_w) // 2, 5), title, fill=(0, 0, 0), font=font)

    # Compose final image: start with main frame
    out = left.copy()

    # Paste overlay in top-left corner
    out.paste(overlay_bg, (10, 10))

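    # Draw up to 6 text lines on the main image, stacked upward from the bottom edge
    draw_main = ImageDraw.Draw(out)
    if text_lines:
        line_h = text_h + 8  # line height derived from the 28 px font measured for the title
        shown = text_lines[:6]
        x = 6
        y = out.height - line_h * len(shown) - 6
        for line in shown:
            draw_main.text((x, y), line, fill=(0, 0, 0), font=font)
            y += line_h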

    # Debug stamp
    if debug_stamp:
        draw_main.text((out.width - 150, out.height - 20), debug_stamp, fill=(0, 0, 0))

    return np.array(out)


# Main loop: each iteration advances buffer (cycle) and runs inference
running = True
step = 0

print(f"running is {running}")

# If showing in-notebook, display a black placeholder the size of a composed frame (H x W); clear_output replaces it on the first step
if SHOW_IN_NOTEBOOK:
    display(IPImage(data=frame_to_jpeg_bytes(np.zeros((H, W, 3), dtype='uint8'))))

try:
    while running:
        start = time.time()
        print(f"step {step}")
        # Advance buffer: pop leftmost time index and append next frame from loaded frames to simulate new arrival
        next_t = next(time_idx_iter)
        buffer = np.roll(buffer, -1, axis=1)  # shift time dimension left
        buffer[:, -1, ...] = image_frames_np[:, next_t, ...]  # append the frame at next_t

        # Create messages from flattened frames (same call as reference notebook)
        frames_for_message = torch.from_numpy(buffer).permute(0,1,4,2,3).contiguous()  # N_cams, num_frames, C, H, W
        messages = helper.create_message(frames_for_message.flatten(0,1))

        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=False,
            continue_final_message=True,
            return_dict=True,
            return_tensors='pt',
        )

        model_inputs = {
            'tokenized_data': inputs,
            'ego_history_xyz': ego_history_xyz.to(device),
            'ego_history_rot': ego_history_rot.to(device),
        }
        model_inputs = helper.to_device(model_inputs, device)

        # Run model inference (may be slow); this uses the same sampling call as the reference notebook
        if device == 'cuda':
            torch.cuda.manual_seed_all(42)
        # Autocast to bfloat16 on GPU only; on CPU autocast is disabled and the float32 model runs as loaded
        with torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=(device == 'cuda')):
            with torch.no_grad():
                pred_xyz, pred_rot, extra = model.sample_trajectories_from_data_with_vlm_rollout(
                    data=deepcopy(model_inputs),
                    top_p=0.98,
                    temperature=0.6,
                    num_traj_samples=1,
                    max_generation_length=256,
                    return_extra=True,
                )

        pred_xy = pred_xyz.float().cpu().numpy()[0, 0, 0, :, :2]  # (T,2); .float() first because NumPy has no bfloat16

        reasoning_text = None
        if isinstance(extra, dict) and 'cot' in extra and extra['cot'] is not None:
            try:
                reasoning_text = extra['cot'][0]
            except Exception:
                reasoning_text = None

        current_rgb = buffer[main_cam_pos, -1]  # H,W,3 uint8 (RGB)

        bev_img = render_bev(pred_xy, gt_xy=None, size=(H, W))
        text_lines = None
        if reasoning_text is not None:
            if isinstance(reasoning_text, (list, tuple)):
                text_lines = [str(x) for x in reasoning_text]
            else:
                text_lines = str(reasoning_text).split('\n')

        composed = compose_frame(current_rgb, bev_img, text_lines=text_lines, debug_stamp=f'step:{step}')

        jpg = frame_to_jpeg_bytes(composed)
        with latest_frame['lock']:
            latest_frame['jpg'] = jpg

        # Update the in-notebook display (clears previous output and shows the latest frame)
        if SHOW_IN_NOTEBOOK:
            try:
                clear_output(wait=True)
                display(IPImage(data=jpg))
            except Exception:
                # If display fails for any reason, continue without crashing the loop
                pass
                
        print(f'Step processing time: {time.time() - start:.1f}s')
        step += 1
        # time.sleep(0.2)

except KeyboardInterrupt:
    print('Interrupted by user, stopping loop')
except Exception:
    # Print the full traceback rather than only the message, then fall through to cleanup
    import traceback
    traceback.print_exc()
finally:
    running = False
    print('Inference loop finished')

## Start Flask MJPEG server in background thread

This cell starts a minimal Flask app that serves `/video` as an MJPEG stream. The app reads from `latest_frame['jpg']` updated by the inference loop. The Flask server runs in a background thread so the notebook remains responsive.
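Each part of a `multipart/x-mixed-replace` response has the following structure:

- a delimiter line (`--` followed by the boundary name declared in the response's `Content-Type` header),
- the part headers,
- a blank line,
- the JPEG bytes.

All of these use CRLF line endings. The following is a minimal sketch of building one part, assuming the boundary name `frame`. The helper name `mjpeg_part` is hypothetical:

```python
def mjpeg_part(jpg: bytes, boundary: str = 'frame') -> bytes:
    """Frame one JPEG as a single part of a multipart/x-mixed-replace stream."""
    header = (
        f'--{boundary}\r\n'               # delimiter = '--' + boundary name
        'Content-Type: image/jpeg\r\n'
        f'Content-Length: {len(jpg)}\r\n'
        '\r\n'                            # blank line ends the part headers
    ).encode('ascii')
    return header + jpg + b'\r\n'


part = mjpeg_part(b'\xff\xd8...\xff\xd9')  # stand-in bytes, not a real JPEG
print(part.split(b'\r\n')[:3])  # [b'--frame', b'Content-Type: image/jpeg', b'Content-Length: 7']
```

The matching response header would be `Content-Type: multipart/x-mixed-replace; boundary=frame`. The boundary parameter names the boundary without the leading `--`.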

In [None]:
# Cell 10: Start Flask server (run this before the inference loop in cell 8, which blocks the kernel)
app = Flask('alpamayo_stream')

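def mjpeg_generator(fps=10):
    # Each part starts with '--' + the boundary name declared in the mimetype below ('frame')
    delimiter = b'--frame'
    while True:
        with latest_frame['lock']:
            jpg = latest_frame.get('jpg')
        if jpg is None:
            time.sleep(0.05)
            continue
        yield (delimiter + b'\r\nContent-Type: image/jpeg\r\n'
               + b'Content-Length: %d\r\n\r\n' % len(jpg)
               + jpg + b'\r\n')
        # Throttle: without a sleep the same frame would be re-sent in a tight loop
        time.sleep(1.0 / fps)

@app.route('/video')
def video_route():
    return Response(mjpeg_generator(), mimetype='multipart/x-mixed-replace; boundary=frame')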

def run_flask():
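    # 0.0.0.0 listens on all interfaces; use '127.0.0.1' to accept only local or SSH-tunnelled connections
    app.run(host='0.0.0.0', port=8000, threaded=True)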

server_thread = threading.Thread(target=run_flask, daemon=True)
server_thread.start()
print('Flask MJPEG server started at http://localhost:8000/video')

## Notes and how to view
- Run the model load cell (4) and the dataset load cell (6), start the Flask server cell (10), then run the inference loop (8) to begin producing frames. The stream stays empty until the loop publishes its first frame.
- The MJPEG endpoint is: http://localhost:8000/video
- If the notebook runs on a remote machine, forward port 8000 to your workstation over SSH (example):

```
# on your workstation (example SSH command):
ssh -L 8000:localhost:8000 user@remote-host
```
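To check the stream without a browser, one simple approach is to scan the byte stream for JPEG start (`FF D8`) and end (`FF D9`) markers instead of parsing the multipart headers. The following is a minimal sketch under that assumption, which holds for plain encoder output but not for JPEGs with embedded thumbnails. The `iter_jpegs` name is hypothetical, and the `urllib` usage in the comments is untested against this server:

```python
import io
from typing import BinaryIO, Iterator


def iter_jpegs(stream: BinaryIO, chunk_size: int = 4096) -> Iterator[bytes]:
    """Yield complete JPEG images found in an MJPEG byte stream by scanning for SOI/EOI markers."""
    buf = b''
    while True:
        chunk = stream.read(chunk_size)
        if not chunk:
            return  # end of stream; any partial frame is dropped
        buf += chunk
        while True:
            start = buf.find(b'\xff\xd8')
            if start == -1:
                buf = buf[-1:]  # keep a possible partial marker byte
                break
            end = buf.find(b'\xff\xd9', start + 2)
            if end == -1:
                buf = buf[start:]  # drop bytes before the frame start, wait for more data
                break
            yield buf[start:end + 2]
            buf = buf[end + 2:]


# Offline check on synthetic multipart bytes (stand-in "JPEGs", not real images)
fake_bytes = (b'--frame\r\nContent-Type: image/jpeg\r\n\r\n\xff\xd8AAA\xff\xd9\r\n'
              b'--frame\r\nContent-Type: image/jpeg\r\n\r\n\xff\xd8BB\xff\xd9\r\n')
frames = list(iter_jpegs(io.BytesIO(fake_bytes), chunk_size=8))
print(len(frames))

# Against the running server (assumes http://localhost:8000/video is reachable):
# import urllib.request
# with urllib.request.urlopen('http://localhost:8000/video') as resp:
#     with open('frame.jpg', 'wb') as f:
#         f.write(next(iter_jpegs(resp)))
```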

Troubleshooting: If model weights are large and you cannot load them, you can still run the rendering loop by mocking `pred_xy` (e.g., simple circle) and setting `latest_frame['jpg']` yourself. The notebook intentionally keeps calls in clear cells so you can replace or step through them.
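For the no-weights path described above, a stand-in trajectory only needs the `(T, 2)` shape that `render_bev` expects for `pred_xy`. The following is a minimal sketch of a constant-speed left-turn arc. It assumes an ego frame with +x forward and +y to the left. The function name, horizon, time step, speed and radius are all illustrative assumptions, not values from the model:

```python
import numpy as np


def mock_pred_xy(num_points: int = 64, dt: float = 0.1, speed: float = 8.0,
                 radius: float = 40.0) -> np.ndarray:
    """Hypothetical stand-in for pred_xy: (num_points, 2) x/y positions along a circular arc.

    The ego starts at the origin heading along +x and turns toward +y at constant
    speed on a circle of the given radius.
    """
    t = np.arange(1, num_points + 1) * dt   # timestamps in seconds
    theta = speed * t / radius              # heading change in radians (arc length / radius)
    x = radius * np.sin(theta)
    y = radius * (1.0 - np.cos(theta))
    return np.stack([x, y], axis=1)


pred_xy = mock_pred_xy()
print(pred_xy.shape)  # (64, 2)
```

Pass this array to `render_bev` and `compose_frame` in place of the model output. Then store `frame_to_jpeg_bytes(composed)` in `latest_frame['jpg']` while holding `latest_frame['lock']`.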