# Useful functions

Defines helper functions used in several notebooks of the ELM–sawtooth correlation project for SUMTRAIC 2023: computing ELM and sawtooth phases at given times, and loading detected sawtooth crash times.

In [1]:
# Author: Magne Lauritzen, mag.lauritzen@gmail.com
import pickle
from pathlib import Path

import numpy as np
from cdb_extras import xarray_support as cdbxr  # access to COMPASS Database (CDB)

%run ./sawtooths_extraction.ipynb

In [2]:
def ELM_phase(nshot: int, t: np.ndarray, relative_to_nearest: bool = False, max_period: float = 20):
    """
    Calculate the ELM phase, ELM delay and ELM period at each timestamp in `t`.

    For every time, the enclosing pair of ELMs (last ELM before it, first ELM after it)
    is looked up in the `t_ELM_start` signal of shot `nshot`.

    Parameters
    ----------
    nshot : int
        Shot number, passed to `cdbxr.Shot`.
    t : array_like
        Times at which to evaluate the ELM phase, in the same units as `t_ELM_start`
        (presumably ms, matching the "20ms" cut-off — TODO confirm).
    relative_to_nearest : bool, default False
        If True, the delay and phase are measured from the nearest ELM instead of the
        preceding one; times closer to the following ELM get a negative delay/phase.
    max_period : float, default 20
        ELM intervals longer than this get NaN phase and delay. Their period is still
        reported in the returned period array.

    Returns
    -------
    ELM_phases, ELM_delays, ELM_period : np.ndarray
        Float arrays with the same shape as `t`. Times not strictly between the first
        and last ELM are NaN. If the shot has no ELMs, all three arrays are all-NaN.
        A time exactly equal to an interior ELM is attributed to the interval ending
        at that ELM (phase 1.0), since `np.searchsorted` uses side='left'.
    """
    shot = cdbxr.Shot(nshot)

    # ELM onset times; searchsorted below assumes they are sorted ascending.
    t_ELM_start = np.asarray(shot['t_ELM_start'].values, dtype=float)
    # Work in float so NaN can be written into the delay array even for integer input.
    t = np.asarray(t, dtype=float)

    # Allocate output arrays (NaN = "no phase available")
    ELM_phases = np.full(fill_value=np.nan, shape=t.shape)
    ELM_delays = np.full(fill_value=np.nan, shape=t.shape)
    ELM_period = np.full(fill_value=np.nan, shape=t.shape)

    if len(t_ELM_start) == 0:
        print(f"No ELMs in shot {nshot}. Cannot compute ELM phases.")
        # Return the same 3-tuple as the normal path so callers can always unpack it.
        return ELM_phases, ELM_delays, ELM_period

    # Only times strictly between the first and last ELM have an ELM on both sides.
    elm_range_mask = (t > t_ELM_start[0]) & (t < t_ELM_start[-1])
    t_masked = t[elm_range_mask]

    # Index of the first ELM at or after each time; the ELM before it precedes the time.
    ELM_ind_following_t = np.searchsorted(t_ELM_start, t_masked)
    ELM_time_following_t = t_ELM_start[ELM_ind_following_t]
    ELM_time_preceding_t = t_ELM_start[ELM_ind_following_t - 1]

    # Duration of the ELM cycle containing each time
    elm_period = ELM_time_following_t - ELM_time_preceding_t

    # Delay since the preceding ELM (or signed distance to the nearest ELM)
    t_delay = t_masked - ELM_time_preceding_t
    if relative_to_nearest:
        t_early = t_masked - ELM_time_following_t  # <= 0: time until the next ELM
        closer_to_next = t_delay > np.abs(t_early)
        t_delay[closer_to_next] = t_early[closer_to_next]

    phases = t_delay / elm_period

    # Discard phase/delay for ELM cycles longer than max_period
    max_duration_mask = elm_period > max_period
    phases[max_duration_mask] = np.nan
    t_delay[max_duration_mask] = np.nan

    ELM_phases[elm_range_mask] = phases
    ELM_delays[elm_range_mask] = t_delay
    ELM_period[elm_range_mask] = elm_period

    # Sanity check
    assert ELM_phases.shape == t.shape
    return ELM_phases, ELM_delays, ELM_period

