In [1]:
################ import package

import os
import time
import datetime as dt
import h5py
from multiprocessing import cpu_count, Pool
from ipdb import set_trace as st
from tqdm import tqdm

import numpy as np
import numpy.linalg as LA
import linecache as lc
import pandas as pd

from scipy import io as sio
from sklearn.metrics import normalized_mutual_info_score
from scipy.signal import find_peaks

import matplotlib.pyplot as plt

from funs import SoHO_read, Read_solar_image, \
    dt2date, Prob_train


import sklearn


In [6]:
###################### Config ######################
# All tunable settings of the notebook live in this cell.

time_stamp = 'Data/time_data_19873_train_valid_test.h5'
channels = ['MDI', 'EIT', 'LASCO']  # alternatives: 'LASCO_diff', 'MDI_diff'

dst_peak = -100       # storm threshold: samples whose Dst target is < dst_peak get label 1
delay = 24            # lead time in hours between a sample date and its target window
time_res = 2          # time step in hours
Peak_width = 5        # `width` passed to scipy.signal.find_peaks
Peak_dis = 120        # `distance` passed to scipy.signal.find_peaks
F107_thres = [0, 500]
SDO_flag = 0          # 1: also load and export the HMI data set in hmi_file
storm_idx = [15]
var_idx = [0, 1, 2]

delay_range = 12*time_res   # target window length in hours (12 steps of time_res)
delay_hour_clu = range(delay, delay+delay_range, time_res)
SoHO_file = 'Res/Solar_data_19873.h5'
Omni_data = 'Data/1999-2012.pkl'
hmi_file = 'Data/hmi_halloween2021.h5'

# Output file for the selected storms.
Res_name = f'Res/Dst_{delay}-{delay+delay_range}-{time_res}--{dst_peak}.h5'


def _params_path(prefix):
    """Return 'Res/<prefix>_<var_idx>_<start>-<end>--<time_res>-<dst_peak>-<storm>.pt'.

    Both parameter paths share this layout; the window end is delay+delay_range
    (previously spelled delay+12*time_res, which is the same value).
    """
    return (f'Res/{prefix}_{np.array(var_idx)!s}_'
            f'{delay}-{delay+delay_range}--{time_res}-{dst_peak}-{storm_idx[0]}.pt')


callname = _params_path('params')
callname_opt = _params_path('params_opt')


In [3]:
################### global variables #################
# OMNI Dst series used as the prediction target.
# NOTE: pd.read_pickle can execute arbitrary code; only load a trusted local file.
df = pd.read_pickle(Omni_data)

n_missing = df['DST'].isna().sum()
print(f'Missing value count {n_missing}/{len(df)}')

# interpolate() fills interior and trailing gaps but leaves leading NaNs, which
# dropna() then removes. omni_date is taken from the cleaned series so that a
# positional index into omni_data and omni_date always refers to the same row.
omni_data = df['DST'].interpolate().dropna()
omni_date = omni_data.index

Missing value count     0/113963


In [4]:
###################### SoHO data (X, run it once) ######################
# Collects the train/valid/test timestamps and the SoHO inputs from SoHO_read and
# writes them to SoHO_file as datasets 'X' and 'date'.
# Mode 'w' truncates SoHO_file, which also discards the 'Y' dataset written by the
# Dst cell below, so re-running this cell requires re-running that one as well.

with h5py.File(time_stamp, 'r') as f:
    print(f.keys())
    train_date = np.array(f['train_15645_dates'])
    valid_date = np.array(f['valid_2301_dates'])
    test_date = np.array(f['test_1927_dates'])

# NOTE(review): dates are stacked as valid, train, test; rows of 'X' and 'date' are
# paired by position, so this presumably matches the order SoHO_read uses — confirm.
all_date = np.vstack([valid_date, train_date, test_date])
X_all, Y_all = SoHO_read(channels, win_size=1)   # Y_all is not stored; 'Y' is built from Dst below

with h5py.File(SoHO_file, 'w') as f:
    f.create_dataset('X', data=X_all)
    f.create_dataset('date', data=all_date)

<KeysViewHDF5 ['test_1927_dates', 'train_15645_dates', 'valid_2301_dates']>


In [5]:
###################### Dst data (Y, run it once) ######################
# For every sample date stored in SoHO_file, compute delay_range//time_res target
# values from the OMNI Dst series via Read_solar_image and store them as 'Y'.

