In [1]:
import sys
sys.path.insert(0, '..')

import torch
import os
import wandb
import random
import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from ignite.handlers.param_scheduler import create_lr_scheduler_with_warmup
from torch.utils.data import DataLoader
from datetime import datetime
import matplotlib.pyplot as plt
import pandas as pd
from scipy import stats

from core.final.dataset import PSMDataset
from core.final.model import GalSpecNet, MetaModel, Informer, AstroModel
from core.final.trainer import Trainer

In [9]:
# Switch matplotlib out of interactive mode.
# NOTE(review): presumably so figures created during evaluation are not shown
# automatically in the notebook — confirm against Trainer.evaluate.
plt.ioff()

In [10]:
def get_model(config):
    """Instantiate the model class selected by ``config['mode']``.

    ``'photo'`` maps to ``Informer``, ``'spectra'`` to ``GalSpecNet`` and
    ``'meta'`` to ``MetaModel``; every other mode value yields ``AstroModel``.
    The chosen class is constructed with ``config`` as its only argument.
    """
    single_modality_models = {
        'photo': Informer,
        'spectra': GalSpecNet,
        'meta': MetaModel,
    }
    # Any mode without a dedicated entry falls back to AstroModel.
    model_cls = single_modality_models.get(config['mode'], AstroModel)
    return model_cls(config)

In [11]:
def calc_results(run_id, last_epoch):
    """Evaluate a stored checkpoint of a wandb run on the test split.

    The run's config is fetched through the wandb API, the model is rebuilt
    from it with ``get_model`` and the weights are read from
    ``<config['weights_path']>-<last segment of run_id>/weights-<last_epoch>.pth``.

    Args:
        run_id: wandb run path; the part after the final ``/`` is used as the
            suffix of the weights directory.
        last_epoch: Tag interpolated into the weights file name (the callers in
            this notebook pass ``'best'``).

    Returns:
        Tuple ``(acc, acc_percent)``. ``acc`` is trace / total of the matrix
        returned by ``trainer.evaluate``; ``acc_percent`` is the same ratio
        computed after each row has been normalised to sum to 100, so every
        row contributes equally regardless of how many samples it holds.
    """
    api = wandb.Api()
    run = api.run(run_id)
    config = run.config

    # Evaluation only: disable wandb logging and always read this dataset file.
    config['use_wandb'] = False
    config['file'] = 'preprocessed_data/full_lb/spectra_and_v'

    test_dataset = PSMDataset(config, split='test')
    test_dataloader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = get_model(config)
    model = model.to(device)

    weights_path = os.path.join(config['weights_path'] + '-' + run_id.split('/')[-1], f'weights-{last_epoch}.pth')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine;
    # without it the CPU fallback chosen above fails while deserialising.
    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location=device, weights_only=False)
    model.load_state_dict(state_dict)

    # No optimiser, schedulers or criterion are passed: the trainer is only used for evaluation here.
    trainer = Trainer(model=model, optimizer=None, scheduler=None, warmup_scheduler=None, criterion=None, device=device, config=config)
    conf_matrix = trainer.evaluate(test_dataloader, test_dataset.id2target)
    # Row-normalise to percentages (each row sums to 100).
    conf_matrix_percent = 100 * conf_matrix / conf_matrix.sum(axis=1)[:, np.newaxis]

    acc = np.trace(conf_matrix) / np.sum(conf_matrix)
    acc_percent = np.trace(conf_matrix_percent) / np.sum(conf_matrix_percent)

    return acc, acc_percent

In [12]:
# Run ids grouped by model mode, five per mode. The evaluation loops in this
# notebook append each id to 'MeriDK/AstroCLIPResults3/' and pass it to
# calc_results when the subset key is 'sub50'.
# NOTE(review): presumably these runs were trained on a 50% subset of the
# training data, and the '*clip' modes on CLIP-pretrained variants — confirm.
sub50 = {
    'photo': ['uwi549ko', 'lasio95e', 'ssjgsf4o', '8pjmji36', 'ey93jptd'],
    'spectra': ['9w2iy7if', 'uxdgxzzg', '20jf5fc3', 'apy0hwze', 'l8r1v87z'],
    'meta': ['dnql8zxb', 'w4z8mc7g', '872t6r9u', 'pdk8lhmr', 'tiorlnpa'],
    'all': ['awnlvghm', 'u7lhipcv', 'ibj38qtl', 'v268pfls', '10k7n8f8'],
    'photoclip': ['6cfb67n2', 'u2xrn2es', '7ds9f4j0', 'yoel2uju', 'klhhmrgz'],
    'spectraclip': ['jft4o7ll', '2ys1pu74', '0u83wkoo', 'ubyfjbl6', '6mhe95eb'],
    'metaclip': ['6afyvxnq', 'f4lt4n3m', '1hlifayz', '30m8j9kv', 'a2pps0vg'],
    'allclip': ['52fk0181', 'f4hvk7s9', 'xm2gkm0g', 'sf3f1fn7', '80s6j269']
}

