# Imports
Some useful resources:  
* [huggingface](https://huggingface.co/docs/diffusers/v0.13.0/en/training/text2image)  
* [training-example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) 

In [None]:
import requests
import torch
from PIL import Image
from io import BytesIO
!pip install diffusers
from diffusers import StableDiffusionImg2ImgPipeline, UNet2DConditionModel, AutoencoderKL, DDPMScheduler 
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers
!pip install accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = "cuda"
print(f"using {device} device")

# Build the Module
Modified from this [example](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py).

In [None]:
module_id_or_path = "CompVis/stable-diffusion-v1-4"  # Hub id (or local path) of the checkpoint.
resolution = 512  # Image size in pixels used by the training transforms.
# Half precision only on CUDA: many CPU kernels do not implement float16, so use float32 there.
weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
revision = None  # Revision of pretrained model identifier from huggingface.co/models.
variant = None  # Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. "fp16".

# Components. revision/variant are passed to every weighted component (including the UNet)
# so that all sub-models are loaded from the same snapshot of the checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    module_id_or_path,
    subfolder="unet",
    revision=revision,
    variant=variant,
)
vae = AutoencoderKL.from_pretrained(
    module_id_or_path,
    subfolder="vae",
    revision=revision,
    variant=variant,
)
text_encoder = CLIPTextModel.from_pretrained(
    module_id_or_path,
    subfolder="text_encoder",
    revision=revision,
    variant=variant,
)
noise_scheduler = DDPMScheduler.from_pretrained(
    module_id_or_path,
    subfolder="scheduler",
    revision=revision,
)
tokenizer = CLIPTokenizer.from_pretrained(
    module_id_or_path,
    subfolder="tokenizer",
    revision=revision,
)
safety_checker = None
# Not passed to the pipeline below; presumably from_pretrained loads the checkpoint's own
# feature extractor in that case.
feature_extractor = None

# Build the img2img pipeline around the components above, so fine-tuning `unet` later is
# reflected in `pipe` (they are the same module objects).
# NOTE(review): torch_dtype appears to apply only to components the pipeline loads itself;
# the modules passed in here keep the dtype they were loaded with (float32 by default).
# Confirm for your diffusers version.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    module_id_or_path,
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    safety_checker=safety_checker,
    revision=revision,
    variant=variant,
    torch_dtype=weight_dtype,
)
pipe = pipe.to(device)

# Test Module Setup
Run a single img2img task to make sure the module is set up correctly.

In [None]:
# Quick sanity check: run img2img on the left third of example.png
# (presumably one panel of a multi-panel image — confirm).
init_image = Image.open('example.png').convert('RGB')
width, height = init_image.size
init_image = init_image.crop((0, 0, width // 3, height)).resize((768, 512))

# Other prompts tried: "other views of the street", "face left".
prompt = "face left side of the road"

# NOTE(review): the final test cell saves to the same file name and will overwrite this output.
output = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5)
images = output.images
images[0].save("fantasy_landscape.png")

# Finetune the Module
(optional)

In [None]:
import torch.nn.functional as F
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr

# Freeze the VAE and text encoder; only the UNet's parameters are optimized below.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.train()

# Placeholder: the real training DataLoader is built in the dataset-preparation cell.
train_dataloader = ...

# Hyper params
max_train_steps = 15000  # Total optimizer updates; also the length of the LR schedule.
learning_rate = 1e-05
max_grad_norm = 1
train_batch_size = 1
optimizer_cls = torch.optim.AdamW
adam_beta1 = 0.9
adam_beta2 = 0.999
adam_weight_decay = 1e-2
adam_epsilon = 1e-8
num_train_epochs = 100
noise_offset = 0  # Scale of extra per-channel offset noise; 0 disables it.
input_perturbation = 0  # Scale of extra noise added to the sampled noise; 0 disables it.
prediction_type = None  # None keeps the scheduler's configured prediction type.
snr_gamma = None  # Min-SNR loss-weighting gamma; recommend 5.0 if you want to use it.
gradient_accumulation_steps = 1
report_to = "tensorboard"
mixed_precision = None
output_dir = "sd-model-finetuned"
logging_dir = "logs"
accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)
# Name of the LR schedule; must be one of "linear", "cosine", "cosine_with_restarts",
# "polynomial", "constant", "constant_with_warmup".
# Kept separate from `lr_scheduler` (the scheduler object built below): reusing one name meant
# that re-running this cell handed the previous scheduler object to get_scheduler.
lr_scheduler_type = "constant"
lr_warmup_steps = 500

