In [None]:
import torch
import json
import random
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict

import sys
sys.path.append('..')

from src.policy_model.policy_model_utils import (load_policy_model, get_policy_probs, create_data_range_dict,
                                                 compute_next_step_reconstruction)
from src.reconstruction_model.reconstruction_model_utils import load_recon_model
from src.helpers.data_loading import create_data_loader

In [None]:
def run_policy(args):
    """Roll out a trained acquisition policy over the test set.

    Loads the reconstruction model (``load_recon_model``) and the policy model from
    ``args.policy_model_checkpoint``, asserts that their settings agree with ``args``, and for every
    test batch runs ``policy_args.acquisition_steps`` acquisition steps. Each step samples one action
    per image from the policy and updates mask / k-space / reconstruction via
    ``compute_next_step_reconstruction``.

    Side effect: ``args.center_fractions`` and ``args.dataset`` are overwritten with the policy
    model's values before the test loader is built.

    Args:
        args: ``Arguments``-like namespace. Read here: ``resolution``, ``accelerations``,
            ``reciprocals_in_center``, ``policy_model_checkpoint``, ``device``, ``single_image`` and
            (when ``single_image`` is true) ``image_idx``; the loaders may read more fields.

    Returns:
        If ``args.single_image`` is false: dict mapping acquisition step (1-based) to a numpy array
        of the sampled actions for all test images, concatenated in loader order.
        If ``args.single_image`` is true: tuple ``(gt, recon, mask)`` of CPU tensors for the test
        image at position ``args.image_idx`` in loader order (each keeps a leading dim of size 1);
        ``mask`` is the final mask with its last dimension indexed at 0.

    Raises:
        AssertionError: if the reconstruction/policy model settings do not match ``args``.
        IndexError: if ``args.single_image`` is true and ``args.image_idx`` is not a valid
            position in the test set.
    """
    # Reconstruction model
    recon_args, recon_model = load_recon_model(args)

    # Policy model
    model, policy_args = load_policy_model(args.policy_model_checkpoint)
    assert recon_args.resolution == policy_args.resolution == args.resolution
    assert args.accelerations == policy_args.accelerations
    assert args.reciprocals_in_center == policy_args.reciprocals_in_center

    # Build the test data with the policy model's center fractions and dataset.
    args.center_fractions = policy_args.center_fractions
    args.dataset = policy_args.dataset

    loader = create_data_loader(args, 'test')
    # Per step, one array of actions per batch (for average policy visualisation); concatenated
    # once at the end instead of re-concatenating on every batch.
    step_actions = defaultdict(list)
    num_seen = 0  # number of test images in all batches processed so far (including skipped ones)
    with torch.no_grad():
        for data in loader:
            kspace, masked_kspace, mask, zf, gt, _gt_mean, _gt_std, _fname, _ = data
            batch_start = num_seen
            num_seen += kspace.shape[0]

            ind = None  # position of the requested image inside this batch (single-image mode)
            if args.single_image:
                # Use the actual batch size rather than args.val_batch_size, so a short final
                # batch cannot produce an out-of-range (silently empty) slice below.
                if not batch_start <= args.image_idx < num_seen:
                    continue
                ind = args.image_idx - batch_start

            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(args.device)
            masked_kspace = masked_kspace.unsqueeze(1).to(args.device)
            mask = mask.unsqueeze(1).to(args.device)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(args.device)
            gt = gt.unsqueeze(1).to(args.device)
            # Base reconstruction model forward pass
            recons = recon_model(zf)

            for step in range(policy_args.acquisition_steps):
                policy, probs = get_policy_probs(model, recons, mask)
                if step == 0:
                    actions = torch.multinomial(probs.squeeze(1), 1, replacement=True)  # single trajectory
                else:
                    actions = policy.sample()

                mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace,
                                                                                   masked_kspace, mask, actions)

                step_actions[step + 1].append(actions.squeeze(-1).cpu().numpy())

            if ind is not None:
                return (gt[ind:ind+1, :, :, :].cpu(),
                        recons[ind:ind+1, :, :, :].cpu(),
                        mask[ind:ind+1, :, :, :, 0].cpu())

    if args.single_image:
        raise IndexError(f'image_idx {args.image_idx} is out of range for a test set of {num_seen} images')

    return {step: np.concatenate(batches) for step, batches in step_actions.items()}

