In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import skimage.io as skio
import torch.optim as optim
import skimage as sk
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
data = np.load(f"lego_200x200.npz")
images_train = data["images_train"] / 255.0
c2ws_train = data["c2ws_train"]
images_val = data["images_val"] / 255.0
c2ws_val = data["c2ws_val"]
c2ws_test = data["c2ws_test"]
focal = data["focal"]

In [4]:
height = 200
width = 200
n_samples = 32

In [5]:
K = np.array([[focal,0,width/2],[0,focal,height/2],[0,0,1]])

In [6]:
def transform(c2w, x_c):
    #camera to world
    num_rows = len(x_c)
    ones_column = np.ones((num_rows, 1))
    x_c_with_one = np.concatenate((x_c, ones_column), axis=1)
    x = (c2w @ x_c_with_one.T).T
    return x[:,:-1]

In [7]:
def pixel_to_camera(K, uv,s):
    num_rows = len(uv)
    ones_column = np.ones((num_rows, 1))
    uv_with_one = np.concatenate((uv, ones_column), axis=1)
    result = (np.linalg.inv(K) @ uv_with_one.T).T
    return result

In [8]:
def pixel_to_ray(K, c2w, uv):
    zeros = np.array([[0,0,0]])
    origin = transform(c2w, zeros)
    depth_1_points = pixel_to_camera(K, uv, 1)
    world_depth_1_points = transform(c2w, depth_1_points)
    world_depth_1_points_direction = world_depth_1_points - origin
    norms = np.linalg.norm(world_depth_1_points_direction, axis=1, keepdims=True)
    directions = world_depth_1_points_direction/ norms

    return origin, directions

In [9]:
class RaysData(Dataset):
    def __init__(self, img_train, K, c2ws_train):
        self.img = img_train
        self.c2ws = c2ws_train
        self.K = K
        self.height = 200
        self.width = 200
        self.length = len(self.img) * self.height * self.width

    def __len__(self):
        return len(self.img) * self.height * self.width
    
    def __getitem__(self, idx):
        img_index = idx // (self.width*self.height)
        residual = idx % (self.width*self.height)
        temp_height = residual // self.height 
        temp_width = residual % self.width 
        c2w = self.c2ws[img_index]
        uv = np.array([[temp_height+0.5, temp_width+0.5]])
        ray_o, ray_d = pixel_to_ray(self.K, c2w, uv)
        pixel = self.img[img_index,temp_height,temp_width,:]
        sample = {'rays_o':ray_o[0],
                  "rays_d":ray_d[0],
                 "pixels":pixel}
        return sample

    def sample_rays(self, num_samples):
        rays_o = []
        rays_d = []
        pixels = []
        # random_numbers = [random.randint(0,self.length -1) for _ in range(num_samples)]
        random_numbers = np.random.randint(0, self.length - 1, size=num_samples) 
        for random_number in random_numbers:
            img_index = random_number // (self.width*self.height)
            residual = random_number % (self.width*self.height)
            temp_height = residual // self.height 
            temp_width = residual % self.width 
            c2w = self.c2ws[img_index]
            uv = np.array([[temp_height+0.5, temp_width+0.5]])
            ray_o, ray_d = pixel_to_ray(self.K, c2w, uv)
            rays_o.append(ray_o[0])
            rays_d.append(ray_d[0])
            pixels.append(self.img[img_index,temp_height,temp_width,:])
        return rays_o, rays_d,pixels

In [10]:
def sample_along_rays(rays_o, rays_d, n_samples=32, perturb=True):
    far = 6 
    near = 2
    
    t_values = torch.linspace(near, far, n_samples).to(device)
    ran_values = (torch.rand((len(rays_o), n_samples)) * (far - near) / n_samples).to(device)

    # Create 3D grid for rays_o and rays_d
    ray_o_grid = rays_o[:, None, :].repeat(1, n_samples, 1)
    ray_d_grid = rays_d[:, None, :].repeat(1, n_samples, 1)

    # Compute points without explicit loops
    p_t = t_values + ran_values
    p_t = p_t.to(device)
    points = ray_o_grid + ray_d_grid * p_t.unsqueeze(2)

    return points.view(-1, 3)

In [11]:
def volrend(sigmas, rgbs, step_size):
    sigmas = sigmas.to(device)
    rgbs = rgbs.to(device)
    size_to_prepend = (sigmas.size(0), 1, 1)

    zeros_to_prepend = torch.zeros(size_to_prepend, dtype=sigmas.dtype).to(device)
    
    tensor_with_zeros = torch.cat((zeros_to_prepend, sigmas), dim=1).to(device)

    
    
    cum_sigmas = torch.cumsum(tensor_with_zeros,dim=1)[:,:-1].to(device)
    T = torch.exp(-cum_sigmas*step_size).to(device)
    interval_sigmas = 1 - torch.exp(-sigmas*step_size).to(device)
    weights = T * interval_sigmas
    colors = rgbs * weights
    cum_colors = torch.sum(colors, dim=1).to(device)

    return cum_colors

In [12]:
class Residual_block(nn.Module):
    def __init__(self, dim):
        super(Residual_block, self).__init__()
        self.dim = dim
        self.linear_1 = nn.Linear(dim, dim)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=dim)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(dim, dim)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=dim)
        nn.init.kaiming_normal_(self.linear_1.weight, mode='fan_in', nonlinearity='relu')
        nn.init.kaiming_normal_(self.linear_2.weight, mode='fan_in', nonlinearity='relu')

    def forward(self,x):
        origin_x = x
        x = self.linear_1(x)
        x = self.layer_norm_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.layer_norm_2(x)
        x = self.relu(x + origin_x)
        return x

