# Textual Inversion Style Transfer Test - Kaggle

Test style transfer on COCO images using a trained Textual Inversion embedding.


## Setup


In [None]:
import os
import torch

# Report which accelerator is available before any model is loaded.
has_gpu = torch.cuda.is_available()
if has_gpu:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected!")


In [None]:
import os
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import torch

# Ask for xformers to be disabled for this run; USE_XFORMERS below honours
# this flag so the pipeline cells fall back to attention slicing instead.
os.environ["XFORMERS_DISABLED"] = "1"

# Probe whether xformers is installed. Only ImportError means "not installed";
# a bare `except:` would also hide KeyboardInterrupt/SystemExit.
try:
    import xformers  # noqa: F401  (availability probe only)
    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False

# The pipeline-loading cells read USE_XFORMERS only to decide whether to skip
# enable_attention_slicing(). It must therefore be False whenever xformers
# will not actually be used; otherwise no memory-efficient attention is
# enabled at all.
USE_XFORMERS = XFORMERS_AVAILABLE and os.environ.get("XFORMERS_DISABLED") != "1"

PLACEHOLDER_TOKEN = "<sks_style>"  # fallback token if the embedding file names none
MIXED_PRECISION = "fp16"           # "fp16" -> torch.float16 weights, else float32
STRENGTH = 0.5                     # img2img `strength` passed to the pipelines
GUIDANCE = 7.5                     # `guidance_scale` passed to the pipelines

# style name -> path of the trained Textual Inversion embedding file
TI_EMBEDDING_PATHS = {
    "sks_style": "/kaggle/input/your-ti-dataset/sks_style_embeddings/sks_style_embedding_fp32.pt",
}

# Candidate locations of the COCO val2017 images; the first existing one is used.
COCO_IMAGE_PATHS = [
    "/kaggle/input/coco-2017-dataset/coco2017/val2017",
    "/kaggle/input/coco2017/val2017",
]

for style_name, embedding_path in TI_EMBEDDING_PATHS.items():
    if os.path.exists(embedding_path):
        print(f"{style_name}: {embedding_path}")
    else:
        print(f"{style_name}: {embedding_path} (not found)")

OUTPUT_DIR = Path("/kaggle/working/ti_inference_samples")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Output: {OUTPUT_DIR}")


In [None]:
NUM_SAMPLE_IMAGES = 5    # how many COCO images to run through the pipelines
IMAGE_SIZE = (512, 512)  # SD v1.5 working resolution; aspect ratio is not preserved

# Use the first candidate directory that exists.
coco_image_dir = None
for path in COCO_IMAGE_PATHS:
    if os.path.exists(path):
        coco_image_dir = Path(path)
        print(f"Found COCO images: {coco_image_dir}")
        break

if coco_image_dir is None:
    print("COCO dataset not found")
    coco_image_dir = Path("/kaggle/input/coco2017/val2017")

# sorted() makes the sample deterministic; raw glob order depends on the filesystem.
image_files = sorted(coco_image_dir.glob("*.jpg"))[:NUM_SAMPLE_IMAGES]
if not image_files:
    image_files = sorted(coco_image_dir.glob("*.png"))[:NUM_SAMPLE_IMAGES]

print(f"Found {len(image_files)} images")
for img_path in image_files:
    print(f"  {img_path.name}")

# Fail here with a clear message instead of letting the generation/plot cells
# crash later on an empty image list.
if not image_files:
    raise FileNotFoundError(
        f"No .jpg or .png images found in {coco_image_dir}; "
        "attach a COCO dataset or update COCO_IMAGE_PATHS."
    )

coco_images = []
for img_path in image_files:
    # The context manager closes the file handle; convert() yields an
    # independent in-memory image that stays valid afterwards.
    with Image.open(img_path) as src:
        img = src.convert("RGB").resize(IMAGE_SIZE)
    coco_images.append(img)
    print(f"Loaded: {img_path.name} ({img.size})")


In [None]:
print("Loading baseline model...")
# Half precision is only requested on CUDA: many PyTorch CPU kernels lack
# float16 support, so a CPU-only session falls back to float32 instead of
# failing (or crawling) at inference time.
baseline_dtype = (
    torch.float16
    if MIXED_PRECISION == "fp16" and torch.cuda.is_available()
    else torch.float32
)
# NOTE(review): the runwayml/stable-diffusion-v1-5 Hub repo has reportedly been
# removed; if loading fails, try the mirror "stable-diffusion-v1-5/stable-diffusion-v1-5".
baseline_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=baseline_dtype,
    safety_checker=None,  # no safety checker: generated images are not filtered
    requires_safety_checker=False,
)
if torch.cuda.is_available():
    baseline_pipeline = baseline_pipeline.to("cuda")
    # Attention slicing lowers peak VRAM at some speed cost.
    if not USE_XFORMERS:
        baseline_pipeline.enable_attention_slicing()
