# NeRF with GCN
## Newly added file:
1. models/graphConv.py: contains the GCN layer, plus the functions that build the adjacency (adj) matrix and normalize it.

In [2]:
class GraphConvolution(Module):
    """
    Dense graph-convolution layer computing ``adj @ (input @ weight) + bias``.

    Modeled on the simple GCN layer of https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        """Create the (in_features, out_features) weight and optional bias."""
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Register the name so that ``self.bias`` exists and is None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight, then bias, uniformly in +-1/sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project the node features, then aggregate them through ``adj``.

        Both products use ``torch.mm``, so ``input`` and ``adj`` must be
        2-D matrices.
        """
        projected = torch.mm(input, self.weight)
        # A dense product is used here; for a large sparse adjacency the
        # alternative would be torch.spmm(adj, projected).
        aggregated = torch.mm(adj, projected)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__,
                                      self.in_features,
                                      self.out_features)

### 1. Build the adj matrix from the pairwise Euclidean distances between ray origins: an entry is set to 1 if the distance is less than or equal to the threshold, and to 0 if it is larger.
### 2. Normalize the adj matrix based on the degree of every node.

In [None]:
def design_adj_matrix(rays_o, threshold):
    """
    Build a dense 0/1 adjacency matrix from pairwise Euclidean distances.

    input: rays_o : 2-D tensor with one point per row (N_rays rows).
                    The original notes say these are 3D coordinates whose
                    Z is fixed, so they are effectively 2D.
           threshold : two points are neighbors when their distance
                    is <= threshold.

    Output:
       return : adjacency matrix (N_rays, N_rays) in the default float
                dtype, on the same device as rays_o.  For a non-negative
                threshold the diagonal is 1 (each point is at distance 0
                from itself), i.e. the result is A + I, not A.
    """
    pairwise_dist = torch.cdist(rays_o, rays_o, p=2)
    # Converting the boolean mask directly avoids allocating two extra dense
    # (N_rays, N_rays) tensors and keeps the result on rays_o's device
    # instead of forcing it onto the current CUDA device.
    return (pairwise_dist <= threshold).to(dtype=torch.get_default_dtype())

def normalize(mx):
    """Row-normalize a dense 2-D tensor so that every non-zero row sums to 1.

    Each row is divided by its own sum.  Rows whose sum is 0 would give an
    infinite scale factor; that factor is replaced by 0, so such rows come
    back as all zeros.

    The computation stays in torch on the device of ``mx`` (no CPU / NumPy /
    SciPy round trip) and the result has the same shape, dtype and device as
    the input.
    """
    row_sum = mx.sum(dim=1, keepdim=True)
    inv_row_sum = row_sum.pow(-1)
    # 1 / 0 -> inf: zero out the scale of empty rows instead of propagating it.
    inv_row_sum = torch.where(torch.isinf(inv_row_sum),
                              torch.zeros_like(inv_row_sum),
                              inv_row_sum)
    # Broadcasting over columns is equivalent to diag(1 / row_sum) @ mx.
    return inv_row_sum * mx


def adj_normalized(adj):
    """
    Normalize the adjacency matrix so that nodes with many neighbors do not
    dominate the aggregated information.

    Input: adj: adj matrix with (N_rays, N_rays), A+I

      The work is delegated to normalize(), which scales every row by the
      inverse of its degree:
      degree_matrix^(-1) * A, degree_ii = sum_j(Aij)

    Output:
       co-adj: (N_rays, N_rays)
    """
    # Note: the symmetric variant degree^(-1/2) * A * degree^(-1/2) is NOT
    # what is applied here; only the row normalization is used.
    co_adj = normalize(adj)
    return co_adj

### 3. Call these adj functions from models/rendering.py, because at that point the (x, y, z) coordinates have already been mapped to NDC coordinates.

