In [None]:
import torch
import pickle
import pathlib
import wandb
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from pprint import pprint
sns.set_style('darkgrid')

api = wandb.Api()

import sys
sys.path.append('..')

from src.policy_model.policy_model_utils import (load_policy_model, get_policy_probs, create_data_range_dict,
                                                 compute_next_step_reconstruction, compute_scores)
from src.reconstruction_model.reconstruction_model_utils import load_recon_model
from src.helpers.data_loading import create_data_loader
from src.helpers.torch_metrics import compute_ssim, compute_psnr

In [None]:
# dataset = 'knee'  # or 'brain'
# wandb_entity = 'timsey'  # 'WANDB_ENTITY NAME'
# wandb_project = 'mri_refactor'  # 'WANDB_PROJECT_NAME'

# # Whether to overwrite SSIM values stored on drive if exist
# force = False

# # Set base path for policy models.
# base_policy_path = '/home/timsey/Projects/mrimpro/refactor'

# if dataset.lower() == 'knee':
#     # Set data path, recon model checkpoint, and policy model checkpoints
#     data_path = '/home/timsey/HDD/data/fastMRI/singlecoil/'
#     recon_model_checkpoint = '/home/timsey/Projects/fastMRI-shi/models/unet/al_nounc_res128_8to4in2_cvol_symk/model.pt'

In [None]:
### BEGIN VALUES TO SET ###

dataset = 'Knee'  # or 'Brain' (matched case-insensitively)
# Later cells read `wandb_entity`; `entity` is kept as an alias holding the same value.
wandb_entity = 'WANDB_ENTITY NAME'
entity = wandb_entity
wandb_project = 'WANDB_PROJECT_NAME'

# Whether to recompute (and overwrite) SSIM/PSNR values already stored on disk.
force = False

# Set base path for policy models. Corresponds to exp_dir in train_policy.py
base_policy_path = '<base_path_to_stored_models>'

if dataset.lower() == 'knee':
    # Set data path and recon model checkpoint
    data_path = '<path_to_knee_data>'
    recon_model_checkpoint = '<path_to_recon_model.pt>'

elif dataset.lower() == 'brain':
    # Set data path and recon model checkpoint
    data_path = '<path_to_brain_data>'
    recon_model_checkpoint = '<path_to_recon_model.pt>'

else:
    # Fail here rather than with a NameError on data_path in a later cell.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

### END VALUES TO SET ###

In [None]:
# This creates a dictionary of run names, dirs and ids, based on the Wandb API.
# Structure: horizon ("16-32" / "4-32") -> mode ('greedy', 'nongreedy' or the gamma value)
#            -> run name -> {'id': run id, 'dir': last component of the run's run_dir}.

# Only runs whose config.sample_rate equals this value are retrieved from Wandb.
# Compare case-insensitively: the config cell accepts e.g. 'Knee' as well as 'knee'.
if dataset.lower() == 'knee':
    sample_rate = 0.5
elif dataset.lower() == 'brain':
    sample_rate = 0.2
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

run_id_dict = {"16-32": defaultdict(dict),
               "4-32": defaultdict(dict)}

runs = api.runs(f"{wandb_entity}/{wandb_project}", {"config.sample_rate": sample_rate})
for run in runs:
    if run.state != 'finished':
        continue

    name = run.name
    args = run.config

    if args['dataset'].lower() != dataset.lower():
        continue  # Skip models not on given dataset

    ### YOUR FILTERS HERE ###

    if args['model_type'] == 'greedy':
        key = 'greedy'
    elif args['gamma'] == 1:
        key = 'nongreedy'
    else:
        key = args['gamma']

    # Same rule get_ckpt_from_id uses (Path.name), so the directory-consistency assert in the
    # evaluation loop compares like with like (e.g. also when run_dir has a trailing slash).
    run_dir = pathlib.Path(args['run_dir']).name

    if args['accelerations'] == [8]:
        run_id_dict["16-32"][key][name] = {'id': run.id, 'dir': run_dir}
    elif args['accelerations'] == [32]:
        run_id_dict["4-32"][key][name] = {'id': run.id, 'dir': run_dir}
    # Runs with any other acceleration setting are ignored.

