In [1]:
# Basic Imports
import pathlib
from pathlib import Path
import os
import sys
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
import json
import time
from tqdm.auto import tqdm
import numpy as np
from copy import copy
from glob import glob
from collections import defaultdict
import matplotlib.pyplot as plt

# Serif body text plus DejaVu Serif math glyphs, so figure labels look closer
# to LaTeX-typeset documents (this is not actual LaTeX rendering).
from matplotlib import rcParams
rcParams.update({'mathtext.fontset': 'dejavuserif', 'font.family': 'serif'})

# JAX/Flax
import jax
import jax.numpy as jnp
from jax import random
jax.config.update('jax_platform_name', 'cpu')

import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)
import optax

# Logging with Tensorboard or Weights and Biases
# from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

# For ODESolver
from scipy import integrate

# PyTorch for Dataloaders
import torch
import torch.utils.data as data
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

# Wandb 
import wandb
wandb.login()
import pprint

# Make the project's `src/` directory (one level above this notebook) importable
# so the local modules below (make_dataset, visualization, flax_trn_loop) resolve.
# Check for the exact path that gets appended, so re-running this cell does not
# keep adding duplicate sys.path entries.
module_path = os.path.abspath(os.path.join('..'))
src_path = os.path.join(module_path, 'src')
if src_path not in sys.path:
    sys.path.append(src_path)

# Import created functions
import make_dataset as mkds
import visualization as vis
import flax_trn_loop as trn
# import NN_model as nnm
# import observable_data as od


from numpy.random import default_rng

# Module-wide NumPy Generator with a fixed seed, for reproducible sampling
rng = default_rng(42)

2023-03-27 16:26:01.327452: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.5/math_libs/11.7/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/extras/CUPTI/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/extras/Debugger/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/nvvm/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/lib64:/opt/cray/pe/papi/7.0.0.1/lib64:/opt/cray/pe/gcc/11.2.0/snos/lib64:/opt/cray/libfabric/1.15.2.0/lib64
2023-03-27 16:26:01.327557: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.5/math_libs/11.7/lib64:/o

# 0. Create the model

In [2]:
class MLP(nn.Module):
    """
    Fully connected network used as a simple test model for PFGM.

    Each entry of ``hidden_dims`` adds a Dense layer followed by a ReLU;
    a final linear Dense layer projects to ``output_dim``. Layers are
    declared inline with ``@nn.compact`` because the model is small
    enough not to need a separate ``setup``.

    Attributes:
        hidden_dims: widths of the hidden Dense layers, in order.
        output_dim: width of the output layer.
    """
    hidden_dims: Sequence[int]
    output_dim: int

    @nn.compact
    def __call__(self, x, **kwargs):
        # Extra keyword arguments are accepted but ignored.
        h = x
        for width in self.hidden_dims:
            h = nn.relu(nn.Dense(width)(h))
        return nn.Dense(self.output_dim)(h)

In [3]:
# # View model layers
# mlp = MLP(hidden_dims=config.hidden_dims,
#           output_dim=config.output_dim)

# print(mlp.tabulate(jax.random.PRNGKey(config.jax_seed), jnp.ones((1, 785))))

# 1. Create the `train_state` that will be passed between updates

In [4]:
def init_train_state(model: Any,
                     random_key: Any,
                     shape: tuple,
                     learning_rate: float,
                     optimizer: Optional[Any] = None) -> train_state.TrainState:
    """
    Initialize the TrainState dataclass, which bundles the step number,
    parameters and optimizer state of a model. Passing this single object
    between update functions means the model is not re-initialized again
    and again: we just update the "state" of the model and pass it on.

    Args:
    -----
        model: nn.Module
            The model that we want to train.
        random_key: jax.random.PRNGKey()
            Used to trigger the initialization functions, which generate
            the initial set of parameters that the model will use.
        shape: tuple
            Shape of a batch of input data. A dummy array of ones with this
            shape is passed through the model so Flax can infer the correct
            weight shapes.
        learning_rate: float
            Step size of the default Adam optimizer. Ignored when
            ``optimizer`` is given.
        optimizer: optax.GradientTransformation, optional
            Optimizer to use instead of ``optax.adam(learning_rate)``.

    Returns:
    --------
        train_state.TrainState:
            A utility class for handling parameter and gradient updates.
    """
    # Initialize the model parameters by tracing a dummy batch of ones.
    variables = model.init(random_key, jnp.ones(shape))

    # Default to Adam so existing callers keep their behavior.
    if optimizer is None:
        optimizer = optax.adam(learning_rate)

    # Create a state
    return train_state.TrainState.create(apply_fn=model.apply,
                                         tx=optimizer,
                                         params=variables['params'])

