See my notebook [here](https://colab.research.google.com/drive/1MY6pk3vY7rrYal8oS6_s7zTkGN9lkHQr?usp=sharing) demonstrating how to use my code to train a NeRF model on the `tiny_nerf_data.npz` file used by the original NeRF authors in their notebook [here](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb).

In [1]:
!pip install -q matplotlib numpy torch natsort

In [2]:
!pip install -q -U einops datasets tqdm

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/41.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m27.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.2/212.2 kB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.9/132.9 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import random
import torch
import os
from natsort import natsorted
from tqdm import tqdm

from torch import nn, optim, einsum

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import math
from einops import rearrange

In [4]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [5]:
class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

In [6]:
# ResNet18
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, down_sample=False, groups=1,
                 base_width=64, dilation=1, time_dim=None, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        self.scale_shift_mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_dim, planes * 2))
        )
                
        self.inplanes = inplanes
        self.planes = planes
        self.down_sample = down_sample

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.conv1x1 = conv1x1(self.inplanes, self.planes, stride)
        

    def forward(self, x, time_emb):

        #scale_shift
        time_emb = self.scale_shift_mlp(time_emb)
        time_emb = rearrange(time_emb, "b c -> b c 1 1")
        scale_shift = time_emb.chunk(2, dim=1)


        identity = x

        out = self.conv1(x)
        out = self.bn1(out)

        scale, shift = scale_shift
        out = out * (scale + 1) + shift

        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.down_sample == True:
          identity = self.conv1x1(identity)
          identity = self.bn1(identity)

        #print(out.shape, identity.shape)
        out += identity #残差結合
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()


        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.norm_layer = norm_layer

        #入力チャンネル
        self.inplanes = 64
        self.dilation = 1
        self.time_dim = self.inplanes * 4

        #if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
        #    replace_stride_with_dilation = [False, False, False]
        #if len(replace_stride_with_dilation) != 3:
        #    raise ValueError("replace_stride_with_dilation should be None "
        #                     "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        #Time_Embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(self.inplanes),
            nn.Linear(self.inplanes, self.time_dim),
            nn.GELU(),
            nn.Linear(self.time_dim, self.time_dim),
        )

        #畳み込み1
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        #Attention
        self.attn = LinearAttention(self.inplanes)

        #正則化
        self.bn1 = norm_layer(self.inplanes)

        #活性化
        self.relu = nn.ReLU(inplace=True)

        #プーリング
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        #ResNet ブロック
        self.layer1 = block(inplanes = self.inplanes, planes = 64, norm_layer = self.norm_layer, time_dim=self.time_dim)
            
        self.layer2 = block(inplanes = 64, planes = 128, stride=2, down_sample = True, norm_layer = self.norm_layer, time_dim=self.time_dim)

        self.layer3 = block(inplanes = 128, planes = 256, stride=2, down_sample = True, norm_layer = self.norm_layer, time_dim=self.time_dim)

        self.layer4 = block(inplanes = 256, planes = 512, stride=2, down_sample = True, norm_layer = self.norm_layer, time_dim=self.time_dim)
        
        #プーリング
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        #全結合
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        # self.fc = nn.Linear(2048*4, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                # if isinstance(m, Bottleneck):
                #     nn.init.constant_(m.bn3.weight, 0)
                # elif isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)


    def _forward_impl(self, x, time):

        t = self.time_mlp(time)

        #print(t.shape, t)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.attn(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x, t)
        x = self.layer2(x, t)
        x = self.layer3(x, t)
        x = self.layer4(x, t)

        x = self.avgpool(x)
        x = torch.flatten(x)
        x = self.fc(x)

        return x

    def forward(self, x, time):
        return self._forward_impl(x, time)