# optimizer
optimizer = optimizer_cls(
    unet.parameters(),
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

# accelerator
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
    log_with=report_to,
    project_config=accelerator_project_config,
)
# Step counts are scaled by num_processes because the prepared scheduler is stepped once per
# process on every optimizer update.
lr_scheduler = get_scheduler(
    lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * accelerator.num_processes,
    num_training_steps=max_train_steps * accelerator.num_processes,
)

# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

See more about loading custom images [here](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

In [None]:
# Prepare dataset and update dataloader.
import datasets
import random
from datasets import load_dataset
import numpy as np

# Hugging Face Hub dataset to fine-tune on, plus optional loading overrides.
# NOTE(review): "lambdalabs/pokemon-blip-captions" may no longer be downloadable from the Hub;
# the upstream diffusers example uses "lambdalabs/naruto-blip-captions". Confirm before running.
dataset_name = "lambdalabs/pokemon-blip-captions"
dataset_config_name = None
cache_dir = None
train_data_dir = None
dataloader_num_workers = None
center_crop = False  # True: CenterCrop to `resolution`; False: RandomCrop.
random_flip = False  # True: apply random horizontal flips as augmentation.
image_column = 'image'
caption_column = 'text'
seed = None  # Seed for the shuffle used when subsampling with max_train_samples.
max_train_samples = None  # If set, train on a random subset of this many examples.

dataset = load_dataset(
    dataset_name,
    dataset_config_name,
    cache_dir=cache_dir,
    data_dir=train_data_dir,
)

# Known datasets -> (image column, caption column); used when a column is not set explicitly.
DATASET_NAME_MAPPING = {
    "lambdalabs/naruto-blip-captions": ("image", "text"),
}

# Resolve which columns hold the images and the captions.
column_names = dataset["train"].column_names
dataset_columns = DATASET_NAME_MAPPING.get(dataset_name, None)


def _resolve_column(requested, default_index, label):
    """Return the name of the training-split column to use for *label*.

    An explicit *requested* name must be an existing column. When *requested* is None, the
    column comes from DATASET_NAME_MAPPING if the dataset is listed there, otherwise from
    position *default_index* in the split's columns.

    Raises:
        ValueError: if *requested* is not a column of the training split, or if no column
            exists at *default_index* to fall back to.
    """
    if requested is not None:
        if requested not in column_names:
            raise ValueError(
                f"{label} value '{requested}' needs to be one of: {', '.join(column_names)}"
            )
        return requested
    if dataset_columns is not None:
        return dataset_columns[default_index]
    if default_index >= len(column_names):
        raise ValueError(
            f"Cannot infer {label}: the dataset only has columns {', '.join(column_names)}"
        )
    return column_names[default_index]


image_column = _resolve_column(image_column, 0, "image_column")
caption_column = _resolve_column(caption_column, 1, "caption_column")
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True):
    """Tokenize the captions of a batch of dataset examples.

    Each entry of ``examples[caption_column]`` is either a single string or a list/ndarray of
    alternative captions. For the latter, a random alternative is used when *is_train* is
    true, otherwise the first one.

    Returns:
        The tokenizer's ``input_ids`` tensor, padded/truncated to ``tokenizer.model_max_length``.

    Raises:
        ValueError: if a caption entry is neither a string nor a list/ndarray.
    """
    def choose(entry):
        if isinstance(entry, str):
            return entry
        if isinstance(entry, (list, np.ndarray)):
            # take a random caption if there are multiple
            return random.choice(entry) if is_train else entry[0]
        raise ValueError(
            f"Caption column `{caption_column}` should contain either strings or lists of strings."
        )

    selected = [choose(entry) for entry in examples[caption_column]]
    encoded = tokenizer(
        selected,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
def preprocess_train(examples):
    """Add model inputs to a batch of examples (mutated in place and returned).

    Sets ``pixel_values`` to the RGB-converted images passed through ``train_transforms``
    and ``input_ids`` to the tokenized captions.
    """
    rgb_images = [picture.convert("RGB") for picture in examples[image_column]]
    examples["pixel_values"] = list(map(train_transforms, rgb_images))
    examples["input_ids"] = tokenize_captions(examples)
    return examples

with accelerator.main_process_first():
    train_split = dataset["train"]
    if max_train_samples is not None:
        # Shuffle before truncating so the subset is a random sample of the split.
        train_split = train_split.shuffle(seed=seed).select(range(max_train_samples))
        dataset["train"] = train_split
    # with_transform is lazy: preprocess_train runs only when items are accessed.
    train_dataset = train_split.with_transform(preprocess_train)

def collate_fn(examples):
    """Stack a list of preprocessed examples into one training batch.

    Returns a dict with ``pixel_values`` (stacked, contiguous, float32) and ``input_ids``
    (stacked as-is).
    """
    stacked_pixels = torch.stack([item["pixel_values"] for item in examples])
    stacked_ids = torch.stack([item["input_ids"] for item in examples])
    return {
        "pixel_values": stacked_pixels.to(memory_format=torch.contiguous_format).float(),
        "input_ids": stacked_ids,
    }

# Image pipeline used by preprocess_train: resize the short side to `resolution`, crop to a
# square, optionally flip, then scale pixel values from [0, 1] to [-1, 1].
train_transforms = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
        transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_batch_size,
    # dataloader_num_workers may be None; DataLoader needs an int (0 = load in main process).
    num_workers=dataloader_num_workers or 0,
)
# This DataLoader is created after the earlier accelerator.prepare(...) call (which only saw a
# placeholder), so prepare it here; otherwise batches stay on the CPU while the models are on
# accelerator.device, and gradient accumulation cannot detect the end of the dataloader.
train_dataloader = accelerator.prepare(train_dataloader)

