In [67]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from tqdm import tqdm
from matplotlib import cm
import seaborn as sns
import matplotlib.lines as mlines
from sklearn.decomposition import PCA
from openTSNE import TSNE
from PIL import Image
import umap
import torch.nn.functional as F
from scipy.spatial.distance import cdist
import umap.plot
import math

from core.final.dataset import PSMDataset
from core.final.model import GalSpecNet, MetaModel, Informer, AstroModel
from core.final.trainer import Trainer

In [2]:
def get_model(config):
    """Build the model class selected by ``config['mode']``.

    'photo' -> Informer, 'spectra' -> GalSpecNet, 'meta' -> MetaModel;
    any other mode falls back to AstroModel. The full ``config`` is passed
    to the chosen constructor.
    """
    mode = config['mode']
    if mode == 'photo':
        return Informer(config)
    if mode == 'spectra':
        return GalSpecNet(config)
    if mode == 'meta':
        return MetaModel(config)
    return AstroModel(config)

def get_embs(dataloader, model=None, device=None):
    """Run ``model.get_embeddings`` over every batch and collect the results on CPU.

    Args:
        dataloader: iterable yielding ``(photometry, photometry_mask, spectra,
            metadata, labels)`` batches.
        model: object exposing ``get_embeddings(photometry, photometry_mask,
            spectra, metadata)`` that returns ``(p_emb, s_emb, m_emb)``.
            Defaults to the notebook-global ``model`` (previous behaviour).
        device: device the four input tensors are moved to before the forward
            pass. Defaults to the notebook-global ``device``.

    Returns:
        ``(p_emb, s_emb, m_emb, labels)``: the per-batch embeddings stacked
        row-wise with ``torch.vstack`` and the labels concatenated with
        ``torch.hstack``; embeddings are moved to CPU.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Explicit arguments remove the hidden dependency on notebook state;
    # falling back to the globals keeps existing calls working.
    if model is None:
        model = globals()['model']
    if device is None:
        device = globals()['device']

    all_p_emb, all_s_emb, all_m_emb, all_labels = [], [], [], []

    # Inference only: no autograd graph is needed for any batch.
    with torch.no_grad():
        for photometry, photometry_mask, spectra, metadata, labels in tqdm(dataloader):
            photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
            spectra, metadata = spectra.to(device), metadata.to(device)

            p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)

            # Move each batch to CPU immediately so accelerator memory does not
            # accumulate results across the whole dataset.
            all_p_emb.append(p_emb.cpu())
            all_s_emb.append(s_emb.cpu())
            all_m_emb.append(m_emb.cpu())
            all_labels.append(labels)

    if not all_labels:
        raise ValueError('get_embs: dataloader yielded no batches')

    return (
        torch.vstack(all_p_emb),
        torch.vstack(all_s_emb),
        torch.vstack(all_m_emb),
        torch.hstack(all_labels),
    )

def get_centers(p_emb, s_emb, m_emb, train_labels):
    p_centers, s_centers, m_centers, all_centers = [], [], [], []

    for i in range(10):
        ind = train_labels == i
    
        p_center = p_emb[ind].mean(axis=0)
        p_center = p_center / p_center.norm()
        p_centers.append(p_center)
    
        s_center = s_emb[ind].mean(axis=0)
        s_center = s_center / s_center.norm()
        s_centers.append(s_center)
    
        m_center = m_emb[ind].mean(axis=0)
        m_center = m_center / m_center.norm()
        m_centers.append(m_center)

        all_emb = (p_emb + s_emb + m_emb) / 3
        all_center = all_emb[ind].mean(axis=0)
        all_center = all_center / all_center.norm()
        all_centers.append(all_center)
    
    p_centers = torch.stack(p_centers)
    s_centers = torch.stack(s_centers)
    m_centers = torch.stack(m_centers)
    all_centers = torch.stack(all_centers)

    return p_centers, s_centers, m_centers, all_centers

In [3]:
run_id = 'MeriDK/AstroCLIPResults3/2wz4ysvn'
api = wandb.Api()
run = api.run(run_id)
config = run.config
# Presumably turns off W&B logging in code that reads this config (analysis only) — TODO confirm.
config['use_wandb'] = False

model = get_model(config)
# Prefer the second GPU when there is one, otherwise the first GPU, otherwise CPU.
# (Requesting 'cuda:1' on a single-GPU machine fails when the model is moved.)
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = model.to(device)
model.eval()

weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], 'weights-best.pth')
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`,
# so the checkpoint also loads on CPU-only machines.
# NOTE(review): weights_only=False unpickles arbitrary objects — only load trusted
# checkpoints; switch to weights_only=True if the file holds a plain state_dict.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=False))

In [49]:
# shuffle=False keeps loader order equal to dataset order, so row k of each
# embedding tensor below belongs to dataset item k (later cells index
# test_dataset with positions taken from the test embeddings).
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=config['batch_size'])

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=config['batch_size'])

# Per-modality embeddings for both splits; in this notebook the train
# embeddings only feed the class centers.
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

# Unit-norm class centers per modality (and for the averaged embedding).
p_centers, s_centers, m_centers, all_centers = get_centers(
    train_p_emb, train_s_emb, train_m_emb, train_labels
)

# Photometry

In [144]:
# For each test object: 1 - (photometry embedding · center of its own true class).
distances_p = torch.zeros(test_labels.shape)

# Iterate over every computed center rather than a hard-coded 10 classes, so
# objects of classes >= 10 are not silently left at distance 0.
for i in range(len(p_centers)):
    ind = test_labels == i
    distances_p[ind] = 1 - test_p_emb[ind] @ p_centers[i]

In [145]:
fig, ax = plt.subplots()
ax.hist(distances_p)
ax.set(title='Photometry: distance to own-class center',
       xlabel='1 - dot(embedding, class center)', ylabel='Number of test objects')
plt.show()

In [151]:
# The 20 test objects farthest from their own class center (topk defaults to largest=True).
topk_values_p, topk_indices_p = distances_p.topk(20)
# Nearest center (highest dot product) for each of those objects.
pred_labels_p = (test_p_emb[topk_indices_p] @ p_centers.T).argmax(dim=1)
true_labels_p = test_labels[topk_indices_p]

(topk_values_p, topk_indices_p, pred_labels_p, true_labels_p)

# TODO: add period

In [165]:
# Presumably makes PSMDataset return phase-folded light curves — TODO confirm.
# Note this flag persists on test_dataset for every later cell.
test_dataset.phased = True

num_cols = 3
# At least one row: plt.subplots rejects a zero-row grid.
num_rows = max(1, math.ceil(len(topk_indices_p) / num_cols))

# squeeze=False always yields a 2-D array of Axes, so flatten() works for any grid shape.
fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows),
                        constrained_layout=True, squeeze=False)
axs = axs.flatten()

for i, idx in enumerate(topk_indices_p):
    photometry, photometry_mask, spectra, metadata, label = test_dataset[idx.item()]

    # Column 1 of the photometry array plotted against column 0.
    axs[i].plot(photometry[:, 0], photometry[:, 1], '.')

    true_name = test_dataset.id2target[true_labels_p[i].item()]
    pred_name = test_dataset.id2target[pred_labels_p[i].item()]
    axs[i].set_title(f'Label: {true_name}\nPred: {pred_name}', fontsize=10)

# Drop the unused trailing axes. Slicing by the number of plotted objects avoids
# relying on the loop variable `i`, which is stale (leaked from an earlier cell)
# when there are no objects to plot.
for ax in axs[len(topk_indices_p):]:
    fig.delaxes(ax)

plt.show()

# Spectra

In [74]:
# For each test object: 1 - (spectra embedding · center of its own true class).
distances_s = torch.zeros(test_labels.shape)

# Iterate over every computed center rather than a hard-coded 10 classes, so
# objects of classes >= 10 are not silently left at distance 0.
for i in range(len(s_centers)):
    ind = test_labels == i
    distances_s[ind] = 1 - test_s_emb[ind] @ s_centers[i]

In [75]:
fig, ax = plt.subplots()
ax.hist(distances_s)
ax.set(title='Spectra: distance to own-class center',
       xlabel='1 - dot(embedding, class center)', ylabel='Number of test objects')
plt.show()

In [77]:
# The 20 test objects whose spectra embedding is farthest from their own class center.
topk_values_s, topk_indices_s = distances_s.topk(20)
# Nearest spectra center for each of those objects.
pred_labels_s = (test_s_emb[topk_indices_s] @ s_centers.T).argmax(dim=1)
true_labels_s = test_labels[topk_indices_s]

(topk_values_s, topk_indices_s, pred_labels_s, true_labels_s)

In [86]:
num_cols = 3
# Size the grid from topk_indices_s (the list plotted below), not topk_indices_p:
# the two lists differ in length whenever their k values differ.
# At least one row: plt.subplots rejects a zero-row grid.
num_rows = max(1, math.ceil(len(topk_indices_s) / num_cols))

# squeeze=False always yields a 2-D array of Axes, so flatten() works for any grid shape.
fig, axs = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows),
                        constrained_layout=True, squeeze=False)
axs = axs.flatten()

for i, idx in enumerate(topk_indices_s):
    photometry, photometry_mask, spectra, metadata, label = test_dataset[idx.item()]

    # Row 0 of this object's spectra array.
    axs[i].plot(spectra[0, :])

    true_name = test_dataset.id2target[true_labels_s[i].item()]
    pred_name = test_dataset.id2target[pred_labels_s[i].item()]
    axs[i].set_title(f'Label: {true_name}\nPred: {pred_name}', fontsize=10)

# Drop the unused trailing axes without relying on the (possibly stale) loop variable `i`.
for ax in axs[len(topk_indices_s):]:
    fig.delaxes(ax)

plt.show()

# Metadata

In [87]:
# For each test object: 1 - (metadata embedding · center of its own true class).
distances_m = torch.zeros(test_labels.shape)

# Iterate over every computed center rather than a hard-coded 10 classes, so
# objects of classes >= 10 are not silently left at distance 0.
for i in range(len(m_centers)):
    ind = test_labels == i
    distances_m[ind] = 1 - test_m_emb[ind] @ m_centers[i]

fig, ax = plt.subplots()
ax.hist(distances_m)
ax.set(title='Metadata: distance to own-class center',
       xlabel='1 - dot(embedding, class center)', ylabel='Number of test objects')
plt.show()

In [88]:
# The 20 test objects whose metadata embedding is farthest from their own class center.
topk_values_m, topk_indices_m = distances_m.topk(20)
# Nearest metadata center for each of those objects.
pred_labels_m = (test_m_emb[topk_indices_m] @ m_centers.T).argmax(dim=1)
true_labels_m = test_labels[topk_indices_m]

(topk_values_m, topk_indices_m, pred_labels_m, true_labels_m)

In [89]:
# Side-by-side view of the top-k outlier indices from each modality.
(topk_indices_p, topk_indices_s, topk_indices_m)

In [100]:
# Outliers shared by photometry and spectra (boolean masking keeps the photometry top-k order).
ps = topk_indices_p[torch.isin(elements=topk_indices_p, test_elements=topk_indices_s)]
ps

In [101]:
# Outliers shared by spectra and metadata (in spectra top-k order).
sm = topk_indices_s[torch.isin(elements=topk_indices_s, test_elements=topk_indices_m)]
sm

In [102]:
# Outliers shared by metadata and photometry (in metadata top-k order).
mp = topk_indices_m[torch.isin(elements=topk_indices_m, test_elements=topk_indices_p)]
mp

In [106]:
# ps ∩ sm: objects that are top-k outliers in photometry, spectra AND metadata.
w = ps[torch.isin(elements=ps, test_elements=sm)]
w

In [108]:
# Nearest-center prediction in each modality for the objects that are outliers
# in all three. Distinct names keep this cell from overwriting the top-k
# `pred_labels_*` that the plotting cells index alongside `topk_indices_*`.
w_pred_labels_p = (test_p_emb[w] @ p_centers.T).argmax(dim=1)
w_pred_labels_s = (test_s_emb[w] @ s_centers.T).argmax(dim=1)
w_pred_labels_m = (test_m_emb[w] @ m_centers.T).argmax(dim=1)
w_true_labels = test_labels[w]

w_pred_labels_p, w_pred_labels_s, w_pred_labels_m, w_true_labels

# Plot misclassifications on UMAP

In [116]:
id2target = test_dataset.id2target
target2id = test_dataset.target2id

# value_counts() sorts by descending count, so iterating sorted_classes
# visits the most frequent class first.
class_freq = test_dataset.df['target'].value_counts()
sorted_classes = class_freq.index

# One tab20 color per class id, shared by all scatter plots below.
palette = sns.color_palette("tab20", len(id2target))
marker_size = 12

In [136]:
def plot_one_embs_outliers(embeddings, targets, outliers=None):
    """Scatter a 2-D embedding colored by class, optionally ringing outliers in black.

    Relies on the notebook globals ``sorted_classes``, ``target2id``,
    ``palette`` and ``marker_size``.

    Args:
        embeddings: array indexed as ``embeddings[rows, col]``; columns 0 and 1
            are plotted (e.g. a 2-component UMAP projection).
        targets: class ids aligned with the rows of ``embeddings``, compared
            against the values of ``target2id``.
        outliers: optional row indices drawn as unfilled black markers and
            listed in the legend as 'Outliers'.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    for class_name in sorted_classes:
        class_id = target2id[class_name]
        class_mask = targets == class_id
        ax.scatter(embeddings[class_mask, 0], embeddings[class_mask, 1],
                   color=palette[class_id], label=class_name, alpha=0.7, s=marker_size)

    handles = [mlines.Line2D([], [], color=palette[target2id[class_name]], marker='o', linestyle='None',
                             markersize=8, label=class_name) for class_name in sorted_classes]

    if outliers is not None:
        # Drawn after the class points so the rings sit on top.
        ax.scatter(embeddings[outliers, 0], embeddings[outliers, 1],
                   facecolor='none', edgecolor='black', s=marker_size, label='Outliers')
        # The legend is built from explicit handles, so the outlier marker needs
        # its own handle or it never shows up in the legend.
        handles.append(mlines.Line2D([], [], marker='o', linestyle='None', markersize=8,
                                     markerfacecolor='none', markeredgecolor='black',
                                     label='Outliers'))

    ax.legend(handles=handles, loc='upper right', bbox_to_anchor=(1.15, 1), fontsize=10, title="Classes")
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    plt.show()