In [13]:
class Nerf_model(nn.Module):
    def __init__(self,high_fre_level, high_fre_level_angle, hidden_dim):
        super(Nerf_model, self).__init__()
        self.high_fre_level = high_fre_level
        self.high_fre_level_angle = high_fre_level_angle
        self.pe_dim = 3+high_fre_level*6
        self.pe_dim_angle = 3 + 6 * high_fre_level_angle
        self.input_layer = nn.Linear(3+high_fre_level*6, hidden_dim)
        self.residual_block_1 = Residual_block(hidden_dim)
        self.residual_block_2 = Residual_block(hidden_dim)
        self.residual_block_3 = Residual_block(hidden_dim)
        
        self.hidden_layer_1 = nn.Linear(hidden_dim, hidden_dim)
        self.hidden_layer_2 = nn.Linear(hidden_dim, hidden_dim)
        
        self.hidden_layer_concat_angle = nn.Linear(hidden_dim + self.pe_dim_angle, hidden_dim//2)

        nn.init.kaiming_normal_(self.hidden_layer_concat_angle.weight, mode='fan_in', nonlinearity='relu')
        
        self.out = nn.Linear(hidden_dim//2, 3)
        nn.init.kaiming_normal_(self.out.weight, mode='fan_in', nonlinearity='relu')
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.density_layer = nn.Linear(hidden_dim, 1)

        power_terms_pos = (torch.exp2(torch.arange(0, high_fre_level))*3.14159).to(device)
        
        power_terms_angle = (torch.exp2(torch.arange(0, high_fre_level_angle))*3.14159).to(device)
        
        self.power_matrix_pos= torch.zeros(3, self.high_fre_level *3).to(device)
        for i in range(3):
            self.power_matrix_pos[i,i*high_fre_level:(i+1)*high_fre_level] = power_terms_pos

        self.power_matrix_angle = torch.zeros(3, self.high_fre_level_angle  *3).to(device)
        for i in range(3):
            self.power_matrix_angle[i,i*high_fre_level_angle :(i+1)*high_fre_level_angle ] = power_terms_angle


        self.middle_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim) 
        
        self.last_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim//2) 

    def positional_encoding(self, data, high_fre_level, power_matrix):

        powered_data = data @ power_matrix
        sin_matrix = torch.sin(powered_data).to(device)
        cos_matrix = torch.cos(powered_data).to(device)

        pe = torch.cat((data, sin_matrix, cos_matrix),1).to(device)

        return pe

    def forward(self,pos, angle):
                
        pos_pe = self.positional_encoding(pos, self.high_fre_level, self.power_matrix_pos)
        
        origin_x = pos_pe
        x = self.input_layer(origin_x)
        x = self.residual_block_1(x)
        x = self.residual_block_2(x)
        x = self.residual_block_3(x)
        x = self.hidden_layer_1(x)
        sigmas = self.density_layer(x)
        sigmas = self.relu(sigmas)

        # x = self.middle_layer_norm(x)
        x = self.hidden_layer_2(x)

        angle_pe = self.positional_encoding(angle, self.high_fre_level_angle, self.power_matrix_angle)
        # angle_input = torch.cat((angle,angle_pe),dim=1)
        concated_x = torch.cat((x,angle_pe), dim = 1).float()
        x = self.hidden_layer_concat_angle(concated_x)
        x = self.last_layer_norm(x)
        x = self.relu(x)
        x = self.out(x)
        x = self.sigmoid(x)
        return x, sigmas

In [14]:
# class Nerf_model(nn.Module):
#     def __init__(self,high_fre_level, high_fre_level_angle, hidden_dim):
#         super(Nerf_model, self).__init__()
#         self.high_fre_level = high_fre_level
#         self.high_fre_level_angle = high_fre_level_angle
#         self.pe_dim = 3+high_fre_level*6
#         self.pe_dim_angle = 3 + 6 * high_fre_level_angle
#         self.input_layer = nn.Linear(3+high_fre_level*6, hidden_dim)
#         # self.input_layer = nn.Linear(2, hidden_dim)
#         hidden_layer_list = []


        
#         for i in range(3):
#             hidden_layer_list.append(nn.Linear(hidden_dim, hidden_dim))
#             hidden_layer_list.append(nn.ReLU())
#         self.hidden_layer_1 = nn.Sequential(*hidden_layer_list)

#         self.concat_hidden_layer = nn.Linear(hidden_dim + self.pe_dim,hidden_dim)
        
#         hidden_layer_list = []
#         for i in range(2):
#             hidden_layer_list.append(nn.Linear(hidden_dim, hidden_dim))
#             hidden_layer_list.append(nn.ReLU())
#         self.hidden_layer_2 = nn.Sequential(*hidden_layer_list)

#         self.hidden_layer_3 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_4 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_5 = nn.Linear(hidden_dim, hidden_dim)
#         self.hidden_layer_concat_angle = nn.Linear(hidden_dim + self.pe_dim_angle, hidden_dim//2)

#         self.out = nn.Linear(hidden_dim//2, 3)
#         self.relu = nn.ReLU()
#         self.sigmoid = nn.Sigmoid()
#         self.density_layer = nn.Linear(hidden_dim, 1)
        
#         power_terms_pos = (torch.exp2(torch.arange(0, high_fre_level))*3.14159).to(device)
        
#         power_terms_angle = (torch.exp2(torch.arange(0, high_fre_level_angle))*3.14159).to(device)
        
#         self.power_matrix_pos= torch.zeros(3, self.high_fre_level *3).to(device)
#         for i in range(3):
#             self.power_matrix_pos[i,i*high_fre_level:(i+1)*high_fre_level] = power_terms_pos

#         self.power_matrix_angle = torch.zeros(3, self.high_fre_level_angle  *3).to(device)
#         for i in range(3):
#             self.power_matrix_angle[i,i*high_fre_level_angle :(i+1)*high_fre_level_angle ] = power_terms_angle

#         # self.layer_norm1 = nn.LayerNorm(normalized_shape=HIDDEN_UNITS)  
#         # self._initialize_weights()
#         self.last_layer_norm = nn.LayerNorm(normalized_shape=hidden_dim//2) 
        

#     def _initialize_weights(self):
#         for m in self.modules():
#             if isinstance(m, nn.Linear):
#                 nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
#                 if m.bias is not None:
#                     nn.init.zeros_(m.bias)


    


#     def positional_encoding(self, data, high_fre_level, power_matrix):

#         powered_data = data @ power_matrix
#         sin_matrix = torch.sin(powered_data).to(device)
#         cos_matrix = torch.cos(powered_data).to(device)

#         pe = torch.cat((data, sin_matrix, cos_matrix),1).to(device)

#         return pe




    
#     def forward_phase_1(self, origin_x):

#         x = self.input_layer(origin_x)
#         x = self.relu(x)
#         x = self.hidden_layer_1(x)
#         x = torch.cat((x,origin_x), dim = 1)
#         x = self.concat_hidden_layer(x)
#         x = self.relu(x)
#         x = self.hidden_layer_2(x)
#         x = self.hidden_layer_3(x)
#         return x


        
#     def forward(self, pos, angle):
        
#         pos_pe = self.positional_encoding(pos, self.high_fre_level, self.power_matrix_pos)
        
#         origin_x = pos_pe
#         x = self.forward_phase_1(origin_x)




        
#         sigmas = self.density_layer(x)
#         sigmas = self.relu(sigmas)
        
#         x = self.hidden_layer_4(x)
#         angle_pe = self.positional_encoding(angle, self.high_fre_level_angle, self.power_matrix_angle)
#         # angle_input = torch.cat((angle,angle_pe),dim=1)
#         concated_x = torch.cat((x,angle_pe), dim = 1).float()
#         x = self.hidden_layer_concat_angle(concated_x)
#         x = self.last_layer_norm(x)
#         x = self.relu(x)
#         x = self.out(x)
#         x = self.sigmoid(x)
#         return x, sigmas
        

In [15]:
model = Nerf_model(30,10,256).to(device)
step_size = (6-2)/n_samples

In [19]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [23]:
# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# criterion = PSNRWithMSELoss()
criterion = torch.nn.MSELoss()
number_epoch  = 40
model.train()
# scheduler = StepLR(optimizer, step_size=70, gamma=0.1)
# for i in range(number_iteration):

dataset = RaysData(images_train, K, c2ws_train)
dataloader = DataLoader(dataset, batch_size=10000,
                        shuffle=True)

for e in range(number_epoch):
    for i_batch, sample_batched in enumerate(tqdm(dataloader)):
        # print(len(sample_batched['pixels']))
        # rays_o, rays_d, pixels = dataset.sample_rays(10000)
        # rays_o, rays_d, pixels = dataset.sample_rays_one(1000)
        # t1 = time.time()
        rays_o = sample_batched['rays_o'].squeeze()
        rays_d = sample_batched['rays_d'].squeeze()
        pixels = sample_batched['pixels']
        rays_o = rays_o.float().to(device)
        rays_d = rays_d.float().to(device)
        # rays_d = np.array(rays_d)
        # rays_o = np.array(rays_o)
        # rays_d = torch.tensor(rays_d).to(device)
        # rays_o = torch.tensor(rays_o).to(device)
        points = sample_along_rays(rays_o, rays_d)
        
    
        points = points.float().to(device)
        # points = np.array(points)
        # rays_d = np.array(rays_d)
        # points = torch.tensor(points).to(device)
        # rays_d = torch.tensor(rays_d).to(device)
    
        rays_d = torch.unsqueeze(rays_d,1)
        rays_d = rays_d.repeat(1,n_samples,1)
        rays_d = rays_d.view(-1,3)
        
        # t2 = time.time()
        # print('p',t2-t1)
        # t1 = t2
        
        rgbs, sigmas = model(points, rays_d)
        
        # t2 = time.time()
        # print('forward',t2-t1)
        # t1 = t2
        
        rgbs = rgbs.to(device)
        sigmas = sigmas.to(device)
        # sigams = model.foward_dentisy()
        sigmas = sigmas.view(-1, n_samples, 1)
        rgbs = rgbs.view(-1, n_samples, 3)
        
    
    
        rendered_colors = volrend(sigmas, rgbs, step_size)

                
        # t2 = time.time()
        # print('render',t2-t1)
        # t1 = t2
        # print("render",rendered_colors[0])
        # print("pixels",pixels[0])
    
        pixels = pixels.float().to(device)
        # pixels = np.array(pixels)
        # pixels = torch.tensor(pixels).float().to(device)
    
        # print(rendered_colors)
        loss = criterion(rendered_colors, pixels)
        
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

                        
        # t2 = time.time()
        # print('backward',t2-t1)
        # t1 = t2
        # scheduler.step()
        if i_batch % 49 == 0:
            print(f'iteration [{i_batch}], Loss: {loss.item()}')
    # scheduler.step()
    # torch.save(model.state_dict(), 'weights_%s.pth'%e)

  0%|                                           | 1/400 [00:00<04:56,  1.34it/s]

iteration [0], Loss: 0.007410584483295679


 12%|█████▎                                    | 50/400 [00:23<02:51,  2.04it/s]

iteration [49], Loss: 0.007934566587209702


 25%|██████████▍                               | 99/400 [00:45<02:31,  1.98it/s]

iteration [98], Loss: 0.007425498683005571


 37%|███████████████▏                         | 148/400 [01:08<02:07,  1.97it/s]

iteration [147], Loss: 0.007331693544983864


 49%|████████████████████▏                    | 197/400 [01:30<01:44,  1.93it/s]

iteration [196], Loss: 0.007999908179044724


 62%|█████████████████████████▏               | 246/400 [01:54<01:26,  1.78it/s]

iteration [245], Loss: 0.0076196808367967606


 74%|██████████████████████████████▏          | 295/400 [02:19<00:59,  1.76it/s]

iteration [294], Loss: 0.008052819408476353


 86%|███████████████████████████████████▎     | 344/400 [02:42<00:27,  2.01it/s]

iteration [343], Loss: 0.0077311014756560326


 98%|████████████████████████████████████████▎| 393/400 [03:04<00:03,  1.96it/s]

iteration [392], Loss: 0.00731813907623291


100%|█████████████████████████████████████████| 400/400 [03:08<00:00,  2.12it/s]
  0%|                                           | 1/400 [00:00<04:37,  1.44it/s]

iteration [0], Loss: 0.007229409646242857


 12%|█████▎                                    | 50/400 [00:23<02:59,  1.95it/s]

iteration [49], Loss: 0.007748833391815424


 25%|██████████▍                               | 99/400 [00:46<02:35,  1.94it/s]

iteration [98], Loss: 0.007614610716700554


 37%|███████████████▏                         | 148/400 [01:09<02:20,  1.79it/s]

iteration [147], Loss: 0.007979323156177998


 49%|████████████████████▏                    | 197/400 [01:35<02:12,  1.53it/s]

iteration [196], Loss: 0.007725160103291273


 62%|█████████████████████████▏               | 246/400 [01:58<01:19,  1.95it/s]

iteration [245], Loss: 0.007604737766087055


 74%|██████████████████████████████▏          | 295/400 [02:21<00:54,  1.94it/s]

iteration [294], Loss: 0.007367763202637434


 86%|███████████████████████████████████▎     | 344/400 [02:45<00:29,  1.87it/s]

iteration [343], Loss: 0.006996949203312397


 98%|████████████████████████████████████████▎| 393/400 [03:08<00:03,  1.91it/s]

iteration [392], Loss: 0.007795737124979496


100%|█████████████████████████████████████████| 400/400 [03:11<00:00,  2.08it/s]
  0%|                                           | 1/400 [00:00<04:40,  1.42it/s]

iteration [0], Loss: 0.007061498239636421


 12%|█████▎                                    | 50/400 [00:23<03:06,  1.88it/s]

iteration [49], Loss: 0.007830890826880932


 25%|██████████▍                               | 99/400 [00:47<02:46,  1.81it/s]

iteration [98], Loss: 0.007280253805220127


 37%|███████████████▏                         | 148/400 [01:10<02:21,  1.78it/s]

iteration [147], Loss: 0.007702145259827375


 49%|████████████████████▏                    | 197/400 [01:33<01:43,  1.96it/s]

iteration [196], Loss: 0.007724258117377758


 62%|█████████████████████████▏               | 246/400 [01:57<01:20,  1.91it/s]

iteration [245], Loss: 0.007644336204975843


 74%|██████████████████████████████▏          | 295/400 [02:21<00:55,  1.89it/s]

iteration [294], Loss: 0.007827994413673878


 86%|███████████████████████████████████▎     | 344/400 [02:45<00:30,  1.85it/s]

iteration [343], Loss: 0.007389109581708908


 98%|████████████████████████████████████████▎| 393/400 [03:09<00:03,  1.81it/s]

iteration [392], Loss: 0.007359923329204321


100%|█████████████████████████████████████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:46,  1.39it/s]

