Importing Libraries

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.metrics import classification_report
import numpy as np
import os


Load class label mapping from Labels.json

In [None]:
import json

# Load the class label mapping shipped with the dataset.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on
# the platform's default encoding (which is not UTF-8 on every system).
with open("ssl_dataset/Labels.json", "r", encoding="utf-8") as f:
    label_map = json.load(f)

# Order the mapping's keys by the value each key maps to. The position of a
# name in `class_names` is what the data loaders below use as its integer label.
# NOTE(review): presumably each value is the class's label index — confirm
# against the contents of Labels.json.
class_names = sorted(label_map, key=label_map.get)


Load Labeled Dataset

In [None]:
train_dirs = ["ssl_dataset/train.X1", "ssl_dataset/train.X2", "ssl_dataset/train.X3", "ssl_dataset/train.X4"]
val_dir = "ssl_dataset/val.X"

batch_size = 64
img_size = (224, 224)
normalization_layer = tf.keras.layers.Rescaling(1./255)


def load_image_folder(path, shuffle):
    """Load one image directory as a batched, normalised `tf.data.Dataset`.

    Args:
        path: Directory containing one sub-directory per class.
        shuffle: Whether the loader should shuffle the files of this directory.

    Returns:
        A dataset of `(images, labels)` batches of size `batch_size`, with the
        images resized to `img_size` and rescaled by 1/255, and integer labels
        given by each class's position in `class_names`.
    """
    # NOTE(review): the loader checks `class_names` against the sub-directories
    # it finds under `path` — confirm that every folder passed here contains
    # all classes listed in Labels.json.
    ds = tf.keras.preprocessing.image_dataset_from_directory(
        path,
        label_mode='int',
        image_size=img_size,
        batch_size=batch_size,
        class_names=class_names,
        shuffle=shuffle
    )
    return ds.map(lambda x, y: (normalization_layer(x), y))


# Load and normalize training datasets from all four folders
all_ds = [load_image_folder(path, shuffle=True) for path in train_dirs]

# Combine the folders into one train_ds.
# Plain concatenation would stream every batch of train.X1, then every batch of
# train.X2, and so on: each folder is shuffled only internally, so one epoch
# would visit the folders strictly one after another. Instead, draw individual
# images at random across the folders and re-batch them, so that every batch
# mixes samples from all folders. Sampling continues from the remaining folders
# once a smaller one is exhausted, so no image is dropped.
train_ds = (
    tf.data.Dataset.sample_from_datasets([ds.unbatch() for ds in all_ds])
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)  # overlap input preparation with training
)

# Load and normalize validation dataset. It is deliberately NOT shuffled so that
# repeated passes over `val_ds` yield the samples in the same order.
val_ds = load_image_folder(val_dir, shuffle=False).prefetch(tf.data.AUTOTUNE)


Load Frozen Encoder

In [None]:
# Restore the saved encoder from disk (presumably the SimCLR-pretrained
# backbone, judging by the file name).
encoder_path = 'simclr_encoder.h5'
encoder = tf.keras.models.load_model(encoder_path)

# Freeze the encoder: with `trainable` switched off, its weights are excluded
# from gradient updates when a model containing it is trained.
setattr(encoder, 'trainable', False)


Attach Linear Classifier

In [None]:
# One output unit per class. The datasets encode each label as the class's
# position in `class_names`, so the classifier width is derived from that list
# rather than hard-coded (100 for ImageNet-100).
num_classes = len(class_names)

# Linear probe: the encoder followed by a single softmax layer.
model = tf.keras.Sequential([
    encoder,
    layers.Dense(num_classes, activation='softmax')
])

# Sparse categorical cross-entropy matches the integer (non one-hot) labels
# produced with label_mode='int'.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


Train Linear Classifier


In [None]:
# Fit the classifier for a fixed number of passes over the training data,
# evaluating on the validation set after each one. The returned History object
# keeps the per-epoch metrics used for the curves plotted later.
num_epochs = 3
history = model.fit(
    train_ds,
    epochs=num_epochs,
    validation_data=val_ds,
)


Evaluate Performance

In [None]:
# Collect the ground-truth labels batch by batch.
# This and the prediction below are two separate passes over `val_ds`; they line
# up only because the validation set is loaded with shuffle=False.
label_batches = []
for _, batch_labels in val_ds:
    label_batches.append(batch_labels.numpy())
y_true = np.concatenate(label_batches, axis=0)

# Class probabilities for every validation sample, then the most likely class.
y_pred_probs = model.predict(val_ds)
y_pred = y_pred_probs.argmax(axis=1)

# Per-class precision / recall / F1 summary
report = classification_report(y_true, y_pred)
print(report)


Plot Accuracy & Loss Curves

In [None]:
import matplotlib.pyplot as plt

# One figure per metric, each showing the training and validation series
# recorded in `history.history` against the epoch number.
# Each entry: (train key, val key, train legend, val legend, title, y-axis label)
curve_specs = [
    ('accuracy', 'val_accuracy', 'Train Acc', 'Val Acc', 'Accuracy Curve', 'Accuracy'),
    ('loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss Curve', 'Loss'),
]

for train_key, val_key, train_label, val_label, title, y_label in curve_specs:
    plt.plot(history.history[train_key], label=train_label)
    plt.plot(history.history[val_key], label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()
    plt.show()
