In [1]:
import os
import argparse
import numpy as np
from tqdm import tqdm
from typing import Any

from dataloader import load_data
from RotNet import rotnet_constructor
from PredNet import prednet_constructor

import jax
import jax.numpy as jnp

import flax.linen as nn
from flax import traverse_util
from flax.core.frozen_dict import freeze
from flax.training import train_state, checkpoints


import optax

import matplotlib.pyplot as plt

2022-12-07 22:46:20.270441: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-12-07 22:46:20.270496: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
# Shape of the dummy input used to initialise models: one 32x32 image with 3 channels
# (leading axis is the batch; channels-last, as passed to `model.init` below).
CIFAR10_INPUT_SHAPE = tuple([1, 32, 32, 3])

class TrainState(train_state.TrainState):
    """flax TrainState extended with a `batch_stats` field for BatchNorm statistics."""

    # Non-trainable BatchNorm variables; swapped for the updated values after each train step.
    batch_stats: Any

In [3]:
def cross_entropy_loss_(logits, labels, num_classes=10):
    """
    Mean softmax cross-entropy between `logits` and integer class `labels`.

    The integer labels are one-hot encoded to `num_classes` columns before being
    handed to optax; the per-example losses are then averaged.

    Define loss: https://flax.readthedocs.io/en/latest/getting_started.html#define-loss
    """
    one_hot_targets = jax.nn.one_hot(labels, num_classes=num_classes)
    per_example_loss = optax.softmax_cross_entropy(logits=logits, labels=one_hot_targets)
    return per_example_loss.mean()

# `num_classes` (positional index 2) sets the one-hot width, so it has to be static under jit.
cross_entropy_loss = jax.jit(cross_entropy_loss_, static_argnums=2)

def compute_metrics_(logits, labels, num_classes):
    """
    Return the batch loss and top-1 accuracy as a dict with keys "loss" and "accuracy".

    Metric computation: https://flax.readthedocs.io/en/latest/getting_started.html#metric-computation
    """
    predicted_class = jnp.argmax(logits, -1)
    return {
        "loss": cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes),
        "accuracy": jnp.mean(predicted_class == labels),
    }
# `num_classes` is forwarded to the jitted loss as a static value, so keep it static here too.
compute_metrics = jax.jit(compute_metrics_, static_argnums=2)

def train_batch_(state, images, labels, num_classes=10):
    """
    Run one SGD step on a batch and return `(new_state, metrics)`.

    The forward pass runs with `train=True` and `batch_stats` marked mutable; the
    mutated statistics are written back into the returned state alongside the
    gradient update.

    Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    """
    def _objective(params):
        model_variables = {"params": params, "batch_stats": state.batch_stats}
        # With a mutable collection, apply returns (output, updated_collections).
        logits, mutated = state.apply_fn(model_variables, images, mutable=["batch_stats"], train=True)
        batch_loss = cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes)
        return batch_loss, (logits, mutated)

    (_, (logits, mutated)), grads = jax.value_and_grad(_objective, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads).replace(batch_stats=mutated["batch_stats"])
    return new_state, compute_metrics(logits=logits, labels=labels, num_classes=num_classes)
train_batch = jax.jit(train_batch_, static_argnums=3)