iteration [0], Loss: 0.007316230796277523


 12%|█████▎                                    | 50/400 [00:24<03:21,  1.73it/s]

iteration [49], Loss: 0.007451023440808058


 25%|██████████▍                               | 99/400 [00:48<02:44,  1.83it/s]

iteration [98], Loss: 0.00771652115508914


 37%|███████████████▏                         | 148/400 [01:13<02:14,  1.87it/s]

iteration [147], Loss: 0.007143587339669466


 49%|████████████████████▏                    | 197/400 [01:37<01:47,  1.90it/s]

iteration [196], Loss: 0.007182182278484106


 62%|█████████████████████████▏               | 246/400 [02:02<01:25,  1.80it/s]

iteration [245], Loss: 0.007268719375133514


 74%|██████████████████████████████▏          | 295/400 [02:26<00:57,  1.83it/s]

iteration [294], Loss: 0.007097249384969473


 86%|███████████████████████████████████▎     | 344/400 [02:50<00:32,  1.72it/s]

iteration [343], Loss: 0.007726332172751427


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.85it/s]

iteration [392], Loss: 0.007507207803428173


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:40,  1.42it/s]

iteration [0], Loss: 0.007494545076042414


 12%|█████▎                                    | 50/400 [00:24<03:01,  1.92it/s]

iteration [49], Loss: 0.0076312837190926075


 25%|██████████▍                               | 99/400 [00:48<02:42,  1.85it/s]

iteration [98], Loss: 0.007104309741407633


 37%|███████████████▏                         | 148/400 [01:12<02:16,  1.85it/s]

iteration [147], Loss: 0.007195927202701569


 49%|████████████████████▏                    | 197/400 [01:37<01:50,  1.83it/s]

iteration [196], Loss: 0.007151078898459673


 62%|█████████████████████████▏               | 246/400 [02:01<01:25,  1.80it/s]

iteration [245], Loss: 0.007922429591417313


 74%|██████████████████████████████▏          | 295/400 [02:25<00:58,  1.79it/s]

iteration [294], Loss: 0.007222754880785942


 86%|███████████████████████████████████▎     | 344/400 [02:49<00:28,  1.95it/s]

iteration [343], Loss: 0.0073536052368581295


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.96it/s]

iteration [392], Loss: 0.007340446580201387


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.05it/s]
  0%|                                           | 1/400 [00:00<04:37,  1.44it/s]

iteration [0], Loss: 0.007352232001721859


 12%|█████▎                                    | 50/400 [00:23<02:59,  1.95it/s]

iteration [49], Loss: 0.007140972185879946


 25%|██████████▍                               | 99/400 [00:46<02:34,  1.95it/s]

iteration [98], Loss: 0.007638930343091488


 37%|███████████████▏                         | 148/400 [01:10<02:14,  1.88it/s]

iteration [147], Loss: 0.007452782243490219


 49%|████████████████████▏                    | 197/400 [01:33<01:49,  1.85it/s]

iteration [196], Loss: 0.007389488164335489


 62%|█████████████████████████▏               | 246/400 [01:57<01:27,  1.76it/s]

iteration [245], Loss: 0.007679012604057789


 74%|██████████████████████████████▏          | 295/400 [02:21<00:54,  1.93it/s]

iteration [294], Loss: 0.007703888230025768


 86%|███████████████████████████████████▎     | 344/400 [02:45<00:29,  1.92it/s]

iteration [343], Loss: 0.007514822296798229


 98%|████████████████████████████████████████▎| 393/400 [03:08<00:03,  1.93it/s]

iteration [392], Loss: 0.007072217762470245


100%|█████████████████████████████████████████| 400/400 [03:12<00:00,  2.08it/s]
  0%|                                           | 1/400 [00:00<04:39,  1.43it/s]

iteration [0], Loss: 0.006426146719604731


 12%|█████▎                                    | 50/400 [00:24<03:08,  1.86it/s]

iteration [49], Loss: 0.0072994171641767025


 25%|██████████▍                               | 99/400 [00:48<02:45,  1.82it/s]

iteration [98], Loss: 0.007461740635335445


 37%|███████████████▏                         | 148/400 [01:12<02:24,  1.74it/s]

iteration [147], Loss: 0.0076539828442037106


 49%|████████████████████▏                    | 197/400 [01:36<01:53,  1.79it/s]

iteration [196], Loss: 0.007151829544454813


 62%|█████████████████████████▏               | 246/400 [01:59<01:19,  1.93it/s]

iteration [245], Loss: 0.00674979155883193


 74%|██████████████████████████████▏          | 295/400 [02:23<00:58,  1.78it/s]

iteration [294], Loss: 0.007170203607529402


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:30,  1.83it/s]

iteration [343], Loss: 0.0069114877842366695


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.84it/s]

iteration [392], Loss: 0.007271016947925091


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.04it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.007125995587557554


 12%|█████▎                                    | 50/400 [00:24<03:08,  1.85it/s]

iteration [49], Loss: 0.006875541061162949


 25%|██████████▍                               | 99/400 [00:47<02:45,  1.82it/s]

iteration [98], Loss: 0.007397100329399109


 37%|███████████████▏                         | 148/400 [01:10<02:08,  1.97it/s]

iteration [147], Loss: 0.007413126528263092


 49%|████████████████████▏                    | 197/400 [01:33<01:44,  1.95it/s]

iteration [196], Loss: 0.007263763342052698


 62%|█████████████████████████▏               | 246/400 [01:56<01:19,  1.94it/s]

