In [1]:
# Expose only the GPU with index 2 to this kernel (must run before CUDA is initialised).
%env CUDA_VISIBLE_DEVICES=2
from pathlib import Path

# Download and unpack the `small_fastmri_pd_3t` archive only if the directory is missing.
if not Path("./small_fastmri_pd_3t").is_dir():
    !gdown --id "1y78Ad6WwQpMGtxfEZlp97A0iV98kAiJN"
    # Extract quietly; the archive is removed only if unzip succeeds (`&&`).
    !unzip -q small_fastmri_pd_3t.zip && rm small_fastmri_pd_3t.zip
    
# Fetch the `dncnn-noiseless.pth` weights file unless it is already present.
if not Path("./dncnn-noiseless.pth").is_file():
    !gdown --id "1azlqmuIkdhcsMQJL_YObF4sEe83D8J8N"

env: CUDA_VISIBLE_DEVICES=2


In [2]:
import os
import sys
import numpy as np
import h5py
import pylab as plt
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from k_space_reconstruction.nets.cdn_dncnn import DnCNNDCModule, CascadeModule
from k_space_reconstruction.datasets.fastmri import FastMRITransform, FastMRIh5Dataset, RandomMaskFunc
from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim, ssim, nmse, psnr
from k_space_reconstruction.utils.loss import l1_loss, compund_mssim_l1_loss
from k_space_reconstruction.utils.kspace import spatial2kspace, kspace2spatial

# Report how many CUDA devices torch can see (limited by CUDA_VISIBLE_DEVICES set above).
print('Available GPUs: ', torch.cuda.device_count())

Available GPUs:  1


# Dataset initialization

In [3]:
# Seed both RNGs up front (torch first, then numpy) for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# A single transform instance is shared by the training and validation datasets.
mask_func = RandomMaskFunc([0.08], [4])
transform = FastMRITransform(mask_func, noise_level=1000, noise_type='none')

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)

# Both loaders use the same number of workers; only the training loader shuffles.
loader_kwargs = dict(num_workers=12)
train_generator = torch.utils.data.DataLoader(
    train_dataset, batch_size=42, shuffle=True, **loader_kwargs
)
val_generator = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, **loader_kwargs
)

# Model definition
We also load the weights of a pretrained DnCNN.

In [4]:
# Hyperparameters passed both to every DnCNN+DC block and to the CascadeModule wrapper.
model_kwargs = dict(
    dncnn_chans=64, 
    dncnn_depth=10, 
    criterion=compund_mssim_l1_loss,
    verbose_batch=50, 
    optimizer='Adam',
    lr=1e-3,
    lr_step_size=3,
    lr_gamma=1.0,
    weight_decay=0.0
)

# Start with a cascade made of a single DnCNN+DC block.
cascade = CascadeModule(net=torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net]), **model_kwargs)
# map_location='cpu' makes loading independent of the device the checkpoint was
# saved from (only one GPU is visible here); load_state_dict then copies the
# values into the block's existing parameters.
cascade.net[0].cascade[0].load_state_dict(torch.load('dncnn-noiseless.pth', map_location='cpu'))

<All keys matched successfully>

In [5]:
def get_trainer(max_epochs=10, accumulate_grad_batches=3):
    """Create a fresh single-GPU ``pl.Trainer`` that logs under ``logs/CascadeDnCNN``.

    A new trainer is needed for every ``fit``/``test`` call in this notebook,
    so the shared configuration lives here.

    Args:
        max_epochs: Maximum number of training epochs (default 10).
        accumulate_grad_batches: Number of batches whose gradients are
            accumulated before an optimizer step (default 3).

    Returns:
        A ``pl.Trainer`` with checkpointing (monitoring ``val_loss``, keeping
        the top 7 and the last checkpoint), per-epoch learning-rate
        monitoring and GPU statistics monitoring enabled.
    """
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        save_last=True,
        save_top_k=7,
        monitor='val_loss',
        filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
    )
    return pl.Trainer(
        gpus=1, max_epochs=max_epochs,
        accumulate_grad_batches=accumulate_grad_batches,
        terminate_on_nan=True,
        default_root_dir='logs/CascadeDnCNN',
        callbacks=[
            checkpoint_cb,
            pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
            pl.callbacks.GPUStatsMonitor(temperature=True)
        ]
    )

# Tensorboard logging

In [None]:
# Load the TensorBoard notebook extension and serve the logs/ directory on port 8001.
%load_ext tensorboard
%tensorboard --logdir logs/ --port 8001

Launching TensorBoard...

# Cascade training
### Iterative training
We sequentially train a block (DnCNN + DC), freeze its layers and append a new block; this is repeated N times.

