lora finetuning code

In [2]:
import sys
import os
import torch

# Add the directory containing model_pipeline to sys.path
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..","CatVTON")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..","CatVTON")))

from model.pipeline import CatVTONPipeline
from peft import LoraConfig, get_peft_model

#from diffusers import StableDiffusionPipeline
# # 1. Stable Diffusion 1.5 모델 로드
# model_id = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionPipeline.from_pretrained(
#     model_id, torch_dtype=torch.float16
# ).to("cuda")

# catvton
base_ckpt = "booksforcharlie/stable-diffusion-inpainting"
attn_ckpt = "zhengchong/CatVTON"
attn_ckpt_version = "vitonhd"

pipe = CatVTONPipeline(
    attn_ckpt_version,
    attn_ckpt,
    base_ckpt, 
    weight_dtype=torch.bfloat16, 
    device="cuda",
    skip_safety_checks=True
)


# 2. LoRA Config 설정 (Cross-Attention에 적용)
lora_config = LoraConfig(
    r=16, #8                     # LoRA Rank
    lora_alpha=32,             # Scaling factor
    target_modules=["to_q", "to_k", "to_v"],  
    lora_dropout=0.1,
)

# 3. LoRA 모델 적용
pipe.unet = get_peft_model(pipe.unet, lora_config)

TypeError: __init__() got an unexpected keyword argument 'skip_safety_checks'

In [None]:
# UNET_TARGET_MODULES = [
#     "to_q",
#     "to_k",
#     "to_v",
#     "proj",
#     "proj_in",
#     "proj_out",
#     "conv",
#     "conv1",
#     "conv2",
#     "conv_shortcut",
#     "to_out.0",
#     "time_emb_proj",
#     "ff.net.2",
# ]

In [None]:
from torchvision import transforms
from PIL import Image
import os

# ✅ 데이터 경로 설정
data_dir = "dataset/"
image_files = [f for f in os.listdir(data_dir) if f.endswith(".jpg")]

# ✅ 이미지 변환 설정 (Stable Diffusion 입력 크기 512x512)
transform = transforms.Compose([
    transforms.Resize((512, 384)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# ✅ 데이터셋 준비
train_dataset = []
for img_file in image_files:
    image_path = os.path.join(data_dir, img_file)
    image = Image.open(image_path).convert("RGB")
    image = transform(image)

    # 이미지 파일명 기반으로 캡션 생성
    caption = "A person sitting in a wheelchair" if "wheelchair" in img_file else "A random object"

    train_dataset.append({"image": image, "caption": caption})

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

# ✅ 데이터 로더 설정
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# ✅ Optimizer 설정
optimizer = optim.AdamW(pipe.unet.parameters(), lr=1e-4)

# ✅ Fine-tuning 함수
def train_lora(pipe, dataloader, epochs=10, batch_size=1):
    pipe.unet.train()

    for epoch in range(epochs):
        for batch in dataloader:
            image = batch["image"].to("cuda")
            caption = batch["caption"]

            # ✅ UNet Forward Pass (LoRA 적용된 상태)
            noise = torch.randn_like(image)  # 가우시안 노이즈 추가
            noisy_image = image + 0.1 * noise
            output = pipe.unet(noisy_image)

            # ✅ 손실 계산 (MSE Loss 사용)
            loss = ((output - image) ** 2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# ✅ LoRA Fine-tuning 실행
train_lora(pipe, train_loader, epochs=20, batch_size=1)

In [None]:
# ✅ LoRA 가중치 저장
pipe.unet.save_pretrained("lora_sd1.5_finetuned")

# ✅ Inference 테스트
prompt = "A person sitting in a wheelchair, cinematic lighting, high detail"
image = pipe(prompt, height=512, width=384).images[0]

# ✅ 결과 저장
image.save("wheelchair_lora_result.png")