LoRA fine-tuning code

In [None]:
import sys
import os
import torch
from tqdm import tqdm
from PIL import Image, ImageFilter

# Make the CatVTON checkout (expected at ../CatVTON relative to the current
# working directory) importable so that `model.pipeline` resolves. The
# commented-out line is the equivalent for a .py script, where __file__ exists.
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..","CatVTON")))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..","CatVTON")))

from model.pipeline import CatVTONPipeline
from peft import LoraConfig, get_peft_model

# Earlier plain Stable Diffusion 1.5 setup, kept for reference only (unused).
#from diffusers import StableDiffusionPipeline
# # 1. Load the Stable Diffusion 1.5 model
# model_id = "runwayml/stable-diffusion-v1-5"
# pipe = StableDiffusionPipeline.from_pretrained(
#     model_id, torch_dtype=torch.float16
# ).to("cuda")

# CatVTON checkpoints: the inpainting base model repo, the CatVTON attention
# weight repo, and the attention-weight variant (presumably selects which of
# the repo's weight sets is loaded -- confirm against CatVTONPipeline).
base_ckpt = "booksforcharlie/stable-diffusion-inpainting"
attn_ckpt = "zhengchong/CatVTON"
attn_ckpt_version = "mix"

# Build the try-on pipeline in float32 on the GPU, skipping the safety checker.
pipe = CatVTONPipeline(
    base_ckpt, 
    attn_ckpt,
    attn_ckpt_version,
    weight_dtype=torch.float32, 
    device="cuda",
    skip_safety_check=True
)


# 2. LoRA configuration. The original note said "applied to Cross-Attention".
# NOTE(review): PEFT matches target_modules by module-name suffix; in a
# diffusers UNet both self-attention (attn1) and cross-attention (attn2) blocks
# own modules named to_q/to_k/to_v, so both are adapted -- confirm intended.
lora_config = LoraConfig(
    r=16,                      # LoRA rank (8 was noted here as an alternative)
    lora_alpha=32,             # scaling numerator: LoRA output is scaled by lora_alpha / r
    target_modules=["to_q", "to_k", "to_v"],  # attention query/key/value projections
    lora_dropout=0.1,          # dropout probability on the LoRA branch
)

# 3. Wrap the UNet with LoRA adapters; the training and saving cells below
# operate on this wrapped pipe.unet.
pipe.unet = get_peft_model(pipe.unet, lora_config)

An error occurred while trying to fetch booksforcharlie/stable-diffusion-inpainting: booksforcharlie/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Fetching 12 files: 100%|██████████| 12/12 [00:00<00:00, 12285.00it/s]


Downloaded zhengchong/CatVTON to C:\Users\010\.cache\huggingface\hub\models--zhengchong--CatVTON\snapshots\2969fcf85fe62f2036605716f0b56f0b81d01d79


### 커스텀 데이터셋 정의

In [None]:
# UNET_TARGET_MODULES = [
#     "to_q",
#     "to_k",
#     "to_v",
#     "proj",
#     "proj_in",
#     "proj_out",
#     "conv",
#     "conv1",
#     "conv2",
#     "conv_shortcut",
#     "to_out.0",
#     "time_emb_proj",
#     "ff.net.2",
# ]

In [None]:
from torchvision import transforms
from PIL import Image
import os

# Data location: every .jpg/.jpeg file (any letter case) directly under
# data_dir is used as a training image. Sorting makes the dataset order
# reproducible, since os.listdir order is arbitrary.
data_dir = "dataset/"
image_files = sorted(
    f for f in os.listdir(data_dir) if f.lower().endswith((".jpg", ".jpeg"))
)

