# **Log Data Processing**

### **Load Packages**

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.collections as mcoll  # For PolyCollection
import numpy as np
import os
import warnings
from scipy.signal import correlate, find_peaks
from scipy.ndimage import gaussian_filter1d

# Create interaction features
from sklearn.preprocessing import PolynomialFeatures
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, f_regression
from sklearn.ensemble import RandomForestRegressor, HistGradientBoostingRegressor
import xgboost as xgb
import lightgbm as lgb
from joblib import Parallel, delayed

## **Functions for Data Cleaning**

### Functions for cleaning artifacts and noise

In [None]:
def compute_peak_valley_shift(signal_a, signal_b, common_depth):
    """
    Compute a candidate depth shift by pairing the two largest peaks of each signal.

    For each signal:
      - Find all local maxima (falling back to the global maximum when
        the signal has no interior local maximum)
      - Keep the two peaks with the largest amplitude
      - Order those peaks by depth (array position)

    The peaks are then paired in depth order and the candidate shift is the
    mean of ``depth(peak_b) - depth(peak_a)`` over the pairs. When one signal
    yields fewer peaks than the other, only the leading pairs are used.

    Args:
        signal_a (ndarray): First signal (e.g., HRMS segment)
        signal_b (ndarray): Second signal (e.g., density, MS, or CT)
        common_depth (ndarray): Depth of each sample; indexed with the peak
            positions found in both signals

    Returns:
        float: The candidate shift (in depth units) based on peak matching.
            Returns 0.0 when either signal is empty.
    """
    signal_a = np.asarray(signal_a)
    signal_b = np.asarray(signal_b)
    common_depth = np.asarray(common_depth)

    # An empty signal has no peak and no global maximum (np.argmax would raise),
    # so there is nothing to match.
    if signal_a.size == 0 or signal_b.size == 0:
        return 0.0

    def _top_two_peaks(signal):
        """Return the positions of the two highest peaks, ordered by position."""
        peaks, _ = find_peaks(signal)
        if len(peaks) == 0:
            # Monotonic or flat signal: use the global maximum instead.
            peaks = np.array([np.argmax(signal)])
        strongest = sorted(peaks, key=lambda i: signal[i], reverse=True)[:2]
        return sorted(strongest)

    top_peaks_a = _top_two_peaks(signal_a)
    top_peaks_b = _top_two_peaks(signal_b)

    # zip() stops at the shorter list, i.e. min(len(a), len(b)) pairs.
    peak_shifts = [common_depth[idx_b] - common_depth[idx_a]
                   for idx_a, idx_b in zip(top_peaks_a, top_peaks_b)]
    return float(np.mean(peak_shifts))

def compute_candidate_shift(signal_a, signal_b, common_depth, 
                            w_corr=0.2, w_peak=0.7):
    """
    Compute a candidate depth shift between two signals based on:
      - Cross-correlation, and
      - Matching top 2 peaks

    The final candidate shift is a weighted combination of these two methods.
    Both candidates use the same sign convention: a positive value means
    features of ``signal_b`` sit deeper than the matching features of
    ``signal_a``, i.e. it is the amount to ADD to ``signal_a``'s depths to
    align it with ``signal_b``.

    Args:
        signal_a (ndarray): First signal (e.g., HRMS segment).
        signal_b (ndarray): Second signal (e.g., density, MS, or CT).
        common_depth (ndarray): Depth of each sample of both signals. The
            spacing of the first two entries is used as the depth step.
        w_corr (float): Weight for the cross-correlation candidate.
        w_peak (float): Weight for the peak/valley candidate.

    Returns:
        float: The weighted candidate shift (in depth units). The
            cross-correlation term is dropped (weight 0) when fewer than two
            depth samples are available or when the correlation is not finite
            (e.g. NaNs in either signal).
    """
    # --- Candidate 1: Cross-correlation shift ---
    sigma = 3  # Gaussian smoothing width, in samples
    cross_corr_shift = 0.0

    if len(common_depth) >= 2:
        # Smooth in floating point so integer inputs are not truncated.
        a_smoothed = gaussian_filter1d(np.asarray(signal_a, dtype=float), sigma=sigma)
        b_smoothed = gaussian_filter1d(np.asarray(signal_b, dtype=float), sigma=sigma)

        # Remove the mean so the correlation reflects shape, not offset.
        a_centered = a_smoothed - np.mean(a_smoothed)
        b_centered = b_smoothed - np.mean(b_smoothed)

        corr = correlate(a_centered, b_centered, mode='full')
        if np.all(np.isfinite(corr)):
            # 'full' output covers lags -(len(b) - 1) .. len(a) - 1.
            lags = np.arange(-(len(b_centered) - 1), len(a_centered))
            best_lag = lags[np.argmax(corr)]
            depth_step = common_depth[1] - common_depth[0]
            # correlate(a, b) peaks at the lag of a RELATIVE TO b (negative
            # when b's features are deeper). Negate it so this candidate has
            # the same "b minus a" sign as the peak-matching candidate.
            cross_corr_shift = -best_lag * depth_step
        else:
            # np.argmax on a NaN-contaminated correlation returns the first
            # NaN, which would masquerade as a large lag. Ignore it.
            w_corr = 0.0
    else:
        # A depth step cannot be derived from fewer than two samples.
        w_corr = 0.0

    # --- Candidate 2: Peak shift ---
    candidate_peak_valley = compute_peak_valley_shift(signal_a, signal_b, common_depth)

    # Return the weighted combination.
    return w_corr * cross_corr_shift + w_peak * candidate_peak_valley

# Maps each supported threshold condition to the pandas Series comparison
# method that implements it, so conditions are dispatched without eval().
_THRESHOLD_METHODS = {'>': 'gt', '<': 'lt', '<=': 'le', '>=': 'ge'}


def _buffered_threshold_mask(values, condition, threshold_value, buffer_size):
    """Return a boolean row mask of threshold hits, widened by buffer_size rows on each side."""
    compare = getattr(values, _THRESHOLD_METHODS[condition])
    hits = compare(float(threshold_value)).to_numpy()
    reach = int(buffer_size)
    mask = np.zeros(len(values), dtype=bool)
    for pos in np.flatnonzero(hits):
        # Slicing clips at the end of the array; clamp both ends at 0 so a
        # negative bound can never wrap around.
        mask[max(0, pos - reach):max(0, pos + reach + 1)] = True
    return mask


def _rescale_depth(frame, core_length):
    """Stretch SB_DEPTH_cm in place so that its maximum equals core_length."""
    depth = frame['SB_DEPTH_cm']
    frame['SB_DEPTH_cm'] = depth * (core_length / depth.max())


def _contiguous_runs(flags):
    """Split the True positions of a boolean array into lists of consecutive positions."""
    positions = np.flatnonzero(flags)
    if positions.size == 0:
        return []
    breaks = np.flatnonzero(np.diff(positions) != 1) + 1
    return [run.tolist() for run in np.split(positions, breaks)]


def _reference_candidate(seg_values, seg_reference, seg_depth):
    """Return (candidate shift, |correlation|) of one HRMS segment against one reference curve."""
    if np.all(np.isnan(seg_reference)):
        return 0.0, 0.0
    shift = compute_candidate_shift(seg_values, seg_reference, seg_depth)
    corr = np.abs(np.corrcoef(seg_values, seg_reference)[0, 1])
    if np.isnan(corr):
        corr = 0.0
    return shift, corr


