In [None]:
import os, sys

sys.path.append("../..")

import torch
import numpy as np
from tqdm import tqdm

from astroclip.models.astroclip import AstroClipModel
from astroclip.data.datamodule import AstroClipDataloader, AstroClipCollator


# Root of the shared AstroCLIP project directory (checkpoints, datasets).
# Set the ASTROCLIP_ROOT environment variable to run from another location;
# the default is the original cluster path.
ASTROCLIP_ROOT = os.environ.get("ASTROCLIP_ROOT", "/mnt/ceph/users/polymathic/astroclip")
file_path = (
    f"{ASTROCLIP_ROOT}/outputs/astroclip-alignment/l1uwsr42/checkpoints/last.ckpt"
)

# Load the trained alignment model from its checkpoint
# (presumably a PyTorch Lightning checkpoint, given `load_from_checkpoint`).
AstroClip = AstroClipModel.load_from_checkpoint(file_path)

In [None]:
# Build the AstroCLIP data module over the image/spectrum dataset and take its
# validation dataloader; the embedding/retrieval analysis below uses only it.
# The dataset lives under the same root as the checkpoint, so reuse the constant.
loader = AstroClipDataloader(
    path=f"{ASTROCLIP_ROOT}/datasets/astroclip_file/",
    batch_size=256,
    num_workers=0,
    collate_fn=AstroClipCollator(),
    columns=["image", "spectrum"],
)
# Presumably prepares the train/val splits for the "fit" stage -- TODO confirm.
loader.setup("fit")
val_loader = loader.val_dataloader()

In [None]:
# Embed every validation example in both modalities, keeping the raw inputs so
# retrieved neighbours can be displayed later.
DEVICE = torch.device("cuda")

# Switch to inference mode explicitly: the Lightning docs recommend calling
# .eval() after load_from_checkpoint so that dropout/normalisation layers (if
# the model has any) behave deterministically while computing embeddings.
AstroClip = AstroClip.to(DEVICE).eval()

im_embeddings, sp_embeddings = [], []
images, spectra = [], []
with torch.no_grad():
    for batch_test in tqdm(val_loader):
        im = batch_test["image"]
        sp = batch_test["spectrum"]

        # NumPy copies of the raw inputs, in loader order.
        images.append(im.cpu().numpy())
        spectra.append(sp.cpu().numpy())

        # No .detach() needed: nothing is tracked under torch.no_grad().
        im_embeddings.append(
            AstroClip(im.to(DEVICE), input_type="image").cpu().numpy()
        )
        sp_embeddings.append(
            AstroClip(sp.to(DEVICE), input_type="spectrum").cpu().numpy()
        )

# Stack the per-batch arrays along the first axis. Assuming the model returns
# one embedding per input row, row k of every array refers to the same example.
images = np.concatenate(images)
spectra = np.concatenate(spectra)
im_embeddings = np.concatenate(im_embeddings)
sp_embeddings = np.concatenate(sp_embeddings)

In [None]:
# Unit-normalise every embedding along its last axis, so the dot products used
# for retrieval below are cosine similarities.
image_features_normed, spectrum_features_normed = (
    embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
    for embeddings in (im_embeddings, sp_embeddings)
)

In [None]:
from matplotlib import pyplot as plt


def plot_image_subset(images, n: int = 10):
    """Show the first ``n * n`` images of ``images`` on an ``n`` x ``n`` grid.

    Each panel is titled with the image's index in ``images``. If fewer than
    ``n * n`` images are available, only those are drawn and the remaining
    grid cells stay empty (previously this raised ``IndexError``). A
    non-positive ``n`` draws nothing.

    Args:
        images: Indexable, sized sequence of arrays. Each element is shown as
            ``images[k].T``; ``.T`` reverses all axes, so e.g. a ``(C, H, W)``
            array is displayed as ``(W, H, C)`` (assumed channel-first input --
            TODO confirm).
        n: Number of rows and columns in the grid.
    """
    # max(n, 0) keeps a negative n from squaring into a positive panel count.
    n_shown = min(max(n, 0) ** 2, len(images))

    plt.figure(figsize=[10, 10])
    for k in range(n_shown):
        plt.subplot(n, n, k + 1)
        plt.imshow(images[k].T)
        plt.title(f"Image {k}")
        plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
