In [None]:
# ================================================================
#  1. Imports
# ================================================================
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler, StableDiffusionPipeline
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image

# ================================================================
#  2. Device
# ================================================================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Torch version:", torch.__version__)

# ================================================================
#  3. Load Stable Diffusion pipeline
# ================================================================
# fp16 halves GPU memory, but many fp16 kernels are missing or extremely slow
# on CPU, so only use half precision when running on CUDA.
weight_dtype = torch.float16 if device.type == "cuda" else torch.float32

# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has reportedly
# been taken down; if a fresh download fails, the community mirror
# "stable-diffusion-v1-5/stable-diffusion-v1-5" should be equivalent — confirm.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=weight_dtype,
).to(device)

# The CLIP tokenizer bundled with the pipeline; used to tokenize captions.
tokenizer = pipe.tokenizer
print(" SD pipeline & tokenizer ready")

# ================================================================
#  4. LoRA config for UNet transformer blocks
# ================================================================
# Low-rank adapters on the attention projections (query/key/value/output)
# of the UNet. With r=4 and lora_alpha=16 the adapter update is scaled by 16/4.
lora_config = LoraConfig(
    bias="none",
    lora_dropout=0.1,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    lora_alpha=16,
    r=4,
)

# Wrap the UNet with PEFT so the adapters are injected into the matching modules.
pipe.unet = get_peft_model(pipe.unet, lora_config)
print(" LoRA attached with:", lora_config.target_modules)

# ================================================================
#  5. Define transform for your dataset
# ================================================================
# Squash every image to a fixed 512x512 square (aspect ratio is not kept),
# then turn it into a float CHW tensor with values in [0, 1].
torch_transform = transforms.Compose(
    [
        transforms.Resize(size=(512, 512)),
        transforms.ToTensor(),
    ]
)

def transform(batch):
    """Convert each image in a dataset batch to an RGB tensor via ``torch_transform``.

    Intended for ``Dataset.with_transform``: ``batch["image"]`` is a list of
    image objects (presumably PIL images — anything with ``.convert("RGB")``).
    The list is replaced in place by the transformed tensors and the same
    dict is returned.
    """
    converted = []
    for image in batch["image"]:
        rgb_image = image.convert("RGB")
        converted.append(torch_transform(rgb_image))
    batch["image"] = converted
    return batch

# ================================================================
#  6. Loop through each style folder
# ================================================================
styles = ["oil_painting", "mosaic", "crayon", "pencil_sketch", "watercolor"]

epochs = 20
batch_size = 4
learning_rate = 1e-4

# Training-time noise schedule, built from the pipeline scheduler's config so
# it uses the same betas / timestep count the model was trained with.
noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config)

# VAE and text encoder are only used as frozen feature extractors.
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)

# Optimize only what get_peft_model left trainable (the LoRA weights), not the
# whole UNet.
trainable_params = [p for p in pipe.unet.parameters() if p.requires_grad]
if not trainable_params:
    raise RuntimeError("UNet has no trainable parameters; was the LoRA adapter attached?")

# Keep trainable weights in fp32 even when the frozen model runs in fp16:
# Adam's eps (1e-8) underflows to 0 in fp16 and small updates vanish.
for param in trainable_params:
    if param.dtype != torch.float32:
        param.data = param.data.float()

# Snapshot of the freshly initialized adapter so each style starts from the
# same initialization instead of continuing from the previous style's weights.
initial_lora_weights = [p.detach().clone() for p in trainable_params]

# Style-independent, so load it once (and close the file) up front; this also
# fails fast before any training if the file is missing. Kept as a global for
# later style-transfer tests.
with Image.open("test.jpg") as raw_test_image:
    test_image = raw_test_image.convert("RGB").resize((512, 512))

