# Setup

In [None]:
# !sudo apt install libcairo2-dev pkg-config python3-dev # uncomment this if you're on linux
!pip install -r ./requirements.txt

## Loading Dataset

### Loading the DeepSVG Dataset

Run the cell below if `./pretrained/hierarchical_ordered.pth.tar` doesn't exist. Downloaded files should be moved to `./pretrained`.

In [None]:
# Give the current user execute permission on the download script, then run it.
!chmod u+x ./pretrained/download.sh
!./pretrained/download.sh

Run the cell below if you need to download the dataset. Downloaded files should be moved to `./dataset`.

In [None]:
# Give the current user execute permission on the download script, then run it.
!chmod u+x ./dataset/download.sh
!./dataset/download.sh

# Defining the Model Training

## Autoencoder

In [None]:
import torch
from configs.hierarchical_ordered import Config
from deepsvg import utils

# Checkpoint whose weights are loaded into the autoencoder below.
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"

# Use the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the model described by the config, move it to the chosen device,
# restore the checkpoint and switch it to evaluation mode.
cfg = Config()
vae_model = cfg.make_model()
vae_model = vae_model.to(device)
utils.load_model(pretrained_path, vae_model)
vae_model.eval()

## Model

In [5]:
from diffusion import create_diffusion
from svgfusion import DiT

def create_model(predict_xstart=True, dropout=0.1, n_classes=56, depth=28, learn_sigma=True, num_heads=16):
    """Build a DiT network together with its diffusion process, ready for training.

    Args:
        predict_xstart: Forwarded to ``create_diffusion``.
        dropout: Passed to ``DiT`` as ``class_dropout_prob``.
        n_classes: Passed to ``DiT`` as ``num_classes``.
        depth: Passed to ``DiT`` as ``depth``.
        learn_sigma: Forwarded to both ``DiT`` and ``create_diffusion``.
        num_heads: Passed to ``DiT`` as ``num_heads``.

    Returns:
        A ``(model, diffusion)`` tuple. The model has been moved to CUDA when
        available (CPU otherwise) and put in training mode.
    """
    target_device = "cuda" if torch.cuda.is_available() else "cpu"

    dit_kwargs = dict(
        class_dropout_prob=dropout,
        num_classes=n_classes,
        depth=depth,
        learn_sigma=learn_sigma,
        num_heads=num_heads,
    )
    model = DiT(**dit_kwargs)
    model.to(target_device)

    # important! Training mode enables embedding dropout for classifier-free guidance.
    model.train()

    # An empty respacing string keeps the default: 1000 steps, linear noise schedule.
    diffusion = create_diffusion(
        timestep_respacing="",
        predict_xstart=predict_xstart,
        learn_sigma=learn_sigma,
    )

    return model, diffusion

## Training

In [6]:
import wandb
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from dataset.dataset import num_classes, dataloader_with_transformed_dataset

def train(latent_scale=0.7128):
    """Train a DiT diffusion model on autoencoder latents and log metrics to Weights & Biases.

    Hyperparameters are read from ``wandb.config``; ``config_defaults`` below is
    what ``wandb.init`` receives as the default configuration. The function
    takes no required arguments, so it can be called directly or handed to
    ``wandb.agent``. It relies on the notebook-level ``vae_model`` and ``cfg``
    objects as well as ``create_model``.

    Args:
        latent_scale: Divisor applied to every latent batch before it is passed
            to the diffusion process (described by the original author as the
            mean of the standard deviations of the latents).

    Raises:
        ValueError: If the training dataloader yields no batches.
    """
    config_defaults = {
            # NOTE(review): 'optimizer' is part of the config but is never read;
            # the loop below always builds AdamW.
            'optimizer': 'adam',
            'predict_xstart': True,
            'learn_sigma': True,
            'use_scheduler': True,
            'num_heads': 16,
            'depth': 28,
            'dropout': 0.1,
            'epochs': 100,
            'learning_rate': 0.001,
            'batch_size': 100,
    }
    wandb.init(config=config_defaults)
    config = wandb.config

    # valid_dataloader is returned by the helper but not used by this training loop.
    train_dataloader, valid_dataloader = dataloader_with_transformed_dataset(vae_model, cfg, batch_n=config.batch_size, length=1000)

    # Fail early with a clear message instead of dividing by zero after the first epoch.
    num_batches = len(train_dataloader)
    if num_batches == 0:
        raise ValueError("train_dataloader is empty; nothing to train on")

    model, diffusion = create_model(dropout=config.dropout, predict_xstart=config.predict_xstart,
                                    n_classes=num_classes(train_dataloader), depth=config.depth,
                                    learn_sigma=config.learn_sigma, num_heads=config.num_heads)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)

    for epoch in range(config.epochs):
        running_loss = 0.0
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)

            # Drop singleton dimensions but never the batch dimension: a bare
            # x.squeeze() would also remove the batch axis when a batch holds a
            # single sample. Then insert a new axis at position 1.
            kept_dims = [size for size in x.shape[1:] if size != 1]
            x = x.reshape(x.shape[0], 1, *kept_dims)
            x = x / latent_scale  # mean of std's of latents

            model_kwargs = dict(y=y)

            # One random diffusion timestep per sample in the batch.
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)

            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once; it is used for both the running sum and the log entry.
            batch_loss = loss.item()
            running_loss += batch_loss

            wandb.log({"batch_loss": batch_loss})

        epoch_loss = running_loss / num_batches
        if config.use_scheduler:
            # ReduceLROnPlateau is stepped with the mean training loss of the epoch.
            scheduler.step(epoch_loss)
        wandb.log({"loss": epoch_loss, "epoch": epoch, 'learning_rate': optimizer.param_groups[0]['lr']})

A single training run for debugging.

In [None]:
try:
    train()
finally:
    # Close the W&B run even when training raises, so a failed debug run
    # is not left open in this kernel.
    wandb.finish()

# Defining the Sweep

## Config

In [None]:
import wandb

# Log in to Weights & Biases before creating the sweep below.
wandb.login()

In [8]:
# Random-search sweep configuration; the objective is to minimize the "loss" metric.
sweep_config = {
    "method": "random",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        # Single-choice entries: identical for every run of the sweep.
        "optimizer": {"values": ["adam"]},
        "predict_xstart": {"value": True},
        "learn_sigma": {"value": True},  # candidates: [True, False]
        # Discrete choices.
        "use_scheduler": {"values": [True, False]},
        "num_heads": {"values": [16, 32, 64, 128, 256]},
        # Integer drawn uniformly from [28, 100].
        "depth": {"distribution": "int_uniform", "min": 28, "max": 100},
        "dropout": {"values": [0.3, 0.4, 0.5]},
        "epochs": {"value": 100},
        # Flat (uniform) distribution between 0.0001 and 0.01.
        "learning_rate": {"distribution": "uniform", "min": 0.0001, "max": 0.01},
        # Quantized (q=8) values between 16 and 128 with
        # evenly-distributed logarithms.
        "batch_size": {"distribution": "q_log_uniform_values", "q": 8, "min": 16, "max": 128},
    },
}

## Sweep

In [None]:
import wandb

# Register the sweep defined by sweep_config under the "svgfusion-sweep" project
# and keep the returned id for the agent cell below.
sweep_id = wandb.sweep(sweep_config, project="svgfusion-sweep")

In [None]:
import wandb

# Start an agent for the sweep that calls train(), with the run count set to 100.
wandb.agent(sweep_id, train, count=100)