In [256]:
import os
import pathlib
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Input
from LWTA.base import *
from LWTA.base_conv2d import LwtaClassifier as lwta_clf
from LWTA.bit_precision import compute_reduced_weights

In [2]:
learning_rate = 0.003
optimizer = keras.optimizers.SGD(learning_rate=learning_rate)  

meta_step_size = 0.25
inner_batch_size = 25
eval_batch_size = 25

meta_iters = 60000
eval_iters = 5
inner_iters = 4

eval_interval = 50
report_frequency = 50
checkpoint_freq = 1000
train_shots = 20
shots = 5 # 1 for 1-shot 5-way
classes = 5 # 5 for 1-shot 5-way
BMA = False
DETERMINISTIC = False

## Prepare the data

The [Omniglot dataset](https://github.com/brendenlake/omniglot/) is a dataset of 1,623
characters taken from 50 different alphabets, with 20 examples for each character.
The 20 samples for each character were drawn online via Amazon's Mechanical Turk. For the
few-shot learning task, `k` samples (or "shots") are drawn randomly from `n` randomly-chosen
classes. These `n` numerical values are used to create a new set of temporary labels to use
to test the model's ability to learn a new task given few examples. In other words, if you
are training on 5 classes, your new class labels will be either 0, 1, 2, 3, or 4.
Omniglot is a great dataset for this task since there are many different classes to draw
from, with a reasonable number of samples for each class.

In [4]:
class Dataset:
    # This class will facilitate the creation of a few-shot dataset
    # from the Omniglot dataset that can be sampled from quickly while also
    # allowing to create new labels at the same time.
    def __init__(self, training):
        # Download the tfrecord files containing the omniglot data and convert to a
        # dataset.
        split = "train" if training else "test"
        ds = tfds.load("omniglot", split=split, as_supervised=True, shuffle_files=False)

        # Iterate over the dataset to get each individual image and its class,
        # and put that data into a dictionary.
        self.data = {}
            
        def extraction(image, label):
            # This function will shrink the Omniglot images to the desired size,
            # scale pixel values and convert the RGB image to grayscale
            image = tf.image.convert_image_dtype(image, tf.float32)
            image = tf.image.rgb_to_grayscale(image)
            image = tf.image.resize(image, [28, 28])
            return image, label

        for image, label in ds.map(extraction):
            image = image.numpy()
            label = str(label.numpy())
            if label not in self.data:
                self.data[label] = []
            self.data[label].append(image)
            self.labels = list(self.data.keys())

    def get_mini_dataset(
        self, batch_size, repetitions, shots, num_classes, split=False
    ):
        temp_labels = np.zeros(shape=(num_classes * shots))
        temp_images = np.zeros(shape=(num_classes * shots, 28, 28, 1))
        if split:
            test_labels = np.zeros(shape=(num_classes))
            test_images = np.zeros(shape=(num_classes, 28, 28, 1))

        # Get a random subset of labels from the entire label set.
        label_subset = random.choices(self.labels, k=num_classes)
        for class_idx, class_obj in enumerate(label_subset):
            # Use enumerated index value as a temporary label for mini-batch in
            # few shot learning.
            temp_labels[class_idx * shots : (class_idx + 1) * shots] = class_idx
            # If creating a split dataset for testing, select an extra sample from each
            # label to create the test dataset.
            if split:
                test_labels[class_idx] = class_idx
                images_to_split = random.choices(
                    self.data[label_subset[class_idx]], k=shots + 1
                )
                test_images[class_idx] = images_to_split[-1]
                temp_images[
                    class_idx * shots : (class_idx + 1) * shots
                ] = images_to_split[:-1]
            else:
                # For each index in the randomly selected label_subset, sample the
                # necessary number of images.
                temp_images[
                    class_idx * shots : (class_idx + 1) * shots
                ] = random.choices(self.data[label_subset[class_idx]], k=shots)

        dataset = tf.data.Dataset.from_tensor_slices(
            (temp_images.astype(np.float32), temp_labels.astype(np.int32))
        )
        dataset = dataset.shuffle(100).batch(batch_size).repeat(repetitions)
        if split:
            return dataset, test_images, test_labels
        return dataset

import urllib3

urllib3.disable_warnings()  # Disable SSL warnings that may happen during download.
train_dataset = Dataset(training=True)
test_dataset = Dataset(training=False)

## Train the model (with dense_LWTA)

In [8]:
#try deterministic=False for point estimates and not distribution estimates
sb_class = LwtaClassifier(original_dim = [classes,1], tau = 5e-2, bma=BMA,
                          deterministic=DETERMINISTIC) 
                          
def train_func(x, train=True, activation="lwta"):
    return sb_class(x, train=train, activation=activation)

train = tf.function(train_func)

In [9]:
# Create checkpoint for resuming training 
ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=sb_class)
ckpt_path = f"./checkpoints_omniglot_lwta_{shots}_shot_{classes}_way"
manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)

