<a href="https://colab.research.google.com/github/samiraabnar/Gift/blob/main/notebooks/noisy_two_moon.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title General Imports

import time
import warnings

import math
import numpy as np
import matplotlib.pyplot as plt

from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
from itertools import cycle, islice
import pandas as pd 

colors = np.array(list(islice(cycle(['#377eb8', '#ff7f00', '#4daf4a',
                                             '#f781bf', '#a65628', '#984ea3',
                                             '#999999', '#e41a1c', '#dede00']),
                                      int(2 + 1))))

In [None]:
#@title Make the two moons dataset

moons = datasets.make_moons(n_samples=10000, random_state=8, noise=0.05)
M, m_labels = moons
M = StandardScaler().fit_transform(M)

angle = (90. / 180.) * math.pi
trans_mat = np.array([[math.cos(angle), math.sin(angle)],[-math.sin(angle), math.cos(angle)]])

translated_M = M.dot(trans_mat)


In [None]:
#@title Ploting utils

import altair as alt
from vega_datasets import data
alt.data_transformers.disable_max_rows()


def plot(x, y, label, title, width=200):
  moon_df = pd.DataFrame.from_dict({'x': x, 
                                    'y': y,
                                    'label': label})


  moon = alt.Chart(moon_df, title=title, height=width, width=width).mark_point(size=2).encode(
      x=alt.X('x', axis=alt.Axis(labels=False, ticks=False), title=''),
      y=alt.Y('y', axis=alt.Axis(labels=False, ticks=False), title=''),
      color=alt.Color('label:N', legend=None),
      shape=alt.Shape('label:N', legend=None),
  )

  return moon

In [None]:
#@title Plot source and target
moon = plot(M[:,0], M[:,1], m_labels, 'Source', width=200)
trainslated_moon = plot(translated_M[:,0], translated_M[:,1], m_labels, 'Target', width=200)

moon | trainslated_moon 

In [None]:
#@title Install flax
!pip install -q ml-collections git+https://github.com/google/flax

