In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
# -*- coding: utf-8 -*-

# https://teddykoker.com/2020/12/dataloader/

import sys

sys.path.append("/home/rkube/repos/frnn-loader")

from os.path import join
from pathlib import Path

import yaml

import torch

from frnn_loader.backends.fetchers import fetcher_d3d_v1
from frnn_loader.backends.backend_hdf5 import backend_hdf5
from frnn_loader.primitives.filters import filter_ip_thresh
from frnn_loader.primitives.resamplers import resampler_causal
from frnn_loader.primitives.signal import signal_0d
from frnn_loader.primitives.normalizers import mean_std_normalizer
from frnn_loader.loaders.frnn_dataset_disk import shot_dataset_disk

In [3]:
"""Construct a dataset for FRNN training.

Predictive machine learning models are trained on datasets. These dataset
consist of a suite of measurements taken on a set of shots.

Deep neural networks are trained on pre-processed and normalized data.
Pre-processing includes:
- Resampling of the measurements onto a common time-base
- Construction of target variables, such as time-to-disruption or time-to-ELM
- Signal clipping

Normalization means the transformation of signals into order unity quantities. Common ways
to do this is by a Z-score transformation (subtract mean, divide by std dev.), min/max normalizer,
etc.

"""

'Construct a dataset for FRNN training.\n\nPredictive machine learning models are trained on datasets. These dataset\nconsist of a suite of measurements taken on a set of shots.\n\nDeep neural networks are trained on pre-processed and normalized data.\nPre-processing includes:\n- Resampling of the measurements onto a common time-base\n- Construction of target variables, such as time-to-disruption or time-to-ELM\n- Signal clipping\n\nNormalization means the transformation of signals into order unity quantities. Common ways\nto do this is by a Z-score transformation (subtract mean, divide by std dev.), min/max normalizer,\netc.\n\n'

In [5]:
# Directory where all project data files are to be stored
proj_dir = "/projects/FRNN/frnn_loader"

# 1/ Describe the dataset
predictor_tags = [
    "q95",
    "efsli",
    "ipspr15V",
    "efsbetan",
    "efswmhd",
    "dusbradial",
    "dssdenest",
    "pradcore",
    "pradedge",
    "bmspinj",
    "bmstinj",
    "ipsiptargt",
    "ipeecoil",
]
predictor_list = [signal_0d(tag) for tag in predictor_tags]

# Contains a list of shots that are non-disruptive
shotlist_clear = "d3d_clear_100.txt"
# Contains a list of shots that are disruptive
shotlist_disrupt = "d3d_disrupt_100.txt"

In [6]:
# Instantiate the filter we use to crimp the shot times
ip_filter = filter_ip_thresh(0.2)
signal_ip = signal_0d("ipspr15V")
my_backend = backend_hdf5(proj_dir)
my_fetcher = fetcher_d3d_v1()

In [7]:
num_shots = 5
shotdict = {}

i = 0
with open(join(proj_dir, "..", "shot_lists", shotlist_clear), "r") as fp:
    for line in fp.readlines():
        # Convert shotnr to int and ttd to float
        shotnr, ttd = [trf(val) for trf, val in zip([int, float], line.split())]

        # Run the Ip filter over the current shot
        tb, data = my_backend.load(signal_ip.info, shotnr)
        tmin, tmax = ip_filter(tb, data)
        shotdict.update(
            {
                shotnr: {
                    "tmin": tmin,
                    "tmax": tmax,
                    "is_disruptive": False,
                    "t_disrupt": -1.0,
                }
            }
        )

        i += 1
        if i >= num_shots:
            break

In [8]:
i = 0
with open(join(proj_dir, "..", "shot_lists", shotlist_disrupt), "r") as fp:
    for line in fp.readlines():
        # Convert shotnr to int and ttd to float
        shotnr, ttd = [trf(val) for trf, val in zip([int, float], line.split())]
        # ttd is given in seconds in the text files. Convert it to milliseconds
        ttd = ttd * 1e3
        shotdict.update(
            {
                shotnr: {
                    "tmin": tmin,
                    "tmax": ttd,
                    "is_disruptive": True,
                    "t_disrupt": ttd,
                }
            }
        )

        i += 1
        if i >= num_shots:
            break

print("shotdict = ", shotdict)

