<a href="https://colab.research.google.com/github/Allen123321/Self-Supervised_Learning_Experiments_re-implement/blob/main/BYOL_TensorFlow_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

BYOL for self-supervised representation learning on the CIFAR-10 dataset.

## Data augmentation

In [1]:
import tensorflow as tf

def random_crop_flip_resize(image):
    """Randomly crop an image, flip it horizontally and resize it to 32x32.

    The crop height is drawn uniformly from 13..32 pixels and the crop width
    is 0.67-1.0 times that height (truncated to an integer); height and width
    are swapped with probability 0.5 so both aspect-ratio orientations occur.

    Args:
        image: image tensor with 3 channels. The crop size can reach 32
            pixels per side, so the input is assumed to be (32, 32, 3).

    Returns:
        The augmented image, resized to (32, 32, 3).
    """
    # Random cropping
    h_crop = tf.cast(tf.random.uniform(shape=[], minval=13, maxval=33, dtype=tf.int32), tf.float32)
    w_crop = h_crop * tf.random.uniform(shape=[], minval=0.67, maxval=1.0)
    h_crop, w_crop = tf.cast(h_crop, tf.int32), tf.cast(w_crop, tf.int32)
    opposite_aspectratio = tf.random.uniform(shape=[])
    if opposite_aspectratio < 0.5:
        h_crop, w_crop = w_crop, h_crop
    image = tf.image.random_crop(image, size=[h_crop, w_crop, 3])

    # Horizontal flipping with probability 0.5.
    # random_flip_left_right already tosses its own fair coin, so it must not
    # be gated behind a second 0.5 test (that would flip only 25% of images).
    image = tf.image.random_flip_left_right(image)

    # Resizing to original size
    image = tf.image.resize(image, size=[32, 32])
    return image


def random_color_distortion(image):
    """Apply random color jittering and random grayscale conversion.

    With probability 0.8 the brightness, contrast, saturation and hue are
    perturbed (in that order) and the result is clipped to [0, 1].
    Independently, with probability 0.2 the image is converted to grayscale
    and replicated over three channels.

    Args:
        image: RGB image tensor; the clipping assumes values in [0, 1].

    Returns:
        The distorted image with three channels.
    """
    # Color jittering (strength 0.5), applied to 80% of the images
    jitter_draw = tf.random.uniform(shape=[])
    if jitter_draw < 0.8:
        image = tf.image.random_brightness(image, max_delta=0.4)
        image = tf.image.random_contrast(image, lower=0.6, upper=1.4)
        image = tf.image.random_saturation(image, lower=0.6, upper=1.4)
        image = tf.clip_by_value(tf.image.random_hue(image, max_delta=0.1), 0, 1)

    # Color dropping, applied to 20% of the images
    drop_draw = tf.random.uniform(shape=[])
    if drop_draw < 0.2:
        # Grayscale has one channel; tile it back to three.
        image = tf.tile(tf.image.rgb_to_grayscale(image), [1, 1, 3])

    return image


@tf.function
def augment_image_pretraining(image):
    """Create one pretraining view: crop/flip/resize, then color distortion."""
    geometric_view = random_crop_flip_resize(image)
    return random_color_distortion(geometric_view)


@tf.function
def augment_image_finetuning(image):
    """Create one finetuning view: crop/flip/resize only, no color distortion."""
    return random_crop_flip_resize(image)

## datasets

