# MALLORN TDE Classification v1.3 - Feature Engineering

Giống như v1.2: dùng bộ features của v1, bổ sung thêm các features rút ra từ paper. Phần huấn luyện model ensemble nằm ở notebook 02 (`02_model_training.ipynb`).

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
from scipy import stats
import warnings
# Silence all Python/library warnings for the rest of the session.
warnings.filterwarnings('ignore')

# ---- Configuration ----
# Root of the competition data (train_log.csv, test_log.csv, split_XX/ folders).
DATA_DIR = Path('../mallorn-astronomical-classification-challenge')
# Photometric bands in the order used throughout; u/g are treated as "blue", r/i/z/y as "red".
BANDS = ['u', 'g', 'r', 'i', 'z', 'y']
# Light curves are shipped in folders split_01 ... split_20.
N_SPLITS = 20
# Stricter detection cut used only by the paper-inspired features (other features use SNR > 3).
SNR_THRESHOLD = 5

# Global numpy seed (nothing in this notebook draws random numbers, but set it for reproducibility).
np.random.seed(42)

In [2]:
train_log = pd.read_csv(DATA_DIR / 'train_log.csv')
test_log = pd.read_csv(DATA_DIR / 'test_log.csv')

# Later cells build object_id -> Z dicts and id sets from these tables, which would
# silently collapse duplicate ids; fail loudly instead of producing misaligned features.
assert train_log['object_id'].is_unique, "duplicate object_id in train_log"
assert test_log['object_id'].is_unique, "duplicate object_id in test_log"
assert not (set(train_log['object_id']) & set(test_log['object_id'])), "object_id present in both train and test logs"

print(f"Train: {len(train_log)}, Test: {len(test_log)}")

Train: 3043, Test: 7135


