In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# from numpy.doc.constants import lines


In [None]:
# ISC bulletin export (comma separated, '#' comment lines, no header row).
PATH_ISC = "/home/rsafran/PycharmProjects/toolbox/data/ISC/ISC_Ocean_indien_2018.csv"
cnames = ['EVENTID', 'TYPE', 'AUTHOR', 'DATE', 'TIME', 'LAT', 'LON', 'DEPTH', 'DEPFIX', 'DEPQUAL',
          'AUTHOR_MG', 'TYPE_MG', 'MAG', "14", '15', '16', '17', '18', '19', '20']
isc = pd.read_csv(PATH_ISC, comment='#', sep=',', header=None, names=cnames)
# Single origin-time column built from the separate DATE and TIME fields.
isc['datetime'] = pd.to_datetime(isc['DATE'] + ' ' + isc['TIME'])
# Drop the raw date/time strings, the type/author fields and the unnamed trailing fields "14".."20".
isc = isc.drop(columns=['TYPE', 'AUTHOR', 'DATE', 'TIME'] + [str(k) for k in range(14, 21)])

In [None]:
catalog_path = '/media/rsafran/CORSAIR/Association/validated/refined_s_-60-5,35-120,350,0.8,0.6_final_filtered.npy'
# The .npy file stores a pickled dict wrapped in a 0-d object array; .item() unwraps it.
# NOTE: allow_pickle=True can execute code embedded in the file — only load trusted outputs.
catalogue_array = np.load(catalog_path, allow_pickle=True)
catalogue = catalogue_array.item()



In [None]:
cat = pd.DataFrame.from_dict(catalogue['filtered_events'], orient='columns')
# NOTE(review): source_point is indexed here as (lon, lat), yet the map cells plot
# ('lat', 'lon') as (x, y) — confirm the tuple order against the association code.
source_points = cat['source_point']
cat['lat'] = source_points.map(lambda point: point[1])
cat['lon'] = source_points.map(lambda point: point[0])
# The uid starts with the event timestamp followed by '_'; keep only that part as a datetime.
cat['uid'] = pd.to_datetime(cat['uid'].map(lambda uid: uid.split('_')[0]))
cat = cat.sort_values("uid", ignore_index=True)

In [None]:
# Catalogue event locations vs ISC epicentres on a plate-carrée map.
map_ax = plt.axes(projection=ccrs.PlateCarree())
# Land/coastline/borders get zorder >= 2 so they are drawn above any background image.
map_ax.add_feature(cfeature.LAND, facecolor='lightgray', zorder=2)
for feature, line_kw in ((cfeature.COASTLINE, {'linewidth': 1}), (cfeature.BORDERS, {'linestyle': ':'})):
    map_ax.add_feature(feature, edgecolor='black', zorder=3, **line_kw)
# NOTE(review): 'lat' is used as x here; this is only correct if source_point is (lat, lon).
cat.plot(x='lat', y='lon', ax=map_ax, style='.')
isc.plot(x='LON', y='LAT', ax=map_ax, style='.')


In [None]:
# Top: catalogue vs ISC event-time distributions; bottom: catalogue alone at coarser binning.
ax_top = plt.subplot(2, 1, 1)
cat['uid'].hist(bins=200, ax=ax_top)
isc['datetime'].hist(bins=200, ax=ax_top)
ax_bottom = plt.subplot(2, 1, 2)
cat['uid'].hist(bins=100, ax=ax_bottom)

In [None]:
# Attach to each catalogue event the latest ISC origin at most 10 min earlier.
# merge_asof requires BOTH frames to be sorted on their join keys; isc is never sorted
# (an unsorted ISC export raises "right keys must be sorted"), so sort it here.
# A stable sort keeps the existing order of cat (already sorted by uid) intact.
join_cat = pd.merge_asof(
    cat.sort_values("uid", kind="stable"),
    isc.sort_values("datetime", kind="stable"),
    left_on="uid",
    right_on="datetime",
    tolerance=pd.Timedelta("10min"),
)

In [None]:
# Catalogue time minus matched ISC origin time (NaT where no ISC match was found).
join_cat["time_error"] = join_cat["uid"].sub(join_cat["datetime"])

