In [4]:
from pathlib import Path

if not Path("./small_fastmri_pd_3t").is_dir():
    !gdown --id "1y78Ad6WwQpMGtxfEZlp97A0iV98kAiJN"
    !unzip -q small_fastmri_pd_3t.zip && rm small_fastmri_pd_3t.zip
    
if not Path("./dncnn-noiseless.pth").is_file():
    !gdown --id "1azlqmuIkdhcsMQJL_YObF4sEe83D8J8N"

In [5]:
import os
import sys
import numpy as np
import h5py
import pylab as plt
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from k_space_reconstruction.nets.cdn_dncnn import DnCNNDCModule, CascadeModule #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
from k_space_reconstruction.datasets.fastmri import FastMRITransform, FastMRIh5Dataset, RandomMaskFunc
from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim, ssim, nmse, psnr
from k_space_reconstruction.utils.loss import l1_loss, compund_mssim_l1_loss
from k_space_reconstruction.utils.kspace import spatial2kspace, kspace2spatial

# Report how many CUDA devices this kernel can see (the trainers below request gpus=1).
print('Available GPUs: ', torch.cuda.device_count(), sep=' ', end='\n')

Available GPUs:  3


# Dataset initialization

In [9]:
# Preprocessing shared by both splits.
# RandomMaskFunc([0.08], [4]) presumably follows the fastMRI (center_fractions, accelerations)
# convention, i.e. 8% fully-sampled centre and 4x acceleration — confirm in the project code.
# noise_type='none' presumably disables noise injection, leaving noise_level unused — confirm.
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=1000,
    noise_type='none'
)

# NOTE(review): the saved output of this cell is an h5py OSError ("Unable to create file ...
# File exists"), which suggests the .h5 file is opened in a writable mode while another
# handle/process holds it — check that FastMRIh5Dataset opens its file read-only.
train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)

# Shuffle the training set: without it every epoch replays the same file order, so each
# accumulated update (accumulate_grad_batches=32 in the trainers) sees a fixed run of
# consecutive samples. Validation order is kept deterministic.
# NOTE(review): assumes FastMRIh5Dataset is a map-style Dataset (shuffle=True is rejected
# for IterableDataset).
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

OSError: Unable to create file (unable to open file: name = 'small_fastmri_pd_3t/train.h5', errno = 17, error message = 'File exists', flags = 15, o_flags = c2)

# Model definition
We also initialise the first block with the weights of a pretrained DnCNN (`dncnn-noiseless.pth`).

In [None]:
# Hyper-parameters shared by every DnCNN+DC block and by the CascadeModule wrapper.
# lr_step_size / lr_gamma presumably configure a step LR schedule inside the module — confirm.
model_kwargs = dict(
    dncnn_chans=64,
    dncnn_depth=12,
    criterion=compund_mssim_l1_loss,
    verbose_batch=50,
    optimizer='Adam',
    lr=1e-4,
    lr_step_size=3,
    lr_gamma=0.2,
    weight_decay=0.0
)

# Start with a one-block cascade; further blocks are appended during iterative training.
cascade = CascadeModule(net=torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net]), **model_kwargs)
# Initialise the first block's DnCNN (presumably `.cascade[0]` of the block — confirm) from the
# pretrained noiseless weights. map_location='cpu' keeps loading independent of the device the
# weights were saved on; load_state_dict then copies them into the module's own parameters.
cascade.net[0].cascade[0].load_state_dict(torch.load('dncnn-noiseless.pth', map_location='cpu'))

In [None]:
def get_trainer(max_epochs: int = 5, default_root_dir: str = 'logs/CascadeDnCNN'):
    """Build a fresh Lightning trainer for one stage of cascade training or evaluation.

    The trainer uses a single GPU, steps the optimizer every 32 batches (gradient
    accumulation), terminates if the loss/parameters become NaN, and writes logs and
    checkpoints under ``default_root_dir``.

    Args:
        max_epochs: Maximum number of training epochs (default 5, as before).
        default_root_dir: Root directory for Lightning logs and checkpoints.

    Returns:
        A newly configured ``pl.Trainer``. Checkpointing keeps the last epoch plus the
        7 best by ``val_loss``; checkpoint file names embed ``ssim``, ``psnr`` and ``nmse``,
        so the LightningModule is expected to log metrics under those names.
    """
    return pl.Trainer(
        gpus=1, max_epochs=max_epochs,
        accumulate_grad_batches=32,
        terminate_on_nan=True,
        default_root_dir=default_root_dir,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                save_last=True,
                save_top_k=7,
                monitor='val_loss',
                filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
            ),
            pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
            pl.callbacks.GPUStatsMonitor(temperature=True)
        ]
    )

# Tensorboard logging

In [None]:
# Start TensorBoard inside the notebook, serving every run under logs/ on port 8001.
%load_ext tensorboard
%tensorboard --logdir logs/ --port 8001

# Cascade training
### Iterative training
We build the cascade block by block: train the current block (DnCNN + DC), freeze its layers, append a new block, and repeat until the cascade has `num_blocks` blocks.

