In [None]:
import math
import itertools
from functools import partial

import torch
import torch.nn.functional as F

In [None]:
import os
import sys

# Put the sibling dinov2 checkout at the front of the import search path.
sys.path.insert(0, os.path.abspath(os.path.join("..", "dinov2")))

In [None]:
from dinov2.utils.config import setup
from dinov2.models import build_model_from_cfg
from dinov2.fsdp import FSDPCheckpointer
from dinov2.train.ssl_meta_arch import SSLMetaArch
from dinov2.eval.setup import setup_and_build_model

In [None]:
class config:
    """Attribute bag that is instantiated and passed to ``setup_and_build_model``.

    The fields hold the run directory, the config file, the checkpoint to load
    and a list of extra options (empty here).
    """

    # Directory of the training run (note the trailing slash).
    output_dir = "/mnt/ceph/users/polymathic/astroclip/outputs/astroclip_image/u6lwxdfu/"
    # Configuration file, relative to the notebook's working directory.
    config_file = "./config.yaml"
    # Teacher checkpoint stored underneath the run directory.
    pretrained_weights = output_dir + "eval/training_99999/teacher_checkpoint.pth"
    # NOTE(review): presumably config overrides consumed by dinov2 -- confirm.
    opts = []

In [None]:
# Instantiate the settings holder and let dinov2 build the network from it.
# The call returns a pair; the second element is kept as `dtype`.
settings = config()
model, dtype = setup_and_build_model(settings)

In [None]:
from data.augmentations import ToRGB

In [None]:
# We should also look at what the images from the dataset look like

# Make the parent directory importable so the `astrodino` package resolves.
sys.path.insert(0, os.path.abspath(os.pardir))

from astrodino.data.loaders import make_dataset

In [None]:
from torchvision.transforms import CenterCrop, Compose

# Dataset descriptor string parsed by make_dataset: test split of LegacySurvey
# rooted at the DECALS directory, with an empty `extra` field.
dataset_spec = 'LegacySurvey:split=test:root=/mnt/ceph/users/polymathic/external_data/astro/DECALS_Stein_et_al/:extra=""'

# Every sample is centre-cropped to 144 pixels and then converted with ToRGB.
preprocess = Compose([CenterCrop(144), ToRGB()])

test_dataset = make_dataset(dataset_str=dataset_spec, transform=preprocess)

In [None]:
from tqdm import tqdm
import numpy as np

# Embed the first N_BATCHES * BATCH_SIZE test images with the pretrained model
# and collect the results into a single (n_images, embedding_dim) numpy array.
BATCH_SIZE = 64
N_BATCHES = 10_000 // BATCH_SIZE  # full batches only; a trailing partial batch is skipped

batch_embeddings = []
# Inference only: no_grad() keeps autograd from recording each forward pass,
# so intermediate activations are not retained on the GPU.
with torch.no_grad():
    for batch_idx in tqdm(range(N_BATCHES)):
        start = batch_idx * BATCH_SIZE
        # Stack on the host first; torch.tensor() on a Python list of arrays
        # converts element by element and is very slow.
        batch = np.stack(
            [np.asarray(test_dataset[k][0]).T for k in range(start, start + BATCH_SIZE)]
        )
        images = torch.from_numpy(batch).cuda()
        batch_embeddings.append(model(images).cpu())
        del images  # release the GPU batch before loading the next one
embeddings = torch.cat(batch_embeddings).numpy()

In [None]:
# The cell that computes the embeddings already joins the per-batch outputs.
# Calling np.concatenate on the resulting 2-D array would treat each row as a
# separate array and flatten everything to 1-D, which breaks the cosine
# similarity below. Only concatenate when we still hold a list of batches, so
# this cell is safe to run any number of times.
if isinstance(embeddings, (list, tuple)):
    embeddings = np.concatenate(embeddings)
assert np.ndim(embeddings) == 2, "embeddings must be a (n_images, embedding_dim) matrix"

In [None]:
%pylab inline

# Show 64 consecutive test images (dataset indices 3200..3263) in an 8x8 grid,
# each panel titled with its dataset index.
first_index = 400 * 8
figure(figsize=[20, 20])
for panel in range(8 * 8):
    sample_index = first_index + panel
    subplot(8, 8, panel + 1)
    imshow(test_dataset[sample_index][0])
    title(f"{sample_index}", fontsize=10)
    axis("off")