In [None]:
# Keep only catalogue events that found an ISC origin within the merge tolerance.
# A bare dropna() also discarded matched rows in which any other column (e.g. DEPFIX,
# DEPQUAL, MAG) happened to be empty, silently biasing this map and the mean below.
join_cat = join_cat.dropna(subset=["datetime"])
ax1 = plt.axes(projection=ccrs.PlateCarree())
# These features will be drawn on top if the image is behind.
ax1.add_feature(cfeature.LAND, facecolor='lightgray', zorder=2)
ax1.add_feature(cfeature.COASTLINE, edgecolor='black', linewidth=1, zorder=3)
ax1.add_feature(cfeature.BORDERS, linestyle=':', edgecolor='black', zorder=3)
# NOTE(review): 'lat' is used as x here; only correct if source_point is stored as (lat, lon).
join_cat.plot('lat', 'lon', ax=ax1, style='.')
join_cat.plot('LON', 'LAT', ax=ax1, style='.')

In [None]:
# Average catalogue-minus-ISC time offset over matched events.
join_cat["time_error"].mean()

## Arrivals

In [None]:
from utils.data_reading.sound_data.station import StationsCatalog
CATALOG_PATH = "/media/rsafran/CORSAIR/OHASISBIO/recensement_stations_OHASISBIO_RS.csv"

DETECTIONS_DIR = "/home/rsafran/Bureau/tissnet/2018"
ASSOCIATION_OUTPUT_DIR = "../../../data/detection/association"

# Stations with both dates and positions, restricted to the 2018 deployment.
STATIONS = (
    StationsCatalog(CATALOG_PATH)
    .filter_out_undated()
    .filter_out_unlocated()
    .by_dataset('2018')
)
for station in STATIONS:
    print(station.name)
    print(station.get_pos(include_depth=False))

In [None]:
from utils.physics.sound_model import ISAS_grid as isg
from pyproj import Geod
from multiprocessing import Manager

ISAS_PATH = "/media/rsafran/CORSAIR/ISAS/86442/field/2018"

# Geographic window of the ISAS grid to load (degrees).
GRID_LAT_BOUNDS = [-60, 5]
GRID_LON_BOUNDS = [35, 120]

DEPTH = 1250  # meters; depth passed to the ISAS travel-time computation
SOUND_SPEED = 1480  # speed used by the straight-line fallback (distance / SOUND_SPEED)
PICKING_ERROR_BASE = 2  # base picking error, added in quadrature to the model error

geod = Geod(ellps="WGS84")


def get_isas_data(month):
    """Return the ISAS grid for ``month``, loading it on first request.

    Loaded grids are memoised in the module-level ``process_local_isas_cache`` dict,
    so later calls for the same month return the already loaded object.
    """
    if month in process_local_isas_cache:
        return process_local_isas_cache[month]
    grid = isg.load_ISAS_TS(
        ISAS_PATH, month, GRID_LAT_BOUNDS, GRID_LON_BOUNDS, fast=False
    )
    process_local_isas_cache[month] = grid
    return grid

def compute_travel_time(lat, lon, station_lat, station_lon, month, velocity_dict=None):
    """Travel time from a source to a station through the ISAS sound-speed grid.

    Args:
        lat, lon: source position in degrees.
        station_lat, station_lon: station position in degrees.
        month: month number selecting the ISAS grid.
        velocity_dict: optional mapping ``month -> ISAS grid``. When None, the grid is
            obtained (and cached) through ``get_isas_data(month)``; previously the None
            default raised TypeError.

    Returns:
        (tt, total_err, dist_m): the travel time, its uncertainty with PICKING_ERROR_BASE
        added in quadrature, and the source-station distance as returned by
        ``isg.compute_travel_time``. If that call raises ValueError, a straight geodesic
        path at constant SOUND_SPEED is used instead (distance in metres from pyproj,
        model error 10 % of the travel time).
    """
    ds = get_isas_data(month) if velocity_dict is None else velocity_dict[month]
    picking_err = PICKING_ERROR_BASE  # base error in picking arrival times
    try:
        tt, total_err, dist_m = isg.compute_travel_time(
            lat, lon, station_lat, station_lon,
            DEPTH, ds,
            resolution=30,
            verbose=False,
            interpolate_missing=True
        )
    except ValueError:
        print(f"Error in ISAS calculation for lat={lat}, lon={lon}, station_lat={station_lat}, station_lon={station_lon}")
        # Fallback: geodesic distance at a constant sound speed. Geod.inv takes (lon, lat) order.
        _, _, dist_m = geod.inv(lon, lat, station_lon, station_lat)
        tt = dist_m / SOUND_SPEED
        total_err = tt * 0.1
    # Combine picking and model errors as independent components.
    total_err = np.sqrt(picking_err**2 + total_err**2)

    return tt, total_err, dist_m

