In [1]:
!pip install diffusers --upgrade
!pip install invisible_watermark transformers accelerate safetensors

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable




In [2]:
from diffusers import StableDiffusionPipeline
import transformers

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import RandAugment
from IPython.core.debugger import set_trace
import os

import random
import numpy as np
import matplotlib.pyplot as plt
import math
from PIL import Image

In [3]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model to use off huggingface
model_id = "CompVis/stable-diffusion-v1-4"

# Path to directory containing images of the subject we want to use dreambooth on
dataset_path = '/home/ahm247/dreambooth/dataset/dog6'

# Path to our 200 photos of our prior found online. For another class generate the data
classes_path = '/home/ahm247/dreambooth/class-images'

# Prior and fine-tuning prompts
prior_prompt = 'A photo of a dog'
id_prompt = 'A photo of a mytoken dog'

In [4]:
# Define Dataset class for prior images
class CustomImageDataset(torch.utils.data.Dataset):
    def __init__(self, directory, transform=None):
        self.directory = directory 
        self.transform = transform
        self.image_paths = [os.path.join(directory, filename) for filename in os.listdir(directory)]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
    
        if self.transform:
            image = self.transform(image)

        return image


In [5]:
finetuned_pipe = StableDiffusionPipeline.from_pretrained(model_id, 
                                                torch_dtype=torch.float32,
                                                use_safetensors=True,
                                                variant="fp16",
                                                safety_checker = None,
                                                requires_safety_checker = False)
finetuned_pipe.to(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.27.2",
  "_name_or_path": "CompVis/stable-diffusion-v1-4",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}

In [6]:
def ram_used():
    memory_bytes = torch.cuda.memory_allocated()
    memory_megabytes = memory_bytes / (1024 ** 2)  # Convert bytes to MB
    print(f"Model is using approximately {memory_megabytes:.2f} MB of GPU memory.")
    
def memory_usage(tensor):
    element_size = tensor.element_size()  # Returns the size in bytes of each element
    num_elements = tensor.numel()  # Returns the total number of elements in the tensor
    total_memory_bytes = num_elements * element_size
    total_memory_mb = total_memory_bytes / (1024 ** 2)  # Convert bytes to megabytes

    print(f"Total memory usage of tensor: {total_memory_mb} MB")
    
ram_used()

Model is using approximately 4067.40 MB of GPU memory.


In [7]:
prior_token = finetuned_pipe.tokenizer(prior_prompt, return_tensors='pt').to(device)
id_token = finetuned_pipe.tokenizer(id_prompt, return_tensors='pt').to(device)

prior_input_ids = prior_token['input_ids']
prior_attention_masks = prior_token['attention_mask']

with torch.no_grad():
    prior_encoder_hidden = finetuned_pipe.text_encoder(input_ids=prior_input_ids, attention_mask=prior_attention_masks)
    
id_input_ids = id_token['input_ids']
id_attention_masks = id_token['attention_mask']

with torch.no_grad(): 
    id_encoder_hidden = finetuned_pipe.text_encoder(input_ids=id_input_ids, attention_mask=id_attention_masks)

del prior_token
del prior_input_ids
del prior_attention_masks

del id_token
del id_input_ids
del id_attention_masks

torch.cuda.empty_cache()