In [7]:
class ImageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        num_classes = 256
        cat_cam_mat_feats = num_classes + 16 # ビュー行列の要素数
        net_width = 256
        self.resnet18 = ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes) # 出力は256次元
        self.cat_cam_mat_mlp = nn.Sequential(
            nn.Linear(cat_cam_mat_feats, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
        )

    def forward(self, images_t, poses, time):
        vt = []
        for i,img in enumerate(images_t):
            outputs_resnet = self.resnet18(img.unsqueeze(0), time)
            cat_cam_mat = torch.cat((outputs_resnet, torch.flatten(poses[i])), dim=-1)
            # print(torch.flatten(poses[i]).shape)
            outputs_cat_cam_mat = self.cat_cam_mat_mlp(cat_cam_mat) # これをすべての視点から集めたい
            vt.append(outputs_cat_cam_mat)
        vt = torch.stack(vt)
        return vt

In [8]:
# class StateEncoder(nn.Module):
#     def __init__(self):
#         super().__init__()
#         net_width = 256
#         self.vt_to_st = nn.Sequential(
#             nn.Linear(net_width, net_width),
#             nn.ReLU(),
#             nn.Linear(net_width, net_width),
#             nn.ReLU(),
#         )

#     def forward(self, vt):
#         vt_avg = torch.mean(vt, dim=0) # dim合ってる？
#         st = self.vt_to_st(vt_avg)
#         return nn.functional.normalize(st,dim=0,p=2)

In [38]:
class VeryTinyNeRFMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.L_pos = 6
        self.L_dir = 4
        num_classes = 256
        pos_enc_feats = 3 + 3 * 2 * self.L_pos + num_classes + 16
        dir_enc_feats = 3 + 3 * 2 * self.L_dir
        net_width = 256
        self.vt_to_st = nn.Sequential(
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
        )
        self.early_mlp = nn.Sequential(
            nn.Linear(net_width +32, net_width), 
            nn.ReLU(),
            nn.Linear(net_width, net_width + 1), 
            nn.ReLU(),
        )
        self.late_mlp = nn.Sequential(
            nn.Linear(net_width + 32, net_width),
            nn.ReLU(),
            nn.Linear(net_width, 3),
            nn.Sigmoid(),
        )

    def forward(self, vt, images_t, target_idx, poses, xs, ds):
        vt_avg = torch.mean(vt, dim=0) # dim合ってる？
        st = self.vt_to_st(vt_avg)
        
        xs_encoded = xs
        #for l_pos in range(self.L_pos):
        #    xs_encoded.append(torch.sin(2 ** l_pos * torch.pi * xs))
        #    xs_encoded.append(torch.cos(2 ** l_pos * torch.pi * xs))

        #xs_encoded = torch.cat(xs_encoded, dim=-1)

        #ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ここわからん
        ds_encoded = ds
        #for l_dir in range(self.L_dir):
        #    ds_encoded.append(torch.sin(2 ** l_dir * torch.pi * ds))
        #    ds_encoded.append(torch.cos(2 ** l_dir * torch.pi * ds))

        #ds_encoded = torch.cat(ds_encoded, dim=-1)
        #print(ds_encoded.shape)
        #print(st.unsqueeze(0).repeat(xs_encoded.shape[0], 1).shape)
        xs_encoded = torch.cat((xs_encoded, st.unsqueeze(0).repeat(xs_encoded.shape[0], 1)), dim=-1) # xs_encodedがどんな形してるか？
        
        outputs = self.early_mlp(xs_encoded)
        sigma_is = outputs[:, 0]
        c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1))
        return {"c_is": c_is, "sigma_is": sigma_is, "st_is": st}

In [35]:
def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
    u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)
    t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap
    r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
    return (r_ts_c, t_is_c)


