# 🔍 Linear Probing with Pretrained MAE Encoder - TensorFlow

In [None]:

import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
from tensorflow.keras.preprocessing import image_dataset_from_directory


## 📦 Load Labeled Dataset for Linear Probing

In [None]:

IMAGE_SIZE = 224
BATCH_SIZE = 64
NUM_CLASSES = 100  # Adjust to match ImageNet-100


def _rescale(images, labels):
    """Divide the image batch by 255.0 and pass the labels through unchanged."""
    return images / 255.0, labels


def _load_split(directory, shuffle):
    """Build a batched dataset of (rescaled image, integer label) pairs.

    Labels are inferred from the sub-directory structure of ``directory``;
    every image is resized to IMAGE_SIZE x IMAGE_SIZE and grouped into
    batches of BATCH_SIZE.
    """
    raw_split = image_dataset_from_directory(
        directory,
        labels='inferred',
        label_mode='int',
        shuffle=shuffle,
        batch_size=BATCH_SIZE,
        image_size=(IMAGE_SIZE, IMAGE_SIZE),
    )
    return raw_split.map(_rescale)


# Training data is shuffled; validation data keeps its on-disk order.
train_dataset = _load_split("path_to/train", shuffle=True)
val_dataset = _load_split("path_to/val", shuffle=False)


## 🔲 Patchify Function for Embedding Extraction

In [None]:

def patchify(images, patch_size=16):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        images: 4-D image tensor in the (batch, rows, cols, channels) layout
            expected by ``tf.image.extract_patches``.
        patch_size: Side length of each square patch in pixels. It is also
            used as the stride, so patches do not overlap.

    Returns:
        A 3-D tensor of shape (batch, num_patches, patch_dims), where
        ``patch_dims`` is the flattened size of one patch. Because padding is
        'VALID', partial patches at the right/bottom edges are dropped.
    """
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, patch_size, patch_size, 1],
        strides=[1, patch_size, patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'
    )
    patch_dims = patches.shape[-1]
    grid_rows, grid_cols = patches.shape[1], patches.shape[2]

    if grid_rows is not None and grid_cols is not None:
        # The patch grid is known statically: bake the patch count into the
        # reshape so the result keeps a fully defined (None, N, D) shape.
        # Letting the batch axis be the inferred one also handles empty batches.
        return tf.reshape(patches, [-1, grid_rows * grid_cols, patch_dims])

    # Spatial size only known at run time: read the batch size dynamically and
    # let TensorFlow infer the number of patches.
    batch_size = tf.shape(images)[0]
    return tf.reshape(patches, [batch_size, -1, patch_dims])


## 🔁 Load Pretrained Encoder

In [None]:

def create_encoder(input_shape, num_patches, embed_dim, *, num_layers=4, num_heads=4):
    """Build the patch encoder: a linear embedding followed by attention blocks.

    Each block applies LayerNormalization, multi-head self-attention and a
    residual Add. There is no feed-forward sub-layer. Layers are created in a
    fixed order, so the defaults reproduce the original 4-block, 4-head
    architecture exactly.

    Args:
        input_shape: Shape of one sample without the batch axis, passed
            straight to ``layers.Input``.
        num_patches: Accepted for interface compatibility but not used; the
            sequence length is taken from ``input_shape``.
        embed_dim: Width of the patch embedding, also used as ``key_dim`` of
            every attention layer.
        num_layers: Number of attention blocks to stack (keyword-only).
        num_heads: Number of attention heads per block (keyword-only).

    Returns:
        A Keras ``Model`` named "encoder" mapping patch sequences to
        layer-normalised embeddings of width ``embed_dim``.
    """
    inputs = layers.Input(shape=input_shape)
    x = layers.Dense(embed_dim)(inputs)
    x = layers.LayerNormalization()(x)
    for _ in range(num_layers):
        attended = layers.LayerNormalization()(x)
        # Self-attention: the normalised sequence is both query and value.
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )(attended, attended)
        x = layers.Add()([x, attended])
    outputs = layers.LayerNormalization()(x)
    return models.Model(inputs, outputs, name="encoder")

# Patch geometry and embedding width used to rebuild the encoder architecture.
PATCH_SIZE, EMBED_DIM = 16, 128
_PATCHES_PER_SIDE = IMAGE_SIZE // PATCH_SIZE
NUM_PATCHES = _PATCHES_PER_SIDE * _PATCHES_PER_SIDE
PATCH_DIM = 3 * PATCH_SIZE * PATCH_SIZE  # flattened patch length for 3-channel images

# Recreate the architecture, restore its weights from disk, then freeze them.
encoder = create_encoder(
    input_shape=(NUM_PATCHES, PATCH_DIM),
    num_patches=NUM_PATCHES,
    embed_dim=EMBED_DIM,
)
encoder.load_weights("mae_encoder_tf.h5")
encoder.trainable = False  # Freeze the encoder


## 🧠 Linear Classifier on Top of Frozen Encoder

In [None]:

# Linear probe: image -> patches -> frozen encoder -> mean-pooled features
# -> single softmax Dense layer.
inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))

# patchify uses raw TensorFlow ops, so run it inside a Lambda layer: the ops
# then execute on real tensors instead of the symbolic Keras input (Keras 3
# rejects tf.* calls on symbolic tensors). PATCH_SIZE is passed explicitly so
# the patching stays in step with NUM_PATCHES / PATCH_DIM, which define the
# encoder's input shape.
patches = layers.Lambda(
    lambda images: patchify(images, patch_size=PATCH_SIZE),
    output_shape=(NUM_PATCHES, PATCH_DIM),
    name="patchify",
)(inputs)
features = encoder(patches)
# Average over the patch axis to get one feature vector per image.
features = layers.GlobalAveragePooling1D()(features)
outputs = layers.Dense(NUM_CLASSES, activation='softmax')(features)

model = models.Model(inputs, outputs)
# Labels are integer class ids (label_mode='int'), hence the sparse loss.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


## 🚀 Train Linear Classifier

In [None]:

# Fit for 10 epochs on the training split, validating after each epoch.
# The encoder was frozen before the model was built.
model.fit(
    x=train_dataset,
    validation_data=val_dataset,
    epochs=10,
)


## 📊 Evaluation: Accuracy and F1 Score

In [None]:

from sklearn.metrics import classification_report
import numpy as np

# Collect ground-truth and predicted class ids batch by batch, so labels and
# predictions stay aligned with each other.
y_true, y_pred = [], []
for x_batch, y_batch in val_dataset:
    # predict_on_batch does one forward pass and returns a NumPy array.
    # Model.predict is meant for whole datasets: called once per batch it adds
    # per-call setup overhead and prints a progress bar every iteration.
    preds = model.predict_on_batch(x_batch)
    y_true.extend(y_batch.numpy())
    y_pred.extend(np.argmax(preds, axis=1))

# Per-class precision / recall / F1 plus overall accuracy, to 4 decimals.
print(classification_report(y_true, y_pred, digits=4))