print("Baseline model loaded")


In [None]:
def _unpack_ti_embedding(embedding_data, placeholder_token):
    """Split a loaded Textual Inversion file into ``(embedding, token)``.

    Accepts ``{"embedding": tensor[, "placeholder_token": str]}``, a dict with
    exactly one tensor entry ``{token: tensor}``, or a bare tensor.
    """
    if not isinstance(embedding_data, dict):
        return embedding_data, placeholder_token
    if "embedding" in embedding_data:
        token = embedding_data.get("placeholder_token", placeholder_token)
        return embedding_data["embedding"], token
    tensor_entries = [
        (key, value) for key, value in embedding_data.items() if torch.is_tensor(value)
    ]
    if len(tensor_entries) != 1:
        raise KeyError(
            "Unsupported embedding file: expected an 'embedding' key or a single "
            f"{{token: tensor}} entry, found keys {sorted(map(str, embedding_data))}"
        )
    token, embedding = tensor_entries[0]
    return embedding, token


def load_textual_inversion_embedding(pipeline, embedding_path, placeholder_token):
    """Register a Textual Inversion embedding in ``pipeline``'s tokenizer/text encoder.

    Supported file layouts (read with ``torch.load``):

    * ``{"embedding": tensor, "placeholder_token": str}``: a token stored in
      the file takes precedence over ``placeholder_token``;
    * ``{token: tensor}`` with exactly one tensor entry;
    * a bare tensor, registered under ``placeholder_token``.

    The tensor must be a single vector shaped ``(hidden,)`` or ``(1, hidden)``,
    where ``hidden`` is the width of the text encoder's input embeddings.

    Args:
        pipeline: Object exposing ``tokenizer`` and ``text_encoder`` (e.g. a
            diffusers Stable Diffusion pipeline); both are modified in place.
        embedding_path: Path of the saved embedding file.
        placeholder_token: Token to use when the file does not name one.

    Returns:
        The token string the embedding was registered under.

    Raises:
        KeyError: If a dict file matches none of the supported layouts.
        ValueError: If the tensor is not a single vector of the expected width.
    """
    # NOTE(review): torch.load unpickles the file. Only load embeddings from
    # trusted sources, or pass weights_only=True where the torch version allows it.
    embedding_data = torch.load(embedding_path, map_location="cpu")
    embedding, token = _unpack_ti_embedding(embedding_data, placeholder_token)

    tokenizer = pipeline.tokenizer
    text_encoder = pipeline.text_encoder

    # Validate before touching the tokenizer so a bad file leaves no
    # half-registered token behind.
    hidden_size = text_encoder.get_input_embeddings().weight.shape[1]
    embedding = embedding.detach()
    if embedding.dim() == 2 and embedding.shape[0] == 1:
        embedding = embedding[0]
    if embedding.dim() != 1:
        raise ValueError(
            f"Expected a single embedding vector, got shape {tuple(embedding.shape)}"
        )
    if embedding.shape[0] != hidden_size:
        raise ValueError(
            f"Embedding width {embedding.shape[0]} does not match the text "
            f"encoder hidden size {hidden_size}"
        )

    num_added = tokenizer.add_tokens(token)
    if num_added == 0:
        # An existing token's embedding row gets overwritten below.
        print(f"Token {token} already exists")

    placeholder_token_id = tokenizer.convert_tokens_to_ids(token)

    # Grow the embedding matrix to cover the new id, then re-fetch the layer in
    # case resizing replaced it.
    text_encoder.resize_token_embeddings(len(tokenizer))
    embedding_layer = text_encoder.get_input_embeddings()

    with torch.no_grad():
        embedding_layer.weight[placeholder_token_id] = embedding.to(
            device=embedding_layer.weight.device,
            dtype=embedding_layer.weight.dtype,
        )

    print(f"Loaded embedding for token: {token}")
    return token


In [None]:
import traceback

print("Loading Textual Inversion models...")
ti_pipelines = {}  # style name -> {"pipeline": pipeline, "token": registered token}

