In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import tensorflow as tf
from tensorflow.keras.preprocessing import image

def load_and_preprocess_image(img_path, target_size=(224, 224)):
    """Read one image file and turn it into a single-item model input batch.

    Args:
        img_path: Path to an image file readable by Keras' ``load_img``.
        target_size: ``(height, width)`` the image is resized to on load.

    Returns:
        A float array with a leading batch axis of length 1, pixel values
        divided by 255 (so 8-bit inputs land in ``[0, 1]``).
    """
    pil_img = image.load_img(img_path, target_size=target_size)
    pixels = image.img_to_array(pil_img)
    scaled = pixels / 255.0
    # Prepend the batch dimension that Keras models expect.
    return scaled[np.newaxis, ...]

# Sub-model that shares MediNet_XG's input but stops at the layer named
# "embedding", so predict() yields that layer's activations.
# `medinet_xg` is defined in an earlier cell (not visible here).
embed_model = tf.keras.Model(
    inputs=medinet_xg.input,
    outputs=medinet_xg.get_layer("embedding").output,
)

# Inputs (one per line on stdin):
#   1. dataset root: one sub-directory per class, images inside each.
#   2. max images per class; a blank line falls back to 30.
root_dir = input().strip()
_max_raw = input().strip()
max_per_class = int(_max_raw) if _max_raw else 30
if max_per_class < 1:
    raise ValueError(f"max_per_class must be >= 1, got {max_per_class}")

_IMAGE_EXTS = (".jpg", ".jpeg", ".png")

# samples: list of (image_path, class_name), at most max_per_class per class.
samples = []
for cls in sorted(os.listdir(root_dir)):
    cls_path = os.path.join(root_dir, cls)
    if not os.path.isdir(cls_path):
        continue  # stray files at the root are not classes
    # os.listdir order is arbitrary; sort so the same subset is chosen
    # every run and the resulting plot is reproducible.
    image_files = [
        fn for fn in sorted(os.listdir(cls_path))
        if fn.lower().endswith(_IMAGE_EXTS)
    ]
    samples.extend(
        (os.path.join(cls_path, fn), cls) for fn in image_files[:max_per_class]
    )

# Embed every sampled image.
# Images are stacked into mini-batches so Keras predict() runs once per batch
# rather than once per image; each predict() call has a large fixed overhead.
# `img_size` comes from an earlier cell (presumably the model's input
# resolution — confirm).
if not samples:
    raise ValueError(f"No .jpg/.jpeg/.png images found under {root_dir!r}")

_EMBED_BATCH = 32
_feature_chunks = []
for _start in range(0, len(samples), _EMBED_BATCH):
    _chunk = samples[_start:_start + _EMBED_BATCH]
    # load_and_preprocess_image returns a batch of one; concatenate on axis 0.
    _batch = np.concatenate(
        [load_and_preprocess_image(p, (img_size, img_size)) for p, _ in _chunk],
        axis=0,
    )
    _out = embed_model.predict(_batch, verbose=0)
    # Flatten each embedding to a row vector so X is always 2-D for TSNE.
    _feature_chunks.append(_out.reshape(len(_chunk), -1))

X = np.concatenate(_feature_chunks, axis=0)  # (n_samples, embedding_dim)
y = np.array([lab for _, lab in samples])    # class name per row of X

# Project the embeddings to 2-D with t-SNE.
# scikit-learn requires perplexity < n_samples, so the default of 30 is
# clamped for small sample sets (few classes or a low max_per_class).
_n_points = X.shape[0]
if _n_points < 2:
    raise ValueError(f"t-SNE needs at least 2 embeddings, got {_n_points}")
_perplexity = min(30.0, float(_n_points - 1))
Z = TSNE(
    n_components=2,
    perplexity=_perplexity,
    init="pca",
    learning_rate="auto",
    random_state=42,  # fixed seed for a reproducible layout
).fit_transform(X)

# Scatter the 2-D t-SNE coordinates, one colour / legend entry per class.
fig, ax = plt.subplots(figsize=(9, 7), dpi=300)
for cls in np.unique(y):
    idx = (y == cls)
    ax.scatter(Z[idx, 0], Z[idx, 1], s=10, label=cls)
# Legend is anchored just outside the right edge of the axes so it does not
# cover points; bbox_inches="tight" below keeps it in the saved file.
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", fontsize=7)
ax.set_title("t-SNE of MediNet_XG Embeddings")
fig.tight_layout()

# OUTPUT_DIR is defined in an earlier cell; create it if needed so savefig
# does not fail with FileNotFoundError.
os.makedirs(OUTPUT_DIR, exist_ok=True)
out_path = os.path.join(OUTPUT_DIR, "tsne_MediNet_XG.png")
fig.savefig(out_path, dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)
