## CIS 580, Machine Perception, Spring 2023
### Homework 5
#### Due: Thursday April 27th 2023, 11:59pm ET

Instructions: Create a folder in your Google Drive and place this .ipynb file inside it. Open the Jupyter notebook with Google Colab. Do not use a GPU runtime while implementing and testing your code; switch to a GPU runtime only for the final training runs (of the 2D image or the NeRF), so that you do not exhaust your Colab GPU quota.

### Part 1: Fitting a 2D Image

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import imageio.v2 as imageio
import time
import gdown

# Select the compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

We first download the image from the web. We then normalize the image so that all pixel values lie in the range [0, 1].

In [None]:
# Download the target painting to the working directory.
url = "https://drive.google.com/file/d/1-Cugk9WiFX2CPjWG5taX3868Gdd0PEVT/view?usp=share_link"
gdown.download(url=url, output='starry_night.jpg', quiet=False, fuzzy=True)

# Read the image, convert it to float32 and divide by 255, then move it to `device`.
painting = np.array(imageio.imread("starry_night.jpg"), dtype=np.float32) / 255.
painting = torch.from_numpy(painting).to(device)

# The first two axes of the image tensor are its height and width.
height_painting = painting.shape[0]
width_painting = painting.shape[1]

In [None]:
def positional_encoding(x, num_frequencies=6, incl_input=True):
    """
    Apply positional encoding to the input.

    For every frequency index k = 0, ..., num_frequencies - 1 the input is
    mapped to sin(2^k * pi * x) and cos(2^k * pi * x); the results are
    concatenated along the last dimension in the order
    [x (optional), sin_0, cos_0, sin_1, cos_1, ...].

    Args:
    x (torch.Tensor): Input tensor to be positionally encoded.
      The dimension of x is [N, D], where N is the number of input coordinates,
      and D is the dimension of the input coordinate.
    num_frequencies (optional, int): The number of frequencies used in
     the positional encoding (default: 6).
    incl_input (optional, bool): If True, concatenate the input with the
        computed positional encoding (default: True).

    Returns:
    (torch.Tensor): Positional encoding of the input tensor, of shape
      [N, D * (2 * num_frequencies + 1)] when incl_input is True and
      [N, D * 2 * num_frequencies] otherwise.
    """

    results = [x] if incl_input else []
    #############################  TODO 1(a) BEGIN  ############################
    # encode input tensor and append the encoded tensor to the list of results.
    for k in range(num_frequencies):
        # Scale by a Python float so pi keeps full precision for float64
        # inputs (a float32 tensor holding pi would round it first).
        angle = (2 ** k) * np.pi * x
        results.append(torch.sin(angle))
        results.append(torch.cos(angle))

    #############################  TODO 1(a) END  ##############################
    return torch.cat(results, dim=-1)


class model_2d(nn.Module):

    """
    Define a 2D model comprising of three fully connected layers,
    two relu activations and one sigmoid activation.

    The network maps a positionally encoded 2D coordinate to three values
    in (0, 1).

    Args:
    filter_size (optional, int): Width of the two hidden layers (default: 128).
    num_frequencies (optional, int): Number of frequencies of the positional
      encoding applied to the input coordinates (default: 6). It fixes the
      expected input width to 2 + 2 * 2 * num_frequencies, i.e. the raw
      coordinate is assumed to be part of the encoding.
    """

    def __init__(self, filter_size=128, num_frequencies=6):
        super().__init__()
        #############################  TODO 1(b) BEGIN  ############################
        # 2 raw coordinates + (sin, cos) per coordinate and per frequency.
        input_dimension = 2 + 2 * num_frequencies * 2
        self.layer1 = nn.Linear(input_dimension, filter_size)
        self.layer2 = nn.Linear(filter_size, filter_size)
        self.layer3 = nn.Linear(filter_size, 3)

        #############################  TODO 1(b) END  ##############################

    def forward(self, x):
        """
        Run the network on a batch of encoded coordinates.

        Args:
        x (torch.Tensor): Tensor whose first dimension is the batch size; all
          remaining dimensions are flattened into the feature dimension.

        Returns:
        (torch.Tensor): Tensor of shape [N, 3] with values in (0, 1).
        """
        #############################  TODO 1(b) BEGIN  ############################

        # Flatten everything but the batch dimension.
        x = x.view(x.shape[0], -1)
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = torch.sigmoid(self.layer3(x))

        #############################  TODO 1(b) END  ##############################
        return x


