# Dreambooth

The goal of this tutorial is to get demo how to fine tune an image generation model

In [1]:
from AIsaac.all import *
import fastcore.all as fc

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from datasets import load_dataset
from transformers import AutoTokenizer, PretrainedConfig,CLIPTextModel
from accelerate import Accelerator
import xformers # May need dev version
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler

import torchvision
from torchvision import transforms
import torchvision.transforms.functional as TF

from PIL import Image
import itertools, math
from pathlib import Path
from itertools import zip_longest

from IPython.display import clear_output

  from .autonotebook import tqdm as notebook_tqdm
A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


In [2]:
instance_data_dir=Path("imgs/dreambooth")

In [3]:
def decode(tokenizer, tokens, start_after='<|startoftext|>', end_at='<|endoftext|>'):
    _decoded = ds.tokenizer.decode(tokens)
    return _decoded[len(start_after):_decoded.find(end_at)].strip()

In [4]:
denorm = UnNormalize([.5],[.5])

In [5]:
from datasets import load_dataset

In [6]:
xmean,xstd = 0.5, 0.5
@inplace
def transformi(b): 
    b['image'] = [(TF.to_tensor(TF.resize(o,512))-xmean)/xstd for o in b['image']]
    b['label'] = ['a photo of sks dog' for _ in b['label']]

dd = load_dataset("iflath/DogDreamBooth").with_transform(transformi)

Found cached dataset imagefolder (/home/.cache/huggingface/datasets/iflath___imagefolder/iflath--DogDreamBooth-fe88f623b33d9619/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)
100%|██████████| 2/2 [00:00<00:00, 459.30it/s]


In [7]:
get_images(instance_data_dir)

(#5) [Path('imgs/dreambooth/alvan-nee-bQaAJCbNq3g-unsplash.jpeg'),Path('imgs/dreambooth/alvan-nee-brFsZ7qszSY-unsplash.jpeg'),Path('imgs/dreambooth/alvan-nee-9M0tSjb-cpA-unsplash.jpeg'),Path('imgs/dreambooth/alvan-nee-Id1DBHv4fbg-unsplash.jpeg'),Path('imgs/dreambooth/alvan-nee-eoqnr8ikwFE-unsplash.jpeg')]

In [8]:
pretrained_model="CompVis/stable-diffusion-v1-4" 

In [9]:

# train_dataset = DreamBoothDataset(instance_data_dir,instance_prompt,tokenizer)
# train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=1,shuffle=True,num_workers=0,)
# xb = fc.first(train_dataloader)
# [o.shape for o in xb.values()]
# show_images(denorm(xb['instance_image']),titles=[decode(train_dataloader.dataset.tokenizer,o) for o in xb['instance_prompt_ids']])


In [10]:
#| export
class DreamBoothDataset(Dataset):
    def __init__(self,instance_data_dir,instance_prompt,tokenizer,img_size=512):
        fc.store_attr('instance_data_dir,instance_prompt,tokenizer,img_size')
        self.instance_images_path = get_images(instance_data_dir)

    def __len__(self): return len(self.instance_images_path) 

    def __getitem__(self, index):
        batch = {}
        instance_image = Image.open(self.instance_images_path[index])
        
        item_tfms = transforms.Compose([
            transforms.CenterCrop(min(instance_image.size)), # Make square
            transforms.Resize(self.img_size+64, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        
        batch['instance_image'] = item_tfms(instance_image)
        
        batch['instance_prompt_ids'] = self.tokenizer(self.instance_prompt,truncation=True,
            padding="max_length",max_length=self.tokenizer.model_max_length,return_tensors="pt",
            ).input_ids.squeeze()
        return batch

In [11]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model,subfolder="tokenizer",revision=None,use_fast=False,)


In [12]:
instance_prompt="a photo of ssskkksss dog"
# ds = DreamBoothDataset(instance_data_dir,instance_prompt,tokenizer)

In [13]:
train_dataset = DreamBoothDataset(instance_data_dir,instance_prompt,tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=2,shuffle=True,num_workers=0,)

In [14]:
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model, subfolder="text_encoder", revision=None)
text_encoder.requires_grad_(False)
vae = AutoencoderKL.from_pretrained(pretrained_model, subfolder="vae", revision=None)
vae.requires_grad_(False)
unet = UNet2DConditionModel.from_pretrained(pretrained_model, subfolder="unet", revision=None)

In [15]:
unet.enable_xformers_memory_efficient_attention()
torch.backends.cuda.matmul.allow_tf32 = True

In [16]:
optimizer = torch.optim.AdamW(itertools.chain(unet.parameters(), text_encoder.parameters()),
    lr=5e-6, betas=(0.9,0.999), weight_decay=1e-2, eps=1e-8)

In [17]:
num_training_steps = 400
num_train_epochs = math.ceil(num_training_steps/len(train_dataloader))
lr_scheduler = get_scheduler('constant',optimizer=optimizer,num_warmup_steps=0, num_training_steps=num_training_steps,num_cycles=1,power=1.,)

In [18]:
accelerator = Accelerator(mixed_precision="fp16")
unet,text_encoder,optimizer,train_dataloader,lr_scheduler=accelerator.prepare(unet,text_encoder,optimizer,train_dataloader,lr_scheduler)

vae.to(accelerator.device, dtype=torch.float16)
text_encoder.to(accelerator.device, dtype=torch.float16)
''

''

In [19]:
# loss_func = fc.bind(F.mse_loss,reduction="mean")
# optimizer = torch.optim.AdamW(itertools.chain(unet.parameters(), text_encoder.parameters()),
#     lr=5e-6, betas=(0.9,0.999), weight_decay=1e-2, eps=1e-8)


# Trainer(train_dataloader,loss_func,o

In [20]:
class DreamBoothModel(nn.Module):
    def __init__(self,unet,text_encoder): 
        super().__init__()
        self.unet = nn.Sequential(unet)
        self.text_encoder = nn.Sequential(text_encoder)
        
    def forward(self,x):        
        encoder_hidden_states = text_encoder(x["instance_prompt_ids"])[0]

        # predict:  Predict the noise residual
        return unet(x['noisy_latents'], x['timesteps'], encoder_hidden_states).sample.float()
model = DreamBoothModel(unet,text_encoder)

In [24]:
batch['noisy_latents'].shape

torch.Size([1, 4, 64, 64])

In [23]:
for epoch in range(num_train_epochs):
# for epoch in range(2):
    clear_output(wait=True); 
    unet.train()
    text_encoder.train()
    for step, batch in enumerate(train_dataloader):

        # Convert images to latent space
        batch['latents'] = vae.encode(batch["instance_image"].float().to(dtype=torch.float16)).latent_dist.sample() * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        batch['noise'] = torch.randn_like(batch['latents']).float()
        
        # Sample a random timestep for each image
        batch['timesteps'] = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch['latents'].shape[0],), device=batch['latents'].device).long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        batch['noisy_latents'] = noise_scheduler.add_noise(batch['latents'], batch['noise'], batch['timesteps'])

        
        # # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch["instance_prompt_ids"])[0]
        # # predict:  Predict the noise residual
        
        model_pred = unet(batch['noisy_latents'], batch['timesteps'], encoder_hidden_states).sample
        # model_pred = model(batch)
        
        # loss: Get the loss 
        loss = F.mse_loss(model_pred, batch['noise'], reduction="mean")

        # backward
        accelerator.backward(loss)
            params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
            accelerator.clip_grad_norm_(params_to_clip, 1.)

        # Step
        optimizer.step()
        lr_scheduler.step()

        # Zero Grad
        optimizer.zero_grad(set_to_none=True)

        logs = {"loss": to_cpu(loss).item(), "lr": lr_scheduler.get_last_lr()[0]}
        print(f"{epoch} of {num_train_epochs} : Batch {step}")
        print(logs)

