This notebook saves the inference data,

i.e., f(x) and g(x).

Note that both depend on the transform parameters and the evaluation data.

In [80]:
import sys
sys.path.append("./mypkg")
from constants import RES_ROOT, FIG_ROOT, DATA_ROOT

In [81]:
%load_ext autoreload
%autoreload 2
# 0,1, 2, 3, be careful about the space

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [82]:
import time
from collections import defaultdict
from easydict import EasyDict as edict
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from numbers import Number 
from joblib import Parallel, delayed

from utils.fastmri import get_dataset, run_varnet_model, load_model
from utils.misc import save_pkl

In [83]:
def _get_namev(v):
    if isinstance(v, Number):
        return f"{v*100:.0f}"
    return v

In [86]:
# Run configuration: where the raw data lives and how k-space is undersampled.
config = edict()
config.data_root = DATA_ROOT / "brain/multicoil_val"

# Undersampling-mask settings; these also determine the output folder name.
config.mask_params = edict(
    mask_type='equispaced',
    center_fraction=0.04,
    acceleration=4,
)

# Output folder: "<data_root stem>_<key>-<value>_..." built from mask_params,
# with numeric values rendered by _get_namev.
config.save_fold = DATA_ROOT / (
    f"{config.data_root.stem}_"
    + "_".join(f"{k}-{_get_namev(v)}" for k, v in config.mask_params.items())
)
if not config.save_fold.exists():
    config.save_fold.mkdir(parents=True, exist_ok=True)

In [33]:
# Two models: one loaded with is_Ysq=False (used for fx below) and one with
# is_Ysq=True (used for gx below).
fmodel = load_model(is_Ysq=False)
gmodel = load_model(is_Ysq=True)

# Dataset built from the configured data root and mask settings.
dataset = get_dataset(
    data_path=config.data_root,
    mask_type=config.mask_params.mask_type,
    center_fraction=config.mask_params.center_fraction,
    acceleration=config.mask_params.acceleration,
)

print(f"There is {len(dataset)} samples in the dataset.")

  0%|                                                               | 6/7270 [21:45:43<26346:34:41, 13057.22s/it]


There is 7270 samples in the dataset.


In [49]:
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data import SliceDataset
import fastmri.data.transforms as T
def get_dataset1(data_path, 
                mask_type="equispaced", 
                center_fraction=0.04, 
                acceleration=4, 
                filter=None):
    """Build a multicoil SliceDataset with a VarNet data transform.

    The inputs should be compatible with the settings used for model training.
    - args:
        - data_path (str or Path): directory containing the h5 files
        - mask_type (str): undersampling mask type, "equispaced" or "random"
            - "equispaced" was used for training on the brain data
        - center_fraction (float): center fraction of the mask, 0.04 or 0.08
        - acceleration (int): acceleration factor of the mask, 4 or 8
        - filter (callable or None): forwarded to SliceDataset as
            `raw_sample_filter`; it is called on raw samples to decide which
            ones to keep (None is passed through unchanged)
    - returns:
        - dataset (SliceDataset): index it with dataset[i] to get one sample
    - raises:
        - AssertionError: if mask_type, center_fraction or acceleration is not
            one of the supported values
    """
    assert mask_type in ["equispaced", "random"], "mask_type should be either 'equispaced' or 'random'"
    assert center_fraction in [0.04, 0.08], "center_fraction should be either 0.04 or 0.08"
    assert acceleration in [4, 8], "acceleration should be either 4 or 8"

    # Mask function -> data transform -> dataset, built in a single expression.
    return SliceDataset(
        root=data_path,
        transform=T.VarNetDataTransform(
            mask_func=create_mask_for_mask_type(
                mask_type_str=mask_type,
                center_fractions=[center_fraction],
                accelerations=[acceleration],
            )
        ),
        challenge="multicoil",
        sample_rate=None,
        raw_sample_filter=filter,
    )