def train_2d_model(test_img, num_frequencies, device, model=model_2d, positional_encoding=positional_encoding,
                   show=True):
    """
    Fit a small MLP to a single image by regressing colour from pixel position.

    Pixel coordinates are normalized to [0, 1], positionally encoded with
    `num_frequencies` frequencies and fed to the network; the network is
    trained with Adam on the mean-squared error against `test_img`.

    Args:
    test_img (torch.Tensor): Target image; its first two dimensions are read
      as height and width, and the prediction is reshaped to
      (height, width, 3) before computing the loss.
    num_frequencies (int): Number of positional-encoding frequencies, used both
      to build the model and to encode the coordinates.
    device (torch.device): Device on which the model and data are placed.
    model (optional): Callable that builds the network; it is called as
      model(num_frequencies=num_frequencies).
    positional_encoding (optional): Encoding function called as
      positional_encoding(coords, num_frequencies=..., incl_input=True).
    show (optional, bool): If True, print statistics and plot the prediction
      every 2000 iterations.

    Returns:
    (torch.Tensor): The last predicted image, detached and moved to the CPU.
    """
    # Optimizer parameters
    lr = 5e-4
    iterations = 10000
    height, width = test_img.shape[:2]

    # Number of iters after which stats are displayed
    display = 2000

    # Define the model and initialize its weights.
    model2d = model(num_frequencies=num_frequencies)
    model2d.to(device)

    def weights_init(m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)

    model2d.apply(weights_init)

    #############################  TODO 1(c) BEGIN  ############################
    # Define the optimizer
    optimizer = optim.Adam(model2d.parameters(), lr=lr)

    #############################  TODO 1(c) END  ############################

    # Seed RNG, for repeatability
    seed = 5670
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Lists to log metrics etc.
    psnrs = []
    iternums = []

    t = time.time()
    t0 = time.time()

    #############################  TODO 1(c) BEGIN  ############################
    # Create the 2D normalized coordinates, and apply positional encoding to them.
    # Rows are ordered row-major (index = row * width + col) and hold (x, y),
    # so that the flat prediction can be reshaped to (height, width, 3).
    xs = torch.linspace(0, 1, width)
    ys = torch.linspace(0, 1, height)
    coords = torch.stack([xs.repeat(height), ys.repeat_interleave(width)], dim=-1)

    # Encode with the requested number of frequencies so the encoding width
    # matches the input width of the model built above.
    PE_coordinates = positional_encoding(coords, num_frequencies=num_frequencies, incl_input=True)
    PE_coordinates = PE_coordinates.to(device)

    # Make sure the target lives on the same device as the prediction.
    target = test_img.to(device)

    #############################  TODO 1(c) END  ############################

    for i in range(iterations + 1):
        optimizer.zero_grad()
        #############################  TODO 1(c) BEGIN  ############################
        # Run one iteration
        pred = model2d(PE_coordinates).view(height, width, 3)

        # Compute mean-squared error between the predicted and target images. Backprop!
        loss = F.mse_loss(pred, target)

        loss.backward()
        optimizer.step()

        #############################  TODO 1(c) END  ############################

        # Display images/plots/stats
        if i % display == 0 and show:
            #############################  TODO 1(c) BEGIN  ############################
            # Calculate psnr (peak value is 1 because the target is divided by 255 when loaded)
            psnr = 10 * torch.log10(1/loss)
            #############################  TODO 1(c) END  ############################

            print("Iteration %d " % i, "Loss: %.4f " % loss.item(), "PSNR: %.2f" % psnr.item(), \
                  "Time: %.2f secs per iter" % ((time.time() - t) / display), "%.2f secs in total" % (time.time() - t0))
            t = time.time()

            psnrs.append(psnr.item())
            iternums.append(i)

            plt.figure(figsize=(13, 4))
            plt.subplot(131)
            plt.imshow(pred.detach().cpu().numpy())
            plt.title(f"Iteration {i}")
            plt.subplot(132)
            plt.imshow(test_img.cpu().numpy())
            plt.title("Target image")
            plt.subplot(133)
            plt.plot(iternums, psnrs)
            plt.title("PSNR")
            plt.show()

    print('Done!')
    return pred.detach().cpu()

