In [1]:
!pip install SimpleITK
!pip install wandb

Collecting SimpleITK
  Downloading SimpleITK-2.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.7/52.7 MB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.2.1
Collecting wandb
  Downloading wandb-0.15.5-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.32-py3-none-any.whl (188 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.5/188.5 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.28.1-py2.py3-none-any.whl (214 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m214.7/214.7 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import random
import wandb
import os
import sys
import yaml
from pprint import pprint
sys.path.append("/content/drive/MyDrive/Colab Notebooks/ETHZ/")
from models.INR import NeRF
from models.dataset import RayDataset
from models.train import train_step, train_log, train_backward
from models.val import val_step, val_log

In [3]:
def load_model(args, model, optimizer, scheduler, device):
    """Restore model, optimizer and scheduler state from ``args.ckpt_path``.

    The checkpoint is read with ``torch.load`` (storages remapped onto
    ``device``) and must contain the keys ``"model"``, ``"opt"``, ``"sche"``
    and ``"epoch"``. The three state dicts are loaded in that order into
    ``model``, ``optimizer`` and ``scheduler``.

    Returns:
        The value stored under ``"epoch"`` in the checkpoint.
    """
    checkpoint = torch.load(args.ckpt_path, map_location=device)
    restore_plan = ((model, "model"), (optimizer, "opt"), (scheduler, "sche"))
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])
    return checkpoint["epoch"]


def save_model(save_path, epoch, model, optimizer, scheduler):
    """Write a training checkpoint to ``<save_path>/<epoch>.pth``.

    The checkpoint is a dict with the keys ``"epoch"``, ``"model"``, ``"opt"``
    and ``"sche"`` holding the epoch value and the ``state_dict()`` of the
    model, optimizer and scheduler respectively.

    Args:
        save_path: Directory to write into. It is created (including parents)
            if it does not exist yet; an empty string means the current
            working directory.
        epoch: Value stored under ``"epoch"`` and used as the file stem.
        model, optimizer, scheduler: Objects exposing ``state_dict()``.
    """
    # torch.save does not create missing parent directories, so make sure the
    # target exists first. os.makedirs("") raises, hence the guard.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    path = os.path.join(save_path, "{}.pth".format(epoch))
    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "opt": optimizer.state_dict(),
            "sche": scheduler.state_dict(),
        },
        path,
    )
    print("Saved checkpoints at {}".format(path))

In [4]:
# Seed both Python's `random` module and PyTorch with the same fixed value.
m_seed = 999
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(m_seed)

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [5]:
# Read the experiment configuration from Drive, authenticate with W&B and
# open a run that carries the configuration.
# NOTE(review): yaml.FullLoader is used here; if config.yaml only holds plain
# scalars/lists/maps, yaml.safe_load would be sufficient — confirm before changing.
config_path = "/content/drive/MyDrive/Colab Notebooks/ETHZ/main/config.yaml"
with open(config_path) as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

