In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import json
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T
import plotly.graph_objects as go
from tqdm.notebook import tqdm
from positional_encodings.torch_encodings import PositionalEncoding2D


In [None]:
def parse_bbox(file_path):
    """Read a bounding-box file and return the numbers on its first line.

    Args:
        file_path: path to a whitespace-separated text file.

    Returns:
        list of floats parsed from the first line (any further lines are ignored).
    """
    with open(file_path, 'r') as handle:
        first_line = handle.readlines()[0]
    return list(map(float, first_line.split()))

def decompose_pose(pose_matrix):
    """Split a 4x4 (or 3x4) pose matrix into its rotation block and translation column.

    Returns:
        (R, t): ``pose_matrix[:3, :3]`` and ``pose_matrix[:3, 3]`` (views, not copies).
    """
    upper_rows = pose_matrix[:3]
    return upper_rows[:, :3], upper_rows[:, 3]

def parse_intrinsics(file_path):
    """Read pinhole intrinsics from a text file.

    The first line must start with ``f cx cy`` (one focal length, reused for both
    axes) and the last line must start with integer ``width height``.

    Returns:
        dict with keys ``fx``, ``fy``, ``cx``, ``cy`` (floats) and ``width``, ``height`` (ints).
    """
    with open(file_path, 'r') as handle:
        rows = handle.readlines()

    head = rows[0].split()
    focal = float(head[0])  # Assuming symmetrical focal length (fx == fy)
    principal_x = float(head[1])
    principal_y = float(head[2])

    tail = rows[-1].split()
    return dict(
        fx=focal,
        fy=focal,
        cx=principal_x,
        cy=principal_y,
        width=int(tail[0]),
        height=int(tail[1]),
    )



In [None]:
class CameraPoseDataset(Dataset):
    """Pairs PNG images with camera-pose text files.

    Images (``*.png`` in ``image_dir``) and poses (``*.txt`` in ``pose_dir``) are
    matched by position after sorting each directory's paths, so both directories
    must contain the same number of files whose names sort into the same order.
    """

    def __init__(self, image_dir, pose_dir):
        self.image_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith('.png'))
        self.pose_paths = sorted(os.path.join(pose_dir, f) for f in os.listdir(pose_dir) if f.endswith('.txt'))
        # Pairing is purely positional: a count mismatch would silently attach poses to the wrong images.
        if len(self.image_paths) != len(self.pose_paths):
            raise ValueError(
                f"Found {len(self.image_paths)} .png images in {image_dir!r} but "
                f"{len(self.pose_paths)} .txt poses in {pose_dir!r}"
            )

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _read_pose(path):
        """Parse a whitespace-separated pose matrix into a float32 array; blank lines are skipped."""
        with open(path, 'r') as f:
            rows = [list(map(float, line.split())) for line in f if line.strip()]
        return np.array(rows, dtype=np.float32)

    def __getitem__(self, idx):
        """Return ``(image, pose)``.

        image: float32 tensor (C, H, W) with at most 3 channels, pixel values scaled by 1/255.
        pose: float32 ndarray parsed from the matching pose file.
        """
        # Context manager releases the file handle opened by Image.open deterministically.
        with Image.open(self.image_paths[idx]) as img:
            pixels = np.array(img).astype(np.float32) / 255.0
        if pixels.ndim == 2:  # single-channel image: add a channel axis so the permute below works
            pixels = pixels[..., None]
        # HxWxC -> CxHxW, keeping at most the first three channels (an alpha channel is dropped)
        image = torch.from_numpy(pixels).permute(2, 0, 1)[0:3]

        pose = self._read_pose(self.pose_paths[idx])
        return image, pose

# Build the Lego dataset (rgb/ and pose/ sub-directories) and a shuffled loader of 4 views per batch.
LEGO_DIR = '/mnt/raid/C1_ML_Analysis/nerf_data/Synthetic_NeRF/Lego'
ds = CameraPoseDataset(os.path.join(LEGO_DIR, 'rgb'), os.path.join(LEGO_DIR, 'pose'))
dl = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)

In [None]:
intrinsics = parse_intrinsics('/mnt/raid/C1_ML_Analysis/nerf_data/Synthetic_NeRF/Lego/intrinsics.txt')
bbox = parse_bbox('/mnt/raid/C1_ML_Analysis/nerf_data/Synthetic_NeRF/Lego/bbox.txt')

