In [1]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats
from tqdm import tqdm
from matplotlib import cm
import seaborn as sns
import matplotlib.lines as mlines
from sklearn.decomposition import PCA
from openTSNE import TSNE
from PIL import Image
import umap
import torch.nn.functional as F
from scipy.spatial.distance import cdist
import umap.plot

from core.final.dataset import PSMDataset
from core.final.model import GalSpecNet, MetaModel, Informer, AstroModel
from core.final.trainer import Trainer

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [2]:
def get_model(config):
    if config['mode'] == 'photo':
        model = Informer(config)
    elif config['mode'] == 'spectra':
        model = GalSpecNet(config)
    elif config['mode'] == 'meta':
        model = MetaModel(config)
    else:
        model = AstroModel(config)

    return model

In [3]:
def get_embs(dataloader):
    all_p_emb, all_s_emb, all_m_emb = [], [], []
    all_labels = []
    
    for photometry, photometry_mask, spectra, metadata, labels in tqdm(dataloader):
        photometry, photometry_mask = photometry.to(device), photometry_mask.to(device)
        spectra, metadata = spectra.to(device), metadata.to(device)
    
        with torch.no_grad():
            p_emb, s_emb, m_emb = model.get_embeddings(photometry, photometry_mask, spectra, metadata)
    
            all_p_emb.append(p_emb.cpu())
            all_s_emb.append(s_emb.cpu())
            all_m_emb.append(m_emb.cpu())
            all_labels.append(labels)
    
    all_p_emb = torch.vstack(all_p_emb)
    all_s_emb = torch.vstack(all_s_emb)
    all_m_emb = torch.vstack(all_m_emb)
    all_labels = torch.hstack(all_labels)

    return all_p_emb, all_s_emb, all_m_emb, all_labels

In [4]:
def get_centers(p_emb, s_emb, m_emb, train_labels):
    p_centers, s_centers, m_centers, all_centers = [], [], [], []

    for i in range(10):
        ind = train_labels == i
    
        p_center = p_emb[ind].mean(axis=0)
        p_center = p_center / p_center.norm()
        p_centers.append(p_center)
    
        s_center = s_emb[ind].mean(axis=0)
        s_center = s_center / s_center.norm()
        s_centers.append(s_center)
    
        m_center = m_emb[ind].mean(axis=0)
        m_center = m_center / m_center.norm()
        m_centers.append(m_center)

        all_emb = (p_emb + s_emb + m_emb) / 3
        all_center = all_emb[ind].mean(axis=0)
        all_center = all_center / all_center.norm()
        all_centers.append(all_center)
    
    p_centers = torch.stack(p_centers)
    s_centers = torch.stack(s_centers)
    m_centers = torch.stack(m_centers)
    all_centers = torch.stack(all_centers)

    return p_centers, s_centers, m_centers, all_centers

In [5]:
def get_pred_center(p_emb, s_emb, m_emb, p_centers, s_centers, m_centers, all_centers):
    p_pred, s_pred, m_pred, all_pred = [], [], [], []
    all_emb = (p_emb + s_emb + m_emb) / 3
    
    for i in range(len(p_emb)):
        p_pred.append(torch.argmax(p_emb[i] @ p_centers.T))
        s_pred.append(torch.argmax(s_emb[i] @ s_centers.T))
        m_pred.append(torch.argmax(m_emb[i] @ m_centers.T))
        all_pred.append(torch.argmax(all_emb[i] @ all_centers.T))
    
    p_pred = torch.stack(p_pred)
    s_pred = torch.stack(s_pred)
    m_pred = torch.stack(m_pred)
    all_pred = torch.stack(all_pred)

    return p_pred, s_pred, m_pred, all_pred

In [6]:
def get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels):
    indeces = torch.argmax(test_p_emb @ train_p_emb.T, axis=1)
    p_pred = train_labels[indeces]

    indeces = torch.argmax(test_s_emb @ train_s_emb.T, axis=1)
    s_pred = train_labels[indeces]
    
    indeces = torch.argmax(test_m_emb @ train_m_emb.T, axis=1)
    m_pred = train_labels[indeces]

    train_all_emb = (train_p_emb + train_s_emb + train_m_emb) / 3
    test_all_emb = (test_p_emb + test_s_emb + test_m_emb) / 3
    indeces = torch.argmax(test_all_emb @ train_all_emb.T, axis=1)
    all_pred = train_labels[indeces]

    return p_pred, s_pred, m_pred, all_pred

