In [1]:
%env CUDA_VISIBLE_DEVICES=3
from pathlib import Path

if not Path("./small_fastmri_pd_3t").is_dir():
    !gdown --id "1y78Ad6WwQpMGtxfEZlp97A0iV98kAiJN"
    !unzip -q small_fastmri_pd_3t.zip && rm small_fastmri_pd_3t.zip

/bin/bash: gdown: command not found
unzip:  cannot find or open small_fastmri_pd_3t.zip, small_fastmri_pd_3t.zip.zip or small_fastmri_pd_3t.zip.ZIP.


In [2]:
import os
import sys
import numpy as np
import h5py
import pylab as plt
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from k_space_reconstruction.nets.unet import UnetModule
from k_space_reconstruction.datasets.fastmri import FastMRITransform, FastMRIh5Dataset, RandomMaskFunc
from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim, ssim, nmse, psnr
from k_space_reconstruction.utils.loss import l1_loss, compund_mssim_l1_loss
from k_space_reconstruction.utils.kspace import spatial2kspace, kspace2spatial

# Dataset initialization

In [3]:
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=1000,
    noise_type='none'
)

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

# Model definition

In [4]:
net = UnetModule(
    unet_chans=16, 
    unet_num_layers=4, 
    criterion=l1_loss, 
    verbose_batch=50, 
    optimizer='RMSprop',
    lr=1e-3,
    lr_step_size=5,
    lr_gamma=0.9,
    weight_decay=0.0
)

# Tensorboard logging

In [5]:
%load_ext tensorboard
%tensorboard --logdir logs/ --port 8001

# Init trainer

In [5]:
trainer = pl.Trainer(
    gpus=1, max_epochs=40,
    accumulate_grad_batches=32,
    terminate_on_nan=True, 
    default_root_dir='logs/UNet',
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            save_last=True,
            save_top_k=7, 
            monitor='val_loss', 
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True)
    ]
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]


# Train model

In [6]:
trainer.fit(net, train_dataloader=train_generator, val_dataloaders=val_generator)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | net              | Unet                 | 1.9 M 
1 | NMSE             | DistributedMetricSum | 0     
2 | SSIM             | DistributedMetricSum | 0     
3 | PSNR             | DistributedMetricSum | 0     
4 | ValLoss          | DistributedMetricSum | 0     
5 | TotExamples      | DistributedMetricSum | 0     
6 | TotSliceExamples | DistributedMetricSum | 0     


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Saving latest checkpoint...


1

# Test model
Load best checkpoint, inference on val dataset and save predictions to .h5 file in logs directory

In [6]:
net = net.load_from_checkpoint('logs/UNet/lightning_logs/version_0/checkpoints/last.ckpt').eval()

In [36]:
trainer.test(net, val_generator)

Testing: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------


1

In [37]:
hf_pred = h5py.File('logs/UNet/2021-05-13 17:40:32.396553.h5')
hf_gt = h5py.File('small_fastmri_pd_3t/val.h5')

# Val metrics

In [38]:
ssim_vals = []
nmse_vals = []
psnr_vals = []
for k in hf_pred.keys():
    ks = hf_gt[k][:] * 1e6
    gt = np.stack([kspace2spatial(k) for k in ks])
    pred = hf_pred[k][:,0]
    ssim_vals.append(ssim(gt, pred))
    nmse_vals.append(nmse(gt, pred))
    psnr_vals.append(psnr(gt, pred))
ssim_vals = np.array(ssim_vals)
nmse_vals = np.array(nmse_vals)
psnr_vals = np.array(psnr_vals)

In [39]:
np.mean(ssim_vals), np.mean(nmse_vals), np.mean(psnr_vals)

(0.7963329494066783, 0.010877203901700527, 31.45718714893031)

# Save state_dict

In [9]:
torch.save(net.net.state_dict(), 'unet16-noiseless.pth')