def preprocess_core_data(data_config, shift_limit_multiplier=3.0):
    """
    Preprocess core data by cleaning and scaling depth values using direct file paths.

    For every input file that exists (CT, RGB, MST, hi-res MS):
      1. Values matching a configured threshold are set to NaN, together with
         ``buffer_size`` rows on each side.
      2. ``SB_DEPTH_cm`` is rescaled so its maximum equals ``core_length``.
      3. The cleaned table is written as CSV.

    Hi-res MS is only processed when a CT or density reference curve is
    available. Each run of consecutive valid hi-res MS samples is shifted in
    depth by a consensus of the shifts suggested by the density and CT curves,
    weighted by their absolute correlation with the segment. A shift is applied
    only if it does not exceed ``shift_limit_multiplier`` times the smaller
    depth step to the neighbouring rows.

    Args:
        data_config (dict): Reads the keys ``mother_dir``, ``data_folder``,
            ``clean_output_folder``, ``core_name``, ``core_length``,
            ``clean_file_paths`` (with entries ``ct``, ``rgb``, ``mst``,
            ``hrms``) and, optionally, ``thresholds``: a mapping of
            parameter name -> ``(condition, threshold_value, buffer_size)``
            where condition is one of ``>``, ``<``, ``<=``, ``>=``.
        shift_limit_multiplier (float): Multiplier applied to the neighbouring
            depth step to obtain the maximum allowed segment shift.

    Returns:
        None. Results are written to the configured CSV paths.

    Raises:
        ValueError: If a threshold uses an unsupported condition.
    """
    # NOTE(review): outputs are written to mother_dir + clean_file_paths[...],
    # but only mother_dir + clean_output_folder is created below. Presumably
    # the clean_file_paths entries already include that folder - confirm
    # against the config.
    valid_conditions = list(_THRESHOLD_METHODS)
    thresholds = data_config.get('thresholds', {})
    for param, threshold in thresholds.items():
        if threshold[0] not in _THRESHOLD_METHODS:
            raise ValueError(f"Invalid condition '{threshold[0]}' for {param}. Must be one of: {valid_conditions}")

    mother_dir = data_config['mother_dir']
    core_name = data_config['core_name']

    # Create output directory
    os.makedirs(mother_dir + data_config['clean_output_folder'], exist_ok=True)

    ct_data = None
    mst_data = None

    # ---- CT ----
    ct_path = mother_dir + data_config['data_folder'] + f"{core_name}_CT.csv"
    if os.path.exists(ct_path):
        ct_data = pd.read_csv(ct_path).astype('float64')

        if 'ct' in thresholds:
            ct_mask = _buffered_threshold_mask(ct_data['CT'], *thresholds['ct'])
            ct_data.loc[ct_mask, ['CT', 'CT_std']] = np.nan

        _rescale_depth(ct_data, data_config['core_length'])
        ct_data.to_csv(mother_dir + data_config['clean_file_paths']['ct'], index=False)

    # ---- RGB ----
    rgb_path = mother_dir + data_config['data_folder'] + f"{core_name}_RGB.csv"
    if os.path.exists(rgb_path):
        rgb_data = pd.read_csv(rgb_path).astype('float64')

        rgb_columns = ['R', 'G', 'B', 'Lumin']
        rgb_mask = np.zeros(len(rgb_data), dtype=bool)
        for col in rgb_columns:
            if col.lower() in thresholds:
                rgb_mask |= _buffered_threshold_mask(rgb_data[col], *thresholds[col.lower()])

        # A hit in any channel blanks every channel and its std column.
        if rgb_mask.any():
            rgb_data.loc[rgb_mask, rgb_columns + [f'{col}_std' for col in rgb_columns]] = np.nan

        _rescale_depth(rgb_data, data_config['core_length'])
        rgb_data.to_csv(mother_dir + data_config['clean_file_paths']['rgb'], index=False)

    # Determine subfolder paths based on core name (Melville99 is the default).
    if core_name.startswith('RR02'):
        mst_subfolder = "OSU orignal dataset/R-V_Revelle02/Calibrated_MST/"
        hrms_subfolder = "OSU orignal dataset/R-V_Revelle02/RR0207_point_mag/"
    else:
        mst_subfolder = "OSU orignal dataset/R-V_Melville99/Calibrated_MST/"
        hrms_subfolder = "OSU orignal dataset/R-V_Melville99/M9907_point_mag/"

    # ---- MST ----
    mst_path = mother_dir + mst_subfolder + f"{core_name}_MST.csv"
    if os.path.exists(mst_path):
        mst_data = pd.read_csv(mst_path).astype('float64')

        mst_columns = {
            'MS': 'ms',
            'PWVel_m/s': 'pwvel',
            'PWAmp': 'pwamp',
            'Den_gm/cc': 'den',
            'ElecRes_ohmm': 'elecres'
        }

        # Density outliers also invalidate MS (applied inside the loop below).
        density_mask = np.zeros(len(mst_data), dtype=bool)
        if 'Den_gm/cc' in mst_data.columns and 'den' in thresholds:
            density_mask = _buffered_threshold_mask(mst_data['Den_gm/cc'], *thresholds['den'])

        for column, param_key in mst_columns.items():
            if column in mst_data.columns and param_key in thresholds:
                column_mask = _buffered_threshold_mask(mst_data[column], *thresholds[param_key])
                if column == 'MS':
                    column_mask |= density_mask
                mst_data.loc[column_mask, column] = np.nan

        # Skip rescaling/writing when nothing but the depth column holds data.
        if not mst_data.drop('SB_DEPTH_cm', axis=1).isna().all().all():
            _rescale_depth(mst_data, data_config['core_length'])
            mst_data.to_csv(mother_dir + data_config['clean_file_paths']['mst'], index=False)

    # ---- Hi-res MS ----
    hrms_path = mother_dir + hrms_subfolder + f"{core_name}_ptMS.csv"
    if os.path.exists(hrms_path):
        hrms_data = pd.read_csv(hrms_path).astype('float64')

        has_density = mst_data is not None and 'Den_gm/cc' in mst_data.columns
        has_ct = ct_data is not None

        # Process Hi-res MS data only if at least one reference curve exists
        if has_density or has_ct:
            # NOTE(review): the reference depths may already have been rescaled
            # to core_length above, while the HRMS depths are only rescaled at
            # the end of this block - confirm this is intended.
            hrms_depth = hrms_data['SB_DEPTH_cm'].values

            # Resample the reference curves onto the HRMS depths (all-NaN when absent).
            if has_density:
                density_resampled = np.interp(hrms_depth,
                                              mst_data['SB_DEPTH_cm'].values,
                                              mst_data['Den_gm/cc'].values)
            else:
                density_resampled = np.full_like(hrms_depth, np.nan)

            if has_ct:
                ct_resampled = np.interp(hrms_depth,
                                         ct_data['SB_DEPTH_cm'].values,
                                         ct_data['CT'].values)
            else:
                ct_resampled = np.full_like(hrms_depth, np.nan)

            # Apply thresholds to HRMS data
            if 'hiresms' in thresholds:
                hrms_mask = _buffered_threshold_mask(hrms_data['hiresMS'], *thresholds['hiresms'])
                hrms_data.loc[hrms_mask, 'hiresMS'] = np.nan

            # Identify continuous segments in HRMS data
            segments = _contiguous_runs(hrms_data['hiresMS'].notna().to_numpy())

            # Process each HRMS segment
            for seg in segments:
                seg_depth = hrms_data.loc[seg, 'SB_DEPTH_cm']
                seg_values = hrms_data.loc[seg, 'hiresMS']
                seg_density = density_resampled[seg]
                seg_ct = ct_resampled[seg]

                # Check if at least one reference curve has valid data
                if np.all(np.isnan(seg_density)) and np.all(np.isnan(seg_ct)):
                    print(f"Warning: No valid reference data for segment at depth {seg_depth.iloc[0]:.2f}")
                    continue

                # Candidate shift and |correlation| from each reference curve
                candidate_shift_density, corr_density = _reference_candidate(
                    seg_values.values, seg_density, seg_depth.values)
                candidate_shift_ct, corr_ct = _reference_candidate(
                    seg_values.values, seg_ct, seg_depth.values)

                # Determine the maximum allowed shift based on neighboring gaps
                first, last = seg[0], seg[-1]
                if first > 0:
                    gap_before = hrms_data.at[first, 'SB_DEPTH_cm'] - hrms_data.at[first - 1, 'SB_DEPTH_cm']
                else:
                    gap_before = np.inf
                if last < len(hrms_data) - 1:
                    gap_after = hrms_data.at[last + 1, 'SB_DEPTH_cm'] - hrms_data.at[last, 'SB_DEPTH_cm']
                else:
                    gap_after = np.inf
                gap_based_shift = min(gap_before, gap_after) * shift_limit_multiplier

                # ---- Consensus Shift ----
                # A candidate that exceeds the limit gets no vote.
                if abs(candidate_shift_density) > gap_based_shift:
                    corr_density = 0.0
                if abs(candidate_shift_ct) > gap_based_shift:
                    corr_ct = 0.0

                # Calculate weights based on correlations
                total_corr = corr_density + corr_ct
                if total_corr > 0:
                    w_density = corr_density / total_corr
                    w_ct = corr_ct / total_corr
                else:
                    w_density, w_ct = 0.4, 0.6

                consensus_shift = (w_density * candidate_shift_density +
                                   w_ct * candidate_shift_ct)

                # Apply the consensus shift by moving hiresMS values to new positions
                if abs(consensus_shift) <= gap_based_shift:
                    depth_column = hrms_data['SB_DEPTH_cm']
                    # Row whose depth is closest to each shifted sample depth
                    target_indices = [
                        (depth_column - (depth_column.at[idx] + consensus_shift)).abs().idxmin()
                        for idx in seg
                    ]

                    # Copy the values out before clearing the original rows,
                    # because targets may overlap the segment itself.
                    original_values = hrms_data.loc[seg, 'hiresMS'].copy()
                    hrms_data.loc[seg, 'hiresMS'] = np.nan

                    # Move values to new positions
                    for orig_val, target_idx in zip(original_values, target_indices):
                        hrms_data.at[target_idx, 'hiresMS'] = orig_val
                else:
                    print(f"Warning: Computed shift ({consensus_shift:.2f}) exceeds gap-based shift limit ({gap_based_shift:.2f}) "
                          f"for segment at depths {seg_depth.iloc[0]:.2f}-{seg_depth.iloc[-1]:.2f}")

            # Rescale the depth after shifting
            _rescale_depth(hrms_data, data_config['core_length'])
            hrms_data.to_csv(mother_dir + data_config['clean_file_paths']['hrms'], index=False)

### Function for plotting cleaned core images and logs

In [None]:
def _has_plottable_values(frame, column):
    """Return True when the column exists in the frame and holds at least one non-NaN value."""
    return column in frame.columns and not frame[column].isna().all()


