In [None]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import json
import random
import torchvision.transforms as transforms 

from diffusers import StableDiffusionPipeline, DiffusionPipeline, UNet2DConditionModel
from transformers import AutoImageProcessor, ViTModel, DeiTModel
from transformers import CLIPConfig, CLIPModel, CLIPTextModel, CLIPProcessor, CLIPFeatureExtractor,CLIPTokenizer
from transformers import tokenization_utils
import csv
import math
from tqdm.auto import tqdm
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from scipy.spatial import distance
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer

In [None]:
def load_t2i_model(modelPath=None, BDType=None):
    """Load a StableDiffusionPipeline onto 'cuda', optionally swapping one component.

    Args:
        modelPath: Hugging Face hub id or local path. Its meaning depends on BDType:
            - 'tpa':    a CLIPTextModel checkpoint that replaces the text encoder of
                        stock "CompVis/stable-diffusion-v1-4".
            - 'badt2i': a UNet2DConditionModel checkpoint that replaces the UNet of
                        stock "CompVis/stable-diffusion-v1-4".
            - anything else ('bagm', None, ...): a complete pipeline checkpoint.
        BDType: selects one of the loading modes above.

    Returns:
        The StableDiffusionPipeline, moved to 'cuda'.
    """
    base_checkpoint = "CompVis/stable-diffusion-v1-4"

    if BDType == 'tpa':
        pipeline = StableDiffusionPipeline.from_pretrained(base_checkpoint).to('cuda')
        pipeline.text_encoder = CLIPTextModel.from_pretrained(modelPath).to('cuda')
        return pipeline

    if BDType == 'badt2i':
        pipeline = StableDiffusionPipeline.from_pretrained(base_checkpoint).to('cuda')
        pipeline.unet = UNet2DConditionModel.from_pretrained(modelPath).to('cuda')
        return pipeline

    # 'bagm' and every other value load modelPath as a full pipeline.
    # This branch can be changed to account for other generative models.
    return StableDiffusionPipeline.from_pretrained(modelPath).to('cuda')
def load_vision_transformer(visionmodelPath="openai/clip-vit-base-patch32"):
    """Load the image processor and vision encoder used to embed generated images.

    Supported checkpoints:
        "openai/clip-vit-base-patch32"            -> AutoProcessor + CLIPVisionModelWithProjection
        "google/vit-base-patch16-224-in21k"       -> AutoImageProcessor + ViTModel
        "facebook/deit-base-distilled-patch16-224" -> AutoImageProcessor + DeiTModel

    Args:
        visionmodelPath: one of the checkpoint ids above.

    Returns:
        A ``(processor, model)`` tuple.

    Raises:
        ValueError: if ``visionmodelPath`` is not one of the supported ids. Failing
            here is clearer than handing back ``(None, None)``, which would only
            surface later as "'NoneType' object is not callable".
    """
    if visionmodelPath == "openai/clip-vit-base-patch32":
        return (AutoProcessor.from_pretrained(visionmodelPath),
                CLIPVisionModelWithProjection.from_pretrained(visionmodelPath))

    if visionmodelPath == "google/vit-base-patch16-224-in21k":
        return (AutoImageProcessor.from_pretrained(visionmodelPath),
                ViTModel.from_pretrained(visionmodelPath))

    if visionmodelPath == "facebook/deit-base-distilled-patch16-224":
        return (AutoImageProcessor.from_pretrained(visionmodelPath),
                DeiTModel.from_pretrained(visionmodelPath))

    raise ValueError(
        "Unsupported vision model path: %r (expected one of "
        "'openai/clip-vit-base-patch32', 'google/vit-base-patch16-224-in21k', "
        "'facebook/deit-base-distilled-patch16-224')" % (visionmodelPath,)
    )
def flatten_list(X):
    """Flatten one level of nesting: concatenate every inner iterable of X into a new list."""
    flat = []
    for inner in X:
        flat.extend(inner)
    return flat
