In [1]:
scene_num = 8 # {8: 1094,  0: 0061, 9: 1100}

## Setup

In [2]:
from nuscenes.nuscenes import NuScenes
from nuscenes.map_expansion.map_api import NuScenesMap

from nuscenes.utils.geometry_utils import box_in_image, view_points
from nuscenes.utils.data_classes import Box

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
from pyquaternion import Quaternion
from PIL import Image
import cv2
import os
import torch


nusc = NuScenes(version='v1.0-mini', dataroot='/network/scratch/a/anthony.gosselin/nuscenes', verbose=True)
nusc_map = NuScenesMap(dataroot='/network/scratch/a/anthony.gosselin/nuscenes', map_name='singapore-hollandvillage')

Loading NuScenes tables for version v1.0-mini...
23 category,
8 attribute,
4 visibility,
911 instance,
12 sensor,
120 calibrated_sensor,
31206 ego_pose,
8 log,
10 scene,
404 sample,
31206 sample_data,
18538 sample_annotation,
4 map,
Done loading in 0.399 seconds.
Reverse indexing ...
Done reverse indexing in 0.1 seconds.


In [3]:
my_scene = nusc.scene[scene_num]  # {8: 1094,  0: 0061, 9: 1100}
first_sample_token = my_scene['first_sample_token']
my_sample = nusc.get('sample', first_sample_token)
cam_front_data = nusc.get('sample_data', my_sample['data']["CAM_FRONT"])
front_camera_sensor = nusc.get('calibrated_sensor', cam_front_data['calibrated_sensor_token'])

In [4]:
def get_random_height(category: str):
    heights_by_category = {
        "human.pedestrian.adult": {"mean": 1.76, "std": 0.12},
        "human.pedestrian.child":  {"mean": 1.37, "std": 0.06},
        "human.pedestrian.construction_worker": {"mean": 1.78, "std": 0.05},
        "human.pedestrian.personal_mobility": {"mean": 1.87, "std": 0.00},
        "human.pedestrian.police_officer": {"mean": 1.81, "std": 0.00},
        "movable_object.barrier": {"mean": 1.06, "std": 0.10},
        "movable_object.debris":  {"mean": 0.46, "std": 0.00},
        "movable_object.pushable_pullable": {"mean": 1.04, "std": 0.20},
        "movable_object.trafficcone": {"mean": 0.78, "std": 0.13},
        "static_object.bicycle_rack":  {"mean": 1.40, "std": 0.00},
        "vehicle.bicycle":  {"mean": 1.39, "std": 0.34},
        "vehicle.bus.bendy":  {"mean": 3.32, "std": 0.07},
        "vehicle.bus.rigid":  {"mean": 3.80, "std": 0.62},
        "vehicle.car": {"mean": 1.69, "std": 0.21},
        "vehicle.construction":  {"mean": 2.38, "std": 0.33},
        "vehicle.motorcycle":  {"mean": 1.47, "std": 0.20},
        "vehicle.trailer":  {"mean": 3.71, "std": 0.27},
        "vehicle.truck":  {"mean": 2.62, "std": 0.68},
    }

    stats = heights_by_category[category]
    height = np.random.normal(stats['mean'], stats['std'])

    return height

## 2D bbox sequence

### No interpolation

In [50]:
bboxes_2d_seq = []
ego_pose_seq = []
ind = 0
curr_sample = my_sample
next_sample_token = curr_sample['token']

MAX_FRAMES = 8

