In [None]:
!pip install -qq diffusers transformers pytorch_lightning

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from diffusers import UNet2DModel


class ConditionalUNet(nn.Module):
    def __init__(self, in_channels: int = 3) -> None:
        super(ConditionalUNet, self).__init__()
        self.in_channels = in_channels
        # self.num_reference_images = num_reference_images

        self.model = UNet2DModel(
            sample_size=28,           # the target image resolution
            in_channels=self.in_channels,  # Additional input channels for class cond.
            out_channels=self.in_channels,           # the number of output channels
            layers_per_block=2,       # how many ResNet layers to use per UNet block
            block_out_channels=(32, 64, 64), 
            down_block_types=( 
                "DownBlock2D",        # a regular ResNet downsampling block
                "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
                "AttnDownBlock2D",
            ), 
            up_block_types=(
                "AttnUpBlock2D", 
                "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",          # a regular ResNet upsampling block
            ),
        )

    def forward(self, noisy_latents: torch.Tensor, timestep: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
        return self.model.forward(sample=noisy_latents, timestep=timestep, class_labels=text_embeddings)


In [5]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler
from PIL import Image
from pytorch_lightning.loggers.wandb import WandbLogger
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizerFast


class StableDiffusionLightningModule(pl.LightningModule):
    def __init__(
            self, 
            in_channels: int, 
            num_train_timsteps: int,
            num_inference_timesteps: int,
            beta_start: float, 
            beta_end: float, 
            beta_schedule: str, 
            device: torch.device,
            max_length: int
        ):
        super(StableDiffusionLightningModule, self).__init__()
        self.unet = ConditionalUNet(in_channels=in_channels)
        self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
        self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        self.noise_scheduler = LMSDiscreteScheduler(
            num_train_timesteps=num_train_timsteps, 
            beta_start=beta_start, 
            beta_end=beta_end, 
            beta_schedule=beta_schedule, 
        )

        # self.unet.to(device=device)
        # self.vae.to(device=device)
        # self.tokenizer.to(device=device)
        # self.text_encoder.to(device=device)
        # self.noise_scheduler.to(device=device)

        self.in_channels = in_channels
        self.num_train_timesteps = num_train_timsteps
        self.num_inference_timesteps = num_inference_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.max_length = max_length
        self.device = device
    
    @staticmethod
    def _compute_loss(noisy_predictions: torch.Tensor, actual_noise: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(input=noisy_predictions, target=actual_noise)
    
    def forward(self, noisy_latents: torch.Tensor, timesteps: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
        out = self.unet.forward(noisy_latents=noisy_latents, timestep=timesteps, text_embeddings=text_embeddings)
        return out.sample

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        image = batch['image']
        input_ids  = batch['input_ids']
        attention_masks = batch['attention_mask']

        with torch.no_grad():
            latents = self.vae.encode(image).latent_dist.sample()
            latents = latents * 0.18215
            text_embeddings = self.text_encoder.forward(input_ids=input_ids, attention_mask=attention_masks)

        noise   = torch.randn_like(latents)
        batch_size = latents.shape[0]
        timesteps = torch.randint(0, self.num_train_timesteps - 1, size=(batch_size, ))

        noisy_latents = self.noise_scheduler.add_noise(original_samples=latents, noise=noise, timesteps=timesteps)

        noisy_prediction = self.forward(noisy_latents=noisy_latents, timesteps=timesteps, text_embeddings=text_embeddings)
        loss = self._compute_loss(noisy_predictions=noisy_prediction, actual_noise=noise)

        self.log("train/loss", loss.item(), on_epoch=True, on_step=True, prog_bar=True, sync_dist=True)

        return loss
    
    @torch.no_grad()
    def inference(self, text: str, guidance_scale: float = 7.5, height: int = 512, width: int = 512, seed: int = 1337):
        generator = torch.manual_seed(seed=seed)
        self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
        cond_batch = self.tokenizer(
            text, 
            return_tensors="pt",
            max_length=self.max_length, 
            padding="max_length", 
            truncation=True
        )
        uncond_batch = self.tokenizer(
            "", 
            return_tensors="pt",
            max_length=self.max_length, 
            padding="max_length",
            truncation=True
        )

        conditioned_embed = self.text_encoder.forward(input_ids=cond_batch['input_ids'], attention_mask=cond_batch['attention_mask'])
        unconditioned_embed = self.text_encoder.forward(input_ids=uncond_batch['input_ids'], attention_mask=uncond_batch['attention_mask'])

        embeddings = torch.cat([conditioned_embed, unconditioned_embed])

        latents = torch.randn(size=(2, self.unet.model.in_channels, height / 8, width / 8), generator=generator).to(device=self.device)
        latents = latents * self.noise_scheduler.init_noise_sigma

        for i, t in tqdm(enumerate(self.noise_scheduler.timesteps)):
            latent_model_inputs = torch.tensor([latents] * 2)
            sigma = self.noise_scheduler.sigmas[i]

            latent_model_inputs = self.noise_scheduler.scale_model_input(latent_model_inputs, t)
            noise_preds = self.forward(noisy_latents=latent_model_inputs, timesteps=t, text_embeddings=embeddings).sample

            noise_pred_cond, noise_pred_uncond = noise_preds.chunk(2)

            noise_pred = noise_pred_uncond + guidance_scale(noise_pred_cond - noise_pred_uncond)

            latents = self.noise_scheduler.step(model_output=noise_pred, timestep=t, sample=latents).prev_sample
        
        latents = 1 / 0.18215 * latents
        output = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]

        return pil_images[0]

In [6]:
model = StableDiffusionLightningModule(
    in_channels=3,
    num_inference_timesteps=50,
    num_train_timsteps=50, 
    beta_start=0.00085, 
    beta_end=0.12, 
    beta_schedule="linear",
    device="cuda",
    max_length=128
)

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.23.self_attn.v_proj.bias', 'vision_model.encoder.layers.20.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.self_attn.k_proj.weight', 'vision_model.encoder.layers.5.mlp.fc2.weight', 'vision_model.encoder.layers.15.layer_norm2.weight', 'vision_model.encoder.layers.3.layer_norm2.weight', 'vision_model.encoder.layers.18.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.k_proj.bias', 'vision_model.encoder.layers.23.mlp.fc1.weight', 'vision_model.encoder.layers.2.layer_norm1.bias', 'vision_model.encoder.layers.11.mlp.fc1.weight', 'vision_model.encoder.layers.9.self_attn.q_proj.bias', 'vision_model.encoder.laye

StableDiffusionLightningModule(
  (unet): ConditionalUNet(
    (model): UNet2DModel(
      (conv_in): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_proj): Timesteps()
      (time_embedding): TimestepEmbedding(
        (linear_1): Linear(in_features=32, out_features=128, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=128, out_features=128, bias=True)
      )
      (down_blocks): ModuleList(
        (0): DownBlock2D(
          (resnets): ModuleList(
            (0-1): 2 x ResnetBlock2D(
              (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
              (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
              (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        