In [None]:
import sys
import os
# NOTE: os.path.abspath(".") is the working directory itself, so `current_dir` ends up
# being the working directory's parent and `parent_dir` its grandparent.
current_dir = os.path.dirname(os.path.abspath("."))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
# Make modules under `parent_dir` importable without adding a duplicate entry.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
import torch 
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm
import pickle
import seaborn as sns
from PIL import Image
from torchvision import transforms


def load_latents(basepath, model_name, n_splits=1):
    """Load the latent features and the path table stored for ``model_name``.

    The directory ``<basepath>/<model_name>`` must contain exactly ``n_splits``
    ``.pt`` files and exactly ``n_splits`` ``.csv`` files. Only the first split
    (in sorted file-name order) is loaded.

    Args:
        basepath: Directory that holds one sub-directory per model.
        model_name: Name of the sub-directory to read.
        n_splits: Expected number of ``.pt`` / ``.csv`` files.

    Returns:
        Tuple ``(features, paths)``: the object returned by ``torch.load`` for
        the first ``.pt`` file and a DataFrame read from the first ``.csv`` file.

    Raises:
        AssertionError: If the number of ``.pt`` or ``.csv`` files differs from
            ``n_splits``.
    """
    path = os.path.join(basepath, model_name)
    # os.listdir returns entries in arbitrary order; sort so that the split picked
    # below is the same on every run and on every filesystem.
    files = sorted(os.listdir(path))
    pts = [x for x in files if x.endswith(".pt")]
    csv = [x for x in files if x.endswith(".csv")]
    assert len(pts) == n_splits and len(csv) == n_splits, f"Unexpected number of csv/pts found in {path}"

    feature_path = os.path.join(path, pts[0])
    print(f"Loading features in {feature_path}")
    features = torch.load(feature_path)
    paths = pd.read_csv(os.path.join(path, csv[0]))
    return features, paths


def update_matplotlib_font(fontsize=11, fontsize_ticks=8, tex=True, scale=1):
    """Set matplotlib's global font configuration for document-ready figures.

    Args:
        fontsize: Base size for axis labels and general text, before scaling.
        fontsize_ticks: Base size for legend and tick labels, before scaling.
        tex: Value written to ``text.usetex`` (render text with LaTeX).
        scale: Multiplier applied to both font sizes.
    """
    import matplotlib.pyplot as plt

    label_size = scale * fontsize
    tick_size = scale * fontsize_ticks

    # Serif text, optionally rendered through LaTeX.
    rc_overrides = {"text.usetex": tex, "font.family": "serif"}
    # Labels and body text share the main size ...
    for key in ("axes.labelsize", "font.size"):
        rc_overrides[key] = label_size
    # ... while legend and tick labels use the smaller one.
    for key in ("legend.fontsize", "xtick.labelsize", "ytick.labelsize"):
        rc_overrides[key] = tick_size

    plt.rcParams.update(rc_overrides)

# Maps internal dataset keys to human-readable dataset names
# (presumably used for plot labels - confirm against the plotting cells).
ds_to_viz = {
    "mimic": "MIMIC-CXR",
    "chexpert": "CheXpert",
    "cxr8": "ChestX-ray8"
}


# Apply the default font configuration; the defaults enable LaTeX text rendering (tex=True).
update_matplotlib_font()


In [None]:
from pathlib import Path

# Add the edm2 directory, which contains generate_images.py, to the Python path
script_dir = Path("/vol/ideadata/ed52egek/pycharm/trichotomy/edm2").resolve()  # Replace with the actual path
sys.path.append(str(script_dir))

# Now functions from generate_images.py can be imported
from generate_images import edm_sampler, StackedRandomGenerator
from training.dataset import LatentDataset

import pickle
import dnnlib
import torch
import os
import tqdm
import numpy as np
import PIL.Image
import sys
from torchvision.transforms import functional as F
import torch