def checkpoint_manager(meta_iter):  
    if meta_iter == 0:
        if os.path.isdir(ckpt_path):
            ckpt.restore(manager.latest_checkpoint)
            print("\nRestored from {}\n".format(manager.latest_checkpoint))
        else:
            print("\nNone checkpoints found => Initializing from scratch.\n")
    elif meta_iter % checkpoint_freq == 0:
        save_path = manager.save()
        print("\nSaved checkpoint for step {}: {}\n".format(meta_iter, save_path))
        
    ckpt.step.assign_add(1)
    return ckpt

#save to extenrla file integer

In [10]:
!rm -r checkpoints_omniglot_lwta_1_shot_5_way
!rm -r checkpoints_omniglot_lwta_5_shot_5_way

rm: cannot remove 'checkpoints_omniglot_lwta_1_shot_5_way': No such file or directory


In [18]:
training = []
testing = []

start = time.time()

for meta_iter in range(int(ckpt.step), meta_iters):
    mini_dataset = train_dataset.get_mini_dataset(
        inner_batch_size, inner_iters, train_shots, classes
    )

    ckpt = checkpoint_manager(meta_iter)

    frac_done = meta_iter / meta_iters
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    
    # Get a sample from the full dataset.
    if meta_iter > 0:
        old_vars = sb_class.get_weights()

    j = 0    
    for images, labels in mini_dataset:
        
        with tf.GradientTape() as tape:
            x1, x2, x3 = images.shape[0], images.shape[1], images.shape[2]
            images = tf.reshape(images, [x1, x2*x3])    
            #preds, _, _ = sb_class(images, train=True, activation="lwta")
            try:
                preds, _, _ = train(images, train=True, activation="lwta")
            except (UnboundLocalError, ValueError):
                train = tf.function(train_func)
                preds, _, _ = train(images, train=True, activation="lwta")
                
            if (j == 0) and (meta_iter == 0):
                old_vars = sb_class.get_weights()
            # we optimize the variational lower bound scaled by the number of data
            # points (so we can keep our intuitions about hyper-params such as the learning rate)
            #kl_loss = sum(sb_class.losses) / (x1 * x2)
            ce = keras.losses.sparse_categorical_crossentropy(labels, preds)
            #loss = ce + kl_loss
            loss = ce
           
        grads = tape.gradient(loss, sb_class.trainable_weights)
        optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))
       
        j += 1
    
    new_vars = sb_class.get_weights()

    # Perform SGD for the meta step.
    for var in range(len(new_vars)):
        new_vars[var] = old_vars[var] + (
            (new_vars[var] - old_vars[var]) * cur_meta_step_size
        )
    # After the meta-learning step, reload the newly-trained weights into the model.
    sb_class.set_weights(new_vars)
   
    if meta_iter == 50:
        print("\n#### The first 50 iterations took {:.2f} secs ####".format(time.time()-start))
 
    # Evaluation loop
    if meta_iter % eval_interval == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # Sample a mini dataset from the full dataset.
            train_set, test_images, test_labels = dataset.get_mini_dataset(
                eval_batch_size, eval_iters, shots, classes, split=True
            )
            old_vars = sb_class.get_weights()
            # Train on the samples and get the resulting accuracies.
            for images, labels in train_set:
              
                with tf.GradientTape() as tape:
                    x1, x2, x3 = images.shape[0], images.shape[1], images.shape[2]
                    images = tf.reshape(images, [x1, x2*x3])
                    try:
                        preds, _, _ = train(images, train=True, activation="lwta")
                    except (UnboundLocalError, ValueError):
                        train = tf.function(train_func)
                        preds, _, _ = train(images, train=True, activation="lwta")
                    #kl_loss = sum(sb_class.losses) / (x1 * x2)
                    ce = keras.losses.sparse_categorical_crossentropy(labels, preds)
                    #loss = ce + kl_loss
                    loss = ce

                grads = tape.gradient(loss, sb_class.trainable_weights)
                optimizer.apply_gradients(zip(grads, sb_class.trainable_weights))
            
            x1, x2, x3 = test_images.shape[0], test_images.shape[1], test_images.shape[2]

            test_images = tf.reshape(test_images, [x1, x2*x3]) 
