# DnCNN for Image Denoising


Beyond a Gaussian Denoiser: Residual Learning of Deep CNN for Image Denoising
https://arxiv.org/pdf/1608.03981.pdf

Download and unzip dataset

In [2]:
from pathlib import Path

# Download the small fastMRI PD 3T subset (used below as train.h5 / val.h5) and
# unpack it; skipped when the directory already exists.
if not Path("./small_fastmri_pd_3t").is_dir():
    !gdown --id "1y78Ad6WwQpMGtxfEZlp97A0iV98kAiJN"
    !unzip -q small_fastmri_pd_3t.zip && rm small_fastmri_pd_3t.zip
    
# Download the noiseless pre-trained DnCNN weights (later restored with
# `net.net.load_state_dict`); skipped when the file is already present.
if not Path("./dncnn-noiseless.pth").is_file():
    !gdown --id "1azlqmuIkdhcsMQJL_YObF4sEe83D8J8N" 
# Noiseless model weights: https://drive.google.com/file/d/1azlqmuIkdhcsMQJL_YObF4sEe83D8J8N/view?usp=sharing
# Gaussian model weights: (link not filled in yet)
# Salt&Pepper model weights: (link not filled in yet)
# Gaussian + Salt&Pepper model weights: (link not filled in yet)

# TODO: starting from the noiseless weights, fine-tune one model per noise type
# (Gaussian, salt & pepper, Gaussian + salt & pepper).
# NOTE(review): the original plan said 10 more epochs; the fine-tuning cells
# below use max_epochs=20.


In [3]:
# Inspect which GPUs are present and their current utilisation / memory usage.
!nvidia-smi

Mon May 17 13:19:45 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  On   | 00000000:17:00.0 Off |                  N/A |
| 90%   84C    P2   239W / 280W |   9230MiB / 11178MiB |     96%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  On   | 00000000:65:00.0 Off |                  N/A |
|  0%   45C    P8    12W / 280W |     12MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 108...  On   | 00000000:66:00.0 Off |                  N/A |
|  0%   

In [4]:
import os
import sys
import numpy as np
import h5py
import pylab as plt
import torch

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from k_space_reconstruction.nets.dncnn import DnCNNModule
from k_space_reconstruction.datasets.fastmri import FastMRITransform, FastMRIh5Dataset, RandomMaskFunc
from k_space_reconstruction.utils.metrics import pt_msssim, pt_ssim, ssim, nmse, psnr
from k_space_reconstruction.utils.loss import l1_loss, compund_mssim_l1_loss
from k_space_reconstruction.utils.kspace import spatial2kspace, kspace2spatial

# Seed Python / NumPy / torch RNGs up front so stochastic parts of the pipeline
# (presumably the RandomMaskFunc masks and the noise injected by FastMRITransform)
# are reproducible across Restart & Run All.
# NOTE(review): on newer Lightning versions pass `workers=True` to also seed
# DataLoader worker processes (num_workers=12 below).
SEED = 42
pl.seed_everything(SEED)

print('Available GPUs: ', torch.cuda.device_count())

Available GPUs:  3


# Dataset initialization

In [5]:
# Noiseless baseline data: random k-space undersampling mask (arguments assumed to
# be center fraction 0.08 and 4x acceleration, as in fastMRI's RandomMaskFunc —
# confirm) with no added noise (`noise_type='none'`, so `noise_level` is
# presumably ignored).
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=1000,
    noise_type='none'
)

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
# Shuffle training slices every epoch so each gradient-accumulation window mixes
# slices instead of replaying the file order. Validation order stays fixed.
# NOTE(review): assumes FastMRIh5Dataset is a map-style Dataset (DataLoader
# rejects shuffle=True for an IterableDataset) — confirm.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

  self.hf = h5py.File(hf_path)


# Model definition

