# Project 01: Generating Tattoos from Text Descriptions

In [1]:
import os

import torch
import torch.nn.functional as F
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler
from peft import LoraConfig, get_peft_model
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"""
Check device availability of GPU
"""
# Pick the best available accelerator so the same cell works unchanged on
# Colab / Linux (CUDA), Apple Silicon (MPS) and CPU-only machines.
# `torch.backends.mps.is_available()` is the long-standing, documented check;
# `torch.mps.is_available()` only exists in recent PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: mps


In [3]:
"""
Load dataset from Hugging Face
"""
# Fetch the tattoo image/caption dataset from the Hugging Face Hub and keep a
# handle on its training split, which is the only split used below.
dataset = load_dataset(path="Drozdik/tattoo_v0")
train_dataset = dataset["train"]

In [4]:
"""
Define image preprocessing
"""
# Per-image preprocessing applied inside collate_fn:
#   1. resize to a fixed 256x256 resolution,
#   2. convert to a float tensor in [0, 1],
#   3. shift/scale to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

In [5]:
""" 
Function to preprocess dataset for DataLoader
"""

def collate_fn(batch):
    """Turn a list of dataset samples into one training batch.

    Args:
        batch: List of samples. Each sample is a mapping with an ``"image"``
            entry (an object exposing ``convert``, e.g. a PIL image) and a
            ``"text"`` caption string.

    Returns:
        dict: ``"pixel_values"`` holds the transformed images stacked along a
        new leading batch dimension; ``"text"`` holds the captions, each
        prefixed with ``"Tattoo of "``.
    """
    # Force 3-channel RGB first: grayscale, RGBA or palette images would
    # otherwise produce tensors with a different channel count, which breaks
    # torch.stack on mixed batches and the 3-channel input the VAE is fed.
    images = [transform(sample["image"].convert("RGB")) for sample in batch]
    # Every caption gets the same prefix so the model learns one trigger phrase.
    captions = ["Tattoo of " + sample["text"] for sample in batch]
    return {
        "pixel_values": torch.stack(images),
        "text": captions,
    }

In [None]:
""" 
DataLoader with collate_fn
""" 

# Shuffled, single-sample batches; collate_fn applies the image transform and
# the caption prefix when each batch is assembled.
train_dataloader = DataLoader(
    dataset=train_dataset,
    collate_fn=collate_fn,
    shuffle=True,
    batch_size=1,
)

In [None]:
""" 
Load pre-trained Stable Diffusion model
"""

# Base checkpoint to fine-tune.
# Alternative checkpoint: model_id = "prompthero/openjourney"
model_id = "runwayml/stable-diffusion-v1-5"

# Load all pipeline components in full precision, then move them to `device`.
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipeline = pipeline.to(device)

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 11.71it/s]


