<a href="https://colab.research.google.com/github/automubashir/text-to-icon/blob/main/icons_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 🔥 FULL AI ICON GENERATOR: Scraping → Training → Gradio UI (PNG/SVG + Icon Font Support)
# Run in Google Colab

import os
import torch
import requests
import subprocess
import xml.etree.ElementTree as ET
from PIL import Image
from io import BytesIO
from svglib.svglib import svg2rlg # Corrected import
from reportlab.graphics import renderPM
from pydub import AudioSegment
import gradio as gr
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model
import shutil
import time
import zipfile
import json
import re

# ================================
# 1. CONFIGURATION
# ================================

# Set your project paths
PROJECT_DIR = "/content/icon-generator"
DATASET_DIR = os.path.join(PROJECT_DIR, "dataset")
SVG_DIR = os.path.join(DATASET_DIR, "svg")  # numbered copies of the source SVGs
PNG_DIR = os.path.join(DATASET_DIR, "png")  # matching rasterized training images
CAPTIONS_FILE = os.path.join(DATASET_DIR, "captions.jsonl")  # one {"file", "text"} JSON object per line
MODEL_OUTPUT_DIR = os.path.join(PROJECT_DIR, "trained_model")  # LoRA adapter saved by train_lora()
GRADIO_SHARE = True  # Share UI publicly

# Create directories (os.makedirs also creates any missing parent directories)
os.makedirs(SVG_DIR, exist_ok=True)
os.makedirs(PNG_DIR, exist_ok=True)
os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)

# Model settings
# Hugging Face Hub id of the Stable Diffusion 1.5 base checkpoint. The previous value,
# "runwayml/stable-diffusion-1-5", is not a valid repo id (the original repo was named
# "runwayml/stable-diffusion-v1-5" and has since been removed from the Hub); this is the
# maintained mirror of the same weights.
BASE_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Side length in pixels for both the training images and generated icons. The Stable
# Diffusion pipeline requires a multiple of 8.
# NOTE(review): SD 1.5 was trained at 512 px; samples at 128 px are likely low quality.
RESOLUTION = 128
TRAIN_BATCH_SIZE = 4
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
LORA_RANK = 32  # rank `r` of the LoRA update matrices

# Icon font settings
FONT_OUTPUT_DIR = os.path.join(PROJECT_DIR, "icon_font")
os.makedirs(FONT_OUTPUT_DIR, exist_ok=True)

# ================================
# 2. DOWNLOAD & SCRAPE ICONS (Material Design Icons)
# ================================

def download_material_icons(dest_dir="/content/material-design-icons"):
    """Fetch the Material Design Icons repository and return its ``src`` directory.

    The clone target is passed to git explicitly. A bare ``git clone URL`` would land
    in the current working directory, which need not be ``/content``. In that case
    the existence check would never match and the returned path would not exist.

    Only the ``src`` tree is checked out. This uses a shallow, blob-less, sparse
    clone, which avoids downloading the repository's other, much larger top-level
    directories.

    Args:
        dest_dir: Directory to clone the repository into.

    Returns:
        Path to ``<dest_dir>/src``.

    Raises:
        subprocess.CalledProcessError: if a git command fails.
    """
    print("🔽 Downloading Material Design Icons...")
    src_dir = os.path.join(dest_dir, "src")
    if os.path.isdir(src_dir):
        print("✅ Already downloaded.")
        return src_dir

    subprocess.run([
        "git", "clone", "--depth=1", "--filter=blob:none", "--sparse",
        "https://github.com/google/material-design-icons.git",
        dest_dir,
    ], check=True)
    # Materialize just the SVG sources; other directories stay un-fetched.
    subprocess.run(["git", "-C", dest_dir, "sparse-checkout", "set", "src"], check=True)
    print("✅ Downloaded!")
    return src_dir

ICON_SRC_DIR = download_material_icons()

# ================================
# 3. PROCESS SVG → PNG + CAPTIONS
# ================================

def clean_name(name):
    """Turn an identifier such as ``"arrow_back-ios"`` into title-cased words.

    Each run of ``-`` / ``_`` characters becomes one space, then the whole string is
    title-cased (e.g. ``"3d_rotation"`` -> ``"3D Rotation"``).
    """
    words = re.split(r"[-_]+", name)
    return " ".join(words).title()

