In [None]:
from typing import IO, Any, Dict, List, Optional, Tuple
import os
import gc
import sys
import copy
import multiprocessing

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append('../src')
sys.path.append('../src/pgm')

from pgm.train_cf import cf_epoch
from pgm.train_pgm import setup_dataloaders
from pgm.flow_pgm import ChestPGM, MorphoMNISTPGM

class Hparams:
    """Plain attribute container for hyperparameters restored from a checkpoint."""

    def update(self, dict):
        """Copy every key/value pair of ``dict`` onto this object as an attribute.

        Existing attributes with the same name are overwritten.
        """
        for name, value in dict.items():
            setattr(self, name, value)

In [None]:
# Predictor used to predict the parent variables of generated counterfactuals (EMA weights).
# NOTE(review): the checkpoint directory mentions "mimic" while the assertion below requires a
# Morpho-MNIST run — confirm this path points at the intended predictor.
predictor_path = '/workspace/checkpoints/a_r_s_f/aux_mimic_60k-auxg/checkpoint.pt'
print(f'\nLoading predictor checkpoint: {predictor_path}')
predictor_checkpoint = torch.load(predictor_path)
predictor_args = Hparams()
predictor_args.update(predictor_checkpoint['hparams'])
assert predictor_args.dataset == 'morphomnist'
predictor = MorphoMNISTPGM(predictor_args)
predictor = predictor.cuda()
predictor.load_state_dict(predictor_checkpoint['ema_model_state_dict'])

In [None]:
# Causal PGM (EMA weights) used to infer counterfactual parents during evaluation.
pgm_path = '/workspace/checkpoints/a_r_s_f/pgm_60k-pgmg/checkpoint.pt'
print(f'\nLoading PGM checkpoint: {pgm_path}')
pgm_checkpoint = torch.load(pgm_path)
pgm_args = Hparams()
pgm_args.update(pgm_checkpoint['hparams'])
assert pgm_args.dataset == 'morphomnist'
pgm = MorphoMNISTPGM(pgm_args)
pgm = pgm.cuda()
pgm.load_state_dict(pgm_checkpoint['ema_model_state_dict'])

In [None]:
def load_fm(fm_path, data_dir='../datasets/'):
    """Load a flow-matching ``DeeperUnet`` from a checkpoint using its EMA weights.

    Args:
        fm_path: path to a checkpoint containing 'hparams' and 'ema_model_state_dict'.
        data_dir: dataset directory written into the restored hparams, overriding the
            value stored in the checkpoint. Defaults to '../datasets/'.

    Returns:
        (fm, fm_args): the model moved to CUDA, and the ``Hparams`` read from the checkpoint.

    Raises:
        AssertionError: if the checkpoint's ``hps`` is not 'morphomnist'.
    """
    print(f'\nLoading flow matching checkpoint: {fm_path}')
    # torch.load unpickles the file — only load checkpoints from a trusted source.
    fm_checkpoint = torch.load(fm_path)
    fm_args = Hparams()
    fm_args.update(fm_checkpoint['hparams'])
    fm_args.data_dir = data_dir

    # init model
    assert fm_args.hps == 'morphomnist'
    from flow_model import DeeperUnet  # project module, imported on first use
    fm = DeeperUnet(fm_args).cuda()
    fm.load_state_dict(fm_checkpoint['ema_model_state_dict'])
    return fm, fm_args

# Name of the flow-matching run directory under ../checkpoints/t_i_d/ — fill in before running.
model_name = ''
# Fail fast with a clear message instead of a confusing FileNotFoundError on 't_i_d//checkpoint.pt'.
assert model_name, 'Set `model_name` to a flow-matching run directory under ../checkpoints/t_i_d/'
fm_path = f'../checkpoints/t_i_d/{model_name}/checkpoint.pt'
fm, fm_args = load_fm(fm_path)

In [None]:
from morphomnist.morpho import ImageMorphology
# Refer to https://github.com/dccastro/Morpho-MNIST for details on Morpho-MNIST

def get_intensity(x, threshold=0.5):
    """Return the median foreground intensity of each image in a batch.

    Only channel 0 of every image is used. A pixel is foreground when it lies at or
    above ``threshold`` of the way from that image's minimum to its maximum value.

    Args:
        x: batched image tensor indexed as ``x[n, 0, h, w]``.
        threshold: fraction of each image's [min, max] range that marks foreground.

    Returns:
        numpy array with one median value per image.
    """
    imgs = x.detach().cpu().numpy()[:, 0]
    lo = imgs.min(axis=(1, 2), keepdims=True)
    hi = imgs.max(axis=(1, 2), keepdims=True)
    foreground = imgs >= lo + (hi - lo) * threshold
    medians = [np.median(img[fg]) for img, fg in zip(imgs, foreground)]
    return np.array(medians)

