In [2]:
import os
import time

import numpy as np
import pandas as pd
from astropy.table import Table, vstack

import GaSNet_MC_model as GaSNet_MC

## Load data

In [10]:
# Read the class/subclass catalog.
ctalog = pd.read_csv('SDSS_DR16_class_num20k.csv')

# One redshift subset per catalog row; rows whose class is 'STAR' are left out.
redshift_subsets = [
    {'class': entry['class'], 'subclass': str(entry['subclass'])}
    for _, entry in ctalog.iterrows()
    if not entry['class'] == 'STAR'
]

print('redshift subsets:',redshift_subsets)


def _read_split(path):
    """Read a FITS table and normalise its CLASS / SUBCLASS columns.

    Each of the two columns is first cast to ``str`` and then replaced by the
    flattened underlying array, in that order.
    """
    table = Table.read(path)
    for name in ('CLASS', 'SUBCLASS'):
        table[name] = table[name].astype(str)
        table[name] = table[name].data.flatten()
    return table


# Load the three data splits.
tra = _read_split('train.fits')
val = _read_split('valid.fits')
tes = _read_split('test.fits')


# Redshift datasets: keep only the rows of one (class, subclass) pair.
def redshift_lable(clas, subclass):
    """Build the redshift datasets for a single class/subclass pair.

    Rows are taken from the module-level tables ``tra``, ``val`` and ``tes``
    where the 'CLASS' column equals ``clas`` and the 'SUBCLASS' column equals
    ``subclass``.

    Parameters
    ----------
    clas : str
        Value compared against the 'CLASS' column.
    subclass : str
        Value compared against the 'SUBCLASS' column.

    Returns
    -------
    train_data : dict
        ``{'train': ..., 'valid': ...}``; each entry holds 'flux' (the
        'int_flux' column) and 'label' (the 'Z' column) of the selected rows.
    test : dict
        'flux' and 'label' as above, plus 'SNR' (flattened 'SN_MEDIAN_ALL')
        and the 'CLASS', 'SUBCLASS', 'PLATE', 'MJD' and 'FIBERID' columns.
    """
    def _matching_rows(table):
        keep = (table['CLASS'] == clas) & (table['SUBCLASS'] == subclass)
        return table[keep]

    train_rows = _matching_rows(tra)
    valid_rows = _matching_rows(val)
    test_rows = _matching_rows(tes)

    train_data = {
        'train': {'flux': train_rows['int_flux'], 'label': train_rows['Z']},
        'valid': {'flux': valid_rows['int_flux'], 'label': valid_rows['Z']},
    }

    # Key order is kept as-is: callers turn this dict into a table later on.
    test = {
        'flux': test_rows['int_flux'],
        'label': test_rows['Z'],
        'SNR': test_rows['SN_MEDIAN_ALL'].data.flatten(),
        'CLASS': test_rows['CLASS'],
        'SUBCLASS': test_rows['SUBCLASS'],
        'PLATE': test_rows['PLATE'],
        'MJD': test_rows['MJD'],
        'FIBERID': test_rows['FIBERID'],
    }

    print('train:', len(train_data['train']['flux']),
          'vaild:', len(train_data['valid']['flux']),
          'test:', len(test['flux']))
    return train_data, test

# Label each row with its combined class/subclass name.
def connect(data):
    """Build '<CLASS>_<SUBCLASS>' string labels for every row of ``data``.

    Parameters
    ----------
    data : table-like
        ``data['CLASS']`` and ``data['SUBCLASS']`` must each expose a ``.data``
        array of strings (assumed to be the str-normalised table columns —
        confirm against the caller).

    Returns
    -------
    numpy.ndarray
        String array of shape ``(n_rows, 1)`` holding the class name, an
        underscore, and the subclass name of each row.
    """
    classes = data['CLASS'].data.reshape(-1, 1)
    subclasses = data['SUBCLASS'].data.reshape(-1, 1)
    # The scalar '_' broadcasts against the column vector, so no per-row
    # separator array has to be built and an empty table gives a (0, 1) result.
    return np.char.add(np.char.add(classes, '_'), subclasses)

