# Simple Icons LoRA 训练

使用 Hugging Face 官方 diffusers 库训练 SDXL LoRA

## 第一步：检查 GPU

In [None]:
!nvidia-smi

import torch
print(f"CUDA 可用: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else '无'}")

## 第二步：安装依赖

In [None]:
# 安装 diffusers 开发版（训练脚本需要）
!pip install -q git+https://github.com/huggingface/diffusers.git
!pip install -q transformers accelerate peft bitsandbytes
!pip install -q torch torchvision safetensors
!pip install -q Pillow tqdm

## 第三步：挂载 Google Drive

In [None]:
# Mount Google Drive so that training output is written to Drive storage
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# Create the working directory that the training step uses as --output_dir.
!mkdir -p /content/drive/MyDrive/lora-training/output

## 第四步：下载训练脚本

In [None]:
# 下载 Hugging Face 官方 SDXL LoRA 训练脚本
!wget -q https://github.com/huggingface/diffusers/raw/main/examples/text_to_image/train_text_to_image_lora_sdxl.py -O /content/train_lora.py
print('训练脚本下载完成')

## 第五步：准备数据（仅需运行一次）

In [None]:
# 检查是否已有数据
from pathlib import Path

output_dir = Path('/content/dataset/simple_icons')
png_count = len(list(output_dir.glob('*.png'))) if output_dir.exists() else 0

if png_count >= 50:
    print(f'数据已存在，共 {png_count} 个 PNG 文件，跳过下载')
    icons = []
else:
    print('数据不存在或不足，开始下载...')
    import requests
    import json
    
    !mkdir -p /content/dataset
    %cd /content/dataset
    
    url = 'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/data/simple-icons.json'
    response = requests.get(url)
    icons = response.json()[:200]
    print(f'将下载: {len(icons)} 个图标')

In [None]:
# Download the SVG files and write metadata.jsonl (only runs when needed).
# Reads `icons` from the data-check cell; each entry is expected to be a dict
# with a 'title' and optionally a 'slug'.
if len(icons) > 0:
    import os
    import json
    from pathlib import Path

    # Imported here as well so this cell does not depend on which branch of
    # the data-check cell happened to run.
    import requests
    from tqdm import tqdm

    output_dir = Path('/content/dataset/simple_icons')
    output_dir.mkdir(parents=True, exist_ok=True)

    metadata = []
    failed = []  # (slug, reason) pairs, reported after the loop

    for icon in tqdm(icons):
        title = icon.get('title', '')
        # Fall back to a slug derived from the title when none is provided.
        slug = icon.get('slug', '') or title.lower().replace(' ', '').replace('-', '')

        if not slug:
            continue

        svg_url = f'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/icons/{slug}.svg'

        # Only network errors are treated as "skip this icon"; anything else
        # (including Ctrl-C / kernel interrupt) is allowed to propagate.
        try:
            resp = requests.get(svg_url, timeout=10)
        except requests.RequestException as err:
            failed.append((slug, str(err)))
            continue

        if resp.status_code != 200:
            failed.append((slug, f'HTTP {resp.status_code}'))
            continue

        # Save the SVG exactly as received, independent of the locale encoding.
        (output_dir / f'{slug}.svg').write_bytes(resp.content)

        # Metadata refers to the PNG name that the conversion cell produces.
        prompt = f'minimalist tech logo of {title}, geometric shape, flat design, single color'
        metadata.append({
            'file_name': f'{slug}.png',
            'text': prompt
        })

    # Write metadata.jsonl, one JSON object per line.
    with open(output_dir / 'metadata.jsonl', 'w', encoding='utf-8') as f:
        for item in metadata:
            f.write(json.dumps(item) + '\n')

    print(f'下载完成，共 {len(list(output_dir.glob("*.svg")))} 个 SVG')
    print(f'metadata.jsonl 包含 {len(metadata)} 条记录')
    if failed:
        # Surface what was skipped instead of hiding it.
        print(f'有 {len(failed)} 个图标下载失败:')
        for failed_slug, reason in failed:
            print(f'  {failed_slug}: {reason}')
else:
    print('跳过 SVG 下载')

In [None]:
# SVG 转 PNG（仅在需要时执行）
if len(icons) > 0:
    !pip install -q cairosvg
    import cairosvg
    
    for svg_file in output_dir.glob('*.svg'):
        png_path = svg_file.with_suffix('.png')
        if not png_path.exists():
            cairosvg.svg2png(url=str(svg_file), write_to=str(png_path), output_width=512, output_height=512)
    
    print(f'转换完成，共 {len(list(output_dir.glob("*.png")))} 个 PNG')
else:
    print('跳过 PNG 转换')

## 第六步：开始训练

In [None]:
# 开始训练（针对 Colab T4 GPU 优化）
# 首先检查/创建 metadata.jsonl
import json
from pathlib import Path
import os

output_dir = Path('/content/dataset/simple_icons')
metadata_file = output_dir / 'metadata.jsonl'

if not metadata_file.exists():
    print("创建 metadata.jsonl...")
    metadata = []
    for txt_file in output_dir.glob('*.txt'):
        png_name = txt_file.with_suffix('.png').name
        with open(txt_file) as f:
            text = f.read().strip().split('\n')[0]
        metadata.append({'file_name': png_name, 'text': text})
    
    with open(metadata_file, 'w') as f:
        for item in metadata:
            f.write(json.dumps(item) + '\n')
    print(f"创建了 metadata.jsonl，共 {len(metadata)} 条记录")

# 检查数据集格式
from datasets import load_dataset

print("检查数据集格式...")
if os.path.exists('/content/dataset/simple_icons/metadata.jsonl'):
    dataset = load_dataset('imagefolder', data_files={'train': '/content/dataset/simple_icons/**'}, split='train')
    print(f"数据集列名: {dataset.column_names}")
    print(f"数据集大小: {len(dataset)}")
else:
    print("错误: metadata.jsonl 不存在!")

# 开始训练
!accelerate launch --mixed_precision="fp16" /content/train_lora.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --train_data_dir=/content/dataset/simple_icons \
  --caption_column="text" \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=8 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --num_train_epochs=5 \
  --learning_rate=1e-4 \
  --max_grad_norm=1.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --output_dir=/content/drive/MyDrive/lora-training/output \
  --checkpointing_steps=500 \
  --rank=16 \
  --seed=42

## 第七步：测试

In [None]:
import torch
from diffusers import DiffusionPipeline

# Load the SDXL base model in half precision and place it on the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Attach the LoRA weights written to Drive by the training step.
pipe.load_lora_weights("/content/drive/MyDrive/lora-training/output/pytorch_lora_weights.safetensors")

# Generate a single sample using the same prompt template as the training captions.
prompt = "minimalist tech logo of React, geometric shape, flat design, single color"
generation = pipe(prompt, num_inference_steps=30, guidance_scale=7.5)
image = generation.images[0]

# A bare expression on the last line makes the notebook render the image.
image