# Preparation

1. `git clone https://github.com/yang-song/score_sde_pytorch.git`

2. Install [required packages](https://github.com/yang-song/score_sde_pytorch/blob/main/requirements.txt)

3. `cd` into folder `score_sde_pytorch`, launch a local jupyter server and connect to colab following [these instructions](https://research.google.com/colaboratory/local-runtimes.html)

4. Download pre-trained [checkpoints](https://drive.google.com/drive/folders/1tFmF_uh57O6lx9ggtZT_5LdonVK2cV-e?usp=sharing) and save them in the `exp` folder.

In [1]:
#@title Autoload all modules
%load_ext autoreload
%autoreload 2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import io
import csv
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import importlib
import os
import functools
import itertools
import torch
from losses import get_optimizer
from models.ema import ExponentialMovingAverage

import torch.nn as nn
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_gan as tfgan
import tqdm
import io
import likelihood
import controllable_generation
from utils import restore_checkpoint
sns.set(font_scale=2)
sns.set(style="whitegrid")

import models
from models import utils as mutils
from models import ncsnv2
from models import ncsnpp3d
from models import ddpm as ddpm_model
from models import layerspp
from models import layers
from models import normalization
import sampling
from likelihood import get_likelihood_fn
from sde_lib import VESDE, VPSDE, subVPSDE
from sampling import (ReverseDiffusionPredictor, 
                      LangevinCorrector, 
                      EulerMaruyamaPredictor, 
                      AncestralSamplingPredictor, 
                      NoneCorrector, 
                      NonePredictor,
                      AnnealedLangevinDynamics)
import datasets

In [2]:
from configs.ve import brain_ncsnpp_continuous as configs
config = configs.get_config()
config

data:
  centered: false
  dataset: BRAIN
  dir_path: /DATA/Users/amahmood/braintyp/
  downsample_size: 200
  image_size: 184
  num_channels: 2
  random_flip: true
  uniform_dequantization: false
device: !!python/object/apply:torch.device
- cuda
- 0
eval:
  batch_size: 16
  begin_ckpt: 3
  bpd_dataset: inlier
  enable_bpd: true
  enable_loss: true
  enable_sampling: false
  end_ckpt: 3
  num_samples: 50000
  ood_eval: true
model:
  attention_type: ddpm
  attn_resolutions: !!python/tuple
  - 16
  beta_max: 20.0
  beta_min: 0.1
  ch_mult: !!python/tuple
  - 1
  - 1
  - 2
  - 2
  - 2
  - 2
  - 2
  conv_size: 3
  dropout: 0.0
  ema_rate: 0.999
  embedding_type: fourier
  fourier_scale: 16
  init_scale: 0.0
  name: ncsnpp3d
  nf: 128
  nonlinearity: swish
  normalization: GroupNorm
  num_res_blocks: 2
  num_scales: 2000
  scale_by_sigma: true
  sigma_max: 348
  sigma_min: 0.01
msma:
  min_timestep: 0.1
  n_timestep: 10
optim:
  beta1: 0.9
  eps: 1.0e-08
  grad_clip: 1.0
  lr: 0.0002
  optimi

In [3]:
# Initialize model.
score_model = mutils.create_model(config)

In [4]:
score_model

DataParallel(
  (module): SegResNetpp(
    (act): ReLU(inplace=True)
    (convInit): Convolution(
      (conv): Conv3d(2, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    )
    (down_layers): ModuleList(
      (0): MultiSequential(
        (0): SegResBlockpp(
          (pre_conv): Identity()
          (resblock): ResBlock(
            (norm1): GroupNorm(8, 8, eps=1e-05, affine=True)
            (norm2): GroupNorm(8, 8, eps=1e-05, affine=True)
            (relu): ReLU(inplace=True)
            (conv1): Convolution(
              (conv): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            )
            (conv2): Convolution(
              (conv): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            )
          )
          (dense): Linear(in_features=1024, out_features=8, bias=True)
          (act): SiLU()
        )
      )
      (1): MultiSequential(
        (0): SegResBl

In [5]:
x_test = torch.zeros(3,2,32,32,32).to(config.device)
t_test = torch.zeros(x_test.shape[0])

In [6]:
score_model(x_test, t_test)

tensor([[[[[ 0.5641,  0.6703,  0.6335,  ...,  0.2507,  0.1927,  0.3353],
           [ 0.7901,  0.7254,  0.7471,  ...,  0.1932,  0.2519,  0.5446],
           [ 0.5447,  0.6000,  0.6721,  ...,  0.2117,  0.3164,  0.4702],
           ...,
           [ 0.4362,  0.5373,  0.5541,  ...,  0.2584,  0.3494,  0.4189],
           [ 0.3495,  0.4315,  0.5155,  ...,  0.2546,  0.2259,  0.3153],
           [ 0.4519,  0.2200,  0.2659,  ...,  0.1029,  0.2108,  0.3066]],

          [[ 0.6503,  0.6832,  0.4642,  ...,  0.1625,  0.2263,  0.3570],
           [ 0.7519,  0.5651,  0.5234,  ...,  0.0678,  0.0474,  0.6206],
           [ 0.4336,  0.5558,  0.3832,  ..., -0.0656,  0.0287,  0.5765],
           ...,
           [ 0.7067,  0.5935,  0.2401,  ...,  0.4091,  0.5324,  0.3794],
           [ 0.6094,  0.6921,  0.5605,  ...,  0.5832,  0.5691,  0.3247],
           [ 0.6146,  0.4289,  0.4857,  ...,  0.4152,  0.5341,  0.3726]],

          [[ 0.5999,  0.8817,  0.7243,  ...,  0.2082,  0.1795,  0.3739],
           [ 0.

In [2]:
import re, glob
import monai
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

np.random.seed(seed=42) 

In [3]:
# !python -m pip install 'monai[all]'

In [7]:
DATADIR = "/DATA/Users/amahmood/braintyp/"

val_file_list = glob.glob(os.path.join(DATADIR,"val/*"))
DATADIR = "/DATA/ABCDFixRelease"
val_file_list[:4], len(val_file_list)

(['/DATA/Users/amahmood/braintyp/val/NDARINV84XXL5T1.nii.gz',
  '/DATA/Users/amahmood/braintyp/val/NDARINVRTDH8349.nii.gz',
  '/DATA/Users/amahmood/braintyp/val/NDARINVCDNAUD1U.nii.gz',
  '/DATA/Users/amahmood/braintyp/val/NDARINVJ35JMG06.nii.gz'],
 321)

In [97]:
from monai.data import CacheDataset, DataLoader, ArrayDataset
from monai.transforms import *

lambd = Lambda(func=lambda x: x[:, :, :,0,:])

img_transform = Compose(
    [
        LoadImage(image_only=True),
        SqueezeDim(dim=3),
        AsChannelFirst(),
        SpatialCrop(roi_start=[11,   9,   0], roi_end=[172, 205, 152]),
        DivisiblePad(k=8),
        RandAdjustContrast()
    ]
)

ds = ArrayDataset(val_file_list[:4], img_transform = img_transform)

In [98]:
x = ds[1]
nchannels, input_shape = x.shape[0], x.shape[1:]

In [99]:
nchannels, input_shape

(2, (168, 200, 152))

In [103]:
ds_loader = DataLoader(ds, batch_size=2, shuffle=False)

In [116]:
for x in ds_loader:
    print(x.shape)

torch.Size([2, 2, 168, 200, 152])
torch.Size([2, 2, 168, 200, 152])


In [4]:
# Setup SDE
import sde_lib
sde = sde_lib.subVPSDE(
    beta_min=config.model.beta_min,
    beta_max=config.model.beta_max,
    N=config.model.num_scales,
)
sampling_eps = 1e-3

In [56]:
n_timesteps=10
eps=1e-5
torch.linspace(sde.T, eps, n_timesteps, device="cpu")

tensor([1.0000e+00, 8.8889e-01, 7.7778e-01, 6.6667e-01, 5.5556e-01, 4.4445e-01,
        3.3334e-01, 2.2223e-01, 1.1112e-01, 1.0000e-05])

In [64]:
t=torch.linspace(sde.T, eps, n_timesteps, device="cpu")
std = sde.marginal_prob(torch.zeros_like(t), t)[1].numpy()
std

array([9.9995679e-01, 9.9964756e-01, 9.9775028e-01, 9.8876739e-01,
       9.5613301e-01, 8.6600149e-01, 6.7983985e-01, 4.0167153e-01,
       1.2538469e-01, 1.0132790e-06], dtype=float32)