def class_label():
    """Assemble the classification datasets from the full data tables.

    Uses every row of the module-level tables ``tra``, ``val`` and ``tes``;
    labels come from ``connect`` applied to the respective table.

    Returns
    -------
    train_data : dict
        ``{'train': ..., 'valid': ...}``; each entry holds 'flux' (the
        'int_flux' column) and 'label'.
    test : dict
        'flux' and 'label' as above, plus 'SNR' (flattened 'SN_MEDIAN_ALL')
        and the 'CLASS', 'SUBCLASS', 'PLATE', 'MJD' and 'FIBERID' columns.
    """
    train = {'flux': tra['int_flux'], 'label': connect(tra)}
    valid = {'flux': val['int_flux'], 'label': connect(val)}
    # Each key appears exactly once; the key order is what callers later
    # turn into table columns.
    test = {
        'flux': tes['int_flux'],
        'label': connect(tes),
        'SNR': tes['SN_MEDIAN_ALL'].data.flatten(),
        'CLASS': tes['CLASS'],
        'SUBCLASS': tes['SUBCLASS'],
        'PLATE': tes['PLATE'],
        'MJD': tes['MJD'],
        'FIBERID': tes['FIBERID'],
    }

    train_data = {'train': train, 'valid': valid}
    print('train:',len(train['flux']),'vaild:',len(valid['flux']),'test:',len(test['flux']))
    return train_data, test

{'QSO_nan': ['QSO', 'nan'], 'GALAXY_STARFORMING': ['GALAXY', 'STARFORMING'], 'GALAXY_nan': ['GALAXY', 'nan'], 'GALAXY_STARBURST': ['GALAXY', 'STARBURST'], 'QSO_BROADLINE': ['QSO', 'BROADLINE'], 'GALAXY_AGN': ['GALAXY', 'AGN']}


## Training for redshift

In [None]:
# All per-subset result files and the combined table are written here. The
# directory is created up front so a missing folder cannot make a write fail
# after the predictions have already been computed.
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

subset_results = []

for subset in redshift_subsets:
    print('Now training:', subset )
    Network_name = '{}_{}-MC'.format(subset['class'], subset['subclass'])
    gasnet3 = GaSNet_MC.GaSNet3(Network_name, Network_name, task='regression', scale_factor=10)
    gasnet3.Plot_Model()
    # data
    train_data, test = redshift_lable(subset['class'], subset['subclass'])
    # Training is switched off here; uncomment the next line to (re)train.
    # NOTE(review): with training disabled, Prodiction presumably relies on
    # previously saved weights for this Network_name — confirm.
    #gasnet3.Train_Model(train_data, epo=50)
    # test: time the prediction over this subset's test spectra
    n_test = len(test['flux'])
    start_time = time.perf_counter()  # monotonic clock, suited to intervals
    pred_hat, pred_std = gasnet3.Prodiction(test['flux'])
    end_time = time.perf_counter()
    total_time = end_time - start_time
    # A subset with no test spectra must not abort the loop on a division.
    time_per_spectrum = total_time / n_test if n_test else float('nan')
    print(Network_name, 'Num:', n_test,'Time:', total_time,'each specturm time:',time_per_spectrum)
    # saving the results
    test['pred_bar'] = pred_hat
    test['pred_std'] = pred_std
    test['class'] = [subset['class']] * len(pred_hat)
    test['subclass'] = [subset['subclass']] * len(pred_hat)
    del test['flux']  # drop the spectra themselves before writing the table
    test = Table(test)
    test.write(os.path.join(results_dir, gasnet3.Network_name + '.fits'), format='fits', overwrite=True)
    subset_results.append(test)

# Stack the per-subset tables into one combined result table.
all_test_results = vstack(subset_results)
all_test_results.write(os.path.join(results_dir, 'test_results.fits'), format='fits', overwrite=True)


## Training for classification

In [None]:
# Map every '<class>_<subclass>' name in the catalog to the index of its row.
classfy_label = {
    '{}_{}'.format(row['class'], str(row['subclass'])): i
    for i, row in ctalog.iterrows()
}
print('classfy label:',classfy_label)

# Created up front so a missing folder cannot make the final write fail after
# the predictions have already been computed.
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

Network_name = 'classify_model'
gasnet3 = GaSNet_MC.GaSNet3(Network_name, classfy_label, task='classification')
gasnet3.Plot_Model()
# data
train_data, test = class_label()
# Training is switched off here; uncomment the next line to (re)train.
# NOTE(review): with training disabled, Prodiction presumably relies on
# previously saved weights for this Network_name — confirm.
#gasnet3.Train_Model(train_data, epo=50)
# test: time the prediction over the whole test set
n_test = len(test['flux'])
start_time = time.perf_counter()  # monotonic clock, suited to intervals
pred_label = gasnet3.Prodiction(test['flux'])
end_time = time.perf_counter()
total_time = end_time - start_time
# Guard the per-spectrum figure against an empty test set.
time_per_spectrum = total_time / n_test if n_test else float('nan')
print(Network_name, 'Num:', n_test,'Time:', total_time,'each specturm time:',time_per_spectrum)
# saving the results
test['pred_label'] = pred_label
del test['flux']  # drop the spectra themselves before writing the table
test = Table(test)
test.write(os.path.join(results_dir, gasnet3.Network_name + '.fits'), format='fits', overwrite=True)