In [None]:
# Fit the painting using 6 positional-encoding frequencies; the returned prediction is discarded.
_ = train_2d_model(test_img=painting, num_frequencies=6, device=device)

1.1 Complete the function positional_encoding()

1.2 Complete the class model_2d() that will be used to fit the 2D image.


You need to complete 1.1 and 1.2 first before completing the train_2d_model function. Don't forget to transfer the completed functions from 1.1 and 1.2 to the part1.py file and upload it to the autograder.

Fill the gaps in the train_2d_model() function to train the model to fit the 2D image.

Train the model to fit the given image without applying positional encoding to the input, and by applying positional encoding of two different frequencies to the input; L = 2 and L = 6.

### Part 2: Fitting a 3D Image

In [None]:
import os
import gdown
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import time

# Select the compute device: the first CUDA device when torch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Download the dataset archive to the working directory as lego_data.npz.
url = "https://drive.google.com/file/d/15W2EK8LooxTMfD0v5vo2BnBMse5ZzlVj/view?usp=share_link"
gdown.download(url=url, output='lego_data.npz', quiet=False, fuzzy=True)

Here, we load the data, which comprises the images, the rotation R and translation T of each camera pose with respect to the world coordinates, and the intrinsic parameters K of the camera.

In [None]:
# Load input images, poses, and intrinsics
data = np.load("lego_data.npz")

# Images
images = data["images"]

# Height and width of each image (axes 1 and 2 of the image stack)
height, width = images.shape[1:3]

# Camera extrinsics (poses)
poses = data["poses"]
poses = torch.from_numpy(poses).to(device)

# Camera intrinsics
intrinsics = data["intrinsics"]
intrinsics = torch.from_numpy(intrinsics).to(device)

# Hold one image out (for test). Keep only the first three channels, exactly
# as done for the training images below, so the held-out target has the same
# channel count as the training targets.
test_image, test_pose = images[101, ..., :3], poses[101]
test_image = torch.from_numpy(test_image).to(device)

# Map images to device (first 100 views, first three channels)
images = torch.from_numpy(images[:100, ..., :3]).to(device)

plt.imshow(test_image.detach().cpu().numpy())
plt.show()

2.1 Complete the following function that calculates the rays that pass through all the pixels of an HxW image

In [None]:

def get_rays(height, width, intrinsics, Rcw, Tcw):  # Rwc, Twc

    """
    Compute the origin and direction of rays passing through all pixels of an image (one ray per pixel).

    The direction of the ray through the pixel in row y and column x is
    Rcw @ K^-1 @ [x, y, 1]^T; directions are not normalized.

    Args:
    height: the height of an image.
    width: the width of an image.
    intrinsics: camera intrinsics matrix of shape (3, 3).
    Rcw: Rotation matrix of shape (3,3) from camera to world coordinates.
    Tcw: Translation vector of shape (3,) or (3,1); it is used as the common
      origin of all rays.

    Returns:
    ray_origins (torch.Tensor): A tensor of shape (height, width, 3) denoting the centers of
      each ray. Note that despite all rays sharing the same origin, the ray
      origin is returned for each ray as (height, width, 3).
    ray_directions (torch.Tensor): A tensor of shape (height, width, 3) denoting the
      direction of each ray.
    """

    device = intrinsics.device
    # Work in the dtype of the intrinsics so the matrix products below never mix dtypes.
    dtype = intrinsics.dtype

    #############################  TODO 2.1 BEGIN  ##########################
    # Pixel grid indexed as [row, col]: x runs along the width axis and y along
    # the height axis, which keeps the output (height, width, 3) even for
    # non-square images.
    xs = torch.arange(width, device=device, dtype=dtype).view(1, width).expand(height, width)
    ys = torch.arange(height, device=device, dtype=dtype).view(height, 1).expand(height, width)
    homogenous_pixels = torch.stack([xs, ys, torch.ones_like(xs)], dim=-1)

    # Back-project pixels to camera-frame directions, then rotate to the world frame.
    # Multiplying row vectors by M^T is the same as applying M to each vector.
    k_inverse = torch.inverse(intrinsics)
    camera_directions = homogenous_pixels @ k_inverse.T
    ray_directions = camera_directions @ Rcw.to(device=device, dtype=dtype).T

    # Every ray starts at the same point; reshape accepts both (3,) and (3, 1).
    origin = Tcw.to(device=device, dtype=dtype).reshape(1, 1, 3)
    ray_origins = origin.repeat(height, width, 1)

    #############################  TODO 2.1 END  ############################
    return ray_origins, ray_directions


