In [20]:
import h5py
import numpy as np
import pandas as pd
import os
import pandas as pd
import seaborn as sn
import torch
import torchvision.models as models
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from IPython.core.display import display
from collections.abc import Mapping
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger
import torchmetrics
from torchvision.transforms import ToTensor
import astropy.units as u
import astropy.coordinates as coord

import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old aliases, so use whichever name this
# installation actually provides.
plt.style.use(
    'seaborn-colorblind'
    if 'seaborn-colorblind' in plt.style.available
    else 'seaborn-v0_8-colorblind'
)

# User input (the trailing comments show the interactive input() equivalents)
sim = 'DR3_lsr012'  # input("DR2 or DR3: ")
dim = '6D_cyl'  # input("Input how many dimensions are needed: ")
galaxy = 'Gaia'  # input("Use m12i or m12f data: ")
transfer = True  # bool(input("Transfer learning (True or False): "))
# Always bind transfer_galaxy (None when transfer learning is off) so that
# later code can test it instead of hitting a NameError.
transfer_galaxy = 'm12i' if transfer else None  # input("Which galaxy parameters for transfer learning: ")

# Training data
# Feature columns to read for each supported value of `dim`.
_X_KEYS_BY_DIM = {
    '4D': ['ra', 'dec', 'pmra', 'pmdec'],
    '5D': ['ra', 'dec', 'pmra', 'pmdec', 'parallax'],
    '6D': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity'],
    '7D': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity', 'feh'],
    '9D': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity',
           'Jr', 'Jphi', 'Jz'],
    '10D': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity',
            'Jr', 'Jphi', 'Jz', 'feh'],
    '6D_cyl': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity'],
    '6D_gal': ['ra', 'dec', 'pmra', 'pmdec', 'parallax', 'radial_velocity'],
}

# Fail immediately on a typo in `dim` rather than leaving x_keys undefined
# and crashing further down with an unrelated-looking NameError.
if dim not in _X_KEYS_BY_DIM:
    raise ValueError(
        f"Unsupported dim {dim!r}; expected one of {sorted(_X_KEYS_BY_DIM)}"
    )
x_keys = list(_X_KEYS_BY_DIM[dim])

y_key = 'is_accreted'

# Directories
path = '/ocean/projects/phy210068p/hsu1/Ananke_datasets_training/AnankeDR3_data_reduced_m12f_lsr012.hdf5'

# Read every requested column exactly once and release the file handle as
# soon as the arrays are in memory. `f` stays bound (but closed) after the
# with-block, so a later `f.close()` is a harmless no-op.
with h5py.File(path, 'r') as f:
    data = [f[key][:] for key in x_keys]

# Getting rid of nan values
# Rows are filtered on a single reference column: 'Jr' when it is among the
# requested keys, otherwise 'radial_velocity'. The column is taken from the
# arrays already loaded above instead of being re-read from disk.
_nan_key = next((key for key in ('Jr', 'radial_velocity') if key in x_keys), None)
if _nan_key is None:
    # Nothing to filter on: keep every row, but still define `mask` so that
    # code which scatters results back to full length keeps working.
    mask = np.ones(len(data[0]), dtype=bool)
    x = data
else:
    mask = ~np.isnan(data[x_keys.index(_nan_key)])
    x = [column[mask] for column in data]

# Attach units to the first six filtered columns. The positional indices rely
# on x_keys starting with ra, dec, pmra, pmdec, parallax, radial_velocity.
_pm_unit = u.mas / u.yr
_rv_unit = u.km / u.s

ra, dec = x[0] * u.deg, x[1] * u.deg
pmra, pmdec = x[2] * _pm_unit, x[3] * _pm_unit
parallax = x[4] * u.mas
rv = x[5] * _rv_unit

# allow_negative=True keeps astropy from raising on negative parallaxes.
dist = coord.Distance(parallax=parallax, allow_negative=True)

# Coord transformation
# Build the ICRS frame with full 6D phase-space information, then switch its
# representation so positions and velocities are exposed as cylindrical
# components of that same frame.
icrs = coord.ICRS(
    ra=ra,
    dec=dec,
    distance=dist,
    pm_ra_cosdec=pmra,
    pm_dec=pmdec,
    radial_velocity=rv,
)
icrs.representation_type = 'cylindrical'

# Unit used for the two linear velocity components (d_rho, d_z).
_linear_rate_unit = u.mas * u.pc / (u.rad * u.yr)

# Strip the units and keep plain arrays.
rho_cyl, phi_cyl, z_cyl = (
    icrs.rho.to_value(u.pc),
    icrs.phi.to_value(u.deg),
    icrs.z.to_value(u.pc),
)
vrho_cyl, vphi_cyl, vz_cyl = (
    icrs.d_rho.to_value(_linear_rate_unit),
    icrs.d_phi.to_value(_pm_unit),
    icrs.d_z.to_value(_linear_rate_unit),
)

  from IPython.core.display import display


In [21]:
# Make sure the HDF5 handle `f` is closed before the file is reopened for
# writing further down.
f.close()

In [22]:
with h5py.File(path, 'r') as f:
    print(list(f.keys()))
    # Take the row count from the dataset's shape metadata; there is no need
    # to load the whole 'ra' column just to measure it.
    length = f['ra'].shape[0]

# The cylindrical arrays only cover rows where `mask` is True, so the mask
# must describe exactly the rows stored in the file.
if mask.shape != (length,):
    raise ValueError(
        f"mask has shape {mask.shape} but the file holds {length} rows"
    )


def _scatter_to_full_length(values):
    """Return a length-`length` float array holding `values` at the positions
    where `mask` is True and NaN everywhere else."""
    full = np.full(length, np.nan)
    full[mask] = values
    return full


new_rho = _scatter_to_full_length(rho_cyl)
new_phi = _scatter_to_full_length(phi_cyl)
new_z = _scatter_to_full_length(z_cyl)
new_vrho = _scatter_to_full_length(vrho_cyl)
new_vphi = _scatter_to_full_length(vphi_cyl)
new_vz = _scatter_to_full_length(vz_cyl)

# Save to hdf5 file
with h5py.File(path, 'a') as f:
    for name, values in (
        ('rho', new_rho),
        ('phi', new_phi),
        ('z', new_z),
        ('vrho', new_vrho),
        ('vphi', new_vphi),
        ('vz', new_vz),
    ):
        # create_dataset raises if the name already exists, which made this
        # cell fail on every re-run; drop any stale copy first.
        if name in f:
            del f[name]
        f.create_dataset(name, data=values)



['b', 'dec', 'feh', 'l', 'parallax', 'parentid', 'pmdec', 'pmra', 'ra', 'radial_velocity', 'radial_velocity_error', 'vx', 'vy', 'x', 'y']


In [23]:
# Re-open the file read-only and list its top-level names to confirm that the
# new cylindrical columns are present.
with h5py.File(path, mode='r') as f:
    print([*f.keys()])

['b', 'dec', 'feh', 'l', 'parallax', 'parentid', 'phi', 'pmdec', 'pmra', 'ra', 'radial_velocity', 'radial_velocity_error', 'rho', 'vphi', 'vrho', 'vx', 'vy', 'vz', 'x', 'y', 'z']