plt.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
# Nearest-neighbour lookup for test image 844: embed the query, rank the
# precomputed `embeddings` by cosine similarity and show the 64 best matches.
query_index = 844
test_im = np.array(test_dataset[query_index][0]).T  # same transpose as the embedded batches
with torch.no_grad():  # inference only; skip autograd bookkeeping
    test_embed = model(torch.from_numpy(test_im[np.newaxis]).to("cuda")).cpu()
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
s = cos(test_embed.reshape(1, -1), torch.from_numpy(embeddings)).numpy()
inds = np.argsort(s)[::-1]  # dataset indices sorted by decreasing similarity

fig, axes = plt.subplots(8, 8, figsize=(20, 20))
for ax, neighbour in zip(axes.flat, inds):
    # Clip to [0, 1] so imshow receives valid RGB floats.
    ax.imshow(np.clip(test_dataset[neighbour][0], 0, 1))
    ax.axis("off")
fig.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
# Nearest-neighbour lookup for test image 800: embed the query, rank the
# precomputed `embeddings` by cosine similarity and show the 64 best matches.
query_index = 800
test_im = np.array(test_dataset[query_index][0]).T  # same transpose as the embedded batches
with torch.no_grad():  # inference only; skip autograd bookkeeping
    test_embed = model(torch.from_numpy(test_im[np.newaxis]).to("cuda")).cpu()
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
s = cos(test_embed.reshape(1, -1), torch.from_numpy(embeddings)).numpy()
inds = np.argsort(s)[::-1]  # dataset indices sorted by decreasing similarity

fig, axes = plt.subplots(8, 8, figsize=(20, 20))
for ax, neighbour in zip(axes.flat, inds):
    # Clip to [0, 1] so imshow receives valid RGB floats.
    ax.imshow(np.clip(test_dataset[neighbour][0], 0, 1))
    ax.axis("off")
fig.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
# Nearest-neighbour lookup for test image 818: embed the query, rank the
# precomputed `embeddings` by cosine similarity and show the 64 best matches.
query_index = 818
test_im = np.array(test_dataset[query_index][0]).T  # same transpose as the embedded batches
with torch.no_grad():  # inference only; skip autograd bookkeeping
    test_embed = model(torch.from_numpy(test_im[np.newaxis]).to("cuda")).cpu()
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
s = cos(test_embed.reshape(1, -1), torch.from_numpy(embeddings)).numpy()
inds = np.argsort(s)[::-1]  # dataset indices sorted by decreasing similarity

fig, axes = plt.subplots(8, 8, figsize=(20, 20))
for ax, neighbour in zip(axes.flat, inds):
    # Clip to [0, 1] so imshow receives valid RGB floats.
    ax.imshow(np.clip(test_dataset[neighbour][0], 0, 1))
    ax.axis("off")
fig.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
from astropy.io import fits

# Read a standalone FITS image, trim its borders and convert it with the same
# ToRGB transform class that the dataset pipeline uses.
dr2_rgb = ToRGB()

d = fits.getdata("test_image.fits")
border = 56  # pixels removed from each edge of the last two axes
# NOTE(review): presumably this reproduces the 144-pixel centre crop used for
# the dataset (for a 256-pixel cutout) -- confirm against the FITS shape.
inner = slice(border, -border)
test_im = dr2_rgb(d[:, inner, inner])
imshow(test_im)

In [None]:
# Nearest-neighbour lookup for the FITS image held in `test_im`: embed it, rank
# the precomputed `embeddings` by cosine similarity and show the 64 best matches.
query = test_im.T[np.newaxis].astype("float32")  # add a batch axis and cast before the forward pass
with torch.no_grad():  # inference only; skip autograd bookkeeping
    test_embed = model(torch.from_numpy(query).to("cuda")).cpu()
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
s = cos(test_embed.reshape(1, -1), torch.from_numpy(embeddings)).numpy()
inds = np.argsort(s)[::-1]  # dataset indices sorted by decreasing similarity

fig, axes = plt.subplots(8, 8, figsize=(20, 20))
for ax, neighbour in zip(axes.flat, inds):
    # Clip to [0, 1] like the other neighbour grids; imshow would otherwise
    # clip out-of-range RGB floats itself and emit a warning per panel.
    ax.imshow(np.clip(test_dataset[neighbour][0], 0, 1))
    ax.axis("off")
fig.subplots_adjust(wspace=0.01, hspace=0.01)