In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from statistics import median, mean
from matplotlib import pyplot as plt
import numpy as np
import json
from tqdm import tqdm
from glob import glob
import os
import sys
sys.path.insert(0,"/study/mrphys/skunkworks/kk/mriUnet")
import unet
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.model_selection import KFold as kf
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
import h5py

In [2]:
allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))[0:40]

In [3]:
def getComplexSlices(path):

    with h5py.File(path,'r') as hf:
        prefix = 'C_000_0'
        imagestackReal = []
        imagestackImag = []
        for i in range(10):
            n = prefix + str(i).zfill(2)
            image = hf['Images'][n]
            imagestackReal.append(np.array(image['real']))
            imagestackImag.append(np.array(image['imag']))
            if i==0:
                normScale = np.abs(np.array(image['real']+image['real']*1j)).max()
        imagestackReal = np.array(imagestackReal)/normScale
        imagestackImag = np.array(imagestackImag)/normScale
        
    return imagestackReal+imagestackImag*1j, normScale

class mriNoisyDataset(Dataset):
    def __init__(self, sample):
        self.originalPath = []
        self.accelPath = [] 

        allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))[0:40]
        folderName  = allImages[sample]
        self.accelPath = folderName +'processed_data/acc_2min/C.h5'
        self.accelFile, self.scale = getComplexSlices(self.accelPath)

    def __getitem__(self, index):
        if index<256:
            return self.accelFile[:,index,:,:]
        elif index<512:
            index = index-256
            return self.accelFile[:,:,index,:]
        else:
            index = index-512
            return self.accelFile[:,:,:,index]
        
    def __len__(self):
        return 768

In [4]:
def predict(model, dataset, device = 5):
    model.eval()
    model.to(device)
    X = []
    Y = []
    Z = []
    for i, noisy in enumerate(dataset):
        noisy = torch.tensor(noisy).to(device).unsqueeze(0)
        with torch.no_grad():
            p = model(noisy).cpu().numpy() * dataset.scale
            if i<256:
                X.append(p)
            elif i<512:
                Y.append(p)
            else:
                Z.append(p)
                
    return np.vstack(X).transpose(1,0,2,3), np.vstack(Y).transpose(1,2,0,3), np.vstack(Z).transpose(1,2,3,0)

In [5]:
folds = 5
kfsplitter = kf(n_splits=folds, shuffle=True, random_state=69420)
for i, (train_index, test_index) in enumerate(kfsplitter.split(allImages)):
    fold = i+1
    model = unet.UNet(
        10,
        10,
        f_maps=32,
        layer_order=['separable convolution', 'relu'],
        depth=4,
        layer_growth=2.0,
        residual=True,
        complex_input=True,
        complex_kernel=True,
        ndims=2,
        padding=1
    )
    name = f'fullDenoiser_{fold}'
    model.load_state_dict(torch.load(f'/study/mrphys/skunkworks/kk/outputs/{name}/weights/{name}_BEST.pth'))
    for index in tqdm(test_index):
        dataset = mriNoisyDataset(index)
        X, Y, Z = predict(model, dataset)
        pred = (X+Y+Z)/3
        np.save(f'pred/denoised_{index}.npy',np.array(pred))

Crop amount [(-4, -4, -4, -4), (-16, -16, -16, -16), (-40, -40, -40, -40)]


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [07:47<00:00, 58.38s/it]


Crop amount [(-4, -4, -4, -4), (-16, -16, -16, -16), (-40, -40, -40, -40)]


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [04:05<00:00, 30.75s/it]


Crop amount [(-4, -4, -4, -4), (-16, -16, -16, -16), (-40, -40, -40, -40)]


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [04:20<00:00, 32.60s/it]


Crop amount [(-4, -4, -4, -4), (-16, -16, -16, -16), (-40, -40, -40, -40)]


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [04:03<00:00, 30.48s/it]


Crop amount [(-4, -4, -4, -4), (-16, -16, -16, -16), (-40, -40, -40, -40)]


100%|█████████████████████████████████████████████████████████████████████████████████████| 8/8 [04:16<00:00, 32.05s/it]


In [8]:
allImages = sorted(glob("/study/mrphys/skunkworks/training_data//mover01/*/", recursive=True))[0:40]
try:
    os.mkdir('/scratch/mrphys/denoised')
except:
    pass

for i, imgIndex in tqdm(enumerate(range(len(allImages)))):
    name = allImages[imgIndex].split('/')[-2]
    with h5py.File(f'/scratch/mrphys/denoised/denoised_{name}.h5','w') as f:
        pred = np.load(f'pred/denoised_{i}.npy')
        temp = pred.astype(np.dtype([('real','f'),('imag','f')]))
        temp['imag'] = pred.imag
        pred = temp
        grp = f.create_group('Images')
        for n in range(10):
            grp.create_dataset('C_000_0'+ str(n).zfill(2), data=pred[n])

  temp = pred.astype(np.dtype([('real','f'),('imag','f')]))
40it [03:02,  4.56s/it]
