# Set up

In [None]:
# !sudo apt install libcairo2-dev pkg-config python3-dev # uncomment this if you're on linux
!pip install -r ./requirements.txt

## Loading Dataset

### Loading the DeepSVG Dataset

Use this cell if ./pretrained/hierarchical_ordered.pth.tar doesn't exist. Downloaded files should be moved to ./pretrained.

In [None]:
# Make the download script executable for the current user, then run it.
!chmod u+x ./pretrained/download.sh
!./pretrained/download.sh

Use this cell if you need to download the dataset. Downloaded files should be moved to ./dataset.

In [None]:
# Make the download script executable for the current user, then run it.
!chmod u+x ./dataset/download.sh
!./dataset/download.sh

### VAE

In [None]:
from configs.deepsvg.hierarchical_ordered import Config
from deepsvg import utils
import torch

# Location of the pretrained checkpoint fetched by ./pretrained/download.sh.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the model described by the config, move it to `device`, then load
# the checkpoint into it.
cfg = Config()
vae_model = cfg.make_model()
vae_model = vae_model.to(device)
utils.load_model(pretrained_path, vae_model)

# Evaluation mode; as the last expression of the cell this also displays the model.
vae_model.eval()

In [4]:
import torch
from deepsvg.utils.utils import batchify
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.svg import SVG
from deepsvg.svglib.geom import Bbox

def encode(data, model):
    """Encode one dataset entry into a latent with `model`.

    Args:
        data: mapping that holds every key listed in ``cfg.model_args``.
        model: model called with the batched inputs and ``encode_mode=True``.

    Returns:
        The model output with its two leading dimensions squeezed
        (each only removed if it has size 1).
    """
    batched_inputs = batchify((data[name] for name in cfg.model_args), device)

    # Encoding is inference only, so no autograd graph is recorded.
    with torch.no_grad():
        latent = model(*batched_inputs, encode_mode=True)

    return latent.squeeze(dim=0).squeeze(dim=0)

def decode(z, model, do_display=True, return_svg=False, return_png=False):
    """Greedy-decode latent `z` with `model` and render the first result.

    Args:
        z: latent passed to ``model.greedy_sample``.
        model: object exposing ``greedy_sample(z=...)``.
        do_display: forwarded to ``SVG.draw``.
        return_svg: when true, return the SVG object without drawing it.
        return_png: forwarded to ``SVG.draw``.

    Returns:
        The SVG object if `return_svg` is true, otherwise whatever
        ``SVG.draw`` returns.
    """
    commands_y, args_y = model.greedy_sample(z=z)

    # Only the first element of the sampled batch is turned into an SVG.
    first_commands = commands_y[0].cpu()
    first_args = args_y[0].cpu()
    tensor_pred = SVGTensor.from_cmd_args(first_commands, first_args)

    svg = SVG.from_tensor(tensor_pred.data, viewbox=Bbox(256), allow_empty=True)
    svg = svg.normalize().split_paths().set_color("random")

    return svg if return_svg else svg.draw(do_display=do_display, return_png=return_png)

## Setting up the Dataloader

### Dataloader

Create the training and validation dataloaders from the VAE-encoded dataset.

In [5]:
from deepsvg.svgtensor_dataset import load_dataset
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Load the dataset described by `cfg`.
# NOTE(review): presumably entries look like {'commands': [...], 'args': [...]} — confirm against load_dataset.
dataset = load_dataset(cfg)


