In [1]:
# Restrict this kernel to physical GPU 2; set here, before `torch` is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=2
from pathlib import Path

# Fetch and unpack the small fastMRI PD 3T dataset (provides the `val.h5` used below) only if
# the directory is not already present, so re-running this cell is cheap.
if not Path("./small_fastmri_pd_3t").is_dir():
    !gdown --id "1y78Ad6WwQpMGtxfEZlp97A0iV98kAiJN"
    !unzip -q small_fastmri_pd_3t.zip && rm small_fastmri_pd_3t.zip
    
# Fetch `dncnn-noiseless.pth` if missing.
# NOTE(review): this file is not loaded by any visible cell, while the cascade checkpoints that
# ARE loaded later (cascade-x5-dncnn_pure-dc-*.pth) are not downloaded here -- confirm their source.
if not Path("./dncnn-noiseless.pth").is_file():
    !gdown --id "1azlqmuIkdhcsMQJL_YObF4sEe83D8J8N"

env: CUDA_VISIBLE_DEVICES=2


In [2]:
import os
import sys
import numpy as np
import h5py
import pylab as plt
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from k_space_reconstruction.nets.cdn_dncnn import DnCNNDCLModule, CascadeModule, PureDnCNNDCModule
from k_space_reconstruction.datasets.fastmri import FastMRITransform, FastMRIh5Dataset, RandomMaskFunc
from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim, ssim, nmse, psnr
from k_space_reconstruction.utils.loss import l1_loss, compund_mssim_l1_loss
from k_space_reconstruction.utils.kspace import spatial2kspace, kspace2spatial

print('Available GPUs: ', torch.cuda.device_count())

Available GPUs:  1


In [3]:
# Hyper-parameters shared by every PureDnCNNDCModule stage and by the CascadeModule wrapper
# (both are constructed with **model_kwargs below).
model_kwargs = {
    # DnCNN width / depth of each stage.
    'dncnn_chans': 64,
    'dncnn_depth': 10,
    # Training objective (project loss combining MS-SSIM and L1, per its name).
    'criterion': compund_mssim_l1_loss,
    'verbose_batch': 50,
    # Optimiser and step learning-rate schedule settings.
    'optimizer': 'Adam',
    'lr': 1e-4,
    'lr_step_size': 3,
    'lr_gamma': 0.2,
    'weight_decay': 0.0,
}

In [4]:
def get_trainer():
    """Create a new single-GPU PyTorch Lightning ``Trainer`` for this notebook.

    Logs and checkpoints are rooted at ``logs/CascadeDnCNN_pure``. The trainer keeps
    up to seven checkpoints ranked by ``val_loss`` plus the last one, logs the learning
    rate once per epoch, and records GPU stats including temperature.

    Returns:
        pl.Trainer: a freshly configured trainer (no state shared between calls).
    """
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        save_last=True,
        save_top_k=7,
        monitor='val_loss',
        filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}',
    )
    lr_monitor_cb = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
    gpu_stats_cb = pl.callbacks.GPUStatsMonitor(temperature=True)

    trainer_options = dict(
        gpus=1,
        max_epochs=10,
        # Effective batch = 3 x DataLoader batch size.
        accumulate_grad_batches=3,
        terminate_on_nan=True,
        default_root_dir='logs/CascadeDnCNN_pure',
        callbacks=[checkpoint_cb, lr_monitor_cb, gpu_stats_cb],
    )
    return pl.Trainer(**trainer_options)

# Test model
Load the trained cascade weights, run inference on the validation dataset, and save the predictions to an `.h5` file in the logs directory.

In [5]:
# Seed both RNGs so the random undersampling masks are reproducible across runs.
torch.manual_seed(42)
np.random.seed(42)

# Noise-free baseline: noise_type='none'.
# NOTE(review): RandomMaskFunc([0.08], [4]) presumably means centre fraction 0.08 and
# 4x acceleration (fastMRI convention) -- confirm against the project implementation.
transform = FastMRITransform(RandomMaskFunc([0.08], [4]), noise_level=1000, noise_type='none')

val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
val_generator = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    num_workers=12,
)

In [6]:
# Five cascaded stages, each the `.net` of a freshly constructed PureDnCNNDCModule.
cascade = CascadeModule(net=torch.nn.ModuleList([PureDnCNNDCModule(**model_kwargs).net for _ in range(5)]), **model_kwargs)
# map_location='cpu' lets the checkpoint load regardless of which (or whether a) GPU it was
# saved from; load_state_dict copies the values into the model's own parameters.
cascade.net.load_state_dict(torch.load('cascade-x5-dncnn_pure-dc-noiseless.pth', map_location='cpu'))

