# LoRA Style Transfer Test - Kaggle

Test style transfer on COCO images using the 5 trained style LoRAs: Action_painting, Analytical_Cubism, Contemporary_Realism, New_Realism, Synthetic_Cubism


## Setup


In [None]:
import os
import torch

# Report up front which device this notebook will run on.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU detected!")


In [None]:
import os
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline
import torch

# Opt out of xformers for this notebook; USE_XFORMERS below honours this flag.
os.environ["XFORMERS_DISABLED"] = "1"

# Importing xformers can fail with more than ImportError (e.g. a build that does
# not match the installed torch/CUDA), so catch Exception rather than using a
# bare except, which would also swallow KeyboardInterrupt/SystemExit.
try:
    import xformers  # noqa: F401
    XFORMERS_INSTALLED = True
except Exception:
    XFORMERS_INSTALLED = False

# The model-loading cells use USE_XFORMERS to decide whether attention slicing
# is needed. It must therefore only be True when xformers is both importable
# AND not disabled above; previously an installed-but-disabled xformers switched
# attention slicing off with nothing replacing it.
USE_XFORMERS = XFORMERS_INSTALLED and os.environ.get("XFORMERS_DISABLED") != "1"

# Styles to evaluate; each must have an entry in LORA_DATASET_MAP.
STYLES = [
    "Action_painting",
    "Analytical_Cubism",
    "Contemporary_Realism",
    "New_Realism",
    "Synthetic_Cubism",
]
# "fp16" requests float16 pipelines; any other value uses float32.
MIXED_PRECISION = "fp16"
STRENGTH = 0.5  # passed as `strength` to every img2img call
GUIDANCE = 7.5  # passed as `guidance_scale` to every img2img call

# Kaggle dataset mount holding each style's trained LoRA weights.
LORA_DATASET_MAP = {
    "Action_painting": "/kaggle/input/dts-lora-actionpainting",
    "Analytical_Cubism": "/kaggle/input/dts-lora-analyticalcubism",
    "Contemporary_Realism": "/kaggle/input/dts-lora-contemporaryrealism",
    "New_Realism": "/kaggle/input/dts-lora-newrealism",
    "Synthetic_Cubism": "/kaggle/input/dts-lora-syntheticcubism",
}

# Candidate locations of the COCO val2017 images, tried in order.
COCO_IMAGE_PATHS = [
    "/kaggle/input/coco-2017-dataset/coco2017/val2017",
    "/kaggle/input/coco2017/val2017",
]

# Report which LoRA dataset mounts are present before loading anything heavy.
for style_name, dataset_path in LORA_DATASET_MAP.items():
    if os.path.exists(dataset_path):
        print(f"{style_name}: {dataset_path}")
    else:
        print(f"{style_name}: {dataset_path} (not found)")

OUTPUT_DIR = Path("/kaggle/working/lora_inference_samples")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
print(f"Output: {OUTPUT_DIR}")


In [None]:
# Pick the first COCO directory that exists; fall back to a default path so the
# messages below still show where images were expected.
coco_image_dir = None
for path in COCO_IMAGE_PATHS:
    if os.path.exists(path):
        coco_image_dir = Path(path)
        print(f"Found COCO images: {coco_image_dir}")
        break

if coco_image_dir is None:
    print("COCO dataset not found")
    coco_image_dir = Path("/kaggle/input/coco2017/val2017")

# Sort before slicing: glob() order is filesystem-dependent, so without sorting
# a different set of 5 images could be chosen on each run.
image_files = sorted(coco_image_dir.glob("*.jpg"))[:5]
if not image_files:
    image_files = sorted(coco_image_dir.glob("*.png"))[:5]

print(f"Found {len(image_files)} images")
for img_path in image_files:
    print(f"  {img_path.name}")

# Fail here, before the expensive model-loading cells, instead of much later
# with an opaque plotting error.
if not image_files:
    raise FileNotFoundError(
        f"No .jpg or .png images found in {coco_image_dir}; check COCO_IMAGE_PATHS"
    )