# 2. Create functions to compute and aggregate the training metrics

In [5]:
def compute_metrics(*, pred, labels):
    """
    Compute the metrics logged during training and evaluation.

    Args:
        pred: model predictions.
        labels: targets; combined element-wise with ``pred``.

    Returns:
        dict with
          'loss': mean squared error over all elements.
          'r2':   coefficient of determination 1 - SS_res / SS_tot, where
                  SS_tot is taken around the mean of *all* label elements.
                  When the labels have zero variance (SS_tot == 0), R^2 is
                  undefined. Following scikit-learn's finite convention, we
                  report 1.0 for a perfect prediction and 0.0 otherwise,
                  instead of nan/inf, which would poison the per-epoch
                  average.
    """
    # Mean squared error
    loss = ((pred - labels) ** 2).mean()

    # R^2 score
    residual = jnp.sum(jnp.square(labels - pred))
    total = jnp.sum(jnp.square(labels - jnp.mean(labels)))
    # Divide by a safe denominator so constant labels never produce nan/inf.
    safe_total = jnp.where(total > 0, total, 1.0)
    r2_score = jnp.where(total > 0,
                         1 - residual / safe_total,
                         jnp.where(residual == 0, 1.0, 0.0))

    return {
        'loss': loss,
        'r2': r2_score
    }


def accumulate_metrics(metrics):
    """
    Average a list of per-batch metric dicts into epoch-level metrics.

    Args:
        metrics: non-empty sequence of dicts (as returned by
            ``compute_metrics``) that all share the keys of the first one.

    Returns:
        dict mapping each key to the unweighted mean of its per-batch values
        (every batch counts equally, regardless of its size).

    Raises:
        ValueError: if ``metrics`` is empty, e.g. because a data loader
            yielded no batches.
    """
    if len(metrics) == 0:
        raise ValueError("accumulate_metrics() received no batch metrics; "
                         "did the data loader yield any batches?")
    # Pull device arrays to host memory before averaging with NumPy.
    metrics = jax.device_get(metrics)
    return {
        k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]
    }

# 3. Create a training step that does a gradient update for a single batch of data

In [6]:
@jax.jit
def train_step(state: train_state.TrainState,
               batch: list):
    """
    Run one optimization step on a single batch of data.

    Args:
        state: current TrainState (params, optimizer state, apply_fn).
        batch: pair ``(inputs, targets)``.

    Returns:
        (new_state, metrics): the updated state and the ``compute_metrics``
        dict (MSE loss and R^2), evaluated on predictions made with the
        parameters *before* this update.
    """
    image, label = batch

    def loss_fn(params: dict):
        """
        Simple MSE loss as described in the PFGM paper. The predictions are
        returned as auxiliary output so the metrics can reuse them without a
        second forward pass.
        """
        pred = state.apply_fn({'params': params}, image)
        loss = ((pred - label) ** 2).mean()
        return loss, pred

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, pred), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    # R^2 (and the loss) are computed once, in compute_metrics.
    metrics = compute_metrics(pred=pred, labels=label)
    return state, metrics

@jax.jit
def eval_step(state, batch):
    """
    Evaluate ``state`` on one batch without updating it.

    ``batch`` is an ``(inputs, targets)`` pair. Returns the
    ``compute_metrics`` dict for the model's predictions on ``inputs``.
    """
    inputs, targets = batch
    predictions = state.apply_fn({'params': state.params}, inputs)
    return compute_metrics(pred=predictions, labels=targets)

# 4. Create functions to save and restore the model

In [7]:
def save_checkpoint_wandb(ckpt_path, state, epoch):
    """
    Serialize ``state`` to ``ckpt_path`` with msgpack, then log that file to
    the current wandb run as an artifact with the aliases "latest" and
    "epoch_<epoch>".
    """
    with open(ckpt_path, "wb") as outfile:
        payload = msgpack_serialize(to_state_dict(state))
        outfile.write(payload)

    artifact_name = f'{wandb.run.name}-checkpoint'
    artifact = wandb.Artifact(artifact_name, type='dataset')
    artifact.add_file(ckpt_path)
    wandb.log_artifact(artifact, aliases=["latest", f"epoch_{epoch}"])
    
def load_checkpoint_wandb(ckpt_file, state):
    """
    Download the "latest" checkpoint artifact of the current wandb run and
    deserialize ``ckpt_file`` from it, using ``state`` as the structural
    template for ``from_bytes``.
    """
    artifact_dir = wandb.use_artifact(
        f'{wandb.run.name}-checkpoint:latest'
    ).download()
    with open(os.path.join(artifact_dir, ckpt_file), "rb") as handle:
        raw = handle.read()
    return from_bytes(state, raw)

