In [None]:
import os
import pathlib
import random
import sys
sys.path.append('../')

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.mri_data import SliceDataset
from utils.data_transform import DataTransform_Diffusion
from utils.sample_mask import RandomMaskGaussianDiffusion, RandomMaskDiffusion, RandomMaskDiffusion2D
from utils.misc import *
from help_func import print_var_detail

from diffusion.kspace_diffusion import KspaceDiffusion
from utils.diffusion_train import Trainer
from net.u_net_diffusion import Unet

# Pick the compute device (first GPU when CUDA is available, else CPU) and
# report it together with the installed PyTorch build.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(torch.__version__)
print('device:', device)

2.5.1+cu124
device: cuda:0


In [2]:
# # ****** ALTERNATIVE TRAINING SETTINGS (fastMRI knee single-coil) — inactive, kept for reference ******
# # dataset settings
# acc = 8  # acceleration factor
# frac_c = 0.04  # center fraction
# path_dir_train = '/home/alvin/UltrAi/Datasets/raw_datasets/fastmri/knee_singlecoil/singlecoil_train/'
# path_dir_test = '/home/alvin/UltrAi/Datasets/raw_datasets/fastmri/knee_singlecoil/singlecoil_test/'
# img_mode = 'fastmri'  # 'fastmri' or 'B1000'
# bhsz = 2
# img_size = 320

In [3]:
# ****** TRAINING SETTINGS ******
# Active configuration: M4Raw multi-coil sample split.
# The dataset root can be overridden with the M4RAW_ROOT environment variable so the
# notebook runs on other machines; the default reproduces the original location, so
# path_dir_train / path_dir_test are byte-identical strings when the variable is unset.
M4RAW_ROOT = pathlib.Path(
    os.environ.get('M4RAW_ROOT', '/home/alvin/UltrAi/Datasets/raw_datasets/m4raw/sample')
)

# dataset settings
acc = 4  # acceleration factor
frac_c = 0.04  # center fraction
path_dir_train = str(M4RAW_ROOT / 'multicoil_train')
path_dir_test = str(M4RAW_ROOT / 'multicoil_test_masked')
img_mode = 'fastmri'  # 'fastmri' or 'B1000'
bhsz = 1  # batch size
img_size = 256

# Seed every RNG that may drive loader shuffling, random k-space masks and weight
# initialisation, so that Restart & Run All reproduces the same run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [8]:
# ====== Construct dataset ======
# Under-sampling mask; acceleration / center fraction come from the settings cell.
mask_func = RandomMaskDiffusion(
    acceleration=acc,
    center_fraction=frac_c,
    size=(1, img_size, img_size),
)

# One transform instance is shared by the train and test splits.
data_transform = DataTransform_Diffusion(
    mask_func,
    img_size=img_size,
    combine_coil=True,
    flag_singlecoil=False,
)

# training set
# num_skip_slice=5: presumably skips edge slices of each volume — TODO confirm in SliceDataset.
dataset_train = SliceDataset(
    root=pathlib.Path(path_dir_train),
    transform=data_transform,
    challenge='multicoil',
    num_skip_slice=5,
)

# test set
dataset_test = SliceDataset(
    root=pathlib.Path(path_dir_test),
    transform=data_transform,
    challenge='multicoil',
    num_skip_slice=5,
)

# Fail early with a readable message if a data path is wrong, instead of an
# opaque sampler / iterator error further down.
assert len(dataset_train) > 0, f'no training slices found under {path_dir_train}'
assert len(dataset_test) > 0, f'no test slices found under {path_dir_test}'

# Only the training split is shuffled; a fixed test order keeps anything computed
# from dataloader_test comparable across runs and checkpoints.
dataloader_train = DataLoader(dataset_train, batch_size=bhsz, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=bhsz, shuffle=False)
print('len dataloader train:', len(dataloader_train))
print('len dataloader test:', len(dataloader_test))

len dataloader train: 45
len dataloader test: 18


In [9]:
# ====== Model / training / checkpoint settings ======
CH_MID = 64  # base channel width passed to the U-Net as `dim`

NUM_EPOCH = 50
learning_rate = 2e-5
time_steps = 1000  # number of diffusion steps T

# Step budget: an epoch-based count is computed, then raised to at least
# MIN_TRAIN_STEPS. With a small dataset the floor dominates, so NUM_EPOCH only
# matters once NUM_EPOCH * len(dataloader_train) exceeds the floor.
MIN_TRAIN_STEPS = 700000
train_steps = max(NUM_EPOCH * len(dataloader_train), MIN_TRAIN_STEPS)
print('train_steps:', train_steps)

# Checkpoint folder, e.g. ../saved_models/fastmri_knee/diffusion_fastmri_4x_T1000_S700000/
# NOTE(review): the 'fastmri_knee' parent folder is kept as-is, although the active
# data paths point at M4Raw — runs with equal settings share this folder; confirm intended.
PATH_MODEL = (
    f'../saved_models/fastmri_knee/'
    f'diffusion_{img_mode}_{acc}x_T{time_steps}_S{train_steps}/'
)
create_path(PATH_MODEL)

train_steps: 700000
Path already exists.


In [10]:
# ====== Construct diffusion model ======
# Options forwarded to KspaceDiffusion and, later, the Trainer.
save_folder = PATH_MODEL
load_path = None  # presumably a checkpoint to resume from; None = fresh model (TODO confirm in Trainer)
blur_routine = 'Constant'
train_routine = 'Final'
sampling_routine = 'x0_step_down'
discrete = False

# Use the `device` selected in the setup cell instead of hard-coding CUDA, so this
# cell no longer crashes on a CPU-only machine. On a GPU machine this is the same
# placement as before ('cuda:0' / 'cuda').
# NOTE(review): KspaceDiffusion / Trainer internals may still assume CUDA — CPU
# execution is not verified here.
model = Unet(
    dim=CH_MID,
    dim_mults=(1, 2, 4, 8),
    channels=2,
).to(device)
print('model size: %.3f MB' % (calc_model_size(model)))

diffusion = KspaceDiffusion(
    model,
    image_size=img_size,
    device_of_kernel=device.type,  # 'cuda' when a GPU is available, as originally hard-coded
    channels=2,
    timesteps=time_steps,  # number of steps
    loss_type='l1',  # L1 or L2
    blur_routine=blur_routine,
    train_routine=train_routine,
    sampling_routine=sampling_routine,
    discrete=discrete,
).to(device)

Is Time embed used ?  True
model size: 53.992 MB


In [None]:
# ====== Construct trainer and train ======
# The full run is very long (train_steps iterations). Set RUN_TRAINING = False to
# build the trainer without launching training, e.g. during Restart & Run All.
RUN_TRAINING = True

trainer = Trainer(
    diffusion,
    image_size=img_size,
    train_batch_size=bhsz,
    train_lr=learning_rate,
    train_num_steps=train_steps,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    fp16=False,  # mixed precision (apex) disabled
    save_and_sample_every=50000,  # presumably checkpoint/sample interval in steps
    results_folder=save_folder,
    load_path=load_path,
    dataloader_train=dataloader_train,
    dataloader_test=dataloader_test,
)
if RUN_TRAINING:
    trainer.train()

LOSS:   0%|          | 0/700000 [00:00<?, ?it/s]

Loss=0.084801:   0%|          | 193/700000 [02:52<168:58:19,  1.15it/s]