In [6]:
# DnCNN wrapped in the project's LightningModule.
net = DnCNNModule(
    # network size: depth 10, 64 feature channels
    dncnn_depth=10,
    dncnn_chans=64,
    # training objective (compound MS-SSIM / L1 loss, per its name) and logging
    criterion=compund_mssim_l1_loss,
    verbose_batch=50,
    # optimiser: Adam, lr 1e-4, no weight decay; lr presumably multiplied by
    # 0.2 every 5 epochs (step schedule) — confirm in DnCNNModule
    optimizer='Adam',
    lr=1e-4,
    weight_decay=0.0,
    lr_step_size=5,
    lr_gamma=0.2,
)

# Tensorboard logging

In [10]:
# Start (or reuse an already running) TensorBoard on port 8123 covering every run under logs/.
%load_ext tensorboard
%tensorboard --logdir logs/ --port 8123

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 8123 (pid 27897), started 1 day, 22:01:47 ago. (Use '!kill 27897' to kill it.)

# Init trainer

In [8]:
# Noiseless training run: one GPU, up to 40 epochs, gradients accumulated over
# 32 single-slice batches (effective batch size 32), stop on NaN loss.
# Checkpointing keeps the 4 lowest-val_loss epochs plus `last.ckpt`; the file
# name records the epoch and the logged ssim / psnr / nmse values.
trainer = pl.Trainer(
    default_root_dir='logs/DnCNN',
    gpus=1,
    max_epochs=40,
    accumulate_grad_batches=32,
    terminate_on_nan=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            monitor='val_loss',
            save_top_k=4,
            save_last=True,
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}',
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True),
    ],
);

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


# Train model

In [9]:
# Train the noiseless model; checkpoints land in logs/DnCNN/lightning_logs/version_*/checkpoints/.
# NOTE(review): the saved output below is stale (its error mentions monitor='ssim',
# while the trainer above monitors 'val_loss') — re-run to refresh it.
trainer.fit(net, train_dataloader=train_generator, val_dataloaders=val_generator)


  | Name             | Type                 | Params
----------------------------------------------------------
0 | net              | DnCNN                | 297 K 
1 | NMSE             | DistributedMetricSum | 0     
2 | SSIM             | DistributedMetricSum | 0     
3 | PSNR             | DistributedMetricSum | 0     
4 | ValLoss          | DistributedMetricSum | 0     
5 | TotExamples      | DistributedMetricSum | 0     
6 | TotSliceExamples | DistributedMetricSum | 0     
----------------------------------------------------------
297 K     Trainable params
0         Non-trainable params
297 K     Total params
1.189     Total estimated model params size (MB)


Epoch 0:  10%|▉         | 239/2437 [00:14<02:12, 16.54it/s, loss=0.227, v_num=8, val_loss=0.216]

Saving latest checkpoint...


MisconfigurationException: ModelCheckpoint(monitor='ssim') not found in the returned metrics: ['train_loss_step']. HINT: Did you call self.log('ssim', tensor) in the LightningModule?

# Test model
Load the best checkpoint, run inference on the validation dataset, and save the predictions to an .h5 file in the logs directory.

In [8]:
# List the checkpoints saved by the noiseless run (version_2) to pick one to load.
!ls "logs/DnCNN/lightning_logs/version_2/checkpoints/"

'epoch=12-ssim=0.7514-psnr=27.9779-nmse=0.01845.ckpt'
'epoch=13-ssim=0.7516-psnr=27.9857-nmse=0.01842.ckpt'
'epoch=14-ssim=0.7517-psnr=27.9931-nmse=0.01838.ckpt'
'epoch=19-ssim=0.7509-psnr=27.9630-nmse=0.01854.ckpt'
 last.ckpt


In [7]:
# Restore the inner DnCNN weights before testing.
# NOTE(review): this reads 'last.ckpt' from the working directory, not one of the
# top-k checkpoints listed above under logs/DnCNN/.../checkpoints/. A Lightning
# .ckpt normally holds a dict with a 'state_dict' entry (keys prefixed 'net.');
# since all keys matched here, this file was presumably a plain
# `net.net.state_dict()` — confirm before relying on it.
net.net.load_state_dict(torch.load('last.ckpt'))
# net.eval()

<All keys matched successfully>

