In [None]:
# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# fastmri
import fastmri
from fastmri.data import subsample
from fastmri.data import transforms, mri_data
from fastmri.evaluate import ssim, psnr, nmse
from fastmri.losses import SSIMLoss

# other stuff
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm 
from networks import VisionTransformer, ReconNet, Img2Img_Mixer, Unet

import os

# Restrict this process to one physical GPU. The variable only takes effect if it
# is set before CUDA is initialized in this process, hence it lives in this early cell.
GPU_ID = '2'
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID

# Device: fall back to CPU when the selected GPU is not visible instead of
# failing later on the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cpu':
    print('CUDA not available (CUDA_VISIBLE_DEVICES={}); running on CPU.'.format(GPU_ID))

In [None]:
class fastMRIDataset(Dataset):
    """Single-coil fastMRI slices paired with zero-filled, 4x-undersampled inputs.

    Each item is the tuple ``(x, target, maxval)`` produced by ``data_transform``:
    ``x`` is the magnitude of the zero-filled reconstruction of the retrospectively
    undersampled k-space (center-cropped to the target's size), ``target`` is the
    target image returned by ``SliceDataset``, and ``maxval`` is the ``'max'`` entry
    of the slice attributes. Both images get a leading channel axis.

    Args:
        isval: If True, read the validation directory and use a deterministic
            per-file undersampling mask (seeded from the filename) so validation
            metrics are reproducible across runs.
        data_path: Optional override of the data directory. When None, defaults to
            ``VAL_PATH`` or ``TRAIN_PATH`` depending on ``isval``.
    """

    # Default data locations; adjust here or pass ``data_path`` explicitly.
    TRAIN_PATH = '/media/hdd1/fastMRIdata/knee/singlecoil_train/'
    VAL_PATH = '/media/hdd1/fastMRIdata/knee/singlecoil_val/'

    def __init__(self, isval=False, data_path=None):
        self.isval = isval
        if data_path is None:
            data_path = self.VAL_PATH if isval else self.TRAIN_PATH
        self.data_path = data_path

        # 4x acceleration with 8% of the center columns fully sampled. Created before
        # the SliceDataset so it exists before any transform call can use it.
        self.mask_func = subsample.RandomMaskFunc(
            center_fractions=[0.08],
            accelerations=[4],
        )

        self.data = mri_data.SliceDataset(
            root=self.data_path,
            transform=self.data_transform,
            challenge='singlecoil',
            use_dataset_cache=True,
        )

    def data_transform(self, kspace, mask, target, data_attributes, filename, slice_num):
        """Convert one raw slice into ``(x, target, maxval)``.

        Used as the ``SliceDataset`` transform. ``mask`` and ``slice_num`` are part
        of the transform signature but are not used here.
        """
        # Validation: derive the seed from the filename so a given file always gets
        # the same mask. Training: seed=None draws a fresh random mask each time.
        seed = tuple(map(ord, filename)) if self.isval else None

        kspace = transforms.to_tensor(kspace)
        # Pass `seed` by keyword and keep only the first output. In newer fastmri
        # releases the third positional parameter is `offset` and the function
        # returns (masked_kspace, mask, num_low_frequencies) rather than a pair;
        # this form works with both signatures.
        masked_kspace = transforms.apply_mask(kspace, self.mask_func, seed=seed)[0]

        target = transforms.to_tensor(target)
        zero_fill = fastmri.ifft2c(masked_kspace)
        zero_fill = transforms.complex_center_crop(zero_fill, target.shape)
        x = fastmri.complex_abs(zero_fill)

        # Add a leading channel axis to both images.
        return x.unsqueeze(0), target.unsqueeze(0), data_attributes['max']

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
# Build train/validation datasets and their loaders.
dataset = fastMRIDataset(isval=False)
val_dataset = fastMRIDataset(isval=True)

# Lower ntrain to train on a (reproducible) random subset of the training slices.
ntrain = len(dataset)  # Vary training data size here
split_sizes = [ntrain, len(dataset) - ntrain]
train_dataset, _ = torch.utils.data.random_split(
    dataset,
    split_sizes,
    generator=torch.Generator().manual_seed(42),  # fixed seed -> same subset every run
)
print(len(train_dataset))

batch_size = 1
# A separate seeded generator makes the shuffling order reproducible as well.
trainloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    generator=torch.Generator().manual_seed(42),
)
valloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

In [None]:
"""Init Model"""
# Exactly one backbone should be active. The Image-to-Image Mixer is used here; the
# Vision Transformer and U-Net configurations below are alternatives (uncomment one
# and comment out the Mixer to switch).

# Image to Image Mixer
mixer_kwargs = dict(
    img_size=320,
    img_channels=1,
    patch_size=4,
    embed_dim=128,
    num_layers=16,
    f_hidden=8,
)
net = Img2Img_Mixer(**mixer_kwargs)

# # Vision Transformer
# net = VisionTransformer(
#     avrg_img_size=320,
#     patch_size=10,
#     in_chans=1, embed_dim=44,
#     depth=4, num_heads=9, mlp_ratio=4.,
#     )

# # Unet
# net = Unet(
#     in_chans=1,
#     out_chans=1,
#     chans=32,
#     num_pool_layers=4,
#     )

# Wrap the backbone in ReconNet and move it to the configured device.
model = ReconNet(net).to(device)

num_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print('#Params:', num_trainable)
print(model)