<All keys matched successfully>

In [8]:
# Run the test loop of the noise-free cascade on the validation loader; per the notebook text,
# predictions are written to an .h5 file under the trainer's log directory.
trainer = get_trainer()
trainer.test(cascade, val_generator)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

# Validation metrics (noise-free)

In [9]:
# Score the newest prediction file against ground-truth images reconstructed from the
# validation k-space: one SSIM / NMSE / PSNR value per volume, then the mean over volumes.
#
# The test cell above writes its predictions to a timestamp-named .h5 in the trainer's
# default_root_dir; pick the most recent one instead of hard-coding the timestamp of a past
# run, so the notebook survives Restart & Run All.
pred_files = sorted(Path('logs/CascadeDnCNN_pure').glob('*.h5'), key=lambda p: p.stat().st_mtime)
assert pred_files, 'no prediction .h5 found in logs/CascadeDnCNN_pure - run the test cell first'
pred_path = pred_files[-1]

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files explicitly read-only (older h5py defaulted to append mode, which silently
# creates an empty file for a missing path) and close them once done.
with h5py.File(str(pred_path), 'r') as hf_pred, h5py.File('small_fastmri_pd_3t/val.h5', 'r') as hf_gt:
    for volume_key in hf_pred.keys():
        # NOTE(review): 1e6 presumably matches the k-space scaling applied by FastMRITransform -- confirm.
        ks = hf_gt[volume_key][:] * 1e6
        # Reconstruct each slice of the volume from its k-space.
        gt = np.stack([kspace2spatial(slice_ks) for slice_ks in ks])
        # Take channel 0 of the stored predictions (assumed (slices, 1, H, W) -- TODO confirm).
        pred = hf_pred[volume_key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.771267379633742, 0.01542535782583012, 29.926976604711154)

### Gaussian Noise

In [10]:
# Gaussian-noise setting: same seeds, mask and validation file as the noise-free run, but the
# inputs are corrupted with noise_type='normal' and the 'gaussian' cascade checkpoint is used.
torch.manual_seed(42)
np.random.seed(42)

transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=100,
    noise_type='normal'
)

val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

cascade = CascadeModule(net=torch.nn.ModuleList([PureDnCNNDCModule(**model_kwargs).net for _ in range(5)]), **model_kwargs)
# map_location='cpu' lets the checkpoint load regardless of the device it was saved from;
# load_state_dict copies the values into the model's own parameters.
cascade.net.load_state_dict(torch.load('cascade-x5-dncnn_pure-dc-gaussian.pth', map_location='cpu'))

trainer = get_trainer()
trainer.test(cascade, val_generator)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

In [11]:
# Score the newest prediction file against ground-truth images reconstructed from the
# validation k-space: one SSIM / NMSE / PSNR value per volume, then the mean over volumes.
#
# The test cell above writes its predictions to a timestamp-named .h5 in the trainer's
# default_root_dir; pick the most recent one instead of hard-coding the timestamp of a past
# run, so the notebook survives Restart & Run All.
pred_files = sorted(Path('logs/CascadeDnCNN_pure').glob('*.h5'), key=lambda p: p.stat().st_mtime)
assert pred_files, 'no prediction .h5 found in logs/CascadeDnCNN_pure - run the test cell first'
pred_path = pred_files[-1]

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files explicitly read-only (older h5py defaulted to append mode, which silently
# creates an empty file for a missing path) and close them once done.
with h5py.File(str(pred_path), 'r') as hf_pred, h5py.File('small_fastmri_pd_3t/val.h5', 'r') as hf_gt:
    for volume_key in hf_pred.keys():
        # NOTE(review): 1e6 presumably matches the k-space scaling applied by FastMRITransform -- confirm.
        ks = hf_gt[volume_key][:] * 1e6
        # Reconstruct each slice of the volume from its k-space.
        gt = np.stack([kspace2spatial(slice_ks) for slice_ks in ks])
        # Take channel 0 of the stored predictions (assumed (slices, 1, H, W) -- TODO confirm).
        pred = hf_pred[volume_key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.6601974306506303, 0.02273588210173777, 28.30644627107848)

### Salt-and-Pepper Noise

In [12]:
# Salt-and-pepper setting: same seeds, mask and validation file as before, inputs corrupted
# with noise_type='salt', evaluated with the 'salt' cascade checkpoint.
torch.manual_seed(42)
np.random.seed(42)

transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=5e4,
    noise_type='salt'
)

val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

