In [None]:
import altair as alt
import numpy as np
import pandas as pd
import torch

from collections import defaultdict
from pathlib import Path
from re import split
from tqdm import tqdm

# [1] **Line Count Heatmap**

In [None]:
# Collect the fitted line parameters from every saved emulator state into one
# long-format frame: one row per (grid point, spectral line) after `explode`.
EMULATOR_DIR = Path('../../experiments/08_blase3D_HPC_test/emulator_states')
# Output column -> state_dict key. `explode` below requires these arrays to have
# equal length within each file (presumably one entry per line — TODO confirm).
STATE_COLUMNS = {
    'center': 'pre_line_centers',
    'shift_center': 'lam_centers',
    'amp': 'amplitudes',
    'sigma': 'sigma_widths',
    'gamma': 'gamma_widths',
}

# Sort so row order (and anything indexed by position downstream) does not
# depend on the filesystem's directory-listing order; tqdm infers the total.
state_files = sorted(EMULATOR_DIR.glob('*0.0.pt'))
if not state_files:
    raise FileNotFoundError(f'no emulator states matching *0.0.pt in {EMULATOR_DIR}')

line_stats = defaultdict(list)
for file in tqdm(state_files):
    # Load straight to CPU: every tensor is converted to numpy anyway, so no
    # GPU is needed just to read the checkpoints.
    state_dict = torch.load(file, map_location='cpu')
    # Splitting the stem on T/G/Z is expected to yield [prefix, teff, logg, Z].
    tokens = split('[TGZ]', file.stem)
    if len(tokens) != 4:
        raise ValueError(f'unexpected emulator state filename: {file.name}')
    line_stats['teff'].append(int(tokens[1]))
    line_stats['logg'].append(float(tokens[2]))
    line_stats['Z'].append(float(tokens[3]))
    for column, key in STATE_COLUMNS.items():
        line_stats[column].append(state_dict[key].numpy())
df = pd.DataFrame(line_stats).explode(list(STATE_COLUMNS))
# Offset between the shifted and the original line centers.
df['jitter'] = df.shift_center - df.center

In [None]:
# Rows per (teff, logg) grid point, i.e. the number of exploded line rows.
# NOTE(review): `size()` also counts a NaN row produced by exploding an empty
# array, and sums over any Z values present — confirm both are acceptable.
line_counts = (
    df.groupby(['teff', 'logg'])
    .size()
    .reset_index(name='n_lines')
)
(
    alt.Chart(line_counts)
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]'),
        y=alt.Y('logg:O', title='Surface Gravity'),
        # Log color scale for the count.
        color=alt.Color('n_lines:Q', title='Line Count', scale=alt.Scale(type='log')),
    )
    .properties(width=1000, height=400)
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
)

# [2] **PHOENIX Subset Discrete Manifold**

In [None]:
# Map `amp` for a single spectral line across the (teff, logg) grid.
# `value_counts` orders centers from most to least frequent; LINE_RANK picks
# which one to plot. NOTE(review): rank 1 is the *second* most common center
# despite the `first_line` name — confirm this was intended rather than 0.
LINE_RANK = 1
first_line = df.value_counts('center').index[LINE_RANK]
(
    alt.Chart(df.query('center == @first_line'))
    .mark_rect()
    .encode(
        x=alt.X('teff:O', title='Effective Temperature [K]'),
        y=alt.Y('logg:O', title='Surface Gravity'),
        color=alt.Color('amp:Q', title='Log-Amplitude'),
    )
    .properties(width=1000, height=400, title=f'Spectral Line at {first_line} Å')
    .configure_axis(labelFontSize=15, titleFontSize=24)
    .configure_legend(labelFontSize=15, titleFontSize=15)
    .configure_title(fontSize=25)
)

In [None]:
# Locate the fitted line center(s) near TARGET_CENTER among the unique centers.
# Compare with a tolerance instead of exact `==` (the centers are floats), and
# compute the position here rather than hard-coding one from an earlier run:
# the order of `df.center.unique()` follows row order, which follows file order.
TARGET_CENTER = 11617.66  # Å
CENTER_ATOL = 0.01  # Å
unique_centers = df.center.unique().astype(np.float64)
match_idx = np.flatnonzero(np.isclose(unique_centers, TARGET_CENTER, rtol=0, atol=CENTER_ATOL))
# Displays (positions, values); an empty result means no center is within tolerance.
match_idx, unique_centers[match_idx]