In [3]:
def extract_band_statistics(flux, flux_err, time):
    """Summary statistics of a single-band light curve.

    Parameters
    ----------
    flux, flux_err, time : np.ndarray
        Equal-length 1-D arrays for one band. ``time`` is expected to be sorted
        ascending (the caller sorts by 'Time (MJD)'); its first/last entries are
        used as the start/end of the light curve.

    Returns
    -------
    dict
        Feature name -> value; ``{}`` when the band has no observations.
        Keys starting with ``_`` (``_peak_time``) are internal and are dropped
        by the caller. Detection-based quantities are 0 when no point has
        SNR > 3.
    """
    n = len(flux)
    if n == 0:
        return {}

    snr = flux / (flux_err + 1e-10)  # epsilon guards against zero errors
    detections = snr > 3
    n_det = np.sum(detections)

    flux_mean = np.mean(flux)
    flux_max, flux_min = np.max(flux), np.min(flux)
    # One call for all quantiles (same linear interpolation as separate calls).
    p10, p25, p75, p90 = np.percentile(flux, [10, 25, 75, 90])

    feats = {
        'n_obs': n, 'n_det': n_det, 'det_frac': n_det / n,  # n > 0 is guaranteed above
        'flux_mean': flux_mean, 'flux_std': np.std(flux), 'flux_median': np.median(flux),
        'flux_max': flux_max, 'flux_min': flux_min, 'flux_range': flux_max - flux_min,
        'flux_iqr': p75 - p25,
        'flux_skew': stats.skew(flux) if n > 2 else 0,
        'flux_kurtosis': stats.kurtosis(flux) if n > 3 else 0,
        'flux_p10': p10, 'flux_p25': p25,
        'flux_p75': p75, 'flux_p90': p90,
        'snr_mean': np.mean(snr), 'snr_max': np.max(snr), 'snr_median': np.median(snr), 'snr_std': np.std(snr),
        'err_mean': np.mean(flux_err), 'err_std': np.std(flux_err),
    }

    if n > 1:
        gaps = np.diff(time)
        feats['time_span'] = time[-1] - time[0]
        feats['cadence_mean'] = np.mean(gaps)
        feats['cadence_std'] = np.std(gaps)
    else:
        feats['time_span'] = feats['cadence_mean'] = feats['cadence_std'] = 0

    if n_det > 0:
        det_flux = flux[detections]
        det_time = time[detections]
        feats['det_flux_mean'] = np.mean(det_flux)
        feats['det_flux_max'] = np.max(det_flux)
        feats['det_duration'] = det_time[-1] - det_time[0] if len(det_time) > 1 else 0
        peak_idx = np.argmax(det_flux)
        feats['peak_flux'] = det_flux[peak_idx]
        # NOTE(review): the +1 in the denominator keeps this strictly below 1;
        # confirm whether a plain division by det_duration was intended.
        feats['peak_time_rel'] = (det_time[peak_idx] - det_time[0]) / (feats['det_duration'] + 1) if feats['det_duration'] > 0 else 0.5
        feats['_peak_time'] = det_time[peak_idx]
        if peak_idx > 0:
            rise_dt = det_time[peak_idx] - det_time[0]
            feats['rise_time'] = rise_dt
            feats['rise_rate'] = (det_flux[peak_idx] - det_flux[0]) / (rise_dt + 1e-10)
        else:
            feats['rise_time'] = feats['rise_rate'] = 0
        if peak_idx < len(det_flux) - 1:
            decay_dt = det_time[-1] - det_time[peak_idx]
            feats['decay_time'] = decay_dt
            feats['decay_rate'] = (det_flux[peak_idx] - det_flux[-1]) / (decay_dt + 1e-10)
        else:
            feats['decay_time'] = feats['decay_rate'] = 0
        feats['variability'] = np.std(det_flux) / (np.mean(det_flux) + 1e-10) if len(det_flux) > 1 else 0
        feats['rms'] = np.sqrt(np.mean(det_flux**2))
    else:
        for k in ['det_flux_mean', 'det_flux_max', 'det_duration', 'peak_flux', 'peak_time_rel', '_peak_time',
                  'rise_time', 'rise_rate', 'decay_time', 'decay_rate', 'variability', 'rms']:
            feats[k] = 0

    feats['frac_above_mean'] = np.sum(flux > flux_mean) / n
    if n >= 3:
        try:
            slope, _, r_value, _, _ = stats.linregress(time, flux)
            feats['trend_slope'] = slope
            feats['trend_r2'] = r_value**2
        except ValueError:
            # linregress raises ValueError when all time values are identical.
            feats['trend_slope'] = feats['trend_r2'] = 0
    else:
        feats['trend_slope'] = feats['trend_r2'] = 0

    return feats

