In [6]:
import pandas as pd
import numpy as np
import h5py
import re
from pathlib import Path

In [7]:
base_dir = Path(Path.cwd()).parent / 'bes-edgeml-models/turbulence_regime_classification'
label_df = pd.read_excel(base_dir/'confinement_database.xlsx').fillna(0)
label_df

Unnamed: 0,shot,tstart (ms),tstop (ms),L-mode,H-mode,QH-mode,WP QH-mode,Notes
0,149992.0,2540.0,2635.0,1.0,0.0,0.0,0.0,0
1,149992.0,2638.0,3200.0,0.0,1.0,0.0,0.0,"ELM-free, then ELMy"
2,149992.0,4038.0,4125.0,1.0,0.0,0.0,0.0,0
3,149992.0,4136.0,4500.0,0.0,1.0,0.0,0.0,"ELM-free, then ELMy"
4,149993.0,1100.0,1900.0,1.0,0.0,0.0,0.0,long L-mode due to failed LH transition
5,149993.0,2540.0,2635.0,1.0,0.0,0.0,0.0,0
6,149993.0,2650.0,3400.0,0.0,1.0,0.0,0.0,"ELM-free, then ELMy"
7,149993.0,4050.0,4165.0,1.0,0.0,0.0,0.0,0
8,149993.0,4172.0,4975.0,0.0,1.0,0.0,0.0,"ELM-free, then ELMy"
9,149994.0,1340.0,1900.0,1.0,0.0,0.0,0.0,long L-mode due to failed LH transition


In [8]:
def make_labels(data, df):
    time = np.array(data['time'])
    labels = np.zeros_like(time)
    for i, row in df.iterrows():
        tstart = row['tstart (ms)']
        tstop = row['tstop (ms)']
        label = row[[col for col in row.index if 'mode' in col]].values.argmax() + 1
        labels[np.nonzero((time > tstart) & (time < tstop))] = label

    return labels.tolist()

In [11]:
# get labeled files if they exists. Probably a better way to do this.
files = [[],[],[]]
for file in (base_dir/'data').iterdir():
    try:
        shot_num, labeled = re.findall(r'_(\d+)_?(labeled)?.hdf5', str(file))[0]
    except IndexError:
        continue
    if shot_num in files[0]:
        i = files[0].index(shot_num)
        if files[2][i] != 'labeled':
            del files[0][i], files[1][i], files[2][i]
        else:
            continue
    files[0].append(shot_num)
    files[1].append(file)
    files[2].append(labeled)


files = [(sn, f) for sn, f, l in zip(*files)]

overwrite = False
signals, labels, time = [], [], []
for shot_num, file in files:
    shot_df = label_df.loc[label_df['shot'] == float(shot_num)]
    if len(shot_df) == 0:
        print(f'{shot_num} not in confinement database.')
        continue
    else:
        print(f'Processing shot {shot_num}')

    with h5py.File(file, 'a') as shot_data:
        signals = np.array(shot_data['signals']).transpose().tolist()
        time = np.array(shot_data['time']).tolist()
        try:
            labels = np.array(shot_data['labels']).tolist()
        except KeyError:
            labels = make_labels(shot_data, shot_df)

            if overwrite:
                shot_data.create_dataset('labels', data=labels)
            else:
                with h5py.File(f'{file.parent / (file.stem + "_labeled" + file.suffix)}', 'w') as sd:
                    for group in shot_data.keys():
                        shot_data.copy(group, sd)
                    sd.create_dataset('labels', data=labels)

    signals.extend(signals)
    labels.extend(labels)
    time.extend(time)


signals = np.array(signals)
labels = np.array(labels)
time = np.array(time)

print(f'Signals: {signals.shape}')
print(f'Labels: {labels.shape}')
print(f'Time: {time.shape}')

Processing shot 149995
Processing shot 149992
Processing shot 184815
Processing shot 175490
Processing shot 184829
Processing shot 149994
Processing shot 184826
Processing shot 184813
Signals: (12582912, 64)
Labels: (12582912,)
Time: (12582912,)


In [15]:
shot_data = h5py.File(files[-1][-1], 'r')
len(shot_data['labels'])

6291456