In [15]:
import argparse
import time
from collections import defaultdict
from pathlib import Path
import h5py
import fastmri
import fastmri.data.transforms as T
import numpy as np
import requests
import torch
from fastmri.data import SliceDataset
from fastmri.models import Unet
from tqdm import tqdm


In [16]:
# loading  multi coil knee file 

fname = '/scratch/svangurp/samuel/data/knee/train/file1000002.h5'
data = h5py.File(fname, 'r')
kspace = data["kspace"][()]


In [17]:
def run_unet_model(batch, model, device):
    image, _, mean, std, fname, slice_num, _ = batch

    output = model(image.to(device).unsqueeze(1)).squeeze(1).cpu()

    mean = mean.unsqueeze(1).unsqueeze(2)
    std = std.unsqueeze(1).unsqueeze(2)
    output = (output * std + mean).cpu()

    return output, int(slice_num[0]), fname[0]


In [18]:
def run_inference(challenge, state_dict_file, data_path, output_path, device):
    model = Unet(in_chans=1, out_chans=1, chans=256, num_pool_layers=4, drop_prob=0.0)
    # download the state_dict if we don't have it
    if state_dict_file is None:
        if not Path(MODEL_FNAMES[challenge]).exists():
            url_root = UNET_FOLDER
            download_model(url_root + MODEL_FNAMES[challenge], MODEL_FNAMES[challenge])

        state_dict_file = MODEL_FNAMES[challenge]

    model.load_state_dict(torch.load(state_dict_file))
    model = model.eval()

    # data loader setup
    if "_mc" in challenge:
        data_transform = T.UnetDataTransform(which_challenge="multicoil")
    else:
        data_transform = T.UnetDataTransform(which_challenge="singlecoil")

    if "_mc" in challenge:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="multicoil",
        )
    else:
        dataset = SliceDataset(
            root=data_path,
            transform=data_transform,
            challenge="singlecoil",
        )
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)

    # run the model
    start_time = time.perf_counter()
    outputs = defaultdict(list)
    model = model.to(device)

    for batch in tqdm(dataloader, desc="Running inference"):
        with torch.no_grad():
            output, slice_num, fname = run_unet_model(batch, model, device)

        outputs[fname].append((slice_num, output))

    # save outputs
    for fname in outputs:
        outputs[fname] = np.stack([out for _, out in sorted(outputs[fname])])

    fastmri.save_reconstructions(outputs, output_path / "reconstructions")

    end_time = time.perf_counter()

    print(f"Elapsed time for {len(dataloader)} slices: {end_time-start_time}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--challenge",
        default="unet_knee_sc",
        choices=(
            "unet_knee_sc",
            "unet_knee_mc",
            "unet_brain_mc",
        ),
        type=str,
        help="Model to run",
    )
    parser.add_argument(
        "--device",
        default="cuda",
        type=str,
        help="Model to run",
    )
    parser.add_argument(
        "--state_dict_file",
        default=None,
        type=Path,
        help="Path to saved state_dict (will download if not provided)",
    )
    parser.add_argument(
        "--data_path",
        type=Path,
        required=True,
        help="Path to subsampled data",
    )
    parser.add_argument(
        "--output_path",
        type=Path,
        required=True,
        help="Path for saving reconstructions",
    )

    args = parser.parse_args()

    run_inference(
        args.challenge,
        args.state_dict_file,
        args.data_path,
        args.output_path,
        torch.device(args.device),
    )


usage: ipykernel_launcher.py [-h]
                             [--challenge {unet_knee_sc,unet_knee_mc,unet_brain_mc}]
                             [--device DEVICE]
                             [--state_dict_file STATE_DICT_FILE] --data_path
                             DATA_PATH --output_path OUTPUT_PATH
ipykernel_launcher.py: error: the following arguments are required: --data_path, --output_path


SystemExit: 2

In [14]:
challenge = 'unet_knee_mc' 
state_dict_file ='/home/svangurp/scratch/samuel/pretrained/knee/unet/knee_mc_leaderboard_state_dict.pt'
data_path = '/scratch/svangurp/samuel/data/knee/train/'
output_path ='/home/svangurp/scratch/samuel/data/knee/model_ouputs/Unet_recon_knee_mc/'
device = 'cuda' 

run_inference(challenge, state_dict_file, data_path, output_path, device)

Running inference: 100%|██████████| 1448/1448 [03:20<00:00,  7.21it/s]


AttributeError: 'str' object has no attribute 'mkdir'