while True:
    curr_sample = nusc.get('sample', next_sample_token)
    bboxes_2d_seq.append([])
    
    # The sequence of ego poses determines the trajectory of the camera through the reconstructed scene
    # (This means we can move the camera as we please through the reconstructed scene)
    curr_cam_front_data = nusc.get('sample_data', curr_sample['data']["CAM_FRONT"])
    ego_pose_seq.append(nusc.get('ego_pose', curr_cam_front_data['ego_pose_token']))

    for ann_token in curr_sample['anns']:
        bbox_3d = nusc.get_box(ann_token)
        yaw, pitch, roll  = bbox_3d.orientation.yaw_pitch_roll
        annotation_data = nusc.get('sample_annotation', ann_token)
        bbox_2d = {'center': [bbox_3d.center[0], bbox_3d.center[1]], 'size': [bbox_3d.wlh[0], bbox_3d.wlh[1]], 'heading': yaw, 'category': bbox_3d.name, 'instance_token': annotation_data['instance_token'], 'height': bbox_3d.center[2]}
        bboxes_2d_seq[ind].append(bbox_2d)
    
    next_sample_token = curr_sample['next']
    if next_sample_token == '' or ind >= MAX_FRAMES:
        break
    ind += 1
    
print("Sequence length:", len(bboxes_2d_seq))

Sequence length: 9


### (1) With interpolation

In [6]:
def interpolate_data(prev_sample, next_sample, frequency):
    """
    Interpolates both ego poses and bounding boxes for a given sample token at a specified frequency.
    :param nusc: NuScenes instance.
    :param sample_token: Token of the sample to interpolate.
    :param frequency: Desired interpolation frequency in Hz.
    :return: List of tuples (interpolated ego pose, interpolated bounding boxes).
    """

    # Load ego poses for the previous and next samples
    prev_cam_front_data = nusc.get('sample_data', prev_sample['data']["CAM_FRONT"])
    next_cam_front_data = nusc.get('sample_data', next_sample['data']["CAM_FRONT"])

    prev_ego_pose = nusc.get('ego_pose', prev_cam_front_data['ego_pose_token'])
    next_ego_pose = nusc.get('ego_pose', next_cam_front_data['ego_pose_token'])

    # Load annotations for the previous and next samples
    prev_anns = {ann['instance_token']: ann for ann in map(lambda ann_token: nusc.get('sample_annotation', ann_token), prev_sample['anns'])}
    next_anns = {ann['instance_token']: ann for ann in map(lambda ann_token: nusc.get('sample_annotation', ann_token), next_sample['anns'])}

    # Calculate the time difference between samples
    dt = (next_sample['timestamp'] - prev_sample['timestamp']) / 1e6  # Convert microseconds to seconds
    steps = int(dt * frequency)  # Number of interpolation steps

    # Interpolate the data
    interpolated_boxes, interpolated_ego_poses, instance_tokens = [], [], []
    for i in range(steps):
        t = i / steps
        
        # Interpolate ego pose
        translation = (1 - t) * np.array(prev_ego_pose['translation']) + t * np.array(next_ego_pose['translation'])
        rotation = Quaternion.slerp(Quaternion(prev_ego_pose['rotation']), Quaternion(next_ego_pose['rotation']), t)
        interpolated_ego_pose = {
            'translation': translation,
            'rotation': rotation
        }
        interpolated_ego_poses.append(interpolated_ego_pose)
        
        # Interpolate bounding boxes
        interpolated_boxes.append([])
        instance_tokens.append([])
        for instance_token in set(prev_anns.keys()).union(next_anns.keys()):
            if instance_token in prev_anns and instance_token in next_anns:
                # Interpolate between the two annotations
                prev_ann = prev_anns[instance_token]
                next_ann = next_anns[instance_token]
                center = (1 - t) * np.array(prev_ann['translation']) + t * np.array(next_ann['translation'])
                size = (1 - t) * np.array(prev_ann['size']) + t * np.array(next_ann['size'])
                orientation = Quaternion.slerp(Quaternion(prev_ann['rotation']), Quaternion(next_ann['rotation']), t)
                box = Box(center=center, size=size, orientation=orientation, name=prev_ann['category_name'])
                interpolated_boxes[i].append(box)
                instance_tokens[i].append(instance_token)
            elif instance_token in prev_anns:
                # The bounding box disappears in the next sample
                prev_ann = prev_anns[instance_token]
                box = Box(center=prev_ann['translation'], size=prev_ann['size'], orientation=Quaternion(prev_ann['rotation']), name=prev_ann['category_name'])
                interpolated_boxes[i].append(box)
                instance_tokens[i].append(instance_token)
            elif instance_token in next_anns:
                # The bounding box appears in the next sample
                next_ann = next_anns[instance_token]
                box = Box(center=next_ann['translation'], size=next_ann['size'], orientation=Quaternion(next_ann['rotation']), name=next_ann['category_name'])
                interpolated_boxes[i].append(box)
                instance_tokens[i].append(instance_token)

    return interpolated_boxes, interpolated_ego_poses, instance_tokens

