# MALLORN TDE Classification v1.2 - Balanced Feature Engineering

**Chiến lược:** Giữ TẤT CẢ features của v1 + THÊM features theo hướng dẫn từ paper

v1.1 thất bại (0.3989) vì chúng ta giảm features quá mạnh.
v1.2 giữ mọi thứ từ v1 và thêm insights mới.

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's legacy global RNG up front so any stochastic step is repeatable.
np.random.seed(42)

# Root folder of the challenge data; light curves are read from
# `split_XX` sub-folders beneath it, metadata logs from the root itself.
DATA_DIR = Path('..') / 'mallorn-astronomical-classification-challenge'

# Filter names used to partition every light curve, in processing order.
BANDS = list('ugrizy')
# Number of `split_XX` folders scanned for light-curve files.
N_SPLITS = 20
# Signal-to-noise cut applied to the peak-phase detection counts.
SNR_THRESHOLD = 5

## Tải Metadata

In [2]:
# Per-object metadata tables (one row per object), read from the data root.
train_log, test_log = (
    pd.read_csv(DATA_DIR / csv_name)
    for csv_name in ('train_log.csv', 'test_log.csv')
)

# Quick sanity summary: table sizes and the share of positive (`target`) rows.
print(f"Train objects: {len(train_log)}")
print(f"Test objects: {len(test_log)}")
print(f"TDE ratio in train: {train_log['target'].mean():.4f}")

Train objects: 3043
Test objects: 7135
TDE ratio in train: 0.0486


## Các Hàm Features V1 Gốc (GIỮ TẤT CẢ)