def train_epoch(state, dataloader, num_classes=10):
    """
    Train for one pass over `dataloader` and return the updated state and epoch metrics.

    Train function: https://flax.readthedocs.io/en/latest/getting_started.html#train-function

    Args:
        state: TrainState to update, one `train_batch` call per batch.
        dataloader: iterable of `(images, labels)` batches.
        num_classes: number of output classes, forwarded to `train_batch`.

    Returns:
        `(state, metrics)` where `metrics` maps each metric key (e.g. "loss",
        "accuracy") to its epoch mean. The mean is weighted by batch size, so a
        smaller final batch does not skew the epoch value.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    batch_metrics = []
    batch_sizes = []
    for images, labels in dataloader:
        state, metrics = train_batch(state, images, labels, num_classes=num_classes)
        batch_metrics.append(metrics)
        batch_sizes.append(len(labels))
    if not batch_metrics:
        raise ValueError("train_epoch: dataloader yielded no batches")
    # One host transfer for the whole epoch instead of a device sync per step.
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.average([m[k] for m in batch_metrics_np], weights=batch_sizes)
        for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np

def eval_batch_(state, images, labels, num_classes=10):
    """
    Compute the loss/accuracy dict for one batch without updating any variables.

    Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    """
    model_variables = {"params": state.params, "batch_stats": state.batch_stats}
    # mutable=False: no collection is written. train=False presumably switches BatchNorm
    # to its stored statistics (depends on the model's handling of `train`).
    predictions = state.apply_fn(model_variables, images, mutable=False, train=False)
    return compute_metrics(logits=predictions, labels=labels, num_classes=num_classes)
eval_batch = jax.jit(eval_batch_, static_argnums=3)

def eval_model(state, dataloader, num_classes=10):
    """
    Evaluate `state` over every batch of `dataloader`.

    Eval function: https://flax.readthedocs.io/en/latest/getting_started.html#eval-function

    Args:
        state: TrainState to evaluate (not modified).
        dataloader: iterable of `(images, labels)` batches.
        num_classes: number of output classes, forwarded to `eval_batch`.

    Returns:
        `(loss, accuracy)`: dataset-level means, weighted by batch size so that a
        smaller final batch counts in proportion to its number of examples.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    batch_metrics = []
    batch_sizes = []
    for images, labels in dataloader:
        batch_metrics.append(eval_batch(state, images, labels, num_classes=num_classes))
        batch_sizes.append(len(labels))
    if not batch_metrics:
        raise ValueError("eval_model: dataloader yielded no batches")
    # One host transfer for all batches.
    batch_metrics_np = jax.device_get(batch_metrics)
    validation_metrics_np = {
        k: np.average([m[k] for m in batch_metrics_np], weights=batch_sizes)
        for k in batch_metrics_np[0]
    }
    return validation_metrics_np["loss"], validation_metrics_np["accuracy"]


#### Fill in the TODO below. Note that the model uses Batch Norm, so its `batch_stats` must be retrieved separately from its `params`.

*HINT: See the documentation link in the code below for more information.*

In [None]:
def create_train_state(rng, model, learning_rate, momentum):
    """
    Initialise `model` and wrap it in a TrainState driven by SGD with momentum.

    Create train state: https://flax.readthedocs.io/en/latest/getting_started.html#create-train-state

    Args:
        rng: JAX PRNG key passed to `model.init`.
        model: flax module exposing a `dtype` attribute and accepting a `train` keyword.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        `(state, variables)`: the new TrainState and the raw variable dict returned by
        `model.init`. Note that `variables` holds the *initial* (untrained) weights;
        the trained values live in `state.params` / `state.batch_stats`.
    """
    variables = model.init(rng, jnp.ones(CIFAR10_INPUT_SHAPE, dtype=model.dtype), train=False)
    # model.init returns one entry per variable collection: the trainable weights are under
    # "params", and the BatchNorm statistics are under "batch_stats".
    params = variables["params"]
    batch_stats = variables["batch_stats"]
    tx = optax.sgd(learning_rate, momentum)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)
    return state, variables

Define your parameters here:

In [None]:
class Args:
    """Run configuration for the RotNet -> PredNet pipeline.

    Every setting is a class attribute. It can be read as an attribute (`args.lr`)
    or dict-style (`args["lr"]`), assigned dict-style (`args["lr"] = ...`), and
    `key in args` reports whether the attribute exists.
    """

    # ------------------------------ Architectures ------------------------------- #
    # RotNet architecture, written as rotnetX_featY:
    #   X = number of CNN layers in the RotNet,
    #   Y = layer whose features are extracted to build the final model.
    # NOTE: Y <= X. You CANNOT extract features from a layer that DOES NOT EXIST!
    rotnet_arch: str = "rotnet3_feat3"  #@param {type: "string"}

    # PredNet classifier (head) architecture, written as prednetX:
    #   X = number of convolutional layers added on top of the RotNet backbone.
    # NOTE: X=0 means no CNN layers in the head, only one dense layer.
    prednet_arch: str = "prednet3" #@param {type: "string"}

    # Directory where RotNet checkpoints are saved
    rotnet_ckpt_dir: str = "./ckpts/rotnet" #@param {type: "string"}

    # Directory where PredNet checkpoints are saved
    prednet_ckpt_dir: str = "./ckpts/prednet" #@param {type: "string"}

    # -------------------------- RotNet Training Params -------------------------- #
    # Resume RotNet training from the checkpoint of this epoch (0 = start fresh)
    rotnet_ckpt_epoch: int = 0 #@param {type: "integer"}

    # Total number of RotNet epochs
    rotnet_epochs: int = 10 #@param {type: "integer"}

    # -------------------------- PredNet Training Params ------------------------- #
    # Resume PredNet training from the checkpoint of this epoch (0 = start fresh)
    prednet_ckpt_epoch: int = 0 #@param {type: "integer"}

    # Total number of PredNet epochs
    prednet_epochs: int = 10 #@param {type: "integer"}

    # -------------------------- General Training Params ------------------------- #
    # Batch size per process
    batch_size: int = 128 #@param {type: "integer"}
    # Number of data-loading workers
    workers: int = 4 #@param {type: "integer"}
    # Optimizer learning rate
    lr: float = 1e-3 #@param {type: "float"}
    # Optimizer momentum
    momentum: float = 0.9 #@param {type: "float"}
    # Print model and parameter info
    verbose: bool = False #@param {type: "boolean"}

    # ------------------ Control Gradients and Type of Training ------------------ #
    # NOTE: To get an untrained RotNet set rotnet_epochs to 0
    # When True, gradient flow into the RotNet backbone is disabled
    no_grad: bool = True #@param {type: "boolean"}

    def __contains__(self, key):
        return hasattr(self, key)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, val):
        setattr(self, key, val)


args = Args()

Setup before training (random key, RotNet model, data loaders, train state, and checkpoint directory):

In [None]:
# ---------------------- Generate JAX Random Number Key ---------------------- #
rng = jax.random.PRNGKey(0)
print("Random Key Generated")

# -------------------------- Create the RotNet Model ------------------------- #
# Define network: https://flax.readthedocs.io/en/latest/getting_started.html#define-network
# TODO: In order to construct a RotNet, you need to fill in all the TODOs in RotNet.py. Check the above link for hints.
rotnet_model = rotnet_constructor(args.rotnet_arch)
print("Network Defined")
if args.verbose:
    # Layer summary traced on one dummy image; the positional False is presumably `train`.
    print(nn.tabulate(rotnet_model, rng)(jnp.ones(CIFAR10_INPUT_SHAPE), False))

# ------------------------- Load the CIFAR10 Dataset ------------------------- #
# Loading data: https://flax.readthedocs.io/en/latest/getting_started.html#loading-data
# NOTE: Choose batch_size and workers based on system specs.
# NOTE: This dataloader requires pytorch to load the dataset for convenience.
# TODO: In order to create the dataset, you will need to implement rotate_image function in utils.py. Check the above link for hints.
loaders = load_data(batch_size=args.batch_size, workers=args.workers)
# Plain CIFAR-10 loaders come first, followed by their rotated ("rot_") counterparts.
(
    train_loader, validation_loader, test_loader,
    rot_train_loader, rot_validation_loader, rot_test_loader,
) = loaders
print("Data Loaded")

# --- Create the Train State Abstraction (see documentation in link below) --- #
# Create train state: https://flax.readthedocs.io/en/latest/getting_started.html#create-train-state
rotnet_state, rotnet_variables = create_train_state(rng, rotnet_model, args.lr, args.momentum)
print("Train State Created")

# ----------------- Specify the Directory to Save Checkpoints ---------------- #
rotnet_ckpt_dir = args.rotnet_ckpt_dir
if os.path.exists(rotnet_ckpt_dir):
    print("RotNet Checkpoint Directory Found")
else:
    os.makedirs(rotnet_ckpt_dir)
    print("RotNet Checkpoint Directory Created")

Train a RotNet:

