# Simple Icons LoRA 训练

在 Google Colab 上训练 SDXL LoRA 模型

## 第一步：检查 GPU

In [None]:
!nvidia-smi

import torch
print(f"CUDA 可用: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else '无'}")

## 第二步：安装依赖

In [None]:
# Install Python dependencies via IPython `!` shell escapes (notebook-only syntax).
# NOTE(review): confirm that the PyPI package named `ai-toolkit` is the trainer
# providing the `ai_toolkit.train` module invoked in step 5 — TODO verify.
!pip install -q ai-toolkit accelerate bitsandbytes
!pip install -q transformers diffusers peft
# NOTE(review): `dataslots` and `wget` are not imported by any cell in this
# notebook; presumably pulled in for the trainer — confirm or drop.
!pip install -q dataslots wget

## 第三步：挂载 Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive so the training output (output_dir in the step-5 config)
# lands on Drive rather than on the ephemeral Colab runtime disk.
drive.mount('/content/drive')

# Create the working directory on Drive; the config writes its output under it.
!mkdir -p /content/drive/MyDrive/lora-training

## 第四步：准备数据

In [None]:
# 下载 Simple Icons 数据
!mkdir -p /content/dataset
%cd /content/dataset

import requests
import json

url = 'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/data/simple-icons.json'
response = requests.get(url)
icons = response.json()

print(f'总图标数: {len(icons)}')

# 只取前 200 个
icons = icons[:200]
print(f'将下载: {len(icons)} 个')

In [None]:
# Download each icon's SVG plus a caption (.txt) file of training prompts.
import os
import re
import unicodedata
from pathlib import Path

import requests
from tqdm import tqdm

output_dir = '/content/dataset/simple_icons'
Path(output_dir).mkdir(parents=True, exist_ok=True)

# Characters that Simple Icons spells out or transliterates before stripping
# everything outside [a-z0-9]; intended to mirror the upstream `titleToSlug`.
# NOTE(review): re-check this table against the simple-icons SDK if slugs 404.
_SLUG_REPLACEMENTS = {
    '+': 'plus', '.': 'dot', '&': 'and',
    'đ': 'd', 'ħ': 'h', 'ı': 'i', 'ĸ': 'k', 'ŀ': 'l', 'ł': 'l', 'ß': 'ss', 'ŧ': 't',
}


def title_to_slug(title):
    """Derive the Simple Icons file slug for *title* (e.g. 'Node.js' -> 'nodedotjs')."""
    s = ''.join(_SLUG_REPLACEMENTS.get(ch, ch) for ch in title.lower())
    # NFD splits accented letters into base letter + combining mark; the mark
    # is then removed by the [^a-z0-9] filter, leaving only the base letter.
    s = unicodedata.normalize('NFD', s)
    return re.sub(r'[^a-z0-9]', '', s)


failed = []  # (slug, reason) for icons that could not be fetched
for icon in tqdm(icons):
    title = icon.get('title', '')
    # An explicit `slug` is used when present; otherwise derive it from the
    # title the same way upstream does (the old space/hyphen stripping missed
    # '.', '+', '&' and accents, so e.g. "Node.js" 404'd).
    slug = icon.get('slug', '') or title_to_slug(title)

    if not slug:
        continue

    svg_url = f'https://raw.githubusercontent.com/simple-icons/simple-icons/develop/icons/{slug}.svg'

    # Best-effort download: network errors and non-200 responses skip the icon
    # but are recorded, instead of being silently swallowed by a bare except.
    try:
        resp = requests.get(svg_url, timeout=10)
    except requests.RequestException as exc:
        failed.append((slug, repr(exc)))
        continue
    if resp.status_code != 200:
        failed.append((slug, f'HTTP {resp.status_code}'))
        continue

    # Explicit UTF-8: SVG text and titles in the prompts may be non-ASCII.
    with open(f'{output_dir}/{slug}.svg', 'w', encoding='utf-8') as f:
        f.write(resp.text)

    prompts = [
        f'minimalist tech logo of {title}, geometric shape, flat design, single color',
        f'{title} icon, simple geometric logo, minimal design'
    ]
    with open(f'{output_dir}/{slug}.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(prompts))

if failed:
    print(f'下载失败: {len(failed)} 个, 例如: {failed[:5]}')

In [None]:
# SVG 转 PNG
!pip install -q cairosvg

import cairosvg

for svg_file in Path(output_dir).glob('*.svg'):
    png_path = svg_file.with_suffix('.png')
    if not png_path.exists():
        cairosvg.svg2png(url=str(svg_file), write_to=str(png_path), output_width=512, output_height=512)

## 第五步：开始训练

In [None]:
# Create the training config file read by the training cell that follows.
# NOTE(review): this schema (model / dataset / lora / training / output) is
# assumed to match what `ai_toolkit.train` expects — TODO confirm against the
# trainer's documentation; unknown keys may be silently ignored.
# NOTE(review): `dtype: bf16` presumably needs an Ampere-or-newer GPU; on a T4
# fp16 may be required. The inference cell loads the base model in float16.
# Effective batch size = batch_size * gradient_accumulation_steps = 2 * 4 = 8.
# `resolution: 512` matches the PNG size produced in step 4 (SDXL's native size is 1024).
config = '''model:
  type: sdxl
  pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0
  dtype: bf16

dataset:
  data_dir: /content/dataset/simple_icons
  resolution: 512

lora:
  r: 32
  lora_alpha: 16

training:
  num_train_epochs: 5
  batch_size: 2
  learning_rate: 1e-4
  gradient_accumulation_steps: 4

output:
  output_dir: /content/drive/MyDrive/lora-training/output
  model_name: simple-icons-v1'''

with open('/content/config.yaml', 'w') as f:
    f.write(config)

print('配置完成')

In [None]:
# Start training with the config written above (IPython shell escape).
# NOTE(review): confirm the installed package exposes an `ai_toolkit.train`
# module that accepts `--config`; otherwise this command fails immediately.
!python -m ai_toolkit.train --config /content/config.yaml

## 第六步：测试

In [None]:
# Load SDXL base in fp16, apply the trained LoRA, and render one sample image.
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    # Fetch the repo's fp16 safetensors weights directly instead of downloading
    # the full-precision checkpoint only to cast it down to float16.
    variant='fp16',
    use_safetensors=True,
)
pipe = pipe.to('cuda')

# NOTE(review): assumes the trainer saved diffusers-loadable LoRA weights in
# output_dir/model_name from the config — confirm the actual output layout.
pipe.load_lora_weights('/content/drive/MyDrive/lora-training/output/simple-icons-v1')

# Fixed seed so samples from different training runs can be compared fairly.
generator = torch.Generator(device='cuda').manual_seed(0)
image = pipe(
    'minimalist tech logo of React, geometric shape',
    num_inference_steps=30,
    generator=generator,
).images[0]
image  # last expression: rendered as the cell output