In [None]:
def get_zero_shot_metrics(random_files):
    res_center = {'photometry': [], 'spectra': [], 'meta': [], 'all': []}
    res_closest = {'photometry': [], 'spectra': [], 'meta': [], 'all': []}
    
    for el in random_files:
        config['file'] = f'preprocessed_data/{el}/spectra_and_v'
        
        train_dataset = PSMDataset(config, split='train')
        train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'])
        
        test_dataset = PSMDataset(config, split='test')
        test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'])
        
        train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
        test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
        
        p_centers, s_centers, m_centers, all_centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
        p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers)
    
        p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
        s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
        m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
        all_acc = (all_pred == test_labels).sum().item() / len(test_labels)
        
        print(f'Center {el}')
        res_center['photometry'].append(p_acc)
        res_center['spectra'].append(s_acc)
        res_center['meta'].append(m_acc)
        res_center['all'].append(all_acc)
        
        p_pred, s_pred, m_pred, all_pred = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
    
        p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
        s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
        m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
        all_acc = (all_pred == test_labels).sum().item() / len(test_labels)
    
        print(f'Closest {el}')
        res_closest['photometry'].append(p_acc)
        res_closest['spectra'].append(s_acc)
        res_closest['meta'].append(m_acc)
        res_closest['all'].append(all_acc)

    return res_center, res_closest

In [8]:
def print_metrics(files, center, closest):
    for i in range(5):
        print(f"{files[i]}\tPhotometry\tSpectra\t\tMeta\t\tAll")
        print(f"Center:\t\t {round(center['photometry'][i], 4)}\t\t{round(center['spectra'][i], 4)}\t\t{round(center['meta'][i], 4)}\t\t{round(center['all'][i], 4)}")
        print(f"Closest:\t {round(closest['photometry'][i], 4)}\t\t{round(closest['spectra'][i], 4)}\t\t{round(closest['meta'][i], 4)}\t\t{round(closest['all'][i], 4)}\n")

In [9]:
def print_metrics_avg(files, center, closest):
    avg_center = {key: (np.mean(val) * 100, np.std(val) * 100) for key, val in center.items()}
    avg_closest = {key: (np.mean(val) * 100, np.std(val) * 100) for key, val in closest.items()}
    
    for key in avg_center:
        print(key)
        print(f'Center:\t\t {round(avg_center[key][0], 3)}\t ± {round(avg_center[key][1], 3)}')
        print(f'Closest:\t {round(avg_closest[key][0], 3)}\t ± {round(avg_closest[key][1], 3)}\n')

In [11]:
run_id = 'MeriDK/AstroCLIPResults3/2wz4ysvn'
api = wandb.Api()
run = api.run(run_id)
config = run.config
config['use_wandb'] = False

model = get_model(config)
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], f'weights-best.pth')
model.load_state_dict(torch.load(weights_path, weights_only=False))

<All keys matched successfully>

In [12]:
files_full = ['full_lb', 'full_lb0', 'full_lb12', 'full_lb123', 'full_lb66']
center_full, closest_full = get_zero_shot_metrics(files_full)

files50 = ['sub50_lb', 'sub50_lb0', 'sub50_lb12', 'sub50_lb123', 'sub50_lb66']
center50, closest50 = get_zero_shot_metrics(files50)

files25 = ['sub25_lb', 'sub25_lb0', 'sub25_lb12', 'sub25_lb123', 'sub25_lb66']
center25, closest25 = get_zero_shot_metrics(files25)

files10 = ['sub10_lb', 'sub10_lb0', 'sub10_lb12', 'sub10_lb123', 'sub10_lb66']
center10, closest10 = get_zero_shot_metrics(files10)

100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:17<00:00,  7.59s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:32<00:00,  6.59s/it]


Center full_lb
Closest full_lb


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:17<00:00,  7.57s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:35<00:00,  7.08s/it]


Center full_lb0
Closest full_lb0


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:20<00:00,  7.66s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00,  6.72s/it]


Center full_lb12
Closest full_lb12


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:19<00:00,  7.62s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:34<00:00,  6.81s/it]


