<a href="https://colab.research.google.com/github/Sherlock-221BBS/Stable_Diffusion_From_Scratch/blob/main/coding_stable_diffusion_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import math



## Attention

In [None]:


class SelfAttention(nn.Module):
  """Multi-head scaled dot-product self-attention.

  Args:
    n_heads: number of attention heads; must divide ``d_embed``.
    d_embed: size of the last (embedding) dimension of the input.
    in_proj_bias: whether the fused query/key/value projection has a bias.
    out_proj_bias: whether the output projection has a bias.

  Raises:
    ValueError: if ``d_embed`` is not divisible by ``n_heads``.
  """

  def __init__(self, n_heads, d_embed, in_proj_bias = True, out_proj_bias = True):
    super().__init__()
    # Fail at construction time instead of with an opaque reshape error in forward().
    if d_embed % n_heads != 0:
      raise ValueError(
          f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
      )
    # A single linear layer produces q, k and v side by side (hence 3 * d_embed).
    self.in_proj = nn.Linear(in_features = d_embed, out_features = d_embed * 3, bias = in_proj_bias)
    self.out_proj = nn.Linear(in_features = d_embed, out_features = d_embed, bias = out_proj_bias)
    self.n_heads = n_heads
    self.d_head = d_embed // n_heads

  def forward(self, x, causal_mask = False):
    """Apply self-attention to ``x``.

    Args:
      x: tensor of shape (batch, seq_len, d_embed).
      causal_mask: when True, position ``i`` may only attend to positions ``<= i``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    input_shape = x.shape
    batch, seq_len, _ = input_shape

    q, k, v = self.in_proj(x).chunk(3, dim = -1)

    # (batch, seq_len, d_embed) -> (batch, n_heads, seq_len, d_head)
    interim_shape = (batch, seq_len, self.n_heads, self.d_head)
    q = q.reshape(interim_shape).transpose(1, 2)
    k = k.reshape(interim_shape).transpose(1, 2)
    v = v.reshape(interim_shape).transpose(1, 2)

    # (batch, n_heads, seq_len, seq_len) raw attention scores.
    weights = q @ k.transpose(-1, -2)

    if causal_mask:
      # Strict upper triangle marks "future" positions; the (seq_len, seq_len)
      # mask broadcasts over the batch and head dimensions.
      mask = torch.ones(seq_len, seq_len, dtype = torch.bool, device = weights.device).triu(1)
      weights = weights.masked_fill(mask, float("-inf"))

    weights = weights / math.sqrt(self.d_head)
    weights = F.softmax(weights, dim = -1)

    # (batch, n_heads, seq_len, d_head) -> (batch, seq_len, d_embed)
    outputs = (weights @ v).transpose(1, 2).reshape(input_shape)
    return self.out_proj(outputs)





# VAE Helper Classes

In [None]:
class VAE_AttentionBlock(nn.Module):
  """Single-head self-attention over the spatial positions of a feature map.

  Every pixel of the (n, c, h, w) input is treated as one sequence element
  with ``c`` features; the attended result is added back to the input.

  Args:
    channels: number of feature-map channels (GroupNorm uses 32 groups, so
      this must be a multiple of 32).
  """

  def __init__(self, channels):
    super().__init__()
    self.groupnorm = nn.GroupNorm(32, channels)
    self.attention = SelfAttention(1, channels)

  def forward(self, x):
    """Return ``x`` plus self-attention over its h*w positions (same shape)."""
    residue = x
    x = self.groupnorm(x)
    n, c, h, w = x.shape

    # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c): one token per pixel.
    x = x.reshape(n, c, h * w).transpose(-1, -2)
    x = self.attention(x)

    # Undo the transpose BEFORE restoring the spatial layout; reshaping the
    # (n, h*w, c) tensor directly to (n, c, h, w) would mix channels and pixels.
    x = x.transpose(-1, -2).reshape(n, c, h, w)
    return x + residue

class VAE_ResidualBlock(nn.Module):
  """Two (GroupNorm -> SiLU -> 3x3 conv) stages with a skip connection.

  Args:
    in_channels: channels of the input feature map (multiple of 32).
    out_channels: channels of the output feature map (multiple of 32).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1)
    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
    if in_channels == out_channels:
      # Shapes already match, so the skip path is a no-op.
      self.residual_layer = nn.Identity()
    else:
      # Project the skip path to out_channels so it can be added to the output.
      # NOTE(review): a 1x1 convolution is the more common choice for this
      # projection; confirm 3x3 is intended before loading pretrained weights.
      self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size = 3,padding = 1)

  def forward(self, x):
    """Map (n, in_channels, h, w) to (n, out_channels, h, w)."""
    residue = x
    x = self.groupnorm_1(x)
    x = F.silu(x)
    x = self.conv_1(x)
    x = self.groupnorm_2(x)
    x = F.silu(x)
    x = self.conv_2(x)
    return x + self.residual_layer(residue)