In [18]:
bboxes_2d_seq = []
ego_pose_seq = []
ind = 0
curr_sample = my_sample
next_sample_token = curr_sample['token']
interpolation_freq = 7  # Convert to 7Hz annotations

MAX_FRAMES = 30

while True:
    curr_sample = nusc.get('sample', next_sample_token)

    next_sample_token = curr_sample['next']
    if next_sample_token == '' or ind >= MAX_FRAMES:
        break 
    next_sample = nusc.get('sample', next_sample_token)
        
    interpolated_boxes, interpolated_ego_poses, instance_tokens = interpolate_data(curr_sample, next_sample, interpolation_freq)
    for bboxes_3d_t, ego_pose, instance_tokens_t in zip(interpolated_boxes, interpolated_ego_poses, instance_tokens):

        bboxes_2d_seq.append([])
        
        # The sequence of ego poses determines the trajectory of the camera through the reconstructed scene
        # (This means we can move the camera as we please through the reconstructed scene)
        ego_pose_seq.append(ego_pose)
        
        for bbox_idx, bbox_3d in enumerate(bboxes_3d_t):
            yaw, pitch, roll  = bbox_3d.orientation.yaw_pitch_roll
            
            instance_token = instance_tokens_t[bbox_idx]
            bbox_2d = {'center': [bbox_3d.center[0], bbox_3d.center[1]], 
                       'size': [bbox_3d.wlh[0], bbox_3d.wlh[1]], 
                       'heading': yaw, 
                       'category': bbox_3d.name, 
                       'instance_token': instance_token, 
                       'height': bbox_3d.center[2]}  # For testing...
            bboxes_2d_seq[ind].append(bbox_2d)
    
        ind += 1
    
print("Sequence length:", len(bboxes_2d_seq))

Sequence length: 117


## 3D bbox sequence

### Setup

In [9]:
from collections import defaultdict

class CVCOLORS:
    RED = (0,0,255)
    GREEN = (0,255,0)
    BLUE = (255,0,0)
    PURPLE = (247,44,200)
    ORANGE = (44,162,247)
    MINT = (239,255,66)
    YELLOW = (2,255,250)
    BROWN = (42,42,165)
    LIME=(51,255,153)
    GRAY=(128, 128, 128)
    LIGHTPINK = (222,209,255)
    LIGHTGREEN = (204,255,204)
    LIGHTBLUE = (255,235,207)
    LIGHTPURPLE = (255,153,204)
    LIGHTRED = (204,204,255)
    WHITE = (255,255,255)
    BLACK = (0,0,0)
    
    TRACKID_LOOKUP = defaultdict(lambda: (np.random.randint(50, 255), np.random.randint(50, 255), np.random.randint(50, 255)))
    TYPE_LOOKUP = [BLUE, WHITE, RED, YELLOW, PURPLE, BROWN, GREEN, ORANGE, LIGHTPURPLE, LIGHTRED, GRAY]
    REVERT_CHANNEL_F = lambda x: (x[2], x[1], x[0])