pprint(run_id_dict)

In [None]:
# Result containers: horizon -> mode -> run name -> [scores, run_dir].
ssim_dict = defaultdict(lambda: defaultdict(dict))
psnr_dict = defaultdict(lambda: defaultdict(dict))

# Pre-create an empty entry for every horizon/mode found in run_id_dict, so each mode is
# present even if none of its runs end up being evaluated.
for horizon, gamma_dict in run_id_dict.items():
    for gamma in gamma_dict:
        ssim_dict[horizon][gamma] = {}
        psnr_dict[horizon][gamma] = {}

In [None]:
def get_ckpt_from_id(run_id, base, entity, project):
    """
    Return the expected policy checkpoint path for a Wandb run.

    The run is looked up through the module-level Wandb ``api`` object; the last component
    of its ``run_dir`` config entry is joined onto ``base``.

    Args:
        run_id: Wandb run id.
        base: Base directory of the stored policy models (``str`` or ``pathlib.Path``).
        entity: Wandb entity name.
        project: Wandb project name.

    Returns:
        ``pathlib.Path`` of the form ``base / <run_dir name> / 'model.pt'``. Whether the
        file exists is not checked.
    """
    run = api.run(f'{entity}/{project}/{run_id}')
    run_dir_name = pathlib.Path(run.config['run_dir']).name
    # Coerce base so plain string paths work as well as Path objects.
    return pathlib.Path(base) / run_dir_name / 'model.pt'

In [None]:
def evaluate(args, recon_model, model, loader, data_range_dict):
    """
    Evaluate a policy by the SSIM and PSNR of reconstructions along sampled acquisition trajectories.

    Doesn't require computing targets! For every slice, ``args.num_test_trajectories`` trajectories
    are simulated in parallel; at every time step the per-slice scores are averaged over the last
    score dimension (presumably the trajectories) and summed over the batch, then divided by the
    total number of slices at the end.

    Args:
        args: Attribute container; reads ``device``, ``acquisition_steps`` and
            ``num_test_trajectories`` and is forwarded to ``compute_psnr`` / ``compute_scores``.
        recon_model: Reconstruction model applied to the zero-filled input ``zf``.
        model: Policy model; switched to eval mode and queried through ``get_policy_probs``.
        loader: Iterable of 9-tuples
            ``(kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, _)``.
        data_range_dict: Maps every entry of ``fname`` to a tensor data range for the metrics.

    Returns:
        Tuple ``(ssims, psnrs)`` of numpy arrays of length ``args.acquisition_steps + 1``: the
        dataset-average score of the initial reconstruction followed by one entry per step.

    Raises:
        ValueError: If ``loader`` yields no slices (the averages would be undefined).
    """
    model.eval()
    # One entry for the initial reconstruction plus one per acquisition step.
    num_scores = args.acquisition_steps + 1
    ssims = np.zeros(num_scores)
    psnrs = np.zeros(num_scores)
    tbs = 0  # data set size counter (number of slices seen)
    with torch.no_grad():
        for data in loader:
            kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, _ = data
            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(args.device)
            masked_kspace = masked_kspace.unsqueeze(1).to(args.device)
            mask = mask.unsqueeze(1).to(args.device)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(args.device)
            gt = gt.unsqueeze(1).to(args.device)
            # Add three trailing singleton dims so the statistics broadcast against gt
            # (presumably (batch,) -> (batch, 1, 1, 1) — TODO confirm).
            gt_mean = gt_mean.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)
            gt_std = gt_std.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(args.device)
            unnorm_gt = gt * gt_std + gt_mean
            data_range = torch.stack([data_range_dict[vol] for vol in fname])
            tbs += mask.size(0)

            # Base reconstruction model forward pass
            recons = recon_model(zf)
            unnorm_recons = recons * gt_std + gt_mean
            # Scores are summed over the batch here and divided by tbs after the loop.
            init_ssim_val = compute_ssim(unnorm_recons, unnorm_gt, size_average=False,
                                         data_range=data_range).mean(dim=(-1, -2)).sum()
            init_psnr_val = compute_psnr(args, unnorm_recons, unnorm_gt, data_range).sum()

            batch_ssims = [init_ssim_val.item()]
            batch_psnrs = [init_psnr_val.item()]

            for step in range(args.acquisition_steps):
                policy, probs = get_policy_probs(model, recons, mask)
                if step == 0:
                    # Branch each slice into num_test_trajectories sampled actions (with replacement).
                    actions = torch.multinomial(probs.squeeze(1), args.num_test_trajectories, replacement=True)
                else:
                    actions = policy.sample()
                # Samples trajectories in parallel
                # For evaluation we can treat greedy and non-greedy the same: in both cases we just simulate
                # num_test_trajectories acquisition trajectories in parallel for each slice in the batch, and store
                # the average SSIM score every time step.
                mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace,
                                                                                   masked_kspace, mask, actions)
                ssim_scores, psnr_scores = compute_scores(args, recons, gt_mean, gt_std, unnorm_gt, data_range,
                                                          comp_psnr=True)
                assert len(ssim_scores.shape) == 2
                ssim_scores = ssim_scores.mean(-1).sum()
                psnr_scores = psnr_scores.mean(-1).sum()

                # eventually length = acquisition_steps + 1
                batch_ssims.append(ssim_scores.item())
                batch_psnrs.append(psnr_scores.item())

            ssims += np.array(batch_ssims)
            psnrs += np.array(batch_psnrs)

    if tbs == 0:
        raise ValueError('Cannot evaluate on an empty data loader.')

    ssims /= tbs
    psnrs /= tbs

    return ssims, psnrs

