In [1]:
!pip install transformers datasets torchvision
!pip install "diffusers[torch]" 


[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Collecting diffusers[torch]
  Downloading diffusers-0.34.0-py3-none-any.whl.metadata (20 kB)
Collecting huggingface-hub>=0.27.0 (from diffusers[torch])
  Downloading huggingface_hub-0.34.1-py3-none-any.whl.metadata (14 kB)
Collecting accelerate>=0.31.0 (from diffusers[torch])
  Downloading accelerate-1.0.1-py3-none-any.whl.metadata (19 kB)
Collecting hf-xet<2.0.0,>=1.1.3 (from huggingface-hub>=0.27.0->diffusers[torch])
  Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)
Downloading accelerate-1.0.1-py3-none-any.whl (330 kB)
Downloading huggingface_hub-0.34.1-py3-none-any.whl (558 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m558.8/558.8 kB[0m [31m63

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer
from datasets import load_dataset
from tqdm import tqdm



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the full WikiArt dataset from HugGAN
dataset = load_dataset("huggan/wikiart", split="train",streaming=True)

# Filter by artist ID with tqdm
artist_id = 22
filtered = []

# Limit tqdm bar to something large but finite (streaming doesn't give length)
for sample in tqdm(dataset, desc=f"Filtering artist_id={artist_id}", total=100_000):
    if sample["artist"] == artist_id:
        filtered.append(sample)


Filtering artist_id=22:  81%|████████▏ | 81444/100000 [15:06<03:26, 89.80it/s]  


In [None]:
# --- get subject images ---
sub_dir = "subject_images"
os.makedirs(sub_dir, exist_ok=True)

for idx, sample in enumerate(filtered):
    image = sample["image"]  # already a PIL Image
    assert isinstance(image, Image.Image), f"Item {idx} is not a PIL image"

    # Optional: encode metadata into filename if desired
    artist = sample.get("artist", "unknown")
    genre = sample.get("genre", "unknown")
    style = sample.get("style", "unknown")

    filename = f"img_{idx:05d}_artist{artist}_genre{genre}_style{style}.jpg"
    filepath = os.path.join(sub_dir, filename)

    image.save(filepath)

print(f"✅ Saved {len(filtered)} images to '{sub_dir}/'")

In [2]:
device = "cuda"

In [3]:
# ----- Load pretrained Stable Diffusion -----
model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
tokenizer: CLIPTokenizer = pipe.tokenizer


Loading pipeline components...: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]


In [66]:
# ----- Add a new token -----
new_token = "[skaz]"
class_token = "artwork"
num_added = tokenizer.add_tokens([new_token])
pipe.text_encoder.resize_token_embeddings(len(tokenizer))
new_token_id = tokenizer.convert_tokens_to_ids(new_token)

# we use class token in order to avoid drift, as mentioned in the dreambooth paper.
class_token_id = tokenizer.convert_tokens_to_ids(class_token)

In [67]:
num_added

1

In [68]:
# ----- fine-tune new token embedding -----
embedding_layer = pipe.text_encoder.get_input_embeddings()

# Freeze all text encoder params
for param in pipe.text_encoder.parameters():
    param.requires_grad = False

# Unfreeze the embedding weights
embedding_layer.weight.requires_grad = True

optimizer = torch.optim.Adam([
    embedding_layer.weight
], lr=5e-4)


In [69]:
transform = transforms.Compose([
    transforms.Resize(512),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [70]:
class SimpleDreamBoothDataset(Dataset):
    def __init__(self, subject_dir, prior_dir, transform):
        # build full paths, and only keep files
        self.subject_images = [
            os.path.join(subject_dir, f)
            for f in os.listdir(subject_dir)
            if os.path.isfile(os.path.join(subject_dir, f))
        ]
        self.prior_images = [
            os.path.join(prior_dir, f)
            for f in os.listdir(prior_dir)
            if os.path.isfile(os.path.join(prior_dir, f))
        ]
        self.transform = transform

    def __len__(self):
        # so we cycle through the smaller set repeatedly
        return max(len(self.subject_images), len(self.prior_images))

    def __getitem__(self, idx):
        # wrap around
        subj_path  = self.subject_images[idx % len(self.subject_images)]
        prior_path = self.prior_images[idx % len(self.prior_images)]

        # open & ensure RGB
        img_subj  = Image.open(subj_path).convert("RGB")
        img_prior = Image.open(prior_path).convert("RGB")

        return self.transform(img_subj), self.transform(img_prior)


In [71]:
class_prompt = "A painting in the style of Van Gogh"
num_images = 400
pri_dir = "prior_images"
os.makedirs(pri_dir, exist_ok=True)

In [9]:
for i in range(num_images):
    if i % 10 == 0:
        print(f"generating prior image #{i}")
    image = pipe(class_prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
    image.save(os.path.join(pri_dir, f"class_image_{i:03}.png"))

100%|██████████| 50/50 [00:00<00:00, 59.93it/s]


In [72]:
dataset = SimpleDreamBoothDataset("subject_images", "prior_images",transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [73]:
from torch.cuda.amp import autocast
import torch.nn.functional as F

# ----- Training loop -----
num_epochs = 1000
lmbda = 1.0  # prior preservation weight

for step, (x_subj, x_prior) in enumerate(dataloader):
    if step >= num_epochs:
        break

    # Move to device
    x_subj = x_subj.to(device)
    x_prior = x_prior.to(device)

    # --- Tokenize prompts ---
    subj_prompt  = f"a photo of {new_token} {class_token}"
    prior_prompt = f"a photo of {class_token}"
    subj_ids  = tokenizer(subj_prompt,  return_tensors="pt").input_ids.to(device)
    prior_ids = tokenizer(prior_prompt, return_tensors="pt").input_ids.to(device)

    # --- Text embeddings (frozen) ---
    with torch.no_grad():
        subj_embeds  = pipe.text_encoder(subj_ids)[0]
        prior_embeds = pipe.text_encoder(prior_ids)[0]

    # --- Encode RGB images → 4-channel latents (float16) ---
    with torch.no_grad():
        x_subj_fp16  = x_subj.to(torch.float16)
        x_prior_fp16 = x_prior.to(torch.float16)

        latents_subj  = pipe.vae.encode(x_subj_fp16).latent_dist.sample()
        latents_prior = pipe.vae.encode(x_prior_fp16).latent_dist.sample()

        latents_subj  *= pipe.vae.config.scaling_factor
        latents_prior *= pipe.vae.config.scaling_factor

    # --- Noise injection ---
    noise = torch.randn_like(latents_subj)
    batch_size = latents_subj.shape[0]
    t = torch.randint(
        0,
        pipe.scheduler.config.num_train_timesteps,
        (batch_size,),
        device=device,
    ).long()

    alphas = pipe.scheduler.alphas_cumprod.to(device)
    alpha_t = alphas[t].view(-1, 1, 1, 1).sqrt()
    sigma_t = (1 - alphas[t]).view(-1, 1, 1, 1).sqrt()

    noisy_subj  = alpha_t * latents_subj  + sigma_t * noise
    noisy_prior = alpha_t * latents_prior + sigma_t * noise

    # --- Forward + Loss (mixed precision) ---
    with autocast():
        pred_subj  = pipe.unet(noisy_subj,  t, encoder_hidden_states=subj_embeds).sample
        pred_prior = pipe.unet(noisy_prior, t, encoder_hidden_states=prior_embeds).sample

        loss_subj  = F.mse_loss(pred_subj,  noise)
        loss_prior = F.mse_loss(pred_prior, noise)
        loss = loss_subj + lmbda * loss_prior

    # --- Backprop & optimize ---
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 10 == 0:
        print(
            f"[{step:04d}] "
            f"Loss: {loss.item():.4f} | "
            f"Subj: {loss_subj.item():.4f} | "
            f"Prior: {loss_prior.item():.4f}"
        )


[0000] Loss: 0.7971 | Subj: 0.3815 | Prior: 0.4156
[0010] Loss: 0.0782 | Subj: 0.0374 | Prior: 0.0408
[0020] Loss: 0.3964 | Subj: 0.1696 | Prior: 0.2267
[0030] Loss: 1.0207 | Subj: 0.4731 | Prior: 0.5476
[0040] Loss: 0.8250 | Subj: 0.3836 | Prior: 0.4414
[0050] Loss: 0.8168 | Subj: 0.4479 | Prior: 0.3689
[0060] Loss: 0.3136 | Subj: 0.1276 | Prior: 0.1860
[0070] Loss: 0.0263 | Subj: 0.0103 | Prior: 0.0161
[0080] Loss: 0.0100 | Subj: 0.0041 | Prior: 0.0059
[0090] Loss: 1.0049 | Subj: 0.5022 | Prior: 0.5027
[0100] Loss: 0.0644 | Subj: 0.0260 | Prior: 0.0384
[0110] Loss: 0.8375 | Subj: 0.4725 | Prior: 0.3650
[0120] Loss: 0.0106 | Subj: 0.0104 | Prior: 0.0003
[0130] Loss: 0.2777 | Subj: 0.1185 | Prior: 0.1592
[0140] Loss: 0.8461 | Subj: 0.4029 | Prior: 0.4433
[0150] Loss: 0.1992 | Subj: 0.0805 | Prior: 0.1187
[0160] Loss: 1.0240 | Subj: 0.5834 | Prior: 0.4406
[0170] Loss: 0.3257 | Subj: 0.1894 | Prior: 0.1364
[0180] Loss: 0.7059 | Subj: 0.3700 | Prior: 0.3360
[0190] Loss: 0.2950 | Subj: 0.1

In [74]:
def generate_skax_images(prompt, num_images=1, guidance_scale=7.5, num_inference_steps=100, height=512, width=512):
    output = pipe(
        prompt,
        height=height,
        width=width,
        num_images_per_prompt=num_images,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    )
    return output.images

In [93]:
prompt = "[skaz] drawing with a grass near the lake and people setting next to it drawn like [skaz]"
imgs = generate_skax_images(prompt,num_images=5)

100%|██████████| 100/100 [00:04<00:00, 23.43it/s]


In [92]:
for idx, img in enumerate(imgs):
    img.save(f"skaz_{prompt.replace('[skaz]' ,'skaz').replace(' ','_')}_{idx}.png")


In [63]:
# save fine-tuned dreambooth model.
output_dir = "dreambooth_skax_finetuned"
pipe.save_pretrained(output_dir)