In [10]:
# Based on closest match to KITTI classes
NUSC_CLASS_TO_GROUP_IDS_KITTI = {
    "human.pedestrian.adult": 4,
    "human.pedestrian.child":  4,
    "human.pedestrian.construction_worker": 5,
    "human.pedestrian.personal_mobility": 4,
    "human.pedestrian.police_officer": 5,
    "movable_object.barrier": 8,
    "movable_object.debris":  8,
    "movable_object.pushable_pullable": 8,
    "movable_object.trafficcone": 8,
    "static_object.bicycle_rack":  8,
    "vehicle.bicycle":  6,
    "vehicle.bus.bendy":  3,
    "vehicle.bus.rigid":  3,
    "vehicle.car": 1,
    "vehicle.construction":  3,
    "vehicle.motorcycle":  6,  # NOTE: Not sure if best to classify as cyclist or car...
    "vehicle.trailer":  3,
    "vehicle.truck":  3,
}

# KITTI:
# IDS_CLASS_LOOKUP = {
#     1: 'Car',
#     2: 'Van',
#     3: 'Truck',
#     4: 'Pedestrian',
#     5: 'Person',
#     6: 'Cyclist',
#     7: 'Tram',
#     8: 'Misc',
#     9: 'DontCare'
# }

# Based on closest match to BDD100k classes
NUSC_CLASS_TO_GROUP_IDS_BDD = {
    "human.pedestrian.adult": 1,
    "human.pedestrian.child":  1,
    "human.pedestrian.construction_worker": 1,
    "human.pedestrian.personal_mobility": 1,
    "human.pedestrian.police_officer": 1,
    "movable_object.barrier": 10,
    "movable_object.debris":  10,
    "movable_object.pushable_pullable": 10,
    "movable_object.trafficcone": 10,
    "static_object.bicycle_rack":  10,
    "vehicle.bicycle":  8,
    "vehicle.bus.bendy":  5,
    "vehicle.bus.rigid":  5,
    "vehicle.car": 3,
    "vehicle.construction":  4,
    "vehicle.motorcycle":  7, 
    "vehicle.trailer":  4,
    "vehicle.truck":  4,
}

# BDD100k:
# IDS_CLASS_LOOKUP = {
#         1: 'pedestrian',
#         2: 'rider',
#         3: 'car',
#         4: 'truck',
#         5: 'bus',
#         6: 'train',
#         7: 'motorcycle',
#         8: 'bicycle',
#         9: 'traffic light',
#         10: 'traffic sign',
#     }

NUSC_CLASS_TO_GROUP_IDS = NUSC_CLASS_TO_GROUP_IDS_BDD


In [11]:
def render_box_3d_style_CV(box, img, canvas3d, canvas2d, view: np.ndarray = np.eye(3), normalize: bool = False, outline_color=(255, 0, 0), fill_color=(0, 255, 0), linewidth: float = 2, show_3d_bboxes=True, show_2d_bboxes=False) -> None:
    """
    Renders the box in the provided canvas
    """
    corners = view_points(box.corners(), view, normalize=normalize)[:2, :]
    corners = np.round(corners).astype(int)

    if show_3d_bboxes:
        def draw_rect(selected_corners, color):
            prev = selected_corners[-1]
            for corner in selected_corners:
                cv2.line(canvas3d, (prev[0], prev[1]), (corner[0], corner[1]), color=color, thickness=linewidth)
                prev = corner

        # Draw front (first 4 corners) and rear (last 4 corners) rectangles(3d)/lines(2d)
        draw_rect(corners.T[:4], outline_color)
        draw_rect(corners.T[4:], outline_color)

        # Draw the sides
        for i in range(4):
            cv2.line(canvas3d, (corners.T[i][0], corners.T[i][1]),
                    (corners.T[i + 4][0], corners.T[i + 4][1]),
                    color=outline_color, thickness=linewidth)
        
        # Draw x mark at the back of the object
        cv2.line(canvas3d, [corners.T[4][0], corners.T[4][1]],
                [corners.T[6][0], corners.T[6][1]],
                color=outline_color, thickness=1)
        cv2.line(canvas3d, [corners.T[5][0], corners.T[5][1]],
                [corners.T[7][0], corners.T[7][1]],
                color=outline_color, thickness=1)

    if show_2d_bboxes:
        # Calculate the bottom left corner of the rectangle
        bottom_left_x, top_right_x = np.min(corners.T[:, 0]), np.max(corners.T[:, 0])
        bottom_left_y, top_right_y = np.min(corners.T[:, 1]), np.max(corners.T[:, 1])

        # Create the rectangle
        cv2.rectangle(canvas2d, (bottom_left_x, bottom_left_y), (top_right_x, top_right_y), color=fill_color, thickness=cv2.FILLED)
        if not show_3d_bboxes:
            cv2.rectangle(canvas2d, (bottom_left_x, bottom_left_y), (top_right_x, top_right_y), color=outline_color, thickness=2)

    alpha_2dbbox = 0.75
    mask = canvas2d.astype(bool)
    img[mask] = cv2.addWeighted(canvas2d, alpha_2dbbox, img, 1-alpha_2dbbox, 0)[mask]
    mask = canvas3d.astype(bool)
    img[mask] = canvas3d[mask]
    
    return img
            
        
