In [1]:
import torch
from data_module import DataModule
from utility.utility import get_args
import numpy as np
from models.nodf import NODF
import nibabel as nib
from models.posterior import FVRF
from utility.utility import get_mask

%load_ext autoreload
%autoreload 2

## Args

In [2]:
# Build the argument namespace in notebook mode (cmd=False; presumably this
# skips command-line parsing — confirm in utility.utility.get_args).
args = get_args(cmd=False)
# TODO: modify arguments here if needed
# Checkpoint restored by the "Model" section below. The path is relative, so it
# is resolved against the kernel's current working directory.
args.ckpt_path = 'output/paper_results/data_stanford/ours_14lvls20hm_4embedsize/hashenc/training/nodf/version_6/checkpoints/epoch=2999-step=90000.ckpt'
# args.n_levels = 2

## Data

In [3]:
# Create the project's DataModule from `args` and run its "fit"-stage setup.
# The saved output of this cell reports a precomputed signal file being used.
data_module = DataModule(args)
data_module.setup("fit")

Using precomputed signal from data/subjects/processedDWI_session1_subset01/train_signal.pt


In [4]:
# Keep handles on the dataset and the training dataloader, then look at the
# coordinate tensor stored on the dataloader's dataset.
dataset = data_module.dataset
dataloader = data_module.train_dataloader()
coords = dataloader.dataset.coords
# Cell result: the shape of `coords` (saved output: 3895192 rows, 3 columns).
coords.shape

torch.Size([3895192, 3])

In [5]:
# Draw the first batch from the training dataloader; the saved output shows a
# dict with 'coords' and 'signal' tensors.
# NOTE(review): if the dataloader shuffles, this batch changes between runs
# unless a seed is fixed beforehand — confirm against DataModule.
batch = next(iter(dataloader))
batch

{'coords': tensor([[0.4570, 0.4425, 0.2147],
         [0.2199, 0.5052, 0.5131],
         [0.3505, 0.6411, 0.7696],
         ...,
         [0.2509, 0.6132, 0.4188],
         [0.6220, 0.4286, 0.4503],
         [0.7698, 0.6202, 0.4084]]),
 'signal': tensor([[0.4426, 0.9191, 0.4784,  ..., 0.4147, 0.4926, 0.5835],
         [0.1539, 0.1949, 0.1098,  ..., 0.1488, 0.1360, 0.1820],
         [0.0970, 0.1862, 0.2132,  ..., 0.1745, 0.2411, 0.1491],
         ...,
         [0.2104, 0.2993, 0.4312,  ..., 0.4117, 0.4690, 0.4523],
         [0.3036, 0.2440, 0.1450,  ..., 0.0394, 0.3649, 0.0479],
         [0.3723, 0.4193, 0.4604,  ..., 0.5558, 0.5283, 0.4778]])}

## Model

In [6]:
# Restore the trained NODF network when a checkpoint path is configured;
# otherwise build a freshly initialised model from `args`.
if args.ckpt_path:
    print("Loading model from checkpoint")
    # map_location="cpu" remaps the stored tensors during deserialisation, so
    # the checkpoint still loads on a machine without a GPU in case it was
    # saved from a CUDA run. The trailing .cpu() is kept so the resulting
    # model placement is exactly what the original cell guaranteed.
    model = NODF.load_from_checkpoint(args.ckpt_path, map_location="cpu").cpu()
else:
    model = NODF(args)

Loading model from checkpoint


In [7]:
# Display the module tree of the loaded network as the cell result.
model

NODF(
  (inr): INR(
    (net): Sequential(
      (0): SineLayer(
        (linear): Linear(in_features=59, out_features=64, bias=True)
      )
      (1): SineLayer(
        (linear): Linear(in_features=64, out_features=64, bias=True)
      )
      (2): SineLayer(
        (linear): Linear(in_features=64, out_features=64, bias=True)
      )
      (3): Linear(in_features=64, out_features=45, bias=False)
    )
  )
  (hash_embedder): HashEmbedder(
    (embeddings): ModuleList(
      (0-13): 14 x Embedding(1048576, 4)
    )
  )
)

In [8]:
# Report parameter counts via the model's own helper. Per the saved output it
# prints one line per parameter tensor plus the trainable total, and the total
# is also shown as the cell result.
model.count_parameters()

inr.net.0.linear.weight: 3776
inr.net.0.linear.bias: 64
inr.net.1.linear.weight: 4096
inr.net.1.linear.bias: 64
inr.net.2.linear.weight: 4096
inr.net.2.linear.bias: 64
inr.net.3.weight: 2880
hash_embedder.embeddings.0.weight: 4194304
hash_embedder.embeddings.1.weight: 4194304
hash_embedder.embeddings.2.weight: 4194304
hash_embedder.embeddings.3.weight: 4194304
hash_embedder.embeddings.4.weight: 4194304
hash_embedder.embeddings.5.weight: 4194304
hash_embedder.embeddings.6.weight: 4194304
hash_embedder.embeddings.7.weight: 4194304
hash_embedder.embeddings.8.weight: 4194304
hash_embedder.embeddings.9.weight: 4194304
hash_embedder.embeddings.10.weight: 4194304
hash_embedder.embeddings.11.weight: 4194304
hash_embedder.embeddings.12.weight: 4194304
hash_embedder.embeddings.13.weight: 4194304
Total Trainable Params: 58735296


58735296

## Forward Pass

In [9]:
# ODF coefficients
# This cell only inspects the network output, so run the forward pass without
# autograd tracking: no graph is recorded and no intermediate activations are
# retained for a backward pass that never happens.
with torch.no_grad():
    chat = model(batch)
# Cell result: shape of the predicted coefficient tensor.
chat.shape

torch.Size([65536, 45])

## Posterior

In [17]:
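# NOTE(review): this cell is currently disabled. The log lines shown beneath it
# appear to be left over from an earlier execution (its execution count, 17, is
# out of sequence), so they do not reflect a fresh Restart & Run All.
# Uncomment the next line to build the posterior object again.
# posterior = FVRF(args)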

Using precomputed signal from data/subjects/processedDWI_session1_subset01/train_signal.pt
Using saved pointwise_estimates.pt from output/paper_results/data_stanford/ours_14lvls20hm_4embedsize/hashenc/prediction/pointwise_estimates.pt
Using saved basis_pointwise_estimates.pt from output/paper_results/data_stanford/ours_14lvls20hm_4embedsize/hashenc/prediction/basis_pointwise_estimates.pt
Using saved vec_W_post_mean.pt and vec_W_post_cov.pt from output/paper_results/data_stanford/ours_14lvls20hm_4embedsize/hashenc/prediction/vec_W_post_mean.pt and output/paper_results/data_stanford/ours_14lvls20hm_4embedsize/hashenc/prediction/vec_W_post_cov.pt


In [11]:
# Region of interest: start from the mask returned by get_mask and clear, in
# place, everything outside one index window per axis.
mask = get_mask(args)

# Index expressions for the regions to switch off, in the order they are
# applied. The anatomical labels are the original author's; the axis-to-plane
# mapping cannot be verified from this notebook alone.
outside_roi = (
    np.s_[:168], np.s_[169:],            # axial: keep only index 168 on axis 0
    np.s_[:, :74], np.s_[:, 88:],        # sagittal: keep indices 74..87 on axis 1
    np.s_[:, :, :67], np.s_[:, :, 85:],  # coronal: keep indices 67..84 on axis 2
)
for region in outside_roi:
    mask[region] = False

In [12]:
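# generate posterior samples
# NOTE(review): disabled together with the `posterior = FVRF(args)` cell above.
# `posterior` must be created before this line is re-enabled, otherwise it
# raises NameError on a fresh kernel.
# post_samples_chat = posterior.sample_posterior(mask)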