In [None]:
import keras
import tensorflow as tf

In [3]:
from google.colab import drive

# Colab's conventional Drive mount point; the user's files appear under <mount>/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
drive  sample_data


In [7]:
# List the BAP project folder on the mounted Drive to confirm the mount worked.
# NOTE(review): this folder contains UATD_TF_Datasets, but the training code below
# loads CIFAR-10 via keras.datasets — confirm which dataset is intended.
!ls /content/drive/MyDrive/BAP/

BYOL.ipynb  UATD_TF_Datasets


In [None]:
# For EMA updates
class EMA:
  """Keep an exponential-moving-average (EMA) copy of a Keras model's weights.

  The copy (BYOL's "target" network) starts as a clone of `model` with the same
  weight values, and each call to `update()` moves it towards the online weights:

      ema_w <- decay * ema_w + (1 - decay) * w

  Args:
    model: The online Keras model whose weights are tracked.
    decay: EMA coefficient in [0, 1]; larger values make the copy change more slowly.

  Raises:
    ValueError: if `decay` is outside [0, 1].
  """

  def __init__(self, model, decay=0.99):
    if not 0.0 <= decay <= 1.0:
      raise ValueError(f"decay must be in [0, 1], got {decay}")
    self.model = model
    self.decay = decay

    # clone_model creates fresh (re-initialised) weights, so copy the online
    # values over to make both networks start identical.
    self.ema_model = keras.models.clone_model(model)
    self.ema_model.set_weights(model.get_weights())

  def update(self):
    """Move every EMA weight one step towards the matching online weight.

    Raises:
      ValueError: if the two models hold different numbers of weights (e.g. one
        of them has not been built yet). Without this check `zip` would silently
        skip the unmatched weights and the update would be partial.
    """
    ema_weights = self.ema_model.weights
    online_weights = self.model.weights
    if len(ema_weights) != len(online_weights):
      raise ValueError(
          f"EMA model has {len(ema_weights)} weights but the online model has "
          f"{len(online_weights)}; build both models before calling update()."
      )
    for (ema_w, w) in zip(ema_weights, online_weights):
      ema_w.assign(self.decay * ema_w + (1 - self.decay) * w)

  def get_ema_model(self):
    """Return the EMA (target) copy of the model."""
    return self.ema_model

In [None]:
# Define encoder network
def build_encoder(input_shape=(224, 224, 3)):
  """Build the backbone encoder: a randomly initialised ResNet50 without its classifier.

  Args:
    input_shape: (height, width, channels) of the input images. Defaults to the
      previously hard-coded (224, 224, 3).

  Returns:
    A `keras.Model` named "encoder" that maps images to globally average-pooled
    features (2048-dimensional for ResNet50).
  """
  base_model = keras.applications.ResNet50(
      include_top=False,
      weights=None,  # trained from scratch by the self-supervised objective
      input_shape=input_shape,
      pooling="avg"
  )
  return keras.Model(base_model.input, base_model.output, name="encoder")

In [None]:
# Define the Projection Head
def build_projection_head(input_dim, output_dim=256, hidden_dim=4096):
  """Build the MLP projection head: Dense(hidden_dim, relu) -> Dense(output_dim).

  Args:
    input_dim: Size of the encoder features fed into the head.
    output_dim: Size of the projection.
    hidden_dim: Width of the hidden layer (4096 previously hard-coded).

  Returns:
    A built `keras.Sequential` named "projection_head".
  """
  model = keras.Sequential(name="projection_head")
  # Declaring the input builds the model immediately, so its weights exist before
  # the first forward pass. This lets EMA's clone_model + set_weights copy them
  # instead of the target head being initialised independently later.
  model.add(keras.Input(shape=(input_dim,)))
  # NOTE(review): the BYOL paper's MLP puts BatchNorm after the hidden Dense layer;
  # it is not included here.
  model.add(keras.layers.Dense(units=hidden_dim, activation="relu"))
  model.add(keras.layers.Dense(units=output_dim, activation=None))
  return model

