# Réseaux de neurones unrolled pour l'IRM

Les lignes suivantes permettent d'installer une bibliothèque pour le calcul de la Transformée de Fourier Non-Uniforme (NUFFT).

In [3]:
!git clone https://github.com/albangossard/Bindings-NUFFT-pytorch
![ -e nufftbindings/ ] && rm -r -f nufftbindings/
!mv Bindings-NUFFT-pytorch/nufftbindings/ ./
!rm -r -f Bindings-NUFFT-pytorch/

Cloning into 'Bindings-NUFFT-pytorch'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (61/61), done.[K
remote: Compressing objects: 100% (27/27), done.[K
remote: Total 61 (delta 36), reused 53 (delta 32), pack-reused 0[K
Unpacking objects: 100% (61/61), 18.64 KiB | 578.00 KiB/s, done.


### Téléchargement des données

In [5]:
!pip install gdown
!gdown https://drive.google.com/uc?id=17k1CYZ4bgbv6q4T4q_zSmEFhwcWlDSVZ
!tar -xzf fastMRI.tar.gz
!rm -r -f data/fastMRI/
!mv fastMRI/ data/fastMRI/
!rm fastMRI.tar.gz

Downloading...
From: https://drive.google.com/uc?id=17k1CYZ4bgbv6q4T4q_zSmEFhwcWlDSVZ
To: /media/DATA/Alban/Course-inverse-problems-and-unrolled-networks/fastMRI.tar.gz
100%|██████████████████████████████████████| 3.32G/3.32G [00:37<00:00, 87.7MB/s]


### Imports

In [6]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
# from tensorboardX import SummaryWriter

from DIDN import DIDN
import nufftbindings.kbnufft as nufft
import dataLoaderfastMRI
import scripts.metrics as metrics
from scripts.recon import cg

In [7]:
# Image size in pixels and mini-batch size, shared by every cell below.
nx = ny = 320
Nbatch = 8

# Fix the RNG state so data shuffling, simulated measurement noise and network
# initialisation are reproducible across runs (torch.manual_seed also seeds CUDA).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda:0')

# k-space sampling locations; K is the number of samples (rows of xi).
xi = torch.tensor(np.load("data/xi_10.npy")).to(device)
print(xi.shape, xi.dtype)
K = xi.shape[0]

# Configure the NUFFT bindings for K samples, (nx, ny) images and batches of Nbatch.
nufft.nufft.set_dims(K, (nx, ny), device, Nb=Nbatch)

nufft.nufft.precompute(xi)

torch.Size([10320, 2]) torch.float32


In [8]:
dataset_train = dataLoaderfastMRI.fastMRIdatasetKnee(train=True)
dataset_test = dataLoaderfastMRI.fastMRIdatasetKnee(train=False)
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=Nbatch, shuffle=True)
# Shuffling brings nothing for evaluation: keep the test order deterministic so
# per-image PSNR arrays returned by `test` are comparable between models.
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=Nbatch, shuffle=False)
print('nb images in training dataset:', len(dataset_train))
print('nb images in testing dataset:', len(dataset_test))

nb images in training dataset: 3619
nb images in testing dataset: 705


Écrire une fonction qui réalise l'optimisation du modèle sur une époque (epoch).

In [35]:
def train(epoch, model, optim, train_loader, xi, verbose=2, writer=None, noise_level=1e0):
    """Run one optimisation epoch of ``model`` over ``train_loader``.

    For each batch of ground-truth images ``f``, measurements
    ``y = NUFFT(f) / sqrt(nx*ny)`` are simulated, corrupted with additive Gaussian
    noise, and fed to ``model``; the mean ``metrics.l2err`` between ``f`` and the
    reconstruction is minimised with ``optim``.

    Args:
        epoch: epoch index, used for printing and for the TensorBoard step.
        model: network mapping measurements to complex images.
        optim: optimizer over the model parameters.
        train_loader: iterable over image batches.
        xi: k-space sampling locations used to simulate the measurements.
        verbose: ``>= 2`` prints one line per iteration, otherwise a tqdm bar is shown.
        writer: optional TensorBoard ``SummaryWriter``.
        noise_level: scale of the noise, applied before the ``1/sqrt(nx*ny)``
            normalisation. The default ``1e0`` reproduces the original behaviour.
    """
    model.train()
    Niter = len(train_loader)
    if verbose >= 2:
        iterfn = lambda x: x
    else:
        print("Training epoch {:<3}".format(epoch))
        iterfn = tqdm
    for nit, data in enumerate(iterfn(train_loader)):
        f = data.to(device).type(torch.complex64)
        optim.zero_grad()
        y = nufft.forward(xi, f)/np.sqrt(nx*ny)
        y = y+torch.randn_like(y)*noise_level/np.sqrt(nx*ny)
        f_tilde = model(y)
        loss = metrics.l2err(f, f_tilde).mean()
        loss.backward()
        optim.step()
        # Monitoring only: detach so no autograd graph is built for the metrics.
        loss_value = loss.item()
        mean_psnr = metrics.psnr(f, f_tilde.detach()).mean().item()
        step = epoch*Niter+nit
        if verbose >= 2:
            print("  Epoch {:<3} It {:<4}/{:<4} cost={:1.3e}  PSNR={:.3f}".format(epoch, nit, Niter, loss_value, mean_psnr))
        if writer is not None:
            writer.add_scalar('loss/train', loss_value, step)
            writer.add_scalar('psnr/train', mean_psnr, step)

Coder une fonction qui évalue le modèle sur tout le jeu de données de test et renvoie un tableau NumPy contenant le PSNR de chaque image.

In [36]:
def test(model, test_loader, xi, noise_level=0.):
    """Evaluate ``model`` on every image of ``test_loader``.

    Measurements are simulated as ``NUFFT(f) / sqrt(nx*ny)``. By default they are
    noise-free; a non-zero ``noise_level`` adds Gaussian noise scaled by
    ``noise_level / sqrt(nx*ny)``.

    Returns:
        1-D numpy array with one PSNR value per test image, in loader order.
    """
    model.eval()
    test_psnr = []
    with torch.no_grad():
        for data in tqdm(test_loader):
            f = data.to(device).type(torch.complex64)
            y = nufft.forward(xi, f)/np.sqrt(nx*ny)
            if noise_level:
                y = y+torch.randn_like(y)*noise_level/np.sqrt(nx*ny)
            f_tilde = model(y)
            # One device-to-host transfer per batch instead of one `.item()` per image.
            test_psnr.extend(metrics.psnr(f, f_tilde).reshape(-1).cpu().tolist())
    return np.array(test_psnr)

Implémenter une fonction appelant $N_{\text{epoch}}$ fois les fonctions `train` et `test`. Ne pas oublier d'appeler, à chaque epoch, l'éventuel scheduler passé en argument.

In [37]:
def run(model, optim, train_loader, test_loader, xi, scheduler=None, Nepoch=10, verbose=1, writer=None):
    """Alternate ``Nepoch`` training epochs with a full pass over the test set.

    After each epoch:
    * the mean test PSNR is logged to ``writer`` (tag ``'psnr/test'``, step = epoch index);
    * it is printed when ``verbose`` is truthy;
    * ``scheduler.step()`` is called once, if a scheduler was given.
    """
    for epoch in range(Nepoch):
        train(epoch, model, optim, train_loader, xi, verbose=verbose, writer=writer)
        epoch_psnr = test(model, test_loader, xi).mean()

        if writer is not None:
            writer.add_scalar('psnr/test', epoch_psnr.item(), epoch)
        if verbose:
            print("  Epoch {:<3}  PSNR={:.3f}".format(epoch, epoch_psnr))

        # Learning-rate schedule advances once per epoch, after the optimiser steps.
        if scheduler is not None:
            scheduler.step()

## Reconstructeur adjoint

In [38]:
class MRIAdj(torch.nn.Module):
    """Adjoint-NUFFT back-projection followed by a DIDN post-processing network.

    The measurements are mapped to image space with the adjoint NUFFT scaled by
    ``1/sqrt(nx*ny)``. The complex result is split into a real float32 tensor with
    the real and imaginary parts as 2 channels on dim 1, passed through ``DIDN``,
    and recombined into a complex image.
    """
    def __init__(self, nufft, xi):
        super(MRIAdj, self).__init__()
        self.nufft = nufft
        # Frozen Parameter: follows `.to(device)` and is saved, but never trained.
        self.xi = torch.nn.Parameter(xi, requires_grad=False)
        self.net = DIDN(2, 2, num_chans=32, bias=True)

    def forward(self, y):
        scale = np.sqrt(nx*ny)
        back_projection = self.nufft.adjoint(self.xi, y)/scale
        channels = torch.stack((back_projection.real, back_projection.imag), dim=1).to(torch.float32)
        out = self.net(channels)
        # Channel 0 -> real part, channel 1 -> imaginary part.
        return out[:, 0]+1j*out[:, 1]

In [39]:
model_adj = MRIAdj(nufft=nufft, xi=xi).to(device=device)

In [40]:
# Adam with its usual betas written out explicitly.
optim = torch.optim.Adam(params=model_adj.parameters(), lr=1e-3, betas=(0.9, 0.999))

In [41]:
# The context manager closes the writer, flushing pending TensorBoard events to disk.
with SummaryWriter('tblogs/mri/adj') as writer:
    run(model_adj, optim, train_loader, test_loader, xi, Nepoch=1, verbose=1, writer=writer)

Training epoch 0  


100%|██████████| 453/453 [01:14<00:00,  6.11it/s]
100%|██████████| 89/89 [00:06<00:00, 13.84it/s]

  Epoch 0    PSNR=19.596





## Unrolled forward-backward

In [49]:
class MRIUnrolledFB(nn.Module):
    """Forward-backward (proximal gradient) iterations unrolled ``Nunrolled`` times.

    Let ``A = NUFFT / sqrt(nx*ny)``. The iterations start from the back-projection
    ``A^H y``. Each unrolled iteration:
    * takes a gradient step of size ``gamma`` on ``0.5 * ||A z - y||^2``;
    * then applies its own DIDN network in place of a proximal operator.

    ``precompute`` must be called before ``forward`` because it sets ``gamma``.
    """
    def __init__(self, nufft, xi, Nunrolled, num_chans_net=32, bias=True):
        super(MRIUnrolledFB, self).__init__()
        self.Nunrolled = Nunrolled
        self.nufft = nufft
        # Frozen Parameter: follows `.to(device)` and is saved, but never trained.
        self.xi = nn.Parameter(xi, requires_grad=False)
        # One independent network per unrolled iteration (no weight sharing).
        self.net = nn.ModuleList([DIDN(2, 2, num_chans=num_chans_net, bias=bias) for k in range(self.Nunrolled)])

    def change_xi(self, xi, f=None):
        """Replace the k-space sampling locations.

        ``gamma`` depends on ``xi``. Pass a sample batch ``f`` (as accepted by
        ``precompute``) to recompute it; otherwise the previous step size is kept.
        """
        self.xi = nn.Parameter(xi, requires_grad=False)
        if f is not None:
            self.precompute(f)

    def precompute(self, f, niter=100):
        """Set the step size ``gamma = 1 / ||A^H A||``, estimated by power iteration.

        Args:
            f: sample batch; only the shape, dtype and device of ``f[:1]`` are used.
            niter: number of power iterations.
        """
        with torch.no_grad():
            x = torch.ones_like(f[:1])
            # Use |x|^2 so the norm is real (and correct) for complex tensors too.
            normx = x.abs().pow(2).sum().sqrt()
            for _ in range(niter):
                x = x/normx
                x = self.nufft.adjoint(self.xi, self.nufft.forward(self.xi, x))/(nx*ny)
                normx = x.abs().pow(2).sum().sqrt()
            self.gamma = 1/normx

    def forward(self, y):
        """Reconstruct complex images from measurements ``y``."""
        scale = np.sqrt(nx*ny)
        z = self.nufft.adjoint(self.xi, y)/scale
        for net in self.net:
            # Gradient of 0.5*||A z - y||^2, i.e. A^H (A z - y).
            grad = self.nufft.adjoint(self.xi, self.nufft.forward(self.xi, z)/scale-y)/scale
            xhat = z-self.gamma*grad
            # Complex -> real float32 tensor with (real, imag) as channels on dim 1.
            xhat = torch.stack((xhat.real, xhat.imag), dim=1).to(torch.float32)
            z = net(xhat)
            z = z[:, 0]+1j*z[:, 1]
        return z

In [50]:
Nunrolled = 6
model_fb = MRIUnrolledFB(nufft, xi, Nunrolled=Nunrolled).to(device)
# The step size gamma is estimated from one training batch (only its first image's
# shape/dtype/device matter).
model_fb.precompute(next(iter(train_loader)).to(device).to(torch.complex64))

In [51]:
# Adam with its usual betas written out explicitly.
optim = torch.optim.Adam(params=model_fb.parameters(), lr=1e-3, betas=(0.9, 0.999))

In [52]:
# The context manager closes the writer, flushing pending TensorBoard events to disk.
with SummaryWriter('tblogs/mri/unrolled_fb') as writer:
    run(model_fb, optim, train_loader, test_loader, xi, Nepoch=1, verbose=1, writer=writer)

Training epoch 0  


100%|██████████| 453/453 [09:45<00:00,  1.29s/it]
100%|██████████| 89/89 [00:39<00:00,  2.24it/s]

  Epoch 0    PSNR=28.744





## Unrolled ADMM

In [None]:
class MRIUnrolledADMM(nn.Module):
    """ADMM iterations unrolled ``Nunrolled`` times, with learned z-updates.

    With ``A = NUFFT / sqrt(nx*ny)`` and penalty ``beta``, each unrolled iteration does:

    * x-step: solve ``(A^H A + beta I) x = A^H y + beta z - mu`` with ``cg``
      (given ``nitermaxcg``, presumably its maximum number of iterations);
    * z-step: apply that iteration's DIDN network to ``x + mu / beta``;
    * dual step: ``mu <- mu + beta (x - z)``.

    These match the ADMM updates for ``0.5 ||A x - y||^2 + R(z)`` subject to
    ``x = z``, with the proximal operator of ``R`` replaced by a network.
    """
    def __init__(self, nufft, xi, Nunrolled, nitermaxcg, num_chans_net=32, bias=True, beta=1.):
        super(MRIUnrolledADMM, self).__init__()
        self.Nunrolled = Nunrolled
        self.nufft = nufft
        # Frozen Parameter: follows `.to(device)` and is saved, but never trained.
        self.xi = nn.Parameter(xi, requires_grad=False)
        # One independent network per unrolled iteration (no weight sharing).
        self.net = nn.ModuleList([DIDN(2, 2, num_chans=num_chans_net, bias=bias) for k in range(self.Nunrolled)])
        self.beta = beta
        self.nitermaxcg = nitermaxcg

    def change_xi(self, xi, f=None):
        """Replace the k-space sampling locations.

        Pass a sample batch ``f`` to also refresh ``gamma`` (see ``precompute``).
        """
        self.xi = nn.Parameter(xi, requires_grad=False)
        if f is not None:
            self.precompute(f)

    def precompute(self, f, niter=100):
        """Set ``gamma = 1 / ||A^H A||``, estimated by power iteration.

        Note: ``forward`` of this class does not use ``gamma``; it is kept for
        interface parity with the forward-backward model.

        Args:
            f: sample batch; only the shape, dtype and device of ``f[:1]`` are used.
            niter: number of power iterations.
        """
        with torch.no_grad():
            x = torch.ones_like(f[:1])
            # Use |x|^2 so the norm is real (and correct) for complex tensors too.
            normx = x.abs().pow(2).sum().sqrt()
            for _ in range(niter):
                x = x/normx
                x = self.nufft.adjoint(self.xi, self.nufft.forward(self.xi, x))/(nx*ny)
                normx = x.abs().pow(2).sum().sqrt()
            self.gamma = 1/normx

    def _Cop(self, x):
        """Apply the x-step system operator ``A^H A + beta I`` to ``x``."""
        return self.nufft.adjoint(self.xi, self.nufft.forward(self.xi, x))/(nx*ny) + self.beta*x

    def forward(self, y):
        """Reconstruct complex images from measurements ``y``."""
        # A^H y is the same at every iteration: compute it once.
        Aty = self.nufft.adjoint(self.xi, y)/np.sqrt(nx*ny)
        x = Aty
        z = x.clone()
        mu = torch.zeros_like(x)
        for net in self.net:
            # x step: (A^H A + beta I) x = A^H y + beta z - mu
            rhs = Aty+self.beta*z-mu
            x, _ = cg(self._Cop, rhs, self.nitermaxcg)

            # z step: learned proximal operator applied to x + mu/beta,
            # fed as a real float32 tensor with (real, imag) channels on dim 1.
            v = x+mu/self.beta
            v = torch.stack((v.real, v.imag), dim=1).to(torch.float32)
            z = net(v)
            z = z[:, 0]+1j*z[:, 1]

            # dual (mu) step
            mu = mu+self.beta*(x-z)
        return z

In [None]:
Nunrolled = 5
nitermaxcg = 10
model_admm = MRIUnrolledADMM(nufft, xi, Nunrolled, nitermaxcg=nitermaxcg).to(device)
# Estimates gamma from one training batch (only its first image's shape/dtype/device matter).
model_admm.precompute(next(iter(train_loader)).to(device).to(torch.complex64))

In [None]:
# Adam with its usual betas written out explicitly.
optim = torch.optim.Adam(params=model_admm.parameters(), lr=1e-3, betas=(0.9, 0.999))

In [None]:
# The context manager closes the writer, flushing pending TensorBoard events to disk.
# verbose=1 (tqdm bar) as for the other models, instead of one printed line per iteration.
with SummaryWriter('tblogs/mri/unrolled_admm') as writer:
    run(model_admm, optim, train_loader, test_loader, xi, Nepoch=1, verbose=1, writer=writer)