In [None]:
"""  
Apply LoRa for easy finetune with minimal computational power
"""
# LoRA hyper-parameters: rank-8 adapters (alpha 16, 10% dropout) attached only
# to the modules named "to_q" and "to_v".
lora_config = LoraConfig(
    target_modules=["to_q", "to_v"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
)

In [9]:
# Wrap the pipeline's UNet with the LoRA adapters and report how many
# parameters remain trainable.
# NOTE(review): this cell is not idempotent - re-running it wraps the already
# wrapped UNet again; restart the kernel (or reload the pipeline) first.
unet_with_lora = get_peft_model(pipeline.unet, lora_config)
pipeline.unet = unet_with_lora
unet_with_lora.print_trainable_parameters()

trainable params: 797,184 || all params: 860,318,148 || trainable%: 0.0927


In [10]:
"""  
Setting up optimizer
"""
# Optimise only the parameters that are actually trainable (the LoRA adapter
# weights once the UNet has been wrapped). Frozen base weights never receive
# gradients, so handing them to AdamW only adds bookkeeping overhead.
trainable_params = [p for p in pipeline.unet.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=5e-5)

In [11]:
""" 
Tokenizer and Text Encoder
""" 

# Reuse the tokenizer and text encoder that were already loaded as part of
# `pipeline` (same `model_id`) instead of loading a second copy of the CLIP
# text encoder into memory.
tokenizer = pipeline.tokenizer
# The text encoder is only used to produce conditioning embeddings and is not
# fine-tuned here, so freeze it; otherwise every backward pass would also
# compute gradients for its weights.
text_encoder = pipeline.text_encoder.to(device).requires_grad_(False)

In [12]:
def encode_captions(captions):
    """Encode text captions into conditioning embeddings for the UNet.

    Args:
        captions: A caption string or list of caption strings.

    Returns:
        The first output of ``text_encoder`` for the tokenized captions,
        computed on ``device`` and detached from autograd.
    """
    # Pad/truncate every caption to the tokenizer's maximum length so all
    # sequences in the batch share the same shape.
    inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    # The text encoder is not being trained: run it without autograd so that
    # loss.backward() does not build a graph through it or accumulate unused
    # gradients on its weights (wasted memory and compute every step).
    with torch.no_grad():
        return text_encoder(inputs.input_ids.to(device))[0]

In [13]:
# Noise schedule used to corrupt latents during training, loaded from the
# checkpoint's own scheduler configuration.
noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

In [None]:
""" 
Encode images into latent space using VAE
""" 

# Grab the pipeline's VAE and make sure it lives on the training device.
vae = pipeline.vae
vae = vae.to(device)

In [15]:
def encode_images_to_latent_space(images):
    """Encode a batch of images into scaled VAE latents.

    Args:
        images: Image tensor batch as produced by ``collate_fn``
            (``"pixel_values"``).

    Returns:
        A sample from the VAE's latent distribution, multiplied by the VAE's
        configured scaling factor, with no autograd history.
    """
    # Match the VAE's device and dtype so the encode call cannot fail on a
    # dtype mismatch (e.g. if the pipeline is later loaded in half precision).
    images = images.to(device=device, dtype=vae.dtype)
    # The VAE is frozen; skip autograd bookkeeping while encoding.
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
        # Read the scaling factor from the checkpoint's config rather than
        # hard-coding it (0.18215 for Stable Diffusion v1.x checkpoints).
        latents = latents * vae.config.scaling_factor
    return latents

In [16]:
# Number of consecutive batches whose gradients are accumulated before each
# optimizer step; the training loop divides every batch loss by this value.
accumulation_steps = 4

In [None]:
""" 
Training loop
"""
# Fine-tune the trainable UNet parameters with gradient accumulation, showing
# a tqdm progress bar per epoch and saving the trainable weights every epoch.
num_epochs = 5
weights_dir = "Project_01/training_weights"
# Create the checkpoint folder up front so the first torch.save cannot fail
# after a full epoch of training.
os.makedirs(weights_dir, exist_ok=True)

num_batches = len(train_dataloader)
pipeline.unet.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    epoch_loss_sum = 0.0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}", leave=False)
    for i, batch in enumerate(progress_bar):
        images = batch["pixel_values"].to(device)
        captions = batch["text"]
        # Encode images into latents and captions into text embeddings
        latents = encode_images_to_latent_space(images)
        encoder_hidden_states = encode_captions(captions)
        # Sample the target noise and one random timestep per batch element
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device).long()
        # Forward diffusion: corrupt the clean latents at the sampled timesteps
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # Predict the noise residual
        noise_pred = pipeline.unet(
            noisy_latents, timesteps, encoder_hidden_states).sample
        # Keep the unscaled loss for reporting; only the value that is
        # back-propagated is divided by the number of accumulation steps.
        loss = F.mse_loss(noise_pred, noise)
        (loss / accumulation_steps).backward()
        # Step every `accumulation_steps` batches AND on the final batch, so
        # gradients from a trailing partial window are applied instead of
        # being thrown away by the zero_grad() at the start of the next epoch.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
        # Update progress with the true (unscaled) per-batch loss
        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        progress_bar.set_postfix(loss=batch_loss)
    # Report the mean loss over the epoch rather than the last batch only
    print(f"Epoch {epoch+1}: Loss = {epoch_loss_sum / max(num_batches, 1)}")

    # Save the trainable (LoRA) weights as .pth after each epoch. Tensors are
    # detached and moved to CPU so the checkpoint can be loaded on any device.
    lora_state_dict = {
        name: param.detach().cpu()
        for name, param in pipeline.unet.named_parameters()
        if param.requires_grad
    }
    torch.save(
        lora_state_dict,
        os.path.join(weights_dir, f"lora_finetuned_weights_epoch_{epoch+1}.pth"))
print("Training complete!")

Epoch 1:  42%|████▏     | 1850/4370 [16:33<22:58,  1.83it/s, loss=0.00119] 

In [None]:
""" 
Generate tattoo from text
"""

# Description of the tattoo to generate. Other examples to try:
# description = "A skull with digital glitch distortions and neon streaks."
# description = "A futuristic mask with neon symbols floating around it."
# description = "A delicate, face mask with celestial engravings and soft glowing edges."
description = "A elephant head drawn with sacred geometric patterns"

# Every training caption was built as "Tattoo of " + text (see collate_fn), so
# use the same prefix at inference time to match the fine-tuned conditioning.
prompt = "Tattoo of " + description

# The training cell leaves the UNet in train mode; switch to eval mode so the
# LoRA dropout layers are disabled while sampling.
pipeline.unet.eval()
image = pipeline(prompt).images[0]

# Save or display the image
image.save("finetuned_generated_tattoo.png")
image.show()