In [7]:
num_blocks = 5
# Read the pretrained DnCNN weights from disk once instead of on every iteration.
# map_location='cpu' avoids depending on the device the checkpoint was saved from;
# load_state_dict copies values, so the same dict can initialise every new block.
pretrained_state = torch.load('dncnn-noiseless.pth', map_location='cpu')
for i in range(num_blocks):
    # Train the cascade in its current state (a fresh trainer per stage)
    trainer = get_trainer()
    trainer.fit(cascade, train_dataloader=train_generator, val_dataloaders=val_generator)
    # Freeze every block trained so far
    for param in cascade.net.parameters():
        param.requires_grad = False
    # Append a new block unless this was the final stage
    if i != num_blocks - 1:
        cascade = CascadeModule(net=cascade.net.append(DnCNNDCModule(**model_kwargs).net), **model_kwargs)
        # Initialise the DnCNN of the newly appended block from the pretrained weights
        cascade.net[-1].cascade[0].load_state_dict(pretrained_state)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 297 K 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 594 K 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 891 K 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 1.2 M 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 1.5 M 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...


# Cascade finetuning
Unfreeze all layers of the cascade and fine-tune them jointly (the trainer below is configured with `max_epochs=20`).

In [7]:
# Same masking/transform settings as for the iterative training stage.
mask_func = RandomMaskFunc([0.08], [4])
transform = FastMRITransform(mask_func, noise_level=1000, noise_type='none')

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)

# Fine-tuning uses a smaller training batch size (8); validation stays at batch size 1.
loader_kwargs = dict(num_workers=12)
train_generator = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True, **loader_kwargs
)
val_generator = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, **loader_kwargs
)

In [None]:
# Checkpoint to fine-tune from.
# NOTE: the `version_N` directory depends on the local run history -- point this at
# the run produced by the last stage of the cascade training before executing.
stage_ckpt = 'logs/CascadeDnCNN/lightning_logs/version_4/checkpoints/last.ckpt'
# Skeleton of five DnCNN+DC blocks into which the checkpoint weights are restored.
blocks = torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net for _ in range(5)])
cascade = CascadeModule.load_from_checkpoint(stage_ckpt, net=blocks)

checkpoint_cb = pl.callbacks.ModelCheckpoint(
    save_last=True,
    save_top_k=7,
    monitor='val_loss',
    filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
)
lr_monitor_cb = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
gpu_stats_cb = pl.callbacks.GPUStatsMonitor(temperature=True)

trainer = pl.Trainer(
    gpus=1,
    max_epochs=20,
    accumulate_grad_batches=16,
    terminate_on_nan=True,
    default_root_dir='logs/CascadeDnCNN',
    callbacks=[checkpoint_cb, lr_monitor_cb, gpu_stats_cb],
)

# Make every parameter of the cascade trainable before fitting.
for param in cascade.net.parameters():
    param.requires_grad = True

trainer.fit(cascade, train_dataloader=train_generator, val_dataloaders=val_generator)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | net              | ModuleList           | 1.5 M 


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# Test model
Load the last checkpoint of the fine-tuning run, run inference on the validation dataset and save the predictions to an .h5 file in the logs directory.

In [9]:
# Rebuild a five-block cascade skeleton, restore its weights from the checkpoint
# and switch the module to evaluation mode.
# NOTE: the `version_N` directory depends on the local run history.
finetuned_ckpt = 'logs/CascadeDnCNN/lightning_logs/version_5/checkpoints/last.ckpt'
skeleton = torch.nn.ModuleList([DnCNNDCModule(**model_kwargs).net for _ in range(5)])
net = CascadeModule.load_from_checkpoint(finetuned_ckpt, net=skeleton)
net = net.eval()

In [10]:
# Build a trainer via get_trainer() and run the test loop on the validation loader.
trainer = get_trainer()
trainer.test(net, val_generator)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]


Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

# Val metrics

In [11]:
# Predictions file (presumably written by the test run above; the name is a
# timestamp, so adjust it to your own run) and the ground-truth validation set.
pred_path = 'logs/CascadeDnCNN/2021-05-18 17:25:13.505632.h5'
gt_path = 'small_fastmri_pd_3t/val.h5'

ssim_vals = []
nmse_vals = []
psnr_vals = []
# Open both files read-only and close them deterministically once the metrics
# have been computed (the handles were previously left open).
with h5py.File(pred_path, 'r') as hf_pred, h5py.File(gt_path, 'r') as hf_gt:
    for key in hf_pred.keys():
        # NOTE(review): the 1e6 factor presumably matches the scaling applied to
        # k-space inside the dataset/transform -- confirm.
        ks = hf_gt[key][:] * 1e6
        # Convert each k-space slice to the image domain and stack the slices.
        gt = np.stack([kspace2spatial(slice_ks) for slice_ks in ks])
        pred = hf_pred[key][:, 0]
        ssim_vals.append(ssim(gt, pred))
        nmse_vals.append(nmse(gt, pred))
        psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

# Mean SSIM, NMSE and PSNR over all entries of the predictions file.
np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.8394470749603186, 0.00729109549092002, 33.299176701801244)

# Saving Weights of the Model 

In [12]:
# Save only the state dict of the inner `net.net` module, not the whole Lightning module.
torch.save(net.net.state_dict(), 'cascade-x5-dncnn-dc-noiseless.pth')