In [1]:
import os
import argparse
import numpy as np
from tqdm import tqdm
from typing import Any

from dataloader import load_data
from RotNet import rotnet_constructor
from PredNet import prednet_constructor

import jax
import jax.numpy as jnp

import flax.linen as nn
from flax import traverse_util
from flax.core.frozen_dict import freeze
from flax.training import train_state, checkpoints


import optax

  from .autonotebook import tqdm as notebook_tqdm
2022-12-07 20:44:41.252877: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/ros/humble/opt/rviz_ogre_vendor/lib:/opt/ros/humble/lib/x86_64-linux-gnu:/opt/ros/humble/lib
2022-12-07 20:44:41.252933: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/ros/humble/opt/rviz_ogre_vendor/lib:/opt/ros/humble/lib/x86_64-linux-gnu:/opt/ros/humble/lib


In [2]:
# Define cifar10 image shape
CIFAR10_INPUT_SHAPE = (1, 32, 32, 3)

class TrainState(train_state.TrainState):
    batch_stats: Any

def cross_entropy_loss_(logits, labels, num_classes=10):
    """
    Define loss: https://flax.readthedocs.io/en/latest/getting_started.html#define-loss
    """
    labels_onehot = jax.nn.one_hot(labels, num_classes=num_classes)
    return optax.softmax_cross_entropy(logits=logits, labels=labels_onehot).mean()
cross_entropy_loss = jax.jit(cross_entropy_loss_, static_argnums=2)

def compute_metrics_(logits, labels, num_classes):
    """
    Metric computation: https://flax.readthedocs.io/en/latest/getting_started.html#metric-computation
    """
    loss = cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    metrics = {"loss": loss, "accuracy": accuracy}
    return metrics
compute_metrics = jax.jit(compute_metrics_, static_argnums=2)

def create_train_state(rng, model, learning_rate, momentum):
    """
    Create train state: https://flax.readthedocs.io/en/latest/getting_started.html#create-train-state
    """
    variables = model.init(rng, jnp.ones(CIFAR10_INPUT_SHAPE, dtype=model.dtype), train=False)
    # TODO: get params and batch_state from varaibles
    # hint: Check the documentation of model.init, and find out how to use it
    params, batch_stats = ... , ...
    tx = optax.sgd(learning_rate, momentum)
    state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)
    return state, variables

def train_batch_(state, images, labels, num_classes=10):
    """
    Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    """
    def loss_fn(params):
        logits, updates = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats}, images, mutable=["batch_stats"], train=True
        )
        loss = cross_entropy_loss(logits=logits, labels=labels, num_classes=num_classes)
        return loss, (logits, updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_, (logits, updates)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates["batch_stats"])
    metrics = compute_metrics(logits=logits, labels=labels, num_classes=num_classes)
    return state, metrics
train_batch = jax.jit(train_batch_, static_argnums=3)

def train_epoch(state, dataloader, num_classes=10):
    """
    Train function: https://flax.readthedocs.io/en/latest/getting_started.html#train-function
    """
    batch_metrics = []
    for images, labels in dataloader:
        state, metrics = train_batch(state, images, labels, num_classes=num_classes)
        batch_metrics.append(metrics)
    batch_metrics_np = jax.device_get(batch_metrics)
    epoch_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np]) for k in batch_metrics_np[0]
    }
    return state, epoch_metrics_np

def eval_batch_(state, images, labels, num_classes=10):
    """
    Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    """
    logits = state.apply_fn(
        {"params": state.params, "batch_stats": state.batch_stats}, images, mutable=False, train=False
    )
    return compute_metrics(logits=logits, labels=labels, num_classes=num_classes)
eval_batch = jax.jit(eval_batch_, static_argnums=3)

