<h1>Infinite Hallway Generator</h1>

Generate infinite videos of hallways using stable diffusion and monocular depth estimation.

<h3>Example</h3>

<figure>
<center>
<img src="https://github.com/NealWadhwa/infinite-hallway/blob/main/example.gif">
</center>

<figcaption>Prompt: <em>A large cat in a hallway with snakes and vines in an hallway in a steampunk aztec temple made of gears</em></figcaption>
</figure>

<a target="_blank" href="https://colab.research.google.com/github/NealWadhwa/infinite-hallway/blob/main/infinite_hallway.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>



In [1]:
# Install dependencies
%pip install matplotlib diffusers torch torchvision torchgeometry
%pip install timm transformers ipywidgets accelerate opencv-python
%pip install -q mediapy

You should consider upgrading via the '/home/nwadhwa/infinite-hallway/.venv/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.
You should consider upgrading via the '/home/nwadhwa/infinite-hallway/.venv/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.
You should consider upgrading via the '/home/nwadhwa/infinite-hallway/.venv/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
# Import dependencies
import os
import matplotlib.pyplot as plt

import numpy as np
from PIL import Image

import cv2

import torch
import torchvision
import torchgeometry

In [3]:
# Login to hugging face
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
device = "cuda:0"

In [5]:
# Load MiDaS model for depth estimation.
model_type = "DPT_Large"
midas = torch.hub.load("intel-isl/MiDaS", model_type)
midas.to(device)
midas.eval()

Using cache found in /home/nwadhwa/.cache/torch/hub/intel-isl_MiDaS_master


DPTDepthModel(
  (pretrained): Module(
    (model): VisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
        (norm): Identity()
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (norm_pre): Identity()
      (blocks): Sequential(
        (0): Block(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): Identity()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate='none')
            (d

In [6]:
# Load inpainting stable diffusion model
from diffusers import StableDiffusionInpaintPipeline, DiffusionPipeline, DPMSolverMultistepScheduler

pipe_inpaint = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16,
    revision="fp16")
pipe_inpaint.scheduler = DPMSolverMultistepScheduler.from_config(
    pipe_inpaint.scheduler.config)
pipe_inpaint = pipe_inpaint.to(device)

Fetching 13 files:   0%|          | 0/13 [00:00<?, ?it/s]



In [7]:
# For speed, we break the image into multiple planes and warp each plane as if it is a constant depth.


@torch.no_grad()
def move_camera(image,
                inverse_depth,
                focal_length,
                rotation,
                translation,
                a=0.0,
                b=1.0,
                mpi_layers=16,
                device="cuda:0"):
    rotation = rotation.to(device)
    translation = translation.to(device)

    # Break the image into a MPI based on depth.
    max_depth = inverse_depth.max()
    min_depth = inverse_depth.min()
    normalized_depth = (inverse_depth - min_depth) / (max_depth - min_depth)

    # Hard cutoffs in the MPI
    mpi = torch.floor(normalized_depth * mpi_layers).to(torch.int32)
    depth = torch.arange(mpi_layers + 1).to(device) / mpi_layers
    depth = depth * (max_depth - min_depth) + min_depth
    masks = []
    layers = []
    _, _, height, width = image.shape
    dsize = (height, width)
    out_image = torch.zeros_like(image)
    out_mask = torch.zeros_like(inverse_depth)
    for i in range(mpi_layers + 1):
        mask = (mpi == i).to(torch.float)

        intrinsic1 = torch.FloatTensor(
            [[1.0 / focal_length, 0, -height / 2 / focal_length],
             [0, 1.0 / focal_length, -width / 2 / focal_length], [0, 0, 1.0],
             [0.0, 0.0, depth[i]]]).to(device)
        intrinsic2 = torch.FloatTensor([[focal_length, 0, height / 2],
                                        [0, focal_length, width / 2],
                                        [0, 0, 1.0]]).to(device)
        rt = torch.cat([rotation, translation], dim=1)

        ptrans = torch.matmul(rt, intrinsic1)
        perspective_transform = torch.matmul(intrinsic2, ptrans)
        warped_mask = torchgeometry.warp_perspective(mask[None, :, :, :],
                                                     perspective_transform,
                                                     dsize=dsize)
        warped_image = torchgeometry.warp_perspective(image,
                                                      perspective_transform,
                                                      dsize=dsize)

        warped_image = torch.clip(warped_image, 0, 1)
        warped_mask = torch.clip(warped_mask, 0, 1)
        warped_layer = warped_image * warped_mask

        out_image = warped_layer + out_mask * out_image * (1 - warped_mask[0])
        out_mask = warped_mask[0] + out_mask * (1 - warped_mask[0])
        out_image = out_image / (1e-16 + out_mask)

    binary_mask = (out_mask == 0).to(torch.float32)
    out_image = out_image * (1 - binary_mask)

    return out_image, (out_mask == 0).to(torch.float32)


def to_numpy(tensor):
    return tensor.detach().cpu().numpy().transpose(1, 2, 0)


def to_uint8(tensor):
    return (to_numpy(tensor) * 255).astype(np.uint8)

In [8]:
def generate_next_frame(image,
                        prompt,
                        focal_length,
                        rotation,
                        translation,
                        negative_prompt="",
                        a=0.0,
                        b=1.0,
                        mpi_layers=16,
                        device="cuda:0"):
    to_tensor = torchvision.transforms.ToTensor()
    # Compute the depth map of the image
    image = image.to(device)
    image = image[None, :, :, :]
    with torch.inference_mode():
        depth = midas(image)

    warped, mask = move_camera(image, depth, focal_length, rotation, translation,
                            a, b, mpi_layers, device)

    warped_scaled = torch.clip(2 * warped - 1, -1, 1)

    inpainted = pipe_inpaint(prompt=prompt,
                             negative_prompt=negative_prompt,
                             image=warped_scaled,
                             mask_image=mask).images[0]
    inpainted = to_tensor(inpainted)[None, :, :, :]
    inpainted = inpainted.to(device)
    inpainted = warped * (1 - mask) + inpainted * mask
    return inpainted[0, :, :, :]

In [9]:
prompt = "A large cat in a hallway with snakes and vines in an hallway in a steampunk aztec temple made of gears"
negative_prompt = "blurry, bad art, blurred, text, watermark"

NUM_FRAMES = 100
VIDEO_FILENAME = "output.mp4"

prompts = [prompt] * NUM_FRAMES

to_tensor = torchvision.transforms.ToTensor()

dummy_image = torch.ones((1, 3, 512, 512), dtype=torch.float32).to("cuda:0")
mask = torch.ones((1, 1, 512, 512), dtype=torch.float32).to("cuda:0")
image = pipe_inpaint(prompt=prompt,
                     negative_prompt=negative_prompt,
                     image=dummy_image,
                     mask_image=mask).images[0]
image = to_tensor(image)

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(VIDEO_FILENAME, fourcc, 15.0, (512, 512))
out.write(to_uint8(image))

for prompt in prompts:
    image = generate_next_frame(image,
                                prompt,
                                negative_prompt=negative_prompt,
                                focal_length=30,
                                rotation=torch.eye(3),
                                translation=torch.FloatTensor([[0.0], [0.0],
                                                               [0.003]]))
    image_np = to_uint8(image)    
    out.write(image_np)

out.release()
cv2.destroyAllWindows()

  0%|          | 0/50 [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [10]:
# Show the video
import mediapy as media

media.show_video(media.read_video("output.mp4"))

0
This browser does not support the video tag.