## 4.1 Dataset and data-loader utilities for the training data

In [8]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.utils.data as data


class PerturbMNIST(Dataset):
    """
    Minimal map-style dataset that pairs an input array with a target array.

    Args:
    -----
        data: np.ndarray
            Input samples, indexed along the first axis.
        targets: np.ndarray
            Per-sample targets; ``len(targets)`` defines the dataset length.
            NOTE(review): the training loop ignores these and recomputes the
            field with ``mkds.empirical_field``.
    """
    def __init__(self, data: np.ndarray, targets: np.ndarray):
        self.data, self.targets = data, targets

    def __len__(self):
        """Number of samples, taken from the length of ``targets``."""
        return len(self.targets)

    def __getitem__(self, idx: int):
        """
        Fetch one (sample, target) pair.

        Args:
        -----
            idx: int
                Position of the sample to return.

        Returns:
        --------
            tuple: ``(data[idx], targets[idx])``.
        """
        return self.data[idx], self.targets[idx]
    
    
def numpy_collate(batch):
    """
    Collate a list of samples into NumPy arrays (instead of torch tensors)
    so batches can be handed straight to JAX.

    ndarray samples are stacked along a new leading axis; tuple/list samples
    are transposed and each field is collated recursively (the result is a
    list); anything else (e.g. scalars) is converted with ``np.array``.
    """
    head = batch[0]
    if isinstance(head, np.ndarray):
        return np.stack(batch)
    if isinstance(head, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)

def create_data_loaders(*datasets : Dataset,
                        train : Union[bool, Sequence[bool]] = True,
                        batch_size : int = 128,
                        num_workers : int = 4,
                        seed : int = 42):
    """
    Creates data loaders used in JAX for a set of datasets.

    Args:
      datasets: Datasets for which data loaders are created.
      train: Sequence indicating which datasets are used for
        training and which not. If single bool, the same value
        is used for all datasets. Training loaders shuffle and drop
        the last incomplete batch; the others keep dataset order and
        every sample.
      batch_size: Batch size to use in the data loaders.
      num_workers: Number of workers for each dataset.
      seed: Seed to initialize the workers and shuffling with.

    Returns:
      List of ``torch.utils.data.DataLoader``, one per dataset and in the
      same order, whose batches are collated to NumPy by ``numpy_collate``.

    Raises:
      ValueError: if ``train`` is a sequence whose length differs from the
        number of datasets (zip would otherwise silently drop loaders).
    """
    if isinstance(train, (list, tuple)):
        train_flags = list(train)
        if len(train_flags) != len(datasets):
            raise ValueError(
                f"got {len(train_flags)} train flags for {len(datasets)} datasets")
    else:
        train_flags = [train] * len(datasets)

    loaders = []
    for dataset, is_train in zip(datasets, train_flags):
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            drop_last=is_train,
            collate_fn=numpy_collate,
            num_workers=num_workers,
            # PyTorch rejects persistent_workers=True when num_workers == 0.
            persistent_workers=is_train and num_workers > 0,
            generator=torch.Generator().manual_seed(seed))
        loaders.append(loader)
    return loaders

# 5. Create a training loop function

In [9]:
def _perturbed_batches(loader, prng):
    """Yield the ``mkds.empirical_field`` pair built from each raw batch of ``loader``."""
    for images, _ in loader:  # stored dataset targets are unused
        images = mkds.reshape_with_channel_dim(images)
        yield mkds.empirical_field(images, prng)


def _evaluate(state, loader, prng):
    """Average the ``eval_step`` metrics of ``state`` over every perturbed batch of ``loader``."""
    return accumulate_metrics(
        [eval_step(state, batch) for batch in _perturbed_batches(loader, prng)])