def render_radiance_volume(images_t, target_idx, poses, r_ts, ds, chunk_size, F_img_enc, F_dec, t_is, time, encoder):
    #print(r_ts.shape)
    rs_emb = encoder(r_ts.unsqueeze(3))
    #print('"here1')
    #print(rs_emb.shape)
    r_ts_flat = rs_emb.reshape((-1, 32))
    ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
    ds_emb = encoder(ds_rep.unsqueeze(3))
    ds_flat = ds_emb.reshape((-1, 32))
    c_is = []
    sigma_is = []
    st_is = []
    for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):
        r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]
        print(r_ts_batch.shape)
        ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]
        #print(ds_batch.shape)
        vt = F_img_enc(images_t, poses, time)
        # st = F_st_enc(vt)
        preds = F_dec(vt, images_t, target_idx, poses, r_ts_batch, ds_batch) # ここで実行してる？
        c_is.append(preds["c_is"])
        sigma_is.append(preds["sigma_is"])
        st_is.append(preds["st_is"])

    c_is = torch.cat(c_is).reshape(r_ts.shape)
    sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])

    delta_is = t_is[..., 1:] - t_is[..., :-1]
    one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
    delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
    delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)

    alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)

    T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
    T_is = torch.roll(T_is, 1, -1)
    T_is[..., 0] = 1.0

    w_is = T_is * alpha_is

    C_rs = (w_is[..., None] * c_is).sum(dim=-2)

    return C_rs, vt, st_is


def run_one_iter_of_tiny_nerf(images_t, target_idx, poses, ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_img_enc, F_dec, time, encoder):
    (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)
    C_rs_c, vt, st = render_radiance_volume(images_t, target_idx, poses, r_ts_c, ds, chunk_size, F_img_enc, F_dec, t_is_c, time, encoder) # ここで実行してる？
    return C_rs_c, vt, st

In [11]:
# class MSETwoView(nn.Module):
#     def __init__(self): # パラメータの設定など初期化処理を行う
#         super(MSETwoView, self).__init__()
        
#     def forward(self, C_rs_c, images_t_target):
#         print(images_t_target.shape)
#         loss = torch.mean((C_rs_c - images_t_target) ** 2, dim=0)
#         print(loss.shape)
#         return loss

In [32]:
from torch.nn import functional as F

