In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import skimage.io as skio
import torch.optim as optim
import skimage as sk
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
data = np.load(f"lego_200x200.npz")
images_train = data["images_train"] / 255.0
c2ws_train = data["c2ws_train"]
images_val = data["images_val"] / 255.0
c2ws_val = data["c2ws_val"]
c2ws_test = data["c2ws_test"]
focal = data["focal"]

In [4]:
height = 200
width = 200
n_samples = 32

In [5]:
K = np.array([[focal,0,width/2],[0,focal,height/2],[0,0,1]])

In [6]:
def transform(c2w, x_c):
    #camera to world
    num_rows = len(x_c)
    ones_column = np.ones((num_rows, 1))
    x_c_with_one = np.concatenate((x_c, ones_column), axis=1)
    x = (c2w @ x_c_with_one.T).T
    return x[:,:-1]

In [7]:
def pixel_to_camera(K, uv,s):
    num_rows = len(uv)
    ones_column = np.ones((num_rows, 1))
    uv_with_one = np.concatenate((uv, ones_column), axis=1)
    result = (np.linalg.inv(K) @ uv_with_one.T).T
    return result

In [8]:
def pixel_to_ray(K, c2w, uv):
    zeros = np.array([[0,0,0]])
    origin = transform(c2w, zeros)
    depth_1_points = pixel_to_camera(K, uv, 1)
    world_depth_1_points = transform(c2w, depth_1_points)
    world_depth_1_points_direction = world_depth_1_points - origin
    norms = np.linalg.norm(world_depth_1_points_direction, axis=1, keepdims=True)
    directions = world_depth_1_points_direction/ norms

    return origin, directions

In [9]:
class RaysData(Dataset):
    def __init__(self, img_train, K, c2ws_train):
        self.img = img_train
        self.c2ws = c2ws_train
        self.K = K
        self.height = 200
        self.width = 200
        self.length = len(self.img) * self.height * self.width

    def __len__(self):
        return len(self.img) * self.height * self.width
    
    def __getitem__(self, idx):
        img_index = idx // (self.width*self.height)
        residual = idx % (self.width*self.height)
        temp_height = residual // self.height 
        temp_width = residual % self.width 
        c2w = self.c2ws[img_index]
        uv = np.array([[temp_height+0.5, temp_width+0.5]])
        ray_o, ray_d = pixel_to_ray(self.K, c2w, uv)
        pixel = self.img[img_index,temp_height,temp_width,:]
        sample = {'rays_o':ray_o[0],
                  "rays_d":ray_d[0],
                 "pixels":pixel}
        return sample

    def sample_rays(self, num_samples):
        rays_o = []
        rays_d = []
        pixels = []
        # random_numbers = [random.randint(0,self.length -1) for _ in range(num_samples)]
        random_numbers = np.random.randint(0, self.length - 1, size=num_samples) 
        for random_number in random_numbers:
            img_index = random_number // (self.width*self.height)
            residual = random_number % (self.width*self.height)
            temp_height = residual // self.height 
            temp_width = residual % self.width 
            c2w = self.c2ws[img_index]
            uv = np.array([[temp_height+0.5, temp_width+0.5]])
            ray_o, ray_d = pixel_to_ray(self.K, c2w, uv)
            rays_o.append(ray_o[0])
            rays_d.append(ray_d[0])
            pixels.append(self.img[img_index,temp_height,temp_width,:])
        return rays_o, rays_d,pixels

In [10]:
def sample_along_rays(rays_o, rays_d, n_samples=32, perturb=True):
    far = 6 
    near = 2
    
    t_values = torch.linspace(near, far, n_samples).to(device)
    ran_values = (torch.rand((len(rays_o), n_samples)) * (far - near) / n_samples).to(device)

    # Create 3D grid for rays_o and rays_d
    ray_o_grid = rays_o[:, None, :].repeat(1, n_samples, 1)
    ray_d_grid = rays_d[:, None, :].repeat(1, n_samples, 1)

    # Compute points without explicit loops
    p_t = t_values + ran_values
    p_t = p_t.to(device)
    points = ray_o_grid + ray_d_grid * p_t.unsqueeze(2)

    return points.view(-1, 3)

In [11]:
def volrend(sigmas, rgbs, step_size):
    sigmas = sigmas.to(device)
    rgbs = rgbs.to(device)
    size_to_prepend = (sigmas.size(0), 1, 1)

    zeros_to_prepend = torch.zeros(size_to_prepend, dtype=sigmas.dtype).to(device)
    
    tensor_with_zeros = torch.cat((zeros_to_prepend, sigmas), dim=1).to(device)

    
    
    cum_sigmas = torch.cumsum(tensor_with_zeros,dim=1)[:,:-1].to(device)
    T = torch.exp(-cum_sigmas*step_size).to(device)
    interval_sigmas = 1 - torch.exp(-sigmas*step_size).to(device)
    weights = T * interval_sigmas
    colors = rgbs * weights
    cum_colors = torch.sum(colors, dim=1).to(device)

    return cum_colors

