In [None]:
# Mount Google Drive; every project path used further down lives under /content/drive.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

In [None]:
import pathlib
import pickle
import sys

# Make the project's `src` directory importable; the `nerf.load_blender` import
# below depends on it (presumably unpickling the generator does too — confirm).
# Guarded so re-running this cell doesn't keep appending duplicate entries.
src_dir = "/content/drive/MyDrive/HyperNeRFGan/src"
if src_dir not in sys.path:
    sys.path.append(src_dir)

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from tqdm.auto import tqdm  # widget progress bar in notebooks, console bar elsewhere

from nerf.load_blender import pose_spherical

In [None]:
# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Sampling knobs.
batch_size = 1        # max latents per generator call when writing standalone images
num_images = 20       # total standalone sample images to write
interp_examples = 5   # number of interpolation strips (one PNG per strip)
interp_steps = 10     # images per interpolation strip

# Every Drive location hangs off one project root, so the notebook can be
# relocated by editing a single line instead of several absolute paths.
project_root = pathlib.Path("/content/drive/MyDrive/HyperNeRFGan")
samples_root = project_root / "pretrained_running" / "samples"

pickle_name = "carla_improved_1600"
pickle_path = project_root / "data" / "pickles" / f"{pickle_name}.pkl"
img_path = samples_root / "images" / pickle_name
interpolation_path = samples_root / "interpolation" / pickle_name

img_path.mkdir(parents=True, exist_ok=True)
interpolation_path.mkdir(parents=True, exist_ok=True)