# Per-process cache of ISAS grids, filled lazily by get_isas_data.
process_local_isas_cache = {}
# Grids for every month, looked up by compute_travel_time.
# The travel-time computation below runs in this single kernel process (DataFrame.apply),
# so a multiprocessing.Manager dict only adds cost: every velocity_dict[month] lookup
# pickles the whole grid through the manager process, and each grid is stored twice.
# A plain dict is a drop-in replacement; if the loop is ever parallelised, hand the
# grids to the workers explicitly (or reintroduce a Manager).
shared_velocity_grid = {}
for m in range(1, 13):
    print(f"Loading ISAS data for month {m}...")
    shared_velocity_grid[m] = get_isas_data(m)


In [None]:
# For each station, one (travel_time, error, distance) tuple per ISC event.
for station in STATIONS:
    print(station.name)
    station_lat, station_lon = station.get_pos(include_depth=False)
    isc[f"travel_time_{station.name}"] = isc.apply(
        lambda row: compute_travel_time(
            row["LAT"], row["LON"], station_lat, station_lon,
            row.datetime.month, velocity_dict=shared_velocity_grid,
        ),
        axis=1,
    )

In [None]:
# Column index of the ISC table, now including the per-station travel-time columns.
isc.keys()

## Detections

In [None]:
from utils.detection.association import load_detections
import glob2
import datetime
from pathlib import Path

# --- Detection loading parameters ---
# True: rebuild from the raw detector output files. False: reuse the cached .npy files
# (DETECTIONS from detections.npy, DETECTIONS_MERGED from refined_detections_merged.npy).
RELOAD_DETECTIONS = True
MIN_P_TISSNET_PRIMARY = 0.4  # minimum probability for detections that are browsed
MIN_P_TISSNET_SECONDARY = 0.1  # minimum probability for detections that may be associated with a browsed one
MERGE_DELTA_S = 10  # seconds; detections closer than this are merged
MERGE_DELTA = datetime.timedelta(seconds=MERGE_DELTA_S)

if not RELOAD_DETECTIONS:
    DETECTIONS = np.load(f"{DETECTIONS_DIR}/cache/detections.npy", allow_pickle=True).item()
    DETECTIONS_MERGED = np.load(f"{DETECTIONS_DIR}/cache/refined_detections_merged.npy", allow_pickle=True)
else:
    det_files = [path for path in glob2.glob(DETECTIONS_DIR + "/*") if Path(path).is_file()]
    DETECTIONS, DETECTIONS_MERGED = load_detections(
        det_files, STATIONS, DETECTIONS_DIR, MIN_P_TISSNET_PRIMARY, MIN_P_TISSNET_SECONDARY, MERGE_DELTA
    )

In [None]:
from datetime import timedelta


def extract_times(detections, station_name):
    """Return the detection times recorded for one station.

    Args:
        detections: mapping whose keys are station objects (exposing ``.name``) and whose
            values are iterables of detection rows; the first item of each row is taken.
        station_name: name of the station to look up.

    Returns:
        list of the first item of every row for that station, or None (after printing the
        available names) when no key has that name.
    """
    by_name = {}
    for station_obj in detections.keys():
        by_name[station_obj.name] = station_obj

    if station_name not in by_name:
        print(f"Station {station_name} not found. Available: {list(by_name.keys())}")
        return None

    return [row[0] for row in detections[by_name[station_name]]]