#             test_preds, _, _ = sb_class(test_images, train=False, activation="lwta") for bma=True
            try:
                test_preds, _, _ = train(test_images, train=False, activation="lwta")
            except (UnboundLocalError, ValueError):
                train = tf.function(train_func)
                test_preds, _, _ = train(test_images, train=False, activation="lwta")
                
            test_preds = tf.argmax(test_preds).numpy()
            num_correct = (test_preds == test_labels).sum()
            
            # Reset the weights after getting the evaluation accuracies.
            sb_class.set_weights(old_vars)
            accuracies.append(num_correct / classes)
       
            
        training.append(accuracies[0])
        testing.append(accuracies[1])
        
        if meta_iter % report_frequency == 0:
            # total num of params
            print(tf.reduce_sum([tf.reduce_prod(v.shape) for v in sb_class.get_weights()]))

            print("Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%" % (meta_iter,
                  100*np.mean(training),100*np.mean(testing)))

end = time.time()
print("The training took {:.2f} secs".format(end-start))


None checkpoints found => Initializing from scratch.

tf.Tensor(54549, shape=(), dtype=int32)
Iter = 0 => train_acc = 20.00% / test_acc = 60.00%

#### The first 50 iterations took 19.49 secs ####
tf.Tensor(54549, shape=(), dtype=int32)
Iter = 50 => train_acc = 20.00% / test_acc = 40.00%


KeyboardInterrupt: 

In [7]:
print("Final training accuracy = {:.1f} %%".format(100*np.mean(training)))
print("Final testing accuracy = {:.1f} %%".format(100*np.mean(testing)))

Final training accuracy = 80.0 %%
Final testing accuracy = 66.0 %%


## Train the model (without LWTA)