def train_and_evaluate(batchsize, state, epochs, ckpt_dir, prng,
                       data_dir='saved_data/MNIST/raw/partitioned',
                       ckpt_file='checkpoint.msgpack'):
    """
    Train ``state`` on the partitioned MNIST splits, checkpoint the model with
    the lowest validation loss to wandb, then evaluate that best model on the
    test split.

    Every batch is perturbed on the fly with ``mkds.empirical_field``; the
    targets stored in the datasets are ignored.

    Args:
        batchsize: batch size used by all three data loaders.
        state: initial TrainState.
        epochs: number of training epochs.
        ckpt_dir: directory where the best state is written with
            ``flax.training.checkpoints.save_checkpoint``.
        prng: random generator passed to ``mkds.empirical_field``.
        data_dir: directory holding the ``partitioned_*_set.pkl`` splits.
        ckpt_file: local file name of the wandb checkpoint artifact.

    Returns:
        (state, restored_state): the state after the last epoch and the
        best-validation state restored from the wandb checkpoint.
    """
    # Load the partitioned splits (index 0: inputs, index 1: targets) and
    # wrap them as PyTorch Datasets.
    datasets = []
    for split in ('training', 'val', 'test'):
        loaded = mkds.load_data(data_dir=data_dir,
                                data_file=f'partitioned_{split}_set.pkl')
        datasets.append(PerturbMNIST(loaded[0], loaded[1]))

    # Instantiate the pytorch data loaders
    train_dl, val_dl, test_dl = create_data_loaders(*datasets,
                                                    train=[True, False, False],
                                                    batch_size=batchsize)

    # Initialised once, outside the epoch loop, so it tracks the best
    # validation loss over the whole run.
    best_val_loss = float('inf')
    best_epoch = None

    for epoch in tqdm(range(1, epochs + 1)):
        # =========== Training =========== #
        train_batch_metrics = []
        for batch in _perturbed_batches(train_dl, prng):
            state, metrics = train_step(state, batch)
            train_batch_metrics.append(metrics)
        train_metrics = accumulate_metrics(train_batch_metrics)
        print(
            'TRAIN (%d/%d): Loss: %.4f, r2: %.2f' % (
                epoch, epochs, train_metrics['loss'],
                train_metrics['r2'])
        )

        # =========== Validation =========== #
        val_metrics = _evaluate(state, val_dl, prng)
        print(
            'Val (%d/%d): Loss: %.4f, r2: %.2f' % (
                epoch, epochs, val_metrics['loss'],
                val_metrics['r2'])
        )

        wandb.log({
            "Train Loss": train_metrics['loss'],
            "Train r2": train_metrics['r2'],
            "Validation Loss": val_metrics['loss'],
            "Validation r2": val_metrics['r2']
        }, step=epoch)

        # Checkpoint only on improvement, so the "latest" artifact is the best model.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_epoch = epoch
            save_checkpoint_wandb(ckpt_file, state, epoch)

    if best_epoch is None:
        # Nothing was checkpointed (epochs == 0, or the validation loss was
        # never below infinity, e.g. NaN), so there is no artifact to restore.
        restored_state = state
    else:
        restored_state = load_checkpoint_wandb(ckpt_file, state)

    # =========== Test =========== #
    # Evaluate the restored best model, not the last-epoch state.
    test_metrics = _evaluate(restored_state, test_dl, prng)
    print(
        'Test: Loss: %.4f, r2: %.2f' % (
            test_metrics['loss'],
            test_metrics['r2']
        )
    )

    wandb.log({
        "Test Loss": test_metrics['loss'],
        "Test r2": test_metrics['r2']
    })

    # Save best state
    # NOTE(review): flax documents `step` as a number; confirm step=None is intended.
    checkpoints.save_checkpoint(ckpt_dir, target=restored_state, step=None)
    return state, restored_state


# 6. Training

In [10]:
# Start a wandb run; values assigned on `config` below are attached to this run.
wandb.init(project='MLP_PFGM')

# Hyperparameters, read back through `config` in the cells below.
config = wandb.config
config.jax_seed = 42                      # seed for jax.random.PRNGKey in model/state init
config.batch_size = 256
config.learning_rate = 1e-4               # Adam step size
config.epochs = 12
config.hidden_dims = [1570, 3140, 1570]   # widths of the MLP hidden layers
config.output_dim = 785                   # output width; matches the (1, 785) input shape used below

In [11]:
# Instantiate the model with random weights
mlp = MLP(hidden_dims=config.hidden_dims,
          output_dim=config.output_dim)

print(mlp.tabulate(jax.random.PRNGKey(config.jax_seed), jnp.ones((1, 785))))


