In [1]:
# !pip install torch diffusers transformers accelerate tqdm datasets gdown

In [2]:
# !git config --global credential.helper store
# !python -m pip install huggingface_hub
# !huggingface-cli login

In [3]:
import torch
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PNDMScheduler
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
from diffusers import DiffusionPipeline, AutoencoderKL
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    EDMEulerScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
# from diffusers.utils import
from datasets import load_dataset
import os
from tqdm import tqdm
from PIL import Image
from torchvision import transforms

  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


In [4]:
# Load SDXL base with the fp16-safe VAE. Training reuses the pipeline's own
# modules instead of loading second copies: the optimizer and checkpoint code
# operate on `pipe.unet`, so the `unet` trained here must be that same object.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE_MODEL_ID, vae=vae, use_safetensors=True, torch_dtype=torch.float16
)
# Every later cell places tensors on "cuda", so the models must live there too.
pipe.to("cuda")

unet = pipe.unet
# SDXL conditions on two text encoders; expose both encoder/tokenizer pairs.
text_encoder = pipe.text_encoder
text_encoder_2 = pipe.text_encoder_2
tokenizer = pipe.tokenizer
tokenizer_2 = pipe.tokenizer_2

# Only the UNet is fine-tuned; keep the VAE and text encoders frozen so their
# forward passes do not build autograd graphs.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)

# NOTE(review): EDMEulerScheduler does not recognize several keys of this
# checkpoint's scheduler config. Confirm EDM-style preconditioning is intended
# for SDXL base; diffusers' SDXL training scripts use DDPMScheduler (or
# EulerDiscreteScheduler) for non-EDM checkpoints like this one.
noise_scheduler = EDMEulerScheduler.from_pretrained(BASE_MODEL_ID, subfolder="scheduler")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

The config attributes {'clip_sample': False, 'sample_max_value': 1.0, 'set_alpha_to_one': False, 'skip_prk_steps': True} were passed to EDMEulerScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


RuntimeError: Unexpected error from cudaGetDeviceCount(). Did you run some cuda functions before calling NumCudaDevices() that might have already set an error? Error 804: forward compatibility was attempted on non supported HW

In [None]:
def tokenize_prompt(tokenizer, prompt):
    """Turn ``prompt`` into fixed-length token ids.

    Padding and truncation both target ``tokenizer.model_max_length``, so every
    prompt in a batch yields the same sequence length.

    Args:
        tokenizer: Hugging Face tokenizer to call.
        prompt: a string or a list of strings.

    Returns:
        The ``input_ids`` produced by the tokenizer (requested as PyTorch tensors).
    """
    encoded = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    return encoded.input_ids

In [None]:
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
    """Encode a prompt with one or more text encoders, SDXL style.

    For every encoder the penultimate hidden state is used as the per-token
    embedding, and these are concatenated along the last (feature) dimension.
    The pooled embedding is element ``[0]`` of the LAST encoder's tuple output.

    Args:
        text_encoders: non-empty sequence of text encoder modules; each is called
            with ``output_hidden_states=True, return_dict=False``.
        tokenizers: sequence of tokenizers parallel to ``text_encoders``, or
            ``None`` to use ``text_input_ids_list`` instead.
        prompt: prompt(s) to tokenize; only used when ``tokenizers`` is given.
        text_input_ids_list: pre-tokenized ids, one entry per encoder; required
            when ``tokenizers`` is ``None``.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``: the concatenated per-token
        embeddings (``(batch, seq_len, sum of hidden sizes)`` assuming 3-D hidden
        states) and the pooled embedding flattened to ``(batch, -1)``.

    Raises:
        ValueError: if ``text_encoders`` is empty, or if neither ``tokenizers``
            nor ``text_input_ids_list`` is provided.
    """
    if len(text_encoders) == 0:
        raise ValueError("encode_prompt needs at least one text encoder")
    if tokenizers is None and text_input_ids_list is None:
        raise ValueError("pass either `tokenizers` or `text_input_ids_list`")

    prompt_embeds_list = []
    pooled_prompt_embeds = None
    for i, text_encoder in enumerate(text_encoders):
        if tokenizers is not None:
            text_input_ids = tokenize_prompt(tokenizers[i], prompt)
        else:
            text_input_ids = text_input_ids_list[i]

        outputs = text_encoder(
            text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
        )

        # Overwritten every iteration, so only the final encoder's value survives.
        # NOTE(review): for SDXL the final encoder must be a projection model
        # (CLIPTextModelWithProjection) whose element [0] is the pooled text
        # embedding; for a plain CLIPTextModel element [0] is the per-token last
        # hidden state instead, and the result is not a pooled vector.
        pooled_prompt_embeds = outputs[0]
        # With output_hidden_states=True the last tuple element is the tuple of
        # all hidden states; [-2] selects the penultimate layer.
        prompt_embeds_list.append(outputs[-1][-2])

    prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
    pooled_prompt_embeds = pooled_prompt_embeds.view(prompt_embeds.shape[0], -1)
    return prompt_embeds, pooled_prompt_embeds