In [None]:
# Define the Prediction Head
def build_prediction_head(input_dim, output_dim=256, hidden_dim=4096):
  """Build the MLP prediction head: Dense(hidden_dim, relu) -> Dense(output_dim).

  Args:
    input_dim: Size of the projection fed into the head.
    output_dim: Size of the prediction (must match the target projection size).
    hidden_dim: Width of the hidden layer (4096 previously hard-coded).

  Returns:
    A built `keras.Sequential` named "prediction_head".
  """
  model = keras.Sequential(name="prediction_head")
  # Declaring the input (previously `input_dim` was ignored) builds the model now.
  model.add(keras.Input(shape=(input_dim,)))
  model.add(keras.layers.Dense(units=hidden_dim, activation="relu"))
  model.add(keras.layers.Dense(units=output_dim, activation=None))
  return model

In [None]:
class BYOL(keras.Model):
  """Bootstrap Your Own Latent (BYOL) model with an online and a target network.

  The online network is encoder -> projection head -> prediction head and is
  trained by gradient descent. The target network (encoder -> projection head)
  is kept by the `EMA` helpers `self.ema` / `self.ema_projection`. It gets no
  gradients; callers must advance it by calling their `update()` methods after
  each optimizer step.

  Args:
    input_shape: Expected image shape. NOTE(review): currently unused — the
      encoder is built by `build_encoder()` with its own default shape.
  """

  def __init__(self, input_shape=(224, 224, 3)):
    super(BYOL, self).__init__()

    self.encoder = build_encoder()
    self.projection_head = build_projection_head(input_dim=2048)
    self.prediction_head = build_prediction_head(input_dim=256)

    # Build the projection head before cloning it, so the EMA copy starts with the
    # same weights as the online head. An unbuilt head would otherwise be cloned
    # empty and its target copy initialised independently on first call.
    if not self.projection_head.built:
      self.projection_head.build((None, 2048))

    self.ema = EMA(self.encoder, decay=0.99) # Exponential moving average
    self.ema_projection = EMA(self.projection_head, decay=0.99)

  def call(self, inputs, training=True):
    """Run both augmented views through the online and target networks.

    Args:
      inputs: A pair `(view1, view2)` of image batches.
      training: Passed to the online encoder (e.g. BatchNorm mode); the target
        encoder always runs with `training=False`.

    Returns:
      `(p1, p2, z1_target, z2_target)`: the online predictions for each view and
      the target projections for each view. Target projections are wrapped in
      `tf.stop_gradient`, so no gradient flows into the target network.
    """
    view1, view2 = inputs # Two augmented views

    # Online network
    z1 = self.projection_head(self.encoder(view1, training=training))
    z2 = self.projection_head(self.encoder(view2, training=training))

    p1 = self.prediction_head(z1)
    p2 = self.prediction_head(z2)

    # Target network (EMA weights). tf.stop_gradient is applied to the output
    # tensors; it is a function, not a context manager.
    target_encoder = self.ema.get_ema_model()
    target_projection = self.ema_projection.get_ema_model()
    z1_target = tf.stop_gradient(
        target_projection(target_encoder(view1, training=False)))
    z2_target = tf.stop_gradient(
        target_projection(target_encoder(view2, training=False)))

    return p1, p2, z1_target, z2_target

In [None]:
def cosine_similarity_loss(x, y):
    """Return minus the batch-mean cosine similarity between `x` and `y`.

    Both inputs are L2-normalised along axis 1 (presumably the feature axis of
    (batch, dim) tensors — confirm). The normalised vectors are multiplied
    elementwise and summed over that axis to give one cosine per sample. The
    mean of those cosines is negated, so more similar inputs give a lower loss.
    """
    x_unit = tf.math.l2_normalize(x, axis=1)
    y_unit = tf.math.l2_normalize(y, axis=1)
    per_sample_cosine = tf.reduce_sum(tf.multiply(x_unit, y_unit), axis=1)
    return tf.negative(tf.reduce_mean(per_sample_cosine))