iteration [245], Loss: 0.0068305255845189095


 74%|██████████████████████████████▏          | 295/400 [02:19<00:54,  1.92it/s]

iteration [294], Loss: 0.0073247081600129604


 86%|███████████████████████████████████▎     | 344/400 [02:42<00:30,  1.86it/s]

iteration [343], Loss: 0.0068644327111542225


 98%|████████████████████████████████████████▎| 393/400 [03:06<00:03,  1.75it/s]

iteration [392], Loss: 0.007420359179377556


100%|█████████████████████████████████████████| 400/400 [03:10<00:00,  2.10it/s]
  0%|                                           | 1/400 [00:00<04:45,  1.40it/s]

iteration [0], Loss: 0.006786893587559462


 12%|█████▎                                    | 50/400 [00:24<03:04,  1.90it/s]

iteration [49], Loss: 0.006673604249954224


 25%|██████████▍                               | 99/400 [00:48<02:40,  1.87it/s]

iteration [98], Loss: 0.007311273366212845


 37%|███████████████▏                         | 148/400 [01:12<02:14,  1.88it/s]

iteration [147], Loss: 0.007047479040920734


 49%|████████████████████▏                    | 197/400 [01:35<01:47,  1.89it/s]

iteration [196], Loss: 0.007246239110827446


 62%|█████████████████████████▏               | 246/400 [01:59<01:28,  1.75it/s]

iteration [245], Loss: 0.0072416034527122974


 74%|██████████████████████████████▏          | 295/400 [02:24<00:59,  1.76it/s]

iteration [294], Loss: 0.006860150024294853


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:30,  1.86it/s]

iteration [343], Loss: 0.007050371263176203


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:03,  1.84it/s]

iteration [392], Loss: 0.007007188629359007


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:50,  1.37it/s]

iteration [0], Loss: 0.006996342446655035


 12%|█████▎                                    | 50/400 [00:24<03:12,  1.82it/s]

iteration [49], Loss: 0.007270782720297575


 25%|██████████▍                               | 99/400 [00:49<02:46,  1.80it/s]

iteration [98], Loss: 0.007086717989295721


 37%|███████████████▏                         | 148/400 [01:14<02:22,  1.76it/s]

iteration [147], Loss: 0.006968893110752106


 49%|████████████████████▏                    | 197/400 [01:39<01:55,  1.76it/s]

iteration [196], Loss: 0.006803616415709257


 62%|█████████████████████████▏               | 246/400 [02:02<01:25,  1.80it/s]

iteration [245], Loss: 0.0071273865178227425


 74%|██████████████████████████████▏          | 295/400 [02:25<00:54,  1.92it/s]

iteration [294], Loss: 0.007154104765504599


 86%|███████████████████████████████████▎     | 344/400 [02:49<00:29,  1.87it/s]

iteration [343], Loss: 0.007326112594455481


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:03,  1.87it/s]

iteration [392], Loss: 0.0072638243436813354


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:43,  1.41it/s]

iteration [0], Loss: 0.007186399307101965


 12%|█████▎                                    | 50/400 [00:24<03:09,  1.85it/s]

iteration [49], Loss: 0.007119185756891966


 25%|██████████▍                               | 99/400 [00:49<02:45,  1.82it/s]

iteration [98], Loss: 0.006842239294201136


 37%|███████████████▏                         | 148/400 [01:13<02:23,  1.75it/s]

iteration [147], Loss: 0.006804965436458588


 49%|████████████████████▏                    | 197/400 [01:37<01:45,  1.92it/s]

iteration [196], Loss: 0.007036606315523386


 62%|█████████████████████████▏               | 246/400 [02:01<01:23,  1.85it/s]

iteration [245], Loss: 0.0067987809889018536


 74%|██████████████████████████████▏          | 295/400 [02:25<00:57,  1.84it/s]

iteration [294], Loss: 0.007312712259590626


 86%|███████████████████████████████████▎     | 344/400 [02:50<00:28,  1.95it/s]

iteration [343], Loss: 0.006878664251416922


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:03,  1.94it/s]

iteration [392], Loss: 0.007185340393334627


100%|█████████████████████████████████████████| 400/400 [03:16<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:39,  1.43it/s]

iteration [0], Loss: 0.0068885935470461845


 12%|█████▎                                    | 50/400 [00:24<03:26,  1.69it/s]

iteration [49], Loss: 0.006992565002292395


 25%|██████████▍                               | 99/400 [00:49<02:53,  1.73it/s]

iteration [98], Loss: 0.00709148682653904


 37%|███████████████▏                         | 148/400 [01:14<02:13,  1.89it/s]

iteration [147], Loss: 0.007114793173968792


 49%|████████████████████▏                    | 197/400 [01:37<01:43,  1.96it/s]

iteration [196], Loss: 0.006824399344623089


 62%|█████████████████████████▏               | 246/400 [02:00<01:19,  1.95it/s]

iteration [245], Loss: 0.006517413537949324


 74%|██████████████████████████████▏          | 295/400 [02:24<00:54,  1.91it/s]

iteration [294], Loss: 0.00670379726216197


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:29,  1.92it/s]

iteration [343], Loss: 0.006722383666783571


 98%|████████████████████████████████████████▎| 393/400 [03:10<00:03,  1.85it/s]

iteration [392], Loss: 0.007091241423040628


100%|█████████████████████████████████████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<06:08,  1.08it/s]

iteration [0], Loss: 0.0067479899153113365


 12%|█████▎                                    | 50/400 [00:25<03:07,  1.87it/s]

iteration [49], Loss: 0.007100452668964863


 25%|██████████▍                               | 99/400 [00:49<02:35,  1.94it/s]

iteration [98], Loss: 0.006938505917787552


 37%|███████████████▏                         | 148/400 [01:12<02:09,  1.94it/s]

iteration [147], Loss: 0.0073486873880028725


 49%|████████████████████▏                    | 197/400 [01:35<01:45,  1.93it/s]

iteration [196], Loss: 0.0067145549692213535


 62%|█████████████████████████▏               | 246/400 [02:00<01:23,  1.84it/s]

iteration [245], Loss: 0.006996646523475647


 74%|██████████████████████████████▏          | 295/400 [02:24<01:00,  1.74it/s]

iteration [294], Loss: 0.006746483966708183


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:30,  1.81it/s]

iteration [343], Loss: 0.006980118807405233


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.82it/s]

iteration [392], Loss: 0.006928410846740007


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.04it/s]
  0%|                                           | 1/400 [00:00<04:52,  1.37it/s]

iteration [0], Loss: 0.006306170951575041


 12%|█████▎                                    | 50/400 [00:25<03:06,  1.88it/s]

iteration [49], Loss: 0.006796051748096943


 25%|██████████▍                               | 99/400 [00:49<02:44,  1.83it/s]

iteration [98], Loss: 0.006799801718443632


 37%|███████████████▏                         | 148/400 [01:14<02:22,  1.77it/s]

iteration [147], Loss: 0.006902963388711214


 49%|████████████████████▏                    | 197/400 [01:37<01:47,  1.89it/s]

iteration [196], Loss: 0.006882703863084316


 62%|█████████████████████████▏               | 246/400 [02:01<01:23,  1.84it/s]

iteration [245], Loss: 0.006831207312643528


 74%|██████████████████████████████▏          | 295/400 [02:24<00:53,  1.95it/s]

iteration [294], Loss: 0.007080285809934139


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:29,  1.93it/s]

iteration [343], Loss: 0.007136027794331312


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.90it/s]

iteration [392], Loss: 0.006599412299692631


100%|█████████████████████████████████████████| 400/400 [03:14<00:00,  2.05it/s]
  0%|                                           | 1/400 [00:00<04:49,  1.38it/s]

iteration [0], Loss: 0.006894113961607218


 12%|█████▎                                    | 50/400 [00:24<03:06,  1.88it/s]

iteration [49], Loss: 0.0068382201716303825


 25%|██████████▍                               | 99/400 [00:48<02:45,  1.82it/s]

iteration [98], Loss: 0.006679829675704241


 37%|███████████████▏                         | 148/400 [01:12<02:15,  1.86it/s]

iteration [147], Loss: 0.006709489040076733


 49%|████████████████████▏                    | 197/400 [01:36<01:55,  1.75it/s]

iteration [196], Loss: 0.00685707526281476


 62%|█████████████████████████▏               | 246/400 [01:59<01:20,  1.92it/s]

iteration [245], Loss: 0.006532457657158375


 74%|██████████████████████████████▏          | 295/400 [02:23<00:55,  1.89it/s]

iteration [294], Loss: 0.00673647178336978


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:29,  1.87it/s]