def _iter_icon_svgs(src_dir):
    """Yield ``(svg_path, icon_name, category)`` for every ``.svg`` file under *src_dir*.

    Directories and files are visited in sorted order, so the subset kept by
    ``process_icons``' limit is reproducible.

    For files at least three directories deep, the icon name and category are taken
    from the path rather than the file name. NOTE(review): this assumes the Material
    Design Icons ``src`` layout ``<category>/<icon>/<style>/<size>px.svg``, where the
    file name (e.g. ``24px.svg``) says nothing about the icon; confirm against the
    checked-out tree. Shallower files fall back to the file name and the parent
    directory.
    """
    for root, dirs, files in os.walk(src_dir):
        dirs.sort()  # os.walk (top-down) honours in-place reordering of `dirs`
        for file in sorted(files):
            if not file.endswith(".svg"):
                continue
            svg_path = os.path.join(root, file)
            parts = os.path.relpath(svg_path, src_dir).split(os.sep)
            if len(parts) >= 4:
                category, raw_name = parts[-4], parts[-3]
            else:
                category, raw_name = os.path.basename(root), file[: -len(".svg")]
            yield svg_path, clean_name(raw_name), category


def _svg_to_icon_png(svg_path, size):
    """Rasterize *svg_path* to a ``size`` x ``size`` RGBA image.

    The result is an opaque black icon on a fully transparent background.

    Raises:
        ValueError: if svglib cannot parse the file (``svg2rlg`` returned None).
    """
    drawing = svg2rlg(svg_path)
    if drawing is None:
        raise ValueError("svglib could not parse the SVG")
    # Render close to the target size (renderPM scales by dpi / 72) instead of
    # upscaling a tiny native-size raster, which blurred the thresholded edges.
    longest_side = max(drawing.width, drawing.height) or 1
    img = renderPM.drawToPIL(drawing, dpi=72 * size / longest_side)
    img = img.convert("RGBA").resize((size, size), Image.LANCZOS)

    # Flatten onto white via the alpha channel, then build a hard mask.
    # Dark pixels (< 200) become 255 (icon); everything else becomes 0.
    flattened = Image.new("RGBA", img.size, (255, 255, 255, 255))
    flattened.paste(img, mask=img.getchannel("A"))
    mask = flattened.convert("L").point(lambda v: 255 if v < 200 else 0)

    icon = Image.new("RGBA", (size, size), (0, 0, 0, 0))
    icon.paste(Image.new("RGBA", (size, size), (0, 0, 0, 255)), mask=mask)
    return icon


def process_icons(max_icons=500):
    """Convert up to *max_icons* source SVGs from ``ICON_SRC_DIR`` into the training dataset.

    For each accepted icon this writes three things:

    * ``PNG_DIR/NNNN.png``;
    * a copy of the source at ``SVG_DIR/NNNN.svg``;
    * a caption ``"<Name> icon, <category>, flat vector style"``.

    All captions are written to ``CAPTIONS_FILE`` as JSON Lines. SVGs that fail to
    convert are reported and skipped without consuming an index.

    Args:
        max_icons: Maximum number of icons to keep. Once reached, the directory walk
            stops; the previous inner-loop ``break`` kept walking the whole tree.
    """
    print("🔧 Processing SVGs to PNG and generating captions...")
    captions = []

    for svg_path, icon_name, category in _iter_icon_svgs(ICON_SRC_DIR):
        if len(captions) >= max_icons:
            break  # leaving the loop also abandons the generator and its os.walk
        stem = f"{len(captions):04d}"
        try:
            png_img = _svg_to_icon_png(svg_path, RESOLUTION)
        except Exception as e:  # best effort: one bad SVG must not abort the run
            print(f"Error processing {svg_path}: {e}")
            continue

        # Write outputs only after conversion succeeded, so SVG_DIR never gets an
        # SVG without a matching PNG and caption.
        png_img.save(os.path.join(PNG_DIR, f"{stem}.png"), "PNG")
        shutil.copyfile(svg_path, os.path.join(SVG_DIR, f"{stem}.svg"))
        captions.append({
            "file": f"{stem}.png",
            "text": f"{icon_name} icon, {category}, flat vector style",
        })

    # Save captions
    with open(CAPTIONS_FILE, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in captions)

    print(f"✅ Processed {len(captions)} icons!")

process_icons()

# ================================
# 4. TRAIN LoRA MODEL ON ICONS
# ================================

