In [2]:
%%writefile src/data_setup.py
"""
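Contains functionality for creating training and testing data loaders
from NumPy (.npy) data and label files.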
"""
import os
import numpy as onp
import jax
import jax.numpy as jnp


class _ArrayBatchLoader:
    """Finite, re-iterable batch iterator over in-memory arrays.

    Every pass over the loader yields ``(data_batch, label_batch)`` tuples that
    together cover each sample exactly once. ``len(loader)`` is the number of
    batches produced by one pass, so the loader can be used both in
    ``for ... in loader`` loops and with ``len()``.
    """

    def __init__(self, data, labels, batch_size, shuffle, seed=0):
        self._data = data
        self._labels = labels
        self._batch_size = batch_size
        self._shuffle = shuffle
        # One RandomState per loader: the first pass is reproducible and every
        # later pass draws a fresh permutation from the same stream.
        self._rng = onp.random.RandomState(seed)

    def __len__(self):
        # Ceiling division so that a trailing partial batch is counted.
        return -(-self._data.shape[0] // self._batch_size)

    def __iter__(self):
        num_samples = self._data.shape[0]
        order = self._rng.permutation(num_samples) if self._shuffle else None
        for start in range(0, num_samples, self._batch_size):
            stop = start + self._batch_size
            # Plain slices avoid an index-array copy when no shuffling is needed.
            batch_idx = slice(start, stop) if order is None else order[start:stop]
            yield self._data[batch_idx], self._labels[batch_idx]


def _load_one_hot(labels_path, num_samples, num_classes):
    """Loads integer class labels from ``labels_path`` and one-hot encodes them."""
    # Flatten so that both (N,) and (N, 1) label layouts give an (N, num_classes)
    # result, including when N == 1.
    labels = onp.load(labels_path).reshape(-1)
    if labels.shape[0] != num_samples:
        raise ValueError(
            f"{labels_path} holds {labels.shape[0]} labels but the data has {num_samples} samples"
        )
    return jax.nn.one_hot(labels, num_classes=num_classes)


def create_dataloaders(
    train_dir: str, 
    train_labels_dir: str, 
    test_dir: str, 
    test_labels_dir: str, 
    batch_size: int, 
    num_classes: int = 10,
    ):
    """Creates training and testing data loaders from ``.npy`` files.

    Args:
    train_dir: Path to the ``.npy`` file holding the training samples.
    train_labels_dir: Path to the ``.npy`` file holding the training labels.
    test_dir: Path to the ``.npy`` file holding the testing samples.
    test_labels_dir: Path to the ``.npy`` file holding the testing labels.
    batch_size: Number of samples per training batch (the last batch of a
        pass may be smaller).
    num_classes: Number of classes used for the one-hot label encoding.

    Returns:
    A tuple ``(train_dataloader, test_dataloader, train_labels)``.
    ``train_dataloader`` yields shuffled ``(samples, one_hot_labels)`` batches
    and reshuffles on every pass. ``test_dataloader`` yields the whole test
    set as a single batch. ``train_labels`` is the one-hot encoded training
    label array. Both loaders are finite, re-iterable and support ``len()``.

    Raises:
    ValueError: If ``batch_size`` is not positive, or a label file does not
        contain one label per sample.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Load the sample arrays
    train_data = onp.load(train_dir)
    test_data = onp.load(test_dir)

    # Load the labels and one-hot encode them
    train_labels = _load_one_hot(train_labels_dir, train_data.shape[0], num_classes)
    test_labels = _load_one_hot(test_labels_dir, test_data.shape[0], num_classes)

    # Turn data into data loaders
    train_dataloader = _ArrayBatchLoader(
        train_data, train_labels, batch_size=batch_size, shuffle=True
    )
    # The test set is evaluated in one batch; max(..., 1) keeps the loader
    # well-defined (zero batches) when the test set is empty.
    test_dataloader = _ArrayBatchLoader(
        test_data, test_labels, batch_size=max(test_data.shape[0], 1), shuffle=False
    )

    return train_dataloader, test_dataloader, train_labels

Overwriting src/data_setup.py


In [3]:
# !pip install -U jax jaxlib
# import jax

In [4]:
%%writefile src/model_builder.py
"""
Contains JAX (stax) model code to instantiate an MLP model.
"""
import jax
from jax import numpy as np
from jax import random

import neural_tangents as nt
from neural_tangents import stax


def MLP_stax(input_shape, hidden_units, output_shape, batch_size):
    """Builds a one-hidden-layer ReLU MLP with ``stax``.

    Args:
        input_shape: Shape of a single input sample without the batch
            dimension, either a tuple such as ``(40,)`` or a bare integer
            feature count.
        hidden_units: Width of the hidden dense layer.
        output_shape: Width of the output dense layer.
        batch_size: Batch size used for the shape sanity check below.

    Returns:
        The ``(nn_init, nn_apply)`` pair produced by ``stax.serial``. No
        parameters are returned; call ``nn_init(rng, in_shape)`` to create them.

    Raises:
        AssertionError: If the network's output shape is not
            ``(batch_size, output_shape)``.
    """
    nn_init, nn_apply, _ = stax.serial(
        stax.Dense(hidden_units),
        stax.Relu(),
        stax.Dense(output_shape),
    )

    # Accept a bare feature count as well as a shape tuple so that the batch
    # dimension can always be prepended.
    feature_shape = (input_shape,) if isinstance(input_shape, int) else tuple(input_shape)
    in_shape = (batch_size,) + feature_shape

    # Dry-run initialisation: only the reported output shape is used here, the
    # throw-away parameters are discarded.
    rng = random.PRNGKey(0)
    out_shape, _ = nn_init(rng, in_shape)

    assert tuple(out_shape) == (batch_size, output_shape), f"Output shape is {out_shape}, but should be {(batch_size, output_shape)}"

    return nn_init, nn_apply

Overwriting src/model_builder.py


In [5]:
%%writefile src/jax_extras.py

@jit
def train_step(step, opt_state,  batch_data, loss_fn):
    """Implements train step.
    
    Args:
        step: Integer representing the step index
        opt_state: Current state of the optimizer
        batch_data: A batch of data (images and labels)
    Returns:
        Batch loss, batch accuracy, updated optimizer state
    """
    params = get_params(opt_state)
    batch_loss, batch_gradients = value_and_grad(loss_fn)(params, batch_data)
    batch_accuracy = calculate_accuracy(params, batch_data)
    return batch_loss, batch_accuracy, opt_update(step, batch_gradients, opt_state)

@jit
def test_step(opt_state, batch_data):
    """Implements train step.

    Args:
        opt_state: Current state of the optimizer
        batch_data: A batch of data (images and labels)
    Returns:
        Batch loss, batch accuracy
    """
    params = get_params(opt_state)
    batch_loss = loss_fn(params, batch_data)
    batch_accuracy = calculate_accuracy(params, batch_data)
    return batch_loss, batch_accuracy


def calculate_accuracy(params, batch_data):
    """Implements accuracy metric.
    
    Args:
        params: Parameters of the network
        batch_data: A batch of data (images and labels)
    Returns:
        Accuracy for the current batch
    """
    inputs, targets = batch_data
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(nn_apply(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)

def cross_loss_fn(params, batch_data):
    """Implements cross-entropy loss function.
    
    Args:
        params: Parameters of the network
        batch_data: A batch of data (images and labels)
    Returns:
        Loss calculated for the current batch
    """
    inputs, targets = batch_data
    preds = nn_apply(params, inputs)
    return -jnp.mean(jnp.sum(log_softmax(preds) * targets, axis=1))


Overwriting src/jax_extras.py


In [6]:
%%writefile src/engine.py
"""
Contains functions for training and testing a JAX model.
"""

from tqdm.auto import tqdm
from typing import Dict, List, Tuple


def _batch_accuracy(model, params, batch):
  """Fraction of samples whose arg-max prediction equals the arg-max target (axis 1)."""
  import jax.numpy as jnp

  inputs, targets = batch
  predicted = jnp.argmax(model(params, inputs), axis=1)
  return jnp.mean(predicted == jnp.argmax(targets, axis=1))


def _make_train_step(model, loss_fn, optimizer):
  """Builds a jitted function performing one optimizer update on one batch."""
  import jax

  _, opt_update, get_params = optimizer

  @jax.jit
  def train_step(step, opt_state, batch):
    params = get_params(opt_state)
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Accuracy is measured with the parameters from before the update.
    acc = _batch_accuracy(model, params, batch)
    return loss, acc, opt_update(step, grads, opt_state)

  return train_step


def _make_eval_step(model, loss_fn):
  """Builds a jitted function computing loss and accuracy on one batch."""
  import jax

  @jax.jit
  def eval_step(params, batch):
    return loss_fn(params, batch), _batch_accuracy(model, params, batch)

  return eval_step


def _average(total, num_batches):
  """Per-batch average; an empty loader is reported instead of dividing by zero."""
  if num_batches == 0:
    raise ValueError("dataloader yielded no batches")
  return total / num_batches


def _run_train_epoch(train_step, dataloader, opt_state, first_step=0):
  """Runs one pass over ``dataloader``; returns (loss, acc, opt_state, num_batches)."""
  total_loss, total_acc, num_batches = 0.0, 0.0, 0
  for batch in dataloader:
    loss, acc, opt_state = train_step(first_step + num_batches, opt_state, batch)
    # float() turns the device scalars into plain Python numbers.
    total_loss += float(loss)
    total_acc += float(acc)
    num_batches += 1
  return (_average(total_loss, num_batches), _average(total_acc, num_batches),
          opt_state, num_batches)


def _run_eval_epoch(eval_step, dataloader, params):
  """Runs one evaluation pass over ``dataloader``; returns (loss, acc)."""
  total_loss, total_acc, num_batches = 0.0, 0.0, 0
  for batch in dataloader:
    loss, acc = eval_step(params, batch)
    total_loss += float(loss)
    total_acc += float(acc)
    num_batches += 1
  return _average(total_loss, num_batches), _average(total_acc, num_batches)


def JAX_train_step(model: callable, # network forward function: model(params, inputs)
               loss_fn: callable, 
               optimizer: Tuple,
               dataloader,
               opt_state) -> Tuple:
  """Trains a JAX model for a single epoch.

  Every batch produced by one pass over ``dataloader`` is used for one
  optimizer update. Only the averaged metrics are returned; the updated
  optimizer state is discarded, so use ``train`` when the learned parameters
  are needed.

  Args:
    model: Forward function called as ``model(params, inputs)``.
    loss_fn: Loss called as ``loss_fn(params, (inputs, targets))``.
    optimizer: The ``(opt_init, opt_update, get_params)`` triple.
    dataloader: Finite iterable yielding ``(inputs, targets)`` batches.
    opt_state: Optimizer state to start the epoch from.

  Returns:
    A tuple ``(train_loss, train_acc)`` of per-batch averages as floats.

  Raises:
    ValueError: If ``dataloader`` yields no batches.
  """
  train_step = _make_train_step(model, loss_fn, optimizer)
  train_loss, train_acc, _, _ = _run_train_epoch(train_step, dataloader, opt_state)
  return train_loss, train_acc


def JAX_test_step(model: callable, # network forward function: model(params, inputs)
              loss_fn: callable, 
              dataloader,
              params=None) -> Tuple:
  """Evaluates a JAX model on every batch of ``dataloader``.

  Args:
    model: Forward function called as ``model(params, inputs)``.
    loss_fn: Loss called as ``loss_fn(params, (inputs, targets))``.
    dataloader: Finite iterable yielding ``(inputs, targets)`` batches.
    params: Network parameters to evaluate. Required; the ``None`` default
      only keeps the positional signature unchanged.

  Returns:
    A tuple ``(test_loss, test_acc)`` of per-batch averages as floats.

  Raises:
    ValueError: If ``params`` is None or ``dataloader`` yields no batches.
  """
  if params is None:
    raise ValueError("JAX_test_step needs the network parameters: pass params=...")
  eval_step = _make_eval_step(model, loss_fn)
  return _run_eval_epoch(eval_step, dataloader, params)


def train(model: callable, 
          optimizer: Tuple,
          loss_fn: callable,
          epochs: int,
          train_dataloader,
          test_dataloader,
          params=None
          ) -> Dict:
  """Trains and evaluates a JAX model for a number of epochs.

  The optimizer state is created once from ``params`` and carried across
  epochs; after each training pass the current parameters are evaluated on
  ``test_dataloader``.

  Args:
    model: Forward function called as ``model(params, inputs)``.
    optimizer: The ``(opt_init, opt_update, get_params)`` triple.
    loss_fn: Loss called as ``loss_fn(params, (inputs, targets))``.
    epochs: Number of passes over ``train_dataloader``.
    train_dataloader: Re-iterable, finite iterable of ``(inputs, targets)``.
    test_dataloader: Re-iterable, finite iterable of ``(inputs, targets)``.
    params: Initial network parameters. Required; the ``None`` default only
      keeps the positional signature unchanged.

  Returns:
    A dictionary with the per-epoch lists ``"train_loss"``, ``"train_acc"``,
    ``"test_loss"`` and ``"test_acc"``, plus ``"params"`` holding the
    parameters after the last epoch.

  Raises:
    ValueError: If ``params`` is None or a dataloader yields no batches.
  """
  if params is None:
    raise ValueError("train needs initial network parameters: pass params=...")

  opt_init, _, get_params = optimizer
  opt_state = opt_init(params)

  # Build the jitted steps once so they are compiled once, not once per epoch.
  train_step = _make_train_step(model, loss_fn, optimizer)
  eval_step = _make_eval_step(model, loss_fn)

  # Create empty results dictionary
  results = {"train_loss": [],
      "train_acc": [],
      "test_loss": [],
      "test_acc": []
  }

  # Global update counter, passed to the optimizer as its step index.
  step = 0

  # Loop through training and testing steps for a number of epochs
  for epoch in tqdm(range(epochs)):
      train_loss, train_acc, opt_state, num_batches = _run_train_epoch(
          train_step, train_dataloader, opt_state, first_step=step)
      step += num_batches
      test_loss, test_acc = _run_eval_epoch(
          eval_step, test_dataloader, get_params(opt_state))

      # Print out what's happening
      print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
      )

      # Update results dictionary
      results["train_loss"].append(train_loss)
      results["train_acc"].append(train_acc)
      results["test_loss"].append(test_loss)
      results["test_acc"].append(test_acc)

  # Expose the trained parameters so the caller can save or reuse them.
  results["params"] = get_params(opt_state)

  # Return the filled results at the end of the epochs
  return results

Overwriting src/engine.py


In [10]:
%%writefile src/utils.py
"""
Contains various utility functions for JAX model training and saving.
"""
from flax import serialization
from pathlib import Path
import os

def save_model(model_params,
               target_dir: str,
               model_name: str):
    """Saves JAX model parameters to a file in a target directory.

        The parameters are serialized with ``flax.serialization.to_bytes``
        and written to ``target_dir / model_name``. The target directory is
        created if it does not exist.

        Args:
        model_params: JAX model parameters to save.
        target_dir: A directory for saving the model to.
        model_name: A filename for the saved model. Should include
            either ".pkl" or ".params" as the file extension.

        Raises:
        AssertionError: If ``model_name`` has neither of the two extensions.

        Example usage:
        save_model(model_params, "models", "model_name.pkl")
    """
    # Create target directory
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True,
                            exist_ok=True)

    # Create model save path
    assert model_name.endswith(".pkl") or model_name.endswith(".params"), "model_name should end with '.pkl' or '.params'"
    model_save_path = target_dir_path / model_name

    # Save the serialized parameters. to_bytes returns the encoded bytes, so
    # they must be written to the file explicitly.
    print(f"[INFO] Saving model to: {model_save_path}")
    with open(model_save_path, "wb") as f:
        f.write(serialization.to_bytes(model_params))



Overwriting src/utils.py


In [9]:
%%writefile src/train.py
"""
Trains a JAX (stax) MLP classifier on datasets stored as NumPy .npy files.
"""

import os

from jax import random
from jax.example_libraries import optimizers

# Sibling modules are imported flat, i.e. this script is expected to run with
# its own directory on sys.path.
import data_setup, engine, jax_extras, model_builder, utils

# Setup hyperparameters
NUM_EPOCHS = 5
BATCH_SIZE = 32
HIDDEN_UNITS = 10
LEARNING_RATE = 0.001
MOMENTUM = 0.9
# Shape of one sample without the batch dimension; kept as a tuple so that the
# batch dimension can be prepended with tuple concatenation.
INPUT_SHAPE = (40,)
SEED = 0

# Setup directories
train_dir = "datasets/X_train.npy"
train_labels_dir = "datasets/y_train.npy"
test_dir = "datasets/X_test.npy"
test_labels_dir = "datasets/y_test.npy"

# Create DataLoaders with help from data_setup.py
train_dataloader, test_dataloader, train_labels = data_setup.create_dataloaders(
    train_dir=train_dir,
    train_labels_dir=train_labels_dir,
    test_dir=test_dir,
    test_labels_dir=test_labels_dir,
    batch_size=BATCH_SIZE
)

# The third return value is the one-hot encoded training labels, so the number
# of classes is its last dimension (len() would give the number of samples).
num_classes = train_labels.shape[-1]

# Create model with help from model_builder.py
nn_init, model = model_builder.MLP_stax(
    input_shape=INPUT_SHAPE,
    hidden_units=HIDDEN_UNITS,
    output_shape=num_classes,
    batch_size=BATCH_SIZE
)

# MLP_stax returns only the init/apply functions, so create the initial
# parameters here.
_, params = nn_init(random.PRNGKey(SEED), (BATCH_SIZE,) + INPUT_SHAPE)

# Set loss and optimizer
# cross_loss_fn looks up the network's forward function under the module-level
# name `nn_apply`, so bind the freshly built model there before training.
jax_extras.nn_apply = model
loss_fn = jax_extras.cross_loss_fn
# `optimizer` is the (opt_init, opt_update, get_params) triple.
optimizer = optimizers.momentum(LEARNING_RATE, MOMENTUM)

# Start training with help from engine.py
results = engine.train(model=model,
                       train_dataloader=train_dataloader,
                       test_dataloader=test_dataloader,
                       loss_fn=loss_fn,
                       optimizer=optimizer,
                       epochs=NUM_EPOCHS,
                       params=params
                       )

# Save the trained parameters with help from utils.py
# (save_model requires a ".pkl" or ".params" file extension)
utils.save_model(model_params=results["params"],
                 target_dir="models",
                 model_name="mlp_stax_model.pkl")

Overwriting src/train.py


In [None]:
  # Initialize model
  # Unpacks `optimizer` into its three functions (opt_init, opt_update, get_params).
  # NOTE(review): no cell visible above assigns `optimizer` in the notebook
  # kernel (the earlier cells only write files via %%writefile) -- confirm
  # where it is meant to come from before running this cell.
  opt_init, opt_update, get_params = optimizer