# NOTE(review): every parsed intrinsic is replaced here by hard-coded values for a
# 128x128 image (the same numbers as Params), so the file's values are discarded.
# Presumably the focal is the original one rescaled to 128 px -- confirm, and check
# that the images loaded by CameraPoseDataset really are 128x128.
intrinsics.update(fx=177.777, fy=177.777, cx=64, cy=64, width=128, height=128)

print(intrinsics, bbox)

In [None]:

# Grab one shuffled batch to inspect shapes: images (B, C, H, W), poses (B, rows, cols).
batch = next(iter(dl))
images, poses = batch
print(images.shape, poses.shape)

In [None]:
def visualize_camera_positions_plotly(poses, bbox):
    """Plot camera centres and the scene bounding box as a 3D Plotly figure.

    Args:
        poses: iterable of pose tensors; the translation column ``pose[:3, 3]``
            of each is plotted as the camera position.
        bbox: sequence of exactly 7 numbers
            ``(x_min, y_min, z_min, x_max, y_max, z_max, extra)``; the 7th is ignored.

    Returns:
        plotly.graph_objects.Figure with the camera markers, 12 box-edge line
        traces and a corner-marker trace.
    """
    # Camera centre = translation part of each pose.
    centres = torch.stack([torch.tensor(pose[:3, 3].cpu().numpy()) for pose in poses]).numpy()

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=centres[:, 0],
        y=centres[:, 1],
        z=centres[:, 2],
        mode='markers',
        marker=dict(size=5, color='blue', opacity=0.8),
        name="Camera Positions"
    ))

    x_min, y_min, z_min, x_max, y_max, z_max, _ = bbox

    # Corners 0-3: bottom face (z_min); corners 4-7: top face (z_max), same winding.
    corners = np.array([
        [x_min, y_min, z_min], [x_max, y_min, z_min], [x_max, y_max, z_min], [x_min, y_max, z_min],
        [x_min, y_min, z_max], [x_max, y_min, z_max], [x_max, y_max, z_max], [x_min, y_max, z_max],
    ])

    bottom = [(0, 1), (1, 2), (2, 3), (3, 0)]
    top = [(4, 5), (5, 6), (6, 7), (7, 4)]
    vertical = [(0, 4), (1, 5), (2, 6), (3, 7)]

    for a, b in bottom + top + vertical:
        fig.add_trace(go.Scatter3d(
            x=[corners[a, 0], corners[b, 0]],
            y=[corners[a, 1], corners[b, 1]],
            z=[corners[a, 2], corners[b, 2]],
            mode='lines',
            line=dict(color='green', width=3),
            showlegend=False
        ))

    fig.add_trace(go.Scatter3d(
        x=corners[:, 0],
        y=corners[:, 1],
        z=corners[:, 2],
        mode='markers',
        marker=dict(size=5, color='green'),
        name="Bounding Box Corners"
    ))

    return fig