def eval_model(state, dataloader, num_classes=10):
    """
    Eval function: https://flax.readthedocs.io/en/latest/getting_started.html#eval-function
    """
    batch_metrics = []
    for images, labels in dataloader:
        metrics = eval_batch(state, images, labels, num_classes=num_classes)
        batch_metrics.append(metrics)
    batch_metrics_np = jax.device_get(batch_metrics)
    validation_metrics_np = {
        k: np.mean([metrics[k] for metrics in batch_metrics_np]) for k in batch_metrics_np[0]
    }
    return validation_metrics_np["loss"], validation_metrics_np["accuracy"]


Define your parameters here:

In [3]:
class Args:
    def __getitem__(self, key):
        return getattr(self, key)
    
    def __setitem__(self, key, val):
        setattr(self, key, val)
        
    def __contains__(self, key):
        return hasattr(self, key)
    
    # Define What RotNet architecture to use. 
    rotnet_arch : str = "rotnet3_feat3" #@param["rotnet3_feat3"] 
    # Define What PredNet architecture to use. 
    prednet_arch: str = "prednet3" #@param["rotnet3_feat3"]
    # Define Directory to Save RotNet Checkpoints"
    rotnet_ckpt_dir: str = "./ckpts/ro state = TrainState.create(apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats)tnet" #@param["rotnet3_feat3"]
    # Define Directory to Save PredNet Checkpoints
    prednet_ckpt_dir: str = "./ckpts/prednet" #@param["rotnet3_feat3"]
    # Continue to Train RotNet from rotnet_ckpt_epoch
    rotnet_ckpt_epoch: int = 0 #@param {type: "integer"}
    # Continue to train PredNet from prednet_ckpt_epoch
    prednet_ckpt_epoch: int = 0 #@param {type: "integer"}
    # Train RotNet for rotnet_epochs in Total
    rotnet_epochs: int = 10 #@param {type: "integer"}
    # Train PredNet for prednet_epochs in Total
    prednet_epochs: int = 10 #@param {type: "integer"}
    # Disable Gradient Flow in RotNet if Set to True
    no_grad: bool = True #@param {type: "boolean"}
    # Batch Size Per Process"
    batch_size: int = 128 #@param {type: "integer"}
    # Number of Data Loading Workers
    workers: int = 4 #@param {type: "integer"}
    # Learning Rate of the Optimizer
    lr: float = 1e-3 #@param {type: "float"}
    # Momentum of the Optimizer
    momentum: float = 0.9 #@param {type: "float"}
    # Print Model and Params Info
    verbose: bool = False #@param {type: "boolean"}
    
args = Args()

Preprocessing before training:

In [26]:
# ---------------------- Generate JAX Random Number Key ---------------------- #
rng = jax.random.PRNGKey(0)
print("Random Key Generated")

# -------------------------- Create the RotNet Model ------------------------- #
# Define network: https://flax.readthedocs.io/en/latest/getting_started.html#define-network
# TODO: In order to construct a RotNet, you need to fill in all the TODOs in RotNet.py. Check the above link for hints.
rotnet_model = rotnet_constructor(args.rotnet_arch)
print("Network Defined")
if args.verbose:
    print(nn.tabulate(rotnet_model, rng)(jnp.ones(CIFAR10_INPUT_SHAPE), False))

# ------------------------- Load the CIFAR10 Dataset ------------------------- #
# Loading data: https://flax.readthedocs.io/en/latest/getting_started.html#loading-data
# NOTE: Choose batch_size and workers based on system specs.
# NOTE: This dataloader requires pytorch to load the datset for convenience.
# TODO: In order to create the dataset, you will need to implement rotate_image function in utils.py. Chech the above link for hints.
loaders = load_data(batch_size=args.batch_size, workers=args.workers)
train_loader, validation_loader, test_loader, rot_train_loader, rot_validation_loader, rot_test_loader = loaders
print("Data Loaded")

# --- Create the Train State Abstraction (see documentation in link below) --- #
# Create train state: https://flax.readthedocs.io/en/latest/getting_started.html#create-train-state
rotnet_state, rotnet_variables = create_train_state(rng, rotnet_model, args.lr, args.momentum)
print("Train State Created")

