In [1]:
import numpy as np
from numpy.random import default_rng

import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

import optax 
import make_dataset as mkds
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
from tqdm.auto import tqdm

import os 

import make_dataset as mkds


2023-03-16 13:45:40.250066: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.5/math_libs/11.7/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/extras/CUPTI/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/extras/Debugger/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/nvvm/lib64:/opt/nvidia/hpc_sdk/Linux_x86_64/22.5/cuda/11.7/lib64:/opt/cray/pe/papi/7.0.0.1/lib64:/opt/cray/pe/gcc/11.2.0/snos/lib64:/opt/cray/libfabric/1.15.2.0/lib64
2023-03-16 13:45:40.250148: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/nvidia/hpc_sdk/Linux_x86_64/22.5/math_libs/11.7/lib64:/o

In [None]:
train_dl, val_dl, test_dl = mkds.load_dataloaders(batch_size=128)

# 1. Load the data

# 2. Make training pipeline

In [3]:
from flax.training import train_state, checkpoints
from flax.serialization import (
    to_state_dict, msgpack_serialize, from_bytes
)

## 1. Define Network

In [4]:
class MLP(nn.Module):
    """
    Simple MLP model for testing PFGM.

    Due to it's simplicity we use @nn.compact instead of setup
    """
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(360)(x)
        x = nn.silu(x)
        x = nn.Dense(785)(x)
        return x

In [5]:
# View model layers
mlp = MLP()

print(mlp.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 785))))