In [None]:
def render_rays(models,
                embeddings,
                rays,
                N_samples=64,
                use_disp=False,
                perturb=0,
                noise_std=1,
                N_importance=0,
                chunk=1024*32,
                white_back=False,
                test_time=False
                ):
    """
    Render a batch of rays, feeding the models one depth sample at a time
    together with a ray-to-ray adjacency matrix.

    Inputs (as used by the visible part of this function):
        models: sequence of models; models[0] is the coarse model
        embeddings: sequence; embeddings[0] embeds xyz, embeddings[1] embeds
                    directions
        rays: (N_rays, 8) tensor; columns 0:3 are ray origins, 3:6 ray
              directions, 6:7 the near bound and 7:8 the far bound

    The remaining arguments are not referenced by the code shown here; they
    are presumably consumed by the elided parts (marked with ``...``).
    """
    def inference(model, embedding_xyz, xyz_, dir_, dir_embedded, z_vals, adj, weights_only=False):  ## adj is the extra argument needed by the GCN layer
        """
        Run `model` once per depth sample over all rays.

        xyz_ is indexed as (N_rays, N_samples_, 3).  For every sample index
        the model receives one row per ray plus `adj`, so the graph is defined
        over the rays of a single depth layer.
        """
        N_samples_ = xyz_.shape[1]
        N_rays = xyz_.shape[0]
        if not weights_only:
            dir_embedded = torch.repeat_interleave(dir_embedded, repeats=N_samples_, dim=0)
                           # (N_rays*N_samples_, embed_dir_channels)
            ## Reshape back to (N_rays, N_samples_, embed_dir_channels) so that
            ## the slice for one sample index can be taken below.  The channel
            ## count is inferred (-1) instead of being hard-coded to 27, so any
            ## direction-embedding size works.
            dir_embedded = dir_embedded.view(N_rays, N_samples_, -1)

        out_chunks = []
        for i in range(0, N_samples_):  ## one model call per depth sample (layer)
            xyz_embedded = embedding_xyz(xyz_[:, i, :]) # one embedded row per ray
            if not weights_only:
                # (N_rays, xyz channels + dir channels)
                xyzdir_embedded = torch.cat([xyz_embedded,
                                             dir_embedded[:, i, :]], 1)
            else:
                xyzdir_embedded = xyz_embedded
            out_chunks += [model(xyzdir_embedded, adj, sigma_only=weights_only)]  ## put data into MLP.

        # The chunks are stacked sample-major (N_samples_ * N_rays rows), so
        # the result is viewed with N_samples_ first and then permuted so
        # that rays become the leading dimension.
        out = torch.cat(out_chunks, 0)
        if weights_only:
            sigmas = out.view(N_samples_, N_rays).permute(1,0)
        else:
            rgbsigma = out.view(N_samples_, N_rays, 4).permute(1,0,2)
            rgbs = rgbsigma[..., :3] # (N_rays, N_samples_, 3)
            sigmas = rgbsigma[..., 3] # (N_rays, N_samples_)
        ...

    model_coarse = models[0]
    embedding_xyz = embeddings[0]
    embedding_dir = embeddings[1]

    # Decompose the inputs
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)

    import math
    # Rays whose origins are within the threshold of each other are neighbors.
    adj = design_adj_matrix(rays_o, 1/math.sqrt(75))  # the threshold still needs tuning
    # Degree-normalized adjacency; presumably passed to inference() as `adj`
    # in the elided code below -- TODO confirm.
    co_adj = adj_normalized(adj)

    # Embed direction
    dir_embedded = embedding_dir(rays_d) # (N_rays, embed_dir_channels)

    ...
    return result

## Changing the data loading to single-image training.

### 1. Changes in datasets/llff.py: every training batch now comes from a single image and contains 4096 points.