class MultiresolutionHashEncoder3d(nn.Module):
    def __init__(self, l=16, t=2**14, f=2, n_min=16, n_max=100, interpolation='trilinear'):
        super().__init__()
        self.l = l
        self.t = t
        self.f = f
        self.interpolation = interpolation

        b = math.exp((math.log(n_max) - math.log(n_min)) / (l - 1))
        self.ns = [int(n_min * (b ** i)) for i in range(l)]

        # Prime Numbers from https://github.com/NVlabs/tiny-cuda-nn/blob/ee585fa47e99de4c26f6ae88be7bcb82b9295310/include/tiny-cuda-nn/encodings/grid.h
        self.register_buffer('primes', torch.tensor([1, 2654435761, 805459861]))
        
        self.hash_table = nn.Parameter(
            torch.rand([l, t, f], requires_grad=True) * 2e-4 - 1e-4)

    @property
    def encoded_vector_size(self):
        return self.l * self.f
        
    def forward(self, x):
        x = x.permute(3, 4, 0, 1, 2)
        b, c, h, w, d = x.size()

        def make_grid(x, n):
            g = F.max_pool3d(x * n, (h // n, w // n, d // n)).to(dtype=torch.long)
            #print(g.shape)
            g = g * self.primes.view([3, 1, 1, 1])
            g = (g[:,0] ^ g[:,1] ^ g[:,2]) % self.t
            #print(g.shape)
            return g

        grids = [make_grid(x, n) for n in self.ns]
        #print(len(grids))
        features = [self.hash_table[i, g].permute(0, 4, 1, 2, 3)
                    for i, g in enumerate(grids)]
        feature_map = torch.hstack([
            F.interpolate(f, (h, w, d), mode=self.interpolation)
            for f in features
        ]) 
        #print(feature_map.shape)

        return feature_map

In [13]:
class Dynamics(nn.Module):
    def __init__(self):
        super().__init__()
        net_width = 256
        self.dynamics = nn.Sequential(
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
            nn.Linear(net_width, net_width),
            nn.ReLU(),
        )

    def forward(self, vt):
        vt_avg = torch.mean(vt, dim=0) # dim合ってる？
        st = self.vt_to_st(vt_avg)
        return nn.functional.normalize(st,dim=0,p=2)

In [14]:
#class Position_Emb(nn.Module):
#    def __init__(self, encoder, num_planes=64, num_layers=2):
#        super().__init__()
#        self.enc = encoder
        
        # 1x1 convolution is equivalent to MLP for a point in the 2D-coordinates
#        layers = [nn.Conv3d(encoder.encoded_vector_size, num_planes, 1)]
#        for _ in range(num_layers - 2):
#            layers += [nn.ReLU(),
#                       nn.Conv3d(num_planes, num_planes, 1)]
#        layers += [nn.ReLU(),
#                   nn.Conv3d(num_planes, 3, 1), #ここあってるかわかんない
#                   nn.Sigmoid()]
#        self.mlp = nn.Sequential(*layers)

#    def forward(self, x):
#        feature = self.enc(x)
#        #print("here")
#        out = self.mlp(feature)
#        #print(out.shape)
#        return out

In [15]:
class TimeContrastiveLoss(nn.Module):
    def __init__(self): # パラメータの設定など初期化処理を行う
        super(TimeContrastiveLoss, self).__init__()
        
    def forward(self, vt_i, vt_j, vt_cont_i, alpha=0):
        loss = torch.norm(vt_i - vt_j, 2)**2 + alpha
         # - torch.norm(vt_i - vt_cont_i, 2)**2 
        return loss

In [16]:
!ls

sample_data  train_2box.npz


In [39]:
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)

device = "cuda:0"
F_image_enc = ImageEncoder().to(device)
F_dec = VeryTinyNeRFMLP().to(device)
encoder = MultiresolutionHashEncoder3d().to(device)
chunk_size = 16384

lr = 5e-4
optimizer_image_enc = optim.Adam(F_image_enc.parameters(), lr=lr)
optimizer_dec = optim.Adam(F_dec.parameters(), lr=lr)
optimizer_emb = optim.Adam(encoder.parameters(), lr=lr)
criterion = nn.MSELoss()
# criterion_TC = TimeContrastiveLoss()

data_f = "train_2box.npz"
data = np.load(data_f)
#test_data_f = "test_2box.npz"
#test_data = np.load(test_data_f)

images = data["images"][:20] / 255
#test_images = test_data["images"][:20] / 255
images_t0 = images[0]
img_size = images_t0.shape[1]
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
camera_coords = pixel_coords / focal
init_ds = camera_coords.to(device)
init_o = torch.Tensor(np.array([0, 0, float(data["cam_dist"])])).to(device)

poses = torch.Tensor(data["poses"].reshape((-1,4,4))).to(device)
#test_poses = torch.Tensor(test_data["poses"].reshape((-1,4,4))).to(device)

t_n = 2.0 # near
t_f = 10.0 # far
N_c = 100
t_i_c_gap = (t_f - t_n) / N_c
t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)

psnrs = []
iternums = []
F_image_enc.train()
# F_state_enc.train()
F_dec.train()

VeryTinyNeRFMLP(
  (vt_to_st): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ReLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): ReLU()
  )
  (early_mlp): Sequential(
    (0): Linear(in_features=288, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=257, bias=True)
    (3): ReLU()
  )
  (late_mlp): Sequential(
    (0): Linear(in_features=288, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=3, bias=True)
    (3): Sigmoid()
  )
)

In [18]:
# # ImageEncoderの学習 -> 要素がすべて0になる
# num_iters_image_enc = 100
# lr = 5e-4
# F_image_enc = ImageEncoder().to(device)
# optimizer_image_enc = optim.Adam(F_image_enc.parameters(), lr=lr)
# # optimizer_image_enc = optim.SGD(F_image_enc.parameters(), lr=lr)
# # 100 step後に5e-5になるように
# scheduler = optim.lr_scheduler.ExponentialLR(optimizer_image_enc, gamma=10**(-0.01))
# criterion_TC = TimeContrastiveLoss()