def _add_value_colored_fill(ax, values, depths, cmap, norm):
    """Fill between x=0 and the curve with one quad per depth interval, colored by its mean value."""
    polys = []
    facecolors = []
    for i in range(len(depths) - 1):
        # Intervals touching a NaN sample are left unfilled.
        if np.isnan(values[i]) or np.isnan(values[i + 1]):
            continue
        polys.append([(0, depths[i]), (values[i], depths[i]),
                      (values[i + 1], depths[i + 1]), (0, depths[i + 1])])
        facecolors.append(cmap(norm((values[i] + values[i + 1]) / 2)))
    if polys:
        ax.add_collection(
            mcoll.PolyCollection(polys, facecolors=facecolors, edgecolors='none', alpha=0.95)
        )


def plot_core_logs(data_config, file_type='clean', title=None):
    """
    Plot core images and logs side by side against depth.

    Panels are created, left to right, for whatever is available:
    CT image + CT curve, RGB image + RGB curves, magnetic susceptibility
    (lo-res MST and/or hi-res), then one panel per remaining MST log.

    Args:
        data_config (dict): Reads ``mother_dir``, ``core_name``,
            ``core_length``, ``ct_image_path``, ``rgb_image_path``,
            ``column_configs`` and, depending on ``file_type``, either
            ``clean_file_paths`` + ``clean_output_folder`` or
            ``filled_file_paths`` + ``filled_output_folder``.
        file_type (str): ``'clean'`` to plot cleaned logs; any other value
            plots the ML-filled logs.
        title (str, optional): Figure title. Defaults to
            ``'<core_name> Cleaned Logs'`` or ``'<core_name> ML-Filled Logs'``.

    Returns:
        tuple: ``(fig, axes)`` - the matplotlib figure and its axes (a list
            when there is a single panel, otherwise the array from
            ``plt.subplots``).

    Raises:
        ValueError: If no data file could be loaded or nothing is plottable.
    """
    # Get file paths based on type
    is_clean = file_type == 'clean'
    data_paths = data_config.get('clean_file_paths' if is_clean else 'filled_file_paths', {})
    folder_key = 'clean_output_folder' if is_clean else 'filled_output_folder'

    # Get available column configs
    available_columns = data_config.get('column_configs', {})

    # Only process data types that have both file path and column config
    # NOTE(review): paths here are mother_dir + output folder + file path. If
    # the *_file_paths entries already contain the output folder, the folder
    # is prepended twice - confirm against the config and the writer.
    full_paths = {
        data_type: data_config['mother_dir'] + data_config[folder_key] + data_paths[data_type]
        for data_type in set(data_paths) & set(available_columns)
    }

    # Load images
    ct_img_path = data_config['mother_dir'] + data_config['ct_image_path']
    rgb_img_path = data_config['mother_dir'] + data_config['rgb_image_path']
    ct_img = plt.imread(ct_img_path) if os.path.exists(ct_img_path) else None
    rgb_img = plt.imread(rgb_img_path) if os.path.exists(rgb_img_path) else None

    # Load Core Length and Name
    core_length = data_config['core_length']
    core_name = data_config['core_name']

    if title is None:
        file_type_title = 'Cleaned' if is_clean else 'ML-Filled'
        title = f'{core_name} {file_type_title} Logs'

    # Load available data (tables without a depth column are ignored)
    data = {}
    for key, path in full_paths.items():
        if os.path.exists(path):
            loaded_data = pd.read_csv(path)
            if 'SB_DEPTH_cm' in loaded_data.columns:
                data[key] = loaded_data

    if not data:
        raise ValueError("No valid data files found to plot")

    # Decide once which panels exist; the same decisions drive both the panel
    # count and the plotting below.
    show_ct = ct_img is not None and 'ct' in data
    show_rgb = rgb_img is not None and 'rgb' in data

    mst_config = available_columns.get('mst', {}) if 'mst' in data else {}
    ms_config = mst_config.get('ms')
    has_mst_ms = ms_config is not None and _has_plottable_values(data['mst'], ms_config['data_col'])

    hrms_config = available_columns.get('hrms') if 'hrms' in data else None
    has_hrms = hrms_config is not None and _has_plottable_values(data['hrms'], hrms_config['data_col'])

    other_mst_logs = [
        (log_type, config) for log_type, config in mst_config.items()
        if log_type != 'ms' and _has_plottable_values(data['mst'], config['data_col'])
    ]

    n_plots = (2 * int(show_ct) + 2 * int(show_rgb)
               + int(has_mst_ms or has_hrms) + len(other_mst_logs))
    if n_plots == 0:
        raise ValueError("No valid data to plot")

    # Create figure and axes
    fig, axes = plt.subplots(1, n_plots, figsize=(10, 16), sharey=True)
    if n_plots == 1:
        axes = [axes]
    fig.suptitle(title, fontweight='bold')

    panels = iter(axes)

    # Plot CT if available
    if show_ct:
        # CT image
        image_ax = next(panels)
        image_ax.imshow(ct_img, aspect='auto', extent=[0, 0.5, core_length, 0])
        image_ax.set_xticks([])
        image_ax.set_xlabel('Sediment\nCore\nCT Scan', fontweight='bold', fontsize='small')

        # CT data
        ct_ax = next(panels)
        ct_config = available_columns['ct']
        ct_std = ct_config['std_col']
        ct_depth = data['ct'][ct_config['depth_col']].astype(np.float64)
        ct_values = data['ct'][ct_config['data_col']].astype(np.float64)

        ct_ax.plot(ct_values, ct_depth, color='black', linewidth=0.7)

        if ct_std in data['ct'].columns:
            ct_spread = data['ct'][ct_std].astype(np.float64)
            ct_ax.fill_betweenx(
                ct_depth,
                ct_values - ct_spread,
                ct_values + ct_spread,
                color='black', alpha=0.2, linewidth=0
            )

        # Color-coded CT values
        _add_value_colored_fill(ct_ax, ct_values.values, ct_depth.values,
                                plt.cm.jet, plt.Normalize(300, 1600))

        ct_ax.set_xlabel('CT#\nBrightness', fontweight='bold', fontsize='small')
        ct_ax.grid(True)
        ct_ax.set_xlim(300, None)
        ct_ax.tick_params(axis='x', labelsize='x-small')

    # Plot RGB if available
    if show_rgb:
        # RGB image
        image_ax = next(panels)
        image_ax.imshow(rgb_img, aspect='auto', extent=[0, 0.5, core_length, 0])
        image_ax.set_xticks([])
        image_ax.set_xlabel('Sediment\nCore\nPhoto', fontweight='bold', fontsize='small')

        # RGB data
        rgb_ax = next(panels)
        rgb_config = available_columns['rgb']
        rgb_frame = data['rgb']
        rgb_depth = rgb_frame[rgb_config['depth_col']].astype(np.float64)
        channel_colors = ['red', 'green', 'blue']

        for col, std, color in zip(rgb_config['data_cols'][:3], rgb_config['std_cols'][:3], channel_colors):
            if col in rgb_frame.columns:
                channel = rgb_frame[col].astype(np.float64)
                rgb_ax.plot(channel, rgb_depth, color=color, linewidth=0.7)
                if std in rgb_frame.columns:
                    spread = rgb_frame[std].astype(np.float64)
                    rgb_ax.fill_betweenx(
                        rgb_depth,
                        channel - spread,
                        channel + spread,
                        color=color, alpha=0.2, linewidth=0
                    )

        # Luminance plot
        if 'Lumin' in rgb_frame.columns:
            lumin_values = rgb_frame['Lumin'].astype(np.float64).values
            valid_lumin = lumin_values[~np.isnan(lumin_values)]
            if len(valid_lumin) > 0:
                vmin, vmax = valid_lumin.min(), valid_lumin.max()
                # A constant curve cannot be mapped onto a color range.
                if not np.isclose(vmin, vmax):
                    _add_value_colored_fill(rgb_ax, lumin_values, rgb_depth.values,
                                            plt.cm.inferno, plt.Normalize(vmin, vmax))

        rgb_ax.set_xlabel('RGB\nLuminance', fontweight='bold', fontsize='small')
        rgb_ax.grid(True)
        rgb_ax.tick_params(axis='x', labelsize='x-small')

    # Plot MS data if available
    if has_mst_ms or has_hrms:
        ms_ax = next(panels)
        if has_mst_ms:
            ms_ax.plot(
                data['mst'][ms_config['data_col']].astype(np.float64),
                data['mst'][ms_config['depth_col']].astype(np.float64),
                color='darkgray', label='Lo-res', linewidth=0.7
            )
        if has_hrms:
            ms_ax.plot(
                data['hrms'][hrms_config['data_col']].astype(np.float64),
                data['hrms'][hrms_config['depth_col']].astype(np.float64),
                color='black', label='Hi-res', linewidth=0.7
            )
        ms_ax.tick_params(axis='x', labelsize='x-small')
        ms_ax.set_xlabel('Magnetic\nSusceptibility', fontweight='bold', fontsize='small')
        ms_ax.grid(True)

    # Plot other MST logs if available
    mst_labels = {
        'density': 'Density\n(g/cc)',
        'pwvel': 'P-wave\nVelocity\n(m/s)',
        'pwamp': 'P-wave\nAmplitude',
        'elecres': 'Electrical\nResistivity\n(ohm-m)'
    }
    mst_colors = {
        'density': 'orange',
        'pwvel': 'purple',
        'pwamp': 'purple',
        'elecres': 'brown'
    }

    for log_type, config in other_mst_logs:
        log_ax = next(panels)
        log_ax.plot(
            data['mst'][config['data_col']].astype(np.float64),
            data['mst'][config['depth_col']].astype(np.float64),
            color=mst_colors.get(log_type, 'black'),
            linewidth=0.7
        )
        # Fall back to the raw log name for logs without a predefined label.
        log_ax.set_xlabel(mst_labels.get(log_type, log_type), fontweight='bold', fontsize='small')
        log_ax.tick_params(axis='x', labelsize='x-small')
        log_ax.grid(True)
        if log_type == 'density':
            log_ax.set_xlim(1, 2)

    # Set common y-axis properties. The leftmost panel always carries the
    # depth label; explicit limits (core_length at the bottom, 0 at the top)
    # give the depth-increases-downward orientation.
    axes[0].set_ylabel('Depth (cm)')
    for ax in axes:
        ax.set_ylim(core_length, 0)

    plt.tight_layout()
    return fig, axes