# Image transform: resize to 512x384 (height x width), convert to a tensor
# and scale pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((512, 384)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Build the in-memory dataset: one {"image", "caption"} dict per file.
train_dataset = []
for img_file in image_files:
    image_path = os.path.join(data_dir, img_file)
    # The context manager closes the file handle as soon as the pixels have
    # been converted, instead of leaving it to the garbage collector.
    with Image.open(image_path) as pil_image:
        image = transform(pil_image.convert("RGB"))

    # Caption derived from the file name.
    caption = "A person sitting in a wheelchair, cinematic lighting, high detail" if "wheelchair" in img_file else "A random object"

    train_dataset.append({"image": image, "caption": caption})

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Data loader: one sample per step, reshuffled every epoch.
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Optimizer over the parameters get_peft_model left trainable (the LoRA
# adapters) rather than every UNet weight.
optimizer = optim.AdamW(
    [p for p in pipe.unet.parameters() if p.requires_grad], lr=1e-5
)


def train_lora(pipe, dataloader, epochs=10, batch_size=32, optimizer=optimizer):
    """Fine-tune the LoRA adapters on ``pipe.unet`` with a noise-prediction loss.

    Each image batch is encoded to VAE latents, noised at a random DDPM
    timestep and fed to the inpainting UNet together with an all-ones mask
    (the dataset has no masks, so the whole image is "repainted"). The loss is
    the MSE between the predicted and the added noise.

    Args:
        pipe: pipeline whose ``unet`` carries LoRA adapters.
        dataloader: yields dicts with an ``"image"`` tensor scaled to [-1, 1].
        epochs: number of passes over ``dataloader``.
        batch_size: unused -- batching is decided by ``dataloader``; kept so
            existing calls keep working.
        optimizer: optimizer over the trainable parameters; defaults to the
            module-level ``optimizer`` created above.
    """
    from diffusers import DDPMScheduler

    # Stable Diffusion 1.x training noise schedule.
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )
    # NOTE(review): assumes the pipeline exposes its AutoencoderKL as `pipe.vae`
    # -- confirm against CatVTONPipeline.
    vae = pipe.vae

    pipe.unet.train()

    for epoch in range(epochs):
        for batch in dataloader:
            image = batch["image"].to("cuda")

            # The VAE is frozen: encode without building a graph.
            with torch.no_grad():
                latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
                # With a full mask the masked image is all zeros.
                masked_latents = (
                    vae.encode(torch.zeros_like(image)).latent_dist.sample()
                    * vae.config.scaling_factor
                )
            mask = torch.ones_like(latents[:, :1])

            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (latents.shape[0],), device=latents.device,
            )
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Inpainting UNet input: noisy latents, mask and masked-image latents
            # stacked on the channel axis. No text conditioning is passed
            # (captions are unused); NOTE(review): confirm this matches how
            # CatVTONPipeline calls its UNet.
            unet_input = torch.cat([noisy_latents, mask, masked_latents], dim=1)
            noise_pred = pipe.unet(
                unet_input, timesteps, encoder_hidden_states=None, return_dict=False
            )[0]

            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Run LoRA fine-tuning.
train_lora(pipe, train_loader, epochs=20, batch_size=1)

In [None]:
# Save the LoRA weights. pipe.unet was wrapped by get_peft_model, so
# save_pretrained writes the adapter weights.
pipe.unet.save_pretrained("lora_sd1.5_finetuned")

# Inference test (currently disabled).
# NOTE(review): this text-prompt call looks copied from a StableDiffusionPipeline
# example; confirm it matches CatVTONPipeline's call signature before enabling.
#prompt = "A person sitting in a wheelchair, cinematic lighting, high detail"
#image = pipe(prompt, height=512, width=384).images[0]

# Save the inference result. Kept disabled together with the inference code:
# without it, `image` still names the last training tensor from the dataset
# cell, which has no .save() method, so a live call raised AttributeError.
#image.save("wheelchair_lora_result.png")

In [None]:
import os
import sys  # used by sys.path.append below; needed when run as a standalone script
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers.image_processor import VaeImageProcessor
from PIL import Image

# LoRA modules are applied through the PEFT library.
try:
    from peft import get_peft_model, LoraConfig, TaskType
except ImportError as err:
    raise ImportError("Please install peft library: pip install peft") from err

# Latent-diffusion components (diffusers' DDPMScheduler).
try:
    from diffusers import DDPMScheduler