# history  = []
# for iter in range(num_iters_image_enc):
#     for t in range(1,images.shape[0]):
#         t_cont = random.randint(0, t-1)
#         images_t = torch.Tensor(images[t])
#         images_t_cont = torch.Tensor(images[t_cont])
#         time = torch.Tensor([t]).to(device)
#         time_cont = torch.Tensor([t_cont]).to(device)
#         while True:
#             i = random.randint(0, images_t.shape[0]-1)
#             j = random.randint(0, images_t.shape[0]-1)
#             if i != j:
#                 break
#         vt_i = F_image_enc(images_t[i].unsqueeze(0).permute(0,3,1,2).to(device), poses, time)
#         vt_j = F_image_enc(images_t[j].unsqueeze(0).permute(0,3,1,2).to(device), poses, time)
#         vt_cont_i = F_image_enc(images_t_cont[i].unsqueeze(0).permute(0,3,1,2).to(device), poses, time_cont)
#         loss_image_enc = criterion_TC(vt_i, vt_j, vt_cont_i, alpha=0)
#         if loss_image_enc > 0:
#             optimizer_image_enc.zero_grad()
#             loss_image_enc.backward()
#             optimizer_image_enc.step()
#         # print(f'{iter}, {t}, {t_cont}, Time Contrastive Loss: {loss_image_enc}')
        
#         history.append(loss_image_enc.item())
        
#     scheduler.step()
#         # print(vt_i)
        
# plt.plot(history)
# #plt.ylim(0, 0.1);

In [40]:
# for file_name in natsorted(os.listdir('dataset')): # dataset分
# data_f = "test (4).npz"
# data = np.load(data_f)
# images = data["images"] / 255
file_name = 'best2.cpt'
min_loss_image_enc = np.inf

lr = 5e-3
# F_image_enc = ImageEncoder().to(device)
# optimizer_image_enc = optim.Adam(F_image_enc.parameters(), lr=lr)
criterion_TC = TimeContrastiveLoss()
#モデルの定義
# F_dec = VeryTinyNeRFMLP().to(device)
# optimizer_dec = optim.Adam(F_dec.parameters(), lr=lr)
criterion = nn.MSELoss()

num_loop = 10
scheduler = optim.lr_scheduler.ExponentialLR(optimizer_image_enc, gamma=10**(-0.002)) # 1000回で0.01倍
scheduler_dec = optim.lr_scheduler.ExponentialLR(optimizer_dec, gamma=10**(-0.1))
num_iters = 100
display_every = 25


for loop in tqdm(range(num_loop)): 
    poses = torch.Tensor(poses).to(device) 
    for i in tqdm(range(num_iters)):

        t = random.randint(0, images.shape[0]-1) #ここあってるかわかんない
        images_t = torch.Tensor(images[t]).to(device)
        time = torch.Tensor([t]).to(device)

        target_img_idx = np.random.randint(images_t.shape[0]-1)
        target_pose = poses[target_img_idx].to(device)
        R = target_pose[:3, :3]
        ds = torch.einsum("ij,hwj->hwi", R, init_ds)
        os = (R @ init_o).expand(ds.shape)

        C_rs_c,_, _ = run_one_iter_of_tiny_nerf(
          images_t.permute(0,3,1,2).to(device), target_img_idx, poses, ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_image_enc, F_dec, time, encoder
        )

        loss_dec = criterion(C_rs_c, images_t[target_img_idx])
        optimizer_dec.zero_grad()
        optimizer_emb.zero_grad()
        loss_dec.backward(retain_graph=True)
        optimizer_dec.step()
        optimizer_emb.step()


        if (i+1) % display_every == 0:
            print(f"Loop: {loop}")
            print(f"Time Step: {t}")
            print(f"Iter: {i}")
            print(f"Loss: {loss_dec.item()}")
            plt.figure(figsize=(4, 4))
            plt.imshow(C_rs_c.detach().cpu().numpy())
            plt.show()

        for t in range(1,images.shape[0]):
            t_cont = random.randint(0, t-1)
            images_t = torch.Tensor(images[t])
            images_t_cont = torch.Tensor(images[t_cont])
            time = torch.Tensor([t]).to(device)
            time_cont = torch.Tensor([t_cont]).to(device)
            while True:
                i = random.randint(0, images_t.shape[0]-1)
                j = random.randint(0, images_t.shape[0]-1)
                if i != j:
                    break
            vt_i = F_image_enc(images_t[i].unsqueeze(0).permute(0,3,1,2).to(device), poses, time)
            vt_j = F_image_enc(images_t[j].unsqueeze(0).permute(0,3,1,2).to(device), poses, time)
            vt_cont_i = F_image_enc(images_t_cont[i].unsqueeze(0).permute(0,3,1,2).to(device), poses, time_cont)
            loss_image_enc = criterion_TC(vt_i, vt_j, vt_cont_i, alpha=0)
            loss_image_enc += loss_dec #loss_decはノルム取るべき？
            if loss_image_enc > 0:
                optimizer_image_enc.zero_grad()
                loss_image_enc.backward(retain_graph=True)
                optimizer_image_enc.step()
            if loss_image_enc < min_loss_image_enc:
                # チェックポイント保存
                torch.save({
                    'loop': loop,
                    'iter': i,
                    'time': t,
                    'image_encoder_state_dict': F_image_enc.state_dict(),
                    'decoder_state_dict': F_dec.state_dict(),
                    'opt_img_enc_state_dict': optimizer_image_enc.state_dict(),
                    'opt_dec_state_dict': optimizer_dec.state_dict(),
                    'loss': loss_image_enc,
                }, file_name)
                min_loss_image_enc = loss_image_enc
            if t == images.shape[0]:
                print(f'{iter}, {t}, {t_cont}, Time Contrastive Loss: {loss_image_enc}')
        print(scheduler.get_lr()[0])
        scheduler.step()
    
    scheduler_dec.step()

  0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/100 [00:00<?, ?it/s][A

torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])
torch.Size([16384, 32])


  0%|          | 0/100 [00:04<?, ?it/s]
  0%|          | 0/10 [00:04<?, ?it/s]