In [None]:
# Training-image preprocessing: resize to a fixed square (a (h, w) tuple forces
# the exact size, so aspect ratio is not preserved), convert to a tensor, then
# normalize with mean 0.5 / std 0.5 so pixel values land in [-1, 1].
TRAIN_RESOLUTION = 512
image_transforms = transforms.Compose(
    [
        transforms.Resize((TRAIN_RESOLUTION, TRAIN_RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [None]:
# Download a dataset file from Google Drive by its file id. Presumably this is
# ./cut.zip or ./cut.jsonl (both read below) — TODO confirm which.
!gdown "1uQs3QuctJU5pkacSNhVKDmBH83s9KdfZ"

In [None]:
# Download the second dataset file from Google Drive by its file id. Presumably
# the other of ./cut.zip / ./cut.jsonl used below — TODO confirm which.
!gdown "1WcUb7nMHOfW_wigBUayIXN43R86tXdIf"

In [None]:
import zipfile

# Unpack the downloaded image archive into the current working directory.
zip_ref = zipfile.ZipFile("./cut.zip", mode="r")
try:
    zip_ref.extractall(path="./")
finally:
    zip_ref.close()

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Image/prompt pairs for fine-tuning.

    Each record of ``dataset`` must provide a ``'prompt'`` string and a
    ``'file_name'`` path relative to ``image_root``.

    Args:
        dataset: indexable sequence of records (e.g. a Hugging Face ``Dataset``).
        tokenizer: kept as an attribute for callers; this class does not use it
            (prompts are returned raw and tokenized later).
        trans: callable applied to each RGB ``PIL.Image``.
        image_root: directory holding the image files; defaults to ``"./cut"``,
            the directory this notebook previously hard-coded.
    """

    def __init__(self, dataset, tokenizer, trans, image_root="./cut"):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.trans = trans
        self.image_root = image_root

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': trans(rgb_image), 'raw': prompt}`` for record ``idx``."""
        # Index once: for a Hugging Face Dataset each access materializes the row.
        record = self.dataset[idx]
        image_path = os.path.join(self.image_root, record["file_name"])
        # Image.open is lazy and keeps the file handle open; convert() produces a
        # separate, fully loaded image, so the handle can be closed right away.
        with Image.open(image_path) as raw_image:
            rgb_image = raw_image.convert("RGB")
        return {"pixel_values": self.trans(rgb_image), "raw": record["prompt"]}

In [None]:
# Build the training set from the JSONL manifest: each record carries a
# 'prompt' and an image 'file_name' (both read by CustomDataset). Batches of 4,
# reshuffled every epoch.
td = load_dataset("json", data_files="./cut.jsonl", split="train[:100%]")
train_dataset = CustomDataset(dataset=td, tokenizer=tokenizer, trans=image_transforms)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=4)

In [None]:
# Optimizer and learning-rate schedule for fine-tuning the pipeline's UNet.
num_epochs = 10

# The linear warmup/decay schedule is sized in optimizer steps (batches x epochs),
# so lr_scheduler.step() is meant to be called once per batch.
# NOTE(review): if the total is below 500 steps, warmup never completes.
# NOTE(review): pipe.unet was loaded in float16; AdamW's default eps (1e-8) is not
#   representable in float16, so pure-fp16 updates may underflow or produce NaNs.
total_training_steps = num_epochs * len(train_dataloader)
optimizer = torch.optim.AdamW(params=pipe.unet.parameters(), lr=1e-5)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=total_training_steps,
    num_warmup_steps=500,
)

In [None]:
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32, device="cuda", scheduler=None):
    """Return the scheduler sigma for each timestep, shaped to broadcast over latents.

    Each entry of ``timesteps`` is located in ``scheduler.timesteps``, and the sigma
    at the same index is gathered. Trailing singleton dims are then appended until
    the result has ``n_dim`` dimensions (e.g. ``(batch, 1, 1, 1)`` for 4-D latents).

    Args:
        timesteps: 1-D tensor whose values are taken from ``scheduler.timesteps``.
        n_dim: rank of the tensor the sigmas will be broadcast against.
        dtype: dtype of the returned sigmas.
        device: device for the lookup and the result (default ``"cuda"``, as before).
        scheduler: object exposing ``sigmas`` and ``timesteps`` tensors; defaults to
            the notebook-global ``noise_scheduler``.

    Raises:
        RuntimeError: raised by ``.item()`` if a timestep is absent from, or appears
            more than once in, the schedule.
    """
    if scheduler is None:
        scheduler = noise_scheduler
    sigmas = scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)

    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while sigma.ndim < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma

In [None]:
def compute_time_ids(
    original_size,
    crops_coords_top_left,
    target_size=(512, 512),
    device="cuda",
    dtype=torch.float16,
):
    """Build one SDXL micro-conditioning ``time_ids`` row.

    Adapted from ``StableDiffusionXLPipeline._get_add_time_ids``: the row is
    ``original_size + crops_coords_top_left + target_size``, each a 2-element pair
    (conventionally (height, width) and (top, left)), returned as a ``(1, 6)`` tensor.

    Args:
        original_size: 2-element pair; tuples, lists, or 1-D tensors are accepted.
        crops_coords_top_left: 2-element pair describing the crop offset.
        target_size: output resolution being trained. Defaults to (512, 512) to
            match the 512x512 resize applied to training images (previously a
            mismatched (256, 256) was hard-coded).
        device: device of the returned tensor (default ``"cuda"``, as before).
        dtype: dtype of the returned tensor (default ``torch.float16``, as before).

    Returns:
        Tensor of shape ``(1, 6)``.
    """
    # Flatten element-wise rather than with `+`: with tensor inputs `+` would
    # add the pairs numerically instead of concatenating them.
    values = [float(v) for v in (*original_size, *crops_coords_top_left, *target_size)]
    return torch.tensor([values], device=device, dtype=dtype)

In [None]:
# Fine-tune the SDXL UNet with EDM-style preconditioning (noise_scheduler is an
# EDMEulerScheduler). Everything trains `pipe.unet`: the same module the optimizer
# was built on and the one checkpointed at the end of each epoch.
# NOTE(review): UNet weights, gradients and AdamW run in pure float16 (no master
#   weights or loss scaling); updates may underflow or go NaN. Consider float32
#   weights with autocast.
device = torch.device("cuda")
pipe.to(device)  # no-op if already on CUDA; everything below assumes CUDA
trained_unet = pipe.unet

for epoch in range(num_epochs):
    trained_unet.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    epoch_loss = 0.0  # sum of per-batch losses; averaged at epoch end

    for batch in progress_bar:
        pixel_values = batch["pixel_values"].to(device=device, dtype=torch.float16)
        prompts = batch["raw"]
        bsz = pixel_values.shape[0]

        # VAE and text encoders are not trained: build no autograd graph for them.
        with torch.no_grad():
            model_input = pipe.vae.encode(pixel_values).latent_dist.sample()
            model_input = model_input * pipe.vae.config.scaling_factor
            # SDXL needs BOTH text encoders: per-token embeddings are concatenated
            # and the pooled embedding comes from the second (projection) encoder.
            prompt_embeds, pooled_prompt_embeds = encode_prompt(
                text_encoders=[pipe.text_encoder, pipe.text_encoder_2],
                tokenizers=None,
                prompt=None,
                text_input_ids_list=[
                    tokenize_prompt(pipe.tokenizer, prompts),
                    tokenize_prompt(pipe.tokenizer_2, prompts),
                ],
            )

        # Sample a noise level per example and build the preconditioned UNet input.
        noise = torch.randn_like(model_input)
        indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
        timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
        sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
        inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)

        # Micro-conditioning. The dataset yields no original-size/crop metadata, and
        # image_transforms resizes to a fixed size without cropping, so the
        # post-resize size with a (0, 0) crop describes every sample.
        height, width = pixel_values.shape[-2:]
        add_time_ids = torch.cat(
            [
                compute_time_ids(original_size=(height, width), crops_coords_top_left=(0, 0))
                for _ in range(bsz)
            ]
        )

        model_pred = trained_unet(
            inp_noisy_latents,
            timesteps,
            prompt_embeds,
            added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
            return_dict=False,
        )[0]

        # EDM output preconditioning turns the raw prediction into a denoised-latent
        # estimate, so the regression target is the clean latent.
        model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
        target = model_input
        # sigma^-2 loss weighting (as in diffusers' SDXL EDM training path). Computed
        # in float32 because small sigmas can push sigma^-2 past the float16 maximum.
        weighting = sigmas.float() ** -2.0
        loss = torch.mean(
            (weighting * (model_pred.float() - target.float()) ** 2).reshape(bsz, -1),
            dim=1,
        ).mean()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # schedule is sized in batches, so step it every batch
        optimizer.zero_grad()  # otherwise gradients accumulate across batches

        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Average epoch loss (guarded against an empty dataloader).
    epoch_loss /= max(len(train_dataloader), 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}")

    # Save a checkpoint of the trained UNet after every epoch.
    pipe.unet.save_pretrained(f"model_checkpoint_{epoch + 1}")

print("Training complete.")