except ImportError as err:
    raise ImportError("Please install diffusers library: pip install diffusers") from err

# Import CatVTONPipeline (adjust the path to your project layout).
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..","CatVTON")))
from model.pipeline import CatVTONPipeline


def parse_args(argv=None):
    """Parse the command-line options for CatVTON LoRA fine-tuning.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``; passing an explicit list lets notebooks call this
            without picking up the kernel's own arguments.

    Returns:
        argparse.Namespace holding the options defined below. ``height``,
        ``width`` and ``eval_pair`` are read by the dataset classes.
    """
    parser = argparse.ArgumentParser(description="LoRA Fine-tuning for Latent Diffusion based CatVTON")
    parser.add_argument("--data_root_path", type=str, required=True, help="Path to the training dataset.")
    parser.add_argument("--output_dir", type=str, default="output", help="Directory to save checkpoints.")
    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training.")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--lora_rank", type=int, default=4, help="LoRA rank parameter.")
    parser.add_argument("--seed", type=int, default=555, help="Random seed for reproducibility.")
    # Latent-diffusion parameters.
    parser.add_argument("--num_train_timesteps", type=int, default=1000, help="Number of diffusion steps for training.")
    # Dataset parameters: TrainDataset.__getitem__ resizes to (height, width)
    # and the VITON-HD loaders read eval_pair.
    parser.add_argument("--height", type=int, default=512, help="Height images are resized to.")
    parser.add_argument("--width", type=int, default=384, help="Width images are resized to.")
    parser.add_argument("--eval_pair", action="store_true",
                        help="Pair each person image with its own cloth image.")
    return parser.parse_args(argv)


class TrainDataset(Dataset):
    """Base dataset returning preprocessed person / cloth / mask tensors.

    Subclasses override ``load_data`` to return a list of dicts with the keys
    ``person_name``, ``person``, ``cloth`` and ``mask`` (the last three are
    image file paths). ``args`` must provide ``height`` and ``width``, the
    size every image is resized to in ``__getitem__``.
    """

    def __init__(self, args):
        self.args = args
        # Person / cloth images: VAE processor with its default settings.
        self.vae_processor = VaeImageProcessor(vae_scale_factor=8)
        # Masks: grayscale, binarised and not normalised.
        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=8,
            do_normalize=False,
            do_binarize=True,
            do_convert_grayscale=True
        )
        self.data = self.load_data()

    def load_data(self):
        """Return the list of sample records; the base class has none."""
        return []

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _open_image(path, mode):
        """Read ``path`` fully into memory as PIL ``mode`` and close the file."""
        with Image.open(path) as img:
            return img.convert(mode)

    def __getitem__(self, idx):
        """Return the preprocessed tensors for sample ``idx``."""
        data = self.data[idx]
        height, width = self.args.height, self.args.width
        # Force 3-channel inputs for the VAE: grayscale or RGBA files would
        # otherwise produce 1- or 4-channel tensors.
        person = self._open_image(data['person'], "RGB")
        cloth = self._open_image(data['cloth'], "RGB")
        # The mask processor converts to grayscale anyway; doing it here lets
        # the file be closed right away.
        mask = self._open_image(data['mask'], "L")
        return {
            'index': idx,
            'person_name': data['person_name'],
            'person': self.vae_processor.preprocess(person, height, width)[0],
            'cloth': self.vae_processor.preprocess(cloth, height, width)[0],
            'mask': self.mask_processor.preprocess(mask, height, width)[0]
        }