iteration [343], Loss: 0.006803066935390234


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.87it/s]

iteration [392], Loss: 0.0065615214407444


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.05it/s]
  0%|                                           | 1/400 [00:00<04:42,  1.41it/s]

iteration [0], Loss: 0.006631833966821432


 12%|█████▎                                    | 50/400 [00:25<03:07,  1.87it/s]

iteration [49], Loss: 0.00666501559317112


 25%|██████████▍                               | 99/400 [00:49<02:55,  1.72it/s]

iteration [98], Loss: 0.006827374454587698


 37%|███████████████▏                         | 148/400 [01:13<02:23,  1.75it/s]

iteration [147], Loss: 0.006965977605432272


 49%|████████████████████▏                    | 197/400 [01:36<01:48,  1.88it/s]

iteration [196], Loss: 0.006483483128249645


 62%|█████████████████████████▏               | 246/400 [02:02<01:25,  1.80it/s]

iteration [245], Loss: 0.006669400259852409


 74%|██████████████████████████████▏          | 295/400 [02:25<00:56,  1.86it/s]

iteration [294], Loss: 0.006688464432954788


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:29,  1.89it/s]

iteration [343], Loss: 0.006877653766423464


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.84it/s]

iteration [392], Loss: 0.006773229222744703


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.05it/s]
  0%|                                           | 1/400 [00:00<04:48,  1.38it/s]

iteration [0], Loss: 0.006324845366179943


 12%|█████▎                                    | 50/400 [00:24<03:16,  1.78it/s]

iteration [49], Loss: 0.006412874441593885


 25%|██████████▍                               | 99/400 [00:48<02:42,  1.85it/s]

iteration [98], Loss: 0.006817615125328302


 37%|███████████████▏                         | 148/400 [01:11<02:18,  1.82it/s]

iteration [147], Loss: 0.006832933519035578


 49%|████████████████████▏                    | 197/400 [01:36<01:45,  1.92it/s]

iteration [196], Loss: 0.0068048145622015


 62%|█████████████████████████▏               | 246/400 [01:59<01:22,  1.87it/s]

iteration [245], Loss: 0.006492867134511471


 74%|██████████████████████████████▏          | 295/400 [02:23<00:56,  1.85it/s]

iteration [294], Loss: 0.006875301245599985


 86%|███████████████████████████████████▎     | 344/400 [02:46<00:31,  1.79it/s]

iteration [343], Loss: 0.006727386731654406


 98%|████████████████████████████████████████▎| 393/400 [03:09<00:03,  1.94it/s]

iteration [392], Loss: 0.0067956107668578625


100%|█████████████████████████████████████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.0066348654218018055


 12%|█████▎                                    | 50/400 [00:24<03:02,  1.92it/s]

iteration [49], Loss: 0.006715047173202038


 25%|██████████▍                               | 99/400 [00:48<02:47,  1.80it/s]

iteration [98], Loss: 0.006891993340104818


 37%|███████████████▏                         | 148/400 [01:13<02:17,  1.84it/s]

iteration [147], Loss: 0.006537829525768757


 49%|████████████████████▏                    | 197/400 [01:37<01:51,  1.82it/s]

iteration [196], Loss: 0.006486142985522747


 62%|█████████████████████████▏               | 246/400 [02:02<01:27,  1.75it/s]

iteration [245], Loss: 0.00731813395395875


 74%|██████████████████████████████▏          | 295/400 [02:26<00:57,  1.81it/s]

iteration [294], Loss: 0.006989235989749432


 86%|███████████████████████████████████▎     | 344/400 [02:51<00:29,  1.90it/s]

iteration [343], Loss: 0.007139451336115599


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.92it/s]

iteration [392], Loss: 0.006596551742404699


100%|█████████████████████████████████████████| 400/400 [03:18<00:00,  2.01it/s]
  0%|                                           | 1/400 [00:00<04:44,  1.40it/s]

iteration [0], Loss: 0.0068042525090277195


 12%|█████▎                                    | 50/400 [00:25<03:09,  1.85it/s]

iteration [49], Loss: 0.006604148540645838


 25%|██████████▍                               | 99/400 [00:48<02:49,  1.78it/s]

iteration [98], Loss: 0.006482970435172319


 37%|███████████████▏                         | 148/400 [01:12<02:23,  1.76it/s]

iteration [147], Loss: 0.006494718138128519


 49%|████████████████████▏                    | 197/400 [01:35<01:51,  1.81it/s]

iteration [196], Loss: 0.006471249274909496


 62%|█████████████████████████▏               | 246/400 [01:59<01:18,  1.96it/s]

iteration [245], Loss: 0.006498234812170267


 74%|██████████████████████████████▏          | 295/400 [02:22<00:53,  1.97it/s]

iteration [294], Loss: 0.006821201182901859


 86%|███████████████████████████████████▎     | 344/400 [02:45<00:28,  1.93it/s]

iteration [343], Loss: 0.006401294842362404


 98%|████████████████████████████████████████▎| 393/400 [03:09<00:03,  1.93it/s]

iteration [392], Loss: 0.006711119785904884


100%|█████████████████████████████████████████| 400/400 [03:12<00:00,  2.08it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.0063773877918720245


 12%|█████▎                                    | 50/400 [00:24<03:07,  1.87it/s]

iteration [49], Loss: 0.006824841722846031


 25%|██████████▍                               | 99/400 [00:47<02:55,  1.72it/s]

iteration [98], Loss: 0.006719151046127081


 37%|███████████████▏                         | 148/400 [01:11<02:15,  1.86it/s]

iteration [147], Loss: 0.00661097839474678


 49%|████████████████████▏                    | 197/400 [01:36<01:49,  1.85it/s]

iteration [196], Loss: 0.006211737170815468


 62%|█████████████████████████▏               | 246/400 [02:00<01:23,  1.84it/s]

iteration [245], Loss: 0.006662752013653517


 74%|██████████████████████████████▏          | 295/400 [02:25<00:57,  1.84it/s]

iteration [294], Loss: 0.006750531028956175


 86%|███████████████████████████████████▎     | 344/400 [02:50<00:31,  1.77it/s]

iteration [343], Loss: 0.006473491434007883


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.88it/s]

iteration [392], Loss: 0.007000879850238562


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.02it/s]
  0%|                                           | 1/400 [00:00<04:40,  1.42it/s]

iteration [0], Loss: 0.006223329342901707


 12%|█████▎                                    | 50/400 [00:24<03:17,  1.78it/s]

iteration [49], Loss: 0.006446422543376684


 25%|██████████▍                               | 99/400 [00:47<02:34,  1.95it/s]

iteration [98], Loss: 0.006735878065228462


 37%|███████████████▏                         | 148/400 [01:11<02:12,  1.90it/s]

iteration [147], Loss: 0.006098803598433733


 49%|████████████████████▏                    | 197/400 [01:35<01:49,  1.85it/s]

iteration [196], Loss: 0.006595437880605459


 62%|█████████████████████████▏               | 246/400 [01:59<01:21,  1.89it/s]

iteration [245], Loss: 0.006549675017595291


 74%|██████████████████████████████▏          | 295/400 [02:22<00:56,  1.87it/s]

iteration [294], Loss: 0.006339050363749266


 86%|███████████████████████████████████▎     | 344/400 [02:46<00:30,  1.82it/s]

iteration [343], Loss: 0.006582935806363821


 98%|████████████████████████████████████████▎| 393/400 [03:09<00:03,  1.76it/s]

iteration [392], Loss: 0.006818955764174461


100%|█████████████████████████████████████████| 400/400 [03:12<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:37,  1.44it/s]

iteration [0], Loss: 0.0063435290940105915


 12%|█████▎                                    | 50/400 [00:24<03:03,  1.90it/s]

iteration [49], Loss: 0.006463781930506229


 25%|██████████▍                               | 99/400 [00:47<02:35,  1.93it/s]

iteration [98], Loss: 0.006516380235552788


 37%|███████████████▏                         | 148/400 [01:11<02:13,  1.89it/s]

iteration [147], Loss: 0.006304625421762466


 49%|████████████████████▏                    | 197/400 [01:34<01:47,  1.89it/s]

iteration [196], Loss: 0.006120136007666588


 62%|█████████████████████████▏               | 246/400 [01:58<01:24,  1.82it/s]

iteration [245], Loss: 0.006441106088459492


 74%|██████████████████████████████▏          | 295/400 [02:21<00:57,  1.83it/s]

iteration [294], Loss: 0.006385373882949352


 86%|███████████████████████████████████▎     | 344/400 [02:44<00:28,  1.94it/s]

iteration [343], Loss: 0.006782786920666695


 98%|████████████████████████████████████████▎| 393/400 [03:08<00:03,  1.89it/s]