# ----------------- Specify the Directory to Save Checkpoints ---------------- #
rotnet_ckpt_dir = args.rotnet_ckpt_dir
if not os.path.exists(rotnet_ckpt_dir):
    os.makedirs(rotnet_ckpt_dir)
    print("RotNet Checkpoint Directory Created")
else:
    print("RotNet Checkpoint Directory Found")

Random Key Generated
Network Defined
Files already downloaded and verified
Files already downloaded and verified
Data Loaded
Train State Created
RotNet Checkpoint Directory Created


Train a RotNet:

In [27]:
# -------------------- Load Existing Checkpoint of RotNet -------------------- #
if args.rotnet_ckpt_epoch > 0:
    # TODO: Load pre-trained model if required
    # hint: Check out flax.training.checkpoints.restore_checkpoint function
    rotnet_state = ...
    print("RotNet Checkpoint Loaded")

# ----------------------------- Train the RotNet ----------------------------- #
print("Starting RotNet Training Loop")
for epoch in tqdm(range(args.rotnet_ckpt_epoch + 1, args.rotnet_epochs + 1)):
    # ------------------------------- Training Step ------------------------------ #
    # Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    # TODO: Use train_epoch defined aboove to train a RotNet 
    rotnet_state, train_epoch_metrics = ...

    # Print training metrics every epoch
    print(
        f"train epoch: {epoch}, \
        loss: {train_epoch_metrics['loss']:.4f}, \
        accuracy:{train_epoch_metrics['accuracy']*100:.2f}%"
    )

    # ------------------------------ Evaluation Step ----------------------------- #
    # Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    # TODO: Use eval_model defined aboove to get validation loss and accuracy
    validation_loss, validation_accuracy = ...
    
    # Print validation metrics every epoch
    print(f"validation loss: {validation_loss:.4f}, validation accuracy:{validation_accuracy*100:.2f}%")

    # ---------------------------- Saving Checkpoints ---------------------------- #
    # ---- https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html --- #
    checkpoints.save_checkpoint(
        ckpt_dir=rotnet_ckpt_dir, target=rotnet_state, step=epoch, overwrite=True, keep=args.rotnet_epochs
    )

    # Print test metrics every nth epoch
    if epoch % 5 == 0:
        _, test_accuracy = eval_model(rotnet_state, rot_test_loader, num_classes=4)
        print("====================")
        print(f"test_accuracy: {test_accuracy*100:.2f}%")
        print("====================")

Starting RotNet Training Loop


  0%|          | 0/10 [00:00<?, ?it/s]

train epoch: 1,         loss: 1.1226,         accuracy:57.21%


 10%|█         | 1/10 [00:26<03:59, 26.64s/it]

validation loss: 0.9133, validation accuracy:64.15%
train epoch: 2,         loss: 0.8422,         accuracy:66.74%


 20%|██        | 2/10 [00:51<03:24, 25.52s/it]

validation loss: 0.8266, validation accuracy:67.87%
train epoch: 3,         loss: 0.7403,         accuracy:71.15%


 30%|███       | 3/10 [01:16<02:56, 25.22s/it]

validation loss: 0.7856, validation accuracy:69.32%
train epoch: 4,         loss: 0.6743,         accuracy:73.94%


 40%|████      | 4/10 [01:41<02:30, 25.09s/it]

validation loss: 0.7866, validation accuracy:69.66%
train epoch: 5,         loss: 0.6203,         accuracy:76.10%
validation loss: 0.7553, validation accuracy:71.04%


 50%|█████     | 5/10 [02:10<02:13, 26.71s/it]

test_accuracy: 70.27%
train epoch: 6,         loss: 0.5753,         accuracy:78.03%


 60%|██████    | 6/10 [02:35<01:43, 25.97s/it]

validation loss: 0.7674, validation accuracy:70.43%
train epoch: 7,         loss: 0.5348,         accuracy:79.62%


 70%|███████   | 7/10 [02:59<01:16, 25.54s/it]

