In [1]:
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import json
import random
import torchvision.transforms as transforms 

from diffusers import StableDiffusionPipeline, DiffusionPipeline, UNet2DConditionModel
from transformers import AutoImageProcessor, ViTModel, DeiTModel
from transformers import CLIPConfig, CLIPModel, CLIPTextModel, CLIPProcessor, CLIPFeatureExtractor,CLIPTokenizer
from transformers import tokenization_utils
import csv
import math
from tqdm.auto import tqdm
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModelWithProjection
from scipy.spatial import distance
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer

In [2]:
def flatten_list(X):
    """Concatenate the inner iterables of ``X`` into one flat list (one level deep)."""
    flat = []
    for inner in X:
        flat.extend(inner)
    return flat

def calculate_cosine_similarity(a, b):
    """Cosine similarity between the first entries of two tensors.

    Only index 0 along the leading dimension of each argument is compared.
    Both tensors are detached and moved to the CPU before conversion to NumPy,
    so tensors that require grad and tensors living on a GPU are accepted.

    Returns:
        A NumPy scalar; a zero-norm input yields nan.
    """
    a = a[0].detach().cpu().numpy()
    b = b[0].detach().cpu().numpy()
    return np.dot(a, b) / (norm(a) * norm(b))

def random_perturb_text_embeddings(embd, targetDim, ptb, perturbationType, STD, ptbShift):
    """Return a randomly perturbed copy of a batch of text embeddings.

    Only batch entry 0 is changed, and ``embd`` itself is never modified: the
    result is a clone, so repeated calls on the same reference embeddings do not
    accumulate noise from earlier calls.

    Args:
        embd: tensor indexed as ``embd[0][token][feature]`` (3-D).
        targetDim: token position perturbed by 'LOCAL' and by the default mode.
        ptb: upper multiplier of ``STD`` for the default (shift) mode.
        perturbationType: 'LOCAL' adds U(-STD, STD) noise to every feature of
            token ``targetDim``; 'GLOBAL' multiplies every element of
            ``embd[0]`` by U(1-STD, 1+STD); any other value subtracts
            U((ptb-ptbShift)*STD, ptb*STD) from every feature of ``targetDim``.
        STD: noise scale.
        ptbShift: width of the default-mode window, in multiples of ``STD``.

    Returns:
        A new tensor with the same shape, dtype and device as ``embd``.
    """
    # Reseed Python's RNG from system entropy so each call draws a fresh
    # perturbation, independent of any torch seeding done by the caller.
    random.seed(None)
    embdCopy = embd.clone()
    target = embdCopy[0]
    if perturbationType == 'LOCAL':
        noise = [random.uniform(-STD, STD) for _ in range(len(target[targetDim]))]
        target[targetDim] += torch.tensor(noise, dtype=target.dtype, device=target.device)
    elif perturbationType == 'GLOBAL':
        nTokens, nFeatures = len(target), len(target[0])
        # Draw in row-major (token, then feature) order and apply all at once.
        factors = [random.uniform(1 - STD, 1 + STD) for _ in range(nTokens * nFeatures)]
        target *= torch.tensor(factors, dtype=target.dtype, device=target.device).reshape(nTokens, nFeatures)
    else:
        low, high = (ptb - ptbShift) * STD, ptb * STD
        shifts = [random.uniform(low, high) for _ in range(len(target[targetDim]))]
        target[targetDim] -= torch.tensor(shifts, dtype=target.dtype, device=target.device)
    return embdCopy