def latent_reconstruction(latents, t2iModel, guidance_scale, embeddings):
    """Denoise ``latents`` with classifier-free guidance, then decode them to uint8 images.

    Args:
        latents: initial noise latents. The notebook scales them by
            ``scheduler.init_noise_sigma`` before calling; assumed shape
            (batch, C, H/8, W/8) — TODO confirm for other pipelines.
        t2iModel: a diffusers pipeline exposing ``scheduler``, ``unet`` and ``vae``.
            ``scheduler.set_timesteps`` must already have been called, because the
            loop iterates over ``scheduler.timesteps``.
        guidance_scale: classifier-free guidance weight.
        embeddings: text embeddings with the unconditional half first and the
            conditional half second along dim 0 (they are split with ``chunk(2)``).

    Returns:
        numpy uint8 array of shape (batch, H, W, C), produced by ``permute(0, 2, 3, 1)``.
    """
    with torch.no_grad():
        for t in tqdm(t2iModel.scheduler.timesteps):
            # Duplicate the latents so the unconditional and conditional noise
            # predictions come out of a single UNet forward pass.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = t2iModel.scheduler.scale_model_input(latent_model_input, timestep=t)

            noise_pred = t2iModel.unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

            # Classifier-free guidance: push the prediction away from the unconditional one.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # x_t -> x_{t-1}
            latents = t2iModel.scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the VAE latent scaling before decoding. Prefer the value stored in the
        # VAE config (differs for e.g. SDXL VAEs); fall back to the SD v1/v2 constant
        # for configs that do not carry the key.
        scaling_factor = getattr(t2iModel.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        image = t2iModel.vae.decode(latents).sample

    # Decoder output is presumably in [-1, 1]; map to [0, 1], clamp, then to NHWC uint8.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return images

In [None]:
#diversity-related functions
def calculate_cosine_similarity(a, b):
    """Cosine similarity between the first rows of two embedding tensors.

    Only ``a[0]`` and ``b[0]`` are compared. In this notebook that is the single
    image embedding for CLIP (``image_embeds``). For ViT/DeiT it is the first token
    of ``last_hidden_state[0]`` (presumably the CLS token — confirm for new models).
    Both indexed rows must be 1-D vectors of equal length.

    Args:
        a, b: torch tensors; they may require grad and may live on any device.

    Returns:
        The cosine similarity as a numpy scalar.
    """
    # detach(): the CLIP path may produce tensors attached to an autograd graph.
    # cpu(): .numpy() raises on CUDA tensors.
    vec_a = a[0].detach().cpu().numpy()
    vec_b = b[0].detach().cpu().numpy()
    cosine = np.dot(vec_a, vec_b) / (norm(vec_a) * norm(vec_b))
    return cosine

def sum_heatmap(heatmap):
    """Diversity score of a square similarity matrix: 1 - mean off-diagonal similarity.

    The diagonal (each image compared with itself) is excluded by subtracting the
    actual trace. Subtracting a fixed 1 per row would bias the score, because a
    float cosine self-similarity is only approximately 1.

    Args:
        heatmap: n x n nested sequence of similarities, where n >= 2.

    Returns:
        ``1 - (sum of off-diagonal entries) / (n**2 - n)``.

    Raises:
        ValueError: if the matrix has fewer than 2 rows (no off-diagonal pairs exist).
    """
    n = len(heatmap)
    if n < 2:
        raise ValueError(
            "sum_heatmap needs at least a 2x2 similarity matrix, got %d row(s)" % n
        )

    total = sum(sum(row) for row in heatmap)
    trace = sum(heatmap[i][i] for i in range(n))
    mean_off_diagonal = (total - trace) / (n * n - n)
    diversity = 1 - mean_off_diagonal
    return diversity

def create_similarity_heatmap(a, b):
    """Return the cosine similarity of two 1-D vectors.

    Despite the name, this computes a single heatmap cell, not a whole matrix.
    It is not called by the visible notebook code.
    """
    magnitude_product = norm(a) * norm(b)
    return np.dot(a, b) / magnitude_product


# PIL image -> torch tensor pipeline built from torchvision's PILToTensor.
# Not referenced by the visible notebook code.
transform = transforms.Compose([transforms.PILToTensor()])
def export_csv_results(filePath, fileHeader, outputData):
    """Write rows to a CSV file (overwriting it), with an optional header row.

    Args:
        filePath: destination path; its parent directory must already exist.
        fileHeader: sequence written as the first row, or None to omit a header.
        outputData: iterable of row sequences.
    """
    # The csv module requires newline=''. Otherwise its '\r\n' row terminator is
    # translated to '\r\r\n' on Windows, producing blank lines between rows.
    with open(filePath, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if fileHeader is not None:
            writer.writerow(fileHeader)
        writer.writerows(outputData)

In [None]:
# --- Experiment configuration ------------------------------------------------
guidance_scale = 7.5
num_inference_steps = 50
height = 768
width = 768
NUMBER_OF_SEEDS = 100        # images generated per token; diversity needs >= 2
MASTER_SEED = None           # int -> reproducible per-image seeds; None -> fresh OS entropy each run
SAVE_EVERY = 1               # save every SAVE_EVERY-th generated image as PNG (1 = save all)
VITPath = "openai/clip-vit-base-patch32"
sensitiveTokensFile = './sensitive_tokens.csv'
batchName = 'batch=RSXXXX'
MP = "stabilityai/stable-diffusion-2-1"          # t2i model path
RD = 'SD_2.1'                                    # results directory
VITName = 'clip-vit'

assert NUMBER_OF_SEEDS >= 2, "diversity is a mean over image pairs; need at least 2 images per token"
assert height % 8 == 0 and width % 8 == 0, "latents are (height // 8, width // 8); sizes must be multiples of 8"

testCondition = VITName+'/'+batchName+'/diversity/g'+str(guidance_scale)+'/steps='+str(num_inference_steps)
resultsDir = './results/' + RD + '/' + testCondition + '/'
csvDir = resultsDir + 'csvResults/'
diversityFile = csvDir + 'diversity.csv'
os.makedirs(csvDir, exist_ok=True)   # also creates resultsDir

# --- Prompts: first CSV column, de-duplicated and sorted (blank rows skipped) --
with open(sensitiveTokensFile, newline='') as csvfile:
    inputTokens = sorted({row[0] for row in csv.reader(csvfile, delimiter=',') if row})

# --- Models -------------------------------------------------------------------
t2iModel = load_t2i_model(MP, None)
(VISIONprocessor, VISIONmodel) = load_vision_transformer(VITPath)
seedRng = random.Random(MASTER_SEED)
batch_size = 1


def _encode_text(texts):
    """Text-encoder hidden states for `texts`, padded/truncated to the tokenizer max length."""
    tokens = t2iModel.tokenizer(texts, padding="max_length", max_length=t2iModel.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
    with torch.no_grad():
        return t2iModel.text_encoder(tokens.input_ids.to('cuda'))[0]


def _embed_image(pil_image):
    """Embed one PIL image: CLIP -> image_embeds, ViT/DeiT -> last_hidden_state[0]."""
    inputs = VISIONprocessor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        outputs = VISIONmodel(**inputs)
    if VITPath == "openai/clip-vit-base-patch32":
        return outputs.image_embeds
    return outputs.last_hidden_state[0]


def _file_stem(text):
    """Replace path separators so a prompt can be embedded in a file name."""
    for sep in (os.sep, os.altsep):
        if sep:
            text = text.replace(sep, '_')
    return text


# The unconditional ("") embedding is prompt-independent; compute it once.
uncond_embeddings = _encode_text([""] * batch_size)

diversityOutput = []
for token in inputTokens:
    prompt = [token]
    fileStem = _file_stem(prompt[0])
    # Prompt embeddings do not change between seeds; encode once per token.
    text_embeddings = torch.cat([uncond_embeddings, _encode_text(prompt)])

    comparisonImages = []
    # Integer upper bound: random.randint rejects float arguments on Python >= 3.12.
    for randomSeed in [seedRng.randint(0, 10**10) for _ in range(NUMBER_OF_SEEDS)]:
        generator = torch.manual_seed(randomSeed)    # seeds the initial latent noise
        latents = torch.randn(
            (batch_size, t2iModel.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        ).to('cuda')
        # Must precede init_noise_sigma, which some schedulers derive from the timesteps.
        t2iModel.scheduler.set_timesteps(num_inference_steps)
        latents = latents * t2iModel.scheduler.init_noise_sigma

        images = latent_reconstruction(latents, t2iModel, guidance_scale, text_embeddings)
        pil_image = Image.fromarray(images[0])
        if len(comparisonImages) % SAVE_EVERY == 0:
            pil_image.save(resultsDir + fileStem + '_' + str(randomSeed) + '.png')

        comparisonImages.append(_embed_image(pil_image))

    # Pairwise cosine-similarity matrix ("heatmap") over all images of this token.
    heatmap = [[calculate_cosine_similarity(emb_i, emb_j) for emb_j in comparisonImages]
               for emb_i in comparisonImages]
    export_csv_results(csvDir + fileStem + '_heatmap.csv', None, heatmap)
    diversityOutput.append([prompt[0], sum_heatmap(heatmap)])

export_csv_results(diversityFile, ['token', 'diversity'], diversityOutput)
    
