# Semi-Supervised Learning Implementation Examples
These examples mainly illustrate what a properly executed training loop looks like. They do not include evaluation (which would simply be a call to the network), and they generally assume typical values for most hyperparameters.
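
Since evaluation is left out of the cells in this notebook, here is a rough sketch of what it would involve once the network has been called on held-out data. The `probs` array stands in for the network's softmax output and is made up for illustration; `top1_accuracy` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def top1_accuracy(probs, labels):
    """Fraction of rows whose highest-probability class matches the label."""
    predictions = np.argmax(probs, axis=1)
    return float(np.mean(predictions == labels))

# Hypothetical softmax outputs for 4 samples and 3 classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4],
                  [0.6, 0.3, 0.1]])
labels = np.array([0, 1, 2, 1])

accuracy = top1_accuracy(probs, labels)
print(accuracy)  # 0.75
```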

A proper training protocol should include a hyperparameter sweep over whichever parameters require tuning, and it should ideally use the __@tf.function__ decorator to speed up the training/inference loops whenever possible.
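
As a minimal sketch of such a sweep, the snippet below enumerates a small grid over hyperparameter names used later in this notebook. The candidate values are illustrative rather than tuned, and `train_and_evaluate` is a hypothetical stand-in that returns a placeholder score so the example runs end to end; in practice it would train a model and return a validation metric.

```python
import itertools

# Candidate values per hyperparameter (illustrative, not tuned)
search_space = {
    "DISTANCE_WEIGHT_MAX": [10, 30, 100],
    "DISTANCE_MAX_WEIGHT_EPOCH": [5, 10],
    "learning_rate": [1e-3, 5e-3],
}

def grid(space):
    """Yield one dict per combination of the candidate values."""
    names = list(space)
    for values in itertools.product(*(space[name] for name in names)):
        yield dict(zip(names, values))

def train_and_evaluate(config):
    """Hypothetical stand-in: train with `config` and return a validation score."""
    # Placeholder score so the sketch runs; replace with a real training run.
    return -abs(config["DISTANCE_WEIGHT_MAX"] - 30) - abs(config["learning_rate"] - 5e-3)

configs = list(grid(search_space))
best = max(configs, key=train_and_evaluate)
print(len(configs), best)
```

A full grid grows multiplicatively with every added hyperparameter, so random search over the same space is a common alternative when the grid becomes too large.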

In [1]:
# This only needs to be run once (it takes some time to initialize the first time you use it)
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.datasets import mnist

# Only useful for GPU devices
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for device in gpu_devices: tf.config.experimental.set_memory_growth(device, True)

In [2]:
# Rest of the required imports for this example
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# HYPERPARAMETERS FOR DATASET
PERCENT_LABELED = 0.10

# HYPERPARAMETERS FOR NETWORK
DISTANCE_WEIGHT_MAX = 30
DISTANCE_MAX_WEIGHT_EPOCH = 10
TRAINING_EPOCHS = 50
BATCH_SIZE = 1024 # unlabeled batch_size determined proportionally!

# Initialization of Model Data
##### Keep in mind that one of the central points of SSL is to train on an unlabeled dataset that is constantly being augmented. While the augmentations aren't carried out here, the Dataset class in TensorFlow does allow for online augmentations to take place. (See [this](https://www.tensorflow.org/tutorials/images/data_augmentation) link.)
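
To give a feel for what an online augmentation might look like, here is a small NumPy sketch of a random translation applied per image. The function name `random_shift` and the choice of a 2-pixel maximum shift are assumptions made for illustration; in a TensorFlow pipeline an equivalent transformation would typically be applied through the dataset's `map` step so that each epoch sees freshly perturbed samples.

```python
import numpy as np

def random_shift(images, max_shift=2, rng=None):
    """Translate each image by a random offset of up to max_shift pixels, zero-filling the border."""
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, c = images.shape
    padded = np.pad(images, ((0, 0), (max_shift, max_shift), (max_shift, max_shift), (0, 0)))
    out = np.empty_like(images)
    for i in range(n):
        # Pick the top-left corner of an h-by-w crop inside the padded image
        dy, dx = rng.integers(0, 2 * max_shift + 1, size=2)
        out[i] = padded[i, dy:dy + h, dx:dx + w, :]
    return out

# Made-up batch with the same layout as the padded MNIST images: (N, 32, 32, 1)
batch = np.random.default_rng(0).random((8, 32, 32, 1))
augmented = random_shift(batch, max_shift=2, rng=np.random.default_rng(1))
print(augmented.shape)  # (8, 32, 32, 1)
```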

In [3]:
# Loads in the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Generate 32x32 image set for use in LeNet
x_train = np.pad(x_train, ((0,0),(2,2),(2,2))).reshape((-1, 32, 32, 1))
x_test = np.pad(x_test, ((0,0),(2,2),(2,2))).reshape((-1, 32, 32, 1))

# Normalize input values
x_train = x_train/255.
x_test = x_test/255.

# Separate dataset into labeled and unlabeled portions
sample_inds = np.arange(x_train.shape[0])
np.random.shuffle(sample_inds)
split_ind = int(PERCENT_LABELED*x_train.shape[0])
x_train_labeled = x_train[sample_inds[:split_ind],:,:]
y_train_labeled = y_train[sample_inds[:split_ind]].astype('float32')
x_train_unlabeled = x_train[sample_inds[split_ind:],:,:]

# Generate datasets with batching from samples
l_train_ds = tf.data.Dataset.from_tensor_slices((x_train_labeled, y_train_labeled))
ul_train_ds = tf.data.Dataset.from_tensor_slices(x_train_unlabeled)

l_train_ds = l_train_ds.shuffle(buffer_size=1024).batch(BATCH_SIZE)
num_batches = int((x_train_labeled.shape[0]+BATCH_SIZE-1)/BATCH_SIZE)
ul_train_ds = ul_train_ds.shuffle(buffer_size=1024).batch(int(x_train_unlabeled.shape[0]/num_batches), drop_remainder=True)

print("Size of labeled dataset: {}".format(x_train_labeled.shape[0]))
print("Size of unlabeled dataset: {}".format(x_train_unlabeled.shape[0]))

Size of labeled dataset: 6000
Size of unlabeled dataset: 54000
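
The batch sizes follow directly from the dataset sizes printed above. The short calculation below repeats the arithmetic from the cell so the numbers are explicit; the variable names are chosen here for readability.

```python
import math

labeled, unlabeled, batch_size = 6000, 54000, 1024

# Number of labeled batches per epoch (the last one is smaller than the rest)
num_batches = math.ceil(labeled / batch_size)
last_labeled_batch = labeled - (num_batches - 1) * batch_size
# Unlabeled batch size chosen so both datasets yield the same number of batches
unlabeled_batch = unlabeled // num_batches
# Unlabeled samples skipped each epoch because of drop_remainder=True
dropped = unlabeled - unlabeled_batch * num_batches

print(num_batches, last_labeled_batch, unlabeled_batch, dropped)  # 6 880 9000 0
```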


In [4]:
# Example model to use (the actual model doesn't matter so long as it has dropout)
def LeNet(dropout_prob = 0.5):
    # input
    xIn = keras.Input(shape=(32,32,1))
    
    # subsequent layers
    out = keras.layers.Conv2D(6, 5, activation='relu')(xIn)
    out = keras.layers.AveragePooling2D(2, 2)(out)
    out = keras.layers.Conv2D(16, 5, activation='relu')(out)
    out = keras.layers.AveragePooling2D(2, 2)(out)
    out = keras.layers.Flatten()(out)
    out = keras.layers.Dense(120, activation='relu')(out)
    out = keras.layers.Dropout(dropout_prob)(out)
    out = keras.layers.Dense(84, activation='relu')(out)
    out = keras.layers.Dropout(dropout_prob)(out)
    out = keras.layers.Dense(10, name='prelogits')(out)
    # Despite the layer name, this softmax output is class probabilities rather than raw logits
    out = keras.layers.Activation('softmax', name='logits')(out)
    
    # Creates model
    mod = keras.Model(inputs=xIn, outputs=out)
    
    return mod

### PI Model
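
Before the TensorFlow cell, here is a framework-free sketch of the loss the PI model optimizes: a supervised cross-entropy term on the labeled samples plus a weighted mean squared difference between two stochastic passes over the same inputs. The one-layer classifier, the input-level dropout, and the random data are all assumptions made to keep the example tiny; only the structure of the loss is meant to carry over.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def noisy_forward(x, weights, drop_prob=0.5):
    """Toy one-layer classifier with inverted dropout applied to its inputs."""
    keep = rng.random(x.shape) >= drop_prob
    return softmax((x * keep / (1.0 - drop_prob)) @ weights)

def pi_model_loss(x_labeled, y_labeled, x_all, weights, consistency_weight):
    # Supervised term: cross-entropy on the labeled samples only
    p_labeled = noisy_forward(x_labeled, weights)
    sup = -np.mean(np.log(p_labeled[np.arange(len(y_labeled)), y_labeled] + 1e-12))
    # Consistency term: two passes over the same inputs with different dropout masks
    p1 = noisy_forward(x_all, weights)
    p2 = noisy_forward(x_all, weights)
    unsup = consistency_weight * np.mean((p1 - p2) ** 2)
    return sup + unsup, sup, unsup

x_all = rng.normal(size=(16, 8))            # 16 samples, 8 features (made-up data)
x_labeled, y_labeled = x_all[:4], np.array([0, 1, 2, 1])
weights = rng.normal(size=(8, 3))           # 3 classes
total, sup, unsup = pi_model_loss(x_labeled, y_labeled, x_all, weights, consistency_weight=30.0)
print(total, sup, unsup)
```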

In [5]:
# Implement PI-Network Training
curModel = LeNet()
sup_loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=5e-3)
epochVar = tf.Variable(0)

# Metrics
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
unsup_loss_metric = keras.metrics.Mean()
sup_loss_metric = keras.metrics.Mean()
total_loss_metric = keras.metrics.Mean()

# Implement the values chosen for the weighting function
w_values = tf.constant(np.linspace(0, DISTANCE_WEIGHT_MAX, DISTANCE_MAX_WEIGHT_EPOCH), dtype='float32')

# Now set up weighing function
@tf.function
def weigh_fn(epoch):
    return w_values[tf.math.minimum(DISTANCE_MAX_WEIGHT_EPOCH-1, epoch)]

@tf.function
def train_step(x_sup, x_unsup, y_sup, epoch):
    # Calculate gradients while performing operations
    with tf.GradientTape() as tape:
        # Process the necessary values for each batch
        l_logits = curModel(x_sup, training=True)
        us_logits_1 = curModel(x_unsup, training=True)
        us_logits_2 = curModel(x_unsup, training=True)

        # Compute losses for each respective dataset
        sup_loss = sup_loss_fn(y_sup, l_logits)
        unsup_loss = weigh_fn(epoch) * tf.reduce_mean(keras.losses.MSE(us_logits_2, us_logits_1))
        total_loss = sup_loss + unsup_loss
        
        # Update relevant metrics
        train_acc_metric.update_state(y_sup, l_logits)
        unsup_loss_metric.update_state(unsup_loss)
        sup_loss_metric.update_state(sup_loss)
        total_loss_metric.update_state(total_loss)
        
    # Use tape to propagate gradients back to weights
    grads = tape.gradient(total_loss, curModel.trainable_weights)
    optimizer.apply_gradients(zip(grads, curModel.trainable_weights))

for epoch in range(TRAINING_EPOCHS):
    print("Epoch {}".format(epoch+1))
    
    # iterate across the batches
    for step, (labeled_batch, unlabeled_batch) in enumerate(zip(l_train_ds, ul_train_ds)):
        labeled_bx, labeled_by = labeled_batch
        unsup_samples = tf.concat([labeled_bx, unlabeled_batch], 0)
        
        # Progress through a step of training
        train_step(labeled_bx, unsup_samples, labeled_by, epochVar)
        
        # Read the running metric values (the metrics are reset at the end of each epoch)
        t_acc = train_acc_metric.result()
        us_loss = unsup_loss_metric.result()
        s_loss = sup_loss_metric.result()
        t_loss = total_loss_metric.result()
        
        # Print epoch final statistics
        if(step == num_batches-1):
            print("Loss: {:.4f}, US_Loss: {:.4f}, S_Loss: {:.4f}, Acc: {:.2f}".format(t_loss, us_loss,
                                                                                      s_loss, t_acc))
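        # NOTE: this increments once per batch, so the unsupervised weight ramps up per step, not per epoch
        epochVar = epochVar + 1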
    
    # Reset states
    train_acc_metric.reset_states()
    unsup_loss_metric.reset_states()
    total_loss_metric.reset_states()
    sup_loss_metric.reset_states()

Epoch 1
Loss: 2.1731, US_Loss: 0.0170, S_Loss: 2.1561, Acc: 0.24
Epoch 2
Loss: 1.7144, US_Loss: 0.1984, S_Loss: 1.5160, Acc: 0.56
Epoch 3
Loss: 1.3375, US_Loss: 0.3267, S_Loss: 1.0109, Acc: 0.72
Epoch 4
Loss: 1.0206, US_Loss: 0.3159, S_Loss: 0.7047, Acc: 0.80
Epoch 5
Loss: 0.8117, US_Loss: 0.2695, S_Loss: 0.5421, Acc: 0.85
Epoch 6
Loss: 0.6637, US_Loss: 0.2410, S_Loss: 0.4227, Acc: 0.89
Epoch 7
Loss: 0.5466, US_Loss: 0.2011, S_Loss: 0.3456, Acc: 0.91
Epoch 8
Loss: 0.4819, US_Loss: 0.1881, S_Loss: 0.2938, Acc: 0.93
Epoch 9
Loss: 0.4156, US_Loss: 0.1670, S_Loss: 0.2486, Acc: 0.94
Epoch 10
Loss: 0.3728, US_Loss: 0.1504, S_Loss: 0.2224, Acc: 0.94
Epoch 11
Loss: 0.3400, US_Loss: 0.1433, S_Loss: 0.1967, Acc: 0.95
Epoch 12
Loss: 0.3027, US_Loss: 0.1270, S_Loss: 0.1756, Acc: 0.95
Epoch 13
Loss: 0.2729, US_Loss: 0.1174, S_Loss: 0.1554, Acc: 0.96
Epoch 14
Loss: 0.2574, US_Loss: 0.1126, S_Loss: 0.1448, Acc: 0.96
Epoch 15
Loss: 0.2238, US_Loss: 0.1022, S_Loss: 0.1217, Acc: 0.97
Epoch 16
Loss: 0.21
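
The unsupervised weight used in the cell above follows a clamped linear ramp. The snippet below rebuilds the same table with NumPy and looks up a few indices, which makes it easy to see that the weight starts at zero and stays at its maximum once the ramp is over. `weight_at` is a name introduced here; it mirrors what `weigh_fn` computes.

```python
import numpy as np

DISTANCE_WEIGHT_MAX = 30
DISTANCE_MAX_WEIGHT_EPOCH = 10

# Same table as w_values in the training cell
w_values = np.linspace(0, DISTANCE_WEIGHT_MAX, DISTANCE_MAX_WEIGHT_EPOCH)

def weight_at(index):
    """Clamp the index so the weight stays at its maximum after the ramp."""
    return float(w_values[min(DISTANCE_MAX_WEIGHT_EPOCH - 1, index)])

print([round(weight_at(i), 2) for i in (0, 1, 5, 9, 25)])
# [0.0, 3.33, 16.67, 30.0, 30.0]
```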

# How Do We Implement Alternatives?
##### Note that no observations were made about temporal ensembling with minibatches. It's quite possible that the stochasticity introduced by minibatches could cause instability on certain problems...

The alternatives are basically the same, with small differences depending on what you implement. For example, the following is an implementation of temporal ensembling. It is not efficient, so don't train large networks like this! A proper implementation would need to manage resources more carefully than is done here.
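
The core bookkeeping of temporal ensembling can be sketched without any framework. In the formulation as it is commonly described, an uncorrected running average of each sample's outputs is kept, and the startup-bias correction is applied only when producing the training targets. The function name, the 1-based epoch convention, and the constant made-up outputs below are assumptions for illustration.

```python
import numpy as np

def update_ensemble(accumulator, new_outputs, epoch, rate=0.5):
    """One temporal-ensembling update.

    accumulator: running, uncorrected average of past outputs (starts at zeros)
    epoch: 1-based count of updates, including this one
    Returns the new accumulator and the bias-corrected targets.
    """
    accumulator = rate * accumulator + (1.0 - rate) * new_outputs
    targets = accumulator / (1.0 - rate ** epoch)
    return accumulator, targets

# Made-up outputs for 3 samples and 2 classes, identical every epoch
outputs = np.array([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
accumulator = np.zeros_like(outputs)
for epoch in range(1, 4):
    accumulator, targets = update_ensemble(accumulator, outputs, epoch)

print(np.allclose(targets, outputs))  # True
```

With constant outputs the corrected targets equal those outputs after every update, which is a handy sanity check that the correction is applied exactly once.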

### Temporal Ensembling

In [5]:
# Shuffling is disabled to allow for proper recording of outputs... Otherwise we would be required
# to record shuffled indices on every iteration (which means implementing our own shuffles!)
l_train_ds = tf.data.Dataset.from_tensor_slices((x_train_labeled, y_train_labeled))
ul_train_ds = tf.data.Dataset.from_tensor_slices(x_train_unlabeled)

l_train_ds = l_train_ds.batch(BATCH_SIZE)
num_batches = int((x_train_labeled.shape[0]+BATCH_SIZE-1)/BATCH_SIZE)
ul_train_ds = ul_train_ds.batch(int(x_train_unlabeled.shape[0]/num_batches), drop_remainder=True)
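
The cell above turns shuffling off so that stored outputs line up with the data order. One way to lift that restriction, sketched here in NumPy, is to carry each sample's original index alongside it and scatter the batch outputs back by index. `fake_model` is a hypothetical stand-in for the network, and the sizes are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samples, num_classes, batch_size = 10, 3, 4

# One row of stored outputs per sample, addressed by the sample's original index
epoch_outputs = np.zeros((num_samples, num_classes))

def fake_model(indices):
    """Hypothetical stand-in for the network: the output depends only on the sample index."""
    return np.tile(indices[:, None].astype(float), (1, num_classes))

order = rng.permutation(num_samples)  # a fresh shuffle each epoch
for start in range(0, num_samples, batch_size):
    batch_indices = order[start:start + batch_size]
    # Scatter the batch's outputs back to the rows they belong to
    epoch_outputs[batch_indices] = fake_model(batch_indices)

print(epoch_outputs[:, 0])  # [0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]
```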

In [6]:
# Implement Temporal Ensembling...
curModel = LeNet()
sup_loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=5e-3)
epochVar = tf.Variable(0)
stepVar = tf.Variable(0)

# Metrics
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
unsup_loss_metric = keras.metrics.Mean()
sup_loss_metric = keras.metrics.Mean()
total_loss_metric = keras.metrics.Mean()

# Implement the values chosen for the weighting function
w_values = tf.constant(np.linspace(0, DISTANCE_WEIGHT_MAX, DISTANCE_MAX_WEIGHT_EPOCH), dtype='float32')

# Now set up weighing function
@tf.function
def weigh_fn(epoch):
    return w_values[tf.math.minimum(DISTANCE_MAX_WEIGHT_EPOCH-1, epoch)]

# Now set up EMA variables
EMA_RATE = tf.constant(0.5)
yEMA = tf.zeros((x_train.shape[0], 10))
totalSamples = tf.constant(x_train.shape[0])
totalBatch = tf.constant(int(x_train_unlabeled.shape[0]/num_batches) + BATCH_SIZE)
epLogits = tf.Variable(0., shape=tf.TensorShape(None))

# Now set up the EMA function
@tf.function
def EMA_fn(toUpdate, update, epoch):
    return (EMA_RATE*toUpdate + (1-EMA_RATE)*update)/(1-tf.math.pow(EMA_RATE,tf.cast(epoch, dtype='float32')))

@tf.function
def train_step(x_sup, x_unsup, y_sup, epoch, step):
    # Calculate gradients while performing operations
    with tf.GradientTape() as tape:
        # Process the necessary values for each batch
        us_logits = curModel(x_unsup, training=True)
        l_logits = us_logits[:x_sup.shape[0]]

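        # Compute losses for each respective dataset
        # NOTE: yEMA is a plain tensor read from the enclosing scope, so tf.function captures the
        # value it holds when train_step is traced; rebinding yEMA in the epoch loop does not
        # update an already-traced graph. Keep the targets in a tf.Variable (updated with
        # .assign) or pass them in as an argument if they should change between epochs.
        sup_loss = sup_loss_fn(y_sup, l_logits)
        unsup_loss = weigh_fn(epoch) * tf.reduce_mean(keras.losses.MSE(us_logits,
                                yEMA[totalBatch*step:tf.math.minimum(totalBatch*(step+1),totalSamples)]))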
        total_loss = sup_loss + unsup_loss
        
        # Update relevant metrics
        train_acc_metric.update_state(y_sup, l_logits)
        unsup_loss_metric.update_state(unsup_loss)
        sup_loss_metric.update_state(sup_loss)
        total_loss_metric.update_state(total_loss)
        
    # Use tape to propagate gradients back to weights
    grads = tape.gradient(total_loss, curModel.trainable_weights)
    optimizer.apply_gradients(zip(grads, curModel.trainable_weights))
    
    return us_logits

for epoch in range(TRAINING_EPOCHS):
    print("Epoch {}".format(epoch+1))
    
    # iterate across the batches
    for step, (labeled_batch, unlabeled_batch) in enumerate(zip(l_train_ds, ul_train_ds)):
        labeled_bx, labeled_by = labeled_batch
        unsup_samples = tf.concat([labeled_bx, unlabeled_batch], 0)
        
        # Progress through a step of training
        tAppend = train_step(labeled_bx, unsup_samples, labeled_by, epochVar, stepVar)
        if(step == 0):
            epLogits = tAppend
        else:
            epLogits = tf.concat([epLogits, tAppend], 0)
        
        # Reset metrics and evaluate results
        t_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        us_loss = unsup_loss_metric.result()
        unsup_loss_metric.reset_states()
        s_loss = sup_loss_metric.result()
        sup_loss_metric.reset_states()
        t_loss = total_loss_metric.result()
        total_loss_metric.reset_states()
        stepVar = stepVar + 1
        
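    # Update EMA now
    # NOTE: EMA_fn already divides by the bias-correction term, and its result is stored back
    # into yEMA, so the correction compounds from one epoch to the next
    if(epoch == 0):
        yEMA = epLogits
    else:
        yEMA = EMA_fn(yEMA, epLogits, epochVar)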
        
    # Print epoch final statistics
    if(step == num_batches-1):
        print("Loss: {:.4f}, US_Loss: {:.4f}, S_Loss: {:.4f}, Acc: {:.2f}".format(t_loss, us_loss,
                                                                                  s_loss, t_acc))
    epochVar = epochVar + 1
    stepVar = stepVar - stepVar # Assignment causes retracing???? WHY

Epoch 1
Loss: 2.0106, US_Loss: 0.0000, S_Loss: 2.0106, Acc: 0.32
Epoch 2
Loss: 1.2828, US_Loss: 0.0751, S_Loss: 1.2077, Acc: 0.60
Epoch 3
Loss: 1.1087, US_Loss: 0.1922, S_Loss: 0.9165, Acc: 0.71
Epoch 4
Loss: 1.0641, US_Loss: 0.3493, S_Loss: 0.7148, Acc: 0.80
Epoch 5
Loss: 1.1460, US_Loss: 0.3819, S_Loss: 0.7640, Acc: 0.85
Epoch 6
Loss: 1.2186, US_Loss: 0.4024, S_Loss: 0.8162, Acc: 0.86
Epoch 7
Loss: 1.2898, US_Loss: 0.3629, S_Loss: 0.9269, Acc: 0.88
Epoch 8
Loss: 1.3041, US_Loss: 0.3525, S_Loss: 0.9515, Acc: 0.90
Epoch 9
Loss: 1.3541, US_Loss: 0.3346, S_Loss: 1.0194, Acc: 0.92
Epoch 10
Loss: 1.4116, US_Loss: 0.3357, S_Loss: 1.0759, Acc: 0.92
Epoch 11
Loss: 1.3707, US_Loss: 0.3482, S_Loss: 1.0225, Acc: 0.92
Epoch 12
Loss: 1.3572, US_Loss: 0.3476, S_Loss: 1.0097, Acc: 0.95
Epoch 13
Loss: 1.3503, US_Loss: 0.3477, S_Loss: 1.0026, Acc: 0.95
Epoch 14
Loss: 1.3434, US_Loss: 0.3449, S_Loss: 0.9985, Acc: 0.95
Epoch 15
Loss: 1.3206, US_Loss: 0.3463, S_Loss: 0.9744, Acc: 0.96
Epoch 16
Loss: 1.32

The remaining methods work in a similar manner, except that they include more networks to sample from or train at specific intervals. Keep in mind that while the EMA method took a bit longer to reach the same accuracy as the previous example, it also has many more hyperparameters that need adjusting in order to tune the network toward a specific response...

With the nightmare of saving the EMA guesses out of the way, we can move on to the Mean Teacher model, which is admittedly much simpler than it appears. The Mean Teacher method requires two models to be instantiated with the same set of weights, and they slowly begin to diverge as training progresses...
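
The teacher update itself is just an exponential moving average applied to every weight array. The sketch below shows it on made-up NumPy arrays; `ema_update` is a name introduced here, and the student is held fixed only to make the convergence easy to read off.

```python
import numpy as np

def ema_update(teacher_weights, student_weights, rate=0.5):
    """Move each teacher array toward the matching student array."""
    return [rate * t + (1.0 - rate) * s
            for t, s in zip(teacher_weights, student_weights)]

# Made-up weights for a tiny model: one kernel and one bias
student = [np.ones((2, 2)), np.ones(2)]
teacher = [np.zeros((2, 2)), np.zeros(2)]

for _ in range(3):  # the student is held fixed here for clarity
    teacher = ema_update(teacher, student)

print(teacher[1])  # [0.875 0.875]
```

A rate closer to 1 makes the teacher change more slowly, so it averages over a longer window of student weights.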

### Mean Teacher

In [5]:
# Implement Mean Teacher Training
teacherMod = LeNet()
studentMod = LeNet()
sup_loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam(learning_rate=5e-3)
epochVar = tf.Variable(0)

# Metrics
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
unsup_loss_metric = keras.metrics.Mean()
sup_loss_metric = keras.metrics.Mean()
total_loss_metric = keras.metrics.Mean()

# Implement the values chosen for the weighting function
w_values = tf.constant(np.linspace(0, DISTANCE_WEIGHT_MAX, DISTANCE_MAX_WEIGHT_EPOCH), dtype='float32')

# Initialize teacher with student weights
teacherMod.set_weights(studentMod.get_weights())

# Set up EMA func for weights
# NOTE: THIS FUNCTION WILL BE RETRACED AT LEAST AS MANY TIMES AS THERE ARE LAYERS WITH DIFFERENTLY
#       SHAPED WEIGHT MATRICES!!!
EMA_RATE = tf.constant(0.5)
@tf.function
def EMA_fn(toUpdate, update):
    return EMA_RATE*toUpdate + (1-EMA_RATE)*update

# Now set up weighing function
@tf.function(experimental_relax_shapes=True)
def weigh_fn(epoch):
    return w_values[tf.math.minimum(DISTANCE_MAX_WEIGHT_EPOCH-1, epoch)]

@tf.function
def train_step(x_sup, x_unsup, y_sup, epoch):
    # Calculate gradients while performing operations
    with tf.GradientTape() as tape:
        # Process the necessary values for each batch
        us_slogits = studentMod(x_unsup, training=True)
        us_tlogits = teacherMod(x_unsup, training=True)
        l_logits = us_slogits[:x_sup.shape[0]]

        # Compute losses for each respective dataset
        sup_loss = sup_loss_fn(y_sup, l_logits)
        unsup_loss = weigh_fn(epoch) * tf.reduce_mean(keras.losses.MSE(us_slogits, us_tlogits))
        total_loss = sup_loss + unsup_loss
        
        # Update relevant metrics
        train_acc_metric.update_state(y_sup, l_logits)
        unsup_loss_metric.update_state(unsup_loss)
        sup_loss_metric.update_state(sup_loss)
        total_loss_metric.update_state(total_loss)
        
    # Use tape to propagate gradients back to weights
    grads = tape.gradient(total_loss, studentMod.trainable_weights)
    optimizer.apply_gradients(zip(grads, studentMod.trainable_weights))

for epoch in range(TRAINING_EPOCHS):
    print("Epoch {}".format(epoch+1))
    
    # iterate across the batches
    for step, (labeled_batch, unlabeled_batch) in enumerate(zip(l_train_ds, ul_train_ds)):
        labeled_bx, labeled_by = labeled_batch
        unsup_samples = tf.concat([labeled_bx, unlabeled_batch], 0)
        
        # Progress through a step of training
        train_step(labeled_bx, unsup_samples, labeled_by, epochVar)
        
        # Apply the EMA for the teacher
        teacherMod.set_weights([EMA_fn(tWeight, sWeight) for (tWeight, sWeight) in zip(teacherMod.get_weights(),
                                                                                       studentMod.get_weights())])
        
        # Reset metrics and evaluate results
        t_acc = train_acc_metric.result()
        train_acc_metric.reset_states()
        us_loss = unsup_loss_metric.result()
        unsup_loss_metric.reset_states()
        s_loss = sup_loss_metric.result()
        sup_loss_metric.reset_states()
        t_loss = total_loss_metric.result()
        total_loss_metric.reset_states()
        
        # Print epoch final statistics
        if(step == num_batches-1):
            print("Loss: {:.4f}, US_Loss: {:.4f}, S_Loss: {:.4f}, Acc: {:.2f}".format(t_loss, us_loss,
                                                                                      s_loss, t_acc))
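        # NOTE: this increments once per batch, so the unsupervised weight ramps up per step, not per epoch
        epochVar = epochVar + 1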

Epoch 1
Loss: 1.9789, US_Loss: 0.0457, S_Loss: 1.9332, Acc: 0.36
Epoch 2
Loss: 1.5585, US_Loss: 0.3317, S_Loss: 1.2268, Acc: 0.60
Epoch 3
Loss: 1.1895, US_Loss: 0.4101, S_Loss: 0.7793, Acc: 0.76
Epoch 4
Loss: 0.9023, US_Loss: 0.3643, S_Loss: 0.5380, Acc: 0.82
Epoch 5
Loss: 0.7144, US_Loss: 0.3016, S_Loss: 0.4129, Acc: 0.87
Epoch 6
Loss: 0.5800, US_Loss: 0.2569, S_Loss: 0.3231, Acc: 0.92
Epoch 7
Loss: 0.5180, US_Loss: 0.2316, S_Loss: 0.2865, Acc: 0.91
Epoch 8
Loss: 0.4137, US_Loss: 0.1972, S_Loss: 0.2165, Acc: 0.95
Epoch 9
Loss: 0.3810, US_Loss: 0.1753, S_Loss: 0.2056, Acc: 0.95
Epoch 10
Loss: 0.3388, US_Loss: 0.1587, S_Loss: 0.1801, Acc: 0.95
Epoch 11
Loss: 0.3026, US_Loss: 0.1481, S_Loss: 0.1545, Acc: 0.96
Epoch 12
Loss: 0.2797, US_Loss: 0.1410, S_Loss: 0.1388, Acc: 0.96
Epoch 13
Loss: 0.2613, US_Loss: 0.1292, S_Loss: 0.1321, Acc: 0.96
Epoch 14
Loss: 0.2309, US_Loss: 0.1207, S_Loss: 0.1102, Acc: 0.96
Epoch 15
Loss: 0.2052, US_Loss: 0.1146, S_Loss: 0.0906, Acc: 0.98
Epoch 16
Loss: 0.18

### Dual Students

At this point, it should be easy to see how to implement the dual students model. It is nearly the same as above, but with an additional student model. Which student model is updated depends largely on the distance of each model from the teacher (it is possible to update both or just one of them), and the teacher is updated in the same way as before.
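
As a rough sketch of the selection step, the snippet below measures each student's distance from the teacher and decides which one to update. The text does not fix a rule, so the one used here (update whichever student currently disagrees more with the teacher, or both when they are tied within a tolerance) is purely an assumption for illustration, as are the function names and the made-up predictions.

```python
import numpy as np

def distance_to_teacher(student_probs, teacher_probs):
    """Mean squared difference between a student's and the teacher's predictions."""
    return float(np.mean((student_probs - teacher_probs) ** 2))

def students_to_update(probs_a, probs_b, teacher_probs, tolerance=0.0):
    """Assumed rule: update the student farther from the teacher, or both if tied."""
    d_a = distance_to_teacher(probs_a, teacher_probs)
    d_b = distance_to_teacher(probs_b, teacher_probs)
    if abs(d_a - d_b) <= tolerance:
        return ["a", "b"]
    return ["a"] if d_a > d_b else ["b"]

teacher = np.array([[0.8, 0.2], [0.3, 0.7]])
student_a = np.array([[0.6, 0.4], [0.5, 0.5]])  # farther from the teacher
student_b = np.array([[0.8, 0.2], [0.3, 0.7]])  # matches the teacher exactly

print(students_to_update(student_a, student_b, teacher))  # ['a']
```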

Note that all of these models involve tradeoffs in terms of space. While the first is the most unstable, it is also the cheapest in terms of space requirements. Temporal ensembling, on the other hand, requires __as much space as the expected output__ in order to train (which can be huge with things like segmentation maps!), and mean teacher requires __double the amount of space to store weights__, since two networks are held on the same GPU during training (which limits the total size of the network that can be trained). Keep that in mind when approaching these problems!
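
To put rough numbers on these tradeoffs, the calculation below estimates the extra storage for float32 values. The MNIST figures follow from the 60000 training images and 10 classes used in this notebook, and the parameter count is worked out from the LeNet layers defined earlier; the segmentation task (256x256 maps with 21 classes) is a hypothetical example chosen only to show the scale.

```python
BYTES_PER_FLOAT32 = 4

# Temporal ensembling: one stored output row per training sample
mnist_targets = 60000 * 10 * BYTES_PER_FLOAT32
# Hypothetical segmentation task: a 256x256 map with 21 classes per sample
seg_targets = 60000 * 256 * 256 * 21 * BYTES_PER_FLOAT32

# Mean teacher: a second copy of every weight.
# LeNet parameters: conv1 + conv2 + dense(120) + dense(84) + dense(10), weights plus biases
lenet_params = (5*5*1*6 + 6) + (5*5*6*16 + 16) + (400*120 + 120) + (120*84 + 84) + (84*10 + 10)
teacher_copy = lenet_params * BYTES_PER_FLOAT32

print(mnist_targets / 1e6, "MB")   # 2.4 MB
print(seg_targets / 1e9, "GB")     # about 330 GB
print(lenet_params, teacher_copy)  # 61706 246824
```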