# Evaluate Trichotomy

## Overview 

Notebook that provides code to evaluate trained diffusion models. Needs image generation model, privacy model, cxr-classification model. 

In [43]:
%run ../basesetup.ipynb 
# load basic functions such as generative model initializer, classification model, privacy mdoel

In [44]:
n_per_index = 4 # == batch_size and factor that sampling takes longer

In [45]:
from pprint import pprint
kwargs = {
    "DiADM":{
        "autoguidance":True,
        "guidance":1.4,
        "model_kwargs":{
            "model_weights":"/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_pseudocond/training-state-0050331.pt",
            "gmodel_weights":"/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/training-state-0008388.pt",
            "path_net":"/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_pseudocond/network-snapshot-0050331-0.100.pkl",
            "path_gnet":"/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/network-snapshot-0008388-0.050.pkl",
        },
        "ds_kwargs":{
            "cond_mode":"pseudocond", # pseudocond, cond
            "basedir":"/vol/idea_ramses/ed52egek/data/trichotomy",
            "basedir_images":"/vol/ideadata/ed52egek/data/chestxray14"
        }
    }
}


print(class_labels)

['No Finding', 'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion', 'Pneumonia', 'Pneumothorax']


## Sample and evaluate samples at the same time 

In [46]:
from einops import repeat

filelist = "/vol/ideadata/ed52egek/pycharm/trichotomy/datasets/eight_cxr8_train.txt" 
target_dir = "diadm_train_with_dse"
mode = "DiADM"
model_kwargs = kwargs[mode]["model_kwargs"]
ds_kwargs = kwargs[mode]["ds_kwargs"]
model_kwargs["name"] = mode
print("="*80)
print("Model kwargs:")
pprint(model_kwargs)
print("Dataset kwargs:")
pprint(ds_kwargs)
print("="*80)
net, gnet, encoder =  get_image_generation_model(**model_kwargs)

outdir = f"./{target_dir}/"
print(f"Saving images to {outdir}")

# indices are the indices of the dataset with certain calss 
sampler_kwargs = {"autoguidance":kwargs[mode]["autoguidance"], 
                "guidance":kwargs[mode]["guidance"], }

print("Sampler kwargs")
pprint(sampler_kwargs)

train_ds = LatentDataset(filelist_txt=filelist, basedir=ds_kwargs["basedir"], cond_mode=ds_kwargs["cond_mode"], load_to_memory=False)



Model kwargs:
{'gmodel_weights': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/training-state-0008388.pt',
 'model_weights': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_pseudocond/training-state-0050331.pt',
 'name': 'DiADM',
 'path_gnet': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/network-snapshot-0008388-0.050.pkl',
 'path_net': '/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_pseudocond/network-snapshot-0050331-0.100.pkl'}
Dataset kwargs:
{'basedir': '/vol/idea_ramses/ed52egek/data/trichotomy',
 'basedir_images': '/vol/ideadata/ed52egek/data/chestxray14',
 'cond_mode': 'pseudocond'}
Loading network from /vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_pseudocond/network-snapshot-0050331-0.100.pkl ...


  net.load_state_dict(torch.load(model_weights)["net"])


Encoder was initilized with {'vae_name': 'stabilityai/stable-diffusion-2', 'encoder_norm_mode': 'cxr8'}
Loading guidance network from /vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/cxr8_diffusionmodels/baseline-runs/cxr8_uncond/network-snapshot-0008388-0.050.pkl ...


  gnet.load_state_dict(torch.load(gmodel_weights)["net"])


Setting up StabilityVAEEncoder...
Saving images to ./diadm_train_with_dse/
Sampler kwargs
{'autoguidance': True, 'guidance': 1.4}


In [47]:
# prepare indices for sampling
indices = torch.arange(len(dataset))
indices = repeat(indices, "l -> b l", b=n_per_index)
indices = indices.transpose(0, 1).flatten()


