In [8]:
from astropy.table import Table
from astropy.io import fits
import pandas as pd
import time as t
import numpy as np
import os
import json

import umap
from sklearn.neighbors import radius_neighbors_graph
from sklearn.decomposition import PCA
from scipy.sparse.csgraph import connected_components

from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rc
matplotlib.rcParams['figure.dpi'] = 360
matplotlib.rcParams['text.usetex'] = True
rc("animation", html = "jshtml")
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.ticker as ticker
import matplotlib.colors as colors
from matplotlib.lines import Line2D
# plt.style.use('default')

from pathlib import Path

In [3]:
# IPython shell escape: print system memory usage in human-readable units.
!free -h

               total        used        free      shared  buff/cache   available
Mem:           251Gi        37Gi       202Gi       931Mi        30Gi       213Gi
Swap:             0B          0B          0B


In [4]:
import psutil

# Memory currently available for new allocations, in decimal gigabytes
# (bytes / 1e9). The label is aligned with the recorded cell output
# ("RAM (GB): ..."), which the previous "RAM gb:" string no longer matched.
available_gb = psutil.virtual_memory().available / 1e9
print("RAM (GB):", available_gb)

RAM (GB): 229.756895232


### Save

In [2]:
# Directory whose entries are enumerated with base.iterdir() in the next cell.
base = Path('/pscratch/sd/v/vtorresg/desi-lenses/outlier_spectra_tiles/')

In [3]:
# Every entry directly inside `base` (no filtering by suffix), as string paths.
files = list(map(str, base.iterdir()))
len(files)

14199

In [5]:
# Directory passed as `output_dir` to build_flux_cube in the (commented-out) call below.
OUTPUT_DIR = Path("/pscratch/sd/v/vtorresg/desi-lenses/outlier_spectra_tiles/")
# Band labels; each one is substituted into the WAVE_<band> / FLUX_<band> column names.
BANDS: Tuple[str, ...] = ("B", "R", "Z")

In [6]:
@dataclass
class Metadata:
    """Bookkeeping produced by ``gather_metadata`` and consumed when filling the flux matrix.

    ``fib_counts`` and every ``*_by_file`` list hold one entry per scanned
    file, in the order the files were scanned.
    """
    # Number of table rows found in each file.
    fib_counts: List[int]
    # band -> sorted, de-duplicated wavelengths collected over all files.
    union_waves: Dict[str, np.ndarray]
    # band -> index of that band's first column in the concatenated wavelength grid.
    band_offsets: Dict[str, int]
    # Sum of fib_counts, i.e. number of rows of the flux matrix.
    total_rows: int
    # Per file: band -> wavelength array taken from the first row of WAVE_<band>.
    waves_by_file: List[Dict[str, np.ndarray]]
    # Per file: TARGETID column cast to int64.
    ids_by_file: List[np.ndarray]
    # Per file: TILEID column cast to int32.
    tiles_by_file: List[np.ndarray]
    # Per file: NIGHT column cast to int32.
    nights_by_file: List[np.ndarray]
    # Per file: PETAL column cast to int16.
    petals_by_file: List[np.ndarray]
    # Per file: FIBER column cast to int32.
    fibers_by_file: List[np.ndarray]


def list_tile_fits(directory: Path) -> List[Path]:
    """Return the regular ``*.fits`` files directly inside ``directory``, sorted.

    Args:
        directory: Folder to search (not recursive).

    Returns:
        The matching paths in ascending order.

    Raises:
        FileNotFoundError: If no regular file matches ``*.fits``.
    """
    matches: List[Path] = []
    for candidate in directory.glob("*.fits"):
        # Skip directories (or other non-files) that happen to end in .fits.
        if candidate.is_file():
            matches.append(candidate)
    matches.sort()

    if len(matches) == 0:
        raise FileNotFoundError(f"No file in {directory}")
    return matches