def check_detection(detections_df, catalogue_arrival_times, time_tolerance_seconds=15):
    """
    Flag each detection lying within a tolerance of any catalogue arrival time.

    Args:
        detections_df (pd.DataFrame): frame with a datetime-like 'detection_time' column.
        catalogue_arrival_times (array-like): theoretical arrival times; missing values are ignored.
        time_tolerance_seconds (float): maximum |detection - arrival| in seconds (inclusive)
            for a detection to count as matched.

    Returns:
        pd.DataFrame: ``detections_df`` itself (modified in place) with a boolean
        'is_in_isc' column.
    """
    tolerance = pd.Timedelta(seconds=time_tolerance_seconds).to_timedelta64()
    arrivals = pd.DatetimeIndex(pd.to_datetime(catalogue_arrival_times)).dropna().sort_values()
    detections = pd.DatetimeIndex(pd.to_datetime(detections_df['detection_time']))

    if len(arrivals) == 0:
        detections_df['is_in_isc'] = np.zeros(len(detections_df), dtype=bool)
        return detections_df

    # Only the arrivals immediately before and after a detection can be the closest one,
    # so a binary search replaces the previous O(n_detections * n_arrivals) full scan.
    pos = arrivals.searchsorted(detections)
    last = len(arrivals) - 1
    previous = arrivals[np.clip(pos - 1, 0, last)]
    following = arrivals[np.clip(pos, 0, last)]
    gap_prev = np.abs((detections - previous).to_numpy())
    gap_next = np.abs((following - detections).to_numpy())
    # NaT gaps (missing detection times) compare False, i.e. "not matched".
    detections_df['is_in_isc'] = (gap_prev <= tolerance) | (gap_next <= tolerance)
    return detections_df

In [None]:
time_tol = [5, 10, 15, 20, 25, 30, 50]  # matching tolerances to scan, in seconds
res = []  # one {"tol": tolerance, <station>: n_matched_detections} record per (tolerance, station)

# Theoretical arrival times and the detections depend only on the station, not on the
# tolerance, so build them once per station instead of once per (tolerance, station).
station_detections = {}
for st in STATIONS:
    name = st.name
    # First item of each (travel_time, error, distance) tuple is the travel time.
    travel_time_s = isc[f"travel_time_{name}"].map(lambda t: float(t[0]))
    isc[f"arrival_time_{name}"] = isc["datetime"] + pd.to_timedelta(travel_time_s, unit="s")
    detection_times = extract_times(DETECTIONS, name)
    detections_df = pd.DataFrame(detection_times, columns=['detection_time'])
    detections_df['detection_time'] = pd.to_datetime(detections_df['detection_time'])  # Ensure correct datetime format
    station_detections[name] = detections_df

for time_tolerance_seconds in time_tol:
    for st in STATIONS:
        name = st.name
        print(name)
        updated_df = check_detection(station_detections[name], isc[f"arrival_time_{name}"], time_tolerance_seconds)
        print(updated_df['is_in_isc'].value_counts())
        # Count matches explicitly: value_counts()[1] is not "the True count" (it resolves
        # by label/position depending on which values are present, and raises KeyError when
        # nothing matched).
        res.append({"tol": time_tolerance_seconds, name: int(updated_df['is_in_isc'].sum())})

In [None]:
stations = ['ELAN', 'MADE', 'MADW', 'NEAMS', 'RTJ', 'SSEIR', 'SSWIR', 'SWAMSbot', 'WKER2']

# Each record in `res` is {"tol": ..., <station>: count}. As a frame this gives one column per
# station (NaN where a record belongs to another station), so each panel selects its own
# station's points. Previously every panel received every record.
res_df = pd.DataFrame(res)

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(9, 6),
                        subplot_kw={'xticks': [5, 10, 15, 20, 25, 30, 50]})
for ax, station in zip(axs.flat, stations):
    if station in res_df.columns:
        station_counts = res_df[["tol", station]].dropna()
        ax.plot(station_counts["tol"], station_counts[station], 'o')
    ax.set_title(station)
for ax in axs[-1, :]:
    ax.set_xlabel("Tolerance (s)")
for ax in axs[:, 0]:
    ax.set_ylabel("Matched detections")

plt.tight_layout()
plt.show()

## Arrival times catalogue