In [48]:
class ImageIterable:
    def __init__(self, 
                 train_ds, 
                 device, 
                 net, 
                 sampler_fn, 
                 gnet, 
                 encoder, 
                 outdir=None, 
                 verbose=False, 
                 sampler_kwargs={},
                 indices=[],
                 max_batch_size=32, 
                 add_seed_to_path=True, 
                 dse=None):
        self.train_ds = train_ds
        self.device = device
        self.net = net
        self.sampler_fn = sampler_fn
        self.gnet = gnet
        self.encoder = encoder
        self.outdir = outdir
        self.verbose = verbose
        self.max_batch_size = max_batch_size
        self.sampler_kwargs = sampler_kwargs
        self.guidance_strength = self.sampler_kwargs["guidance"]

        # Prepare seeds and batches
        self.num_batches = max((len(indices) - 1) // max_batch_size + 1, 1)
        self.rank_batches = np.array_split( np.arange(len(indices)), self.num_batches)
        self.indices = np.array_split(np.array(indices), self.num_batches)
        self.add_seed_to_path = add_seed_to_path

        self.dse = dse

        if verbose:
            print(f'Generating {len(self.seeds)} images...')

    def __len__(self):
        return len(self.rank_batches)

    def __iter__(self):

        for batch_idx in range(len(self.rank_batches)):
            # one batch only consists of one single image!

            image_generated = False
            guidance = self.guidance_strength
            indices = self.indices[batch_idx]
            r = dnnlib.EasyDict(images=None, labels=None, noise=None, 
                                batch_idx=batch_idx, num_batches=len(self.rank_batches), 
                                indices=indices, paths=None)
            r.seeds =  self.rank_batches[batch_idx] 



            while not image_generated: 
                if len(r.seeds) > 0:
                    while not image_generated:
                        # Generate noise and labels
                        rnd = StackedRandomGenerator(self.device, r.seeds)
                        r.noise = rnd.randn([len(r.seeds), self.net.img_channels, self.net.img_resolution, self.net.img_resolution], device=self.device)
                        r.labels = torch.stack([self.train_ds.get_label(x) for x in r.indices]).to(self.device)
                        r.paths = [self.train_ds.file_list[x] for x in r.indices]

                        # Generate images
                        latents = dnnlib.util.call_func_by_name(func_name=self.sampler_fn, net=self.net, noise=r.noise,
                                                                labels=r.labels, gnet=self.gnet, randn_like=rnd.randn_like, **self.sampler_kwargs)
                        r.images = self.encoder.decode(latents)
                        r.images = r.images.float() / 255.

                        clf_pred_scores, priv_pred = self.dse.predict(r.images)
                        if priv_pred.min() < 1: 
                            image_generated = True
                            idx = clf_pred_scores.argmin()
                            image = (r.images[idx] * 255).to(torch.uint8)
                            image = image.permute(1, 2, 0).cpu().numpy()
                            path_real = r.paths[idx]
                            image_pth = os.path.join(self.outdir, path_real)

                            os.makedirs(os.path.dirname(image_pth), exist_ok=True)
                            PIL.Image.fromarray(image, 'RGB').save(image_pth)

                        if not image_generated:
                            r.seeds = r.seeds + self.max_batch_size
                            print(f"only memorized for guidance: {guidance} and path: {r.paths[0]}")
                            guidance = guidance - 0.1
                
            # Yield results
            yield r

device = torch.device("cuda")
dse = DiADMSampleEvaluator(device)

image_iter = ImageIterable(train_ds=train_ds, 
                           indices=indices, 
                           device=device, 
                           net=net, 
                           sampler_fn=edm_sampler, 
                           gnet=gnet, 
                           encoder=encoder,
                           outdir=outdir, 
                           max_batch_size=n_per_index,
                           dse=dse,
                           sampler_kwargs=sampler_kwargs)

for r in tqdm.tqdm(image_iter, unit='batch', total=len(image_iter), desc=f"Generating images"):
    pass
    


  net.load_state_dict(torch.load(path)["state_dict"])
  modelCheckpoint = torch.load(model_path)
  self.indices = np.array_split(np.array(indices), self.num_batches)
Generating images:   0%|          | 9/67309 [00:42<87:02:43,  4.66s/batch]

only memorized for guidance: 1.4 and path: images/00000005_007.png


Generating images:   0%|          | 11/67309 [00:57<109:13:25,  5.84s/batch]

only memorized for guidance: 1.4 and path: images/00000008_000.png
only memorized for guidance: 1.2999999999999998 and path: images/00000008_000.png


Generating images:   0%|          | 12/67309 [01:10<153:45:10,  8.22s/batch]

only memorized for guidance: 1.4 and path: images/00000008_001.png


Generating images:   0%|          | 37/67309 [03:22<102:14:35,  5.47s/batch]


KeyboardInterrupt: 

In [None]:
r.images

NameError: name 'r' is not defined