def dataloader_with_transformed_dataset(batch_n: int, length: int = None):
    """Encode dataset entries with the VAE and build train/validation loaders.

    The first `length` entries of the global `dataset` are encoded with
    ``encode(..., vae_model)`` and paired with their label. The encoded
    entries are then split 80/20 into training and validation subsets using
    a fixed seed, so the split is the same on every call.

    Args:
        batch_n: batch size for both loaders.
        length: number of entries to use. A falsy value (None or 0) uses the
            whole dataset; values larger than the dataset are clamped to its
            size instead of indexing past the end.

    Returns:
        ``(train_loader, validation_loader)``. Both sample randomly from
        their subset and drop the last incomplete batch.
    """
    available = len(dataset)
    data_len = min(length, available) if length else available

    encoded_dataset_with_labels = []
    for i in range(data_len):
        xy = dataset.get(i, model_args=['commands', 'args', 'label'])
        label = xy.pop('label')
        encoded_dataset_with_labels.append([encode(xy, vae_model), label])

    dataset_size = len(encoded_dataset_with_labels)
    validation_split = .2
    random_seed = 42

    # Shuffle with a private RandomState: this yields the same permutation as
    # seeding the global NumPy RNG with `random_seed`, but does not reset the
    # global RNG state for the rest of the notebook.
    indices = list(range(dataset_size))
    np.random.RandomState(random_seed).shuffle(indices)

    split = int(np.floor(validation_split * dataset_size))
    train_indices, val_indices = indices[split:], indices[:split]

    # Both loaders share the same underlying list; the samplers restrict
    # each one to its own subset of indices.
    train_loader = DataLoader(
        encoded_dataset_with_labels,
        batch_size=batch_n,
        sampler=SubsetRandomSampler(train_indices),
        drop_last=True,
    )
    validation_loader = DataLoader(
        encoded_dataset_with_labels,
        batch_size=batch_n,
        sampler=SubsetRandomSampler(val_indices),
        drop_last=True,
    )

    return train_loader, validation_loader

In [21]:
def num_classes(dataloader):
    """Count the distinct labels produced by `dataloader`.

    Args:
        dataloader: iterable yielding ``(x, y)`` pairs where ``y`` is a
            tensor of labels.

    Returns:
        Number of distinct label values seen over one full pass; 0 when the
        iterable is empty.
    """
    all_classes = set()

    for _, y in dataloader:
        # tolist() works for tensors on any device and for tensors that
        # track gradients, whereas numpy() raises for both.
        all_classes.update(y.reshape(-1).tolist())

    return len(all_classes)

# Model

In [None]:
from diffusion import create_diffusion
from svgfusion import DiT

def create_model(predict_xstart=True, dropout=0.1, n_classes=56, depth=28, learn_sigma=True, num_heads=16):
    """Build a DiT model plus its diffusion object, ready for training.

    Args:
        predict_xstart: forwarded to ``create_diffusion``.
        dropout: passed to DiT as ``class_dropout_prob``.
        n_classes: passed to DiT as ``num_classes``.
        depth, learn_sigma, num_heads: forwarded to DiT unchanged.

    Returns:
        ``(model, diffusion)`` with the model on CUDA when available
        (otherwise CPU) and switched to training mode.
    """
    dit_options = dict(
        class_dropout_prob=dropout,
        num_classes=n_classes,
        depth=depth,
        learn_sigma=learn_sigma,
        num_heads=num_heads,
    )
    model = DiT(**dit_options)

    target = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(target)

    # Original author's note: default is 1000 steps with a linear noise schedule.
    diffusion = create_diffusion(timestep_respacing="", predict_xstart=predict_xstart)

    # Original author's note: training mode enables the embedding dropout
    # used for classifier-free guidance.
    model.train()

    return model, diffusion

# Training

### Train Utils

In [16]:
from pathlib import Path
from datetime import datetime
from deepsvg.svglib.utils import to_gif
import IPython.display as ipd
import cairosvg
import io
import os
from PIL import Image

def draw(svg_obj, fill=False, file_path=None, do_display=True, return_png=False,
         with_points=False, with_handles=False, with_bboxes=False, with_markers=False, color_firstlast=False,
         with_moves=True, width=600, height=600):
    """Optionally save, display and/or rasterise an SVG object.

    Args:
        svg_obj: object providing ``to_str``, ``save_svg`` and ``save_png``.
        fill, with_points, with_handles, with_bboxes, with_markers,
            color_firstlast, with_moves: forwarded to ``svg_obj.to_str``.
        file_path: when given, the SVG is saved there first; only ``.svg``
            and ``.png`` extensions are accepted.
        do_display: show the SVG inline through IPython.
        return_png: return a PIL image of the drawing.
        width, height: PNG size; only applied when `file_path` is None.

    Returns:
        A PIL image when `return_png` is true, otherwise None.

    Raises:
        ValueError: if `file_path` has an extension other than .svg/.png.
    """
    extension = None
    if file_path is not None:
        extension = os.path.splitext(file_path)[1]
        if extension == ".svg":
            svg_obj.save_svg(file_path)
        elif extension == ".png":
            svg_obj.save_png(file_path)
        else:
            raise ValueError(f"Unsupported file_path extension {extension}")

    render_options = dict(
        fill=fill,
        with_points=with_points,
        with_handles=with_handles,
        with_bboxes=with_bboxes,
        with_markers=with_markers,
        color_firstlast=color_firstlast,
        with_moves=with_moves,
    )
    svg_str = svg_obj.to_str(**render_options)

    if do_display:
        ipd.display(ipd.SVG(svg_str))

    if not return_png:
        return None

    # No file on disk: rasterise the in-memory SVG string at the requested size.
    if file_path is None:
        png_bytes = cairosvg.svg2png(bytestring=svg_str, output_width=width, output_height=height)
        return Image.open(io.BytesIO(png_bytes))

    # A saved .svg is rasterised from disk; a saved .png is opened directly.
    if extension == ".svg":
        return Image.open(io.BytesIO(cairosvg.svg2png(url=file_path)))
    return Image.open(file_path)