In [39]:
class Residual_block(nn.Module):
    def __init__(self, dim):
        super(Residual_block, self).__init__()
        self.dim = dim
        self.linear_1 = nn.Linear(dim, dim)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=dim)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(dim, dim)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=dim)
        nn.init.kaiming_normal_(self.linear_1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.linear_2.weight, mode='fan_in', nonlinearity='relu')

    def forward(self,x):
        origin_x = x
        x = self.linear_1(x)
        x = self.layer_norm_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.layer_norm_2(x)
        x = self.relu(x + origin_x)
        return x

In [49]:
class Nerf_model(nn.Module):
    def __init__(self,high_fre_level, high_fre_level_angle, hidden_dim):
        super(Nerf_model, self).__init__()
        self.high_fre_level = high_fre_level
        self.high_fre_level_angle = high_fre_level_angle
        self.pe_dim = 3+high_fre_level*6
        self.pe_dim_angle = 3 + 6 * high_fre_level_angle
        self.input_layer = nn.Linear(3+high_fre_level*6, hidden_dim)
        self.residual_block_1 = Residual_block(hidden_dim)
        self.residual_block_2 = Residual_block(hidden_dim)
        self.residual_block_3 = Residual_block(hidden_dim)
        
        self.hidden_layer_1 = nn.Linear(hidden_dim, hidden_dim)
        self.hidden_layer_2 = nn.Linear(hidden_dim, hidden_dim)
        
        self.hidden_layer_concat_angle = nn.Linear(hidden_dim + self.pe_dim_angle, hidden_dim//2)

        nn.init.kaiming_normal_(self.hidden_layer_concat_angle.weight, mode='fan_in', nonlinearity='relu')
        
        self.out = nn.Linear(hidden_dim//2, 3)
        nn.init.kaiming_normal_(self.out.weight, mode='fan_in', nonlinearity='relu')
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.density_layer = nn.Linear(hidden_dim, 1)

        power_terms_pos = (torch.exp2(torch.arange(0, high_fre_level))*3.14159).to(device)
        
        power_terms_angle = (torch.exp2(torch.arange(0, high_fre_level_angle))*3.14159).to(device)
        
        self.power_matrix_pos= torch.zeros(3, self.high_fre_level *3).to(device)
        for i in range(3):
            self.power_matrix_pos[i,i*high_fre_level:(i+1)*high_fre_level] = power_terms_pos

        self.power_matrix_angle = torch.zeros(3, self.high_fre_level_angle  *3).to(device)
        for i in range(3):
            self.power_matrix_angle[i,i*high_fre_level_angle :(i+1)*high_fre_level_angle ] = power_terms_angle


        self.middle_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim) 
        
        self.last_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim//2) 

    def positional_encoding(self, data, high_fre_level, power_matrix):

        powered_data = data @ power_matrix
        sin_matrix = torch.sin(powered_data).to(device)
        cos_matrix = torch.cos(powered_data).to(device)

        pe = torch.cat((data, sin_matrix, cos_matrix),1).to(device)

        return pe

    def forward(self,pos, angle):
                
        pos_pe = self.positional_encoding(pos, self.high_fre_level, self.power_matrix_pos)
        
        origin_x = pos_pe
        x = self.input_layer(origin_x)
        x = self.residual_block_1(x)
        x = self.residual_block_2(x)
        x = self.residual_block_3(x)
        x = self.hidden_layer_1(x)
        sigmas = self.density_layer(x)
        sigmas = self.relu(sigmas)

        # x = self.middle_layer_norm(x)
        x = self.hidden_layer_2(x)

        angle_pe = self.positional_encoding(angle, self.high_fre_level_angle, self.power_matrix_angle)
        # angle_input = torch.cat((angle,angle_pe),dim=1)
        concated_x = torch.cat((x,angle_pe), dim = 1).float()
        x = self.hidden_layer_concat_angle(concated_x)
        x = self.last_layer_norm(x)
        x = self.relu(x)
        x = self.out(x)
        x = self.sigmoid(x)
        return x, sigmas

In [50]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [14]:
# class Nerf_model(nn.Module):
#     def __init__(self,high_fre_level, high_fre_level_angle, hidden_dim):
#         super(Nerf_model, self).__init__()
#         self.high_fre_level = high_fre_level
#         self.high_fre_level_angle = high_fre_level_angle
#         self.pe_dim = 3+high_fre_level*6
#         self.pe_dim_angle = 3 + 6 * high_fre_level_angle
#         self.input_layer = nn.Linear(3+high_fre_level*6, hidden_dim)
#         # self.input_layer = nn.Linear(2, hidden_dim)
#         hidden_layer_list = []


        
#         for i in range(3):
#             hidden_layer_list.append(nn.Linear(hidden_dim, hidden_dim))
#             hidden_layer_list.append(nn.ReLU())
#         self.hidden_layer_1 = nn.Sequential(*hidden_layer_list)

#         self.concat_hidden_layer = nn.Linear(hidden_dim + self.pe_dim,hidden_dim)
        
#         hidden_layer_list = []
#         for i in range(2):
#             hidden_layer_list.append(nn.Linear(hidden_dim, hidden_dim))
#             hidden_layer_list.append(nn.ReLU())
#         self.hidden_layer_2 = nn.Sequential(*hidden_layer_list)

#         self.hidden_layer_3 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_4 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_5 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_concat_angle = nn.Linear(hidden_dim + self.pe_dim_angle, hidden_dim//2)

#         self.out = nn.Linear(hidden_dim//2, 3)
#         self.relu = nn.ReLU()
#         self.sigmoid = nn.Sigmoid()
#         self.density_layer = nn.Linear(hidden_dim, 1)
        
#         power_terms_pos = (torch.exp2(torch.arange(0, high_fre_level))*3.14159).to(device)
        
#         power_terms_angle = (torch.exp2(torch.arange(0, high_fre_level_angle))*3.14159).to(device)
        
#         self.power_matrix_pos= torch.zeros(3, self.high_fre_level *3).to(device)
#         for i in range(3):
#             self.power_matrix_pos[i,i*high_fre_level:(i+1)*high_fre_level] = power_terms_pos

#         self.power_matrix_angle = torch.zeros(3, self.high_fre_level_angle  *3).to(device)
#         for i in range(3):
#             self.power_matrix_angle[i,i*high_fre_level_angle :(i+1)*high_fre_level_angle ] = power_terms_angle

#         # self.layer_norm1 = nn.LayerNorm(normalized_shape=HIDDEN_UNITS)  
#         # self._initialize_weights()
#         self.last_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim//2) 
        

#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.zeros_(m.bias)


    


#     def positional_encoding(self, data, high_fre_level, power_matrix):

#         powered_data = data @ power_matrix
#         sin_matrix = torch.sin(powered_data).to(device)
#         cos_matrix = torch.cos(powered_data).to(device)

#         pe = torch.cat((data, sin_matrix, cos_matrix),1).to(device)

#         return pe




    
#     def forward_phase_1(self, origin_x):

#         x = self.input_layer(origin_x)
#         x = self.relu(x)
#         x = self.hidden_layer_1(x)
#         x = torch.cat((x,origin_x), dim = 1)
#         x = self.concat_hidden_layer(x)
#         x = self.relu(x)
#         x = self.hidden_layer_2(x)
#         x = self.hidden_layer_3(x)
#         return x


        
#     def forward(self, pos, angle):
        
#         pos_pe = self.positional_encoding(pos, self.high_fre_level, self.power_matrix_pos)
        
#         origin_x = pos_pe
#         x = self.forward_phase_1(origin_x)




        
#         sigmas = self.density_layer(x)
#         sigmas = self.relu(sigmas)
        
#         x = self.hidden_layer_4(x)
#         angle_pe = self.positional_encoding(angle, self.high_fre_level_angle, self.power_matrix_angle)
#         # angle_input = torch.cat((angle,angle_pe),dim=1)
#         concated_x = torch.cat((x,angle_pe), dim = 1).float()
#         x = self.hidden_layer_concat_angle(concated_x)
#         x = self.last_layer_norm(x)
#         x = self.relu(x)
#         x = self.out(x)
#         x = self.sigmoid(x)
#         return x, sigmas
        

In [51]:
model = Nerf_model(30,10,256).to(device)
step_size = (6-2)/n_samples

In [52]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [54]:
# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# criterion = PSNRWithMSELoss()
criterion = torch.nn.MSELoss()
number_epoch  = 2
model.train()
# scheduler = StepLR(optimizer, step_size=70, gamma=0.1)
# for i in range(number_iteration):

dataset = RaysData(images_train, K, c2ws_train)
dataloader = DataLoader(dataset, batch_size=10000,
                        shuffle=True)

for e in range(number_epoch):
    for i_batch, sample_batched in enumerate(tqdm(dataloader)):
        # print(len(sample_batched['pixels']))
        # rays_o, rays_d, pixels = dataset.sample_rays(10000)
        # rays_o, rays_d, pixels = dataset.sample_rays_one(1000)
        # t1 = time.time()
        rays_o = sample_batched['rays_o'].squeeze()
        rays_d = sample_batched['rays_d'].squeeze()
        pixels = sample_batched['pixels']
        rays_o = rays_o.float().to(device)
        rays_d = rays_d.float().to(device)
        # rays_d = np.array(rays_d)
        # rays_o = np.array(rays_o)
        # rays_d = torch.tensor(rays_d).to(device)
        # rays_o = torch.tensor(rays_o).to(device)
        points = sample_along_rays(rays_o, rays_d)
        
    
        points = points.float().to(device)
        # points = np.array(points)
        # rays_d = np.array(rays_d)
        # points = torch.tensor(points).to(device)
        # rays_d = torch.tensor(rays_d).to(device)
    
        rays_d = torch.unsqueeze(rays_d,1)
        rays_d = rays_d.repeat(1,n_samples,1)
        rays_d = rays_d.view(-1,3)
        
        # t2 = time.time()
        # print('p',t2-t1)
        # t1 = t2
        
        rgbs, sigmas = model(points, rays_d)
        
        # t2 = time.time()
        # print('forward',t2-t1)
        # t1 = t2
        
        rgbs = rgbs.to(device)
        sigmas = sigmas.to(device)
        # sigams = model.foward_dentisy()
        sigmas = sigmas.view(-1, n_samples, 1)
        rgbs = rgbs.view(-1, n_samples, 3)
        
    
    
        rendered_colors = volrend(sigmas, rgbs, step_size)

                
        # t2 = time.time()
        # print('render',t2-t1)
        # t1 = t2
        # print("render",rendered_colors[0])
        # print("pixels",pixels[0])
    
        pixels = pixels.float().to(device)
        # pixels = np.array(pixels)
        # pixels = torch.tensor(pixels).float().to(device)
    
        # print(rendered_colors)
        loss = criterion(rendered_colors, pixels)
        
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

                        
        # t2 = time.time()
        # print('backward',t2-t1)
        # t1 = t2
        # scheduler.step()
        if i_batch % 49 == 0:
            print(f'iteration [{i_batch}], Loss: {loss.item()}')
    # scheduler.step()
    # torch.save(model.state_dict(), 'weights_%s.pth'%e)

  0%|                                           | 1/400 [00:00<04:51,  1.37it/s]

iteration [0], Loss: 0.035258326679468155


 12%|█████▎                                    | 50/400 [00:23<03:18,  1.76it/s]

iteration [49], Loss: 0.03232917934656143


 25%|██████████▍                               | 99/400 [00:48<02:43,  1.84it/s]

iteration [98], Loss: 0.03199579939246178


 37%|███████████████▏                         | 148/400 [01:13<02:16,  1.84it/s]

iteration [147], Loss: 0.02739843539893627


 49%|████████████████████▏                    | 197/400 [01:38<01:48,  1.87it/s]

iteration [196], Loss: 0.025908321142196655


 62%|█████████████████████████▏               | 246/400 [02:01<01:20,  1.92it/s]

iteration [245], Loss: 0.0235576331615448


 74%|██████████████████████████████▏          | 295/400 [02:25<00:56,  1.87it/s]

iteration [294], Loss: 0.022329488769173622


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:29,  1.91it/s]