def stratified_sampling(ray_origins, ray_directions, near, far, samples):
    """
    Sample 3D points on the given rays. The near and far variables indicate the bounds of sampling range.

    The range [near, far) is split into `samples` equal bins and the depth at
    the start of each bin is used: t_i = near + (i - 1) / samples * (far - near)
    for i = 1, ..., samples. No random jitter is applied, so every ray gets
    the same depths.

    Args:
    ray_origins: Origin of each ray in the "bundle" as returned by the
      get_rays() function. Shape: (height, width, 3).
    ray_directions: Direction of each ray in the "bundle" as returned by the
      get_rays() function. Shape: (height, width, 3).
    near: The 'near' extent of the bounding volume.
    far:  The 'far' extent of the bounding volume.
    samples: Number of samples to be drawn along each ray.

    Returns:
    ray_points: Query 3D points along each ray. Shape: (height, width, samples, 3).
    depth_points: Sampled depth values along each ray. Shape: (height, width, samples).
    """

    #############################  TODO 2.2 BEGIN  ############################

    height, width, _ = ray_origins.shape

    # Zero-based bin index 0, ..., samples - 1. This already equals (i - 1) for
    # the one-based i of the formula, so the first sample sits exactly at `near`.
    bin_index = torch.arange(samples, dtype=torch.float32, device=ray_origins.device)
    depths = near + bin_index * (far - near) / samples

    # Tile the depths to match the shape (height, width, samples).
    depth_points = depths.view(1, 1, samples).repeat(height, width, 1)

    # Points along each ray: origin + t * direction, broadcast over the sample axis.
    ray_points = ray_origins.unsqueeze(2) + ray_directions.unsqueeze(2) * depth_points.unsqueeze(-1)

    #############################  TODO 2.2 END  ############################
    return ray_points, depth_points


class nerf_model(nn.Module):
    """
    Fully connected network that maps an encoded 3D position and an encoded
    viewing direction to a colour and a density value.

    The encoded position passes through five linear+ReLU layers, is
    concatenated with the network's own input (skip connection) and passes
    through three more linear+ReLU layers. The density is a linear read-out of
    that trunk (no activation is applied here). The colour is obtained from a
    linear feature layer concatenated with the encoded direction, followed by
    one linear+ReLU layer and a final linear layer with a sigmoid.

    Args:
    filter_size (optional, int): Width of the trunk layers (default: 256).
    num_x_frequencies (optional, int): Encoding frequencies of the position;
      the expected position width is 3 + 3 * 2 * num_x_frequencies (default: 6).
    num_d_frequencies (optional, int): Encoding frequencies of the direction;
      the expected direction width is 3 + 3 * 2 * num_d_frequencies (default: 3).
    """

    def __init__(self, filter_size=256, num_x_frequencies=6, num_d_frequencies=3):
        super().__init__()
        #############################  TODO 2.3 BEGIN  ############################

        # Widths of the encoded inputs: raw 3-vector plus (sin, cos) per axis and frequency.
        position_width = 3 * num_x_frequencies * 2 + 3
        direction_width = 3 * num_d_frequencies * 2 + 3
        half_width = filter_size // 2

        # Trunk before the skip connection.
        self.input_layer = nn.Linear(position_width, filter_size)
        self.layers2 = nn.Linear(filter_size, filter_size)
        self.layers3 = nn.Linear(filter_size, filter_size)
        self.layers4 = nn.Linear(filter_size, filter_size)
        self.layers5 = nn.Linear(filter_size, filter_size)
        # Trunk after the skip connection (first layer also sees the encoded position).
        self.layers6 = nn.Linear(filter_size + position_width, filter_size)
        self.layers7 = nn.Linear(filter_size, filter_size)
        self.layers8 = nn.Linear(filter_size, filter_size)
        # Output heads.
        self.sigma_layer = nn.Linear(filter_size, 1)
        self.feature_layer = nn.Linear(filter_size, filter_size)
        self.direction_layer = nn.Linear(filter_size + direction_width, half_width)
        self.Output_layer = nn.Linear(half_width, 3)

        #############################  TODO 2.3 END  ############################

    def forward(self, x, d):
        """
        Args:
        x (torch.Tensor): Encoded positions, last dimension of width
          3 + 3 * 2 * num_x_frequencies.
        d (torch.Tensor): Encoded directions, last dimension of width
          3 + 3 * 2 * num_d_frequencies.

        Returns:
        rgb (torch.Tensor): Colours in (0, 1), last dimension of size 3.
        sigma (torch.Tensor): Raw density values, last dimension of size 1.
        """
        #############################  TODO 2.3 BEGIN  ############################

        hidden = x
        for layer in (self.input_layer, self.layers2, self.layers3, self.layers4, self.layers5):
            hidden = F.relu(layer(hidden))

        # Skip connection: re-inject the encoded position.
        hidden = torch.cat([hidden, x], dim=-1)
        for layer in (self.layers6, self.layers7, self.layers8):
            hidden = F.relu(layer(hidden))

        sigma = self.sigma_layer(hidden)

        # Colour head: trunk features joined with the viewing direction.
        features = self.feature_layer(hidden)
        colour_hidden = F.relu(self.direction_layer(torch.cat([features, d], dim=-1)))
        rgb = torch.sigmoid(self.Output_layer(colour_hidden))

        return rgb, sigma