In [4]:
def extract_color_features(band_data):
    """Colour features built from per-band mean and peak flux.

    For each band the reference points are the SNR > 3 points when any exist,
    otherwise all points; bands that are absent or empty count as flux 0.

    Returns
    -------
    dict
        ``color_{a}_{b}`` / ``color_peak_{a}_{b}`` magnitude differences
        (-2.5 log10 of the flux ratio; 0 unless both mean fluxes are positive),
        blue (u+g) vs red (r+i+z+y) flux fractions and ratio, u dominance, and
        0/1 flags for which band holds the largest peak flux.
    """
    mean_flux, peak_flux = {}, {}
    for band in BANDS:
        if band not in band_data or len(band_data[band]['flux']) == 0:
            mean_flux[band] = peak_flux[band] = 0
            continue
        flux = band_data[band]['flux']
        detected = flux / (band_data[band]['flux_err'] + 1e-10) > 3
        reference = flux[detected] if np.sum(detected) > 0 else flux
        mean_flux[band] = np.mean(reference)
        peak_flux[band] = np.max(reference)

    colors = {}
    pairs = (('u', 'g'), ('u', 'r'), ('u', 'i'), ('g', 'r'), ('g', 'i'), ('r', 'i'), ('i', 'z'), ('z', 'y'))
    for first, second in pairs:
        mean_key = f'color_{first}_{second}'
        peak_key = f'color_peak_{first}_{second}'
        if not (mean_flux[first] > 0 and mean_flux[second] > 0):
            colors[mean_key] = colors[peak_key] = 0
            continue
        colors[mean_key] = -2.5 * np.log10(mean_flux[first] / mean_flux[second])
        colors[peak_key] = -2.5 * np.log10((peak_flux[first] + 1e-10) / (peak_flux[second] + 1e-10))

    blue_flux = mean_flux['u'] + mean_flux['g']
    red_flux = sum(mean_flux[b] for b in ('r', 'i', 'z', 'y'))
    total = blue_flux + red_flux
    strongest_mean = max(mean_flux.values())
    strongest_peak = max(peak_flux.values())

    colors['blue_fraction'] = blue_flux / (total + 1e-10)
    colors['u_fraction'] = mean_flux['u'] / (total + 1e-10)
    colors['g_fraction'] = mean_flux['g'] / (total + 1e-10)
    colors['blue_red_ratio'] = blue_flux / (red_flux + 1e-10)
    colors['u_dominance'] = mean_flux['u'] / (strongest_mean + 1e-10)
    colors['peak_band_is_u'] = int(peak_flux['u'] == strongest_peak)
    colors['peak_band_is_g'] = int(peak_flux['g'] == strongest_peak)
    blue_peak = max(peak_flux['u'], peak_flux['g'])
    red_peak = max(peak_flux['r'], peak_flux['i'], peak_flux['z'], peak_flux['y'])
    colors['peak_band_is_blue'] = int(blue_peak >= red_peak)

    return colors

In [5]:
def extract_temporal_features(band_data):
    """Cross-band timing features of a light curve.

    A point counts as a detection when flux / flux_err > 3. Each detected band
    contributes the time of its brightest detection; the brightest detection
    over all bands defines the global peak.

    Returns
    -------
    dict
        Time spans / counts over all observations and over detections,
        peak-time spread and u-r / g-r peak delays (0 when unavailable), the
        number of bands with at least one detection, and the internal key
        ``_global_peak_time`` (None if nothing was detected), which the caller
        pops before storing the features.
    """
    obs_times, det_times = [], []
    band_peak_time = {}
    best_time, best_flux = None, 0

    for band in BANDS:
        if band not in band_data or len(band_data[band]['flux']) == 0:
            continue
        data = band_data[band]
        t, f = data['time'], data['flux']
        obs_times.extend(t)
        is_det = f / (data['flux_err'] + 1e-10) > 3
        if not np.sum(is_det) > 0:
            continue
        t_det, f_det = t[is_det], f[is_det]
        det_times.extend(t_det)
        i_peak = np.argmax(f_det)
        band_peak_time[band] = t_det[i_peak]
        # Strict '>' keeps the first band (in BANDS order) on ties.
        if f_det[i_peak] > best_flux:
            best_flux, best_time = f_det[i_peak], t_det[i_peak]

    feats = {
        'total_time_span': max(obs_times) - min(obs_times) if obs_times else 0,
        'total_observations': len(obs_times),
        'detection_time_span': max(det_times) - min(det_times) if det_times else 0,
        'total_detections': len(det_times),
    }

    if len(band_peak_time) < 2:
        feats['peak_time_spread'] = feats['peak_delay_u_r'] = feats['peak_delay_g_r'] = 0
    else:
        feats['peak_time_spread'] = max(band_peak_time.values()) - min(band_peak_time.values())
        for early in ('u', 'g'):
            both = early in band_peak_time and 'r' in band_peak_time
            feats[f'peak_delay_{early}_r'] = band_peak_time[early] - band_peak_time['r'] if both else 0

    # A band has a peak entry exactly when it has at least one detection.
    feats['n_bands_detected'] = len(band_peak_time)
    feats['_global_peak_time'] = best_time
    return feats

