In [1]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from tqdm import tqdm
from matplotlib import cm
import seaborn as sns
import matplotlib.lines as mlines
from sklearn.decomposition import PCA
from openTSNE import TSNE
from PIL import Image
import umap
import torch.nn.functional as F
from scipy.spatial.distance import cdist
import umap.plot

from core.final.dataset import PSMDataset
from core.final.model import GalSpecNet, MetaModel, Informer, AstroModel
from core.final.trainer import Trainer

In [2]:
def get_model(config):
    """Build the encoder selected by ``config['mode']``.

    'photo' -> Informer, 'spectra' -> GalSpecNet, 'meta' -> MetaModel;
    any other mode falls back to the multimodal AstroModel.
    """
    builders_by_mode = {
        'photo': Informer,
        'spectra': GalSpecNet,
        'meta': MetaModel,
    }
    builder = builders_by_mode.get(config['mode'], AstroModel)
    return builder(config)

In [3]:
def get_embs(dataloader, model=None, device=None):
    """Embed every batch of ``dataloader`` and collect per-modality embeddings.

    Args:
        dataloader: iterable yielding ``(photometry, photometry_mask, spectra,
            metadata, labels)`` batches.
        model: object exposing ``get_embeddings(photometry, photometry_mask,
            spectra, metadata) -> (p_emb, s_emb, m_emb)``. Defaults to the
            notebook-global ``model``.
        device: device the inputs are moved to. Defaults to the notebook-global
            ``device``.

    Returns:
        ``(all_p_emb, all_s_emb, all_m_emb, all_labels)``: embeddings stacked
        row-wise on CPU (``torch.vstack``) and labels concatenated
        (``torch.hstack``). Rows follow dataloader order and stay aligned with
        the labels even when the dataloader shuffles.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Fall back to the notebook globals so existing ``get_embs(loader)`` calls
    # behave exactly as before, while allowing explicit injection.
    if model is None:
        model = globals()['model']
    if device is None:
        device = globals()['device']

    all_p_emb, all_s_emb, all_m_emb = [], [], []
    all_labels = []

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        for photometry, photometry_mask, spectra, metadata, labels in tqdm(dataloader):
            photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
            spectra, metadata = spectra.to(device), metadata.to(device)

            p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)

            # Accumulate on CPU so the collected embeddings do not hold device memory.
            all_p_emb.append(p_emb.cpu())
            all_s_emb.append(s_emb.cpu())
            all_m_emb.append(m_emb.cpu())
            all_labels.append(labels)  # labels are never moved to `device`

    if not all_labels:
        raise ValueError('dataloader yielded no batches; cannot build embeddings')

    return (
        torch.vstack(all_p_emb),
        torch.vstack(all_s_emb),
        torch.vstack(all_m_emb),
        torch.hstack(all_labels),
    )

In [4]:
def get_centers(p_emb, s_emb, m_emb, train_labels, num_classes=10):
    """Compute one L2-normalised class centroid per modality and for the fused embedding.

    The fused ("all") embedding is the element-wise mean of the photometry,
    spectra and metadata embeddings, the same fusion used by
    ``get_pred_center`` and ``get_pred_closest``.

    Args:
        p_emb, s_emb, m_emb: 2-D embedding tensors, one row per sample; rows must
            align with ``train_labels`` (presumably ``(n_samples, dim)``).
        train_labels: 1-D tensor of integer class ids.
        num_classes: centroids are produced for ids ``0 .. num_classes - 1``,
            in that order. Defaults to 10.

    Returns:
        ``(p_centers, s_centers, m_centers, all_centers)``, each stacked to
        ``(num_classes, dim)`` with unit-norm rows.

    Raises:
        ValueError: if a class id in ``range(num_classes)`` has no samples. Its
            mean would be NaN, which could silently corrupt argmax-based
            predictions downstream.
    """
    # Loop-invariant: fuse the modalities once rather than once per class.
    all_emb = (p_emb + s_emb + m_emb) / 3

    def _unit_mean(emb, mask):
        # Mean of the selected rows, rescaled to unit L2 norm.
        center = emb[mask].mean(dim=0)
        return center / center.norm()

    p_centers, s_centers, m_centers, all_centers = [], [], [], []
    for cls in range(num_classes):
        mask = train_labels == cls
        if not mask.any():
            raise ValueError(f'class {cls} has no samples; cannot compute its center')

        p_centers.append(_unit_mean(p_emb, mask))
        s_centers.append(_unit_mean(s_emb, mask))
        m_centers.append(_unit_mean(m_emb, mask))
        all_centers.append(_unit_mean(all_emb, mask))

    return (
        torch.stack(p_centers),
        torch.stack(s_centers),
        torch.stack(m_centers),
        torch.stack(all_centers),
    )

In [5]:
def get_pred_center(p_emb, s_emb, m_emb, p_centers, s_centers, m_centers, all_centers):
    """Nearest-centroid prediction for every sample, per modality and fused.

    Each sample is assigned to the class whose center has the largest dot
    product with its embedding. With unit-norm centers (as produced by
    ``get_centers``), this equals the nearest center by cosine similarity.

    Args:
        p_emb, s_emb, m_emb: 2-D embedding tensors, one row per sample.
        p_centers, s_centers, m_centers: per-modality centers, one row per class.
        all_centers: centers for the fused embedding ``(p + s + m) / 3``.

    Returns:
        ``(p_pred, s_pred, m_pred, all_pred)``: 1-D int64 tensors of class
        indices, one per sample (empty tensors for empty input).
    """
    all_emb = (p_emb + s_emb + m_emb) / 3

    # Vectorised: one matmul per modality instead of a Python loop over rows.
    # torch.argmax returns the first index on ties, matching the per-row version.
    p_pred = torch.argmax(p_emb @ p_centers.T, dim=1)
    s_pred = torch.argmax(s_emb @ s_centers.T, dim=1)
    m_pred = torch.argmax(m_emb @ m_centers.T, dim=1)
    all_pred = torch.argmax(all_emb @ all_centers.T, dim=1)

    return p_pred, s_pred, m_pred, all_pred

In [6]:
def get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels):
    """1-nearest-neighbour prediction (by dot product) against the training set.

    Each test sample receives the label of the training sample with the highest
    dot-product similarity. This is done separately per modality and for the
    fused embedding ``(p + s + m) / 3``.

    Returns:
        ``(p_pred, s_pred, m_pred, all_pred)``: labels taken from ``train_labels``.
    """
    def nearest_label(queries, references):
        best_match = torch.argmax(queries @ references.T, dim=1)
        return train_labels[best_match]

    fused_train = (train_p_emb + train_s_emb + train_m_emb) / 3
    fused_test = (test_p_emb + test_s_emb + test_m_emb) / 3

    return (
        nearest_label(test_p_emb, train_p_emb),
        nearest_label(test_s_emb, train_s_emb),
        nearest_label(test_m_emb, train_m_emb),
        nearest_label(fused_test, fused_train),
    )

In [None]:
def get_zero_shot_metrics(random_files):
    """Score centroid and nearest-neighbour zero-shot classifiers on several splits.

    For each name in ``random_files``, the notebook-global ``config`` is pointed
    at ``preprocessed_data/<name>/spectra_and_v``; this mutation persists after
    the call. The train and test splits are then embedded with ``get_embs``,
    and two classifiers are scored on the test split:

    * centroid: ``get_centers`` on train + ``get_pred_center`` on test;
    * nearest neighbour: ``get_pred_closest`` against every train embedding.

    Returns:
        ``(res_center, res_closest)``: dicts keyed ``'photometry'``,
        ``'spectra'``, ``'meta'`` and ``'all'``. Each value lists one accuracy
        per entry of ``random_files``, in the same order.
    """
    result_keys = ('photometry', 'spectra', 'meta', 'all')
    res_center = {key: [] for key in result_keys}
    res_closest = {key: [] for key in result_keys}

    def accuracies(predictions, labels):
        # One accuracy per prediction tensor, in result_keys order.
        return [(pred == labels).sum().item() / len(labels) for pred in predictions]

    for split_name in random_files:
        config['file'] = f'preprocessed_data/{split_name}/spectra_and_v'

        loaders = {}
        for split in ('train', 'test'):
            loaders[split] = DataLoader(PSMDataset(config, split=split), batch_size=config['batch_size'])

        train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(loaders['train'])
        test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(loaders['test'])

        centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
        center_accs = accuracies(get_pred_center(test_p_emb, test_s_emb, test_m_emb, *centers), test_labels)
        print(f'Center {split_name}')
        for key, acc in zip(result_keys, center_accs):
            res_center[key].append(acc)

        closest_preds = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb,
                                         test_m_emb, train_m_emb, train_labels)
        closest_accs = accuracies(closest_preds, test_labels)
        print(f'Closest {split_name}')
        for key, acc in zip(result_keys, closest_accs):
            res_closest[key].append(acc)

    return res_center, res_closest

In [8]:
def print_metrics(files, center, closest):
    """Print per-run centroid ("Center") and 1-NN ("Closest") accuracies as a tab-aligned table.

    Args:
        files: names of the evaluated splits, one per run.
        center, closest: dicts with keys 'photometry', 'spectra', 'meta' and
            'all'. Each value is a list of accuracies aligned with ``files``, as
            returned by ``get_zero_shot_metrics``.
    """
    # Iterate over every run provided. A hard-coded five raised IndexError for
    # fewer files and silently dropped any extra ones.
    for i, file_name in enumerate(files):
        print(f"{file_name}\tPhotometry\tSpectra\t\tMeta\t\tAll")
        print(f"Center:\t\t {round(center['photometry'][i], 4)}\t\t{round(center['spectra'][i], 4)}\t\t{round(center['meta'][i], 4)}\t\t{round(center['all'][i], 4)}")
        print(f"Closest:\t {round(closest['photometry'][i], 4)}\t\t{round(closest['spectra'][i], 4)}\t\t{round(closest['meta'][i], 4)}\t\t{round(closest['all'][i], 4)}\n")

In [9]:
def print_metrics_avg(files, center, closest):
    """Print mean ± population std (``np.std``) of accuracy, in percent, per result key.

    ``files`` is accepted for signature parity with ``print_metrics`` but is not used.
    Rows follow the key order of ``center``; ``closest`` must contain the same keys.
    """
    def summarize(results):
        return {
            name: (np.mean(scores) * 100, np.std(scores) * 100)
            for name, scores in results.items()
        }

    center_stats = summarize(center)
    closest_stats = summarize(closest)

    for key, (c_mean, c_std) in center_stats.items():
        print(key)
        print(f'Center:\t\t {round(c_mean, 3)}\t ± {round(c_std, 3)}')
        k_mean, k_std = closest_stats[key]
        print(f'Closest:\t {round(k_mean, 3)}\t ± {round(k_std, 3)}\n')

In [11]:
# Restore the training run's config from W&B and load its best checkpoint for evaluation.
run_id = 'MeriDK/AstroCLIPResults3/2wz4ysvn'
api = wandb.Api()
run = api.run(run_id)
config = run.config
config['use_wandb'] = False

model = get_model(config)
# Prefer the second GPU as before, but fall back to cuda:0 on single-GPU hosts
# (cuda:1 does not exist there) and to CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = model.to(device)
model.eval()

weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], 'weights-best.pth')
# map_location remaps tensors saved on a GPU onto `device`; without it, loading
# GPU-saved weights fails on CPU-only machines or a different GPU index.
# NOTE(review): weights_only=False unpickles arbitrary objects (code execution risk);
# only load trusted checkpoints. weights_only=True should suffice if this file is a plain state dict.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=False))

In [12]:
# Zero-shot evaluation on the full data and on the 50/25/10% subsets, each over
# five preprocessed splits (suffixes presumably encode the split's random seed).
files_full = [f'full_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center_full, closest_full = get_zero_shot_metrics(files_full)

files50 = [f'sub50_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center50, closest50 = get_zero_shot_metrics(files50)

files25 = [f'sub25_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center25, closest25 = get_zero_shot_metrics(files25)

files10 = [f'sub10_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center10, closest10 = get_zero_shot_metrics(files10)

In [13]:
# Mean ± std summaries for each data fraction.
print_metrics_avg(files=files_full, center=center_full, closest=closest_full)
print_metrics_avg(files=files50, center=center50, closest=closest50)
print_metrics_avg(files=files25, center=center25, closest=closest25)
print_metrics_avg(files=files10, center=center10, closest=closest10)

In [43]:
# Train/test loaders for the split currently selected by config['file'].
# NOTE(review): config['file'] is whatever an earlier cell left behind
# (get_zero_shot_metrics leaves it at the last file it evaluated); set it explicitly if that matters.
# No shuffling: sample order does not affect the accuracy metrics, and a fixed
# order keeps embedding extraction deterministic (as in get_zero_shot_metrics).
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'])

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

In [44]:
# Embed both splits with the globally loaded model.
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(dataloader=train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(dataloader=test_dataloader)

In [52]:
# Nearest-centroid accuracy on the current test split, per modality and fused.
p_centers, s_centers, m_centers, all_centers = get_centers(
    p_emb=train_p_emb, s_emb=train_s_emb, m_emb=train_m_emb, train_labels=train_labels
)
p_pred, s_pred, m_pred, all_pred = get_pred_center(
    test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers
)

print('Photometry Acc', p_pred.eq(test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', s_pred.eq(test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', m_pred.eq(test_labels).sum().item() / len(test_labels))
print('All Acc       ', all_pred.eq(test_labels).sum().item() / len(test_labels))

In [150]:
"""
Data Type     No CLIP           CLIP
Photometry    84.642 ± 6.317    91.468 ± 0.446
Spectra       76.278 ± 0.931    77.396 ± 0.614
Metadata      85.623 ± 0.628    85.855 ± 0.856
All           94.065 ± 0.390    94.153 ± 0.577
"""

In [51]:
# 1-nearest-neighbour accuracy on the current test split, per modality and fused.
p_pred, s_pred, m_pred, all_pred = get_pred_closest(
    test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels
)

print('Photometry Acc', p_pred.eq(test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', s_pred.eq(test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', m_pred.eq(test_labels).sum().item() / len(test_labels))
print('All Acc       ', all_pred.eq(test_labels).sum().item() / len(test_labels))

In [19]:
# Single-split check on the 50% subset. Here "All" is a majority vote
# (torch.mode) over the three per-modality predictions, not the
# averaged-embedding "all" used by get_zero_shot_metrics.
config['file'] = 'preprocessed_data/sub50_lb/spectra_and_v'

# Evaluation order does not affect accuracy, so no shuffling.
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'])

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

# get_centers needs the labels, and get_centers / get_pred_center / get_pred_closest
# each return four outputs. The old three-argument / three-value calls raised
# TypeError / ValueError.
p_centers, s_centers, m_centers, all_centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
p_pred, s_pred, m_pred, _ = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Center 50% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred, _ = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Closest 50% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

In [20]:
# Single-split check on the 25% subset. Here "All" is a majority vote
# (torch.mode) over the three per-modality predictions, not the
# averaged-embedding "all" used by get_zero_shot_metrics.
config['file'] = 'preprocessed_data/sub25_lb/spectra_and_v'

# Evaluation order does not affect accuracy, so no shuffling.
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'])

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

# get_centers needs the labels, and get_centers / get_pred_center / get_pred_closest
# each return four outputs. The old three-argument / three-value calls raised
# TypeError / ValueError.
p_centers, s_centers, m_centers, all_centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
p_pred, s_pred, m_pred, _ = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Center 25% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred, _ = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Closest 25% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

In [21]:
# Single-split check on the 10% subset. Here "All" is a majority vote
# (torch.mode) over the three per-modality predictions, not the
# averaged-embedding "all" used by get_zero_shot_metrics.
config['file'] = 'preprocessed_data/sub10_lb/spectra_and_v'

# Evaluation order does not affect accuracy, so no shuffling.
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'])

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

# get_centers needs the labels, and get_centers / get_pred_center / get_pred_closest
# each return four outputs. The old three-argument / three-value calls raised
# TypeError / ValueError.
p_centers, s_centers, m_centers, all_centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
p_pred, s_pred, m_pred, _ = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Center 10% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred, _ = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), dim=0)

print('Closest 10% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

## Results Across Different Random Seeds

In [None]:
files_full = [f'full_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center_full, closest_full = get_zero_shot_metrics(random_files=files_full)

In [None]:
files50 = [f'sub50_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center50, closest50 = get_zero_shot_metrics(random_files=files50)

In [None]:
files25 = [f'sub25_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center25, closest25 = get_zero_shot_metrics(random_files=files25)

In [None]:
files10 = [f'sub10_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center10, closest10 = get_zero_shot_metrics(random_files=files10)

In [None]:
print_metrics_avg(files=files10, center=center10, closest=closest10)

In [100]:
print_metrics_avg(files=files10, center=center10, closest=closest10)

In [101]:
# Re-run the 10% subset evaluation and summarise it.
files10 = [f'sub10_lb{suffix}' for suffix in ('', '0', '12', '123', '66')]
center10, closest10 = get_zero_shot_metrics(random_files=files10)
print_metrics_avg(files=files10, center=center10, closest=closest10)

In [72]:
print_metrics_avg(files=files25, center=center25, closest=closest25)

In [73]:
print_metrics_avg(files=files50, center=center50, closest=closest50)

In [74]:
print_metrics_avg(files=files_full, center=center_full, closest=closest_full)

# Single-Example (One-Shot) Prediction

In [76]:
# Training embeddings of the 10% subset, used to pick one-shot prototypes below.
# shuffle=False keeps the per-class sample order fixed (assuming PSMDataset loads
# in a fixed order), so the example picked by index in the next cell is
# reproducible. With shuffle=True it changed on every run.
config['file'] = 'preprocessed_data/sub10_lb/spectra_and_v'
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=False)
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)

In [105]:
# One-shot prototypes: a single training example per class, taken at position
# ONE_SHOT_INDEX among that class's samples. Prototypes are L2-normalised so
# they play the same role as the unit-norm centroids from get_centers; without
# this, prototype magnitude biased the dot-product argmax. The fused prototype
# averages the three raw modality embeddings before normalising, mirroring the
# "all" centroid in get_centers.
ONE_SHOT_INDEX = 2
NUM_CLASSES = 10


def _pick_one_shot(emb):
    """Stack the ONE_SHOT_INDEX-th example of each class 0..NUM_CLASSES-1 into one row per class."""
    return torch.vstack([emb[train_labels == cls][ONE_SHOT_INDEX] for cls in range(NUM_CLASSES)])


raw_p = _pick_one_shot(train_p_emb)
raw_s = _pick_one_shot(train_s_emb)
raw_m = _pick_one_shot(train_m_emb)

one_p_emb = F.normalize(raw_p, dim=1)
one_s_emb = F.normalize(raw_s, dim=1)
one_m_emb = F.normalize(raw_m, dim=1)
one_all_emb = F.normalize((raw_p + raw_s + raw_m) / 3, dim=1)

In [106]:
# Test split for whatever config['file'] points at (sub10_lb after the cells above).
# Evaluation order does not affect accuracy, so no shuffling (as in get_zero_shot_metrics).
test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

In [107]:
# Classify each test sample by its most similar one-shot prototype.
p_pred, s_pred, m_pred, all_pred = get_pred_center(
    p_emb=test_p_emb, s_emb=test_s_emb, m_emb=test_m_emb,
    p_centers=one_p_emb, s_centers=one_s_emb, m_centers=one_m_emb, all_centers=one_all_emb,
)

In [108]:
# Fraction of test samples whose one-shot prediction matches the label.
p_acc = p_pred.eq(test_labels).sum().item() / len(test_labels)
s_acc = s_pred.eq(test_labels).sum().item() / len(test_labels)
m_acc = m_pred.eq(test_labels).sum().item() / len(test_labels)
all_acc = all_pred.eq(test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

In [95]:
# Evaluate the one-shot prototypes (picked from sub10_lb train) on the 25% subset's test split.
config['file'] = 'preprocessed_data/sub25_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
# Evaluation order does not affect accuracy, so no shuffling (as in get_zero_shot_metrics).
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

In [96]:
# Evaluate the one-shot prototypes (picked from sub10_lb train) on the 50% subset's test split.
config['file'] = 'preprocessed_data/sub50_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
# Evaluation order does not affect accuracy, so no shuffling (as in get_zero_shot_metrics).
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

In [97]:
# Evaluate the one-shot prototypes (picked from sub10_lb train) on the full data's test split.
config['file'] = 'preprocessed_data/full_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
# Evaluation order does not affect accuracy, so no shuffling (as in get_zero_shot_metrics).
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

In [27]:
# NOTE(review): `val_m_emb` is not defined anywhere in this notebook, so this
# stale debug cell raises NameError on a fresh kernel.
# Intent: dot-product similarity of one metadata embedding to every metadata class center.
val_m_emb[0] @ m_centers.T

In [12]:

# NOTE(review): `val_m_emb` and `val_labels` are not defined anywhere in this
# notebook (stale debug cell; NameError on a fresh kernel).
# Prints the similarity of the first metadata embedding to each of the 10
# class centers, then that sample's true label.
for i in range(10):
    print(i, val_m_emb[0] @ m_centers[i].cpu())

print(val_labels[0])

In [10]:
# Display the metadata class centers currently held in the kernel (from whichever cell last set them).
m_centers