shotdict =  {167475: {'tmin': 37.75, 'tmax': 6233.25, 'is_disruptive': False, 't_disrupt': -1.0}, 167481: {'tmin': 30.5, 'tmax': 6455.25, 'is_disruptive': False, 't_disrupt': -1.0}, 167482: {'tmin': 34.0, 'tmax': 6445.25, 'is_disruptive': False, 't_disrupt': -1.0}, 167483: {'tmin': 37.0, 'tmax': 6842.25, 'is_disruptive': False, 't_disrupt': -1.0}, 167484: {'tmin': 37.5, 'tmax': 6245.0, 'is_disruptive': False, 't_disrupt': -1.0}, 167480: {'tmin': 37.5, 'tmax': 5073.5, 'is_disruptive': True, 't_disrupt': 5073.5}, 167487: {'tmin': 37.5, 'tmax': 2159.0, 'is_disruptive': True, 't_disrupt': 2159.0}, 167488: {'tmin': 37.5, 'tmax': 5269.0, 'is_disruptive': True, 't_disrupt': 5269.0}, 167492: {'tmin': 37.5, 'tmax': 7088.0, 'is_disruptive': True, 't_disrupt': 7088.0}, 167494: {'tmin': 37.5, 'tmax': 3478.5, 'is_disruptive': True, 't_disrupt': 3478.5}}


In [9]:
#########################################################################################################
#
# Next we create a list of datasets for all shots.
# The shots are cut to the time intervals defined by tmin and tmax
# No transformation has been defined,

dset_list = []
for shotnr in shotdict.keys():
    print(shotnr)

    # Resample all signals over the valid intervals
    my_resampler = resampler_causal(0.0, shotdict[shotnr]["tmax"], 1.0)

    ds = shot_dataset_disk(
        shotnr,
        predictors=predictor_list,
        resampler=my_resampler,
        backend_file=my_backend,
        fetcher=my_fetcher,
        root=proj_dir,
        download=True,
        dtype=torch.float32,
    )

    dset_list.append(ds)

167475
__init__: root = /projects/FRNN/frnn_loader
167481
__init__: root = /projects/FRNN/frnn_loader
167482
__init__: root = /projects/FRNN/frnn_loader
167483
__init__: root = /projects/FRNN/frnn_loader
167484
__init__: root = /projects/FRNN/frnn_loader
167480
__init__: root = /projects/FRNN/frnn_loader
167487
__init__: root = /projects/FRNN/frnn_loader
167488
__init__: root = /projects/FRNN/frnn_loader
167492
__init__: root = /projects/FRNN/frnn_loader
167494
__init__: root = /projects/FRNN/frnn_loader


In [10]:
#########################################################################################################
#
# With all datasets cropped to the correct time in place we continue by calculating the normalization.
# Do this using multi-processing
my_normalizer = mean_std_normalizer()
my_normalizer.fit(dset_list)

In [12]:
my_normalizer.mean

In [13]:
my_normalizer.mean_all

tensor([ 6.2881e+00,  1.0096e+00,  1.7226e+00,  1.6998e+00,  5.6891e+05,
         9.0436e-01,  3.2198e+00,  1.5518e-03,  1.3113e-03,  5.0626e+03,
         4.0367e+00,  1.7478e+00, -2.8650e-02])

In [14]:
my_normalizer.std_all

tensor([1.8086e+00, 3.4762e-01, 4.5560e-01, 1.0277e+00, 3.6443e+05, 1.0497e+00,
        1.3093e+00, 1.9447e-03, 1.0103e-03, 3.3218e+03, 2.7402e+00, 4.6416e-01,
        1.1235e-01])

In [15]:
# Now we can add the normalizer as a transform. The __getitem__ method of the dataset
# will apply the transform.

# Verify that the returned data is about zero mean and order unity std deviation
for ds in dset_list:
    ds.transform = my_normalizer

    print(ds[:].shape, ds[:].mean(axis=0), ds[:].std(axis=0))


torch.Size([6234, 13]) tensor([-0.0267, -0.0952,  0.0347,  0.1136,  0.1621, -0.0045, -0.1391,  0.1717,
         0.0775,  0.1645,  0.1270,  0.0469, -0.0595]) tensor([0.7978, 1.0978, 0.9519, 0.9931, 1.0204, 0.9509, 0.8209, 0.8511, 0.9136,
        1.0144, 0.9674, 0.8824, 0.7272])