In [39]:
def conv_bn(x):
    x = layers.Conv2D(filters=64, kernel_size=3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    return layers.ReLU()(x)

inputs = layers.Input(shape=(28, 28, 1))
x = conv_bn(inputs)
x = conv_bn(x)
x = conv_bn(x)
x = conv_bn(x)
x = layers.Flatten()(x)
outputs = layers.Dense(classes, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile()

In [40]:
def model_func(x):
    return model(x)

train = tf.function(model_func)

In [41]:
# Create checkpoint for resuming training 
ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
ckpt_path = f"./checkpoints_omniglot_lwta_{shots}_shot_{classes}_way"
manager = tf.train.CheckpointManager(ckpt, ckpt_path, max_to_keep=3)

def checkpoint_manager(meta_iter):  
    if meta_iter == 0:
        if os.path.isdir(ckpt_path):
            ckpt.restore(manager.latest_checkpoint)
            print("\nRestored from {}\n".format(manager.latest_checkpoint))
        else:
            print("\nNone checkpoints found => Initializing from scratch.\n")
    elif meta_iter % checkpoint_freq == 0:
        save_path = manager.save()
        print("\nSaved checkpoint for step {}: {}\n".format(meta_iter, save_path))
        
    ckpt.step.assign_add(1)
    return ckpt

#save to extenrla file integer

In [42]:
!rm -r checkpoints_omniglot_lwta_1_shot_5_way
!rm -r checkpoints_omniglot_lwta_5_shot_5_way

rm: cannot remove 'checkpoints_omniglot_lwta_1_shot_5_way': No such file or directory
rm: cannot remove 'checkpoints_omniglot_lwta_5_shot_5_way': No such file or directory


In [43]:
training = []
testing = []
num_params = 0
train_acc = []
test_acc = []
train_loss = []

start = time.time()

for meta_iter in range(int(ckpt.step), meta_iters):   
    mini_dataset = train_dataset.get_mini_dataset(
        inner_batch_size, inner_iters, train_shots, classes
    )
    
    checkpoint_manager(meta_iter)
        
    frac_done = meta_iter / meta_iters
    cur_meta_step_size = (1 - frac_done) * meta_step_size
    # Temporarily save the weights from the model.
    old_vars = model.get_weights()
    # Get a sample from the full dataset.
    
    j = 0
    for images, labels in mini_dataset:
        with tf.GradientTape() as tape:
#             preds = model(images)
            try:
                preds = train(images)
            except (UnboundLocalError, ValueError):
                train = tf.function(model_func)
                preds = train(images)
            loss = keras.losses.sparse_categorical_crossentropy(labels, preds)
        grads = tape.gradient(loss, model.trainable_weights)

        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        j += 1
    new_vars = model.get_weights()
    # Perform SGD for the meta step.
    for var in range(len(new_vars)):
        new_vars[var] = old_vars[var] + (
            (new_vars[var] - old_vars[var]) * cur_meta_step_size
        )
    # After the meta-learning step, reload the newly-trained weights into the model.
    model.set_weights(new_vars)
    if meta_iter == 50:
        print("\n#### The first 50 iterations took {:.2f} secs ####".format(time.time()-start))
    
    # Evaluation loop
    if meta_iter % 1 == 0:
        accuracies = []
        for dataset in (train_dataset, test_dataset):
            # Sample a mini dataset from the full dataset.
            train_set, test_images, test_labels = dataset.get_mini_dataset(
                eval_batch_size, eval_iters, shots, classes, split=True
            )
            old_vars = model.get_weights()
            # Train on the samples and get the resulting accuracies.
            for images, labels in train_set:
                with tf.GradientTape() as tape:
#                     preds = model(images)
                    try:
                        preds = train(images)
                    except (UnboundLocalError, ValueError):
                        train = tf.function(model_func)
                        preds = train(images)
                    loss = keras.losses.sparse_categorical_crossentropy(labels, preds)                          
                grads = tape.gradient(loss, model.trainable_weights)
                optimizer.apply_gradients(zip(grads, model.trainable_weights))
            
            try:
                test_preds = train(test_images)
            except (UnboundLocalError, ValueError):
                train = tf.function(model_func)
                test_preds = train(test_images)
#             test_preds = model.predict(test_images)
            
            train_loss_ = keras.losses.sparse_categorical_crossentropy(test_labels, test_preds) 
            test_preds = tf.argmax(test_preds).numpy()
                        
            num_correct = (test_preds == test_labels).sum()
           
            # Reset the weights after getting the evaluation accuracies.
            model.set_weights(old_vars)
            accuracies.append(num_correct / classes)
       
        training.append(accuracies[0])
        testing.append(accuracies[1])
        train_acc.append(100*np.mean(training))
        test_acc.append(100*np.mean(testing))
        train_loss.append(tf.reduce_mean(train_loss_))

        with open('print_metrics_folder/train_acc_reptile_no_lwta.txt', 'a+') as file:
            file.write("%f\n" % (100.0*np.mean(training)))

        with open('print_metrics_folder/train_loss_reptile_no_lwta.txt', 'a+') as file:
            file.write("%f\n" % (tf.reduce_mean(train_loss_).numpy()))

        with open('print_metrics_folder/test_acc_reptile_no_lwta.txt', 'a+') as file:
            file.write("%f\n" % (100.0*np.mean(testing)))
        
        if meta_iter % report_frequency == 0:
            print("Iter = %d => train_acc = %.2f%% / test_acc = %.2f%%" % (meta_iter,
                  100*np.mean(training),100*np.mean(testing)))

end = time.time()
print("The training took {:.2f} secs".format(end-start))


None checkpoints found => Initializing from scratch.

Iter = 0 => train_acc = 60.00% / test_acc = 20.00%

#### The first 50 iterations took 29.22 secs ####
Iter = 50 => train_acc = 50.20% / test_acc = 46.27%
Iter = 100 => train_acc = 58.61% / test_acc = 55.25%
Iter = 150 => train_acc = 66.23% / test_acc = 64.77%
Iter = 200 => train_acc = 71.14% / test_acc = 68.86%
Iter = 250 => train_acc = 74.34% / test_acc = 72.75%
Iter = 300 => train_acc = 75.81% / test_acc = 74.82%
Iter = 350 => train_acc = 77.21% / test_acc = 76.41%
Iter = 400 => train_acc = 77.76% / test_acc = 77.36%
Iter = 450 => train_acc = 78.94% / test_acc = 77.83%
Iter = 500 => train_acc = 79.84% / test_acc = 78.84%
Iter = 550 => train_acc = 80.44% / test_acc = 79.17%
Iter = 600 => train_acc = 81.13% / test_acc = 79.83%
Iter = 650 => train_acc = 81.41% / test_acc = 79.88%
Iter = 700 => train_acc = 81.88% / test_acc = 79.91%
Iter = 750 => train_acc = 82.24% / test_acc = 80.37%
Iter = 800 => train_acc = 82.25% / test_acc = 80.

Iter = 7300 => train_acc = 90.54% / test_acc = 87.19%
Iter = 7350 => train_acc = 90.55% / test_acc = 87.21%
Iter = 7400 => train_acc = 90.58% / test_acc = 87.25%
Iter = 7450 => train_acc = 90.62% / test_acc = 87.26%
Iter = 7500 => train_acc = 90.64% / test_acc = 87.27%
Iter = 7550 => train_acc = 90.67% / test_acc = 87.29%
Iter = 7600 => train_acc = 90.70% / test_acc = 87.30%
Iter = 7650 => train_acc = 90.73% / test_acc = 87.28%
Iter = 7700 => train_acc = 90.74% / test_acc = 87.27%
Iter = 7750 => train_acc = 90.77% / test_acc = 87.26%
Iter = 7800 => train_acc = 90.79% / test_acc = 87.29%
Iter = 7850 => train_acc = 90.80% / test_acc = 87.30%
Iter = 7900 => train_acc = 90.82% / test_acc = 87.33%
Iter = 7950 => train_acc = 90.85% / test_acc = 87.34%

Saved checkpoint for step 8000: ./checkpoints_omniglot_lwta_5_shot_5_way/ckpt-8

Iter = 8000 => train_acc = 90.88% / test_acc = 87.35%
Iter = 8050 => train_acc = 90.90% / test_acc = 87.39%
Iter = 8100 => train_acc = 90.94% / test_acc = 87.41%


KeyboardInterrupt: 

In [7]:
print("Final training accuracy = {:.1f} %%".format(100*np.mean(training)))
print("Final testing accuracy = {:.1f} %%".format(100*np.mean(testing)))

Final training accuracy = 83.8 %%
Final testing accuracy = 68.6 %%