def gather_metadata(files: Sequence[Path], bands: Sequence[str]) -> Metadata:
    """Scan every file once and collect what is needed to size and fill the flux matrix.

    For each file, HDU 1 is read: the TARGETID, TILEID, NIGHT, PETAL and FIBER
    columns are copied out with fixed dtypes, and for each band the first row
    of ``WAVE_<band>`` is taken as that file's wavelength grid.

    Args:
        files: Paths of the FITS tables to scan; their order fixes the row order.
        bands: Band labels used to build the ``WAVE_<band>`` column names.

    Returns:
        A ``Metadata`` with the per-file arrays, the per-band union of
        wavelengths (sorted, unique) and each band's starting column offset.
    """
    n_files = len(files)

    # FITS column name -> dtype the column is cast to.
    column_dtypes = {
        "TARGETID": np.int64,
        "TILEID": np.int32,
        "NIGHT": np.int32,
        "PETAL": np.int16,
        "FIBER": np.int32,
    }
    per_file_columns: Dict[str, List[np.ndarray]] = {name: [] for name in column_dtypes}
    per_file_waves: List[Dict[str, np.ndarray]] = []
    wave_pool: Dict[str, List[np.ndarray]] = {band: [] for band in bands}
    row_counts: List[int] = []

    for file_no, fits_path in enumerate(files):
        print(f' {file_no}/{n_files}')

        with fits.open(fits_path, memmap=True) as hdul:
            table = hdul[1].data

            for name, dtype in column_dtypes.items():
                per_file_columns[name].append(np.asarray(table[name], dtype=dtype))

            # Only the first row's wavelengths are read for each band.
            grids = {band: np.asarray(table[f"WAVE_{band}"][0], dtype=np.float64)
                     for band in bands}

        row_counts.append(per_file_columns["TARGETID"][-1].size)
        for band in bands:
            wave_pool[band].append(grids[band])
        per_file_waves.append(grids)

    # np.unique both sorts and de-duplicates the pooled wavelengths.
    union_waves = {}
    for band in bands:
        union_waves[band] = np.unique(np.concatenate(wave_pool[band]))

    # Bands are laid out side by side, in the order given by `bands`.
    band_offsets: Dict[str, int] = {}
    next_col = 0
    for band in bands:
        band_offsets[band] = next_col
        next_col += union_waves[band].size

    return Metadata(
        fib_counts=row_counts,
        union_waves=union_waves,
        band_offsets=band_offsets,
        total_rows=sum(row_counts),
        waves_by_file=per_file_waves,
        ids_by_file=per_file_columns["TARGETID"],
        tiles_by_file=per_file_columns["TILEID"],
        nights_by_file=per_file_columns["NIGHT"],
        petals_by_file=per_file_columns["PETAL"],
        fibers_by_file=per_file_columns["FIBER"],
    )


def allocate_matrices(md: Metadata, bands: Sequence[str]):
    """Allocate the output arrays sized from ``md``.

    Args:
        md: Metadata giving the total row count and the per-band wavelength unions.
        bands: Band labels, in the order their columns are concatenated.

    Returns:
        ``(wave_grid, flux, target_ids, tile_ids, nights, petals, fibers)`` where
        ``wave_grid`` is the concatenation of the per-band wavelength unions,
        ``flux`` is a zero-filled float32 matrix of shape
        ``(md.total_rows, len(wave_grid))`` and the remaining five are
        uninitialised 1-D arrays of length ``md.total_rows``.
    """
    n_rows = md.total_rows
    per_band_grids = [md.union_waves[band] for band in bands]
    n_cols = sum(grid.size for grid in per_band_grids)

    wave_grid = np.concatenate(per_band_grids)
    # Zero-filled so that columns a file does not cover stay at 0.
    flux = np.zeros((n_rows, n_cols), dtype=np.float32)

    # Same dtypes as the per-file arrays collected in the metadata pass.
    target_ids, tile_ids, nights, petals, fibers = (
        np.empty(n_rows, dtype=dtype)
        for dtype in (np.int64, np.int32, np.int32, np.int16, np.int32)
    )

    return wave_grid, flux, target_ids, tile_ids, nights, petals, fibers