def get_batches(ray_points, ray_directions, num_x_frequencies, num_d_frequencies):
    """
    Prepare the network inputs for all sampled points and split them into chunks.

    The ray directions are scaled to unit length and repeated once per sample,
    points and directions are flattened to rows of 3 values, positionally
    encoded (raw input included), and cut into chunks of at most 2**15 rows.

    Args:
    ray_points: Sampled 3D points; the third dimension is read as the number
      of samples per ray and the tensor is viewed as rows of 3 values.
    ray_directions: One direction per ray, last dimension of size 3.
    num_x_frequencies: Number of encoding frequencies for the points.
    num_d_frequencies: Number of encoding frequencies for the directions.

    Returns:
    ray_points_batches: List of encoded point chunks.
    ray_directions_batches: List of encoded direction chunks, aligned row by
      row with ray_points_batches.
    """
    rows_per_chunk = 2 ** 15

    def encode(values, num_frequencies):
        # Raw values followed by sin/cos pairs at frequencies 2^k * pi.
        pieces = [values]
        for k in range(num_frequencies):
            pieces.append(torch.sin((2 ** k) * torch.pi * values))
            pieces.append(torch.cos((2 ** k) * torch.pi * values))
        return torch.cat(pieces, dim=-1)

    def split_rows(rows):
        chunks = []
        for start in range(0, rows.shape[0], rows_per_chunk):
            chunks.append(rows[start:start + rows_per_chunk])
        return chunks

    # Unit-length directions, one copy per sample along the ray.
    samples_per_ray = ray_points.shape[2]
    unit_directions = ray_directions / torch.norm(ray_directions, dim=-1, keepdim=True)
    per_sample_directions = unit_directions.unsqueeze(2).repeat(1, 1, samples_per_ray, 1)

    # Flatten to rows of 3 values and encode.
    encoded_points = encode(ray_points.view(-1, 3), num_x_frequencies)
    encoded_directions = encode(per_sample_directions.view(-1, 3), num_d_frequencies)

    return split_rows(encoded_points), split_rows(encoded_directions)