In [None]:
def plot_average_policies(base_args, row_dict, runs):
    """Plot, for four runs, how often each k-space column has been acquired up to each step.

    Panel placement is derived from the run name:
    - Runs containing 'nongreedy' or 'gamma' go to the right column; all others go to the left.
    - Runs containing 'long' go to the bottom row, with 32 used as the acceleration for the center band.
    - All other runs go to the top row, with 8 used as the acceleration.

    Each panel is a (resolution x steps + 1) image. Entry [c, s] is the number of times column c
    was picked at any step <= s, divided by the number of images. Rows in the central band of width
    about resolution / accel are set to 1 (presumably the initially sampled center).

    Args:
        base_args: object providing ``resolution``.
        row_dict: run name -> {step: array of sampled columns}, as returned by ``run_policy``
            when ``single_image`` is false.
        runs: exactly four run names, all keys of ``row_dict``.
    """
    # Assumes four entries in row_dict to plot
    assert len(row_dict) == len(runs) == 4
    sns.set_style('dark')
    # The two rows have different horizons (base vs. long runs), so x is only shared within a row.
    # Every panel has `resolution` rows, so y is shared across the grid; with sharey, matplotlib
    # hides the right column's y tick labels itself.
    fig, axes = plt.subplots(2, 2, sharey=True, sharex='row', figsize=(18, 8))

    res = base_args.resolution
    im = None
    for run in runs:
        rows = row_dict[run]
        steps = len(rows)
        col = 1 if ('nongreedy' in run or 'gamma' in run) else 0
        row_idx = 1 if 'long' in run else 0
        accel = 32 if 'long' in run else 8
        ax = axes[row_idx, col]

        if col == 0:
            ax.set_ylabel('column', fontsize=15)
        else:
            # Drop the tick marks too; tick_params is per-axes and does not touch the shared locator.
            ax.tick_params(axis='y', left=False)
        if row_idx == 0:
            if col == 0:
                ax.set_title('Greedy', fontsize=18)
            elif 'nongreedy' in run:
                ax.set_title('NGreedy', fontsize=18)
            elif 'gamma' in run:
                ax.set_title('γ = 0.9', fontsize=18)
        else:
            ax.set_xlabel('acquisition step', fontsize=15)

        img = np.zeros((res, steps + 1))
        for step, cols in rows.items():
            # Columns picked at `step` are counted for that step and every later one.
            counts = np.bincount(np.asarray(cols).ravel(), minlength=res)
            img[:, step:] += counts[:, None]
        img /= len(rows[1])
        img[int(res // 2 * (1 - 1/accel)):int(res // 2 * (1 + 1/accel)), :] = 1
        im = ax.imshow(img, vmin=0, vmax=1, cmap='gist_gray', aspect='auto')

    fig.subplots_adjust(right=0.95)
    cbar_ax = fig.add_axes([0.97, 0.14, 0.017, 0.7])
    # Every panel uses the same cmap and vmin/vmax, so the last image can drive the colorbar.
    fig.colorbar(im, cax=cbar_ax)
    fig.subplots_adjust(wspace=0.05, hspace=0.18)

    plt.show()

In [None]:
def plot_image_grid(base_args, image_dict, runs):
    """Show one row per run: [mask | reconstruction | ground truth | |reconstruction - ground truth|].

    Rows follow the order of ``runs`` from top to bottom. Each panel is normalised on its own
    (``scale_each=True``) before being tiled.

    Args:
        base_args: not used.
        image_dict: run name -> ``(gt, recon, mask)`` tuple returned by ``run_policy`` in
            single-image mode.
        runs: run names to plot; must have as many entries as ``image_dict``.
    """
    assert len(image_dict) == len(runs)

    import torchvision
    plt.figure(figsize=(10, len(image_dict) * 2.5))

    panels = []
    for run_name in runs:
        gt, recon, mask = image_dict[run_name]
        # NOTE(review): assumes the mask's third dim has size 1; it is expanded to the width
        # (last dim) so the mask renders as an image.
        mask_img = mask.expand(-1, -1, mask.shape[-1], -1)
        error_img = torch.abs(recon - gt)
        panels.append(torch.cat((mask_img, recon, gt, error_img), dim=0))

    stacked = torch.cat(panels, dim=0)
    grid_img = torchvision.utils.make_grid(stacked, nrow=4, normalize=True, scale_each=True, pad_value=1)
    plt.imshow(grid_img.permute(1, 2, 0), cmap='gist_gray', aspect='auto')
    for clear_ticks in (plt.xticks, plt.yticks):
        clear_ticks([], [])
    plt.tight_layout()
    plt.show()

In [None]:
class Arguments:
    """Mutable argument namespace passed to ``run_policy`` and the model/data loaders.

    Fixed defaults are set here, and dataset-specific ones are chosen from ``dataset``.
    ``policy_model_checkpoint``, ``accelerations``, ``acquisition_steps``, ``single_image`` and
    ``image_idx`` are not set here; callers assign them before running a policy.
    """

    def __init__(self, dataset, recon_model_checkpoint, data_path):
        """
        Args:
            dataset: 'knee' or 'brain' (case-insensitive); selects dataset-specific defaults.
            recon_model_checkpoint: path to the reconstruction model checkpoint (stored as a Path).
            data_path: path to the data (stored as a Path).

        Raises:
            ValueError: if ``dataset`` is not 'knee' or 'brain'.
        """
        self.seed = 0
        self.device = 'cuda'
        self.num_workers = 8
        self.acquisition = None
        self.reciprocals_in_center = [1]
        self.recon_model_checkpoint = pathlib.Path(recon_model_checkpoint)
        self.data_path = pathlib.Path(data_path)

        name = dataset.lower()
        if name == 'knee':
            self.center_volume = True
            self.sample_rate = 0.5
            self.val_batch_size = 512
            self.resolution = 128
        elif name == 'brain':
            self.center_volume = False
            self.sample_rate = 0.2
            self.val_batch_size = 128
            self.resolution = 256
        else:
            # Fail here instead of leaving resolution etc. unset and erroring much later.
            raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'")

In [None]:
### BEGIN VALUES TO SET ###

# Which dataset ('Knee' or 'Brain', case-insensitive)
dataset = 'Knee'
# dataset = 'Brain'

# Which runs. Each run name must contain 'base' or 'long', which selects the acquisition horizon.
# In the average-policy figure, names containing 'nongreedy' or 'gamma' go in the right-hand column.
runs = ['base_greedy']
# runs = ['base_greedy', 'base_nongreedy', 'long_greedy', 'long_nongreedy']
# runs = ['base_greedy', 'base_gamma09', 'long_greedy', 'long_gamma09']

# Visualise average policy
average_policy = False
# Visualise MR images in grid (not performed if average_policy is True)
single_image = True
image_idx = 0  # Which slice to pick (position in the test set, in loader order)

# Set base path for policy models. Entries in policy_model_checkpoints will be appended to this base path 
# to construct the full policy model path. Corresponds to exp_dir in train_policy.py, though it can be set
# in any way that works out with the relative policy model paths (to be specified further below).
base_policy_path = '<base_path_to_stored_models>'

# In policy_model_checkpoints, set an entry to None to skip that run.
if dataset.lower() == 'knee':
    # Set data path, recon model checkpoint, and policy model checkpoints
    data_path = '<path_to_knee_data>'
    recon_model_checkpoint = '<path_to_recon_model.pt>'
    policy_model_checkpoints = {'base_greedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_greedy': '<relative_path_to_policy_model.pt_from_base_policy_path>', 
                                'base_nongreedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_nongreedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'base_gamma09': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_gamma09': '<relative_path_to_policy_model.pt_from_base_policy_path>'}
    
elif dataset.lower() == 'brain':
    # Set data path, recon model checkpoint, and policy model checkpoints
    data_path = '<path_to_brain_data>'
    recon_model_checkpoint = '<path_to_recon_model.pt>'
    policy_model_checkpoints = {'base_greedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_greedy': '<relative_path_to_policy_model.pt_from_base_policy_path>', 
                                'base_nongreedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_nongreedy': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'base_gamma09': '<relative_path_to_policy_model.pt_from_base_policy_path>',
                                'long_gamma09': '<relative_path_to_policy_model.pt_from_base_policy_path>'}

else:
    # Otherwise data_path etc. would be undefined and the next cell would fail with a NameError.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'Knee' or 'Brain'")

### END VALUES TO SET ###

In [None]:
base_args = Arguments(dataset, recon_model_checkpoint, data_path)
row_dict = {}
image_dict = {}

for run in runs:
    print(run)
    # Skip runs for which no policy model is given (None or missing entry)
    if policy_model_checkpoints.get(run) is None:
        print(f'  no policy model checkpoint for {run!r}; skipping')
        continue

    # Set policy model
    base_args.policy_model_checkpoint = pathlib.Path(base_policy_path) / policy_model_checkpoints[run]

    # Set horizon parameters
    if 'base' in run:
        base_args.accelerations = [8]
        base_args.acquisition_steps = 16
    elif 'long' in run:
        base_args.accelerations = [32]
        base_args.acquisition_steps = 28
    else:
        # Without this the previous run's horizon (or none at all) would silently be used.
        raise ValueError(f"Run name {run!r} must contain 'base' or 'long' to select the acquisition horizon")

    # Actions are sampled from the policy, so seed every RNG per run to make results reproducible.
    random.seed(base_args.seed)
    np.random.seed(base_args.seed)
    torch.manual_seed(base_args.seed)

    if average_policy:
        # Do average policy computation
        base_args.single_image = False
        row_dict[run] = run_policy(base_args)
    elif single_image:
        # Or just grab a single image from the policy run (only when not computing the average policy)
        base_args.single_image = True
        base_args.image_idx = image_idx
        gt, recon, mask = run_policy(base_args)
        image_dict[run] = (gt, recon, mask)

# Plot only the runs that were actually computed; skipped runs have no results.
if average_policy:
    plot_average_policies(base_args, row_dict, [run for run in runs if run in row_dict])
elif single_image:
    plot_image_grid(base_args, image_dict, [run for run in runs if run in image_dict])