def my_render_3d_style_CV(nusc, boxes_3d, camera_sensor, ego_pose, data_path=None, transform=True, background=False, show_3d_bboxes=True, show_2d_bboxes=False) -> None:
    """
    Bboxes are to be in global coordinate frame, and will be projected to the specified camera
    """
    
    im_size = (1600, 900)
    img = torch.zeros((3, im_size[1], im_size[0]))
    img = img.permute((1, 2, 0)).detach().cpu().numpy().copy()*255
    img = img.astype(np.uint8)
    canvas3d = np.zeros_like(img)
    canvas2d = np.zeros_like(img)
    
    # Camera extrinsic and intrinsic parameters
    camera_intrinsic = np.array(camera_sensor['camera_intrinsic'])

    for ind, box_3d in enumerate(boxes_3d):
        
        if transform:
            # Move box to ego vehicle coord system.
            box_3d.translate(-np.array(ego_pose['translation']))
            box_3d.rotate(Quaternion(ego_pose['rotation']).inverse)
            #  Move box to sensor coord system.
            box_3d.translate(-np.array(camera_sensor['translation']))
            box_3d.rotate(Quaternion(camera_sensor['rotation']).inverse)

        # Only render bboxes that fit in image frame
        if not box_in_image(box_3d, camera_intrinsic, im_size, vis_level=1):
            continue
            
        # outline_color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[NUSC_CLASS_TO_GROUP_IDS[box_3d.name]])) / 255.0
        outline_color = CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TYPE_LOOKUP[NUSC_CLASS_TO_GROUP_IDS[box_3d.name]])
        instance_token = box_3d.token
        if not transform:
            annotation_data = nusc.get('sample_annotation', box_3d.token)
            instance_token = annotation_data['instance_token']
        # fill_color = np.array(CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TRACKID_LOOKUP[instance_token])) / 255.0
        fill_color = CVCOLORS.REVERT_CHANNEL_F(CVCOLORS.TRACKID_LOOKUP[instance_token])
        
        img = render_box_3d_style_CV(box_3d, img, canvas3d, canvas2d, view=camera_intrinsic, normalize=True, outline_color=outline_color, fill_color=fill_color, show_3d_bboxes=show_3d_bboxes, show_2d_bboxes=show_2d_bboxes)
        # if "personal_mobility" in box_3d.name:
        #     print(box_3d.name, instance_token)
    
    return img

### (2) Create video

In [19]:
from torchvision import transforms
import torch

ego_height = get_random_height('vehicle.car')
front_cam_token = my_sample['data']['CAM_FRONT']
agent_heights = {}

