In [1]:

!pip install -q --upgrade https://storage.googleapis.com/jax-releases/`nvcc -V | sed -En "s/.* release ([0-9]*)\.([0-9]*),.*/cuda\1\2/p"`/jaxlib-0.1.42-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-linux_x86_64.whl jax
!pip install -q git+https://github.com/google/flax.git@dev-setup
!unzip -q simpsons_faces.zip


[K     |████████████████████████████████| 67.8MB 47kB/s 
[K     |████████████████████████████████| 286kB 4.9MB/s 
[?25h  Building wheel for jax (setup.py) ... [?25l[?25hdone
  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [2]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [3]:

import jax
import flax

import numpy as onp
import jax.numpy as jnp
import csv
import tensorflow as tf
import tensorflow_datasets as tfds




In [4]:

class CNN(flax.nn.Module):
  def apply(self, x):
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=128, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=64, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=32, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = flax.nn.Conv(x, features=16, kernel_size=(3, 3))
    x = flax.nn.relu(x)
    x = flax.nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = flax.nn.Dense(x, features=256)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=64)
    x = flax.nn.relu(x)
    x = flax.nn.Dense(x, features=4)
    x = flax.nn.log_softmax(x)
    return x

@jax.vmap
def cross_entropy_loss(logits, label):
  return -logits[label]

def compute_metrics(logits, labels):
  loss = jnp.mean(cross_entropy_loss(logits, labels))
  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
  return {'loss': loss, 'accuracy': accuracy}

@jax.jit
def train_step(optimizer, batch):
  def loss_fn(model):
    logits = model(batch[0])
    loss = jnp.mean(cross_entropy_loss(
        logits, batch[1]))
    return loss
  grad = jax.grad(loss_fn)(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  return optimizer


@jax.jit
def eval(model, eval_ds):
  logits = model(eval_ds['image'])
  return compute_metrics(logits, eval_ds['label'])

#nachteil tf noch nötig!
def create_dataset(mode):
  if mode is tf.estimator.ModeKeys.EVAL:
      ds = tf.data.TextLineDataset('labels_test.csv')
      ds = ds.map(lambda x: map_resize(x, mode))
  elif mode is tf.estimator.ModeKeys.TRAIN:
      ds = tf.data.TextLineDataset('labels_train.csv')
      ds = ds.shuffle(2500) # Buffer_size is bigger than dataset size to get a uniform shuffle
      ds = ds.map(lambda x: map_resize(x, mode))
      ds = ds.batch(32)
  return ds

def map_resize(img, mode):
  img, label = tf.io.decode_csv(img, record_defaults=['',-1], field_delim=',')
  img = tf.io.read_file(img)
  img = tf.image.decode_jpeg(img, channels=3)
  img = (2 * img / 255) - 1 
  label = int(label)
  return img, label

def train():

  train_ds = create_dataset(tf.estimator.ModeKeys.TRAIN)
  test_ds = create_dataset(tf.estimator.ModeKeys.EVAL)
  
  test_ds = test_ds.prefetch(tf.data.experimental.AUTOTUNE)
  #test_ds is one giant batch
  test_ds = test_ds.batch(1000)
  #test ds is a feature dictonary!
  test_ds = tf.compat.v1.data.experimental.get_single_element(test_ds)
  test_ds = tfds.as_numpy(test_ds)
  test_ds = {'image': test_ds[0].astype(jnp.float32), 'label': test_ds[1].astype(jnp.int32)}

  _, initial_params = CNN.init_by_shape(jax.random.PRNGKey(0), [((1, 160, 120, 3), jnp.float32)])

  model = flax.nn.Model(CNN, initial_params)

  optimizer = flax.optim.Momentum(learning_rate=0.01, beta=0.9, weight_decay=0.0005).create(model)

  for epoch in range(50):
    for batch in tfds.as_numpy(train_ds):
      optimizer = train_step(optimizer, batch)

    metrics = eval(optimizer.target, test_ds)

    print('eval epoch: %d, loss: %.4f, accuracy: %.2f' % (epoch+1,metrics['loss'], metrics['accuracy'] * 100))

train()

eval epoch: 1, loss: 1.2812, accuracy: 44.86
eval epoch: 2, loss: 0.9695, accuracy: 56.91
eval epoch: 3, loss: 0.7705, accuracy: 69.77
eval epoch: 4, loss: 0.8001, accuracy: 66.24
eval epoch: 5, loss: 0.7520, accuracy: 69.94
eval epoch: 6, loss: 0.6220, accuracy: 77.33
eval epoch: 7, loss: 0.5647, accuracy: 78.94
eval epoch: 8, loss: 0.5618, accuracy: 78.94
eval epoch: 9, loss: 0.5280, accuracy: 79.42
eval epoch: 10, loss: 0.4943, accuracy: 81.35
eval epoch: 11, loss: 0.5699, accuracy: 80.39
eval epoch: 12, loss: 0.6678, accuracy: 78.30
eval epoch: 13, loss: 0.5361, accuracy: 79.10
eval epoch: 14, loss: 0.5533, accuracy: 82.32
eval epoch: 15, loss: 0.4565, accuracy: 85.53
eval epoch: 16, loss: 0.7185, accuracy: 82.80
eval epoch: 17, loss: 0.4648, accuracy: 84.89
eval epoch: 18, loss: 0.6321, accuracy: 84.41
eval epoch: 19, loss: 0.5234, accuracy: 85.37
eval epoch: 20, loss: 1.0434, accuracy: 74.44
eval epoch: 21, loss: 0.5883, accuracy: 86.01
eval epoch: 22, loss: 0.6007, accuracy: 85.