## Setup

In [None]:
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer, CLIPFeatureExtractor, CLIPProcessor
from diffusers import AutoencoderKL, UNet2DConditionModel

from src import StableDiff

import numpy as np
from PIL import Image
from torch import autocast

from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from difflib import SequenceMatcher

# Setup PyTorch: use the first CUDA GPU when one is present, otherwise the CPU.
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

if device.type == "cuda":
    # Report the GPU name and the CUDA caching allocator's usage, in GiB.
    print(torch.cuda.get_device_name(0))
    print("Memory Usage:")
    for label, query in (("Allocated:", torch.cuda.memory_allocated),
                         ("Cached:   ", torch.cuda.memory_reserved)):
        print(label, round(query(0) / 1024 ** 3, 1), "GB")

In [None]:
# Re-import src.StableDiff so edits to it take effect without restarting the
# kernel; left as the cell's final expression so Jupyter echoes the module.
import importlib
importlib.reload(StableDiff)

In [None]:
# Init CLIP tokenizer and model
import os

model_path_clip = "openai/clip-vit-large-patch14"
# SECURITY: never commit a Hugging Face access token to source control. The
# token that was previously hard-coded on this line has been exposed and should
# be revoked at https://huggingface.co/settings/tokens.
# Supply a token through the HF_TOKEN environment variable instead. When it is
# unset, auth_token is None and the Hub client falls back to its own stored
# login (e.g. `huggingface-cli login`), if any, or to anonymous access.
# auth_token is also read by the diffusion-model cell below.
auth_token = os.environ.get("HF_TOKEN")
clip = StableDiff.CLIP(device,
                       clip_model=CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch.float16, use_auth_token=auth_token),
                       clip_processor=CLIPProcessor.from_pretrained(model_path_clip, use_auth_token=auth_token)
                       )
clip.load()

In [None]:
# Init diffusion model: load the UNet and VAE sub-models of Stable Diffusion
# v1.5 in float16 and hand them to the project's SimpleDiff wrapper.
model_path_diffusion = "runwayml/stable-diffusion-v1-5"

# Loader options shared by both sub-models (auth_token comes from the CLIP cell).
_fp16_hub_kwargs = {"use_auth_token": auth_token, "torch_dtype": torch.float16}

stable_diff = StableDiff.SimpleDiff(
    device,
    unet=UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", **_fp16_hub_kwargs),
    vae=AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", **_fp16_hub_kwargs),
)
stable_diff.load()

## Generate Image

In [None]:
# Text-to-image generation from a single prompt.
prompt = "Cartoon of a Student in Southampton"

# Embed an empty prompt and the real prompt with CLIP. generate() receives both,
# presumably for classifier-free guidance weighted by guidance_scale.
with autocast("cuda"):
    embedding_unconditional = clip.embed_text("")
    embedding_conditional = clip.embed_text(prompt)
    print(embedding_unconditional.shape, embedding_conditional.shape, sep="\n")

# Kept as the cell's final expression so Jupyter displays whatever generate() returns.
stable_diff.generate(
    embedding_unconditional=embedding_unconditional,
    embedding_conditional=embedding_conditional,
    tokens_length=clip.max_length,
    seed=1234,
    guidance_scale=8.0,
    steps=100,
)

## Image Variations

In [None]:
# Image variations: condition the diffusion model on a CLIP embedding of an
# existing image instead of a text prompt.

# Image.open is lazy and keeps the file handle open. Open it in a context
# manager and copy() (which loads the pixel data) so the handle is released.
with Image.open("polar_bear.png") as _source_image:
    init_image = _source_image.copy()

#Perform CLIP embeddings
with autocast("cuda"):
    embedding_unconditional = clip.embed_text("")
    embedding_image = clip.embed_image(init_image)

    # NOTE(review): generate() is given tokens_length=clip.max_length for this
    # image embedding as well. An earlier experiment projected the vision
    # model's last_hidden_state[:, :77] through visual_projection to match the
    # text embedding's length. Confirm that embed_image returns a tensor shaped
    # like embed_text's output (compare the two shapes printed below).
    print(embedding_unconditional.shape)
    print(embedding_image.shape)

# Kept as the cell's final expression so Jupyter displays whatever generate() returns.
stable_diff.generate(embedding_unconditional=embedding_unconditional,
                     embedding_conditional=embedding_image,
                     tokens_length=clip.max_length, seed=123, guidance_scale=6.0, steps=100)
