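# Project 01: Generating Tattoos from Text Descriptions

Loads the `Drozdik/tattoo_v0` tattoo image/caption dataset and a pre-trained Stable Diffusion pipeline as the starting point for generating tattoo designs from text prompts.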

In [4]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler
from peft import LoraConfig, get_peft_model
from transformers import CLIPTextModel, CLIPTokenizer
from torch.optim import AdamW
from tqdm.auto import tqdm  # Import tqdm for progress bars

In [5]:
"""
Check device availability of GPU
"""
# device = "cuda" if torch.cuda.is_available() else "cpu" # if you are using colab
device = "mps" if torch.mps.is_available() else "cpu" # If running in mac
print(f"Using device: {device}")

Using device: mps


In [None]:
"""
Load dataset from Hugging Face
"""
dataset = load_dataset("Drozdik/tattoo_v0")
train_dataset = dataset["train"]

In [None]:
"""
Define image preprocessing
"""
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images
    transforms.ToTensor(),          # Convert to tensor
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])

In [8]:
""" 
Function to preprocess dataset for DataLoader
"""

def collate_fn(batch):
    images = [transform(sample["image"]) for sample in batch]
    captions = ["Tattoo of " + sample["text"]
                for sample in batch]  # Modify captions here
    return {
        "pixel_values": torch.stack(images),  # Convert to tensor
        "text": captions
    }

In [9]:
# Training DataLoader: one sample per step, reshuffled every epoch.
# A seeded generator makes the shuffle order reproducible across Restart & Run All.
SHUFFLE_SEED = 42
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [11]:
# Base Stable Diffusion checkpoint to start from.
# Other candidate checkpoints: "runwayml/stable-diffusion-v1-5", "prompthero/openjourney".
model_id = "SG161222/Realistic_Vision_V4.0"

# Load all pipeline components in full precision (float32), then move them to
# the selected device. Requires the `transformers` package to be installed in
# the same environment as the notebook kernel.
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
pipeline = pipeline.to(device)

ImportError: 
StableDiffusionPipeline requires the transformers library but it was not found in your environment. You can install it with pip: `pip
install transformers`