In [6]:
def extract_variability_features(band_data):
    """Per-band variability features plus cross-band summaries.

    For each band with more than two observations, with flux ordered by time:
      - ``{band}_flux_diff_std`` / ``{band}_flux_diff_mean``: std and mean |.| of
        successive flux differences;
      - ``{band}_residual_std``: std of residuals from a quadratic fit against
        observation index (needs more than three points, else 0);
      - ``{band}_cv``: std(flux) / mean(|flux|).
    Bands with two or fewer observations get 0 for all four.

    Returns
    -------
    dict
        Feature name -> value, plus ``mean_variability`` / ``max_variability``
        over the per-band ``cv`` values (0 if no band qualifies).
    """
    feats = {}
    all_var = []
    for band in BANDS:
        if band not in band_data or len(band_data[band]['flux']) <= 2:
            feats[f'{band}_flux_diff_std'] = feats[f'{band}_flux_diff_mean'] = feats[f'{band}_residual_std'] = feats[f'{band}_cv'] = 0
            continue

        # Order by time once so the successive differences and the polynomial fit use
        # the same sequence (previously only the fit was re-sorted). A stable sort
        # leaves already-sorted input, ties included, unchanged.
        order = np.argsort(band_data[band]['time'], kind='stable')
        flux = band_data[band]['flux'][order]

        diffs = np.diff(flux)
        feats[f'{band}_flux_diff_std'] = np.std(diffs)
        feats[f'{band}_flux_diff_mean'] = np.mean(np.abs(diffs))

        residual_std = 0
        if len(flux) > 3:
            idx = np.arange(len(flux))
            try:
                coeffs = np.polyfit(idx, flux, 2)
                residual_std = np.std(flux - np.polyval(coeffs, idx))
            except (np.linalg.LinAlgError, ValueError):
                # The least-squares solve can fail (e.g. on non-finite flux); fall back to 0.
                residual_std = 0
        feats[f'{band}_residual_std'] = residual_std

        cv = np.std(flux) / (np.mean(np.abs(flux)) + 1e-10)
        all_var.append(cv)
        feats[f'{band}_cv'] = cv

    feats['mean_variability'] = np.mean(all_var) if all_var else 0
    feats['max_variability'] = np.max(all_var) if all_var else 0
    return feats