In [None]:
# start training(fine-tuning)
for epoch in range(num_train_epochs):
    train_loss = 0.0
    for step, batch in enumerate(train_dataloader):
        # Convert images to latent space
        latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        if noise_offset:
            # https://www.crosslabs.org//blog/diffusion-with-offset-noise
            noise += noise_offset * torch.randn(
                (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
            )
        if input_perturbation:
            new_noise = noise + input_perturbation * torch.randn_like(noise)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        if input_perturbation:
            noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
        else:
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the text embedding for conditioning
        encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

        # Get the target for loss depending on the prediction type
        if prediction_type is not None:
            # set prediction_type of scheduler if defined
            noise_scheduler.register_to_config(prediction_type=prediction_type)

        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # Predict the noise residual and compute loss
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        if snr_gamma is None:
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        else:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Since we predict the noise instead of x_0, the original formulation is slightly changed.
            # This is discussed in Section 4.2 of the same paper.
            snr = compute_snr(noise_scheduler, timesteps)
            mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                dim=1
            )[0]
            if noise_scheduler.config.prediction_type == "epsilon":
                mse_loss_weights = mse_loss_weights / snr
            elif noise_scheduler.config.prediction_type == "v_prediction":
                mse_loss_weights = mse_loss_weights / (snr + 1)

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
            loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
            loss = loss.mean()

        # Gather the losses across all processes for logging (if we use distributed training).
        avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
        train_loss += avg_loss.item() / gradient_accumulation_steps
        print(f"training loss is {train_loss}")

        # Backpropagate
        accelerator.backward(loss)
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

# Test the module
Re-run the earlier test cell to verify that the fine-tuned weights have changed the output.

In [None]:
# Run img2img on a public sample sketch with the (possibly fine-tuned) pipeline.
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

# Use a bounded timeout so a stalled download cannot hang the notebook. Fail loudly on HTTP
# errors instead of handing an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
init_image = Image.open(BytesIO(response.content)).convert("RGB")
init_image = init_image.resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")