cascade = CascadeModule(net=torch.nn.ModuleList([PureDnCNNDCModule(**model_kwargs).net for _ in range(5)]), **model_kwargs)
# map_location='cpu' lets the checkpoint load regardless of the device it was saved from;
# load_state_dict copies the values into the model's own parameters.
cascade.net.load_state_dict(torch.load('cascade-x5-dncnn_pure-dc-salt.pth', map_location='cpu'))

# Build a fresh trainer (as the Gaussian cell does) instead of reusing one left over from an
# earlier cell, so this cell does not depend on hidden kernel state.
trainer = get_trainer()
trainer.test(cascade, val_generator)

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

In [13]:
# Score the newest prediction file against ground-truth images reconstructed from the
# validation k-space: one SSIM / NMSE / PSNR value per volume, then the mean over volumes.
#
# The test cell above writes its predictions to a timestamp-named .h5 in the trainer's
# default_root_dir; pick the most recent one instead of hard-coding the timestamp of a past
# run, so the notebook survives Restart & Run All.
pred_files = sorted(Path('logs/CascadeDnCNN_pure').glob('*.h5'), key=lambda p: p.stat().st_mtime)
assert pred_files, 'no prediction .h5 found in logs/CascadeDnCNN_pure - run the test cell first'
pred_path = pred_files[-1]

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files explicitly read-only (older h5py defaulted to append mode, which silently
# creates an empty file for a missing path) and close them once done.
with h5py.File(str(pred_path), 'r') as hf_pred, h5py.File('small_fastmri_pd_3t/val.h5', 'r') as hf_gt:
    for volume_key in hf_pred.keys():
        # NOTE(review): 1e6 presumably matches the k-space scaling applied by FastMRITransform -- confirm.
        ks = hf_gt[volume_key][:] * 1e6
        # Reconstruct each slice of the volume from its k-space.
        gt = np.stack([kspace2spatial(slice_ks) for slice_ks in ks])
        # Take channel 0 of the stored predictions (assumed (slices, 1, H, W) -- TODO confirm).
        pred = hf_pred[volume_key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.427608139076021, 0.06543111281323623, 23.928479712829827)

### Gaussian + Salt-and-Pepper Noise

In [14]:
# Combined setting: same seeds, mask and validation file as before, inputs corrupted with
# noise_type='normal_and_salt', evaluated with the 'normal-and-salt' cascade checkpoint.
torch.manual_seed(42)
np.random.seed(42)

transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=5e4,
    noise_type='normal_and_salt'
)

val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

cascade = CascadeModule(net=torch.nn.ModuleList([PureDnCNNDCModule(**model_kwargs).net for _ in range(5)]), **model_kwargs)
# map_location='cpu' lets the checkpoint load regardless of the device it was saved from;
# load_state_dict copies the values into the model's own parameters.
cascade.net.load_state_dict(torch.load('cascade-x5-dncnn_pure-dc-normal-and-salt.pth', map_location='cpu'))

# Build a fresh trainer (as the Gaussian cell does) instead of reusing one left over from an
# earlier cell, so this cell does not depend on hidden kernel state.
trainer = get_trainer()
trainer.test(cascade, val_generator)

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

In [15]:
# Score the newest prediction file against ground-truth images reconstructed from the
# validation k-space: one SSIM / NMSE / PSNR value per volume, then the mean over volumes.
#
# The test cell above writes its predictions to a timestamp-named .h5 in the trainer's
# default_root_dir; pick the most recent one instead of hard-coding the timestamp of a past
# run, so the notebook survives Restart & Run All.
pred_files = sorted(Path('logs/CascadeDnCNN_pure').glob('*.h5'), key=lambda p: p.stat().st_mtime)
assert pred_files, 'no prediction .h5 found in logs/CascadeDnCNN_pure - run the test cell first'
pred_path = pred_files[-1]

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files explicitly read-only (older h5py defaulted to append mode, which silently
# creates an empty file for a missing path) and close them once done.
with h5py.File(str(pred_path), 'r') as hf_pred, h5py.File('small_fastmri_pd_3t/val.h5', 'r') as hf_gt:
    for volume_key in hf_pred.keys():
        # NOTE(review): 1e6 presumably matches the k-space scaling applied by FastMRITransform -- confirm.
        ks = hf_gt[volume_key][:] * 1e6
        # Reconstruct each slice of the volume from its k-space.
        gt = np.stack([kspace2spatial(slice_ks) for slice_ks in ks])
        # Take channel 0 of the stored predictions (assumed (slices, 1, H, W) -- TODO confirm).
        pred = hf_pred[volume_key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.4551870839661391, 0.07177892930233501, 23.933858410713643)