In [None]:
%load_ext autoreload
%autoreload 2

# make sure you're logged in with `huggingface-cli login`
import torch
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
from contextlib import nullcontext

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# LMS scheduler configured with a "scaled_linear" beta schedule.
lms = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear"
)

# Load the "fp16" revision of Stable Diffusion v1.4 in half precision and move
# it to `device`.
# Authentication: `use_auth_token=True` makes the hub client use the token
# cached by `huggingface-cli login`, so no secret is stored in the notebook.
# NOTE(review): a literal access token was previously committed in this call;
# treat it as leaked and revoke it.
# NOTE(review): float16 is requested even when `device` is 'cpu' -- confirm
# half precision is supported on that path.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    scheduler=lms,
    revision='fp16',
    torch_dtype=torch.float16,
    use_auth_token=True
).to(device)

Downloading:   0%|          | 0.00/12.5k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/543 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.63k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/608M [00:00<?, ?B/s]

In [None]:
import functools 


class DummyInputIds:
    """Stand-in for tokenizer ``input_ids`` that carries an arbitrary payload.

    The wrapped object is stored untouched on the ``input_ids`` attribute.
    """

    def __init__(self, input_ids):
        self.input_ids = input_ids

    @property
    def shape(self):
        """Shape of the wrapped object with its final dimension dropped."""
        wrapped_shape = self.input_ids.shape
        return wrapped_shape[:-1]

    def to(self, device: str):
        """Return this same wrapper; ``device`` is accepted but ignored."""
        return self
        

def tokenizer_prefix_function(function):
    """Wrap a tokenizer callable so pre-computed tensors bypass tokenization.

    When the text argument (keyword ``text``, otherwise the first positional
    argument) is a non-empty list whose first element is a ``torch.Tensor``,
    that tensor is wrapped in a ``DummyInputIds`` and returned on the
    ``input_ids`` attribute of a small result object; ``function`` is not
    called and only the first list element is used.

    Every other call is forwarded to ``function`` unchanged. This includes an
    empty list, which is passed through rather than indexed.

    Args:
        function: The callable to wrap.

    Returns:
        A wrapper carrying ``function``'s metadata (via ``functools.wraps``).
    """

    class DummyReturnClass:
        # Minimal result object exposing only an `input_ids` attribute.
        def __init__(self, input_ids):
            self.input_ids = DummyInputIds(input_ids)

    def _is_tensor_list(value):
        # The truthiness check guards the `[0]` lookup against an empty list.
        return (
            isinstance(value, list)
            and bool(value)
            and isinstance(value[0], torch.Tensor)
        )

    @functools.wraps(function)
    def run(*args, **kwargs):
        # The keyword form is checked first, then the first positional argument.
        if 'text' in kwargs and _is_tensor_list(kwargs['text']):
            return DummyReturnClass(input_ids=kwargs['text'][0])
        if args and _is_tensor_list(args[0]):
            return DummyReturnClass(input_ids=args[0][0])
        # TODO: do for list of strings
        return function(*args, **kwargs)
    return run


def patch_call(instance, func):
    """Make calling ``instance`` invoke ``func`` instead.

    ``instance`` is re-classed in place to a fresh subclass of its current
    type whose ``__call__`` forwards the call arguments to ``func``. The
    instance itself is not passed along to ``func``.
    """
    current_type = type(instance)

    # `instance(...)` looks up `__call__` on the type, not the instance, so
    # the override has to live on a class rather than be set as an attribute.
    class _(current_type):
        def __call__(self, *call_args, **call_kwargs):
            # `self` is deliberately dropped: `func` gets only the call arguments.
            return func(*call_args, **call_kwargs)

    instance.__class__ = _

# Route direct calls on the tokenizer through the tensor-aware wrapper.
patch_call(pipe.tokenizer, tokenizer_prefix_function(pipe.tokenizer.__call__))

# Wrap the remaining tokenizer entry points the same way, as instance attributes.
for method_name in ("tokenize", "encode", "encode_plus", "batch_encode_plus"):
    setattr(
        pipe.tokenizer,
        method_name,
        tokenizer_prefix_function(getattr(pipe.tokenizer, method_name)),
    )