In [3]:
def extract_band_statistics(flux, flux_err, time):
    """Summarise one band's light curve as a flat dict of statistics (full v1 set).

    Parameters
    ----------
    flux, flux_err, time : array-like
        Equal-length 1-D sequences for a single band. Plain lists are accepted
        and converted to NumPy arrays. ``time`` must already be sorted in
        ascending order: spans, cadence, rise and decay are all read from the
        first/last elements.

    Returns
    -------
    dict
        ``{}`` when there are no observations. Otherwise the same keys are
        always present, in the same order (missing quantities are filled
        with 0). The ``_peak_time`` entry is an internal helper value; the
        caller in this notebook drops keys that start with ``_``.
    """
    flux = np.asarray(flux)
    flux_err = np.asarray(flux_err)
    time = np.asarray(time)

    n = len(flux)
    if n == 0:
        return {}

    # A point is a "detection" when its signal-to-noise exceeds 3; the
    # epsilon keeps a zero error bar from producing a division by zero.
    snr = flux / (flux_err + 1e-10)
    detections = snr > 3
    n_det = np.sum(detections)

    # Compute the shared order statistics once instead of per feature.
    p10, p25, p75, p90 = np.percentile(flux, [10, 25, 75, 90])
    flux_mean = np.mean(flux)
    flux_max = np.max(flux)
    flux_min = np.min(flux)

    feats = {
        'n_obs': n,
        'n_det': n_det,
        'det_frac': n_det / n,

        'flux_mean': flux_mean,
        'flux_std': np.std(flux),
        'flux_median': np.median(flux),
        'flux_max': flux_max,
        'flux_min': flux_min,
        'flux_range': flux_max - flux_min,
        'flux_iqr': p75 - p25,
        'flux_skew': stats.skew(flux) if n > 2 else 0,
        'flux_kurtosis': stats.kurtosis(flux) if n > 3 else 0,

        'flux_p10': p10,
        'flux_p25': p25,
        'flux_p75': p75,
        'flux_p90': p90,

        'snr_mean': np.mean(snr),
        'snr_max': np.max(snr),
        'snr_median': np.median(snr),
        'snr_std': np.std(snr),

        'err_mean': np.mean(flux_err),
        'err_std': np.std(flux_err),
    }

    # Sampling pattern of the band (needs at least two epochs).
    if n > 1:
        gaps = np.diff(time)
        feats['time_span'] = time[-1] - time[0]
        feats['cadence_mean'] = np.mean(gaps)
        feats['cadence_std'] = np.std(gaps)
    else:
        feats['time_span'] = 0
        feats['cadence_mean'] = 0
        feats['cadence_std'] = 0

    if n_det > 0:
        det_flux = flux[detections]
        det_time = time[detections]

        feats['det_flux_mean'] = np.mean(det_flux)
        feats['det_flux_max'] = np.max(det_flux)
        feats['det_duration'] = det_time[-1] - det_time[0] if len(det_time) > 1 else 0

        peak_idx = np.argmax(det_flux)
        feats['peak_flux'] = det_flux[peak_idx]
        # Position of the peak inside the detection window; 0.5 is the
        # neutral value used when the window has zero length.
        if feats['det_duration'] > 0:
            feats['peak_time_rel'] = (det_time[peak_idx] - det_time[0]) / (feats['det_duration'] + 1)
        else:
            feats['peak_time_rel'] = 0.5
        feats['_peak_time'] = det_time[peak_idx]  # For phase features

        # Rise: first detection -> peak (undefined if the peak is the first point).
        if peak_idx > 0:
            rise_dt = det_time[peak_idx] - det_time[0]
            rise_df = det_flux[peak_idx] - det_flux[0]
            feats['rise_time'] = rise_dt
            feats['rise_rate'] = rise_df / (rise_dt + 1e-10)
        else:
            feats['rise_time'] = 0
            feats['rise_rate'] = 0

        # Decay: peak -> last detection (undefined if the peak is the last point).
        if peak_idx < len(det_flux) - 1:
            decay_dt = det_time[-1] - det_time[peak_idx]
            decay_df = det_flux[peak_idx] - det_flux[-1]
            feats['decay_time'] = decay_dt
            feats['decay_rate'] = decay_df / (decay_dt + 1e-10)
        else:
            feats['decay_time'] = 0
            feats['decay_rate'] = 0

        if len(det_flux) > 1:
            feats['variability'] = np.std(det_flux) / (np.mean(det_flux) + 1e-10)
            feats['rms'] = np.sqrt(np.mean(det_flux**2))
        else:
            feats['variability'] = 0
            feats['rms'] = det_flux[0]
    else:
        # No detections: keep the key set (and its order) stable with zeros.
        for key in ['det_flux_mean', 'det_flux_max', 'det_duration', 'peak_flux',
                    'peak_time_rel', '_peak_time', 'rise_time', 'rise_rate', 'decay_time',
                    'decay_rate', 'variability', 'rms']:
            feats[key] = 0

    feats['frac_above_mean'] = np.sum(flux > flux_mean) / n

    # Linear trend over the whole band. The fit stays best-effort (a failed
    # regression yields zeros), but only ordinary exceptions are absorbed so
    # that Ctrl-C / SystemExit still stop a long extraction run.
    feats['trend_slope'] = 0
    feats['trend_r2'] = 0
    if n >= 3:
        try:
            fit = stats.linregress(time, flux)
        except Exception:
            fit = None
        if fit is not None:
            feats['trend_slope'] = fit.slope
            feats['trend_r2'] = fit.rvalue**2

    return feats