%pylab inline
figure(figsize=[10, 10])
for i in range(10):
    for j in range(10):
        subplot(10, 10, i * 10 + j + 1)
        imshow(images[i * 10 + j].T)
        title(i * 10 + j)
        axis("off")
# plt.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
# Retrieval queries: indices (into the validation arrays above) of the galaxies
# whose nearest neighbours are shown in the figures below.
# (Previously assigned to a misspelled name `in_query`, so the loop raised NameError.)
ind_query = [7, 30, 31, 48]
TOP_K = 8  # neighbours kept per query and per similarity

# For each query, the top-TOP_K images and spectra under four similarities
# (dot products of the normalised embeddings):
#   ims1/sps1: query spectrum vs. all spectrum embeddings  (sp_sim)
#   ims2/sps2: query image    vs. all image embeddings     (im_sim)
#   ims3/sps3: query image    vs. all spectrum embeddings  (x_im_sim, cross-modal)
#   ims4/sps4: query spectrum vs. all image embeddings     (x_sp_sim, cross-modal)
# Each list gets one entry per query, in ind_query order.
ims1, ims2, ims3, ims4 = [], [], [], []
sps1, sps2, sps3, sps4 = [], [], [], []

for ind in ind_query:
    sp_sim = spectrum_features_normed[ind] @ spectrum_features_normed.T
    im_sim = image_features_normed[ind] @ image_features_normed.T
    x_im_sim = image_features_normed[ind] @ spectrum_features_normed.T
    x_sp_sim = spectrum_features_normed[ind] @ image_features_normed.T

    for sim, ims_out, sps_out in (
        (sp_sim, ims1, sps1),
        (im_sim, ims2, sps2),
        (x_im_sim, ims3, sps3),
        (x_sp_sim, ims4, sps4),
    ):
        # Reversed ascending argsort -> indices by decreasing similarity.
        top = np.argsort(sim)[::-1][:TOP_K]
        ims_out.append([images[i] for i in top])
        sps_out.append([spectra[i] for i in top])

In [None]:
# Retrieval grid: one row per query galaxy. Column 0 is the query image, then
# N_SHOWN top-ranked images from each ranking, in this column order:
#   image->image (ims2), spectrum->spectrum (ims1),
#   image->spectrum cross-modal (ims3), spectrum->image cross-modal (ims4).
# For the same-modality rankings, rank 0 is typically the query itself.
N_SHOWN = 3
galleries = [ims2, ims1, ims3, ims4]
n_cols = 1 + N_SHOWN * len(galleries)

fig, axes = plt.subplots(
    len(ind_query), n_cols, figsize=[19.4, 6.1], squeeze=False
)
for row, query_idx in enumerate(ind_query):
    axes[row, 0].imshow(images[query_idx].T)
    for g, gallery in enumerate(galleries):
        for rank in range(N_SHOWN):
            axes[row, 1 + g * N_SHOWN + rank].imshow(gallery[row][rank].T)

for ax in axes.flat:
    ax.axis("off")

# A single call: the previous first subplots_adjust was immediately overridden.
fig.subplots_adjust(wspace=0.00, hspace=0.01)
plt.show()

In [None]:
from scipy.ndimage import gaussian_filter1d

# Wavelength grid shared by all spectra: evenly spaced between these bounds,
# one point per spectral bin (units presumably Angstrom -- TODO confirm).
l = np.linspace(3586.7408577, 10372.89543574, spectra.shape[1])

SMOOTH_SIGMA = 5  # Gaussian smoothing width (in bins) applied for display
N_RETRIEVED = 3  # retrieved spectra overlaid per query


