In [1]:
import os
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union, Type, List
import numpy as np
import matplotlib.pyplot as plt
import pprint

# Make figures resemble LaTeX typesetting: serif text plus the DejaVu Serif
# mathtext font set. This only switches fonts; it does not set `text.usetex`.
from matplotlib import rcParams
rcParams['mathtext.fontset'] = 'dejavuserif'
rcParams['font.family'] = 'serif'

# JAX/Flax
import jax
import jax.numpy as jnp
from jax import random
import wandb

from visualization import visualize as vis
from data import data_functions as df
from models import model_architecture as march
from models import train_model as trm
from models import generate_model as gen

In [2]:
def _fixed(setting):
    """Sweep parameter spec for a hyperparameter pinned to a single value."""
    return {'value': setting}


def _choice(*options):
    """Sweep parameter spec for a hyperparameter drawn from a list of options."""
    return {'values': list(options)}


# Random-search sweep definition; the reported metric is 'Loss' with goal 'minimize'.
sweep_config = {
    'method': 'random',
    'metric': {'name': 'Loss', 'goal': 'minimize'},
    'parameters': {
        'optimizer': _fixed('adam'),
        'latent_dim': _choice(64, 128, 256, 512, 1024),
        'epochs': _fixed(50),
        'batch_size': _choice(128, 256, 512, 1024),
        'learning_rate': _choice(1e-2, 1e-3, 1e-4, 1e-5),
        'N': _fixed(2),
        'D': _choice(128, 1024, 2048, 4096),
        'std_data': _fixed(0.5),
        'size': _choice(2000, 4000, 6000),
        'gen_data_key': _fixed(21),
        'train_model_key': _fixed(47),
    },
}

# Echo the assembled configuration so it is recorded in the notebook output.
pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'Loss'},
 'parameters': {'D': {'values': [128, 1024, 2048, 4096]},
                'N': {'value': 2},
                'batch_size': {'values': [128, 256, 512, 1024]},
                'epochs': {'value': 50},
                'gen_data_key': {'value': 21},
                'latent_dim': {'values': [64, 128, 256, 512, 1024]},
                'learning_rate': {'values': [0.01, 0.001, 0.0001, 1e-05]},
                'optimizer': {'value': 'adam'},
                'size': {'values': [2000, 4000, 6000]},
                'std_data': {'value': 0.5},
                'train_model_key': {'value': 47}}}


In [3]:
def train_sweep(config=None):
    """Run a single wandb sweep trial: build the data, the model, and train it.

    Designed to be handed to ``wandb.agent``. A wandb run is opened, the
    hyperparameters are read from ``wandb.config``, a dataset and loader are
    built, a ``ScoreNet`` and its train state are initialised, and training is
    delegated to ``trm.train_model_sweep``.

    Args:
        config: Optional hyperparameter mapping forwarded to ``wandb.init``.
            Inside the run it is replaced by ``wandb.config``. Keys read
            directly here: ``gen_data_key``, ``size``, ``batch_size``, ``N``,
            ``latent_dim``, ``std_data`` and ``learning_rate``. The whole
            config object is also passed on to ``trm.train_model_sweep``.

    Returns:
        None. The trained model and state are not returned.
    """
    # Derive the PRNG keys from a fixed seed. The 4-way split is kept (with the
    # first two keys discarded) so the two keys actually used are unchanged.
    # NOTE(review): the seed is hard-coded to 0 and `train_model_key` from the
    # sweep config is not read in this function -- presumably it is consumed
    # inside trm.train_model_sweep; confirm.
    init_rng = random.PRNGKey(0)
    _, _, noise_key, init_key = random.split(init_rng, num=4)

    # Make sure the run directory exists before handing it to wandb, so run
    # files land where intended even on a fresh machine.
    run_dir = os.path.join(os.path.expanduser('~'), 'PFGMPP/saved_models/toy')
    os.makedirs(run_dir, exist_ok=True)

    # Initialize a new wandb run
    with wandb.init(config=config, dir=run_dir):
        # If called by wandb.agent, this config will be set by the Sweep Controller
        config = wandb.config

        # Generate the toy dataset (2D Gaussians, per the original author's note)
        X = df.generate_data(config['gen_data_key'], config['size'])

        # Wrap the data in the project's dataset/loader and pull one batch;
        # the batch is only used below for its shape.
        train_dataset = df.JaxDataset(X=X)
        train_loader = df.NumpyLoader(dataset=train_dataset,
                                      batch_size=config['batch_size'],
                                      shuffle=True)
        batch = next(iter(train_loader))

        # Instantiate the score network
        model = march.ScoreNet(dim=config['N'],
                               latent_dim=config['latent_dim'],
                               std_data=config['std_data'])

        # Sample log-normal noise levels, one per batch element. Only
        # `t.shape` is consumed in this function (for state initialisation).
        rnd_normal = random.normal(noise_key, shape=(batch.shape[0], 1))
        t = jnp.exp(rnd_normal * 1.2 - 1.2)

        # Initialize the model's train state from the example input shapes
        state = trm.init_train_state(model=model,
                                     random_key=init_key,
                                     x_shape=batch.shape,
                                     t_shape=t.shape,
                                     learning_rate=config['learning_rate'])

        # Run the training loop; the remaining hyperparameters are taken from
        # `config` by the training helper.
        model, state = trm.train_model_sweep(train_loader,
                                             model,
                                             state,
                                             config)

In [None]:
# Register the sweep described by `sweep_config` under the "toy_pfgmpp" project.
sweep_id = wandb.sweep(
    sweep=sweep_config,
    project="toy_pfgmpp",
)

# Start an agent in this kernel that executes `train_sweep` for 20 runs.
wandb.agent(
    sweep_id,
    function=train_sweep,
    count=20,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 0200zm40
Sweep URL: https://wandb.ai/mdowicz/toy_pfgmpp/sweeps/0200zm40


[34m[1mwandb[0m: Agent Starting Run: ob43pt5u with config:
[34m[1mwandb[0m: 	D: 1024
[34m[1mwandb[0m: 	N: 2
[34m[1mwandb[0m: 	batch_size: 512
[34m[1mwandb[0m: 	epochs: 50
[34m[1mwandb[0m: 	gen_data_key: 21
[34m[1mwandb[0m: 	latent_dim: 128
[34m[1mwandb[0m: 	learning_rate: 0.001
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: 	size: 6000
[34m[1mwandb[0m: 	std_data: 0.5
[34m[1mwandb[0m: 	train_model_key: 47
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmdowicz[0m. Use [1m`wandb login --relogin`[0m to force relogin


 26%|██▌       | 13/50 [02:39<07:27, 12.10s/it]