In [3]:
def ST_time_and_phase(nshot, t, relative_to_nearest=False, max_period=20):
    """
    Calculate the sawtooth (ST) phase, ST delay and preceding ST amplitude at each time in `t`.

    Sawtooth crash times and amplitudes are read from the cache file
    ``./sawtooth_data/{nshot}_st.bin``. If that file does not exist, `ST_detector`
    (presumably provided by sawtooths_extraction.ipynb via %run) is called first and is
    expected to create it.

    Parameters
    ----------
    nshot : int
        Shot number.
    t : array_like
        Times at which to evaluate the ST phase, in the same units as the stored crash
        times (presumably ms, matching the "20ms" cut-off — TODO confirm).
    relative_to_nearest : bool, default False
        If True, delay and phase are measured from the nearest crash instead of the
        preceding one; times closer to the following crash get negative values.
    max_period : float, default 20
        Sawtooth periods longer than this get NaN phase, delay and amplitude.

    Returns
    -------
    ST_phases, t_delays, ST_amplitudes : np.ndarray
        Float arrays with the same shape as `t`. The amplitude is that of the crash
        preceding each time. Times not strictly between the first and last crash are
        NaN. If the shot has no crashes, all three arrays are all-NaN.
    """
    # Load positions of sawtooth crashes, running the detector first if not cached.
    save_dir = './sawtooth_data'
    sawtooth_data_folder = Path(save_dir)
    sawtooth_data_folder.mkdir(exist_ok=True)
    filepath = sawtooth_data_folder / f"{nshot}_st.bin"

    if not filepath.exists():
        ST_detector(nshot, save_path=save_dir)

    # NOTE: pickle.load can execute arbitrary code; only load cache files produced locally.
    with open(filepath, 'rb') as fp:
        ST_data = pickle.load(fp)
    # Float arrays: allows fancy indexing even if stored as lists, and lets NaN be
    # written into amplitudes that were stored as integers.
    ST_times = np.asarray(ST_data.times, dtype=float)
    ST_amplitudes = np.asarray(ST_data.amplitudes, dtype=float)
    t = np.asarray(t, dtype=float)

    # Allocate output arrays (NaN = "no value available")
    ST_phases = np.full(fill_value=np.nan, shape=t.shape)
    ret_t_delays = np.full(fill_value=np.nan, shape=t.shape)
    ret_ST_amplitudes = np.full(fill_value=np.nan, shape=t.shape)

    if len(ST_times) == 0:
        print(f"No STs in shot {nshot}. Cannot compute ST phases.")
        # Return the same 3-tuple as the normal path so callers can always unpack it.
        return ST_phases, ret_t_delays, ret_ST_amplitudes

    # Only times strictly between the first and last crash have a crash on both sides.
    mask = (t > ST_times[0]) & (t < ST_times[-1])
    t_masked = t[mask]

    # Index of the first crash at or after each time (assumes ST_times sorted ascending).
    ST_ind_following_t = np.searchsorted(ST_times, t_masked)
    ST_time_following_t = ST_times[ST_ind_following_t]
    ST_time_preceding_t = ST_times[ST_ind_following_t - 1]

    # Amplitude of the crash preceding each time (fancy indexing returns a copy).
    ST_amplitude_preceding_t = ST_amplitudes[ST_ind_following_t - 1]

    # ST phase of each time
    ST_duration = ST_time_following_t - ST_time_preceding_t
    t_delay = t_masked - ST_time_preceding_t
    if relative_to_nearest:
        t_early = t_masked - ST_time_following_t  # <= 0: time until the next crash
        selection_mask = t_delay > np.abs(t_early)
        t_delay[selection_mask] = t_early[selection_mask]

    phases = t_delay / ST_duration

    # Discard values for sawtooth periods longer than max_period
    max_duration_mask = ST_duration > max_period
    phases[max_duration_mask] = np.nan
    ST_amplitude_preceding_t[max_duration_mask] = np.nan
    t_delay[max_duration_mask] = np.nan

    ST_phases[mask] = phases
    ret_ST_amplitudes[mask] = ST_amplitude_preceding_t
    ret_t_delays[mask] = t_delay

    # Sanity checks
    assert ST_phases.shape == t.shape
    assert ret_t_delays.shape == t.shape
    assert ret_ST_amplitudes.shape == t.shape
    if not relative_to_nearest:
        assert np.all(ST_phases[~np.isnan(ST_phases)] >= 0)
        assert np.all(ret_t_delays[~np.isnan(ret_t_delays)] >= 0)
    assert np.all(ret_ST_amplitudes[~np.isnan(ret_ST_amplitudes)] >= 0)

    return ST_phases, ret_t_delays, ret_ST_amplitudes

In [4]:
def load_ST_crash_time(nshot):
    """
    Return the sawtooth crash times stored for shot `nshot`.

    The times are read from the cache file ``./sawtooth_data/{nshot}_st.bin``. If that
    file does not exist, `ST_detector` (presumably provided by sawtooths_extraction.ipynb
    via %run) is called first and is expected to create it.

    Returns
    -------
    The `times` attribute of the unpickled sawtooth data object, unchanged.
    """
    save_dir = './sawtooth_data'
    sawtooth_data_folder = Path(save_dir)
    sawtooth_data_folder.mkdir(exist_ok=True)
    filepath = sawtooth_data_folder / f"{nshot}_st.bin"

    if not filepath.exists():
        ST_detector(nshot, save_path=save_dir)

    # NOTE: pickle.load can execute arbitrary code; only load cache files produced locally.
    with open(filepath, 'rb') as fp:
        ST_data = pickle.load(fp)

    return ST_data.times