In [None]:
class Arguments:
    """
    Plain attribute container passed to ``load_recon_model``, ``create_data_loader`` and
    ``create_data_range_dict``.

    Args:
        dataset: 'knee' or 'brain' (case-insensitive); selects dataset-specific settings
            (center_volume, sample_rate, val_batch_size, resolution). Stored unchanged.
        recon_model_checkpoint: Path to the reconstruction model checkpoint.
        policy_model_checkpoint: Path to a policy model checkpoint (this notebook passes 'None').
        data_path: Root directory of the data.
        accel: Acceleration factor; the center fraction is set to ``1 / accel``.
        acq: Number of acquisition steps.
        num_trajectories: Number of test trajectories per slice used during evaluation.

    Raises:
        ValueError: If ``dataset`` is neither 'knee' nor 'brain' (otherwise the
            dataset-specific attributes would silently be missing).
    """

    def __init__(self, dataset, recon_model_checkpoint, policy_model_checkpoint, data_path, accel=8, acq=16,
                 num_trajectories=8):
        dataset_key = dataset.lower()
        if dataset_key not in ('knee', 'brain'):
            raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

        self.seed = 0
        self.device = 'cuda'
        self.num_workers = 8
        self.acquisition = None
        self.reciprocals_in_center = [1]
        self.recon_model_checkpoint = pathlib.Path(recon_model_checkpoint)
        self.policy_model_checkpoint = pathlib.Path(policy_model_checkpoint)
        self.data_path = pathlib.Path(data_path)
        self.accelerations = [accel]
        self.center_fractions = [1 / accel]
        self.acquisition_steps = acq
        self.num_trajectories = num_trajectories
        self.dataset = dataset

        if dataset_key == 'knee':
            self.center_volume = True
            self.sample_rate = 0.5
            self.val_batch_size = 512
            self.resolution = 128
        else:  # 'brain'
            self.center_volume = False
            self.sample_rate = 0.2
            self.val_batch_size = 128
            self.resolution = 256

In [None]:
# Load recon model
args = Arguments(dataset, recon_model_checkpoint, 'None', data_path)
recon_args, recon_model = load_recon_model(args)

# Horizon label -> (acceleration, acquisition steps). Labels must match the keys of run_id_dict.
horizon_settings = {'16-32': (8, 16), '4-32': (32, 28)}

