In [None]:
# Mount Google Drive into this Colab runtime; the project sources, checkpoints
# and output folders referenced below all live under /content/drive.
import google.colab.drive as drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import sys
sys.path.append("/content/drive/MyDrive/HyperNeRFGAN/src")

import pickle
import pathlib

import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from nerf.load_blender import pose_spherical
import torch
from tqdm import tqdm

In [None]:
batch_size = 1
num_images = 10
interp_examples = 5
interp_steps = 10

retrained_name = "carla_ns8_improved_200"
# retrained_name = "carla_ps8_improved_200"
# retrained_name = "carla_p0_improved_200"
# retrained_name = "carla_r1_improved_200"

retrained_path = pathlib.Path(f"/content/drive/MyDrive/HyperNeRFGAN/data/retrained/{retrained_name}.pkl")
img_path = pathlib.Path(f"/content/drive/MyDrive/HyperNeRFGAN/retrained_running/images/{retrained_name}")
interpolation_path = pathlib.Path(f"/content/drive/MyDrive/HyperNeRFGAN/retrained_running/interpolation/{retrained_name}")

img_path.mkdir(parents=True, exist_ok=True)
interpolation_path.mkdir(parents=True, exist_ok=True)

In [None]:
def get_interpolated_images(G, num_steps=10, z_dim=128, pose=None):
    """Render frames along a straight line between two random latent codes.

    Two codes ``z1, z2 ~ N(0, I)`` of width ``z_dim`` are drawn and blended as
    ``(1 - alpha) * z1 + alpha * z2`` for ``num_steps`` evenly spaced ``alpha``
    values in ``[0, 1]``, with both endpoints included. Every frame is rendered
    from the same camera pose, so only the latent code changes between frames.

    Args:
        G: Generator called as ``G(z=..., c=None, poses=..., scale=..., crop=...,
            perturb=..., use_normal=...)``. Its output is assumed to be
            ``(N, C, H, W)``, as implied by the permute below (TODO: confirm
            against the HyperNeRFGAN generator).
        num_steps: Number of interpolation frames, at least 1. With 1 the single
            frame is rendered from ``z1``.
        z_dim: Latent width. The default 128 matches the rest of this notebook.
        pose: Camera pose shared by all frames. Defaults to
            ``pose_spherical(theta=30, phi=-30, radius=4.0)``.

    Returns:
        A CPU tensor with channels last, i.e. ``(num_steps, H, W, C)``, and no
        autograd history.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")
    if pose is None:
        pose = pose_spherical(theta=30, phi=-30, radius=4.0)

    z1 = torch.randn(1, z_dim)
    z2 = torch.randn(1, z_dim)
    # A (num_steps, 1) column of blend weights broadcasts against the (1, z_dim)
    # codes. linspace also handles num_steps == 1, where i / (num_steps - 1)
    # would divide by zero.
    alphas = torch.linspace(0.0, 1.0, num_steps).unsqueeze(1)
    interpolated_vectors = (1 - alphas) * z1 + alphas * z2

    # Inference only: skipping the autograd graph saves memory and lets the
    # result be handed straight to matplotlib.
    with torch.no_grad():
        images = G(
            z=interpolated_vectors,
            c=None,
            poses=[pose] * num_steps,
            scale=False,
            crop=False,
            perturb=False,
            use_normal=False,
        )

    # Move channels last for imshow and bring the frames to host memory.
    return images.permute(0, 2, 3, 1).cpu()


def make_interpolation_examples(
    G, interpolation_path, interpolation_examples=interp_examples, num_steps=interp_steps
):
    """Render latent-interpolation strips and save one PNG per strip.

    For each of ``interpolation_examples`` iterations, ``get_interpolated_images``
    draws a fresh pair of latent codes and renders ``num_steps`` frames. The
    frames are laid out in a single row and written to
    ``interpolation_path / f"{i}.png"``.

    Args:
        G: Generator passed through to ``get_interpolated_images``.
        interpolation_path: ``pathlib.Path`` of an existing output directory.
        interpolation_examples: Number of strips (PNG files) to produce.
        num_steps: Number of frames per strip.
    """
    for i in range(interpolation_examples):
        images = get_interpolated_images(G, num_steps=num_steps)

        # squeeze=False keeps ``axs`` a 2-D array even when num_steps == 1;
        # otherwise plt.subplots returns a bare Axes, which is not iterable.
        # Width scales with the number of frames (5 in per frame, so 50 for 10).
        fig, axs = plt.subplots(
            1, num_steps, figsize=(5 * num_steps, 5), tight_layout=True, squeeze=False
        )

        for j, ax in enumerate(axs[0]):
            # Detach and move to CPU so imshow also works for tensors that track
            # gradients or live on the GPU. Clamp to [0, 1], which is the range
            # imshow clips float RGB data to anyway.
            frame = images[j].detach().cpu().clamp(0.0, 1.0).numpy()
            ax.imshow(frame)
            ax.axis("off")

        fig.savefig(interpolation_path.joinpath(f"{i}.png"))
        # The PNG on disk is the deliverable. Close the figure so each iteration
        # does not leave another open figure behind.
        plt.close(fig)


def make_image_examples(G, num_images, batch_size, result_path, z_dim=128, pose=None):
    """Sample ``num_images`` random latent codes and save each render as a PNG.

    Files are written to ``result_path`` as ``1.png`` … ``{num_images}.png``.
    Generation is batched, and the final batch is shortened so that exactly
    ``num_images`` files are produced.

    Args:
        G: Generator called as ``G(z=..., c=None, poses=..., ...)``. Each output
            image is assumed to be a ``(C, H, W)`` tensor accepted by
            ``transforms.ToPILImage`` (TODO: confirm).
        num_images: Total number of images to save. Zero saves nothing.
        batch_size: Maximum number of latent codes per generator call, at least 1.
        result_path: ``pathlib.Path`` of an existing output directory.
        z_dim: Latent width. The default 128 matches the rest of this notebook.
        pose: Camera pose used for every image. Defaults to the single fixed view
            ``pose_spherical(theta=30, phi=-30, radius=4.0)``.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if pose is None:
        # Use one fixed camera pose for every image.
        pose = pose_spherical(theta=30, phi=-30, radius=4.0)
    to_pil = transforms.ToPILImage()

    for start in tqdm(range(0, num_images, batch_size)):
        # The last batch may be smaller so that exactly num_images files are saved.
        # The former ``int(num_images / batch_size + 1)`` batch count always
        # overshot, e.g. 11 images for num_images=10, batch_size=1.
        n = min(batch_size, num_images - start)
        z = torch.randn(n, z_dim)
        with torch.no_grad():
            imgs = G(z=z, c=None, poses=[pose] * n, scale=False, crop=False, perturb=False, use_normal=False)

        for j, img in enumerate(imgs):
            img = torch.clamp(img.detach().cpu(), 0.0, 1.0)
            to_pil(img).save(result_path / f"{start + j + 1}.png")