def _plot_query_with_retrievals(
    ax,
    wavelength,
    query_spectrum,
    query_image,
    retrieved_spectra,
    *,
    color,
    query_label=None,
    inset_bounds=(0, 0.55, 0.4, 0.4),
    n_retrieved=N_RETRIEVED,
    sigma=SMOOTH_SIGMA,
):
    """Overlay a query spectrum and its retrieved neighbours on ``ax``.

    The query spectrum is drawn in ``color``. The retrieved spectra at ranks
    1..``n_retrieved`` are drawn in semi-transparent gray; rank 0 is skipped.
    Only the first gray curve is labelled, so a legend shows a single
    "retrieved spectra" entry. Every curve is channel 0 of its spectrum,
    smoothed with a Gaussian of width ``sigma`` bins. The query image is shown
    as an inset at ``inset_bounds`` (axes-fraction ``[x0, y0, width, height]``).

    Returns:
        The inset ``Axes`` holding the query image.
    """
    ax.plot(
        wavelength,
        gaussian_filter1d(query_spectrum[:, 0], sigma),
        color=color,
        lw=1,
        label=query_label,
    )
    for rank, retrieved in enumerate(retrieved_spectra[1 : n_retrieved + 1]):
        ax.plot(
            wavelength,
            gaussian_filter1d(retrieved[:, 0], sigma),
            alpha=0.5,
            lw=1,
            color="gray",
            label="retrieved spectra" if rank == 0 else None,
        )
    ax.set_xlabel(r"$\lambda$")
    ax.set_ylabel("flux")

    inset = ax.inset_axes(list(inset_bounds))
    inset.imshow(query_image.T)
    inset.axis("off")
    return inset


# `fig` rather than `figure`, so pylab's figure() function is not shadowed.
fig = plt.figure(figsize=[15, 5])

# Left panel: the first query's own spectrum vs. spectra retrieved by
# cross-modal similarity of its image embedding (sps3); inset upper-left.
ax1 = fig.add_subplot(121)
axins1 = _plot_query_with_retrievals(
    ax1,
    l,
    spectra[ind_query[0]],
    images[ind_query[0]],
    sps3[0],
    color="r",
    query_label="spectrum of query image",
)
ax1.legend()

# Right panel: same for the third query, fixed y-range, inset upper-right.
ax2 = fig.add_subplot(122)
axins2 = _plot_query_with_retrievals(
    ax2,
    l,
    spectra[ind_query[2]],
    images[ind_query[2]],
    sps3[2],
    color="b",
    inset_bounds=(0.6, 0.55, 0.4, 0.4),
)
ax2.set_ylim(0, 15)
plt.show()

In [None]:
# Spectrum -> spectrum retrieval: each query galaxy's own spectrum (coloured)
# overlaid with the spectra ranked 1-3 by spectrum-embedding similarity (sps1),
# plus the query image as an inset. Rank 0 is not plotted.
figure = plt.figure(figsize=[15, 5])

# ---- Left panel: first query, with legend ----
ax1 = figure.add_subplot(121)
ax1.plot(
    l,
    gaussian_filter1d(spectra[ind_query[0]][:, 0], 5),
    color="r",
    lw=1,
    label="query spectrum",
)
for rank in (1, 2, 3):
    # Label only the first gray curve so the legend has a single entry for them.
    ax1.plot(
        l,
        gaussian_filter1d(sps1[0][rank][:, 0], 5),
        alpha=0.5,
        lw=1,
        color="gray",
        label="retrieved spectra" if rank == 1 else None,
    )
ax1.set_xlabel(r"$\lambda$")
ax1.set_ylabel("flux")
ax1.legend(loc=1)  # loc=1 is "upper right"

# Query image thumbnail, upper-left corner (axes-fraction coordinates).
image_data = images[ind_query[0]]
axins1 = ax1.inset_axes([0, 0.55, 0.4, 0.4])
axins1.imshow(image_data.T)
axins1.axis("off")

# ---- Right panel: third query, fixed y-range, no legend ----
ax2 = figure.add_subplot(122)
ax2.plot(l, gaussian_filter1d(spectra[ind_query[2]][:, 0], 5), color="b", lw=1)
for rank in (1, 2, 3):
    ax2.plot(
        l,
        gaussian_filter1d(sps1[2][rank][:, 0], 5),
        alpha=0.5,
        lw=1,
        color="gray",
        label="retrieved spectra" if rank == 1 else None,
    )
ax2.set_xlabel(r"$\lambda$")
ax2.set_ylabel("flux")
ax2.set_ylim(0, 15)

# Query image thumbnail, upper-right corner.
axins2 = ax2.inset_axes([0.6, 0.55, 0.4, 0.4])
axins2.imshow(images[ind_query[2]].T)
axins2.axis("off")

# plt.savefig('spectra_retrieval_spectrum.pdf', bbox_inches = 'tight', pad_inches = 0 )