In [6]:
import torch
import os
from torchvision import transforms, datasets
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import json
import imageio
# Pick the compute device once: the first CUDA GPU when one is visible,
# otherwise the CPU. (On a hosted runtime, switch the runtime type to GPU.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")




ModuleNotFoundError: No module named 'imageio'

In [1]:
# Write a class to read in the image frame and store all the data related to it 
class FrameManager:
    """Load the test/train/val image splits of a dataset directory into ``Frame`` objects.

    Attributes:
        test_frames, train_frames, val_frames: lists of ``Frame`` objects, one
            per entry of the matching ``transforms_*.json`` file.
        cam_angle: ``camera_angle_x`` of the most recently parsed split; it is
            used as the horizontal field of view in radians.
        f: focal length in pixels derived from ``cam_angle`` and ``W``.
        H, W: height and width (pixels) of the most recently read image.
    """

    # (attribute that receives the frames, config file name), in read order.
    _SPLITS = (
        ('test_frames', 'transforms_test.json'),
        ('train_frames', 'transforms_train.json'),
        ('val_frames', 'transforms_val.json'),
    )

    def __init__(self):
        self.test_frames = []
        self.train_frames = []
        self.val_frames = []
        self.cam_angle = 0
        self.f = None
        self.H = None
        self.W = None

    def read_frames(self, path):
        """Read every split found under ``path`` and populate the frame lists.

        Args:
            path: directory holding the three ``transforms_*.json`` files. Each
                frame's ``file_path`` entry is resolved relative to this
                directory and ``.png`` is appended to it.

        Raises:
            OSError: if a config file or an image cannot be opened.
            KeyError: if a config lacks ``camera_angle_x``, ``frames``,
                ``file_path`` or ``transform_matrix``.
        """
        # Parse all configs first so a missing or corrupt file fails before
        # any image has been decoded.
        configs = []
        for attr, cfg_name in self._SPLITS:
            with open(os.path.join(path, cfg_name)) as json_file:
                configs.append((attr, json.load(json_file)))

        for attr, data in configs:
            # Use the field of view of the split being processed, not always
            # the one from the first (test) config.
            self.cam_angle = data['camera_angle_x']
            frames = []
            for frame in data['frames']:
                img_file = os.path.join(path, frame['file_path'] + '.png')
                img = imageio.imread(img_file)
                self.H, self.W = img.shape[0], img.shape[1]
                # Pinhole geometry: half the field of view (a) and half the
                # image width (A) form a right triangle with the focal length
                # F, so tan(a) = A / F and therefore F = A / tan(a).
                self.f = (0.5 * self.W) / np.tan(0.5 * self.cam_angle)
                pose = np.array(frame['transform_matrix'])
                frames.append(Frame(img, pose, self.f))
            setattr(self, attr, frames)