[?25l[K     |███▊                            | 10kB 22.0MB/s eta 0:00:01[K     |███████▍                        | 20kB 20.7MB/s eta 0:00:01[K     |███████████                     | 30kB 11.0MB/s eta 0:00:01[K     |██████████████▉                 | 40kB 8.9MB/s eta 0:00:01[K     |██████████████████▌             | 51kB 5.5MB/s eta 0:00:01[K     |██████████████████████▏         | 61kB 6.5MB/s eta 0:00:01[K     |█████████████████████████▉      | 71kB 6.4MB/s eta 0:00:01[K     |█████████████████████████████▋  | 81kB 6.3MB/s eta 0:00:01[K     |████████████████████████████████| 92kB 4.5MB/s 
[K     |████████████████████████████████| 102kB 5.4MB/s 
[K     |████████████████████████████████| 61kB 5.5MB/s 
[?25h  Building wheel for flax (setup.py) ... [?25l[?25hdone


In [None]:
#@title Modelling and Training Util
from absl import logging
from flax import optim
from flax.metrics import tensorboard
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import tensorflow_datasets as tfds
from jax import numpy as jnp, random, lax
from flax import linen as nn
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
from flax.linen import Module, compact
from pprint import pprint
from flax.linen import initializers
import functools

logging.set_verbosity(logging.INFO)


class MLP(Module):

  @compact
  def __call__(self, x, input_key='inputs', return_hidden=False):
    hidden_reps = [x]
    dense0 = nn.Dense(64)
    dense1 = nn.Dense(2)
    if input_key == 'inputs':
      x = dense0(x)
      x = nn.relu(x)
      hidden_reps.append(x)

    if return_hidden:
      return dense1(x), hidden_reps
    else:
      return dense1(x)


def get_initial_params(key):
  init_val = jnp.ones((1, 2), jnp.float32)
  initial_params = MLP().init(key, init_val)['params']
  return initial_params


def create_optimizer(params, learning_rate, beta):
  optimizer_def = optim.Adam(learning_rate=learning_rate)
  optimizer = optimizer_def.create(params)
  return optimizer


def onehot(labels, num_classes=2):
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  return x.astype(jnp.float32)


def cross_entropy_loss(logits, labels, weights):

  if len(labels.shape) == 1:
    labels = onehot(labels)
  if weights is None:
    loss = -jnp.sum(jnp.sum(labels * jax.nn.softmax(logits, axis=-1), axis=-1))
    normalizer = logits.shape[0]
  else:
    loss = -jnp.sum(jnp.sum(labels * jax.nn.softmax(logits, axis=-1), axis=-1) * weights)
    normalizer = weights.sum()
  
  return loss / normalizer

  
def compute_metrics(logits, labels, weights):
  loss = cross_entropy_loss(logits, labels, weights)
  if weights is None:
    corrects = jnp.sum(jnp.argmax(logits, -1) == labels)
    normalizer = logits.shape[0]
  else:
    corrects = jnp.sum((jnp.argmax(logits, -1) == labels)*weights)
    normalizer = weights.sum()

  accuracy = corrects / normalizer
  metrics = {
      'loss': loss,
      'accuracy': accuracy,
  }
  return metrics


@jax.jit
def train_step(optimizer, batch, weight_decay=0.01):
  """Train for a single step."""
  def loss_fn(params):
    logits = MLP().apply({'params': params}, 
                               batch['image'], 
                               input_key='inputs')
    loss = cross_entropy_loss(logits, batch['label'], batch.get('weight'))
   
    weight_penalty_params = jax.tree_leaves(params)
    weight_l2 = sum([jnp.sum(x ** 2)
                     for x in weight_penalty_params
                     if x.ndim > 1])
    weight_penalty = weight_decay * 0.5 * weight_l2

    loss = loss + weight_penalty

    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'], batch.get('weight'))
  return optimizer, metrics

@jax.jit
def self_train_step(optimizer, batch, weight_decay=0.01):
  """Train for a single step."""
  def loss_fn(params):
    logits = MLP().apply({'params': params}, batch['image'], input_key='inputs')
    loss = cross_entropy_loss(logits, batch['predicted_label'], batch.get('weight'))

    weight_penalty_params = jax.tree_leaves(params)
    weight_l2 = sum([jnp.sum(x ** 2)
                     for x in weight_penalty_params
                     if x.ndim > 1])
    weight_penalty = weight_decay * 0.5 * weight_l2

    loss = loss + weight_penalty

    return loss, logits

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, logits), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, batch['label'], batch.get('weight'))

  return optimizer, metrics

@jax.jit
def gift_train_step(rng, optimizer, teacher_params, batch, test_batch, conf=0.1, weight_decay=0.01):
  """Train for a single step."""
  def loss_fn(params):
    _, reps1 = MLP().apply({'params': params}, batch['image'],
                                 input_key='inputs',
                                 return_hidden=True)
    _, reps2 = MLP().apply({'params': teacher_params}, test_batch['image'],
                                input_key='inputs',
                                return_hidden=True)
    reps1 = reps1[-1]
    reps2 = jax.random.permutation(rng, reps2[-1])

    new_batch = {}
    new_batch['image'] = (1. - lmbda) * reps1 + lmbda * reps2
    teacher_logits= MLP().apply({'params': teacher_params}, new_batch['image'],
                              input_key='hidden')
    teacher_logits = jax.lax.stop_gradient(teacher_logits)
    new_batch['predicted_label'] = jnp.argmax(teacher_logits, axis=-1)
    new_batch['label'] = jnp.argmax(teacher_logits, axis=-1)
    confidence = jnp.max(teacher_logits, axis=-1) - jnp.min(teacher_logits, axis=-1)
    threshold = jax.numpy.quantile(confidence, conf)
    new_batch['weight'] = jnp.float32(confidence > threshold)

    logits = MLP().apply({'params': params}, new_batch['image'], input_key='hidden')
    loss = cross_entropy_loss(logits, 
                              new_batch['predicted_label'], 
                              new_batch.get('weight'))
    
    weight_penalty_params = jax.tree_leaves(params)
    weight_l2 = sum([jnp.sum(x ** 2)
                     for x in weight_penalty_params
                     if x.ndim > 1])
    weight_penalty = weight_decay * 0.5 * weight_l2

    loss = loss + weight_penalty
    return loss, (logits, new_batch)


  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (_, (logits, new_batch)), grad = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grad)
  metrics = compute_metrics(logits, new_batch['label'], new_batch.get('weight'))
  return optimizer, metrics

@jax.jit
def eval_step(params, batch):
  logits = MLP().apply({'params': params}, batch['image'], input_key='inputs')
  return compute_metrics(logits, batch['label'], batch.get('weight'))


def train_epoch(optimizer, train_ds, batch_size, epoch, rng, weight_decay=0.01):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    optimizer, metrics = train_step(optimizer, batch, weight_decay)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

  return optimizer, epoch_metrics_np

def self_train_epoch(optimizer, train_ds, batch_size, epoch, rng, 
                     conf,
                     weight_decay=0.01,
                     steps_per_epoch=None):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = steps_per_epoch or train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  teacher_params = jax.tree_map(lambda x: x, optimizer.target)
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    teacher_logits = MLP().apply({'params': teacher_params}, batch['image'],
                                 input_key='inputs')

    batch['predicted_label'] = jnp.argmax(teacher_logits, 
                                axis=-1)
    confidence = jnp.max(teacher_logits, axis=-1) - jnp.min(teacher_logits, axis=-1)
    threshold = jax.numpy.quantile(confidence, conf)
    batch['weight'] = jnp.float32(confidence > threshold)
    optimizer, metrics = self_train_step(optimizer, batch, weight_decay)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

  return optimizer, epoch_metrics_np


def gift_train_epoch(optimizer, train_ds, test_ds, lmbda, conf, batch_size, 
                     epoch, rng, weight_decay=0.01,
                     steps_per_epoch=None):
  """Train for a single epoch."""
  train_ds_size = len(train_ds['image'])
  steps_per_epoch = steps_per_epoch or train_ds_size // batch_size

  perms = jax.random.permutation(rng, len(train_ds['image']))
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))

  test_perms = jax.random.permutation(rng, len(test_ds['image']))
  test_perms = test_perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  test_perms = test_perms.reshape((steps_per_epoch, batch_size))

  batch_metrics = []
  teacher_params = jax.tree_map(lambda x: x, optimizer.target)
  for perm in perms:
    batch = {k: v[perm, ...] for k, v in train_ds.items()}
    test_batch = {k: v[perm, ...] for k, v in train_ds.items()}
    
    optimizer, metrics = gift_train_step(rng, optimizer, teacher_params,
                                         batch, test_batch, conf, weight_decay)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  logging.info('train epoch: %d, loss: %.4f, accuracy: %.2f', epoch,
               epoch_metrics_np['loss'], epoch_metrics_np['accuracy'] * 100)

  return optimizer, epoch_metrics_np


def eval_model(params, test_ds):
  metrics = eval_step(params, test_ds)
  metrics = jax.device_get(metrics)
  summary = jax.tree_map(lambda x: x.item(), metrics)
  return summary['loss'], summary['accuracy']


def get_datasets(X, labels, angle=0):

  train_ds = {}
  test_ds = {}
  trans_mat = np.array([[math.cos(angle), math.sin(angle)],[-math.sin(angle), math.cos(angle)]])
  X = X.dot(trans_mat)
  train_ds['image'] = jnp.float32(X)
  train_ds['label'] = jnp.int64(labels)
  test_ds = train_ds
  return train_ds, test_ds


def train_and_evaluate(config: ml_collections.ConfigDict, train_ds, test_ds, 
                       workdir: str,
                       optimizer=None,
                       steps_per_epoch=None):
  """Execute model training and evaluation loop.
  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.
  Returns:
    The trained optimizer.
  """
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  if optimizer is None:
    params = get_initial_params(init_rng)
    optimizer =  create_optimizer(
        params, config.learning_rate, config.momentum)

    train_fn = functools.partial(train_epoch,
                                 weight_decay=config.weight_decay)
  else:
    train_fn = functools.partial(self_train_epoch, 
                                 conf=config.conf,
                                 weight_decay=config.weight_decay,
                                 steps_per_epoch=steps_per_epoch)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_fn(
        optimizer, train_ds, config.batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)

    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)

    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return optimizer



In [None]:
#@title GIFT

def gift_train_and_evaluate(config: ml_collections.ConfigDict, train_ds, test_ds, 
                            lmbda,
                            workdir: str,
                            optimizer=None,
                            steps_per_epoch=None):
  """Execute model training and evaluation loop.
  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.
  Returns:
    The trained optimizer.
  """
  rng = jax.random.PRNGKey(0)

  summary_writer = tensorboard.SummaryWriter(workdir)
  summary_writer.hparams(dict(config))

  rng, init_rng = jax.random.split(rng)
  if optimizer is None:
    params = get_initial_params(init_rng)
    optimizer =  create_optimizer(
        params, config.learning_rate, config.momentum)

    train_fn = train_epoch
  else:
    train_fn = functools.partial(gift_train_epoch, 
                                 weight_decay=config.weight_decay,
                                 steps_per_epoch=steps_per_epoch)

  for epoch in range(1, config.num_epochs + 1):
    rng, input_rng = jax.random.split(rng)
    optimizer, train_metrics = train_fn(
        optimizer, train_ds,test_ds, lmbda, config.conf, config.batch_size, epoch, input_rng)
    loss, accuracy = eval_model(optimizer.target, test_ds)

    logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',
                 epoch, loss, accuracy * 100)

    summary_writer.scalar('train_loss', train_metrics['loss'], epoch)
    summary_writer.scalar('train_accuracy', train_metrics['accuracy'], epoch)
    summary_writer.scalar('eval_loss', loss, epoch)
    summary_writer.scalar('eval_accuracy', accuracy, epoch)

  summary_writer.flush()
  return optimizer

In [None]:
#@title Get train and test datasets

train_ds, _ = get_datasets(M, m_labels)
_, test_ds = get_datasets(M, m_labels, angle=(math.pi/2))


  lax._check_user_dtype_supported(dtype, "array")


In [None]:
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.

<IPython.core.display.Javascript object>

In [None]:
#@title Train base model:
config = ml_collections.ConfigDict()
config.learning_rate = 0.0002
config.momentum = 0.9
config.batch_size = 128
config.num_epochs = 100
config.weight_decay = 0.01
config.conf = 0.
optimizer = train_and_evaluate(config, train_ds, test_ds, workdir=f'./models/base')


INFO:absl:train epoch: 1, loss: -0.4960, accuracy: 47.77
INFO:absl:eval epoch: 1, loss: -0.5622, accuracy: 76.57
INFO:absl:train epoch: 2, loss: -0.6449, accuracy: 76.87
INFO:absl:eval epoch: 2, loss: -0.5476, accuracy: 69.91
INFO:absl:train epoch: 3, loss: -0.7333, accuracy: 84.76
INFO:absl:eval epoch: 3, loss: -0.5279, accuracy: 64.51
INFO:absl:train epoch: 4, loss: -0.7789, accuracy: 86.52
INFO:absl:eval epoch: 4, loss: -0.5142, accuracy: 60.37
INFO:absl:train epoch: 5, loss: -0.8044, accuracy: 87.22
INFO:absl:eval epoch: 5, loss: -0.5049, accuracy: 58.34
INFO:absl:train epoch: 6, loss: -0.8199, accuracy: 87.65
INFO:absl:eval epoch: 6, loss: -0.4985, accuracy: 56.95
INFO:absl:train epoch: 7, loss: -0.8306, accuracy: 88.03
INFO:absl:eval epoch: 7, loss: -0.4943, accuracy: 56.01
INFO:absl:train epoch: 8, loss: -0.8385, accuracy: 88.33
INFO:absl:eval epoch: 8, loss: -0.4908, accuracy: 55.13
INFO:absl:train epoch: 9, loss: -0.8449, accuracy: 88.60
INFO:absl:eval epoch: 9, loss: -0.4882,

In [None]:
logits = MLP().apply({'params': optimizer.target}, 
                               test_ds['image'], 
                               input_key='inputs') 
pred_labels = jax.device_get(jnp.argmax(logits, axis=-1))
x, l = jax.device_get(test_ds)['image'], jax.device_get(test_ds)['label']

predicted_base = plot(x[:,0], x[:,1], pred_labels, 'No Adaptation',width=200)
predicted_ground_truth = plot(x[:,0], x[:,1], l, 'Ground Truth', width=200)

predicted_base

In [None]:
#@title Gradual Self-training (with ground truth intermediate steps)
new_optimizer = optimizer.replace()
total_steps = 1000
iters = 20
config.conf = 0.4
config.learning_rate = 0.0001
config.num_epochs = 3
for i in range(1, iters+1):
  angel = (math.pi/(iters*2))*i
  print('angel:', angel)
  rotated_train_ds, rotated_test_ds = get_datasets(M, m_labels, angle=angel)
  x, y = jax.device_get(rotated_train_ds)['image'], jax.device_get(rotated_train_ds)['label']
  new_optimizer = train_and_evaluate(config, rotated_train_ds, test_ds, workdir=f'./models/gs', 
                                    optimizer=new_optimizer,
                                    steps_per_epoch=total_steps//iters)

angel: 0.07853981633974483


  lax._check_user_dtype_supported(dtype, "array")
INFO:absl:train epoch: 1, loss: -0.9906, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.4780, accuracy: 54.06
INFO:absl:train epoch: 2, loss: -0.9919, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.4798, accuracy: 53.94
INFO:absl:train epoch: 3, loss: -0.9920, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.4819, accuracy: 53.93
  lax._check_user_dtype_supported(dtype, "array")


angel: 0.15707963267948966


INFO:absl:train epoch: 1, loss: -0.9916, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.4855, accuracy: 54.02
INFO:absl:train epoch: 2, loss: -0.9922, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.4898, accuracy: 54.34
INFO:absl:train epoch: 3, loss: -0.9917, accuracy: 99.97
INFO:absl:eval epoch: 3, loss: -0.4945, accuracy: 54.64


angel: 0.23561944901923448


INFO:absl:train epoch: 1, loss: -0.9914, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.5007, accuracy: 55.22
INFO:absl:train epoch: 2, loss: -0.9918, accuracy: 99.97
INFO:absl:eval epoch: 2, loss: -0.5074, accuracy: 55.83
INFO:absl:train epoch: 3, loss: -0.9916, accuracy: 99.97
INFO:absl:eval epoch: 3, loss: -0.5138, accuracy: 56.36


angel: 0.3141592653589793


INFO:absl:train epoch: 1, loss: -0.9911, accuracy: 99.97
INFO:absl:eval epoch: 1, loss: -0.5218, accuracy: 57.05
INFO:absl:train epoch: 2, loss: -0.9915, accuracy: 99.95
INFO:absl:eval epoch: 2, loss: -0.5302, accuracy: 57.87
INFO:absl:train epoch: 3, loss: -0.9915, accuracy: 99.97
INFO:absl:eval epoch: 3, loss: -0.5382, accuracy: 58.58


angel: 0.39269908169872414


INFO:absl:train epoch: 1, loss: -0.9907, accuracy: 99.95
INFO:absl:eval epoch: 1, loss: -0.5477, accuracy: 59.44
INFO:absl:train epoch: 2, loss: -0.9910, accuracy: 99.90
INFO:absl:eval epoch: 2, loss: -0.5574, accuracy: 60.44
INFO:absl:train epoch: 3, loss: -0.9912, accuracy: 99.95
INFO:absl:eval epoch: 3, loss: -0.5666, accuracy: 61.25


angel: 0.47123889803846897


INFO:absl:train epoch: 1, loss: -0.9907, accuracy: 99.95
INFO:absl:eval epoch: 1, loss: -0.5768, accuracy: 62.33
INFO:absl:train epoch: 2, loss: -0.9910, accuracy: 99.90
INFO:absl:eval epoch: 2, loss: -0.5874, accuracy: 63.54
INFO:absl:train epoch: 3, loss: -0.9911, accuracy: 99.92
INFO:absl:eval epoch: 3, loss: -0.5972, accuracy: 64.44


angel: 0.5497787143782138


INFO:absl:train epoch: 1, loss: -0.9908, accuracy: 99.95
INFO:absl:eval epoch: 1, loss: -0.6083, accuracy: 65.64
INFO:absl:train epoch: 2, loss: -0.9909, accuracy: 99.87
INFO:absl:eval epoch: 2, loss: -0.6195, accuracy: 66.88
INFO:absl:train epoch: 3, loss: -0.9910, accuracy: 99.90
INFO:absl:eval epoch: 3, loss: -0.6300, accuracy: 68.07


angel: 0.6283185307179586


INFO:absl:train epoch: 1, loss: -0.9910, accuracy: 99.95
INFO:absl:eval epoch: 1, loss: -0.6410, accuracy: 69.07
INFO:absl:train epoch: 2, loss: -0.9908, accuracy: 99.84
INFO:absl:eval epoch: 2, loss: -0.6519, accuracy: 69.96
INFO:absl:train epoch: 3, loss: -0.9912, accuracy: 99.90
INFO:absl:eval epoch: 3, loss: -0.6620, accuracy: 70.76


angel: 0.7068583470577035


INFO:absl:train epoch: 1, loss: -0.9908, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.6725, accuracy: 71.49
INFO:absl:train epoch: 2, loss: -0.9904, accuracy: 99.77
INFO:absl:eval epoch: 2, loss: -0.6829, accuracy: 72.27
INFO:absl:train epoch: 3, loss: -0.9912, accuracy: 99.87
INFO:absl:eval epoch: 3, loss: -0.6926, accuracy: 72.87


angel: 0.7853981633974483


INFO:absl:train epoch: 1, loss: -0.9911, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.7030, accuracy: 73.87
INFO:absl:train epoch: 2, loss: -0.9906, accuracy: 99.77
INFO:absl:eval epoch: 2, loss: -0.7130, accuracy: 74.42
INFO:absl:train epoch: 3, loss: -0.9914, accuracy: 99.87
INFO:absl:eval epoch: 3, loss: -0.7218, accuracy: 75.02


angel: 0.8639379797371931


INFO:absl:train epoch: 1, loss: -0.9914, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.7312, accuracy: 75.77
INFO:absl:train epoch: 2, loss: -0.9909, accuracy: 99.77
INFO:absl:eval epoch: 2, loss: -0.7401, accuracy: 76.49
INFO:absl:train epoch: 3, loss: -0.9916, accuracy: 99.87
INFO:absl:eval epoch: 3, loss: -0.7481, accuracy: 77.12


angel: 0.9424777960769379


INFO:absl:train epoch: 1, loss: -0.9914, accuracy: 99.87
INFO:absl:eval epoch: 1, loss: -0.7567, accuracy: 77.75
INFO:absl:train epoch: 2, loss: -0.9908, accuracy: 99.74
INFO:absl:eval epoch: 2, loss: -0.7642, accuracy: 78.39
INFO:absl:train epoch: 3, loss: -0.9913, accuracy: 99.82
INFO:absl:eval epoch: 3, loss: -0.7716, accuracy: 79.15


angel: 1.0210176124166828


INFO:absl:train epoch: 1, loss: -0.9916, accuracy: 99.87
INFO:absl:eval epoch: 1, loss: -0.7790, accuracy: 79.89
INFO:absl:train epoch: 2, loss: -0.9915, accuracy: 99.79
INFO:absl:eval epoch: 2, loss: -0.7856, accuracy: 80.42
INFO:absl:train epoch: 3, loss: -0.9915, accuracy: 99.82
INFO:absl:eval epoch: 3, loss: -0.7923, accuracy: 80.98


angel: 1.0995574287564276


INFO:absl:train epoch: 1, loss: -0.9920, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.7990, accuracy: 81.85
INFO:absl:train epoch: 2, loss: -0.9913, accuracy: 99.77
INFO:absl:eval epoch: 2, loss: -0.8049, accuracy: 82.38
INFO:absl:train epoch: 3, loss: -0.9916, accuracy: 99.82
INFO:absl:eval epoch: 3, loss: -0.8107, accuracy: 82.77


angel: 1.1780972450961724


INFO:absl:train epoch: 1, loss: -0.9918, accuracy: 99.87
INFO:absl:eval epoch: 1, loss: -0.8167, accuracy: 83.33
INFO:absl:train epoch: 2, loss: -0.9916, accuracy: 99.79
INFO:absl:eval epoch: 2, loss: -0.8216, accuracy: 83.80
INFO:absl:train epoch: 3, loss: -0.9919, accuracy: 99.84
INFO:absl:eval epoch: 3, loss: -0.8264, accuracy: 84.32


angel: 1.2566370614359172


INFO:absl:train epoch: 1, loss: -0.9919, accuracy: 99.87
INFO:absl:eval epoch: 1, loss: -0.8316, accuracy: 84.87
INFO:absl:train epoch: 2, loss: -0.9922, accuracy: 99.84
INFO:absl:eval epoch: 2, loss: -0.8357, accuracy: 85.31
INFO:absl:train epoch: 3, loss: -0.9924, accuracy: 99.90
INFO:absl:eval epoch: 3, loss: -0.8397, accuracy: 85.98


angel: 1.335176877775662


INFO:absl:train epoch: 1, loss: -0.9922, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.8440, accuracy: 86.66
INFO:absl:train epoch: 2, loss: -0.9924, accuracy: 99.87
INFO:absl:eval epoch: 2, loss: -0.8471, accuracy: 87.19
INFO:absl:train epoch: 3, loss: -0.9924, accuracy: 99.90
INFO:absl:eval epoch: 3, loss: -0.8501, accuracy: 87.67


angel: 1.413716694115407


INFO:absl:train epoch: 1, loss: -0.9925, accuracy: 99.92
INFO:absl:eval epoch: 1, loss: -0.8534, accuracy: 88.23
INFO:absl:train epoch: 2, loss: -0.9922, accuracy: 99.84
INFO:absl:eval epoch: 2, loss: -0.8554, accuracy: 88.46
INFO:absl:train epoch: 3, loss: -0.9925, accuracy: 99.90
INFO:absl:eval epoch: 3, loss: -0.8574, accuracy: 88.78


angel: 1.4922565104551517


INFO:absl:train epoch: 1, loss: -0.9925, accuracy: 99.92
INFO:absl:eval epoch: 1, loss: -0.8595, accuracy: 88.87
INFO:absl:train epoch: 2, loss: -0.9922, accuracy: 99.84
INFO:absl:eval epoch: 2, loss: -0.8605, accuracy: 88.83
INFO:absl:train epoch: 3, loss: -0.9922, accuracy: 99.87
INFO:absl:eval epoch: 3, loss: -0.8615, accuracy: 88.70


angel: 1.5707963267948966


INFO:absl:train epoch: 1, loss: -0.9923, accuracy: 99.90
INFO:absl:eval epoch: 1, loss: -0.8623, accuracy: 88.57
INFO:absl:train epoch: 2, loss: -0.9922, accuracy: 99.84
INFO:absl:eval epoch: 2, loss: -0.8623, accuracy: 88.33
INFO:absl:train epoch: 3, loss: -0.9923, accuracy: 99.87
INFO:absl:eval epoch: 3, loss: -0.8623, accuracy: 88.05


In [None]:
logits = MLP().apply({'params': new_optimizer.target}, 
                               test_ds['image'], 
                               input_key='inputs') 
pred_labels = jax.device_get(jnp.argmax(logits, axis=-1))
x, l = jax.device_get(test_ds)['image'], jax.device_get(test_ds)['label']

predicted_gs = plot(x[:,0], x[:,1], pred_labels, 'Gradual Self-training', width=200)

predicted_gs

In [None]:
#@title Direct Self-training
new_optimizer = optimizer.replace()
config.learning_rate = 0.0001
config.num_epochs = 1
config.conf = 0.4

new_optimizer = train_and_evaluate(config, test_ds, test_ds, workdir=f'./models/direct', 
                                  optimizer=new_optimizer,
                                  steps_per_epoch=None)

INFO:absl:train epoch: 1, loss: -0.3735, accuracy: 40.61
INFO:absl:eval epoch: 1, loss: -0.4791, accuracy: 52.13


In [None]:
logits = MLP().apply({'params': new_optimizer.target}, 
                               test_ds['image'], 
                               input_key='inputs') 
pred_labels = jax.device_get(jnp.argmax(logits, axis=-1))
x, l = jax.device_get(test_ds)['image'], jax.device_get(test_ds)['label']

predicted_direct = plot(x[:,0], x[:,1], pred_labels, 'Self-training', width=200)

predicted_direct

In [None]:
#@title Iterative Self-training
new_optimizer = optimizer.replace()
total_steps = 1000
iters = 20
config.conf = 0.4
config.weight_decay = 0.01
config.learning_rate = 0.0001
config.epoch = 3
for i in range(1, iters+1):
  new_optimizer = train_and_evaluate(config, test_ds, test_ds, workdir=f'./models/directp', 
                                    optimizer=new_optimizer,
                                    steps_per_epoch=total_steps//iters)

INFO:absl:train epoch: 1, loss: -0.3722, accuracy: 40.70
INFO:absl:eval epoch: 1, loss: -0.4723, accuracy: 51.88
INFO:absl:train epoch: 1, loss: -0.3607, accuracy: 38.34
INFO:absl:eval epoch: 1, loss: -0.4621, accuracy: 49.88
INFO:absl:train epoch: 1, loss: -0.3430, accuracy: 35.97
INFO:absl:eval epoch: 1, loss: -0.4469, accuracy: 47.94
INFO:absl:train epoch: 1, loss: -0.3204, accuracy: 33.27
INFO:absl:eval epoch: 1, loss: -0.4292, accuracy: 45.78
INFO:absl:train epoch: 1, loss: -0.2986, accuracy: 30.81
INFO:absl:eval epoch: 1, loss: -0.4110, accuracy: 43.52
INFO:absl:train epoch: 1, loss: -0.2754, accuracy: 28.23
INFO:absl:eval epoch: 1, loss: -0.3927, accuracy: 41.43
INFO:absl:train epoch: 1, loss: -0.2553, accuracy: 26.05
INFO:absl:eval epoch: 1, loss: -0.3754, accuracy: 39.26
INFO:absl:train epoch: 1, loss: -0.2341, accuracy: 23.77
INFO:absl:eval epoch: 1, loss: -0.3589, accuracy: 37.18
INFO:absl:train epoch: 1, loss: -0.2140, accuracy: 21.58
INFO:absl:eval epoch: 1, loss: -0.3431,

In [None]:
logits = MLP().apply({'params': new_optimizer.target}, 
                               test_ds['image'], 
                               input_key='inputs') 
pred_labels = jax.device_get(jnp.argmax(logits, axis=-1))
x, l = jax.device_get(test_ds)['image'], jax.device_get(test_ds)['label']

predicted_iter = plot(x[:,0], x[:,1], pred_labels, 'Iterative Self-Training',width=200)

predicted_iter

In [None]:
#@title GIFT (hidden layer)

new_optimizer = optimizer.replace()
total_steps = 1000
iters = 20
config.conf = 0.4
config.weight_decay = 0.03
config.learning_rate = 0.0002
for i in range(1, iters+1):
  lmbda = (1.0/iters)*i
  print(lmbda)
  config.num_epochs = 4
  new_optimizer = gift_train_and_evaluate(config, train_ds, test_ds, lmbda,
                                          workdir=f'./models/gift_hidden', 
                                          optimizer=new_optimizer,
                                          steps_per_epoch=total_steps//iters)



0.05


INFO:absl:train epoch: 1, loss: -0.9865, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.4671, accuracy: 53.88
INFO:absl:train epoch: 2, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.4680, accuracy: 53.94
INFO:absl:train epoch: 3, loss: -0.9809, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.4731, accuracy: 54.38
INFO:absl:train epoch: 4, loss: -0.9801, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.4797, accuracy: 54.89


0.1


INFO:absl:train epoch: 1, loss: -0.9806, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.4864, accuracy: 55.60
INFO:absl:train epoch: 2, loss: -0.9805, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.4937, accuracy: 56.21
INFO:absl:train epoch: 3, loss: -0.9797, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.5019, accuracy: 57.06
INFO:absl:train epoch: 4, loss: -0.9796, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.5099, accuracy: 57.81


0.15000000000000002


INFO:absl:train epoch: 1, loss: -0.9805, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.5172, accuracy: 58.54
INFO:absl:train epoch: 2, loss: -0.9805, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.5239, accuracy: 59.25
INFO:absl:train epoch: 3, loss: -0.9799, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.5318, accuracy: 59.94
INFO:absl:train epoch: 4, loss: -0.9800, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.5402, accuracy: 60.77


0.2


INFO:absl:train epoch: 1, loss: -0.9808, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.5469, accuracy: 61.56
INFO:absl:train epoch: 2, loss: -0.9809, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.5526, accuracy: 62.14
INFO:absl:train epoch: 3, loss: -0.9803, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.5598, accuracy: 62.96
INFO:absl:train epoch: 4, loss: -0.9805, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.5674, accuracy: 63.77


0.25


INFO:absl:train epoch: 1, loss: -0.9812, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.5740, accuracy: 64.42
INFO:absl:train epoch: 2, loss: -0.9812, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.5794, accuracy: 65.07
INFO:absl:train epoch: 3, loss: -0.9807, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.5859, accuracy: 65.74
INFO:absl:train epoch: 4, loss: -0.9810, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.5922, accuracy: 66.47


0.30000000000000004


INFO:absl:train epoch: 1, loss: -0.9816, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.5972, accuracy: 66.92
INFO:absl:train epoch: 2, loss: -0.9816, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6018, accuracy: 67.50
INFO:absl:train epoch: 3, loss: -0.9811, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6066, accuracy: 67.91
INFO:absl:train epoch: 4, loss: -0.9814, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6111, accuracy: 68.20


0.35000000000000003


INFO:absl:train epoch: 1, loss: -0.9820, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6155, accuracy: 68.56
INFO:absl:train epoch: 2, loss: -0.9819, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6190, accuracy: 68.83
INFO:absl:train epoch: 3, loss: -0.9814, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6222, accuracy: 69.09
INFO:absl:train epoch: 4, loss: -0.9817, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6245, accuracy: 69.25


0.4


INFO:absl:train epoch: 1, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6268, accuracy: 69.49
INFO:absl:train epoch: 2, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6290, accuracy: 69.68
INFO:absl:train epoch: 3, loss: -0.9816, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6309, accuracy: 69.84
INFO:absl:train epoch: 4, loss: -0.9819, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6317, accuracy: 69.78


0.45


INFO:absl:train epoch: 1, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6327, accuracy: 69.88
INFO:absl:train epoch: 2, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6342, accuracy: 69.97
INFO:absl:train epoch: 3, loss: -0.9817, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6359, accuracy: 70.04
INFO:absl:train epoch: 4, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6359, accuracy: 70.02


0.5


INFO:absl:train epoch: 1, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6363, accuracy: 70.07
INFO:absl:train epoch: 2, loss: -0.9823, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6372, accuracy: 70.12
INFO:absl:train epoch: 3, loss: -0.9818, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6384, accuracy: 70.21
INFO:absl:train epoch: 4, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6377, accuracy: 70.11


0.55


INFO:absl:train epoch: 1, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6379, accuracy: 70.17
INFO:absl:train epoch: 2, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6389, accuracy: 70.28
INFO:absl:train epoch: 3, loss: -0.9819, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6400, accuracy: 70.32
INFO:absl:train epoch: 4, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6390, accuracy: 70.22


0.6000000000000001


INFO:absl:train epoch: 1, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6390, accuracy: 70.28
INFO:absl:train epoch: 2, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6397, accuracy: 70.39
INFO:absl:train epoch: 3, loss: -0.9820, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6409, accuracy: 70.44
INFO:absl:train epoch: 4, loss: -0.9823, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6396, accuracy: 70.22


0.65


INFO:absl:train epoch: 1, loss: -0.9827, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6395, accuracy: 70.32
INFO:absl:train epoch: 2, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6400, accuracy: 70.36
INFO:absl:train epoch: 3, loss: -0.9820, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6410, accuracy: 70.43
INFO:absl:train epoch: 4, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6396, accuracy: 70.20


0.7000000000000001


INFO:absl:train epoch: 1, loss: -0.9827, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6396, accuracy: 70.33
INFO:absl:train epoch: 2, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6401, accuracy: 70.39
INFO:absl:train epoch: 3, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6409, accuracy: 70.40
INFO:absl:train epoch: 4, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6394, accuracy: 70.19


0.75


INFO:absl:train epoch: 1, loss: -0.9827, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6395, accuracy: 70.29
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6399, accuracy: 70.36
INFO:absl:train epoch: 3, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6408, accuracy: 70.36
INFO:absl:train epoch: 4, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6392, accuracy: 70.16


0.8


INFO:absl:train epoch: 1, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6392, accuracy: 70.22
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6396, accuracy: 70.33
INFO:absl:train epoch: 3, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6405, accuracy: 70.32
INFO:absl:train epoch: 4, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6389, accuracy: 70.14


0.8500000000000001


INFO:absl:train epoch: 1, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6388, accuracy: 70.19
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6392, accuracy: 70.24
INFO:absl:train epoch: 3, loss: -0.9821, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6403, accuracy: 70.27
INFO:absl:train epoch: 4, loss: -0.9824, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6387, accuracy: 70.13


0.9


INFO:absl:train epoch: 1, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6386, accuracy: 70.19
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6390, accuracy: 70.21
INFO:absl:train epoch: 3, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6401, accuracy: 70.25
INFO:absl:train epoch: 4, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6386, accuracy: 70.12


0.9500000000000001


INFO:absl:train epoch: 1, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6385, accuracy: 70.19
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6389, accuracy: 70.20
INFO:absl:train epoch: 3, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6400, accuracy: 70.25
INFO:absl:train epoch: 4, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6386, accuracy: 70.12


1.0


INFO:absl:train epoch: 1, loss: -0.9828, accuracy: 100.00
INFO:absl:eval epoch: 1, loss: -0.6384, accuracy: 70.19
INFO:absl:train epoch: 2, loss: -0.9826, accuracy: 100.00
INFO:absl:eval epoch: 2, loss: -0.6389, accuracy: 70.20
INFO:absl:train epoch: 3, loss: -0.9822, accuracy: 100.00
INFO:absl:eval epoch: 3, loss: -0.6400, accuracy: 70.27
INFO:absl:train epoch: 4, loss: -0.9825, accuracy: 100.00
INFO:absl:eval epoch: 4, loss: -0.6385, accuracy: 70.12


In [None]:
logits = MLP().apply({'params': new_optimizer.target}, 
                               test_ds['image'], 
                               input_key='inputs') 
pred_labels = jax.device_get(jnp.argmax(logits, axis=-1))
x, l = jax.device_get(test_ds)['image'], jax.device_get(test_ds)['label']

predicted_gift = plot(x[:,0], x[:,1], pred_labels, 'GIFT', width=200)

predicted_gift

In [None]:

(moon & trainslated_moon) | (predicted_base | predicted_gs) & (predicted_iter | predicted_gift)