In [None]:
# Excerpt of the LLFF dataset class; `...` marks code that is not shown here.
# For the 'train' split, read_meta() builds self.all_rays / self.all_rgbs so
# that every consecutive group of self.batch_size rows comes from one image.
# NOTE(review): the indentation of this excerpt is inconsistent (the `if` line
# and the `assert` line); it is reproduced unchanged.
class LLFFDataset(Dataset):
    ...
    def read_meta(self):
        ...
         if self.split == 'train': 
                ...
                self.all_rays = []
                self.all_rgbs = []
                # Per-image caches of the full-image rays / rgbs, keyed by image path.
                rays_dict = {}
                rgbs_dict = {}

                # val_idx comes from the elided code above; the validation image is excluded.
                img_paths = self.image_paths[:val_idx] + self.image_paths[val_idx+1:]
                img_paths = np.array(list(img_paths) * 20)  # every training image appears 20 times
                np.random.shuffle(img_paths)  # random visiting order of the repeated images
                for i, image_path in enumerate(img_paths):
                    # First visit of an image: load it and compute its rays once,
                    # then cache both so later visits do not recompute them.
                    if image_path not in rays_dict.keys():
                        idx = list(self.image_paths).index(image_path)
                        c2w = torch.FloatTensor(self.poses[idx])  # the pose is looked up by the image's position in self.image_paths
                        img = Image.open(image_path).convert('RGB')
                            assert img.size[1]*self.img_wh[0] == img.size[0]*self.img_wh[1], \
                                f'''{image_path} has different aspect ratio than img_wh, 
                                    please check your data!'''
                        img = img.resize(self.img_wh, Image.LANCZOS)
                        img = self.transform(img) # (3, h, w)
                        img = img.view(3, -1).permute(1, 0)  # (h*w, 3)
                        
                        self.all_rgbs += [img] # full image for now; subsampled below
                        rgbs_dict[image_path] = img
                        
                        rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
                        if not self.spheric_poses:
                            near, far = 0, 1
                            rays_o, rays_d = get_ndc_rays(self.img_wh[1], self.img_wh[0],
                                                      self.focal, 1.0, rays_o, rays_d)
                                         # near plane is always at 1.0
                                         # near and far in NDC are always 0 and 1
                                         # See https://github.com/bmild/nerf/issues/34
                        else:
                            near = self.bounds.min()
                            far = min(8 * near, self.bounds.max()) # focus on central object only

                        # Full-image rays for now; subsampled below.
                        self.all_rays += [torch.cat([rays_o, rays_d, 
                                                 near*torch.ones_like(rays_o[:, :1]),
                                                 far*torch.ones_like(rays_o[:, :1])],
                                                 1)] # (h*w, 8)
                    
                        # The same (h*w, 8) tensor is built a second time for the cache.
                        rays_dict[image_path] = torch.cat([rays_o, rays_d,
                                                 near*torch.ones_like(rays_o[:, :1]),
                                                 far*torch.ones_like(rays_o[:, :1])],
                                                 1)

                        # Exactly one entry is appended per loop iteration, so
                        # index i refers to the entry that was just appended.
                        _index = np.arange(self.all_rays[i].shape[0], dtype=np.int32)  
                        np.random.shuffle(_index)

                        # Keep a random subset of self.batch_size rays; rays and
                        # rgbs use the same indices so they stay aligned.
                        self.all_rays[i] = self.all_rays[i][_index[:self.batch_size]]
                        self.all_rgbs[i] = self.all_rgbs[i][_index[:self.batch_size]]
                    else:
                        # Image seen before: draw a fresh random subset from the cache.
                        _index = np.arange(rays_dict[image_path].shape[0], dtype=np.int32)
                        np.random.shuffle(_index)

                        self.all_rays += [rays_dict[image_path][_index[:self.batch_size]]]
                        self.all_rgbs += [rgbs_dict[image_path][_index[:self.batch_size]]]

                # One block of self.batch_size rows per entry of img_paths
                # (fewer rows for an image with less than batch_size pixels).
                self.all_rays = torch.cat(self.all_rays, 0) # (len(img_paths)*batch_size, 8)
                self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(img_paths)*batch_size, 3)

### 2. Disable shuffling in DataLoader(), because datasets/llff.py already shuffles the data within each image's block of 4096 points.
### As a result, every batch fed to the model consists of 4096 already-shuffled points from a single image.