def fill_flux_matrix(files: Sequence[Path], bands: Sequence[str], md: Metadata, flux: np.ndarray,
                     target_ids: np.ndarray, tile_ids: np.ndarray, nights: np.ndarray, petals: np.ndarray,
                     fibers: np.ndarray) -> None:
    """Fill ``flux`` and the identifier arrays in place, one file at a time.

    Files are processed in order; file ``i`` occupies ``md.fib_counts[i]``
    consecutive rows. Identifier values come from ``md``; fluxes are read from
    the ``FLUX_<band>`` columns of HDU 1 and scattered into the columns of the
    band's union grid that match the file's own wavelengths.

    Args:
        files: FITS paths, in the same order used to build ``md``.
        bands: Band labels used to build the ``FLUX_<band>`` column names.
        md: Metadata from the first pass.
        flux: Destination matrix, modified in place.
        target_ids, tile_ids, nights, petals, fibers: Destination 1-D arrays,
            modified in place.
    """
    n_files = len(files)
    start = 0
    for file_no, (fits_path, count) in enumerate(zip(files, md.fib_counts)):
        print(f'--- {file_no}/{n_files}')
        stop = start + count

        with fits.open(fits_path, memmap=True) as hdul:
            table = hdul[1].data

            # Identifiers were already extracted during the metadata pass.
            target_ids[start:stop] = md.ids_by_file[file_no]
            tile_ids[start:stop] = md.tiles_by_file[file_no]
            nights[start:stop] = md.nights_by_file[file_no]
            petals[start:stop] = md.petals_by_file[file_no]
            fibers[start:stop] = md.fibers_by_file[file_no]

            # Column vector of local row indices; broadcasts against the
            # per-band column indices in the fancy assignment below.
            local_rows = np.arange(count)[:, None]
            for band in bands:
                grid = md.union_waves[band]
                first_col = md.band_offsets[band]

                # Position of each of this file's wavelengths inside the union grid.
                cols = np.searchsorted(grid, md.waves_by_file[file_no][band])

                # Basic slicing yields a view, so writing to it writes into `flux`.
                band_view = flux[start:stop, first_col: first_col + grid.size]
                spectra = np.vstack(
                    [np.asarray(spectrum, dtype=np.float32) for spectrum in table[f"FLUX_{band}"]]
                )
                band_view[local_rows, cols] = spectra

        start = stop


def build_flux_cube(output_dir: Path, bands: Sequence[str]):
    """Build the combined flux matrix from every ``*.fits`` file in ``output_dir``.

    Runs the full pipeline: list the files, gather metadata, allocate the
    output arrays and fill them.

    Args:
        output_dir: Directory containing the FITS tables.
        bands: Band labels to read.

    Returns:
        A dict with keys ``wave_grid``, ``flux``, ``target_id``, ``tile_id``,
        ``night``, ``petal``, ``fiber``, ``band_offsets`` and ``union_waves``.
    """
    tile_files = list_tile_fits(output_dir)
    md = gather_metadata(tile_files, bands)

    # Order: wave_grid, flux, target_ids, tile_ids, nights, petals, fibers.
    arrays = allocate_matrices(md, bands)
    fill_flux_matrix(tile_files, bands, md, arrays[1], *arrays[2:])

    array_keys = ("wave_grid", "flux", "target_id", "tile_id", "night", "petal", "fiber")
    cube = dict(zip(array_keys, arrays))
    cube["band_offsets"] = md.band_offsets
    cube["union_waves"] = md.union_waves
    return cube

In [7]:
# result = build_flux_cube(OUTPUT_DIR, BANDS)
# print("Flux matrix shape:", result["flux"].shape, "| wave grid length:", result["wave_grid"].size)

