In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union, Type, List
import numpy as np
import matplotlib.pyplot as plt
import pprint

# Changing fonts to be latex typesetting
from matplotlib import rcParams
rcParams['mathtext.fontset'] = 'dejavuserif'
rcParams['font.family'] = 'serif'

# JAX/Flax
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
import flax
from flax import linen as nn
from flax.training import train_state, orbax_utils
import optax
import orbax
import wandb
wandb_dir = os.path.join(os.path.expanduser('~'), "PFGMPP")
os.environ["WANDB_DIR"] = os.path.abspath(wandb_dir)
from tqdm import tqdm

from visualization import visualize as vis
from data import data_functions as df
from models import model_architecture as march
from models import train_model as trm
from models import generate_model as gen

In [2]:
class MLP(nn.Module):
    initial_feature: int = 16
    out_feature: int = 2
    std_data: float = 0.5
    # embedding_dim: int = 32

    def setup(self):
        # Initial dense layer
        self.inc = nn.Dense(self.initial_feature) # 2->64
        # Activation
        self.act = nn.relu
        # Final Output Layer
        self.outc = nn.Dense(self.out_feature)

    def __call__(self, x, t):
        # Preconditioning terms
        t = t.squeeze()
        c_out = t * self.std_data / jnp.sqrt(self.std_data**2 + t**2)
        c_skip = self.std_data**2 / (self.std_data**2 + t**2)

        # Sampling the noise and embedding the noise via positional encoding
        t = jnp.log(t.flatten()) / 4.
        t = self.pos_encoding(t, self.initial_feature)
        x_orig = x

        # First layer + noise embedding
        x_1 = self.inc(x) + t
        x_2 = self.act(x)
        # Second layer
        x_3 = self.inc(x_2)
        x_4 = self.act(x)
        # Third layer
        x_5 = self.inc(x_4)
        x_6 = self.act(x)
        # Output layer
        output = self.outc(x_6)

        # Reshape c_out & c_skip to match dimensions for broadcasting
        c_out = jnp.reshape(c_out, (-1,1))
        c_skip = jnp.reshape(c_skip, (-1,1))
        return c_out * output + c_skip * x_orig

    def pos_encoding(self, t, channels):
        t = jnp.expand_dims(t, axis=-1)  # Add an additional dimension to t
        inv_freq = 1.0 / (10000 ** (jnp.arange(0, channels, 2).astype(jnp.float32) / channels))
        pos_enc_a = jnp.sin(t * inv_freq)
        pos_enc_b = jnp.cos(t * inv_freq)
        pos_enc = jnp.concatenate([pos_enc_a, pos_enc_b], axis=-1)
        return pos_enc

In [3]:
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'Train Loss',
        'goal': 'minimize'
    },
    'parameters': {
        'optimizer': {
            'value': 'adam'
        },
        # 'embedding_dim': {
        #     'values': [16, 32, 64, 128, 256]
        # },
        'epochs': {
            'value': 60
        },
        'batch_size': {
            'value': 64
        },
        'learning_rate': {
            'values': [1e-3, 1e-4]
        },
        'N': {
            'value': 2
        },
        'D': {
            'values': [10, 100, 1000, 10000]
        },
        'std_data': {
            'values': [0.5, 1]
        },
        'initial_feature': {
            'values': [16, 32, 64, 128, 256]
        },
        'out_feature': {
            'value': 2
        },
        'size': {
            'value': 2000
        },
        'gen_data_key': {
            'value': 21
        },
        'train_model_key': {
            'value': 47
        }
    }
}

pprint.pprint(sweep_config)

{'method': 'random',
 'metric': {'goal': 'minimize', 'name': 'Train Loss'},
 'parameters': {'D': {'values': [10, 100, 1000, 10000]},
                'N': {'value': 2},
                'batch_size': {'value': 64},
                'epochs': {'value': 60},
                'gen_data_key': {'value': 21},
                'initial_feature': {'values': [16, 32, 64, 128, 256]},
                'learning_rate': {'values': [0.001, 0.0001]},
                'optimizer': {'value': 'adam'},
                'out_feature': {'value': 2},
                'size': {'value': 2000},
                'std_data': {'values': [0.5, 1]},
                'train_model_key': {'value': 47}}}


