In [None]:
import torch
import pickle
import pathlib
import wandb
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from collections import defaultdict
from pprint import pprint
sns.set_style('darkgrid')

api = wandb.Api()

import sys
sys.path.append('..')

from src.policy_model.policy_model_utils import (load_policy_model, get_policy_probs, create_data_range_dict,
                                                 compute_next_step_reconstruction, compute_scores)
from src.reconstruction_model.reconstruction_model_utils import load_recon_model
from src.helpers.data_loading import create_data_loader

In [None]:
def evaluate(args, recon_model, model, loader, data_range_dict):
    """
    Roll the policy out on every slice in ``loader`` and collect row-selection statistics.

    No targets and no SSIM are computed. For every acquisition step this records the sampled rows,
    the policy's conditional entropy and the (accumulated) policy probabilities from which the
    marginal entropy is computed.

    Args:
        args: policy arguments; reads ``device``, ``acquisition_steps`` and ``num_test_trajectories``.
        recon_model: reconstruction model, applied to the zero-filled image of each batch.
        model: policy model; switched to eval mode.
        loader: yields ``(kspace, masked_kspace, mask, zf, gt, gt_mean, gt_std, fname, _)``;
            only the first four entries are used.
        data_range_dict: not used by this function; kept so existing callers keep working.

    Returns:
        rows_dict: step -> list, extended per batch with ``actions.tolist()`` (at step 0 this is one
            list of ``num_test_trajectories`` sampled row indices per slice).
        avg_cond_ent_dict: step -> entropy over rows, averaged over trajectories, summed over slices
            and divided by the number of slices.
        avg_marg_ent_dict: step -> entropy of the policy probabilities averaged over trajectories and
            slices.
    """
    model.eval()
    rows_dict = defaultdict(list)
    cond_ent_dict = defaultdict(float)
    marg_prob_dict = defaultdict(float)

    tbs = 0  # data set size counter (number of slices seen)
    with torch.no_grad():
        for data in loader:
            # gt / normalisation statistics / file names are not needed for these statistics.
            kspace, masked_kspace, mask, zf, _gt, _gt_mean, _gt_std, _fname, _ = data
            # shape after unsqueeze = batch x channel x columns x rows x complex
            kspace = kspace.unsqueeze(1).to(args.device)
            masked_kspace = masked_kspace.unsqueeze(1).to(args.device)
            mask = mask.unsqueeze(1).to(args.device)
            # shape after unsqueeze = batch x channel x columns x rows
            zf = zf.unsqueeze(1).to(args.device)
            tbs += mask.size(0)

            # Base reconstruction model forward pass
            recons = recon_model(zf)

            for step in range(args.acquisition_steps):
                policy, probs = get_policy_probs(model, recons, mask)
                if step == 0:
                    # probs has a singleton dim 1 here; draw all trajectories' first rows at once.
                    actions = torch.multinomial(probs.squeeze(1), args.num_test_trajectories, replacement=True)
                else:
                    actions = policy.sample()

                # Store all trajectories per data point
                rows_dict[step].extend(actions.cpu().numpy().tolist())
                probs_cpu = probs.cpu()
                # ent over rows, average over trajs, sum over data points
                cond_ent_dict[step] += ent(probs_cpu, dim=-1).sum(dim=0).mean()
                # average probs over trajs, sum over data points (normalised by tbs below)
                marg_prob_dict[step] += probs_cpu.mean(dim=1).sum(dim=0)

                # Simulate num_test_trajectories acquisition trajectories in parallel for each slice:
                # acquire the sampled rows and reconstruct from the updated mask.
                mask, masked_kspace, zf, recons = compute_next_step_reconstruction(recon_model, kspace,
                                                                                   masked_kspace, mask, actions)

    avg_cond_ent_dict = {step: sum_ent / tbs for step, sum_ent in cond_ent_dict.items()}
    avg_marg_ent_dict = {step: ent(marg_prob / tbs, dim=0) for step, marg_prob in marg_prob_dict.items()}
    return rows_dict, avg_cond_ent_dict, avg_marg_ent_dict

In [None]:
def get_ckpt_from_id(run_id, base, entity, project):
    """
    Resolve the policy checkpoint path for a wandb run.

    Fetches ``entity/project/run_id`` through the module-level wandb ``api``, keeps only the final
    path component of the run's ``run_dir`` config entry, and returns ``base / <that> / 'model.pt'``.
    """
    run_config = api.run(f'{entity}/{project}/{run_id}').config
    run_dir_name = pathlib.Path(run_config['run_dir']).name
    return base / run_dir_name / 'model.pt'