In [None]:
# -------------------- Load Existing Checkpoint of RotNet -------------------- #
if args.rotnet_ckpt_epoch > 0:
    # Resume from the checkpoint saved at epoch `rotnet_ckpt_epoch`; `target` provides
    # the pytree structure to restore into (same call as for the PredNet below).
    rotnet_state = checkpoints.restore_checkpoint(
        ckpt_dir=rotnet_ckpt_dir, target=rotnet_state, step=args.rotnet_ckpt_epoch
    )
    print("RotNet Checkpoint Loaded")

# ----------------------------- Train the RotNet ----------------------------- #
print("Starting RotNet Training Loop")

train_acc = []
train_loss = []
valid_acc = []
valid_loss = []
test_acc = []

# RotNet is trained on the rotation-labelled loaders with 4 classes (matches the
# test evaluation below).
rotnet_num_classes = 4

for epoch in tqdm(range(args.rotnet_ckpt_epoch + 1, args.rotnet_epochs + 1)):
    # ------------------------------- Training Step ------------------------------ #
    # Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    rotnet_state, train_epoch_metrics = train_epoch(
        rotnet_state, rot_train_loader, num_classes=rotnet_num_classes
    )

    # Print training metrics every epoch. Implicit concatenation is used instead of
    # backslash continuations, which would embed the next line's indentation in the text.
    print(
        f"train epoch: {epoch}, "
        f"loss: {train_epoch_metrics['loss']:.4f}, "
        f"accuracy:{train_epoch_metrics['accuracy']*100:.2f}%"
    )

    # ------------------------------ Evaluation Step ----------------------------- #
    # Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    validation_loss, validation_accuracy = eval_model(
        rotnet_state, rot_validation_loader, num_classes=rotnet_num_classes
    )

    # Print validation metrics every epoch
    print(f"validation loss: {validation_loss:.4f}, validation accuracy:{validation_accuracy*100:.2f}%\n")

    # ---------------------------- Saving Checkpoints ---------------------------- #
    # ---- https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html --- #
    checkpoints.save_checkpoint(
        ckpt_dir=rotnet_ckpt_dir, target=rotnet_state, step=epoch, overwrite=True, keep=args.rotnet_epochs
    )

    train_acc.append(train_epoch_metrics['accuracy'])
    train_loss.append(train_epoch_metrics['loss'])
    valid_acc.append(validation_accuracy)
    valid_loss.append(validation_loss)

    # Print (and record) test metrics every 5th epoch
    if epoch % 5 == 0:
        _, test_accuracy = eval_model(rotnet_state, rot_test_loader, num_classes=rotnet_num_classes)
        test_acc.append(test_accuracy)
        print("====================")
        print(f"test_accuracy: {test_accuracy*100:.2f}%")
        print("====================")

In [None]:
# Assumes train_acc / valid_acc still hold the RotNet run (the PredNet cell reuses these names).
# The loop counts epochs from 1 and resumes after `rotnet_ckpt_epoch`, so label the x-axis
# with the actual epoch numbers rather than list indices starting at 0.
rotnet_epoch_axis = args.rotnet_ckpt_epoch + np.arange(1, len(train_acc) + 1)
fig, ax = plt.subplots()
ax.plot(rotnet_epoch_axis, train_acc, color="blue", label="Training_Accuracy")
ax.plot(rotnet_epoch_axis, valid_acc, color="green", label="Validation_Accuracy")
ax.set_title("Training and Validation Accuracy vs epochs")
ax.set_xlabel("epoch")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()

