In [None]:
# Environment setup (Linux + NVIDIA GPU):
# - install cuDNN, the cuDNN Python frontend, CuPy and Pillow;
# - then do editable installs of tinygrad (pinned to v0.8.0) and tinyfusers.
# Each clone is skipped when its target directory already exists, so the cell can be re-run.
get_ipython().system('nvidia-smi')
get_ipython().system('python -m pip install nvidia-cudnn-cu12')
# CUDNN_PATH is set for the cudnn-frontend install only; it points at the
# nvidia/cudnn directory inside the site-packages location reported for the
# nvidia-cudnn-cu12 wheel installed just above.
get_ipython().system('CUDNN_PATH=`python -m pip show nvidia-cudnn-cu12  | grep Location | cut -d":" -f2 | xargs`/nvidia/cudnn python -m pip install git+https://github.com/NVIDIA/cudnn-frontend.git')
get_ipython().system('python -m pip install cupy-cuda12x pillow')
get_ipython().system('[ -d tinygrad ] || git clone --depth 1 --branch v0.8.0 https://github.com/tinygrad/tinygrad')
# Clone into ./tinyfusers rather than ".". The cwd is not empty here (tinygrad
# was just cloned into it), so git would refuse to clone into ".". The cells
# below also expect a ./tinyfusers directory.
get_ipython().system('[ -d tinyfusers ] || git clone https://github.com/Fatlonder/tinyfusers.git tinyfusers')

get_ipython().run_line_magic('cd', 'tinygrad')
# -f: a fresh clone has no build/ directory; don't emit an error for that.
get_ipython().system('rm -rf build')
get_ipython().system('python -m pip install -e .')
get_ipython().run_line_magic('cd', '..')

get_ipython().run_line_magic('cd', 'tinyfusers')
get_ipython().system('rm -rf build')
get_ipython().system('python -m  pip install -e .')
get_ipython().run_line_magic('cd', '..')

In [1]:
import os
import sys

# Make the cloned tinyfusers checkout importable from the notebook's working
# directory. The path is resolved to an absolute path once, here. The guard
# keeps re-runs of this cell from stacking duplicate sys.path entries.
_TINYFUSERS_DIR = os.path.abspath('tinyfusers')
if _TINYFUSERS_DIR not in sys.path:
  sys.path.insert(0, _TINYFUSERS_DIR)

import tinyfusers
import importlib
# NOTE: importlib.reload re-executes only the top-level package (its
# __init__). Submodules imported below (tinyfusers.variants.sd,
# tinyfusers.tokenizer.clip, tinyfusers.storage.state) are not refreshed by it.
importlib.reload(tinyfusers)

from IPython.display import display
from tqdm import tqdm
from PIL import Image
import numpy as np
import cupy as cp
from tinygrad import GlobalCounters, dtypes, Tensor
from tinygrad.nn.state import torch_load, get_state_dict
from tinygrad.helpers import Timing, Context, getenv, fetch
from tinyfusers.variants.sd import StableDiffusion
from tinyfusers.tokenizer.clip import ClipTokenizer
from tinyfusers.storage.state import update_state
import gc
import ctypes

# Handles used after the sampling loop to release memory:
# - CuPy's default device memory pool;
# - CuPy's default pinned (host) memory pool;
# - the C library, for malloc_trim.
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
libc = ctypes.CDLL("libc.so.6")  # glibc soname: Linux-only

def model_to_fp16(model):
  """Cast every tensor in ``model``'s tinygrad state dict to float16.

  Each tensor is replaced with a realized ``dtypes.float16`` cast of itself.
  The model object is modified in place; nothing is returned.
  """
  state = get_state_dict(model)
  for tensor in state.values():
    half = tensor.cast(dtypes.float16).realize()
    tensor.replace(half)

In [2]:
default_prompt = "a horse sized cat eating a bagel"
# Generation settings consumed by later cells.
args = dict(
  prompt=default_prompt,
  steps=20,
  fp16=True,
  out="rendered.png",
  noshow=False,
  timing=False,
  guidance=7.5,
  seed=42,
)

# Inference only: disable tinygrad gradient tracking before the model is built.
Tensor.no_grad = True
model = StableDiffusion()

# Download the original Stable Diffusion v1.4 checkpoint (stored as
# sd-v1-4.ckpt) and copy its 'state_dict' entries into the model.
ckpt_path = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
state_dictionary = torch_load(ckpt_path)
update_state(model, state_dictionary['state_dict'])

In [None]:
# Encode the prompt and an empty prompt with the CLIP text model. Both contexts
# are passed to the model at every step together with args['guidance']
# (presumably classifier-free guidance — TODO confirm in tinyfusers).
tokenizer = ClipTokenizer()
prompt = cp.asarray([tokenizer.encode(args['prompt'])])
context = model.cond_stage_model.transformer.text_model(prompt)
print("got CLIP context", context.shape)

prompt = cp.asarray([tokenizer.encode("")])
unconditional_context = model.cond_stage_model.transformer.text_model(prompt)
print("got unconditional CLIP context", unconditional_context.shape)

# The text encoder is not used by the loop below; drop the model's reference
# to it so its weights can be reclaimed.
# NOTE: after this line the cell cannot be re-run. Re-running raises
# AttributeError until the model is rebuilt in the previous cell.
del model.cond_stage_model

# Timesteps start at 1 and are spaced 1000//steps apart, below 1000. For the
# default 20 steps that is 1, 51, ..., 951 (20 values). For steps that do not
# divide 1000 the count can differ from args['steps'].
timesteps = list(range(1, 1000, 1000//args['steps']))
print(f"running for {timesteps} timesteps")
alphas = model.alphas_cumprod[timesteps]
# alphas_prev[i] is alphas[i-1]; the first entry is 1.0.
alphas_prev = cp.concatenate((cp.array([1.0]), alphas[:-1]))

if args['seed'] is not None: cp.random.seed(seed=args['seed'])
# NOTE(review): cp.random.randn returns float64 by default. Confirm the model
# expects a float64 latent rather than float32/float16.
latent = cp.random.randn(1,4,64,64)

with Context(BEAM=getenv("LATEBEAM")):
  # Iterate from the last timestep back to the first.
  for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
    GlobalCounters.reset()
    t.set_description(f"({index}, {timestep})")
    # GlobalCounters.mem_used reflects tinygrad's counters. Memory allocated
    # directly through CuPy is presumably not included — TODO confirm.
    with Timing("step in ", enabled=args['timing'], on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
      tid = cp.asarray([index])
      latent = model(unconditional_context, context, latent, cp.asarray([timestep]), alphas[tid], alphas_prev[tid], cp.asarray([args['guidance']]))
  # After sampling, release cached memory:
  # - collect garbage;
  # - return CuPy's cached device and pinned blocks;
  # - ask glibc to trim its heap.
  gc.collect()
  mempool.free_all_blocks()
  pinned_mempool.free_all_blocks()
  # malloc_trim returns 1 if it released memory to the OS and 0 otherwise.
  # The value is a flag, not an amount.
  print(f"malloc_trim released memory: {bool(libc.malloc_trim(0))}")

# NOTE(review): assumes model.decode returns an object with .numpy() holding
# 0-255 image values in (H, W, C) layout — confirm in tinyfusers.
x = model.decode(latent)
im = Image.fromarray(x.numpy().astype(np.uint8, copy=False))

# Honour the output options declared in args.
if args['out']:
  im.save(args['out'])
  print(f"saved {args['out']}")
if not args['noshow']:
  display(im)