In [1]:
!pip install --quiet --upgrade diffusers transformers accelerate mediapy peft openai datasets
!pip install --quiet fsspec==2024.10.0

# Remove jax and jaxlib if present to prevent KeyArray errors
!pip uninstall -y jax jaxlib

import IPython
IPython.display.clear_output()
print("Dependencies installed. jax and jaxlib removed to avoid KeyArray errors. You can now continue with the next cells.")

Dependencies installed. jax and jaxlib removed to avoid KeyArray errors. You can now continue with the next cells.


In [2]:
import os
import random
import sys
import time
from io import BytesIO

import mediapy as media
import numpy as np
import requests
import torch
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, DDPMScheduler
from PIL import Image
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Record the library versions in the cell output so a run can be reproduced later.
import diffusers
import transformers

print("Diffusers version:", diffusers.__version__)
print("Transformers version:", transformers.__version__)

In [None]:
# Fetch the CIFAR-10 training split; the cat images (label 3) are picked out of it below.
dataset = load_dataset(path="cifar10", split="train")
print("Total samples in dataset:", len(dataset))

def is_cat(example):
    """Return True when the example's ``"label"`` field equals 3 (CIFAR-10 "cat")."""
    cat_label = 3
    return example["label"] == cat_label

# Keep only the rows accepted by is_cat.
filtered = dataset.filter(is_cat)
print("Filtered cat images:", len(filtered))

# Textual inversion only needs a handful of examples: keep at most the first 3.
keep_count = min(3, len(filtered))
filtered = filtered.select(range(keep_count))
print("Final samples for training:", len(filtered))

# Pull the "img" field (expected to be a PIL image) out of every remaining row.
images = []
for i, record in enumerate(filtered):
    images.append(record["img"])
    print(f"Loaded cat image {i+1} from CIFAR-10")

print(f"Total training images loaded: {len(images)}")


In [None]:
# Pseudo-word that stands for the new concept in prompts.
placeholder_token = "<my_concept>"


class TextualInversionDataset(Dataset):
    """Map-style dataset pairing each training image with the placeholder prompt.

    Args:
        images: Indexable collection of PIL-style images (each must provide
            ``.resize`` and be convertible with ``np.array``).
        placeholder_token: Text returned as the prompt for every item.
        size: Side length, in pixels, that every image is resized to.
    """

    def __init__(self, images, placeholder_token="<my_concept>", size=512):
        self.images = images
        self.placeholder_token = placeholder_token
        self.size = size

    def __len__(self):
        """Number of training images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": float tensor scaled to [-1, 1], "prompt": str}``.

        The pixel tensor is channel-first; the ``permute(2, 0, 1)`` assumes the
        image converts to a 3-D (height, width, channel) array.
        """
        image = self.images[idx].resize((self.size, self.size), Image.BICUBIC)
        # torch.tensor() does not accept a PIL image, so convert through NumPy first.
        pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).float()
        # Map 0..255 to -1..1.
        pixels = (pixels / 255.0 - 0.5) / 0.5
        return {"pixel_values": pixels, "prompt": self.placeholder_token}

# Wrap the collected images and serve them one per batch, reshuffled every epoch.
dataset = TextualInversionDataset(images=images, placeholder_token=placeholder_token)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=1)


In [None]:
# NOTE(review): this Hub repository may have been superseded by the
# "stable-diffusion-v1-5/stable-diffusion-v1-5" mirror; confirm it still resolves.
model_id = "runwayml/stable-diffusion-v1-5"

# Half precision is only appropriate on the GPU; `device` can fall back to "cpu",
# where float16 inference is poorly supported, so use float32 there.
pipe_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=pipe_dtype).to(device)

# Replace the safety checker with a pass-through that returns the images unchanged
# together with one "not flagged" entry per image.
pipe.safety_checker = lambda images, clip_input: (images, [False]*len(images))

print("Pipeline loaded and safety checker modified.")

In [None]:
# Register the placeholder as a new vocabulary entry; stop if it was already known.
num_added_tokens = pipe.tokenizer.add_tokens(placeholder_token)
if not num_added_tokens:
    raise ValueError(f"The token {placeholder_token} already exists.")

# Grow the text encoder's embedding table to match the enlarged vocabulary.
text_encoder = pipe.text_encoder
text_encoder.resize_token_embeddings(len(pipe.tokenizer))