def _load_vitonhd_pairs(args, pair_file, split, skip_existing):
    """Read a VITON-HD pair list and build per-sample path records.

    Args:
        args: namespace providing ``data_root_path``, ``output_dir`` and
            ``eval_pair``. It is only read, never modified.
        pair_file: name of the pair list inside ``data_root_path``.
        split: sub-directory of ``data_root_path`` holding the images.
        skip_existing: skip persons whose result already exists under
            ``<output_dir>/vitonhd/<paired|unpaired>`` (resume support for
            inference runs).

    Returns:
        list of dicts with ``person_name``, ``person``, ``cloth`` and ``mask``.

    Raises:
        FileNotFoundError: if the pair list does not exist.
    """
    pair_txt = os.path.join(args.data_root_path, pair_file)
    if not os.path.exists(pair_txt):
        raise FileNotFoundError(f"File {pair_txt} does not exist.")
    with open(pair_txt, 'r') as f:
        lines = f.readlines()

    # Use a local split root instead of overwriting args.data_root_path, which
    # made a second construction resolve to e.g. ".../train/train".
    split_root = os.path.join(args.data_root_path, split)
    output_dir = os.path.join(
        args.output_dir,
        "vitonhd",
        'paired' if args.eval_pair else 'unpaired'
    )
    data = []
    for line in lines:
        fields = line.split()
        if not fields:
            continue  # tolerate blank / trailing lines
        person_img, cloth_img = fields
        if skip_existing and os.path.exists(os.path.join(output_dir, person_img)):
            continue
        if args.eval_pair:
            cloth_img = person_img
        data.append({
            'person_name': person_img,
            'person': os.path.join(split_root, 'image', person_img),
            'cloth': os.path.join(split_root, 'cloth', cloth_img),
            'mask': os.path.join(split_root, 'agnostic-mask', person_img.replace('.jpg', '_mask.png')),
        })
    return data


class VITONHDTestDataset(TrainDataset):
    """VITON-HD test split read from ``test_pairs_unpaired.txt``."""

    def load_data(self):
        # Persons already rendered into the output dir are skipped so an
        # interrupted inference run can resume.
        return _load_vitonhd_pairs(
            self.args, 'test_pairs_unpaired.txt', "test", skip_existing=True
        )


class VITONHDTrainDataset(TrainDataset):
    """VITON-HD training split read from ``train_pairs.txt``."""

    def load_data(self):
        # Every training pair is used: the "skip already generated results"
        # logic only makes sense for inference and silently dropped samples.
        return _load_vitonhd_pairs(
            self.args, 'train_pairs.txt', "train", skip_existing=False
        )


def _encode_latents(vae, pixels):
    """Encode images into scaled VAE latents without tracking gradients.

    NOTE(review): assumes ``vae`` is a diffusers ``AutoencoderKL`` and that
    ``pixels`` are scaled to [-1, 1] -- confirm.
    """
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()
    return latents * vae.config.scaling_factor


def _build_catvton_inputs(vae, person, cloth, mask, weight_dtype):
    """Build the CatVTON training tensors for one batch.

    Person and cloth are concatenated along the height axis (``dim=-2``).
    NOTE(review): this layout mirrors CatVTON's inference code -- confirm
    against ``model/pipeline.py``.

    Returns:
        (target, mask_concat, masked_concat):
        - ``target``: the clean latents to be noised.
        - ``mask_concat``: the 1-channel inpainting mask at latent resolution.
        - ``masked_concat``: the 4-channel masked-image latents.
    """
    concat_dim = -2
    person = person.to(dtype=weight_dtype)
    cloth = cloth.to(dtype=weight_dtype)
    mask = mask.to(dtype=weight_dtype)

    # Blank out the region to be repainted (mask >= 0.5) before encoding.
    masked_person = person * (mask < 0.5)
    person_latent = _encode_latents(vae, person)
    cloth_latent = _encode_latents(vae, cloth)
    masked_latent = _encode_latents(vae, masked_person)
    mask_latent = nn.functional.interpolate(
        mask, size=person_latent.shape[-2:], mode="nearest"
    )

    target = torch.cat([person_latent, cloth_latent], dim=concat_dim)
    # The cloth half is never repainted: zero mask, clean cloth latents.
    mask_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)
    masked_concat = torch.cat([masked_latent, cloth_latent], dim=concat_dim)
    return target, mask_concat, masked_concat