In [None]:
class NeRFSystem(LightningModule):
    ...
    def train_dataloader(self):
        """Return the training DataLoader with shuffling turned off.

        Keeping the dataset order means each batch is one consecutive slice
        of ``hparams.batch_size`` rows of the training dataset.
        """
        loader_options = dict(
            shuffle=False,
            num_workers=4,
            batch_size=self.hparams.batch_size,
            pin_memory=True,
        )
        return DataLoader(self.train_dataset, **loader_options)

## Adding a GCN layer to NeRF.
### 1. Add one GCN layer after the 9 MLP layers, in the fine model only (models/nerf.py).

In [None]:
class NeRF(nn.Module):
    """NeRF MLP with one graph-convolution layer in front of the color branch.

    Excerpt: ``...`` marks code that is not shown here.
    """
    ...
    def __init__(self,
                 D=8, W=256,
                 in_channels_xyz=63, in_channels_dir=27,
                 skips=[4]):
        ...
        # Graph convolution that keeps the feature width W.
        self.gcn = GraphConvolution(W, W)

    def forward(self, x, adj, sigma_only=False):
        """
        Predict sigma only, or rgb and sigma, for a batch of embedded points.

        Inputs:
            x: embedded xyz when sigma_only is True, otherwise the embedded
               xyz and the embedded direction concatenated on the last dim
            adj: adjacency matrix handed to the GCN layer (2-D, one row and
                 one column per row of x, since the layer uses torch.mm)
            sigma_only: return right after sigma is computed; the GCN layer
                        and the color branch are skipped in that case

        Outputs:
            sigma if sigma_only, else rgb and sigma concatenated on the last dim
        """
        if sigma_only:
            input_xyz = x
        else:
            input_xyz, input_dir = torch.split(
                x, [self.in_channels_xyz, self.in_channels_dir], dim=-1)

        hidden = input_xyz
        for layer_idx in range(self.D):
            if layer_idx in self.skips:
                # Skip connection: feed the raw xyz embedding in again.
                hidden = torch.cat([input_xyz, hidden], -1)
            encoder = getattr(self, f"xyz_encoding_{layer_idx+1}")
            hidden = encoder(hidden)

        # sigma is read from the features BEFORE the GCN layer, so the graph
        # convolution only influences the color prediction.
        sigma = self.sigma(hidden)
        if sigma_only:
            return sigma

        hidden = F.relu(self.gcn(hidden, adj))  ## only one GCN layer.

        xyz_encoding_final = self.xyz_encoding_final(hidden)
        dir_encoding = self.dir_encoding(
            torch.cat([xyz_encoding_final, input_dir], -1))
        rgb = self.rgb(dir_encoding)

        return torch.cat([rgb, sigma], -1)

### 2. train.py

In [None]:
class NeRFSystem(LightningModule):
    def __init__(self, hparams):
        """Create the loss, the two embeddings and the coarse / fine models.

        The fine model is only created when ``hparams.N_importance > 0``.
        """
        super().__init__()
        self.hparams = hparams

        loss_cls = loss_dict[hparams.loss_type]
        self.loss = loss_cls()

        # 10 and 4 are the default second arguments (presumably the number
        # of frequency bands -- TODO confirm against Embedding).
        self.embedding_xyz = Embedding(3, 10)
        self.embedding_dir = Embedding(3, 4)
        self.embeddings = [self.embedding_xyz, self.embedding_dir]

        self.nerf_coarse = NeRF_Coarse()
        self.models = [self.nerf_coarse]
        if hparams.N_importance > 0:
            # The fine model is the NeRF class, i.e. the variant that
            # contains the GCN layer.
            self.nerf_fine = NeRF()
            self.models.append(self.nerf_fine)

## Validation feeds the pixels of an image in order, which differs from how we train. In addition, one image has 504*378 = 190512 points in total; with a batch_size of 4096, 190512 % 4096 = 2096, so the final chunk has only 2096 points. This is a big issue for our validation accuracy.