class ToTensorIfNotTensor:
    """Transform that converts its input with ``F.to_tensor`` unless it already is a tensor."""

    def __call__(self, input):
        # Tensors pass through untouched; anything else is handed to F.to_tensor.
        already_tensor = isinstance(input, torch.Tensor)
        return input if already_tensor else F.to_tensor(input)


def get_classification_model(model_path): 
    """Build the DenseNet-121 classifier, load its checkpoint and wrap it with preprocessing.

    The number of outputs is taken from the module-level ``class_labels`` list,
    which therefore has to be defined before this function is called.

    Args:
        model_path: Path to a checkpoint readable by ``torch.load`` that stores
            the weights under the ``'state_dict'`` key. Keys may carry a
            ``"module."`` prefix, which is stripped.

    Returns:
        A ``Classifier`` module wrapping the CUDA-resident network. It exposes
        ``forward`` and ``lazy_foward`` (also available as ``lazy_forward``).
    """
    import torch
    import torch.nn as nn
    import torch.backends.cudnn as cudnn
    # Inside this function `F` is torch.nn.functional (needed for relu / pooling below).
    import torch.nn.functional as F
    import torchvision
    import torchvision.transforms as T

    class DenseNet121(nn.Module):
        """torchvision DenseNet-121 whose classifier is a Linear + Sigmoid head."""

        def __init__(self, classCount, isTrained):
            super(DenseNet121, self).__init__()

            self.densenet121 = torchvision.models.densenet121(pretrained=isTrained)

            kernelCount = self.densenet121.classifier.in_features

            self.densenet121.classifier = nn.Sequential(nn.Linear(kernelCount, classCount), nn.Sigmoid())

        def forward(self, x):
            x = self.densenet121(x)
            return x

    cudnn.benchmark = True

    #-------------------- SETTINGS: NETWORK ARCHITECTURE, MODEL LOAD
    # The strict load_state_dict below overwrites every parameter and buffer, so
    # fetching ImageNet weights first would only cost a download.
    model = DenseNet121(len(class_labels), False).cuda()

    modelCheckpoint = torch.load(model_path)
    # Strip the "module." prefix only where it is present, so checkpoints saved
    # without that prefix load as well.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in modelCheckpoint['state_dict'].items()
    }
    model.load_state_dict(state_dict)

    class Classifier(nn.Module): 
        """Couples the network with its input transforms."""

        def __init__(self, model, transforms="default") -> None:
            super().__init__()
            if transforms == "default": 
                normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

                transformList = []
                #transformList.append(T.Resize(256)) -- forward pass during inference uses tencrop 
                transformList.append(T.Resize(226))
                transformList.append(T.CenterCrop(226))
                transformList.append(ToTensorIfNotTensor())
                transformList.append(normalize)
                self.transforms=T.Compose(transformList)
            else: 
                self.transforms = transforms
            self.model = model

        def forward(self, x): 
            """Apply the transforms and run the wrapped model."""
            x_in = self.transforms(x)
            return self.model(x_in)

        def lazy_foward(self, x): 
            """Gradient-free inference on CUDA.

            Returns:
                Tuple ``(out, hidden_features)``: the classifier head output and
                the pooled, flattened feature vector fed into that head.
            """
            # accepts tensor, 0-1, bchw 
            self.model.eval()
            self.model.to("cuda")

            with torch.no_grad():
                x_in = self.transforms(x)
                # A single image (3 dims) is promoted to a batch of one.
                if x_in.dim() == 3: 
                    x_in = x_in.unsqueeze(dim=0)

                varInput = x_in.cuda()

                # Run the feature extractor, pooling and head step by step so the
                # intermediate feature vector can be returned too.
                features = self.model.densenet121.features(varInput)
                out = F.relu(features, inplace=True)
                out = F.adaptive_avg_pool2d(out, (1, 1))
                hidden_features = torch.flatten(out, 1)
                out = self.model.densenet121.classifier(hidden_features)
            return out.data, hidden_features.data

        # Correctly spelled alias; the original name is kept for existing callers.
        lazy_forward = lazy_foward

    return Classifier(model)


