In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Subset
import Models   
from torchvision.models import resnet50,ResNet50_Weights
from torchvision import transforms
from torchvision.datasets import ImageFolder
import os 
import PIL
from scipy.linalg import sqrtm
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from prdc import compute_prdc
import random
import umap 
from plotnine import ggplot,aes,geom_point,scale_color_manual
from plotnine.labels import xlab,ylab
import pandas as pd
from scipy.stats import sem,tmean
from PIL import ImageEnhance


Define the datasets and dataloaders for the real and fake subsets of the data

In [None]:
# Build the real/fake datasets from ../real_vs_fake, an ImageFolder root whose
# class sub-directories include 'real' and 'fake'.
current_directory = os.getcwd()
data_dir = os.path.abspath(os.path.join(current_directory, '..', 'real_vs_fake'))

# Full dataset (plain ToTensor), used for the labels and the combined dataloader.
dataset = ImageFolder(root=data_dir, transform=transforms.ToTensor())
dataloader = DataLoader(dataset=dataset, num_workers=8, shuffle=False, batch_size=16)
labels = np.array(dataset.targets)
num_classes = len(set(labels))
# Positions i where the label of sample i differs from that of sample i+1.
idx_change = np.where(labels[:-1] != labels[1:])[0]
idx_fake = np.where(labels == dataset.class_to_idx['fake'])[0]
idx_real = np.where(labels == dataset.class_to_idx['real'])[0]

# Fake images are sharpened (enhance factor 3) before tensor conversion.
# ImageEnhance needs a PIL image, so this must run before ToTensor.
transform_fake = transforms.Compose([
    transforms.Lambda(lambda img: ImageEnhance.Sharpness(img).enhance(3)),
    transforms.ToTensor(),
])

# A Subset delegates __getitem__ to its parent dataset, so assigning a transform
# to the Subset object has no effect. The transform has to live on the parent:
# build a second ImageFolder over the same root with the sharpening transform.
# The sample order must match so that idx_fake indexes the same files.
dataset_fake_source = ImageFolder(root=data_dir, transform=transform_fake)
assert dataset_fake_source.samples == dataset.samples, "sample order differs between ImageFolder instances"

dataset_real = Subset(dataset, idx_real)
dataset_fake = Subset(dataset_fake_source, idx_fake)
dataloader_real = DataLoader(dataset=dataset_real, num_workers=8, shuffle=False, batch_size=16)
dataloader_fake = DataLoader(dataset=dataset_fake, num_workers=8, shuffle=False, batch_size=16)



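Calculate the manifold similarity measures (precision, recall, density and coverage) between the real and fake embeddings, averaged over several randomly initialised feature extractors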

In [None]:
# Precision / recall / density / coverage (via prdc.compute_prdc) between the
# real and fake embeddings. The embeddings are produced by N_RUNS independently
# initialised ResNet-50s (weights=None, i.e. random weights). Each run's metrics
# are collected, and the mean +/- standard error is reported.
N_RUNS = 5
NEAREST_K = 10
EMBEDDING_DIM = 100
SEED = 0  # base seed; run i uses SEED + i

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
metrics_all = {'precision': [], 'recall': [], 'density': [], 'coverage': []}

for i in range(N_RUNS):
    # Seed each run so both the random weights and the noise images are reproducible.
    torch.manual_seed(SEED + i)
    # model = resnet50(weights="IMAGENET1K_V2")  # use this instead for the pretrained model
    model = resnet50(weights=None)
    linear_size = model.fc.in_features
    model.fc = nn.Linear(linear_size, EMBEDDING_DIM)
    model = model.to(device)
    model.eval()

    embeddings_real = []
    embeddings_random = []
    embeddings_fake = []
    with torch.no_grad():
        for images, _ in tqdm(dataloader_real, desc='real_data'):
            images = images.to(device)
            # Uniform [0, 1) noise batch with the same shape as the real batch;
            # its embeddings serve as a reference cloud in the UMAP plot.
            images_random = torch.rand(images.shape, device=device)
            embeddings_real.append(model(images).cpu())
            embeddings_random.append(model(images_random).cpu())

        for images, _ in tqdm(dataloader_fake, desc='fake_data'):
            embeddings_fake.append(model(images.to(device)).cpu())

    embeddings_real = torch.cat(embeddings_real, dim=0).numpy()
    embeddings_random = torch.cat(embeddings_random, dim=0).numpy()
    embeddings_fake = torch.cat(embeddings_fake, dim=0).numpy()

    metrics = compute_prdc(real_features=embeddings_real,
                           fake_features=embeddings_fake,
                           nearest_k=NEAREST_K)
    for key in metrics:
        metrics_all[key].append(metrics[key])

print(metrics_all)
for key, values in metrics_all.items():
    avg_value = tmean(values)
    se = sem(values)
    print(f'Average {key}: {avg_value} +/- {se}')

Visualize an example manifold in 2D with UMAP (embeddings of fake, real and random-noise images from the last run above)

In [None]:
# 2-D UMAP projection of the embeddings from the LAST run of the PRDC loop
# (fake, real and uniform-noise images), coloured by origin. The plot is saved
# as a PNG and the raw embeddings as embedding.npy.
embeddings = np.concatenate((embeddings_fake, embeddings_real, embeddings_random), axis=0)
shp = np.shape(embeddings_random)[0]
# Build the labels from the arrays that were actually concatenated
# (0 = fake, 1 = real, 2 = random), so they always line up with `embeddings`.
# This avoids relying on the ImageFolder targets being ordered fake-then-real
# and containing only those two classes.
labels_all = np.concatenate((
    np.zeros(len(embeddings_fake), dtype=int),
    np.ones(len(embeddings_real), dtype=int),
    np.full(shp, 2, dtype=int),
), axis=0)
assert len(labels_all) == len(embeddings)

# Legend labels and colours, matched positionally to label codes 0, 1, 2.
mapping = ["fake", "real", "random"]
pal = ["#FF0000",   # fake
       "#0000FF",   # real
       "#00FF00"]   # random

reducer = umap.UMAP()
umap_embeddings = reducer.fit_transform(embeddings)
df = pd.DataFrame(umap_embeddings, columns=["x", "y"])
df["id"] = labels_all

g = (
    ggplot(df, aes(x="x", y="y", color="factor(id)"))
    + geom_point(alpha=0.5, size=1.6)
    + scale_color_manual(name="Tissue - Origin", values=pal, labels=mapping)
    + xlab("UMAP1")
    + ylab("UMAP2")
)
g.save(filename='./random_embedding_enh3.png', height=15, width=15, units='in', dpi=500)
np.save('embedding.npy', embeddings)

Calculate the FID score between all the real (training) data and all the generated data

In [None]:
# Frechet Inception Distance (Inception-v3, 2048-d features) between every real
# tile under ../Train and every generated tile under
# ../Generated-data/synthetic_tiles_512TO256_GTEX.
from torchmetrics.image.fid import FrechetInceptionDistance
from PIL import ImageEnhance

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
real_data_dir = os.path.abspath(os.path.join(current_directory, '..', 'Train'))
gen_data_dir = os.path.abspath(
    os.path.join(current_directory, '..', 'Generated-data', 'synthetic_tiles_512TO256_GTEX')
)

# Real tiles: shorter side resized to 256 px. Generated tiles: sharpened
# (factor 3) but not resized.
# NOTE(review): the generated tiles are presumably already 256 px (the folder
# name says "512TO256"), but confirm this, because the two pipelines differ.
transform = transforms.Compose([transforms.Resize(256), transforms.ToTensor()])
transform_fake = transforms.Compose([
    transforms.Lambda(lambda img: ImageEnhance.Sharpness(img).enhance(3)),
    transforms.ToTensor(),
])

dataset_real = ImageFolder(root=real_data_dir, transform=transform)
dataset_gen = ImageFolder(root=gen_data_dir, transform=transform_fake)
dataloader_real = DataLoader(dataset=dataset_real, num_workers=8, shuffle=False, batch_size=8)
dataloader_gen = DataLoader(dataset=dataset_gen, num_workers=8, shuffle=False, batch_size=8)

# normalize=True: inputs are float images in [0, 1], which is what ToTensor produces.
fid_v3_score = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
fid_passes = (
    (dataloader_real, True, 'real_data'),
    (dataloader_gen, False, 'fake_data'),
)
for loader, is_real, pass_name in fid_passes:
    for images, _ in tqdm(loader, desc=pass_name):
        fid_v3_score.update(images.to(device), real=is_real)

print(f"FID score with inception3 network: {fid_v3_score.compute()}")


Class-based FID (computed separately for each tissue class)

In [None]:
# Per-tissue FID: for each tissue class present in both ../Train and the
# generated folder, compute the FID between that class's real and generated tiles.
from torchmetrics.image.fid import FrechetInceptionDistance

gen_data_dir = os.path.abspath(os.path.join(current_directory, '..', 'Generated-data', 'synthetic_tiles_512TO256_GTEX'))
real_data_dir = os.path.abspath(os.path.join(current_directory, '..', 'Train'))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
])

transform_fake = transforms.Compose([
    transforms.Lambda(lambda img: ImageEnhance.Sharpness(img).enhance(3)),
    transforms.ToTensor(),
])

dataset_real = ImageFolder(root=real_data_dir, transform=transform)
dataset_gen = ImageFolder(root=gen_data_dir, transform=transform_fake)

# dataset_real.classes lists only the class directories ImageFolder indexed.
# os.listdir would also return stray files (e.g. .DS_Store) and make the
# class_to_idx lookup raise KeyError. Tissues with no generated samples are
# reported and skipped instead of crashing the cell.
tissues = [tissue for tissue in dataset_real.classes if tissue in dataset_gen.class_to_idx]
missing_tissues = sorted(set(dataset_real.classes) - set(tissues))
if missing_tissues:
    print(f"Skipping tissues with no generated data: {missing_tissues}")

# Build the Inception feature extractor once and reset its accumulated
# statistics for each tissue, instead of reloading the network every iteration.
fid_v3_score = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
fid_class = {}
for tissue in tissues:
    tissue_index_real = dataset_real.class_to_idx[tissue]
    tissue_index_gen = dataset_gen.class_to_idx[tissue]
    idx_real_tissue = [idx for idx, target in enumerate(dataset_real.targets) if target == tissue_index_real]
    idx_gen_tissue = [idx for idx, target in enumerate(dataset_gen.targets) if target == tissue_index_gen]
    dataset_real_tissue = Subset(dataset_real, idx_real_tissue)
    dataset_gen_tissue = Subset(dataset_gen, idx_gen_tissue)

    dataloader_real = DataLoader(dataset=dataset_real_tissue, num_workers=8, shuffle=False, batch_size=8)
    dataloader_gen = DataLoader(dataset=dataset_gen_tissue, num_workers=8, shuffle=False, batch_size=8)

    fid_v3_score.reset()
    for images, _ in tqdm(dataloader_real, desc='real_data' + '_' + tissue):
        fid_v3_score.update(images.to(device), real=True)

    for images, _ in tqdm(dataloader_gen, desc='fake_data' + '_' + tissue):
        fid_v3_score.update(images.to(device), real=False)
    # NOTE(review): FID is estimated from 2048x2048 feature covariances. With few
    # tiles per class the estimate is noisy and biased, so compare classes with care.
    fid_class[tissue] = float(fid_v3_score.compute())

for key, value in fid_class.items():
    print(f"FID score with inception3 network for {key} : {value}")



Calculate the Inception Score (IS) on all the generated data

In [None]:
# Inception Score over every generated tile, computed with 5 splits;
# compute() returns the (mean, std) over those splits.
from PIL import ImageEnhance
from torchmetrics.image.inception import InceptionScore

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
gen_data_dir = os.path.abspath(
    os.path.join(current_directory, '..', 'Generated-data', 'synthetic_tiles_512TO256_GTEX')
)

# Generated tiles are sharpened (factor 3) before tensor conversion.
sharpen = transforms.Lambda(lambda img: ImageEnhance.Sharpness(img).enhance(3))
transform = transforms.Compose([sharpen, transforms.ToTensor()])

dataset_gen_2 = ImageFolder(root=gen_data_dir, transform=transform)
dataloader_gen_2 = DataLoader(dataset=dataset_gen_2, num_workers=8, shuffle=False, batch_size=8)

# normalize=True: inputs are float images in [0, 1].
inception = InceptionScore(normalize=True, splits=5).to(device)
for batch, _ in tqdm(dataloader_gen_2, desc='fake_data'):
    inception.update(batch.to(device))
inception_v3 = inception.compute()
print(f"inception score with inception3 network: {inception_v3}")

Calculate the Inception Score (IS) on the real data

In [None]:
# Inception Score over every real training tile (shorter side resized to 256 px),
# computed with 5 splits; compute() returns the (mean, std) over those splits.
from torchmetrics.image.inception import InceptionScore

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
real_data_dir = os.path.abspath(os.path.join(current_directory, '..', 'Train'))

resize_step = transforms.Resize(256)
transform = transforms.Compose([resize_step, transforms.ToTensor()])

dataset_real = ImageFolder(root=real_data_dir, transform=transform)
dataloader_real = DataLoader(dataset=dataset_real, num_workers=8, shuffle=False, batch_size=8)

# normalize=True: inputs are float images in [0, 1].
inception = InceptionScore(normalize=True, splits=5).to(device)
for batch, _ in tqdm(dataloader_real, desc='Real_data'):
    inception.update(batch.to(device))
inception_v3 = inception.compute()
print(f"inception score with inception3 network: {inception_v3}")