In [2]:
class CIFAR10:
    """In-memory CIFAR-10 dataset with batch accessors for BYOL experiments.

    Images are scaled to [0, 1] and stored as float32 tensors; labels are
    squeezed to 1-D int32 tensors.
    """

    def __init__(self):
        """Load CIFAR-10 through Keras and normalize images and labels."""
        (self.x_train, self.y_train), (self.x_test, self.y_test) = tf.keras.datasets.cifar10.load_data()
        self.num_train_images, self.num_test_images = self.y_train.shape[0], self.y_test.shape[0]
        # Official CIFAR-10 label order; index 8 is 'ship' (not 'sheep').
        self.class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

        # Normalize training and testing images
        self.x_train = tf.cast(self.x_train / 255., tf.float32)
        self.x_test = tf.cast(self.x_test / 255., tf.float32)

        # Drop the trailing label axis so labels are 1-D
        self.y_train = tf.cast(tf.squeeze(self.y_train), tf.int32)
        self.y_test = tf.cast(tf.squeeze(self.y_test), tf.int32)


    def get_batch_pretraining(self, batch_id, batch_size):
        """Return two independently augmented views of one training batch.

        Args:
            batch_id: zero-based index of the batch in the current ordering.
            batch_size: number of images per batch.

        Returns:
            Tuple (x_batch_1, x_batch_2); row i of both tensors is derived
            from the same training image.
        """
        augmented_images_1, augmented_images_2 = [], []
        for image_id in range(batch_id*batch_size, (batch_id+1)*batch_size):
            image = self.x_train[image_id]
            augmented_images_1.append(augment_image_pretraining(image))
            augmented_images_2.append(augment_image_pretraining(image))
        x_batch_1 = tf.stack(augmented_images_1)
        x_batch_2 = tf.stack(augmented_images_2)
        return x_batch_1, x_batch_2  # (bs, 32, 32, 3), (bs, 32, 32, 3)


    def get_batch_finetuning(self, batch_id, batch_size):
        """Return one augmented training batch together with its labels.

        Args:
            batch_id: zero-based index of the batch in the current ordering.
            batch_size: number of images per batch.

        Returns:
            Tuple (x_batch, y_batch) of augmented images and int32 labels.
        """
        augmented_images = []
        for image_id in range(batch_id*batch_size, (batch_id+1)*batch_size):
            image = self.x_train[image_id]
            augmented_images.append(augment_image_finetuning(image))
        x_batch = tf.stack(augmented_images)
        y_batch = tf.slice(self.y_train, [batch_id*batch_size], [batch_size])
        return x_batch, y_batch  # (bs, 32, 32, 3), (bs)


    def get_batch_testing(self, batch_id, batch_size):
        """Return one unaugmented test batch together with its labels.

        Args:
            batch_id: zero-based index of the batch.
            batch_size: number of images per batch.

        Returns:
            Tuple (x_batch, y_batch) sliced from the test set.
        """
        x_batch = tf.slice(self.x_test, [batch_id*batch_size, 0, 0, 0], [batch_size, -1, -1, -1])
        y_batch = tf.slice(self.y_test, [batch_id*batch_size], [batch_size])
        return x_batch, y_batch  # (bs, 32, 32, 3), (bs)


    def shuffle_training_data(self):
        """Shuffle training images and labels with one shared permutation."""
        random_ids = tf.random.shuffle(tf.range(self.num_train_images))
        self.x_train = tf.gather(self.x_train, random_ids)
        self.y_train = tf.gather(self.y_train, random_ids)

## loss

In [3]:

def byol_loss(p, z):
    """Return 2 - 2 * (mean cosine similarity between rows of p and z).

    Args:
        p: 2-D tensor of predictions, one row per sample.
        z: 2-D tensor of targets with the same shape as `p`.

    Returns:
        Scalar loss; 0 when every row of `p` points in the same direction as
        the matching row of `z`.
    """
    # Unit-normalize each row so the dot product is the cosine similarity
    p_unit = tf.math.l2_normalize(p, axis=1)
    z_unit = tf.math.l2_normalize(z, axis=1)

    row_cosine = tf.reduce_sum(p_unit * z_unit, axis=1)
    mean_cosine = tf.reduce_mean(row_cosine)
    return 2 - 2 * mean_cosine

## model