# Half precision is only requested on CUDA: many PyTorch CPU kernels lack
# float16 support, so a CPU-only session falls back to float32.
use_cuda = torch.cuda.is_available()
ti_dtype = torch.float16 if MIXED_PRECISION == "fp16" and use_cuda else torch.float32

for style_name, embedding_path in TI_EMBEDDING_PATHS.items():
    if not os.path.exists(embedding_path):
        print(f"Embedding not found: {embedding_path}")
        continue

    print(f"\nLoading {style_name} from {embedding_path}...")
    try:
        # NOTE(review): every style loads its own full SD pipeline, because the
        # embedding mutates the tokenizer/text encoder. Memory grows with the
        # number of styles.
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=ti_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )

        if use_cuda:
            pipeline = pipeline.to("cuda")
            if not USE_XFORMERS:
                pipeline.enable_attention_slicing()

        token = load_textual_inversion_embedding(pipeline, embedding_path, PLACEHOLDER_TOKEN)
    except Exception as e:
        # Best-effort per style: report the failure and move on to the next one.
        print(f"Error loading {style_name}: {e}")
        traceback.print_exc()
        continue

    ti_pipelines[style_name] = {"pipeline": pipeline, "token": token}
    print(f"{style_name} loaded successfully")


In [None]:
print("Generating baseline results...")
baseline_results = []
baseline_prompt = "a painting"

# Seed each image's generator (GENERATION_SEED + image index) so baseline
# outputs are reproducible across reruns. A CPU generator keeps results
# independent of the GPU model.
GENERATION_SEED = 42
NUM_INFERENCE_STEPS = 50

for i, coco_img in enumerate(coco_images):
    print(f"  Baseline {i+1}/{len(coco_images)}")
    generator = torch.Generator(device="cpu").manual_seed(GENERATION_SEED + i)
    result = baseline_pipeline(
        prompt=baseline_prompt,
        image=coco_img,
        strength=STRENGTH,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE,
        generator=generator,
    ).images[0]
    baseline_results.append(result)
    result.save(OUTPUT_DIR / f"baseline_transfer_{i+1}.png")

print(f"Saved {len(baseline_results)} baseline transfers")


In [None]:
all_results = {}  # style name -> list of generated images, in coco_images order

# Seeds depend only on the image index (GENERATION_SEED + i), so every style
# starts from the same noise for a given image and reruns are reproducible.
GENERATION_SEED = 42
NUM_INFERENCE_STEPS = 50

for style_name, style_data in ti_pipelines.items():
    pipeline = style_data["pipeline"]
    token = style_data["token"]

    print(f"\nTransferring {style_name}...")
    style_results = []
    style_prompt = f"a painting in {token} style"

    for i, coco_img in enumerate(coco_images):
        print(f"  Image {i+1}/{len(coco_images)}")
        generator = torch.Generator(device="cpu").manual_seed(GENERATION_SEED + i)
        result = pipeline(
            prompt=style_prompt,
            image=coco_img,
            strength=STRENGTH,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE,
            generator=generator,
        ).images[0]
        style_results.append(result)
        result.save(OUTPUT_DIR / f"{style_name}_transfer_{i+1}.png")

    all_results[style_name] = style_results
    print(f"Saved {len(style_results)} transfers")


In [None]:
num_images = len(coco_images)
num_styles = len(all_results)

# One (label, images) pair per grid row: source photos, plain-SD baseline,
# then one row per Textual Inversion style.
grid_rows = [("Original", coco_images), ("Baseline", baseline_results)]
grid_rows.extend(all_results.items())

if num_images == 0:
    print("No images to plot; skipping comparison figure")
else:
    # squeeze=False keeps `axes` 2-D even with a single image column; the
    # default would return a 1-D array and break row/column indexing.
    fig, axes = plt.subplots(
        len(grid_rows),
        num_images,
        figsize=(4 * num_images, 4 * len(grid_rows)),
        squeeze=False,
    )
    # Hide every frame up front so any cell left without an image stays blank.
    for ax in axes.flat:
        ax.axis('off')

    for row_axes, (label, images) in zip(axes, grid_rows):
        for col, (ax, image) in enumerate(zip(row_axes, images), start=1):
            ax.imshow(image)
            ax.set_title(f"{label}\n{col}", fontsize=9)

    plt.tight_layout()
    comparison_path = OUTPUT_DIR / "style_transfer_comparison.png"
    plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f"Saved: {comparison_path}")