iteration [392], Loss: 0.006724074482917786


100%|█████████████████████████████████████████| 400/400 [03:11<00:00,  2.08it/s]
  0%|                                           | 1/400 [00:00<04:44,  1.40it/s]

iteration [0], Loss: 0.006389732006937265


 12%|█████▎                                    | 50/400 [00:24<03:01,  1.92it/s]

iteration [49], Loss: 0.006493333727121353


 25%|██████████▍                               | 99/400 [00:47<02:38,  1.90it/s]

iteration [98], Loss: 0.006413002964109182


 37%|███████████████▏                         | 148/400 [01:10<02:15,  1.86it/s]

iteration [147], Loss: 0.006271649617701769


 49%|████████████████████▏                    | 197/400 [01:34<01:49,  1.85it/s]

iteration [196], Loss: 0.006492236629128456


 62%|█████████████████████████▏               | 246/400 [01:58<01:24,  1.83it/s]

iteration [245], Loss: 0.006454160436987877


 74%|██████████████████████████████▏          | 295/400 [02:23<00:59,  1.77it/s]

iteration [294], Loss: 0.006042396649718285


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:31,  1.79it/s]

iteration [343], Loss: 0.006162840407341719


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.77it/s]

iteration [392], Loss: 0.006208798848092556


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.02it/s]
  0%|                                           | 1/400 [00:00<04:50,  1.38it/s]

iteration [0], Loss: 0.006367339286953211


 12%|█████▎                                    | 50/400 [00:25<03:15,  1.79it/s]

iteration [49], Loss: 0.0066231326200068


 25%|██████████▍                               | 99/400 [00:49<02:39,  1.89it/s]

iteration [98], Loss: 0.006016012281179428


 37%|███████████████▏                         | 148/400 [01:13<02:29,  1.69it/s]

iteration [147], Loss: 0.006211102940142155


 49%|████████████████████▏                    | 197/400 [01:38<01:51,  1.82it/s]

iteration [196], Loss: 0.006187673192471266


 62%|█████████████████████████▏               | 246/400 [02:02<01:18,  1.96it/s]

iteration [245], Loss: 0.006253335624933243


 74%|██████████████████████████████▏          | 295/400 [02:26<00:56,  1.86it/s]

iteration [294], Loss: 0.006344610825181007


 86%|███████████████████████████████████▎     | 344/400 [02:50<00:31,  1.77it/s]

iteration [343], Loss: 0.006472984794527292


 98%|████████████████████████████████████████▎| 393/400 [03:15<00:03,  1.79it/s]

iteration [392], Loss: 0.006297153886407614


100%|█████████████████████████████████████████| 400/400 [03:18<00:00,  2.01it/s]
  0%|                                           | 1/400 [00:00<04:51,  1.37it/s]

iteration [0], Loss: 0.006012956146150827


 12%|█████▎                                    | 50/400 [00:24<03:24,  1.71it/s]

iteration [49], Loss: 0.00614289753139019


 25%|██████████▍                               | 99/400 [00:49<02:50,  1.77it/s]

iteration [98], Loss: 0.006382896099239588


 37%|███████████████▏                         | 148/400 [01:13<02:15,  1.86it/s]

iteration [147], Loss: 0.0064943693578243256


 49%|████████████████████▏                    | 197/400 [01:38<01:49,  1.86it/s]

iteration [196], Loss: 0.006319753360003233


 62%|█████████████████████████▏               | 246/400 [02:01<01:20,  1.91it/s]

iteration [245], Loss: 0.006568428594619036


 74%|██████████████████████████████▏          | 295/400 [02:24<00:55,  1.88it/s]

iteration [294], Loss: 0.006561018992215395


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:29,  1.87it/s]

iteration [343], Loss: 0.006152417976409197


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.81it/s]

iteration [392], Loss: 0.006178099662065506


100%|█████████████████████████████████████████| 400/400 [03:14<00:00,  2.06it/s]
  0%|                                           | 1/400 [00:00<05:38,  1.18it/s]

iteration [0], Loss: 0.006636727135628462


 12%|█████▎                                    | 50/400 [00:24<03:00,  1.94it/s]

iteration [49], Loss: 0.006203705444931984


 25%|██████████▍                               | 99/400 [00:47<02:39,  1.88it/s]

iteration [98], Loss: 0.006176709663122892


 37%|███████████████▏                         | 148/400 [01:11<02:11,  1.92it/s]

iteration [147], Loss: 0.006333045195788145


 49%|████████████████████▏                    | 197/400 [01:35<01:47,  1.89it/s]

iteration [196], Loss: 0.005889314226806164


 62%|█████████████████████████▏               | 246/400 [01:58<01:23,  1.84it/s]

iteration [245], Loss: 0.0061921062879264355


 74%|██████████████████████████████▏          | 295/400 [02:22<00:59,  1.76it/s]

iteration [294], Loss: 0.005987337324768305


 86%|███████████████████████████████████▎     | 344/400 [02:45<00:28,  1.95it/s]

iteration [343], Loss: 0.006156736984848976


 98%|████████████████████████████████████████▎| 393/400 [03:09<00:03,  1.96it/s]

iteration [392], Loss: 0.006368695758283138


100%|█████████████████████████████████████████| 400/400 [03:12<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:34,  1.45it/s]

iteration [0], Loss: 0.006302895024418831


 12%|█████▎                                    | 50/400 [00:23<03:13,  1.81it/s]

iteration [49], Loss: 0.006436687428504229


 25%|██████████▍                               | 99/400 [00:48<02:44,  1.84it/s]

iteration [98], Loss: 0.006450957152992487


 37%|███████████████▏                         | 148/400 [01:11<02:11,  1.91it/s]

iteration [147], Loss: 0.00665638130158186


 49%|████████████████████▏                    | 197/400 [01:35<01:53,  1.79it/s]

iteration [196], Loss: 0.006479126866906881


 62%|█████████████████████████▏               | 246/400 [01:58<01:19,  1.94it/s]

iteration [245], Loss: 0.006108568981289864


 74%|██████████████████████████████▏          | 295/400 [02:22<00:54,  1.93it/s]

iteration [294], Loss: 0.006440696772187948


 86%|███████████████████████████████████▎     | 344/400 [02:46<00:29,  1.88it/s]

iteration [343], Loss: 0.006500013638287783


 98%|████████████████████████████████████████▎| 393/400 [03:10<00:03,  1.89it/s]

iteration [392], Loss: 0.006438667885959148


100%|█████████████████████████████████████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.006413136143237352


 12%|█████▎                                    | 50/400 [00:24<03:09,  1.85it/s]

iteration [49], Loss: 0.006106858607381582


 25%|██████████▍                               | 99/400 [00:48<02:44,  1.83it/s]

iteration [98], Loss: 0.0061151571571826935


 37%|███████████████▏                         | 148/400 [01:12<02:23,  1.76it/s]

iteration [147], Loss: 0.006164039019495249


 49%|████████████████████▏                    | 197/400 [01:36<01:44,  1.93it/s]

iteration [196], Loss: 0.0061837900429964066


 62%|█████████████████████████▏               | 246/400 [01:59<01:21,  1.89it/s]

iteration [245], Loss: 0.006459631025791168


 74%|██████████████████████████████▏          | 295/400 [02:23<00:55,  1.90it/s]

iteration [294], Loss: 0.006344257388263941


 86%|███████████████████████████████████▎     | 344/400 [02:46<00:29,  1.88it/s]

iteration [343], Loss: 0.006283781956881285


 98%|████████████████████████████████████████▎| 393/400 [03:10<00:03,  1.88it/s]

iteration [392], Loss: 0.006368444766849279


100%|█████████████████████████████████████████| 400/400 [03:13<00:00,  2.07it/s]
  0%|                                           | 1/400 [00:00<04:35,  1.45it/s]

iteration [0], Loss: 0.005783864296972752


 12%|█████▎                                    | 50/400 [00:24<03:21,  1.74it/s]

iteration [49], Loss: 0.00644346559420228


 25%|██████████▍                               | 99/400 [00:49<02:41,  1.86it/s]

iteration [98], Loss: 0.006327606271952391


 37%|███████████████▏                         | 148/400 [01:13<02:13,  1.89it/s]

iteration [147], Loss: 0.0061172302812337875


 49%|████████████████████▏                    | 197/400 [01:37<01:49,  1.86it/s]

iteration [196], Loss: 0.0061457292176783085


 62%|█████████████████████████▏               | 246/400 [02:01<01:23,  1.86it/s]

iteration [245], Loss: 0.006057244259864092


 74%|██████████████████████████████▏          | 295/400 [02:26<00:57,  1.82it/s]