In [None]:
# Assumes train_loss / valid_loss still hold the RotNet run (the PredNet cell reuses these names).
# The loop counts epochs from 1 and resumes after `rotnet_ckpt_epoch`, so label the x-axis
# with the actual epoch numbers rather than list indices starting at 0.
rotnet_epoch_axis = args.rotnet_ckpt_epoch + np.arange(1, len(train_loss) + 1)
fig, ax = plt.subplots()
ax.plot(rotnet_epoch_axis, train_loss, color="blue", label="Training_Loss")
ax.plot(rotnet_epoch_axis, valid_loss, color="green", label="Validation_Loss")
ax.set_title("Training and Validation Loss vs epochs")
ax.set_xlabel("epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

## Model Surgery:

Transfer the trained RotNet's feature extractor into the PredNet as its backbone.

In [None]:
# ---- https://flax.readthedocs.io/en/latest/guides/transfer_learning.html --- #
# ----------------------------- Extract Backbone ----------------------------- #
def extract_submodule(model):
    """
    Return an unbound copy of `model.features` together with its variables.

    Intended to be called as `nn.apply(extract_submodule, model)(variables)`, which binds
    `model` to `variables` so the submodule's own slice of them is accessible.
    See: https://flax.readthedocs.io/en/latest/guides/transfer_learning.html#extracting-a-submodule
    """
    feature_extractor = model.features.clone()
    # A bound submodule exposes only its own variables (every collection it owns).
    variables = model.features.variables
    return feature_extractor, variables

# `rotnet_variables` (returned by create_train_state) holds the weights from `model.init`,
# i.e. the *untrained* RotNet. Build the variable dict from the trained (or restored)
# train state so the PredNet actually inherits what the RotNet learned.
trained_rotnet_variables = {"params": rotnet_state.params, "batch_stats": rotnet_state.batch_stats}
backbone_model, backbone_model_variables = nn.apply(extract_submodule, rotnet_model)(trained_rotnet_variables)

# ------------------------- Create the Prednet Model ------------------------- #
# TODO: Construct your PredNet. Please take a look at the TODOs in PredNet.py
# hint: Check out how the RotNet was defined above
# NOTE(review): the prednet_constructor signature is defined in PredNet.py (not visible here).
prednet_model = ...

# ----------------------- Extract Variables and Params ----------------------- #
prednet_variables   = prednet_model.init(rng, jnp.ones(CIFAR10_INPUT_SHAPE), train=False)
prednet_params      = prednet_variables['params']
prednet_batch_stats = prednet_variables['batch_stats']

# --------------------- Transfer the Backbone Parameters --------------------- #
# Assumes the PredNet stores the RotNet feature extractor under the "backbone" key.
prednet_params              = prednet_params.unfreeze()
prednet_params['backbone']  = backbone_model_variables['params']
prednet_params              = freeze(prednet_params)

# The backbone's BatchNorm statistics belong with its transferred params, so copy them
# regardless of `no_grad`. Freezing the params does not make freshly initialised
# statistics match them.
prednet_batch_stats              = prednet_batch_stats.unfreeze()
prednet_batch_stats['backbone']  = backbone_model_variables['batch_stats']
prednet_batch_stats              = freeze(prednet_batch_stats)

# -------------------------- Define How to Backprop -------------------------- #
if args.no_grad:
    # Backbone leaves get a zero update (optax.set_to_zero); the head trains with SGD.
    # See https://flax.readthedocs.io/en/latest/guides/transfer_learning.html#optax-multi-transform
    partition_optimizers = {'trainable': optax.sgd(args.lr, args.momentum), 'frozen': optax.set_to_zero()}

    # Label every param leaf whose path contains 'backbone' as frozen. freeze() makes the
    # label tree a FrozenDict, matching the pytree structure of `prednet_params`.
    prednet_param_partitions = freeze(traverse_util.path_aware_map(
        lambda path, _: 'frozen' if 'backbone' in path else 'trainable', prednet_params
    ))

    tx = optax.multi_transform(partition_optimizers, prednet_param_partitions)

    # ---------------- Visualize param_partitions to double check ---------------- #
    if args.verbose:
        flat = list(traverse_util.flatten_dict(prednet_param_partitions).items())
        # print() is required: a bare expression inside an `if` block is never displayed.
        print(freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:]))))

else:
    tx = optax.sgd(args.lr, args.momentum)

# ---------------------- Create Train State for PredNet ---------------------- #
prednet_state = TrainState.create(
    apply_fn=prednet_model.apply, params=prednet_params, tx=tx, batch_stats=prednet_batch_stats
)