In [32]:
# Run the module's test loop over the validation loader. It returns no metrics
# ([{}]); predictions are presumably written to a timestamped .h5 file under
# logs/DnCNN (read back below) — confirm in DnCNNModule.
trainer.test(net, val_generator)

Testing: 100%|██████████| 395/395 [00:10<00:00, 37.02it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------


[{}]

In [34]:
# Predictions saved by the test run (the file name appears to be the run's
# timestamp — update it for a new run) and the ground-truth validation k-space.
# Open both read-only: without an explicit mode h5py warns about the changing
# default mode and may open the files writable. Both stay open for the metrics
# cell below.
hf_pred = h5py.File('logs/DnCNN/2021-05-14 18:49:03.756824.h5', 'r')
hf_gt = h5py.File('small_fastmri_pd_3t/val.h5', 'r')

  hf_pred = h5py.File('logs/DnCNN/2021-05-14 18:49:03.756824.h5')
  hf_gt = h5py.File('small_fastmri_pd_3t/val.h5')


# Val metrics

In [35]:
# Per-volume validation metrics. Ground-truth images are rebuilt slice by slice
# from the stored k-space (scaled by 1e6, presumably to match the scaling used
# during training — confirm) and compared with channel 0 of the predictions.
ssim_vals, nmse_vals, psnr_vals = [], [], []
for volume_key in hf_pred.keys():
    kspace_slices = hf_gt[volume_key][:] * 1e6
    target = np.stack([kspace2spatial(slice_kspace) for slice_kspace in kspace_slices])
    prediction = hf_pred[volume_key][:, 0]
    for bucket, metric in ((ssim_vals, ssim), (nmse_vals, nmse), (psnr_vals, psnr)):
        bucket.append(metric(target, prediction))
ssim_vals, nmse_vals, psnr_vals = map(np.array, (ssim_vals, nmse_vals, psnr_vals))

In [36]:
# Mean SSIM, NMSE and PSNR over all validation volumes.
tuple(np.mean(per_volume) for per_volume in (ssim_vals, nmse_vals, psnr_vals))

(0.7450819779897695, 0.020334082331490425, 28.62007806339689)

## Saving Weights of the Model 

In [15]:
# Save only the inner DnCNN weights (not the Lightning wrapper) so they can be
# restored with `net.net.load_state_dict` as the start point for the noisy runs.
torch.save(net.net.state_dict(), 'dncnn-noiseless.pth')

# Continue Training with Noises

We take our **noiseless** pre-trained model weights and continue training them separately on each of the following noise types (one fine-tuned model per type):

- Gaussian noise, lvl: 400
- Salt&Pepper noise, lvl: 5e4
- Gaussian + Salt&Pepper noise, lvl: 400 + 5e4

### Gaussian Noise

In [7]:
# Reset the network to the noiseless weights before fine-tuning on Gaussian noise.
net.net.load_state_dict(torch.load('dncnn-noiseless.pth'))

<All keys matched successfully>

In [8]:
# Gaussian fine-tuning data: same undersampling mask as the noiseless run, plus
# Gaussian noise (`noise_type='normal'`) at level 400.
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=400,
    noise_type='normal'
)

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
# Shuffle training slices every epoch; validation order stays fixed.
# NOTE(review): assumes FastMRIh5Dataset is a map-style Dataset (DataLoader
# rejects shuffle=True for an IterableDataset) — confirm.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

In [None]:
# Gaussian-noise fine-tuning: fresh Trainer, up to 20 epochs, logs and
# checkpoints under logs/DnCNN_gaussian (4 lowest-val_loss epochs + last.ckpt).
trainer = pl.Trainer(
    default_root_dir='logs/DnCNN_gaussian',
    gpus=1,
    max_epochs=20,
    accumulate_grad_batches=32,
    terminate_on_nan=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            monitor='val_loss',
            save_top_k=4,
            save_last=True,
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}',
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True),
    ],
);

trainer.fit(
    net,
    train_dataloader=train_generator,
    val_dataloaders=val_generator,
)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores

  | Name             | Type                 | Params