In [4]:
def extract_color_features(band_data):
    """Build cross-band colour features from per-band flux levels (full v1 set).

    ``band_data`` maps a band name to a dict holding ``'flux'`` and
    ``'flux_err'`` arrays. For each band in ``BANDS`` a representative mean
    and peak flux are taken from the detections (S/N > 3) when there are any,
    otherwise from all points; bands that are absent or empty contribute 0.

    Returns a dict with magnitude-style colours for fixed band pairs (mean
    and peak based), blue/red flux fractions and peak-band indicator flags.

    NOTE(review): absent bands enter the comparisons as 0, so when every
    other peak is <= 0 an absent band can still "win" the ``peak_band_is_*``
    flags. Left untouched here to keep the v1 feature values unchanged.
    """
    level = {}   # representative mean flux per band
    peak = {}    # representative peak flux per band

    for band in BANDS:
        entry = band_data.get(band)
        if entry is None or len(entry['flux']) == 0:
            level[band] = 0
            peak[band] = 0
            continue

        flux = entry['flux']
        detected = flux / (entry['flux_err'] + 1e-10) > 3
        # Prefer the detected points; fall back to the full series otherwise.
        sample = flux[detected] if detected.any() else flux
        level[band] = np.mean(sample)
        peak[band] = np.max(sample)

    colors = {}

    # Colours are only defined when both mean fluxes are strictly positive.
    for b1, b2 in [('u', 'g'), ('u', 'r'), ('u', 'i'), ('g', 'r'),
                   ('g', 'i'), ('r', 'i'), ('i', 'z'), ('z', 'y')]:
        if not (level[b1] > 0 and level[b2] > 0):
            colors[f'color_{b1}_{b2}'] = 0
            colors[f'color_peak_{b1}_{b2}'] = 0
            continue
        colors[f'color_{b1}_{b2}'] = -2.5 * np.log10(level[b1] / level[b2])
        peak_ratio = (peak[b1] + 1e-10) / (peak[b2] + 1e-10)
        colors[f'color_peak_{b1}_{b2}'] = -2.5 * np.log10(peak_ratio)

    # Share of the summed mean flux carried by the blue (u, g) bands.
    blue_flux = sum(level[b] for b in ('u', 'g'))
    red_flux = sum(level[b] for b in ('r', 'i', 'z', 'y'))
    total_flux = blue_flux + red_flux

    colors['blue_fraction'] = blue_flux / (total_flux + 1e-10)
    colors['u_fraction'] = level['u'] / (total_flux + 1e-10)
    colors['g_fraction'] = level['g'] / (total_flux + 1e-10)
    colors['blue_red_ratio'] = blue_flux / (red_flux + 1e-10)

    colors['u_dominance'] = level['u'] / (np.max(list(level.values())) + 1e-10)

    # Which band holds the brightest peak.
    top_peak = max(peak.values())
    colors['peak_band_is_u'] = 1 if peak['u'] == top_peak else 0
    colors['peak_band_is_g'] = 1 if peak['g'] == top_peak else 0
    best_blue = max(peak['u'], peak['g'])
    best_red = max(peak['r'], peak['i'], peak['z'], peak['y'])
    colors['peak_band_is_blue'] = 1 if best_blue >= best_red else 0

    return colors

In [5]:
def extract_temporal_features(band_data):
    """Derive timing features that combine all bands (full v1 set).

    ``band_data`` maps a band name to a dict with ``'flux'``, ``'flux_err'``
    and ``'time'`` arrays. Detections are points with S/N > 3.

    Returns a dict with the overall and detection-only time spans and counts,
    the spread of per-band peak times, the u-r and g-r peak delays, the
    number of bands with at least one detection, and ``_global_peak_time``
    (time of the brightest detection over all bands, or ``None``), which the
    caller pops before storing the features.
    """
    obs_times = []
    det_times = []
    peak_times = {}
    brightest_flux = 0
    brightest_time = None

    for band in BANDS:
        entry = band_data.get(band)
        if entry is None or len(entry['flux']) == 0:
            continue

        flux = entry['flux']
        time = entry['time']
        obs_times.extend(time)

        hits = flux / (entry['flux_err'] + 1e-10) > 3
        if not hits.any():
            continue

        hit_time = time[hits]
        hit_flux = flux[hits]
        det_times.extend(hit_time)

        k = np.argmax(hit_flux)
        peak_times[band] = hit_time[k]
        # Strict comparison: on ties the earlier band in BANDS keeps the peak.
        if hit_flux[k] > brightest_flux:
            brightest_flux = hit_flux[k]
            brightest_time = hit_time[k]

    feats = {}

    feats['total_time_span'] = max(obs_times) - min(obs_times) if len(obs_times) > 0 else 0
    feats['total_observations'] = len(obs_times)

    feats['detection_time_span'] = max(det_times) - min(det_times) if len(det_times) > 0 else 0
    feats['total_detections'] = len(det_times)

    # Peak-timing comparisons need peaks in at least two bands.
    feats['peak_time_spread'] = 0
    feats['peak_delay_u_r'] = 0
    feats['peak_delay_g_r'] = 0
    if len(peak_times) >= 2:
        per_band_peaks = list(peak_times.values())
        feats['peak_time_spread'] = max(per_band_peaks) - min(per_band_peaks)
        if 'r' in peak_times:
            if 'u' in peak_times:
                feats['peak_delay_u_r'] = peak_times['u'] - peak_times['r']
            if 'g' in peak_times:
                feats['peak_delay_g_r'] = peak_times['g'] - peak_times['r']

    # A band has an entry in peak_times exactly when it has >= 1 detection.
    feats['n_bands_detected'] = len(peak_times)
    feats['_global_peak_time'] = brightest_time

    return feats