Center full_lb123
Closest full_lb123


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:19<00:00,  7.64s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:34<00:00,  6.97s/it]


Center full_lb66
Closest full_lb66


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:06<00:00,  7.45s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:16<00:00,  5.52s/it]


Center sub50_lb
Closest sub50_lb


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:04<00:00,  7.33s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.95s/it]


Center sub50_lb0
Closest sub50_lb0


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:08<00:00,  7.54s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.69s/it]


Center sub50_lb12
Closest sub50_lb12


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:07<00:00,  7.50s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.19s/it]


Center sub50_lb123
Closest sub50_lb123


100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:07<00:00,  7.50s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:16<00:00,  5.51s/it]


Center sub50_lb66
Closest sub50_lb66


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:02<00:00,  6.90s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.90s/it]


Center sub25_lb
Closest sub25_lb


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:03<00:00,  7.05s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:10<00:00,  5.31s/it]


Center sub25_lb0
Closest sub25_lb0


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:11<00:00,  7.96s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.86s/it]


Center sub25_lb12
Closest sub25_lb12


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:13<00:00,  8.13s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.68s/it]


Center sub25_lb123
Closest sub25_lb123


100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:11<00:00,  7.96s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.74s/it]


Center sub25_lb66
Closest sub25_lb66


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:28<00:00,  7.06s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/it]


Center sub10_lb
Closest sub10_lb


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.61s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.49s/it]


Center sub10_lb0
Closest sub10_lb0


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.65s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.66s/it]


Center sub10_lb12
Closest sub10_lb12


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.55s/it]


Center sub10_lb123
Closest sub10_lb123


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.65s/it]


Center sub10_lb66
Closest sub10_lb66


In [13]:
print_metrics_avg(files_full, center_full, closest_full)
print_metrics_avg(files50, center50, closest50)
print_metrics_avg(files25, center25, closest25)
print_metrics_avg(files10, center10, closest10)

photometry
Center:		 81.86	 ± 0.865
Closest:	 87.018	 ± 0.346

spectra
Center:		 67.439	 ± 0.688
Closest:	 72.631	 ± 0.526

meta
Center:		 74.136	 ± 0.712
Closest:	 81.459	 ± 0.613

all
Center:		 80.035	 ± 0.532
Closest:	 88.791	 ± 0.254

photometry
Center:		 83.793	 ± 0.792
Closest:	 84.53	 ± 0.512

spectra
Center:		 65.595	 ± 0.677
Closest:	 67.57	 ± 1.351

meta
Center:		 74.211	 ± 0.576
Closest:	 78.445	 ± 1.561

all
Center:		 80.779	 ± 0.564
Closest:	 86.791	 ± 0.543

photometry
Center:		 83.605	 ± 1.339
Closest:	 84.508	 ± 0.937

spectra
Center:		 65.989	 ± 1.727
Closest:	 65.131	 ± 1.339

meta
Center:		 74.672	 ± 1.559
Closest:	 76.944	 ± 1.066

all
Center:		 80.399	 ± 1.563
Closest:	 87.03	 ± 0.64

photometry
Center:		 84.895	 ± 1.334
Closest:	 86.442	 ± 1.054

spectra
Center:		 64.237	 ± 1.655
Closest:	 62.604	 ± 1.388

meta
Center:		 75.702	 ± 1.895
Closest:	 76.978	 ± 0.93

all
Center:		 80.616	 ± 2.461
Closest:	 85.804	 ± 1.597



In [43]:
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

In [44]:
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:12<00:00,  7.42s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:32<00:00,  6.49s/it]


In [52]:
p_centers, s_centers, m_centers, all_centers = get_centers(train_p_emb, train_s_emb, train_m_emb, train_labels)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers, all_centers)

print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (all_pred == test_labels).sum().item() / len(test_labels))

Photometry Acc 0.8085393258426966
Spectra Acc    0.6719101123595506
Meta Acc       0.7492134831460674
All Acc        0.762247191011236


In [150]:
"""
Data Type     No CLIP           CLIP
Photometry    84.642 ± 6.317    91.468 ± 0.446
Spectra       76.278 ± 0.931    77.396 ± 0.614
Metadata      85.623 ± 0.628    85.855 ± 0.856
All           94.065 ± 0.390    94.153 ± 0.577
"""

