# Setup

In [None]:
# !sudo apt install libcairo2-dev pkg-config python3-dev # uncomment this if you're on linux
!pip install -r ./requirements.txt

## Loading Dataset

### Downloading the DeepSVG pretrained model and dataset

Run this cell if `./pretrained/hierarchical_ordered.pth.tar` does not exist yet, then move the downloaded files into `./pretrained`.

In [None]:
# Make the bundled pretrained-model download script executable, then run it.
# Paths are relative to the notebook's working directory (assumed to be the repository root — TODO confirm).
!chmod u+x ./pretrained/download.sh
!./pretrained/download.sh

Run this cell if you need to download the dataset, then move the downloaded files into `./dataset`.

In [None]:
# Make the bundled dataset download script executable, then run it.
# Paths are relative to the notebook's working directory (assumed to be the repository root — TODO confirm).
!chmod u+x ./dataset/download.sh
!./dataset/download.sh

### Loading the pretrained DeepSVG VAE

In [None]:
from configs.hierarchical_ordered import Config
from deepsvg import utils
import torch

# Pretrained DeepSVG "hierarchical_ordered" checkpoint; run the download cell above if it is missing.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the VAE from its config, move it to the device, then load the pretrained weights into it.
cfg = Config()
vae_model = cfg.make_model().to(device)
utils.load_model(pretrained_path, vae_model)

# The VAE is only used for inference here (the training optimizer covers the DiT's parameters only).
# The trailing ';' stops Jupyter from dumping the full module repr returned by .eval().
vae_model.eval();

# Model

In [None]:
from diffusion import create_diffusion
from svgfusion import DiT

def create_model(predict_xstart=True, dropout=0.1, n_classes=56, depth=28, learn_sigma=True, num_heads=16):
    """Build a DiT denoiser together with its matching diffusion process.

    Args:
        predict_xstart: forwarded to ``create_diffusion``.
        dropout: class-label dropout probability, passed to ``DiT`` as ``class_dropout_prob``.
        n_classes: number of class labels, passed to ``DiT`` as ``num_classes``.
        depth: forwarded to ``DiT``.
        learn_sigma: forwarded to both ``DiT`` and ``create_diffusion``.
        num_heads: forwarded to ``DiT``.

    Returns:
        ``(model, diffusion)``: the DiT placed on CUDA when available (otherwise CPU)
        and switched to training mode, plus the diffusion object.
    """
    dit = DiT(
        class_dropout_prob=dropout,
        num_classes=n_classes,
        depth=depth,
        learn_sigma=learn_sigma,
        num_heads=num_heads,
    )
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    dit.to(target_device)

    # Empty respacing keeps the full schedule (per the original note: 1000 steps, linear noise schedule).
    gaussian_diffusion = create_diffusion(
        timestep_respacing="",
        predict_xstart=predict_xstart,
        learn_sigma=learn_sigma,
    )

    # Training mode matters: it enables the label dropout used for classifier-free guidance.
    dit.train()

    return dit, gaussian_diffusion

# Training

### Training Loop

In [None]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from dataset.dataset import num_classes, dataloader_with_transformed_dataset
from utils import log_training

def _squeeze_to_channel(x):
    """Drop every size-1 non-batch dim of ``x`` and insert a channel dim at position 1.

    Gives the same result as ``x.squeeze().unsqueeze(dim=1)`` whenever the batch holds
    more than one element. It never squeezes away the batch dim itself, so a
    single-sample batch (e.g. a short final batch) keeps its leading dimension of 1.
    """
    kept = [size for size in x.shape[1:] if size != 1]
    return x.reshape(x.shape[0], *kept).unsqueeze(dim=1)