#### If the input is a whole image (the number of rays equals the number of pixels in the image), we shuffle the rays and divide them into several parts of chunk size (set equal to batch_size in the runs below).

In [6]:
class NeRFSystem(LightningModule):
    ...
    def forward(self, rays):
        """
        Render `rays` chunk by chunk and concatenate the per-chunk results.

        When the number of rays equals img_wh[0] * img_wh[1] (a whole image,
        which the original notes say only happens in validation), the rays
        are visited in a random order so that every chunk holds pixels spread
        over the image, and the outputs are put back into the original ray
        order before returning.

        Inputs:
            rays: tensor with one ray per row

        Outputs:
            dict mapping each key returned by render_rays to the chunk
            results concatenated along dim 0.  In the whole-image case every
            value is assumed to have one row per ray.
        """
        B = rays.shape[0]
        chunk = self.hparams.chunk
        results = defaultdict(list)

        whole_image = B == (self.hparams.img_wh[0] * self.hparams.img_wh[1])
        perm = None
        if whole_image:
            # The shuffle still uses NumPy's global RNG; the permutation is
            # then kept as an int64 torch index tensor on the rays' device.
            order = np.arange(B, dtype=np.int64)
            np.random.shuffle(order)
            perm = torch.from_numpy(order).to(rays.device)

        for i in range(0, B, chunk):
            # Slicing clamps at B, so the last (shorter) chunk needs no
            # special case.
            if perm is not None:
                _rays = rays[perm[i:i+chunk]]
            else:
                _rays = rays[i:i+chunk]
            rendered_ray_chunks = \
                render_rays(self.models,
                            self.embeddings,
                            _rays,
                            self.hparams.N_samples,
                            self.hparams.use_disp,
                            self.hparams.perturb,
                            self.hparams.noise_std,
                            self.hparams.N_importance,
                            self.hparams.chunk, # chunk size is effective in val mode
                            self.train_dataset.white_back)

            for k, v in rendered_ray_chunks.items():
                results[k] += [v]

        # Output row j belongs to ray perm[j]; indexing with argsort(perm)
        # restores the original order.  This stays on the tensor's device and
        # keeps its dtype, instead of sorting Python lists and rebuilding the
        # tensor on the current CUDA device.
        inverse = torch.argsort(perm) if perm is not None else None
        for k, v in results.items():
            merged = torch.cat(v, 0)
            if inverse is not None:
                merged = merged[inverse.to(merged.device)]
            results[k] = merged
        return results
         
        

2096

## Case 1: to avoid this problem, note that 190512 / 3024 = 63, so use 3024 as the batch_size (and chunk size) and check the performance.
---
### 1. python train.py --dataset_name llff --root_dir "/data/gpfs/projects/punim1006/zzk/dataset/nerf_llff_data/fern"   --N_importance 64 --img_wh "504" "378"   --num_epochs "3000" --batch_size 3024   --optimizer adam --lr 5e-4   --lr_scheduler cosine   --exp_name "fern_case"    --num_gpus 4    --chunk 3024
---
### 2. python train.py    --dataset_name llff    --root_dir "/data/gpfs/projects/punim1006/zzk/dataset/nerf_llff_data/flower"    --N_importance 64 --img_wh "504" "378"    --num_epochs "3000" --batch_size 3024    --optimizer adam --lr 5e-4    --lr_scheduler cosine    --exp_name "flower_l"    --num_gpus 4    --chunk 3024 
---
### 3. python train.py    --dataset_name llff    --root_dir "/data/gpfs/projects/punim1006/zzk/dataset/nerf_llff_data/fortress"    --N_importance 64 --img_wh "504" "378"    --num_epochs "3000" --batch_size 3024    --optimizer adam --lr 5e-4    --lr_scheduler cosine    --exp_name "fortress_l"    --num_gpus 4    --chunk 3024

---
## Reference case:
---
### Everything is the same, except that the MLP does not have the GCN layer.