In [1]:
import jax.numpy as jnp

from jax import random, vmap, pmap, local_devices, local_device_count

from models import VariationalAutoencoder
from input_pipeline import MNISTVAEDataLoader

import wandb
import tensorflow_datasets as tfds
from tqdm.auto import tqdm, trange
import matplotlib.pyplot as plt
%matplotlib inline

from configs.default import get_config

# Show which devices JAX can see locally, so it is clear up front what hardware
# the rest of the notebook will run on.
print(local_devices())

[CpuDevice(id=0)]


In [2]:
# Load the default experiment configuration and build the VAE from it.
# `config` and `model` are reused by every later cell.
# NOTE(review): the recorded output of this cell is a "VAE Summary" table,
# presumably printed by the VariationalAutoencoder constructor — confirm in models.py.
config = get_config()
model = VariationalAutoencoder(config)


[3m                                  VAE Summary                                   [0m
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath         [0m[1m [0m┃[1m [0m[1mmodule     [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs      [0m[1m [0m┃[1m [0m[1mparams        [0m[1m [0m┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩
│               │ VAE         │ -             │ [2mfloat32[0m[1,78… │                │
│               │             │ [2mfloat32[0m[1,78… │               │                │
│               │             │ - [2mfloat32[0m[16] │               │                │
├───────────────┼─────────────┼───────────────┼───────────────┼────────────────┤
│ encoder       │ MlpEncoder  │ -             │ -             │                │
│               │             │ [2mfloat32[0m[1,78… │ [2mfloat32[0m[1,16] │                │
│               │    

In [3]:
# Start a Weights & Biases run. The project and run names come from the
# `wandb` section of the config; the top-level config is attached to the run.
wandb_config = config.wandb
wandb.init(
    project=wandb_config.project,
    name=wandb_config.name,
    config=dict(config),
)

[34m[1mwandb[0m: Currently logged in as: [33mpgp[0m ([33mjaxpi[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Evaluation step
def eval_step(config, model, batch):
    """Collect the metrics to log for a single batch.

    Args:
        config: Run configuration; only ``config.logging.log_losses`` is read.
        model: Object exposing ``state.params`` and
            ``eval_losses(params, batch)``, which must return a
            ``(kl_loss, recon_loss)`` pair.
        batch: Batch forwarded unchanged to ``model.eval_losses``.

    Returns:
        dict: ``{'kl_loss': ..., 'recon_loss': ...}`` with the mean of each
        loss when loss logging is enabled, otherwise an empty dict.
    """
    current_params = model.state.params
    metrics = {}

    # Nothing to evaluate when loss logging is switched off.
    if not config.logging.log_losses:
        return metrics

    kl_terms, recon_terms = model.eval_losses(current_params, batch)
    metrics.update(
        kl_loss=kl_terms.mean(),
        recon_loss=recon_terms.mean(),
    )
    return metrics

In [None]:
# Build the MNIST loader from the training section of the config.
training_cfg = config.training
vae_loader = MNISTVAEDataLoader(
    batch_size=training_cfg.batch_size,
    num_mc_samples=training_cfg.num_mc_samples,
    latent_dim=config.eps_dim,
    seed=config.seed,
)

# Dataset metadata is only needed for the size of the train split.
# NOTE(review): `_load_dataset` is a private loader method — confirm there is
# no public accessor for `ds_info`.
_, ds_info = vae_loader._load_dataset()
num_train_examples = ds_info.splits['train'].num_examples
steps_per_epoch = num_train_examples // config.training.batch_size
total_steps = steps_per_epoch * config.training.num_epochs

# Batch iterator, driven by a PRNG key derived from the config seed.
rng_key = random.PRNGKey(config.seed)
train_ds = vae_loader.get_iterator(rng_key)

print("Starting training...")

# Main training loop: one optimisation step per batch, with periodic metric
# logging to wandb and to the progress bar.
#
# The wandb run is closed in `finally`, so an interrupted or failed run is
# still flushed and marked as failed instead of being left open.
run_exit_code = 1  # assume failure until the loop has completed
try:
    with tqdm(total=total_steps, desc="Training") as pbar:
        global_step = 0  # monotonically increasing step across all epochs

        for epoch in range(config.training.num_epochs):
            for step in range(steps_per_epoch):
                # Each item from the iterator unpacks as (images, labels, eps);
                # the labels are not used for training.
                images, labels, eps = next(train_ds)
                batch = (images, eps)
                model.state = model.step(model.state, batch)

                # The logging cadence is counted within each epoch.
                if step % config.logging.log_every_steps == 0:
                    log_dict = eval_step(config, model, batch)

                    # Add the epoch BEFORE logging so it actually reaches wandb.
                    log_dict['epoch'] = epoch + 1

                    # The x-axis is the global step, not the per-epoch step.
                    wandb.log(log_dict, step=global_step)

                    # eval_step returns no losses when loss logging is disabled,
                    # so only show the ones that are present.
                    postfix = {
                        name: f"{log_dict[name]:.4f}"
                        for name in ('kl_loss', 'recon_loss')
                        if name in log_dict
                    }
                    postfix['epoch'] = epoch + 1
                    pbar.set_postfix(postfix)

                # Update progress bar and increment global step
                pbar.update(1)
                global_step += 1

    run_exit_code = 0
finally:
    wandb.finish(exit_code=run_exit_code)

JAX running on 1 devices
Global batch size: 32
Per-device batch size: 32
Starting training...


Training:   0%|          | 0/37500 [00:00<?, ?it/s]



KeyboardInterrupt: 