def log_training(epoch_number: int, loss: float, timestep: int = None):
    """Append one timestamped loss entry to ``artifacts/log.txt``.

    The ``./artifacts/`` directory is created if it does not exist yet.

    Args:
        epoch_number: epoch index written to the log line.
        loss: loss value written to the log line.
        timestep: optional timestep; when given (including 0) it is appended
            to the line.
    """
    Path("./artifacts/").mkdir(parents=True, exist_ok=True)

    current_time = datetime.now().strftime("%H:%M:%S")

    # Compare against None so that timestep 0 is still reported.
    if timestep is not None:
        entry = f"{current_time} Epoch {epoch_number}: {loss} for timestep {timestep} \n\n"
    else:
        entry = f"{current_time} Epoch {epoch_number}: {loss} \n\n"

    # The context manager closes the file even if the write fails.
    with open("artifacts/log.txt", "a") as f:
        f.write(entry)


def sample_from_diffusion(diffusion, model, class_labels, x_t=None, normalization_factor=0.7, display_gif=False, cfg_scale=4, null_class=None):
    """Sample latents with classifier-free guidance and decode them with the VAE.

    Args:
        diffusion: object providing ``p_sample_loop`` and
            ``p_sample_loop_progressive``.
        model: model providing ``forward_with_cfg``.
        class_labels: list of class indices, one per sample.
        x_t: optional starting noise; when None, noise of shape
            ``(len(class_labels), 1, 256)`` is drawn.
        normalization_factor: the samples are divided by their own std and
            multiplied by this factor before decoding.
        display_gif: when true, every intermediate step is decoded and
            rendered, and every second frame is passed to ``to_gif``.
        cfg_scale: guidance scale handed to ``model.forward_with_cfg``.
        null_class: label used for the unconditional half of the batch.
            When None it is read from ``model.y_embedder.num_classes`` if the
            model exposes that attribute, and otherwise falls back to
            ``len(class_labels)`` (the previous behaviour).

    Returns:
        With ``display_gif=True``: the last dict yielded by
        ``p_sample_loop_progressive`` (its ``"sample"`` entry still contains
        both the conditional and the null half).
        Otherwise: the conditional half of the final samples as a tensor.
    """
    n = len(class_labels)

    # `x_t is None` instead of `not x_t`: truth-testing a tensor with more
    # than one element raises.
    z = torch.randn(n, 1, 256, device=device) if x_t is None else x_t
    y = torch.tensor(class_labels, device=device)

    # NOTE(review): assumes the model follows the reference DiT layout, where
    # the unconditional embedding sits at index `num_classes` — confirm
    # against svgfusion.DiT.
    if null_class is None:
        null_class = getattr(getattr(model, "y_embedder", None), "num_classes", None)
    if null_class is None:
        null_class = n

    # Classifier-free guidance: duplicate the noise and pair the copy with
    # the null label.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([null_class] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

    if display_gif:
        img_list = []
        final_sample = None
        for sample in diffusion.p_sample_loop_progressive(
            model.forward_with_cfg, z.shape, z, clip_denoised=False,
            model_kwargs=model_kwargs, progress=True, device=device
        ):
            samples, _ = sample["sample"].chunk(2, dim=0)  # Remove null class samples
            sample_svg = decode((samples.unsqueeze(dim=0) / samples.std()) * normalization_factor,
                                vae_model, return_svg=True, do_display=False)
            sample_png = draw(sample_svg, width=1200, height=1200, do_display=False, return_png=True)
            img_list.append(sample_png)
            final_sample = sample

        # Only every second frame goes into the GIF.
        to_gif(img_list[::2])
        return final_sample

    samples = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, progress=True, device=device
    )

    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
    # Called with decode's default do_display=True, so this shows the result.
    decode((samples.unsqueeze(dim=0) / samples.std()) * normalization_factor, vae_model)

    return samples