# Parameters for the video
video_filename = f"video_out/{my_scene['name']}_style.avi"
FPS = 7 # NOTE Real time is 2 fps
frame_size = (512, 320)

transform=transforms.Compose([
                    transforms.Resize((frame_size[1], frame_size[0])),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # map from [0,1] to [-1,1]
                 ])

# Create a VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'XVID')
video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)

img_seq = torch.zeros([len(bboxes_2d_seq), 3, frame_size[1], frame_size[0]])

for t, bboxes_2d_t in enumerate(bboxes_2d_seq):
    bboxes_3d_out = []
    for ind, bbox_2d in enumerate(bboxes_2d_t):
        center = [bbox_2d['center'][0], bbox_2d['center'][1], ego_height/2] # Adjust height because sensor is mounted on ego (And we don't have z-height information)
        
        agent_height = agent_heights.get(bbox_2d['instance_token'])  # OR Gt height information: bbox_2d['height'] 
        if agent_height is None:
            agent_heights[bbox_2d['instance_token']] = get_random_height(bbox_2d['category'])
            agent_height = agent_heights.get(bbox_2d['instance_token'])
                       
        size = [bbox_2d['size'][0], bbox_2d['size'][1], agent_height]
        orientation = Quaternion._from_axis_angle(np.array([0, 0, 1]), bbox_2d['heading'])
        bbox_3d = Box(center, size, orientation, name=bbox_2d['category'], token=bbox_2d['instance_token'])
        bboxes_3d_out.append(bbox_3d)

    # Update ego pose (else the POV will remain fixed in the scene)
    curr_ego_pose = ego_pose_seq[t]

    img = my_render_3d_style_CV(nusc, bboxes_3d_out, front_camera_sensor, curr_ego_pose, data_path=None, background=True, show_2d_bboxes=True, show_3d_bboxes=False)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Save to tensor buffer
    pil_image = Image.fromarray(img)
    img_seq[t] = transform(pil_image)

    img = cv2.resize(img, frame_size)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    video_writer_out.write(img)
    
video_writer_out.release()
print(f"Video saved: {video_filename}")

Video saved: video_out/scene-1094_style.avi


### Create GT video

In [14]:
# from torchvision import transforms
# import torch

# ego_height = get_random_height('vehicle.car')
# front_cam_token = my_sample['data']['CAM_FRONT']
# agent_heights = {}

# # Parameters for the video
# video_filename = f"video_out/{my_scene['name']}_style_gt.avi"
# FPS = 2  # NOTE Real time is 2 fps
# frame_size = (512, 320)

# transform=transforms.Compose([
#                     transforms.Resize((frame_size[1], frame_size[0])),
#                     transforms.ToTensor(),
#                     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # map from [0,1] to [-1,1]
#                  ])

# # Create a VideoWriter object
# fourcc = cv2.VideoWriter_fourcc(*'XVID')
# video_writer_out = cv2.VideoWriter(video_filename, fourcc, FPS, frame_size)
# fig, ax = plt.subplots(1, 1, figsize=(9, 9))
# plt.tight_layout(pad=0)

# # img_seq_gt = torch.zeros([len(bboxes_2d_seq), 3, frame_size[1], frame_size[0]])

# curr_sample = my_sample
# next_sample_token = curr_sample['token']

# t = 0
# while True:
#     curr_sample = nusc.get('sample', next_sample_token)

#     # fig, ax = plt.subplots(1, 1, figsize=(9, 9))
#     # plt.tight_layout(pad=0)

#     ax.clear()
#     data_path, bboxes_3d_local, camera_intrinsic = nusc.get_sample_data(curr_sample['data']['CAM_FRONT'], selected_anntokens=curr_sample['anns'])
#     img = my_render_3d_style_CV(nusc, bboxes_3d_local, front_camera_sensor, ego_pose, data_path=data_path, transform=False, background=True, show_2d_bboxes=True, show_3d_bboxes=False)

#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     img = cv2.resize(img, frame_size)

