In [None]:

import numpy as np
import pandas as pd
from scipy import stats
from multiprocessing import Pool, cpu_count
from functools import partial
from tqdm import tqdm
import os
import math
import time
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# ---------------- USER SETTINGS ----------------
CSV_FILE = "../nz_1980_2024_mc.csv"  
OUT_BASE = "final-data"                    
MC = 3.0                             # Completeness magnitude
WINDOW_SIZE = 100                     # Number of events per sliding window
STEP = 10                             # Step size for sliding window
N_R = 30                              # Number of points in log(r) for C(r)
BOOTSTRAP_REPS = 0
NUM_WORKERS = max(1, min(4, cpu_count()-1))  # Avoid CPU overload
BATCH_SIZE = 1000                     # Number of windows per batch
MAX_WINDOWS = 5000                    # Maximum number of windows to process
MAX_EVENTS_FOR_CR = 2000              # Max events per C(r) calculation
# ------------------------------------------------

R_EARTH = 6371.0  # Earth radius in km

# ----------------  Functions ----------------
def find_column(df, choices):
    """Find first matching column name from a list of possible choices."""
    for c in choices:
        if c in df.columns:
            return c
    return None

def safe_great_circle_matrix(lat_deg, lon_deg, max_size=2000):
    """Compute great-circle distance matrix safely with memory check."""
    if len(lat_deg) > max_size:
        indices = np.random.choice(len(lat_deg), max_size, replace=False)
        lat_deg = lat_deg[indices]
        lon_deg = lon_deg[indices]
    
    try:
        lat = np.radians(lat_deg)
        lon = np.radians(lon_deg)
        # Use float32 to save memory
        cos_ang = (np.cos(lat[:, None]).astype(np.float32) * np.cos(lat[None, :]).astype(np.float32) +
                   np.sin(lat[:, None]).astype(np.float32) * np.sin(lat[None, :]).astype(np.float32) *
                   np.cos((lon[:, None] - lon[None, :]).astype(np.float32)))
        cos_ang = np.clip(cos_ang, -1.0, 1.0)
        ang = np.arccos(cos_ang)
        return R_EARTH * ang
    except MemoryError:
        print(f"Memory error with {len(lat_deg)} points, subsampling further...")
        indices = np.random.choice(len(lat_deg), max_size//2, replace=False)
        return safe_great_circle_matrix(lat_deg[indices], lon_deg[indices], max_size//2)

def correlation_integral_from_dists(dists, N, r_vals):
    """Compute correlation integral C(r) from distances."""
    if len(dists) == 0 or N == 0:
        return np.zeros_like(r_vals)
    denom = max(1, N*(N-1))
    C = np.zeros_like(r_vals)
    for i, r in enumerate(r_vals):
        try:
            C[i] = 2.0 * np.sum(dists < r) / denom
        except:
            C[i] = 0.0
    return C

def chunks(lst, n):
    """Yield successive n-sized chunks from list."""
    for i in range(0, len(lst), n):
        yield lst[i:i+n]

# ---------------- Fractal Analysis ----------------
def estimate_D2_from_positions_safe(lat, lon, n_r=N_R):
    """Estimate correlation dimension D2 safely from positions."""
    N = len(lat)
    if N < 3:
        return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)
    try:
        Dmat = safe_great_circle_matrix(lat, lon)
        iu = np.triu_indices(len(Dmat), k=1)
        dists = Dmat[iu]
        dists_pos = dists[dists > 0]
        if dists_pos.size == 0:
            return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)
        r_min = np.min(dists_pos) * 1.2
        r_max = np.max(dists_pos) / 2.0
        if r_min <= 0 or r_min >= r_max:
            r_min = np.min(dists_pos)
            r_max = np.max(dists_pos)
        if r_min >= r_max:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)
        r_vals = np.logspace(math.log10(r_min), math.log10(r_max), n_r)
        C = correlation_integral_from_dists(dists, len(Dmat), r_vals)
        mask = C > 0
        if mask.sum() < 6:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)
        logr = np.log10(r_vals[mask])
        logC = np.log10(C[mask])
        i0 = len(logr)//4
        i1 = 3*len(logr)//4
        if i1 - i0 < 3:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)
        slope, _, _, _, se = stats.linregress(logr[i0:i1], logC[i0:i1])
        D2, D2_err = slope, se
        fit_rmin, fit_rmax = 10**logr[i0], 10**logr[i1-1]
        return dict(D2=D2, D2_err=D2_err, r_min=fit_rmin, r_max=fit_rmax)
    except Exception as e:
        print(f"Error in D2 estimation: {e}")
        return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)