In [138]:
# Fixed random_state, single-threaded, for a reproducible 2-D projection.
umap_model = umap.UMAP(
    n_components=2,
    n_neighbors=22,
    metric='cosine',
    random_state=42,
    n_jobs=1,
)
p_emb_umap = umap_model.fit_transform(test_p_emb)
plot_one_embs_outliers(p_emb_umap, test_labels, outliers=topk_indices_p)

In [139]:
# Same UMAP settings as the photometry projection, applied to spectra embeddings.
umap_model = umap.UMAP(
    n_components=2,
    n_neighbors=22,
    metric='cosine',
    random_state=42,
    n_jobs=1,
)
s_emb_umap = umap_model.fit_transform(test_s_emb)
plot_one_embs_outliers(s_emb_umap, test_labels, outliers=topk_indices_s)

In [140]:
# Same UMAP settings as the photometry projection, applied to metadata embeddings.
umap_model = umap.UMAP(
    n_components=2,
    n_neighbors=22,
    metric='cosine',
    random_state=42,
    n_jobs=1,
)
m_emb_umap = umap_model.fit_transform(test_m_emb)
plot_one_embs_outliers(m_emb_umap, test_labels, outliers=topk_indices_m)

# Photometry outliers with k = 50

In [153]:
# Widen the photometry outlier list to the 50 objects farthest from their own class center.
topk_values_p, topk_indices_p = distances_p.topk(50)
pred_labels_p = (test_p_emb[topk_indices_p] @ p_centers.T).argmax(dim=1)
true_labels_p = test_labels[topk_indices_p]

(topk_values_p, topk_indices_p, pred_labels_p, true_labels_p)

In [156]:
# Outliers whose nearest center is nevertheless their true class.
topk_indices_p[torch.eq(pred_labels_p, true_labels_p)]

In [163]:
# Similarity of selected test objects to every photometry center.
# NOTE(review): these indices look hand-copied from an earlier cell's output; they
# only stay meaningful while the test split and its ordering are unchanged.
tuple(test_p_emb[j] @ p_centers.T for j in (530, 713, 962, 426))

In [164]:
# True class ids of the outliers that are still assigned to their own class.
test_labels[topk_indices_p[torch.eq(pred_labels_p, true_labels_p)]]