for horizon, mode_dict in run_id_dict.items():
    if horizon not in horizon_settings:
        raise ValueError(f'Unknown horizon {horizon!r}; expected one of {list(horizon_settings)}.')
    accel, acq = horizon_settings[horizon]

    # Load data for this horizon
    args = Arguments(dataset, recon_model_checkpoint, 'None', data_path, accel, acq)
    loader = create_data_loader(args, 'test')
    data_range_dict = create_data_range_dict(args, loader)

    for mode, runs in mode_dict.items():
        for name, run_info in runs.items():
            run_id = run_info['id']
            run_dir = run_info['dir']
            # base is a Path, so ckpt is a pathlib.Path.
            ckpt = get_ckpt_from_id(run_id, pathlib.Path(base_policy_path), wandb_entity, wandb_project)
            assert ckpt.parent.name == run_dir, 'Something went wrong with storing directory.'

            try:
                model, policy_args = load_policy_model(ckpt)
            except FileNotFoundError:
                print(f' File corresponding to {name} not found:\n    {ckpt}')
                continue

            # NOTE(review): evaluate() reads acquisition_steps/device from the policy's own args,
            # not from the horizon-specific `args` above — confirm they agree.
            policy_args.num_test_trajectories = args.num_trajectories
            ssim_save_path = ckpt.parent / f'test_ssims_t{args.num_trajectories}.pkl'
            psnr_save_path = ckpt.parent / f'test_psnrs_t{args.num_trajectories}.pkl'

            if ssim_save_path.exists() and psnr_save_path.exists() and not force:
                # These pickles are written by this notebook; only unpickle files you trust.
                print(f'SSIMs already stored in: {ssim_save_path}')
                with open(ssim_save_path, 'rb') as f:
                    ssims = pickle.load(f)
                print(f'PSNRs already stored in: {psnr_save_path}')
                with open(psnr_save_path, 'rb') as f:
                    psnrs = pickle.load(f)
            else:
                # Trajectory sampling is stochastic; seed so the stored scores are reproducible and
                # every model is evaluated from the same random state.
                torch.manual_seed(args.seed)
                ssims, psnrs = evaluate(policy_args, recon_model, model, loader, data_range_dict)
                # Only write when freshly computed; cached values are already on disk.
                with open(ssim_save_path, 'wb') as f:
                    pickle.dump(ssims, f)
                with open(psnr_save_path, 'wb') as f:
                    pickle.dump(psnrs, f)

            ssim_dict[horizon][mode][name] = [ssims, run_dir]
            psnr_dict[horizon][mode][name] = [psnrs, run_dir]

            print(name, run_id, ssims[0], ssims[-1])
            print(name, run_id, psnrs[0], psnrs[-1])

In [None]:
def print_scores(score_dict, expected_runs=5):
    """
    Print the mean and sample std of the final-step score for every horizon/mode in ``score_dict``.

    Args:
        score_dict: Nested mapping ``horizon -> mode -> run name -> [scores, run_dir]`` where the last
            entry of ``scores`` is the score after the final acquisition step (e.g. ``ssim_dict``
            or ``psnr_dict``).
        expected_runs: Only modes with exactly this many runs (presumably one per seed) contribute the
            ``run_dir`` of their best run (highest final score); these are printed at the end.

    Each line reads ``<n runs> Horizon <h>, model <mode>: <mean> \\pm <std>`` (LaTeX-style plus-minus).
    Modes without runs are skipped; with a single run the sample std is undefined and printed as nan.
    """
    best_run_dirs = []
    for horizon, hor_dict in score_dict.items():
        for mode, mode_dict in hor_dict.items():
            if not mode_dict:
                continue

            run_dirs = [run_dir for _, run_dir in mode_dict.values()]
            end_scores = np.array([scores[-1] for scores, _ in mode_dict.values()])

            if len(end_scores) == expected_runs:
                best_run_dirs.append(run_dirs[int(np.argmax(end_scores))])

            mean = end_scores.mean()
            # Sample std (ddof=1) is undefined for one run; avoid numpy's degrees-of-freedom warning.
            std = end_scores.std(ddof=1) if len(end_scores) > 1 else float('nan')

            # Escaped backslash: '\p' is an invalid escape sequence in a normal string literal.
            print(len(end_scores), f'Horizon {horizon}, model {mode}: {mean:.4f} \\pm {std:.4f}')
    print(best_run_dirs)

In [None]:
# Final-step SSIM summary per horizon and mode; swap in psnr_dict for the PSNR summary.
print_scores(score_dict=ssim_dict)
# print_scores(psnr_dict)