In [14]:
import os
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [15]:
# RGB 强度轉換公式
# v1 - ['R', 'G', 'B', 'observe_type_eyes']
import numpy as np

# Convert a magnitude into a relative flux; a larger result means a brighter source.
def mag_to_flux(mag):
    """Return the relative flux ``10 ** (-0.4 * mag)`` for the magnitude ``mag``."""
    exponent = -0.4 * mag
    return 10 ** exponent
    
# The threshold can be tuned to change sensitivity.
# In the author's experiments threshold=0.4 gave the highest accuracy.
def normalize_binary(r, g, b, threshold=0.4):
    """
    Binarise the R/G/B values to 0 or 1.

    A channel counts as "bright" (1) when it is at least ``threshold`` times
    the largest of the three values; otherwise it is 0. When the largest
    value is 0 the result is (0, 0, 0).
    """
    peak = max(r, g, b)
    if peak == 0:
        return (0, 0, 0)
    cutoff = threshold * peak
    return tuple(int(channel >= cutoff) for channel in (r, g, b))

def ugriz_to_rgb_binary(u=None, g=None, r=None, i=None, z=None, threshold=0.5):
    """
    Collapse ugriz band magnitudes into binarised (R, G, B) flags.

    Each band magnitude is converted to a relative flux with ``mag_to_flux``;
    a missing band contributes a flux of 0. The channels are mixed as

        red   = flux_i + 0.5 * flux_z
        green = flux_r
        blue  = flux_g + 0.3 * flux_u

    and then binarised with ``normalize_binary``.

    Parameters:
        u, g, r, i, z: scalar band magnitudes. ``None`` and NaN/NA (the
            missing-value markers pandas hands out) are both treated as
            "band not available".
        threshold: fraction of the brightest channel a channel must reach
            to be flagged 1; forwarded to ``normalize_binary``.

    Returns:
        Tuple ``(r_bin, g_bin, b_bin)`` of 0/1 integers.

    NOTE(review): the default threshold here is 0.5, while
    ``normalize_binary`` defaults to 0.4 — confirm which one the pipeline
    is meant to use when no threshold is passed.
    """
    def _flux_or_zero(mag):
        # Values read from a DataFrame are never None; a missing magnitude
        # arrives as NaN, which would otherwise turn the channel sum into NaN
        # and make the max() inside normalize_binary order-dependent.
        if mag is None or pd.isna(mag):
            return 0
        return mag_to_flux(mag)

    flux_u = _flux_or_zero(u)
    flux_g = _flux_or_zero(g)
    flux_r = _flux_or_zero(r)
    flux_i = _flux_or_zero(i)
    flux_z = _flux_or_zero(z)

    red   = flux_i + 0.5 * flux_z
    green = flux_r
    blue  = flux_g + 0.3 * flux_u

    r_bin, g_bin, b_bin = normalize_binary(red, green, blue, threshold)

    return (r_bin, g_bin, b_bin)

In [16]:
# Rule for the "visible to the naked eye" column: class == 'STAR'
# v1 - ['R', 'G', 'B', 'observe_type_eyes']
def only_stars(class_str):
    """Return 1 when ``class_str`` equals 'STAR', otherwise 0."""
    if class_str == 'STAR':
        return 1
    return 0

In [17]:
# Run configuration: input/output locations and the version tag used in file names.
# File names are given without an extension (.csv/.excel); CSV is assumed, so
# adjust by hand for any other format.
# The input folder must already exist; the output folder is created below if missing.
version = 'v1'
today = datetime.today().strftime('%Y%m%d')

file_path = "data"
file_name = "star_classification"

output_path = 'output/csv'
output_file_name = f'{output_path}/{file_name}_{version}_{today}'

# Later cells write CSV files under output_path; create it up front so those
# writes do not fail on a fresh checkout. exist_ok keeps re-runs harmless.
os.makedirs(output_path, exist_ok=True)

In [18]:
# Read the raw catalogue from "<file_path>/<file_name>.csv" and preview the first rows.
source_csv = f"{file_path}/{file_name}.csv"
data = pd.read_csv(source_csv)
data.head()

