KittyDiffusion😻

It's so easy I can even explain it to my cat. (Oh, and it generates pretty cats.)


Results.


What is Stable Diffusion?

Stable Diffusion took the world of AI by storm.
It is an image generation model that takes a text prompt (for simplicity, this one doesn't)
and generates a high-quality image matching the prompt.
It sounds complex (and at full scale it is), but you can actually implement your own mini Stable Diffusion
in a couple of evenings! Let's see how it works.

How does Stable Diffusion work?

Roughly speaking, Stable Diffusion takes an image from the data (the kind of images we want our model to produce).
It adds noise to the image and then trains to remove that noise.
Yep, we destroy the image to train our model to reconstruct it, so that when we later feed the model pure noise, it turns it into a plausible image.
Let's look at it in detail.
The Stable Diffusion algorithm consists of 2 stages.

Forward

In the forward stage we add noise to the image. However, we need to train our network, not confuse it.
We add noise of different strengths to the images. For this purpose we have a noise scheduler, which controls how much noise to add and how much of the image to keep.
Explore `processes.py` to see how it works under the hood (this is a high-level overview, so for the finer details I strongly recommend watching this video).
Long story short, we randomly select a timestep t, which defines how strong the noise is, and add noise of this strength to our image.
Note: the `sample` function runs the entire diffusion loop to generate an image.
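
To make the scheduler idea concrete, here is a minimal sketch of what the forward noising could look like. The linear beta schedule, the number of steps `T`, and the shape handling are assumptions for illustration, not necessarily what `processes.py` does; only the idea that `add_noise` returns both the noisy image and the noise comes from this README.

```python
import torch

T = 1000  # number of diffusion timesteps (assumed value)

# DDPM-style linear beta schedule (assumption): how much noise each step adds
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # running product of alphas

def add_noise(x0, t):
    """Jump straight to noise level t via the closed form q(x_t | x_0).

    x0: clean images, shape (B, C, H, W)
    t:  integer timesteps, shape (B,)
    Returns the noisy images and the noise that was added (the MSE target).
    """
    noise = torch.randn_like(x0)
    ab = alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise
    return x_t, noise
```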

Backward

In the backward process we take a destroyed image and a timestep t and pass them to the model to denoise the image (we pass t to add context; without it, we confuse the model).
The model predicts the noise (an image of the same size as the input, so that we can subtract it - and that's exactly what we do).
We gradually subtract the noise from the image (doing it in one step produces low-quality results).
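
Here is a minimal sketch of that gradual denoising loop, reusing the schedule from the sketch above; the function name `sample`, the image shape, and the call signature `model(x, t)` are assumptions for illustration:

```python
@torch.no_grad()
def sample(model, shape=(1, 3, 64, 64), device="cpu"):
    """Start from pure noise and denoise step by step, t = T-1 ... 0."""
    x = torch.randn(shape, device=device)
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        pred_noise = model(x, t_batch)  # model predicts the noise that was added
        alpha, alpha_bar = alphas[t], alphas_cumprod[t]
        # remove a scaled fraction of the predicted noise (DDPM posterior mean)
        x = (x - (1 - alpha) / (1 - alpha_bar).sqrt() * pred_noise) / alpha.sqrt()
        if t > 0:
            # add a little noise back on every step except the last one
            x = x + betas[t].sqrt() * torch.randn_like(x)
    return x
```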
Sounds like a charm, but how do we calculate the loss and what model do we use? You're goddamn right to question this.
For the loss we use plain MSE, which compares the predicted noise (the model's output) with the noise we actually added (the `add_noise` function in `processes.py` returns both the noisy image and the noise).
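
Put together, a single training step might look like the sketch below; the helper names come from the sketches above, and the model and optimizer are whatever you pass in:

```python
import torch.nn.functional as F

def train_step(model, optimizer, x0):
    """One step: noise a clean batch, then regress the predicted noise onto the true noise."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)  # a random timestep per image
    x_t, noise = add_noise(x0, t)
    pred_noise = model(x_t, t)
    loss = F.mse_loss(pred_noise, noise)  # plain MSE between predicted and added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```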
A U-Net serves as the model that predicts the noise. It is an architecture from 2015, originally built for image segmentation.
If you're familiar with autoencoders (AE), this one will remind you of them a lot. On top of the AE-like architecture it uses residual skip-connections and self-attention (the original U-Net didn't, but in DDPM we do).
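
For intuition, here is a toy U-Net-shaped network. It is deliberately stripped down: the timestep conditioning and the self-attention blocks mentioned above are omitted, and all layer sizes are made up, so treat it as a shape diagram rather than the repo's actual model:

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Toy U-Net: downsample, bottleneck, upsample, with a skip connection.

    The real DDPM model also conditions on the timestep t and uses self-attention;
    both are left out here to keep the sketch short.
    """
    def __init__(self, ch=3, base=64):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(ch, base, 3, padding=1), nn.ReLU(),
                                   nn.Conv2d(base, base, 3, padding=1), nn.ReLU())
        self.down2 = nn.Sequential(nn.Conv2d(base, base * 2, 3, stride=2, padding=1), nn.ReLU())
        self.mid = nn.Sequential(nn.Conv2d(base * 2, base * 2, 3, padding=1), nn.ReLU())
        self.up1 = nn.ConvTranspose2d(base * 2, base, 4, stride=2, padding=1)
        self.out = nn.Conv2d(base * 2, ch, 3, padding=1)  # base*2 channels because of the skip concat

    def forward(self, x, t=None):              # t is ignored in this stripped-down sketch
        d1 = self.down1(x)                      # (B, base, H, W)
        d2 = self.down2(d1)                     # (B, 2*base, H/2, W/2)
        m = self.mid(d2)
        u1 = self.up1(m)                        # back up to (B, base, H, W)
        u1 = torch.cat([u1, d1], dim=1)         # skip connection from the encoder
        return self.out(u1)                     # predicted noise, same shape as the input
```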


No one is probably reading this anyway, and ngl I'm too lazy to explain everything here, so if I ever explain DDPM (Denoising Diffusion Probabilistic Models) properly, I'll do it in video format.
There's a lot left to explore: denoising latents (adding a VAE to the pipeline), adding guidance, adding CLIP text embeddings to follow the prompt.
This implementation is as simple as possible.
I left out all the complex stuff on purpose (+ ngl, I'm a lazy ass).

Developed by Cheslaff with love❤️.