'\nData Type     No CLIP           CLIP\nPhotometry    84.642 ± 6.317    91.468 ± 0.446\nSpectra       76.278 ± 0.931    77.396 ± 0.614\nMetadata      85.623 ± 0.628    85.855 ± 0.856\nAll           94.065 ± 0.390    94.153 ± 0.577\n'

In [51]:
p_pred, s_pred, m_pred, all_pred = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)

print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (all_pred == test_labels).sum().item() / len(test_labels))

Photometry Acc 0.8692134831460674
Spectra Acc    0.7164044943820225
Meta Acc       0.8116853932584269
All Acc        0.8831460674157303


In [19]:
config['file'] = 'preprocessed_data/sub50_lb/spectra_and_v'

train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

p_centers, s_centers, m_centers = get_centers(train_p_emb, train_s_emb, train_m_emb)
p_pred, s_pred, m_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Center 50% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Closest 50% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

100%|██████████████████████████████████████████████████████████████████████████████████| 17/17 [02:03<00:00,  7.28s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:16<00:00,  5.47s/it]


Center 50% Data
Photometry Acc 0.828082808280828
Spectra Acc    0.648064806480648
Meta Acc       0.7407740774077408
All Acc        0.7911791179117912
Closest 50% Data
Photometry Acc 0.8568856885688569
Spectra Acc    0.666966696669667
Meta Acc       0.7704770477047704
All Acc        0.8487848784878488


In [20]:
config['file'] = 'preprocessed_data/sub25_lb/spectra_and_v'

train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

p_centers, s_centers, m_centers = get_centers(train_p_emb, train_s_emb, train_m_emb)
p_pred, s_pred, m_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Center 25% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Closest 25% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [01:01<00:00,  6.83s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:07<00:00,  3.91s/it]

Center 25% Data
Photometry Acc 0.827027027027027
Spectra Acc    0.6558558558558558
Meta Acc       0.7279279279279279
All Acc        0.790990990990991
Closest 25% Data
Photometry Acc 0.8450450450450451
Spectra Acc    0.6324324324324324
Meta Acc       0.7621621621621621
All Acc        0.8324324324324325





In [21]:
config['file'] = 'preprocessed_data/sub10_lb/spectra_and_v'

train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)

test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

p_centers, s_centers, m_centers = get_centers(train_p_emb, train_s_emb, train_m_emb)
p_pred, s_pred, m_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, p_centers, s_centers, m_centers)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Center 10% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

p_pred, s_pred, m_pred = get_pred_closest(test_p_emb, train_p_emb, test_s_emb, train_s_emb, test_m_emb, train_m_emb, train_labels)
pred, _ = torch.mode(torch.stack([p_pred, s_pred, m_pred]), axis=0)

print('Closest 10% Data')
print('Photometry Acc', (p_pred == test_labels).sum().item() / len(test_labels))
print('Spectra Acc   ', (s_pred == test_labels).sum().item() / len(test_labels))
print('Meta Acc      ', (m_pred == test_labels).sum().item() / len(test_labels))
print('All Acc       ', (pred == test_labels).sum().item() / len(test_labels))

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:24<00:00,  6.10s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.95s/it]

Center 10% Data
Photometry Acc 0.85
Spectra Acc    0.6409090909090909
Meta Acc       0.7772727272727272
All Acc        0.8227272727272728
Closest 10% Data
Photometry Acc 0.8590909090909091
Spectra Acc    0.6
Meta Acc       0.7818181818181819
All Acc        0.8318181818181818





## Results Across Diff Random Seeds

In [None]:
files_full = ['full_lb', 'full_lb0', 'full_lb12', 'full_lb123', 'full_lb66']
center_full, closest_full = get_zero_shot_metrics(files_full)

100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:54<00:00,  8.67s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:38<00:00,  7.67s/it]


Center full_lb
Closest full_lb


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:56<00:00,  8.71s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:39<00:00,  7.93s/it]


Center full_lb0
Closest full_lb0


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:57<00:00,  8.74s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:38<00:00,  7.60s/it]


Center full_lb12
Closest full_lb12


100%|██████████████████████████████████████████████████████████████████████████████████| 34/34 [04:51<00:00,  8.58s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:39<00:00,  7.83s/it]


Center full_lb123
Closest full_lb123


 76%|██████████████████████████████████████████████████████████████▋                   | 26/34 [03:50<01:11,  8.88s/it]

