In [1]:
import os
# HDF5 consults this variable when it opens a file; setting it before h5py is
# imported guarantees no HDF5 call can run with file locking still enabled.
os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
import glob, json
import h5py, pathlib
import xml.etree.ElementTree as etree
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import robust_loss_pytorch.adaptive
import robust_loss_pytorch.general

import fastmri
from fastmri.data import subsample
from fastmri.data import transforms, mri_data
from fastmri.data import SliceDataset

from utils.network import VarNet


In [3]:
# Equispaced Cartesian undersampling; fastMRI pairs accelerations and centre
# fractions by index (4x with 8%, 6x with 6%, 8x with 4%).
mask_func = subsample.EquispacedMaskFunc(
    accelerations=[4, 6, 8],
    center_fractions=[0.08, 0.06, 0.04],
)

In [8]:
# Probe the mask for a (128, 128, 2) shape; the printed shape shows the mask
# only varies along axis -2.
msk = mask_func((128, 128, 2))
print(msk.shape)

torch.Size([1, 128, 1])


In [9]:
# Dummy all-ones 2-D array; only its shape is used by the next cell.
im_comb = np.ones(shape=(128, 256), dtype=float)

In [15]:
# Append a dummy trailing axis so the 2-D image shape matches the 3-D layout
# used above (mask varies along axis -2, i.e. the 256 axis here), then drop
# that dummy axis again with [..., 0].
msk = mask_func([*im_comb.shape, 1])[..., 0]
print(msk.shape)

torch.Size([1, 256])


In [48]:
# Create a mask function: random Cartesian undersampling; fastMRI pairs
# accelerations and centre fractions by index (4x with 8%, 8x with 4%).
mask_func = subsample.RandomMaskFunc(
    center_fractions=[0.08, 0.04],
    accelerations=[4, 8]
)