def run_varnet_model(batch, 
                     model, 
                     is_Ysq = False):
    """
    Run the model on one sample and return its output as a numpy array.
    - args: 
        - batch: the sample to run; must expose `mask` and `masked_kspace`
            tensors (e.g. an item from the dataset or a data loader)
            - a 4-dim mask is treated as an unbatched sample and a leading
              batch axis is added to both tensors
            - otherwise the leading (batch) axis must have size 1
        - model: the trained model, called as model(masked_kspace, mask)
            - note: the model is switched to eval mode as a side effect
        - is_Ysq: if the model was trained for Y^2 or not 
            - kept for interface compatibility; it is not used in this function
    - return:
        - output: the output of the model, moved to CPU, as a numpy array
    - raises:
        - AssertionError: if a batched input has a batch size other than 1
    """
    mask = batch.mask
    masked_kspace = batch.masked_kspace
    if mask.dim() == 4:
        # Unbatched sample: prepend the batch axis the model call expects.
        mask = mask[None]
        masked_kspace = masked_kspace[None]
    else:
        # The previous check only tested that shape[0] was non-zero, so any
        # batch size silently passed despite the message.
        assert mask.shape[0] == 1, "currently only support batch size 1"
    model.eval()
    with torch.no_grad():
        output = model(masked_kspace, mask).cpu()
    return output.numpy()

In [53]:
# Both subsets share the same data root and mask settings; only the filter on
# the third "_"-separated token of the file stem differs.
_shared_kwargs = dict(
    data_path=config.data_root,
    mask_type=config.mask_params.mask_type,
    center_fraction=config.mask_params.center_fraction,
    acceleration=config.mask_params.acceleration,
)

dataset1 = get_dataset1(
    filter=lambda sample: sample.fname.stem.split("_")[2] == "AXT1",
    **_shared_kwargs,
)
dataset2 = get_dataset1(
    filter=lambda sample: sample.fname.stem.split("_")[2] == "AXFLAIR",
    **_shared_kwargs,
)

for _subset in (dataset1, dataset2):
    print(f"There is {len(_subset)} samples in the dataset.")

There is 492 samples in the dataset.
There is 518 samples in the dataset.


In [58]:
# One example slice from each filtered subset, both run through fmodel so the
# output shapes can be compared below.
batch1, batch2 = dataset1[1], dataset2[380]
res1, res2 = run_varnet_model(batch1, fmodel), run_varnet_model(batch2, fmodel)

In [57]:
# Print the masked k-space shape of every slice in dataset2, to see how the
# tensor shape varies from slice to slice.
for ix in range(len(dataset2)):
    batch = dataset2[ix]
    print(ix, batch.masked_kspace.shape)