In [4]:
class BasicBlock(tf.keras.layers.Layer):
    """Residual block of two 3x3 convolutions, each followed by batch norm.

    When `strides` differs from 1 the shortcut is a strided 1x1 convolution
    with batch norm; otherwise the input is added back unchanged.
    """

    def __init__(self, filters, strides):
        super().__init__()
        self.strides = strides

        # Main path
        self.conv1 = tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=strides, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters=filters, kernel_size=(3, 3), strides=1, padding='same')
        self.bn2 = tf.keras.layers.BatchNormalization()

        # Shortcut path, only needed when the main path downsamples
        if strides != 1:
            self.convdown = tf.keras.layers.Conv2D(filters=filters, kernel_size=(1, 1), strides=strides)
            self.bndown = tf.keras.layers.BatchNormalization()

    def call(self, inp, training=False):
        residual = tf.nn.relu(self.bn1(self.conv1(inp), training=training))
        residual = self.bn2(self.conv2(residual), training=training)

        if self.strides == 1:
            shortcut = inp
        else:
            shortcut = self.bndown(self.convdown(inp), training=training)

        return tf.nn.relu(tf.keras.layers.add([residual, shortcut]))


# ResNet with BasicBlock (adapted to CIFAR-10)
class BasicResNet(tf.keras.Model):
    """ResNet encoder built from BasicBlock layers (adapted to CIFAR-10).

    A 3x3 stem convolution is followed by four stages with 64, 128, 256 and
    512 filters; stages two to four start with a stride-2 block. The output
    is the globally average-pooled feature vector.

    Args:
        layer_blocks: sequence whose first four entries give the number of
            blocks in each stage.
    """

    def __init__(self, layer_blocks):
        super().__init__()

        self.conv1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')
        self.bn1 = tf.keras.layers.BatchNormalization()

        # (filters, stride of the first block) for each of the four stages
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))

        self.blocks = []
        for stage_id, (filters, first_stride) in enumerate(stage_specs):
            # Every stage has at least its first block
            self.blocks.append(BasicBlock(filters=filters, strides=first_stride))
            for _ in range(layer_blocks[stage_id] - 1):
                self.blocks.append(BasicBlock(filters=filters, strides=1))

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()

    def call(self, inp, training=False):
        features = tf.nn.relu(self.bn1(self.conv1(inp), training=training))
        for block in self.blocks:
            features = block(features, training=training)
        return self.avgpool(features)