def main():
    """Fine-tune LoRA adapters on CatVTON's UNet with a noise-prediction loss.

    Options come from the command line (see ``parse_args``). Training uses the
    VITON-HD training split. The LoRA adapter weights are saved to
    ``<output_dir>/lora_epoch_<n>`` after every epoch.

    Raises:
        RuntimeError: if the training split contains no pairs.
    """
    args = parse_args()
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weight_dtype = torch.bfloat16

    # 1. Pipeline and model initialisation (adjust to your environment).
    base_ckpt = "booksforcharlie/stable-diffusion-inpainting"
    attn_ckpt = "zhengchong/CatVTON"
    attn_ckpt_version = "mix"

    pipeline = CatVTONPipeline(
        base_ckpt,
        attn_ckpt,
        attn_ckpt_version,
        weight_dtype=weight_dtype,
        device=device.type,  # follow the detected device instead of hard-coding "cuda"
        skip_safety_check=True
    )

    # The pipeline exposes its denoiser as `.unet` (there is no `.model`).
    unet = pipeline.unet
    # NOTE(review): assumes the pipeline exposes its AutoencoderKL as `.vae` -- confirm.
    vae = pipeline.vae
    vae.requires_grad_(False)

    # 2. LoRA configuration. PEFT has no default target modules for a diffusers
    # UNet, so the attention projections are named explicitly.
    # get_peft_model freezes every non-LoRA weight.
    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=6,
        target_modules=["to_q", "to_k", "to_v"],
        lora_dropout=0.1,
    )
    model = get_peft_model(unet, lora_config)
    model.to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print("LoRA 적용 완료. 현재 학습 파라미터 수:",
          sum(p.numel() for p in trainable_params))

    # 3. Noise scheduler: reuse the base model's beta schedule so the training
    # noise matches what the pretrained UNet saw.
    scheduler = DDPMScheduler.from_pretrained(
        base_ckpt, subfolder="scheduler", num_train_timesteps=args.num_train_timesteps
    )

    # 4. Optimizer over the LoRA parameters only; MSE between predicted and true noise.
    optimizer = optim.AdamW(trainable_params, lr=args.lr)
    loss_fn = nn.MSELoss()

    # 5. Dataset and DataLoader. TrainDataset takes the parsed args and does its
    # own resizing / normalisation, so no torchvision transform is needed.
    dataset = VITONHDTrainDataset(args)
    if len(dataset) == 0:
        raise RuntimeError(f"No training pairs found under {args.data_root_path}.")
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4)

    os.makedirs(args.output_dir, exist_ok=True)
    model.train()
    for epoch in range(args.num_epochs):
        total_loss = 0.0
        for batch in dataloader:
            # Batches are collated dicts from TrainDataset.__getitem__.
            target, mask_concat, masked_concat = _build_catvton_inputs(
                vae,
                batch["person"].to(device),
                batch["cloth"].to(device),
                batch["mask"].to(device),
                weight_dtype,
            )

            # Sample noise and random timesteps, then noise the clean latents.
            noise = torch.randn_like(target)
            timesteps = torch.randint(
                0, args.num_train_timesteps, (target.shape[0],), device=device
            ).long()
            noisy_latents = scheduler.add_noise(target, noise, timesteps)

            # Inpainting UNet input: noisy latents, mask and masked-image latents
            # stacked on the channel axis. No text conditioning is passed;
            # NOTE(review): CatVTON's own pipeline is believed to call the UNet
            # with encoder_hidden_states=None as well -- confirm.
            unet_input = torch.cat([noisy_latents, mask_concat, masked_concat], dim=1)

            optimizer.zero_grad()
            predicted_noise = model(
                unet_input, timesteps, encoder_hidden_states=None, return_dict=False
            )[0]
            loss = loss_fn(predicted_noise.float(), noise.float())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{args.num_epochs}], Loss: {avg_loss:.4f}")

        # Checkpoint: save only the LoRA adapter weights, not the full UNet.
        checkpoint_path = os.path.join(args.output_dir, f"lora_epoch_{epoch+1}")
        model.save_pretrained(checkpoint_path)
        print(f"Checkpoint 저장: {checkpoint_path}")

    print("학습 완료.")


if __name__ == "__main__":
    main()