In [None]:
# Rename ISC fields to the catalogue schema used by isc_to_csv_cat.
ISC_TO_CATALOGUE_COLUMNS = {
    "EVENTID": "id",
    "LAT": 'latitude',
    "LON": 'longitude',
    "DEPTH": 'depth',
    'datetime': 'time',
    'MAG': 'mag',
    'TYPE_MG': "magType",
}
isc = isc.rename(columns=ISC_TO_CATALOGUE_COLUMNS)

In [None]:
def isc_to_csv_cat(isc, name):
    """Build the T-phase arrival catalogue of one station from the ISC table.

    Args:
        isc: ISC events with columns time, latitude, longitude, depth, mag, magType, plus
            ``travel_time_{name}`` (tuples whose first item is the travel time and last item
            the distance) and ``arrival_time_{name}``.
        name: station name used to select those per-station columns.

    Returns:
        A new DataFrame with one row per ISC event and the added columns phase ("T"),
        travel_time, arrival_time and distance_deg.
    """
    travel = isc[f'travel_time_{name}']
    # .copy() so the added columns do not write into a slice of `isc`.
    catalogue = isc[['time', 'latitude', 'longitude', 'depth', 'mag', 'magType']].copy()
    catalogue['phase'] = "T"
    # Per-event values. The previous [0][0] / [0][-1] took row 0's tuple and broadcast the
    # first event's travel time and distance to every row.
    catalogue['travel_time'] = travel.map(lambda t: t[0])
    catalogue['arrival_time'] = isc[f'arrival_time_{name}']
    # NOTE(review): compute_travel_time returns this value as `dist_m` (metres in its geodesic
    # fallback) — confirm the unit before relying on the "_deg" column name.
    catalogue['distance_deg'] = travel.map(lambda t: t[-1])
    return catalogue

In [None]:
# T-phase arrival catalogue for station ELAN (input to the candidate search below).
catalogue = isc_to_csv_cat(isc, name='ELAN')

In [None]:
from scipy import signal


# Function to downsample audio data
def downsample_audio(data, original_fs, target_fs):
    """Low-pass filter and decimate ``data`` by an integer factor.

    The decimation factor is ``int(original_fs / target_fs)``. When that ratio is not an
    integer, the output rate is ``original_fs / factor`` (>= target_fs), and that ACTUAL rate
    is returned. Previously ``target_fs`` was returned, e.g. 50 Hz for data that was really
    at 60 Hz (240 Hz / 4).

    Args:
        data: 1-D signal sampled at ``original_fs`` Hz.
        original_fs: input sampling rate in Hz.
        target_fs: requested output rate in Hz, 0 < target_fs <= original_fs.

    Returns:
        (downsampled_data, actual_fs)

    Raises:
        ValueError: if target_fs is outside (0, original_fs].
    """
    if not 0 < target_fs <= original_fs:
        raise ValueError(f"target_fs must be in (0, {original_fs}], got {target_fs}")
    if target_fs == original_fs:
        # Nothing to do; a low-pass at the input Nyquist frequency cannot be designed.
        return np.array(data, dtype=float), original_fs

    factor = int(original_fs / target_fs)
    actual_fs = original_fs / factor
    # Anti-aliasing low-pass at target_fs/2 (<= the actual output Nyquist), zero-phase,
    # in second-order sections for numerical robustness.
    sos = signal.butter(5, target_fs / 2, fs=original_fs, btype='low', output='sos')
    filtered_data = signal.sosfiltfilt(sos, data)
    # Keep every `factor`-th sample.
    downsampled_data = filtered_data[::factor]
    return downsampled_data, actual_fs

