In [2]:
import json
import os
from os import path
import queue
import threading

from internal import math, utils  # pylint: disable=g-multiple-import
import jax
import numpy as np
from PIL import Image

import cv2

In [5]:
# Directory holding one video per camera (absolute, machine-specific path).
Path_to_videos = '/media/pleasework/Storage/Nerf_Datasets/Datasets/surrey/02_dancer'


# Keep only the files whose name ends with .mp4, in sorted order.
# A new list is built instead of calling names.remove() inside a loop over
# names: removing while iterating skips the entry that follows each removal,
# so some non-.mp4 files could be left in the list.
names = sorted(
    name for name in os.listdir(Path_to_videos) if name.endswith('.mp4'))
print(names)

['cam_1.mp4', 'cam_2.mp4', 'cam_3.mp4', 'cam_4.mp4', 'cam_5.mp4', 'cam_6.mp4', 'cam_7.mp4', 'cam_8.mp4']


In [15]:
# Absolute, machine-specific path to a metadata.json file. Not referenced by
# any other cell visible in this notebook.
meta_path = r'/media/pleasework/Storage/regnerf/data/pop-sparse/metadata.json'

In [8]:
def downsample(img, factor, patch_size=-1, mode=cv2.INTER_AREA):
  """Downsamples img by an integer factor (area interpolation by default).

  Args:
    img: image array; shape[0] and shape[1] are integer-divided by `factor`
      (factor should evenly divide both).
    factor: integer downsampling factor.
    patch_size: lower bound applied to each output dimension; the default of
      -1 never takes effect.
    mode: OpenCV interpolation flag.

  Returns:
    The resized image as returned by cv2.resize.
  """
  sh = img.shape
  # cv2.resize takes the target size as (shape[1], shape[0]) order.
  out_shape = (max(sh[1] // factor, patch_size),
               max(sh[0] // factor, patch_size))
  # The interpolation flag must be passed by keyword: the third positional
  # parameter of cv2.resize is `dst`, so passing `mode` positionally does not
  # select the interpolation method.
  return cv2.resize(img, out_shape, interpolation=mode)

def sample_recon_scale(image_list, dist='uniform_scale'):
  """Samples the index of a scale for the reconstruction loss.

  Args:
    image_list: sequence with one array per scale.
    dist: 'uniform_scale' draws every index with equal probability;
      'uniform_size' draws index i with probability proportional to
      image_list[i].shape[0].

  Returns:
    An integer index into image_list (a Python int for 'uniform_scale', a
    0-d NumPy integer for 'uniform_size').

  Raises:
    ValueError: if `dist` is not one of the supported names. Previously this
      case failed with an UnboundLocalError on the return statement.
  """
  if dist == 'uniform_scale':
    return np.random.randint(len(image_list))
  if dist == 'uniform_size':
    n_img = np.array([i.shape[0] for i in image_list], dtype=np.float32)
    probs = n_img / np.sum(n_img)
    return np.random.choice(np.arange(len(image_list)), size=(), p=probs)
  raise ValueError(
      f'Unsupported dist \'{dist}\'; expected \'uniform_scale\' or '
      '\'uniform_size\'.')

def anneal_nearfar(d, it, near_final, far_final,
                   n_steps=2000, init_perc=0.2, mid_perc=0.5):
  """Linearly anneals the near/far planes of every Rays entry in a batch.

  The planes start shrunk towards a mid point (placed `mid_perc` of the way
  from near_final to far_final, keeping `init_perc` of each half-range) and
  reach their final values once `it` reaches `n_steps`.

  Args:
    d: batch dict; only values that are utils.Rays and whose key contains
      'rays' are rewritten, everything else is passed through.
    it: current iteration.
    near_final: near plane at the end of annealing.
    far_final: far plane at the end of annealing.
    n_steps: number of iterations over which to anneal.
    init_perc: fraction of each half-range used at iteration 0.
    mid_perc: position of the mid point between near_final and far_final.

  Returns:
    A new dict with the same keys as `d`.
  """
  mid = near_final + mid_perc * (far_final - near_final)
  near_init = mid + init_perc * (near_final - mid)
  far_init = mid + init_perc * (far_final - mid)

  # Interpolation weight, clamped so the planes stop moving after n_steps.
  weight = min(it * 1.0 / n_steps, 1.0)
  near_i = near_init + weight * (near_final - near_init)
  far_i = far_init + weight * (far_final - far_init)

  def _with_annealed_planes(key, value):
    """Returns `value` with near/far replaced when it is a rays entry."""
    if not ('rays' in key and isinstance(value, utils.Rays)):
      return value
    ones = np.ones_like(value.origins[Ellipsis, :1])
    return utils.Rays(
        origins=value.origins, directions=value.directions,
        viewdirs=value.viewdirs, radii=value.radii,
        lossmult=value.lossmult, near=ones*near_i, far=ones*far_i)

  return {k: _with_annealed_planes(k, v) for (k, v) in d.items()}


def subsample_patches(images, patch_size, batch_size, batching='all_images'):
  """Draws random square patches from one randomly chosen scale.

  Args:
    images: list with one entry per scale. Each entry is either an ndarray
      or an object exposing `.origins` (handled through utils.dataclass_map);
      axes 0, 1 and 2 are indexed as image, row and column.
    patch_size: side length of each square patch.
    batch_size: total pixel budget; batch_size // patch_size**2 patches are
      drawn.
    batching: 'all_images' picks an image per patch, 'single_image' uses one
      image for every patch.

  Returns:
    A tuple (patches, scale): the gathered pixels flattened over patches, and
    a float32 array of shape (n_patches, 1) filled with the chosen scale
    index.

  Raises:
    ValueError: if `batching` is not a supported mode.
  """
  n_patches = batch_size // (patch_size ** 2)

  # Pick which scale to sample from.
  scale = np.random.randint(0, len(images))
  selected = images[scale]
  is_array = isinstance(selected, np.ndarray)
  shape = selected.shape if is_array else selected.origins.shape

  # Pick the source image of every patch.
  if batching == 'single_image':
    idx_img = np.full(
        (n_patches, 1), np.random.randint(0, shape[0]), dtype=int)
  elif batching == 'all_images':
    idx_img = np.random.randint(0, shape[0], size=(n_patches, 1))
  else:
    raise ValueError('Not supported batching type!')

  # Pick the top-left corner of every patch (x first, then y).
  x0 = np.random.randint(0, shape[2] - patch_size + 1, size=(n_patches, 1, 1))
  y0 = np.random.randint(0, shape[1] - patch_size + 1, size=(n_patches, 1, 1))

  # Per-pixel offsets inside a patch, flattened in row-major order.
  off_x, off_y = np.meshgrid(
      np.arange(patch_size), np.arange(patch_size), indexing='xy')
  cols = x0[:, :, 0] + off_x.reshape(1, -1)
  rows = y0[:, :, 0] + off_y.reshape(1, -1)

  # Gather the patch pixels.
  if is_array:
    out = selected[idx_img, rows, cols].reshape(-1, 3)
  else:
    def _gather(x):
      return x[idx_img, rows, cols].reshape(-1, x.shape[-1])
    out = utils.dataclass_map(_gather, selected)

  scale_out = np.ones((n_patches, 1), dtype=np.float32) * scale
  return out, scale_out

In [11]:
class Dataset(threading.Thread):
  """Dataset base class that prefetches batches on a background thread.

  Constructing an instance loads the data for `split` and immediately starts
  the thread; its `run` loop keeps a bounded queue filled with batches, and
  iterating over the instance pops from that queue.

  This class does not define `_load_renderings` or `_load_masks`; subclasses
  must supply them. The methods below also expect `self.images`,
  `self.camtoworlds`, `self.width`, `self.height`, `self.focal` and
  `self.n_examples` to exist already (presumably set by the subclass's
  `_load_renderings` -- verify against the subclass).
  """

  def __init__(self, split, data_dir, config):
    """Loads the data for `split` and starts the prefetch thread.

    Args:
      split: 'train', 'test' or 'path'.
      data_dir: base data directory; for the 'dtu', 'llff' and 'blender'
        loaders the scan/scene name from `config` is appended to it.
      config: configuration object whose attributes are read below.

    Raises:
      ValueError: if `split` is not one of the accepted values.
    """
    super().__init__()
    self.queue = queue.Queue(3)  # Set prefetch buffer to 3 batches.
    self.daemon = True
    self.use_tiffs = config.use_tiffs
    self.load_disps = config.compute_disp_metrics
    self.load_normals = config.compute_normal_metrics
    self.load_random_rays = config.load_random_rays
    self.load_random_fullimage_rays = config.dietnerf_loss_mult != 0.0
    self.load_masks = ((config.dataset_loader == 'dtu') and (split == 'test')
                       and (not config.dtu_no_mask_eval)
                       and (not config.render_path))

    self.split = split
    if config.dataset_loader == 'dtu':
      self.data_base_dir = data_dir
      data_dir = os.path.join(data_dir, config.dtu_scan)
    elif config.dataset_loader == 'llff':
      self.data_base_dir = data_dir
      data_dir = os.path.join(data_dir, config.llff_scan)
    elif config.dataset_loader == 'blender':
      self.data_base_dir = data_dir
      data_dir = os.path.join(data_dir, config.blender_scene)
    self.data_dir = data_dir
    self.near = config.near
    self.far = config.far
    self.near_origin = config.near_origin
    self.anneal_nearfar = config.anneal_nearfar
    self.anneal_nearfar_steps = config.anneal_nearfar_steps
    self.anneal_nearfar_perc = config.anneal_nearfar_perc
    self.anneal_mid_perc = config.anneal_mid_perc
    self.sample_reconscale_dist = config.sample_reconscale_dist

    if split == 'train':
      self._train_init(config)
    elif split == 'test' or split == 'path':
      self._test_init(config)
    else:
      # The message lists every accepted value, including 'path'.
      raise ValueError(
          f'`split` should be \'train\', \'test\' or \'path\', '
          f'but is \'{split}\'.')
    self.batch_size = config.batch_size // jax.host_count()
    self.batch_size_random = config.batch_size_random // jax.host_count()
    print('Using following batch size', self.batch_size)
    self.patch_size = config.patch_size
    self.batching = config.batching
    self.batching_random = config.batching_random
    self.render_path = config.render_path
    self.render_train = config.render_train
    self.start()

  def __iter__(self):
    return self

  def __next__(self):
    """Get the next training batch or test example.

    Returns:
      batch: dict, has 'rgb' and 'rays'.
    """
    x = self.queue.get()
    if self.split == 'train':
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def peek(self):
    """Peek at the next training batch or test example without dequeuing it.

    Returns:
      batch: dict, has 'rgb' and 'rays'.
    """
    x = self.queue.queue[0].copy()  # Make a copy of the front of the queue.
    if self.split == 'train':
      return utils.shard(x)
    else:
      return utils.to_device(x)

  def run(self):
    """Thread body: keeps producing batches and pushing them on the queue."""
    if self.split == 'train':
      next_func = self._next_train
    else:
      next_func = self._next_test
    while True:
      # put() blocks while the queue is full, which throttles this loop.
      self.queue.put(next_func())

  @property
  def size(self):
    """Number of examples (`self.n_examples`)."""
    return self.n_examples

  def _train_init(self, config):
    """Initialize training."""
    self._load_renderings(config)
    self._generate_downsampled_images(config)
    self._generate_rays(config)
    self._generate_downsampled_rays(config)

    # Generate more rays / image patches for unobserved-view-based losses.
    if self.load_random_rays:
      self._generate_random_rays(config)
    if self.load_random_fullimage_rays:
      self._generate_random_fullimage_rays(config)
      self._load_renderings_featloss(config)

    self.it = 0
    self.images_noreshape = self.images[0]

    if config.batching == 'all_images':
      # flatten the ray and image dimension together.
      self.images = [i.reshape(-1, 3) for i in self.images]
      if self.load_disps:
        self.disp_images = self.disp_images.flatten()
      if self.load_normals:
        self.normal_images = self.normal_images.reshape([-1, 3])

      self.ray_noreshape = [self.rays]
      # zip() is kept so that the result has one entry per resolution.
      self.rays = [utils.dataclass_map(lambda r: r.reshape(  # pylint: disable=g-long-lambda
          [-1, r.shape[-1]]), i) for (i, _) in zip(
              self.rays, self.resolutions)]

    elif config.batching == 'single_image':
      print("image shape:",self.images_noreshape.shape)
      self.images = [i.reshape(
          [-1, r, 3]) for (i, r) in zip(self.images, self.resolutions)]
      print("shape after operation:",self.images[0].shape)
      # NOTE(review): `self.resolution` (singular) is not assigned anywhere in
      # this class; it presumably comes from the subclass -- confirm.
      if self.load_disps:
        self.disp_images = self.disp_images.reshape([-1, self.resolution])
      if self.load_normals:
        self.normal_images = self.normal_images.reshape(
            [-1, self.resolution, 3])

      self.ray_noreshape = [self.rays]
      # The lambda is called inside the same comprehension step, so `res`
      # still holds the value for the current scale.
      self.rays = [utils.dataclass_map(lambda r: r.reshape(  # pylint: disable=g-long-lambda
          [-1, res, r.shape[-1]]), i) for (i, res) in  # pylint: disable=cell-var-from-loop
                   zip(self.rays, self.resolutions)]
      print("self.resolutions:", self.resolutions)
    else:
      raise NotImplementedError(
          f'{config.batching} batching strategy is not implemented.')

  def _test_init(self, config):
    """Initialize testing: load images (and masks if requested) and rays."""
    self._load_renderings(config)
    if self.load_masks:
      self._load_masks(config)
    self._generate_rays(config)
    self.it = 0

  def _next_train(self):
    """Sample next training batch."""

    self.it = self.it + 1
    return_dict = {}
    if self.batching == 'all_images':
      # sample scale
      idxs = sample_recon_scale(self.images, self.sample_reconscale_dist)
      ray_indices = np.random.randint(0, self.rays[idxs].origins.shape[0],
                                      (self.batch_size,))
      return_dict['rgb'] = self.images[idxs][ray_indices]
      return_dict['rays'] = utils.dataclass_map(lambda r: r[ray_indices],
                                                self.rays[idxs])
      if self.load_disps:
        return_dict['disps'] = self.disp_images[ray_indices]
      if self.load_normals:
        return_dict['normals'] = self.normal_images[ray_indices]

    elif self.batching == 'single_image':
      idxs = sample_recon_scale(self.images, self.sample_reconscale_dist)
      # NOTE(review): the sampled scale is immediately overridden, so only
      # scale 0 is ever used with single_image batching. Kept as written;
      # confirm whether this is intentional.
      idxs = 0
      image_index = np.random.randint(0, self.n_examples, ())
      ray_indices = np.random.randint(0, self.rays[idxs].origins[0].shape[0],
                                      (self.batch_size,))
      return_dict['rgb'] = self.images[idxs][image_index][ray_indices]
      return_dict['rays'] = utils.dataclass_map(
          lambda r: r[image_index][ray_indices], self.rays[idxs])
      if self.load_disps:
        return_dict['disps'] = self.disp_images[image_index][ray_indices]
      if self.load_normals:
        return_dict['normals'] = self.normal_images[image_index][ray_indices]
    else:
      raise NotImplementedError(
          f'{self.batching} batching strategy is not implemented.')

    if self.load_random_rays:
      return_dict['rays_random'], return_dict['rays_random_scale'] = (
          subsample_patches(self.random_rays, self.patch_size,
                            self.batch_size_random,
                            batching=self.batching_random))
      return_dict['rays_random2'], return_dict['rays_random2_scale'] = (
          subsample_patches(
              self.random_rays, self.patch_size, self.batch_size_random,
              batching=self.batching_random))
    if self.load_random_fullimage_rays:
      idx_img = np.random.randint(self.random_fullimage_rays.origins.shape[0])
      return_dict['rays_feat'] = utils.dataclass_map(
          lambda x: x[idx_img].reshape(-1, x.shape[-1]),
          self.random_fullimage_rays)
      idx_img = np.random.randint(self.images_feat.shape[0])
      return_dict['image_feat'] = self.images_feat[idx_img].reshape(-1, 3)

    if self.anneal_nearfar:
      return_dict = anneal_nearfar(return_dict, self.it, self.near, self.far,
                                   self.anneal_nearfar_steps,
                                   self.anneal_nearfar_perc,
                                   self.anneal_mid_perc)

    return return_dict

  def _next_test(self):
    """Sample next test example."""

    return_dict = {}

    idx = self.it
    # Wrap around so the test set can be iterated repeatedly.
    self.it = (self.it + 1) % self.n_examples

    if self.render_path:
      return_dict['rays'] = utils.dataclass_map(lambda r: r[idx],
                                                self.render_rays)
    else:
      return_dict['rgb'] = self.images[idx]
      return_dict['rays'] = utils.dataclass_map(lambda r: r[idx], self.rays)

    if self.load_masks:
      return_dict['mask'] = self.masks[idx]
    if self.load_disps:
      return_dict['disps'] = self.disp_images[idx]
    if self.load_normals:
      return_dict['normals'] = self.normal_images[idx]

    return return_dict

  def _rays_from_poses(self, poses, width, height, focal, pixel_offset=0.0):
    """Builds a utils.Rays bundle for a width x height pinhole image per pose.

    Args:
      poses: array indexed as [pose, row, col]; [:, :3, :3] is applied to the
        camera-space directions and [:, :3, -1] is used as the ray origin.
      width: number of pixel columns.
      height: number of pixel rows.
      focal: focal length in pixels.
      pixel_offset: constant added to the integer pixel coordinates before
        they are turned into directions.

    Returns:
      utils.Rays whose near/far are filled from self.near / self.far and
      whose lossmult is all ones.
    """
    x, y = np.meshgrid(  # pylint: disable=unbalanced-tuple-unpacking
        np.arange(width, dtype=np.float32) + pixel_offset,  # X-Axis (columns)
        np.arange(height, dtype=np.float32) + pixel_offset,  # Y-Axis (rows)
        indexing='xy')
    camera_dirs = np.stack(
        [(x - width * 0.5 + 0.5) / focal,
         -(y - height * 0.5 + 0.5) / focal, -np.ones_like(x)],
        axis=-1)
    directions = ((camera_dirs[None, Ellipsis, None, :] *
                   poses[:, None, None, :3, :3]).sum(axis=-1))
    origins = np.broadcast_to(poses[:, None, None, :3, -1], directions.shape)
    viewdirs = directions / np.linalg.norm(directions, axis=-1, keepdims=True)

    # Distance between each direction vector and its neighbor along array
    # axis 1; the last entry is padded by repeating the previous one.
    dx = np.sqrt(
        np.sum((directions[:, :-1, :, :] - directions[:, 1:, :, :])**2, -1))
    dx = np.concatenate([dx, dx[:, -2:-1, :]], axis=1)
    # Cut the distance in half, multiply it to match the variance of a uniform
    # distribution the size of a pixel (1/12, see paper).
    radii = dx[Ellipsis, None] * 2 / np.sqrt(12)

    ones = np.ones_like(origins[Ellipsis, :1])
    return utils.Rays(
        origins=origins,
        directions=directions,
        viewdirs=viewdirs,
        radii=radii,
        lossmult=ones,
        near=ones * self.near,
        far=ones * self.far)

  def _generate_rays(self, config):
    """Generating rays for all images."""
    print("using the generate rays from dataset classsssssssss")
    del config  # Unused.
    self.rays = self._rays_from_poses(
        self.camtoworlds, self.width, self.height, self.focal)
    self.render_rays = self.rays

  def _generate_random_poses(self, config):
    """Generates random poses and stores them in self.random_poses.

    With random_pose_type 'allposes' the poses in self.camtoworlds_all are
    used; with 'renderpath' camera positions are sampled on a sphere and
    oriented towards the origin.

    Raises:
      ValueError: for any other random_pose_type.
    """
    if config.random_pose_type == 'allposes':
      random_poses = list(self.camtoworlds_all)
    elif config.random_pose_type == 'renderpath':
      def sample_on_sphere(n_samples, only_upper=True, radius=4.03112885717555):
        """Samples points on a sphere (upper half only when requested)."""
        p = np.random.randn(n_samples, 3)
        if only_upper:
          p[:, -1] = abs(p[:, -1])
        p = p / np.linalg.norm(p, axis=-1, keepdims=True) * radius
        return p

      def create_look_at(eye, target=np.array([0, 0, 0]),
                         up=np.array([0, 0, 1]), dtype=np.float32):
        """Creates lookat matrix."""
        eye = eye.reshape(-1, 3).astype(dtype)
        target = target.reshape(-1, 3).astype(dtype)
        up = up.reshape(-1, 3).astype(dtype)

        def normalize_vec(x, eps=1e-9):
          return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

        forward = normalize_vec(target - eye)
        side = normalize_vec(np.cross(forward, up))
        up = normalize_vec(np.cross(side, forward))

        up = up * np.array([1., 1., 1.]).reshape(-1, 3)
        # The stored third axis is the negated viewing direction.
        forward = forward * np.array([-1., -1., -1.]).reshape(-1, 3)

        rot = np.stack([side, up, forward], axis=-1).astype(dtype)
        return rot

      origins = sample_on_sphere(config.n_random_poses)
      rotations = create_look_at(origins)
      random_poses = np.concatenate([rotations, origins[:, :, None]], axis=-1)
    else:
      raise ValueError('Not supported random pose type.')
    self.random_poses = np.stack(random_poses, axis=0)

  def _generate_random_rays(self, config):
    """Generating rays for the random poses, one bundle per scale."""
    self._generate_random_poses(config)

    random_rays = []
    for sfactor in [2**i for i in range(config.random_scales_init,
                                        config.random_scales)]:
      random_rays.append(self._rays_from_poses(
          self.random_poses,
          self.width // sfactor,
          self.height // sfactor,
          self.focal / (sfactor * 1.0)))
    self.random_rays = random_rays

  def _load_renderings_featloss(self, config):
    """Loads renderings for DietNeRF's feature loss."""
    images = self.images[0]
    res = config.dietnerf_loss_resolution
    images_feat = []
    for img in images:
      # The interpolation flag must be passed by keyword: the third positional
      # parameter of cv2.resize is `dst`.
      images_feat.append(
          cv2.resize(img, (res, res), interpolation=cv2.INTER_AREA))
    self.images_feat = np.stack(images_feat)

  def _generate_random_fullimage_rays(self, config):
    """Generating random rays for full images."""
    self._generate_random_poses(config)

    width = config.dietnerf_loss_resolution
    height = config.dietnerf_loss_resolution
    f = self.focal / (self.width * 1.0 / width)

    # NOTE(review): pixel_offset=.5 comes on top of the +0.5 already applied
    # when forming the camera directions; kept as in the original code --
    # confirm that the double half-pixel shift is intended.
    self.random_fullimage_rays = self._rays_from_poses(
        self.random_poses, width, height, f, pixel_offset=.5)

  def _generate_downsampled_images(self, config):
    """Generating downsampled images."""
    images = []
    resolutions = []
    for sfactor in [2**i for i in range(config.recon_loss_scales)]:
      print("single image shape:",self.images[0].shape)
      imgi = np.stack([downsample(i, sfactor) for i in self.images])
      images.append(imgi)
      # Number of pixels per image at this scale.
      resolutions.append(imgi.shape[1] * imgi.shape[2])

    self.images = images
    self.resolutions = resolutions

  def _generate_downsampled_rays(self, config):
    """Generating downsampled rays, one bundle per reconstruction scale."""
    rays, height, width, focal = self.rays, self.height, self.width, self.focal
    ray_list = [rays]
    try:
      for sfactor in [2**i for i in range(1, config.recon_loss_scales)]:
        # _generate_rays reads the camera size from self, so it is swapped
        # temporarily for each scale.
        self.height = height // sfactor
        self.width = width // sfactor
        self.focal = focal * 1.0 / sfactor
        self._generate_rays(config)
        ray_list.append(self.rays)
    finally:
      # Restore the full-resolution camera even if ray generation fails.
      self.height = height
      self.width = width
      self.focal = focal
    self.rays = ray_list


In [27]:
import json
import cv2
# Save each frame of each camera as an image, and list the camera poses in the corresponding order.

class multicam_video(Dataset):
    """Multi-camera dataset loader (work in progress).

    Only `_load_renderings` is provided. It fills `self.meta`, `self.images`
    and `self.n_examples`; it does not set the camera attributes
    (`camtoworlds`, `width`, `height`, `focal`) that the base class reads when
    generating rays -- TODO derive them from `self.meta`.
    """

    def _load_renderings(self, config):
        """Loads the metadata and the images listed in it for the current split.

        Args:
            config: configuration object; `render_path` and `white_background`
                are read.

        Raises:
            ValueError: if `config.render_path` is set.
        """
        if config.render_path:
            raise ValueError('render_path cannot be used for the Multicam dataset.')
        with utils.open_file(path.join(self.data_dir, 'metadata.json'), 'r') as fp:
            self.meta = json.load(fp)[self.split]
            print("self.split:", self.split)
        self.meta = {k: np.array(self.meta[k]) for k in self.meta}
        # Should now have ['pix2cam', 'cam2world', 'width', 'height'] in self.meta.

        # read the video sequence:
        # TODO: iterate over the per-camera videos and decode their frames.
        # The unfinished loop header that stood here (`for i in range(len())`)
        # was a syntax error and prevented the class from being defined.
        # NOTE(review): the capture below is opened but never read or released
        # in this method.
        self.video = cv2.VideoCapture(path.join(self.data_dir, 'video.mp4'))

        images = []

        for fbase in self.meta['file_path']:
            print('file path:', fbase)
            fname = os.path.join(self.data_dir, fbase)
            with utils.open_file(fname, 'rb') as imgin:
                image = np.array(Image.open(imgin), dtype=np.float32) / 255.
            if config.white_background:
                # Composite onto white, using the last channel as alpha.
                image = image[Ellipsis, :3] * image[Ellipsis, -1:] + (1. - image[Ellipsis, -1:])
            images.append(image[Ellipsis, :3])

        self.images = np.stack(images, axis = 0)
        print("image shape after stack:", self.images.shape)
        self.n_examples = len(self.images)






In [35]:
## Next step: split each video sequence into its individual frames.

In [34]:
# Main training script, used here to test the output of the multicam video sequences.

import functools
import gc
import time

from absl import app
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
from internal import configs, datasets, math, models, utils, vis  # pylint: disable=g-multiple-import
import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from skimage.metrics import structural_similarity

def main():
    """Placeholder entry point; it performs no work yet."""
    return None


if __name__ == '__main__':
    main()

ImportError: cannot import name 'datasets' from 'internal' (unknown location)

In [None]:
# construct the camera poses,