def img_thickness(img, threshold, scale):
    """Mean stroke thickness of a single image, measured by Morpho-MNIST's ``ImageMorphology``."""
    morphology = ImageMorphology(np.asarray(img), threshold, scale)
    return morphology.mean_thickness

def unpack(args):
    """Call ``img_thickness`` with a packed ``(img, threshold, scale)`` tuple.

    A module-level single-argument function, so it can be used with ``map`` and ``Pool.imap``.
    """
    thickness = img_thickness(*args)
    return thickness

def get_thickness(x, threshold=0.5, scale=4, pool=None, chunksize=100):
    """Measure the mean stroke thickness of every image in a batch.

    Only channel 0 of each image is used. Work runs serially when ``pool`` is None,
    otherwise through ``pool.imap`` in chunks of ``chunksize``. A tqdm progress bar
    is shown either way.

    Args:
        x: batched image tensor indexed as ``x[n, 0, h, w]``.
        threshold: passed through to ``ImageMorphology``.
        scale: passed through to ``ImageMorphology``.
        pool: optional ``multiprocessing.Pool``.
        chunksize: chunk size for ``pool.imap``; ignored when ``pool`` is None.

    Returns:
        list with one thickness value per image, in input order.
    """
    imgs = x.detach().cpu().numpy()[:, 0]
    jobs = ((img, threshold, scale) for img in imgs)
    mapper = map(unpack, jobs) if pool is None else pool.imap(unpack, jobs, chunksize=chunksize)
    return list(tqdm(mapper, total=len(imgs), unit='img', ascii=True))

def fm_preprocess(pa: Dict[str, Tensor], input_res: int = 28) -> Dict[str, Tensor]:
    """Append a trailing singleton dimension to every parent tensor.

    A new dict is returned; the caller's dict and its tensors are left unchanged.

    Args:
        pa: mapping from parent name to tensor.
        input_res: currently unused; kept so existing callers that pass it keep working.

    Returns:
        dict with the same keys, each value replaced by ``v.unsqueeze(-1)``.
    """
    return {k: v.unsqueeze(-1) for k, v in pa.items()}

    
