In [3]:

import sys, os
# point Python at the cloned repo
sys.path.append(os.path.abspath("nerfw-pytorch"))

import torch
from torch.utils.data import DataLoader

# options parser
from option.nerf_option import NeRFOption
# dataset (for Cambridge; swap for 7‑Scenes as needed)
from dataset.cambridge import CambridgeDataset
# the NeRFW system
from model.nerfw_system import NeRFWSystem

In [4]:

# Modified second cell
# The option parser is driven through a fake command line: options are declared
# as (flag, value) pairs grouped by topic, then flattened into sys.argv.
_cli_options = [
    # ——— Training / experiment setup ———
    ("--is_train", "True"),                 # train mode (False for inference)
    ("--root_dir", "./runs/nerf"),          # where to write logs & checkpoints
    ("--exp_name", "my_experiment"),        # subfolder under root_dir for this run

    # ——— Data loading ———
    ("--data_root_dir", "./data/Cambridge"),  # top-level data folder
    ("--scene", "KingsCollege"),              # which scene subfolder to use
    ("--img_downscale", "3"),                 # downsample factor for images

    # ——— Batch & sampling ———
    ("--batch_size", "1024"),               # rays per optimization step
    ("--chunk", "4096"),                    # max rays per network forward
    ("--N_c", "64"),                        # number of coarse samples per ray
    ("--N_f", "128"),                       # number of fine samples per ray
    ("--perturb", "1.0"),                   # stratified sampling noise

    # ——— Model & regularization ———
    ("--encode_a", "True"),                 # enable appearance embedding
    ("--encode_t", "True"),                 # enable transient embedding
    ("--a_dim", "48"),                      # appearance embedding size
    ("--t_dim", "16"),                      # transient embedding size
    ("--beta_min", "0.1"),                  # min transient variance
    ("--lambda_u", "0.01"),                 # weight on transient regularizer

    # ——— Optimization ———
    ("--lr", "5e-4"),                       # initial learning rate
    ("--epochs", "20"),                     # number of training epochs
    ("--num_gpus", "1"),                    # number of GPUs (DDP if >1)

    # ——— Caching ———
    # NOTE(review): this value is a real bool, not a string like its neighbours.
    # The recorded run prints opt.use_cache == False with it; confirm how
    # NeRFOption converts option strings before replacing it with "False".
    ("--use_cache", False),                 # build ray cache?
    ("--if_save_cache", "True"),            # save ray cache for next runs

    # ——— Checkpointing & logging ———
    ("--save_latest_freq", "1000"),         # iters between "latest" saves
    # Removed --save_epoch_freq as it's not in the original options

    # ——— (Inference only) ———
    ("--ckpt_path", "checkpoints/nerfw_epoch20.pth"),  # Changed --ckpt to --ckpt_path
    ("--last_epoch", "0"),                  # epoch index corresponding to ckpt
]

# argv[0] stands in for the program name; every pair contributes flag then value.
sys.argv = [""] + [token for pair in _cli_options for token in pair]

opt = NeRFOption().into_opt()

In [5]:
# Show the two cache-related options as parsed, one value per line.
for cache_flag in ("use_cache", "if_save_cache"):
    print(getattr(opt, cache_flag))

False
True


In [6]:
# Select the compute device: Apple Metal (MPS) first, then CUDA, then CPU.
# The MPS backend module is looked up defensively so that a PyTorch build
# without `torch.backends.mps` falls through instead of raising AttributeError.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using M2 GPU via Metal!")
elif torch.cuda.is_available():
    # Previously a CUDA machine would silently fall back to the CPU.
    device = torch.device("cuda")
    print("Using CUDA GPU!")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU!")


Using M2 GPU via Metal!


In [10]:
from dataset.cambridge import CambridgeDataset  # Note: it's cambridge.py not cambridge_dataset.py
from torch.utils.data import DataLoader
from model.nerfw_system import NeRFWSystem

# 1. build dataloader
# NOTE(review): use_cache is hard-coded to True here although the parsed options
# carry opt.use_cache (False in the recorded run) — confirm which is intended
# before wiring the option through.
train_dataset = CambridgeDataset(
    root_dir=opt.data_root_dir,
    scene=opt.scene,
    split='train',
    img_downscale=opt.img_downscale,
    use_cache=True,
    if_save_cache=opt.if_save_cache
)

train_loader = DataLoader(
    train_dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=4,
    # Pinned host memory only helps host-to-CUDA copies; requesting it on
    # MPS/CPU merely produces a warning.
    pin_memory=(device.type == "cuda")
)