torch.Size([6456, 13]) tensor([ 0.0234, -0.1454, -0.0376,  0.0219,  0.0647,  0.1862,  0.2429,  0.3585,
        -0.1759,  0.3616,  0.3748, -0.0572,  0.0820]) tensor([0.7518, 0.9331, 1.0060, 1.0907, 1.0695, 0.9223, 1.1449, 1.1887, 0.7864,
        1.3123, 1.3134, 1.0297, 0.3068])
torch.Size([6446, 13]) tensor([-0.0311, -0.0795, -0.0192,  0.0953,  0.1292,  0.1281,  0.0173,  0.2036,
        -0.0305,  0.3278,  0.3510, -0.0523,  0.1358]) tensor([0.7493, 1.1101, 0.9920, 1.1733, 1.1274, 0.9080, 0.9154, 1.1008, 0.8759,
        1.2848, 1.2854, 1.0226, 0.2710])
torch.Size([6843, 13]) tensor([ 0.0513, -0.0568, -0.1926,  0.1774,  0.2087,  0.2226, -0.1236, -0.2423,
         0.3224, -0.1173, -0.1194, -0.2677,  0.3227]) tensor([0

In [11]:
my_conf = parse_config("conf_test_preprocessor.yaml")

Selected signals (determines which signals are used for training):
None


# Paths

The basis path where all data is searched for is given by conf["paths"]["fs_path"]. 

Directories that are searched for signals are given by conf["paths"]["signal_prepath"]





In [16]:
my_conf

{'callbacks': {'list': ['earlystop'],
  'metrics': ['val_loss', 'val_roc', 'train_loss'],
  'mode': 'min',
  'monitor': 'val_loss',
  'patience': 50,
  'write_grads': False},
 'data': {'T_max': 1000.0,
  'T_min_warn': 30,
  'augment_during_training': False,
  'augmentation_mode': 'none',
  'bleed_in': 0,
  'bleed_in_remove_from_test': True,
  'current_end_thresh': 10000,
  'current_index': 0,
  'current_thresh': 750000,
  'cut_shot_ends': False,
  'dt': 0.001,
  'equalize_classes': False,
  'floatx': 'float32',
  'normalizer': 'var',
  'plotting': False,
  'positive_example_penalty': 16.0,
  'recompute': False,
  'recompute_normalization': False,
  'signal_to_augment': 'None',
  'use_shots': 200000,
  'window_decay': 2,
  'window_size': 10,
  'target': plasma.models.targets.FLATTarget},
 'env': {'name': 'torch-env', 'type': 'anaconda3'},
 'model': {'PCS': True,
  'backend': 'tensorflow',
  'cell_order': 4,
  'cell_rank': 11,
  'cell_steps': 5,
  'clipnorm': 10.0,
  'dense_layers_1d': 1

In [17]:
my_conf["paths"]

{'data': 'd3d_data_ped_spec',
 'executable': 'torch_learn.py',
 'shallow_executable': 'learn.py',
 'base_path': '/tigress',
 'shot_list_dir': '/tigress/FRNN/shot_lists/',
 'signal_prepath': ['/tigress/FRNN/signal_data_ipsip/',
  '/tigress/FRNN/signal_data_new_nov2019/',
  '/tigress/FRNN/signal_data_new_2020/',
  '/tigress/FRNN/signal_data_new_2021/',
  '/tigress/FRNN/signal_data_new_REAL_TIME/',
  '/tigress/FRNN/signal_data/',
  '/tigress/FRNN/signal_data_efit/'],
 'global_normalizer_path': '/normalization/normalization_signal_group_184694441437167251751036554375425130447.npz',
 'normalizer_path': './normalization/normalization_signal_group_184694441437167251751036554375425130447.npz',
 'model_save_path': 'model_checkpoints',
 'csvlog_save_path': 'csv_logs',
 'results_prepath': 'results',
 'saved_shotlist_path': '/tigress/processed_shotlists_torch',
 'processed_prepath': '/tigress/../FRNN/rkube-temp/processed_shots_torch/signal_group_{h}',
 'all_signals_dict': {'qmin': Minimum safety f