# Setting up the datasets/dataloaders
transform = transforms.Compose([
    transforms.Resize((512,512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

prior_dataset = CustomImageDataset(directory=classes_path, transform=transform)
id_dataset = CustomImageDataset(directory=dataset_path, transform=transform)
prior_dataloader = torch.utils.data.DataLoader(prior_dataset, batch_size=1, shuffle=True)
id_dataloader = torch.utils.data.DataLoader(id_dataset, batch_size=1, shuffle=True)

In [8]:
ram_used()

Model is using approximately 4075.57 MB of GPU memory.


In [21]:
# NUM_EPOCHS = 10
# accumulation_steps = 4
# finetuned_pipe.unet.train()
# optimizer = optim.AdamW(finetuned_pipe.unet.parameters(), 
#                         lr=5e-6,
#                         betas=(0.9,0.999),
#                         weight_decay=1e-2,
#                         eps=1e-08)
# # optimizer = optim.SGD(finetuned_pipe.unet.parameters(), 
# #                       lr=5e-6, 
# #                       weight_decay=1e-2,
# #                       momentum=0.9)
# mse_loss = nn.MSELoss()
# max_timesteps = finetuned_pipe.scheduler.num_train_timesteps


# for epoch in range(NUM_EPOCHS):
#     total_loss = 0
#     num_batches = 0

#     for i, prior_images in enumerate(prior_dataloader):
        
#         prior_images = prior_images.to(torch.float32).to(device)
#         id_batch = next(iter(id_dataloader))
#         id_images = random.choice(id_batch).unsqueeze(0).to(torch.float32).to(device)

#         prior_latent = finetuned_pipe.vae.encode(prior_images).latent_dist.sample()
#         prior_latent *= 0.18215
#         noisy_prior_latent = prior_latent + torch.randn_like(prior_latent)

#         id_latent = finetuned_pipe.vae.encode(id_images).latent_dist.sample()
#         id_latent *= 0.18215
#         noisy_id_latent = id_latent + torch.randn_like(id_latent)

#         denoised_prior_latent = finetuned_pipe.unet(noisy_prior_latent, timestep=max_timesteps, encoder_hidden_states=prior_encoder_hidden.last_hidden_state)      
#         denoised_id_latent = finetuned_pipe.unet(noisy_id_latent, timestep=max_timesteps, encoder_hidden_states=id_encoder_hidden.last_hidden_state)

#         # Calculate loss
#         loss_pr = mse_loss(denoised_prior_latent.sample, prior_latent)
#         loss_id = mse_loss(denoised_id_latent.sample, id_latent)
#         loss = loss_id + loss_pr
#         loss /= accumulation_steps
#         loss.backward()
        
#         # Free up memory again
#         del denoised_id_latent
#         del noisy_id_latent
#         del id_latent
#         del denoised_prior_latent
#         del noisy_prior_latent
#         del prior_latent
#         torch.cuda.empty_cache()
        
#         if (i + 1) % accumulation_steps == 0: 
#             optimizer.step() 
#             optimizer.zero_grad()
            
#         total_loss += loss.item() * accumulation_steps
#         num_batches += 1
        
        
#     average_loss = total_loss / num_batches
#     print(f"Epoch {epoch + 1}, Loss: {loss.item()}")
to_pil = transforms.ToPILImage()

In [10]:
NUM_EPOCHS = 10
accumulation_steps = 4
finetuned_pipe.unet.train()
optimizer = optim.AdamW(finetuned_pipe.unet.parameters(), 
                        lr=5e-6,
                        betas=(0.9,0.999),
                        weight_decay=1e-2,
                        eps=1e-08)
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
mse_loss = nn.MSELoss()
max_timesteps = finetuned_pipe.scheduler.num_train_timesteps

for epoch in range(NUM_EPOCHS):
    
    final_image = finetuned_pipe('A photo of a mytoken dog swimming').images[0]
    image_output_directory = './my_images'
    if not os.path.exists(image_output_directory):
        os.makedirs(image_output_directory)
    name = 'mytoken_dog_swimming' + str(epoch) + '.jpg'
    # Save the image in the directory
    image_path = os.path.join(image_output_directory, name)  # You can change the filename extension according to the image format
    final_image.save(image_path)
    
    total_loss = 0
    num_batches = 0

    for i, prior_images in enumerate(prior_dataloader):
        prior_images = prior_images.to(torch.float32).to(device)
        id_batch = next(iter(id_dataloader))
        id_images = random.choice(id_batch).unsqueeze(0).to(torch.float32).to(device)
        
        # Add gradual noise
        noise_level = torch.rand(1, device=device)
        
        with torch.no_grad():
            prior_latent = finetuned_pipe.vae.encode(prior_images).latent_dist.sample()
            prior_latent *= 0.18215
            noisy_prior_latent = prior_latent + noise_level * torch.randn_like(prior_latent)

            id_latent = finetuned_pipe.vae.encode(id_images).latent_dist.sample()
            id_latent *= 0.18215
            noisy_id_latent = id_latent + noise_level * torch.randn_like(id_latent)

        # Sample a timestep
        timestep = torch.randint(0, max_timesteps, (1,), device=device).long()

        # Forward pass
        denoised_prior_latent = finetuned_pipe.unet(noisy_prior_latent, timestep=timestep, encoder_hidden_states=prior_encoder_hidden.last_hidden_state).sample
        denoised_id_latent = finetuned_pipe.unet(noisy_id_latent, timestep=timestep, encoder_hidden_states=id_encoder_hidden.last_hidden_state).sample

        # Calculate loss
        loss_pr = mse_loss(denoised_prior_latent, prior_latent)
        loss_id = mse_loss(denoised_id_latent, id_latent)
        loss = (loss_id + loss_pr) / accumulation_steps
        loss.backward()

        # Free up memory
        del denoised_id_latent
        del noisy_id_latent
        del id_latent
        del denoised_prior_latent
        del noisy_prior_latent
        del prior_latent
        torch.cuda.empty_cache()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps
        num_batches += 1

    average_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1}, Loss: {average_loss:.4f}")

  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1, Loss: 1.6384


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 2, Loss: 1.0690


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 3, Loss: 0.7942


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 4, Loss: 0.6318


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 5, Loss: 0.5180


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 6, Loss: 0.4005


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 7, Loss: 0.3640


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 8, Loss: 0.3013


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 9, Loss: 0.2894


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 10, Loss: 0.2574


In [13]:
# Save model weights
saved_name = 'main_dog.pth'
torch.save(finetuned_pipe.unet.state_dict(), os.path.join('.','checkpoints', saved_name))

different_prompts = [
    'A photo of a mytoken dog swimming',
    'A photo of a mytoken dog in a city',
    'A photo of a mytoken dog wearing sunglasses',
    'A photo of a mytoken dog in Paris'
]
# Save image from fine-tuned model
for prompt in different_prompts: 
    final_image = finetuned_pipe(prompt).images[0]
    image_output_directory = './final_images'
    if not os.path.exists(image_output_directory):
        os.makedirs(image_output_directory)

    # Save the image in the directory
    image_path = os.path.join(image_output_directory, prompt + '.jpg')  # You can change the filename extension according to the image format
    final_image.save(image_path)

print('Images saved')

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

Images saved


In [31]:
for i, prior_images in enumerate(prior_dataloader):
    prior_images = prior_images.to(torch.float32).to(device)
    id_batch = next(iter(id_dataloader))
    id_images = random.choice(id_batch).unsqueeze(0).to(torch.float32).to(device)

    # Add gradual noise
    noise_level = torch.rand(1, device=device)

    with torch.no_grad():
        prior_latent = finetuned_pipe.vae.encode(prior_images).latent_dist.sample()
        prior_latent *= 0.18215
        noisy_prior_latent = prior_latent + noise_level * torch.randn_like(prior_latent)

    # Sample a timestep
    timestep = torch.randint(0, max_timesteps, (1,), device=device).long()

    # Forward pass
    denoised_prior_latent = finetuned_pipe.unet(noisy_prior_latent, timestep=timestep, encoder_hidden_states=prior_encoder_hidden.last_hidden_state).sample

    with torch.no_grad():
        reconstructed_image = finetuned_pipe.vae.decode(denoised_prior_latent).sample.squeeze(0)
        prior_image = finetuned_pipe.vae.decode(prior_latent).sample.squeeze(0)
        
    reconstructed_image = to_pil(reconstructed_image)
    prior_image = to_pil(prior_image)
        
    image_output_directory = './final_images'
    if not os.path.exists(image_output_directory):
        os.makedirs(image_output_directory)

    # Save the image in the directory
    prior_path = os.path.join(image_output_directory, 'prior_image.jpg')  # You can change the filename extension according to the image format
    reconstructed_path = os.path.join(image_output_directory, 'reconstructed_image.jpg')
    reconstructed_image.save(reconstructed_path)
    prior_image.save(prior_path)
    break