# Encoder of VAE

In [None]:



class VAE_Encoder(nn.Sequential):
  """Convolutional VAE encoder: image -> sampled, scaled latent.

  The layer stack downsamples by 2 three times (8x overall) and ends with an
  8-channel map that is split into a 4-channel mean and a 4-channel
  log-variance.
  """

  # Latent scaling constant; the decoder in this file divides by the same
  # quantity, so the encoder has to multiply by it for the pair to round-trip.
  LATENT_SCALE = 0.18215

  def __init__(self):
    super().__init__(
        nn.Conv2d(in_channels = 3,
                  out_channels = 128,
                  kernel_size = 3,
                  padding = 1),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),
        # Stride-2 convs use padding = 0; forward() pads right/bottom by hand.
        nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size = 3, stride = 2, padding = 0),
        VAE_ResidualBlock(128, 256),
        VAE_ResidualBlock(256, 256),
        nn.Conv2d(in_channels = 256, out_channels = 256, kernel_size = 3, stride = 2, padding = 0),
        VAE_ResidualBlock(256, 512),
        VAE_ResidualBlock(512, 512),
        nn.Conv2d(in_channels = 512, out_channels =512, kernel_size = 3, stride = 2, padding =0),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_AttentionBlock(512),
        VAE_ResidualBlock(512, 512),
        nn.GroupNorm(32, 512),
        nn.SiLU(),
        nn.Conv2d(in_channels = 512, out_channels = 8, kernel_size = 3, padding =1)
    )


  def forward(self, x, noise):
    """Encode ``x`` and draw one latent sample using ``noise``.

    Args:
      x: image batch of shape (n, 3, h, w).
      noise: tensor broadcastable to the latent shape (n, 4, h/8, w/8) for
        even h and w at every stage; used as the reparameterisation noise.

    Returns:
      ``(mean + noise * std_dev) * LATENT_SCALE``.
    """
    for module in self:
      if getattr(module, "stride", None) == (2, 2):
        # Asymmetric padding (right and bottom only) so a 3x3 / stride-2 conv
        # exactly halves an even spatial size.
        x = F.pad(x, (0, 1, 0, 1))

      x = module(x)

    mean, log_variance = x.chunk(2, dim = 1)
    # Clamp so exp() stays in a numerically safe range.
    log_variance = torch.clamp(log_variance, min = -30, max = 20)
    variance = log_variance.exp()
    std_dev = variance.sqrt()
    x = mean + noise * std_dev
    # Scale the sample into the latent range the decoder expects.
    x = x * self.LATENT_SCALE
    return x




In [None]:
# Smoke test: push one random 512x512 RGB "image" through the encoder.
vae_encoder = VAE_Encoder()
# NOTE(review): `input` shadows the Python builtin of the same name.
input = torch.randn(1, 3, 512, 512)
# Reparameterisation noise; its shape matches the latent (4 channels, 512 / 8 = 64).
noise = torch.randn(1, 4, 64, 64)
encoder_outputs = vae_encoder(input, noise)
# Last expression of the cell, so the notebook displays the latent shape.
encoder_outputs.shape

successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed


torch.Size([1, 4, 64, 64])

# Decoder of VAE