0 torch.Size([16, 640, 320, 2])
1 torch.Size([16, 640, 320, 2])
2 torch.Size([16, 640, 320, 2])
3 torch.Size([16, 640, 320, 2])
4 torch.Size([16, 640, 320, 2])
5 torch.Size([16, 640, 320, 2])
6 torch.Size([16, 640, 320, 2])
7 torch.Size([16, 640, 320, 2])
8 torch.Size([16, 640, 320, 2])
9 torch.Size([16, 640, 320, 2])
10 torch.Size([16, 640, 320, 2])
11 torch.Size([16, 640, 320, 2])
12 torch.Size([16, 640, 320, 2])
13 torch.Size([16, 640, 320, 2])
14 torch.Size([16, 640, 320, 2])
15 torch.Size([16, 640, 320, 2])
16 torch.Size([20, 640, 320, 2])
17 torch.Size([20, 640, 320, 2])
18 torch.Size([20, 640, 320, 2])
19 torch.Size([20, 640, 320, 2])
20 torch.Size([20, 640, 320, 2])
21 torch.Size([20, 640, 320, 2])
22 torch.Size([20, 640, 320, 2])
23 torch.Size([20, 640, 320, 2])
24 torch.Size([20, 640, 320, 2])
25 torch.Size([20, 640, 320, 2])
26 torch.Size([20, 640, 320, 2])
27 torch.Size([20, 640, 320, 2])
28 torch.Size([20, 640, 320, 2])
29 torch.Size([20, 640, 320, 2])
30 torch.Size([20, 6

In [61]:
# Shape of the masked k-space tensor for the slice picked from dataset2.
batch2.masked_kspace.shape

torch.Size([4, 512, 213, 2])

In [64]:
# Field names available on a sample (these are what _run_fn iterates over).
batch2._fields

('masked_kspace',
 'mask',
 'num_low_frequencies',
 'target',
 'fname',
 'slice_num',
 'max_value',
 'crop_size')

In [63]:
# Compare the target image shape with the crop size recorded on the sample.
batch2.target.shape, batch2.crop_size

(torch.Size([213, 213]), (512, 408))

In [60]:
# Output shapes of fmodel for the two example slices.
res1.shape, res2.shape

((1, 640, 320), (1, 512, 213))

In [None]:
# Save the run configuration as config.pkl inside the output folder, so the
# settings are stored alongside the results written there.
# The trailing `;` keeps the notebook from echoing the call's return value.
save_pkl(config.save_fold/"config.pkl", config);

Save to /data/rajlab1/user_data/jin/MyResearch/imageCP_dev/mypkg/../data/mask_type-equispaced_center_fraction-4_acceleration-400/config.pkl


In [None]:
def _run_fn(batch):
    """Run both models on one sample and pickle the results.

    The output file is `config.save_fold/<file stem>-<slice number>.pkl`; if it
    already exists the sample is skipped. The saved object holds:
    - fx: output of fmodel
    - gx: output of gmodel
    - target, mask: the corresponding batch tensors as numpy arrays
    - attrs: every other batch field except masked_kspace, with tensors
      converted to numpy arrays

    Relies on the notebook globals `config`, `fmodel` and `gmodel`.
    Always returns None.
    """
    fn = batch.fname.split(".")[0]
    sn = batch.slice_num
    fn_root = config.save_fold/f"{fn}-{sn}.pkl"

    if fn_root.exists():
        print(f"{fn_root} exists, skip it.")
        return None

    # mask and target are stored at the top level; masked_kspace (x) is not
    # saved at all because it is very large.
    excluded = ("mask", "target", "masked_kspace")
    attr_names = [name for name in batch._fields if name not in excluded]

    res = edict()
    res.fx = run_varnet_model(batch, fmodel)
    res.gx = run_varnet_model(batch, gmodel, is_Ysq=True)
    res.target = batch.target.numpy()
    res.mask = batch.mask.numpy()

    res.attrs = edict()
    for name in attr_names:
        value = getattr(batch, name)
        res.attrs[name] = value.numpy() if isinstance(value, torch.Tensor) else value

    save_pkl(fn_root, res, is_force=False)
    return None

In [None]:
# Process every sample of the dataset with _run_fn, using n_jobs joblib workers.
n_data = len(dataset)
n_jobs = 30

# NOTE(review): `dataset[idx]` is evaluated by this generator in the notebook
# process, so each sample is loaded here before being handed to _run_fn; the
# "already exists" skip inside _run_fn therefore happens only after the sample
# has been loaded. Presumably this limits the benefit of the workers — confirm.
# The trailing `;` suppresses the echoed list of return values.
Parallel(n_jobs=n_jobs)(delayed(_run_fn)(dataset[idx]) for idx in tqdm(range(n_data), total=n_data));

  0%|                                                                                   | 0/7270 [00:00<?, ?it/s]

  0%|                                                                           | 6/7270 [00:02<46:00,  2.63it/s]

KeyboardInterrupt: 

In [22]:
# File name of the first sample (its stem is used in the output file name).
dataset[0].fname

'file_brain_AXFLAIR_200_6002471.h5'

In [27]:
# Slice number of sample 9 (used as the suffix of the output file name).
dataset[9].slice_num

9