In [13]:
# np.savez('/pscratch/sd/v/vtorresg/desi-lenses/matrix_outliers.npz',
#         flux=result["flux"], wave=result["wave_grid"], targets=result['target_id'],
#         tiles=result["tile_id"], nights=result['night'], petals=result['petal'],
#         fibers=result['fiber'], band_off=result['band_offsets'], union_w=result['union_waves']
#        )

In [8]:
# 2-component UMAP with cosine metric; the same configuration is created
# again in the "Read" section.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=100,
    min_dist=1.0,
    metric='cosine',
    n_jobs=-1,
)

### Read

In [5]:
# Only the flux matrix is needed here. np.load on an .npz returns a lazily
# read archive that keeps the file open; the `with` block closes it as soon
# as the array has been read into memory.
with np.load('/pscratch/sd/v/vtorresg/desi-lenses/matrix_outliers.npz') as data:
    flux = data['flux']

In [6]:
# Display the dimensions of the loaded flux matrix.
flux.shape

(1143255, 7958)

In [7]:
# float32 input for PCA/UMAP; copy=False avoids duplicating the matrix when
# `flux` is already float32.
X = flux.astype(np.float32, copy=False)

In [None]:
# Reuse the float32 matrix `X` defined above. The previous
# `flux.astype(np.float32)` call defaulted to copy=True and materialised a
# second full copy of the (1143255, 7958) float32 matrix (~36 GB) only to
# hand it to PCA.
pca = PCA(n_components=200, svd_solver="randomized", random_state=42)
Xp = pca.fit_transform(X)

In [8]:
# 2-component UMAP with cosine metric, fitted on `X` in the next cell.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=100,
    min_dist=1.0,
    metric='cosine',
    n_jobs=-1,
)

In [None]:
# Fit the UMAP model on the full float32 matrix.
reducer.fit(X)

In [None]:
# Bookkeeping for the batched transform: N rows in total, at most B rows per
# batch, and Y receives the 2-D coordinates of every row.
N = len(X)
B = min(500_000, N)
Y = np.empty(shape=(N, 2), dtype=np.float32)

In [None]:
# `tqdm` was used here without ever being imported, so the cell raised
# NameError. Import it locally; if tqdm is not installed, fall back to a
# pass-through so the transform still runs (without a progress bar).
try:
    from tqdm.auto import tqdm
except ImportError:
    def tqdm(iterable, **_kwargs):
        return iterable

# Each iteration embeds one batch of up to B rows, hence unit="batch"
# (the bar counts batches, not rows).
for i in tqdm(range(0, N, B), desc="UMAP transform", unit="batch"):
    Y[i:i+B] = reducer.transform(X[i:i+B])

In [None]:
npz_path = "/pscratch/sd/v/vtorresg/desi-lenses/matrix_outliers.npz"

# band_off / union_w hold dicts (hence the .item() calls below), so
# allow_pickle=True is required to read them.
# NOTE(security): unpickling can execute arbitrary code — only load archives
# you produced yourself.
# The `with` block closes the archive once every array has been read.
with np.load(npz_path, allow_pickle=True) as d:
    flux = d["flux"]
    wavegrid = d["wave"]
    targets = d["targets"]
    tiles = d["tiles"]
    nights = d["nights"]
    petals = d["petals"]
    fibers = d["fibers"]
    # 0-d object arrays -> the dicts stored inside them.
    band_off = d["band_off"].item()
    union_w = d["union_w"].item()

In [None]:
# `embedding` was not defined by any cell of this notebook (NameError on a
# fresh kernel). The (N, 2) UMAP coordinates produced by the batched
# transform are stored in `Y`, so bind the name to that array.
embedding = Y

umap1 = embedding[:, 0].astype(np.float32, copy=False)
umap2 = embedding[:, 1].astype(np.float32, copy=False)

In [None]:
# Number of rows in the flux matrix; written to the FLUX header as NOBJ below.
N = flux.shape[0]

