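# Linear Probing with a Pretrained MAE Encoder (TensorFlow)

Evaluate a pretrained masked-autoencoder (MAE) encoder by freezing it and training only a linear softmax classifier on its mean-pooled patch features.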

In [None]:

import json

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing import image_dataset_from_directory


## Load the Labeled Dataset for Linear Probing

In [None]:
IMAGE_SIZE = 224
BATCH_SIZE = 32
SEED = 42  # makes file shuffling and directory interleaving repeatable

# Load the class-label mapping. Classes are ordered by their mapped value; a class's
# integer label is its position in `class_names`.
# NOTE(review): the schema of Labels.json is not visible here. This assumes its keys
# are the class sub-directory names used in the train/val folders — confirm.
with open("ssl_dataset/Labels.json", "r") as f:
    label_map = json.load(f)
class_names = sorted(label_map, key=lambda k: label_map[k])

# Derived from the label file rather than hard-coded, so the classifier head can
# never disagree with the number of labels.
NUM_CLASSES = len(class_names)
_global_label = {name: index for index, name in enumerate(class_names)}

train_dirs = ["ssl_dataset/train.X1", "ssl_dataset/train.X2", "ssl_dataset/train.X3", "ssl_dataset/train.X4"]
val_dir = "ssl_dataset/val.X"


def _load_image_dir(path, shuffle):
    """Load one image directory with labels expressed as global class indices.

    Only the entries of `class_names` that exist as sub-directories of `path`
    are requested. Their directory-local indices are remapped to each class's
    position in `class_names`. If a directory holds every class, the remap is
    the identity. It also works when classes are spread across several
    directories.

    Args:
        path: directory whose sub-directories are named after classes.
        shuffle: forwarded to `image_dataset_from_directory`.

    Returns:
        A batched dataset of `(images, labels)`. Pixel values are divided by
        255, and labels are int32 indices into `class_names`.

    Raises:
        ValueError: if `path` contains none of the classes from Labels.json.
    """
    present = [name for name in class_names if tf.io.gfile.isdir(tf.io.gfile.join(path, name))]
    if not present:
        raise ValueError(f"{path!r} contains none of the classes listed in Labels.json")
    ds = image_dataset_from_directory(
        path,
        labels='inferred',
        label_mode='int',
        class_names=present,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        seed=SEED,
    )
    to_global = tf.constant([_global_label[name] for name in present], dtype=tf.int32)
    # Scale pixels to [0, 1] and translate directory-local labels to global ones.
    # NOTE(review): presumably matches the preprocessing used for MAE pretraining — confirm.
    return ds.map(
        lambda x, y: (x / 255.0, tf.gather(to_global, y)),
        num_parallel_calls=tf.data.AUTOTUNE,
    )


# Training data: interleave the directories at the example level. A plain
# `concatenate` would visit train.X1 completely, then train.X2, and so on. Then
# `shuffle=True` would only mix images *within* a directory, and every epoch would
# see the directories in the same fixed order.
train_parts = [_load_image_dir(path, shuffle=True) for path in train_dirs]
# Weight each directory by its size (in batches) so they run out at about the same
# time. cardinality() is negative when unknown; fall back to uniform sampling then.
part_sizes = [int(part.cardinality()) for part in train_parts]
part_weights = (
    [size / sum(part_sizes) for size in part_sizes]
    if min(part_sizes) >= 0 and sum(part_sizes) > 0
    else None
)
train_dataset = (
    tf.data.Dataset.sample_from_datasets(
        [part.unbatch() for part in train_parts],
        weights=part_weights,
        seed=SEED,
    )
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation data: no shuffling, so evaluation order is deterministic.
val_dataset = _load_image_dir(val_dir, shuffle=False).prefetch(tf.data.AUTOTUNE)


## Patchify: Split Images into Flattened Patches for the Encoder

In [None]:

def patchify(images, patch_size=16):
    """Split a batch of images into non-overlapping, flattened square patches.

    Args:
        images: 4-D tensor `(batch, height, width, channels)`, the layout
            `tf.image.extract_patches` expects.
        patch_size: side length of each square patch, in pixels. With
            `padding='VALID'`, rows and columns that do not fill a whole patch
            are dropped.

    Returns:
        Tensor of shape `(batch, num_patches, patch_size * patch_size * channels)`.
        The patch grid is flattened in row-major order, and each patch is
        flattened in the order `extract_patches` produces.
        NOTE(review): this layout must match the patchify used during MAE
        pretraining — confirm.
    """
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    # The flattened patch length is normally known statically. Fall back to the
    # dynamic shape so the reshape still works when the channel count is unknown.
    patch_dims = patches.shape[-1]
    if patch_dims is None:
        patch_dims = tf.shape(patches)[-1]
    return tf.reshape(patches, [batch_size, -1, patch_dims])


## Load the Pretrained Encoder

In [None]:

def create_encoder(input_shape: tuple, num_patches: int, embed_dim: int):
    """Build the patch encoder that the pretrained MAE weights are loaded into.

    The layer stack must stay exactly as written. `load_weights` restores the
    checkpoint into these layers, so changing the layers or their order is
    likely to break loading.
    NOTE(review): presumably mirrors the encoder used during MAE pretraining —
    confirm against the pretraining code.

    Architecture as built here:
      - a linear patch embedding followed by a LayerNorm;
      - four pre-norm self-attention blocks with residual connections;
      - a final LayerNorm.
    There are no feed-forward (MLP) sub-layers and no positional embedding.

    Args:
        input_shape: shape of one example without the batch axis. It is called
            with `(NUM_PATCHES, PATCH_DIM)`.
        num_patches: not used by this function. NOTE(review): because no
            positional embedding is added, the attention blocks receive no
            patch-position information.
        embed_dim: width of the patch embedding. It is also passed as `key_dim`
            to each 4-head attention layer.

    Returns:
        A Keras model named "encoder". For a 2-D `input_shape` it maps
        `(batch, input_shape[0], input_shape[1])` to
        `(batch, input_shape[0], embed_dim)` token features.
    """
    inputs = layers.Input(shape=input_shape)
    # Linear projection of each flattened patch to `embed_dim`.
    x = layers.Dense(embed_dim)(inputs)
    x = layers.LayerNormalization()(x)
    for _ in range(4):
        # Pre-norm self-attention block with residual: x = x + MHA(LN(x), LN(x)).
        x1 = layers.LayerNormalization()(x)
        x1 = layers.MultiHeadAttention(num_heads=4, key_dim=embed_dim)(x1, x1)
        x = layers.Add()([x, x1])
    outputs = layers.LayerNormalization()(x)
    return models.Model(inputs, outputs, name="encoder")

PATCH_SIZE = 16
# patchify uses padding='VALID', so an image size that is not a multiple of the
# patch size would silently crop the image border. Fail loudly instead.
assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be a multiple of PATCH_SIZE"
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2
PATCH_DIM = PATCH_SIZE * PATCH_SIZE * 3  # flattened length of one 3-channel patch
EMBED_DIM = 128
ENCODER_WEIGHTS_PATH = "mae_encoder_tf.h5"

encoder = create_encoder((NUM_PATCHES, PATCH_DIM), NUM_PATCHES, EMBED_DIM)
# The constants above must match the configuration the checkpoint was saved with.
# Otherwise weight shapes differ and loading is expected to fail.
encoder.load_weights(ENCODER_WEIGHTS_PATH)
encoder.trainable = False  # freeze the encoder: only the linear head is trained


## Linear Classifier on Top of the Frozen Encoder

In [None]:

inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
# Run patchify inside a Lambda layer so its raw TensorFlow ops become a proper
# Keras layer; Keras 3 rejects tf ops applied directly to a Keras Input.
# Passing PATCH_SIZE keeps the patches consistent with the encoder's PATCH_DIM.
patches = layers.Lambda(
    lambda images: patchify(images, patch_size=PATCH_SIZE), name="patchify"
)(inputs)
# training=False keeps the frozen encoder in inference mode, the standard
# setup for a frozen backbone.
features = encoder(patches, training=False)
# Mean-pool the patch tokens into one feature vector per image.
features = layers.GlobalAveragePooling1D()(features)
# The only trainable weights: a single linear layer with softmax output.
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(features)

model = models.Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


## Train the Linear Classifier

In [None]:

EPOCHS = 3  # number of passes over the training data
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=val_dataset,
)


## Evaluation: Accuracy and F1 Score

In [None]:

# Collect ground-truth labels and predicted class indices over the validation set.
# predict_on_batch avoids the per-call setup and progress-bar output that
# model.predict adds when it is called once per batch inside a loop.
y_true, y_pred = [], []
for x_batch, y_batch in val_dataset:
    probs = model.predict_on_batch(x_batch)
    y_true.extend(y_batch.numpy().tolist())
    y_pred.extend(np.argmax(probs, axis=1).tolist())

# Per-class precision, recall and F1, plus overall accuracy.
# - labels: listing every index keeps rows aligned with `class_names` even if a
#   class never appears in y_true or y_pred.
# - zero_division=0: scores such classes as 0 without emitting warnings.
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
    zero_division=0,
))


In [None]:
# Training vs. validation loss per epoch, both recorded by model.fit.
train_loss = history.history['loss']
val_loss = history.history['val_loss']
# Number epochs from 1, matching Keras's "Epoch 1/N" training log.
epoch_numbers = range(1, len(train_loss) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, train_loss, label='train_loss')
ax.plot(epoch_numbers, val_loss, label='val_loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='MAE Linear Probing Loss')
ax.legend()
plt.show()