class Frame:
    """One image with its camera pose and the rays cast through its pixels.

    Attributes:
        img: the image as passed in; becomes a tensor after ``copy_to_device``.
        pose: the camera matrix as passed in; ``get_rays`` reads its top-left
            3x3 block as a rotation and its last column as the ray origin
            (presumably a camera-to-world transform; confirm with the dataset).
        H, W: image height and width, taken from ``image.shape[:2]``.
        f: focal length in pixels.
        samples, depth_values: filled in later by the sampling code.
        rays_o, rays_d: ray origins and directions, set by ``get_rays``.
    """

    def __init__(self, image, pose, f):
        self.img = image
        self.pose = pose
        self.H, self.W = image.shape[0], image.shape[1]
        self.f = f
        self.samples = None
        self.rays_o = None
        self.rays_d = None
        self.depth_values = None

    def copy_to_device(self, device):
        """Hold ``img`` and ``pose`` as tensors on ``device``.

        ``torch.as_tensor`` accepts both NumPy arrays and tensors, so calling
        this more than once (e.g. via repeated ``get_rays`` calls) is safe.
        """
        self.img = torch.as_tensor(self.img, device=device)
        self.pose = torch.as_tensor(self.pose, device=device)

    def get_rays(self, device):
        """Compute one ray per pixel for a pinhole camera.

        Args:
            device: torch device (or device string) the rays are created on.

        Returns:
            ``(rays_o, rays_d)``, both shaped ``(H, W, 3)`` so they line up
            with the image's row/column layout. The results are also stored
            on ``self.rays_o`` and ``self.rays_d``.
        """
        self.copy_to_device(device)

        # rows[h, w] == h and cols[h, w] == w, both shaped (H, W).
        rows, cols = torch.meshgrid(
            torch.arange(self.H, dtype=torch.float32, device=device),
            torch.arange(self.W, dtype=torch.float32, device=device),
            indexing='ij',
        )
        # Camera-space directions: x follows the column index (centred on
        # W/2), y is the negated row index (centred on H/2), and the camera
        # looks along -z.
        dirs = torch.stack(
            [
                (cols - self.W * 0.5) / self.f,
                -(rows - self.H * 0.5) / self.f,
                -torch.ones_like(cols),
            ],
            -1,
        )

        # Rotate every direction by the pose's 3x3 block: a per-pixel
        # matrix-vector product written as broadcast-multiply-and-sum.
        rays_d = torch.sum(dirs[..., None, :] * self.pose[:3, :3], -1)
        # All rays start at the translation stored in the pose's last column.
        rays_o = torch.broadcast_to(self.pose[:3, -1], rays_d.shape)
        self.rays_o = rays_o
        self.rays_d = rays_d

        return rays_o, rays_d
        
        

        


IndentationError: expected an indented block (3026848341.py, line 13)

In [None]:
# Stratified depth sampling along the rays of a single frame.
# (Could become a method of the FrameManager class.)
def sample_frame(frames, num_samples, near, far, dev=None):
    """Sample ``num_samples`` jittered depths per ray and the matching 3D points.

    The interval ``[near, far]`` is split into ``num_samples`` bins around
    evenly spaced depths, and one depth is drawn uniformly inside each bin.
    The same random offsets are shared by every ray of the frame.

    Args:
        frames: a single frame object exposing ``rays_o`` and ``rays_d``
            tensors whose last dimension is 3 (despite the plural name, the
            code reads these attributes directly, not from a list).
        num_samples: number of depth samples per ray.
        near, far: bounds of the sampled depth interval.
        dev: device for the sampled depths. ``None`` (default) uses the
            device that ``frames.rays_o`` lives on, so CPU-only runs work.

    Returns:
        ``(pts, depth_value)`` where ``pts`` has shape
        ``rays_o.shape[:-1] + (num_samples, 3)`` and ``depth_value`` has shape
        ``rays_o.shape[:-1] + (num_samples,)``. Both are also stored on the
        frame as ``samples`` and ``depth_values``.

    Raises:
        ValueError: if the frame's rays have not been generated yet.
    """
    if frames.rays_o is None or frames.rays_d is None:
        raise ValueError("frame has no rays; generate them before sampling")
    if dev is None:
        # Follow the rays so depths and rays can be combined without a
        # cross-device error.
        dev = frames.rays_o.device

    sample_space = torch.linspace(0., 1., num_samples, device=dev)
    depth = near * (1. - sample_space) + far * sample_space

    # Bin edges: midpoints between neighbouring depths, closed off by the
    # first and last depth themselves.
    mid_depth = (depth[1:] + depth[:-1]) / 2
    upper_sample = torch.cat([mid_depth, depth[-1:]], dim=-1)
    lower_sample = torch.cat([depth[:1], mid_depth], dim=-1)

    # One uniform draw per bin, reused for every ray.
    rand_sampling = torch.rand([num_samples], device=dev)
    depth_value = lower_sample + rand_sampling * (upper_sample - lower_sample)
    depth_value = depth_value.expand(list(frames.rays_o.shape[:-1]) + [num_samples])

    # Points along each ray: origin + direction * depth, with a new
    # sample axis inserted before the xyz axis.
    pts = frames.rays_o[..., None, :] + frames.rays_d[..., None, :] * depth_value[..., None]
    frames.samples = pts
    frames.depth_values = depth_value
    return pts, depth_value

    