## **Functions for Machine Learning to fill data gaps**

### Function for plotting filled data

In [None]:
def plot_filled_data(target_log, original_data, filled_data, core_length, core_name, ML_type = 'ML'):
    """
    Draw one log against depth, overlaying the original curve on the gap-filled curve.

    The filled curve (red) is only drawn when the original log contains at
    least one NaN; the original curve (black) is always drawn on top of it.
    When the original table has a ``<target_log>_std`` column, a +/- one-std
    band is shaded around the original curve. The figure is shown with
    ``plt.show()``; nothing is returned.

    Args:
        target_log (str): Name of the log column to plot
        original_data (pd.DataFrame): Original data; must hold ``SB_DEPTH_cm``
            and ``target_log``
        filled_data (pd.DataFrame): Data with filled gaps; must hold
            ``SB_DEPTH_cm`` and ``target_log``
        core_length (int): Upper limit of the depth axis
        core_name (str): Name of the core, used in the figure title
        ML_type (str): Name of the gap-filling method, used in the figure title
    """
    depth_col = 'SB_DEPTH_cm'
    observed = original_data[target_log]
    gaps_present = observed.isna().any()

    fig, ax = plt.subplots(figsize=(15, 3))

    # Title depends on whether there was anything to fill
    if gaps_present:
        title_suffix = f'Use {ML_type} for Data Gap Filling'
    else:
        title_suffix = "(No Data Gap to be filled by ML)"
    fig.suptitle(f'{core_name} {target_log} Values {title_suffix}', fontweight='bold')

    # Filled curve first, so the original curve is drawn over it
    if gaps_present:
        ax.plot(
            filled_data[depth_col],
            filled_data[target_log],
            color='red',
            label=f'ML Predicted {target_log}',
            linewidth=0.7,
            alpha=0.7,
        )

    ax.plot(
        original_data[depth_col],
        observed,
        color='black',
        label=f'Original {target_log}',
        linewidth=0.7,
    )

    # Optional uncertainty band around the original curve
    std_col = f'{target_log}_std'
    if std_col in original_data.columns:
        spread = original_data[std_col]
        ax.fill_between(
            original_data[depth_col],
            observed - spread,
            observed + spread,
            color='black',
            alpha=0.2,
            linewidth=0,
        )

    # Axis cosmetics
    ax.set_ylabel(f'{target_log}\nBrightness', fontweight='bold', fontsize='small')
    ax.set_xlabel('Depth (cm)')
    ax.grid(True)
    # NOTE(review): the explicit limits set right after this call determine
    # the final axis direction (0 on the left), so the inversion has no
    # visible effect - confirm which orientation is intended.
    ax.invert_xaxis()
    ax.set_xlim(0, core_length)
    ax.tick_params(axis='y', labelsize='x-small')
    ax.legend()

    plt.tight_layout()
    plt.show()

### Functions for Machine Learning Data Gap filling

In [None]:
# Helper Functions for fill_gaps_with_ml

def prepare_feature_data(target_log, All_logs, merge_tolerance):
    """
    Build the merged feature table used to train the gap-filling models.

    The dataframe that lists `target_log` among its columns supplies the target.
    Every other dataframe in `All_logs` is joined onto the target depths with a
    nearest-depth `merge_asof`, and its listed columns become features named
    f'{df_name}_{col}'. Dataframes that list `target_log` are never used as
    feature sources. 'SB_DEPTH_cm' itself is appended as the last feature.

    The returned `merged_data` has exactly one row per row of `target_data`,
    in the same order and with the same index, so a boolean mask computed on
    `target_data` can be applied to `merged_data` directly.

    Args:
        target_log (str): Name of the target column.
        All_logs (dict): {'df_name': (dataframe, [column_names])}.
        merge_tolerance (float): Maximum depth difference (in SB_DEPTH_cm units)
            for a feature row to be matched to a target row.

    Returns:
        tuple: (target_data, merged_data, features) where `target_data` is a
        copy of the target dataframe with float64 depths, `merged_data` holds
        'SB_DEPTH_cm', the target column and all feature columns, and
        `features` is the list of feature column names.

    Raises:
        ValueError: If no dataframe in `All_logs` lists `target_log`.
    """
    depth_col = 'SB_DEPTH_cm'
    order_col = '__target_row_order__'

    # Locate the dataframe that owns the target log
    target_data = None
    for df, cols in All_logs.values():
        if target_log in cols:
            target_data = df.copy()
            break

    if target_data is None:
        raise ValueError(f"Target log '{target_log}' not found in any dataset")

    target_data[depth_col] = target_data[depth_col].astype('float64')

    # merge_asof needs the left side sorted by depth and returns a fresh
    # RangeIndex, so remember each row's original position and restore the
    # original order/index once all merges are done.
    merged_data = target_data[[depth_col, target_log]].copy()
    merged_data[order_col] = range(len(merged_data))
    merged_data = merged_data.sort_values(depth_col, kind='mergesort')

    features = []
    for df_name, (df, cols) in All_logs.items():
        if target_log in cols:  # Skip the target dataframe
            continue

        key_col = f'{depth_col}_{df_name}'
        feature_cols = [col for col in cols if col != depth_col]

        # Only carry the listed columns into the merge; unlisted extras could
        # collide with the target column or with another log's columns.
        right = df[[depth_col] + feature_cols].copy()
        right[depth_col] = right[depth_col].astype('float64')
        for col in feature_cols:
            if right[col].dtype.kind in 'biufc':
                right[col] = right[col].astype('float64')

        # Prefix feature names with the log name and give the depth key a
        # temporary unique name so it cannot clash with the left depth column
        renamed = {col: f'{df_name}_{col}' for col in feature_cols}
        renamed[depth_col] = key_col
        right = right.rename(columns=renamed).sort_values(key_col)

        merged_data = pd.merge_asof(
            merged_data,
            right,
            left_on=depth_col,
            right_on=key_col,
            direction='nearest',
            tolerance=merge_tolerance
        )

        # Rows with no feature depth inside the tolerance come back with a NaN key
        unmatched = merged_data[key_col].isna().sum()
        if unmatched > 0:
            warnings.warn(f"{unmatched} rows did not have a matching depth within tolerance for log '{df_name}'.")

        features.extend(renamed[col] for col in feature_cols)
        merged_data = merged_data.drop(columns=[key_col])

    # Depth itself is used as a feature
    features.append(depth_col)

    # Put the rows back in the target's original order and index
    merged_data = merged_data.sort_values(order_col).drop(columns=[order_col])
    merged_data.index = target_data.index

    return target_data, merged_data, features

def apply_feature_weights(X, method):
    """
    Scale selected feature columns by fixed, method-specific weights.

    Column matching depends on the method:
      - 'xgb': a column is weighted only if its name equals a feature name.
      - 'xgblgbm': a column is weighted if its name equals a feature name or
        ends with f'_{feature}' (the f'{df_name}_{col}' naming produced when
        logs are merged). Each column receives at most one weight; when
        several feature names match, the longest one wins.

    Plain substring matching is deliberately not used: it would apply several
    weights to one column ('MS' is a substring of 'hrms_hiresMS') and would
    touch unrelated columns ('B' is a substring of 'SB_DEPTH_cm').

    Args:
        X (pd.DataFrame): Feature table.
        method (str): ML method name. Only 'xgb' and 'xgblgbm' apply weights.

    Returns:
        pd.DataFrame: A weighted copy of `X` (weighted columns are float32),
        or `X` itself, unchanged, for any other method.
    """
    weight_tables = {
        'xgb': {
            'RGB': {
                'R': 0.1,
                'G': 0.1,
                'B': 0.1,
                'Lumin': 0.32
            },
            'MS': {
                'hiresMS': 3.0,
                'MS': 0.5
            },
            'Physical': {
                'PWVel_m/s': 0.05,
                'PWAmp': 0.05,
                'Den_gm/cc': 3.0
            }
        },
        'xgblgbm': {
            'RGB': {
                'R': 0.3,
                'G': 0.3,
                'B': 0.3,
                'Lumin': 0.3
            },
            'MS': {
                'hiresMS': 3.0,
                'MS': 0.05
            },
            'Physical': {
                'PWVel_m/s': 0.01,
                'PWAmp': 0.01,
                'Den_gm/cc': 0.5
            }
        }
    }

    table = weight_tables.get(method)
    if table is None:
        return X

    # The group level is only for readability; flatten to feature -> weight
    weights = {feature: weight
               for group in table.values()
               for feature, weight in group.items()}
    allow_suffix_match = (method == 'xgblgbm')

    X_weighted = X.copy()
    for col in X_weighted.columns:
        if col in weights:
            matched = col
        elif allow_suffix_match and isinstance(col, str):
            candidates = [feature for feature in weights if col.endswith('_' + feature)]
            # Longest name wins so a column can never collect two weights
            matched = max(candidates, key=len) if candidates else None
        else:
            matched = None

        if matched is not None:
            X_weighted[col] = (X_weighted[col] * weights[matched]).astype('float32')

    return X_weighted