# ----------------- Specify the Directory to Save Checkpoints ---------------- #
prednet_ckpt_dir = args.prednet_ckpt_dir
if not os.path.exists(prednet_ckpt_dir):
    os.makedirs(prednet_ckpt_dir)
    print("PredNet Checkpoint Directory Created")
else:
    print("PredNet Checkpoint Directory Found")

Train a PredNet:

In [None]:
# -------------------- Load Existing Checkpoint of PredNet ------------------- #
if args.prednet_ckpt_epoch > 0:
    prednet_state = checkpoints.restore_checkpoint(
        ckpt_dir=prednet_ckpt_dir, target=prednet_state, step=args.prednet_ckpt_epoch
    )
    print("PredNet Checkpoint Loaded")

# ----------------------------- Train the PredNet ---------------------------- #
print("Starting PredNet Training Loop")

train_acc = []
train_loss = []
valid_acc = []
valid_loss = []
test_acc = []

# PredNet is trained on the plain CIFAR-10 loaders with 10 classes (matches the test
# evaluation below).
prednet_num_classes = 10

for epoch in tqdm(range(args.prednet_ckpt_epoch + 1, args.prednet_epochs + 1)):
    # ------------------------------- Training Step ------------------------------ #
    # Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    prednet_state, train_epoch_metrics = train_epoch(
        prednet_state, train_loader, num_classes=prednet_num_classes
    )

    # Print training metrics every epoch. Implicit concatenation is used instead of
    # backslash continuations, which would embed the next line's indentation in the text.
    print(
        f"train epoch: {epoch}, "
        f"loss: {train_epoch_metrics['loss']:.4f}, "
        f"accuracy:{train_epoch_metrics['accuracy']*100:.2f}%"
    )

    # ------------------------------ Evaluation Step ----------------------------- #
    # Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    validation_loss, validation_accuracy = eval_model(
        prednet_state, validation_loader, num_classes=prednet_num_classes
    )

    # Print validation metrics every epoch
    print(f"validation loss: {validation_loss:.4f}, validation accuracy:{validation_accuracy*100:.2f}%\n")

    # ---------------------------- Saving Checkpoints ---------------------------- #
    # ---- https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html --- #
    checkpoints.save_checkpoint(
        ckpt_dir=prednet_ckpt_dir, target=prednet_state, step=epoch, overwrite=True, keep=args.prednet_epochs
    )

    train_acc.append(train_epoch_metrics['accuracy'])
    train_loss.append(train_epoch_metrics['loss'])
    valid_acc.append(validation_accuracy)
    valid_loss.append(validation_loss)

    # Print (and record) test metrics every 5th epoch
    if epoch % 5 == 0:
        _, test_accuracy = eval_model(prednet_state, test_loader, num_classes=prednet_num_classes)
        test_acc.append(test_accuracy)
        print("====================")
        print(f"test_accuracy: {test_accuracy*100:.2f}%")
        print("====================")

In [None]:
# The PredNet loop counts epochs from 1 and resumes after `prednet_ckpt_epoch`, so label
# the x-axis with the actual epoch numbers rather than list indices starting at 0.
prednet_epoch_axis = args.prednet_ckpt_epoch + np.arange(1, len(train_acc) + 1)
fig, ax = plt.subplots()
ax.plot(prednet_epoch_axis, train_acc, color="blue", label="Training_Accuracy")
ax.plot(prednet_epoch_axis, valid_acc, color="green", label="Validation_Accuracy")
ax.set_title("Training and Validation Accuracy vs epochs")
ax.set_xlabel("epoch")
ax.set_ylabel("Accuracy")
ax.legend()
plt.show()


In [None]:
# The PredNet loop counts epochs from 1 and resumes after `prednet_ckpt_epoch`, so label
# the x-axis with the actual epoch numbers rather than list indices starting at 0.
prednet_epoch_axis = args.prednet_ckpt_epoch + np.arange(1, len(train_loss) + 1)
fig, ax = plt.subplots()
ax.plot(prednet_epoch_axis, train_loss, color="blue", label="Training_Loss")
ax.plot(prednet_epoch_axis, valid_loss, color="green", label="Validation_Loss")
ax.set_title("Training and Validation Loss vs epochs")
ax.set_xlabel("epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()