In [1]:
from typing import Callable, Tuple, Any, Dict, List
from absl import logging
from functools import partial
from tqdm import tqdm
from dataclasses import dataclass

import os
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.random import PRNGKey as jkey
from chex import Scalar, Array, PRNGKey, Shape
import flax
from flax import linen as nn
from flax.training.train_state import TrainState as RawTrainState
from flax.training.checkpoints import restore_checkpoint
import optax
import matplotlib.pyplot as plt
import tensorflow as tf

from training_cnn import *
from architectures import *

# Only show absl log messages at WARNING level and above (INFO chatter is suppressed).
# `WARNING` is the canonical absl level name; `WARN` is a deprecated alias for it.
logging.set_verbosity(logging.WARNING)


# Global seed for every seeded call in this notebook; reference SEED rather than
# repeating the literal so a single edit here changes all of them.
SEED = 42



In [2]:
# Train the global-average-pooling CNN built by `create_GAPCNN`, logging train/test
# metrics every `log_every` epochs and checkpointing under checkpoints/gap_cnn.
# NOTE(review): the exact meaning of some arguments (e.g. `ds_chunk_size=0.1`,
# presumably a fraction of the dataset) is defined in `training_cnn.train_and_eval`,
# which is not visible here — confirm there.
final_state, metrices, elapsed_time = train_and_eval(
    seed=SEED,  # reuse the config-cell seed instead of repeating the literal 42
    epochs=10,
    batch_size=32,
    create_state_fun=create_GAPCNN,
    lr=0.001,
    momentum=0.9,
    ds_chunk_size=0.1,
    log_every=1,
    # Join the path components so the separator is correct on every platform.
    checkpoint_dir=os.path.join("checkpoints", "gap_cnn"),
)
print(f"Total training time: {elapsed_time:.3f}")

epoch:  1, train_loss: 2.2817, train_accuracy: 22.76, test_loss: 2.4850, test_accuracy: 10.10
epoch:  2, train_loss: 1.8716, train_accuracy: 31.09, test_loss: 2.3741, test_accuracy: 13.70
epoch:  3, train_loss: 1.7111, train_accuracy: 36.24, test_loss: 1.8696, test_accuracy: 32.60
epoch:  4, train_loss: 1.5939, train_accuracy: 41.09, test_loss: 1.7909, test_accuracy: 34.50
epoch:  5, train_loss: 1.5263, train_accuracy: 43.67, test_loss: 1.6881, test_accuracy: 36.80
epoch:  6, train_loss: 1.4837, train_accuracy: 45.09, test_loss: 1.4637, test_accuracy: 49.30
epoch:  7, train_loss: 1.4396, train_accuracy: 46.73, test_loss: 1.4469, test_accuracy: 46.20
epoch:  8, train_loss: 1.3861, train_accuracy: 48.58, test_loss: 1.6482, test_accuracy: 41.00
epoch:  9, train_loss: 1.3572, train_accuracy: 51.08, test_loss: 1.5712, test_accuracy: 45.00
epoch: 10, train_loss: 1.3217, train_accuracy: 52.38, test_loss: 1.4965, test_accuracy: 45.50
Total training time: 86.821


In [14]:
# Load the sample dog photo from disk.
# NOTE(review): for JPEG files matplotlib's imread returns raw uint8 pixel values
# (0-255); confirm this matches whatever scaling/normalization was applied to the
# training images, otherwise predictions on this image may be unreliable.
doggo = plt.imread("labrador_32x24.jpg")

# Prepend a batch axis of size 1 so the single image can be fed to the model as a batch.
doggo_batch = jnp.asarray(doggo)[jnp.newaxis, ...]

In [20]:
# Run the trained network on the single dog image in inference mode (training=False),
# supplying the learned parameters and the 'batch_stats' collection (presumably
# BatchNorm running statistics) taken from the final training state.
# Both names are bound because other cells may refer to either one.
doggo_prediction = logits = GAPCNN().apply(
    {'params': final_state.params, 'batch_stats': final_state.batch_stats},
    doggo_batch,
    training=False,
    # Seed the dropout stream from the config-cell SEED instead of a repeated literal.
    # NOTE(review): with training=False dropout is presumably disabled, so this key is
    # likely unused — confirm in GAPCNN before relying on it.
    rngs={'dropout': jax.random.PRNGKey(SEED)},
)
# Unnormalized per-class scores (values can be negative); apply jax.nn.softmax to get
# probabilities or argmax to get the predicted class index.
doggo_prediction

DeviceArray([[ 0.00648331, -1.2280122 ,  2.0572402 ,  1.1882851 ,
               1.7229375 ,  1.4963266 ,  1.3262426 ,  0.52195233,
              -1.4348559 , -1.0805315 ]], dtype=float32)