def latent_reconstruction(latents, t2iModel, guidance_scale, embeddings):
    """Denoise ``latents`` with classifier-free guidance and decode them to images.

    Args:
        latents: initial latent noise.
        t2iModel: pipeline exposing ``scheduler``, ``unet`` and ``vae``. The loop
            iterates ``t2iModel.scheduler.timesteps``, so ``set_timesteps``
            should be called before this function.
        guidance_scale: classifier-free guidance weight.
        embeddings: text embeddings whose first half (along dim 0) matches the
            unconditional branch and second half the conditional branch, e.g.
            ``torch.cat([uncond, cond])``.

    Returns:
        uint8 NumPy array, permuted from the VAE's (batch, channels, H, W)
        output to (batch, H, W, channels), with values in [0, 255].
    """
    for t in tqdm(t2iModel.scheduler.timesteps):
        # Duplicate the latents so the unconditional and conditional noise
        # predictions come from a single UNet forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = t2iModel.scheduler.scale_model_input(latent_model_input, timestep=t)

        # Predict the noise residual.
        with torch.no_grad():
            noise_pred = t2iModel.unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

        # Classifier-free guidance: move from the unconditional prediction
        # towards the conditional one, weighted by guidance_scale.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = t2iModel.scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the VAE latent scaling before decoding. Prefer the factor stored in
    # the VAE config so other VAEs work; fall back to the previously hard-coded
    # Stable Diffusion v1 value when the config does not provide one.
    scaling_factor = getattr(getattr(t2iModel.vae, "config", None), "scaling_factor", None)
    if scaling_factor is None:
        scaling_factor = 0.18215
    latents = 1 / scaling_factor * latents

    # Decode the latents into images.
    with torch.no_grad():
        image = t2iModel.vae.decode(latents).sample

    # Map [-1, 1] to [0, 1] (clamping anything outside), then to uint8 channels-last.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")

    return images

def load_t2i_model(modelPath=None, BDType=None, device='cuda',
                   baseModelPath="CompVis/stable-diffusion-v1-4"):
    """Load a Stable Diffusion pipeline, optionally swapping in one component.

    Args:
        modelPath: path or hub id of the model (or component) to load; required.
        BDType: 'tpa' loads ``baseModelPath`` and replaces its text encoder with
            ``modelPath``; 'badt2i' loads ``baseModelPath`` and replaces its
            UNet with ``modelPath``; 'bagm' or any other value (including None)
            loads the full pipeline from ``modelPath``.
        device: device every loaded module is moved to.
        baseModelPath: pipeline used as the base for the component swaps.

    Returns:
        The loaded pipeline.

    Raises:
        ValueError: if ``modelPath`` is None (every variant needs it).
    """
    if modelPath is None:
        raise ValueError("modelPath is required to load a text-to-image model")
    if BDType == 'tpa':
        t2iModel = StableDiffusionPipeline.from_pretrained(baseModelPath).to(device)
        t2iModel.text_encoder = CLIPTextModel.from_pretrained(modelPath).to(device)
    elif BDType == 'badt2i':
        t2iModel = StableDiffusionPipeline.from_pretrained(baseModelPath).to(device)
        t2iModel.unet = UNet2DConditionModel.from_pretrained(modelPath).to(device)
    else:
        # 'bagm' and every other type: a full pipeline from modelPath. This branch
        # can be changed to account for other generative models.
        t2iModel = StableDiffusionPipeline.from_pretrained(modelPath).to(device)
    return t2iModel