def ResNet18():
    """Build a BasicResNet with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return BasicResNet(layer_blocks=blocks_per_stage)


def ResNet34():
    """Build a BasicResNet with 3, 4, 6 and 3 blocks in its four stages."""
    blocks_per_stage = [3, 4, 6, 3]
    return BasicResNet(layer_blocks=blocks_per_stage)


# 512 (h) -> 256 -> 128 (z)
class ProjectionHead(tf.keras.Model):
    """Two-layer MLP: Dense(256) -> BatchNorm -> ReLU -> Dense(128)."""

    def __init__(self):
        super().__init__()
        self.fc1 = tf.keras.layers.Dense(units=256)
        self.bn = tf.keras.layers.BatchNormalization()
        self.fc2 = tf.keras.layers.Dense(units=128)

    def call(self, inp, training=False):
        hidden = self.bn(self.fc1(inp), training=training)
        return self.fc2(tf.nn.relu(hidden))


# 512 (h) -> 10 (s)
class ClassificationHead(tf.keras.Model):
    """Single linear layer that maps features to 10 class logits."""

    def __init__(self):
        super().__init__()
        self.fc = tf.keras.layers.Dense(units=10)

    def call(self, inp):
        return self.fc(inp)

### pre-train

In [5]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

#import argparse
import numpy as np
import tensorflow as tf

#from datasets import CIFAR10
#from models import ResNet18, ResNet34, ProjectionHead
#from losses import byol_loss



# Encoder name -> factory function returning a fresh encoder model
encoders = dict(resnet18=ResNet18, resnet34=ResNet34)


def _ema_update(target_net, online_net, beta):
    """Blend the weights of `online_net` into `target_net`.

    Every array returned by `get_weights()` is replaced by
    beta * target + (1 - beta) * online.
    """
    blended_weights = [
        beta * w_target + (1 - beta) * w_online
        for w_target, w_online in zip(target_net.get_weights(), online_net.get_weights())
    ]
    target_net.set_weights(blended_weights)


def train(encoder, num_epochs, batch_size):
    """Pretrain an encoder on CIFAR-10 with BYOL.

    An online branch (encoder f, projector g, predictor q) is trained with
    Adam to predict the projection of a target branch (encoder f, projector
    g) for the other augmented view. After every batch the target branch is
    updated as an exponential moving average of the online branch.

    Side effects: prints progress, saves the online encoder weights every
    100 epochs, and writes all per-batch losses to 'losses.txt' at the end.

    Args:
        encoder: key into `encoders` ('resnet18' or 'resnet34').
        num_epochs: number of passes over the training set.
        batch_size: images per batch; also scales the learning rate.
    """

    # Load CIFAR-10 dataset
    data = CIFAR10()

    # Instantiate networks
    f_online = encoders[encoder]()
    g_online = ProjectionHead()
    q_online = ProjectionHead()

    f_target = encoders[encoder]()
    g_target = ProjectionHead()


    # Initialize the weights of the networks
    x = tf.random.normal((256, 32, 32, 3))
    h = f_online(x, training=False)
    print('Initializing online networks...')
    print('Shape of h:', h.shape)
    z = g_online(h, training=False)
    print('Shape of z:', z.shape)
    p = q_online(z, training=False)
    print('Shape of p:', p.shape)

    h = f_target(x, training=False)
    print('Initializing target networks...')
    print('Shape of h:', h.shape)
    z = g_target(h, training=False)
    print('Shape of z:', z.shape)

    num_params_f = tf.reduce_sum([tf.reduce_prod(var.shape) for var in f_online.trainable_variables])
    print('The encoders have {} trainable parameters each.'.format(num_params_f))


    # Define optimizer (learning rate scaled linearly with the batch size)
    lr = 1e-3 * batch_size / 512
    opt = tf.keras.optimizers.Adam(learning_rate=lr)
    print('Using Adam optimizer with learning rate {}.'.format(lr))


    @tf.function
    def train_step_pretraining(x1, x2):  # (bs, 32, 32, 3), (bs, 32, 32, 3)

        # Target branch: runs outside the tape, so no gradient reaches it
        h_target_1 = f_target(x1, training=True)
        z_target_1 = g_target(h_target_1, training=True)

        h_target_2 = f_target(x2, training=True)
        z_target_2 = g_target(h_target_2, training=True)

        with tf.GradientTape() as tape:
            h_online_1 = f_online(x1, training=True)
            z_online_1 = g_online(h_online_1, training=True)
            p_online_1 = q_online(z_online_1, training=True)

            h_online_2 = f_online(x2, training=True)
            z_online_2 = g_online(h_online_2, training=True)
            p_online_2 = q_online(z_online_2, training=True)

            # Symmetric pairing: each prediction is matched with the target
            # projection of the *other* view.
            p_online = tf.concat([p_online_1, p_online_2], axis=0)
            z_target = tf.concat([z_target_2, z_target_1], axis=0)
            loss = byol_loss(p_online, z_target)

        # Backward pass (update online networks).
        # All online variables are updated by ONE apply_gradients call so the
        # optimizer's step counter advances once per training step; three
        # separate calls would advance it three times per step.
        online_variables = (
            f_online.trainable_variables
            + g_online.trainable_variables
            + q_online.trainable_variables
        )
        grads = tape.gradient(loss, online_variables)
        opt.apply_gradients(zip(grads, online_variables))

        return loss


    batches_per_epoch = data.num_train_images // batch_size
    log_every = 10  # batches
    save_every = 100  # epochs
    beta = 0.99  # EMA decay of the target networks

    losses = []
    for epoch_id in range(num_epochs):
        data.shuffle_training_data()

        for batch_id in range(batches_per_epoch):
            x1, x2 = data.get_batch_pretraining(batch_id, batch_size)
            loss = train_step_pretraining(x1, x2)
            loss_value = float(loss)
            losses.append(loss_value)

            # Update target networks (exponential moving average of online networks)
            _ema_update(f_target, f_online, beta)
            _ema_update(g_target, g_online, beta)

            if (batch_id + 1) % log_every == 0:
                print('[Epoch {}/{} Batch {}/{}] Loss={:.5f}.'.format(epoch_id+1, num_epochs, batch_id+1, batches_per_epoch, loss_value))

        if (epoch_id + 1) % save_every == 0:
            f_online.save_weights('/content/drive/MyDrive/dataset/f_online_{}.h5'.format(epoch_id + 1))
            print('Weights of f saved.')

    # The losses are plain Python floats, so NumPy can write them directly
    np.savetxt('losses.txt', np.asarray(losses))




Pre-training

In [None]:
# Pretrain a ResNet-18 encoder for 200 epochs with batches of 512 images
train(encoder="resnet18", num_epochs=200, batch_size=512)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Initializing online networks...
Shape of h: (256, 512)
Shape of z: (256, 128)
Shape of p: (256, 128)
Initializing target networks...
Shape of h: (256, 512)
Shape of z: (256, 128)
The encoders have 11173632 trainable parameters each.
Using Adam optimizer with learning rate 0.001.
[Epoch 1/200 Batch 10/97] Loss=0.81563.
[Epoch 1/200 Batch 20/97] Loss=0.67043.
[Epoch 1/200 Batch 30/97] Loss=0.57101.
[Epoch 1/200 Batch 40/97] Loss=0.54568.
[Epoch 1/200 Batch 50/97] Loss=0.47895.
[Epoch 1/200 Batch 60/97] Loss=0.40210.
[Epoch 1/200 Batch 70/97] Loss=0.24860.
[Epoch 1/200 Batch 80/97] Loss=0.24036.
[Epoch 1/200 Batch 90/97] Loss=0.14926.
[Epoch 2/200 Batch 10/97] Loss=0.13301.
[Epoch 2/200 Batch 20/97] Loss=0.14967.
[Epoch 2/200 Batch 30/97] Loss=0.14494.
[Epoch 2/200 Batch 40/97] Loss=0.13782.
[Epoch 2/200 Batch 50/97] Loss=0.13844.
[Epoch 2/200 Batch 60/97] Loss=0.13403.
[Epoch 2/200 Batch 70/97] Loss=0.14207.
[E

### Linear evaluation

In [10]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import argparse
import tensorflow as tf

#from datasets import CIFAR10
#from models import ResNet18, ResNet34, ClassificationHead



# Encoder name -> factory function returning a fresh encoder model
encoders = dict(resnet18=ResNet18, resnet34=ResNet34)


def compute_test_accuracy(data, f_net, c_net):
    """Return the classification accuracy of c_net(f_net(x)) on the test set.

    The test set is processed in full batches of 500 images; images of a
    trailing partial batch are not evaluated and are therefore excluded from
    the denominator as well.

    Args:
        data: dataset object providing `num_test_images` and
            `get_batch_testing(batch_id, batch_size)`.
        f_net: encoder, called with `training=False`.
        c_net: classification head mapping features to logits.

    Returns:
        Scalar float32 tensor with the fraction of correct predictions.
    """
    batch_size = 500
    num_batches = data.num_test_images // batch_size

    # Divide by the images actually evaluated, not by the full test-set size
    num_evaluated = num_batches * batch_size
    if num_evaluated == 0:
        return tf.constant(0.0, dtype=tf.float32)

    num_correct_predictions = 0
    for batch_id in range(num_batches):
        x, y = data.get_batch_testing(batch_id, batch_size)
        h = f_net(x, training=False)
        y_pred_logits = c_net(h)
        y_pred_labels = tf.argmax(y_pred_logits, axis=1, output_type=tf.int32)

        num_correct_predictions += tf.reduce_sum(tf.cast(tf.equal(y_pred_labels, y), tf.int32))

    return tf.cast(num_correct_predictions / num_evaluated, tf.float32)


def main(encoder, encoder_weights, num_epochs=50, batch_size=512):
    """Run linear evaluation of a pretrained encoder on CIFAR-10.

    The encoder weights are loaded from `encoder_weights` and kept fixed; only
    a linear classification head is trained on top of the encoder features.
    The final test accuracy is printed.

    Args:
        encoder: key into `encoders` ('resnet18' or 'resnet34').
        encoder_weights: path of the weight file saved during pretraining.
        num_epochs: number of training epochs for the classification head.
        batch_size: images per training batch.
    """

    # Load CIFAR-10 dataset
    data = CIFAR10()

    # Instantiate networks f and c
    f_net = encoders[encoder]()
    c_net = ClassificationHead()

    # Initialize the weights of f and c
    x, y = data.get_batch_finetuning(batch_id=0, batch_size=batch_size)
    h = f_net(x, training=False)
    print('Shape of h:', h.shape)
    s = c_net(h)
    print('Shape of s:', s.shape)

    # Load the weights of f from pretraining
    f_net.load_weights(encoder_weights)
    print('Weights of f loaded.')


    # Define optimizer (learning rate decays polynomially over all updates)
    batches_per_epoch = data.num_train_images // batch_size
    total_update_steps = num_epochs * batches_per_epoch
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(5e-2, total_update_steps, 5e-4, power=2)
    opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)


    @tf.function
    def train_step_evaluation(x, y):  # (bs, 32, 32, 3), (bs)

        # The encoder is frozen, so its forward pass runs outside the tape
        h = f_net(x, training=False)  # (bs, 512)

        # Forward pass of the trainable head
        with tf.GradientTape() as tape:
            y_pred_logits = c_net(h)  # (bs, 10)
            loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=y_pred_logits))

        # Backward pass (only the classification head is updated)
        grads = tape.gradient(loss, c_net.trainable_variables)
        opt.apply_gradients(zip(grads, c_net.trainable_variables))

        return loss


    log_every = 10  # batches
    for epoch_id in range(num_epochs):
        data.shuffle_training_data()

        for batch_id in range(batches_per_epoch):
            x, y = data.get_batch_finetuning(batch_id, batch_size)
            loss = train_step_evaluation(x, y)
            if (batch_id + 1) % log_every == 0:
                print('[Epoch {}/{} Batch {}/{}] Loss: {:.4f}'.format(epoch_id+1, num_epochs, batch_id+1, batches_per_epoch, loss))

    # Compute classification accuracy on test set
    test_accuracy = compute_test_accuracy(data, f_net, c_net)
    print('Test Accuracy: {:.4f}'.format(test_accuracy))
    




In [11]:

# parser = argparse.ArgumentParser()

# parser.add_argument('--encoder', type=str, required=True, choices=['resnet18', 'resnet34'], help='Encoder architecture')
# parser.add_argument('--encoder_weights', type=str, help='Encoder weights')

# args = parser.parse_args()
# Evaluate the ResNet-18 encoder saved after 100 pretraining epochs
main(encoder='resnet18', encoder_weights='/content/drive/MyDrive/dataset/f_online_100.h5')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Shape of h: (512, 512)
Shape of s: (512, 10)
Weights of f loaded.
[Epoch 1/50 Batch 10/97] Loss: 3.3132
[Epoch 1/50 Batch 20/97] Loss: 2.0262
[Epoch 1/50 Batch 30/97] Loss: 1.4943
[Epoch 1/50 Batch 40/97] Loss: 1.4179
[Epoch 1/50 Batch 50/97] Loss: 1.0761
[Epoch 1/50 Batch 60/97] Loss: 1.0093
[Epoch 1/50 Batch 70/97] Loss: 1.0558
[Epoch 1/50 Batch 80/97] Loss: 0.7877
[Epoch 1/50 Batch 90/97] Loss: 0.8801
[Epoch 2/50 Batch 10/97] Loss: 0.7555
[Epoch 2/50 Batch 20/97] Loss: 0.8941
[Epoch 2/50 Batch 30/97] Loss: 0.8735
[Epoch 2/50 Batch 40/97] Loss: 0.7674
[Epoch 2/50 Batch 50/97] Loss: 0.7736
[Epoch 2/50 Batch 60/97] Loss: 0.8255
[Epoch 2/50 Batch 70/97] Loss: 0.8255
[Epoch 2/50 Batch 80/97] Loss: 0.8494
[Epoch 2/50 Batch 90/97] Loss: 0.7836
[Epoch 3/50 Batch 10/97] Loss: 0.8952
[Epoch 3/50 Batch 20/97] Loss: 0.8055
[Epoch 3/50 Batch 30/97] Loss: 0.7151
[Epoch 3/50 Batch 40/97] Loss: 0.7777
[Epoch 3/50 Batch 50