def volumetric_rendering(rgb, sigma, depth_points):
    """
    Differentiably renders a radiance field, given the origin of each ray in the
    "bundle", and the sampled depth values along them.

    For each ray with samples i = 1, ..., N the pixel colour is
    sum_i T_i * (1 - exp(-sigma_i * delta_i)) * rgb_i, where delta_i is the
    distance to the next sample (1e9 for the last one) and
    T_i = prod_{j < i} exp(-sigma_j * delta_j) is the transmittance, with
    T_1 = 1. Negative densities are clamped to zero with a ReLU.

    Args:
    rgb: RGB color at each query location (X, Y, Z). Shape: (height, width, samples, 3).
    sigma: Volume density at each query location (X, Y, Z). Shape: (height, width, samples).
    depth_points: Sampled depth values along each ray. Shape: (height, width, samples).

    Returns:
    rec_image: The reconstructed image after applying the volumetric rendering to every pixel.
    Shape: (height, width, 3)
    """
    depth_points = depth_points.to(rgb.device)
    # Clamp densities and align their shape with the depths.
    density = F.relu(sigma).reshape(depth_points.shape)

    # Distance between consecutive samples; the last sample gets a very large
    # interval so that any positive density there absorbs the remaining light.
    gaps = depth_points[..., 1:] - depth_points[..., :-1]
    last_gap = torch.full_like(depth_points[..., :1], 1e9)
    delta = torch.cat([gaps, last_gap], dim=-1)

    # Fraction of light surviving each interval, and the matching opacity.
    survival = torch.exp(-density * delta)
    alpha = 1.0 - survival

    # Exclusive cumulative product: light reaching sample i has only passed the
    # intervals before it, so the first sample sees transmittance 1. (A cyclic
    # shift would instead put the product over *all* intervals in front.)
    inclusive = torch.cumprod(survival, dim=-1)
    transmittance = torch.cat([torch.ones_like(inclusive[..., :1]), inclusive[..., :-1]], dim=-1)

    # Compositing weights, then the weighted sum of colours along each ray.
    weights = transmittance * alpha
    rec_image = (weights[..., None] * rgb).sum(dim=-2)

    return rec_image

def one_forward_pass(height, width, intrinsics, pose, near, far, samples, model, num_x_frequencies, num_d_frequencies):
    """
    Render one image from the given camera pose with the given model.

    Args:
    height: the height of the rendered image.
    width: the width of the rendered image.
    intrinsics: camera intrinsics matrix, passed on to get_rays().
    pose: camera pose; pose[:3, :3] is used as the rotation and pose[:3, -1]
      as the translation handed to get_rays().
    near: The 'near' extent of the sampling range.
    far: The 'far' extent of the sampling range.
    samples: Number of samples drawn along each ray.
    model: Network called as model(encoded_points, encoded_directions) and
      returning (rgb, sigma) for every row.
    num_x_frequencies: Number of encoding frequencies for the points.
    num_d_frequencies: Number of encoding frequencies for the directions.

    Returns:
    rec_image: The output of volumetric_rendering() for this view.
    """
    #############################  TODO 2.5 BEGIN  ############################

    # Work on the device of the intrinsics instead of relying on a
    # module-level `device` variable being defined.
    device = intrinsics.device

    # compute all the rays from the image
    R_cw = pose[:3, :3].to(device)
    T_cw = pose[:3, -1].to(device)
    ray_origins, ray_directions = get_rays(height, width, intrinsics, R_cw, T_cw)
    ray_origins = ray_origins.to(device)
    ray_directions = ray_directions.to(device)

    # sample the points from the rays
    ray_points, depth_points = stratified_sampling(ray_origins, ray_directions, near, far, samples)

    # divide data into batches to avoid memory errors
    point_batches, direction_batches = get_batches(ray_points, ray_directions, num_x_frequencies, num_d_frequencies)

    # forward pass the batches and concatenate the outputs at the end
    rgb_chunks = []
    sigma_chunks = []
    for points, directions in zip(point_batches, direction_batches):
        rgb_chunk, sigma_chunk = model(points.float(), directions.float())
        rgb_chunks.append(rgb_chunk)
        sigma_chunks.append(sigma_chunk)

    rgb = torch.cat(rgb_chunks).reshape((height, width, samples, 3))
    sigma = torch.cat(sigma_chunks).reshape((height, width, samples))

    # Apply volumetric rendering to obtain the reconstructed image
    rec_image = volumetric_rendering(rgb, sigma, depth_points)

    #############################  TODO 2.5 END  ############################

    return rec_image


