Importing libraries

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import os
import random
from tensorflow.keras.preprocessing import image_dataset_from_directory


Dataset Loading (Unlabeled for Pretraining)

In [None]:
import os
import tensorflow as tf

# Images are resized to IMAGE_SIZE x IMAGE_SIZE pixels and served BATCH_SIZE at a time.
IMAGE_SIZE = 224
BATCH_SIZE = 32

# The four unlabeled pretraining folders: ssl_dataset/train.X1 ... ssl_dataset/train.X4.
train_dirs = [f"ssl_dataset/train.X{shard}" for shard in (1, 2, 3, 4)]

def load_mae_dataset(dirs=None, image_size=None, batch_size=None):
    """Build one unlabeled, normalized image stream from several directories.

    Each directory is loaded without labels, shuffled on its own, scaled by
    1/255, and the per-directory datasets are concatenated in the order given
    (so one pass visits directory 1 completely, then directory 2, ...).

    Args:
        dirs: Iterable of image directories. Defaults to the module-level
            `train_dirs`.
        image_size: Side length (pixels) every image is resized to. Defaults
            to the module-level `IMAGE_SIZE`.
        batch_size: Number of images per batch. Defaults to the module-level
            `BATCH_SIZE`.

    Returns:
        A `tf.data.Dataset` of image batches divided by 255.

    Raises:
        ValueError: If no directories are supplied.
    """
    dirs = train_dirs if dirs is None else list(dirs)
    image_size = IMAGE_SIZE if image_size is None else image_size
    batch_size = BATCH_SIZE if batch_size is None else batch_size

    if not dirs:
        raise ValueError("load_mae_dataset needs at least one image directory.")

    train_ds = None
    for path in dirs:
        ds = tf.keras.preprocessing.image_dataset_from_directory(
            path,
            labels=None,
            label_mode=None,
            image_size=(image_size, image_size),
            batch_size=batch_size,
            shuffle=True
        )
        # Normalize; the parallel map keeps element order (deterministic by default).
        ds = ds.map(lambda x: x / 255.0, num_parallel_calls=tf.data.AUTOTUNE)
        train_ds = ds if train_ds is None else train_ds.concatenate(ds)

    # Overlap input decoding with the training step instead of stalling on I/O.
    return train_ds.prefetch(tf.data.AUTOTUNE)


train_dataset = load_mae_dataset()



Patch + Masking Utility

In [None]:
def patchify(images, patch_size=16):
    """Cut a batch of images into non-overlapping square patches and flatten each.

    Args:
        images: Image batch accepted by `tf.image.extract_patches`
            (assumed [batch, height, width, channels] — TODO confirm at call sites).
        patch_size: Side length of each square patch; also used as the stride,
            so patches do not overlap.

    Returns:
        A 3-D tensor [batch, num_patches, patch_values] where the last axis is
        whatever depth `extract_patches` produced for one patch.
    """
    window = [1, patch_size, patch_size, 1]
    grid = tf.image.extract_patches(
        images=images,
        sizes=window,
        strides=window,
        rates=[1, 1, 1, 1],
        padding='VALID'  # windows that would run past the border are not emitted
    )
    # Static last dimension; the batch dimension is read dynamically so the
    # reshape also works when it is unknown at graph-construction time.
    flat_dim = grid.shape[-1]
    n_images = tf.shape(images)[0]
    return tf.reshape(grid, [n_images, -1, flat_dim])

