# Dataset Preparation

## Modules and Utility Functions

In [None]:
import pandas as pd
import numpy as np
import math
import seaborn as sns

%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
matplotlib.rcParams['figure.dpi'] = 75

from snad import OSCCurve
from snad import SNFiles
from utils import binarization, interpolation, rolling_diff, rolling_ratio, log10_

from tqdm import tqdm_notebook as tqdm
from tqdm import trange

from scipy import stats

# Notebook-wide pandas display settings: floats with 3 decimals, wide tables unclipped.
pd.set_option('display.float_format', "{:.3f}".format)
pd.options.display.max_columns = 500

In [None]:
def _plot_curve(row):
    """Draw one r-band curve, scaled to its peak and shifted so the peak sits at x = 0.

    Uses the current pyplot axes and clips the x axis to [-200, 500].
    Returns the number of points in ``row.curve_r``.
    """
    curve = row.curve_r
    flux = curve.y / curve.y.max()
    phase = curve.x - curve.x[flux.argmax()]
    plt.plot(phase, flux, 'b', alpha=0.2)
    plt.xlim(-200, 500)
    return len(curve.x)

def _plot_spyder(row, ax1, ax2):
    x_curve = [i for i in range(16)]
    y_curve = row[['x_' + str(i) for i in range(16)]]
    x_spyder = [i for i in range(16, 32)]
    y_spyder = row[['x_' + str(i) for i in range(16, 32)]]
    ax1.plot(x_curve, y_curve, c='r', alpha=0.2)
    ax2.plot(x_spyder, y_spyder, c='r', alpha=0.2)
    return row.size_r
    
def _plot_filtered(selected, header, plot_spyders):
    """Plot every row of *selected*, then add grid and title and show the figure.

    *header* is the left part of the title (e.g. ``"rms >= 0.1000"``). The object
    count and the min/max number of points are computed from *selected* itself,
    so the title always describes exactly what was drawn.
    """
    if plot_spyders:
        _, axes = plt.subplots(1, 2, figsize=(12, 3))
    else:
        plt.figure(figsize=(8, 3))
        plt.ylabel("Поток, эрг / с / Гц / см^2", fontsize=10)

    if selected.empty:
        # Nothing to draw. Skip apply: on an empty frame it can return a
        # DataFrame rather than per-row lengths.
        lengths = np.array([], dtype=int)
    elif plot_spyders:
        lengths = np.asarray(selected.apply(_plot_spyder, args=[axes[0], axes[1]], axis=1))
    else:
        lengths = np.asarray(selected.apply(_plot_curve, axis=1))

    lo, hi = (lengths.min(), lengths.max()) if lengths.size else ('-', '-')
    # grid(True) instead of grid(b=1): the ``b`` keyword was deprecated in
    # Matplotlib 3.5 (renamed ``visible``) and later removed.
    plt.grid(True)
    plt.title(header + ";    Кол-во объектов: {0};    Точек: от {1} до {2}".format(len(selected), lo, hi),
              loc='center', fontsize=12)
    plt.tight_layout()
    plt.show()


def plot_curves_by(df, metric, min_metric, min_size_r, plot_spyders=False, max_size_r=400):
    """Plot objects with ``df[metric] >= min_metric`` and ``min_size_r <= size_r <= max_size_r``.

    Parameters
    ----------
    df : DataFrame with *metric* and ``size_r`` columns, plus ``curve_r`` (curve
        mode) or ``x_0``..``x_31`` (feature mode).
    metric : name of the column to threshold.
    min_metric : inclusive lower bound for *metric*.
    min_size_r, max_size_r : inclusive bounds on the number of r-band points.
    plot_spyders : if True, plot the processed features on two axes instead of
        the normalised light curves.
    """
    mask = ((df[metric] >= min_metric)
            & (df.size_r >= min_size_r)
            & (df.size_r <= max_size_r))
    _plot_filtered(df[mask], "{0} >= {1:.4f}".format(metric, min_metric), plot_spyders)