#     # Save to tensor buffer
#     pil_image = Image.fromarray(img)
#     img_seq[t] = transform(pil_image)

#     # img_seq_gt[t] = transform(Image.fromarray(img))
#     # t += 1

#     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
#     video_writer_out.write(img)

#     # plt.show(fig)

#     next_sample_token = curr_sample['next']
#     if next_sample_token == '':
#         break

    
# video_writer_out.release()
# print(f"Video saved: {video_filename}")
# plt.close(fig)

## Ctrl-V

### Setup

In [15]:
import warnings
import numpy as np
import torch
torch.cuda.empty_cache()
import torch.utils.checkpoint

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from ctrlv.utils import export_to_video
    from ctrlv.pipelines import StableVideoControlPipeline

OUT_DIR = "/network/scratch/x/xuolga/Results/sd3d/bdd100k_ctrlv_240511_200727/" #kitti_ctrlv_240510_141159/" # This one exists: kitti_ctrlv_240513_195113/ 
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

# set_seed(args.seed)
generator = None #torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None

In [16]:
# Load pipelines
from ctrlv.models import UNetSpatioTemporalConditionModel, ControlNetModel
ctrlnet = ControlNetModel.from_pretrained(OUT_DIR, subfolder="controlnet")
unet = UNetSpatioTemporalConditionModel.from_pretrained(OUT_DIR, subfolder="unet")
pipeline = StableVideoControlPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", controlnet = ctrlnet, unet = unet,)
pipeline = pipeline.to(device)
pipeline.set_progress_bar_config(disable=True)

Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  7.54it/s]


### (3) Generate video

In [20]:
# Get initial image for conditioning
image_init_path = os.path.join(nusc.dataroot, cam_front_data['filename'])
image_init = Image.open(image_init_path)
sample_out = {'bbox_img': img_seq, 'image_init': image_init}

# sample_out = {'bbox_img': img_seq_gt, 'image_init': image_init} 


CLIP_LENGTH = sample_out['bbox_img'].shape[0]


def run_inference_with_pipeline(pipeline, demo_samples):
    for sample_i, sample in enumerate(demo_samples):
        frames = pipeline(sample['image_init'], 
                        cond_images=sample['bbox_img'].unsqueeze(0),
                        height=frame_size[1], width=frame_size[0], 
                        decode_chunk_size=8, motion_bucket_id=127, fps=FPS, 
                        num_inference_steps=30,
                        num_frames=CLIP_LENGTH,
                        control_condition_scale=1.0,
                        min_guidance_scale=1.0,
                        max_guidance_scale=3.0,
                        noise_aug_strength=0.01,
                        generator=generator, output_type='pt').frames[0]
        #frames = F.interpolate(frames, (dataset.orig_H, dataset.orig_W)).detach().cpu().numpy()*255
        frames = frames.detach().cpu().numpy()*255
        frames = frames.astype(np.uint8)

        tmp = np.moveaxis(np.transpose(frames, (0, 2, 3, 1)), 0, 0)
        output_video_path = f"video_out/generated_ctrl_{my_scene['name']}.mp4"
        export_to_video(tmp, output_video_path, fps=FPS)
        print(f"Video saved:", output_video_path)
        # log_dict = {}
        # log_dict["generated_videos"] = wandb.Video(frames, fps=args.fps)
        # log_dict["gt_bbox_frames"] = wandb.Video(sample['bbox_img_np'], fps=args.fps)
        # log_dict["gt_videos"] = wandb.Video(sample['gt_clip_np'], fps=args.fps)
        # frame_bboxes = wandb_frames_with_bbox(frames, sample['objects_tensors'], (dataset.orig_W, dataset.orig_H))
        # log_dict["frames_with_bboxes_{}".format(sample_i)] = frame_bboxes

        
print("Start inference...")
sample_out['bbox_img'].to(device)
run_inference_with_pipeline(pipeline, [sample_out])

Start inference...
