In [None]:
# Clone the ControlNet repo (provides the cldm package and tool_add_control_sd21.py used below).
!git clone https://github.com/lllyasviel/ControlNet

In [None]:
# 토큰 77 없애기 전처리 — preprocessing: drop rows whose caption is too long for CLIP's 77-token text context

from transformers import CLIPTokenizer
import pandas as pd

# Raw training metadata; the filter below reads its 'caption' column.
csv_path = "/content/train.csv"
df = pd.read_csv(csv_path)

# CLIP BPE tokenizer, used here only to count tokens per caption.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

def is_within_token_limit(text, max_tokens=77, tok=None):
    """Return True if `text` fits in a CLIP text context of `max_tokens` tokens.

    CLIP's 77-token context length counts the start-of-text and end-of-text
    special tokens, which ``tokenize()`` does not emit. The caption itself may
    therefore use at most ``max_tokens - 2`` BPE tokens; otherwise the text
    encoder silently truncates it.

    Missing captions (NaN / None) return False, so they are dropped instead of
    being kept as the literal string "nan".

    Args:
        text: caption value from the CSV (non-strings are converted with ``str``).
        max_tokens: total context length, including the two special tokens.
        tok: object exposing ``tokenize(str) -> list``; defaults to the
            module-level ``tokenizer``.

    Returns:
        bool: whether the caption fits.
    """
    if not isinstance(text, str) and pd.isna(text):
        return False
    if tok is None:
        tok = tokenizer
    n_special = 2  # <|startoftext|> and <|endoftext|> are added at encode time
    return len(tok.tokenize(str(text))) + n_special <= max_tokens

# Boolean mask: True where the caption fits the token budget.
keep_mask = df['caption'].apply(is_within_token_limit)
filtered_df = df[keep_mask].reset_index(drop=True)
n_removed = len(df) - len(filtered_df)

filtered_df.to_csv("/content/train_filtered.csv", index=False)
print(f"전처리 완료: {n_removed}개 삭제됨. 'train_filtered.csv' 저장됨")

In [None]:
# 데이터셋 처리! (우리 데이터셋과 맞게) — Dataset wrapper for our CSV (input_img_path / gt_img_path / caption columns)

import pandas as pd
import json
import cv2
import numpy as np

from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Paired (condition image, ground-truth image, caption) dataset.

    Reads a CSV with columns ``input_img_path``, ``gt_img_path`` and
    ``caption``. Relative image paths are resolved against ``data_root``;
    absolute paths are used as-is.

    Each item is ``dict(jpg=target, txt=prompt, hint=source)`` where:
      * ``target`` is the ground-truth image as an RGB float32 array in [-1, 1];
      * ``source`` is the input/condition image as an RGB float32 array in [0, 1];
      * ``prompt`` is the raw ``caption`` value from the CSV.
    """

    def __init__(self, csv_path="/content/train_filtered.csv", data_root="/content"):
        self.data = pd.read_csv(csv_path)
        for col in ('input_img_path', 'gt_img_path'):
            self.data[col] = self.data[col].apply(lambda p: self._resolve_path(p, data_root))

    @staticmethod
    def _resolve_path(path, data_root="/content"):
        """Make `path` absolute under `data_root`; absolute paths pass through unchanged.

        Only literal leading "./" prefixes are removed. The previous
        ``str.lstrip('./')`` stripped any run of '.' and '/' characters, which
        mangled names such as ".hidden/x.png".
        """
        if path.startswith('/'):
            return path
        while path.startswith('./'):
            path = path[2:]
        return data_root.rstrip('/') + '/' + path

    @staticmethod
    def _read_rgb(path):
        """Load an image file as an RGB uint8 array, raising if it cannot be read."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None for missing/unreadable files instead of raising
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        source = self._read_rgb(row['input_img_path'])
        target = self._read_rgb(row['gt_img_path'])
        prompt = row['caption']

        # Normalize source images to [0, 1].
        source = source.astype(np.float32) / 255.0

        # Normalize target images to [-1, 1].
        target = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target, txt=prompt, hint=source)


In [None]:
%cd /content/

dataset = MyDataset()
print(len(dataset))

item = dataset[1234]
jpg = item['jpg']
txt = item['txt']
hint = item['hint']
print(txt)
print(jpg.shape)
print(hint.shape)

In [None]:
# Download the SD 2.1-base (512px) EMA-pruned weights into ControlNet/models/ as ema_pruned.ckpt.
!wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt -O /content/ControlNet/models/ema_pruned.ckpt

In [None]:

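# Before running the next cell (이 코드 실행 전), patch
# /content/ControlNet/tool_add_control_sd21.py — on line 29 change
#     pretrained_weights = torch.load(input_path)
# to
#     pretrained_weights = torch.load(input_path, weights_only=False)
# Newer PyTorch releases default torch.load to weights_only=True, which can refuse
# to unpickle this checkpoint. weights_only=False runs the full pickle loader, so
# only use it on checkpoints from a trusted source.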


In [None]:
%cd /content/ControlNet/
# Build the initial ControlNet checkpoint (control_sd21_ini.ckpt) from the SD 2.1 weights;
# the training cell loads it as `resume_ckpt`.
!python tool_add_control_sd21.py ./models/ema_pruned.ckpt ./models/control_sd21_ini.ckpt
%cd /content/

In [None]:

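# Save this cell's code as /content/ControlNet/cldm/cldm_hsv_clip.py (여기에 추가하기),
# e.g. by making `%%writefile /content/ControlNet/cldm/cldm_hsv_clip.py` the first line of the cell.
# The training cell imports ControlLDM_HSV_CLIP from that module.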

import torch, torch.nn.functional as F, torchvision, open_clip
from cldm.cldm import ControlLDM      # ← 방금 올려준 ControlLDM 재사용
import kornia.color as kc

class ControlLDM_HSV_CLIP(ControlLDM):
    """ControlLDM trained with two extra image-space losses on the x0 estimate.

    Total loss::

        loss = MSE(eps_pred, eps)
             + lambda_hsv  * weighted L1 in HSV between decoded x0-estimate and target
             + lambda_clip * (1 - cos(CLIP_image(decoded x0-estimate), CLIP_text(caption)))

    There are two ways to obtain an instance:
      * construct it directly (``config_path`` is ignored; other kwargs go to ControlLDM), or
      * build a plain ControlLDM, set ``obj.__class__ = ControlLDM_HSV_CLIP`` and then call
        ``obj._init_hsv_clip(...)``. ``__init__`` does not run in that case.

    The frozen CLIP model is deliberately kept out of the nn.Module tree (see
    ``_init_hsv_clip``), so it is not written into ``state_dict()`` / checkpoints.
    """

    # OpenAI CLIP image normalization (per-channel RGB mean / std).
    _CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
    _CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
    # Per-channel weights (H, S, V) for the HSV L1 term.
    _HSV_WEIGHTS = (3., 1., 1.)
    # kornia.color.rgb_to_hsv returns hue as an angle in [0, 2*pi].
    _TWO_PI = 6.283185307179586

    def __init__(self, config_path=None, lambda_hsv=10., lambda_clip=5., **kwargs):
        # config_path is unused; it is kept (now optional) for backward compatibility.
        super().__init__(**kwargs)                # regular ControlLDM initialization
        self._init_hsv_clip(lambda_hsv=lambda_hsv, lambda_clip=lambda_clip)

    @staticmethod
    def _rgb2hsv(t):
        """RGB -> HSV via kornia; expects (B, 3, H, W) RGB in [0, 1]."""
        return kc.rgb_to_hsv(t)

    def _init_hsv_clip(self, lambda_hsv=10., lambda_clip=5.):
        """Set the loss weights and load the frozen CLIP ViT-L-14 (OpenAI weights) and its tokenizer.

        Called by ``__init__``. Call it manually after swapping ``__class__`` onto an
        existing ControlLDM instance.
        """
        self.lambda_hsv = lambda_hsv
        self.lambda_clip = lambda_clip

        clip_model, clip_preproc, _ = open_clip.create_model_and_transforms(
            "ViT-L-14", pretrained="openai", device="cpu")
        clip_model.eval().requires_grad_(False)
        # object.__setattr__ bypasses nn.Module.__setattr__, so CLIP is NOT a registered
        # submodule. As a result it:
        #   * is excluded from state_dict()/checkpoints, which stay loadable into a plain ControlLDM;
        #   * is not returned by parameters();
        #   * is unaffected by .train()/.to() calls on this module.
        # It is moved to the training device lazily by _clip_on_device().
        self._modules.pop("clip_model", None)   # drop a previously registered copy, if any
        object.__setattr__(self, "clip_model", clip_model)
        self.clip_preproc = clip_preproc
        self.tokenizer = open_clip.get_tokenizer("ViT-L-14")
        self.rgb2hsv = kc.rgb_to_hsv

    def _clip_on_device(self):
        """Return the frozen CLIP model, moved to ``self.device`` if it is not there yet."""
        clip_model = self.clip_model
        if next(clip_model.parameters()).device != self.device:
            clip_model.to(self.device)
        return clip_model

    def training_step(self, batch, batch_idx):
        """Compute noise-MSE + HSV-L1 + CLIP-text loss for one batch; log each term and return the total."""
        # All terms below assume epsilon prediction (noise target, predict_start_from_noise).
        parameterization = getattr(self, "parameterization", "eps")
        if parameterization != "eps":
            raise NotImplementedError(
                f"ControlLDM_HSV_CLIP only supports parameterization='eps', got {parameterization!r}")

        # ── 1) Standard diffusion forward pass ──────────────────────
        x, c = self.get_input(batch, self.first_stage_key)     # latent, cond
        t = torch.randint(0, self.num_timesteps, (x.size(0),), device=self.device).long()
        noise = torch.randn_like(x)
        x_noisy = self.q_sample(x_start=x, t=t, noise=noise)   # forward-diffused latent
        noise_pred = self.apply_model(x_noisy, t, c)           # model prediction

        # ── 2) Base noise-prediction MSE (the original loss) ───────
        loss_simple = F.mse_loss(noise_pred, noise)

        # ── 3) Estimate x0 and decode it to RGB in [0, 1] WITH autograd ──
        pred_x0 = self.predict_start_from_noise(x_noisy, t, noise_pred)  # clean latent estimate
        # LatentDiffusion.decode_first_stage is wrapped in @torch.no_grad() in the ldm code
        # (ldm/models/diffusion/ddpm.py). Going through it made both auxiliary losses constants
        # with zero gradient. Call the VAE decoder directly, using the same 1/scale_factor
        # scaling. Backprop through the decoder is memory-heavy; lower batch_size on OOM.
        rec = self.first_stage_model.decode(1. / self.scale_factor * pred_x0)
        # pred_x0 at large t can decode outside [-1, 1]; the HSV conversion expects [0, 1].
        rec = ((rec + 1) / 2).clamp(0, 1)                              # B 3 H W

        tgt = (batch["jpg"].to(device=rec.device, dtype=rec.dtype) + 1) / 2   # 0~1
        if tgt.shape[1] != 3 and tgt.shape[-1] == 3:                   # dataset yields B H W C
            tgt = tgt.permute(0, 3, 1, 2).contiguous()
        if tgt.shape[-2:] != rec.shape[-2:]:
            # The decoded size can differ from the input size when H/W are not multiples of the
            # VAE downsampling factor; align the target so the L1 term is well defined.
            tgt = F.interpolate(tgt, size=tuple(rec.shape[-2:]), mode="bilinear", align_corners=False)

        # ── 4) Weighted HSV L1, with hue compared on the circle ──
        diff = (self._rgb2hsv(rec) - self._rgb2hsv(tgt)).abs()
        # Hue is an angle: hues just either side of 0 / 2*pi are close, so use the shorter arc.
        hue_diff = torch.minimum(diff[:, :1], self._TWO_PI - diff[:, :1])
        diff = torch.cat([hue_diff, diff[:, 1:]], dim=1)
        hsv_w = diff.new_tensor(self._HSV_WEIGHTS).view(1, 3, 1, 1)
        loss_hsv = (hsv_w * diff).mean()

        # ── 5) CLIP image-text cosine loss ───────────────────────────
        clip_model = self._clip_on_device()
        rec_224 = F.interpolate(rec, size=224, mode="bilinear", align_corners=False)
        clip_mean = rec.new_tensor(self._CLIP_MEAN).view(1, 3, 1, 1)
        clip_std = rec.new_tensor(self._CLIP_STD).view(1, 3, 1, 1)
        # Match the CLIP weights' dtype (relevant if CLIP is run in reduced precision).
        img_clip = ((rec_224 - clip_mean) / clip_std).to(clip_model.visual.conv1.weight.dtype)
        img_emb = clip_model.encode_image(img_clip)   # keeps the graph: gradient flows back to rec

        with torch.no_grad():   # the text embedding is a fixed target; no graph needed
            txt_tok = self.tokenizer(list(batch["txt"])).to(self.device)
            txt_emb = clip_model.encode_text(txt_tok)
        loss_clip = 1 - F.cosine_similarity(img_emb, txt_emb).mean()

        # ── 6) Combine & log ─────────────────────────────────────────
        loss = loss_simple + self.lambda_hsv * loss_hsv + self.lambda_clip * loss_clip
        self.log_dict(
            {"loss": loss, "mse": loss_simple, "hsv": loss_hsv, "clip": loss_clip},
            prog_bar=True, logger=True)
        return loss



In [None]:
# Re-import cldm.cldm_hsv_clip so edits to the .py file take effect without restarting the kernel.
# The cldm package lives in /content/ControlNet, but an earlier cell did `%cd /content/`,
# so on a fresh kernel `import cldm...` would fail. Put the repo on sys.path explicitly.
import importlib
import sys

CONTROLNET_DIR = "/content/ControlNet"
if CONTROLNET_DIR not in sys.path:
    sys.path.insert(0, CONTROLNET_DIR)

import cldm.cldm_hsv_clip
importlib.reload(cldm.cldm_hsv_clip)

In [None]:
%cd /content/ControlNet
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from cldm.model import create_model, load_state_dict
from cldm.cldm_hsv_clip import ControlLDM_HSV_CLIP

# ───── (1) 베이스 모델 그대로 만들기 ─────
cfg_path   = '/content/ControlNet/models/cldm_v21.yaml'
resume_ckpt= '/content/ControlNet/models/control_sd21_ini.ckpt'

base = create_model(cfg_path).cpu()                    # 기존 헬퍼!
base.load_state_dict(load_state_dict(resume_ckpt, 'cpu'))

# ───── (2) 클래스 교체 + 새 Loss 초기화 ─────
base.__class__ = ControlLDM_HSV_CLIP                   # 타입 갈아끼우기
base._init_hsv_clip(lambda_hsv=10, lambda_clip=5)      # CLIP 로드 & 파라미터 저장

# ───── (3) 나머지 기존 하이퍼파라미터 유지 ─────
base.learning_rate = 1e-5
base.sd_locked     = False
base.only_mid_control = False

model = base                                          # 가독성을 위해 alias

# ───── (4) DataLoader · Trainer 그대로 ─────
%cd /content/
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

from cldm.logger import ImageLogger
logger = ImageLogger(batch_frequency=3000)
ckpt_cb= ModelCheckpoint(dirpath='/content/drive/MyDrive/SD2.1_training_RGB_hsv+clip',
                         filename='test-{epoch:02d}', save_last=True, every_n_epochs=1, save_top_k=-1)

trainer = pl.Trainer(accelerator="gpu", devices=1, precision=32,
                     max_epochs=5, callbacks=[logger, ckpt_cb])

In [None]:
# Train! Fits `model` on `dataloader` with the Trainer configured above;
# checkpoints are written by the ModelCheckpoint callback passed to it.
trainer.fit(model, dataloader)