def random_masking(patches, mask_ratio=0.75):
    """Choose a random subset of patches to hide, independently per sample.

    Sampling is done with TensorFlow ops so that a fresh mask is drawn on every
    call, including inside a `tf.function` (NumPy sampling there runs once at
    trace time and is frozen into the graph), and so that dimensions unknown at
    trace time are supported.

    Args:
        patches: [batch, num_patches, patch_dim] tensor (or array).
        mask_ratio: Fraction of patches to hide; the count is truncated
            toward zero.

    Returns:
        mask: float32 tensor [batch, num_patches]; 1 = masked, 0 = kept.
            Exactly `num_mask` entries per row are 1.
        mask_indices: int32 tensor [batch, num_mask] holding the indices of
            the masked patches for each sample.
    """
    dynamic_shape = tf.shape(patches)
    batch, num_patches = dynamic_shape[0], dynamic_shape[1]

    static_num_patches = patches.shape[1]
    if static_num_patches is not None:
        # Known at trace time: compute the count in Python.
        num_mask = int(mask_ratio * static_num_patches)
    else:
        num_mask = tf.cast(mask_ratio * tf.cast(num_patches, tf.float32), tf.int32)

    # Sorting i.i.d. uniform noise yields an independent random permutation of
    # the patch indices for every sample.
    noise = tf.random.uniform([batch, num_patches])
    shuffled = tf.argsort(noise, axis=-1)
    mask_indices = shuffled[:, :num_mask]

    # ranks[b, p] is the position of patch p in sample b's permutation; the
    # first `num_mask` positions are the masked ones.
    ranks = tf.argsort(shuffled, axis=-1)
    mask = tf.cast(ranks < num_mask, tf.float32)  # 0 = keep, 1 = mask

    return mask, mask_indices


MAE Encoder + Decoder

In [None]:
def create_encoder(input_shape, num_patches, embed_dim):
    """Assemble the attention-only encoder as a Keras functional model.

    Patches are projected to `embed_dim`, normalized, passed through four
    pre-norm self-attention blocks with residual connections, and normalized
    once more.

    Args:
        input_shape: Per-sample input shape handed to `layers.Input`.
        num_patches: Accepted for interface compatibility; not used here.
        embed_dim: Width of the token embedding; also used as the per-head
            `key_dim` of every attention layer.

    Returns:
        A `models.Model` named "encoder".

    NOTE(review): no positional information is added to the tokens and
    `num_patches` is unused — presumably a positional embedding was intended;
    confirm before relying on position-specific reconstructions.
    """
    depth, heads = 4, 4

    patch_input = layers.Input(shape=input_shape)
    tokens = layers.Dense(embed_dim)(patch_input)
    tokens = layers.LayerNormalization()(tokens)

    for _block in range(depth):  # Use more layers for deeper encoder
        normed = layers.LayerNormalization()(tokens)
        attended = layers.MultiHeadAttention(num_heads=heads, key_dim=embed_dim)(normed, normed)
        tokens = layers.Add()([tokens, attended])  # residual connection

    encoded = layers.LayerNormalization()(tokens)
    return models.Model(patch_input, encoded, name="encoder")

def create_decoder(embed_dim, patch_dim):
    """Assemble the lightweight decoder as a Keras functional model.

    Args:
        embed_dim: Size of the last axis of the incoming latent tokens.
        patch_dim: Size of the last axis of the output (one flattened patch).

    Returns:
        A `models.Model` named "decoder" that accepts a variable-length
        sequence of `embed_dim`-wide tokens.
    """
    latent_input = layers.Input(shape=(None, embed_dim))

    # Dense -> LayerNorm -> Dense, applied in sequence.
    stack = (
        layers.Dense(embed_dim),
        layers.LayerNormalization(),
        layers.Dense(patch_dim),
    )
    features = latent_input
    for layer in stack:
        features = layer(features)

    return models.Model(latent_input, features, name="decoder")


MAE Model Wrapper