coco_images = []
for img_path in image_files:
    # convert() returns an in-memory copy, so the file handle can be closed
    # immediately instead of leaking one per image.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    # Fixed 512x512 input; aspect ratio is not preserved.
    img = img.resize((512, 512))
    coco_images.append(img)
    print(f"Loaded: {img_path.name} ({img.size})")


In [None]:
print("Loading baseline model...")
# float16 is only requested on GPU: PyTorch CPU inference in float16 is often
# unsupported or extremely slow, so fall back to float32 without CUDA.
use_cuda = torch.cuda.is_available()
baseline_dtype = (
    torch.float16 if (MIXED_PRECISION == "fp16" and use_cuda) else torch.float32
)
# NOTE(review): the runwayml/stable-diffusion-v1-5 Hub repo was taken down in
# 2024; if the download fails, the same weights are published as
# stable-diffusion-v1-5/stable-diffusion-v1-5. Keep it in sync with the LoRA base.
baseline_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=baseline_dtype,
    safety_checker=None,
    requires_safety_checker=False,
)
if use_cuda:
    baseline_pipeline = baseline_pipeline.to("cuda")
    # Only skip attention slicing when xformers attention is actually enabled;
    # previously USE_XFORMERS skipped slicing without enabling anything.
    xformers_enabled = False
    if USE_XFORMERS:
        try:
            baseline_pipeline.enable_xformers_memory_efficient_attention()
            xformers_enabled = True
        except Exception as xf_err:
            print(f"Could not enable xformers ({xf_err}); using attention slicing")
    if not xformers_enabled:
        baseline_pipeline.enable_attention_slicing()
print("Baseline model loaded")


In [None]:
print("Loading LoRA models...")
style_pipelines = {}

# float16 is only requested on GPU: PyTorch CPU inference in float16 is often
# unsupported or extremely slow, so fall back to float32 without CUDA.
use_cuda = torch.cuda.is_available()
lora_dtype = torch.float16 if (MIXED_PRECISION == "fp16" and use_cuda) else torch.float32

# NOTE(review): every style loads its own full copy of the base model and keeps
# it on the GPU next to the baseline pipeline; with five styles this may not fit
# on a 16 GB Kaggle GPU. Failures are reported per style and skipped below.
for style_name in STYLES:
    if style_name not in LORA_DATASET_MAP:
        print(f"No dataset path mapped for {style_name}")
        continue

    # Expected layout: <dataset>/lora_models/<style>/pytorch_lora_weights.safetensors
    dataset_base = Path(LORA_DATASET_MAP[style_name])
    lora_dir = dataset_base / "lora_models" / style_name
    lora_weights = lora_dir / "pytorch_lora_weights.safetensors"

    if not dataset_base.exists():
        print(f"Dataset not found: {dataset_base}")
        continue

    if not lora_weights.exists():
        print(f"LoRA weights not found: {lora_weights}")
        # List sibling model folders to help spot a naming mismatch.
        if lora_dir.parent.exists():
            print(f"Available models: {[d.name for d in lora_dir.parent.iterdir() if d.is_dir()]}")
        continue

    print(f"\nLoading {style_name} from {lora_dir}...")
    pipeline = None
    try:
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=lora_dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # Load exactly the weights file whose existence was verified above.
        pipeline.load_lora_weights(str(lora_dir), weight_name=lora_weights.name)
        if use_cuda:
            pipeline = pipeline.to("cuda")
            # Only skip attention slicing when xformers attention is actually
            # enabled; previously USE_XFORMERS skipped slicing without enabling it.
            xformers_enabled = False
            if USE_XFORMERS:
                try:
                    pipeline.enable_xformers_memory_efficient_attention()
                    xformers_enabled = True
                except Exception as xf_err:
                    print(f"Could not enable xformers for {style_name} ({xf_err}); using attention slicing")
            if not xformers_enabled:
                pipeline.enable_attention_slicing()
        style_pipelines[style_name] = pipeline
        print(f"{style_name} loaded successfully")
    except Exception as e:
        # Best-effort: skip this style, but drop the partially built pipeline
        # and release cached GPU memory so a failure (e.g. CUDA OOM) doesn't
        # starve the styles that follow.
        print(f"Error loading {style_name}: {e}")
        pipeline = None
        if use_cuda:
            torch.cuda.empty_cache()