In [6]:
def extract_variability_features(band_data):
    """Measure point-to-point and residual variability per band (full v1 set).

    ``band_data`` maps a band name to a dict with ``'flux'`` and ``'time'``
    arrays. For every band in ``BANDS`` four keys are always written
    (``{band}_flux_diff_std``, ``{band}_flux_diff_mean``,
    ``{band}_residual_std``, ``{band}_cv``); they are 0 when the band is
    absent or has fewer than three points. ``mean_variability`` and
    ``max_variability`` aggregate the per-band coefficients of variation.

    Returns
    -------
    dict
        Flat feature dict as described above.
    """
    feats = {}

    all_variabilities = []

    for band in BANDS:
        if band in band_data and len(band_data[band]['flux']) > 2:
            flux = band_data[band]['flux']
            time = band_data[band]['time']

            # Successive differences, taken in the order the points arrive.
            diffs = np.diff(flux)
            feats[f'{band}_flux_diff_std'] = np.std(diffs)
            feats[f'{band}_flux_diff_mean'] = np.mean(np.abs(diffs))

            if len(flux) > 3:
                sorted_idx = np.argsort(time)
                sorted_flux = flux[sorted_idx]

                # Scatter around a quadratic fitted against the point index.
                # Best-effort: a failed fit yields 0, but only ordinary
                # exceptions are absorbed so Ctrl-C / SystemExit still work.
                try:
                    positions = range(len(sorted_flux))
                    coeffs = np.polyfit(positions, sorted_flux, 2)
                    poly_fit = np.polyval(coeffs, positions)
                    residuals = sorted_flux - poly_fit
                    feats[f'{band}_residual_std'] = np.std(residuals)
                except Exception:
                    feats[f'{band}_residual_std'] = 0
            else:
                feats[f'{band}_residual_std'] = 0

            # Coefficient of variation, normalised by the mean absolute flux.
            cv = np.std(flux) / (np.mean(np.abs(flux)) + 1e-10)
            all_variabilities.append(cv)
            feats[f'{band}_cv'] = cv
        else:
            feats[f'{band}_flux_diff_std'] = 0
            feats[f'{band}_flux_diff_mean'] = 0
            feats[f'{band}_residual_std'] = 0
            feats[f'{band}_cv'] = 0

    if len(all_variabilities) > 0:
        feats['mean_variability'] = np.mean(all_variabilities)
        feats['max_variability'] = np.max(all_variabilities)
    else:
        feats['mean_variability'] = 0
        feats['max_variability'] = 0

    return feats

## Features Mới theo Hướng dẫn Paper (THÊM LÊN TRÊN)