In [None]:
# Validate model
def validate(model, dataset=None, num_workers=2):
    """Evaluate ``model`` on the validation set, print PSNR/SSIM, return the SSIM loss.

    Iterates ``dataset`` (default: the global ``val_dataset``) one slice at a time
    with gradients disabled and the model in eval mode, then restores the model's
    previous train/eval mode. Prints mean +- 2 std of PSNR and SSIM.

    SSIM is computed with fastMRI's ``SSIMLoss`` (the same formulation as the
    training criterion), so the returned selection metric matches the objective;
    values may differ slightly from the skimage-based ``fastmri.evaluate.ssim``.
    PSNR is ``10 * log10(maxval**2 / MSE)`` per slice.

    Args:
        model: Reconstruction network; its output is compared directly against the
            targets (presumably same shape, e.g. (1, 1, H, W) — TODO confirm).
        dataset: Dataset yielding ``(input, target, maxval)``; None means ``val_dataset``.
        num_workers: DataLoader worker processes.

    Returns:
        float: ``1 - mean SSIM`` over all slices (lower is better).
    """
    if dataset is None:
        dataset = val_dataset
    loader = DataLoader(dataset, batch_size=1, shuffle=False,
                        num_workers=num_workers, pin_memory=True)

    # The previous version referenced an undefined `myutils` module (NameError on the
    # first validation). Use the already-imported fastMRI SSIMLoss instead.
    ssim_loss = SSIMLoss().to(device)
    was_training = model.training
    model.eval()

    psnrs, ssims = [], []
    with torch.no_grad():
        for inputs, targets, maxval in tqdm(loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            maxval = maxval.to(device)
            outputs = model(inputs)

            # SSIMLoss averages over the batch; with batch_size=1 this is per slice.
            ssims.append((1 - ssim_loss(outputs, targets, maxval)).reshape(1))
            mse = (outputs - targets).pow(2).flatten(1).mean(dim=1)
            psnrs.append(10 * torch.log10(maxval ** 2 / mse))

    model.train(was_training)

    psnr_all = torch.cat(psnrs)
    ssim_all = torch.cat(ssims)
    print(' Recon. PSNR: {:0.3f} pm {:0.2f}'.format(psnr_all.mean(), 2*psnr_all.std()))
    print(' Recon. SSIM: {:0.4f} pm {:0.3f}'.format(ssim_all.mean(), 2*ssim_all.std()))

    return (1 - ssim_all.mean()).item()

# Save model
def save_model(path, model, train_hist, val_hist, optimizer, scheduler=None):
    """Persist loss histories and a training checkpoint.

    ``path`` is used as a plain string prefix (e.g. ``'./'``), so it should end with
    a path separator when it names a directory. Writes, in order:
    ``train_hist.pt``, ``val_hist.pt`` and ``checkpoint.pth``.

    The checkpoint dict contains a new ``ReconNet`` wrapping ``model.net``, the
    model and optimizer state dicts, and, only if ``scheduler`` is truthy, the
    scheduler's state dict under ``'scheduler'``.
    """
    checkpoint = {
        'model' :  ReconNet(model.net),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if scheduler:
        checkpoint['scheduler'] = scheduler.state_dict()

    outputs = (
        (train_hist, 'train_hist.pt'),
        (val_hist, 'val_hist.pt'),
        (checkpoint, 'checkpoint.pth'),
    )
    for payload, filename in outputs:
        torch.save(payload, path + filename)

In [None]:
"""Choose optimizer"""
# The scheduler is stepped once per epoch, so its horizon is the epoch count.
num_epochs = 40

criterion = SSIMLoss().to(device)
# lr=0.0 is only a placeholder: the OneCycleLR scheduler below sets the actual rate.
optimizer = optim.Adam(model.parameters(), lr=0.0)

train_hist, val_hist = [], []
best_val = float("inf")
path = './'  # Path for saving model checkpoint and loss history

# Linear one-cycle schedule: warm up for 10% of the epochs, then anneal down.
# cycle_momentum is disabled, so the momentum bounds are unused.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.0005,
    total_steps=num_epochs,
    pct_start=0.1,
    anneal_strategy='linear',
    cycle_momentum=False,
    base_momentum=0.,
    max_momentum=0.,
    div_factor=0.1 * num_epochs,
    final_div_factor=9,
)

In [None]:
"""Train Model"""
def _run_train_epoch():
    """Run one optimization pass over `trainloader`; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0

    with tqdm(total=len(trainloader)) as pbar:
        for inputs, targets, maxval in trainloader:
            optimizer.zero_grad()
            outputs = model(inputs.to(device))
            loss = criterion(outputs, targets.to(device), maxval.to(device))
            loss.backward()
            # Clip the L1 norm of all gradients to at most 1 before the update.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=1.)
            optimizer.step()
            pbar.update(1)

            running_loss += loss.item()

    return running_loss / len(trainloader)


for epoch in tqdm(range(0, 40)):  # loop over the dataset multiple times
    epoch_loss = _run_train_epoch()
    scheduler.step()  # one scheduler step per epoch
    train_hist.append(epoch_loss)
    print('Epoch {}, Train loss.: {:0.4f}'.format(epoch+1, train_hist[-1]))

    # Validate every 5th epoch; checkpoint only when the validation loss improves.
    if (epoch + 1) % 5 != 0:
        continue
    val_hist.append(validate(model))
    if val_hist[-1] < best_val:
        save_model(path, model, train_hist, val_hist, optimizer, scheduler=scheduler)
        best_val = val_hist[-1]