def plot_curves_by_lower(df, metric, min_metric, min_size_r, plot_spyders=False, max_size_r=400):
    """Plot objects with ``df[metric] < min_metric`` and ``min_size_r <= size_r <= max_size_r``.

    Unlike :func:`plot_curves_by`, this also requires ``rms > 0``. That drops
    rows where ``get_sn_data`` stored its -1 "RMS unavailable" sentinel.
    Parameters are the same as for :func:`plot_curves_by`.
    """
    mask = ((df[metric] < min_metric)
            & (df.rms > 0)
            & (df.size_r >= min_size_r)
            & (df.size_r <= max_size_r))
    _plot_filtered(df[mask], "{0} < {1:.4f}".format(metric, min_metric), plot_spyders)

In [None]:
def get_bin_curves(row, n_bins=16, x_range=(-50, 100)):
    """Turn one supernova row into a flat feature list.

    The r-band curve ``row.curve_r`` is processed in four steps:

    1. Scale it to its peak flux.
    2. Shift it so the peak is at x = 0.
    3. Bin it with ``binarization`` into *n_bins* values over *x_range*.
    4. Pass the result through ``interpolation`` (presumably to fill empty bins;
       see utils) and append its ``rolling_ratio``.

    Parameters
    ----------
    row : a DataFrame row (e.g. from ``df.apply(..., axis=1)``). It needs
        ``is_ia``, ``sn_name``, ``rms``, ``size_r``, ``type`` and ``curve_r``.
        NOTE(review): ``get_sn_data`` does not produce ``is_ia``; it is
        presumably added before this is called.
    n_bins : number of bins passed to ``binarization`` as ``n``.
    x_range : (start, end) passed to ``binarization`` as ``x_ranges``.

    Returns
    -------
    list
        ``[is_ia, sn_name, rms, size_r, type, *binned, *ratios]``.
    """
    curve = row.curve_r
    flux = curve.y / curve.y.max()
    phase = curve.x - curve.x[flux.argmax()]

    binned = binarization(phase, flux, n=n_bins, x_ranges=list(x_range))
    binned = interpolation(np.arange(len(binned)), binned)
    ratios = rolling_ratio(binned)

    meta = [row.is_ia, row.sn_name, row.rms, row.size_r, row.type]
    return meta + list(binned) + list(ratios)

def preprocess_sn_data(sn_data):
    """Build the feature table from raw supernova rows.

    Each row is converted with ``get_bin_curves``. The resulting lists become
    the columns ``is_ia, sn_name, rms, size_r, type, x_0 ... x_{k-1}``. Two
    columns are then appended:

    * ``weight``: 1.0 for ``is_ia == 0`` rows, ``n_neg / n_pos`` for
      ``is_ia == 1`` rows (class balancing), and 0.0 for any other label.
    * ``spec_class``: copied positionally from *sn_data*.

    An empty *sn_data* gives an empty frame with the metadata, ``weight`` and
    ``spec_class`` columns. The feature count is unknown without a row.
    """
    meta_cols = ['is_ia', 'sn_name', 'rms', 'size_r', 'type']
    if sn_data.empty:
        return pd.DataFrame(columns=meta_cols + ['weight', 'spec_class'])

    rows = list(sn_data.apply(get_bin_curves, axis=1))
    n_features = len(rows[0]) - len(meta_cols)
    cols = meta_cols + ['x_' + str(i) for i in range(n_features)]
    processed = pd.DataFrame(data=rows, columns=cols)

    labels = processed['is_ia'].values
    is_neg = labels == 0
    is_pos = labels == 1
    n_neg, n_pos = is_neg.sum(), is_pos.sum()
    # Guard n_pos == 0. Dividing by zero there would turn every weight into
    # NaN, class-0 rows included (1 + 0/0 = NaN).
    pos_weight = n_neg / n_pos if n_pos else 0.0
    processed['weight'] = np.where(is_neg, 1.0, np.where(is_pos, pos_weight, 0.0))
    processed['spec_class'] = sn_data.spec_class.values
    return processed

In [None]:
def plot_log_hist(ax, df_col, xlim, bins, title):
    """Draw a histogram of *df_col* on *ax* with a logarithmic x axis.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df_col : values to histogram (anything ``Axes.hist`` accepts).
    xlim : (left, right) limits of the x axis.
    bins : passed straight to ``Axes.hist``. On a log axis, log-spaced edges
        (e.g. ``np.logspace``) typically give evenly sized bars.
    title : axes title.
    """
    ax.hist(df_col, bins=bins)
    ax.set_xlabel('Value')
    ax.set_ylabel('Count')
    ax.set_xscale('log')
    ax.set_xlim(xlim)
    # ``grid(b=...)`` was deprecated in Matplotlib 3.5 (renamed ``visible``) and
    # later removed. The positional flag works on every version.
    ax.grid(True)
    ax.title.set_text(title)
    