# Function to apply dehazing (using spectral subtraction)
def dehaze_audio(data, fs, frame_size=1024, overlap=0.8):
    """Denoise ``data`` by magnitude spectral subtraction with Hann-windowed overlap-add.

    The noise profile is the mean magnitude spectrum of the first (up to) 5 frames; each
    frame's magnitude is reduced by 1.5x that profile and floored at 1 % of its own value,
    its phase is kept, and frames are overlap-added back.

    Args:
        data: 1-D signal.
        fs: sampling rate (unused; kept for interface compatibility).
        frame_size: frame length in samples.
        overlap: fraction of overlap between consecutive frames.

    Returns:
        Array of ``len(data)`` samples scaled so its peak absolute value is 1. If the result
        is all zeros (e.g. data shorter than one frame), it is returned unscaled instead of
        becoming NaN. Samples after the last full frame remain 0.

    Raises:
        ValueError: if ``overlap`` leaves a hop of less than one sample.
    """
    data = np.asarray(data, dtype=float)
    hop_size = int(frame_size * (1 - overlap))
    if hop_size < 1:
        raise ValueError(f"overlap={overlap} gives a hop size of {hop_size} samples")
    window = np.hanning(frame_size)
    num_noise_frames = 5

    # Full frames only; the +1 keeps a frame that ends exactly on the last sample
    # (previously a signal of exactly frame_size samples produced no frame at all).
    starts = range(0, len(data) - frame_size + 1, hop_size)
    frames = [data[s:s + frame_size] for s in starts]

    # Noise profile: average over the frames actually available (was always divided by 5).
    noise_estimate = np.zeros(frame_size // 2 + 1)
    noise_frames = frames[:num_noise_frames]
    for noise_frame in noise_frames:
        noise_estimate += np.abs(np.fft.rfft(noise_frame * window))
    if noise_frames:
        noise_estimate /= len(noise_frames)

    result = np.zeros(len(data))
    for start, frame in zip(starts, frames):
        spectrum = np.fft.rfft(frame * window)
        magnitude = np.abs(spectrum)
        phase = np.angle(spectrum)
        # Over-subtract the noise, but never go below 1 % of the original magnitude.
        magnitude = np.maximum(magnitude - noise_estimate * 1.5, 0.01 * magnitude)
        # n=frame_size so odd frame sizes round-trip to the right length.
        enhanced_frame = np.fft.irfft(magnitude * np.exp(1j * phase), n=frame_size)
        result[start:start + frame_size] += enhanced_frame

    peak = np.max(np.abs(result)) if len(result) else 0.0
    if peak > 0:
        result = result / peak
    return result

# Function to apply Butterworth bandpass filter
def apply_butter_bandpass(data, fs, lowcut, highcut, order=5):
    """Zero-phase Butterworth band-pass filter.

    Args:
        data: 1-D signal.
        fs: sampling rate in Hz.
        lowcut, highcut: band edges in Hz, with 0 < lowcut < highcut < fs / 2.
        order: Butterworth order (the band-pass design has 2 * order poles).

    Returns:
        Filtered signal with the same length as ``data`` (forward-backward filtering,
        so no phase shift).
    """
    nyq = 0.5 * fs
    # Second-order sections instead of (b, a): the polynomial coefficients of a 2*order-pole
    # band-pass at low normalised frequencies (e.g. 0.6-1.5 Hz at 60 Hz) are numerically
    # fragile, which SciPy explicitly warns about.
    sos = signal.butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')
    return signal.sosfiltfilt(sos, data)

# Function to create spectrogram
def create_spectrogram(data, fs, nperseg=256, noverlap=128, cmap='viridis'):
    """Compute the spectrogram of ``data`` sampled at ``fs`` Hz.

    Thin wrapper around ``scipy.signal.spectrogram``. ``cmap`` is accepted for interface
    compatibility but unused, since nothing is plotted here.

    Returns:
        (frequencies, segment_times, Sxx)
    """
    frequencies, segment_times, power = signal.spectrogram(
        data, fs=fs, nperseg=nperseg, noverlap=noverlap
    )
    return frequencies, segment_times, power

# Function to detect potential seismic events using energy
def detect_seismic_events(data, fs, window_size=5.0, threshold_factor=5.0):
    """Find intervals whose short-window mean energy exceeds a multiple of the median.

    Energy is the mean squared amplitude over windows of ``window_size`` seconds, advanced by
    half a window (only windows starting before ``len(data) - window`` are used). An event opens
    at the first window above ``median(energy) * threshold_factor`` and closes at the next
    window at or below it. An event still open at the end closes at ``len(data) / fs``.

    Returns:
        (events, energy, threshold): list of (start_s, end_s) tuples in seconds from the
        start of ``data``, the per-window energy array, and the threshold used.
    """
    print("Detecting potential seismic events")
    win = int(window_size * fs)
    step = win // 2

    energy = np.array([
        np.sum(data[start:start + win] ** 2) / win
        for start in range(0, len(data) - win, step)
    ])
    threshold = np.median(energy) * threshold_factor

    events = []
    onset = None  # start time (s) of the event currently open, if any
    for k, level in enumerate(energy):
        if level > threshold and onset is None:
            onset = k * step / fs
        elif level <= threshold and onset is not None:
            events.append((onset, k * step / fs))
            onset = None

    if onset is not None:
        events.append((onset, len(data) / fs))

    return events, energy, threshold

# Main processing function
def process_underwater_recording(data, df, date_time, original_fs=240., target_fs=50, low_pass=1.5, high_pass=0.6,
                                 time_tolerance=10):
    """Process one recording segment and plot its detection energy against expected arrivals.

    The segment is assumed to start 10 minutes before the first arrival in ``df``, so the
    expected first arrival lies 600 s into it. Pipeline: downsample -> dehaze -> band-pass
    between ``high_pass`` and ``low_pass`` Hz (``high_pass`` is the LOWER band edge) ->
    energy-based event detection.

    Side effect: if a detected event starts within ``time_tolerance`` seconds of the expected
    arrival, rows of the module-level ``catalogue`` whose ``time`` equals ``date_time`` get
    ``candidate = True``.

    Args:
        data: raw samples at ``original_fs`` Hz.
        df: arrivals for this event, with 'arrival_time' and 'phase' columns; the first row
            is taken as the first arrival.
        date_time: event origin time used to select the catalogue rows to flag.
        original_fs, target_fs: input and requested processing sampling rates (Hz).
        low_pass, high_pass: upper and lower band-pass edges (Hz).
        time_tolerance: matching tolerance in seconds (new; default 10 as before).

    Returns:
        (filtered_data, events)
    """
    print(f"Original sampling rate: {original_fs}Hz")
    print(f"Original data length: {len(data)} samples ({len(data)/original_fs:.2f} seconds)")

    downsampled_data, new_fs = downsample_audio(data, original_fs, target_fs)
    print(f"Downsampled data length: {len(downsampled_data)} samples ({len(downsampled_data)/new_fs:.2f} seconds)")

    dehazed_data = dehaze_audio(downsampled_data, new_fs)
    filtered_data = apply_butter_bandpass(dehazed_data, new_fs, high_pass, low_pass)

    events, energy, threshold = detect_seismic_events(filtered_data, new_fs)

    # Event starts are floats in seconds from the segment start; compare in seconds.
    # (The previous `timedelta - float` raised TypeError whenever an event was detected.)
    expected_s = timedelta(minutes=10).total_seconds()
    if any(abs(event_start - expected_s) < time_tolerance for event_start, _ in events):
        catalogue.loc[catalogue['time'] == date_time, "candidate"] = True

    # Energy curve with the expected arrivals marked.
    plt.figure(figsize=(10, 4))
    window_size = 5.0  # seconds; must match detect_seismic_events' default window
    first_arrival = df['arrival_time'].iloc[0]
    for arrival, phase in zip(df['arrival_time'], df['phase']):
        print(arrival, phase)
        offset_s = (arrival - first_arrival + timedelta(minutes=10)).total_seconds()
        plt.axvline(offset_s, color='g', linestyle='--', label=phase)
    # Energy windows advance by half a window.
    time_axis = np.arange(len(energy)) * (window_size / 2)
    plt.plot(time_axis, energy)
    plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
    plt.title('Signal Energy for Event Detection')
    plt.xlabel('Time (s)')
    plt.ylabel('Energy')
    plt.legend()
    plt.show()
    return filtered_data, events

In [None]:
from utils.data_reading.sound_data.sound_file_manager import DatFilesManager
# Configuration
from tqdm.notebook import tqdm
# Station whose raw recordings are scanned (alternative Windows mount: F:/OHASISBIO/2018/{name}).
name = "ELAN"
PATH = f"/media/rsafran/CORSAIR/OHASISBIO/2018/{name}"

# Processing parameters: sampling rates and band edges in Hz.
ORIGINAL_FS = 240.0
TARGET_FS = 60
HIGH_PASS, LOW_PASS = 0.6, 1.5  # HIGH_PASS is passed as the lower band edge, LOW_PASS as the upper
TIME_TOLERANCE = 10  # seconds

# Every event starts unflagged and without an associated recording file.
catalogue = catalogue.assign(candidate=False, file_number=-1)
# Reader for the station's .dat recordings, passed to find_candidates.
manager = DatFilesManager(PATH, kwargs='raw')

def find_candidates(manager, catalogue, original_fs=240.0, target_fs=60, low_pass=1.5, high_pass=0.6,
                    time_tolerance=10, require_detection=False, plot_debug=False):
    """Flag catalogue events whose recording segment could be processed as candidates.

    For every distinct value of ``catalogue['time']``, the segment
    [first_arrival - 10 min, first_arrival + 10 min] is read from ``manager``, then
    downsampled, dehazed, band-passed between ``high_pass`` and ``low_pass`` Hz
    (``high_pass`` is the lower edge) and run through ``detect_seismic_events`` with
    ``threshold_factor=8.0``. Errors while loading/processing a segment are printed and
    that event is skipped.

    Args:
        manager: recording reader providing ``get_segment(start, end)`` and
            ``find_file_name(start)``.
        catalogue: DataFrame with 'time', 'arrival_time', 'candidate' and 'file_number'
            columns; modified in place.
        original_fs, target_fs, low_pass, high_pass: processing parameters (defaults are the
            values previously hard-coded inside this function).
        time_tolerance: seconds allowed between a detected event start and the expected
            arrival (600 s into the segment); used only when ``require_detection`` is True.
        require_detection: False (default, previous behaviour) flags every event whose
            segment was processed without error; True flags only events with a detection
            within ``time_tolerance`` of the expected arrival.
        plot_debug: plot the energy curve of each processed segment.

    Returns:
        (data, catalogue): the last segment returned by ``manager.get_segment`` (None if no
        segment could be loaded; previously UnboundLocalError) and the same catalogue object.
    """
    expected_time_sec = 600  # the segment starts 10 min before the first arrival
    data = None
    for date_time in tqdm(catalogue['time'].unique()):
        is_event = catalogue['time'] == date_time
        first_arrival = catalogue.loc[is_event, 'arrival_time'].min()

        # Time window: 10 minutes before and after the first arrival.
        start = (first_arrival - timedelta(minutes=10)).replace(tzinfo=None)
        end = (first_arrival + timedelta(minutes=10)).replace(tzinfo=None)

        try:
            data = manager.get_segment(start, end)
            file_number = manager.find_file_name(start)
            downsampled_data, new_fs = downsample_audio(data, original_fs, target_fs)
            dehazed_data = dehaze_audio(downsampled_data, new_fs)
            filtered_data = apply_butter_bandpass(dehazed_data, new_fs, high_pass, low_pass)

            events, energy, threshold = detect_seismic_events(filtered_data, new_fs, threshold_factor=8.0)

            event_starts = np.array([e[0] for e in events], dtype=float)
            matched = bool(np.any(np.abs(event_starts - expected_time_sec) < time_tolerance))
            if matched or not require_detection:
                catalogue.loc[is_event, 'candidate'] = True
                catalogue.loc[is_event, 'file_number'] = file_number

            if plot_debug:
                plt.figure(figsize=(10, 4))
                time_axis = np.arange(len(energy)) * (5.0 / 2)  # detector windows: 5 s, half-window hop
                plt.plot(time_axis, energy)
                plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
                plt.axvline(expected_time_sec, color='g', linestyle='--', label='Expected Arrival')
                plt.xlabel('Time (s)')
                plt.ylabel('Energy')
                plt.legend()
                plt.show()

        except Exception as e:
            # Best effort: report and move on to the next event.
            print(f"Error processing {date_time}: {str(e)}")
            continue
    return data, catalogue

data, catalogue = find_candidates(manager, catalogue)

# Plotting convention for a segment: the first arrival sits in the middle of the 20-minute
# window (600 s after its start). A later arrival is drawn at
#     dt = later_arrival - first_arrival + timedelta(minutes=10)
# seconds from the segment start, as done in process_underwater_recording.

In [None]:
# NOTE(review): path is relative to the notebook's working directory, while
# ASSOCIATION_OUTPUT_DIR uses ../../../data — confirm which depth is intended.
output_csv = f'../../data/{name}_2018_T.csv'
catalogue.to_csv(output_csv, index=False)