In [7]:
def extract_paper_features(band_data, global_peak_time, redshift=0):
    """Compute the paper-guided features that are added on top of the v1 set.

    Parameters
    ----------
    band_data : dict
        Maps a band name to a dict with ``'flux'``, ``'flux_err'`` and
        ``'time'`` arrays.
    global_peak_time : float or None
        Time of the brightest detection over all bands. The phase counts are
        only filled in when this is not ``None`` and greater than 0.
    redshift : float, optional
        Object redshift; values that are not greater than 0 are treated as
        "no correction" (factor 1).

    Returns
    -------
    dict
        Detection counts before / near / after the peak, the derived
        ``some_color_score``, a redshift-scaled u-r colour, a duration class
        (0, 1 or 2) and, for the g and r bands, lag-1 autocorrelation and
        sign-change rate.
    """
    feats = {}

    # 1. Detection phase features (some_color metrics)
    # Windows relative to the peak: pre < -10, near [-10, 10], post (10, 30].
    # Detections here use SNR_THRESHOLD, not the S/N > 3 cut used elsewhere.
    if global_peak_time is not None and global_peak_time > 0:
        n_det_pre = 0
        n_det_near = 0
        n_det_post = 0
        bands_near_peak = set()

        for band in BANDS:
            if band not in band_data:
                continue
            flux = band_data[band]['flux']
            flux_err = band_data[band]['flux_err']
            time = band_data[band]['time']
            det_mask = flux / (flux_err + 1e-10) > SNR_THRESHOLD
            if not det_mask.any():
                continue

            rel_t = time[det_mask] - global_peak_time
            near = (rel_t >= -10) & (rel_t <= 10)
            n_det_pre += int(np.sum(rel_t < -10))
            n_det_near += int(np.sum(near))
            n_det_post += int(np.sum((rel_t > 10) & (rel_t <= 30)))
            if near.any():
                bands_near_peak.add(band)

        feats['n_det_pre_peak'] = n_det_pre
        feats['n_det_near_peak'] = n_det_near
        feats['n_det_post_peak'] = n_det_post
        feats['n_bands_near_peak'] = len(bands_near_peak)
        feats['has_u_near_peak'] = 1 if 'u' in bands_near_peak else 0

        # some_color score: fraction of the three coverage criteria that hold.
        feats['some_color_score'] = (
            (1 if n_det_pre >= 1 else 0) +
            (1 if len(bands_near_peak) >= 3 else 0) +
            (1 if n_det_post >= 2 else 0)
        ) / 3.0
    else:
        feats['n_det_pre_peak'] = 0
        feats['n_det_near_peak'] = 0
        feats['n_det_post_peak'] = 0
        feats['n_bands_near_peak'] = 0
        feats['has_u_near_peak'] = 0
        feats['some_color_score'] = 0

    # 2. Redshift-corrected features
    z_factor = 1 + redshift if redshift > 0 else 1

    # Mean fluxes over all points (not only detections) for the u-r colour.
    u_flux = 0
    r_flux = 0
    if 'u' in band_data and len(band_data['u']['flux']) > 0:
        u_flux = np.mean(band_data['u']['flux'])
    if 'r' in band_data and len(band_data['r']['flux']) > 0:
        r_flux = np.mean(band_data['r']['flux'])

    if u_flux > 0 and r_flux > 0:
        color_u_r = -2.5 * np.log10(u_flux / r_flux)
        feats['color_u_r_z_norm'] = color_u_r / z_factor
    else:
        feats['color_u_r_z_norm'] = 0

    # 3. Duration classification
    # Longest first-to-last detection span (S/N > 3) found in any single band.
    det_span = 0
    for band in BANDS:
        if band in band_data:
            flux = band_data[band]['flux']
            flux_err = band_data[band]['flux_err']
            time = band_data[band]['time']
            det_mask = flux / (flux_err + 1e-10) > 3
            if np.sum(det_mask) > 1:
                det_time = time[det_mask]
                det_span = max(det_span, det_time[-1] - det_time[0])

    # Thresholds from the original notebook: TDEs last ~400 days, SNe ~100-150 days
    if det_span > 300:
        feats['duration_class'] = 2  # Long (TDE-like)
    elif det_span > 150:
        feats['duration_class'] = 1  # Medium
    else:
        feats['duration_class'] = 0  # Short (SN-like)

    # 4. Smoothness features (AGN vs TDE)
    for band in ['g', 'r']:
        if band in band_data and len(band_data[band]['flux']) > 4:
            flux = band_data[band]['flux']
            time = band_data[band]['time']
            flux_sorted = flux[np.argsort(time)]

            # Lag-1 autocorrelation of the time-ordered flux. Best-effort:
            # NaN (e.g. constant flux) and ordinary failures both map to 0,
            # while Ctrl-C / SystemExit are no longer swallowed.
            try:
                autocorr = np.corrcoef(flux_sorted[:-1], flux_sorted[1:])[0, 1]
                feats[f'{band}_autocorr'] = autocorr if not np.isnan(autocorr) else 0
            except Exception:
                feats[f'{band}_autocorr'] = 0

            # Fraction of consecutive steps where the flux changes direction.
            # (At least five points are guaranteed here, so no length guard.)
            diffs = np.diff(flux_sorted)
            sign_changes = np.sum(np.diff(np.sign(diffs)) != 0)
            feats[f'{band}_sign_change_rate'] = sign_changes / (len(diffs) - 1 + 1e-10)
        else:
            feats[f'{band}_autocorr'] = 0
            feats[f'{band}_sign_change_rate'] = 0

    return feats