def train_lora():
    """Fine-tune LoRA adapters on the Stable Diffusion UNet using the processed icon dataset.

    Reads the caption/image pairs written by ``process_icons`` (``CAPTIONS_FILE`` and
    ``PNG_DIR``). Only the LoRA weights injected into the UNet attention projections
    are trained; the VAE, text encoder and base UNet weights stay frozen. The adapter
    is saved with PEFT's ``save_pretrained`` into ``MODEL_OUTPUT_DIR``.

    Raises:
        RuntimeError: if the captions file contains no samples.
    """
    print("🏋️ Starting LoRA training...")

    from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
    from torch.utils.data import Dataset, DataLoader
    import accelerate
    from tqdm import tqdm
    import numpy as np

    class IconDataset(Dataset):
        """Image/caption pairs read from a JSON Lines file of ``{"file", "text"}`` objects."""

        def __init__(self, data_file, tokenizer, img_dir):
            with open(data_file, "r", encoding="utf-8") as f:
                self.items = [json.loads(line) for line in f if line.strip()]
            self.tokenizer = tokenizer
            self.img_dir = img_dir

        def __len__(self):
            return len(self.items)

        def __getitem__(self, idx):
            item = self.items[idx]
            with Image.open(os.path.join(self.img_dir, item["file"])) as img:
                rgba = img.convert("RGBA")
            # The icon PNGs are black shapes on a *transparent* background.
            # convert("RGB") alone maps transparent (0,0,0,0) to black, which would
            # make every training image solid black. Flatten onto white first.
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            image = Image.alpha_composite(white, rgba).convert("RGB")
            image = image.resize((RESOLUTION, RESOLUTION))
            # HWC uint8 -> CHW float in [-1, 1], the range the VAE expects.
            pixel_values = torch.tensor(np.array(image)).permute(2, 0, 1).float() / 127.5 - 1.0
            input_ids = self.tokenizer(
                item["text"],
                max_length=77,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids[0]
            return {"input_ids": input_ids, "pixel_values": pixel_values}

    accelerator = accelerate.Accelerator(mixed_precision="fp16", gradient_accumulation_steps=1)
    weight_dtype = torch.float16

    # Load models
    tokenizer = CLIPTokenizer.from_pretrained(BASE_MODEL, subfolder="tokenizer")
    noise_scheduler = DDPMScheduler.from_pretrained(BASE_MODEL, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(BASE_MODEL, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(BASE_MODEL, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")

    # Frozen components: no gradients, stored in fp16 on the training device.
    for frozen_model in (text_encoder, vae, unet):
        frozen_model.requires_grad_(False)
        frozen_model.to(accelerator.device, dtype=weight_dtype)

    # LoRA
    lora_config = LoraConfig(
        r=LORA_RANK,
        lora_alpha=16,
        target_modules=["to_q", "to_v", "to_k", "to_out.0"],
        lora_dropout=0.0,
        bias="none",
        modules_to_save=[],
    )
    unet = get_peft_model(unet, lora_config)
    # PEFT creates the LoRA weights in the base layer's dtype (fp16 here). The fp16
    # gradient scaler used by accelerate refuses to unscale fp16 gradients, so keep the
    # trainable parameters in fp32; accelerate's autocast still runs the forward in fp16.
    trainable_params = [p for p in unet.parameters() if p.requires_grad]
    for param in trainable_params:
        param.data = param.data.float()

    # Dataset
    dataset = IconDataset(CAPTIONS_FILE, tokenizer, PNG_DIR)
    if len(dataset) == 0:
        raise RuntimeError(f"No training samples in {CAPTIONS_FILE}; run process_icons() first.")
    dataloader = DataLoader(dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)

    # Optimizer (LoRA parameters only)
    optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)

    # Prepare
    unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)

    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    latent_scale = vae.config.scaling_factor  # 0.18215 for SD 1.x

    # Training loop
    progress_bar = tqdm(total=len(dataloader) * NUM_EPOCHS, desc="Training")
    for epoch in range(NUM_EPOCHS):
        unet.train()
        for batch in dataloader:
            with accelerator.accumulate(unet):
                with torch.no_grad():
                    latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                    latents = latents * latent_scale
                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, num_train_timesteps, (latents.shape[0],), device=latents.device
                ).long()
                # DDPM forward process: sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.
                # This matches the noise-prediction objective the inference schedulers assume.
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix(epoch=epoch, loss=loss.item())
    progress_bar.close()

    # Save model (adapter weights + adapter_config.json)
    accelerator.wait_for_everyone()
    accelerator.unwrap_model(unet).save_pretrained(MODEL_OUTPUT_DIR)
    print(f"✅ Model saved to {MODEL_OUTPUT_DIR}")

# Uncomment to train (takes 10-20 mins on Colab)
# train_lora()

# ================================
# 5. GRADIO UI FOR GENERATION
# ================================

def load_trained_pipeline():
    """Build the generation pipeline, applying the trained LoRA adapter if one exists.

    ``MODEL_OUTPUT_DIR`` is created unconditionally at start-up, so its existence says
    nothing about whether training ran. Instead this looks for the adapter files
    themselves and accepts two layouts:

    * PEFT ``save_pretrained`` output (``adapter_config.json``), as written by
      ``train_lora``. It is loaded with ``PeftModel`` and merged into the UNet weights.
    * diffusers LoRA format (``pytorch_lora_weights.safetensors``), loaded with
      ``pipe.load_lora_weights``.

    Returns:
        A fp16 ``StableDiffusionPipeline`` with a DDIM scheduler, moved to CUDA.
    """
    pipe = StableDiffusionPipeline.from_pretrained(BASE_MODEL, torch_dtype=torch.float16)

    peft_config_path = os.path.join(MODEL_OUTPUT_DIR, "adapter_config.json")
    diffusers_lora_path = os.path.join(MODEL_OUTPUT_DIR, "pytorch_lora_weights.safetensors")
    if os.path.isfile(peft_config_path):
        print("🔁 Loading fine-tuned model...")
        from peft import PeftModel
        # Merging folds the LoRA deltas into the base weights and returns the plain UNet,
        # so the pipeline keeps calling an ordinary UNet2DConditionModel.
        pipe.unet = PeftModel.from_pretrained(pipe.unet, MODEL_OUTPUT_DIR).merge_and_unload()
    elif os.path.isfile(diffusers_lora_path):
        print("🔁 Loading fine-tuned model...")
        pipe.load_lora_weights(MODEL_OUTPUT_DIR)
    else:
        print("🆕 Using base model (not fine-tuned yet)...")

    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")
    return pipe

pipe = load_trained_pipeline()

def _bitmap_to_svg(pixels, width, height):
    """Return an SVG document whose single black path covers every 0-valued pixel.

    *pixels* is a flat, row-major sequence of ``width * height`` values, such as
    ``Image.getdata()`` of a mode ``"1"`` image, where 0 is black. Consecutive dark
    pixels in a row are merged into one ``M x,y hN v1 h-N z`` rectangle. This covers
    exactly the same area as one unit square per pixel, with a far shorter path.
    """
    commands = []
    for y in range(height):
        row = pixels[y * width:(y + 1) * width]
        x = 0
        while x < width:
            if row[x] != 0:
                x += 1
                continue
            start = x
            while x < width and row[x] == 0:
                x += 1
            run = x - start
            commands.append(f"M{start},{y}h{run}v1h-{run}z")
    path_data = "".join(commands)
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">\n'
        f'<path d="{path_data}" fill="black"/>\n</svg>'
    )