# ---------------- B-Value ----------------
def b_value_mle(mags, Mc):
    """Maximum likelihood estimate of b-value."""
    try:
        mags = np.asarray(mags)
        mags = mags[mags >= Mc]
        if mags.size == 0: 
            return np.nan
        Mbar = mags.mean()
        if Mbar <= Mc: 
            return np.nan
        return 0.4342944819 / (Mbar - Mc)
    except:
        return np.nan

# ---------------- Window Processing ----------------
def process_window_safe(idx_start, df, lat_col, lon_col, mag_col, win, Mc):
    """Process one sliding window safely."""
    try:
        end_idx = min(idx_start + win, len(df))
        subset = df.iloc[idx_start:end_idx]
        if len(subset) < 3:
            return None
        lat = subset[lat_col].values
        lon = subset[lon_col].values
        mags = subset[mag_col].values
        if np.any(np.isnan(lat)) or np.any(np.isnan(lon)):
            return None
        Dres = estimate_D2_from_positions_safe(lat, lon, n_r=N_R)
        b = b_value_mle(mags, Mc)
        D_pred = 2.3 - 0.73*b if (not np.isnan(b)) else np.nan

        # Include original data for this window
        subset_data = subset[['latitude','longitude','time']] if 'time' in subset.columns else subset[['latitude','longitude']]

        return {
            "start_idx": int(idx_start),
            "end_idx": int(end_idx),
            "n_events": int(len(subset)),
            "D2": float(Dres["D2"]) if not np.isnan(Dres["D2"]) else np.nan,
            "D2_err": float(Dres["D2_err"]) if not np.isnan(Dres["D2_err"]) else np.nan,
            "r_min_km": float(Dres["r_min"]) if not np.isnan(Dres["r_min"]) else np.nan,
            "r_max_km": float(Dres["r_max"]) if not np.isnan(Dres["r_max"]) else np.nan,
            "b": float(b) if not np.isnan(b) else np.nan,
            "D_pred": float(D_pred) if not np.isnan(D_pred) else np.nan,
            "events_data": subset_data.to_dict(orient='records')  # Keep original lat/lon/time
        }
    except Exception as e:
        print(f"Error processing window {idx_start}: {e}")
        return None