def ent(probs, dim):
    """
    Shannon entropy (in nats) of ``probs`` along dimension ``dim``.

    A constant 1e-11 is added to every probability before the log so that zero entries do not give
    ``0 * log(0) = nan``; the shifted values are also used as the weights.
    """
    shifted = probs + 1e-11
    return torch.sum(-shifted * shifted.log(), dim=dim)

def save_results(res, save_name):
    """Pickle ``res`` into the file ``save_name`` (overwriting it). Returns None."""
    with open(save_name, mode='wb') as handle:
        pickle.dump(res, handle)
        
def load_results(save_name):
    """
    Unpickle and return the object stored in ``save_name``.

    NOTE: ``pickle.load`` can execute arbitrary code; only use it on result files written by this notebook.
    """
    with open(save_name, mode='rb') as handle:
        return pickle.load(handle)

In [None]:
class Arguments:
    """
    Plain attribute container handed to the project helpers (recon-model loading, data loaders,
    data-range computation) in place of parsed command-line arguments.

    Args:
        accel: acceleration factor; stored as ``accelerations = [accel]`` with centre fraction ``1 / accel``.
        acq: number of acquisition steps.
        force: if True, stored results are recomputed instead of loaded.
        res: image resolution.
        batch_size: validation/test batch size.
        sample_rate: fraction of the dataset to use.
        center_volume: stored as-is for the data loader.
        recon_model_checkpoint: path to the reconstruction model checkpoint.
        data_path: path to the dataset.
        dataset: dataset name (e.g. 'knee' or 'brain').
        num_trajectories: trajectories simulated per slice during evaluation (default 8).
        device: torch device string (default 'cuda').
        num_workers: data loader worker count (default 4).
        batches_step: stored as-is (default 1).
    """
    def __init__(self, accel, acq, force, res, batch_size, sample_rate, center_volume, recon_model_checkpoint,
                 data_path, dataset, num_trajectories=8, device='cuda', num_workers=4, batches_step=1):
        self.accelerations = [accel]
        self.center_fractions = [1 / accel]
        self.acquisition_steps = acq
        self.resolution = res
        self.val_batch_size = batch_size
        self.batches_step = batches_step
        self.num_trajectories = num_trajectories
        self.dataset = dataset

        self.data_path = pathlib.Path(data_path)
        self.recon_model_checkpoint = pathlib.Path(recon_model_checkpoint)

        self.sample_rate = sample_rate
        self.acquisition = None
        self.center_volume = center_volume
        self.device = device
        self.num_workers = num_workers

        self.force = force

In [None]:
### BEGIN VALUES TO SET ###

dataset = 'knee'  # or 'brain'
wandb_entity = 'WANDB_ENTITY NAME'

wandb_knee_project = 'WANDB_KNEE_PROJECT_NAME'
wandb_brain_project = 'WANDB_BRAIN_PROJECT_NAME'

# Set base path for policy models. Corresponds to exp_dir in train_policy.py
knee_base = '<path_to_knee_policy_model_base_dir>'
brain_base = '<path_to_brain_policy_model_base_dir>'

# Whether to recompute and overwrite the row / entropy results already stored next to each checkpoint
force = False

if dataset == 'knee':
    batch_size = 128
    res = 128
    sample_rate = 0.5
    center_volume = True
    data_path = '<path_to_knee_data>'
    recon_model_checkpoint = '<path_to_knee_recon_model.pt>'
    wandb_project = wandb_knee_project

elif dataset == 'brain':
    res = 256
    sample_rate = 0.2
    center_volume = False
    batch_size = 32
    data_path = '<path_to_brain_data>'
    recon_model_checkpoint = '<path_to_brain_recon_model.pt>'
    wandb_project = wandb_brain_project

else:
    # Fail here rather than with a NameError on res/batch_size/... in a later cell.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

### END VALUES TO SET ###

In [None]:
# This creates a dictionary of run names, dirs and ids, based on the Wandb API.