In [None]:
class VAE_Decoder(nn.Sequential):
  """Convolutional VAE decoder: scaled latent -> image.

  The layer stack upsamples by 2 three times (8x overall) and ends with a
  3-channel convolution.
  """

  # Latent scaling constant that is undone before decoding (0.18215; the
  # previous literal 0.18125 had two digits transposed).
  LATENT_SCALE = 0.18215

  def __init__(self):
    super().__init__(
        nn.Conv2d(4, 4, kernel_size = 1, padding = 0),
        nn.Conv2d(4, 512, kernel_size = 3, padding = 1),
        VAE_ResidualBlock(512, 512),
        VAE_AttentionBlock(512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        nn.Upsample(scale_factor = 2),
        nn.Conv2d(512, 512, kernel_size = 3, padding = 1),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        VAE_ResidualBlock(512, 512),
        nn.Upsample(scale_factor = 2),
        nn.Conv2d(512, 512, kernel_size = 3, padding =1),
        VAE_ResidualBlock(512, 256),
        VAE_ResidualBlock(256, 256),
        VAE_ResidualBlock(256, 256),
        nn.Upsample(scale_factor = 2),
        nn.Conv2d(256, 256, kernel_size = 3, padding = 1),
        VAE_ResidualBlock(256, 128),
        VAE_ResidualBlock(128, 128),
        VAE_ResidualBlock(128, 128),
        nn.GroupNorm(32, 128),
        nn.SiLU(),
        nn.Conv2d(128, 3, kernel_size = 3, padding = 1)

    )

  def forward(self, x):
    """Decode a latent batch of shape (n, 4, h, w) into (n, 3, 8*h, 8*w).

    The input tensor is left untouched: the rescaling is done out of place so
    the caller's latents are not silently modified.
    """
    x = x / self.LATENT_SCALE
    for module in self:
      x = module(x)

    return x


In [None]:
# Smoke test: decode the latents produced by the encoder cell above.
vae_decoder = VAE_Decoder()
decoder_output = vae_decoder(encoder_outputs)
# Last expression of the cell, so the notebook displays the decoded shape
# (three 2x upsamples and a final 3-channel conv: expected (1, 3, 512, 512)).
decoder_output.shape

successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed
successfully passed


In [None]:
class CLIPEmbedding(nn.Module):
  """Token embedding plus a learned per-position embedding.

  Args:
    n_vocab: number of rows in the token embedding table.
    n_embd: embedding width.
    n_tokens: sequence length covered by the positional table.
  """

  def __init__(self, n_vocab, n_embd, n_tokens):
    super().__init__()
    # Token table first, positional table second (keeps RNG consumption order).
    self.embedding = nn.Embedding(n_vocab, n_embd)
    position_init = torch.randn(n_tokens, n_embd)
    self.positional_embedding = nn.Parameter(position_init)

  def forward(self, x):
    """Look up token ids ``x`` and add the positional table (broadcast over batch)."""
    token_vectors = self.embedding(x)
    return token_vectors + self.positional_embedding


class CLIPLayer(nn.Module):
  """Pre-norm transformer block: self-attention, then a 4x-wide feed-forward.

  Args:
    n_head: number of attention heads.
    n_embd: embedding width of the sequence.
  """

  def __init__(self, n_head, n_embd):
    super().__init__()
    self.layernorm_1 = nn.LayerNorm(n_embd)
    self.attention = SelfAttention(n_head, n_embd)

    self.layernorm_2 = nn.LayerNorm(n_embd)

    self.linear_1 = nn.Linear(n_embd, 4 * n_embd)
    self.linear_2 = nn.Linear(4 * n_embd, n_embd)

  def forward(self, x):
    """Map (batch, seq_len, n_embd) to a tensor of the same shape."""
    # Attention sub-block with residual connection.
    residue = x
    x = self.layernorm_1(x)
    # NOTE(review): attention is called without causal_mask here; CLIP's text
    # transformer is usually causal -- confirm whether masking is intended.
    x = self.attention(x)
    x = x + residue

    # Feed-forward sub-block with residual connection.
    residue = x
    x = self.layernorm_2(x)
    x = self.linear_1(x)
    # Sigmoid-based GELU approximation ("QuickGELU").
    x = x * torch.sigmoid(1.702 * x)
    x = self.linear_2(x)

    return x + residue

class CLIP(nn.Module):
  """Text encoder: embedding, a stack of CLIPLayer blocks, final LayerNorm.

  All arguments are optional; the defaults reproduce the previously
  hard-coded configuration.

  Args:
    n_vocab: size of the token vocabulary.
    n_embd: embedding width.
    n_tokens: sequence length expected by the positional embedding.
    n_head: attention heads per layer.
    n_layers: number of CLIPLayer blocks.
  """

  # NOTE(review): the default vocabulary size is 49708; the published CLIP
  # tokenizer is usually quoted as 49408 entries -- confirm before loading
  # pretrained weights.
  def __init__(self, n_vocab = 49708, n_embd = 768, n_tokens = 77, n_head = 12, n_layers = 12):
    super().__init__()
    self.embedding = CLIPEmbedding(n_vocab, n_embd, n_tokens)
    self.layers = nn.ModuleList([
        CLIPLayer(n_head, n_embd) for _ in range(n_layers)
    ])
    self.layernorm = nn.LayerNorm(n_embd)


  def forward(self, tokens):
    """Encode token ids of shape (batch, n_tokens) into (batch, n_tokens, n_embd)."""
    # nn.Embedding needs integer indices.
    tokens = tokens.type(torch.long)
    state = self.embedding(tokens)

    for layer in self.layers:
      state = layer(state)

    output = self.layernorm(state)

    return output



In [None]:
# Smoke test: one sequence of 77 random token ids drawn from [1, 49708).
tokens = torch.randint(1, 49708, (1, 77))
clip = CLIP()
clip_output = clip(tokens)
# Last expression of the cell, so the notebook displays the output shape.
clip_output.shape

embedding layer passed
first normalization layer passed


RuntimeError: ignored