In [None]:
# Tunable settings shared by the sampling and feature-extraction cells.
img_size = (224, 224)  # target size passed to tf.image.resize and used as the model's spatial input shape
SAMPLES_PER_CLASS = 200  # upper bound on the number of images kept per class; adjust as needed

In [None]:
import numpy as np
from collections import defaultdict

# Build a class-balanced subset of train_ds: keep at most SAMPLES_PER_CLASS
# images per class, each resized to img_size, so that frequent classes do not
# dominate the embedding plot.
image_buffer = []
label_buffer = []

# Number of images kept so far, keyed by integer class index.
class_counts = defaultdict(int)
num_classes = len(class_names)

# Classes still below the quota. Tracking them in a set avoids re-scanning
# every class count after each example.
unfilled = set(range(num_classes)) if SAMPLES_PER_CLASS > 0 else set()

# After unbatch() each iteration yields a single (image, label) pair.
# NOTE(review): int(...) below assumes each label is a scalar class index
# (not one-hot) — confirm against how train_ds was built.
for batch_imgs, batch_labels in train_ds.unbatch():
    label = int(batch_labels.numpy())

    if class_counts[label] < SAMPLES_PER_CLASS:
        img = tf.image.resize(batch_imgs, img_size).numpy()
        image_buffer.append(img)
        label_buffer.append(label)
        class_counts[label] += 1
        if class_counts[label] >= SAMPLES_PER_CLASS:
            unfilled.discard(label)

    # Stop as soon as every class has reached its quota.
    if not unfilled:
        break

balanced_images = np.array(image_buffer)
balanced_labels = np.array(label_buffer)
del image_buffer, label_buffer  # the arrays hold the data now; drop the duplicate lists

# The loop also ends when the dataset runs out, which can leave some classes
# short of the quota. Surface that instead of silently plotting an unbalanced set.
if unfilled:
    shortfall = {class_names[c]: class_counts[c] for c in sorted(unfilled)}
    print(f"Warning: fewer than {SAMPLES_PER_CLASS} samples collected for: {shortfall}")

print("Balanced images:", balanced_images.shape)
print("Balanced labels:", balanced_labels.shape)


In [None]:
from tensorflow import keras

# EfficientNet-B0 with ImageNet weights, used as a frozen embedding model:
# the classification head is dropped and global average pooling is applied,
# so each input image yields a single feature vector.
feature_extractor = keras.applications.EfficientNetB0(
    input_shape=(*img_size, 3),  # img_size plus a trailing channel dimension of 3
    include_top=False,
    pooling="avg",
    weights="imagenet",
)

In [None]:
# Run the balanced images through EfficientNet's own preprocessing function
# before handing them to the feature extractor.
balanced_images_pp = keras.applications.efficientnet.preprocess_input(balanced_images)

# One embedding row per image, computed in mini-batches of 32.
embeddings = feature_extractor.predict(
    balanced_images_pp,
    batch_size=32,
)
print("Embeddings shape:", embeddings.shape)

In [None]:
!pip install umap-learn

import umap.umap_ as umap

reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    metric="euclidean",
    random_state=42
)

emb_2d = reducer.fit_transform(embeddings)
print("UMAP shape:", emb_2d.shape)

In [None]:
import matplotlib.pyplot as plt

# Scatter plot of the 2-D UMAP projection, one point per image, coloured by
# integer class label.
fig, ax = plt.subplots(figsize=(12, 9))
x_coords, y_coords = emb_2d[:, 0], emb_2d[:, 1]
scatter = ax.scatter(x_coords, y_coords, c=balanced_labels, cmap="tab20", s=6)

# Narrow colorbar mapping colours back to label values.
fig.colorbar(scatter, ax=ax, fraction=0.02)
ax.set_title("Balanced UMAP of EfficientNet-B0 Embeddings (Equal Samples Per Class)")
ax.set_xlabel("UMAP-1")
ax.set_ylabel("UMAP-2")
plt.show()