In [1]:
import os
import pathlib
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from utils.mri_data import SliceDataset
from utils.data_transform import DataTransform_Diffusion
from utils.sample_mask import RandomMaskGaussianDiffusion, RandomMaskDiffusion, RandomMaskDiffusion2D
from utils.misc import *
from help_func import print_var_detail

from diffusion.kspace_diffusion import KspaceDiffusion
from utils.diffusion_train import Trainer
from net.u_net_diffusion import Unet

# Report the PyTorch build, then select the compute device:
# the first CUDA GPU when one is visible, otherwise the CPU.
print(torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device:', device)

2.0.0
device: cuda:0


In [2]:
# ****** TRAINING SETTINGS ******
# dataset settings
acc = 4  # acceleration factor
frac_c = 0.08  # center fraction
# Dataset root. Defaults to the original author's local checkout; set the
# FASTMRI_ROOT environment variable to point elsewhere instead of editing
# this cell. Trailing separators are stripped so the joins below stay clean.
DATA_ROOT = os.environ.get('FASTMRI_ROOT', 'C:/TortoiseGitRepos/datasets/fastmri').rstrip('/\\')
path_dir_train = DATA_ROOT + '/knee_singlecoil_train_full/'
path_dir_test = DATA_ROOT + '/knee_singlecoil_test_5/'
img_mode = 'fastmri'  # 'fastmri' or 'B1000'
bhsz = 6  # batch size for both the train and test dataloaders
img_size = 320  # spatial size used by the mask, the transform and the model

# Seed the RNGs the pipeline may draw from (python, numpy, torch) so runs are
# repeatable (DataLoader shuffling and weight init use torch; whether the mask
# sampler uses numpy or python `random` is not visible here — hence all three).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# ====== Construct dataset ======
# Gaussian-density undersampling mask, shared by the train and test transforms.
mask_func = RandomMaskGaussianDiffusion(
    acceleration=acc,
    center_fraction=frac_c,
    size=(1, img_size, img_size),
)

# One transform object applied to every slice of both splits
# (single-coil input, coil combination enabled).
data_transform = DataTransform_Diffusion(
    mask_func,
    img_size=img_size,
    combine_coil=True,
    flag_singlecoil=True,
)


def _make_slice_dataset(root_dir):
    """Return a single-coil SliceDataset rooted at `root_dir` using the shared transform."""
    return SliceDataset(
        root=pathlib.Path(root_dir),
        transform=data_transform,
        challenge='singlecoil',
        num_skip_slice=5,
    )


# Build the training split first, then the test split.
dataset_train, dataset_test = (
    _make_slice_dataset(split_dir) for split_dir in (path_dir_train, path_dir_test)
)

# Both loaders use the same batch size and shuffle their split.
dataloader_train, dataloader_test = (
    DataLoader(split_ds, batch_size=bhsz, shuffle=True)
    for split_ds in (dataset_train, dataset_test)
)
print('len dataloader train:', len(dataloader_train))
print('len dataloader test:', len(dataloader_test))

len dataloader train: 4331
len dataloader test: 24


In [4]:
# model settings
CH_MID = 64  # passed to Unet as `dim`
# training settings
NUM_EPOCH = 50
learning_rate = 2e-5
time_steps = 1000  # passed to KspaceDiffusion as `timesteps`
# Lower bound on the number of optimisation steps. The dataset-derived count
# (epochs x batches per epoch) is used only when it exceeds this floor.
MIN_TRAIN_STEPS = 700_000
train_steps = max(NUM_EPOCH * len(dataloader_train), MIN_TRAIN_STEPS)
print('train_steps:', train_steps)
# save settings: the run directory name encodes the main hyper-parameters
PATH_MODEL = (
    f'../saved_models/fastmri_knee/diffusion_Gmask_{img_mode}_{acc}x'
    f'_T{time_steps}_S{train_steps}/'
)
create_path(PATH_MODEL)

train_steps: 700000
The new directory is created!


In [5]:
# construct diffusion model
save_folder = PATH_MODEL
load_path = None  # forwarded to Trainer; presumably a checkpoint to resume from — confirm in Trainer
blur_routine = 'Constant'
train_routine = 'Final'
sampling_routine = 'x0_step_down'
discrete = False
# Channel count shared by the Unet and the diffusion wrapper; the two must agree.
N_CHANNELS = 2

# Move both modules to the device selected in the first cell instead of
# hard-coding `.cuda()`. With CUDA available this is the same 'cuda:0' as before.
# NOTE(review): CPU-only runs additionally rely on KspaceDiffusion accepting
# device_of_kernel='cpu' — confirm in diffusion.kspace_diffusion.
model = Unet(
    dim=CH_MID,
    dim_mults=(1, 2, 4, 8),
    channels=N_CHANNELS,
).to(device)
print('model size: %.3f MB' % (calc_model_size(model)))

diffusion = KspaceDiffusion(
    model,
    image_size=img_size,
    device_of_kernel=device.type,  # 'cuda' when a GPU is available, else 'cpu'
    channels=N_CHANNELS,
    timesteps=time_steps,  # number of steps
    loss_type='l1',  # L1 or L2
    blur_routine=blur_routine,
    train_routine=train_routine,
    sampling_routine=sampling_routine,
    discrete=discrete,
).to(device)

Is Time embed used ?  True
model size: 53.992 MB


In [6]:
# construct trainer and train

# Schedule knobs for the Trainer, named here rather than inlined in the call.
GRAD_ACCUM_STEPS = 2  # gradient accumulation steps
EMA_DECAY = 0.995  # exponential moving average decay
SAVE_AND_SAMPLE_EVERY = 50000  # step interval handed to save_and_sample_every

trainer_config = dict(
    image_size=img_size,
    train_batch_size=bhsz,
    train_lr=learning_rate,
    train_num_steps=train_steps,  # total training steps
    gradient_accumulate_every=GRAD_ACCUM_STEPS,
    ema_decay=EMA_DECAY,
    fp16=False,  # apex mixed-precision training left off
    save_and_sample_every=SAVE_AND_SAMPLE_EVERY,
    results_folder=save_folder,
    load_path=load_path,
    dataloader_train=dataloader_train,
    dataloader_test=dataloader_test,
)
trainer = Trainer(diffusion, **trainer_config)
trainer.train()

Loss=0.020132:   7%|▋         | 50000/700000 [15:03:27<194:58:47,  1.08s/it]

Mean LOSS of last 50000: 0.022516000228856598


Loss=0.020229:  14%|█▍        | 100000/700000 [30:05:51<180:27:02,  1.08s/it]

Mean LOSS of last 100000: 0.020917338005897335


Loss=0.025851:  21%|██▏       | 150000/700000 [45:07:23<165:04:13,  1.08s/it]

Mean LOSS of last 150000: 0.020691756511561042


Loss=0.017474:  29%|██▊       | 200000/700000 [60:07:34<150:25:38,  1.08s/it]

Mean LOSS of last 200000: 0.02056140551326685


Loss=0.019450:  36%|███▌      | 250000/700000 [75:07:21<134:41:35,  1.08s/it]

Mean LOSS of last 250000: 0.020482118282787606


Loss=0.027134:  43%|████▎     | 300000/700000 [90:07:10<120:11:26,  1.08s/it]

Mean LOSS of last 300000: 0.020447353555304883


Loss=0.024909:  50%|█████     | 350000/700000 [105:09:18<104:52:48,  1.08s/it]

Mean LOSS of last 350000: 0.020411293396658228


Loss=0.021947:  57%|█████▋    | 400000/700000 [120:09:17<89:58:49,  1.08s/it] 

Mean LOSS of last 400000: 0.02037504099915388


Loss=0.022954:  64%|██████▍   | 450000/700000 [135:09:07<74:56:41,  1.08s/it] 

Mean LOSS of last 450000: 0.020347583682777032


Loss=0.018148:  71%|███████▏  | 500000/700000 [150:08:55<59:58:04,  1.08s/it]

Mean LOSS of last 500000: 0.020326228611408096


Loss=0.017249:  79%|███████▊  | 550000/700000 [165:09:07<44:56:50,  1.08s/it]

Mean LOSS of last 550000: 0.020314246162739993


Loss=0.020000:  86%|████████▌ | 600000/700000 [180:11:39<30:02:04,  1.08s/it]

Mean LOSS of last 600000: 0.02030197813940629


Loss=0.019887:  93%|█████████▎| 650000/700000 [195:12:43<15:01:46,  1.08s/it]

Mean LOSS of last 650000: 0.02029891712274781


Loss=0.021574: 100%|██████████| 700000/700000 [210:13:33<00:00,  1.08s/it]   


training completed