def adjust_gap_predictions(df, gap_mask, ml_preds, target_log):
    """
    Blend ML predictions with linear interpolation inside each contiguous gap.

    For every run of consecutive gap rows that has an observed value on both
    sides, each prediction is pulled towards the straight line joining the two
    boundary values:

        final = interp + weight * (prediction - interp)
        weight = 1 - 2 * |x - 0.5|, clipped to [0, 1]

    where x is the row's depth normalised between the two boundary depths.
    The weight is 0 at the gap edges and 1 at the gap centre. Runs touching
    the first or last row of `df`, or whose boundary value is NaN, keep the
    raw ML predictions.

    Args:
        df (pd.DataFrame): Table containing 'SB_DEPTH_cm' and `target_log`;
            row adjacency is taken from the row order of this table.
        gap_mask (pd.Series or array-like of bool): True for rows that are gaps.
        ml_preds (array-like): One prediction per True entry of `gap_mask`,
            in row order.
        target_log (str): Name of the target column in `df`.

    Returns:
        np.ndarray: float64 array of adjusted predictions, same length and
        order as `ml_preds`. Empty when there are no gaps.

    Raises:
        ValueError: If the number of predictions differs from the number of gaps.
    """
    gap_positions = np.flatnonzero(np.asarray(gap_mask, dtype=bool))
    adjusted = np.array(ml_preds, dtype='float64')  # private copy
    if adjusted.ndim != 1 or adjusted.size != gap_positions.size:
        raise ValueError(
            f"Expected {gap_positions.size} predictions (one per gap row), got shape {adjusted.shape}"
        )
    if gap_positions.size == 0:
        # Nothing to adjust
        return adjusted

    depths = df['SB_DEPTH_cm'].to_numpy(dtype='float64')
    observed = df[target_log].to_numpy(dtype='float64')
    last_row = len(df) - 1

    # A new run starts wherever consecutive gap positions are not adjacent
    run_breaks = np.flatnonzero(np.diff(gap_positions) != 1) + 1
    run_starts = np.concatenate(([0], run_breaks))
    run_stops = np.concatenate((run_breaks, [gap_positions.size]))

    for lo, hi in zip(run_starts, run_stops):
        first_row = gap_positions[lo]
        final_row = gap_positions[hi - 1]

        # Both neighbours must exist ...
        if first_row == 0 or final_row == last_row:
            continue
        # ... and both must be observed values
        left_value = observed[first_row - 1]
        right_value = observed[final_row + 1]
        if np.isnan(left_value) or np.isnan(right_value):
            continue

        left_depth = depths[first_row - 1]
        right_depth = depths[final_row + 1]
        rows = gap_positions[lo:hi]

        # Normalised position of each gap row between the two boundaries
        if right_depth == left_depth:
            x = np.full(rows.size, 0.5)
        else:
            x = (depths[rows] - left_depth) / (right_depth - left_depth)

        interp_val = left_value + (right_value - left_value) * x
        weight = np.clip(1 - 2 * np.abs(x - 0.5), 0, 1)
        adjusted[lo:hi] = interp_val + weight * (adjusted[lo:hi] - interp_val)

    return adjusted


def train_model(model):
    """
    Wrap `model` in a fit-then-predict callable for parallel model training.

    Args:
        model: Any estimator exposing `fit(X, y)` and `predict(X)`.

    Returns:
        callable: `train_wrapper(X_train, y_train, X_pred)`, which fits `model`
        on the training data and returns its predictions for `X_pred`.
    """
    def train_wrapper(X_train, y_train, X_pred):
        # Fit first; predict() is only meaningful on the fitted model
        model.fit(X_train, y_train)
        predicted = model.predict(X_pred)
        return predicted

    return train_wrapper




In [None]:
def fill_gaps_with_ml(target_log=None, 
                      All_logs=None, 
                      output_csv=False, 
                      output_dir=None, 
                      core_name=None, 
                      merge_tolerance=3.0,
                      ml_method='xgblgbm'):
    """
    Predict the missing values of one log from the other logs and fill them in.

    Rows where `target_log` is present are used for training; rows where it is
    NaN are predicted with the selected method and written into a copy of the
    target dataframe. When the log has no gaps the copy is returned unchanged.

    Args:
        target_log (str): Name of the target column to fill gaps in.
        All_logs (dict): Dictionary of dataframes containing feature data and target data.
                         Format: {'df_name': (dataframe, [column_names])}
        output_csv (bool): Whether to write the (filled) target dataframe to CSV.
        output_dir (str): Directory for the CSV file; required if output_csv is True.
        core_name (str): Core name used in the CSV filename; required if output_csv is True.
        merge_tolerance (float): Maximum allowed difference in depth (SB_DEPTH_cm) for merging
                                 rows from different logs.
        ml_method (str): ML method to use - 'rf', 'rftc', 'xgb', 'xgblgbm' (default)

    Returns:
        tuple: (target_data_filled, gap_mask), where gap_mask is True for the
        rows of the target dataframe that were originally missing.

    Raises:
        ValueError: If a required argument is missing or `ml_method` is unknown.
    """
    # --- Argument checks ---------------------------------------------------
    if target_log is None or All_logs is None:
        raise ValueError("Both target_log and All_logs must be provided")

    csv_target_missing = output_dir is None or core_name is None
    if output_csv and csv_target_missing:
        raise ValueError("output_dir and core_name must be provided when output_csv is True")

    if ml_method not in ('rf', 'rftc', 'xgb', 'xgblgbm'):
        raise ValueError("ml_method must be one of: 'rf', 'rftc', 'xgb', 'xgblgbm'")

    # --- Assemble data -----------------------------------------------------
    target_data, merged_data, features = prepare_feature_data(target_log, All_logs, merge_tolerance)
    target_data_filled = target_data.copy()
    gap_mask = target_data[target_log].isna()

    def _write_csv_if_requested():
        # File name uses the part of the log name before the first underscore
        if output_csv:
            file_name = f"{core_name}_{target_log.split('_')[0]}_MLfilled.csv"
            target_data_filled.to_csv(os.path.join(output_dir, file_name), index=False)

    # Nothing to predict: hand back the untouched copy
    if not gap_mask.any():
        _write_csv_if_requested()
        return target_data_filled, gap_mask

    # --- Build the ML matrices --------------------------------------------
    X = merged_data[features].copy()
    for col in X.columns:
        if X[col].dtype.kind in 'biufc':
            X[col] = X[col].astype('float64')
    y = merged_data[target_log].copy().astype('float64')

    observed_mask = ~gap_mask
    X_train = X[observed_mask]
    y_train = y[observed_mask]
    X_pred = X[gap_mask]

    # --- Run the chosen method --------------------------------------------
    # Lambdas keep the lookup of each backend lazy: only the selected one runs.
    runners = {
        'rf': lambda: _apply_random_forest(X_train, y_train, X_pred),
        'rftc': lambda: _apply_random_forest_with_trend_constraints(
            X_train, y_train, X_pred, merged_data, gap_mask, target_log),
        'xgb': lambda: _apply_xgboost(X_train, y_train, X_pred),
        'xgblgbm': lambda: _apply_xgboost_lightgbm(X_train, y_train, X_pred),
    }
    predictions = runners[ml_method]()

    # --- Write the predictions into the gaps ------------------------------
    target_data_filled.loc[gap_mask, target_log] = predictions
    _write_csv_if_requested()

    return target_data_filled, gap_mask


def _apply_random_forest(X_train, y_train, X_pred):
    """
    Predict gap values with a Random Forest + HistGradientBoosting ensemble.

    Training rows whose target lies outside
    [q_lo - 1.5 * (q_hi - q_lo), q_hi + 1.5 * (q_hi - q_lo)] are dropped first,
    where q_lo / q_hi are the 2.5% / 97.5% quantiles of `y_train`. Both models
    are then fitted in parallel and their predictions averaged.

    Args:
        X_train (pd.DataFrame): Features of the rows with observed target values.
        y_train (pd.Series): Observed target values.
        X_pred (pd.DataFrame): Features of the gap rows.

    Returns:
        np.ndarray: Mean of the two models' predictions for `X_pred`.
    """
    # Trim extreme targets using a quantile-based fence
    tail = 0.025
    q_lo = y_train.quantile(tail)
    q_hi = y_train.quantile(1 - tail)
    fence = 1.5 * (q_hi - q_lo)
    keep = (y_train >= q_lo - fence) & (y_train <= q_hi + fence)
    X_fit = X_train[keep]
    y_fit = y_train[keep]

    def fit_and_predict(estimator):
        estimator.fit(X_fit, y_fit)
        return estimator.predict(X_pred)

    forest = RandomForestRegressor(
        n_estimators=1000,
        max_depth=30,
        min_samples_split=5,
        min_samples_leaf=5,
        max_features='sqrt',
        bootstrap=True,
        random_state=42,
        n_jobs=-1,
    )
    boosting = HistGradientBoostingRegressor(
        max_iter=800,
        learning_rate=0.05,
        max_depth=5,
        min_samples_leaf=50,
        l2_regularization=1.0,
        random_state=42,
        verbose=0,
    )

    # One job per model; results come back in submission order
    per_model = Parallel(n_jobs=-1)(
        delayed(fit_and_predict)(estimator) for estimator in (forest, boosting)
    )

    # Equal-weight average of the two prediction vectors
    return np.mean(per_model, axis=0)