def plot_hist(ax, df_col, bins, xlim, title):
    """Draw a histogram of *df_col* on *ax* with a linear x axis.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df_col : values to histogram (anything ``Axes.hist`` accepts).
    bins : passed straight to ``Axes.hist`` (bin count or edges).
    xlim : (left, right) limits of the x axis.
    title : axes title.
    """
    ax.hist(df_col, bins=bins)
    ax.set_xlabel('Value')
    ax.set_ylabel('Count')
    ax.set_xlim(xlim)
    # ``grid(b=...)`` was deprecated in Matplotlib 3.5 (renamed ``visible``) and
    # later removed. The positional flag works on every version.
    ax.grid(True)
    ax.title.set_text(title)
    
def plot_2d_hist(fig, ax, col1, col2, x_bins=None, y_bins=None, cmap=None):
    """Plot a 2-D histogram of RMS (*col1*) against number of points (*col2*).

    Parameters
    ----------
    fig : Figure that receives the colorbar.
    ax : Axes to draw on. It uses a log x axis.
    col1, col2 : pandas Series (``.values`` is read): RMS and point counts.
    x_bins : RMS bin edges. Default ``np.logspace(-2, 3, 20)`` (0.01 to 1000).
    y_bins : point-count bin edges. Default ``np.linspace(0, 30, 30)``.
    cmap : colormap for ``pcolormesh``. Default ``cm.Blues``.

    Values outside the bin edges are silently dropped by ``np.histogram2d``.
    """
    x = col1.values
    y = col2.values
    if x_bins is None:
        x_bins = np.logspace(-2, 3, 20)
    if y_bins is None:
        y_bins = np.linspace(0, 30, 30)
    if cmap is None:
        cmap = cm.Blues
    counts, xedges, yedges = np.histogram2d(x, y, bins=[x_bins, y_bins])
    # histogram2d indexes counts as [x_bin, y_bin]. pcolormesh wants rows along
    # y, hence the transpose.
    pc = ax.pcolormesh(xedges, yedges, counts.T, cmap=cmap, alpha=0.8)
    ax.set_xscale('log')
    ax.set_xlabel('RMS')
    ax.set_ylabel('Number of Points')
    ax.title.set_text('RMS & Number of Points Distribution')
    # Attach the colorbar to *ax* explicitly. Otherwise older Matplotlib takes
    # its space from the current axes, which may be a different subplot.
    fig.colorbar(pc, ax=ax)
    
def plot_cum_hist(ax, df_col):
    """Compare the empirical CDF of log(RMS) with a fitted normal CDF on *ax*.

    Draws the forward and reversed cumulative step histograms of ``ln(df_col)``
    (50 bins, density-normalised). It also draws, dashed, a normal distribution
    with the sample mean and population std of ``ln(df_col)``, approximated by
    cumulatively summing its density at the bin edges.

    Non-positive and non-finite entries are dropped first. ``get_sn_data``
    stores -1 when RMS cannot be computed, and the resulting NaN/-inf logarithms
    make ``hist`` fail when it auto-detects the bin range.

    Raises
    ------
    ValueError
        If no positive finite values remain.
    """
    values = np.asarray(df_col, dtype=float)
    values = values[np.isfinite(values) & (values > 0)]
    if values.size == 0:
        raise ValueError("plot_cum_hist: df_col has no positive finite values")
    x = np.log(values)
    mu = x.mean()
    sigma = x.std()  # ddof=0, same as np.std
    n_bins = 50
    _, bins, _ = ax.hist(x, n_bins, density=True, histtype='step',
                         cumulative=True, label='Empirical')
    if sigma > 0:
        # Normal density at the bin edges, cumulatively summed and rescaled to
        # end at 1: a discrete approximation of the fitted normal CDF.
        y = ((1 / (np.sqrt(2 * np.pi) * sigma)) *
             np.exp(-0.5 * (1 / sigma * (bins - mu))**2))
        y = y.cumsum()
        y /= y[-1]
        ax.plot(bins, y, 'k--', linewidth=1.5, label='Theoretical')
    ax.hist(x, bins=bins, density=True, histtype='step', cumulative=-1,
            label='Reversed emp.')
    ax.grid(True)
    ax.legend(loc='best')
    ax.set_title('Cumulative step histograms')
    ax.set_xlabel('RMS (log)')
    ax.set_ylabel('Likelihood of occurrence')