# Run ids grouped by model mode, five per mode. The evaluation loops in this
# notebook append each id to 'MeriDK/AstroCLIPResults3/' and pass it to
# calc_results when the subset key is 'sub25'.
# NOTE(review): presumably these runs were trained on a 25% subset of the
# training data — confirm.
sub25 = {
    'photo': ['i16e091u', 'yq6wrwn8', 'khf2rzs1', '32po6rjv', '3oz9g756'],
    'spectra': ['85x94bpe', 'k4pywvjv', 'dy7uliaw', 's66ifk69', 'rn3x84v9'],
    'meta': ['5n21zp2m', 'e2jkak5c', '02zpe2gy', 'dado4tth', 'etn9tbcy'],
    'all': ['wit0n3w3', 'w3d0xbe2', '96fwt2gx', 'du5zsfz6', '0k73x5qd'],
    'photoclip': ['1xgyaljv', '35idhf4w', 'o8n5dti7', 'fhpmo60y', '4at8i8ji'],
    'spectraclip': ['x1ec0yff', 'd5c3rhwz', 'kggejkom', '52xgy4wt', 'r5ctmuri'],
    'metaclip': ['osxq73a0', 'p2ozjutp', 'sce7zk10', 'bw5lzst5', '8hhqrs6d'],
    'allclip': ['a5peiq96', '05o27w72', 'xnbiygd1', '71qz3y5q', 'j8bnyk18']
}

# Run ids grouped by model mode, five per mode. The evaluation loops in this
# notebook append each id to 'MeriDK/AstroCLIPResults3/' and pass it to
# calc_results when the subset key is 'sub10'.
# NOTE(review): presumably these runs were trained on a 10% subset of the
# training data — confirm.
sub10 = {
    'photo': ['nq1bhy0w', 'ogd2l28v', 'yagx8pk7', 'eh9qmriw', 'mk9ckjb4'],
    'spectra': ['vvljqb9o', 'bjnix390', 'wwrv3nut', 'pagk0jo5', '2dx7j4sy'],
    'meta': ['qqjg8qk3', 'n38nru7q', 'xftkn3gw', 'dtcx3xqk', 'r5mqe2rp'],
    'all': ['8hl6cpj7', 'h01wkom5', '5ivkbdrq', '5l9h32ak', 'thf3hzty'],
    'photoclip': ['v9hs5mvj', 'vc0tj3ee', '451pr1oc', 'lit1jza0', 'xp3lxwkn'],
    'spectraclip': ['x27wg2o5', 'gbp5cbc4', '0t8y55t8', 'hu285ze4', '3xe7g7ac'],
    'metaclip': ['0sjxdzrr', 'fm8wkcdn', 'vsku583d', '68n5fsm7', 'plpa1d4x'],
    'allclip': ['h4bq7juh', '7l1h5ak1', '6apkn9rt', 'hqwxn9df', 'eeur7zbf']
}

In [13]:
# Evaluate the 'best' checkpoint of every metadata-only run and collect one
# (acc, acc_percent) tuple per run under meta_sub_acc[mode][subset].
subset_runs = {'sub50': sub50, 'sub25': sub25, 'sub10': sub10}

meta_sub_acc = {
    mode: {subset: [] for subset in subset_runs}
    for mode in ('meta', 'metaclip')
}

for mode, per_subset in meta_sub_acc.items():
    for s, results in per_subset.items():
        print(mode, s)

        # Run ids registered for this mode in the matching subset table.
        runs = subset_runs[s][mode]

        for r in runs:
            acc, acc_percent = calc_results(f'MeriDK/AstroCLIPResults3/{r}', 'best')
            results.append((acc, acc_percent))

In [14]:
# Evaluate the 'best' checkpoint of every photometry-only run and collect one
# (acc, acc_percent) tuple per run under photo_sub_acc[mode][subset].
subset_runs = {'sub50': sub50, 'sub25': sub25, 'sub10': sub10}

photo_sub_acc = {
    mode: {subset: [] for subset in subset_runs}
    for mode in ('photo', 'photoclip')
}

for mode, per_subset in photo_sub_acc.items():
    for s, results in per_subset.items():
        print(mode, s)

        # Run ids registered for this mode in the matching subset table.
        runs = subset_runs[s][mode]

        for r in runs:
            acc, acc_percent = calc_results(f'MeriDK/AstroCLIPResults3/{r}', 'best')
            results.append((acc, acc_percent))