def visualize_rays_plotly(ray_origins, ray_directions, poses, bbox, num_rays=50, ray_length=1.0, pts=None):
    """
    Visualizes a random subset of camera rays (and optionally their sample points) using Plotly.

    Args:
        ray_origins: Tensor of shape [batch_size, num_rays_total, 3]
        ray_directions: Tensor of shape [batch_size, num_rays_total, 3]
        poses: Tensor of shape [batch_size, 4, 4] containing camera poses.
        bbox: Bounding box forwarded to visualize_camera_positions_plotly (7 values).
        num_rays: Number of rays to sample per camera (capped at num_rays_total).
        ray_length: Length of the rays in visualization
        pts: Optional tensor indexed as [batch_size, num_rays_total, ...] whose
            trailing dimension is 3 (e.g. per-ray sample points); the points of
            the drawn rays are plotted.

    Returns:
        Interactive 3D Plotly figure
    """
    batch_size, num_rays_total, _ = ray_origins.shape

    fig = visualize_camera_positions_plotly(poses, bbox)

    for cam_idx in range(batch_size):
        # randperm(...)[:num_rays] yields min(num_rays, num_rays_total) distinct indices,
        # so looping over the sampled rays (not range(num_rays)) can never overrun.
        sampled_indices = torch.randperm(num_rays_total)[:num_rays]
        origins = ray_origins[cam_idx, sampled_indices].detach().cpu().numpy()
        directions = ray_directions[cam_idx, sampled_indices].detach().cpu().numpy()
        endpoints = origins + ray_length * directions

        # One trace per camera: None entries break the polyline between consecutive rays.
        xs, ys, zs = [], [], []
        for start, end in zip(origins, endpoints):
            xs += [start[0], end[0], None]
            ys += [start[1], end[1], None]
            zs += [start[2], end[2], None]
        fig.add_trace(go.Scatter3d(
            x=xs,
            y=ys,
            z=zs,
            mode='lines',
            line=dict(color='red', width=2),
            name="Rays",
            legendgroup="rays",
            showlegend=cam_idx == 0  # single legend entry for all cameras
        ))

        if pts is not None:
            p = pts[cam_idx, sampled_indices].detach().cpu().numpy().reshape(-1, 3)
            fig.add_trace(go.Scatter3d(
                x=p[:, 0],
                y=p[:, 1],
                z=p[:, 2],
                mode='markers',
                marker=dict(size=2, color='blue'),
                name="Sample points",
                legendgroup="samples",
                showlegend=cam_idx == 0
            ))

    fig.update_layout(
        title="3D Visualization of Camera Rays & Planes",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
            aspectmode="auto"
        )
    )

    return fig


In [None]:
class Params:
    """Fixed pinhole-camera hyperparameters (pixels) for a 128x128 image."""

    def __init__(self):
        # Focal lengths
        self.fx = self.fy = 177.777
        # Principal point (image centre)
        self.cx = self.cy = 64
        # Image size
        self.width = self.height = 128

class Ray():
    """Generates one camera ray per pixel for a batch of camera-to-world poses.

    Intrinsics are read from ``self.hparams`` (a ``Params`` instance by default):
    ``fx``, ``width``, ``height`` and, when present, ``fy``, ``cx``, ``cy``.
    """

    def __init__(self):
        self.hparams = Params()
        # Kept for backward compatibility; generate_rays now follows c2w's device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def generate_rays(self, c2w):
        """
        Generate rays for a batch of camera poses.

        Args:
            c2w: Camera-to-world transforms, shape (B, 4, 4) or (B, 3, 4).
                Rotation is ``c2w[:, :3, :3]``, camera origin ``c2w[:, :3, -1]``.

        Returns:
            rays_o: Ray origins (B, H*W, 3) (an expanded view of the origins).
            rays_d: Unnormalised ray directions (B, H*W, 3), pixels in row-major order.
        """
        # Build the pixel grid where the poses live, so CPU poses work on a CUDA machine (and vice versa).
        device, dtype = c2w.device, c2w.dtype
        hp = self.hparams
        W, H = hp.width, hp.height
        fx = hp.fx
        fy = getattr(hp, 'fy', fx)
        cx = getattr(hp, 'cx', W * .5)
        cy = getattr(hp, 'cy', H * .5)

        i, j = torch.meshgrid(
            torch.arange(W, dtype=dtype, device=device),
            torch.arange(H, dtype=dtype, device=device),
            indexing='xy'
        )
        # Camera looks down -z with +y up (image rows grow downwards, hence the minus sign).
        dirs = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1)  # (H, W, 3)

        # (1, H*W, 3) @ (B, 3, 3) broadcasts over the batch without materialising B copies of dirs.
        rays_d = torch.matmul(dirs.reshape(1, -1, 3), c2w[:, :3, :3].transpose(-1, -2))
        rays_o = c2w[:, :3, -1].unsqueeze(1).expand(rays_d.shape)

        return rays_o, rays_d

# Generate rays for the sample batch and draw a few of them against the camera positions / bbox.
# Runs on GPU when available, otherwise on CPU (no hard-coded .cuda()).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ray_origins, ray_directions = Ray().generate_rays(poses.to(DEVICE))

fig = visualize_rays_plotly(ray_origins, ray_directions, poses, bbox, num_rays=20, ray_length=1.0)
fig.show()


