In [26]:
!pip install -q diffusers==0.30.0 transformers accelerate safetensors datasets
!pip install -q torchvision

# 1. Chuẩn bị subset COCO + caption chứa token mới

In [27]:
import json, random
from pathlib import Path
from collections import defaultdict
from PIL import Image
from torchvision import transforms

BASE_INPUT = Path("/kaggle/input")

# 1. Tìm file captions_train2017.json và thư mục train2017
caption_files = list(BASE_INPUT.rglob("captions_train2017.json"))
train_dirs    = list(BASE_INPUT.rglob("train2017"))

print("Found captions files:", caption_files)
print("Found train dirs    :", train_dirs)

if not caption_files:
    raise FileNotFoundError("Không tìm thấy captions_train2017.json trong /kaggle/input")
if not train_dirs:
    raise FileNotFoundError("Không tìm thấy thư mục train2017 trong /kaggle/input")

CAPTION_FILE  = caption_files[0]       
TRAIN_IMG_DIR = train_dirs[0]

print("Using CAPTION_FILE :", CAPTION_FILE)
print("Using TRAIN_IMG_DIR:", TRAIN_IMG_DIR)

# 2. Đọc annotation COCO
with open(CAPTION_FILE, "r") as f:
    coco_caps = json.load(f)

images      = {img["id"]: img for img in coco_caps["images"]}
annotations = coco_caps["annotations"]

imgid2caps = defaultdict(list)
for ann in annotations:
    imgid2caps[ann["image_id"]].append(ann["caption"])

# 3. Chọn random N ảnh làm instance images cho style <sks_style>
N_IMAGES = 20  
all_image_ids = list(imgid2caps.keys())
random.seed(42)
selected_ids = random.sample(all_image_ids, N_IMAGES)

# 4. Tạo metadata.jsonl + ảnh 512x512
WORK_DIR   = Path("/kaggle/working/sks_style_data")
IMG_OUT_DIR = WORK_DIR / "images"
IMG_OUT_DIR.mkdir(parents=True, exist_ok=True)

metadata_path = WORK_DIR / "metadata.jsonl"

resize_512 = transforms.Compose([
    transforms.Resize((512, 512)),
])

with open(metadata_path, "w", encoding="utf-8") as fw:
    for img_id in selected_ids:
        img_info  = images[img_id]
        file_name = img_info["file_name"]          
        src_path  = TRAIN_IMG_DIR / file_name

        if not src_path.exists():
            continue

        # caption COCO gốc
        caps = imgid2caps[img_id]
        base_caption = caps[0]

        # chèn token style
        text = f"{base_caption}, in <sks_style> style"

        # Resize về 512x512 và lưu
        image = Image.open(src_path).convert("RGB")
        image = resize_512(image)
        out_name = f"{img_id}.png"
        out_path = IMG_OUT_DIR / out_name
        image.save(out_path)

        rec = {
            "file_name": str(out_path.relative_to(WORK_DIR)),
            "text": text
        }
        fw.write(json.dumps(rec) + "\n")

print("✅ Done. Created subset at:", WORK_DIR)
print(" - Images:", IMG_OUT_DIR)
print(" - Metadata:", metadata_path)


Found captions files: [PosixPath('/kaggle/input/d/awsaf49/coco-2017-dataset/coco2017/annotations/captions_train2017.json')]
Found train dirs    : [PosixPath('/kaggle/input/d/awsaf49/coco-2017-dataset/coco2017/train2017')]
Using CAPTION_FILE : /kaggle/input/d/awsaf49/coco-2017-dataset/coco2017/annotations/captions_train2017.json
Using TRAIN_IMG_DIR: /kaggle/input/d/awsaf49/coco-2017-dataset/coco2017/train2017
✅ Done. Created subset at: /kaggle/working/sks_style_data
 - Images: /kaggle/working/sks_style_data/images
 - Metadata: /kaggle/working/sks_style_data/metadata.jsonl


# 2. Dataset class cho Textual Inversion (dùng metadata.jsonl)

In [29]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class COCOTIStyleDataset(Dataset):
    def __init__(self, data_root, tokenizer, size=512):
        self.data_root = Path(data_root)
        self.tokenizer = tokenizer
        self.size = size
        
        self.records = []
        meta_path = self.data_root / "metadata.jsonl"
        with open(meta_path, "r", encoding="utf-8") as f:
            for line in f:
                self.records.append(json.loads(line.strip()))
        
        self.image_transform = transforms.Compose([
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1, 1]
        ])
    
    def __len__(self):
        return len(self.records)
    
    def __getitem__(self, idx):
        rec = self.records[idx]
        img_path = self.data_root / rec["file_name"]
        caption = rec["text"]
        
        image = Image.open(img_path).convert("RGB")
        image = self.image_transform(image)
        
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt"
        )
        
        return {
            "pixel_values": image,
            "input_ids": tokenized.input_ids[0],
            "attention_mask": tokenized.attention_mask[0]
        }


# 3. Load Stable Diffusion v1.5 và thêm token <sks_style>

In [30]:
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler

device = "cuda"

model_id = "runwayml/stable-diffusion-v1-5"

pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    safety_checker=None
).to(device)

tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
noise_scheduler = pipe.scheduler

text_encoder.to(torch.float32)
emb_layer = text_encoder.get_input_embeddings()
emb_layer.weight.data = emb_layer.weight.data.to(torch.float32)

# 1. Thêm token mới
placeholder_token = "<sks_style>"
num_added = tokenizer.add_tokens(placeholder_token)
if num_added == 0:
    print("Token đã tồn tại.")
placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