In [None]:
class BYOLLoss(keras.losses.Loss):
    """Symmetrised BYOL loss over the 4-tuple returned by `BYOL.call`.

    `y_true` is ignored. `y_pred` is `(p1, p2, z1_target, z2_target)`. Each view's
    online prediction is compared with the *other* view's target projection via
    `cosine_similarity_loss`, and the two terms are averaged.
    """

    def call(self, y_true, y_pred):
        online_p1, online_p2, target_z1, target_z2 = y_pred
        view1_vs_view2 = cosine_similarity_loss(online_p1, target_z2)
        view2_vs_view1 = cosine_similarity_loss(online_p2, target_z1)
        return (view1_vs_view2 + view2_vs_view1) / 2

In [None]:
# Create dataset of augmented image pairs
def get_dataset(batch_size=32, image_size=224, crop_fraction=0.75, shuffle_buffer=10_000):
    """Build a tf.data pipeline of BYOL view pairs from the CIFAR-10 training images.

    Each element is a tuple `(view1, view2)` of float32 image batches of shape
    `(batch, image_size, image_size, channels)` with values in [0, 1]. The two
    views are independent random augmentations (crop + horizontal flip) of the
    same source image.

    Images are converted and resized one at a time inside `map`. Resizing all
    50,000 images to 224x224 float32 up front needs ~30 GB of memory.

    Args:
        batch_size: Number of image pairs per batch.
        image_size: Side length the augmented crops are resized to (the encoder input size).
        crop_fraction: Side of the random crop as a fraction of the source image, in
            (0, 1]. Use 1.0 to disable cropping; the original code's crop was
            effectively a no-op.
        shuffle_buffer: Shuffle buffer size; 0 or None disables shuffling.

    Returns:
        A batched, prefetched `tf.data.Dataset`.

    Raises:
        ValueError: if `crop_fraction` is not in (0, 1].
    """
    if not 0.0 < crop_fraction <= 1.0:
        raise ValueError(f"crop_fraction must be in (0, 1], got {crop_fraction}")

    (x_train, _), (_, _) = keras.datasets.cifar10.load_data()
    height, width, channels = x_train.shape[1:]
    crop_height = max(1, int(round(height * crop_fraction)))
    crop_width = max(1, int(round(width * crop_fraction)))

    dataset = tf.data.Dataset.from_tensor_slices(x_train)
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)

    def augment(image):
        image = tf.cast(image, tf.float32) / 255.0  # Normalize to [0, 1]
        # Crop a random region of the source image, then resize it to the encoder
        # input size. Cropping at full size (as before) would be a no-op.
        image = tf.image.random_crop(image, (crop_height, crop_width, channels))
        image = tf.image.resize(image, (image_size, image_size))
        return tf.image.random_flip_left_right(image)

    # augment runs twice per image, so each view draws its own random crop and flip.
    dataset = dataset.map(lambda x: (augment(x), augment(x)),
                          num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# Seed Python, NumPy and the backend RNGs so weight init and augmentations are
# reproducible across Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

byol = BYOL()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = BYOLLoss()

@tf.function
def train_step(images):
    """Run one BYOL optimisation step on a batch of view pairs.

    Args:
        images: A `(view1, view2)` batch as yielded by `get_dataset()`.

    Returns:
        The scalar BYOL loss for this batch, computed before the weight update.
    """
    with tf.GradientTape() as tape:
        predictions = byol(images, training=True)
        # BYOLLoss ignores y_true. Keras 3's Loss.__call__ still converts y_true
        # with convert_to_tensor before calling `call`, and that can reject None.
        # A dummy scalar works with every Keras version.
        loss = loss_fn(tf.zeros(()), predictions)

    grads = tape.gradient(loss, byol.trainable_variables)
    optimizer.apply_gradients(zip(grads, byol.trainable_variables))

    # Move the target (EMA) networks towards the freshly updated online weights.
    byol.ema.update()
    byol.ema_projection.update()

    return loss

dataset = get_dataset()
epochs = 10

for epoch in range(epochs):
    # Average over all batches; printing only the last batch's loss is noisy and
    # misleading. Counting batches also avoids a NameError on an empty dataset.
    epoch_loss_sum = 0.0
    num_batches = 0
    for batch in dataset:
        epoch_loss_sum += float(train_step(batch))
        num_batches += 1
    mean_loss = epoch_loss_sum / max(num_batches, 1)
    print(f"Epoch {epoch + 1}, Mean loss: {mean_loss:.4f}")