OutOfMemoryError: ignored

In [None]:
# 新規視点の描画
test_data_f = "test_2box.npz"
#test_data = np.load(test_data_f)
#test_images = test_data["images"][:20] / 255
test_poses = torch.Tensor(test_data["poses"].reshape((-1,4,4))).to(device)
F_dec_out_st = VeryTinyNeRFMLP().to(device)
cpt = torch.load('out_2box.cpt')
F_dec_out_st.load_state_dict(cpt['decoder_state_dict'])
F_image_enc.eval()
F_dec_out_st.eval()
for t in range(test_images.shape[0]):
    for test_idx in range(test_images.shape[1]):
        test_R = torch.Tensor(test_poses[test_idx, :3, :3]).to(device)
        test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
        test_os = (test_R @ init_o).expand(test_ds.shape)
        test_img = torch.Tensor(test_images[t][test_idx]).to(device)
        test_pose = torch.Tensor(test_poses[test_idx]).unsqueeze(0).to(device)
        with torch.no_grad():
            C_rs_c_test, vt_test, st = run_one_iter_of_tiny_nerf(
                test_img.unsqueeze(0).permute(0,3,1,2), target_img_idx, test_pose, test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, chunk_size, F_image_enc, F_dec_out_st
            )
        loss_dec = criterion(C_rs_c_test, test_img)
        print(f"Time: {t}")
        print(f"View: {test_idx}")
        print(f"Loss: {loss_dec.item()}")
        print(f"st: {st}")
        plt.figure(figsize=(10, 4))
        plt.subplot(121)
        plt.imshow(C_rs_c.detach().cpu().numpy())
        plt.subplot(122)
        plt.imshow(test_img.detach().cpu().numpy())
        plt.show()

In [None]:
test_ds.shape

In [None]:
test_images.shape

In [None]:
file_name = 'out_2box.cpt'
torch.save({
'time': 0,
'loop': 4,
'image_encoder_state_dict': F_image_enc.state_dict(), # vtまでのencoder
'decoder_state_dict': F_dec.state_dict(),
'opt_img_enc_state_dict': optimizer_image_enc.state_dict(),
'opt_dec_state_dict': optimizer_dec.state_dict(),
'loss': loss_dec,
}, file_name)

In [None]:
plt.imshow(images[20][19])