[3m                                  MLP Summary                                   [0m
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath   [0m[1m [0m┃[1m [0m[1mmodule[0m[1m [0m┃[1m [0m[1minputs        [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams                 [0m[1m [0m┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         │ MLP    │ [2mfloat32[0m[1,785] │ [2mfloat32[0m[1,785] │                         │
├─────────┼────────┼────────────────┼────────────────┼─────────────────────────┤
│ Dense_0 │ Dense  │ [2mfloat32[0m[1,785] │ [2mfloat32[0m[1,360] │ bias: [2mfloat32[0m[360]      │
│         │        │                │                │ kernel:                 │
│         │        │                │                │ [2mfloat32[0m[785,360]        │
│         │        │                │                │                         │
│         │  

In [6]:
train_dl, val_dl, test_dl = mkds.load_dataloaders(batch_size=128)

params = mlp.init(jax.random.PRNGKey(0), jnp.ones((1, 785)))['params']
data_batch = next(iter(train_dl))
pred = mlp.apply({'params': params}, data_batch[0])
print(pred.shape)

(128, 785)


## 2. Create a `TrainState`

In [7]:
def init_train_state(model: Any,
                     random_key: Any,
                     shape: tuple,
                     learning_rate: int) -> train_state.TrainState:
    """
    Function to initialize the TrainState dataclass, which represents
    the entire training state, including step number, parameters, and 
    optimizer state. This is useful because we no longer need to
    initialize the model again and again with new variables, we just 
    update the "state" of the mdoel and pass this as inputs to functions.

    Args:
    -----
        model: nn.Module    
            The model that we want to train.
        random_key: jax.random.PRNGKey()
            Used to trigger the initialization functions, which generate
            the initial set of parameters that the model will use.
        shape: tuple
            Shape of the batch of data that will be input into the model.
            This is used to trigger shape inference, which is where the model
            figures out by itself what the correct size the weights should be
            when they see the inputs.
        learning_rate: int
            How large of a step the optimizer should take.

    Returns:
    --------
        train_state.TrainState:
            A utility class for handling parameter and gradient updates. 
    """
    # Initialize the model
    variables = model.init(random_key, jnp.ones(shape))

    # Create the optimizer
    optimizer = optax.adam(learning_rate) # TODO update this to be user defined

    # Create a state
    return train_state.TrainState.create(apply_fn=model.apply,
                                         tx=optimizer,
                                         params=variables['params'])

In [8]:
init_rng = jax.random.PRNGKey(0)

learning_rate = 0.01
state = init_train_state(mlp, init_rng, (1, 785), learning_rate)
del init_rng  # Must not be used anymore.

## 3. Training Step

A function that:

- Evaluates the neural network given the parameters and a batch of input images with `TrainState.apply_fn` (which contains the `Module.apply` method (forward pass)).

- Computes the cross entropy loss, using the predefined `optax.l2_loss`. Note that this function expects integer labels, so there is no need to convert labels to onehot encoding.

- Evaluates the gradient of the loss function using jax.grad.

- Applies a pytree of gradients to the optimizer to update the model’s parameters.

Use JAX’s `@jit` decorator to trace the entire `train_step` function and just-in-time compile it with XLA into fused device operations that run faster and more efficiently on hardware accelerators.

In [9]:
import optax

@jax.jit
def train_step(state: train_state.TrainState,
               batch: list):
    """
    Function to run training on one batch of data.
    """
    data, targets = batch

    def loss_fn(params: dict):
        """
        Simple MSE loss as described in the PFGM paper.
        """
        pred = state.apply_fn({'params': params}, data)
        loss = optax.l2_loss(
            predictions=pred, targets=targets).mean()
        return loss, pred

    def r_squared(params):
        """
        Function to calculate the coefficient of determination or 
        R^2, which quantifies how well the regression model fits 
        the observed data. Or more formally, it is a statistical
        measure that represents the proportion of variance in the
        dependent variable that is explained by the independent 
        variable(s) in a regression model. R^2 ranges from 0 to 1, 
        with a higher value indicating a better fit. 

        An R^2 of 0 means that the regression model does not explain
        any of the variability in the dependent variable, while an
        R^2 of 1 indicates that the regression model explains all of
        the variability in the dependent model.
        """
        pred = state.apply_fn({'params': params}, data)
        residual = jnp.sum(jnp.square(targets - pred))
        total = jnp.sum(jnp.square(targets - jnp.mean(targets)))
        r2_score = 1 - (residual / total)
        return r2_score

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, pred), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    metrics = compute_metrics(pred=pred, targets=targets)
    return state, metrics

@jax.jit
def eval_step(state, batch):
    data, targets = batch
    pred = state.apply_fn({'params': state.params}, data)
    return compute_metrics(pred=pred, targets=targets)


## 4. Metric Computation

In [10]:
def compute_metrics(*, pred, targets):
    """
    Function that computes metrics that will be logged
    during training
    """
    # Calculate the MSE loss
    loss = ((pred - targets) ** 2).mean()

    # Calculate the R^2 score
    residual = jnp.sum(jnp.square(targets - pred))
    total = jnp.sum(jnp.square(targets - jnp.mean(targets)))
    r2_score = 1 - (residual / total)

    # Save these metrics into a dict
    metrics = {
        'loss': loss,
        'r2': r2_score
    }

    return metrics

def accumulate_metrics(metrics):
    """
    Function that accumulates all the metrics for each batch and 
    accumulates/calculates the metrics for each epoch.
    """
    metrics = jax.device_get(metrics)
    return {
        k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]
    }

## 5. Initialize the `TrainState`

In [11]:
init_rng = jax.random.PRNGKey(0)

learning_rate = 0.01
state = init_train_state(mlp, init_rng, (1, 785), learning_rate)
del init_rng  # Must not be used anymore.

In [12]:
train_dl, val_dl, test_dl = mkds.load_dataloaders(batch_size=128)

train_batch_metrics = []
for cnt, batch in enumerate(train_dl):
    state, metrics = train_step(state, batch)
    train_batch_metrics.append(metrics)

## 6. Train and Evaluate

In [13]:
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

class PerturbMNIST(Dataset):
    """
    Simple dataset class that stores the data and targets as NumPy arrays.
    
    Args:
    -----
        data: np.ndarray
            The perturbed input data.
        targets: np.ndarray
            The empirical field that generated the perturbed data.
    """
    def __init__(self, data: np.ndarray, targets: np.ndarray):
        self.data = data
        self.targets = targets
        
    def __len__(self):
        """
        Returns the total number of samples in the dataset.
        """
        return len(self.targets)
    
    def __getitem__(self, idx: int):
        """
        Returns the i-th sample and corresponding target in the dataset.
        
        Args:
        -----
            idx: int
                The index of the sample to return.
                
        Returns:
        --------
            tuple: A tuple containing the sample and target.
        """
        sample = self.data[idx]
        target = self.targets[idx]
        return sample, target

In [14]:
perturbed_training = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
                                     data_file='partitioned_training_set.pkl')

perturbed_val = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
                                     data_file='partitioned_val_set.pkl')

perturbed_test = mkds.load_data(data_dir='/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_data/MNIST/perturbed/partitioned',
                                     data_file='partitioned_test_set.pkl')

In [15]:
epochs = 2

training = PerturbMNIST(perturbed_training[0], perturbed_training[1])
train_dl = DataLoader(training, 
                      collate_fn=mkds.numpy_collate,
                      batch_size=128) # create your dataloader

val = PerturbMNIST(perturbed_val[0], perturbed_val[1])
val_dl = DataLoader(val,
                    collate_fn=mkds.numpy_collate,
                    batch_size=128) # create your dataloader

testing = PerturbMNIST(perturbed_test[0], perturbed_test[1])
test_dl = DataLoader(testing, 
                     collate_fn=mkds.numpy_collate,
                     batch_size=128) # create your dataloader

In [22]:
for epoch in tqdm(range(1,epochs+1)):
    ### Training ###
    train_batch_metrics = []
    for cnt, batch in enumerate(train_dl):
        state, metrics = train_step(state, batch)
        train_batch_metrics.append(metrics)
    train_metrics = accumulate_metrics(train_batch_metrics)
    
    print(
        'TRAIN (%d/%d): Loss: %.4f, r2: %.2f' % (
            epoch, epochs, train_metrics['loss'], 
            train_metrics['r2'])
    )   
    
    ### Validation ###
    val_batch_metrics = []
    for cnt, batch in enumerate(val_dl):
        state, metrics = train_step(state, batch)
        val_batch_metrics.append(metrics)
    val_metrics = accumulate_metrics(val_batch_metrics)
    print(
        'Val (%d/%d): Loss: %.4f, r2: %.2f' % (
            epoch, epochs, val_metrics['loss'], 
            val_metrics['r2'])
    )

    ### Testing ###    
    test_batch_metrics = []
    for cnt, batch in enumerate(test_dl):
            metrics = eval_step(state, batch)
            test_batch_metrics.append(metrics)
        
    test_metrics = accumulate_metrics(test_batch_metrics)
    print(
        'Test: Loss: %.4f, r2: %.2f' % (
            test_metrics['loss'],
            test_metrics['r2']
        )
    )
    print()
    


  0%|          | 0/2 [00:00<?, ?it/s]

TRAIN (1/2): Loss: 63102800363520.0000, r2: -63184274718720.00
Val (1/2): Loss: 45787794898944.0000, r2: -45846980722688.00
Test: Loss: 142529786797883392.0000, r2: -142713448189394944.00

TRAIN (2/2): Loss: 48549882494976.0000, r2: -48612545396736.00
Val (2/2): Loss: 59777925251072.0000, r2: -59855154970624.00
Test: Loss: 134506830838628352.0000, r2: -134680184308629504.00



In [24]:
CHECKPOINT_PATH = '/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_models'

In [27]:
checkpoints.save_checkpoint(ckpt_dir=CHECKPOINT_PATH, target=state, step=0)

'/pscratch/sd/m/mdowicz/PFGM_MNIST/saved_models/checkpoint_0'

In [28]:
restored_state = checkpoints.restore_checkpoint(ckpt_dir=CHECKPOINT_PATH, target=state)