In [None]:
class NeRF(nn.Module):
    """NeRF-style MLP mapping (point, view direction) to (rgb, density).

    Points and view directions are lifted with Fourier features
    ``[sin(2^k * pi * x), cos(2^k * pi * x)]`` of their coordinate values.
    Density is predicted from the point branch only, so it does not depend on
    the viewing direction. Colour uses both branches.

    Args:
        input_dim: coordinate dimensionality of points and directions.
        pos_dim: size of the point encoding; must be a positive multiple of
            ``2 * input_dim`` (``pos_dim // (2 * input_dim)`` octaves, 10 by default).
        view_dim: size of the direction encoding; same divisibility rule
            (4 octaves by default).
        hidden: width of the hidden layers.

    Raises:
        ValueError: if ``pos_dim`` or ``view_dim`` is not a positive multiple of ``2 * input_dim``.
    """

    def __init__(self, input_dim=3, pos_dim=60, view_dim=24, hidden=256) -> None:
        super().__init__()

        for arg_name, dim in (("pos_dim", pos_dim), ("view_dim", view_dim)):
            if dim <= 0 or dim % (2 * input_dim):
                raise ValueError(
                    f"{arg_name}={dim} must be a positive multiple of 2*input_dim={2 * input_dim}"
                )
        self.input_dim = input_dim
        self.n_freqs_pos = pos_dim // (2 * input_dim)
        self.n_freqs_view = view_dim // (2 * input_dim)

        self.block1 = nn.Sequential(
            nn.Linear(pos_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU())

        # Skip connection: the point encoding is re-injected after block1.
        self.block2 = nn.Sequential(
            nn.Linear(pos_dim + hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden))  # No activation: feature vector for the heads

        # Softplus keeps density non-negative.
        self.final_sigma = nn.Sequential(nn.Linear(hidden, 1), nn.Softplus())

        self.final_rgb = nn.Sequential(
            nn.Linear(view_dim + hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 3), nn.Sigmoid())

    @staticmethod
    def _fourier_features(x, n_freqs):
        """Encode (..., D) -> (..., 2 * n_freqs * D) with sin/cos at frequencies 2^k * pi."""
        freqs = (2.0 ** torch.arange(n_freqs, device=x.device, dtype=x.dtype)) * torch.pi
        args = x.unsqueeze(-2) * freqs.unsqueeze(-1)  # (..., n_freqs, D)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1).flatten(-2)

    def forward(self, x_p, x_v):
        """
        Args:
            x_p: points, any shape (..., input_dim).
            x_v: view directions, same leading shape as x_p, (..., input_dim).

        Returns:
            rgb: (..., 3) in [0, 1].
            sigma: (..., 1), non-negative.
        """
        enc_p = self._fourier_features(x_p, self.n_freqs_pos)

        h = self.block1(enc_p)
        h = self.block2(torch.cat([h, enc_p], dim=-1))

        sigma = self.final_sigma(h)

        enc_v = self._fourier_features(x_v, self.n_freqs_view)
        rgb = self.final_rgb(torch.cat([h, enc_v], dim=-1))

        return rgb, sigma

In [None]:
def compute_accumulated_transmittance(alphas):
    """Exclusive cumulative product along dim 1 of a 2D tensor.

    Row ``r`` becomes ``[1, a[r,0], a[r,0]*a[r,1], ...]`` (last product dropped),
    i.e. the transmittance before each sample when ``alphas`` holds per-sample
    survival probabilities ``1 - alpha``.
    """
    running = alphas.cumprod(dim=1)
    leading_ones = torch.ones((running.shape[0], 1), device=alphas.device)
    return torch.cat([leading_ones, running[:, :-1]], dim=-1)

# def render_rays(network_fn, rays_o, rays_d, near, far, nb_bins, device, rand=False):

#     # Sampling
#     z_vals = torch.linspace(near, far, steps=nb_bins, device=device)

#     if rand:
#         z_vals += torch.rand(*z_vals.shape[:-1], nb_bins, device=rays_o.device) * (far - near) / nb_bins

#     pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

#     # Normalize view directions
#     view_dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
#     view_dirs = view_dirs[..., None, :].expand(pts.shape)
    
#     rgb, sigma = network_fn(pts, view_dirs)
    
#     # Improved volume rendering
#     dists = z_vals[..., 1:] - z_vals[..., :-1]  # Shape: [batch, N_samples-1]
#     dists = torch.cat([dists, torch.tensor([1e10], device=device)], -1)
    
#     # No need to manually expand dists as broadcasting will handle it
#     alpha = 1. - torch.exp(-sigma.squeeze(-1) * dists)  # Shape: [batch, N_samples]
#     alpha = alpha.unsqueeze(-1)  # Shape: [batch, N_samples, 1]
    
#     # Computing transmittance
#     ones_shape = (alpha.shape[0], alpha.shape[1], 1, 1)
    
#     T = torch.cumprod(
#         torch.cat([
#             torch.ones(ones_shape, device=device),
#             1. - alpha + 1e-10
#         ], dim=2),
#         dim=2
#     )[:,:,:-1]  # Shape: [batch, N_samples, 1]

#     weights = alpha * T  # Shape: [batch, N_samples, 1]

#     # Compute final colors and depths
#     rgb_map = torch.sum(weights * rgb, dim=2)  # Sum along sample dimension
#     depth_map = torch.sum(weights.squeeze(-1) * z_vals, dim=-1)  # Shape: [batch]
#     acc_map = torch.sum(weights.squeeze(-1), dim=-1)  # Shape: [batch]

#     return rgb_map, depth_map, acc_map


def render_rays(nerf_model, rays_o, rays_d, rand=False, device='cpu', near=1.0, far=6.0, nb_bins=64):
    """Volume-render rays through ``nerf_model``.

    Args:
        nerf_model: callable ``nerf_model(pts, dirs) -> (rgb, sigma)`` with
            ``rgb`` (..., nb_bins, 3) and ``sigma`` (..., nb_bins, 1).
        rays_o, rays_d: ray origins / directions, shape (..., 3). Directions need
            not be unit length; samples are placed at ``o + t * d``.
        rand: stratified sampling. Each ray gets its own jitter, and every sample
            stays inside ``[near, far]``.
        device: kept for backward compatibility; every tensor is created on
            ``rays_o.device``.
        near, far: range of ``t`` (z_vals) to sample.
        nb_bins: samples per ray.

    Returns:
        rgb_map (..., 3), depth_map (..., ) as expected ``t``, acc_map (..., ),
        and the sample points ``pts`` (..., nb_bins, 3).
    """
    device = rays_o.device
    batch_shape = rays_o.shape[:-1]

    # Sampling: z_vals is t in the paper, one row per ray
    t = torch.linspace(near, far, steps=nb_bins, device=device, dtype=rays_o.dtype)
    z_vals = t.expand(*batch_shape, nb_bins)

    if rand:
        # Stratified sampling: one uniform draw per ray and bin, between neighbouring bin midpoints.
        mids = 0.5 * (t[1:] + t[:-1])
        upper = torch.cat([mids, t[-1:]])
        lower = torch.cat([t[:1], mids])
        u = torch.rand(*batch_shape, nb_bins, device=device, dtype=rays_o.dtype)
        z_vals = lower + (upper - lower) * u

    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    # The network sees unit view directions regardless of the ray parameterisation.
    view_dirs = nn.functional.normalize(rays_d, dim=-1)
    dirs = view_dirs.unsqueeze(-2).expand(pts.shape)

    rgb, sigma = nerf_model(pts, dirs)

    # Spacing between consecutive samples in world units: dt * |d| (the last interval is "infinite").
    delta = torch.diff(z_vals, dim=-1)
    delta = torch.cat([delta, torch.full_like(z_vals[..., :1], 1e10)], dim=-1)
    delta = delta * rays_d.norm(dim=-1, keepdim=True)

    alpha = 1.0 - torch.exp(-sigma.squeeze(-1) * delta)
    # Exclusive cumulative product = transmittance reaching each sample.
    cumprod = torch.cumprod(1.0 - alpha + 1e-10, dim=-1)
    exclusive_cumprod = torch.cat([torch.ones_like(alpha[..., :1]), cumprod[..., :-1]], dim=-1)

    weights = alpha * exclusive_cumprod

    rgb_map = (weights.unsqueeze(-1) * rgb).sum(dim=-2)
    depth_map = (weights * z_vals).sum(dim=-1)
    acc_map = weights.sum(dim=-1)

    return rgb_map, depth_map, acc_map, pts

# Render the sample batch with an untrained model (inference only) and overlay the sample points.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nerf_model = NeRF().to(DEVICE)
# No autograd graph needed here: every ray x 64 samples goes through the MLP at once.
with torch.no_grad():
    rgb, depth, acc, pts = render_rays(nerf_model, ray_origins.to(DEVICE), ray_directions.to(DEVICE), rand=False, device=DEVICE, near=1.0, far=6.0, nb_bins=64)

fig = visualize_rays_plotly(ray_origins, ray_directions, poses, bbox, num_rays=20, ray_length=1.0, pts=pts)
fig.show()

In [None]:
def _rays_from_intrinsics(c2w, intrinsics):
    """Build one ray per pixel for a batch of camera-to-world poses from an intrinsics dict.

    Args:
        c2w: tensor (B, 4, 4) or (B, 3, 4); rotation ``c2w[:, :3, :3]``, origin ``c2w[:, :3, 3]``.
        intrinsics: dict with ``fx``, ``width``, ``height`` and optionally ``fy``, ``cx``,
            ``cy`` (defaulting to ``fx``, ``width / 2``, ``height / 2``).

    Returns:
        (rays_o, rays_d), each (B, height * width, 3), pixels in row-major order.
    """
    W, H = intrinsics['width'], intrinsics['height']
    fx = intrinsics['fx']
    fy = intrinsics.get('fy', fx)
    cx = intrinsics.get('cx', W * .5)
    cy = intrinsics.get('cy', H * .5)
    i, j = torch.meshgrid(
        torch.arange(W, dtype=c2w.dtype, device=c2w.device),
        torch.arange(H, dtype=c2w.dtype, device=c2w.device),
        indexing='xy'
    )
    # Same camera convention as Ray.generate_rays: looking down -z, +y up.
    dirs = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], dim=-1)
    rays_d = torch.matmul(dirs.reshape(1, -1, 3), c2w[:, :3, :3].transpose(-1, -2))
    rays_o = c2w[:, :3, 3].unsqueeze(1).expand_as(rays_d)
    return rays_o, rays_d


