In [1]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import json
import random
import torchvision.transforms as transforms 

from diffusers import StableDiffusionPipeline, DiffusionPipeline, UNet2DConditionModel
from transformers import AutoImageProcessor, ViTModel, DeiTModel
from transformers import CLIPConfig, CLIPModel, CLIPTextModel, CLIPProcessor, CLIPFeatureExtractor,CLIPTokenizer
from transformers import tokenization_utils
import csv
import math
from tqdm.auto import tqdm
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from scipy.spatial import distance
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer

In [2]:
def load_t2i_model(modelPath=None, BDType=None):
    t2iModel = None
    if BDType == 'bagm':
        t2iModel = StableDiffusionPipeline.from_pretrained(modelPath).to('cuda')
    elif BDType == 'tpa':
        t2iModel = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2").to('cuda')
        t2iModel.text_encoder=CLIPTextModel.from_pretrained(modelPath).to('cuda')
    elif BDType == 'badt2i':
        t2iModel = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to('cuda')
        t2iModel.unet = UNet2DConditionModel.from_pretrained(modelPath).to('cuda')
    else:
        # This line can be changed to account for other generative models
        t2iModel = StableDiffusionPipeline.from_pretrained(modelPath).to('cuda')
    return t2iModel
def load_vision_transformer(visionmodelPath="openai/clip-vit-base-patch32"):
    Vmodel = None
    Vprocessor = None
    if visionmodelPath == "openai/clip-vit-base-patch32":
        Vprocessor = AutoProcessor.from_pretrained(visionmodelPath)
        Vmodel = CLIPVisionModelWithProjection.from_pretrained(visionmodelPath)
        
    elif visionmodelPath == "google/vit-base-patch16-224-in21k":
        Vprocessor = AutoImageProcessor.from_pretrained(visionmodelPath)
        Vmodel = ViTModel.from_pretrained(visionmodelPath)

    elif visionmodelPath == "facebook/deit-base-distilled-patch16-224":
        Vprocessor = AutoImageProcessor.from_pretrained(visionmodelPath)
        Vmodel = DeiTModel.from_pretrained(visionmodelPath)
        
    return (Vprocessor, Vmodel)

def flatten_list(X):
    return [x for xs in X for x in xs]