@torch.no_grad()
def cf_epoch(
    vae: nn.Module, 
    pgm: nn.Module, 
    predictor: nn.Module, 
    dataloaders: Dict[str, DataLoader],
    do_pa: Optional[str] = None, 
    te_cf: bool = False
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Tensor]:
    """Generate counterfactuals for every test batch and predict their parents.

    NOTE(review): this definition shadows the ``cf_epoch`` imported from ``pgm.train_cf``
    at the top of the notebook; calls made after this cell resolve to this version.

    For each batch of ``dataloaders['test']``:
      1. sample interventions by randomly permuting parent values from the training set:
         only ``do_pa`` if given; otherwise each PGM variable is included on an
         independent coin flip, repeated until at least one variable is chosen;
      2. infer counterfactual parents with ``pgm.counterfactual``;
      3. abduct the latent ``z`` with ``vae.abduct`` and the pixel noise
         ``u = (x - rec_loc) / rec_scale``;
      4. decode with the counterfactual parents (with a counterfactual latent when
         ``te_cf`` and ``vae.cond_prior`` are both true) and clamp to [-1, 1];
      5. run ``predictor.predict`` on the counterfactual image and parents.

    Args:
        vae: model exposing ``abduct``, ``forward_latents`` and ``cond_prior``.
        pgm: causal model exposing ``variables`` (iterated for the variable names) and ``counterfactual``.
        predictor: model whose ``predict(**cfs)`` returns a dict keyed by variable name.
        dataloaders: needs 'train' (its ``dataset.samples`` supplies intervention values —
            presumably a dict of tensors keyed by variable name, TODO confirm) and 'test'.
        do_pa: the single parent to intervene on; None means random interventions.
        te_cf: request total-effect counterfactuals; ignored unless ``vae.cond_prior``.

    Returns:
        (targets, preds, x_counterfactuals). ``targets`` and ``preds`` are dicts keyed by
        PGM variable, holding stacked, squeezed CPU tensors; a target is the intervened
        value when that variable was intervened on, else the counterfactual parent.
        ``x_counterfactuals`` is the stacked counterfactual images on CPU.
    """
    vae.eval()
    pgm.eval()
    predictor.eval()
    dag_vars = list(pgm.variables.keys())
    preds = {k: [] for k in dag_vars}
    targets = {k: [] for k in dag_vars}
    x_counterfactuals = []
    # Deep copy so shuffling/cloning intervention values never touches the dataset itself.
    train_set = copy.deepcopy(dataloaders['train'].dataset.samples)
    loader = tqdm(enumerate(dataloaders['test']), total=len(
        dataloaders['test']), mininterval=0.1)

    for _, batch in loader:
        bs = batch['x'].shape[0]
        # NOTE(review): `preprocess` is not defined or imported in this notebook — import it
        # (presumably from the PGM training code) before running.
        batch = preprocess(batch)
        pa = {k: v for k, v in batch.items() if k != 'x'}
        # randomly intervene on a single parent do(pa_k), pa_k ~ p(pa_k)
        do = {}
        if do_pa is not None:
            # random permutation then take the first bs -> sample training values without replacement
            idx = torch.randperm(train_set[do_pa].shape[0])
            do[do_pa] = train_set[do_pa].clone()[idx][:bs]
        else: # random interventions
            while not do:
                for k in dag_vars:
                    if torch.rand(1) > 0.5:  # coin flip to intervene on pa_k
                        idx = torch.randperm(train_set[k].shape[0])
                        do[k] = train_set[k].clone()[idx][:bs]
        do = preprocess(do)
        # infer counterfactual parents
        cf_pa = pgm.counterfactual(obs=pa, intervention=do, num_particles=1)
        # NOTE(review): `vae_preprocess` is not defined in this notebook; `fm_preprocess`
        # may be the intended helper — confirm. Inputs are cloned so the helper cannot mutate pa/cf_pa.
        _pa = vae_preprocess({k: v.clone() for k, v in pa.items()})
        _cf_pa = vae_preprocess({k: v.clone() for k, v in cf_pa.items()})
        # abduct exogenous noise z
        t_z = t_u = 0.1  # sampling temp: t_z for abducting z, t_u scales the decoder scale below
        z = vae.abduct(batch['x'], parents=_pa, t=t_z)
        if vae.cond_prior:
            z = [z[i]['z'] for i in range(len(z))]
        # forward vae with observed parents
        rec_loc, rec_scale = vae.forward_latents(z, parents=_pa)
        # abduct exogenous noise u (scale clamped to avoid division by zero)
        u = (batch['x'] - rec_loc) / rec_scale.clamp(min=1e-12)
        if vae.cond_prior and te_cf:  # g(z*, pa*)
            # infer counterfactual mediator z*
            # NOTE(review): alpha=0.65 is an undocumented constant; its meaning depends on vae.abduct.
            cf_z = vae.abduct(x=batch['x'], parents=_pa, cf_parents=_cf_pa, alpha=0.65)
            cf_loc, cf_scale = vae.forward_latents(cf_z, parents=_cf_pa)
        else:  # g(z, pa*)
            cf_loc, cf_scale = vae.forward_latents(z, parents=_cf_pa)
        cf_scale = cf_scale * t_u
        cfs = {'x':  torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)}
        cfs.update(cf_pa)
        x_counterfactuals.extend(cfs['x'])
        # predict labels of inferred counterfactuals
        preds_cf = predictor.predict(**cfs)
        for k, v in preds_cf.items():
            preds[k].extend(v)
        # targets are the interventions and/or counterfactual parents
        for k in targets.keys():
            t_k = do[k].clone() if k in do.keys() else cfs[k].clone()
            targets[k].extend(t_k)
    for k, v in targets.items():
        targets[k] = torch.stack(v).squeeze().cpu()
        preds[k] = torch.stack(preds[k]).squeeze().cpu()
    x_counterfactuals = torch.stack(x_counterfactuals).cpu()
    return targets, preds, x_counterfactuals


