In [None]:
import os
import yaml
import numpy as np
import torch
import matplotlib.pyplot as plt
from postprocessing.plot_knn import plot_knn_examples
from postprocessing.plot_umap import plot_umap_projection
from postprocessing.plot_grid import plot_images_on_grid
from models.simclr import SimCLRModel

In [None]:
# Load Dataset Path from Configuration
# The image directory is read from the YAML config rather than hard-coded here.
# The file is opened explicitly as UTF-8 so that non-ASCII characters in the
# configured path decode identically on every platform (the default text
# encoding of open() depends on the locale).
with open("configs/dataset_config.yaml", "r", encoding="utf-8") as file:
    dataset_config = yaml.safe_load(file)
DATASET_PATH = dataset_config["dataset"]["data_path"]

# Load Model
# Fail early with an actionable message when the checkpoint is missing.
model_path = "checkpoints/final_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError("Model checkpoint not found. Please run `main.sh` first.")

# Prefer the GPU when one is available; map_location remaps the stored tensors
# onto whichever device was selected.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimCLRModel()
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source. If the checkpoint is a plain state dict and the installed
# PyTorch version supports it, consider weights_only=True — TODO confirm.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()  # put the model in evaluation mode before using it
print("Model loaded successfully.")


In [None]:
# Load Embeddings
# Both the full embeddings and their 2D counterpart are read from .npy files
# in the current working directory.
embeddings_path = "embeddings.npy"
embeddings_2d_path = "embeddings_2d.npy"

# Guard clause: stop right away unless both files are present.
if not os.path.exists(embeddings_path) or not os.path.exists(embeddings_2d_path):
    raise FileNotFoundError("Embeddings not found. Please run `main.sh` first.")

embeddings = np.load(embeddings_path)
embeddings_2d = np.load(embeddings_2d_path)
print("Embeddings loaded successfully.")

In [None]:
# Retrieve Filenames Directly from Dataset Path
# os.listdir returns entries in arbitrary order, so the paths are sorted to get
# a deterministic ordering. Only files whose (case-insensitive) name ends with
# one of the listed image extensions are kept.
# NOTE(review): this assumes the embeddings were computed over the same files
# in the same sorted order — confirm against the embedding pipeline.
image_extensions = (".jpg", ".png", ".jpeg")
filenames = sorted(
    os.path.join(DATASET_PATH, f)
    for f in os.listdir(DATASET_PATH)
    if f.lower().endswith(image_extensions)
)

# Every embedding row must correspond to exactly one image file.
if len(filenames) != len(embeddings):
    raise ValueError(f"Mismatch: {len(embeddings)} embeddings vs. {len(filenames)} filenames!")

# The 2D embeddings are also passed alongside `filenames` when plotting images
# on the grid, so they must have the same number of rows as well.
if len(filenames) != len(embeddings_2d):
    raise ValueError(f"Mismatch: {len(embeddings_2d)} 2D embeddings vs. {len(filenames)} filenames!")

In [None]:
# Plot 1: nearest-neighbour examples
# Passes the full embeddings together with the matching image paths.
# NOTE(review): presumably renders 6 query images with 5 neighbours each —
# confirm against postprocessing.plot_knn.
plot_knn_examples(
    embeddings,
    filenames,
    n_neighbors=5,
    num_examples=6,
)

In [None]:
# Plot 2: UMAP Projection
# Plots the 2D embeddings that were loaded from embeddings_2d.npy.
# NOTE(review): the function name suggests the coordinates come from UMAP;
# confirm where embeddings_2d.npy is produced.
plot_umap_projection(embeddings_2d)

In [None]:
# Plot 3: images laid out on a grid
# Passes the 2D embeddings and the image paths to the grid plotter.
# NOTE(review): grid_size, cell_size and step are forwarded unchanged; their
# exact meaning (cells per side, pixels per cell, sampling stride?) should be
# confirmed against postprocessing.plot_grid.
plot_images_on_grid(
    embeddings_2d,
    filenames,
    grid_size=20,
    cell_size=128,
    step=50,
)