In [4]:
def train_model_sweep(train_loader,
                      model,
                      state,
                      config,
                      dir_name,
                      project_name,
                      key_seed=47,
                      wandb_logging=True):
    """
    Train a machine learning model with optional Weights & Biases (wandb) logging.

    Parameters:
    -----------
    train_loader: 
        A data loader providing the training data.
    model: 
        The model to be trained.
    state: 
        The initial state of the model.
    config: dict
        A dictionary containing configuration parameters for training, such as learning rate, batch size etc.
    wandb_logging: bool
        If True, training progress is logged using wandb.
        Default is False.

    Returns:
    --------
        model: The trained model.
        state: The final state of the model after training.
    """
    # Initialize a var to hold the best test loss seen so far
    best_train_loss = float('inf')
    
    # Start the training loop
    for epoch in tqdm(range(config['epochs'])):
        # Initialize a list to store all batch-level metrics
        batch_metrics = []

        for batch in train_loader:
            var_data = 0.2 ** 2 + 1 ** 2
            sigma_data = jnp.sqrt(var_data)
            
            # Prepare the data
            batch = jax.device_put(batch)
            
            # Normalize the data such that it's std is 1
            batch /= sigma_data
            
            # Update the model
            train_step_jit = jax.jit(trm.train_step, static_argnums=(3,4,5))
            state, batch_loss = train_step_jit(state, batch, config['std_data'], config['D'], config['N'], config['train_model_key'])    

            # Store the batch-level metric in the list
            batch_metrics.append({'Train Loss': batch_loss})

        # Use accumulate_metrics to calculate average metrics for the epoch
        epoch_metrics = trm.accumulate_metrics(batch_metrics)
        

        # If the train loss for this epoch is better than the previous best,
        # save the model
        if epoch_metrics['Train Loss'] < best_train_loss:
            best_train_loss = epoch_metrics['Train Loss'] # Update the best train loss
            checkpt_dir = dir_name # dir where models are saved to
            trm.save_checkpoint(checkpt_dir, state, epoch, project_name)

        
        # If wandb logging is enabled, log metrics
        if wandb_logging:
            print('Epoch Metrics =', epoch_metrics)
            wandb.log(epoch_metrics)

    return model, state

def train_sweep(config=None, project_name='pfgmpp_test', dir_name='PFGMPP/saved_models/toy'):
    
    # Create random PRNG keys for training
    init_rng = random.PRNGKey(0)
    rng, subkey1, subkey2, subkey3 = random.split(init_rng, num=4)

    # Get the absolute path of the saved_models/toy directory
    dir_name = os.path.join(os.path.expanduser('~'), str(dir_name)) #FIX THIS HARDCODE
    
    # Initialize a new wandb run
    with wandb.init(config=config, dir=dir_name):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        config = wandb.config
        
        # Generate dataset of 2D Gaussians
        X = df.generate_data(config['gen_data_key'], config['size'])
        
        # Make the data suitable for a JAX Dataloader
        train_dataset = df.JaxDataset(X=X)
        train_loader = df.NumpyLoader(dataset=train_dataset,
                                      batch_size=config['batch_size'],
                                      shuffle=True)
        batch = next(iter(train_loader))
        
        # Get the rng & model instantiated
        model = MLP(initial_feature=config['initial_feature'],
                    std_data=config['std_data'],
                    out_feature=config['out_feature'])#,
                    #embedding_dim=config['embedding_dim'])
        
        # Sample the noise distribution
        rnd_normal = random.normal(subkey2, shape=(batch.shape[0], 1))
        t = jnp.exp(rnd_normal * 1.2 - 1.2)  
    
    
        # Initialize the models state
        state = trm.init_train_state(model=model,
                                     random_key=subkey3,
                                     x_shape=batch.shape,
                                     t_shape=t.shape,
                                     learning_rate=config['learning_rate'])
        
        model, state = train_model_sweep(train_loader,
                                         model,
                                         state,
                                         config,
                                         key_seed=47,
                                         wandb_logging=True,
                                         project_name=project_name,
                                         dir_name=dir_name)

        del state
        del model
        del train_dataset
        del train_loader
        del X
        del t 
        del batch
        wandb.finish()

In [5]:
sweep_id = wandb.sweep(sweep_config, project="simple_PFGM++")
wandb.agent(sweep_id, train_sweep, count=20)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Create sweep with ID: 4s01fuhj
Sweep URL: https://wandb.ai/mdowicz/simple_PFGM%2B%2B/sweeps/4s01fuhj