In [None]:
files50 = ['sub50_lb', 'sub50_lb0', 'sub50_lb12', 'sub50_lb123', 'sub50_lb66']
center50, closest50 = get_zero_shot_metrics(files50)

In [None]:
files25 = ['sub25_lb', 'sub25_lb0', 'sub25_lb12', 'sub25_lb123', 'sub25_lb66']
center25, closest25 = get_zero_shot_metrics(files25)

In [None]:
files10 = ['sub10_lb', 'sub10_lb0', 'sub10_lb12', 'sub10_lb123', 'sub10_lb66']
center10, closest10 = get_zero_shot_metrics(files10)

In [None]:
print_metrics_avg(files10, center10, closest10)

In [100]:
print_metrics_avg(files10, center10, closest10)

photometry
Center:		 84.895	 ± 1.334
Closest:	 86.442	 ± 1.054

spectra
Center:		 64.237	 ± 1.655
Closest:	 62.604	 ± 1.388

meta
Center:		 75.702	 ± 1.895
Closest:	 76.978	 ± 0.93

all
Center:		 80.524	 ± 2.593
Closest:	 85.804	 ± 1.597



In [101]:
files10 = ['sub10_lb', 'sub10_lb0', 'sub10_lb12', 'sub10_lb123', 'sub10_lb66']
center10, closest10 = get_zero_shot_metrics(files10)
print_metrics_avg(files10, center10, closest10)

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:28<00:00,  7.11s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.90s/it]


Center sub10_lb
Closest sub10_lb


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.68s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.43s/it]


Center sub10_lb0
Closest sub10_lb0


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:25<00:00,  6.50s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.71s/it]


Center sub10_lb12
Closest sub10_lb12


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:25<00:00,  6.44s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.49s/it]


Center sub10_lb123
Closest sub10_lb123


100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:26<00:00,  6.59s/it]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.57s/it]


Center sub10_lb66
Closest sub10_lb66
photometry
Center:		 84.895	 ± 1.334
Closest:	 86.442	 ± 1.054

spectra
Center:		 64.237	 ± 1.655
Closest:	 62.604	 ± 1.388

meta
Center:		 75.702	 ± 1.895
Closest:	 76.978	 ± 0.93

all
Center:		 80.524	 ± 2.593
Closest:	 85.804	 ± 1.597



In [72]:
print_metrics_avg(files25, center25, closest25)

photometry
Center:		 83.605	 ± 1.339
Closest:	 84.327	 ± 1.787

spectra
Center:		 65.989	 ± 1.727
Closest:	 65.131	 ± 1.339

meta
Center:		 74.672	 ± 1.559
Closest:	 76.944	 ± 1.066

all
Center:		 80.399	 ± 1.563
Closest:	 86.957	 ± 0.728



In [73]:
print_metrics_avg(files50, center50, closest50)

photometry
Center:		 83.722	 ± 0.748
Closest:	 84.8	 ± 0.545

spectra
Center:		 65.595	 ± 0.677
Closest:	 67.57	 ± 1.351

meta
Center:		 74.211	 ± 0.576
Closest:	 78.445	 ± 1.561

all
Center:		 80.761	 ± 0.589
Closest:	 86.737	 ± 0.712



In [74]:
print_metrics_avg(files_full, center_full, closest_full)

photometry
Center:		 81.851	 ± 0.867
Closest:	 86.732	 ± 0.485

spectra
Center:		 67.439	 ± 0.688
Closest:	 72.64	 ± 0.531

meta
Center:		 74.136	 ± 0.712
Closest:	 81.459	 ± 0.613

all
Center:		 80.018	 ± 0.544
Closest:	 88.603	 ± 0.358



# Single example prediction

In [76]:
config['file'] = 'preprocessed_data/sub10_lb/spectra_and_v'
train_dataset = PSMDataset(config, split='train')
train_dataloader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
train_p_emb, train_s_emb, train_m_emb, train_labels = get_embs(train_dataloader)

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:28<00:00,  7.20s/it]


In [105]:
one_p_emb = []
one_s_emb = []
one_m_emb = []

for i in range(10):
    one_p_emb.append(train_p_emb[train_labels == i][2])
    one_s_emb.append(train_s_emb[train_labels == i][2])
    one_m_emb.append(train_m_emb[train_labels == i][2])

