In [1]:
import cv2
import matplotlib.pyplot as plt
import torch
from IPython import display
from constants import DTYPE_STATE


def stack_preprocess_frames_old(frames, device="cpu", mode="grayscale"):
    frames = [torch.tensor(preprocess_frame(frame, mode=mode), device=device, dtype=DTYPE_STATE) for frame in frames]
    stacked_frames_tensor = torch.stack(frames)
    if mode == "grayscale":
        stacked_frames_tensor = stacked_frames_tensor.unsqueeze(1)  # Add channel dimension for grayscale
    elif mode == "rgb":
        stacked_frames_tensor = stacked_frames_tensor.permute(0, 3, 1, 2).contiguous()  # Rearrange dimensions for RGB
    else:
        raise ValueError("Invalid mode: choose 'grayscale' or 'rgb'")
    return stacked_frames_tensor


def preprocess_frame(frame, mode="grayscale"):
    """
    Preprocesses a given frame by converting it to the specified mode and resizing it to 84x84.

    Parameters:
    - frame: The input frame in RGB format.
    - mode: 'grayscale' or 'rgb'.

    Returns:
    - The preprocessed frame.
    """
    if mode == "grayscale":
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (84, 84))
    elif mode == "rgb":
        frame = cv2.resize(frame, (84, 84))
    else:
        raise ValueError("Invalid mode: choose 'grayscale' or 'rgb'")
    return frame

In [15]:
import torch
import torchvision.transforms as T
import numpy as np

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


def stack_preprocess_frames_new(frames, device="cpu", mode="grayscale"):
    """
    Efficiently preprocesses and stacks frames using torchvision transforms.

    Parameters:
    - frames: List of input frames in RGB format (as NumPy arrays or equivalent).
    - device: The device to place the tensors on ('cpu' or 'cuda').
    - mode: 'grayscale' or 'rgb'.
    - dtype: The desired output data type (default: uint8 for image data).

    Returns:
    - A tensor of preprocessed frames with the specified dtype.
    """
    # Define transformations
    if mode == "grayscale":
        transform = T.Compose([T.Resize((84, 84)), T.Grayscale(num_output_channels=1)])
    elif mode == "rgb":
        transform = T.Compose([T.Resize((84, 84))])
    else:
        raise ValueError("Invalid mode: choose 'grayscale' or 'rgb'")

    # Apply transformations to all frames and stack them
    frames = np.stack(frames)
    frames_tensor = torch.from_numpy(frames).to(device).permute(0, 3, 1, 2)
    frames_tensor = transform(frames_tensor)

    return frames_tensor

In [16]:
import torch.utils.benchmark as benchmark

import numpy as np

# Create sample frames (3 frames of 84x84 RGB images)
sample_frames = [np.random.randint(0, 256, (96, 96, 3), dtype=np.uint8) for _ in range(1000)]

# Define the setup code
setup_code = """
import torch
from __main__ import stack_preprocess_frames_old, stack_preprocess_frames_new, sample_frames
"""

# Define the statements to be timed
stmt_old = "stack_preprocess_frames_old(sample_frames, device='cuda:1', mode='grayscale')"
stmt_new = "stack_preprocess_frames_new(sample_frames, device='cuda:1', mode='grayscale')"

# Create Timer objects
timer_old = benchmark.Timer(stmt=stmt_old, setup=setup_code)
timer_new = benchmark.Timer(stmt=stmt_new, setup=setup_code)

# Run the benchmarks using adaptive_autorange
result_old = timer_old.adaptive_autorange()
result_new = timer_new.adaptive_autorange()

print(f"percentage improvement: {result_old.median / result_new.median * 100:.2f}%")

percentage improvement: 604.62%


In [20]:
new = stack_preprocess_frames_new(sample_frames, device="cuda:1", mode="grayscale")
old = stack_preprocess_frames_old(sample_frames, device="cuda:1", mode="grayscale")

torch.allclose(new, old, rtol=10)

True