with h5py.File(SoHO_file, 'r+') as f:
    # create_dataset refuses to overwrite an existing name, so drop a stale 'Y'.
    if 'Y' in f:
        del f['Y']

    all_date = np.array(f['date'])
    out = np.zeros([len(all_date), delay_range // time_res])

    for idx in tqdm(range(len(all_date))):
        out[idx] = Read_solar_image(idx, omni_data, omni_date, all_date,
                                    time_res, delay_range, delay)

    f.create_dataset('Y', data=out)

100%|██████████| 19873/19873 [02:53<00:00, 114.34it/s]


In [7]:
##########dst storm time selection(run it once) ###################
# Accumulators filled by the storm-selection loop further down this cell.
dates = []
sample_storm = []
n = 1  # 1-based counter of accepted storms

# Full sample set written by the SoHO and Dst cells: inputs X, targets Y, timestamps.
with h5py.File(SoHO_file, 'r') as f:
    X = np.array(f['X'])
    Y = np.array(f['Y'])
    all_date = np.array(f['date'])

# Optional extra data set from hmi_file; only loaded when SDO_flag is set.
if SDO_flag:
    with h5py.File(hmi_file, 'r') as f:
        X_ex = np.array(f['data'])
        Y_reg_ex = np.array(f['y'])
        # Binary label: 1.0 where the regression target is below dst_peak, else 0.0.
        Y_ex = np.zeros(Y_reg_ex.shape)
        Y_ex[Y_reg_ex < dst_peak] = 1
        date_ex = np.array(f['date'])

# Convert every stored sample timestamp with dt2date and append it to `dates`.
dates.extend(dt2date(stamp, time_res) for stamp in tqdm(all_date))

# Locate Dst minima (peaks of -Dst) and, for each sufficiently deep one, the
# window from the last non-negative Dst before it to the first non-negative Dst
# after it, widened by delay_range on both sides.
# Work on a plain array: positional Series.__getitem__ on a DatetimeIndex is
# deprecated in pandas 2.
dst_values = np.asarray(omni_data)
peaks, _ = find_peaks(-dst_values,
                      distance=Peak_dis,
                      width=Peak_width)
# Keep peaks down to dst_peak+20, i.e. slightly weaker than the storm threshold.
idx = np.where(dst_values[peaks] <= dst_peak + 20)[0]

idx_clu = np.zeros([len(idx), 2], dtype=int)
last_pos = len(dst_values) - 1

for i, idx_t in enumerate(tqdm(idx)):
    peak_pos = peaks[idx_t]
    print('peak {}:'.format(i), omni_date[peak_pos])

    before = np.flatnonzero(dst_values[:peak_pos] >= 0)
    after = np.flatnonzero(dst_values[peak_pos:] >= 0)
    # If Dst never returns to >= 0 on one side, fall back to the record edge
    # instead of raising IndexError.
    start = (before[-1] if before.size else 0) - delay_range
    end = (after[0] + peak_pos if after.size else last_pos) + delay_range
    # Clip so a negative start cannot wrap to the end of omni_date and the end
    # cannot run past the series.
    idx_clu[i] = [max(start, 0), min(end, last_pos)]

# For each storm window, collect the samples whose target time (date + delay)
# falls inside it and keep the storm if it is well covered.
# The shifted dates do not depend on the storm, so compute them once.
shifted_dates = [d + dt.timedelta(hours=delay) for d in dates]
min_samples = 30     # a storm needs more than this many samples
max_long_gaps = 15   # ... and at most this many gaps longer than 3 time steps
max_storms = 50      # stop once this many storms are accepted (n counts from 1)

for i, window in enumerate(tqdm(idx_clu)):
    date_beg = omni_date[int(window[0])]
    date_end = omni_date[int(window[1])]

    index_image = [j for j, target in enumerate(shifted_dates)
                   if date_beg <= target <= date_end]

    if len(index_image) > min_samples:
        # all_date stacks valid/train/test dates, so index order need not be
        # chronological; sort before measuring gaps. total_seconds() is used
        # because timedelta.seconds drops whole days (26 h would read as 2 h).
        storm_dates = sorted(dates[j] for j in index_image)
        gap = np.array([(later - earlier).total_seconds() // 3600
                        for earlier, later in zip(storm_dates[:-1], storm_dates[1:])])

        if (gap > time_res * 3).sum() <= max_long_gaps:
            sample_storm.append(index_image)
            print('size of {}th storm should be {}/{}'
                  .format(n, (window[1] - window[0]) // time_res,
                          len(index_image)))
            print('start/end time {}/{}'.format(date_beg, date_end))

            n += 1
            if n > max_storms:
                break

# Write one set of datasets per accepted storm (X_i, Y_reg_i, Y_i, time_i), the
# optional HMI data set, and the number of storms.
with h5py.File(Res_name, 'w') as f:
    # Mode 'w' truncates the file, so there are no stale datasets to delete.
    for i, rows in enumerate(tqdm(sample_storm)):
        f.create_dataset('X_'+str(i), data=X[rows])
        f.create_dataset('Y_reg_'+str(i), data=Y[rows])
        f.create_dataset('Y_'+str(i), data=Y[rows] < dst_peak)
        f.create_dataset('time_'+str(i), data=all_date[rows])

    # X_ex, date_ex, Y_ex and Y_reg_ex are only loaded when SDO_flag is set;
    # writing them unconditionally raised NameError on a fresh kernel with SDO_flag = 0.
    if SDO_flag:
        f.create_dataset('X_ex', data=X_ex)
        f.create_dataset('time_ex', data=date_ex)
        f.create_dataset('Y_ex', data=Y_ex)
        f.create_dataset('Y_reg_ex', data=Y_reg_ex)
    f.create_dataset('storm_num', data=len(sample_storm))