In [21]:
# Evaluate the 'best' checkpoint of every spectra-only run and collect one
# (acc, acc_percent) tuple per run under spectra_sub_acc[mode][subset].
subset_runs = {'sub50': sub50, 'sub25': sub25, 'sub10': sub10}

spectra_sub_acc = {
    mode: {subset: [] for subset in subset_runs}
    for mode in ('spectra', 'spectraclip')
}

for mode, per_subset in spectra_sub_acc.items():
    for s, results in per_subset.items():
        print(mode, s)

        # Run ids registered for this mode in the matching subset table.
        runs = subset_runs[s][mode]

        for r in runs:
            acc, acc_percent = calc_results(f'MeriDK/AstroCLIPResults3/{r}', 'best')
            results.append((acc, acc_percent))

In [25]:
# Evaluate the 'best' checkpoint of every 'all'/'allclip' run and collect one
# (acc, acc_percent) tuple per run under all_sub_acc[mode][subset].
subset_runs = {'sub50': sub50, 'sub25': sub25, 'sub10': sub10}

all_sub_acc = {
    mode: {subset: [] for subset in subset_runs}
    for mode in ('all', 'allclip')
}

for mode, per_subset in all_sub_acc.items():
    for s, results in per_subset.items():
        print(mode, s)

        # Run ids registered for this mode in the matching subset table.
        runs = subset_runs[s][mode]

        for r in runs:
            acc, acc_percent = calc_results(f'MeriDK/AstroCLIPResults3/{r}', 'best')
            results.append((acc, acc_percent))

In [16]:
# Mean and standard deviation of the plain accuracy across the runs of each
# mode/subset combination.
for mode in meta_sub_acc:
    for s in meta_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 0
        # from all runs; indexing the list with [0] would instead select the
        # first run's whole tuple and average its two different metrics.
        accs = [run_result[0] for run_result in meta_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [17]:
# Mean and standard deviation of the row-normalised ('weighted') accuracy
# across the runs of each mode/subset combination.
for mode in meta_sub_acc:
    for s in meta_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 1
        # from all runs; indexing the list with [1] would instead select the
        # second run's whole tuple and average its two different metrics.
        accs = [run_result[1] for run_result in meta_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, 'weighted', s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [18]:
# Mean and standard deviation of the plain accuracy across the runs of each
# mode/subset combination.
for mode in photo_sub_acc:
    for s in photo_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 0
        # from all runs; indexing the list with [0] would instead select the
        # first run's whole tuple and average its two different metrics.
        accs = [run_result[0] for run_result in photo_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [20]:
# Mean and standard deviation of the row-normalised ('weighted') accuracy
# across the runs of each mode/subset combination.
for mode in photo_sub_acc:
    for s in photo_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 1
        # from all runs; indexing the list with [1] would instead select the
        # second run's whole tuple and average its two different metrics.
        accs = [run_result[1] for run_result in photo_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, 'weighted', s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [23]:
# Mean and standard deviation of the plain accuracy across the runs of each
# mode/subset combination.
for mode in spectra_sub_acc:
    for s in spectra_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 0
        # from all runs; indexing the list with [0] would instead select the
        # first run's whole tuple and average its two different metrics.
        accs = [run_result[0] for run_result in spectra_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [24]:
# Mean and standard deviation of the row-normalised ('weighted') accuracy
# across the runs of each mode/subset combination.
for mode in spectra_sub_acc:
    for s in spectra_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 1
        # from all runs; indexing the list with [1] would instead select the
        # second run's whole tuple and average its two different metrics.
        accs = [run_result[1] for run_result in spectra_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, 'weighted', s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [26]:
# Mean and standard deviation of the plain accuracy across the runs of each
# mode/subset combination.
for mode in all_sub_acc:
    for s in all_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 0
        # from all runs; indexing the list with [0] would instead select the
        # first run's whole tuple and average its two different metrics.
        accs = [run_result[0] for run_result in all_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))

In [27]:
# Mean and standard deviation of the row-normalised ('weighted') accuracy
# across the runs of each mode/subset combination.
for mode in all_sub_acc:
    for s in all_sub_acc[mode]:
        # Every entry is one run's (acc, acc_percent) tuple. Collect element 1
        # from all runs; indexing the list with [1] would instead select the
        # second run's whole tuple and average its two different metrics.
        accs = [run_result[1] for run_result in all_sub_acc[mode][s]]
        mean = np.mean(accs)
        std = np.std(accs)  # population std (numpy default ddof=0)
        # 'min'/'max' are mean -/+ one standard deviation, not the extreme runs.
        print(mode, 'weighted', s, 'mean', round(mean * 100, 3), 'std', round(std * 100, 3), 'min', round((mean - std) * 100, 3), 'max', round((mean + std) * 100, 3))