def text_encoder_prefix_function(function):
    """Wrap a text-encoder callable so ``DummyInputIds`` inputs skip encoding.

    The input is taken from the ``input_ids`` keyword when present, otherwise
    from the first positional argument. If it is a ``DummyInputIds``, its
    payload is returned inside a one-element list and ``function`` is not
    called. Any other call is forwarded to ``function`` unchanged, including
    one with no ``input_ids`` keyword and no positional arguments, so that
    ``function`` can handle or reject it itself.

    Args:
        function: The callable to wrap.

    Returns:
        A wrapper carrying ``function``'s metadata (via ``functools.wraps``).
    """
    @functools.wraps(function)
    def run(*args, **kwargs):
        if 'input_ids' in kwargs:
            candidate = kwargs['input_ids']
        elif args:
            candidate = args[0]
        else:
            # Nothing to inspect; avoid indexing an empty argument tuple.
            candidate = None

        if isinstance(candidate, DummyInputIds):
            # One-element list, so indexing the result with [0] gives the payload.
            return [candidate.input_ids]
        return function(*args, **kwargs)
    return run

# Route calls on the text encoder through the wrapper that returns the payload
# of a DummyInputIds input directly instead of encoding it.
wrapped_encoder_call = text_encoder_prefix_function(pipe.text_encoder.__call__)
patch_call(pipe.text_encoder, wrapped_encoder_call)

In [None]:
from captum.attr import (
    IntegratedGradients,
    Saliency,
    InputXGradient,
    DeepLift,
    DeepLiftShap,
    GuidedBackprop,
    GuidedGradCam,
    Deconvolution,
    LRP
)
from typing import Optional 

def forward(
    input_embeds: torch.Tensor,
    pipe: StableDiffusionPipeline
):
    """Run ``pipe`` for two inference steps on pre-computed prompt embeddings.

    ``input_embeds`` is handed to the pipeline as a one-element list in the
    prompt position, with ``output_type=None``.

    Returns:
        Whatever ``pipe`` returns for that call.
    """
    prompt_batch = [input_embeds]
    return pipe(prompt_batch, output_type=None, num_inference_steps=2)

# Tokenize the prompt, padded/truncated to the tokenizer's maximum length, then
# encode it; element 0 of the encoder output is kept as the text embeddings.
prompt = "a photo of an astronaut riding a horse on mars"
text_input = pipe.tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=pipe.tokenizer.model_max_length,
    padding="max_length",
)
text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]

# Attribution method built on `forward` with the pipeline argument pre-bound.
saliency = InputXGradient(functools.partial(forward, pipe=pipe))

In [None]:
# Display the shape of the encoded prompt embeddings.
text_embeddings.shape

In [None]:
# Run the pipeline on the pre-computed embeddings through `forward`, which
# passes them in the prompt position.
output = forward(text_embeddings, pipe)

In [None]:
# Display the shape of the generated sample stored under the 'sample' key.
output['sample'].shape

In [None]:
# Move the generated samples to host memory before converting: `.numpy()` only
# works on CPU tensors, and `.cpu()` is a no-op when the tensor is already there.
images = output['sample'].detach().cpu().numpy()
# Scale to 0-255 and convert to 8-bit values for display.
# assumes the sample values lie in [0, 1] -- TODO confirm
images = (images * 255).round().astype("uint8")

In [None]:
from PIL import Image
# Build a PIL image from the first entry of `images`; as the last expression of
# the cell it is rendered as the cell output.
Image.fromarray(images[0])

In [None]:
# Compute InputXGradient attributions with respect to the text embeddings.
# NOTE(review): `forward` returns the pipeline's output object (indexed with
# ['sample'] elsewhere in this notebook), not a bare tensor -- confirm that
# this is acceptable as a Captum forward function.
aux = saliency.attribute(text_embeddings)

In [None]:
prompt = "a photo of an astronaut riding a horse on mars"

# Use CUDA autocast only when running on the GPU; otherwise a do-nothing context.
precision_context = torch.autocast('cuda') if device == 'cuda' else nullcontext()
with precision_context:
    result = pipe(prompt, num_inference_steps=5)
    image = result["sample"][0]
    
#image.save("astronaut_rides_horse.png")

In [None]:
# Display the first generated sample from the text-prompt run.
image

In [None]:
# Display the pipeline object's repr to inspect its configuration.
pipe