iteration [343], Loss: 0.02249150536954403


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.80it/s]

iteration [392], Loss: 0.019733626395463943


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.04it/s]
  0%|                                           | 1/400 [00:00<05:55,  1.12it/s]

iteration [0], Loss: 0.01948956958949566


 12%|█████▎                                    | 50/400 [00:24<03:05,  1.88it/s]

iteration [49], Loss: 0.01855948381125927


 25%|██████████▍                               | 99/400 [00:48<02:44,  1.83it/s]

iteration [98], Loss: 0.018138885498046875


 37%|███████████████▏                         | 148/400 [01:13<02:21,  1.78it/s]

iteration [147], Loss: 0.017201313748955727


 49%|████████████████████▏                    | 197/400 [01:36<01:47,  1.89it/s]

iteration [196], Loss: 0.01672467216849327


 62%|█████████████████████████▏               | 246/400 [02:00<01:26,  1.79it/s]

iteration [245], Loss: 0.016931967809796333


 74%|██████████████████████████████▏          | 295/400 [02:24<00:55,  1.88it/s]

iteration [294], Loss: 0.01626514084637165


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:31,  1.76it/s]

iteration [343], Loss: 0.015142186544835567


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.89it/s]

iteration [392], Loss: 0.01447840966284275


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.04it/s]


In [55]:
torch.save(model.state_dict(), 'weights_arc2.pth')

In [None]:
weights_path = 'weights_arc2.pth'  # Provide the correct path to the saved weights file
model.load_state_dict(torch.load(weights_path))