wandb.login()
run_settings = dict(entity="dl_prac", project="img_register", name="NVF_MAE_case02")
wandb.init(config=config, **run_settings)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzangqb[0m ([33mdl_prac[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Alias the active W&B run's config as `args`; the following cells read
# hyper-parameters and paths from it by attribute (e.g. `args.lr`).
args = wandb.config
# Bare expression so the notebook displays the resolved configuration.
args

{'fixed_vol_path': '/content/drive/MyDrive/Colab Notebooks/ETHZ/data/case02/case2_T00-ssm.mha', 'min_val': 0, 'std_val': 4247, 'num_repeat': 1, 'moving_vol_path': '/content/drive/MyDrive/Colab Notebooks/ETHZ/data/case02/case2_T50-ssm.mha', 'netdepth': 8, 'netwidth': 256, 'multires': 10, 'num_epochs': 24, 'batch_size': 2048, 'lr': 0.00012, 'lambda_jacob': 0.05, 'num_step_opt': 1, 'num_step_log': 10, 'num_step_val': 400, 'num_val_planes': 4, 'chunk_size': 32768, 'num_epoch_save': 1, 'ckpt_nvf_path': '/content/drive/MyDrive/Colab Notebooks/ETHZ/check/case02/case02.pth', 'save_path': '/content/drive/MyDrive/Colab Notebooks/ETHZ/check/case02/'}

In [7]:
# Build the dataset from the run configuration and wrap it in a DataLoader.
# batch_size=None turns off the DataLoader's automatic batching, so every item
# the dataset returns is passed through as-is; shuffle=True randomises the
# order in which items are requested.
dataset = RayDataset(args)
loader_options = dict(batch_size=None, shuffle=True, pin_memory=True)
dataloader = DataLoader(dataset, **loader_options)

# Instantiate the network with depth/width/multires taken from the config and
# move it onto the selected device.
model = NeRF(args, args.netdepth, args.netwidth, multi_res=args.multires)
model = model.to(device)
# Bare expression so the notebook prints the module structure.
model

NeRF(
  (embedder): Embedder()
  (pts_linears): ModuleList(
    (0): SineLayer(
      (linear): Linear(in_features=80, out_features=256, bias=True)
    )
    (1-4): 4 x SineLayer(
      (linear): Linear(in_features=256, out_features=256, bias=True)
    )
    (5): SineLayer(
      (linear): Linear(in_features=336, out_features=256, bias=True)
    )
    (6-7): 2 x SineLayer(
      (linear): Linear(in_features=256, out_features=256, bias=True)
    )
  )
  (output_linear): Linear(in_features=256, out_features=4, bias=False)
)

In [None]:
# Optimiser and learning-rate schedule.
# The scheduler is stepped once every `sched_step_interval` training steps, so
# with gamma = 0.1 ** (1 / 300) the learning rate shrinks by a factor of 10
# every 300 scheduler steps.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / 300))

# Tunables that were previously inlined in the loop body.
grad_clip_max_norm = 5e-5   # max_norm passed to clip_grad_norm_ before every optimiser step
sched_step_interval = 100   # number of training steps between scheduler.step() calls

# NOTE(review): start_epoch is fixed at 0 and load_model() is never called in
# this cell, so training always starts from scratch — confirm that is intended.
start_epoch = 0
step = 1
metrics = dict()
try:
    # NOTE(review): the upper bound is num_epochs + 1, i.e. epochs
    # 0..num_epochs inclusive — confirm the extra epoch is intended.
    for epoch in range(start_epoch, args.num_epochs + 1):
        dataset.shuffle()
        # Drop any gradients left over from an incomplete accumulation window.
        optimizer.zero_grad()
        for data in dataloader:
            metrics = train_step(args, model, data, optimizer, metrics, device)

            # Apply the optimiser only every `num_step_opt` steps; gradients
            # are clipped right before the update and cleared right after.
            if step % args.num_step_opt == 0:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max_norm)
                optimizer.step()
                optimizer.zero_grad()

            if step % args.num_step_log == 0:
                metrics = train_log(args, epoch, step, metrics)
            if step % args.num_step_val == 0:
                # Validation needs no gradients.
                with torch.no_grad():
                    val_step(args, dataset, model, step, device)
            if step % sched_step_interval == 0:
                scheduler.step()

            step += 1
        if epoch % args.num_epoch_save == 0:
            save_model(args.save_path, epoch, model, optimizer, scheduler)
except BaseException:
    # Close the W&B run even when training fails or is interrupted (Ctrl-C /
    # "interrupt kernel"), and record it as failed instead of leaving it open.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
==Train== epoch: 4, step: 14810, MAE: 18.701004, reg_jacob: 16.361048, off_mean: 2.006481, off_max: 35.833048, off_std: 4.270611, lr: 0.000039
==Train== epoch: 4, step: 14820, MAE: 18.863953, reg_jacob: 16.485441, off_mean: 2.035780, off_max: 35.362096, off_std: 4.317837, lr: 0.000039
==Train== epoch: 4, step: 14830, MAE: 18.447352, reg_jacob: 15.832932, off_mean: 2.022568, off_max: 36.317522, off_std: 4.317720, lr: 0.000039
==Train== epoch: 4, step: 14840, MAE: 18.470172, reg_jacob: 15.876800, off_mean: 1.980399, off_max: 35.786806, off_std: 4.258821, lr: 0.000039
==Train== epoch: 4, step: 14850, MAE: 18.100459, reg_jacob: 16.023791, off_mean: 1.998850, off_max: 35.598735, off_std: 4.270772, lr: 0.000039
==Train== epoch: 4, step: 14860, MAE: 18.200073, reg_jacob: 16.586919, off_mean: 2.010380, off_max: 35.645032, off_std: 4.287445, lr: 0.000039
==Train== epoch: 4, step: 14870, MAE: 18.545682, reg_jacob: 15.614742, off_me