## Trích Xuất Features Chủ

In [8]:
def extract_features_for_object(obj_id, lc_df, redshift=0):
    """Assemble the full feature row for one object: v1 features plus paper additions.

    Parameters
    ----------
    obj_id
        Identifier matched against the ``object_id`` column of ``lc_df``.
    lc_df : pandas.DataFrame
        Light-curve table with ``object_id``, ``Filter``, ``Time (MJD)``,
        ``Flux`` and ``Flux_err`` columns.
    redshift : float, optional
        Passed through to ``extract_paper_features``.

    Returns
    -------
    dict or None
        Flat feature dict starting with ``object_id``, or ``None`` when
        ``lc_df`` holds no rows for the object.
    """
    obj_rows = lc_df[lc_df['object_id'] == obj_id]
    if obj_rows.empty:
        return None

    # Split the object's rows by filter, each band sorted by time.
    band_data = {}
    for band in BANDS:
        rows = obj_rows[obj_rows['Filter'] == band].sort_values('Time (MJD)')
        if len(rows) == 0:
            continue
        band_data[band] = {
            'flux': rows['Flux'].values,
            'flux_err': rows['Flux_err'].values,
            'time': rows['Time (MJD)'].values,
        }

    # Keys written as zeros for a band without observations, so every object
    # ends up with the same per-band columns.
    missing_band_keys = (
        'n_obs', 'n_det', 'det_frac', 'flux_mean', 'flux_std', 'flux_median',
        'flux_max', 'flux_min', 'flux_range', 'flux_iqr', 'flux_skew', 'flux_kurtosis',
        'flux_p10', 'flux_p25', 'flux_p75', 'flux_p90', 'snr_mean', 'snr_max',
        'snr_median', 'snr_std', 'err_mean', 'err_std', 'time_span', 'cadence_mean',
        'cadence_std', 'det_flux_mean', 'det_flux_max', 'det_duration', 'peak_flux',
        'peak_time_rel', 'rise_time', 'rise_rate', 'decay_time', 'decay_rate',
        'variability', 'rms', 'frac_above_mean', 'trend_slope', 'trend_r2',
    )

    features = {'object_id': obj_id}

    # V1 ORIGINAL FEATURES (ALL): per-band statistics, prefixed with the band.
    for band in BANDS:
        if band not in band_data:
            for key in missing_band_keys:
                features[f'{band}_{key}'] = 0
            continue

        series = band_data[band]
        band_feats = extract_band_statistics(series['flux'], series['flux_err'], series['time'])
        for key, value in band_feats.items():
            if key.startswith('_'):
                continue  # internal helper values are not features
            features[f'{band}_{key}'] = value

    # V1 color features
    features.update(extract_color_features(band_data))

    # V1 temporal features; the global peak time is only an intermediate value.
    temporal_feats = extract_temporal_features(band_data)
    global_peak_time = temporal_feats.pop('_global_peak_time', None)
    features.update(temporal_feats)

    # V1 variability features
    features.update(extract_variability_features(band_data))

    # NEW: Paper-guided features (ADDITIONS)
    features.update(extract_paper_features(band_data, global_peak_time, redshift))

    return features

## Xử lý Dữ liệu Training

In [9]:
print("="*60)
print("Processing TRAINING data from ALL 20 splits...")
print("="*60)

# Redshift per object and the set of ids that belong to the training log.
train_z_lookup = dict(zip(train_log['object_id'], train_log['Z']))
train_features_list = []
train_object_ids = set(train_log['object_id'].values)