In [None]:
import numpy as np

# Re-create the NumPy scalar aliases (np.float / np.int / np.bool) that newer
# NumPy releases removed, so legacy code that still references them keeps
# working. Presumably this is needed by modules pulled in when the checkpoint
# is unpickled, which is why the cell runs before pickle.load (TODO: confirm).
# An alias that already exists is left untouched.
for _alias, _builtin in (("float", float), ("int", int), ("bool", bool)):
    if not hasattr(np, _alias):
        setattr(np, _alias, _builtin)
del _alias, _builtin
# np.object is intentionally not aliased; add ("object", object) above if needed.

In [None]:
# Inside a notebook kernel __name__ is always "__main__", so this guard only
# matters if the cell is exported to a standalone script.
if __name__ == "__main__":
    # NOTE(review): pickle can execute arbitrary code while loading, so only
    # open checkpoints from a trusted source.
    content = pickle.loads(retrained_path.read_bytes())
    G = content["G_ema"].eval()

    # Each job is re-seeded with 0 so its outputs are reproducible on their own,
    # regardless of how much randomness the previous job consumed.
    for message, job in (
        ("Generating interpolation examples...", lambda: make_interpolation_examples(G, interpolation_path)),
        ("Generating images...", lambda: make_image_examples(G, num_images, batch_size, img_path)),
    ):
        print(message)
        torch.manual_seed(0)
        job()

Generating interpolation examples...


100%|██████████| 10/10 [00:00<00:00, 8305.55it/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 10/10 [00:00<00:00, 11618.57it/s]
100%|██████████| 10/10 [00:00<00:00, 13281.52it/s]
100%|██████████| 10/10 [00:00<00:00, 10443.98it/s]
