# Evaluate Trichotomy

## Overview 

This notebook evaluates trained diffusion models. It requires an image-generation model, a privacy model, and a CXR classification model.

In [16]:
# Load basic helpers from the shared setup notebook, such as the generative-model
# initializer, the classification model and the privacy model.
# NOTE(review): names used below but not defined in this notebook (e.g. class_labels,
# get_ds_and_indices, ImageIterable, edm_sampler, tqdm, torch, Image) are presumably
# provided by this %run as well — verify when reorganising the setup notebook.
%run ../basesetup.ipynb 

In [21]:
from pprint import pprint
# Root directories shared by every sampling configuration below. Edit these when the
# checkpoints or datasets move instead of touching each configuration individually.
RUNS_DIR = "/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs"
TRICHOTOMY_DIR = "/vol/idea_ramses/ed52egek/data/trichotomy"
CXR14_IMAGE_DIR = "/vol/ideadata/ed52egek/data/chestxray14"


def _make_config(net_run, net_step, gnet_run, gnet_step, *, guidance, cond_mode,
                 autoguidance=True, net_suffix="0.100", gnet_suffix="0.050"):
    """Build one sampling configuration (model paths, guidance and dataset settings).

    Args:
        net_run: run directory (under RUNS_DIR) of the main network.
        net_step: zero-padded training step of the main network, as used in filenames.
        gnet_run: run directory (under RUNS_DIR) of the guiding network.
        gnet_step: zero-padded training step of the guiding network.
        guidance: guidance strength passed to the sampler.
        cond_mode: dataset conditioning mode, "cond" or "pseudocond".
        autoguidance: whether the guiding network is used as autoguidance.
        net_suffix / gnet_suffix: trailing value of the network-snapshot filename
            (presumably the EMA setting — verify against the training runs).

    Returns:
        A fresh nested dict on every call, so later in-place edits of one
        configuration (e.g. adding "name"/"device") never leak into another.
    """
    return {
        "autoguidance": autoguidance,
        "guidance": guidance,
        "model_kwargs": {
            "model_weights": f"{RUNS_DIR}/{net_run}/training-state-{net_step}.pt",
            "gmodel_weights": f"{RUNS_DIR}/{gnet_run}/training-state-{gnet_step}.pt",
            "path_net": f"{RUNS_DIR}/{net_run}/network-snapshot-{net_step}-{net_suffix}.pkl",
            "path_gnet": f"{RUNS_DIR}/{gnet_run}/network-snapshot-{gnet_step}-{gnet_suffix}.pkl",
        },
        "ds_kwargs": {
            "cond_mode": cond_mode,  # pseudocond, cond
            "basedir": TRICHOTOMY_DIR,
            "basedir_images": CXR14_IMAGE_DIR,
        },
    }


# Sampling configurations keyed by the model name used for output folders.
kwargs = {
    # Conditional net guided by an earlier snapshot of the same conditional run.
    "EDM-2-AG": _make_config("cxr8_cond", "0083886", "cxr8_cond", "0008388",
                             guidance=1.2, cond_mode="cond"),
    # Conditional net guided by the unconditional run.
    "EDM-2": _make_config("cxr8_cond", "0050331", "cxr8_uncond", "0008388",
                          guidance=1.4, cond_mode="cond"),
    # Pseudo-conditional net guided by the unconditional run.
    "DiADM": _make_config("cxr8_pseudocond", "0050331", "cxr8_uncond", "0008388",
                          guidance=1.4, cond_mode="pseudocond"),
}


# Show the label vocabulary as the cell's rich output (bare last expression instead
# of print). `class_labels` is presumably defined by ../basesetup.ipynb.
class_labels

['No Finding', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion', 'Pneumonia', 'Pneumothorax']


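## Start Sampling

Generate synthetic images for each selected model configuration and class, and collect them together with the corresponding real images.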

In [22]:
from src.diffusion.generation import get_image_generation_model

import os

from torchvision.transforms import Resize, ToTensor

# --- Sampling configuration -------------------------------------------------
filelist = "/vol/ideadata/ed52egek/pycharm/trichotomy/datasets/eight_cxr8_train.txt"
# Synthetic samples per dataset index; values other than 1 only make sense for
# the "pseudocond" mode.
n_per_index = 2  # if ds_kwargs["cond_mode"] == "pseudocond" else 1

N = 5  # dataset indices per class; 95 = size of the smallest class (Edema)
class_labels_to_sample = ["No Finding"]  # names taken from `class_labels`
modes_to_sample = ["EDM-2-AG"]  # keys of `kwargs`, e.g. also "EDM-2", "DiADM"
# Keep at most this many images from each generated batch (None = keep all);
# set to 1 for a quick smoke test.
max_images_per_batch = None


def path_to_img(path, basedir_images=None):
    """Load the real image at `path` as an RGB tensor (Resize(512) then ToTensor).

    `path` is joined onto `basedir_images`, which defaults to the current global
    `ds_kwargs["basedir_images"]` so existing one-argument calls keep working.
    """
    if basedir_images is None:
        basedir_images = ds_kwargs["basedir_images"]
    # `with` closes the underlying file once the pixels have been read.
    with Image.open(os.path.join(basedir_images, path)) as img:
        return ToTensor()(Resize(512)(img.convert("RGB")))


# Results per mode; `data` keeps pointing at the most recently sampled mode so the
# `data` cell below still works.
all_data = {}

for mode in modes_to_sample:
    print(f"Generating images for {mode}")

    # Copy so re-running this cell never mutates the shared `kwargs` config.
    model_kwargs = {**kwargs[mode]["model_kwargs"], "name": mode, "device": "cuda"}
    ds_kwargs = kwargs[mode]["ds_kwargs"]
    print("=" * 80)
    print("Model kwargs:")
    pprint(model_kwargs)
    print("Dataset kwargs:")
    pprint(ds_kwargs)
    print("=" * 80)

    net, gnet, encoder = get_image_generation_model(**model_kwargs)

    # Guidance settings depend only on the mode, not on the class being sampled.
    sampler_kwargs = {
        "autoguidance": kwargs[mode]["autoguidance"],
        "guidance": kwargs[mode]["guidance"],
    }
    print("Sampler kwargs")
    pprint(sampler_kwargs)

    data = {}
    all_data[mode] = data
    for class_label in class_labels_to_sample:
        # Position in the full label vocabulary, not in `class_labels_to_sample`:
        # enumerate() over the subset would give e.g. "Edema" index 0 ("No Finding").
        # NOTE(review): assumes get_ds_and_indices expects this global index — verify.
        class_idx = class_labels.index(class_label)
        outdir = f"./gen_1p0/{model_kwargs['name']}/{class_label.replace(' ', '_')}"
        print(f"Saving images to {outdir}")

        # `indices` are the dataset indices belonging to `class_label`.
        dataset, indices = get_ds_and_indices(
            filelist=filelist, class_idx=class_idx, N=N, n_per_index=n_per_index, **ds_kwargs
        )
        print(f"Dataset indices for {class_label}: {indices}")

        image_iter = ImageIterable(
            train_ds=dataset,
            indices=indices,
            device=torch.device(model_kwargs["device"]),
            net=net,
            sampler_fn=edm_sampler,
            gnet=gnet,
            encoder=encoder,
            outdir=outdir,
            sampler_kwargs=sampler_kwargs,
        )

        records = {"real_path": [], "real_img": [], "snth_img": [], "label": []}
        data[class_label] = records
        for r in tqdm.tqdm(image_iter, unit="batch", total=len(image_iter),
                           desc=f"Generating {class_label} images"):
            n_keep = len(r.images)
            if max_images_per_batch is not None:
                n_keep = min(n_keep, max_images_per_batch)
            for i in range(n_keep):
                records["real_path"].append(r.paths[i])
                records["real_img"].append(path_to_img(r.paths[i], ds_kwargs["basedir_images"]))
                # Assumes generated images are on a 0-255 scale — confirm against ImageIterable.
                records["snth_img"].append(r.images[i] / 255.)
                records["label"].append(r.labels[i])

Generating images for EDM-2
Model kwargs:
{'device': 'cuda',
 'gmodel_weights': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/training-state-0008388.pt',
 'model_weights': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_cond/training-state-0050331.pt',
 'name': 'EDM-2',
 'path_gnet': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/network-snapshot-0008388-0.050.pkl',
 'path_net': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_cond/network-snapshot-0050331-0.100.pkl'}
Dataset kwargs:
{'basedir': '/vol/idea_ramses/ed52egek/data/trichotomy',
 'basedir_images': '/vol/ideadata/ed52egek/data/chestxray14',
 'cond_mode': 'cond'}
Loading network from /vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_cond/network-snapshot-0050331-0

Generating No Finding images: 100%|██████████| 1/1 [00:06<00:00,  6.31s/batch]

[2, 2, 3, 3, 10, 10, 12, 12, 21, 21]





In [None]:
# Summarise instead of dumping every image tensor: number of collected items per
# class and field (all fields should have equal length).
{class_label: {field: len(values) for field, values in records.items()}
 for class_label, records in data.items()}