for split_num in range(1, N_SPLITS + 1):
    split_name = f'split_{split_num:02d}'
    lc_file = DATA_DIR / split_name / 'train_full_lightcurves.csv'
    
    if not lc_file.exists():
        print(f"  {split_name}: train file not found, skipping")
        continue
    
    print(f"  Processing {split_name}...")
    lc_df = pd.read_csv(lc_file)
    
    # Keep first-appearance order so the feature rows come out in file order.
    object_ids = lc_df['object_id'].unique()
    object_ids = [oid for oid in object_ids if oid in train_object_ids]
    
    # Index the rows by object once. Handing each object only its own rows
    # avoids a boolean mask over the whole split frame for every object.
    rows_by_object = lc_df.groupby('object_id', sort=False)
    
    for obj_id in tqdm(object_ids, desc=f"  {split_name}", leave=False):
        z = train_z_lookup.get(obj_id, 0)
        obj_lc = rows_by_object.get_group(obj_id)
        feats = extract_features_for_object(obj_id, obj_lc, redshift=z)
        if feats is not None:
            train_features_list.append(feats)
    
    print(f"    Processed {len(object_ids)} objects, total: {len(train_features_list)}")

# Build the frame once from the collected row dicts.
train_features = pd.DataFrame(train_features_list)
print(f"\nTotal training features: {len(train_features)}")

Processing TRAINING data from ALL 20 splits...
  Processing split_01...


                                                             

    Processed 155 objects, total: 155
  Processing split_02...


                                                             

    Processed 170 objects, total: 325
  Processing split_03...


                                                             

    Processed 138 objects, total: 463
  Processing split_04...


                                                             

    Processed 145 objects, total: 608
  Processing split_05...


                                                             

    Processed 165 objects, total: 773
  Processing split_06...


                                                             

    Processed 155 objects, total: 928
  Processing split_07...


                                                             

    Processed 165 objects, total: 1093
  Processing split_08...


                                                             

    Processed 162 objects, total: 1255
  Processing split_09...


                                                             

    Processed 128 objects, total: 1383
  Processing split_10...


                                                             

    Processed 144 objects, total: 1527
  Processing split_11...


                                                             

    Processed 146 objects, total: 1673
  Processing split_12...


                                                             

    Processed 155 objects, total: 1828
  Processing split_13...


                                                             

    Processed 143 objects, total: 1971
  Processing split_14...


                                                             

    Processed 154 objects, total: 2125
  Processing split_15...


                                                             

    Processed 158 objects, total: 2283
  Processing split_16...


                                                             

    Processed 155 objects, total: 2438
  Processing split_17...


                                                             

    Processed 153 objects, total: 2591
  Processing split_18...


                                                             

    Processed 152 objects, total: 2743
  Processing split_19...


                                                             

    Processed 147 objects, total: 2890
  Processing split_20...


                                                             

    Processed 153 objects, total: 3043

Total training features: 3043


## Xử lý Dữ liệu Test

In [10]:
print("\n" + "="*60)
print("Processing TEST data from ALL 20 splits...")
print("="*60)

# Redshift per object and the set of ids that belong to the test log.
test_z_lookup = dict(zip(test_log['object_id'], test_log['Z']))
test_features_list = []
test_object_ids = set(test_log['object_id'].values)

for split_num in range(1, N_SPLITS + 1):
    split_name = f'split_{split_num:02d}'
    lc_file = DATA_DIR / split_name / 'test_full_lightcurves.csv'
    
    if not lc_file.exists():
        print(f"  {split_name}: test file not found, skipping")
        continue
    
    print(f"  Processing {split_name}...")
    lc_df = pd.read_csv(lc_file)
    
    # Keep first-appearance order so the feature rows come out in file order.
    object_ids = lc_df['object_id'].unique()
    object_ids = [oid for oid in object_ids if oid in test_object_ids]
    
    # Index the rows by object once. Handing each object only its own rows
    # avoids a boolean mask over the whole split frame for every object.
    rows_by_object = lc_df.groupby('object_id', sort=False)
    
    for obj_id in tqdm(object_ids, desc=f"  {split_name}", leave=False):
        z = test_z_lookup.get(obj_id, 0)
        obj_lc = rows_by_object.get_group(obj_id)
        feats = extract_features_for_object(obj_id, obj_lc, redshift=z)
        if feats is not None:
            test_features_list.append(feats)
    
    print(f"    Processed {len(object_ids)} objects, total: {len(test_features_list)}")