def export_csv_results(filePath, fileHeader, outputData):
    """Write ``fileHeader`` followed by every row of ``outputData`` to a CSV file.

    Any existing file at ``filePath`` is overwritten. The file is opened with
    ``newline=''`` as the csv module requires (otherwise extra blank lines
    appear between rows on Windows) and as UTF-8, so non-ASCII prompts are
    written identically on every platform.
    """
    with open(filePath, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(fileHeader)
        writer.writerows(outputData)
            
def load_vision_transformer(visionmodelPath="openai/clip-vit-base-patch32"):
    """Load the image processor and vision model used to score image similarity.

    Args:
        visionmodelPath: one of the supported hub ids listed below.

    Returns:
        ``(processor, model)`` tuple.

    Raises:
        ValueError: for an unsupported ``visionmodelPath`` (previously this
            silently returned ``(None, None)`` and failed later at first use).
    """
    # Hub id -> (processor class, model class); the processor is loaded first.
    loaders = {
        "openai/clip-vit-base-patch32": (AutoProcessor, CLIPVisionModelWithProjection),
        "google/vit-base-patch16-224-in21k": (AutoImageProcessor, ViTModel),
        "facebook/deit-base-distilled-patch16-224": (AutoImageProcessor, DeiTModel),
    }
    if visionmodelPath not in loaders:
        raise ValueError(
            f"Unsupported vision model {visionmodelPath!r}; expected one of {sorted(loaders)}"
        )
    processorCls, modelCls = loaders[visionmodelPath]
    Vprocessor = processorCls.from_pretrained(visionmodelPath)
    Vmodel = modelCls.from_pretrained(visionmodelPath)
    return (Vprocessor, Vmodel)
# Image -> tensor conversion pipeline built from torchvision's PILToTensor.
# NOTE(review): `transform` is not referenced by any visible cell; candidate for removal.
transform = transforms.Compose(
    [transforms.PILToTensor()]
)

In [5]:
# --- Experiment configuration ---------------------------------------------------
# Image generation settings.
guidance_scale = 15
height = 512
width = 512
num_inference_steps = 100

# A perturbed image is treated as changed once its vision-embedding cosine
# similarity to the base image falls below this value (derive from empirical eval).
VSIM_THRESHOLD = 0.9
nPerturbations = 5                                # random draws per perturbation level

# Special-token ids; these depend on the tokenizer + text encoder in use.
SOS_TOKEN = 49406
EOS_TOKEN = 49407

# Inputs, models and output naming.
sensitivePromptsFile = './sensitive_prompts.csv'  # globally unreliable prompts
batchName = 'batch=RSXXXX'                        # consistent with Rglobal
VITPath = "openai/clip-vit-base-patch32"
VITName = 'clip-vit'
MP = "CompVis/stable-diffusion-v1-4"              # t2i model path
RD = 'SD_V1.5'                                    # results directory
# NOTE(review): RD is labelled SD_V1.5 while MP loads stable-diffusion-v1-4; confirm the label.

# Seed for the initial latent noise; make it consistent with Rglobal if doing a dependent analysis.
randomSeed = random.randint(0, 100000)

# Unique prompts from the first CSV column, sorted for a stable processing order.
# Blank lines (which csv.reader yields as empty rows) are skipped instead of
# raising IndexError on row[0].
with open(sensitivePromptsFile, newline='') as csvfile:
    inputPrompts = sorted({row[0] for row in csv.reader(csvfile, delimiter=',') if row})
                               
# Images are saved under ./results/<RD>/<testCondition>; CSV summaries go in its
# csvResults/ subfolder.
testCondition = VITName+'/'+batchName+'/Rlocal/'
# A single makedirs call creates both folders. exist_ok makes the cell safe to
# re-run and avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs('./results/'+RD+ '/' + testCondition + '/csvResults/', exist_ok=True)
localDataFile = './results/'+RD+ '/'+testCondition+ 'csvResults/'+ 'full_output_data_local_'+RD.split('/')[-1]+'_RS_'+str(randomSeed)+'.csv'
localReliabilityFile = './results/'+RD+ '/'+testCondition+ 'csvResults/'+ 'local_reliability_'+RD.split('/')[-1]+'_RS_'+str(randomSeed)+'.csv'

# Accumulators filled by the per-prompt loop and written to CSV at the end.
localOutputData, localReliabilityData = [], []

# Text-to-image pipeline (default loader branch) and the vision model used to
# compare generated images.
t2iModel = load_t2i_model(MP, None)
VISIONprocessor, VISIONmodel = load_vision_transformer(VITPath)
# --- Local (per-token) reliability sweep ----------------------------------------
# For every prompt: generate a base image. Then, for each content token, add
# growing random noise to that token's text embedding until the generated
# image's vision-embedding cosine similarity to the base image drops below
# VSIM_THRESHOLD.

PTB_START = 0.025   # first perturbation level, in multiples of the token embedding's std
PTB_STEP = 0.025    # increment between levels, same units


def _tokenize(texts):
    """Tokenize a list of strings, padded/truncated to the tokenizer's max length."""
    return t2iModel.tokenizer(texts, padding="max_length", max_length=t2iModel.tokenizer.model_max_length,
                              truncation=True, return_tensors="pt")


def _encode(input_ids):
    """Return the text encoder's first output for ``input_ids``, computed without autograd."""
    with torch.no_grad():
        return t2iModel.text_encoder(input_ids.to('cuda'))[0]


def _generate_image(condEmbeddings):
    """Generate one PIL image for ``condEmbeddings`` starting from seed ``randomSeed``.

    The torch RNG is re-seeded on every call, so the base image and every
    perturbed image start from identical latent noise.
    """
    generator = torch.manual_seed(randomSeed)
    latents = torch.randn(
        (condEmbeddings.shape[0], t2iModel.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to('cuda')
    # Re-run before every generation so each one starts from a freshly
    # configured timestep schedule.
    t2iModel.scheduler.set_timesteps(num_inference_steps)
    latents = latents * t2iModel.scheduler.init_noise_sigma
    text_embeddings = torch.cat([uncond_embeddings, condEmbeddings])
    images = latent_reconstruction(latents, t2iModel, guidance_scale, text_embeddings)
    return Image.fromarray(images[0])


def _vision_embedding(pilImage):
    """Embed ``pilImage`` with the vision model selected by VITPath (no autograd)."""
    inputs = VISIONprocessor(images=pilImage, return_tensors="pt")
    with torch.no_grad():
        outputs = VISIONmodel(**inputs)
    if VITPath != "openai/clip-vit-base-patch32":
        return outputs.last_hidden_state[0]
    return outputs.image_embeds


def _path_safe(text):
    """Replace path separators so ``text`` can be used as a single file-name component."""
    return text.replace('/', '_').replace(os.sep, '_')


resultsDir = './results/' + RD + '/' + testCondition + '/'
# The unconditional ("") embedding is identical for every prompt; compute it once.
uncond_embeddings = _encode(
    t2iModel.tokenizer([""], padding="max_length", max_length=t2iModel.tokenizer.model_max_length,
                       return_tensors="pt").input_ids
)

for promptText in inputPrompts:
    text_input = _tokenize([promptText])
    # (position, token id) of every token that is not the SOS or EOS id.
    contentTokens = [(pos, tokId) for pos, tokId in enumerate(text_input['input_ids'][0].tolist())
                     if tokId not in (SOS_TOKEN, EOS_TOKEN)]
    baseEmbeddings = _encode(text_input.input_ids)

    basePil = _generate_image(baseEmbeddings)
    fileStem = resultsDir + _path_safe(promptText)
    basePil.save(fileStem + '.png')
    baseEmb = _vision_embedding(basePil)

    # Row layout: [prompt | token dim | ptb index | ptb magnitude | rand(ptb) no. | VSim]
    localOutputData.append([promptText, "N.A.", 0, 0, "N.A.", calculate_cosine_similarity(baseEmb, baseEmb)])

    for jj, token in contentTokens:
        print("Token Index = ", jj)
        STD = baseEmbeddings[0][jj].std().item()   # std of this token's embedding vector
        ptb = PTB_START
        ptbShift = PTB_STEP
        kk = 1
        VSim = 1.0
        # NOTE(review): no upper bound on ptb; the loop relies on VSim eventually
        # dropping below VSIM_THRESHOLD.
        while VSim >= VSIM_THRESHOLD:
            shift = ptb * STD
            print("Shift = ", shift)
            for n_ptb in range(nPerturbations):
                # Pass a copy so baseEmbeddings stays the unperturbed reference and
                # perturbations never accumulate across draws or levels.
                PtbEmbeddings = random_perturb_text_embeddings(baseEmbeddings.clone(), jj, ptb,
                                                               'LOCAL', shift, ptbShift)
                pilImage = _generate_image(PtbEmbeddings)
                VSim = calculate_cosine_similarity(baseEmb, _vision_embedding(pilImage))
                localOutputData.append([promptText, jj, kk, ptb, n_ptb, VSim])
                if VSim < VSIM_THRESHOLD:
                    break
            ptb += ptbShift
            kk += 1

        # `shift` is the level that produced the last image; ptb has already been
        # advanced past it, so it must not be used here.
        pilImage.save(fileStem + '_dim=' + str(jj) + '_Vshift=' + str(shift)[:7] + 'stds_nPTB=' + str(n_ptb) + '.png')
        localReliabilityData.append([promptText, jj, t2iModel.tokenizer.decode([token]), shift, VSim])

# Persist every generated sample and the per-token reliability summary.
export_csv_results(filePath=localDataFile,
                   fileHeader=['prompt', 'token dim', 'ptb index', 'ptb magnitude', 'rand(ptb) no.', 'VSim'],
                   outputData=localOutputData)
export_csv_results(filePath=localReliabilityFile,
                   fileHeader=['prompt', 'token index', 'token', 'required ptb', 'VSIM'],
                   outputData=localReliabilityData)



Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()
  (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),


  0%|          | 0/101 [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


Token Index =  1
Shift =  0.025478556752204895


  (batch_size, t2iModel.unet.in_channels, height // 8, width // 8),


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

Shift =  0.05095711350440979


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

Shift =  0.0764356702566147


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

Shift =  0.10191422700881958


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

Shift =  0.12739278376102448


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

Token Index =  2
Shift =  0.025187584757804873


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

KeyboardInterrupt: 