def latent_reconstruction(latents, t2iModel, guidance_scale, embeddings):
    for t in tqdm(t2iModel.scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
        latent_model_input = torch.cat([latents] * 2)

        latent_model_input = t2iModel.scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = t2iModel.unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = t2iModel.scheduler.step(noise_pred, t, latents).prev_sample

    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    
    with torch.no_grad():
        image = t2iModel.vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return images

def calculate_cosine_similarity(a, b):
    a = a[0].detach().numpy()
    b = b[0].detach().numpy()
    cosine = np.dot(a,b)/(norm(a)*norm(b))
    return cosine
def export_csv_results(filePath, fileHeader, outputData):
    with open(filePath, 'w') as csvfile:
        w = csv.writer(csvfile)
        if fileHeader is not None:
            w.writerow(fileHeader)
        for row in outputData:
            w.writerow(row)
transform = transforms.Compose([ 
    transforms.PILToTensor() ])

In [4]:
guidance_scale = 0.1            # low guidance or fairness
num_inference_steps = 100       # Number of denoising steps
NUMBER_OF_SEEDS = 100
SOS_TOKEN = 49406
EOS_TOKEN = 49407
width = 512
height = 512
sensitivePromptsFile = './sensitive_prompts.csv'
inputPrompts = []
with open(sensitivePromptsFile, newline='') as csvfile:
    rdr = csv.reader(csvfile, delimiter=',')
    for row in rdr:
        inputPrompts.append(row[0])
inputPrompts = sorted(list(set(inputPrompts)))

batchName = 'batch=RSXXXX'                     # consistent with Rglobal
VITPath = "openai/clip-vit-base-patch32"
VITName='clip-vit'                         
MP = "stabilityai/stable-diffusion-2"          # t2i model path
RD = 'SD_V1.5'                                 # results directory
                         
testCondition = VITName + '/'+batchName+'/fairness/g'+str(guidance_scale)+'/steps='+str(num_inference_steps)
if not os.path.exists('./results/'+RD+ '/' + testCondition + '/'):
    os.makedirs('./results/'+RD+ '/' + testCondition + '/')
if not os.path.exists('./results/'+RD+ '/' + testCondition + '/csvResults/'):
    os.makedirs('./results/'+RD+ '/' + testCondition + '/csvResults/')

t2iModel = load_t2i_model(MP, None)
(VISIONprocessor, VISIONmodel) = load_vision_transformer(VITPath)

leftOutTokens = {}
for randomSeed in [random.randint(0, 1e10) for x in range(NUMBER_OF_SEEDS)]:
    fairnessOutputData = []
    for ii,prompt in enumerate(inputPrompts):
        prompt = [prompt]
        leaveOut = []
        batch_size = 1
        text_input = t2iModel.tokenizer(prompt, padding="max_length", max_length=t2iModel.tokenizer.model_max_length, 
                                        truncation=True, return_tensors="pt")
        originalTokenData = [[],[]]
        for ii,val in enumerate(text_input['input_ids'][0]):
            if not val.item() in [SOS_TOKEN, EOS_TOKEN]:
                try:
                    originalTokenData[0].append(ii)
                    originalTokenData[1].append(val.item())
                except:
                    print(prompt[0], originalTokenData)
        baseEmbeddings = t2iModel.text_encoder(text_input.input_ids.to('cuda'))[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = t2iModel.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = t2iModel.text_encoder(uncond_input.input_ids.to('cuda'))[0]   

        generator = torch.manual_seed(randomSeed)    # Seed generator to create the inital latent noise
        latents = torch.randn(
            (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to('cuda')
        t2iModel.scheduler.set_timesteps(num_inference_steps)
        latents = latents * t2iModel.scheduler.init_noise_sigma

        t2iModel.scheduler.set_timesteps(num_inference_steps)

        text_embeddings = torch.cat([uncond_embeddings, baseEmbeddings])
        images = latent_reconstruction(latents, t2iModel, guidance_scale, text_embeddings)
        
        pil_images = [Image.fromarray(image) for image in images]
        pil_images[0].save('./results/'+RD + '/'+testCondition +'/'+str(randomSeed)+'_'+prompt[0]+'_SKT=0.png')
        
        baseInput = VISIONprocessor(images=pil_images[0], return_tensors="pt")    
        if VITPath != "openai/clip-vit-base-patch32":
            with torch.no_grad():
                baseOutput = VISIONmodel(**baseInput)
                baseEmb = baseOutput.last_hidden_state[0]
        else:
            baseOutput = VISIONmodel(**baseInput)
            baseEmb = baseOutput.image_embeds
        VSim = calculate_cosine_similarity(baseEmb,baseEmb)
        fairnessOutputData.append([prompt[0],"N.A.", "N.A.", VSim])
        
        for skipToken, skipEmbedding in zip(originalTokenData[0], originalTokenData[1]):
            text_input = t2iModel.tokenizer(prompt, padding="max_length", max_length=t2iModel.tokenizer.model_max_length, 
                                            truncation=True, return_tensors="pt")
            text_input.input_ids = torch.cat([text_input.input_ids[0][0:skipToken],
                                              text_input.input_ids[0][skipToken+1:],
                                              torch.tensor([EOS_TOKEN], dtype=torch.int64)])
            text_input.input_ids = torch.tensor([text_input.input_ids.tolist()])
            tokenData = [[],[]]
            NEWPrompt = ''
            for ii,val in enumerate(text_input.input_ids[0]):
                if not val.item() in [SOS_TOKEN, EOS_TOKEN]:
                    try:
                        tokenData[0].append(ii)
                        tokenData[1].append(val.item())
                        NEWPrompt = NEWPrompt + ' ' + t2iModel.tokenizer.decode([val.item()])
                    except:
                        print(prompt[0], tokenData)
            baseEmbeddings = t2iModel.text_encoder(text_input.input_ids.to('cuda'))[0]
            max_length = text_input.input_ids.shape[-1]
            uncond_input = t2iModel.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = t2iModel.text_encoder(uncond_input.input_ids.to('cuda'))[0]   

            generator = torch.manual_seed(randomSeed)    # Seed generator to create the inital latent noise
            latents = torch.randn(
                (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),
                generator=generator,
            )
            latents = latents.to('cuda')
            t2iModel.scheduler.set_timesteps(num_inference_steps)
            latents = latents * t2iModel.scheduler.init_noise_sigma

            t2iModel.scheduler.set_timesteps(num_inference_steps)

            text_embeddings = torch.cat([uncond_embeddings, baseEmbeddings])
            images = latent_reconstruction(latents, t2iModel, guidance_scale, text_embeddings)
            pil_images = [Image.fromarray(image) for image in images]
            
            baseInput = VISIONprocessor(images=pil_images[0], return_tensors="pt")    
            if VITPath != "openai/clip-vit-base-patch32":
                with torch.no_grad():
                    baseOutput = VISIONmodel(**baseInput)
                    leaveOutEmb = baseOutput.last_hidden_state[0]
            else:
                baseOutput = VISIONmodel(**baseInput)
                leaveOutEmb = baseOutput.image_embeds
            VSim = calculate_cosine_similarity(baseEmb,leaveOutEmb)
            fairnessOutputData.append([NEWPrompt,skipToken, t2iModel.tokenizer.decode([skipEmbedding]), VSim])
            # update dictionary object for left out tokens (to calculate fairness score later)
            if t2iModel.tokenizer.decode([skipEmbedding]) in leftOutTokens:
                leftOutTokens[t2iModel.tokenizer.decode([skipEmbedding])].append(VSim)
            else:
                leftOutTokens[t2iModel.tokenizer.decode([skipEmbedding])] = [VSim]
                
                
    fairnessDataFile = './results/'+RD+ '/'+testCondition+ '/csvResults/'+ 'fairness_'+RD.split('/')[-1]+'_RS_'+str(randomSeed)+'.csv'
    export_csv_results(fairnessDataFile, ['prompt','token dim','skiptoken','Vsim'], fairnessOutputData)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),


  0%|          | 0/11 [00:00<?, ?it/s]

  (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),


  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

  0%|          | 0/11 [00:00<?, ?it/s]

{'this': [0.99799514, 0.9887357, 0.99671495, 0.9997727, 0.9994562, 0.9985247, 0.99957335, 0.9991695, 0.9991489, 0.9997204], 'is': [0.9994517, 0.9831527, 0.9996056, 0.99976385, 0.99982584, 0.98797745, 0.99966484, 0.9998621, 0.999576, 0.99946463], 'a': [0.9959929, 0.9594575, 0.9984823, 0.9997999, 0.9996146, 0.9938587, 0.9994461, 0.9994826, 0.9989897, 0.9977789], 'test': [0.9988507, 0.9828649, 0.9905903, 0.99950254, 0.9995167, 0.9721194, 0.99910736, 0.99691963, 0.9995421, 0.9964243], 'prompt': [0.99376756, 0.953658, 0.9984275, 0.999705, 0.9996219, 0.99935716, 0.99972963, 0.9996468, 0.99937075, 0.99739254]}


In [13]:
def calculate_fairness_scores(dataDictionary):
    fairScores = []
    for key in dataDictionary:
        fairScores.append([key, -np.log10(1-np.mean(dataDictionary[key]))])      
    return fairScores

fairnessScores = calculate_fairness_scores(leftOutTokens)
fairnessScoresFile = './results/'+RD+ '/'+testCondition+ '/csvResults/fairness_scores.csv'

export_csv_results(fairnessScoresFile, ['token','fairness'], fairnessScores)