def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Undersample one k-space slice; used as the ``SliceDataset`` transform.

    Args:
        kspace: Raw k-space for one slice, converted with ``transforms.to_tensor``.
        mask: Mask supplied by ``SliceDataset``. Ignored: a fresh mask is always
            drawn from the module-level ``mask_func`` (no seed is passed, so the
            mask may differ on every access).
        target: Reconstruction target, returned unchanged.
        data_attributes: Per-volume metadata; reads ``"padding_left"``,
            ``"padding_right"`` and ``"max"``.
        filename: Unused.
        slice_num: Unused.

    Returns:
        Tuple ``(masked_kspace, mask_uint8, target, max_value_float32)``.
    """
    kspace = transforms.to_tensor(kspace)

    # Acquired phase-encode range recorded by SliceDataset; passing it as
    # `padding` makes apply_mask zero the mask outside [acq_start, acq_end),
    # as fastMRI's own VarNet data transform does with the same attributes.
    acq_start = data_attributes["padding_left"]
    acq_end = data_attributes["padding_right"]
    masked_kspace, mask = transforms.apply_mask(
        kspace, mask_func, padding=(acq_start, acq_end)
    )

    max_value = data_attributes["max"].astype('float32')

    return masked_kspace, mask.byte(), target, max_value



In [49]:
# Multicoil fastMRI slices, each passed through data_transform on access.
dataset = SliceDataset(
    root=pathlib.Path('data') / 'raw_knee2d',
    challenge='multicoil',
    transform=data_transform,
)

In [50]:
# Sanity-check one un-batched sample straight from the dataset; these names are
# overwritten by the DataLoader batch a few cells below.
ksp, mask, rec, max_value = dataset[0]

In [51]:
# In-order loader yielding one sample per batch (adds a leading batch axis).
dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=1)

In [52]:
# Iterator over the loader so individual batches can be pulled with next().
tmp = iter(dataloader)

In [53]:
# Pull the first collated batch and show the k-space, mask and target shapes.
(ksp, mask, rec, max_value) = next(tmp)
print(*(t.shape for t in (ksp, mask, rec)))

torch.Size([1, 15, 640, 368, 2]) torch.Size([1, 1, 1, 368, 1]) torch.Size([1, 320, 320])


In [54]:
# Seed torch's global RNG so the network's (presumably random) weight
# initialisation is reproducible across kernel restarts.
torch.manual_seed(0)
varnet = VarNet()

In [55]:
# Forward pass (autograd enabled, needed for loss.backward() later). Per the
# printed shapes below, out[0] is a (1, 640, 368) image estimate and out[1] a
# (1, 15, 640, 368, 2) tensor, presumably refined k-space — TODO confirm
# against utils.network.VarNet.
out = varnet(ksp, mask)

In [56]:
# Shapes of the first two VarNet outputs, on one line.
print(*(out[k].shape for k in range(2)))

torch.Size([1, 640, 368]) torch.Size([1, 15, 640, 368, 2])


In [57]:
# Image estimate: the first element of VarNet's output.
estY  = out[0]

In [58]:
# dtype of the collated per-sample maximum (cast to float32 in data_transform).
max_value.dtype

torch.float32

In [59]:
# Crop target and estimate to their common (smallest) spatial size, then scale
# both by the per-sample maximum.
# NOTE: not idempotent — re-running this cell divides rec/estY a second time.
rec, estY = transforms.center_crop_to_smallest(rec, estY)
# After DataLoader collation max_value is a (batch,) tensor; view it as
# (batch, 1, 1) so each (batch, H, W) sample is scaled by its own maximum.
# Plain broadcasting would align it with the W axis and only works for batch=1.
rec, estY = (
    rec / max_value.view(-1, 1, 1),
    estY / max_value.view(-1, 1, 1),
)

print(rec.shape, estY.shape)

torch.Size([1, 320, 320]) torch.Size([1, 320, 320])


In [60]:
# Adaptive robust loss over a single residual dimension, on CPU in float32.
# Its shape/scale parameters are learnable and only adapt if they are given
# to an optimizer.
adaptive = robust_loss_pytorch.adaptive.AdaptiveLossFunction(
    device='cpu',
    float_dtype=np.float32,
    num_dims=1,
)

In [61]:
# Every residual pixel becomes one row of a (N, 1) matrix (num_dims=1);
# the per-pixel losses are averaged into a scalar.
loss = adaptive.lossfun((rec - estY).reshape(-1, 1)).mean()

In [None]:
# fastMRI's SSIM loss module, constructed with its default settings.
loss_func = fastmri.SSIMLoss()

In [None]:
# rec and estY were already divided by max_value in the crop/normalise cell,
# so their dynamic range is ~1. SSIMLoss derives its stabilising constants
# from data_range, so passing max_value (~1e-4) here would shrink them far
# below the signal scale. Use a per-sample range of 1 instead.
loss = loss_func(estY.unsqueeze(1), rec.unsqueeze(1), data_range=torch.ones_like(max_value))

In [63]:
# Backpropagate through whichever `loss` was assigned last.
# NOTE(review): in top-to-bottom order that is the SSIM loss, but the recorded
# execution counts show the adaptive loss was the one actually used — confirm
# which is intended.
loss.backward()

In [None]:
# Per-sample normalisation constant from the batch (presumably shape (batch,)).
print(max_value)

In [62]:
# Value of the most recently computed scalar loss as a Python float.
loss.item()

1.186174988746643

In [27]:
# Re-derive, for one volume, the padding / recon-size metadata that
# SliceDataset attaches to each slice, by parsing the ISMRMRD XML header.
# Uses fastMRI's `mri_data.et_query` so this cell runs on a fresh kernel
# (the notebook's own `et_query` is defined further down).
with h5py.File('data/div_knee2d/Train/file1000000.h5', 'r') as hf:
    header_xml = hf["ismrmrd_header"][()]
    attrs = dict(hf.attrs)

et_root = etree.fromstring(header_xml)

enc = ["encoding", "encodedSpace", "matrixSize"]
enc_size = tuple(int(mri_data.et_query(et_root, enc + [axis])) for axis in ("x", "y", "z"))

# Distinct name: `rec` holds the reconstruction tensor used by the loss cells
# and must not be rebound to this query path.
recon_query = ["encoding", "reconSpace", "matrixSize"]
recon_size = tuple(
    int(mri_data.et_query(et_root, recon_query + [axis])) for axis in ("x", "y", "z")
)

lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
enc_limits_center = int(mri_data.et_query(et_root, lims + ["center"]))
enc_limits_max = int(mri_data.et_query(et_root, lims + ["maximum"])) + 1

padding_left = enc_size[1] // 2 - enc_limits_center
padding_right = padding_left + enc_limits_max

In [28]:
# Header-derived padding and recon size next to the stored "max" attribute.
print(*(padding_left, padding_right, recon_size, attrs['max']))

18 350 (320, 320, 1) 0.00018395940825409758


In [25]:
# Per-volume HDF5 attributes read in the header cell (the dict is named
# `attrs`). The output includes patient_id — review before sharing.
print(attrs)

{'acquisition': 'CORPDFS_FBK', 'max': 0.00018395940825409758, 'norm': 0.056879915583771964, 'patient_id': 'b2a82c7521fe2d4aebb627bbaae92a1916bf06e75cb374fc4187b0909e5c0e36'}


In [20]:
def et_query(
    root: etree.Element,
    qlist: Sequence[str],
    namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
    """Find a nested element in a namespaced XML tree and return its text.

    Each entry of ``qlist`` becomes one ``//ns:tag`` step of an ElementPath
    expression rooted at ``.``, so ``["encoding", "matrixSize"]`` matches a
    ``matrixSize`` element at any depth below an ``encoding`` element.

    Args:
        root: Element to search from.
        qlist: Tag names for successive nested searches.
        namespace: XML namespace bound to every tag in ``qlist``.

    Returns:
        ``str(element.text)`` of the first match (``"None"`` if it has no text).

    Raises:
        RuntimeError: If no element matches.
    """
    prefix = "ismrmrd_namespace"
    path = "." + "".join(f"//{prefix}:{tag}" for tag in qlist)

    found = root.find(path, {prefix: namespace})
    if found is None:
        raise RuntimeError("Element not found")
    return str(found.text)

In [51]:
from utils.dataset import Data2D

In [52]:
# Project-local 2-D dataset over the test split (see utils.dataset.Data2D).
dset = Data2D('data/div_knee2d/Test')

In [53]:
# Number of samples in the test split (presumably slices — confirm in Data2D).
len(dset)

1404

In [54]:
# NOTE(review): the meaning of the four returned tensors is defined in
# utils.dataset.Data2D. From the shapes printed below they look like
# multicoil k-space (mk), a column mask (m), another coil-sized tensor (s)
# and a single-channel complex image (i) — confirm and give them clearer names.
mk, m, s, i = dset[0]

In [55]:
# One shape per returned tensor, space-separated on a single line.
print(*(t.shape for t in (mk, m, s, i)))

torch.Size([15, 640, 372, 2]) torch.Size([1, 1, 372, 1]) torch.Size([15, 640, 372, 2]) torch.Size([1, 640, 372, 2])


In [50]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