Unnamed: 0,obj_ID,alpha,delta,u,g,r,i,z,run_ID,rerun_ID,cam_col,field_ID,spec_obj_ID,class,redshift,plate,MJD,fiber_ID
0,1.237661e+18,135.689107,32.494632,23.87882,22.2753,20.39501,19.16573,18.79371,3606,301,2,79,6.543777e+18,GALAXY,0.634794,5812,56354,171
1,1.237665e+18,144.826101,31.274185,24.77759,22.83188,22.58444,21.16812,21.61427,4518,301,5,119,1.176014e+19,GALAXY,0.779136,10445,58158,427
2,1.237661e+18,142.18879,35.582444,25.26307,22.66389,20.60976,19.34857,18.94827,3606,301,2,120,5.1522e+18,GALAXY,0.644195,4576,55592,299
3,1.237663e+18,338.741038,-0.402828,22.13682,23.77656,21.61162,20.50454,19.2501,4192,301,3,214,1.030107e+19,GALAXY,0.932346,9149,58039,775
4,1.23768e+18,345.282593,21.183866,19.43718,17.58028,16.49747,15.97711,15.54461,8102,301,3,137,6.891865e+18,GALAXY,0.116123,6121,56187,842


In [19]:
# Find the rows holding the -9999 placeholder and drop them.
# Every photometric band is checked, not only 'u': a -9999 magnitude in any
# band would be turned into an overflowing flux by the later RGB conversion.
sentinel_mask = (data[['u', 'g', 'r', 'i', 'z']] == -9999).any(axis=1)
dropped_item_index = data[sentinel_mask].index
new_data = data.drop(dropped_item_index)
new_data.shape

(99999, 18)

In [20]:
# Turn the band magnitudes into binarised RGB intensities; the threshold
# sensitivity decides which channels count as bright relative to the strongest.
# No threshold is passed here, so ugriz_to_rgb_binary runs with its own default.
rgb_flux_data = new_data.copy()

color_columns = ['R', 'G', 'B']

magnitudes = rgb_flux_data[['u', 'g', 'r', 'i', 'z']]
rgb_flags = magnitudes.apply(
    lambda row: pd.Series(
        ugriz_to_rgb_binary(u=row['u'], g=row['g'], r=row['r'], i=row['i'], z=row['z'])
    ),
    axis=1,
)
rgb_flux_data[color_columns] = rgb_flags
rgb_flux_data[color_columns].head()

Unnamed: 0,R,G,B
0,1,0,0
1,1,0,0
2,1,0,0
3,1,0,0
4,1,0,0


In [21]:
# Add the "visible to the naked eye" column.
# The rule is replaceable (for example magnitude < 6), but that needs a new
# function and has not been tested.
observe_type_columns = 'observe_type_eyes'
star_flags = rgb_flux_data['class'].apply(only_stars)
rgb_flux_data[observe_type_columns] = star_flags
rgb_flux_data[observe_type_columns]

0        0
1        0
2        0
3        0
4        0
        ..
99995    0
99996    0
99997    0
99998    0
99999    0
Name: observe_type_eyes, Length: 99999, dtype: int64

In [22]:
# Collect the exported columns in one list: the feature columns
# ['R', 'G', 'B', 'observe_type_eyes'] followed by the label column.
output_column = 'class'

rgb_flux_field = [*color_columns, observe_type_columns, output_column]

output_data = rgb_flux_data[rgb_flux_field]
output_data

Unnamed: 0,R,G,B,observe_type_eyes,class
0,1,0,0,0,GALAXY
1,1,0,0,0,GALAXY
2,1,0,0,0,GALAXY
3,1,0,0,0,GALAXY
4,1,0,0,0,GALAXY
...,...,...,...,...,...
99995,1,0,0,0,GALAXY
99996,1,0,0,0,GALAXY
99997,1,0,0,0,GALAXY
99998,1,0,0,0,GALAXY


In [23]:
# The version tag in the file name marks which column-transformation scheme
# produced the data; the scheme may change in the future.
# v1 - ['R', 'G', 'B', 'observe_type_eyes']
full_csv_name = f'{output_file_name}.csv'
output_data.to_csv(full_csv_name, index=False, encoding='utf-8-sig')

In [24]:
# Split the dataset: 80% training, 20% validation, with a fixed seed.
train_data, val_data = train_test_split(output_data, test_size=0.2, random_state=0)

# Target files for the training split and the validation split.
train_file_name = f'{output_file_name}_train.csv'
val_file_name = f'{output_file_name}_validation.csv'

# Write both splits with identical CSV settings.
for split_frame, split_file in ((train_data, train_file_name), (val_data, val_file_name)):
    split_frame.to_csv(split_file, index=False, encoding='utf-8-sig')

In [25]:
# Report the shape of each split as a sanity check on the 80/20 partition.
# The first label previously read "tran"; it is spelled "train" here.
print(f'train shape: {train_data.shape}')
print(f'validation shape: {val_data.shape}')

tran shape: (79999, 5)
validation shape: (20000, 5)