def collect_run_ids(entity, project, dataset_name, rate, api_client=None):
    """
    Query wandb for finished runs of ``entity/project`` whose config has ``sample_rate == rate``.

    Returns ``{horizon: {mode: {run_name: {'id': run_id, 'dir': run_dir_name}}}}`` where
    ``horizon`` is ``"16-32"`` for runs with ``accelerations == [8]`` and ``"4-32"`` for
    ``accelerations == [32]`` (runs with other accelerations are dropped), and ``mode`` is
    ``'greedy'``, ``'nongreedy'`` (gamma == 1) or the raw ``gamma`` value otherwise.
    Runs whose config ``dataset`` differs (case-insensitively) from ``dataset_name`` are skipped.

    ``api_client`` defaults to the module-level wandb ``api``.
    """
    client = api if api_client is None else api_client
    run_ids = {"16-32": defaultdict(dict),
               "4-32": defaultdict(dict)}

    for run in client.runs(f"{entity}/{project}", {"config.sample_rate": rate}):
        if run.state != 'finished':
            continue

        run_config = run.config
        # Runs logged without a dataset entry cannot be attributed to a dataset: skip them.
        if run_config.get('dataset', '').lower() != dataset_name.lower():
            continue  # Skip models not on given dataset

        ### YOUR FILTERS HERE ###

        if run_config['model_type'] == 'greedy':
            mode = 'greedy'
        elif run_config['gamma'] == 1:
            mode = 'nongreedy'
        else:
            mode = run_config['gamma']

        # Keep only the final path component (also correct when run_dir ends with '/').
        run_dir = pathlib.Path(run_config['run_dir']).name

        if run_config['accelerations'] == [8]:
            run_ids["16-32"][mode][run.name] = {'id': run.id, 'dir': run_dir}
        elif run_config['accelerations'] == [32]:
            run_ids["4-32"][mode][run.name] = {'id': run.id, 'dir': run_dir}

    return run_ids


run_id_dict = collect_run_ids(wandb_entity, wandb_project, dataset, sample_rate)
pprint(run_id_dict)

In [None]:
all_rows_dict = defaultdict(lambda: defaultdict(dict))
all_cond_ent_dict = defaultdict(lambda: defaultdict(dict))
all_marg_ent_dict = defaultdict(lambda: defaultdict(dict))

# Horizon label -> (acceleration, number of acquisition steps) used to build the test loader.
horizon_settings = {'16-32': (8, 16), '4-32': (32, 28)}

# Base dir of the policy checkpoints; it depends only on the dataset, so resolve it once.
if dataset == 'knee':
    base = pathlib.Path(knee_base)
elif dataset == 'brain':
    base = pathlib.Path(brain_base)
else:
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'knee' or 'brain'.")

# Load recon model. Acceleration / steps (2, 2) are placeholders here -- presumably unused by
# load_recon_model; TODO confirm.
args = Arguments(2, 2, force, res, batch_size, sample_rate, center_volume, recon_model_checkpoint, data_path, dataset)
recon_args, recon_model = load_recon_model(args)

for horizon, mode_dict in run_id_dict.items():
    if horizon not in horizon_settings:
        raise ValueError(f'Unknown horizon: {horizon}')
    accel, acq = horizon_settings[horizon]

    # Load data for this horizon
    args = Arguments(accel, acq, force, res, batch_size, sample_rate, center_volume, recon_model_checkpoint, data_path, dataset)
    loader = create_data_loader(args, 'test')
    # Depends only on (args, loader): built at most once per horizon, and only if some model needs evaluating.
    data_range_dict = None

    for mode, mode_runs in mode_dict.items():
        for name, run_dict in mode_runs.items():
            run_id = run_dict['id']
            print(horizon, mode, name, run_id)
            ckpt = get_ckpt_from_id(run_id, base, wandb_entity, wandb_project)

            try:
                model, policy_args = load_policy_model(ckpt)
            except FileNotFoundError:
                # Report the path that was tried (Arguments has no policy_model_checkpoint attribute).
                print(f'File not found: {ckpt}')
                continue

            policy_args.num_test_trajectories = args.num_trajectories

            row_save_name = ckpt.parent / f'rows_t{args.num_trajectories}.pkl'
            cond_ent_save_name = ckpt.parent / f'ent_t{args.num_trajectories}.pkl'
            marg_ent_save_name = ckpt.parent / f'ment_t{args.num_trajectories}.pkl'
            cached = all(p.exists() for p in (row_save_name, cond_ent_save_name, marg_ent_save_name))

            if cached and not args.force:
                print(f'Results already stored in: \n   {row_save_name.parent}')
                rows = load_results(row_save_name)
                cents = load_results(cond_ent_save_name)
                ments = load_results(marg_ent_save_name)
            else:
                if data_range_dict is None:
                    data_range_dict = create_data_range_dict(args, loader)
                rows, cents, ments = evaluate(policy_args, recon_model, model, loader, data_range_dict)
                # Only write freshly computed results; cached ones are already on disk.
                save_results(rows, row_save_name)
                save_results(cents, cond_ent_save_name)
                save_results(ments, marg_ent_save_name)

            all_rows_dict[horizon][mode][name] = rows
            all_cond_ent_dict[horizon][mode][name] = cents
            all_marg_ent_dict[horizon][mode][name] = ments

In [None]:
# Plot mutual information per acquisition step: marginal entropy minus conditional entropy of the
# policy, mean ± sample std over the runs of each mode.