133 of 134 : Batch 0
{'loss': 0.08537904918193817, 'lr': 5e-06}
133 of 134 : Batch 1
{'loss': 0.21098282933235168, 'lr': 5e-06}
133 of 134 : Batch 2
{'loss': 0.051845647394657135, 'lr': 5e-06}


In [None]:
output_dir="dreambooth/out"
pipeline = DiffusionPipeline.from_pretrained(
    pretrained_model,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
    revision=None,
)
pipeline.save_pretrained(output_dir)

In [None]:
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(output_dir, torch_dtype=torch.float16).to("cuda")
prompt = "A photo of ssskkksss dog with a tennis ball"
    
images = [pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] for i in range(9)]
clear_output()
print(prompt)
show_images(images)

In [None]:
#| export
# class BasicTrainCB(Callback):
#     '''Callback for basic pytorch training loop'''
#     def predict(self,trainer): trainer.preds = trainer.model(trainer.batch[0])
#     def get_loss(self,trainer): trainer.loss = trainer.loss_func(trainer.preds,trainer.batch[1])
#     def backward(self,trainer): trainer.loss.backward()
#     def step(self,trainer): trainer.opt.step()
#     def zero_grad(self,trainer): trainer.opt.zero_grad()

In [None]:
class AccelerateCB(BasicTrainCB):
    order = DeviceCB.order+10
    def __init__(self, n_inp=1, mixed_precision="fp16"):
        super().__init__(n_inp=n_inp)
        self.acc = Accelerator(mixed_precision=mixed_precision)
        
    def before_fit(self, learn):
        '''Wraps model, opt, data in accelerate'''
        learn.model,learn.opt,learn.dls.train,learn.dls.valid = self.acc.prepare(
            learn.model, learn.opt, learn.dls.train, learn.dls.valid)

        
    def backward(self, learn): 
        '''Using accelerate for backward pass'''
        self.acc.backward(learn.loss)

In [None]:
class DreamBoothTrainCB:
    
    def predict(self,trainer): 
        encoder_hidden_states = text_encoder(trainer.batch["instance_prompt_ids"])[0]
        trainer.preds = unet(noisy_latents, timesteps, encoder_hidden_states).sample.float() # Predict the noise residual

    def loss(self,trainer): 
        trainer.loss = F.mse_loss(trainer.preds, noise.float(), reduction="mean")

    def backward(self,trainer): 
        accelerator.backward(trainer.loss)
        if accelerator.sync_gradients:
            params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
            accelerator.clip_grad_norm_(params_to_clip, 1.)
            
    def step(self,trainer):
        optimizer.step()
        lr_scheduler.step()
        
    def zero_grad(self,trainer): pass
        optimizer.zero_grad(set_to_none=True)

    
    def generate_latents(self):
        latents = vae.encode(batch["instance_image"].float().to(dtype=torch.float16)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor
        
    def noisify_latents(self):
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()

        # Add noise to the latents according to the noise magnitude at each timestep (forward diffusion)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)