In [None]:
def get_interpolated_images(G, num_steps=8, z_dim=128):
    """Render images along a straight line between two random latent codes.

    Two latents ``z1`` and ``z2`` are drawn from ``torch.randn`` and blended
    linearly with ``num_steps`` evenly spaced weights ``alpha`` in [0, 1]
    (``alpha = 0`` gives ``z1``, ``alpha = 1`` gives ``z2``). Every blend is
    rendered from the same fixed camera pose, so only the latent varies across
    the returned sequence.

    Args:
        G: Generator called as ``G(z=..., c=None, poses=..., scale=False,
            crop=False, perturb=False)``. Assumed to return an NCHW tensor whose
            values are meant to lie in [0, 1] — TODO confirm for this model.
        num_steps: Number of interpolation points, endpoints included; must be >= 1.
        z_dim: Latent dimensionality fed to ``G`` (128 by default, as before).

    Returns:
        float32 CPU tensor laid out NHWC (``num_steps`` first), clamped to [0, 1]
        so it can be passed straight to ``plt.imshow``.

    Raises:
        ValueError: If ``num_steps`` < 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    z1 = torch.randn(1, z_dim, device=device)
    z2 = torch.randn(1, z_dim, device=device)

    # (S, 1) column of weights broadcast against (1, z_dim) latents -> (S, z_dim).
    alphas = torch.linspace(0, 1, steps=num_steps, device=device).unsqueeze(1)
    interpolated_vectors = (1 - alphas) * z1 + alphas * z2

    # One shared camera for every step so only the latent changes along the strip.
    poses = [pose_spherical(theta=30, phi=-30, radius=4.0)] * num_steps

    # torch.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast;
    # fp16 mixed precision only kicks in when actually running on a GPU.
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device.type == "cuda")):
            images = G(
                z=interpolated_vectors,
                c=None,
                poses=poses,
                scale=False,
                crop=False,
                perturb=False,
            )

    # NCHW -> NHWC for imshow. Clamp out-of-place (the result is an inference
    # tensor, so in-place ops are not allowed here); matplotlib would clip float
    # RGB to [0, 1] anyway, this just avoids its warning.
    images = images.float().clamp(0.0, 1.0).permute(0, 2, 3, 1).cpu()
    return images


def make_interpolation_examples(G, interpolation_path, interpolation_examples=4, num_steps=8):
    """Save latent-interpolation strips as PNG files, one strip per file.

    Each strip is a single row of ``num_steps`` images from
    ``get_interpolated_images`` (a fresh random latent pair per strip). It is
    written to ``<interpolation_path>/<i>.png`` for
    ``i = 0 .. interpolation_examples - 1``.

    Args:
        G: Generator, forwarded unchanged to ``get_interpolated_images``.
        interpolation_path: Output directory (``str`` or ``pathlib.Path``);
            created if it does not exist.
        interpolation_examples: Number of strips to write; values <= 0 write nothing.
        num_steps: Images per strip.
    """
    out_dir = pathlib.Path(interpolation_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    for i in range(interpolation_examples):
        images = get_interpolated_images(G, num_steps=num_steps)

        # squeeze=False keeps `axs` 2-D even when num_steps == 1; otherwise
        # plt.subplots returns a bare Axes that cannot be iterated.
        fig, axs = plt.subplots(
            1, num_steps, figsize=(2.5 * num_steps, 3), tight_layout=True, squeeze=False
        )
        for ax, image in zip(axs[0], images):
            ax.imshow(image)
            ax.axis("off")

        fig.savefig(out_dir / f"{i}.png", dpi=120)
        # Close explicitly so figures don't accumulate across iterations or get
        # re-displayed by the inline backend at the end of the cell.
        plt.close(fig)

        del images
        if device.type == "cuda":
            # Release cached GPU blocks left over from rendering this strip.
            torch.cuda.empty_cache()


def make_image_examples(G, num_images, batch_size, result_path, z_dim=128):
    """Render random samples and save them as numbered PNG files.

    Latents are drawn with ``torch.randn`` in batches of at most ``batch_size``.
    No ``poses`` argument is passed, so camera selection is left to ``G``.
    Files are named ``1.png`` .. ``<num_images>.png`` inside ``result_path``.

    Args:
        G: Generator called as ``G(z=..., c=None, scale=False, crop=False,
            perturb=False)``. Assumed to return an NCHW tensor whose values are
            meant to lie in [0, 1] — TODO confirm for this model.
        num_images: Total number of images to write; values <= 0 write nothing.
        batch_size: Maximum latents per generator call; must be >= 1.
        result_path: Output directory (``str`` or ``pathlib.Path``); created if missing.
        z_dim: Latent dimensionality fed to ``G`` (128 by default, as before).

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    to_pil = transforms.ToPILImage()
    out_dir = pathlib.Path(result_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    num_batches = (num_images + batch_size - 1) // batch_size  # ceil division
    counter = 0

    for _ in tqdm(range(num_batches), desc="images"):
        # The last batch may be smaller than batch_size.
        cur_bs = min(batch_size, num_images - counter)
        z = torch.randn(cur_bs, z_dim, device=device)

        # torch.autocast("cuda", ...) replaces the deprecated torch.cuda.amp.autocast;
        # fp16 mixed precision only kicks in when actually running on a GPU.
        with torch.inference_mode():
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(device.type == "cuda")):
                imgs = G(z=z, c=None, scale=False, crop=False, perturb=False)

        # ToPILImage turns float tensors into bytes via `mul(255).byte()`, which
        # does not saturate, so values slightly outside [0, 1] (renderer or fp16
        # overshoot) become wrapped/garbage pixels. Clamp first (out-of-place:
        # `imgs` is an inference tensor).
        imgs = imgs.float().clamp(0.0, 1.0).cpu()

        for img in imgs:
            counter += 1
            to_pil(img).save(out_dir / f"{counter}.png")

        del imgs, z
        if device.type == "cuda":
            # Release cached GPU blocks left over from rendering this batch.
            torch.cuda.empty_cache()

In [None]:
if __name__ == "__main__":
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only point pickle_path at checkpoints from a trusted source.
    with open(pickle_path, "rb") as f:
        content = pickle.load(f)

    # Generator stored under the "G_ema" key, put in eval mode and moved to `device`.
    G = content["G_ema"]
    G = G.eval()
    G = G.to(device)

    # Each stage re-seeds with 0 so its samples are reproducible regardless of
    # how many random numbers the previous stage consumed.
    print("Generating interpolation examples...")
    torch.manual_seed(0)
    make_interpolation_examples(
        G,
        interpolation_path,
        interpolation_examples=interp_examples,
        num_steps=interp_steps,
    )

    print("Generating images...")
    torch.manual_seed(0)
    make_image_examples(
        G,
        num_images=num_images,
        batch_size=batch_size,
        result_path=img_path,
    )