In [7]:
def extract_paper_features(band_data, global_peak_time, redshift=0):
    """Paper-inspired features: detection coverage around the peak, a
    redshift-scaled u-r colour, a coarse duration class, and g/r smoothness.

    Parameters
    ----------
    band_data : dict
        band -> {'flux', 'flux_err', 'time'} numpy arrays (the caller sorts each
        band by time).
    global_peak_time : float or None
        Time of the brightest SNR > 3 detection across bands; ``None`` or a
        non-positive value disables the peak-window features, which are then 0.
    redshift : float, default 0
        Only positive values rescale the colour by ``1 / (1 + z)``; 0 or NaN
        leave it unscaled.

    Returns
    -------
    dict
        Feature name -> value.
    """
    feats = {}

    if global_peak_time is not None and global_peak_time > 0:
        n_det_pre = n_det_near = n_det_post = 0
        bands_near_peak = set()
        for band in BANDS:
            if band not in band_data:
                continue
            data = band_data[band]
            # Stricter cut (SNR_THRESHOLD) than the SNR > 3 used by the other features.
            det_mask = (data['flux'] / (data['flux_err'] + 1e-10)) > SNR_THRESHOLD
            rel_t = data['time'][det_mask] - global_peak_time
            # Windows relative to peak: before -10, within [-10, 10], and (10, 30].
            near = (rel_t >= -10) & (rel_t <= 10)
            n_det_pre += int(np.sum(rel_t < -10))
            n_det_near += int(np.sum(near))
            n_det_post += int(np.sum((rel_t > 10) & (rel_t <= 30)))
            if near.any():
                bands_near_peak.add(band)

        feats['n_det_pre_peak'] = n_det_pre
        feats['n_det_near_peak'] = n_det_near
        feats['n_det_post_peak'] = n_det_post
        feats['n_bands_near_peak'] = len(bands_near_peak)
        feats['has_u_near_peak'] = 1 if 'u' in bands_near_peak else 0
        # Fraction of three coverage criteria met: any pre-peak detection,
        # >= 3 bands near peak, >= 2 post-peak detections.
        feats['some_color_score'] = ((1 if n_det_pre >= 1 else 0)
                                     + (1 if len(bands_near_peak) >= 3 else 0)
                                     + (1 if n_det_post >= 2 else 0)) / 3.0
    else:
        for k in ['n_det_pre_peak', 'n_det_near_peak', 'n_det_post_peak', 'n_bands_near_peak', 'has_u_near_peak', 'some_color_score']:
            feats[k] = 0

    z_factor = 1 + redshift if redshift > 0 else 1
    u_flux = np.mean(band_data['u']['flux']) if 'u' in band_data and len(band_data['u']['flux']) > 0 else 0
    r_flux = np.mean(band_data['r']['flux']) if 'r' in band_data and len(band_data['r']['flux']) > 0 else 0
    feats['color_u_r_z_norm'] = (-2.5 * np.log10(u_flux / r_flux)) / z_factor if u_flux > 0 and r_flux > 0 else 0

    # Longest SNR > 3 detection span over bands, bucketed at 150 and 300 (time units of 'time').
    det_span = 0
    for band in BANDS:
        if band not in band_data:
            continue
        data = band_data[band]
        det_times = data['time'][(data['flux'] / (data['flux_err'] + 1e-10)) > 3]
        if len(det_times) > 1:
            det_span = max(det_span, det_times[-1] - det_times[0])
    feats['duration_class'] = 2 if det_span > 300 else (1 if det_span > 150 else 0)

    for band in ['g', 'r']:
        if band in band_data and len(band_data[band]['flux']) > 4:
            flux = band_data[band]['flux'][np.argsort(band_data[band]['time'])]
            try:
                autocorr = np.corrcoef(flux[:-1], flux[1:])[0, 1]
            except (ValueError, FloatingPointError):
                autocorr = 0
            # A constant flux series yields a NaN correlation; report 0 instead.
            feats[f'{band}_autocorr'] = 0 if np.isnan(autocorr) else autocorr
            diffs = np.diff(flux)
            # len(flux) > 4, so there are >= 4 diffs and the denominator is >= 3.
            feats[f'{band}_sign_change_rate'] = np.sum(np.diff(np.sign(diffs)) != 0) / (len(diffs) - 1 + 1e-10)
        else:
            feats[f'{band}_autocorr'] = feats[f'{band}_sign_change_rate'] = 0

    return feats