In [None]:
print("\nPerforming style transfer...")
# Prompt for the un-adapted base model, used as the reference row.
baseline_prompt = "a realistic depiction of the same scene"
# Text prompt per LoRA style, keyed by the names in STYLES.
style_prompts = {
    "Action_painting": "an action painting with energetic brush strokes of the scene",
    "Analytical_Cubism": "an analytical cubism interpretation of the scene",
    "Contemporary_Realism": "a contemporary realism painting of the scene",
    "New_Realism": "a new realism painting of the scene",
    "Synthetic_Cubism": "a synthetic cubism painting of the scene",
}

baseline_results = []
for i, coco_img in enumerate(coco_images):
    print(f"\nProcessing image {i+1}/{len(coco_images)}...")
    # Seed the noise with the image index so runs are reproducible; using the
    # same per-image seed for the style pipelines makes their outputs directly
    # comparable to this baseline. A CPU generator is device-independent.
    generator = torch.Generator().manual_seed(i)
    result = baseline_pipeline(
        prompt=baseline_prompt,
        image=coco_img,
        strength=STRENGTH,
        num_inference_steps=50,
        guidance_scale=GUIDANCE,
        generator=generator,
    ).images[0]
    baseline_results.append(result)
    result.save(OUTPUT_DIR / f"baseline_transfer_{i+1}.png")
print(f"Saved {len(baseline_results)} baseline transfers")


In [None]:
# style name -> list of stylised images, in the same order as coco_images.
all_results = {}

for style_name, pipeline in style_pipelines.items():
    print(f"\nTransferring {style_name}...")
    style_results = []
    # Fall back to a generic prompt built from the style name.
    style_prompt = style_prompts.get(style_name, f"a {style_name.replace('_', ' ').lower()} painting of the scene")
    for i, coco_img in enumerate(coco_images):
        print(f"  Image {i+1}/{len(coco_images)}")
        # Seed with the image index so every style (and the baseline, when it
        # uses the same per-image seed) starts image i from identical noise.
        generator = torch.Generator().manual_seed(i)
        result = pipeline(
            prompt=style_prompt,
            image=coco_img,
            strength=STRENGTH,
            num_inference_steps=50,
            guidance_scale=GUIDANCE,
            generator=generator,
        ).images[0]
        style_results.append(result)
        result.save(OUTPUT_DIR / f"{style_name}_transfer_{i+1}.png")
    all_results[style_name] = style_results
    print(f"Saved {len(style_results)} transfers")


In [None]:
# Comparison grid: row 0 = original inputs, row 1 = baseline img2img, then one
# row per entry in STYLES. Styles with no results are labelled "not loaded".
num_images = len(coco_images)
if num_images == 0:
    raise ValueError("No COCO images loaded; nothing to plot")

num_rows = len(STYLES) + 2
# squeeze=False keeps `axes` 2-D even when num_images == 1, so axes[row, col]
# indexing works for any number of images.
fig, axes = plt.subplots(
    num_rows,
    num_images,
    figsize=(4 * num_images, 4 * num_rows),
    squeeze=False,
)

for col in range(num_images):
    axes[0, col].imshow(coco_images[col])
    axes[0, col].set_title(f"Original\n{col+1}", fontsize=9)
    axes[0, col].axis('off')

    axes[1, col].imshow(baseline_results[col])
    axes[1, col].set_title(f"Baseline\n{col+1}", fontsize=9)
    axes[1, col].axis('off')

for row, style_name in enumerate(STYLES, 2):
    row_results = all_results.get(style_name)
    for col in range(num_images):
        ax = axes[row, col]
        if row_results is None:
            # Label the empty row instead of leaving bare framed axes.
            ax.set_title(f"{style_name}\n(not loaded)", fontsize=9)
        else:
            ax.imshow(row_results[col])
            ax.set_title(f"{style_name}\n{col+1}", fontsize=9)
        ax.axis('off')

plt.tight_layout()
comparison_path = OUTPUT_DIR / "style_transfer_comparison.png"
plt.savefig(comparison_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved: {comparison_path}")