### Training Loop

In [None]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

def train(epochs=100, learning_rate=0.0001, batch_size=10, n_samples=1000, use_scheduler=True, dropout=0.1, predict_xstart=True, depth=28, learn_sigma=True, num_heads=16, latent_std=0.7128):
    """Train a DiT diffusion model on VAE-encoded dataset entries.

    Args:
        epochs: number of passes over the training loader.
        learning_rate: AdamW learning rate.
        batch_size: batch size of the data loaders.
        n_samples: number of dataset entries to encode and train on.
        use_scheduler: step a ReduceLROnPlateau scheduler on the mean epoch loss.
        dropout, predict_xstart, depth, learn_sigma, num_heads: forwarded to
            ``create_model``.
        latent_std: latents are divided by this value before training
            (original author's note: mean of the latents' standard deviations).

    Returns:
        ``(model, diffusion, optimizer)`` so the trained objects can be used
        for sampling and checkpointing afterwards.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    train_dataloader, _ = dataloader_with_transformed_dataset(batch_n=batch_size, length=n_samples)

    n_batches = len(train_dataloader)
    if n_batches == 0:
        # Both loaders drop incomplete batches, so a small n_samples or a
        # large batch_size can leave nothing to train on.
        raise ValueError("training dataloader is empty; lower batch_size or raise n_samples")

    # Labels index into the model's class table, so it needs room for the
    # highest label even if some smaller labels are absent from this subset.
    highest_label = max(int(y.max()) for _, y in train_dataloader)
    n_classes = max(num_classes(train_dataloader), highest_label + 1)

    model, diffusion = create_model(dropout=dropout, predict_xstart=predict_xstart,
                                    n_classes=n_classes, depth=depth,
                                    learn_sigma=learn_sigma, num_heads=num_heads)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

    for epoch in range(epochs):
        running_loss = 0.0
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)

            # Force shape (batch, 1, features). A bare squeeze() would also
            # drop the batch axis when batch_size == 1.
            # NOTE(review): assumes each latent is a flat vector, possibly
            # wrapped in singleton dims — confirm.
            x = x.reshape(x.shape[0], 1, -1)
            x = x / latent_std

            # One random diffusion timestep per sample.
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)

            loss_dict = diffusion.training_losses(model, x, t, dict(y=y))
            loss = loss_dict["loss"].mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        epoch_loss = running_loss / n_batches
        if use_scheduler:
            scheduler.step(epoch_loss)
        print(optimizer.param_groups[0]['lr'])
        log_training(epoch, epoch_loss)

    return model, diffusion, optimizer

In [None]:
# Run training with the default hyper-parameters.
# NOTE(review): the return value is discarded here, yet later cells reference
# `model`, `diffusion` and `optimizer` at global scope — confirm they are
# bound before running those cells.
train()

### Sampling

In [None]:
# Switch the diffusion model to evaluation mode before sampling.
# NOTE(review): no earlier cell visibly assigns `model` at global scope — confirm it exists.
model.eval()  # important!

In [None]:
# Decode `tmp` (rescaled to a std of 0.6) into an SVG and render it as a 1200x1200 PNG.
# NOTE(review): `tmp` is only assigned by the sampling cell below, so this cell
# fails on a fresh top-to-bottom run — confirm the intended cell order.
tmp_svg = decode((tmp.unsqueeze(dim=0) / tmp.std()) * 0.6, vae_model, return_svg=True, do_display=False)
draw(tmp_svg, width=1200, height=1200, return_png=True)
# for i in range(samples.shape[0]):
#   decode((samples[i].unsqueeze(dim=0).unsqueeze(dim=0) / samples.std()) * 0.7, vae_model, return_png=False, do_display=True)

In [None]:
# Draw one sample for class label 4; with display_gif=False the call returns a tensor.
# NOTE(review): `magical_number` is only visible as a local inside train();
# confirm it is defined at global scope before running this cell.
tmp = sample_from_diffusion(diffusion=diffusion, model=model, class_labels=[4], normalization_factor=magical_number, display_gif=False)
# for i in range(10): sample_from_diffusion(diffusion=diffusion, model=model, class_labels=[0])

In [None]:
# `tmp` comes from sample_from_diffusion, which returns a dict holding both
# halves of the batch when display_gif=True, and the conditional half as a
# plain tensor otherwise. Handle both instead of assuming the dict form.
if isinstance(tmp, dict):
    samples, _ = tmp["sample"].chunk(2, dim=0)  # Remove null class samples
else:
    samples = tmp

# Rescale the samples to the target std, then decode, display and rasterise them.
# NOTE(review): `magical_number` must be defined at global scope for this to run.
test_svg = decode((samples.unsqueeze(dim=0) / samples.std()) * magical_number,
                  vae_model, return_png=True, do_display=True)

In [None]:
# Decode every sample at ten increasing scales: 0, 1/2, 2/3, ..., 9/10.
# Dedicated loop names avoid the inner loop shadowing the outer index, and
# avoid overwriting `tmp`, which holds the sampling result used by other cells.
samples_std = samples.std()
for step in range(10):
    scale = 1 - (1 / (step + 1))
    for sample_idx in range(samples.shape[0]):
        latent = samples[sample_idx].unsqueeze(dim=0).unsqueeze(dim=0)
        decode((latent / samples_std) * scale, vae_model, return_png=False, do_display=True)

# Saving/Loading the Model

In [None]:
from torch.optim import AdamW
from copy import deepcopy

def save_model(model, ema, optimizer, diffusion, total_epochs, predict_noise=False):
    """Write a training checkpoint to ``./models``.

    The checkpoint is a dict with the state dicts of `model`, `ema` and
    `optimizer` under "model", "ema" and "opt", plus the `diffusion` object
    itself under "diffusion".

    Args:
        model, ema, optimizer: objects exposing ``state_dict()``.
        diffusion: stored in the checkpoint as-is.
        total_epochs: used in the file name.
        predict_noise: selects the ``noise`` / ``x0`` part of the file name.
    """
    export_dir = './models'
    Path(export_dir).mkdir(parents=True, exist_ok=True)

    target_kind = 'noise' if predict_noise else 'x0'
    exported_model_path = f"{export_dir}/predict_{target_kind}_{total_epochs}.pt"

    checkpoint = dict(
        model=model.state_dict(),
        ema=ema.state_dict(),
        opt=optimizer.state_dict(),
        diffusion=diffusion,
    )
    torch.save(checkpoint, exported_model_path)


def load_model(model_path, num_classes, device, for_training=True, return_optimizer=False, **dit_kwargs):
    """Restore a checkpoint written by ``save_model``.

    Args:
        model_path: path to the checkpoint file.
        num_classes: passed to DiT as ``num_classes``.
        device: device the models are moved to and the checkpoint is mapped to.
        for_training: selects the return shape (see Returns).
        return_optimizer: unused; kept so existing calls keep working.
        **dit_kwargs: extra DiT constructor arguments (for example ``depth``
            or ``num_heads``). They must match the values used at training
            time, otherwise loading the state dict fails.

    Returns:
        ``(model, optimizer, diffusion, ema)`` when `for_training` is true,
        otherwise ``(model, diffusion, ema)`` with both models in eval mode.
    """
    model = DiT(num_classes=num_classes, **dit_kwargs).to(device)
    ema = deepcopy(model).to(device)

    # The checkpoint stores the pickled `diffusion` object next to the state
    # dicts, so the weights-only loader (the default in recent PyTorch
    # releases) cannot read it.
    # SECURITY: full unpickling can execute arbitrary code — only load
    # checkpoints from a trusted source.
    state = torch.load(model_path, map_location=device, weights_only=False)

    ema.load_state_dict(state['ema'])
    model.load_state_dict(state['model'])

    if not for_training:
        # Inference path: no optimizer is needed.
        model.eval()
        ema.eval()
        return model, state['diffusion'], ema

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0)
    optimizer.load_state_dict(state['opt'])
    return model, optimizer, state['diffusion'], ema

In [None]:
# Write the current training state to ./models.
# NOTE(review): `ema` and `total_epochs` are not assigned anywhere in the
# visible notebook — confirm they exist before running this cell.
save_model(model, ema, optimizer, diffusion, total_epochs)