for style in styles:
    print(f"\n Training LoRA for style: {style}")

    dataset = load_dataset(
        "imagefolder",
        data_dir=f"style_dataset/{style}"
    )["train"]

    dataset = dataset.with_transform(transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Reset the adapter to its initial weights for this style.
    with torch.no_grad():
        for param, initial in zip(trainable_params, initial_lora_weights):
            param.copy_(initial)

    optimizer = torch.optim.Adam(trainable_params, lr=learning_rate)

    # The caption is the same for every image of a style: encode it once.
    encoding = tokenizer(
        [f"A {style} artwork"],
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt"
    )
    with torch.no_grad():
        caption_embeds = pipe.text_encoder(encoding.input_ids.to(device))[0]

    # from_pretrained leaves modules in eval mode; train() enables LoRA dropout.
    pipe.unet.train()
    for epoch in range(epochs):
        print(f" Epoch {epoch + 1}/{epochs}")
        epoch_loss = 0.0
        num_steps = 0
        for batch in dataloader:
            images = batch["image"].to(device)
            bsz = images.size(0)

            with torch.no_grad():
                # ToTensor yields [0, 1]; the SD VAE expects inputs in [-1, 1].
                pixel_values = (images * 2.0 - 1.0).to(dtype=pipe.vae.dtype)
                latents = pipe.vae.encode(pixel_values).latent_dist.sample()
                latents = latents * pipe.vae.config.scaling_factor

            # Standard denoising objective: add noise at a random timestep and
            # train the UNet to predict it (or the velocity for v-pred models).
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            if noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                target = noise

            encoder_hidden_states = caption_embeds.repeat(bsz, 1, 1)
            model_pred = pipe.unet(
                noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states
            ).sample
            # Compute the loss in fp32 to avoid fp16 overflow/precision loss.
            loss = F.mse_loss(model_pred.float(), target.float())

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_steps += 1

        if num_steps:
            print(f"   mean loss: {epoch_loss / num_steps:.4f}")
        else:
            print("   no batches produced (empty dataset?)")
    pipe.unet.eval()

    # NOTE(review): this saves the full PEFT-wrapped UNet state dict (frozen base
    # weights + LoRA), matching the previous file format. Saving only the
    # trainable parameters would be far smaller if no loader needs the full dict.
    save_name = f"lora_{style}.pt"
    torch.save(pipe.unet.state_dict(), save_name)
    print(f" Saved LoRA for {style}: {save_name}")

    print(f" Finished {style}!")

print("\n All styles processed! Ready for style transfer tests.")


Torch version: 2.7.1+cu126


Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 12.38it/s]


 SD pipeline & tokenizer ready
 LoRA attached with: {'to_q', 'to_v', 'to_out.0', 'to_k'}

 Training LoRA for style: oil_painting
 Epoch 1/10
 Epoch 2/10
 Epoch 3/10
 Epoch 4/10
 Epoch 5/10
 Epoch 6/10
 Epoch 7/10
 Epoch 8/10
 Epoch 9/10
 Epoch 10/10
 Saved LoRA for oil_painting: lora_oil_painting.pt
 Finished oil_painting!

 Training LoRA for style: mosaic
 Epoch 1/10
 Epoch 2/10
 Epoch 3/10
 Epoch 4/10
 Epoch 5/10
 Epoch 6/10
 Epoch 7/10
 Epoch 8/10
 Epoch 9/10
 Epoch 10/10
 Saved LoRA for mosaic: lora_mosaic.pt
 Finished mosaic!

 Training LoRA for style: crayon
 Epoch 1/10
 Epoch 2/10
 Epoch 3/10
 Epoch 4/10
 Epoch 5/10
 Epoch 6/10
 Epoch 7/10
 Epoch 8/10
 Epoch 9/10
 Epoch 10/10
 Saved LoRA for crayon: lora_crayon.pt
 Finished crayon!

 Training LoRA for style: pencil_sketch
 Epoch 1/10
 Epoch 2/10
 Epoch 3/10
 Epoch 4/10
 Epoch 5/10
 Epoch 6/10
 Epoch 7/10
 Epoch 8/10
 Epoch 9/10
 Epoch 10/10
 Saved LoRA for pencil_sketch: lora_pencil_sketch.pt
 Finished pencil_sketch!

 Training 