# Simple Icons LoRA 训练

使用 Hugging Face 官方 diffusers 库训练 SDXL LoRA

## 第一步：检查 GPU

In [None]:
!nvidia-smi

import torch
print(f"CUDA 可用: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else '无'}")

## 第二步：安装依赖

In [None]:
!pip install -q diffusers transformers accelerate peft bitsandbytes
!pip install -q torch torchvision safetensors
!pip install -q Pillow tqdm

## 第三步：挂载 Google Drive

In [None]:
# Mount Google Drive at /content/drive so files written there outlive the
# Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Create the working directory on Drive; the training cell passes this same
# path as --output_dir. `-p` makes the command a no-op if it already exists.
!mkdir -p /content/drive/MyDrive/lora-training/output

## 第四步：下载训练脚本

In [None]:
# 下载 Hugging Face 官方 SDXL LoRA 训练脚本
!wget -q https://github.com/huggingface/diffusers/raw/main/examples/text_to_image/train_text_to_image_lora_sdxl.py -O /content/train_lora.py
print('训练脚本下载完成')

## 第五步：准备数据

In [None]:
# 下载 Simple Icons 数据
!mkdir -p /content/dataset
%cd /content/dataset

import requests
import json

url = 'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/data/simple-icons.json'
response = requests.get(url)
icons = response.json()

print(f'总图标数: {len(icons)}')

# 只取前 200 个
icons = icons[:200]
print(f'将下载: {len(icons)} 个')

In [None]:
# Download each icon's SVG and write a caption (.txt) file next to it.
import os
import re
import unicodedata
from pathlib import Path
from tqdm import tqdm

output_dir = '/content/dataset/simple_icons'
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Symbols that are spelled out rather than dropped when deriving a slug.
_SLUG_REPLACEMENTS = {'+': 'plus', '.': 'dot', '&': 'and'}


def _title_to_slug(title):
    """Derive an icon file slug from its title.

    Approximates the title-to-slug rule of the simple-icons project
    (TODO confirm against its SDK): lowercase, spell out ``+ . &``, strip
    accents, then drop everything that is not ``a-z`` or ``0-9``.
    """
    lowered = title.lower()
    for char, word in _SLUG_REPLACEMENTS.items():
        lowered = lowered.replace(char, word)
    # NFD splits accented letters into base letter + combining mark; the
    # marks are then removed by the character filter below.
    decomposed = unicodedata.normalize('NFD', lowered)
    return re.sub(r'[^a-z0-9]', '', decomposed)


failed = []  # (slug, reason) for every icon that could not be saved

for icon in tqdm(icons):
    title = icon.get('title', '')
    # Index entries only carry an explicit slug in some cases; otherwise it
    # has to be derived from the title.
    slug = icon.get('slug', '') or _title_to_slug(title)

    if not slug:
        continue

    svg_url = f'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/icons/{slug}.svg'

    try:
        resp = requests.get(svg_url, timeout=10)
        if resp.status_code != 200:
            failed.append((slug, f'HTTP {resp.status_code}'))
            continue

        with open(f'{output_dir}/{slug}.svg', 'w', encoding='utf-8') as f:
            f.write(resp.text)

        prompt = f'minimalist tech logo of {title}, geometric shape, flat design, single color'
        with open(f'{output_dir}/{slug}.txt', 'w', encoding='utf-8') as f:
            f.write(prompt)
    except (requests.RequestException, OSError) as err:
        # Stay best-effort per icon, but only for network/file errors, so
        # that Ctrl-C and programming mistakes are no longer swallowed.
        failed.append((slug, str(err)))

print(f'下载完成，共 {len(list(Path(output_dir).glob("*.svg")))} 个 SVG')
if failed:
    print(f'跳过 {len(failed)} 个图标: {failed[:10]}')

In [None]:
# SVG 转 PNG
!pip install -q cairosvg

import cairosvg

for svg_file in Path(output_dir).glob('*.svg'):
    png_path = svg_file.with_suffix('.png')
    if not png_path.exists():
        cairosvg.svg2png(url=str(svg_file), write_to=str(png_path), output_width=512, output_height=512)

print(f'转换完成，共 {len(list(Path(output_dir).glob("*.png")))} 个 PNG')

## 第六步：开始训练

In [None]:
# 开始训练
!accelerate launch --mixed_precision="bf16" /content/train_lora.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --train_data_dir=/content/dataset/simple_icons \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --train_batch_size=2 \
  --gradient_accumulation_steps=4 \
  --num_train_epochs=5 \
  --learning_rate=1e-4 \
  --max_grad_norm=1.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --output_dir=/content/drive/MyDrive/lora-training/output \
  --save_steps=500 \
  --logging_steps=10 \
  --rank=32 \
  --seed=42

## 第七步：测试

In [None]:
import torch
from diffusers import DiffusionPipeline

# Build the SDXL base pipeline in half precision and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Apply the LoRA weights that the training run saved to Drive.
pipe.load_lora_weights("/content/drive/MyDrive/lora-training/output/pytorch_lora_weights.safetensors")

# Generate one sample using the same caption template as the training data.
prompt = "minimalist tech logo of React, geometric shape, flat design, single color"
generation = pipe(prompt, num_inference_steps=30, guidance_scale=7.5)
image = generation.images[0]

# A bare expression on the last line makes the notebook render the image.
image