iteration [294], Loss: 0.006243813782930374


 86%|███████████████████████████████████▎     | 344/400 [02:49<00:30,  1.84it/s]

iteration [343], Loss: 0.0064918906427919865


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:04,  1.72it/s]

iteration [392], Loss: 0.006700810976326466


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<05:44,  1.16it/s]

iteration [0], Loss: 0.006240257993340492


 12%|█████▎                                    | 50/400 [00:25<03:13,  1.81it/s]

iteration [49], Loss: 0.005970018915832043


 25%|██████████▍                               | 99/400 [00:49<02:38,  1.90it/s]

iteration [98], Loss: 0.006068261805921793


 37%|███████████████▏                         | 148/400 [01:13<02:12,  1.90it/s]

iteration [147], Loss: 0.005909058731049299


 49%|████████████████████▏                    | 197/400 [01:38<01:57,  1.72it/s]

iteration [196], Loss: 0.006227156613022089


 62%|█████████████████████████▏               | 246/400 [02:02<01:23,  1.85it/s]

iteration [245], Loss: 0.006236984394490719


 74%|██████████████████████████████▏          | 295/400 [02:27<01:04,  1.63it/s]

iteration [294], Loss: 0.006566166412085295


 86%|███████████████████████████████████▎     | 344/400 [02:51<00:31,  1.80it/s]

iteration [343], Loss: 0.00596719142049551


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.96it/s]

iteration [392], Loss: 0.006366857793182135


100%|█████████████████████████████████████████| 400/400 [03:17<00:00,  2.02it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.006186351645737886


 12%|█████▎                                    | 50/400 [00:24<03:06,  1.87it/s]

iteration [49], Loss: 0.006159232929348946


 25%|██████████▍                               | 99/400 [00:47<02:37,  1.91it/s]

iteration [98], Loss: 0.006106516811996698


 37%|███████████████▏                         | 148/400 [01:10<02:12,  1.91it/s]

iteration [147], Loss: 0.006272933911532164


 49%|████████████████████▏                    | 197/400 [01:34<01:53,  1.78it/s]

iteration [196], Loss: 0.006314653437584639


 62%|█████████████████████████▏               | 246/400 [01:57<01:23,  1.85it/s]

iteration [245], Loss: 0.006166763603687286


 74%|██████████████████████████████▏          | 295/400 [02:21<00:54,  1.94it/s]

iteration [294], Loss: 0.006298112217336893


 86%|███████████████████████████████████▎     | 344/400 [02:44<00:29,  1.89it/s]

iteration [343], Loss: 0.006240364629775286


 98%|████████████████████████████████████████▎| 393/400 [03:08<00:03,  1.80it/s]

iteration [392], Loss: 0.006119653582572937


100%|█████████████████████████████████████████| 400/400 [03:11<00:00,  2.08it/s]
  0%|                                           | 1/400 [00:00<04:53,  1.36it/s]

iteration [0], Loss: 0.006143559701740742


 12%|█████▎                                    | 50/400 [00:25<02:58,  1.97it/s]

iteration [49], Loss: 0.005882592871785164


 25%|██████████▍                               | 99/400 [00:49<02:49,  1.77it/s]

iteration [98], Loss: 0.006028469186276197


 37%|███████████████▏                         | 148/400 [01:14<02:24,  1.75it/s]

iteration [147], Loss: 0.0059808399528265


 49%|████████████████████▏                    | 197/400 [01:37<01:42,  1.98it/s]

iteration [196], Loss: 0.006210966035723686


 62%|█████████████████████████▏               | 246/400 [02:00<01:18,  1.95it/s]

iteration [245], Loss: 0.006228812504559755


 74%|██████████████████████████████▏          | 295/400 [02:24<00:54,  1.93it/s]

iteration [294], Loss: 0.006054645404219627


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:28,  1.96it/s]

iteration [343], Loss: 0.0062163835391402245


 98%|████████████████████████████████████████▎| 393/400 [03:10<00:03,  1.89it/s]

iteration [392], Loss: 0.006055046804249287


100%|█████████████████████████████████████████| 400/400 [03:14<00:00,  2.06it/s]
  0%|                                           | 1/400 [00:00<04:41,  1.42it/s]

iteration [0], Loss: 0.006092226132750511


 12%|█████▎                                    | 50/400 [00:24<03:06,  1.88it/s]

iteration [49], Loss: 0.006092081777751446


 25%|██████████▍                               | 99/400 [00:47<02:47,  1.79it/s]

iteration [98], Loss: 0.0063253967091441154


 37%|███████████████▏                         | 148/400 [01:10<02:11,  1.92it/s]

iteration [147], Loss: 0.006457846146076918


 49%|████████████████████▏                    | 197/400 [01:33<01:46,  1.91it/s]

iteration [196], Loss: 0.005854059010744095


 62%|█████████████████████████▏               | 246/400 [01:57<01:23,  1.84it/s]

iteration [245], Loss: 0.006180926691740751


 74%|██████████████████████████████▏          | 295/400 [02:22<00:58,  1.80it/s]

iteration [294], Loss: 0.006117685232311487


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:31,  1.76it/s]

iteration [343], Loss: 0.006356190890073776


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.76it/s]

iteration [392], Loss: 0.006028407718986273


100%|█████████████████████████████████████████| 400/400 [03:14<00:00,  2.06it/s]
  0%|                                           | 1/400 [00:00<05:50,  1.14it/s]

iteration [0], Loss: 0.006073535420000553


 12%|█████▎                                    | 50/400 [00:25<03:21,  1.73it/s]

iteration [49], Loss: 0.0060497792437672615


 25%|██████████▍                               | 99/400 [00:49<02:41,  1.86it/s]

iteration [98], Loss: 0.006056047044694424


 37%|███████████████▏                         | 148/400 [01:13<02:17,  1.83it/s]

iteration [147], Loss: 0.006276226136833429


 49%|████████████████████▏                    | 197/400 [01:38<01:49,  1.86it/s]

iteration [196], Loss: 0.006054833065718412


 62%|█████████████████████████▏               | 246/400 [02:02<01:25,  1.81it/s]

iteration [245], Loss: 0.0063095698133111


 74%|██████████████████████████████▏          | 295/400 [02:26<00:58,  1.78it/s]

iteration [294], Loss: 0.006096535362303257


 86%|███████████████████████████████████▎     | 344/400 [02:51<00:32,  1.73it/s]

iteration [343], Loss: 0.00611021788790822


 98%|████████████████████████████████████████▎| 393/400 [03:15<00:03,  1.85it/s]

iteration [392], Loss: 0.00588356563821435


100%|█████████████████████████████████████████| 400/400 [03:19<00:00,  2.01it/s]
  0%|                                           | 1/400 [00:00<04:45,  1.40it/s]

iteration [0], Loss: 0.0058984835632145405


 12%|█████▎                                    | 50/400 [00:24<03:13,  1.81it/s]

iteration [49], Loss: 0.006199426483362913


 25%|██████████▍                               | 99/400 [00:49<02:46,  1.80it/s]

iteration [98], Loss: 0.00625060498714447


 37%|███████████████▏                         | 148/400 [01:13<02:16,  1.85it/s]

iteration [147], Loss: 0.006282494869083166


 49%|████████████████████▏                    | 197/400 [01:38<01:52,  1.81it/s]

iteration [196], Loss: 0.006383533589541912


 62%|█████████████████████████▏               | 246/400 [02:02<01:29,  1.72it/s]

iteration [245], Loss: 0.005750419571995735


 74%|██████████████████████████████▏          | 295/400 [02:25<00:54,  1.94it/s]

iteration [294], Loss: 0.006466447841376066


 86%|███████████████████████████████████▎     | 344/400 [02:49<00:29,  1.89it/s]

iteration [343], Loss: 0.006430764216929674


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:03,  1.93it/s]

iteration [392], Loss: 0.006396578159183264


100%|█████████████████████████████████████████| 400/400 [03:16<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:43,  1.41it/s]

iteration [0], Loss: 0.005760648753494024


 12%|█████▎                                    | 50/400 [00:24<03:06,  1.88it/s]

iteration [49], Loss: 0.006148628890514374


 25%|██████████▍                               | 99/400 [00:48<02:45,  1.82it/s]

iteration [98], Loss: 0.006197987124323845


 37%|███████████████▏                         | 148/400 [01:12<02:20,  1.80it/s]

iteration [147], Loss: 0.005722958594560623


 49%|████████████████████▏                    | 197/400 [01:36<01:48,  1.87it/s]

iteration [196], Loss: 0.006060248706489801


 62%|█████████████████████████▏               | 246/400 [02:00<01:20,  1.91it/s]