# 2. Resize embedding layer
text_encoder.resize_token_embeddings(len(tokenizer))
embedding_layer = text_encoder.get_input_embeddings()

# 3. Khởi tạo embedding mới từ token "painting" (hoặc "style")
init_token = "painting"
init_token_id = tokenizer.encode(init_token, add_special_tokens=False)[0]
with torch.no_grad():
    token_embeds[placeholder_token_id] = token_embeds[init_token_id].clone()

# 4. Freeze toàn bộ model trừ embedding của placeholder token
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)

embedding_layer.weight.requires_grad_(True)



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Parameter containing:
tensor([[-0.0012,  0.0368,  0.0221,  ...,  0.0158,  0.0046, -0.0219],
        [ 0.0152,  0.0262, -0.0132,  ..., -0.0037,  0.0002,  0.0121],
        [-0.0154, -0.0131,  0.0065,  ..., -0.0206, -0.0139, -0.0025],
        ...,
        [ 0.0011,  0.0032,  0.0003,  ..., -0.0018,  0.0003,  0.0019],
        [ 0.0012,  0.0077, -0.0011,  ..., -0.0015,  0.0009,  0.0052],
        [-0.0010, -0.0005,  0.0006,  ..., -0.0002, -0.0002,  0.0006]],
       device='cuda:0', requires_grad=True)

# 4. Vòng lặp train Textual Inversion 

In [31]:
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch.nn.utils as nn_utils

train_dataset = COCOTIStyleDataset(WORK_DIR, tokenizer, size=512)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

max_train_steps = 400
learning_rate = 5e-5        
grad_accum = 4

optimizer = AdamW([emb_layer.weight], lr=learning_rate)

global_step = 0
accum_steps = 0
loss_history = []

while global_step < max_train_steps:
    for batch in train_dataloader:
        if global_step >= max_train_steps:
            break

        pixel_values = batch["pixel_values"].to(device, dtype=torch.float16)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # 1. image -> latents (half, no grad)
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * 0.18215

        # 2. add noise
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps,
            (latents.shape[0],), device=device
        ).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # 3. text encoder ở float32
        encoder_hidden_states = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )[0]                      # float32

        # cast ra dtype của UNet (half)
        encoder_hidden_states = encoder_hidden_states.to(unet.dtype)

        # 4. UNet dự đoán noise (half)
        model_pred = unet(
            noisy_latents,
            timesteps,
            encoder_hidden_states
        ).sample

        target = noise
        loss = torch.nn.functional.mse_loss(
            model_pred.float(), target.float(), reduction="mean"
        )

        if torch.isnan(loss) or torch.isinf(loss):
            print("Loss is NaN/Inf, stopping training.")
            break

        loss = loss / grad_accum
        loss.backward()
        accum_steps += 1

        if accum_steps % grad_accum == 0:
            # chỉ cho token <sks_style> cập nhật
            with torch.no_grad():
                grad = emb_layer.weight.grad
                mask = torch.ones(grad.shape[0], dtype=torch.bool, device=grad.device)
                mask[placeholder_token_id] = False
                grad[mask] = 0.0

            # CLIP grad để tránh nổ
            nn_utils.clip_grad_norm_([emb_layer.weight], max_norm=1.0)

            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
            accum_steps = 0
            loss_history.append(loss.item() * grad_accum)
            if global_step % 50 == 0:
                print(f"Step {global_step}/{max_train_steps}, loss={loss_history[-1]:.4f}")

    if torch.isnan(loss) or torch.isinf(loss):
        break


Step 50/400, loss=0.0271
Step 150/400, loss=0.1555
Step 200/400, loss=0.0173
Step 250/400, loss=0.0320
Step 300/400, loss=0.2494
Step 350/400, loss=0.2249
Step 400/400, loss=0.1043


In [34]:
with torch.no_grad():
    emb_layer = text_encoder.get_input_embeddings()
    sks_emb = emb_layer.weight[placeholder_token_id]

    # norm trung bình của toàn vocab
    avg_norm = emb_layer.weight.detach().cpu().norm(dim=1).mean()
    print("avg vocab norm:", avg_norm.item())

    # scale embedding <sks_style> lên cùng “tầm” với vocab
    scale = avg_norm / sks_emb.norm()
    sks_emb_scaled = sks_emb * scale

    emb_layer.weight[placeholder_token_id] = sks_emb_scaled

# (tuỳ chọn) lưu lại embedding đã scale
import pathlib, torch
save_dir = pathlib.Path("/kaggle/working/sks_style_embeddings")
save_dir.mkdir(exist_ok=True, parents=True)

torch.save(
    {"placeholder_token": placeholder_token,
     "embedding": sks_emb_scaled.detach().cpu().unsqueeze(0)},
    save_dir / "sks_style_embedding_fp32_scaled.pt"
)


avg vocab norm: 0.3853093981742859


In [42]:
with torch.no_grad():
    emb = emb_layer.weight[placeholder_token_id].detach().cpu()

print("min:", emb.min().item(), "max:", emb.max().item())
print("has NaN:", torch.isnan(emb).any().item())
print("has Inf :", torch.isinf(emb).any().item())
print("norm:", emb.norm().item())


min: -0.08013193309307098 max: 0.06354348361492157
has NaN: False
has Inf : False
norm: 0.3853094279766083


# 5. Lưu checkpoint embedding

In [37]:
import pathlib, torch
save_dir = pathlib.Path("/kaggle/working/sks_style_embeddings")
save_dir.mkdir(exist_ok=True, parents=True)

torch.save(
    {"placeholder_token": placeholder_token,
     "embedding": emb.unsqueeze(0)},
    save_dir / "sks_style_embedding_fp32.pt"
)
