In [None]:
%env CUDA_VISIBLE_DEVICES=0
%load_ext autoreload
%autoreload
from utils import svg_string_to_tensor, svg_to_tensor
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import os
import yaml
from models import VSQ, VQ_SVG_Stage2
from utils import get_side_by_side_reconstruction, map_wand_config
from dataset import VSQDataset, VQDataset, VQDataModule
from tokenizer import VQTokenizer
from svg_fixing import min_dist_fix_global, min_dist_fix
from glob import glob
print(torch.cuda.is_available())

import os
import time
import torch
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

from models.decoder import SketchDecoder
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox
from transformers import AutoTokenizer

# im2vec
from dataset import GenericRasterizedSVGDataset
from models.vector_vae_nlayers import VectorVAEnLayers
import pydiffvg
from typing import List

# Turn off the HuggingFace `tokenizers` library's internal parallelism.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Plain-string device name; the setup cell below rebinds `device` to a torch.device.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Setup

In [None]:
# Single torch.device handle shared by every loader and sampler in this notebook.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model_from_basepath(basepath, device="cpu"):
    """Load the most recent stage-1 VSQ checkpoint from a run directory.

    Reads ``<basepath>/config.yaml``, builds ``VSQ(**config["model_params"])`` on
    ``device`` and loads the newest ``<basepath>/checkpoints/*.ckpt`` (newest by
    file modification time). Checkpoint keys that carry a leading ``model.``
    prefix are accepted as a fallback.

    Args:
        basepath: run directory containing ``config.yaml`` and ``checkpoints/``.
        device: device the model is moved to and the checkpoint is mapped onto.

    Returns:
        ``(model, ds, config)``: the VSQ model in eval mode, a ``VSQDataset``
        built with ``train=False`` and the (inference-overridden) config dict.

    Raises:
        FileNotFoundError: if ``<basepath>/checkpoints`` contains no ``*.ckpt``.
    """
    with open(os.path.join(basepath, 'config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    # Inference overrides: allow up to 2000 shapes per SVG and keep batches tiny.
    config["data_params"]["max_shapes_per_svg"] = 2000
    config["data_params"]["train_batch_size"] = 2
    config["data_params"]["val_batch_size"] = 2
    model = VSQ(**config["model_params"]).to(device)

    ckpt_dir = os.path.join(basepath, "checkpoints")
    all_ckpts = glob(os.path.join(ckpt_dir, "*.ckpt"))
    if not all_ckpts:
        raise FileNotFoundError(f"no *.ckpt files found in {ckpt_dir}")
    # sort by modification time; the last entry is the newest checkpoint
    latest_ckpt_path = sorted(all_ckpts, key=os.path.getmtime)[-1]
    state_dict = torch.load(latest_ckpt_path, map_location=device)["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strict loading failed on key names: retry with a leading "model." prefix removed.
        # Only the prefix is stripped so submodule names containing "model." stay intact.
        model.load_state_dict({
            (k[len("model."):] if k.startswith("model.") else k): v
            for k, v in state_dict.items()
        })
    ds = VSQDataset(**config["data_params"], train=False)
    model = model.eval()
    return model, ds, config

def load_stage2_model_from_basepath(vsq_model, basepath, device="cpu"):
    """Build the stage-2 text-to-SVG model around a trained VSQ and load its newest checkpoint.

    Reads and flattens ``<basepath>/config.yaml`` (via ``map_wand_config``), builds
    a ``VQTokenizer`` and a ``VQ_SVG_Stage2`` model, and builds a set-up ``VQDataModule``
    (``train=False``) that uses a text-only tokenizer. It then loads the newest
    ``<basepath>/checkpoints/*.ckpt`` whose file name does not contain "last".

    Args:
        vsq_model: trained stage-1 VSQ model used by both tokenizers.
        basepath: stage-2 run directory containing ``config.yaml`` and ``checkpoints/``.
        device: device for the tokenizer, the model and checkpoint mapping.

    Returns:
        ``(model, dm, config)``: the stage-2 model in eval mode, the data module
        and the flattened config dict.

    Raises:
        FileNotFoundError: if no eligible checkpoint exists under ``<basepath>/checkpoints``.
    """
    with open(os.path.join(basepath, 'config.yaml'), 'r') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    config = map_wand_config(config)
    config["data_params"]["train_batch_size"] = 2
    config["data_params"]["val_batch_size"] = 2

    data_params = config["data_params"]
    stage1_params = config["stage1_params"]
    model_params = config["model_params"]
    # Configs may name the grid resolution either "grid_size" or "width".
    grid_size = data_params.get("grid_size") or data_params.get("width")
    max_text_token_length = data_params.get("max_text_token_length") or 50

    tokenizer = VQTokenizer(vsq_model,
                            grid_size,
                            stage1_params["num_codes_per_shape"],
                            model_params["text_encoder_str"],
                            lseg=stage1_params["lseg"],
                            device=device,
                            max_text_token_length=max_text_token_length)

    model = VQ_SVG_Stage2(tokenizer, **model_params, device=device)

    # Text-only tokenizer for the data module; it reuses the full tokenizer's codebook size.
    text_only_tokenizer = VQTokenizer(vsq_model,
                                      grid_size,
                                      stage1_params["num_codes_per_shape"],
                                      model_params["text_encoder_str"],
                                      use_text_encoder_only=True,
                                      lseg=stage1_params["lseg"],
                                      codebook_size=tokenizer.codebook_size,
                                      max_text_token_length=max_text_token_length)
    dm = VQDataModule(tokenizer=text_only_tokenizer,
                      **data_params,
                      context_length=model_params['max_seq_len'],
                      train=False)
    dm.setup()

    ckpt_dir = os.path.join(basepath, "checkpoints")
    # Skip "last*" checkpoints. Match on the file name only, so that a "last" anywhere
    # in basepath does not exclude every checkpoint.
    all_ckpts = [ckpt for ckpt in glob(os.path.join(ckpt_dir, "*.ckpt"))
                 if "last" not in os.path.basename(ckpt)]
    if not all_ckpts:
        raise FileNotFoundError(f"no non-'last' *.ckpt files found in {ckpt_dir}")
    # sort by modification time; the last entry is the newest checkpoint
    latest_ckpt_path = sorted(all_ckpts, key=os.path.getmtime)[-1]
    print(f"loading from {latest_ckpt_path}")
    state_dict = torch.load(latest_ckpt_path, map_location=device)["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Strict loading failed on key names: retry with a leading "model." prefix removed.
        model.load_state_dict({k.replace("model.", "", 1) if k.startswith("model.") else k: v for k, v in state_dict.items()})
    print("loaded weights successfully")
    model = model.eval()
    return model, dm, config

def raster_svg(pixels: np.ndarray):
    """Decode one IconShop token sequence into deepsvg-style command tensors.

    ``pixels`` is a sequence of rows (assumed shape ``(N, 2)`` -- TODO confirm).
    After subtracting the token offset of 6, each row is examined:
    - A row whose first entry is -3 opens a Move command; the next two rows are
      its start and end.
    - A row whose first entry is -2 opens a Line command; the next row is its end.
    - A row whose first entry is -1 opens a Curve command; the next three rows are
      control1, control2 and end.

    Each command becomes a 9-float row laid out as
    ``[cmd, start(1:3), control1(3:5), control2(5:7), end(7:9)]`` with
    cmd 0 = Move, 1 = Line, 2 = Curve. The start slots are left at zero.
    A Move whose start and end rows are equal closes the current path (if it
    already holds commands) and begins a new one.

    Args:
        pixels: token array for a single sample; it is copied, not modified.

    Returns:
        ``[paths]``, where ``paths`` is a list of tensors of shape
        ``(num_commands, 9)``, one per path. It is an empty list if nothing decoded.

    Raises:
        AssertionError: if decoding fails (e.g. a command is cut off at the end
            of the sequence); the original exception is chained as ``__cause__``.
    """
    pixels = pixels.copy()  # never shift the caller's array in place
    try:
        pixels -= 6  # 3 END_TOKEN + 1 SVG_END + 2 CAUSAL_TOKEN

        svg_tensors = []
        path_tensor = []
        for i, pix in enumerate(pixels):
            # COMMAND = 0
            # START_POS = [1, 3)
            # CONTROL1 = [3, 5)
            # CONTROL2 = [5, 7)
            # END_POS = [7, 9)
            if pix[0] == -3:  # Move
                cmd_tensor = np.zeros(9)
                cmd_tensor[0] = 0
                cmd_tensor[7:9] = pixels[i+2]
                start_pos = pixels[i+1]
                end_pos = pixels[i+2]
                # A Move onto its own start marks a path boundary: flush the current path.
                if np.all(start_pos == end_pos) and path_tensor:
                    svg_tensors.append(torch.tensor(path_tensor))
                    path_tensor = []
                path_tensor.append(cmd_tensor.tolist())
            elif pix[0] == -2:  # Line
                cmd_tensor = np.zeros(9)
                cmd_tensor[0] = 1
                cmd_tensor[7:9] = pixels[i+1]
                path_tensor.append(cmd_tensor.tolist())
            elif pix[0] == -1:  # Curve
                cmd_tensor = np.zeros(9)
                cmd_tensor[0] = 2
                cmd_tensor[3:5] = pixels[i+1]
                cmd_tensor[5:7] = pixels[i+2]
                cmd_tensor[7:9] = pixels[i+3]
                path_tensor.append(cmd_tensor.tolist())
        append_t = torch.tensor(path_tensor)
        if append_t.size(0) > 0:
            svg_tensors.append(append_t)
        return [svg_tensors]
    except Exception as error:
        print(error, pixels)
        # Explicit raise instead of `assert False`: it still fires under `python -O`,
        # and `from error` keeps the root cause in the traceback.
        raise AssertionError("error in raster_svg") from error

@torch.no_grad()
def iconshop_gen_to_svg(sample, bbox=200):
    """Convert raw IconShop samples into deepsvg ``SVG`` objects.

    Args:
        sample: iterable of token arrays, one per generated icon; each is decoded
            with ``raster_svg``.
        bbox: side length of the square viewbox used for every SVG.

    Returns:
        list of ``SVG``, one per sample, with all paths filled. A sample that
        decodes to no paths yields an empty ``SVG`` instead of raising ``IndexError``.
    """
    gen_data = []
    for sample_pixel in sample:
        gen_data += raster_svg(sample_pixel)

    svgs = []
    for index, data in enumerate(gen_data):
        print("decoding svg", index)
        path_groups = []
        for d in data:
            path = SVGTensor.from_data(d)
            path = SVG.from_tensor(path.data, viewbox=Bbox(bbox))
            path.fill_(True)
            path_groups.extend(path.svg_path_groups)
        # Merge every decoded path into one SVG (empty when the sample had no paths).
        svgs.append(SVG(path_groups, viewbox=Bbox(bbox)))

    return svgs

@torch.no_grad()
def sample_iconshop_from_text(text:str, model, tokenizer, n_samples=1, max_length: int = 50) -> list[SVG]:
    """Generate ``n_samples`` IconShop icons conditioned on one text prompt.

    Args:
        text: prompt string.
        model: IconShop decoder exposing ``sample(n_samples=..., text=...)``.
        tokenizer: HuggingFace tokenizer matching the decoder's text vocabulary.
        n_samples: number of icons to generate for the prompt.
        max_length: length the prompt is padded/truncated to. It should equal the
            decoder's ``text_len`` (50 in this notebook's IconShop ``cfg``).

    Returns:
        list of deepsvg ``SVG`` objects produced by ``iconshop_gen_to_svg``.

    Note: tensors are moved to the notebook-level ``device``.
    """
    encoded_dict = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=False,  # for RoBERTa
    )
    # (1, max_length) ids for the single prompt, repeated once per requested sample.
    text_tokens = encoded_dict["input_ids"].reshape(1, -1).to(device)
    sample = model.sample(n_samples=n_samples, text=text_tokens.repeat(n_samples, 1))
    return iconshop_gen_to_svg(sample)

def pretty_print_trainable_params(model, only_trainable:bool=False):
    """Print parameter counts of ``model`` in millions (two decimals).

    Always prints the trainable count (parameters with ``requires_grad``). It also
    prints the total count unless ``only_trainable`` is True.
    """
    params = list(model.parameters())
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    print(f"Num trainable params: {n_trainable / 1e6:.2f}M")
    if only_trainable:
        return
    n_total = sum(p.numel() for p in params)
    print(f"Num params: {n_total / 1e6:.2f}M")

def map_wand_config(config):
    """Flatten a (presumably wandb-exported) config into a plain ``{key: value}`` dict.

    - Keys containing "wandb" are dropped.
    - Dict values that carry a "value" key are replaced by that value.
    - Every other entry, including ints, strings and None, is kept unchanged.

    Note: this definition shadows ``utils.map_wand_config`` imported in the first cell.
    """
    new_config = {}
    for k, v in config.items():
        if "wandb" in k:
            continue
        # Check the type first: `"value" in v` raises TypeError for ints, floats and None.
        if isinstance(v, dict) and "value" in v:
            new_config[k] = v["value"]
        else:
            new_config[k] = v
    return new_config

def save_im2vec_points_to_svg(model:VectorVAEnLayers,
                            all_points:List,
                            imsize,
                            save_base_dir,
                            filename):
    """Write Im2Vec control points to disk as closed, black-stroked, unfilled SVG paths.

    Args:
        model: Im2Vec model; only ``model.curves`` (curves per shape) is read.
        all_points: one point tensor per shape. A leading singleton dim is squeezed
            when the tensor has more than 2 dims. Points are multiplied by ``imsize``
            (presumably normalized to [0, 1] -- TODO confirm).
        imsize: width and height of the SVG canvas.
        save_base_dir: output directory; created if it does not exist.
        filename: file name written inside ``save_base_dir``.
    """
    # Opaque black stroke as float RGBA; shapes are not filled.
    stroke_color = torch.tensor([0.0, 0.0, 0.0, 1.0])

    shapes = []
    shape_groups = []
    for shape_points in all_points:
        points = shape_points.cpu()
        if points.ndim > 2:
            points = points.squeeze(0)
        points = points * imsize
        # 2 control points per segment for every one of the model's curves.
        num_ctrl_pts = torch.full((model.curves,), 2, dtype=torch.int32)

        path = pydiffvg.Path(
            num_control_points=num_ctrl_pts, points=points,
            is_closed=True)
        shapes.append(path)
        shape_groups.append(pydiffvg.ShapeGroup(
            shape_ids=torch.tensor([len(shapes) - 1]),
            fill_color=None,
            stroke_color=stroke_color))

    os.makedirs(save_base_dir, exist_ok=True)
    pydiffvg.save_svg(os.path.join(save_base_dir, filename),
                      imsize, imsize, shapes, shape_groups)

# Inference IconShop

In [None]:
NUM_SAMPLE = 1  # not referenced by the cells below
BS = 4  # not referenced by the cells below
BBOX = 200  # viewbox size; equals the default `bbox` of iconshop_gen_to_svg
cfg = {
    'pix_len': 512,
    'text_len': 50,

    'tokenizer_name': 'google/bert_uncased_L-12_H-512_A-8',
    'word_emb_path': 'iconshop_checkpoints/word_embedding_512.pt',
    'pos_emb_path': None,
}

# Test-split prompts used to benchmark the IconShop baseline.
df = pd.read_csv(".data/stage2_split.csv")
df = df[df["split"] == "test"]

tokenizer = AutoTokenizer.from_pretrained(cfg['tokenizer_name'])

sketch_decoder = SketchDecoder(
    config={
        'hidden_dim': 1024,
        'embed_dim': 512,
        'num_layers': 16,
        'num_heads': 8,
        'dropout_rate': 0.1
    },
    pix_len=cfg['pix_len'],
    text_len=cfg['text_len'],
    num_text_token=tokenizer.vocab_size,
    word_emb_path=cfg['word_emb_path'],
    pos_emb_path=cfg['pos_emb_path'],
)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
iconshop_ckpt_path = os.path.join("iconshop_checkpoints", "epoch_100", 'pytorch_model.bin')
sketch_decoder.load_state_dict(torch.load(iconshop_ckpt_path, map_location=device))
sketch_decoder = sketch_decoder.to(device).eval()

Single text prompt

In [None]:
# Generate icon(s) for one hand-written prompt and render the first result.
text = "heart heart-shape hearts like love"
n_samples = 1
all_svgs = sample_iconshop_from_text(text, sketch_decoder, tokenizer, n_samples=n_samples)

all_svgs[0].draw()

Prompts from the dataset's test split

In [None]:
# Draw 25 test-split descriptions. Sampling is with replacement, like 25 independent
# single-row draws, but seeded so the prompt set is reproducible. One icon per prompt.
NUM_DATASET_PROMPTS = 25
prompt_rows = df.sample(n=NUM_DATASET_PROMPTS, replace=True, random_state=0)

all_svgs = []
all_texts = []
for text in prompt_rows["description"]:
    svgs = sample_iconshop_from_text(text, sketch_decoder, tokenizer, n_samples=1)
    all_svgs.extend(svgs)
    all_texts.extend([text] * len(svgs))

In [None]:
# all_texts is aligned 1:1 with all_svgs, so pair them directly.
for prompt, svg in zip(all_texts, all_svgs):
    print(prompt)
    svg.draw()

# Inference $VSQ_l$

In [None]:
# Stage-1 VSQ_l run; `model` and `ds` are used by the reconstruction cell below.
base_path = "results/VSQ_l"
model, ds, config = load_model_from_basepath(basepath=base_path, device=device)

In [None]:
from utils import drawing_to_tensor

# Reconstruct dataset sample `idx` once per quantization grid size.
grid_sizes = [56, 72, 100, 128, 200, 256]
idx = 3
drawings = []

for grid_size in grid_sizes:
    img, drawing = get_side_by_side_reconstruction(
        model,
        ds,
        idx=idx,
        device=device,
        dataset_name="glyphazzn",
        override_global_stroke_width=0.4,
        return_drawing=True,
        quantize_grid_size=grid_size,
    )
    drawings.append(drawing)

# Smallest grid (first) next to largest grid (last).
plt.imshow(
    make_grid([drawing_to_tensor(drawings[0]), drawing_to_tensor(drawings[-1])]).permute(1, 2, 0)
)
plt.axis("off")
plt.show()

In [None]:
# Side-by-side image from the last loop iteration above (grid size 256); CHW -> HWC.
hwc_image = img.permute(1, 2, 0)
plt.imshow(hwc_image)
plt.axis("off")
plt.show()

# Inference $TM_l$

In [None]:
# dataset returns: text_tokens, attention_mask, vq_tokens, vq_targets, torch.ones(1).to(text_tokens.device)*self.pad_token
stage2_base_path = "results/TM_l"

# load config to extract stage1 params (which VSQ checkpoint this stage-2 run used)
with open(os.path.join(stage2_base_path, 'config.yaml'), 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
config = map_wand_config(config)

# load VSQ: keep everything before "checkpoints" in the stored checkpoint path,
# i.e. the run directory that load_model_from_basepath expects
vsq_base_path = config["stage1_params"]["checkpoint_path"].split("checkpoints")[0]
vsq_model = load_model_from_basepath(vsq_base_path, device=device)[0]

stage2_model, stage2_dm, stage2_config = load_stage2_model_from_basepath(vsq_model, stage2_base_path, device=device)
stage2_val_dl = stage2_dm.val_dataloader()

Generate from a free-text prompt

In [None]:
# temperature=0.0 (presumably near-greedy decoding), with post-processing and
# global position fixing enabled.
stage2_model._generate_from_text(
    "camera picture",
    temperature=0.0,
    return_drawing=True,
    post_process=True,
    global_position_fixing=True,
    max_dist=2.0,
    v2=False,
)

Generate from validation-set prompts

In [None]:
# Generate SVGs for randomly chosen validation prompts; both the prompt choice and
# the sampling are seeded so the cell is reproducible.
NUM_GENERATIONS = 20
prompt_rng = np.random.default_rng(0)
torch.manual_seed(0)

generation_drawings = []
prompts = []
generations = []
stage2_ds = stage2_val_dl.dataset

for i in range(NUM_GENERATIONS):
    random_idx = int(prompt_rng.integers(0, len(stage2_ds)))
    # Only the text conditioning is used; the ground-truth VQ tokens/targets/pad are ignored.
    text_tokens, attention_mask = stage2_ds[random_idx][:2]

    text = stage2_model.tokenizer.decode_text(text_tokens)
    prompts.append(text)
    print(f"Doing '{text}' - {i+1}/{NUM_GENERATIONS}")

    text_tokens = text_tokens.unsqueeze(0).to(device)
    attention_mask = attention_mask.unsqueeze(0).to(device)
    # Generation starts from a single token with id 1 (presumably the start token), shape (1, 1).
    vq_tokens = torch.tensor([[1]], device=device, dtype=torch.int64)

    generation, reason = stage2_model.generate(text_tokens, attention_mask, vq_tokens, temperature=0.1, sampling_method="top_p", sampling_kwargs={"thres": 0.5})
    generations.append(generation)
    gen_drawing = stage2_model.tokenizer._tokens_to_svg_drawing(generation, post_process=True, max_dist_frac=0.01)
    generation_drawings.append(gen_drawing)

print(prompts[0])
generation_drawings[0]

# Inference Im2Vec

Setup and model loading

In [None]:

# Im2Vec run layout: <base_path>/<class_name>/{config.yaml, checkpoints/...}
im2vecsweep_base_config = {
    "base_path": "results/Im2Vec",
    "im2vec_model_path": "checkpoints/last-v2.ckpt",
    "im2vec_config_path": "config.yaml",
    "out_base_dir": "results/Im2Vec",
    "dataset": "icons",
}

class_name = "figr8_full"

# NOTE: selected_config is the same dict object, so both names receive "class_name".
selected_config = im2vecsweep_base_config
im2vecsweep_base_config["class_name"] = class_name

base_path = os.path.join(selected_config["base_path"], class_name)
im2vec_model_path = os.path.join(base_path, selected_config["im2vec_model_path"])
im2vec_config_path = os.path.join(base_path, selected_config["im2vec_config_path"])
dataset = selected_config["dataset"]
out_base_dir = os.path.join(selected_config["out_base_dir"], class_name)

with open(im2vec_config_path, "r") as f:
    try:
        im2vec_config = yaml.safe_load(f)
    except yaml.YAMLError as exc:
        # Fail loudly here instead of continuing with im2vec_config undefined.
        raise RuntimeError(f"could not parse Im2Vec config {im2vec_config_path}") from exc

im2vec_config = map_wand_config(im2vec_config)

# Inference overrides: 128 px images and a 256-d latent.
im2vec_config["model_params"]["imsize"] = 128
im2vec_config["model_params"]["latent_dim"] = 256
im2vec_config["data_params"]["img_size"] = 128


# ds = GenericRasterizedSVGDataset(**im2vec_config["data_params"], train=None)
im2vec = VectorVAEnLayers(**im2vec_config["model_params"])
state_dict = torch.load(im2vec_model_path, map_location=device)["state_dict"]
try:
    im2vec.load_state_dict(state_dict)
except RuntimeError:
    # Strict loading failed on key names: retry with only a leading "model." prefix removed.
    im2vec.load_state_dict({
        (k[len("model."):] if k.startswith("model.") else k): v
        for k, v in state_dict.items()
    })

im2vec = im2vec.eval().to(device)
# Moved explicitly; presumably a plain tensor attribute that Module.to() skips -- TODO confirm.
im2vec.base_control_features = im2vec.base_control_features.to(device)

Inference (unconditional: Im2Vec is a VAE, so there is no text conditioning)

In [None]:
# Draw unconditional samples (presumably 10) and export one of them as an SVG.
samples_points = im2vec.multishape_sample(10, return_points=True, device=device)
idx = 0
svg_dir = "results/Im2Vec"
svg_filename = f"im2vec_sample_{idx}.svg"
# The SVG canvas is 72 px, although the model was configured with imsize=128 above.
save_im2vec_points_to_svg(im2vec, samples_points[idx], 72, svg_dir, svg_filename)

# Read back the file just written: it lives in svg_dir, not the working directory.
# adjust these settings for altering visual appearance
img = svg_to_tensor(os.path.join(svg_dir, svg_filename), new_stroke_color="black", new_fill_color="none", new_stroke_width=0.7, output_width=480)
plt.imshow(img.permute(1,2,0))
plt.axis("off")
plt.show()