def generate_icon(prompt, output_format="png", negative_prompt="text, numbers, complex background"):
    """Generate an icon for *prompt* and return the path of the requested file.

    Every call writes both ``output_icon.png`` and a thresholded, pixel-traced
    ``output_icon.svg`` into ``PROJECT_DIR``, overwriting the previous pair.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        output_format: ``"png"`` or ``"svg"``; selects which file path is returned.
        negative_prompt: Concepts to steer the pipeline away from.

    Returns:
        Path to the PNG or SVG file.

    Raises:
        ValueError: for any other *output_format*. This is checked before the
            (expensive) pipeline runs instead of silently returning None afterwards.
    """
    if output_format not in ("png", "svg"):
        raise ValueError(f"Unsupported output_format {output_format!r}; expected 'png' or 'svg'.")

    # Generate image
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=RESOLUTION,
        height=RESOLUTION,
        num_inference_steps=30,
        guidance_scale=7.0
    )
    image = output.images[0]

    # Save as PNG
    png_path = os.path.join(PROJECT_DIR, "output_icon.png")
    image.save(png_path)

    # Convert to SVG: threshold at mid-grey, then trace the dark pixels.
    svg_path = os.path.join(PROJECT_DIR, "output_icon.svg")
    bitmap = image.convert("L").point(lambda v: 0 if v < 128 else 255, mode='1')
    width, height = bitmap.size
    with open(svg_path, "w", encoding="utf-8") as f:
        f.write(_bitmap_to_svg(list(bitmap.getdata()), width, height))

    return png_path if output_format == "png" else svg_path