# Build the frame once from the collected row dicts.
test_features = pd.DataFrame(test_features_list)
print(f"\nTotal test features: {len(test_features)}")


Processing TEST data from ALL 20 splits...
  Processing split_01...


                                                             

    Processed 364 objects, total: 364
  Processing split_02...


                                                             

    Processed 414 objects, total: 778
  Processing split_03...


                                                             

    Processed 338 objects, total: 1116
  Processing split_04...


                                                             

    Processed 332 objects, total: 1448
  Processing split_05...


                                                             

    Processed 375 objects, total: 1823
  Processing split_06...


                                                             

    Processed 374 objects, total: 2197
  Processing split_07...


                                                             

    Processed 398 objects, total: 2595
  Processing split_08...


                                                             

    Processed 387 objects, total: 2982
  Processing split_09...


                                                             

    Processed 289 objects, total: 3271
  Processing split_10...


                                                             

    Processed 331 objects, total: 3602
  Processing split_11...


                                                             

    Processed 325 objects, total: 3927
  Processing split_12...


                                                             

    Processed 353 objects, total: 4280
  Processing split_13...


                                                             

    Processed 379 objects, total: 4659
  Processing split_14...


                                                             

    Processed 351 objects, total: 5010
  Processing split_15...


                                                             

    Processed 342 objects, total: 5352
  Processing split_16...


                                                             

    Processed 354 objects, total: 5706
  Processing split_17...


                                                             

    Processed 351 objects, total: 6057
  Processing split_18...


                                                             

    Processed 345 objects, total: 6402
  Processing split_19...


                                                             

    Processed 375 objects, total: 6777
  Processing split_20...


                                                             

    Processed 358 objects, total: 7135

Total test features: 7135


## Gộp & Làm sạch

In [11]:
# Metadata columns attached from the logs.
# NOTE(review): the test table receives `Z_err` while the training table
# does not — confirm the training notebook expects this difference.
train_meta_cols = ['Z', 'EBV', 'target']
test_meta_cols = ['Z', 'Z_err', 'EBV']

# Dropping any previously merged metadata first makes this cell safe to
# re-run: without it a second run would create `Z_x` / `Z_y`-style
# duplicate columns. On a fresh kernel the drop is a no-op.
train_features = (
    train_features
    .drop(columns=train_meta_cols, errors='ignore')
    .merge(train_log[['object_id'] + train_meta_cols], on='object_id', how='left')
)
test_features = (
    test_features
    .drop(columns=test_meta_cols, errors='ignore')
    .merge(test_log[['object_id'] + test_meta_cols], on='object_id', how='left')
)

# Replace missing and infinite values with 0.
train_features = train_features.fillna(0).replace([np.inf, -np.inf], 0)
test_features = test_features.fillna(0).replace([np.inf, -np.inf], 0)

print(f"Train shape: {train_features.shape}")
print(f"Test shape: {test_features.shape}")

Train shape: (3043, 308)
Test shape: (7135, 308)


In [12]:
# Write both feature tables to the working directory (no index column).
for feature_table, csv_name in (
    (train_features, 'train_features.csv'),
    (test_features, 'test_features.csv'),
):
    feature_table.to_csv(csv_name, index=False)

# Closing banner. The count excludes two columns of the training table
# (presumably `object_id` and `target` — confirm).
banner = "="*60
print("\n" + banner)
print("v1.2 FEATURE ENGINEERING COMPLETE")
print(banner)
print(f"Features: {len(train_features.columns) - 2}")
print(f"  = v1 original (~280) + paper additions (~15)")
print("\nNext: Run 02_model_training.ipynb")


v1.2 FEATURE ENGINEERING COMPLETE
Features: 306
  = v1 original (~280) + paper additions (~15)

Next: Run 02_model_training.ipynb