In [8]:
def extract_features_for_object(obj_id, lc_df, redshift=0):
    """Assemble one feature row (dict) for a single object.

    Parameters
    ----------
    obj_id
        Value matched against ``lc_df['object_id']``.
    lc_df : pd.DataFrame
        Light-curve rows with columns 'object_id', 'Filter', 'Flux', 'Flux_err'
        and 'Time (MJD)'; rows of other objects are ignored.
    redshift : float, default 0
        Forwarded to ``extract_paper_features``.

    Returns
    -------
    dict or None
        ``{'object_id': obj_id, '<band>_<stat>': ..., colour, temporal,
        variability and paper features}``; ``None`` if the object has no rows.
    """
    obj_rows = lc_df.loc[lc_df['object_id'] == obj_id]
    if obj_rows.empty:
        return None

    # band -> time-sorted arrays; bands without rows are left out.
    rows_by_band = dict(tuple(obj_rows.groupby('Filter')))
    band_data = {}
    for band in BANDS:
        if band not in rows_by_band or rows_by_band[band].empty:
            continue
        rows = rows_by_band[band].sort_values('Time (MJD)')
        band_data[band] = {
            'flux': rows['Flux'].values,
            'flux_err': rows['Flux_err'].values,
            'time': rows['Time (MJD)'].values,
        }

    # Public keys produced by extract_band_statistics (its internal '_peak_time'
    # excluded); bands with no data are zero-filled with these. Keep in sync.
    zero_fill_keys = (
        'n_obs', 'n_det', 'det_frac', 'flux_mean', 'flux_std', 'flux_median', 'flux_max', 'flux_min', 'flux_range',
        'flux_iqr', 'flux_skew', 'flux_kurtosis', 'flux_p10', 'flux_p25', 'flux_p75', 'flux_p90', 'snr_mean', 'snr_max',
        'snr_median', 'snr_std', 'err_mean', 'err_std', 'time_span', 'cadence_mean', 'cadence_std', 'det_flux_mean',
        'det_flux_max', 'det_duration', 'peak_flux', 'peak_time_rel', 'rise_time', 'rise_rate', 'decay_time', 'decay_rate',
        'variability', 'rms', 'frac_above_mean', 'trend_slope', 'trend_r2',
    )

    features = {'object_id': obj_id}
    for band in BANDS:
        if band in band_data:
            arrays = band_data[band]
            band_stats = extract_band_statistics(arrays['flux'], arrays['flux_err'], arrays['time'])
            features.update((f'{band}_{key}', val) for key, val in band_stats.items() if not key.startswith('_'))
        else:
            features.update((f'{band}_{key}', 0) for key in zero_fill_keys)

    features.update(extract_color_features(band_data))
    temporal = extract_temporal_features(band_data)
    global_peak_time = temporal.pop('_global_peak_time', None)
    features.update(temporal)
    features.update(extract_variability_features(band_data))
    features.update(extract_paper_features(band_data, global_peak_time, redshift))
    return features

In [9]:
print("Processing TRAINING data...")
train_z = dict(zip(train_log['object_id'], train_log['Z']))
train_ids = set(train_log['object_id'])
train_features_list = []

for split_num in range(1, N_SPLITS + 1):
    lc_file = DATA_DIR / f'split_{split_num:02d}' / 'train_full_lightcurves.csv'
    if not lc_file.exists():
        continue
    lc_df = pd.read_csv(lc_file)
    # Split the frame by object once (O(rows)) instead of boolean-filtering the whole
    # split for every object (O(objects x rows)). sort=False keeps first-appearance
    # order, i.e. the same order as lc_df['object_id'].unique().
    obj_groups = [(obj_id, obj_df) for obj_id, obj_df in lc_df.groupby('object_id', sort=False)
                  if obj_id in train_ids]
    for obj_id, obj_df in tqdm(obj_groups, desc=f"split_{split_num:02d}", leave=False):
        feats = extract_features_for_object(obj_id, obj_df, train_z.get(obj_id, 0))
        if feats:
            train_features_list.append(feats)
    print(f"  split_{split_num:02d}: total {len(train_features_list)}")

# Every logged training object should yield exactly one feature row.
found_train_ids = [f['object_id'] for f in train_features_list]
assert len(set(found_train_ids)) == len(found_train_ids), "train object_id found in more than one split"
missing_train_ids = train_ids - set(found_train_ids)
assert not missing_train_ids, f"{len(missing_train_ids)} train objects have no light curve"

train_features = pd.DataFrame(train_features_list)
print(f"Train features: {len(train_features)}")