# Fill the new token's embedding row in place with Gaussian noise (mean 0, std 0.5).
token_id = pipe.tokenizer.convert_tokens_to_ids(placeholder_token)
with torch.no_grad():
    token_embedding = text_encoder.get_input_embeddings().weight[token_id]
    nn.init.normal_(token_embedding, mean=0.0, std=0.5)

print("Placeholder token added and embedding initialized.")


In [None]:
# Freeze every text-encoder weight; only the placeholder embedding is meant to learn.
for param in text_encoder.parameters():
    param.requires_grad_(False)

# `weight[token_id]` builds a brand-new tensor on every access, so setting
# requires_grad on it, or handing it to the optimizer, never reaches the tensor the
# forward pass actually uses. Train the real embedding Parameter instead and confine
# the updates to the placeholder row with a gradient mask.
embedding_weight = text_encoder.get_input_embeddings().weight
embedding_weight.requires_grad_(True)


def _keep_placeholder_grad_only(grad):
    """Return a copy of ``grad`` that is zero everywhere except the placeholder row."""
    masked = torch.zeros_like(grad)
    masked[token_id] = grad[token_id]
    return masked


# Remove the hook left behind by a previous run of this cell so hooks do not pile up.
if "_placeholder_grad_hook" in globals():
    _placeholder_grad_hook.remove()
_placeholder_grad_hook = embedding_weight.register_hook(_keep_placeholder_grad_only)

# weight_decay=0 keeps rows with zero gradient exactly as they are (AdamW's decay is
# applied regardless of the gradient). eps is raised to the dtype's smallest normal
# number because the default 1e-8 underflows to 0 in float16, which would turn the
# update of the zero-gradient rows into 0/0.
adam_eps = max(1e-8, torch.finfo(embedding_weight.dtype).tiny)
optimizer = AdamW([embedding_weight], lr=5e-4, weight_decay=0.0, eps=adam_eps)


In [None]:
class TextualInversionDataset(Dataset):
    """Dataset that yields one resized, normalised image plus the placeholder prompt.

    ``images`` is an indexable collection of PIL-style images, ``placeholder_token``
    is the prompt text returned with every item, and ``size`` is the square side
    length each image is resized to.
    """

    def __init__(self, images, placeholder_token="<my_concept>", size=512):
        self.images, self.placeholder_token, self.size = images, placeholder_token, size

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with the ``"pixel_values"`` tensor and the ``"prompt"`` string."""
        side = self.size
        resized = self.images[idx].resize((side, side), Image.BICUBIC)
        # PIL image -> NumPy array -> channel-first float tensor scaled to 0..1
        scaled = torch.from_numpy(np.array(resized)).permute(2, 0, 1).float() / 255.0
        # Shift and scale 0..1 to -1..1
        normalized = (scaled - 0.5) / 0.5
        return {"pixel_values": normalized, "prompt": self.placeholder_token}


In [None]:
# Directory that receives the exported concept; created on first use.
output_dir = "./textual_inversion_concept"
os.makedirs(output_dir, exist_ok=True)

# Copy the placeholder token's embedding row to the CPU as a NumPy array and store it
# next to the token string it belongs to.
embedding_matrix = text_encoder.get_input_embeddings().weight
learned_embedding = embedding_matrix[token_id].detach().cpu().numpy()
checkpoint = {"embedding": learned_embedding, "token": placeholder_token}
torch.save(checkpoint, os.path.join(output_dir, "learned_embedding.pt"))

print("Learned embedding saved.")


In [None]:
# Text prompt for the final sample ("loking" typo corrected to "looking").
# NOTE(review): the prompt does not contain placeholder_token ("<my_concept>"), so the
# learned concept is not referenced here; confirm that is intended.
prompt = """
cartoon looking, cute cat eating banana
"""

# Seed PyTorch's default CPU generator and pass it to the pipeline for repeatable output.
generator = torch.manual_seed(1234)

# NOTE: this rebinds `images`, which earlier held the training images; rerun the
# data-loading cell before rebuilding the dataset.
images = pipe(
    prompt=prompt,
    num_inference_steps=50,
    guidance_scale=7.5,
    generator=generator,
    height=512,
    width=512,
).images

# Display the result inline and keep the first image on disk.
media.show_images(images)
images[0].save("final_output.jpg")
print("final_output.jpg saved. Check above for the displayed image.")