validation loss: 0.7377, validation accuracy:71.85%
train epoch: 8,         loss: 0.5017,         accuracy:81.10%


 80%|████████  | 8/10 [03:24<00:50, 25.35s/it]

validation loss: 0.7988, validation accuracy:69.95%
train epoch: 9,         loss: 0.4677,         accuracy:82.44%


 90%|█████████ | 9/10 [03:49<00:25, 25.11s/it]

validation loss: 0.8497, validation accuracy:69.34%
train epoch: 10,         loss: 0.4431,         accuracy:83.47%
validation loss: 0.8121, validation accuracy:70.31%


100%|██████████| 10/10 [04:19<00:00, 25.94s/it]

test_accuracy: 70.45%





Preprocessing for training PredNet:

In [28]:
# ---- https://flax.readthedocs.io/en/latest/guides/transfer_learning.html --- #
# ----------------------------- Extract Backbone ----------------------------- #
def extract_submodule(model):
    feature_extractor = model.features.clone()
    # TODO: extract variable from model
    # hint: checkout https://flax.readthedocs.io/en/latest/guides/transfer_learning.html
    variables = ...
    return feature_extractor, variables

backbone_model, backbone_model_variables = nn.apply(extract_submodule, rotnet_model)(rotnet_variables)

# ------------------------- Create the Prednet Model ------------------------- #
# TODO: Construct your PredNet
# hint: Check out how the RotNet was defined above
prednet_model = ...

# ----------------------- Extract Variables and Params ----------------------- #
prednet_variables   = prednet_model.init(rng, jnp.ones(CIFAR10_INPUT_SHAPE), train=False)
prednet_params      = prednet_variables['params']
prednet_batch_stats = prednet_variables['batch_stats']

# --------------------- Transfer the Backbone Parameters --------------------- #
prednet_params              = prednet_params.unfreeze()
prednet_params['backbone']  = backbone_model_variables['params']
prednet_params              = freeze(prednet_params)

if not args.no_grad:
    prednet_batch_stats              = prednet_batch_stats.unfreeze()
    prednet_batch_stats['backbone']  = backbone_model_variables['batch_stats']
    prednet_batch_stats              = freeze(prednet_batch_stats)

# -------------------------- Define How to Backprop -------------------------- #
if args.no_grad:
    # TODO: Freeze layers with optax.multi_transform for the mode
    ############### Your Code Starts Here ###################
    # hint: Check out https://flax.readthedocs.io/en/latest/guides/transfer_learning.html
    partition_optimizers = {'trainable': optax.sgd(args.lr, args.momentum), 'frozen': optax.set_to_zero()}
    
    
    prednet_param_partitions = ...
    

    tx = optax.multi_transform(partition_optimizers, prednet_param_partitions)
    
    ############### Your Code Ends Here ###################

    # ---------------- Visualize param_partitions to double check ---------------- #
    if args.verbose:
        flat = list(traverse_util.flatten_dict(prednet_param_partitions).items())
        freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))
        
else:
    tx = optax.sgd(args.lr, args.momentum)
    
# ---------------------- Create Train State for PredNet ---------------------- #
# TODO: Create Train State for PredNet
# hint: We also did it for RotNet.
prednet_state = ...

# ----------------- Specify the Directory to Save Checkpoints ---------------- #
prednet_ckpt_dir = args.prednet_ckpt_dir
if not os.path.exists(prednet_ckpt_dir):
    os.makedirs(prednet_ckpt_dir)
    print("PredNet Checkpoint Directory Created")
else:
    print("PredNet Checkpoint Directory Found")

PredNet Checkpoint Directory Created


Train a PredNet:

In [29]:
# -------------------- Load Existing Checkpoint of PredNet ------------------- #
if args.prednet_ckpt_epoch > 0:
    prednet_state = checkpoints.restore_checkpoint(
        ckpt_dir=prednet_ckpt_dir, target=prednet_state, step=args.prednet_ckpt_epoch
    )
    print("PredNet Checkpoint Loaded")