Processing TRAINING data...


                                                           

  split_01: total 155


                                                           

  split_02: total 325


                                                           

  split_03: total 463


                                                           

  split_04: total 608


                                                           

  split_05: total 773


                                                           

  split_06: total 928


                                                           

  split_07: total 1093


                                                           

  split_08: total 1255


                                                           

  split_09: total 1383


                                                           

  split_10: total 1527


                                                           

  split_11: total 1673


                                                           

  split_12: total 1828


                                                           

  split_13: total 1971


                                                           

  split_14: total 2125


                                                           

  split_15: total 2283


                                                           

  split_16: total 2438


                                                           

  split_17: total 2591


                                                           

  split_18: total 2743


                                                           

  split_19: total 2890


                                                           

  split_20: total 3043
Train features: 3043


In [10]:
print("\nProcessing TEST data...")
test_z = dict(zip(test_log['object_id'], test_log['Z']))
test_ids = set(test_log['object_id'])
test_features_list = []

for split_num in range(1, N_SPLITS + 1):
    lc_file = DATA_DIR / f'split_{split_num:02d}' / 'test_full_lightcurves.csv'
    if not lc_file.exists():
        continue
    lc_df = pd.read_csv(lc_file)
    # Split the frame by object once (O(rows)) instead of boolean-filtering the whole
    # split for every object (O(objects x rows)). sort=False keeps first-appearance
    # order, i.e. the same order as lc_df['object_id'].unique().
    obj_groups = [(obj_id, obj_df) for obj_id, obj_df in lc_df.groupby('object_id', sort=False)
                  if obj_id in test_ids]
    for obj_id, obj_df in tqdm(obj_groups, desc=f"split_{split_num:02d}", leave=False):
        feats = extract_features_for_object(obj_id, obj_df, test_z.get(obj_id, 0))
        if feats:
            test_features_list.append(feats)
    print(f"  split_{split_num:02d}: total {len(test_features_list)}")

# Every logged test object needs exactly one feature row, otherwise predictions go missing.
found_test_ids = [f['object_id'] for f in test_features_list]
assert len(set(found_test_ids)) == len(found_test_ids), "test object_id found in more than one split"
missing_test_ids = test_ids - set(found_test_ids)
assert not missing_test_ids, f"{len(missing_test_ids)} test objects have no light curve"

test_features = pd.DataFrame(test_features_list)
print(f"Test features: {len(test_features)}")


Processing TEST data...


                                                           

  split_01: total 364


                                                           

  split_02: total 778


                                                           

  split_03: total 1116


                                                           

  split_04: total 1448


                                                           

  split_05: total 1823


                                                           

  split_06: total 2197


                                                           

  split_07: total 2595


                                                           

  split_08: total 2982


                                                           

  split_09: total 3271


                                                           

  split_10: total 3602


                                                           

  split_11: total 3927


                                                           

  split_12: total 4280


                                                           

  split_13: total 4659


                                                           

  split_14: total 5010


                                                           

  split_15: total 5352


                                                           

  split_16: total 5706


                                                           

  split_17: total 6057


                                                           

  split_18: total 6402


                                                           

  split_19: total 6777


                                                           

  split_20: total 7135
Test features: 7135


In [None]:
# Persist the feature tables so 02_model_training.ipynb (a separate kernel) can load them;
# previously this cell printed "saved" without writing anything to disk.
# NOTE(review): output location/format is a best guess — confirm it matches what notebook 02 reads.
# NOTE(review): the recorded output reports 308 columns, while the feature functions in this
# notebook build 305 (object_id included); a cell that merged extra columns from the logs may
# have been lost — confirm before training.
FEATURES_DIR = Path('features')
FEATURES_DIR.mkdir(parents=True, exist_ok=True)
train_features.to_csv(FEATURES_DIR / 'train_features.csv', index=False)
test_features.to_csv(FEATURES_DIR / 'test_features.csv', index=False)

print(f"\nĐã lưu: train {train_features.shape}, test {test_features.shape}")
print("Tiếp theo: 02_model_training.ipynb (ENSEMBLE)")


Saved: train (3043, 308), test (7135, 308)
Next: 02_model_training.ipynb (ENSEMBLE)