sns.set_style('darkgrid')
cdict = {'greedy': 'tab:blue', 'nongreedy': 'tab:cyan', '0.9': 'tab:orange'}
mode_labels = {'greedy': 'Greedy', 'nongreedy': 'NGreedy'}
# The long-horizon panel was previously disabled by an unconditional `continue`; set True to draw it.
plot_long_horizon = False

plt.figure(figsize=(15, 5))
for horizon, hor_dict in sorted(all_cond_ent_dict.items()):
    if horizon == '16-32':
        ax = plt.subplot(1, 2, 1)
        ax.set_ylabel('mutual information (nats)', fontsize=15)
        ax.set_title('base horizon', fontsize=18)
    elif plot_long_horizon:
        ax = plt.subplot(1, 2, 2)
        ax.set_title('long horizon', fontsize=18)
    else:
        continue
    ax.set_xlabel('acquisition step', fontsize=15)
    ax.set_ylim(0, 2.6)

    for mode, mode_dict in hor_dict.items():
        if not mode_dict:
            continue
        # Gamma modes are float keys (e.g. 0.9): label them generically instead of reusing a stale label.
        label = mode_labels.get(mode, f'γ = {mode}')

        cent_arr = np.array([list(ents.values()) for ents in mode_dict.values()])
        ment_arr = np.array([list(all_marg_ent_dict[horizon][mode][name].values()) for name in mode_dict])
        minf = ment_arr - cent_arr

        avg_minf = np.mean(minf, axis=0)
        # ddof=1 needs at least two runs; a single run gets a zero-width band instead of NaN.
        std_minf = np.std(minf, axis=0, ddof=1) if len(minf) > 1 else np.zeros_like(avg_minf)

        steps = np.arange(1, len(avg_minf) + 1)
        # Unknown modes get None -> next colour in the cycle; the band reuses the line's colour.
        line, = ax.plot(steps, avg_minf, label=label, color=cdict.get(str(mode)))
        ax.fill_between(steps, avg_minf - std_minf, avg_minf + std_minf, alpha=0.3, color=line.get_color())

    ax.legend(loc='upper left')

plt.tight_layout()
plt.show()

In [None]:
# Plot conditional and marginal entropy of the policy per acquisition step,
# mean ± sample std over the runs of each mode.

sns.set_style('darkgrid')
cdict = {'greedy': 'tab:blue', 'nongreedy': 'tab:cyan', '0.9': 'tab:orange'}
mode_labels = {'greedy': 'Greedy', 'nongreedy': 'NGreedy'}


def _mean_and_std_over_runs(arr):
    """Mean and sample std (ddof=1) over axis 0 (runs); the std is zero when there is a single run."""
    mean = np.mean(arr, axis=0)
    if len(arr) > 1:
        return mean, np.std(arr, axis=0, ddof=1)
    return mean, np.zeros_like(mean)


plt.figure(figsize=(15, 5))
for horizon, hor_dict in sorted(all_cond_ent_dict.items()):
    if horizon == '16-32':
        ax = plt.subplot(1, 2, 1)
        ax.set_ylabel('entropy (nats)', fontsize=15)
        ax.set_title('base horizon', fontsize=18)
    else:
        ax = plt.subplot(1, 2, 2)
        ax.set_title('long horizon', fontsize=18)
    ax.set_xlabel('acquisition step', fontsize=15)
    ax.set_ylim(0, 4)

    for mode, mode_dict in hor_dict.items():
        if not mode_dict:
            continue
        # Gamma modes are float keys (e.g. 0.9): label them generically instead of reusing a stale label.
        label = mode_labels.get(mode, f'γ = {mode}')

        cent_arr = np.array([list(ents.values()) for ents in mode_dict.values()])
        ment_arr = np.array([list(all_marg_ent_dict[horizon][mode][name].values()) for name in mode_dict])
        avg_cent, std_cent = _mean_and_std_over_runs(cent_arr)
        avg_ment, std_ment = _mean_and_std_over_runs(ment_arr)

        steps = np.arange(1, len(avg_ment) + 1)
        # Unknown modes get None -> next colour in the cycle; all artists of a mode share that colour.
        ment_line, = ax.plot(steps, avg_ment, '--', label=f'{label} marg ent', color=cdict.get(str(mode)))
        color = ment_line.get_color()
        ax.fill_between(steps, avg_ment - std_ment, avg_ment + std_ment, alpha=0.3, color=color)

        ax.plot(steps, avg_cent, label=f'{label} cond ent', color=color)
        ax.fill_between(steps, avg_cent - std_cent, avg_cent + std_cent, alpha=0.3, color=color)

    ax.legend(loc='upper left')

plt.tight_layout()
plt.show()