[3m                                  MLP Summary                                   [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs         [0m[1m [0m┃[1m [0m[1moutputs        [0m[1m [0m┃[1m [0m[1mparams               [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ MLP    │ [2mfloat32[0m[1,785]  │ [2mfloat32[0m[1,785]  │                       │
├─────────┼────────┼─────────────────┼─────────────────┼───────────────────────┤
│ Dense_0 │ Dense  │ [2mfloat32[0m[1,785]  │ [2mfloat32[0m[1,1570] │ bias: [2mfloat32[0m[1570]   │
│         │        │                 │                 │ kernel:               │
│         │        │                 │                 │ [2mfloat32[0m[785,1570]     │
│         │        │                 │                 │                       │
│         │  

In [12]:
# Instantiate the initial train_state object
state = init_train_state(model=mlp, 
                         random_key=jax.random.PRNGKey(config.jax_seed), 
                         shape=(1, 785), 
                         learning_rate=config.learning_rate)


In [None]:
# The JAX key is only used to seed a NumPy Generator, which train_and_evaluate
# passes to mkds.empirical_field when perturbing each batch.
prng = jax.random.PRNGKey(21)
perturb_rng = default_rng(seed=np.asarray(prng))
state, best_state = train_and_evaluate(config.batch_size,
                                       state,
                                       config.epochs,
                                       ckpt_dir='saved_models/',
                                       prng=perturb_rng)

  0%|          | 0/12 [00:00<?, ?it/s]

TRAIN (1/12): Loss: 13.0952, r2: -12.11
Val (1/12): Loss: 1.0020, r2: -0.00
TRAIN (2/12): Loss: 1.0008, r2: -0.00
Val (2/12): Loss: 1.0001, r2: -0.00
TRAIN (3/12): Loss: 0.9996, r2: -0.00
Val (3/12): Loss: 0.9995, r2: -0.00
TRAIN (4/12): Loss: 0.9991, r2: -0.00
Val (4/12): Loss: 0.9992, r2: -0.00
TRAIN (5/12): Loss: 0.9988, r2: -0.00
Val (5/12): Loss: 0.9990, r2: -0.00
TRAIN (6/12): Loss: 0.9987, r2: -0.00
Val (6/12): Loss: 0.9990, r2: -0.00


In [None]:
# Reload the state written to saved_models/ by train_and_evaluate and check
# that its parameters match the best_state returned by the training call.
restored_state = checkpoints.restore_checkpoint(ckpt_dir='saved_models/', target=state)

# jax.tree_map is deprecated; jax.tree_util.tree_map is the supported spelling.
params_equal = jax.tree_util.tree_map(lambda x, y: (x == y).all(),
                                      best_state.params, restored_state.params)
assert jax.tree_util.tree_all(params_equal)

In [None]:
# Load the perturbed datasets
partitioned_training = mkds.load_data(data_dir='saved_data/MNIST/raw/partitioned',
                                     data_file='partitioned_training_set.pkl')

In [None]:
# Use names that do not rebind `data`, which is the `torch.utils.data`
# module alias imported at the top of the notebook.
train_inputs, train_targets = partitioned_training[0], partitioned_training[1]

train_inputs.shape

In [None]:
# Load the perturbed datasets
perturbed_training = mkds.load_data(data_dir='saved_data/MNIST/raw/partitioned',
                                     data_file='partitioned_training_set.pkl')

perturbed_val = mkds.load_data(data_dir='saved_data/MNIST/raw/partitioned',
                                     data_file='partitioned_val_set.pkl')

perturbed_test = mkds.load_data(data_dir='saved_data/MNIST/raw/partitioned',
                                     data_file='partitioned_test_set.pkl')


# Put them into Pytorch Dataset Objects
training = PerturbMNIST(perturbed_training[0], perturbed_training[1])
val = PerturbMNIST(perturbed_val[0], perturbed_val[1])
testing = PerturbMNIST(perturbed_test[0], perturbed_test[1])

# Instantiate the pytorch data loaders
train_dl, val_dl, test_dl = create_data_loaders(training, val, testing,
                                                train=[True, False, False],
                                                batch_size=128)

test_batch = next(iter(test_dl))
data_batch = test_batch[0]
target_batch = test_batch[1]

pred_batch = restored_state.apply_fn({'params': restored_state.params}, data_batch)

In [None]:
# Compare the empirical ("true") Poisson field with the model's prediction for
# a single pixel coordinate. The x-axis is that pixel's perturbed value, the
# y-axis is the extra z coordinate (last column), and the arrows are the field
# components along those two coordinates.
PIXEL_IDX = 2
panels = [(target_batch, 'True Poisson field'),
          (pred_batch, 'Predicted Poisson field')]

fig, ax = plt.subplots(1, 2, figsize=(10, 8))
for axis, (field, title) in zip(ax, panels):
    axis.quiver(data_batch[:, PIXEL_IDX], data_batch[:, -1],
                field[:, PIXEL_IDX], field[:, -1])
    axis.set_title(title)
    axis.set_xlabel('Perturbed value')
    axis.set_ylabel('Z value')
plt.show()