def get_privacy_model(path="/vol/ideadata/ed52egek/pycharm/trichotomy/privacy/archive/Siamese_ResNet50_allcxr/Siamese_ResNet50_allcxr_checkpoint.pth"): 
    """Instantiate the Siamese privacy network and load its checkpoint.

    Args:
        path: Checkpoint file readable by ``torch.load`` that stores the weights
            under the ``"state_dict"`` key.

    Returns:
        The ``SiameseNetwork`` instance with the checkpoint weights loaded.
        It is not moved to any device here.
    """
    import torch
    import sys
    from pathlib import Path

    # Make the privacy project (which provides networks.SiameseNetwork) importable.
    # Guard against appending the same entry again on every call.
    script_dir = str(Path("/vol/ideadata/ed52egek/pycharm/trichotomy/privacy").resolve())  # Replace with the actual path
    if script_dir not in sys.path:
        sys.path.append(script_dir)

    from networks.SiameseNetwork import SiameseNetwork

    net = SiameseNetwork()
    # Load onto the CPU so the checkpoint does not depend on the device it was
    # saved from; load_state_dict copies the values into the network's parameters.
    checkpoint = torch.load(path, map_location="cpu")
    net.load_state_dict(checkpoint["state_dict"])

    return net


def get_image_generation_model(path_net, path_gnet, model_weights, gmodel_weights, name, device=None): 
    """Load the main network, the guidance network and the latent encoder.

    Args:
        path_net: Path/URL of the main network pickle. The encoder configuration
            is read from this pickle, so a string is required in practice.
        path_gnet: Path/URL of the guidance network pickle, or an already
            constructed guidance network.
        model_weights: Checkpoint whose ``"net"`` entry is loaded into the main network.
        gmodel_weights: Checkpoint whose ``"net"`` entry is loaded into the guidance network.
        name: Currently unused.
        device: Target device; defaults to ``"cuda"``.

    Returns:
        Tuple ``(net, gnet, encoder)``.

    Raises:
        AssertionError: If a network is ``None`` or no encoder could be built,
            for example because ``path_net`` was not a pickle path.
    """
    if device is None: 
        device = "cuda"

    encoder_batch_size = 4
    net = path_net
    gnet = path_gnet
    # Only set when the main network is loaded from a pickle; initialised here so
    # the assertion further down reports the problem instead of an unbound name.
    encoder = None

    # Load main network.
    if isinstance(net, str):
        print(f'Loading network from {net} ...')
        # NOTE(security): pickle.load can execute arbitrary code - only load trusted pickles.
        with dnnlib.util.open_url(net, verbose=True) as f:
            data = pickle.load(f)
        net = data['ema'].to(device)
        net.load_state_dict(torch.load(model_weights)["net"])

        # Rebuild the encoder from the configuration stored next to the network.
        stored_encoder = data.get('encoder', None)
        if stored_encoder is not None:
            encoder_mode = stored_encoder.init_kwargs.encoder_norm_mode
            encoder = dnnlib.util.construct_class_by_name(class_name='training.encoders.StabilityVAEEncoder', vae_name=stored_encoder.init_kwargs.vae_name, encoder_norm_mode=encoder_mode)
            print(f"Encoder was initilized with {encoder._init_kwargs}")

    assert net is not None

    # Load guidance network.
    if isinstance(gnet, str):
        print(f'Loading guidance network from {gnet} ...')
        # NOTE(security): see the pickle remark above.
        with dnnlib.util.open_url(gnet, verbose=True) as f:
            data = pickle.load(f)
        gnet = data['ema'].to(device)
        gnet.load_state_dict(torch.load(gmodel_weights)["net"])

    assert gnet is not None

    # Initialize encoder.
    assert encoder is not None, "No encoder available: path_net must be a network pickle that stores an encoder"
    print(f'Setting up {type(encoder).__name__}...')
    encoder.init(device)
    if encoder_batch_size is not None and hasattr(encoder, 'batch_size'):
        encoder.batch_size = encoder_batch_size

    return net, gnet, encoder