In [None]:
# Integer identifier columns for the META table:
# "K" = 64-bit integer (TARGETID), "J" = 32-bit integer (the others).
col_tid = fits.Column(
    name="TARGETID",
    format="K",
    array=targets.astype(np.int64, copy=False),
)
col_tile = fits.Column(
    name="TILEID",
    format="J",
    array=tiles.astype(np.int32, copy=False),
)
col_fiber = fits.Column(
    name="FIBER",
    format="J",
    array=fibers.astype(np.int32, copy=False),
)
col_petal = fits.Column(
    name="PETAL",
    format="J",
    array=petals.astype(np.int32, copy=False),
)

In [None]:
# NIGHT is stored as a 64-bit integer column when the array is integer-typed;
# otherwise it is written as a fixed-width ASCII string column.
if not np.issubdtype(nights.dtype, np.integer):
    nights_str = np.asarray(nights).astype(str)
    # Column width: longest string, clamped to the range [5, 32].
    longest = int(np.max([len(s) for s in nights_str]))
    width = max(5, min(32, longest))
    nights_ascii = np.char.encode(nights_str, encoding="ascii", errors="replace")
    col_night = fits.Column(name="NIGHT", format=f"{width}A", array=nights_ascii)
else:
    col_night = fits.Column(
        name="NIGHT",
        format="K",
        array=nights.astype(np.int64, copy=False),
    )

In [None]:
# The two UMAP coordinates as single-precision ("E") columns.
col_u1, col_u2 = (
    fits.Column(name=label, format="E", array=coords)
    for label, coords in (("UMAP1", umap1), ("UMAP2", umap2))
)

# One row per spectrum: identifiers followed by the embedding coordinates.
meta_hdu = fits.BinTableHDU.from_columns(
    [
        col_tid,
        col_tile,
        col_night,
        col_petal,
        col_fiber,
        col_u1,
        col_u2,
    ]
)
meta_hdu.header["EXTNAME"] = "META"

In [None]:
# Image extension holding the full flux matrix, with its dimensions and the
# per-band starting column recorded in the header.
flux_hdu = fits.ImageHDU(data=flux, name="FLUX")
for key, value in (("NOBJ", N), ("NLAMBDA", flux.shape[1])):
    flux_hdu.header[key] = value

# One OFF_<BAND> keyword per band.
for b in band_off:
    off = band_off[b]
    flux_hdu.header[f"OFF_{b.upper()}"] = int(off)

In [None]:
# Image extension holding the concatenated wavelength grid loaded from the npz.
wave_hdu = fits.ImageHDU(data=wavegrid, name="WAVEGRID")

In [None]:
# Concatenated upper-case labels of the bands present in union_w, in B, R, Z
# order. NOTE(review): not referenced by any other visible cell.
bands_order = "".join(b.upper() for b in ("B", "R", "Z") if b.upper() in union_w)

# One WAVE_<BAND> image extension per entry of union_w, in dict order.
band_hdus = [
    fits.ImageHDU(data=np.asarray(arr), name=f"WAVE_{b.upper()}")
    for b, arr in union_w.items()
]

In [None]:
# Empty primary HDU; the data live in the extensions.
primary_hdu = fits.PrimaryHDU()
# Handle on the primary header. NOTE(review): no keyword is written through it
# in the visible cells.
hdr = primary_hdu.header

In [None]:
# Extension order: primary, FLUX, WAVEGRID, META, then the per-band WAVE_<BAND> HDUs.
hdul = fits.HDUList([primary_hdu, flux_hdu, wave_hdu, meta_hdu] + band_hdus)

In [None]:
# Destination file. NOTE(review): the .gz suffix presumably makes astropy
# gzip-compress the output — confirm, since the flux matrix is very large.
out_fits = "/pscratch/sd/v/vtorresg/desi-lenses/matrix_outliers.fits.gz"
# overwrite=True replaces an existing file at this path.
hdul.writeto(out_fits, overwrite=True)