# ----------------------------- Train the PredNet ---------------------------- #
print("Starting PredNet Training Loop")
for epoch in tqdm(range(args.prednet_ckpt_epoch + 1, args.prednet_epochs + 1)):
    # ------------------------------- Training Step ------------------------------ #
    # Training step: https://flax.readthedocs.io/en/latest/getting_started.html#training-step
    # TODO: Use train_epoch defined aboove to train a RotNet 
    prednet_state, train_epoch_metrics = ...

    # Print training metrics every epoch
    print(
        f"train epoch: {epoch}, \
        loss: {train_epoch_metrics['loss']:.4f}, \
        accuracy:{train_epoch_metrics['accuracy']*100:.2f}%"
    )

    # ------------------------------ Evaluation Step ----------------------------- #
    # Evaluation step: https://flax.readthedocs.io/en/latest/getting_started.html#evaluation-step
    # TODO: Use eval_model defined aboove to get validation loss and accuracy
    validation_loss, validation_accuracy = ...
    
    # Print validation metrics every epoch
    print(f"validation loss: {validation_loss:.4f}, validation accuracy:{validation_accuracy*100:.2f}%")

    # ---------------------------- Saving Checkpoints ---------------------------- #
    # ---- https://flax.readthedocs.io/en/latest/guides/use_checkpointing.html --- #
    checkpoints.save_checkpoint(
        ckpt_dir=prednet_ckpt_dir, target=prednet_state, step=epoch, overwrite=True, keep=args.prednet_epochs
    )

    # Print test metrics every nth epoch
    if epoch % 5 == 0:
        _, test_accuracy = eval_model(prednet_state, test_loader, num_classes=10)
        print("====================")
        print(f"test_accuracy: {test_accuracy*100:.2f}%")
        print("====================")

Starting PredNet Training Loop


  0%|          | 0/10 [00:00<?, ?it/s]

train epoch: 1,         loss: 1.8956,         accuracy:46.84%


 10%|█         | 1/10 [00:10<01:31, 10.21s/it]

validation loss: 1.7487, validation accuracy:44.49%
train epoch: 2,         loss: 0.7775,         accuracy:73.39%


 20%|██        | 2/10 [00:16<01:01,  7.66s/it]

validation loss: 1.3920, validation accuracy:56.07%
train epoch: 3,         loss: 0.3648,         accuracy:88.40%


 30%|███       | 3/10 [00:21<00:47,  6.86s/it]

validation loss: 1.2737, validation accuracy:59.16%
train epoch: 4,         loss: 0.1739,         accuracy:96.32%


 40%|████      | 4/10 [00:27<00:38,  6.49s/it]

validation loss: 1.3250, validation accuracy:59.86%
train epoch: 5,         loss: 0.0904,         accuracy:99.17%
validation loss: 1.2508, validation accuracy:61.31%


 50%|█████     | 5/10 [00:34<00:33,  6.63s/it]

test_accuracy: 61.96%
train epoch: 6,         loss: 0.0547,         accuracy:99.81%


 60%|██████    | 6/10 [00:40<00:25,  6.38s/it]

validation loss: 1.2069, validation accuracy:62.70%
train epoch: 7,         loss: 0.0382,         accuracy:99.95%


 70%|███████   | 7/10 [00:46<00:18,  6.22s/it]

validation loss: 1.2222, validation accuracy:63.71%
train epoch: 8,         loss: 0.0287,         accuracy:99.99%


 80%|████████  | 8/10 [00:52<00:12,  6.11s/it]

validation loss: 1.2276, validation accuracy:64.18%
train epoch: 9,         loss: 0.0238,         accuracy:100.00%


 90%|█████████ | 9/10 [00:58<00:06,  6.05s/it]

validation loss: 1.2417, validation accuracy:64.06%
train epoch: 10,         loss: 0.0202,         accuracy:100.00%
validation loss: 1.2510, validation accuracy:64.90%


100%|██████████| 10/10 [01:04<00:00,  6.49s/it]

test_accuracy: 64.38%