one_p_emb = torch.vstack(one_p_emb)
one_s_emb = torch.vstack(one_s_emb)
one_m_emb = torch.vstack(one_m_emb)
one_all_emb = (one_p_emb + one_s_emb + one_m_emb) / 3

In [106]:
test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)
test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.26s/it]


In [107]:
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

In [108]:
p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

One 	 photo 0.4954545454545455 spectra 0.4590909090909091 meta 0.41818181818181815 all 0.4863636363636364


In [95]:
config['file'] = 'preprocessed_data/sub25_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

100%|████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:09<00:00,  4.61s/it]

One 	 photo 0.5369369369369369 spectra 0.4972972972972973 meta 0.4828828828828829 all 0.5513513513513514





In [96]:
config['file'] = 'preprocessed_data/sub50_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:17<00:00,  5.97s/it]

One 	 photo 0.5427542754275427 spectra 0.5184518451845185 meta 0.49954995499549953 all 0.585958595859586





In [97]:
config['file'] = 'preprocessed_data/full_lb/spectra_and_v'
test_dataset = PSMDataset(config, split='test')
test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=True)

test_p_emb, test_s_emb, test_m_emb, test_labels = get_embs(test_dataloader)
p_pred, s_pred, m_pred, all_pred = get_pred_center(test_p_emb, test_s_emb, test_m_emb, one_p_emb, one_s_emb, one_m_emb, one_all_emb)

p_acc = (p_pred == test_labels).sum().item() / len(test_labels)
s_acc = (s_pred == test_labels).sum().item() / len(test_labels)
m_acc = (m_pred == test_labels).sum().item() / len(test_labels)
all_acc = (all_pred == test_labels).sum().item() / len(test_labels)

print(f'One \t photo {p_acc} spectra {s_acc} meta {m_acc} all {all_acc}')

100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:33<00:00,  6.79s/it]


One 	 photo 0.5177528089887641 spectra 0.529438202247191 meta 0.4732584269662921 all 0.5716853932584269


In [27]:
val_m_emb[0] @ m_centers.T

tensor([-0.1157,  0.0946,  0.2835,  0.3743,  0.2619,  0.3575, -0.0509,  0.8764,
         0.7402,  0.1351])

In [12]:

for i in range(10):
    print(i, val_m_emb[0] @ m_centers[i].cpu())

print(val_labels[0])

0 tensor(-0.1157)
1 tensor(0.0946)
2 tensor(0.2835)
3 tensor(0.3743)
4 tensor(0.2619)
5 tensor(0.3575)
6 tensor(-0.0509)
7 tensor(0.8764)
8 tensor(0.7402)
9 tensor(0.1351)
tensor(7)


In [10]:
m_centers

[tensor([ 2.1247e-02,  3.7952e-03, -1.4793e-02, -1.0881e-01, -2.7797e-03,
          2.1860e-02,  1.1879e-02, -1.1445e-02, -3.9842e-03,  2.9460e-02,
         -1.6177e-02,  1.1165e-02,  3.9058e-03, -2.7822e-02, -2.4768e-02,
         -9.7808e-03,  2.7336e-02, -3.7336e-03,  2.4106e-04, -1.9938e-02,
         -1.0377e-02,  6.5937e-03, -2.3949e-02,  4.3492e-02,  1.0484e-02,
          1.3267e-02, -8.8545e-03,  1.0652e-02,  9.4356e-03,  8.6364e-03,
         -5.6013e-02, -1.1076e-02,  2.0148e-02, -7.8863e-03, -1.9276e-04,
          3.9877e-02,  1.5867e-01,  6.6636e-03, -1.5462e-02, -1.0152e-02,
          9.9161e-03, -1.3483e-02,  8.0748e-03,  2.6081e-03,  3.0227e-03,
         -3.6729e-05,  7.0379e-03, -1.8633e-02, -9.6590e-03, -1.2463e-02,
         -2.1684e-02, -2.3899e-02,  1.8698e-01, -2.6397e-02, -1.3501e-02,
          2.0383e-03, -1.6449e-02, -9.4272e-03,  2.2986e-03, -9.4092e-03,
          4.2974e-03,  2.0360e-02,  2.7541e-02, -2.3124e-02,  3.8132e-02,
          1.5180e-02, -1.5430e-05, -6.