In [1]:
# Basic Imports
from pathlib import Path
import os
import sys
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
import json
import time
import numpy as np
from copy import copy
from glob import glob
from collections import defaultdict
import matplotlib.pyplot as plt

# Changing fonts to be latex typesetting
from matplotlib import rcParams
rcParams['mathtext.fontset'] = 'dejavuserif'
rcParams['font.family'] = 'serif'
from tqdm.auto import tqdm

# JAX/Flax
import jax
import jax.numpy as jnp
from jax import random
# jax.config.update('jax_platform_name', 'cpu')

import flax
from flax import linen as nn
from flax.training import train_state, checkpoints
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)
import optax

# For ODESolver
from scipy import integrate

# PyTorch for Dataloaders
import torch
import torch.utils.data as data
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

# Wandb 
import wandb
wandb.login()
import pprint

# Import created functions
import make_dataset as mkds
import visualization as vis
# import NN_model as nnm
# import observable_data as od


from numpy.random import default_rng
# JAX PRNG key built from the fixed seed 42.
prng = jax.random.PRNGKey(42)
# NumPy Generator seeded with the raw contents of the JAX key, so the NumPy
# random stream is derived from the same fixed seed as the JAX key above.
rng = default_rng(seed=np.asarray(prng))

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmdowicz[0m. Use [1m`wandb login --relogin`[0m to force relogin


# 0. Model Pipeline

In [2]:
# Create the model
class MLP(nn.Module):
    """
    Plain fully-connected network used as a test model for PFGM.

    Because the model is so small, its layers are declared inline with
    @nn.compact instead of in a separate setup() method.

    Attributes:
    -----------
        hidden_dims: Sequence[int]
            Width of each hidden Dense layer, in the order they are applied.
            Every hidden layer is followed by a ReLU.
        output_dim: int
            Width of the final Dense layer (no activation is applied to it).
    """
    hidden_dims: Sequence[int]
    output_dim: int

    @nn.compact
    def __call__(self, x, **kwargs):
        # Extra keyword arguments are accepted but not used by this model.
        hidden = x
        for width in self.hidden_dims:
            hidden = nn.relu(nn.Dense(width)(hidden))
        return nn.Dense(self.output_dim)(hidden)
    
def init_train_state(model: Any,
                     random_key: Any,
                     shape: tuple,
                     learning_rate: float) -> train_state.TrainState:
    """
    Function to initialize the TrainState dataclass, which represents
    the entire training state, including step number, parameters, and
    optimizer state. This is useful because we no longer need to
    initialize the model again and again with new variables, we just
    update the "state" of the model and pass this as inputs to functions.

    Args:
    -----
        model: nn.Module
            The model that we want to train.
        random_key: jax.random.PRNGKey()
            Used to trigger the initialization functions, which generate
            the initial set of parameters that the model will use.
        shape: tuple
            Shape of the batch of data that will be input into the model.
            An all-ones array of this shape is passed through the model once
            to trigger shape inference, which is where the model figures out
            by itself what the correct size the weights should be.
        learning_rate: float
            How large of a step the Adam optimizer should take (e.g. 3e-4).

    Returns:
    --------
        train_state.TrainState:
            A utility class for handling parameter and gradient updates,
            created with the model's apply function, an Adam optimizer and
            the freshly initialized 'params' collection.
    """
    # Initialize the model with a dummy all-ones input of the given shape
    variables = model.init(random_key, jnp.ones(shape))

    # Create the optimizer
    optimizer = optax.adam(learning_rate) # TODO update this to be user defined

    # Create a state; only the 'params' collection is carried in the state
    return train_state.TrainState.create(apply_fn=model.apply,
                                         tx=optimizer,
                                         params=variables['params'])

def compute_metrics(*, pred, labels):
    """
    Build the dictionary of metrics that is logged during training.

    Args:
    -----
        pred:
            Model predictions.
        labels:
            Target values the predictions are compared against.

    Returns:
    --------
        dict with two entries:
            'loss': mean squared error between pred and labels.
            'r2':   coefficient of determination, 1 - SS_res / SS_tot, where
                    SS_tot is taken around the mean of all label entries.
    """
    # Mean squared error
    squared_error = (pred - labels) ** 2
    mse = squared_error.mean()

    # R^2 = 1 - (residual sum of squares / total sum of squares)
    centered_labels = labels - jnp.mean(labels)
    ss_total = jnp.sum(jnp.square(centered_labels))
    ss_residual = jnp.sum(jnp.square(labels - pred))
    r2 = 1 - (ss_residual / ss_total)

    return dict(loss=mse, r2=r2)


def accumulate_metrics(metrics):
    """
    Collapse a list of per-batch metric dicts into one dict of epoch metrics.

    The metrics are first fetched to host memory; then, for every key of the
    first batch's dict, the values of all batches are averaged with np.mean.
    """
    host_metrics = jax.device_get(metrics)
    epoch_metrics = {}
    # The keys of the first batch define which metrics are reported.
    for name in host_metrics[0]:
        per_batch_values = [batch_metrics[name] for batch_metrics in host_metrics]
        epoch_metrics[name] = np.mean(per_batch_values)
    return epoch_metrics

@jax.jit
def train_step(state: train_state.TrainState,
               batch: list):
    """
    Function to run training on one batch of data.

    Args:
    -----
        state: train_state.TrainState
            Current training state (parameters, optimizer state, apply_fn).
        batch: list
            Two-element sequence (inputs, targets); it is unpacked as
            `image, label`.

    Returns:
    --------
        tuple of (updated state after one gradient step, metrics dict as
        produced by compute_metrics for this batch's predictions).
    """
    image, label = batch

    def loss_fn(params: dict):
        """
        Simple MSE loss as described in the PFGM paper.

        Returns the loss together with the predictions (as auxiliary output)
        so the forward pass does not need to be repeated for the metrics.
        """
        pred = state.apply_fn({'params': params}, image)
        loss = ((pred - label) ** 2).mean()
        return loss, pred

    # has_aux=True: loss_fn returns (loss, pred); only the loss is differentiated.
    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, pred), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    # Loss and R^2 are both derived from the pre-update predictions.
    metrics = compute_metrics(pred=pred, labels=label)
    return state, metrics

@jax.jit
def eval_step(state, batch):
    """
    Evaluate the model held in `state` on one (inputs, targets) batch.

    Runs a forward pass with the state's current parameters (no gradient
    update) and returns the metrics dict produced by compute_metrics.
    """
    inputs, targets = batch
    variables = {'params': state.params}
    predictions = state.apply_fn(variables, inputs)
    return compute_metrics(pred=predictions, labels=targets)


def save_checkpoint_wandb(ckpt_path, state, epoch):
    """
    Serialize `state` to `ckpt_path` and upload the file as a W&B artifact.

    The state is converted to a state dict, msgpack-encoded and written to
    disk. The file is then logged as an artifact named after the active run,
    with the aliases "latest" and "epoch_<epoch>".
    """
    with open(ckpt_path, "wb") as ckpt_file:
        serialized_state = msgpack_serialize(to_state_dict(state))
        ckpt_file.write(serialized_state)

    artifact_name = f'{wandb.run.name}-checkpoint'
    checkpoint_artifact = wandb.Artifact(artifact_name, type='dataset')
    checkpoint_artifact.add_file(ckpt_path)

    aliases = ["latest", f"epoch_{epoch}"]
    wandb.log_artifact(checkpoint_artifact, aliases=aliases)
    
def load_checkpoint_wandb(ckpt_file, state):
    """
    Download the run's "latest" checkpoint artifact and restore it into `state`.

    Args:
    -----
        ckpt_file:
            File name of the checkpoint inside the downloaded artifact directory.
        state:
            Template object whose structure is used by from_bytes to rebuild
            the stored state.

    Returns:
    --------
        The result of from_bytes(state, <checkpoint bytes>).
    """
    latest_artifact_name = f'{wandb.run.name}-checkpoint:latest'
    download_dir = wandb.use_artifact(latest_artifact_name).download()

    with open(os.path.join(download_dir, ckpt_file), "rb") as ckpt_handle:
        raw_bytes = ckpt_handle.read()

    return from_bytes(state, raw_bytes)

# 1. Creating perturbed dataset

In [3]:
class MNIST_DS(Dataset):
    """
    Minimal map-style dataset that keeps inputs and targets as given.

    Args:
    -----
        data: np.ndarray
            The perturbed input data.
        targets: np.ndarray
            The empirical field that generated the perturbed data.
    """
    def __init__(self, data: np.ndarray, targets: np.ndarray):
        self.data, self.targets = data, targets

    def __len__(self):
        """
        Number of samples, taken from the length of the targets.
        """
        return len(self.targets)

    def __getitem__(self, idx: int):
        """
        Fetch one (sample, target) pair.

        Args:
        -----
            idx: int
                Position of the requested sample.

        Returns:
        --------
            tuple: (self.data[idx], self.targets[idx]).
        """
        return self.data[idx], self.targets[idx]
    
    
def numpy_collate(batch):
    """
    Collate function that produces NumPy arrays instead of torch tensors.

    - A batch of np.ndarray samples is stacked along a new leading axis.
    - A batch of tuples/lists is transposed and each field is collated
      recursively; the result is a list with one entry per field.
    - Anything else is converted with np.array.

    The first element of the batch decides which case applies.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    # Transpose [(x0, y0), (x1, y1), ...] -> [(x0, x1, ...), (y0, y1, ...)]
    return list(map(numpy_collate, zip(*batch)))

def create_data_loaders(*datasets : Sequence[data.Dataset],
                        train : Union[bool, Sequence[bool]] = True,
                        batch_size : int = 128,
                        num_workers : int = 4,
                        seed : int = 42):
    """
    Creates data loaders used in JAX for a set of datasets.

    Args:
      datasets: Datasets for which data loaders are created.
      train: Sequence indicating which datasets are used for
        training and which not. If single bool, the same value
        is used for all datasets. Training loaders shuffle and
        drop the last incomplete batch.
      batch_size: Batch size to use in the data loaders.
      num_workers: Number of workers for each dataset. With 0 the
        data is loaded in the main process.
      seed: Seed to initialize the workers and shuffling with.

    Returns:
      List with one DataLoader per dataset, in the same order.

    Raises:
      ValueError: If `train` is a list/tuple whose length differs from
        the number of datasets.
    """
    if not isinstance(train, (list, tuple)):
        train = [train for _ in datasets]
    elif len(train) != len(datasets):
        # zip() below would otherwise silently drop the unmatched datasets.
        raise ValueError(
            f"Got {len(datasets)} dataset(s) but {len(train)} train flag(s)."
        )

    # DataLoader rejects persistent_workers=True when there are no worker
    # processes, so only request it when workers actually exist.
    has_workers = num_workers > 0

    loaders = []
    for dataset, is_train in zip(datasets, train):
        loader = torch.utils.data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=is_train,
                                 drop_last=is_train,
                                 collate_fn=numpy_collate,
                                 num_workers=num_workers,
                                 persistent_workers=bool(is_train) and has_workers,
                                 generator=torch.Generator().manual_seed(seed))
        loaders.append(loader)
    return loaders

In [4]:
# Load the MNIST train/test splits through the project helper.
# NOTE(review): download=False presumably means the files must already be on
# disk — confirm in make_dataset.
train, test = mkds.download_MNIST(download=False)
# Convert the training images to a NumPy array and pass them through the
# helper that, judging by its name, adds a channel dimension (the resulting
# shape is not visible from this notebook).
train_data = mkds.reshape_with_channel_dim(train.data.numpy())

# 2. Creating the data for training

In [5]:
def process_perturbed_data(dataset: np.ndarray,
                           prng: Any,
                           chunk_size: int = 1000):
    """
    Function that perturbs the raw MNIST files (ie. train data file/test data file)
    and returns them as a tuple of (perturbed_data, empirical_field).

    The dataset is processed in consecutive chunks of `chunk_size` samples, each
    chunk is passed to mkds.empirical_field, and the per-chunk results are
    concatenated along the first axis. Chunking exists only to bound memory use:
    the default of 1000 was chosen because larger chunks ran out of memory.

    Args:
    -----
        dataset: np.ndarray
            Either the raw MNIST training/testing dataset.
            NOTE: Dataset here refers to the pixel data only, we discard the labels because
                  in PFGM the NN learns the empirical field at each pixel location. One can
                  think of the empirical field at each pixel location as that pixels "label".

        prng:
            Source of randomness, forwarded unchanged to mkds.empirical_field for
            every chunk. The notebook passes a NumPy Generator here.
            NOTE(review): if a stateless JAX key were passed instead, every chunk
            would receive the same key — confirm what empirical_field expects.

        chunk_size: int
            Number of samples handed to mkds.empirical_field at once. Must be > 0.

    Returns:
    --------
        tuple of the perturbed data for each sample and the empirical field for each image.
        An empty dataset yields an empty tuple.

    Raises:
    -------
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    num_samples = dataset.shape[0]

    # One entry per chunk; each entry is the tuple returned by empirical_field
    outputs = []
    for start_idx in range(0, num_samples, chunk_size):
        # Slicing past the end is safe, so the last chunk may simply be shorter
        batch = dataset[start_idx:start_idx + chunk_size]
        outputs.append(mkds.empirical_field(batch, prng))

    # Regroup by output position and concatenate each along the first axis
    return tuple(np.concatenate(o, axis=0) for o in zip(*outputs))

In [6]:
# perturbed_data = process_perturbed_data(train_data, prng=rng)

In [7]:
def partition_MNIST(root_dir: str = 'saved_data/MNIST/perturbed/partitioned/',
                    perturb_on: bool = True,
                    sigma: float = 0.2,
                    tau: float = 0.06,
                    M: int = 291,
                    download: bool = True,
                    validation_frac: float = 1/6):
    """
    Function to partition the perturbed training/test data into training, validation, and test
    datasets. The validation set is a random sample without replacement from the perturbed
    training set; with the default validation_frac of 1/6 and the 60k MNIST training images
    the split is approximately 50k / 10k, plus the 10k perturbed test images.

    Args:
    ------
        root_dir: str
            Path to where the partitioned data should be saved to.
        perturb_on: bool
            Currently unused: the data is always perturbed.
        sigma, tau, M:
            Currently unused by this function; they are not forwarded to the
            perturbation helper.
        download: bool
             If True, save the partitioned datasets to root_dir. Else, simply return the
             partitioned datasets.
        validation_frac: float
             Fraction of the perturbed training samples moved into the validation set.

     Returns:
     --------
         training_set: tuple
             (perturbed_data, empirical_field) for the samples used for training the NN.
         val_set: tuple
             (perturbed_data, empirical_field) for the samples used for validating the NN.
             These samples come from the raw MNIST training set.
         test_set: tuple
             (perturbed_data, empirical_field) for the perturbed version of the raw
             MNIST test set.
    """
    # Instantiate the MNIST data
    train, test = mkds.download_MNIST(download=False)

    # Give a channel dimension to the data
    train_data = mkds.reshape_with_channel_dim(train.data.numpy())
    test_data = mkds.reshape_with_channel_dim(test.data.numpy())

    # Perturb the data (uses the notebook-level seeded NumPy generator `rng`)
    perturbed_train = process_perturbed_data(train_data, prng=rng)
    perturbed_test = process_perturbed_data(test_data, prng=rng)

    # Get the number of samples in the training set
    num_samples = perturbed_train[0].shape[0]

    # Generate a random permutation of the sample indices. The seeded `rng` is
    # used instead of the global NumPy state so the split is reproducible.
    permutation = rng.permutation(num_samples)

    # Calculate the number of samples to use for validation
    num_validation_samples = int(num_samples * validation_frac)

    # Split the permutation into training and validation indices
    validation_indices = permutation[:num_validation_samples]
    training_indices = permutation[num_validation_samples:]

    # Split the training set into training and validation sets
    train_new = (perturbed_train[0][training_indices],
                 perturbed_train[1][training_indices])
    val = (perturbed_train[0][validation_indices],
           perturbed_train[1][validation_indices])

    if download:
        # Save the data
        mkds.save_data(train_new,
                       directory=root_dir,
                       filename='perturbed_training_set.pkl')

        mkds.save_data(val,
                       directory=root_dir,
                       filename='perturbed_val_set.pkl')

        mkds.save_data(perturbed_test,
                       directory=root_dir,
                       filename='perturbed_test_set.pkl')

    # Return the perturbed test set (the same object that is saved above),
    # not the raw MNIST test dataset.
    return train_new, val, perturbed_test

In [8]:
# train, val, test = partition_MNIST(download=True)

In [9]:
def _evaluate_loader(state, loader):
    """Run eval_step on every batch of `loader` and return the averaged metrics."""
    batch_metrics = [eval_step(state, batch) for batch in loader]
    return accumulate_metrics(batch_metrics)


def train_and_evaluate(batchsize, state, epochs, ckpt_dir, prng,
                       data_dir='saved_data/MNIST/perturbed/partitioned'):
    """
    Train on the saved perturbed training set, validate after every epoch,
    checkpoint the best model to W&B and finally evaluate it on the test set.

    Args:
    -----
        batchsize: int
            Batch size used for all three data loaders.
        state: train_state.TrainState
            Initial training state.
        epochs: int
            Number of passes over the training set.
        ckpt_dir: str
            Directory the best state is written to with flax checkpoints.
        prng:
            Currently unused (the data is perturbed ahead of time); kept so
            existing callers keep working.
        data_dir: str
            Directory holding the three partitioned .pkl files.

    Returns:
    --------
        tuple (state, restored_state): the state after the last epoch and the
        state with the lowest validation loss. If no checkpoint was written
        (e.g. epochs == 0), restored_state is the final state.
    """
    # Load the partitioned datasets
    training = mkds.load_data(data_dir=data_dir,
                              data_file='perturbed_training_set.pkl')
    val = mkds.load_data(data_dir=data_dir,
                         data_file='perturbed_val_set.pkl')
    test = mkds.load_data(data_dir=data_dir,
                          data_file='perturbed_test_set.pkl')
    print('data loaded')

    # Put them into Pytorch Dataset Objects
    training = MNIST_DS(training[0], training[1])
    val = MNIST_DS(val[0], val[1])
    testing = MNIST_DS(test[0], test[1])

    # Instantiate the pytorch data loaders
    train_dl = DataLoader(training, batch_size=batchsize, collate_fn=numpy_collate, shuffle=True)
    val_dl = DataLoader(val, batch_size=batchsize, collate_fn=numpy_collate, shuffle=False)
    test_dl = DataLoader(testing, batch_size=batchsize, collate_fn=numpy_collate, shuffle=False)

    # Tracked across ALL epochs so that a checkpoint is only written when the
    # validation loss actually improves; the W&B "latest" alias then always
    # points at the best model seen so far.
    best_val_loss = float('inf')
    checkpoint_saved = False

    for epoch in tqdm(range(1, epochs+1)):
        # =========== Training =========== #
        train_batch_metrics = []
        for batch in train_dl:
            # Do one train step with the pre-perturbed batch
            state, metrics = train_step(state, batch)
            train_batch_metrics.append(metrics)
        train_batch_metrics = accumulate_metrics(train_batch_metrics)
        print(
            'TRAIN (%d/%d): Loss: %.4f, r2: %.2f' % (
                epoch, epochs, train_batch_metrics['loss'],
                train_batch_metrics['r2'])
        )

        # =========== Validation =========== #
        val_batch_metrics = _evaluate_loader(state, val_dl)
        print(
            'Val (%d/%d): Loss: %.4f, r2: %.2f' % (
                epoch, epochs, val_batch_metrics['loss'],
                val_batch_metrics['r2'])
        )
        print()

        wandb.log({
            "Train Loss": train_batch_metrics['loss'],
            "Train r2": train_batch_metrics['r2'],
            "Validation Loss": val_batch_metrics['loss'],
            "Validation r2": val_batch_metrics['r2']
        }, step=epoch)

        if val_batch_metrics['loss'] < best_val_loss:
            best_val_loss = val_batch_metrics['loss']
            save_checkpoint_wandb("checkpoint.msgpack", state, epoch)
            checkpoint_saved = True

    # =========== Testing =========== #
    if checkpoint_saved:
        restored_state = load_checkpoint_wandb("checkpoint.msgpack", state)
    else:
        # Nothing was uploaded, so there is no artifact to download.
        restored_state = state

    # Evaluate the best (restored) model, not the last-epoch state.
    test_batch_metrics = _evaluate_loader(restored_state, test_dl)
    print(
        'Test: Loss: %.4f, r2: %.2f' % (
            test_batch_metrics['loss'],
            test_batch_metrics['r2']
        )
    )

    wandb.log({
        "Test Loss": test_batch_metrics['loss'],
        "Test r2": test_batch_metrics['r2']
    })

    # Save best state
    # NOTE(review): step=None is kept from the original call; confirm how
    # flax names the checkpoint file and handles re-saving in that case.
    checkpoints.save_checkpoint(ckpt_dir, target=restored_state, step=None)
    return state, restored_state


In [10]:
# Start a W&B run in the 'MLP_PFGM' project; the hyperparameters assigned to
# wandb.config below are attached to that run.
wandb.init(project='MLP_PFGM')

config = wandb.config
config.jax_seed = 42  # seed for the JAX PRNG key used when tabulating the model
config.batch_size = 128
config.learning_rate = 3e-4
config.epochs = 50
config.hidden_dims = [1570, 3140, 3140, 1570]  # widths of the hidden Dense layers
# Output width equals the 785-wide dummy input used below.
# NOTE(review): presumably 28*28 pixels + 1 extra PFGM dimension — confirm
# against make_dataset.
config.output_dim = 785

# Build the module definition; parameters are only created when the module is
# initialized/tabulated with a PRNG key and an example input.
mlp = MLP(hidden_dims=config.hidden_dims,
          output_dim=config.output_dim)

# Print a per-layer summary (shapes and parameters) for a single 785-wide input.
print(mlp.tabulate(jax.random.PRNGKey(config.jax_seed), jnp.ones((1, 785))))


[3m                                  MLP Summary                                   [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs         [0m[1m [0m┃[1m [0m[1moutputs        [0m[1m [0m┃[1m [0m[1mparams               [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ MLP    │ [2mfloat32[0m[1,785]  │ [2mfloat32[0m[1,785]  │                       │
├─────────┼────────┼─────────────────┼─────────────────┼───────────────────────┤
│ Dense_0 │ Dense  │ [2mfloat32[0m[1,785]  │ [2mfloat32[0m[1,1570] │ bias: [2mfloat32[0m[1570]   │
│         │        │                 │                 │ kernel:               │
│         │        │                 │                 │ [2mfloat32[0m[785,1570]     │
│         │        │                 │                 │                       │
│         │  

In [11]:
# Instantiate the initial train_state object
# Instantiate the initial train_state object, using a (1, 785) dummy input
# for shape inference and the seed/learning rate from the W&B config.
state = init_train_state(model=mlp, 
                         random_key=jax.random.PRNGKey(config.jax_seed), 
                         shape=(1, 785), 
                         learning_rate=config.learning_rate)

# Rebinds the notebook-level `prng` to a new JAX key (seed 21). What is passed
# to train_and_evaluate below is a NumPy Generator seeded from that key;
# train_and_evaluate does not currently use its `prng` argument.
prng = jax.random.PRNGKey(21)
# Returns the last-epoch state and the restored checkpoint state.
state, best_state = train_and_evaluate(config.batch_size, 
                                       state, 
                                       config.epochs, 
                                       ckpt_dir='saved_models/',
                                       prng=default_rng(seed=np.asarray(prng)))


data loaded


  0%|          | 0/50 [00:00<?, ?it/s]

TRAIN (1/50): Loss: 2.9721, r2: -1.98
Val (1/50): Loss: 0.9986, r2: 0.00

TRAIN (2/50): Loss: 0.9985, r2: 0.00
Val (2/50): Loss: 0.9985, r2: 0.00

TRAIN (3/50): Loss: 0.9984, r2: 0.00
Val (3/50): Loss: 0.9983, r2: 0.00

TRAIN (4/50): Loss: 0.9982, r2: 0.00
Val (4/50): Loss: 0.9982, r2: 0.00

TRAIN (5/50): Loss: 0.9981, r2: 0.00
Val (5/50): Loss: 0.9981, r2: 0.00

TRAIN (6/50): Loss: 0.9981, r2: 0.00
Val (6/50): Loss: 0.9981, r2: 0.00

TRAIN (7/50): Loss: 0.9980, r2: 0.00
Val (7/50): Loss: 0.9980, r2: 0.00

TRAIN (8/50): Loss: 0.9980, r2: 0.00
Val (8/50): Loss: 0.9980, r2: 0.00

TRAIN (9/50): Loss: 0.9979, r2: 0.00
Val (9/50): Loss: 0.9980, r2: 0.00

TRAIN (10/50): Loss: 0.9979, r2: 0.00
Val (10/50): Loss: 0.9980, r2: 0.00

TRAIN (11/50): Loss: 0.9979, r2: 0.00
Val (11/50): Loss: 0.9980, r2: 0.00

TRAIN (12/50): Loss: 0.9979, r2: 0.00
Val (12/50): Loss: 0.9980, r2: 0.00

TRAIN (13/50): Loss: 0.9979, r2: 0.00
Val (13/50): Loss: 0.9980, r2: 0.00

TRAIN (14/50): Loss: 0.9979, r2: 0.00
Val 

KeyboardInterrupt: 

In [None]:
# Restore a checkpoint from 'saved_models/' (the directory train_and_evaluate
# was asked to save into), using the current `state` as the structure template.
restored_state = checkpoints.restore_checkpoint(ckpt_dir='saved_models/', target=state)