# Physically Recurrent Neural Networks - Demo notebook

This notebook provides a minimal demonstration of training a PRNN.
It is assumed that the reader is familiar with the torch-based implementation.

This JAX-based code should train more than 10x faster than the torch-based implementation, as demonstrated by the first example code block, which closely mirrors the corresponding example in the torch-based notebook.

The final two code blocks produce learning curves that compare two decoder layers (`SoftLayer` and `SparseNormLayer`).

## Load packages, get datasets

In [None]:
%load_ext autoreload
%autoreload 2
# %matplotlib widget

import os
import random
import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu')  # force jax to use cpu; often faster.
import time
import zipfile

import matplotlib.pyplot as plt

from urllib.request import urlretrieve

from trainer import Trainer
from utils import StressStrainDataset

from prnn import create_prnn_model

# Download and unzip the same dataset as for the torch-based prnn.
# Skipped entirely when a local 'datasets' directory already exists.
if not os.path.isdir('datasets'):
    print('Downloading and unzipping datasets...')
    urlretrieve('https://surfdrive.surf.nl/files/index.php/s/OcSDq0zNqkVbvIO/download', 'datasets.zip')
    # Context manager guarantees the archive handle is closed after extraction
    with zipfile.ZipFile('datasets.zip') as zip_file:
        zip_file.extractall('.')

## Example code

### Train a PRNN model on 18 simple paths for 20 epochs, to compare the computational speed to the torch-based approach

In [None]:
# Hyperparameters for the timing demo. The same dict is forwarded (**settings) to both
# Trainer(...) and Trainer.train(...), so it mixes model, data and optimizer options.
settings = {
    'data_path': 'datasets/canonical.data',
    'seq_length': 60,
    'train_batch_size': 3,
    'decoder_type': 'SoftLayer',

    'input_norm': False,
    'output_norm': False,

    'mat_points': 2,
    'max_epochs': 20,
    'learning_rate': 1e-1,
    'patience': 50,                     # early stopping epochs
    'interval': 1,                      # interval for which we compute validation loss (once every x epochs)

    'feature_dim': 3,
    'verbose': True,
    'seed': 42
}

# Seed every source of randomness: JAX (explicit key), NumPy and Python's `random`
key = jax.random.PRNGKey(settings['seed'])
np.random.seed(settings['seed'])  # for NumPy-based shuffling, if used
random.seed(settings['seed'])     # for Python `random`, if used

# The two index lists presumably select the input (0-2) and output (3-5) columns
# of the data file -- verify against StressStrainDataset.
dataset = StressStrainDataset(settings['data_path'], [0, 1, 2], [3, 4, 5], seq_length=settings['seq_length'])
all_samples = dataset.get_all_batches()

# The PRNN parameters and material properties are passed around explicitly
prnn, params, material = create_prnn_model(random_key=key, n_matpts=settings['mat_points'], decoder_type=settings['decoder_type'])

train_handler = Trainer(prnn, params, material=material, random_key=key, **settings)

# Demonstration only: all samples serve as both the training and the validation set,
# because this cell measures speed, not generalization. Since patience (50) exceeds
# max_epochs (20), early stopping presumably cannot trigger, so the full epoch budget runs.
# NOTE(review): the measured time presumably includes JIT compilation -- confirm in Trainer.
start_time = time.perf_counter()  # monotonic, high-resolution clock for durations
train_handler.train(all_samples, all_samples, **settings)
total_time = time.perf_counter() - start_time
print(f"Training for {settings['max_epochs']} epochs: {total_time:.2f} seconds")


### Train thirty PRNNs from scratch and plot a learning curve
This will take approximately 10-30 minutes.

In [None]:
# Base hyperparameters for the learning-curve experiment; every run copies and adapts them.
base_settings = {
    'data_path': 'datasets/gpCurves.data',
    'seq_length': 60,
    'train_samples': 80,   # placeholder: overwritten below with the current number of curves
    'train_batch_size': 4,
    'decoder_type': 'SoftLayer',

    'input_norm': False,   # Note: keep false. Normalization has not yet been consistently implemented for computing losses.
    'output_norm': False,

    'mat_points': 2,

    'max_epochs': 10000,
    'learning_rate': 1e-2,              # Note that setting a constant lr is not ideal
    'patience': 50,                     # early stopping epochs
    'interval': 1,                      # interval for which we compute validation loss (once every x epochs)

    'feature_dim': 3,
    'verbose': False,
    'seed': 42
}

# Training-set sizes (number of curves) on the learning curve, and repeats per size
ncurves = [1, 2, 3, 4, 5, 6, 8, 10, 12, 16]
nmodels = 3

# Fixed split by sample index: [0, 40) training pool, [40, 70) validation, [70, end) test
N_TRAIN_POOL = 40
N_VAL_END = 70

dataset = StressStrainDataset(base_settings['data_path'], [0, 1, 2], [3, 4, 5], seq_length=base_settings['seq_length'])
all_samples = dataset.get_all_batches()
num_samples = len(dataset)
all_indices = np.arange(num_samples)
train_indices = all_indices[:N_TRAIN_POOL]
val_indices = all_indices[N_TRAIN_POOL:N_VAL_END]
test_indices = all_indices[N_VAL_END:]

# Fail early instead of silently evaluating on an empty test set or training on fewer curves
assert len(test_indices) > 0, f"dataset has only {num_samples} samples; need more than {N_VAL_END} for a test set"
assert max(ncurves) <= len(train_indices), "cannot draw more training curves than the training pool holds"

val_dataset = all_samples[val_indices]
test_dataset = all_samples[test_indices]

seed = base_settings['seed']
test_losses = np.zeros((len(ncurves), nmodels))  # rows: ncurves entries, columns: repeats

for run_i in range(nmodels):
    for curve_i, ncurve in enumerate(ncurves):
        print(f"\n--- Running with {ncurve} training samples, run {run_i + 1}/{nmodels} ---")
        # Fresh seed per (run, ncurve) pair so every model gets its own key and subset
        seed += 1
        key = jax.random.PRNGKey(seed)
        np.random.seed(seed)  # drives the training-subset shuffle below
        random.seed(seed)     # for Python `random`, if used

        # Copy and modify settings for this run
        settings = base_settings.copy()
        settings['train_samples'] = ncurve
        settings['train_batch_size'] = min(max(1, ncurve // 4), 4)  # ncurve // 4, clamped to [1, 4]
        settings['seed'] = seed

        # Random training subset of size `ncurve`, drawn from the training pool
        cur_indices = train_indices.copy()
        np.random.shuffle(cur_indices)
        cur_indices = cur_indices[:settings['train_samples']]
        train_dataset = all_samples[cur_indices]

        # The PRNN parameters and material properties are passed around explicitly
        prnn, params, material = create_prnn_model(random_key=key, n_matpts=settings['mat_points'], decoder_type=settings['decoder_type'])
        train_handler = Trainer(prnn, params, material=material, random_key=key, **settings)

        train_handler.train(train_dataset, val_dataset, **settings)
        print(f"Model trained for {train_handler._epoch} epochs.")

        # Test loss (note: L2 norm here; the torch-based notebook used the L1 norm)
        test_losses[curve_i, run_i] = train_handler._eval_step_jit(train_handler._state, test_dataset, material)

# One column of points per training-set size (nmodels repeats each)
ncurves_arr = np.repeat(np.array(ncurves)[:, np.newaxis], nmodels, axis=1)
fig, ax = plt.subplots()
ax.scatter(ncurves_arr, test_losses)
ax.set(xlabel='n curves', ylabel='L2 Loss', title=f"Learning curve ({base_settings['decoder_type']} decoder)")
plt.show()


### Repeat the learning curve experiment with a custom sparse decoder layer
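With a custom decoder layer that connects component-wise and whose weights sum to one, we expect to need even fewer training samples for good performance, particularly when using more material points. Note that this comparison keeps `mat_points = 2`, the same as the base experiment, so that only the decoder differs.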

In [None]:
# Identical experiment to the previous cell, only with the SparseNormLayer decoder
new_base_settings = base_settings.copy()
new_base_settings['decoder_type'] = 'SparseNormLayer'

# Restart from the same base seed as the previous experiment: each (run, ncurve) pair then
# gets the same PRNG key and the same shuffled training subset, giving a paired comparison.
seed = new_base_settings['seed']
test_losses_sparse_norm = np.zeros((len(ncurves), nmodels))  # rows: ncurves entries, columns: repeats

for run_i in range(nmodels):
    for curve_i, ncurve in enumerate(ncurves):
        print(f"\n--- Running with {ncurve} training samples, run {run_i + 1}/{nmodels} ---")
        # Setup keys and seeds (same sequence as the base experiment)
        seed += 1
        key = jax.random.PRNGKey(seed)
        np.random.seed(seed)  # drives the training-subset shuffle below
        random.seed(seed)     # for Python `random`, if used

        # Copy and modify settings for this run
        settings = new_base_settings.copy()
        settings['train_samples'] = ncurve
        settings['train_batch_size'] = min(max(1, ncurve // 4), 4)  # ncurve // 4, clamped to [1, 4]
        settings['seed'] = seed

        # Random training subset of size `ncurve`, drawn from the training pool
        cur_indices = train_indices.copy()
        np.random.shuffle(cur_indices)
        cur_indices = cur_indices[:settings['train_samples']]
        train_dataset = all_samples[cur_indices]

        # Create and train the model
        prnn, params, material = create_prnn_model(random_key=key, n_matpts=settings['mat_points'], decoder_type=settings['decoder_type'])
        train_handler = Trainer(prnn, params, material=material, random_key=key, **settings)
        train_handler.train(train_dataset, val_dataset, **settings)
        print(f"Model trained for {train_handler._epoch} epochs.")

        # Test loss (L2 norm, same metric as the base experiment)
        test_losses_sparse_norm[curve_i, run_i] = train_handler._eval_step_jit(train_handler._state, test_dataset, material)

# Overlay both learning curves; one column of points per training-set size
ncurves_arr = np.repeat(np.array(ncurves)[:, np.newaxis], nmodels, axis=1)
fig, ax = plt.subplots()
ax.scatter(ncurves_arr, test_losses, label='Base')
ax.scatter(ncurves_arr, test_losses_sparse_norm, label='New decoder')
ax.set(xlabel='n curves', ylabel='L2 Loss',
       title=f"Learning curves: {base_settings['decoder_type']} vs {new_base_settings['decoder_type']}")
ax.legend()  # without this, the `label=` arguments above are never shown
plt.show()