def train(epochs=100, learning_rate=0.0001, batch_size=10, n_samples=1000, use_scheduler=True, dropout=0.1, predict_xstart=True, depth=28, learn_sigma=True, num_heads=16, magical_number=0.7128):
    """Train a class-conditional DiT diffusion model on the transformed dataset.

    Uses the notebook-level ``vae_model`` and ``cfg`` (from the VAE cell) to build the
    data loaders.

    Args:
        epochs: number of passes over the training loader.
        learning_rate: initial AdamW learning rate.
        batch_size: forwarded to the loader factory as ``batch_n``.
        n_samples: forwarded to the loader factory as ``length`` (``None`` = all samples,
            per the config cell).
        use_scheduler: when True, step ``ReduceLROnPlateau`` on the mean epoch loss.
        dropout, predict_xstart, depth, learn_sigma, num_heads: forwarded to ``create_model``.
        magical_number: every latent batch is divided by this before diffusion
            (described in the config cell as the mean of the latents' std's).

    Returns:
        ``(model, optimizer, diffusion, scheduler, n_classes)``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    # The validation split is returned by the loader factory but is not used yet.
    train_dataloader, _valid_dataloader = dataloader_with_transformed_dataset(vae_model, cfg, batch_n=batch_size, length=n_samples)
    n_classes = num_classes(train_dataloader)
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("train_dataloader is empty; check n_samples and batch_size")

    model, diffusion = create_model(dropout=dropout, predict_xstart=predict_xstart,
                                    n_classes=n_classes, depth=depth,
                                    learn_sigma=learn_sigma, num_heads=num_heads)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

    for epoch in range(epochs):
        epoch_loss = 0.0
        for x, y in train_dataloader:
            x = _squeeze_to_channel(x.to(device))
            # Rescale latents by the caller-supplied factor.
            x = x / magical_number
            y = y.to(device)

            # One uniformly sampled diffusion timestep per example.
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)

            loss_dict = diffusion.training_losses(model, x, t, dict(y=y))
            loss = loss_dict["loss"].mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        mean_loss = epoch_loss / n_batches
        if use_scheduler:
            scheduler.step(mean_loss)
        print(optimizer.param_groups[0]['lr'])
        log_training(epoch, mean_loss)

    return model, optimizer, diffusion, scheduler, n_classes

In [None]:
# Hyperparameters for this run; every key is a keyword argument of train().
config = dict(
    predict_xstart=True,
    learn_sigma=True,
    use_scheduler=True,
    num_heads=16,
    depth=28,
    dropout=0.1,
    epochs=100,
    learning_rate=0.0001,
    batch_size=100,
    n_samples=None,  # None = all of the samples
    magical_number=0.7128,  # mean of std's of latents
)

model, optimizer, diffusion, scheduler, n_classes = train(**config)

# Saving/Loading the Model

In [None]:
from torch.optim import AdamW
from pathlib import Path

def save_model(model, optimizer, diffusion, scheduler, n_classes, config, export_dir='./models'):
    """Save a training checkpoint and return the path it was written to.

    The checkpoint dict holds the model's and optimizer's ``state_dict``s, the
    ``diffusion`` and ``scheduler`` objects pickled as-is, ``n_classes`` and the run
    ``config``. The file is named ``predict_{x0|noise}_{epochs}.pt`` based on
    ``config['predict_xstart']`` and ``config['epochs']``.

    Args:
        model: module whose ``state_dict`` is stored.
        optimizer: optimizer whose ``state_dict`` is stored.
        diffusion: diffusion object; pickled whole.
        scheduler: learning-rate scheduler; pickled whole.
        n_classes: number of class labels the model was built with.
        config: run configuration; must contain ``'predict_xstart'`` and ``'epochs'``.
        export_dir: output directory, created if missing (default ``./models``).

    Returns:
        pathlib.Path of the written checkpoint.
    """
    export_path = Path(export_dir)
    export_path.mkdir(parents=True, exist_ok=True)

    # Everything needed to resume training is saved; revisit if checkpoints get too large.
    checkpoint = {
        "model": model.state_dict(),
        "opt": optimizer.state_dict(),
        "diffusion": diffusion,
        "scheduler": scheduler,
        "num_classes": n_classes,
        "config": config,
    }
    target = 'x0' if config['predict_xstart'] else 'noise'
    exported_model_path = export_path / f"predict_{target}_{config['epochs']}.pt"
    torch.save(checkpoint, exported_model_path)
    return exported_model_path


def load_model(model_path, device, for_training=True):
    """Restore a checkpoint written by ``save_model``.

    Rebuilds the DiT with the architecture settings stored in the checkpoint's
    ``config`` and loads its weights. When ``for_training`` is True, it also restores
    the optimizer and re-attaches the saved scheduler to it.

    Args:
        model_path: path to the ``.pt`` checkpoint.
        device: used as ``torch.load``'s ``map_location`` and as the model's target device.
        for_training: when False, return an eval-mode model without optimizer/scheduler.

    Returns:
        ``(model, diffusion, config)`` when ``for_training`` is False, otherwise
        ``(model, optimizer, diffusion, scheduler, config)``.
    """
    # The checkpoint pickles whole Python objects (diffusion, scheduler, config), which
    # torch.load's weights_only=True default (PyTorch >= 2.6) refuses to read.
    # Unpickling can execute arbitrary code: only load checkpoints you trust.
    state = torch.load(model_path, map_location=device, weights_only=False)
    config = state['config']

    # Rebuild the exact architecture that was trained. The fallbacks match train()'s
    # defaults, which is what was used when a key was omitted from the config.
    model = DiT(
        num_classes=state['num_classes'],
        depth=config.get('depth', 28),
        num_heads=config.get('num_heads', 16),
        learn_sigma=config.get('learn_sigma', True),
        class_dropout_prob=config.get('dropout', 0.1),
    ).to(device)
    model.load_state_dict(state['model'])

    if not for_training:
        model.eval()
        return model, state['diffusion'], config

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0)
    optimizer.load_state_dict(state['opt'])

    scheduler = state['scheduler']
    if scheduler is not None:
        # The unpickled scheduler still points at the optimizer copy pickled alongside it;
        # rebind it so scheduler.step() adjusts the optimizer returned here.
        scheduler.optimizer = optimizer
    return model, optimizer, state['diffusion'], scheduler, config

In [None]:
# Persist the trained run (weights, optimizer state, diffusion, scheduler, config).
save_model(
    model=model,
    optimizer=optimizer,
    diffusion=diffusion,
    scheduler=scheduler,
    n_classes=n_classes,
    config=config,
)