def get_ds_and_indices(filelist="", basedir="", cond_mode="", class_idx=None, N=100, n_per_index=1, one_per_subject=True, **kwargs): 
    """Create the latent dataset and pick the sample indices to generate from.

    The dataset is scanned from the start until ``N`` samples are selected; each
    selected index is repeated ``n_per_index`` times in the result.

    Args:
        filelist: Forwarded to ``LatentDataset`` as ``filelist_txt``.
        basedir: Forwarded to ``LatentDataset``.
        cond_mode: Forwarded to ``LatentDataset``.
        class_idx: If given, only samples whose ``label_list`` entry equals it are used.
        N: Number of distinct samples to select.
        n_per_index: How often each selected index is repeated.
        one_per_subject: Skip a sample whose subject id (the integer prefix of the
            file name before the first ``_``) equals that of the previously
            selected sample.
        **kwargs: Ignored.

    Returns:
        Tuple ``(train_ds, indices)`` where ``indices`` is a list of
        ``N * n_per_index`` dataset indices.

    Raises:
        IndexError: If the dataset is exhausted before enough samples are found.
    """
    train_ds = LatentDataset(filelist_txt=filelist, basedir=basedir, cond_mode=cond_mode, load_to_memory=False)

    target = N * n_per_index
    indices = []
    last_subject_id = -1
    i = 0
    while len(indices) < target:
        if class_idx is None or train_ds.label_list[i] == class_idx:
            subject_id = int(train_ds.file_list[i].split("/")[-1].split("_")[0])
            if not (subject_id == last_subject_id and one_per_subject):
                last_subject_id = subject_id
                indices.extend([i,] * n_per_index)
        # Always move on to the next sample. Previously the unconditional branch
        # did not advance after accepting one, so with one_per_subject=False the
        # same index was selected over and over.
        i += 1

    return train_ds, indices 