# 2. init model & optimizer
model = NeRFWSystem(
    N_views=len(train_dataset.train_set),  # Number of training views
    N_c=opt.N_c,  # Number of coarse samples
    N_f=opt.N_f,  # Number of fine samples
    use_disp=opt.use_disp,
    perturb=opt.perturb,
    layers=opt.layers,
    W=opt.W,
    N_xyz_freq=opt.N_xyz_freq,
    N_dir_freq=opt.N_dir_freq,
    encode_a=opt.encode_a,
    encode_t=opt.encode_t,
    a_dim=opt.a_dim,
    t_dim=opt.t_dim,
    res_layer=[4],  # Default value from the code
    device=device,
    beta_min=opt.beta_min,
    lambda_u=opt.lambda_u,
    white_back=False  # Default value
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)

# Create the checkpoint directory once, up front; exist_ok avoids the
# check-then-create race and makes the cell safe to re-run.
checkpoint_dir = os.path.join(opt.root_dir, opt.exp_name)
os.makedirs(checkpoint_dir, exist_ok=True)


def _save_checkpoint(path, epoch, loss):
    """Atomically write a training checkpoint to `path`.

    Args:
        path: destination file of the checkpoint.
        epoch: index of the epoch the state belongs to (for a mid-epoch
            "latest" save this is the epoch still in progress).
        loss: last training loss, or None if no batch was processed; stored
            detached and on the CPU so the file does not reference the
            autograd graph.
    """
    tmp_path = path + ".tmp"
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': None if loss is None else loss.detach().cpu(),
    }, tmp_path)
    # Replace in one step so an interrupt mid-write cannot leave a truncated
    # checkpoint under the final name.
    os.replace(tmp_path, path)


# 3. training loop
for epoch in range(opt.epochs):
    model.train()
    last_loss = None  # stays None if the loader yields no batch
    for batch_idx, rays in enumerate(train_loader):
        rays = rays.to(device)

        # Per the notebook's original notes each ray row is laid out as
        # [rays_o(3), rays_d(3), near, far, view_id, rgb(3)].
        rays_input = rays[..., :9]  # First 9 dimensions [position, direction, near, far, id]
        target = rays[..., 9:12]    # Last 3 dimensions [RGB]

        optimizer.zero_grad()

        res_c, res_f, losses = model(rays_input, target, cal_loss=True)
        loss = losses['coarse'] + losses['fine']
        if losses['fine_regular'] is not None:
            loss = loss + losses['fine_regular']

        loss.backward()
        optimizer.step()
        last_loss = loss

        if batch_idx % opt.save_latest_freq == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
            # Also persist a rolling "latest" checkpoint so that interrupting a
            # long epoch does not lose all progress (skipped at batch 0, where
            # there is nothing worth saving yet).
            if batch_idx > 0:
                _save_checkpoint(os.path.join(checkpoint_dir, 'latest.pth'), epoch, last_loss)

    # Save checkpoint at the end of each epoch
    checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch}.pth')
    _save_checkpoint(checkpoint_path, epoch, last_loss)

load reconstruction data of scene "KingsCollege", split: train
all rays:  torch.Size([281088000, 12])
cache load done...
Epoch: 0, Batch: 0, Loss: 2.5382
Epoch: 0, Batch: 1000, Loss: 1.6158
Epoch: 0, Batch: 2000, Loss: 1.4260
Epoch: 0, Batch: 3000, Loss: 1.2899
Epoch: 0, Batch: 4000, Loss: 1.2964
Epoch: 0, Batch: 5000, Loss: 1.1919
Epoch: 0, Batch: 6000, Loss: 1.1758
Epoch: 0, Batch: 7000, Loss: 1.1442
Epoch: 0, Batch: 8000, Loss: 1.1713
Epoch: 0, Batch: 9000, Loss: 1.1829
Epoch: 0, Batch: 10000, Loss: 1.1302
Epoch: 0, Batch: 11000, Loss: 1.1570
Epoch: 0, Batch: 12000, Loss: 1.1639
Epoch: 0, Batch: 13000, Loss: 1.1263
Epoch: 0, Batch: 14000, Loss: 1.1477
Epoch: 0, Batch: 15000, Loss: 1.1532
Epoch: 0, Batch: 16000, Loss: 1.1172
Epoch: 0, Batch: 17000, Loss: 1.0882
Epoch: 0, Batch: 18000, Loss: 1.1093
Epoch: 0, Batch: 19000, Loss: 1.0841
Epoch: 0, Batch: 20000, Loss: 1.0808
Epoch: 0, Batch: 21000, Loss: 1.1059
Epoch: 0, Batch: 22000, Loss: 1.0822
Epoch: 0, Batch: 23000, Loss: 1.0474
Epoc

KeyboardInterrupt: 