----------------------------------------------------------
0 | net              | DnCNN                | 297 K 
1 | NMSE             | DistributedMetricSum | 0     
2 | SSIM             | DistributedMetricSum | 0     
3 | PSNR             | DistributedMetricSum | 0     
4 | ValLoss          | DistributedMetricSum | 0     
5 | TotExamples      | DistributedMetricSum | 0     
6 | TotSliceExamples | DistributedMetricSum | 0     
----------------------------------------------------------
297 K     Trainable params
0         Non-trainable params
297 K     Total params
1.189     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/2437 [00:00<?, ?it/s]                     

  value = torch.tensor(value, device=device, dtype=torch.float)


Epoch 0:  84%|████████▍ | 2042/2437 [05:30<01:04,  6.17it/s, loss=0.0652, v_num=6, val_loss=0.126]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/395 [00:00<?, ?it/s][A
Epoch 0:  84%|████████▍ | 2045/2437 [05:31<01:03,  6.16it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  84%|████████▍ | 2049/2437 [05:32<01:02,  6.17it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  84%|████████▍ | 2053/2437 [05:32<01:02,  6.18it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  84%|████████▍ | 2057/2437 [05:32<01:01,  6.19it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  85%|████████▍ | 2061/2437 [05:32<01:00,  6.20it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  85%|████████▍ | 2065/2437 [05:32<00:59,  6.21it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  85%|████████▍ | 2069/2437 [05:32<00:59,  6.22it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  85%|████████▌ | 2073/2437 [05:32<00:58,  6.23it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  85%|██████

Epoch 0:  96%|█████████▌| 2337/2437 [05:42<00:14,  6.82it/s, loss=0.0652, v_num=6, val_loss=0.126]
Validating:  75%|███████▍  | 295/395 [00:11<00:03, 29.60it/s][A
Epoch 0:  96%|█████████▌| 2341/2437 [05:42<00:14,  6.83it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  96%|█████████▌| 2345/2437 [05:43<00:13,  6.83it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  96%|█████████▋| 2349/2437 [05:43<00:12,  6.84it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  97%|█████████▋| 2353/2437 [05:43<00:12,  6.85it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  97%|█████████▋| 2357/2437 [05:43<00:11,  6.86it/s, loss=0.0652, v_num=6, val_loss=0.126]
Validating:  80%|███████▉  | 315/395 [00:12<00:03, 25.74it/s][A
Epoch 0:  97%|█████████▋| 2361/2437 [05:43<00:11,  6.87it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  97%|█████████▋| 2365/2437 [05:44<00:10,  6.87it/s, loss=0.0652, v_num=6, val_loss=0.126]
Epoch 0:  97%|█████████▋| 2369/2437 [05:44<00:09,  6.88it/s, loss=0.0652, v_nu

Epoch 1:  91%|█████████ | 2220/2437 [04:45<00:27,  7.78it/s, loss=0.0467, v_num=6, val_loss=0.101]
Epoch 1:  91%|█████████▏| 2224/2437 [04:45<00:27,  7.79it/s, loss=0.0467, v_num=6, val_loss=0.101]
Validating:  46%|████▌     | 182/395 [00:07<00:08, 26.46it/s][A
Epoch 1:  91%|█████████▏| 2228/2437 [04:45<00:26,  7.80it/s, loss=0.0467, v_num=6, val_loss=0.101]
Epoch 1:  92%|█████████▏| 2232/2437 [04:45<00:26,  7.81it/s, loss=0.0467, v_num=6, val_loss=0.101]
Epoch 1:  92%|█████████▏| 2236/2437 [04:45<00:25,  7.82it/s, loss=0.0467, v_num=6, val_loss=0.101]
Validating:  49%|████▉     | 194/395 [00:08<00:07, 28.00it/s][A
Epoch 1:  92%|█████████▏| 2240/2437 [04:45<00:25,  7.83it/s, loss=0.0467, v_num=6, val_loss=0.101]
Epoch 1:  92%|█████████▏| 2244/2437 [04:46<00:24,  7.84it/s, loss=0.0467, v_num=6, val_loss=0.101]
Epoch 1:  92%|█████████▏| 2248/2437 [04:46<00:24,  7.85it/s, loss=0.0467, v_num=6, val_loss=0.101]
Validating:  52%|█████▏    | 206/395 [00:08<00:08, 22.30it/s][A
Epoch 1:  92%

Epoch 2:  86%|████████▌ | 2096/2437 [06:04<00:59,  5.75it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  86%|████████▌ | 2100/2437 [06:04<00:58,  5.75it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  86%|████████▋ | 2104/2437 [06:05<00:57,  5.76it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  86%|████████▋ | 2108/2437 [06:05<00:56,  5.77it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Validating:  17%|█▋        | 66/395 [00:03<00:12, 25.74it/s][A
Epoch 2:  87%|████████▋ | 2112/2437 [06:05<00:56,  5.78it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  87%|████████▋ | 2116/2437 [06:05<00:55,  5.79it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  87%|████████▋ | 2120/2437 [06:05<00:54,  5.80it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Validating:  20%|█▉        | 78/395 [00:03<00:11, 28.22it/s][A
Epoch 2:  87%|████████▋ | 2124/2437 [06:05<00:53,  5.81it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  87%|████████▋ | 2128/2437 [06:05<00:53,  5.82it/s, loss=0.0431

Epoch 2:  97%|█████████▋| 2376/2437 [06:15<00:09,  6.33it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Validating:  85%|████████▍ | 334/395 [00:13<00:02, 28.32it/s][A
Epoch 2:  98%|█████████▊| 2380/2437 [06:15<00:08,  6.34it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  98%|█████████▊| 2384/2437 [06:15<00:08,  6.35it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  98%|█████████▊| 2388/2437 [06:15<00:07,  6.35it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  98%|█████████▊| 2392/2437 [06:15<00:07,  6.36it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  98%|█████████▊| 2396/2437 [06:16<00:06,  6.37it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  98%|█████████▊| 2400/2437 [06:16<00:05,  6.38it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Validating:  91%|█████████ | 358/395 [00:14<00:01, 23.41it/s][A
Epoch 2:  99%|█████████▊| 2404/2437 [06:16<00:05,  6.38it/s, loss=0.0431, v_num=6, val_loss=0.0881]
Epoch 2:  99%|█████████▉| 2408/2437 [06:16<00:04,  6.39it/s, loss=0.04

Epoch 3:  93%|█████████▎| 2260/2437 [05:54<00:27,  6.37it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  93%|█████████▎| 2264/2437 [05:55<00:27,  6.38it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  93%|█████████▎| 2268/2437 [05:55<00:26,  6.39it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  93%|█████████▎| 2272/2437 [05:55<00:25,  6.39it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  93%|█████████▎| 2276/2437 [05:55<00:25,  6.40it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  94%|█████████▎| 2280/2437 [05:55<00:24,  6.41it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  94%|█████████▎| 2284/2437 [05:55<00:23,  6.42it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  94%|█████████▍| 2288/2437 [05:55<00:23,  6.43it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Epoch 3:  94%|█████████▍| 2292/2437 [05:56<00:22,  6.44it/s, loss=0.0417, v_num=6, val_loss=0.0849]
Validating:  63%|██████▎   | 250/395 [00:10<00:04, 29.63it/s][A
Epoch 3:  94%|█████████▍| 2296/2437

Epoch 4:  90%|█████████ | 2199/2437 [04:06<00:26,  8.91it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  90%|█████████ | 2204/2437 [04:06<00:26,  8.93it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  91%|█████████ | 2209/2437 [04:06<00:25,  8.94it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  91%|█████████ | 2214/2437 [04:07<00:24,  8.96it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  91%|█████████ | 2219/2437 [04:07<00:24,  8.98it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  91%|█████████▏| 2224/2437 [04:07<00:23,  8.99it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  91%|█████████▏| 2229/2437 [04:07<00:23,  9.01it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  92%|█████████▏| 2234/2437 [04:07<00:22,  9.02it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  92%|█████████▏| 2239/2437 [04:07<00:21,  9.04it/s, loss=0.0408, v_num=6, val_loss=0.0825]
Epoch 4:  92%|█████████▏| 2244/2437 [04:07<00:21,  9.05it/s, loss=0.0408, v_num=6, val_loss=0.0825]


Epoch 5:  88%|████████▊ | 2155/2437 [04:38<00:36,  7.74it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Epoch 5:  89%|████████▊ | 2160/2437 [04:38<00:35,  7.76it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Validating:  30%|██▉       | 118/395 [00:05<00:10, 26.89it/s][A
Epoch 5:  89%|████████▉ | 2165/2437 [04:38<00:35,  7.77it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Epoch 5:  89%|████████▉ | 2170/2437 [04:38<00:34,  7.78it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Validating:  32%|███▏      | 128/395 [00:05<00:09, 28.68it/s][A
Epoch 5:  89%|████████▉ | 2175/2437 [04:38<00:33,  7.80it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Validating:  34%|███▍      | 134/395 [00:05<00:09, 28.36it/s][A
Epoch 5:  89%|████████▉ | 2180/2437 [04:39<00:32,  7.81it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Epoch 5:  90%|████████▉ | 2185/2437 [04:39<00:32,  7.82it/s, loss=0.0404, v_num=6, val_loss=0.0814]
Validating:  36%|███▌      | 143/395 [00:06<00:08, 28.35it/s][A
Epoch 5:  90%|████████▉ | 2190/2437 [04:

Epoch 6:  84%|████████▍ | 2045/2437 [06:07<01:10,  5.56it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  84%|████████▍ | 2050/2437 [06:08<01:09,  5.57it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:   2%|▏         | 8/395 [00:01<00:35, 10.78it/s][A
Epoch 6:  84%|████████▍ | 2055/2437 [06:08<01:08,  5.58it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  85%|████████▍ | 2060/2437 [06:08<01:07,  5.59it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  85%|████████▍ | 2065/2437 [06:08<01:06,  5.60it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:   6%|▌         | 23/395 [00:01<00:15, 24.15it/s][A
Epoch 6:  85%|████████▍ | 2070/2437 [06:08<01:05,  5.61it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  85%|████████▌ | 2075/2437 [06:08<01:04,  5.62it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:   8%|▊         | 33/395 [00:01<00:13, 26.81it/s][A
Epoch 6:  85%|████████▌ | 2080/2437 [06:09<01:03,  5.64it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validatin

Epoch 6:  96%|█████████▌| 2335/2437 [06:19<00:16,  6.16it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:  74%|███████▍  | 294/395 [00:12<00:03, 26.87it/s][A
Epoch 6:  96%|█████████▌| 2340/2437 [06:19<00:15,  6.17it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  96%|█████████▌| 2345/2437 [06:19<00:14,  6.18it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:  77%|███████▋  | 303/395 [00:12<00:05, 17.32it/s][A
Epoch 6:  96%|█████████▋| 2350/2437 [06:19<00:14,  6.19it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:  78%|███████▊  | 309/395 [00:12<00:04, 21.29it/s][A
Epoch 6:  97%|█████████▋| 2355/2437 [06:19<00:13,  6.20it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  97%|█████████▋| 2360/2437 [06:20<00:12,  6.21it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Validating:  81%|████████  | 319/395 [00:13<00:02, 26.14it/s][A
Epoch 6:  97%|█████████▋| 2365/2437 [06:20<00:11,  6.22it/s, loss=0.0403, v_num=6, val_loss=0.0798]
Epoch 6:  97%|█████████▋| 2370/2437 [06:

In [None]:
# Save the Gaussian-fine-tuned DnCNN weights (inner network only).
torch.save(net.net.state_dict(), 'dncnn-with-gaussian.pth')

### Salt&Pepper Noise

In [None]:
# Reset the network to the noiseless weights before fine-tuning on salt & pepper
# noise (otherwise it would continue from the Gaussian-fine-tuned weights).
# The noiseless weights were saved as 'dncnn-noiseless.pth', not '.ckpt'.
net.net.load_state_dict(torch.load('dncnn-noiseless.pth'))

In [None]:
# Salt & pepper fine-tuning data: same undersampling mask, plus salt & pepper
# noise (`noise_type='salt'`) at level 5e4.
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=5e4,
    noise_type='salt'
)

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
# Shuffle training slices every epoch; validation order stays fixed.
# NOTE(review): assumes FastMRIh5Dataset is a map-style Dataset (DataLoader
# rejects shuffle=True for an IterableDataset) — confirm.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

In [None]:
# Salt & pepper fine-tuning: fresh Trainer, up to 20 epochs, logs and
# checkpoints under logs/DnCNN_salt (4 lowest-val_loss epochs + last.ckpt).
trainer = pl.Trainer(
    default_root_dir='logs/DnCNN_salt',
    gpus=1,
    max_epochs=20,
    accumulate_grad_batches=32,
    terminate_on_nan=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(
            monitor='val_loss',
            save_top_k=4,
            save_last=True,
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}',
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True),
    ],
);

trainer.fit(
    net,
    train_dataloader=train_generator,
    val_dataloaders=val_generator,
)

In [None]:
# Save the salt & pepper-fine-tuned DnCNN weights (inner network only).
torch.save(net.net.state_dict(), 'dncnn-with-salt.pth')

### Gaussian + Salt&Pepper Noise

In [None]:
# Reset the network to the noiseless weights before fine-tuning on combined
# Gaussian + salt & pepper noise (otherwise it would continue from the
# previous run's weights). The noiseless weights were saved as
# 'dncnn-noiseless.pth', not '.ckpt'.
net.net.load_state_dict(torch.load('dncnn-noiseless.pth'))

In [None]:
# Gaussian + salt & pepper fine-tuning data.
# FIXME: these transform arguments are identical to the salt & pepper-only cell
# (noise_level=5e4, noise_type='salt'), so this run currently trains on salt &
# pepper noise alone instead of the intended Gaussian (400) + salt & pepper (5e4)
# mix. Switch to the combined setting once the matching FastMRITransform option
# is confirmed.
transform = FastMRITransform(
    RandomMaskFunc([0.08], [4]),
    noise_level=5e4,
    noise_type='salt'
)

train_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/train.h5', transform)
val_dataset = FastMRIh5Dataset('small_fastmri_pd_3t/val.h5', transform)
# Shuffle training slices every epoch; validation order stays fixed.
# NOTE(review): assumes FastMRIh5Dataset is a map-style Dataset (DataLoader
# rejects shuffle=True for an IterableDataset) — confirm.
train_generator = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12)
val_generator = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=12)

In [None]:
# Gaussian + salt & pepper fine-tuning: fresh Trainer, up to 20 epochs.
# Uses its own log directory; the previous value 'logs/DnCNN_salt' mixed this
# run's TensorBoard logs and checkpoints with the salt & pepper-only run.
trainer = pl.Trainer(
    gpus=1, max_epochs=20,
    accumulate_grad_batches=32,
    terminate_on_nan=True, 
    default_root_dir='logs/DnCNN_gaussian_salt',
    callbacks=[
        # Keep the 4 lowest-val_loss epochs plus last.ckpt; the file name
        # records the epoch and the logged ssim / psnr / nmse values.
        pl.callbacks.ModelCheckpoint(
            save_last=True,
            save_top_k=4, 
            monitor='val_loss', 
            filename='{epoch}-{ssim:.4f}-{psnr:.4f}-{nmse:.5f}'
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='epoch'),
        pl.callbacks.GPUStatsMonitor(temperature=True)
    ]
);

trainer.fit(net, train_dataloader=train_generator, val_dataloaders=val_generator)

In [None]:
# Save the Gaussian + salt & pepper-fine-tuned DnCNN weights (inner network only).
torch.save(net.net.state_dict(), 'dncnn-with-gaussian-salt.pth')