class ImageIterable:
    """Iterable that generates images batch by batch for selected dataset indices.

    ``indices`` is split into batches of at most ``max_batch_size`` entries. The
    position of an entry inside ``indices`` is used as its random seed. Each
    iteration step yields a ``dnnlib.EasyDict`` with the fields ``images``,
    ``labels``, ``noise``, ``batch_idx``, ``num_batches``, ``indices``,
    ``paths`` and ``seeds``.

    Args:
        train_ds: Dataset providing ``get_label(idx)`` and ``file_list``.
        device: Device used for noise and labels.
        net: Main network; ``img_channels`` and ``img_resolution`` define the noise shape.
        sampler_fn: Sampler, invoked through ``dnnlib.util.call_func_by_name``.
        gnet: Guidance network passed to the sampler.
        encoder: Object whose ``decode`` turns sampled latents into images.
        outdir: If set, every generated image is written below this directory as PNG.
        verbose: Print the number of images that will be generated.
        sampler_kwargs: Extra keyword arguments for the sampler.
        indices: Dataset indices to generate images for.
        max_batch_size: Upper bound for the batch size.
        add_seed_to_path: Append ``_seed_<seed>`` to the saved file names.
    """

    def __init__(self, 
                 train_ds, 
                 device, 
                 net, 
                 sampler_fn, 
                 gnet, 
                 encoder, 
                 outdir=None, 
                 verbose=False, 
                 sampler_kwargs=None,
                 indices=None,
                 max_batch_size=32, 
                 add_seed_to_path=True):
        # Avoid shared mutable default arguments.
        if sampler_kwargs is None:
            sampler_kwargs = {}
        if indices is None:
            indices = []

        self.train_ds = train_ds
        self.device = device
        self.net = net
        self.sampler_fn = sampler_fn
        self.gnet = gnet
        self.encoder = encoder
        self.outdir = outdir
        self.verbose = verbose
        self.max_batch_size = max_batch_size
        self.sampler_kwargs = sampler_kwargs

        # Prepare seeds and batches (always at least one, possibly empty, batch).
        self.num_batches = max((len(indices) - 1) // max_batch_size + 1, 1)
        self.rank_batches = np.array_split(np.arange(len(indices)), self.num_batches)
        self.indices = np.array_split(np.array(indices), self.num_batches)
        self.add_seed_to_path = add_seed_to_path

        if verbose:
            # There is no `seeds` attribute on this object; one image is produced per index.
            print(f'Generating {len(indices)} images...')

    def __len__(self):
        """Number of batches."""
        return len(self.rank_batches)

    def __iter__(self):
        for batch_idx in range(len(self.rank_batches)):
            indices = self.indices[batch_idx]
            r = dnnlib.EasyDict(images=None, labels=None, noise=None, 
                                batch_idx=batch_idx, num_batches=len(self.rank_batches), 
                                indices=indices, paths=None)
            r.seeds = self.rank_batches[batch_idx]
            if len(r.seeds) > 0:
                # Generate noise and labels
                rnd = StackedRandomGenerator(self.device, r.seeds)
                r.noise = rnd.randn([len(r.seeds), self.net.img_channels, self.net.img_resolution, self.net.img_resolution], device=self.device)
                r.labels = torch.stack([self.train_ds.get_label(x) for x in r.indices]).to(self.device)
                r.paths = [self.train_ds.file_list[x] for x in r.indices]

                # Generate images
                latents = dnnlib.util.call_func_by_name(func_name=self.sampler_fn, net=self.net, noise=r.noise,
                                                        labels=r.labels, gnet=self.gnet, randn_like=rnd.randn_like, **self.sampler_kwargs)
                r.images = self.encoder.decode(latents)

                # Save images
                if self.outdir is not None:
                    for path, image, seed in zip(r.paths, r.images.permute(0, 2, 3, 1).cpu().numpy(), r.seeds):
                        # Drop only the file extension; joining the dot-separated
                        # pieces would also delete dots elsewhere in the path.
                        file_name = os.path.splitext(path)[0]
                        if self.add_seed_to_path: 
                            file_name += f"_seed_{seed}.png"
                        else: 
                            file_name += ".png"
                        image_pth = os.path.join(self.outdir, file_name)

                        os.makedirs(os.path.dirname(image_pth), exist_ok=True)
                        PIL.Image.fromarray(image, 'RGB').save(image_pth)

            # Yield results (an empty batch yields a record without images)
            yield r

In [None]:
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T

class ImageDatasetReal(Dataset):
    """Dataset of real images described by a list of metadata records."""

    def __init__(self, root_dir, real_files, transform=None):
        """
        Args:
            root_dir (str): Directory that the records' ``'full_path'`` entries are relative to.
            real_files: Indexable collection of records with the keys
                ``'full_path'``, ``'model_name'`` and ``'class_name'``.
            transform (callable, optional): Applied to every loaded image.
        """
        self.image_list = real_files
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        record = self.image_list[idx]
        rel_path = record['full_path']

        # Resolve against the root directory and load as RGB.
        abs_path = os.path.join(self.root_dir, rel_path)
        img = Image.open(abs_path).convert('RGB')
        if self.transform:
            img = self.transform(img)

        sample = dict(
            image=img,
            is_real=True,
            model_name=record['model_name'],
            class_name=record['class_name'],
            full_path=abs_path,
            real_image_name=rel_path,
        )
        return sample


class SnthImageDataset(Dataset):
    """Dataset of synthetic images laid out as ``<root>/<model>/<class>/images/*.png``."""

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Root directory containing the image folders.
            transform (callable, optional): Transform to be applied on an image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_list = []
        self._load_images()

    def _load_images(self):
        """Collect one record per PNG file found below ``root_dir``.

        Directory listings are sorted so that the dataset order is reproducible
        (``os.listdir`` gives no ordering guarantee).
        """
        for model_name in sorted(os.listdir(self.root_dir)):
            model_path = os.path.join(self.root_dir, model_name)
            if not os.path.isdir(model_path):
                continue
            for class_name in sorted(os.listdir(model_path)):
                images_path = os.path.join(model_path, class_name, 'images')
                if not os.path.isdir(images_path):
                    continue
                for image_name in sorted(os.listdir(images_path)):
                    if not image_name.endswith('.png'):
                        continue
                    # The real image name is made of the first two "_"-separated
                    # parts. Remove the extension first, otherwise a name without
                    # a seed suffix ("a_b.png") would become "a_b.png.png".
                    stem = os.path.splitext(image_name)[0]
                    real_image_name = '_'.join(stem.split('_')[:2]) + '.png'
                    self.image_list.append({
                        'model_name': model_name,
                        'class_name': class_name,
                        'real_image_name': real_image_name,
                        'full_path': os.path.join(images_path, image_name)
                    })

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_info = self.image_list[idx]
        image = Image.open(image_info['full_path']).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return {
            'image': image,
            'is_real': False,
            'model_name': image_info['model_name'],
            'class_name': image_info['class_name'].replace("_", " "), # No_finding --> No Finding 
            'full_path': image_info['full_path'],
            'real_image_name': image_info['real_image_name']
        }


In [None]:
from einops import repeat
from torch.nn.functional import binary_cross_entropy
# Default checkpoint locations used by DiADMSampleEvaluator for the classifier
# and the privacy network. These are absolute, machine-specific paths.
DEFAULT_CLF_PATH =  "/vol/ideadata/ed52egek/pycharm/trichotomy/importantmodels/results_chexnet_real/saved_models_cxr8/m-05122024-131940.pth.tar"
DEFAULT_PRIV_PATH = "/vol/ideadata/ed52egek/pycharm/trichotomy/privacy/archive/Siamese_ResNet50_allcxr/Siamese_ResNet50_allcxr_checkpoint.pth"


class DiADMSampleEvaluator(): 
    """Scores synthetic images against a real reference image.

    Holds a classifier (from ``get_classification_model``) and a privacy network
    (from ``get_privacy_model``), both moved to ``device``.
    """

    def __init__(self, device, clf_path=DEFAULT_CLF_PATH, priv_path=DEFAULT_PRIV_PATH) -> None:
        # Without an explicit path the privacy loader falls back to its own default.
        if priv_path is None:
            privacy_net = get_privacy_model()
        else:
            privacy_net = get_privacy_model(path=priv_path)
        self.privnet = privacy_net.to(device)

        classifier = get_classification_model(clf_path)
        self.clf_model = classifier.to(device)

    def predict(self, batch): 
        """Compare every synthetic image in ``batch`` with the real one.

        Args:
            batch: Images in 0-1 range, any size; ``batch[0]`` is the real image
                and ``batch[1:]`` are the synthetic images.

        Returns:
            Tuple ``(clf_pred_scores, priv_pred)``: the per-image mean binary
            cross-entropy between the classifier outputs for the real and each
            synthetic image, and the squeezed output of the privacy network.
        """
        probs, _ = self.clf_model.lazy_foward(batch)
        real_probs, synth_probs = probs[0], probs[1:]
        # Tile the real prediction so it lines up with every synthetic prediction.
        tiled_real_probs = repeat(real_probs, "f -> b f", b=len(synth_probs))
        per_label_bce = binary_cross_entropy(tiled_real_probs, synth_probs, reduction='none')
        clf_pred_scores = per_label_bce.mean(dim=1)

        synth_images = batch[1:]
        tiled_real_image = repeat(batch[0], "c h w -> b c h w", b=len(synth_images))

        priv_pred = self.privnet.lazy_pred(tiled_real_image, synth_images)
        return clf_pred_scores, priv_pred.squeeze()

In [None]:
# Label names for the classifier. get_classification_model reads this global to size
# its output layer, so this cell must run before that function is called.
# (Presumably the order matches the classifier's output order - confirm.)
class_labels = ["No Finding", "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion", "Pneumonia", "Pneumothorax"] 