def eval_cf_loop(
    vae: nn.Module,
    pgm: nn.Module,
    predictor: nn.Module,
    dataloaders: Dict[str, DataLoader],
    file: IO[str],
    total_effect: bool = False,
    seeds: List[int] = [0, 1, 2],
):
    """Score counterfactuals for each intervention type, repeated over several seeds.

    For do(thickness), do(intensity), do(digit) and random interventions, ``cf_epoch``
    runs once per seed. The counterfactuals are then scored with:
      - the ``predictor``'s outputs: digit accuracy, and thickness/intensity MAE;
      - the Morpho-MNIST measurement functions (``get_intensity``, ``get_thickness``):
        thickness/intensity MAE of the generated images.
    One line per seed and a mean/std summary per intervention type are appended to ``file``.

    Args:
        vae: generative model passed to ``cf_epoch``; must have ``cond_prior`` set when
            ``total_effect`` is True.
        pgm: causal model passed to ``cf_epoch``.
        predictor: parent predictor passed to ``cf_epoch``.
        dataloaders: must contain 'train' (whose ``dataset.min_max`` un-normalises
            thickness/intensity) and 'test'.
        file: text stream the results are written to; flushed after every write.
        total_effect: evaluate total-effect counterfactuals (requires ``vae.cond_prior``).
        seeds: torch and numpy global RNGs are reseeded with each value before its run.

    Raises:
        AssertionError: if ``total_effect`` is requested but ``vae.cond_prior`` is falsy.
    """
    # Total-effect counterfactuals need a conditional prior; fail before doing any work.
    assert vae.cond_prior if total_effect else True

    def _mean_std(values):
        # Shared formatting for the per-intervention summary lines.
        arr = np.array(values)
        return f'mean: {arr.mean()} - std: {arr.std()}'

    for do_pa in ['thickness', 'intensity', 'digit', None]:  # "None" is for random interventions
        do_name = do_pa if do_pa is not None else "random"
        acc_runs = []
        mae_runs = {
            'thickness': {'predicted': [], 'measured': []},
            'intensity': {'predicted': [], 'measured': []}
        }

        for seed in seeds:
            print(f'do({do_name}), seed {seed}:')
            # Reseed the global RNGs that cf_epoch draws from (torch.randperm / torch.rand) so
            # every run is reproducible and the seed recorded in `file` actually means something.
            torch.manual_seed(seed)
            np.random.seed(seed)
            targets, preds, x_cfs = cf_epoch(vae, pgm, predictor, dataloaders, do_pa, total_effect)
            acc = (targets['digit'].argmax(-1).numpy() == preds['digit'].argmax(-1).numpy()).mean()
            print('predicted digit acc:', acc)
            # evaluate inferred cfs using true causal mechanisms; images are mapped [-1, 1] -> [0, 255]
            x_cfs_255 = (x_cfs + 1.0) * 127.5
            measured = {'intensity': torch.tensor(get_intensity(x_cfs_255))}
            with multiprocessing.Pool() as pool:
                measured['thickness'] = torch.tensor(get_thickness(x_cfs_255, pool=pool, chunksize=250))

            mae = {'thickness': {}, 'intensity': {}}
            for k in ['thickness', 'intensity']:
                # map values from [-1, 1] back to the training-set range [_min, _max]
                min_max = dataloaders['train'].dataset.min_max[k]
                _min, _max = min_max[0], min_max[1]
                preds_k = ((preds[k] + 1) / 2) * (_max - _min) + _min
                targets_k = ((targets[k] + 1) / 2) * (_max - _min) + _min
                mae[k]['predicted'] = (targets_k - preds_k).abs().mean().item()
                mae[k]['measured'] = (targets_k - measured[k]).abs().mean().item()
                print(f'predicted {k} mae:', mae[k]['predicted'])
                print(f'measured {k} mae:', mae[k]['measured'])
                mae_runs[k]['predicted'].append(mae[k]['predicted'])
                mae_runs[k]['measured'].append(mae[k]['measured'])
            acc_runs.append(acc)

            file.write(
                f'\ndo({do_name}) | digit acc: {acc}, ' +
                f'thickness mae (predicted): {mae["thickness"]["predicted"]}, ' +
                f'thickness mae (measured): {mae["thickness"]["measured"]}, ' +
                f'intensity mae (predicted): {mae["intensity"]["predicted"]}, ' +
                f'intensity mae (measured): {mae["intensity"]["measured"]} | seed {seed}'
            )
            file.flush()
            gc.collect()

        summary = [f'digit acc | {_mean_std(acc_runs)}']
        for k in ['thickness', 'intensity']:
            for source in ['predicted', 'measured']:
                summary.append(f'{k} mae ({source}) | {_mean_std(mae_runs[k][source])}')
        header = f'Total effect: {total_effect}' if vae.cond_prior else ''
        file.write(f'\n{header}\n' + ''.join(line + '\n' for line in summary))
        file.flush()
    return

# Evaluate each trained model listed below; results are appended to ./eval_<model_name>.txt.
# NOTE(review): `load_vae` is not defined in this notebook (only `load_fm` is) — define or
# import it before running this cell.
assert pgm_args.dataset == 'morphomnist'
pgm_args.data_dir = 'your dataset dir here'
pgm_args.bs = 32
# The dataloaders depend only on pgm_args, so build them once instead of once per model.
dataloaders = setup_dataloaders(pgm_args)

for model_name in [
    'add your model name(s) here'
]:
    vae_path = '../checkpoints/' + model_name + '/checkpoint.pt'
    vae, vae_args = load_vae(vae_path)
    # `with` guarantees the results file is closed even if evaluation raises.
    with open(f'./eval_{model_name}.txt', 'a') as file:
        eval_cf_loop(vae, pgm, predictor, dataloaders, file)