def _apply_random_forest_with_trend_constraints(X_train, y_train, X_pred, merged_data, gap_mask, target_log):
    """
    Predict gap values with a Random Forest + HistGradientBoosting ensemble,
    then blend the result with linear interpolation across each gap.

    Training rows whose target lies outside
    [Q1 - 1.5 * (Q3 - Q1), Q3 + 1.5 * (Q3 - Q1)] are dropped first, where
    Q1 / Q3 are the 15% / 85% quantiles of `y_train`. Both models are fitted in
    parallel, their predictions are averaged, and the average is passed to
    `_adjust_gap_predictions` together with `merged_data` and `gap_mask`.

    Args:
        X_train (pd.DataFrame): Features of the rows with observed target values.
        y_train (pd.Series): Observed target values.
        X_pred (pd.DataFrame): Features of the gap rows.
        merged_data (pd.DataFrame): Table holding 'SB_DEPTH_cm' and `target_log`,
            used to read the values bordering each gap.
        gap_mask (pd.Series): True for rows that are gaps.
        target_log (str): Name of the target column.

    Returns:
        np.ndarray: Trend-adjusted predictions, one per gap row.
    """
    # Handle outliers using a quantile-based fence
    quantile_cutoff = 0.15
    Q1 = y_train.quantile(quantile_cutoff)
    Q3 = y_train.quantile(1 - quantile_cutoff)
    IQR = Q3 - Q1
    outlier_mask = (y_train >= Q1 - 1.5 * IQR) & (y_train <= Q3 + 1.5 * IQR)
    X_train = X_train[outlier_mask]
    y_train = y_train[outlier_mask]
    
    def train_model(model):
        model.fit(X_train, y_train)
        return model.predict(X_pred)
    
    # Initialize two ensemble models
    models = [
        RandomForestRegressor(n_estimators=1000,
                              max_depth=30,
                              min_samples_split=5,
                              min_samples_leaf=5,
                              max_features='sqrt',
                              bootstrap=True,
                              random_state=42,
                              n_jobs=-1),
        # verbose must be a non-negative int for scikit-learn estimators;
        # -1 (a LightGBM convention) is rejected by parameter validation in fit().
        HistGradientBoostingRegressor(max_iter=800,
                                      learning_rate=0.05,
                                      max_depth=5,
                                      min_samples_leaf=50,
                                      l2_regularization=1.0,
                                      random_state=42,
                                      verbose=0)
    ]
    
    # Train models in parallel and average their predictions
    predictions = Parallel(n_jobs=-1)(delayed(train_model)(model) for model in models)
    ensemble_predictions = np.mean(predictions, axis=0)
    
    # Pull predictions towards the straight line between each gap's boundaries
    adjusted_predictions = _adjust_gap_predictions(merged_data, gap_mask, ensemble_predictions, target_log)
    
    return adjusted_predictions


def _adjust_gap_predictions(df, gap_mask, ml_preds, target_log):
    """
    Blend ML predictions with linear interpolation inside each contiguous gap.

    For every run of consecutive gap rows that has an observed value on both
    sides, each prediction is pulled towards the straight line joining the two
    boundary values:

        final = interp + weight * (prediction - interp)
        weight = 1 - 2 * |x - 0.5|, clipped to [0, 1]

    where x is the row's depth normalised between the two boundary depths.
    The weight is 0 at the gap edges and 1 at the gap centre. Runs touching
    the first or last row of `df`, or whose boundary value is NaN, keep the
    raw ML predictions.

    Args:
        df (pd.DataFrame): Table containing 'SB_DEPTH_cm' and `target_log`;
            row adjacency is taken from the row order of this table.
        gap_mask (pd.Series or array-like of bool): True for rows that are gaps.
        ml_preds (array-like): One prediction per True entry of `gap_mask`,
            in row order.
        target_log (str): Name of the target column in `df`.

    Returns:
        np.ndarray: float64 array of adjusted predictions, same length and
        order as `ml_preds`. Empty when there are no gaps.

    Raises:
        ValueError: If the number of predictions differs from the number of gaps.
    """
    gap_positions = np.flatnonzero(np.asarray(gap_mask, dtype=bool))
    adjusted = np.array(ml_preds, dtype='float64')  # private copy
    if adjusted.ndim != 1 or adjusted.size != gap_positions.size:
        raise ValueError(
            f"Expected {gap_positions.size} predictions (one per gap row), got shape {adjusted.shape}"
        )
    if gap_positions.size == 0:
        # Nothing to adjust
        return adjusted

    depths = df['SB_DEPTH_cm'].to_numpy(dtype='float64')
    observed = df[target_log].to_numpy(dtype='float64')
    last_row = len(df) - 1

    # A new run starts wherever consecutive gap positions are not adjacent
    run_breaks = np.flatnonzero(np.diff(gap_positions) != 1) + 1
    run_starts = np.concatenate(([0], run_breaks))
    run_stops = np.concatenate((run_breaks, [gap_positions.size]))

    for lo, hi in zip(run_starts, run_stops):
        first_row = gap_positions[lo]
        final_row = gap_positions[hi - 1]

        # Both neighbours must exist ...
        if first_row == 0 or final_row == last_row:
            continue
        # ... and both must be observed values
        left_value = observed[first_row - 1]
        right_value = observed[final_row + 1]
        if np.isnan(left_value) or np.isnan(right_value):
            continue

        left_depth = depths[first_row - 1]
        right_depth = depths[final_row + 1]
        rows = gap_positions[lo:hi]

        # Normalised position of each gap row between the two boundaries
        if right_depth == left_depth:
            x = np.full(rows.size, 0.5)
        else:
            x = (depths[rows] - left_depth) / (right_depth - left_depth)

        interp_val = left_value + (right_value - left_value) * x
        weight = np.clip(1 - 2 * np.abs(x - 0.5), 0, 1)
        adjusted[lo:hi] = interp_val + weight * (adjusted[lo:hi] - interp_val)

    return adjusted


def _apply_xgboost(X_train, y_train, X_pred):
    """
    Predict gap values with a single XGBoost regressor.

    Steps: drop training rows whose target lies outside a quantile-based fence
    (2.5% / 97.5% quantiles, widened by 1.5x their spread), run the features
    through impute -> scale -> pairwise-interaction expansion -> SelectKBest
    (k='all', i.e. every expanded feature is kept), cast to float32, fit, predict.

    Args:
        X_train (pd.DataFrame): Features of the rows with observed target values.
        y_train (pd.Series): Observed target values.
        X_pred (pd.DataFrame): Features of the gap rows.

    Returns:
        np.ndarray: float32 predictions for `X_pred`.
    """
    # Quantile-based outlier fence on the target
    tail = 0.025
    q_lo = y_train.quantile(tail)
    q_hi = y_train.quantile(1 - tail)
    fence = 1.5 * (q_hi - q_lo)
    inliers = (y_train >= q_lo - fence) & (y_train <= q_hi + fence)
    X_fit = X_train[inliers]
    y_fit = y_train[inliers]

    # Feature engineering steps, fitted on the training rows only
    steps = [
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=2, interaction_only=True, include_bias=True)),
        ('selector', SelectKBest(score_func=f_regression, k='all')),
    ]
    feature_pipeline = Pipeline(steps)

    train_matrix = feature_pipeline.fit_transform(X_fit, y_fit).astype('float32')
    pred_matrix = feature_pipeline.transform(X_pred).astype('float32')
    y_fit = y_fit.astype('float32')

    # Many shallow-learning-rate trees with fairly strong regularisation
    xgb_params = dict(
        n_estimators=5000,
        learning_rate=0.003,
        max_depth=10,
        min_child_weight=5,
        subsample=0.75,
        colsample_bytree=0.75,
        gamma=0.2,
        reg_alpha=0.3,
        reg_lambda=3.0,
        random_state=42,
        n_jobs=-1,
    )
    regressor = xgb.XGBRegressor(**xgb_params)
    regressor.fit(train_matrix, y_fit)

    return regressor.predict(pred_matrix).astype('float32')