[34m[1mwandb[0m: Agent Starting Run: c4eo9307 with config:
[34m[1mwandb[0m: 	D: 10
[34m[1mwandb[0m: 	N: 2
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	epochs: 60
[34m[1mwandb[0m: 	gen_data_key: 21
[34m[1mwandb[0m: 	initial_feature: 256
[34m[1mwandb[0m: 	learning_rate: 0.0001
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: 	out_feature: 2
[34m[1mwandb[0m: 	size: 2000
[34m[1mwandb[0m: 	std_data: 0.5
[34m[1mwandb[0m: 	train_model_key: 47
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmdowicz[0m. Use [1m`wandb login --relogin`[0m to force relogin


  2%|▏         | 1/60 [00:05<05:53,  5.99s/it]

Epoch Metrics = {'Train Loss': 152.40222}


  3%|▎         | 2/60 [00:10<04:43,  4.89s/it]

Epoch Metrics = {'Train Loss': 147.66794}


  5%|▌         | 3/60 [00:14<04:16,  4.51s/it]

Epoch Metrics = {'Train Loss': 148.5904}


  7%|▋         | 4/60 [00:18<04:02,  4.33s/it]

Epoch Metrics = {'Train Loss': 150.38896}


  8%|▊         | 5/60 [00:22<03:52,  4.22s/it]

Epoch Metrics = {'Train Loss': 147.97176}


 10%|█         | 6/60 [00:26<03:44,  4.15s/it]

Epoch Metrics = {'Train Loss': 147.68816}


 12%|█▏        | 7/60 [00:30<03:37,  4.11s/it]

Epoch Metrics = {'Train Loss': 147.97417}


 13%|█▎        | 8/60 [00:34<03:32,  4.08s/it]

Epoch Metrics = {'Train Loss': 152.9771}


 15%|█▌        | 9/60 [00:38<03:27,  4.07s/it]

Epoch Metrics = {'Train Loss': 150.0151}


 17%|█▋        | 10/60 [00:42<03:23,  4.06s/it]

Epoch Metrics = {'Train Loss': 147.89035}


 18%|█▊        | 11/60 [00:46<03:18,  4.06s/it]

Epoch Metrics = {'Train Loss': 147.88747}


 20%|██        | 12/60 [00:50<03:14,  4.05s/it]

Epoch Metrics = {'Train Loss': 143.82253}


 22%|██▏       | 13/60 [00:54<03:10,  4.05s/it]

Epoch Metrics = {'Train Loss': 145.91238}


 23%|██▎       | 14/60 [00:58<03:06,  4.05s/it]

Epoch Metrics = {'Train Loss': 147.59827}


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
 23%|██▎       | 14/60 [00:59<03:15,  4.24s/it]


# Old Model Architectures

In [6]:
# class MLP(nn.Module):
#     depth: int = 4
#     initial_feature: int = 64
#     output_feature: int = 2
#     std_data: float = 0.5
#     group_norm: bool = False

#     def setup(self):
#         # Encoder
#         self.down_blocks = [DoubleConvolution(self.initial_filters * 2**i) for i in range(self.depth)]
#         self.downsamples = [DownSample() for _ in range(self.depth)]
        
#         # Bottleneck 
#         self.bottleneck_block = DoubleConvolution(self.initial_filters * 2**self.depth)

#         # Decoder
#         self.up_samples = [UpSample(self.initial_filters * 2**(i-1)) for i in range(self.depth, 0, -1)]
#         self.up_blocks = [DoubleConvolution(self.initial_filters * 2**i) for i in range(self.depth-1, -1, -1)]

#         # Final Convolutional Layer
#         self.final_conv = nn.Conv(self.output_channels, 
#                                   kernel_size=(1, 1),
#                                   strides=(1, 1),
#                                   padding="SAME")

#     def __call__(self, x):
#         skip_connections = []
        
#         # Encoder path
#         for i in range(self.depth):
#             x = self.down_blocks[i](x)
#             skip_connections.append(x)
#             x = self.downsamples[i](x)
#             # print(f'Encoder{i+1} x.shape =', x.shape)

#         # Bottleneck
#         x = self.bottleneck_block(x)
#         # print('Bottleneck x.shape =', x.shape)

#         # Decoder path
#         for i in range(self.depth):
#             x = self.up_samples[i](x)
#             x = jnp.concatenate([x, skip_connections.pop()], axis=-1)
#             # print(f'Skip_connection{i+1} x.shape =', x.shape)
#             x = self.up_blocks[i](x)
#             # print(f'Decoder{i+1} x.shape =', x.shape)

#         # Final Convolution layer
#         x = self.final_conv(x)
#         # print('Final x.shape =', x.shape)

#         return x

In [7]:
# class DoubleDense(nn.Module):
#     features: int

#     def setup(self):
#         self.dense1 = nn.Dense(self.features)
#         self.act1 = nn.swish
#         self.dense2 = nn.Dense(self.features)
#         self.act2 = nn.swish

#     def __call__(self, x):
#         x = self.dense1(x)
#         x = self.act1(x)
#         x = self.dense2(x)
#         x = self.act2(x)
#         return x

# class Down(nn.Module):
#     features: int
a
#     def setup(self):
#         self.dense1 = DoubleDense(self.features)
#         self.emb_layer = nn.Dense(self.features)

#     def __call__(self, x, t):
#         x = self.dense1(x)
#         emb = self.emb_layer(t)
#         return x + emb


# class Up(nn.Module):
#     features: int

#     def setup(self):
#         self.dense1 = DoubleDense(self.features)
#         self.emb_layer = nn.Dense(self.features)

#     def __call__(self, x, skip_x, t):
#         x = jnp.concatenate([skip_x, x], axis=-1)
#         x = self.dense1(x)
#         emb = self.emb_layer(t)
#         return x + emb


# class UNet(nn.Module):
#     depth: int = 4
#     initial_feature: int = 64
#     out_feature: int = 2
#     std_data: float = 0.5
#     embedding_dim: int = 64

#     def setup(self):
#         # Initial dense layer
#         self.inc = DoubleDense(self.initial_feature) # 2->64

#         # Encoder Block (Downsampling)
#         for i in range(1, self.depth):
#             features = self.initial_feature * (2 ** i)
#             # print('Encoder features', features)
#             setattr(self, f'down{i}', Down(features))

#         # Bottleneck Layers
#         bottleneck_features = self.initial_feature * (2 ** (self.depth-1))
#         self.bot1 = DoubleDense(bottleneck_features) # 512->512
#         self.bot2 = DoubleDense(bottleneck_features * 2) # 512->1024
#         self.bot3 = DoubleDense(bottleneck_features) # 1024->512

#         # Decoder Block (Upsampling)
#         for i in reversed(range(0, self.depth-1)):
#             features = self.initial_feature * (2 ** i)
#             setattr(self, f'up{i}', Up(features))

#         # Final Output Layer
#         self.outc = nn.Dense(self.out_feature)

#     def __call__(self, x, t):
#         # Preconditioning terms
#         t = t.squeeze()
#         c_out = t * self.std_data / jnp.sqrt(self.std_data**2 + t**2)
#         c_skip = self.std_data**2 / (self.std_data**2 + t**2)

#         # Sampling the noise and embedding the noise via positional encoding
#         t = jnp.log(t.flatten()) / 4.
#         t = self.pos_encoding(t, self.embedding_dim)
#         x_orig = x

#         skip_connections = []

#         # Pass through the initial layer
#         x = self.inc(x) # 2 -> 64
#         skip_connections.append(x) # Store the output for skip connection

#         # Pass through the dynamic encoder layers
#         for i in range(1, self.depth):
#             x = getattr(self, f'down{i}')(x, t)
#             skip_connections.append(x) # Store the outputs for skip connections
            
#         # Pass through the bottleneck layers
#         x = self.bot1(x) # 256 -> 512
#         x = self.bot2(x) # 512 -> 1024
#         x = self.bot3(x) # 1024 -> 512

#         # Pass through the dynamic decoder (upsampling) layers
#         skip_connections.pop() # THIS IS FOR TESTING PURPOSES FOUND THERE WAS A DISCREPANCY
#                        # WITH THE NUMBER OF PARAMS BETWEEN STATIC AND DYNAMIC MODEL
#                        # DUE TO THE SKIP CONNECTIONS TAKING WRONG DIMENSIONAL DATA
#         for i in reversed(range(0, self.depth-1)):
#             skip_output = skip_connections.pop() # Retrieve last stored output
#             x = getattr(self, f'up{i}')(x, skip_output, t)

#         # Pass through the final output layer
#         output = self.outc(x) # 64 -> 2

#         # Reshape c_out & c_skip to match dimensions for broadcasting
#         c_out = jnp.reshape(c_out, (-1,1))
#         c_skip = jnp.reshape(c_skip, (-1,1))
#         return c_out * output + c_skip * x_orig

#     def pos_encoding(self, t, channels):
#         t = jnp.expand_dims(t, axis=-1)  # Add an additional dimension to t
#         inv_freq = 1.0 / (10000 ** (jnp.arange(0, channels, 2).astype(jnp.float32) / channels))
#         pos_enc_a = jnp.sin(t * inv_freq)
#         pos_enc_b = jnp.cos(t * inv_freq)
#         pos_enc = jnp.concatenate([pos_enc_a, pos_enc_b], axis=-1)
#         return pos_enc

NameError: name 'a' is not defined