def train(nerf_model, optimizer, scheduler, data_loader, intrinsics, device='cpu', near=0, far=1, nb_epochs=int(1e5), nb_bins=64, NSamples=128*128):
    """Fit ``nerf_model`` with a sum-of-squared-errors photometric loss.

    Each step draws ``NSamples`` pixel indices (with replacement, shared by all
    images of the batch), renders those rays with ``render_rays`` and compares
    them with the ground-truth pixels. ``scheduler.step()`` runs once per epoch.

    Args:
        nerf_model: model passed to ``render_rays``.
        optimizer, scheduler: torch optimizer and LR scheduler.
        data_loader: yields ``(images, poses)`` with images (B, C, H, W).
        intrinsics: dict with ``fx``, ``width``, ``height`` (optionally ``fy``, ``cx``, ``cy``).
        device: device the batch is moved to.
        near, far, nb_bins: sampling settings forwarded to ``render_rays``.
        nb_epochs: passes over ``data_loader``.
        NSamples: rays sampled per image per step.

    Returns:
        list of per-step loss values.

    Raises:
        ValueError: if the images' (H, W) differs from ``(intrinsics['height'], intrinsics['width'])``.
    """
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        for images, poses in data_loader:
            images = images.to(device)
            poses = poses.to(device)

            n_img, n_ch, height, width = images.shape
            # Rays live on the intrinsics pixel grid; pixel indices only line up if the image matches it.
            if (height, width) != (intrinsics['height'], intrinsics['width']):
                raise ValueError(
                    f"images are {height}x{width} but intrinsics describe "
                    f"{intrinsics['height']}x{intrinsics['width']}"
                )

            ray_origins, ray_directions = _rays_from_intrinsics(poses, intrinsics)

            idx_samples = torch.randint(0, ray_origins.shape[1], (NSamples,), device=device)
            ray_origins = ray_origins[:, idx_samples]
            ray_directions = ray_directions[:, idx_samples]

            target = images.reshape(n_img, n_ch, -1)[:, :, idx_samples]  # (B, C, NSamples)

            # NOTE: B * NSamples * nb_bins points go through the MLP in one pass; lower NSamples if memory is tight.
            # render_rays returns (rgb, depth, acc, pts); only rgb is supervised.
            rgb = render_rays(nerf_model, ray_origins, ray_directions, device=device, near=near, far=far, nb_bins=nb_bins)[0]

            loss = ((target - rgb.permute(0, 2, 1)) ** 2).sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss.append(loss.item())
        scheduler.step()

    return training_loss