In [None]:
# Iteratively grow the cascade: train, freeze everything trained so far, append a new
# DnCNN+DC block initialised from the pretrained DnCNN, repeat.
# NOTE: this cell rebinds `cascade`; re-running it without re-running the model-definition
# cell keeps appending blocks to the already-trained cascade.
num_blocks = 5
# Load the pretrained DnCNN weights once; load_state_dict copies them into each new block,
# so the same dict can be reused. map_location='cpu' avoids depending on the saving device.
dncnn_pretrained = torch.load('dncnn-noiseless.pth', map_location='cpu')
for i in range(num_blocks):
    # Train the current cascade (after the first round only the newest block is trainable).
    trainer = get_trainer()
    # Loaders are passed positionally: the `train_dataloader=` keyword was renamed
    # (`train_dataloaders=`) in newer PyTorch Lightning releases.
    trainer.fit(cascade, train_generator, val_generator)
    # Freeze every block trained so far.
    for param in cascade.net.parameters():
        param.requires_grad = False
    # Add a new block to the cascade (except after the final round).
    if i != num_blocks - 1:
        # ModuleList.append returns the list itself, so the grown list is passed as `net`.
        cascade = CascadeModule(net=cascade.net.append(DnCNNDCModule(**model_kwargs).net), **model_kwargs)
        # Initialise the DnCNN of the new, trainable last block from the pretrained weights.
        cascade.net[-1].cascade[0].load_state_dict(dncnn_pretrained)

# Cascade finetuning
Jointly fine-tune all layers of the cascade (every block unfrozen) for 3 epochs.

In [None]:
# Reload the 5-block cascade produced by the iterative stage.
# NOTE(review): each get_trainer() call in the iterative loop starts a new Lightning version,
# so with an initially empty logs/CascadeDnCNN the last iterative run is version_4 — adjust
# if the directory already held runs. (The old path pointed at the CascadeUNet experiment,
# whose weights do not match DnCNN+DC blocks.)
cascade = CascadeModule.load_from_checkpoint(
    'logs/CascadeDnCNN/lightning_logs/version_4/checkpoints/last.ckpt',
    net=torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net for _ in range(5)]),
)

trainer = pl.Trainer(
    gpus=1, max_epochs=3,
    accumulate_grad_batches=32,
    terminate_on_nan=True,
    default_root_dir='logs/CascadeDnCNN',
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            save_last=True,
            save_top_k=7,
            monitor='val_loss',
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True)
    ]
)
# Make every block trainable for joint fine-tuning.
for param in cascade.net.parameters():
    param.requires_grad = True
# Loaders passed positionally: the `train_dataloader=` keyword was renamed in newer Lightning.
trainer.fit(cascade, train_generator, val_generator)

# Test model
Load the last checkpoint of the fine-tuning run, run inference on the validation set, and save the predictions to an .h5 file in the logs directory.

In [None]:
# Reload the fine-tuned cascade (5 DnCNN+DC blocks) in eval mode for inference.
# The blocks must be DnCNNDCModule networks to match the trained weights; `UnetDCModule`
# is not imported in this notebook and raised a NameError here.
# NOTE(review): version_5 assumes the fine-tuning run followed the five iterative runs
# (versions 0-4) in an initially empty logs/CascadeDnCNN — adjust if needed.
net = CascadeModule.load_from_checkpoint(
    'logs/CascadeDnCNN/lightning_logs/version_5/checkpoints/last.ckpt',
    net=torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net for _ in range(5)]),
).eval()

In [None]:
# Evaluate the reloaded cascade on the validation loader (there is no separate test split).
# Predictions are expected to be written to an .h5 file by the module's test hooks (not
# visible here) — presumably under the trainer's default_root_dir; confirm.
trainer = get_trainer()
trainer.test(net, val_generator)

# Val metrics

In [None]:
# Validation metrics: compare saved predictions with ground-truth images reconstructed
# from the raw validation k-space.
# NOTE(review): assumes the test step writes its timestamped prediction .h5 into the
# trainer's default_root_dir (logs/CascadeDnCNN); the newest such file is evaluated.
# The previous hard-coded path pointed at a CascadeUNet run, not at this model.
pred_files = sorted(Path('logs/CascadeDnCNN').glob('*.h5'), key=lambda p: p.stat().st_mtime)
if not pred_files:
    raise FileNotFoundError('No prediction .h5 files found in logs/CascadeDnCNN; run the test cell first.')

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files read-only and close them afterwards: with h5py < 3 the default mode is
# 'a' (read/write, create if missing), which can modify or lock the ground-truth file.
with h5py.File(pred_files[-1], 'r') as hf_pred, h5py.File('small_fastmri_pd_3t/val.h5', 'r') as hf_gt:
    for key in hf_pred.keys():
        # Same 1e6 k-space scaling as before; presumably matches FastMRITransform — TODO confirm.
        kspace = hf_gt[key][:] * 1e6
        # Reconstruct one image per entry along the first axis (presumably slices);
        # `slice_kspace` no longer shadows the outer loop variable.
        gt = np.stack([kspace2spatial(slice_kspace) for slice_kspace in kspace])
        # Index 0 of the second axis (presumably the single output channel) — confirm.
        pred = hf_pred[key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

# Saving Weights of the Model 

In [None]:
# Persist only the state dict of the cascade's block list (`net.net`), not the Lightning wrapper.
torch.save(obj=net.net.state_dict(), f='cascade-x5-dncnn-dc-noiseless.pth')