# svgllm SFT on Colab (LLaVA 1.5 7B)

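This notebook mounts Google Drive, clones (or updates) the svgllm repo, installs its dependencies with `uv`, sanity-checks the SVG dataset stored on Drive, runs a small SFT job on that data, and finally generates an SVG from the fine-tuned checkpoint to compare against the target image.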


In [None]:
# 1) Runtime: set GPU to A100 in Colab UI: Runtime > Change runtime type > GPU > A100
# Mount Google Drive, then define the path constants used by the later cells.
from google.colab import drive  # type: ignore
import os
from glob import glob

drive.mount('/content/drive', force_remount=True)

# Kept as plain strings so they can be interpolated directly into `!` shell commands.
COLAB_ROOT = '/content'
DRIVE_ROOT = os.path.join(COLAB_ROOT, 'drive', 'MyDrive')
REPO_DIR = os.path.join(COLAB_ROOT, 'svgllm')
DATA_DIR = os.path.join(DRIVE_ROOT, 'WikipediaSVG')  # customize if needed
OUTPUT_DIR = os.path.join(DRIVE_ROOT, 'runs', 'sft-llava')

print('Drive mounted. Data dir:', DATA_DIR)

In [None]:
# 2) Clone or update the repo, then install dependencies with uv
import os
import shutil
import subprocess
import sys
from pathlib import Path

# uv manages the project environment, and later cells invoke it as `!uv run ...`.
# Install it whenever it is missing rather than only alongside a fresh clone.
# Otherwise a runtime that already has the checkout (e.g. cloned manually or by
# another notebook) but no uv fails at `uv sync` below.
# `sys.executable -m pip` targets the kernel's own interpreter instead of
# whichever `pip` happens to be first on PATH.
if shutil.which("uv") is None:
  subprocess.run([sys.executable, "-m", "pip", "install", "-U", "uv"], check=True)

# Clone if absent; otherwise hard-reset the checkout to the latest origin/main.
# NOTE: the reset discards any local edits made under REPO_DIR.
repo_url = "https://github.com/JacobAsmuth/svgllm"
if not Path(REPO_DIR).exists():
  subprocess.run(["git", "clone", repo_url, REPO_DIR], check=True)
else:
  for git_args in (["fetch", "origin"], ["checkout", "main"], ["reset", "--hard", "origin/main"]):
    subprocess.run(["git", "-C", REPO_DIR, *git_args], check=True)

# The later `!uv run python -m scripts...` cells run from the current working
# directory, so switch into the repo before syncing.
os.chdir(REPO_DIR)
subprocess.run(["uv", "sync"], check=True)

In [None]:
# 3) Quick dataset sanity check: render preview grid via uv
!uv run python -m scripts.sanity_check --data-dir $DATA_DIR --limit 8 --size 256 --out /content/svg_preview.png
from IPython.display import Image as DispImage, display
try:
  display(DispImage('/content/svg_preview.png'))
except Exception as e:
  print('Preview not available:', e)


In [None]:
# 4) Dry-run collator shapes
!uv run python -m scripts.train_sft_llava --dry-run --data-dir $DATA_DIR --max-items 2 --batch-size 1


In [None]:
# 5) Train (small)
!uv run python -m scripts.train_sft_llava \
  --data-dir $DATA_DIR \
  --max-items 512 \
  --batch-size 1 \
  --output-dir $OUTPUT_DIR


In [None]:
# 6) Load fine-tuned checkpoint and generate SVG for a sample
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from svgllm.renderer import render_svg_to_rgb
from svgllm.data.svg_dataset import SvgSftDataset
from PIL import Image
import numpy as np

GEN_SEED = 0  # fixed so the sampled SVG is reproducible across re-runs
USE_CUDA = torch.cuda.is_available()

# Assumes the training cell saved both model and processor to OUTPUT_DIR — TODO confirm.
ckpt_dir = OUTPUT_DIR
processor = AutoProcessor.from_pretrained(ckpt_dir)
model = LlavaForConditionalGeneration.from_pretrained(
    ckpt_dir,
    torch_dtype=torch.float16 if USE_CUDA else torch.float32,
).to(0 if USE_CUDA else 'cpu')

# Always build the sample dataset here. The sanity-check cell runs in a
# subprocess and never defines `ds` in this kernel, so the old
# `if 'ds' not in globals()` guard could only pick up stale, unrelated state.
ds = SvgSftDataset(DATA_DIR, image_size=(256, 256), max_items=8)
sample = ds[0]
image = sample.image  # `image` and `svg_text` are reused by the next cell

messages = [
    {"role": "user", "content": [
        {"type": "text", "text": "Reproduce this image as an SVG."},
        {"type": "image"},
    ]},
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# Move every tensor to the model's device and cast the floating-point ones
# (e.g. pixel_values) to the model's dtype; integer ids/masks keep their dtype.
inputs = processor(images=image, text=prompt, return_tensors='pt').to(model.device, model.dtype)

torch.manual_seed(GEN_SEED)  # do_sample=True below, so seed for reproducibility
gen_ids = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
# generate() returns prompt + continuation; decode only the newly generated tokens.
input_len = inputs["input_ids"].shape[1]
svg_text = processor.tokenizer.decode(gen_ids[0][input_len:], skip_special_tokens=True)
print(svg_text[:1000])


In [None]:
# 7) Render the generated SVG next to the target
from IPython.display import display


def _render_or_blank(svg, fallback_like):
  """Rasterize `svg` at 256x256 with render_svg_to_rgb.

  Rendering is best-effort: an invalid generated SVG should not abort the
  notebook. On any error, print it and return zeros shaped like `fallback_like`.
  """
  try:
    return render_svg_to_rgb(svg, size=(256, 256))
  except Exception as err:
    print('Render failed:', err)
    return np.zeros_like(fallback_like)


target_np = np.array(image.convert('RGB'))
pred_np = _render_or_blank(svg_text, target_np)

print('Target (left) vs Generated (right)')
# NOTE(review): hstack assumes the rendered array matches the target's height and
# channel count (both nominally 256x256 RGB here) — confirm render_svg_to_rgb's output.
comparison = np.hstack([target_np, pred_np])
display(Image.fromarray(comparison))