# Model, optimizer and a step LR schedule (halve every 1000 scheduler steps, i.e. epochs in train()).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nerf_model = NeRF().to(DEVICE)
optimizer = torch.optim.Adam(nerf_model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

# train(nerf_model=nerf_model, optimizer=optimizer, scheduler=scheduler, data_loader=dl, intrinsics=intrinsics, device=DEVICE, near=0.8, far=4, nb_epochs=10, nb_bins=64, NSamples=128*128)

In [None]:

@torch.no_grad()
def test(hn, hf, dataset, chunk_size=10, img_index=0, nb_bins=192, H=400, W=400, model=None, device=None, out_dir='novel_views'):
    """Render one H x W image from a flat table of rays and save it as a PNG.

    Args:
        hn, hf: near / far sampling bounds, forwarded to ``render_rays`` as ``near`` / ``far``.
        dataset: 2D tensor of rays, one per row: origins in columns 0:3, directions
            in 3:6. Image ``img_index`` occupies rows ``[img_index*H*W, (img_index+1)*H*W)``
            in row-major pixel order.
        chunk_size: number of image rows rendered per forward pass.
        img_index: which image of ``dataset`` to render (also names the output file).
        nb_bins: samples per ray.
        H, W: image height and width in pixels.
        model: model to render with, called as ``model(pts, dirs)``. Required.
        device: where rays are moved; defaults to the device of ``model``'s first
            parameter (CPU if it has none).
        out_dir: directory receiving ``img_{img_index}.png``; created if missing.

    Raises:
        ValueError: if ``model`` is not given.
    """
    if model is None:
        raise ValueError("test() needs a model to render with; pass model=...")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    start = img_index * H * W
    ray_origins = dataset[start:start + H * W, :3]
    ray_directions = dataset[start:start + H * W, 3:6]

    rows_per_chunk = W * chunk_size
    pixels = []
    for lo in range(0, H * W, rows_per_chunk):
        o = ray_origins[lo:lo + rows_per_chunk].to(device)
        d = ray_directions[lo:lo + rows_per_chunk].to(device)
        # Add a leading batch axis (1, n, 3), the layout render_rays is called with elsewhere in
        # this notebook. render_rays returns (rgb, depth, acc, pts); keep rgb only.
        rgb = render_rays(model, o[None], d[None], device=device, near=hn, far=hf, nb_bins=nb_bins)[0]
        pixels.append(rgb[0])
    img = torch.cat(pixels).clamp(0, 1).cpu().numpy().reshape(H, W, 3)

    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray((img * 255).round().astype(np.uint8)).save(os.path.join(out_dir, f'img_{img_index}.png'))

In [None]:
# Shape probe for positional_encodings' PositionalEncoding2D on a 4D input
# (presumably (batch, x, y, channels) -- see the package docs).
pe2 = PositionalEncoding2D(128)
pe2_input = torch.rand(4, 1024, 64, 128)
pe2(pe2_input).shape

In [None]:
def encoding(x, L=10):
    """Fourier-feature encoding (no pi factor) along the last axis.

    Returns ``cat([x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)], -1)``,
    i.e. ``D * (2 * L + 1)`` features for a (..., D) input.
    """
    features = [x]
    for octave in range(L):
        scaled = 2. ** octave * x
        features.extend((torch.sin(scaled), torch.cos(scaled)))
    return torch.cat(features, dim=-1)

class NeRF_v2(nn.Module):
    """8-layer NeRF MLP taking a single (..., 6) input of [position, view direction].

    Positions are encoded with ``encoding(..., L=10)`` and directions with
    ``encoding(..., L=4)``. The position encoding is re-injected after the 5th
    linear layer. Output is ``cat([sigma, rgb], -1)`` with shape (..., 4).
    """

    def __init__(self, pos_enc_dim = 63, view_enc_dim = 27, hidden = 256) -> None:
        super().__init__()

        def hidden_stack(depth):
            # `depth` x (Linear(hidden, hidden) -> ReLU), created in order.
            layers = []
            for _ in range(depth):
                layers += [nn.Linear(hidden, hidden), nn.ReLU()]
            return nn.Sequential(*layers)

        self.linear1 = nn.Sequential(nn.Linear(pos_enc_dim, hidden), nn.ReLU())
        self.pre_skip_linear = hidden_stack(4)
        self.linear_skip = nn.Sequential(nn.Linear(hidden + pos_enc_dim, hidden), nn.ReLU())
        self.post_skip_linear = hidden_stack(2)

        self.density_layer = nn.Sequential(nn.Linear(hidden, 1), nn.ReLU())
        self.linear2 = nn.Linear(hidden, hidden)

        self.color_linear1 = nn.Sequential(nn.Linear(hidden + view_enc_dim, hidden // 2), nn.ReLU())
        self.color_linear2 = nn.Sequential(nn.Linear(hidden // 2, 3), nn.Sigmoid())

        # Not used in forward; kept so the module structure is unchanged.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        # Split and encode: first 3 channels are positions, the rest view directions.
        pos_enc = encoding(input[..., :3], L=10)
        view_enc = encoding(input[..., 3:], L=4)

        trunk = self.pre_skip_linear(self.linear1(pos_enc))
        trunk = self.linear_skip(torch.cat([trunk, pos_enc], dim=-1))  # skip connection
        trunk = self.post_skip_linear(trunk)

        sigma = self.density_layer(trunk)

        # Incorporate view encoding, then predict colour.
        features = torch.cat([self.linear2(trunk), view_enc], dim=-1)
        rgb = self.color_linear2(self.color_linear1(features))

        return torch.cat([sigma, rgb], dim=-1)

In [None]:
# Smoke test: 64 random (position, direction) pairs -> expect (1, 64, 4) = [sigma, r, g, b].
model = NeRF_v2()
sample_out = model(torch.rand(1, 64, 6))
sample_out.shape

In [None]:
def encoding(x, L=10):
    """Vectorised NeRF positional encoding with pi-scaled octaves.

    For a (..., D) input returns, concatenated on the last axis,
    ``[x, sin(2^0 pi x), cos(2^0 pi x), ..., sin(2^(L-1) pi x), cos(2^(L-1) pi x)]``
    where every sin/cos term spans all D features. This is the same layout the
    loop-based ``encoding_v2`` produces. The output has ``D * (2 * L + 1)`` features.
    """
    # Frequency factors: shape (L,)
    freqs = (2 ** torch.arange(L, device=x.device, dtype=x.dtype)) * torch.pi
    # Arguments for sin/cos: (..., L, D)
    args = x.unsqueeze(-2) * freqs.unsqueeze(-1)
    # Stack to (..., L, 2, D): per octave all sines, then all cosines. The explicit size
    # (rather than -1) keeps L=0 working, where the tensor is empty.
    sin_cos = torch.stack((torch.sin(args), torch.cos(args)), dim=-2)
    sin_cos = sin_cos.reshape(*x.shape[:-1], 2 * L * x.shape[-1])
    return torch.cat((x, sin_cos), dim=-1)


# Random (4, 4, 4, 3) coordinates, encoded with the vectorised encoding (default L=10).
x = torch.rand((4, 4, 4, 3))
x_1 = encoding(x, L=10)

def encoding_v2(x, L=10):
    """Loop-based reference NeRF positional encoding with pi-scaled octaves.

    Returns ``cat([x, sin(2^0 pi x), cos(2^0 pi x), ..., sin(2^(L-1) pi x), cos(2^(L-1) pi x)], -1)``
    for a (..., D) input, i.e. ``D * (2 * L + 1)`` features. Works for ``L = 0``
    (returns a copy of ``x``).
    """
    res = [x]
    for i in range(L):
        for fn in (torch.sin, torch.cos):
            res.append(fn(2 ** i * torch.pi * x))
    return torch.cat(res, dim=-1)

x_2 = encoding_v2(x)

# Largest element-wise gap between the vectorised and loop encodings (~0 when their layouts agree).
torch.max(torch.abs(x_1 - x_2))


In [None]:
# data = np.load('/mnt/raid/C1_ML_Analysis/nerf_data/tiny_nerf_data.npz')
# images = data['images']
# poses = data['poses']
# focal = data['focal']
# H, W = images.shape[1:3]
# print(images.shape, poses.shape, focal)