Complete the next function to visualize how the dataset was created. You will be able to see from which point of view each image of the 3D object was captured. What we want to achieve here is to be able to interpolate between these given views and synthesize new, realistic views of the 3D object.

2.2 Complete the following function to implement the sampling of points along a given ray.

2.3 Define the network architecture of NeRF, along with a function that divides the data into chunks to avoid running out of memory during training.

2.4 Compute the compositing weights of samples on camera ray and then complete the volumetric rendering procedure to reconstruct a whole RGB image from the sampled points and the outputs of the neural network.

2.5 Combine everything together. Given the pose of a camera, compute the camera rays and sample the 3D points along these rays. Divide those points into batches and feed them to the neural network. Concatenate the network outputs and use them for the volumetric rendering to reconstruct the final image.

If you manage to pass the autograder for all the previous functions, then it is time to train a NeRF! We provide the hyperparameters for you, we initialize the NeRF model and its weights, and we define a couple lists that will be needed to store results.

In [None]:
# Hyperparameters of the NeRF training run.
num_x_frequencies = 10   # positional-encoding frequencies for the 3D points
num_d_frequencies = 4    # positional-encoding frequencies for the ray directions
learning_rate  = 5e-4    # learning rate handed to Adam below
iterations = 3000        # number of training iterations
samples = 64             # samples drawn along each ray
display = 25             # evaluate and plot every `display` iterations
near = 0.667             # near bound of the sampling range
far = 2                  # far bound of the sampling range

# Build the network with matching encoding sizes and move it to the device.
model = nerf_model(num_x_frequencies=num_x_frequencies,num_d_frequencies=num_d_frequencies).to(device)

# Re-initialize the weight matrix of every linear layer with Xavier-uniform values
# (biases are left at whatever nn.Linear set them to).
def weights_init(m):
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
model.apply(weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Logged PSNR values and the iterations at which they were recorded.
psnrs = []
iternums = []

# `t` is reset after each report in the training loop; `t0` marks the overall start.
t = time.time()
t0 = time.time()

In [None]:
# Train the NeRF: each iteration renders one randomly chosen training view and
# takes one Adam step on the mean-squared error against the real image.
for i in range(iterations+1):

    #############################  TODO 2.6 BEGIN  ############################
    # Choose a random training picture for the forward pass. The upper bound is
    # taken from the training set itself rather than a hard-coded count.
    idx = torch.randint(low=0, high=images.shape[0], size=(1,)).item()
    pose = poses[idx].float()
    img = images[idx].float()
    height, width = img.shape[:2]
    test_recw = one_forward_pass(height, width, intrinsics, pose, near, far, samples, model, num_x_frequencies, num_d_frequencies)
    loss = F.mse_loss(test_recw, img)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    #############################  TODO 2.6 END  ############################

    # Display images/plots/stats
    if i % display == 0:
        with torch.no_grad():
            #############################  TODO 2.6 BEGIN  ############################
            # Render the held-out view. Pose and target are cast to float32
            # exactly like the training pose and image above.
            test_rec_image = one_forward_pass(height, width, intrinsics, test_pose.float(), near, far, samples, model, num_x_frequencies, num_d_frequencies)

            # Calculate the loss and the psnr between the original test image
            # and the reconstructed one (peak value taken as 1).
            test_loss = F.mse_loss(test_rec_image, test_image.float())
            psnr = 10*torch.log10(1/test_loss)
            #############################  TODO 2.6 END  ############################

        print("Iteration %d " % i, "Loss: %.4f " % loss.item(), "PSNR: %.2f " % psnr.item(), \
                "Time: %.2f secs per iter, " % ((time.time() - t) / display), "%.2f mins in total" % ((time.time() - t0)/60))

        t = time.time()
        psnrs.append(psnr.item())
        iternums.append(i)

        # Three panels: current rendering, target, and PSNR history.
        plt.figure(figsize=(16, 4))
        plt.subplot(131)
        plt.imshow(test_rec_image.detach().cpu().numpy())
        plt.title(f"Iteration {i}")
        plt.subplot(132)
        plt.imshow(test_image.detach().cpu().numpy())
        plt.title("Target image")
        plt.subplot(133)
        plt.plot(iternums, psnrs)
        plt.title("PSNR")
        plt.show()

print('Done!')