# ---------------- C(r) Analysis for Batch ----------------
def analyze_Cr_for_batch_safe(batch_results, df, lat_col, lon_col, batch_idx):
    """Safe C(r) analysis for one batch."""
    try:
        valid_results = [r for r in batch_results if r is not None]
        if not valid_results:
            return None
        start_indices = [r['start_idx'] for r in valid_results]
        end_indices = [r['end_idx'] for r in valid_results]
        min_idx = min(start_indices)
        max_idx = max(end_indices)
        batch_events = df.iloc[min_idx:max_idx]
        if len(batch_events) > MAX_EVENTS_FOR_CR:
            indices = np.random.choice(len(batch_events), MAX_EVENTS_FOR_CR, replace=False)
            batch_events = batch_events.iloc[indices]
        if len(batch_events) < 10:
            return None
        lat = batch_events[lat_col].values
        lon = batch_events[lon_col].values
        if np.any(np.isnan(lat)) or np.any(np.isnan(lon)):
            return None
        Dmat = safe_great_circle_matrix(lat, lon)
        iu = np.triu_indices(len(Dmat), k=1)
        dists_km = Dmat[iu]
        dists_deg = dists_km / 111.0
        dists_deg = dists_deg[dists_deg > 0]
        if len(dists_deg) < 10:
            return None
        r_min_deg = np.min(dists_deg) * 1.2
        r_max_deg = np.max(dists_deg) / 2.0
        if r_min_deg >= r_max_deg:
            r_min_deg = np.min(dists_deg)
            r_max_deg = np.max(dists_deg)
        r_vals_deg = np.logspace(np.log10(r_min_deg), np.log10(r_max_deg), min(25, N_R))
        N = len(lat)
        C_r = correlation_integral_from_dists(dists_deg, N, r_vals_deg)
        mask = C_r > 0
        if mask.sum() < 6:
            return None
        r_vals_filtered = r_vals_deg[mask]
        C_r_filtered = C_r[mask]
        log_r = np.log10(r_vals_filtered)
        log_C = np.log10(C_r_filtered)
        start_idx = max(1, len(log_r)//4)
        end_idx = min(len(log_r)-1, 3*len(log_r)//4)
        if end_idx - start_idx < 3:
            return None
        slope, intercept, r_value, p_value, std_err = stats.linregress(
            log_r[start_idx:end_idx], log_C[start_idx:end_idx]
        )
        return {
            'batch_idx': batch_idx,
            'n_events': len(batch_events),
            'r_vals_deg': r_vals_filtered,
            'C_r': C_r_filtered,
            'slope': slope,
            'intercept': intercept,
            'r_value': r_value,
            'p_value': p_value,
            'std_err': std_err,
            'fit_range': (10**log_r[start_idx], 10**log_r[end_idx-1])
        }
    except Exception as e:
        print(f"Error in C(r) analysis for batch {batch_idx}: {e}")
        return None

# ---------------- Safe Plotting Functions ----------------
def plot_batch_Cr_safe(Cr_results, save_individual=True):
    """Plot C(r) and slope for all batches."""
    if not Cr_results:
        print("No C(r) results to plot")
        return
    try:
        if save_individual:
            for result in Cr_results[:10]:
                plt.figure(figsize=(8, 6))
                plt.loglog(result['r_vals_deg'], result['C_r'], 'bo-', markersize=4, linewidth=1, alpha=0.7)
                log_r = np.log10(result['r_vals_deg'])
                fit_line = 10**(result['slope'] * log_r + result['intercept'])
                plt.loglog(result['r_vals_deg'], fit_line, 'r--', linewidth=2,
                           label=f'Slope = {result["slope"]:.3f} ± {result["std_err"]:.3f}')
                plt.xlabel('Distance r (degrees)')
                plt.ylabel('Correlation Integral C(r)')
                plt.title(f'Batch {result["batch_idx"]}: C(r) vs r (N={result["n_events"]})')
                plt.grid(True, which="both", ls="--", alpha=0.7)
                plt.legend()
                plt.savefig(f'Cr_plot_batch_{result["batch_idx"]:03d}.png', dpi=150, bbox_inches='tight')
                plt.close()
        # Comparison plot
        plt.figure(figsize=(10, 6))
        colors = plt.cm.viridis(np.linspace(0,1,len(Cr_results)))
        for i, result in enumerate(Cr_results):
            plt.loglog(result['r_vals_deg'], result['C_r'], 'o-', color=colors[i], markersize=2, linewidth=1, alpha=0.7,
                       label=f'Batch {result["batch_idx"]} (slope={result["slope"]:.3f})')
        plt.xlabel('Distance r (degrees)')
        plt.ylabel('C(r)')
        plt.title('C(r) Comparison All Batches')
        plt.grid(True, which="both", ls="--", alpha=0.5)
        if len(Cr_results) <= 10:
            plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
        plt.tight_layout()
        plt.savefig('all_batches.png', dpi=150, bbox_inches='tight')
        plt.close()
        # Slope comparison
        plt.figure(figsize=(10,6))
        batch_indices = [r['batch_idx'] for r in Cr_results]
        slopes = [r['slope'] for r in Cr_results]
        errors = [r['std_err'] for r in Cr_results]
        plt.errorbar(batch_indices, slopes, yerr=errors, fmt='bo-', capsize=3, markersize=4)
        plt.xlabel('Batch Number')
        plt.ylabel('C(r) Slope')
        plt.title('Slope Comparison Across Batches')
        plt.grid(True, alpha=0.7)
        plt.axhline(np.mean(slopes), color='r', linestyle='--', label=f'Mean slope = {np.mean(slopes):.3f}')
        plt.legend()
        plt.tight_layout()
        plt.savefig('slope_comparison_batches.png', dpi=150, bbox_inches='tight')
        plt.close()
        print("All plots saved successfully")
    except Exception as e:
        print(f"Error in plotting: {e}")

# ---------------- Main ----------------
def main():
    try:
        print("Reading CSV...")
        df = pd.read_csv(CSV_FILE)
        print(f"Loaded {len(df)} rows")
        
        # Detect columns
        lat_col = find_column(df, ["latitude","lat","LAT","Latitude"])
        lon_col = find_column(df, ["longitude","lon","LON","Longitude","LONGITUDE"])
        mag_col = find_column(df, ["magnitude","mag","MAG","magnitude"])
        time_col = find_column(df, ["time","date","datetime","origin_time"])
        if lat_col is None or lon_col is None or mag_col is None:
            raise RuntimeError("CSV missing latitude/longitude/magnitude columns.")
        print(f"Using columns: {lat_col}, {lon_col}, {mag_col}, time={time_col}")

        # Optional depth filter
        depth_col = find_column(df, ["depth","DEPTH","depth_km"])
        if depth_col:
            initial_len = len(df)
            df = df[df[depth_col]<=60.0].reset_index(drop=True)
            print(f"Filtered by depth: {initial_len} -> {len(df)} events")
        
        # Sort by time
        if time_col:
            df = df.sort_values(time_col).reset_index(drop=True)
            print("Sorted by time")

        # Filter by magnitude
        initial_len = len(df)
        df = df[df[mag_col] >= MC].reset_index(drop=True)
        n_total = len(df)
        print(f"Filtered by Mc={MC}: {initial_len} -> {n_total} events")
        if n_total < 10:
            print("Too few events after filtering")
            return

        # Set sliding window parameters
        win = min(WINDOW_SIZE, max(10, n_total//10)) if n_total >= WINDOW_SIZE else max(10, n_total//3)
        step = min(STEP, max(1, win//5)) if n_total >= WINDOW_SIZE else max(1, win//5)
        start_indices = list(range(0, n_total - win + 1, step))
        if MAX_WINDOWS:
            start_indices = start_indices[:MAX_WINDOWS]
        print(f"Window size {win}, step {step}, total windows {len(start_indices)}")
        if len(start_indices)==0:
            print("No windows to process")
            return

        start_time = time.time()
        worker = partial(process_window_safe, df=df, lat_col=lat_col, lon_col=lon_col, mag_col=mag_col, win=win, Mc=MC)

        all_batches = list(chunks(start_indices, BATCH_SIZE))
        batch_idx = 0
        Cr_results = []

        for batch in all_batches:
            batch_idx += 1
            print(f"\nProcessing batch {batch_idx}/{len(all_batches)} (size={len(batch)})")
            batch_results = []
            try:
                with Pool(processes=NUM_WORKERS) as pool:
                    for res in tqdm(pool.imap_unordered(worker, batch), total=len(batch)):
                        if res is not None:
                            batch_results.append(res)
            except Exception as e:
                print(f"Error in multiprocessing for batch {batch_idx}: {e}")
                continue

            if not batch_results:
                print(f"No valid results for batch {batch_idx}")
                continue

            # Save batch results including actual lat/lon/time
            try:
                # Flatten events_data for CSV
                expanded_rows = []
                for r in batch_results:
                    for event in r['events_data']:
                        row = r.copy()
                        row.pop('events_data')
                        row.update(event)
                        expanded_rows.append(row)
                batch_df = pd.DataFrame(expanded_rows)
                batch_filename = f"{OUT_BASE}_{batch_idx:03d}.csv"
                batch_df.to_csv(batch_filename, index=False)
                print(f"Saved {len(batch_df)} rows to {batch_filename}")
            except Exception as e:
                print(f"Error saving batch {batch_idx}: {e}")
                continue

            # C(r) analysis
            print(f"Analyzing C(r) for batch {batch_idx}...")
            Cr_result = analyze_Cr_for_batch_safe(batch_results, df, lat_col, lon_col, batch_idx)
            if Cr_result:
                Cr_results.append(Cr_result)
                print(f"Batch {batch_idx}: C(r) slope = {Cr_result['slope']:.4f} ± {Cr_result['std_err']:.4f}")
            else:
                print(f"Batch {batch_idx}: Could not calculate C(r) slope")

        elapsed = time.time() - start_time
        print(f"\nCompleted processing. Elapsed time: {elapsed:.2f} seconds")

        if Cr_results:
            print(f"\nCreating C(r) plots for {len(Cr_results)} batches...")
            plot_batch_Cr_safe(Cr_results)
            # Save summary CSV
            try:
                Cr_summary = pd.DataFrame([{
                    'batch_idx': r['batch_idx'],
                    'n_events': r['n_events'],
                    'slope': r['slope'],
                    'intercept': r['intercept'],
                    'r_value': r['r_value'],
                    'p_value': r['p_value'],
                    'std_err': r['std_err'],
                    'fit_r_min_deg': r['fit_range'][0],
                    'fit_r_max_deg': r['fit_range'][1]
                } for r in Cr_results])
                Cr_summary.to_csv('Cr_slopes_summary.csv', index=False)
                print("Saved C(r) slope analysis to 'Cr_slopes_summary.csv'")
                slopes = Cr_summary['slope'].values
                print(f"\nC(r) Slope Statistics:\nNumber of batches: {len(slopes)}\nMean: {np.mean(slopes):.4f}\nStd:  {np.std(slopes):.4f}\nMin:  {np.min(slopes):.4f}\nMax:  {np.max(slopes):.4f}")
            except Exception as e:
                print(f"Error saving summary: {e}")
        else:
            print("No valid C(r) results obtained.")

    except Exception as e:
        print(f"Fatal error in main: {e}")
        import traceback
        traceback.print_exc()

# ---------------- Entry ----------------
if __name__ == "__main__":
    main()


In [1]:
# script.py
import numpy as np
import pandas as pd
from scipy import stats
from multiprocessing import Pool, cpu_count
from functools import partial
import math
import warnings
warnings.filterwarnings('ignore')

# ---------------- USER SETTINGS ----------------
CSV_FILE = "data.csv"        # Input CSV (assumes columns: date,time,latitude,longitude,depth,magnitude)
OUT_BASE = "final-data"      # Output base name
MC = 3.0                     # Completeness magnitude (filter)
WINDOW_SIZE = 100            # events per sliding window (like Hirata's N=100)
STEP = 10                    # sliding step
BATCH_SIZE = 1000            # how many windows per output CSV
MAX_WINDOWS = 5000           # limit number of windows
NUM_WORKERS = max(1, min(4, cpu_count()-1))
N_R = 30                     # points in r for C(r) (logspace)
MAX_DIST_SUBSAMPLE = 2000    # safe subsample if too many events for distance matrix
R_EARTH = 6371.0             # km
# ---------------- end settings ----------------

def b_value_mle(mags, Mc):
    mags = np.asarray(mags)
    mags = mags[mags >= Mc]
    if mags.size == 0:
        return np.nan
    Mbar = mags.mean()
    if Mbar <= Mc:
        return np.nan
    return 0.4342944819 / (Mbar - Mc)  # ln10 / (Mbar - Mc)

def safe_great_circle_matrix(lat_deg, lon_deg, max_size=MAX_DIST_SUBSAMPLE):
    # Subsample if necessary to avoid huge memory
    n = len(lat_deg)
    if n > max_size:
        idx = np.random.choice(n, max_size, replace=False)
        lat_deg = lat_deg[idx]
        lon_deg = lon_deg[idx]
        n = max_size
    lat = np.radians(lat_deg)
    lon = np.radians(lon_deg)
    # vectorized great-circle using cos law
    cos_ang = (np.cos(lat)[:,None] * np.cos(lat)[None,:] +
               np.sin(lat)[:,None] * np.sin(lat)[None,:] * np.cos((lon[:,None] - lon[None,:])))
    cos_ang = np.clip(cos_ang, -1.0, 1.0)
    ang = np.arccos(cos_ang)
    return R_EARTH * ang  # distances in km

def correlation_integral_from_dists(dists, N, r_vals):
    # dists: 1D array of pairwise distances (same units as r_vals)
    if len(dists)==0 or N < 2:
        return np.zeros_like(r_vals)
    denom = max(1, N*(N-1))
    C = np.zeros_like(r_vals, dtype=float)
    for i, r in enumerate(r_vals):
        C[i] = 2.0 * np.sum(dists < r) / denom
    return C

def estimate_D2_from_positions(lat, lon, n_r=N_R):
    """
    Compute D2 (correlation dimension) from lat/lon arrays following Hirata (1989).
    Returns dict with D2, D2_err, r_min (deg), r_max (deg) where fit was done.
    """
    try:
        N = len(lat)
        if N < 3:
            return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)

        # compute pairwise distances (km), safe subsampling inside function
        Dmat_km = safe_great_circle_matrix(lat, lon)
        iu = np.triu_indices(len(Dmat_km), k=1)
        dists_km = Dmat_km[iu]
        # convert to degrees roughly (1 deg ~ 111 km)
        dists_deg = dists_km / 111.0
        dists_deg = dists_deg[dists_deg > 0]
        if dists_deg.size < 10:
            return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)

        # r_min and r_max as Hirata: r_min = min(dists)*1.2, r_max = max(dists)/2
        r_min = np.min(dists_deg) * 1.2
        r_max = np.max(dists_deg) / 2.0
        if r_min <= 0 or r_min >= r_max:
            r_min = np.min(dists_deg)
            r_max = np.max(dists_deg)
        if r_min >= r_max:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)

        r_vals = np.logspace(math.log10(r_min), math.log10(r_max), n_r)
        C = correlation_integral_from_dists(dists_deg, len(Dmat_km), r_vals)
        mask = C > 0
        if mask.sum() < 6:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)

        logr = np.log10(r_vals[mask])
        logC = np.log10(C[mask])

        # use middle 50% of points (i0..i1) as Hirata
        L = len(logr)
        i0 = L // 4
        i1 = 3 * L // 4
        if i1 - i0 < 3:
            return dict(D2=np.nan, D2_err=np.nan, r_min=r_min, r_max=r_max)

        slope, intercept, r_value, p_value, std_err = stats.linregress(logr[i0:i1], logC[i0:i1])
        fit_rmin = 10**(logr[i0])
        fit_rmax = 10**(logr[i1-1])
        return dict(D2=float(slope), D2_err=float(std_err), r_min=fit_rmin, r_max=fit_rmax)
    except Exception as e:
        # on any error, return NaNs
        return dict(D2=np.nan, D2_err=np.nan, r_min=np.nan, r_max=np.nan)

def process_window(idx_start, df, win, Mc):
    end_idx = min(idx_start + win, len(df))
    subset = df.iloc[idx_start:end_idx]
    if len(subset) < 3:
        return None

    mags = subset["magnitude"].values
    b = b_value_mle(mags, Mc)

    lat = subset["latitude"].values
    lon = subset["longitude"].values
    Dres = estimate_D2_from_positions(lat, lon, n_r=N_R)

    result = {
        "start_idx": int(idx_start),
        "end_idx": int(end_idx),
        "n_events": int(len(subset)),
        "b_value": float(b) if not np.isnan(b) else np.nan,
        "D2": float(Dres["D2"]) if not np.isnan(Dres["D2"]) else np.nan,
        "D2_err": float(Dres["D2_err"]) if not np.isnan(Dres["D2_err"]) else np.nan,
        "r_fit_min_deg": float(Dres["r_min"]) if not np.isnan(Dres["r_min"]) else np.nan,
        "r_fit_max_deg": float(Dres["r_max"]) if not np.isnan(Dres["r_max"]) else np.nan,
        "events_data": subset[["date","time","latitude","longitude","magnitude"]].to_dict(orient='records')
    }
    return result

def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i+n]

def main():
    df = pd.read_csv(CSV_FILE)
    print(f"Loaded {len(df)} rows from {CSV_FILE}")

    # filter by magnitude (Mc)
    df = df[df["magnitude"] >= MC].reset_index(drop=True)
    n_total = len(df)
    print(f"Events after Mc={MC} filter: {n_total}")
    if n_total < 10:
        print("Too few events after filtering; exiting.")
        return

    # sliding windows
    start_indices = list(range(0, n_total - WINDOW_SIZE + 1, STEP))
    start_indices = start_indices[:MAX_WINDOWS]
    print(f"Total windows to process: {len(start_indices)} (window size {WINDOW_SIZE}, step {STEP})")
    if len(start_indices) == 0:
        print("No windows to process; exit.")
        return

    worker = partial(process_window, df=df, win=WINDOW_SIZE, Mc=MC)
    all_batches = list(chunks(start_indices, BATCH_SIZE))

    batch_idx = 0
    Cr_results = []
    for batch in all_batches:
        batch_idx += 1
        print(f"\nProcessing batch {batch_idx}/{len(all_batches)} (windows in batch: {len(batch)})")
        batch_results = []
        try:
            with Pool(processes=NUM_WORKERS) as pool:
                for res in pool.imap_unordered(worker, batch):
                    if res is not None:
                        batch_results.append(res)
        except Exception as e:
            print(f"Multiprocessing error: {e}")
            continue

        if not batch_results:
            print("No valid windows in this batch.")
            continue

        # Flatten and save batch CSV
        expanded_rows = []
        for r in batch_results:
            base = {
                "start_idx": r["start_idx"],
                "end_idx": r["end_idx"],
                "n_events": r["n_events"],
                "b_value": r["b_value"],
                "D2": r["D2"],
                "D2_err": r["D2_err"],
                "r_fit_min_deg": r["r_fit_min_deg"],
                "r_fit_max_deg": r["r_fit_max_deg"]
            }
            for ev in r["events_data"]:
                row = base.copy()
                row.update(ev)  # adds date,time,latitude,longitude,magnitude
                expanded_rows.append(row)

        batch_df = pd.DataFrame(expanded_rows)
        out_file = f"{OUT_BASE}_{batch_idx:03d}.csv"
        batch_df.to_csv(out_file, index=False)
        print(f"Saved batch {batch_idx} -> {out_file} ({len(batch_df)} rows / {len(batch_results)} windows)")

        # collect C(r)/D2 summary for later overall summary
        for r in batch_results:
            Cr_results.append({
                "batch_idx": batch_idx,
                "start_idx": r["start_idx"],
                "end_idx": r["end_idx"],
                "n_events": r["n_events"],
                "b_value": r["b_value"],
                "D2": r["D2"],
                "D2_err": r["D2_err"],
                "r_fit_min_deg": r["r_fit_min_deg"],
                "r_fit_max_deg": r["r_fit_max_deg"]
            })

    # save summary if we have any results
    if Cr_results:
        summary_df = pd.DataFrame(Cr_results)
        summary_file = "Cr_slopes_summary.csv"
        summary_df.to_csv(summary_file, index=False)
        print(f"\nSaved summary to {summary_file} (rows: {len(summary_df)})")
        # print basic stats
        slopes = summary_df['D2'].dropna().values
        if slopes.size > 0:
            print(f"D2 stats -> count: {len(slopes)}, mean: {np.mean(slopes):.4f}, std: {np.std(slopes):.4f}, min: {np.min(slopes):.4f}, max: {np.max(slopes):.4f}")
    else:
        print("No valid D2/b-value results computed.")

if __name__ == "__main__":
    main()


Loaded 396271 rows from data.csv
Events after Mc=3.0 filter: 127061
Total windows to process: 5000 (window size 100, step 10)

Processing batch 1/5 (windows in batch: 1000)
Saved batch 1 -> final-data_001.csv (100000 rows / 1000 windows)

Processing batch 2/5 (windows in batch: 1000)
Saved batch 2 -> final-data_002.csv (100000 rows / 1000 windows)

Processing batch 3/5 (windows in batch: 1000)
Saved batch 3 -> final-data_003.csv (100000 rows / 1000 windows)

Processing batch 4/5 (windows in batch: 1000)
Saved batch 4 -> final-data_004.csv (100000 rows / 1000 windows)

Processing batch 5/5 (windows in batch: 1000)
Saved batch 5 -> final-data_005.csv (100000 rows / 1000 windows)

Saved summary to Cr_slopes_summary.csv (rows: 5000)
D2 stats -> count: 5000, mean: 1.2444, std: 0.2919, min: 0.2732, max: 1.9951