In [None]:
class MAE(tf.keras.Model):
    """Masked-autoencoder wrapper: patchify, mask, encode, decode.

    Args:
        encoder: Model mapping patch tokens to latent tokens.
        decoder: Model mapping latent tokens back to flattened patches.
        num_patches: Number of patches per image (stored for reference).
        patch_dim: Length of one flattened patch (stored for reference).
        patch_size: Side length of the square patches cut by `patchify`.
            Defaults to 16, the value previously hard-coded in `call`.
        mask_ratio: Value forwarded to `random_masking`. Defaults to 0.75,
            the value previously hard-coded in `call`.
    """

    def __init__(self, encoder, decoder, num_patches, patch_dim,
                 patch_size=16, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_patches = num_patches
        self.patch_dim = patch_dim
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

    def call(self, images):
        """Run one masked reconstruction pass.

        Returns:
            reconstructed: Decoder output for every patch position.
            patches: The flattened input patches (reconstruction target).
            mask: Per-patch mask from `random_masking`; entries equal to 1 are
                zeroed before encoding, entries equal to 0 are left visible.
        """
        patches = patchify(images, patch_size=self.patch_size)  # [B, N, D]
        mask, _ = random_masking(patches, mask_ratio=self.mask_ratio)
        # Masked patches are zeroed rather than removed, so the sequence
        # length seen by the encoder stays the same.
        visible_patches = patches * tf.expand_dims(1 - mask, -1)
        latent = self.encoder(visible_patches)
        reconstructed = self.decoder(latent)
        return reconstructed, patches, mask


Training Loop

In [None]:
# Patch geometry. IMAGE_SIZE comes from the dataset-loading cell.
PATCH_SIZE = 16
# Patches per image: a (IMAGE_SIZE // PATCH_SIZE) x (IMAGE_SIZE // PATCH_SIZE) grid.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
# Values in one flattened patch; the factor 3 assumes 3-channel images.
PATCH_DIM = PATCH_SIZE * PATCH_SIZE * 3
# Width of the encoder's token embedding.
EMBED_DIM = 128

# Build the two sub-models and wrap them in the MAE model.
encoder = create_encoder((NUM_PATCHES, PATCH_DIM), NUM_PATCHES, EMBED_DIM)
decoder = create_decoder(EMBED_DIM, PATCH_DIM)
mae = MAE(encoder, decoder, NUM_PATCHES, PATCH_DIM)

# Reconstruction loss and optimizer (Adam with its default hyperparameters).
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(images):
    """Run one optimization step and return the masked reconstruction loss.

    Args:
        images: One batch of images fed to `mae`.

    Returns:
        Scalar tensor: mean squared error averaged over the entries of the
        patches whose mask value is 1.
    """
    with tf.GradientTape() as tape:
        reconstructed, original, mask = mae(images)
        mask_3d = tf.expand_dims(mask, -1)
        # loss_fn averages over every patch element, but entries with mask == 0
        # are zeroed on both sides and contribute nothing, which dilutes the
        # value by the masked fraction.
        diluted = loss_fn(original * mask_3d, reconstructed * mask_3d)
        # Rescale so the loss is the mean over masked patches only
        # (divide_no_nan yields 0 when nothing is masked).
        masked_fraction = tf.reduce_mean(mask)
        loss = tf.math.divide_no_nan(diluted, masked_fraction)
    gradients = tape.gradient(loss, mae.trainable_variables)
    optimizer.apply_gradients(zip(gradients, mae.trainable_variables))
    return loss

# Pretraining loop: one pass over train_dataset per epoch, reporting the mean
# loss of the epoch rather than whatever the final batch happened to produce.
EPOCHS = 3
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    loss_total = tf.zeros((), dtype=tf.float32)
    num_batches = 0
    for batch in train_dataset:
        loss = train_step(batch)
        loss_total += loss
        num_batches += 1
    if num_batches == 0:
        # Previously this surfaced as a NameError on `loss` below.
        raise ValueError("train_dataset yielded no batches; check the dataset directories.")
    print(f"Loss: {(loss_total / num_batches).numpy():.4f}")


Save Encoder for Linear Probing

In [None]:
# Persist only the encoder's weights (the decoder is not saved) to
# "mae_encoder_tf.h5" in the current working directory.
# NOTE(review): newer Keras releases (Keras 3) reportedly require weight files
# to end in ".weights.h5" — confirm against the installed version before
# changing the file name, since downstream code may load this exact path.
encoder.save_weights("mae_encoder_tf.h5")