## Data Import

In [None]:
# Lookup table from the first column of types.csv to its second column.
# NOTE(review): dict(rows) assumes the CSV has exactly two columns; confirm.
new_types = dict(pd.read_csv('../data/types.csv').values)

def type_mapping(claimedtype):
    """Map a raw ``claimedtype`` value to its entry in ``new_types``.

    The value is stringified (so NaN becomes ``'nan'``), and only the part
    before the first ``';'`` is looked up. Unknown types map to 0.
    """
    primary, _, _ = str(claimedtype).partition(';')
    return new_types.get(primary, 0)

def _fetch_r_curve(aname, baseurl, print_errors):
    """Fetch one object's filtered r-band curve; return ``(curve_r, rms, rms_for_pval)`` or None."""
    try:
        sn = OSCCurve.from_name(aname,
                                bands=["r"],
                                down_args={'baseurl': baseurl})
    except Exception as e:  # best effort: an object that cannot be loaded is skipped
        if print_errors:
            print("{}: failed to load curve: {}".format(aname, e))
        return None
    sn = sn.filtered(with_upper_limits=False, with_inf_e_flux=True)
    rms = -1  # sentinel kept when the RMS cannot be computed
    rms_for_pval = -1
    try:
        rms = sn.rms()['r']
        rms_for_pval = sn.rms(pval=True)['r']
    except Exception as e:  # best effort: keep the curve, leave the sentinel
        if print_errors:
            print("{}: failed to compute RMS: {}".format(aname, e))
    return sn['r'], rms, rms_for_pval


def get_sn_data(file_name, print_errors=False,
                baseurl='http://sai.snad.space/sne20200801/'):
    """Load r-band light curves for every supernova listed in *file_name*.

    Parameters
    ----------
    file_name : path to a CSV with at least ``Name``, ``claimedtype`` and
        ``spec_class`` columns.
    print_errors : if True, print why an object was skipped or why its RMS
        could not be computed.
    baseurl : mirror passed to ``OSCCurve.from_name`` via ``down_args``.

    Returns
    -------
    pandas.DataFrame
        One row per object whose ``OSCCurve.from_name`` call succeeded. Columns
        are ``sn_name``, ``type``, ``size_r``, ``curve_r``, ``spec_class``,
        ``rms`` and ``rms_for_pval``. The two RMS columns hold -1 when the
        value could not be computed.
    """
    data = pd.read_csv(file_name)
    columns = {name: [] for name in ('sn_name', 'type', 'size_r', 'curve_r',
                                      'spec_class', 'rms', 'rms_for_pval')}
    rows = zip(data['Name'].values, data['claimedtype'].values,
               data['spec_class'].values)
    with tqdm(total=data.shape[0]) as pbar:
        for aname, atype, aspec_class in rows:
            fetched = _fetch_r_curve(aname, baseurl, print_errors)
            # Advance for every input row, skipped ones included, so the bar
            # reaches its total.
            pbar.update(1)
            if fetched is None:
                continue
            curve_r, rms, rms_for_pval = fetched
            columns['sn_name'].append(aname)
            columns['type'].append(atype)
            columns['size_r'].append(len(curve_r))
            columns['curve_r'].append(curve_r)
            columns['spec_class'].append(aspec_class)
            columns['rms'].append(rms)
            columns['rms_for_pval'].append(rms_for_pval)
    return pd.DataFrame(columns)

In [None]:
%%time
# Build the raw r-band table. Each object is fetched through OSCCurve.from_name,
# so this cell can take a while (hence %%time).
data_r = get_sn_data(file_name='../data/min0obs_r.csv')

In [None]:
# (number of loaded objects, number of columns) of the table built by get_sn_data
data_r.shape

In [None]:
# Cache the table (curve objects included) so it can be reloaded with pd.read_pickle.
pd.to_pickle(data_r, '../data/new_article_data_r.pkl')