In [1]:
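# %pip install tensorflow-datasets==4.5.2   # uncomment on first run; %pip installs into this kernel's environment
# Reference implementation this notebook follows: https://keras.io/examples/vision/nnclr/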

Collecting tensorflow-datasets
  Downloading tensorflow_datasets-4.5.2-py3-none-any.whl (4.2 MB)
Collecting tensorflow-metadata
  Downloading tensorflow_metadata-1.6.0-py3-none-any.whl (48 kB)
Collecting numpy
  Downloading numpy-1.22.1-cp39-cp39-win_amd64.whl (14.7 MB)
Collecting absl-py
  Downloading absl_py-1.0.0-py3-none-any.whl (126 kB)
Collecting protobuf>=3.12.2
  Downloading protobuf-3.19.4-cp39-cp39-win_amd64.whl (895 kB)
Collecting termcolor
  Using cached termcolor-1.1.0-py3-none-any.whl
Collecting promise
  Using cached promise-2.3-py3-none-any.whl
Collecting dill
  Using cached dill-0.3.4-py2.py3-none-any.whl (86 kB)
Collecting googleapis-common-protos<2,>=1.52.0
  Downloading googleapis_common_protos-1.54.0-py2.py3-none-any.whl (207 kB)
Collecting colorama
  Using cached colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: protobuf, googleapis-common-protos, colorama, absl-py, termcolor, tensorflow-metadata, promise, numpy, dill, tensorflow-datasets


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
pylint 2.10.2 requires isort<6,>=4.2.5, which is not installed.
pylint 2.10.2 requires platformdirs>=2.2.0, which is not installed.
pylint 2.10.2 requires toml>=0.7.1, which is not installed.


In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import tensorflow as tf
import tensorflow_datasets as tfds

# Sentinel passed to Dataset.prefetch() so tf.data picks the buffer size itself.
AUTOTUNE = tf.data.AUTOTUNE

In [12]:
# Shuffle buffer size used by every training tf.data pipeline.
N_SHUFFLE = 5000
# Number of batches per training pass: make_dataset derives each split's
# batch size as `<split image count> // N_EPOCH`. Despite the name, this is
# not an epoch count.
N_EPOCH = 64

# Image counts of the STL-10 splits loaded in make_dataset.
unlabeled_images = 100_000
labeled_images = 5_000

# Model input shape (height, width, channels) and feature width of the
# encoder / heads.
input_shape = (96, 96, 3)
width = 128

# Keyword arguments for aug(): one augmenter for the contrastive branch and
# a milder one (less brightness jitter, larger crops) for the classifier.
cont_aug = dict(brightness=0.5, name='cont_aug', scale=(0.2, 1.0))
class_aug = dict(brightness=0.2, name='class_aug', scale=(0.5, 1.0))


In [3]:
def make_dataset():
    """Build the STL-10 tf.data pipelines for NNCLR training and evaluation.

    Each training split is batched with its own size, `<image count> // N_EPOCH`,
    so both streams yield the same number of batches per pass (assuming the
    split sizes match `unlabeled_images` / `labeled_images`).

    Returns:
        batch_size: combined per-step batch size (unlabeled + labeled).
        train_dataset: zipped `(unlabeled_batch, labeled_batch)` pairs, each an
            `(images, labels)` tuple because `as_supervised=True`; prefetched.
        labeled_train_dataset: the labeled training stream on its own.
        test_dataset: the STL-10 test split, batched and prefetched; the last
            partial batch is kept so every test image is evaluated.
    """
    unlabeled_batch_size = unlabeled_images // N_EPOCH
    labeled_batch_size = labeled_images // N_EPOCH

    batch_size = unlabeled_batch_size + labeled_batch_size

    unlabeled_train_dataset = (
        tfds.load(
            'stl10', split='unlabelled', as_supervised=True, shuffle_files=True
        )
        .shuffle(N_SHUFFLE)
        .batch(unlabeled_batch_size, drop_remainder=True)
    )

    # The labeled split must use its own (much smaller) batch size. Otherwise it
    # yields only a handful of batches and zip() below truncates every
    # training pass to that many steps.
    labeled_train_dataset = (
        tfds.load(
            'stl10', split='train', as_supervised=True, shuffle_files=True
        )
        .shuffle(N_SHUFFLE)
        .batch(labeled_batch_size, drop_remainder=True)
    )

    # No drop_remainder for evaluation: dropping the final partial batch would
    # silently exclude test images from the reported metrics.
    test_dataset = (
        tfds.load(
            'stl10', split='test', as_supervised=True
        )
        .batch(unlabeled_batch_size)
        .prefetch(AUTOTUNE)
    )

    # zip() stops at the shorter of the two streams.
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(AUTOTUNE)

    return batch_size, train_dataset, labeled_train_dataset, test_dataset

batch_size, train_dataset, labeled_train_dataset, test_dataset = make_dataset()

In [4]:
class RandomResizedCrop(tf.keras.layers.Layer):
    """Randomly crop a box from each image and resize it back to the input size.

    For every image in the batch:
    - an area fraction is drawn uniformly from `scale`;
    - an aspect ratio (crop width / crop height, in normalized coordinates) is
      drawn log-uniformly from `ratio`;
    - the box is placed at a uniformly random offset and resized with
      `tf.image.crop_and_resize` to the batch's original height and width.

    Args:
        scale: `(min, max)` fraction of the image area covered by the crop.
        ratio: `(min, max)` aspect ratio of the crop.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, scale, ratio, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        # Sample aspect ratios in log space so that r and 1/r are equally likely.
        self.log_ratio = (tf.math.log(ratio[0]), tf.math.log(ratio[1]))

    def call(self, images):
        batch_size = tf.shape(images)[0]
        image_height = tf.shape(images)[1]
        image_width = tf.shape(images)[2]

        random_scales = tf.random.uniform((batch_size, ), self.scale[0], self.scale[1])
        random_ratios = tf.exp(
            tf.random.uniform((batch_size, ), self.log_ratio[0], self.log_ratio[1])
        )

        # Crop side lengths as fractions of the image; clamp so the box never
        # exceeds the image (possible when scale is near 1 and ratio != 1).
        new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1)
        new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1)

        # Random top-left corner that keeps the whole box inside the image.
        height_offsets = tf.random.uniform((batch_size, ), 0, 1 - new_heights)
        width_offsets = tf.random.uniform((batch_size, ), 0, 1 - new_widths)

        # crop_and_resize expects normalized boxes as [y1, x1, y2, x2].
        bounding_boxes = tf.stack(
            [
                height_offsets,
                width_offsets,
                height_offsets + new_heights,
                width_offsets + new_widths,
            ],
            axis=1
        )

        images = tf.image.crop_and_resize(
            images, bounding_boxes, tf.range(batch_size), (image_height, image_width)
        )

        return images

In [5]:
class RandomBrightness(tf.keras.layers.Layer):
    """Multiply each image by a random per-image factor and clip to [0, 1].

    The factor is drawn uniformly from `[1 - brightness, 1 + brightness]`.
    This is implemented as a blend of the image with black (0).

    Args:
        brightness: half-width of the multiplicative factor range.
    """

    def __init__(self, brightness):
        super().__init__()
        self.brightness = brightness

    def blend(self, images_1, images_2, ratios):
        """Return `ratios * images_1 + (1 - ratios) * images_2`, clipped to [0, 1]."""
        weighted_first = ratios * images_1
        weighted_second = (1.0 - ratios) * images_2
        return tf.clip_by_value(weighted_first + weighted_second, 0, 1)

    def random_brightness(self, images):
        """Scale every image in the batch by its own random brightness factor."""
        lower = 1 - self.brightness
        upper = 1 + self.brightness
        # One factor per image, broadcast over height, width and channels.
        factors = tf.random.uniform((tf.shape(images)[0], 1, 1, 1), lower, upper)
        return self.blend(images, 0, factors)

    def call(self, images):
        return self.random_brightness(images)

In [7]:
def aug(brightness, name, scale):
    """Build an image augmentation pipeline as a named Keras Sequential model.

    Steps, in order:
    1. Rescale pixels by 1/255 (assumes inputs are in [0, 255]).
    2. Random horizontal flip.
    3. RandomResizedCrop with aspect ratios in [3/4, 4/3].
    4. RandomBrightness.

    Args:
        brightness: brightness jitter passed to RandomBrightness.
        name: name of the returned Sequential model.
        scale: `(min, max)` crop-area fraction passed to RandomResizedCrop.

    Returns:
        A `tf.keras.Sequential` taking images of shape `input_shape`.
    """
    return tf.keras.Sequential(
        [
            tf.keras.layers.Input(shape=input_shape),
            tf.keras.layers.Rescaling(1 / 255),
            tf.keras.layers.RandomFlip('horizontal'),
            RandomResizedCrop(scale=scale, ratio=(3 / 4, 4 / 3)),
            RandomBrightness(brightness=brightness),
        ],
        # `name` was previously accepted but ignored.
        name=name,
    )

In [9]:
def encoder():
    """Build the convolutional encoder mapping an image to a `width`-dim feature.

    Four 3x3, stride-2 ReLU convolutions with `width` filters, then Flatten and
    a ReLU Dense layer with `width` units. The explicit Input layer builds the
    model at construction time, so `output_shape` is available immediately
    (NNCLR.__init__ reads it to size its feature queue).
    """
    conv_stack = [
        tf.keras.layers.Conv2D(width, kernel_size=3, strides=2, activation='relu')
        for _ in range(4)
    ]
    return tf.keras.Sequential(
        [
            tf.keras.layers.Input(shape=input_shape),
            *conv_stack,
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(width, activation='relu'),
        ],
        name='encoder',
    )

In [None]:
class NNCLR(tf.keras.Model):
    def __init__(self, temp, queue_size):
        """Create the NNCLR sub-models, metrics and the nearest-neighbour queue.

        Args:
            temp: temperature, stored as `self.temp` (presumably used to scale
                similarity logits in code outside this view; confirm).
            queue_size: number of feature vectors held in the support queue.
        """
        super().__init__()
        # NOTE(review): "prob" presumably means the linear probe; confirm.
        self.prob_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.corr_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
        self.cont_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

        self.prob_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.cont_aug = aug(**cont_aug)
        self.class_aug = aug(**class_aug)
        self.encoder = encoder()
        # NOTE(review): the attribute is `prediction_head` but the Keras model is
        # named 'projection_head'. Both are kept for compatibility with other code
        # and saved weights.
        self.prediction_head = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(width, )),
            tf.keras.layers.Dense(width, activation='relu'),
            tf.keras.layers.Dense(width)
        ], name='projection_head')
        # Single Dense layer producing 10 logits from the encoder features.
        self.linear_probe = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(width, )),
            tf.keras.layers.Dense(10)
        ], name='linear_probe')
        self.temp = temp
        # Requires the encoder to be built (it declares an Input layer).
        feature_dims = self.encoder.output_shape[1]
        # Support queue of unit-norm random features; not updated by gradients.
        # NOTE(review): "featrue" is misspelled, but nearest_neighbour reads
        # this exact attribute name.
        self.featrue_queue = tf.Variable(
            tf.math.l2_normalize(
                tf.random.normal(shape=(queue_size, feature_dims)), axis=1
            ),
            trainable=False
        )
        
    def compile(self, cont_optimizer, prob_optimizer, **kwargs):
        """Compile the model and keep separate optimizers for the two objectives.

        Args:
            cont_optimizer: optimizer stored as `self.cont_optimizer`.
            prob_optimizer: optimizer stored as `self.prob_optimizer`.
            **kwargs: forwarded unchanged to `tf.keras.Model.compile`.
        """
        super().compile(**kwargs)
        self.cont_optimizer, self.prob_optimizer = cont_optimizer, prob_optimizer
        
    def nearest_neighbour(self, projections):
        """Swap each projection for its most similar entry in the support queue.

        Similarity is the dot product with every queued feature. This equals
        cosine similarity when `projections` are unit-normalized, since the
        queue rows are normalized at init. The best-scoring queue row is
        gathered for each projection.

        The result uses a straight-through estimator: the forward value is the
        neighbour, but gradients flow to `projections` as if it were returned
        unchanged.
        """
        queue = self.featrue_queue
        similarity_matrix = tf.matmul(projections, queue, transpose_b=True)
        best_indices = tf.argmax(similarity_matrix, axis=1)
        neighbours = tf.gather(queue, best_indices, axis=0)
        return projections + tf.stop_gradient(neighbours - projections)
    
    def update_cont_accuracy(self, feature1, feature2):
        """Intended to update `self.cont_accuracy` from two batches of features.

        NOTE(review): in the visible code this method only L2-normalizes both
        inputs along axis 1. It neither computes similarities nor updates
        `self.cont_accuracy`, and it returns None. TODO: complete or remove.
        """
        feature1 = tf.math.l2_normalize(feature1, axis=1) 
        feature2 = tf.math.l2_normalize(feature2, axis=1) 
        
        
        
        