iteration [245], Loss: 0.006232595071196556


 74%|██████████████████████████████▏          | 295/400 [02:24<00:55,  1.88it/s]

iteration [294], Loss: 0.005987475626170635


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:29,  1.88it/s]

iteration [343], Loss: 0.005667619872838259


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.87it/s]

iteration [392], Loss: 0.006048225332051516


100%|█████████████████████████████████████████| 400/400 [03:16<00:00,  2.04it/s]
  0%|                                           | 1/400 [00:00<04:48,  1.38it/s]

iteration [0], Loss: 0.00608826894313097


 12%|█████▎                                    | 50/400 [00:24<03:16,  1.78it/s]

iteration [49], Loss: 0.006008485797792673


 25%|██████████▍                               | 99/400 [00:48<02:51,  1.76it/s]

iteration [98], Loss: 0.005943145137280226


 37%|███████████████▏                         | 148/400 [01:12<02:13,  1.89it/s]

iteration [147], Loss: 0.006011574994772673


 49%|████████████████████▏                    | 197/400 [01:37<01:47,  1.88it/s]

iteration [196], Loss: 0.006040182895958424


 62%|█████████████████████████▏               | 246/400 [02:01<01:23,  1.85it/s]

iteration [245], Loss: 0.006293350365012884


 74%|██████████████████████████████▏          | 295/400 [02:25<00:56,  1.87it/s]

iteration [294], Loss: 0.005721669644117355


 86%|███████████████████████████████████▎     | 344/400 [02:50<00:31,  1.77it/s]

iteration [343], Loss: 0.006211589090526104


 98%|████████████████████████████████████████▎| 393/400 [03:14<00:03,  1.77it/s]

iteration [392], Loss: 0.00595530541613698


100%|█████████████████████████████████████████| 400/400 [03:18<00:00,  2.02it/s]
  0%|                                           | 1/400 [00:00<05:48,  1.15it/s]

iteration [0], Loss: 0.005911531858146191


 12%|█████▎                                    | 50/400 [00:25<03:10,  1.84it/s]

iteration [49], Loss: 0.0059252274222671986


 25%|██████████▍                               | 99/400 [00:50<02:40,  1.88it/s]

iteration [98], Loss: 0.00597411161288619


 37%|███████████████▏                         | 148/400 [01:14<02:17,  1.83it/s]

iteration [147], Loss: 0.006128179375082254


 49%|████████████████████▏                    | 197/400 [01:38<01:54,  1.78it/s]

iteration [196], Loss: 0.006153138354420662


 62%|█████████████████████████▏               | 246/400 [02:02<01:20,  1.91it/s]

iteration [245], Loss: 0.005956715904176235


 74%|██████████████████████████████▏          | 295/400 [02:25<00:54,  1.93it/s]

iteration [294], Loss: 0.006193230394273996


 86%|███████████████████████████████████▎     | 344/400 [02:49<00:32,  1.70it/s]

iteration [343], Loss: 0.005565815605223179


 98%|████████████████████████████████████████▎| 393/400 [03:13<00:04,  1.73it/s]

iteration [392], Loss: 0.005985285621136427


100%|█████████████████████████████████████████| 400/400 [03:16<00:00,  2.03it/s]
  0%|                                           | 1/400 [00:00<04:47,  1.39it/s]

iteration [0], Loss: 0.005973868537694216


 12%|█████▎                                    | 50/400 [00:24<03:04,  1.90it/s]

iteration [49], Loss: 0.00574469193816185


 25%|██████████▍                               | 99/400 [00:49<02:43,  1.84it/s]

iteration [98], Loss: 0.006255338899791241


 37%|███████████████▏                         | 148/400 [01:13<02:18,  1.83it/s]

iteration [147], Loss: 0.005865000654011965


 49%|████████████████████▏                    | 197/400 [01:37<01:47,  1.88it/s]

iteration [196], Loss: 0.00601186091080308


 62%|█████████████████████████▏               | 246/400 [02:01<01:26,  1.78it/s]

iteration [245], Loss: 0.005928500555455685


 74%|██████████████████████████████▏          | 295/400 [02:24<00:58,  1.78it/s]

iteration [294], Loss: 0.0057656546123325825


 86%|███████████████████████████████████▎     | 344/400 [02:48<00:29,  1.89it/s]

iteration [343], Loss: 0.006072668358683586


 98%|████████████████████████████████████████▎| 393/400 [03:12<00:03,  1.90it/s]

iteration [392], Loss: 0.006294204853475094


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.04it/s]
  0%|                                           | 1/400 [00:00<04:45,  1.40it/s]

iteration [0], Loss: 0.00574893644079566


 12%|█████▎                                    | 50/400 [00:24<03:04,  1.90it/s]

iteration [49], Loss: 0.006171909626573324


 25%|██████████▍                               | 99/400 [00:48<02:44,  1.83it/s]

iteration [98], Loss: 0.0061696982011199


 37%|███████████████▏                         | 148/400 [01:12<02:21,  1.78it/s]

iteration [147], Loss: 0.005702509544789791


 49%|████████████████████▏                    | 197/400 [01:36<01:51,  1.81it/s]

iteration [196], Loss: 0.00562923913821578


 62%|█████████████████████████▏               | 246/400 [01:59<01:19,  1.93it/s]

iteration [245], Loss: 0.005723692011088133


 74%|██████████████████████████████▏          | 295/400 [02:23<00:54,  1.93it/s]

iteration [294], Loss: 0.005831304006278515


 86%|███████████████████████████████████▎     | 344/400 [02:47<00:30,  1.85it/s]

iteration [343], Loss: 0.005950096528977156


 98%|████████████████████████████████████████▎| 393/400 [03:11<00:03,  1.81it/s]

iteration [392], Loss: 0.005805119872093201


100%|█████████████████████████████████████████| 400/400 [03:15<00:00,  2.05it/s]


In [24]:
torch.save(model.state_dict(), 'weights_arc2.pth')

In [18]:
weights_path = 'weights_arc2.pth'  # Provide the correct path to the saved weights file
model.load_state_dict(torch.load(weights_path))

<All keys matched successfully>

In [25]:
for index in range(len(c2ws_test)):
    print("saving: ", index)
    model.eval()
    eval_height = 200
    eval_width = 200
    # c2w_eval = c2ws_train[2]
    c2w_eval = c2ws_test[index]
    results = np.zeros((eval_height, eval_width,3))
    for i in range(eval_height):
        uv = []
        for j in range(eval_width):
            uv.append([i,j])
        ray_o, ray_d = pixel_to_ray(K, c2w_eval, uv)
        rays_o = np.repeat(ray_o, eval_height, axis=0)
        rays_d = list(ray_d)
    
        rays_o = torch.tensor(rays_o).float().to(device)
        rays_d = torch.tensor(rays_d).float().to(device)
        points = sample_along_rays(rays_o, rays_d)
        points = points.to(device)
        # rays_d = torch.tensor(rays_d).to(device)
        rays_d = torch.unsqueeze(rays_d,1)
        rays_d = rays_d.repeat(1,n_samples,1)
        rays_d = rays_d.view(-1,3)
        
        rgbs, sigmas = model(points, rays_d)
        rgbs = rgbs.to(device)
        sigmas = sigmas.to(device)
        # sigams = model.foward_dentisy()
        sigmas = sigmas.view(-1, n_samples, 1)
        rgbs = rgbs.view(-1, n_samples, 3)
        
        rendered_colors = volrend(sigmas, rgbs, step_size).detach().cpu().numpy()
    
        # print(rendered_colors.shape)
        results[i,:,:] = rendered_colors
    save_file_path = 'test_results/%s.png'%index  # Change the file extension as needed (e.g., .jpg, .png, .bmp, etc.)

    # Save the NumPy array as an image using skimage
    skio.imsave(save_file_path, results)

saving:  0


  rays_d = torch.tensor(rays_d).float().to(device)


saving:  1




saving:  2




saving:  3




saving:  4




saving:  5




saving:  6




saving:  7




saving:  8




saving:  9




saving:  10




saving:  11




saving:  12




saving:  13




saving:  14




saving:  15




saving:  16




saving:  17




saving:  18




saving:  19




saving:  20




saving:  21




saving:  22




saving:  23




saving:  24




saving:  25




saving:  26




saving:  27




saving:  28




saving:  29




saving:  30




saving:  31




saving:  32




saving:  33




saving:  34




saving:  35




saving:  36




saving:  37




saving:  38




saving:  39




saving:  40




saving:  41




saving:  42




saving:  43




saving:  44




saving:  45




saving:  46




saving:  47




saving:  48




saving:  49




saving:  50




saving:  51




saving:  52




saving:  53




saving:  54




saving:  55




saving:  56




saving:  57




saving:  58




saving:  59


