In [None]:
get_ipython().system('nvidia-smi')
get_ipython().system('python -m pip install nvidia-cudnn-cu12')
get_ipython().system('CUDNN_PATH=`python -m pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn python -m pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
!python -m pip install cupy-cuda12x==12.3.0 pillow
!git clone --depth 1 --branch v0.8.0 https://github.com/tinygrad/tinygrad
!git clone https://github.com/Fatlonder/tinyfusers.git .

%cd tinygrad
!rm -r build
!python -m pip install -e .
%cd ..

%cd tinyfusers
!rm -r build
!python -m  pip install -e .
%cd ..

In [1]:
import sys
sys.path.insert(0, 'tinyfusers')

import tinyfusers
import importlib
importlib.reload(tinyfusers)
from IPython.display import display
from PIL import Image
from tqdm import tqdm
import numpy as np
import cupy as cp
from tinygrad import GlobalCounters, Tensor
from tinygrad.nn.state import torch_load
from tinygrad.helpers import Timing, Context, getenv, fetch
from tinyfusers.variants.sd import StableDiffusion
from tinyfusers.tokenizer.clip import ClipTokenizer
from tinyfusers.storage.state import update_state
from tinyfusers.storage.unpicker import load_weights
import gc
import ctypes

# Handle to the C runtime, loaded by its Linux soname ("libc.so.6").
# Not referenced by the other cells visible here.
libc = ctypes.CDLL("libc.so.6")
# CuPy's default device-memory pool and pinned-host-memory pool, kept as globals.
# Presumably intended for releasing cached blocks between stages — TODO confirm;
# neither is referenced by the other cells visible here.
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

def model_to_fp16(model):
  """Cast every tensor in ``model``'s state dict to float16, in place.

  Each entry returned by tinygrad's ``get_state_dict(model)`` is replaced by a
  realized float16 copy of itself via ``Tensor.replace``. Returns ``None``.

  NOTE(review): this assumes the model's parameters are tinygrad Tensors
  (objects with ``cast``/``realize``/``replace``) — confirm before using it on
  a model whose weights are stored as CuPy arrays.
  """
  # Imported locally: the notebook's import cell brings in neither name, so the
  # original body raised NameError on first call.
  from tinygrad import dtypes
  from tinygrad.nn.state import get_state_dict

  for tensor in get_state_dict(model).values():
    tensor.replace(tensor.cast(dtypes.float16).realize())

In [2]:
# Generation settings read by the sampling cell.
default_prompt = "a horse sized cat eating a bagel"
args = dict(
  prompt=default_prompt,
  steps=20,
  fp16=True,
  out="rendered.png",
  noshow=False,
  timing=False,
  guidance=7.5,
  seed=42,
)

# Build the pipeline, then fetch the SD v1.4 checkpoint under the name
# 'sd-v1-4.ckpt' and copy its 'state_dict' entry into the model.
model = StableDiffusion()
state_dictionary = torch_load(
  fetch(
    'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',
    'sd-v1-4.ckpt',
  )
)
update_state(model, state_dictionary['state_dict'])

In [None]:
# --- Text conditioning ---
# Run the prompt and the empty prompt through CLIP. Both encodings are passed
# to the model on every step together with the guidance scale (presumably for
# classifier-free guidance inside model(...) — confirm in tinyfusers).
tokenizer = ClipTokenizer()
prompt = cp.array([tokenizer.encode(args['prompt'])])
context = model.cond_stage_model.transformer.text_model(prompt)

prompt = cp.array([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt)

print(f"CLIP context: {context.shape}, unconditional CLIP context: {unconditional_context.shape}")
# The text encoder is not used again below. NOTE: after this delete the cell
# cannot be re-run until the model-loading cell has been run again.
del model.cond_stage_model

# --- Schedule ---
# Evenly strided timesteps starting at 1 (stride = 1000 // steps).
timesteps = list(range(1, 1000, 1000//args['steps']))
print(f"running for {timesteps} timesteps")
alphas = model.alphas_cumprod[timesteps]
# alphas_prev[i] is the alpha of the previous schedule entry; the first entry gets 1.0.
alphas_prev = cp.concatenate((cp.array([1.0]), alphas[:-1])).astype(cp.float32)

# --- Initial latent ---
cur_stream = cp.cuda.get_current_stream()
cur_stream.use()
# Seed tinygrad's RNG so the starting noise is reproducible.
if args['seed'] is not None: Tensor._seed = args['seed']
latent = Tensor.randn(1,4,64,64)
# The noise is drawn with tinygrad and then copied into a CuPy array via NumPy.
latent = cp.asarray(latent.numpy())
cur_stream.synchronize()
cp.cuda.Device().synchronize()

# --- Sampling loop ---
with Context(BEAM=getenv("LATEBEAM")):
  # Walk the schedule from the last timestep back to the first.
  for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
    GlobalCounters.reset()
    t.set_description("%3d %3d" % (index, timestep))
    with Timing("step in ", enabled=args['timing'], on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
      tid = cp.array([index])
      latent = model(unconditional_context, context, latent, cp.array([timestep]), alphas[tid], alphas_prev[tid], cp.array([args['guidance']]))

# --- Decode, save, show ---
x = model.decode(latent)
print(x.shape)
# Assumes model.decode returns an image-shaped array with values already in
# the 0..255 range — TODO confirm against tinyfusers.
im = Image.fromarray(cp.asnumpy(x).astype(np.uint8, copy=False))
print(f"saving {args['out']}")
im.save(args['out'])
# Honour the 'noshow' setting, which was previously declared but ignored.
if not args['noshow']:
  display(im)