# Launch Gradio
# Build the main text-to-icon interface here; it is launched at the end of the script.
_prompt_input = gr.Textbox(value="home icon, flat design", label="Prompt")
_format_input = gr.Radio(["png", "svg"], value="png", label="Output Format")
_icon_output = gr.Image(type="filepath", label="Generated Icon")
demo = gr.Interface(
    fn=generate_icon,
    inputs=[_prompt_input, _format_input],
    outputs=_icon_output,
    title="🎨 AI Icon & Icon Font Generator",
    description="Generate icons from text. Outputs PNG or simplified SVG.",
    allow_flagging="never",
)

# ================================
# 6. ICON FONT GENERATION (via SVG → Font)
# ================================

def create_icon_font(max_glyphs=100, start_codepoint=0xE001):
    """Build a TrueType icon font from the SVGs in ``SVG_DIR``.

    SVG files are taken in sorted file-name order, so glyph-to-codepoint assignment
    is reproducible. Each of the first *max_glyphs* is imported as a glyph at
    consecutive code points starting at *start_codepoint*; the default U+E001 lies in
    the Unicode Private Use Area. If FontForge's Python module is missing, it is
    installed with apt.

    Args:
        max_glyphs: Maximum number of SVGs to turn into glyphs.
        start_codepoint: Code point of the first glyph.

    Returns:
        Path of the generated ``AIIconFont.ttf`` in ``FONT_OUTPUT_DIR``.

    Raises:
        FileNotFoundError: if ``SVG_DIR`` contains no ``.svg`` files.
        subprocess.CalledProcessError: if installing FontForge fails.
    """
    try:
        import fontforge
    except ImportError:
        # Plain-Python equivalent of the notebook shell magic
        # `!apt-get update && apt-get install -y fontforge python3-fontforge`,
        # so this function is also valid when the file is run as a .py script.
        import importlib
        subprocess.run(["apt-get", "update"], check=True)
        subprocess.run(["apt-get", "install", "-y", "fontforge", "python3-fontforge"], check=True)
        importlib.invalidate_caches()  # let the import system see the newly installed module
        import fontforge

    svg_files = sorted(name for name in os.listdir(SVG_DIR) if name.lower().endswith(".svg"))
    svg_files = svg_files[:max_glyphs]
    if not svg_files:
        raise FileNotFoundError(f"No .svg files in {SVG_DIR}; run process_icons() first.")

    font = fontforge.font()
    font.fontname = "AIIconFont"
    font.fullname = "AI Generated Icon Font"
    font.familyname = "AIIconFont"

    for codepoint, svg_file in enumerate(svg_files, start=start_codepoint):
        glyph = font.createChar(codepoint)
        glyph.importOutlines(os.path.join(SVG_DIR, svg_file))
        glyph.left_side_bearing = 50
        glyph.right_side_bearing = 50

    font_path = os.path.join(FONT_OUTPUT_DIR, "AIIconFont.ttf")
    font.generate(font_path)
    print(f"✅ Icon font saved to {font_path}")
    return font_path

print("✅ Setup complete. Use the button below to generate an icon font.")
# Single-button UI that builds the icon font. The function is passed directly (no
# wrapper lambda is needed). prevent_thread_lock=True stops this launch from blocking
# the main thread when the file runs as a plain script, so the main UI launch below
# is still reached. NOTE(review): in notebooks, launch() is expected to be non-blocking
# already.
font_demo = gr.Interface(
    fn=create_icon_font,
    inputs=None,
    outputs="file",
    title="Generate Icon Font",
)
font_demo.launch(share=False, prevent_thread_lock=True)

# ================================
# 7. LAUNCH GRADIO UI
# ================================

# Serve the main text-to-icon interface; GRADIO_SHARE decides whether a public
# share link is requested.
_main_launch_options = {"share": GRADIO_SHARE}
print("🎉 Starting Gradio UI...")
demo.launch(**_main_launch_options)

🔽 Downloading Material Design Icons...