def _apply_xgboost_lightgbm(X_train, y_train, X_pred):
    """
    Predict gap values with an equal-weight XGBoost + LightGBM ensemble.

    Features are imputed (median), standardised and expanded with pairwise
    interaction terms. SelectKBest then keeps the k best-scoring expanded
    features, with k = min(50, n_training_rows // 10, n_expanded_features),
    but never fewer than 1 so that very small training sets still leave the
    models something to fit on. No outlier trimming is applied here.

    Args:
        X_train (pd.DataFrame): Features of the rows with observed target values.
        y_train (pd.Series): Observed target values.
        X_pred (pd.DataFrame): Features of the gap rows.

    Returns:
        np.ndarray: float32 mean of the two models' predictions for `X_pred`.
    """
    # Create feature pipeline
    feature_pipeline = Pipeline([
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
        ('poly', PolynomialFeatures(degree=2, interaction_only=True, include_bias=True))
    ])

    # Expand first: the selector's k depends on the expanded feature count
    X_train_processed = feature_pipeline.fit_transform(X_train, y_train)

    # With fewer than 10 training rows the row-based cap is 0, which would
    # select no features at all; keep at least one.
    max_features = max(1, min(50, X_train.shape[0] // 10, X_train_processed.shape[1]))
    selector = SelectKBest(score_func=f_regression, k=max_features)
    X_train_processed = selector.fit_transform(X_train_processed, y_train)
    X_pred_processed = selector.transform(feature_pipeline.transform(X_pred))

    # Convert processed arrays to float32
    X_train_processed = X_train_processed.astype('float32')
    X_pred_processed = X_pred_processed.astype('float32')
    y_train = y_train.astype('float32')

    # Initialize models
    xgb_model = xgb.XGBRegressor(
        n_estimators=3000,
        learning_rate=0.003,
        max_depth=10,
        min_child_weight=5,
        subsample=0.75,
        colsample_bytree=0.75,
        gamma=0.2,
        reg_alpha=0.3,
        reg_lambda=3.0,
        random_state=42,
        n_jobs=-1,
    )

    lgb_model = lgb.LGBMRegressor(
        n_estimators=3000,
        learning_rate=0.003,
        max_depth=6,
        num_leaves=20,
        min_child_samples=50,
        subsample=0.9,
        colsample_bytree=0.9,
        reg_alpha=0.3,
        reg_lambda=3.0,
        random_state=42,
        n_jobs=-1,
        force_col_wise=True,
        verbose=-1
    )

    # Fit and predict with library warnings silenced (uses the file-level
    # `warnings` import)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        xgb_model.fit(X_train_processed, y_train)
        lgb_model.fit(X_train_processed, y_train, feature_name='auto')
        xgb_predictions = xgb_model.predict(X_pred_processed).astype('float32')
        lgb_predictions = lgb_model.predict(X_pred_processed).astype('float32')

    # Ensemble predictions (simple average)
    predictions = (xgb_predictions + lgb_predictions) / 2

    return predictions

### Function to process and fill logs with chosen ML methods

In [None]:
def process_and_fill_logs(data_config, ml_method='xgblgbm'):
    """
    Load the cleaned logs of one core, ML-fill every configured log, plot each
    result, and consolidate the per-log CSV files.

    Workflow:
      1. Read the cleaned CSVs listed in data_config['clean_file_paths'] from
         mother_dir + clean_output_folder (missing or empty files are skipped).
      2. Build the feature dictionary ({'ct' | 'rgb' | 'mst' | 'hrms':
         (dataframe, [columns])}) from data_config['column_configs'].
      3. For every configured log column call fill_gaps_with_ml (which writes
         one '<core>_<log>_MLfilled.csv' per log) and plot_filled_data.
      4. Merge the per-column RGB files into '<core>_RGB_MLfilled.csv' and the
         per-column MST files into '<core>_MST_MLfilled.csv', deleting the
         per-column files afterwards.

    Args:
        data_config (dict): Configuration dictionary. Keys read here:
            'mother_dir', 'core_name', 'core_length', 'clean_output_folder',
            'filled_output_folder', 'clean_file_paths', 'column_configs'.
            Folder values are joined by plain string concatenation, so they
            need their own trailing path separators.
        ml_method (str): One of 'rf', 'rftc', 'xgb', 'xgblgbm'.

    Returns:
        None. Prints a message and returns early when no data files, no
        feature data, or no logs are available.

    Raises:
        ValueError: If `ml_method` is not one of the supported names.
    """
    mother_dir = data_config['mother_dir']
    core_name = data_config['core_name']
    core_length = data_config['core_length']
    clean_output_folder = data_config['clean_output_folder']
    filled_output_folder = data_config['filled_output_folder']
    
    # Make sure the output folder exists before any CSV is written
    os.makedirs(mother_dir + filled_output_folder, exist_ok=True)
    
    # Only data types that have both a file name and a column config are usable
    clean_paths = data_config.get('clean_file_paths', {})
    available_columns = data_config.get('column_configs', {})
    valid_data_types = set(clean_paths.keys()) & set(available_columns.keys())
    
    # Load data with correct path construction
    data_dict = {}
    for data_type in valid_data_types:
        full_path = mother_dir + clean_output_folder + clean_paths[data_type]
        if os.path.exists(full_path):
            data = pd.read_csv(full_path)
            if not data.empty:
                data_dict[data_type] = data

    if not data_dict:
        print("No valid data files found for processing")
        return

    # Create feature data dictionary: {data_type: (dataframe, [depth + data columns])}
    feature_data = {}
    
    if 'ct' in data_dict and 'ct' in available_columns:
        ct_col = available_columns['ct']['data_col']
        if ct_col in data_dict['ct'].columns:
            feature_data['ct'] = (data_dict['ct'], ['SB_DEPTH_cm', ct_col])
    
    if 'rgb' in data_dict and 'rgb' in available_columns:
        # Keep only the configured RGB columns that are actually in the file
        valid_rgb_cols = ['SB_DEPTH_cm'] + [col for col in available_columns['rgb']['data_cols'] 
                                           if col in data_dict['rgb'].columns]
        if len(valid_rgb_cols) > 1:
            feature_data['rgb'] = (data_dict['rgb'], valid_rgb_cols)
    
    if 'mst' in data_dict and 'mst' in available_columns:
        # The MST config is nested: one entry per log type, each with its own data_col
        mst_cols = ['SB_DEPTH_cm']
        for log_type, config in available_columns['mst'].items():
            if config['data_col'] in data_dict['mst'].columns:
                mst_cols.append(config['data_col'])
        if len(mst_cols) > 1:
            feature_data['mst'] = (data_dict['mst'], mst_cols)
    
    if 'hrms' in data_dict and 'hrms' in available_columns:
        hrms_col = available_columns['hrms']['data_col']
        if hrms_col in data_dict['hrms'].columns:
            feature_data['hrms'] = (data_dict['hrms'], ['SB_DEPTH_cm', hrms_col])

    if not feature_data:
        print("No valid feature data found for ML processing")
        return

    # Define logs to process as (target column, plot name, source dataframe) tuples
    logs_to_process = []
    
    if 'ct' in data_dict and 'ct' in available_columns:
        ct_col = available_columns['ct']['data_col']
        if ct_col in data_dict['ct'].columns:
            logs_to_process.append((ct_col, ct_col, data_dict['ct']))
    
    if 'rgb' in data_dict and 'rgb' in available_columns:
        for col in available_columns['rgb']['data_cols']:
            if col in data_dict['rgb'].columns and not data_dict['rgb'][col].empty:
                logs_to_process.append((col, col, data_dict['rgb']))
    
    if 'mst' in data_dict and 'mst' in available_columns:
        for log_type, config in available_columns['mst'].items():
            col = config['data_col']
            if col in data_dict['mst'].columns and not data_dict['mst'][col].empty:
                logs_to_process.append((col, col, data_dict['mst']))
    
    if 'hrms' in data_dict and 'hrms' in available_columns:
        hrms_col = available_columns['hrms']['data_col']
        if hrms_col in data_dict['hrms'].columns:
            logs_to_process.append((hrms_col, hrms_col, data_dict['hrms']))

    if not logs_to_process:
        print("No valid logs found for processing")
        return

    # Display names used in plot titles; the keys double as the set of valid methods
    ml_names = {
        'rf': 'Random Forest',
        'rftc': 'Random Forest + Trend Constraint', 
        'xgb': 'XGBoost',
        'xgblgbm': 'XGBoost + LightGBM'
    }
    
    if ml_method not in ml_names:
        raise ValueError("ml_method must be one of: 'rf', 'rftc', 'xgb', 'xgblgbm'")

    # Process each log
    for target_log, plot_name, data in logs_to_process:
        if target_log in ['R', 'G', 'B', 'Lumin']:
            # Colour logs use a restricted feature set: hrms, ct, the rgb
            # frame itself, and from MST only the density column
            filtered_features = {}
            priority_features = ['hrms', 'ct', 'rgb']
            
            for key in priority_features:
                if key in feature_data:
                    if key == 'rgb':
                        df, cols = feature_data[key]
                        valid_cols = ['SB_DEPTH_cm'] + [c for c in ['R','G','B','Lumin'] if c in cols and c in df.columns]
                        if len(valid_cols) > 1:
                            filtered_features[key] = (df, valid_cols)
                    else:
                        filtered_features[key] = feature_data[key]
            
            if 'mst' in feature_data:
                df, cols = feature_data['mst']
                if 'Den_gm/cc' in cols and 'Den_gm/cc' in df.columns:
                    filtered_features['mst'] = (df, ['SB_DEPTH_cm', 'Den_gm/cc'])
                    
            filled_data, gap_mask = fill_gaps_with_ml(
                target_log=target_log,
                All_logs=filtered_features,
                output_csv=True,
                output_dir=mother_dir + filled_output_folder,
                core_name=core_name,
                ml_method=ml_method
            )
        else:
            # All other logs may use every available feature dataframe
            filled_data, gap_mask = fill_gaps_with_ml(
                target_log=target_log,
                All_logs=feature_data,
                output_csv=True,
                output_dir=mother_dir + filled_output_folder,
                core_name=core_name,
                ml_method=ml_method
            )
            
        plot_filled_data(plot_name, data, filled_data, core_length, core_name, ML_type=ml_names[ml_method])

    # Consolidate RGB data: copy each filled column back into one RGB table
    if 'rgb' in data_dict and 'rgb' in available_columns:
        rgb_data = data_dict['rgb'].copy()
        rgb_columns = available_columns['rgb']['data_cols']
        updated = False
        
        for col in rgb_columns:
            if col in rgb_data.columns:
                filled_file = mother_dir + filled_output_folder + f'{core_name}_{col}_MLfilled.csv'
                if os.path.exists(filled_file):
                    filled_data = pd.read_csv(filled_file)
                    if col in filled_data.columns:
                        # Column assignment aligns on the row index of both tables
                        rgb_data[col] = filled_data[col]
                        updated = True
                        print(f"Updated {col} column with ML-filled data")
        
        if updated:
            # Write the combined file, then remove the per-column intermediates
            rgb_data.to_csv(mother_dir + filled_output_folder + f'{core_name}_RGB_MLfilled.csv', index=False)
            for col in rgb_columns:
                filled_file = mother_dir + filled_output_folder + f'{core_name}_{col}_MLfilled.csv'
                if os.path.exists(filled_file):
                    os.remove(filled_file)

    # Consolidate MST data: same procedure as for RGB
    if 'mst' in data_dict and 'mst' in available_columns:
        mst_data = data_dict['mst'].copy()
        updated = False
        
        for log_type, config in available_columns['mst'].items():
            col = config['data_col']
            if col in mst_data.columns:
                # Per-log files are named after the part of the column name
                # before the first underscore (same rule as fill_gaps_with_ml)
                col_name = col.split('_')[0] if '_' in col else col
                filled_file = mother_dir + filled_output_folder + f'{core_name}_{col_name}_MLfilled.csv'
                if os.path.exists(filled_file):
                    filled_data = pd.read_csv(filled_file)
                    if col in filled_data.columns:
                        mst_data[col] = filled_data[col]
                        updated = True
                        print(f"Updated {col} column with ML-filled data")
        
        if updated:
            # Write the combined file, then remove the per-log intermediates
            mst_data.to_csv(mother_dir + filled_output_folder + f'{core_name}_MST_MLfilled.csv', index=False)
            for log_type, config in available_columns['mst'].items():
                col = config['data_col']
                col_name = col.split('_')[0] if '_' in col else col
                filled_file = mother_dir + filled_output_folder + f'{core_name}_{col_name}_MLfilled.csv'
                if os.path.exists(filled_file):
                    os.remove(filled_file)

<hr>

### **Define data structure**

#### Define core name and core length

In [None]:
# Select the core to process. Exactly one core_name / total_length_cm pair
# should be active; the commented pairs below are the other available cores.
# Both values are copied into data_config in the next cell.
core_name = "M9907-11PC"  # Core name
total_length_cm = 439     # Core length in cm

# core_name = "M9907-12PC"  # Core name
# total_length_cm = 488     # Core length in cm

# core_name = "M9907-14TC"  # Core name
# total_length_cm = 199     # Core length in cm

# core_name = "M9907-22PC"  # Core name
# total_length_cm = 501     # Core length in cm

# core_name = "M9907-22TC"  # Core name
# total_length_cm = 173     # Core length in cm

# core_name = "M9907-23PC"  # Core name
# total_length_cm = 783     # Core length in cm

# core_name = "M9907-25PC"  # Core name
# total_length_cm = 797     # Core length in cm

# core_name = "RR0207-56PC"  # Core name
# total_length_cm = 794     # Core length in cm

# core_name = "M9907-30PC"  # Core name
# total_length_cm = 781     # Core length in cm

# core_name = "M9907-31PC"  # Core name
# total_length_cm = 767     # Core length in cm

#### Define file paths, data configuration, and outlier cut-off thresholds for ML data processing

In [None]:
# Data configuration for ML data imputation

data_config = {
    # Root directory; the folder and file names below are appended to it by
    # plain string concatenation, so keep the trailing slash.
    # NOTE(review): machine-specific absolute path — adjust before running elsewhere.
    'mother_dir': '/Users/larryslai/Library/CloudStorage/Dropbox/My Documents/University of Texas Austin/(Project) NWP turbidites/Cascadia_core_data/OSU_dataset/',
    'core_name': core_name,
    'core_length': total_length_cm,
    # Folders relative to mother_dir (trailing slash required)
    'data_folder': f'_compiled_logs/{core_name}/',
    'clean_output_folder': f'_compiled_logs/{core_name}/ML_clean/',
    'filled_output_folder': f'_compiled_logs/{core_name}/ML_filled/',
    
    # Cleaned input file names, relative to clean_output_folder
    'clean_file_paths': {
        'ct': f'{core_name}_CT_clean.csv',
        'rgb': f'{core_name}_RGB_clean.csv',
        'mst': f'{core_name}_MST_clean.csv',
        'hrms': f'{core_name}_hiresMS_clean.csv'
    },
    
    # Gap-filled file names; the rgb and mst names match the consolidated
    # files written by process_and_fill_logs into filled_output_folder
    'filled_file_paths': {
        'ct': f'{core_name}_CT_MLfilled.csv',
        'rgb': f'{core_name}_RGB_MLfilled.csv',
        'mst': f'{core_name}_MST_MLfilled.csv',
        'hrms': f'{core_name}_hiresMS_MLfilled.csv'
    },
    
    # Core images (presumably relative to mother_dir — confirm against the plotting code)
    'ct_image_path': f'_compiled_logs/{core_name}/{core_name}_CT.tiff',
    'rgb_image_path': f'_compiled_logs/{core_name}/{core_name}_RGB.tiff',
    
    # Column names per data type. 'ct' and 'hrms' have a single data column,
    # 'rgb' has a list of data columns, and 'mst' nests one entry per log type.
    'column_configs': {
        'ct': {'data_col': 'CT', 'std_col': 'CT_std', 'depth_col': 'SB_DEPTH_cm'},
        'rgb': {
            'data_cols': ['R', 'G', 'B', 'Lumin'],
            'std_cols': ['R_std', 'G_std', 'B_std', 'Lumin_std'],
            'depth_col': 'SB_DEPTH_cm'
        },
        'mst': {
            'density': {'data_col': 'Den_gm/cc', 'depth_col': 'SB_DEPTH_cm'},
            'pwvel': {'data_col': 'PWVel_m/s', 'depth_col': 'SB_DEPTH_cm'},
            'pwamp': {'data_col': 'PWAmp', 'depth_col': 'SB_DEPTH_cm'},
            'elecres': {'data_col': 'ElecRes_ohmm', 'depth_col': 'SB_DEPTH_cm'},
            'ms': {'data_col': 'MS', 'depth_col': 'SB_DEPTH_cm'}
        },
        'hrms': {'data_col': 'hiresMS', 'depth_col': 'SB_DEPTH_cm'}
    },

    # Thresholds for data cleaning, one [operator, value, n] triple per log.
    # NOTE(review): the meaning of the third element is defined by the cleaning
    # code, which is not shown in this section — confirm before changing it.
    'thresholds': {
        'ms': ['>', 180, 1],
        'pwvel': ['>=', 1077, 1], 
        'pwamp': ['>=', 30, 1],
        'den': ['<', 1.14, 1],
        'elecres': ['<', 0, 1],
        'hiresms': ['<=', 19, 1]
    }
}

### Data cleaning

In [None]:
# Run the data cleaning step for the core selected in data_config
print("Starting data cleaning...")
preprocess_core_data(data_config, shift_limit_multiplier=3.0)
print("Data cleaning completed.")

# Plot the cleaned logs; plot_core_logs takes the whole data_config
fig, axes = plot_core_logs(
    data_config,                           # Data configuration containing all parameters
    file_type='clean',                     # Type of data files to plot ('clean' or 'filled')
    title=f'{core_name} Cleaned Logs'      # Title for the plot figure
)
plt.show()

### ML-based data gap filling

In [None]:
# Fill the gaps of every configured log and write the '*_MLfilled.csv' files
process_and_fill_logs(data_config,              # Data configuration containing all parameters
                      ml_method='xgblgbm')      # Available ml_method options: 'rf', 'rftc', 'xgb', 'xgblgbm'
                                                # - 'rf': Random Forest ML
                                                # - 'rftc': Random Forest ML with trend constraints
                                                # - 'xgb': XGBoost ML
                                                # - 'xgblgbm': XGBoost + LightGBM ML         

#### Plot ML-based gap-filled log diagram

In [None]:
# Plot ML-based gap-filled log diagram
# NOTE: the title names the method explicitly; keep it in sync with the
# ml_method passed to process_and_fill_logs above.
fig, axes = plot_core_logs(
    data_config,                                              # Data configuration containing all parameters
    file_type='filled',                                       # Type of data files to plot ('filled' for gap-filled data)
    title=f'{core_name} XGBoost + LightGBM ML-Filled Logs'    # Title for the plot figure
)
plt.show()