# Augment dataset

In this notebook we show how to augment a dataset. This is typically done to increase the size and/or representativeness of training sets.

#### Index<a name="index"></a>
1. [Import Packages](#imports)
2. [Load Dataset](#loadData)
    1. [GP Path](#oriGpPath)
3. [Augment Dataset](#augData)
    1. [Choose the Events to Augment](#chooseEvent)
    2. [Choose the Photometric Redshift](#choosePhotoZ)
    3. [Run Augmentation](#aug)
    4. [See Augmented Dataset Properties](#statsAug)
4. [Save Augmented Dataset](#saveAug)
5. [Light curve visualization](#see)

## 1. Import Packages<a name="imports"></a>

In [None]:
!pip install ../snmachine/

In [None]:
import collections
import os
import pickle
import sys
import time

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
from snmachine import snaugment, sndata
from utils.plasticc_pipeline import get_directories, load_dataset

In [None]:
%config Completer.use_jedi = False  # enable autocomplete

#### Aesthetic settings

In [None]:
%matplotlib inline

sns.set(font_scale=1.3, style="ticks")

## 2. Load Dataset<a name="loadData"></a>

First, **write** the path to the folder that contains the dataset we want to augment, `folder_path`.

In [None]:
# os_name = 'baseline_v2_0_paper'
os_name = 'noroll_v2_0_paper'
# os_name = 'presto_v2_0_paper'

folder_path = f'/path/to/save/'

Then, **write** in `data_file_name` the name of the file where your dataset is saved.

In this notebook we use the dataset saved in [2_preprocess_data]().

In [None]:
is_only_roll = 1
is_updated = 1

In [None]:
extra_name_to_save = 'ddf_wfd'

file_id = '000'
#file_id = '002' # until 009

data_file_name = f'train_{extra_name_to_save}_{file_id}_gapless50.pckl'
if is_only_roll:
    data_file_name = f'train_{extra_name_to_save}_{file_id}_roll_gapless50.pckl'
if is_updated:
    data_file_name = data_file_name[:-5] + '_updated.pckl'
data_file_name
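
The file name above is assembled by slicing off the `.pckl` extension with `[:-5]`. As a minimal sketch of the same naming scheme, the hypothetical helper `build_data_file_name` (not part of `snmachine`) builds the name from its parts, and `os.path.splitext` recovers the name without the extension:

```python
import os


def build_data_file_name(extra_name, file_id, is_only_roll, is_updated):
    """Assemble the training-set file name from the notebook's naming flags."""
    parts = ['train', extra_name, file_id]
    if is_only_roll:
        parts.append('roll')
    parts.append('gapless50')
    stem = '_'.join(parts)
    if is_updated:
        stem += '_updated'
    return stem + '.pckl'


name = build_data_file_name('ddf_wfd', '000', is_only_roll=True, is_updated=True)
print(name)
# The analysis name used later is the file name without its extension.
print(os.path.splitext(name)[0])
```

Building the name from parts avoids depending on the extension having exactly five characters.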

Load the dataset.

In [None]:
data_path = os.path.join(folder_path, data_file_name)
dataset = load_dataset(data_path)

In [None]:
train_metadata = dataset.metadata

### 2.1. GP Path<a name="oriGpPath"></a>

The GP augmentation uses the previously saved GPs, so **write** the path where they were saved. For help in fitting GPs to the dataset, follow [3_model_lightcurves](3_model_lightcurves.ipynb).

**<font color=Orange>A)</font>** Obtain GP path from folder structure.

If you created a folder structure, you can obtain the path from there. **Write** the name of the folder in `analysis_name`. 

In [None]:
analysis_name = data_file_name[:-5]

In [None]:
folder_path

Obtain the required GP path.

In [None]:
folder_analysis_path = folder_path[:-14] + 'analyses'
directories = get_directories(folder_analysis_path, analysis_name) 
path_saved_gps = directories['intermediate_files_directory']

**<font color=Orange>B)</font>** Directly **write** where you saved the GP files.

```python
path_saved_gps = os.path.join(folder_path, data_file_name[:-5])
```

## 3. Augment Dataset<a name="augData"></a>

Here we augment the data and make sure all the properties have the expected values.

In the following sections we decide the following augmentation inputs: 
1. `objs_number_to_aug` : a dictionary specifying which events to augment and by how much.
2. `choose_z` : function used to choose the new true redshift of the augmented events.
3. `z_table` : dataset containing the spectroscopic and photometric redshift, and photometric redshift error of events; it is used to generate realistic augmented photometric redshifts.
4. `max_duration` : maximum duration of the augmented light curves.
5. `random_seed` : random seed used; saving this seed allows reproducible results.

### 3.1. Choose the Events to Augment<a name="chooseEvent"></a>

**Write** in `aug_obj_names` a list containing all the events to augment. Here we will try to augment them all.

In [None]:
aug_obj_names = dataset.object_names  # try to augment all events

**Create** a dictionary that maps each event to the target number of synthetic events to create from it. Note that some augmentations will fail, so this is not the final number of events. Additionally, each class has a different creation efficiency.

In [None]:
np.random.seed(42)
is_to_aug = np.isin(dataset.object_names, aug_obj_names)  # `np.in1d` is deprecated in recent NumPy

# Choose the target number of events in the augmented dataset. 
# Usually, only half of this number are accepted in the augmented dataset
target_number_aug = np.sum(is_to_aug) * 40

number_objs_per_label = collections.Counter(dataset.labels[is_to_aug])
number_aug_per_label = target_number_aug//len(number_objs_per_label.keys())
objs_number_to_aug = {}
for label in number_objs_per_label.keys():
    is_label = dataset.labels[is_to_aug] == label
    aug_is_label_obj_names = aug_obj_names[is_label]
    number_aug_per_obj = number_aug_per_label // np.sum(is_label)
    if label == 90:
        number_aug_per_obj = int(number_aug_per_obj*.8)
    elif label == 95:
        number_aug_per_obj = int(number_aug_per_obj*.5)
    number_extra_aug_per_obj = number_aug_per_label % np.sum(is_label)
    extra_obj = np.random.choice(aug_is_label_obj_names, size=number_extra_aug_per_obj, 
                                 replace=False)
    objs_number_to_aug.update({obj: number_aug_per_obj for obj in aug_is_label_obj_names})
    objs_number_to_aug.update({obj: number_aug_per_obj+1 for obj in extra_obj})

In [None]:
print(f'We aim to create up to {sum(objs_number_to_aug.values())} events.')  # confirm how many events to create
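
The allocation above splits the target evenly between classes and then between the events of each class, giving the remainder to randomly chosen events. The following is a simplified sketch of that idea using only the standard library; `allocate_augmentations` and the toy labels are hypothetical, and the class-specific scaling applied above to labels 90 and 95 is left out.

```python
import collections
import random


def allocate_augmentations(obj_labels, target_total, seed=42):
    """Split `target_total` evenly across classes, then across each class's events."""
    rng = random.Random(seed)
    objs_per_label = collections.defaultdict(list)
    for obj, label in obj_labels.items():
        objs_per_label[label].append(obj)

    per_label = target_total // len(objs_per_label)
    allocation = {}
    for label, objs in objs_per_label.items():
        base, extra = divmod(per_label, len(objs))
        allocation.update({obj: base for obj in objs})
        # Hand the remainder to randomly chosen events, one extra each.
        for obj in rng.sample(objs, extra):
            allocation[obj] += 1
    return allocation


# Toy example: three classes with 3, 1 and 2 events.
obj_labels = {'a': 90, 'b': 90, 'c': 90, 'd': 42, 'e': 62, 'f': 62}
allocation = allocate_augmentations(obj_labels, target_total=30)

per_class = collections.Counter()
for obj, number in allocation.items():
    per_class[obj_labels[obj]] += number
print(dict(per_class))
```

With this scheme every class receives the same target, so rare classes get many more synthetic events per original event than common ones.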

### 3.2. Choose the Photometric Redshift<a name="choosePhotoZ"></a>

In order to simulate realistic photometric redshifts for the synthetic events, we follow [Boone (2019)](https://iopscience.iop.org/article/10.3847/1538-3881/ab5182): we choose a random event from the test set that has a spectroscopic redshift measurement and calculate the difference between its spectroscopic and photometric redshifts. We then add this difference to the true redshift of the augmented event to generate a photometric redshift.

**Add** such a dataset containing spectroscopic and photometric redshift, and photometric redshift error of events as `z_table`. If none is provided, a similar table is generated from the events in `dataset`.
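
The procedure described above can be sketched as follows. This is an illustration rather than the `snmachine` implementation: `draw_photoz` is a hypothetical helper, the table is reduced to `(specz, photoz, photoz_err)` tuples, the offset is taken as photometric minus spectroscopic redshift, and reusing the photometric redshift error of the drawn event is an assumption.

```python
import random


def draw_photoz(true_z, z_table, rng):
    """Return a (photo-z, photo-z error) pair for an event at redshift `true_z`."""
    # Pick one reference event that has both redshift measurements.
    specz, photoz, photoz_err = rng.choice(z_table)
    # Offset between the photometric and spectroscopic redshift of that event.
    offset = photoz - specz
    return true_z + offset, photoz_err


# Toy table of (specz, photoz, photoz_err) rows; the values are made up.
z_table_rows = [(0.50, 0.56, 0.03), (0.80, 0.74, 0.05)]
rng = random.Random(42)
new_photoz, new_photoz_err = draw_photoz(0.30, z_table_rows, rng)
print(new_photoz, new_photoz_err)
```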

In [None]:
ini_time = time.time() 
test_data_file_name = f'test_wfd_000_gapless50.pckl'
if is_only_roll:
    test_data_file_name = f'test_wfd_000_roll_gapless50.pckl'
if is_updated:
    test_data_file_name = test_data_file_name[:-5] + '_updated.pckl'
test_data_path = os.path.join(folder_path, test_data_file_name)
print(test_data_path)

test_data = load_dataset(test_data_path)
test_metadata = test_data.metadata

# Discard the events without spectroscopic redshift; 
# these are encoded with `hostgal_specz` equal to -9
z_table = test_metadata[test_metadata.hostgal_specz > -2]
print(time.time() - ini_time)

In [None]:
# `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the equivalent
z_table = pd.concat([z_table, dataset.metadata[dataset.metadata.hostgal_specz > -2]])

In [None]:
file_name = 'z_table_' + test_data_file_name
if is_only_roll:
    file_name = 'z_table_roll_' + test_data_file_name
path_to_save_z_table = os.path.join(folder_path, file_name)
with open(path_to_save_z_table, 'wb') as path:
    pickle.dump(z_table, path)

### 3.3. Run Augmentation<a name="aug"></a>

We also need to choose which survey to emulate in the augmentation. At the moment `snmachine` contains the Wide-Fast-Deep (WFD) and the Deep Drilling Field (DDF) surveys of the Rubin Observatory Legacy Survey of Space and Time. Use `snaugment.PlasticcWFDAugment` for the former survey and `snaugment.PlasticcDDFAugment` for the latter.
You can also implement your own augmentation using those classes as an example.

In addition to the above inputs, we **choose** the random seed (`random_seed`) used to allow reproducible results and the maximum duration of the augmented light curves (`max_duration`).

The value of `max_duration` must be higher than the maximum duration of any light curve in `dataset`. If none is provided, `max_duration` is set to the length of the longest event in `dataset`.
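
The rule above can be checked before running the augmentation. Below is a minimal sketch with a hypothetical helper, `resolve_max_duration`, that works on a plain list of light curve lengths in days. It accepts a value equal to the longest light curve, which is an assumption based on the value used in this notebook.

```python
def resolve_max_duration(lc_lengths, max_duration=None):
    """Return a valid `max_duration` given the light curve lengths (in days)."""
    longest = max(lc_lengths)
    if max_duration is None:
        # Default to the length of the longest event.
        return longest
    if max_duration < longest:
        raise ValueError(
            f'`max_duration` = {max_duration} is shorter than the longest '
            f'light curve ({longest:.2f} days).')
    return max_duration


lengths = [120.5, 295.0, 80.0]
print(resolve_max_duration(lengths))        # falls back to the longest event
print(resolve_max_duration(lengths, 300))   # a larger value is kept as given
```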

In [None]:
print(f'The longest event in `dataset` has {dataset.get_max_length():.2f} days.')

In [None]:
random_seed = 42 
max_duration = 295  # this is the length of the longest event in the paper SNe datasets

In [None]:
dataset.get_max_length()

Here we augment following the WFD survey.

In [None]:
aug = snaugment.PrestoV20WFDAugmentOri(dataset=dataset, path_saved_gps=path_saved_gps, 
                                    objs_number_to_aug=objs_number_to_aug,
                                    random_seed=random_seed, max_duration=max_duration, 
                                    z_table=z_table)

```python
aug = snaugment.BaselineV20WFDAugment(dataset=dataset, path_saved_gps=path_saved_gps, 
                                      objs_number_to_aug=objs_number_to_aug,
                                      random_seed=random_seed, max_duration=max_duration, 
                                      z_table=z_table)
                                      
aug = snaugment.NorollV20WFDAugment(dataset=dataset, path_saved_gps=path_saved_gps, 
                                    objs_number_to_aug=objs_number_to_aug,
                                    random_seed=random_seed, max_duration=max_duration, 
                                    z_table=z_table)
```

In [None]:
aug.augment()

Go to:
* [Index](#index)
* [Save Augmented Dataset](#saveAug)
  
### 3.4. See Augmented Dataset Properties<a name="statsAug"></a>

Here we see some properties of the augmented dataset. 

In [None]:
# Always read the augmented events from `aug`; the earlier `try`/`except` check
# for previously loaded data was overridden by these assignments anyway
aug_data = aug.only_new_dataset
aug_metadata = aug_data.metadata

In [None]:
try:  # a test set was provided
    datasets_label = ['Original', 'Only Aug.', 'Test data']
    datasets_metadata = [dataset.metadata, aug_metadata, test_metadata]
except NameError:   # no test set was provided
    datasets_label = ['Original', 'Only Aug.']
    datasets_metadata = [dataset.metadata, aug_metadata]

In [None]:
print(f'The longest event in the augmented dataset (`aug.only_new_dataset`)'
      f' has {aug_data.get_max_length():.2f} days.')

In [None]:
print(f'In total we generated {len(aug_data.object_names)} events.')

Note that we generated fewer events than our target number of augmented events. As mentioned in Section [Choose the Events to Augment](#chooseEvent), some of the augmentations fail.
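
To quantify how many augmentations failed in each class, the created events can be compared with the requested ones. The sketch below is an illustration under stated assumptions: `creation_efficiency` is a hypothetical helper, and it assumes synthetic events are named `<original name>_aug<number>`, as in the event shown at the end of this notebook.

```python
import collections


def creation_efficiency(objs_number_to_aug, created_names, obj_labels):
    """Fraction of requested synthetic events that were created, per class."""
    requested_per_label = collections.Counter()
    created_per_label = collections.Counter()
    for obj, number in objs_number_to_aug.items():
        requested_per_label[obj_labels[obj]] += number
    for name in created_names:
        # Recover the original event name from e.g. '595791_aug32'.
        original = name.rsplit('_aug', 1)[0]
        created_per_label[obj_labels[original]] += 1
    return {label: created_per_label[label] / number
            for label, number in requested_per_label.items() if number > 0}


# Toy example with made-up requests and results.
requested = {'595791': 4, '13': 2, '27': 2}
labels = {'595791': 90, '13': 42, '27': 42}
created = ['595791_aug0', '595791_aug3', '13_aug1']
efficiency = creation_efficiency(requested, created, labels)
print(efficiency)
```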

In [None]:
# Use the same row format for the header and the rows so the columns line up
row_format = '{:^12}  {:^12}  {:^12}  {:^12}'
print(row_format.format('Dataset', 'total # objs', '# DDF objs', '% DDF objs'))
print('-'*(12*4 + 3*2))
for label, metadata in zip(datasets_label, datasets_metadata):
    is_ddf = metadata['ddf'] == 1
    number_total_objs = len(is_ddf)
    number_ddf_objs = np.sum(is_ddf)
    percent_ddf = f'{number_ddf_objs/number_total_objs * 100:.2f}'
    print(row_format.format(label, number_total_objs, number_ddf_objs, percent_ddf))

In [None]:
collections.Counter(aug.only_new_dataset.labels)

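#### 3.4.1. Redshift Distribution <a name="distrRedshift"></a>

We now see the distribution of the redshift of each class. The plot uses the spectroscopic redshift (`hostgal_specz`) and falls back to the simulated redshift (`sim_redshift_cmb`) for datasets where some events lack a spectroscopic measurement.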

In [None]:
diverg_color = sns.color_palette("Set2", 6, desat=1)
sn_type_color = {42: diverg_color[1], 62: diverg_color[0], 90: diverg_color[2]}
sn_type_name = {42: 'SN II', 62: 'SN Ibc', 90: 'SN Ia'}
unique_types = [90, 42, 62] #, 52, 67, 95]
datasets_ls = ['-', '-', '--']
datasets_linewidth = [1, 3, 3]
datasets_bw_adjust = [.3, .4, .4]

In [None]:
bins = np.linspace(0, 1.5, 65)
for sn_type in unique_types: # sns scale 2
    plt.figure(figsize=(12, 4))
    for i, metadata in enumerate(datasets_metadata):
        label = datasets_label[i]
        ls = datasets_ls[i]
        linewidth = datasets_linewidth[i]
        bw_adjust = datasets_bw_adjust[i]
        is_sn_type = (metadata['target'] == sn_type)
        sn_type_metadata = metadata[is_sn_type]
        spec_zs = sn_type_metadata['hostgal_specz']
        if np.min(spec_zs) < 0:
            print(label)
            spec_zs = sn_type_metadata['sim_redshift_cmb']
        sns.distplot(a=spec_zs, bins=bins,
                     label=label, 
                     kde_kws={'linestyle': ls, 
                              'linewidth': linewidth, 
                              'bw_adjust': bw_adjust})
    sn_name = sn_type_name[sn_type]
    plt.title('trap .1*2/d log(z) z aug\n'+sn_name)
    plt.xlabel('Simulated redshift')
    plt.legend()
    plt.xlim(0, 1.5)

[Go back to top.](#index)

#### 3.4.2. Number of Observations <a name="distrNumberObs"></a>

We compute the number of observations in each light curve.

In [None]:
def compute_number_obs(dataset):
    obj_names = dataset.object_names
    number_obs = np.zeros(len(obj_names))
    for i in np.arange(len(obj_names)):
        obj = obj_names[i]
        obj_data = dataset.data[obj].to_pandas()
        number_obs[i] = np.shape(obj_data)[0]
    return number_obs

In [None]:
train_number_obs = compute_number_obs(dataset)
aug_metadata['number_obs'] = compute_number_obs(aug_data)
test_metadata['number_obs'] = compute_number_obs(test_data)

In [None]:
print(np.min(train_number_obs), np.min(aug_metadata['number_obs']), np.min(test_metadata['number_obs']))
print(np.max(train_number_obs), np.max(aug_metadata['number_obs']), np.max(test_metadata['number_obs']))
print(np.mean(train_number_obs), np.mean(aug_metadata['number_obs']), np.mean(test_metadata['number_obs']))

In [None]:
bins = np.linspace(0, 250, 51)
g = sns.distplot(a=train_number_obs, kde=True, norm_hist=True,
                 label='train', bins=bins)
g = sns.distplot(a=test_metadata['number_obs'], kde=True, norm_hist=True,
                 label='test', bins=bins)
g = sns.distplot(a=aug_metadata['number_obs'], kde=True, norm_hist=True,
                 label='aug. train', bins=bins)
plt.legend()
#plt.title(f'All SN')
plt.ylabel('Density')
plt.xlabel('Total number of observations')

In [None]:
bins = np.linspace(0, 250, 51)
g = sns.distplot(a=train_number_obs, kde=True, norm_hist=True,
                 label='train', bins=bins, hist=False)
g = sns.distplot(a=test_metadata['number_obs'], kde=True, norm_hist=True,
                 label='test', bins=bins, hist=False)
g = sns.distplot(a=aug_metadata['number_obs'], kde=True, norm_hist=True,
                 label='aug. train', bins=bins, hist=False)
plt.legend()
#plt.title(f'All SN')
plt.ylabel('Density')
plt.xlabel('Total number of observations')
# aug noroll has 6obs less in general than the test

In [None]:
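# `analysis` is not imported at the top of the notebook; this assumes that
# `compute_lc_length` lives in the `snmachine.analysis` module
from snmachine import analysis

lc_test = analysis.compute_lc_length(test_data)
lc_aug = analysis.compute_lc_length(aug.only_new_dataset)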

In [None]:
g = sns.distplot(a=lc_test, kde=True, norm_hist=True,
                 label='test', bins=bins, hist=False)
g = sns.distplot(a=lc_aug, kde=True, norm_hist=True,
                 label='aug. train', bins=bins, hist=False)
plt.legend()

[Go back to top.](#index)

#### 3.4.3. Observations Uncertainty <a name="distrUnc"></a>

First we compute the uncertainty in each passband for each light curve.

In [None]:
def make_big_pb_unc_table(dataset, pb, subset=None):
    if subset is None:
        obj_names = dataset.object_names
    else:
        obj_names = dataset.object_names[subset]
    metadata = dataset.metadata
    unc_pb = []
    obs_target = []
    for obj in obj_names:
        obj_data = dataset.data[obj].to_pandas()
        is_pb = obj_data['filter'] == pb
        obj_data_pb = obj_data[is_pb]
        unc_pb.append(obj_data_pb['flux_error'])
        obs_target.append(len(obj_data_pb) * [metadata.loc[obj, 'target']])
    unc_pb = pd.concat(unc_pb, ignore_index=True)
    obs_target = pd.DataFrame([inner for outer in obs_target for inner in outer])
    return unc_pb, obs_target
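
The function above scans every light curve once per passband. As an alternative sketch, the flux errors of all passbands can be grouped in a single pass. The example uses only the standard library and a simplified input, a mapping from event name to `(passband, flux_error)` pairs, which is an assumption made for illustration and not the `PlasticcData` format.

```python
import collections


def group_flux_errors(light_curves):
    """Map each passband to the flux errors of all its observations."""
    errors_per_pb = collections.defaultdict(list)
    for observations in light_curves.values():
        for passband, flux_error in observations:
            errors_per_pb[passband].append(flux_error)
    return dict(errors_per_pb)


# Toy light curves with made-up flux errors.
light_curves = {
    'obj1': [('lsstu', 2.1), ('lsstg', 1.3), ('lsstu', 2.4)],
    'obj2': [('lsstg', 1.1), ('lssty', 5.0)],
}
errors_per_pb = group_flux_errors(light_curves)
print(errors_per_pb)
```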

In [None]:
unc_train = []
unc_test = []
unc_aug = []
obs_target_train = []
obs_target_test = []
obs_target_aug = []
for pb in dataset.filter_set:
    unc_pb, obs_target_pb = make_big_pb_unc_table(dataset, pb)
    unc_train.append(unc_pb)
    obs_target_train.append(obs_target_pb)
    
    unc_pb, obs_target_pb = make_big_pb_unc_table(test_data, pb)
    unc_test.append(unc_pb)
    obs_target_test.append(obs_target_pb)
    
    unc_pb, obs_target_pb = make_big_pb_unc_table(aug_data, pb)
    unc_aug.append(unc_pb)
    obs_target_aug.append(obs_target_pb)
u_unc_train, g_unc_train, r_unc_train, i_unc_train, z_unc_train, y_unc_train = unc_train
u_unc_test, g_unc_test, r_unc_test, i_unc_test, z_unc_test, y_unc_test = unc_test
u_unc_aug, g_unc_aug, r_unc_aug, i_unc_aug, z_unc_aug, y_unc_aug = unc_aug
(u_obs_target_train, g_obs_target_train, r_obs_target_train, 
 i_obs_target_train, z_obs_target_train, y_obs_target_train) = obs_target_train
(u_obs_target_test, g_obs_target_test, r_obs_target_test, 
 i_obs_target_test, z_obs_target_test, y_obs_target_test) = obs_target_test
(u_obs_target_aug, g_obs_target_aug, r_obs_target_aug, 
 i_obs_target_aug, z_obs_target_aug, y_obs_target_aug) = obs_target_aug

In [None]:
bins = np.linspace(np.log(.4), np.log(30), 50)
g = sns.distplot(a=np.log(u_unc_train), kde=True, norm_hist=True,
                 label='train', bins=bins)
g = sns.distplot(a=np.log(u_unc_test), kde=True, norm_hist=True,
                 label='test', bins=bins)
g = sns.distplot(a=np.log(u_unc_aug), kde=True, norm_hist=True,
                 label='aug. train', bins=bins)

plt.legend()
plt.title('u passband')
plt.ylabel('Density')
plt.xlim(0, 5)

In [None]:
pb_colors = {'lsstu':'#984ea3', 'lsstg':'#377eb8', 'lsstr':'#4daf4a', 
             'lssti':'#e3c530', 'lsstz':'#ff7f00', 'lssty':'#e41a1c'} # colours for the plot

In [None]:
#datasets_flux_error, pb = [u_unc_train, u_unc_aug, u_unc_test], 'lsstu'
#datasets_flux_error, pb = [g_unc_train, g_unc_aug, g_unc_test], 'lsstg'
#datasets_flux_error, pb = [r_unc_train, r_unc_aug, r_unc_test], 'lsstr'
#datasets_flux_error, pb = [i_unc_train, i_unc_aug, i_unc_test], 'lssti'
# datasets_flux_error, pb = [z_unc_train, z_unc_aug, z_unc_test], 'lsstz'
datasets_flux_error, pb = [y_unc_train, y_unc_aug, y_unc_test], 'lssty'
datasets_label = ['Train set unc.', 'Aug. set unc.', 'Test set unc.']
datasets_ls = ['-', '-', '--']
datasets_linewidth = [1, 3, 3]

bins = np.linspace(np.log(.4), np.log(30), 50)

for i, metadata in enumerate(datasets_flux_error):
    label = datasets_label[i]
    ls = datasets_ls[i]
    linewidth = datasets_linewidth[i]
    sns.distplot(a=np.log(metadata), kde=True, color=pb_colors[pb],
                 hist=False, label=label, bins=bins,
                 kde_kws={'linestyle':ls, 'linewidth':linewidth,
                          'bw_adjust':.7})
#sn_name = sn_type_name[sn_type]
plt.title(f'Passband {pb}')
plt.xlim(-1, 4) # g with DDF train
#plt.ylim(0, 3)
#plt.xscale('log')
plt.xlabel('Log(Flux uncertainty)')
plt.ylabel('Density')
plt.legend(handletextpad=.3, borderaxespad=.3, handlelength=1,
           labelspacing=.2, borderpad=.3, columnspacing=.4)

#### Other

In [None]:
def plot_feature_space(metadata, title, y_feature_name, ylabel, ylim, yscale):
    fig, axs = plt.subplots(2, 2, sharex='col', sharey='row', figsize=(7, 7),
                            gridspec_kw={'hspace': 0, 'wspace': 0, 
                                         'width_ratios': [4, 1], 
                                         'height_ratios': [1, 4]})
    (ax1, ax2), (ax3, ax4) = axs
    fig.suptitle(title, y=.94)
    n_bins = 10000
    for sn_type in unique_types:
        is_sn_type = (metadata['target'] == sn_type)
        sn_type_metadata = metadata[is_sn_type]
        y_feature = sn_type_metadata[y_feature_name]
        ax3.plot(sn_type_metadata['hostgal_photoz'], 
                 y_feature, alpha=.1, 
                 linestyle='', marker='.', 
                 color=sn_type_color[sn_type])
        ax3.plot(0, 1, 'o', color=sn_type_color[sn_type], 
                 label=sn_type_name[sn_type])

        ax1.hist(sn_type_metadata['hostgal_photoz'], n_bins, density=True, 
                 histtype='step', cumulative=True, label='CDF', linewidth=1.5, 
                 color=sn_type_color[sn_type])

        ax4.hist(y_feature, n_bins, density=True, histtype='step', 
                 cumulative=True, label='CDF', linewidth=1.5, orientation='horizontal', 
                 color=sn_type_color[sn_type])

    ax1.set_ylim(-.1, 1.1)
    ax1.set_ylabel('CDF')
    ax4.set_xlim(-.1, 1.1)
    ax4.set_xlabel('CDF')
    ax4.set_xticks([0., .5, 1.])

    ax3.legend(handletextpad=.3, borderaxespad=.3, labelspacing=.2, 
               borderpad=.2, columnspacing=.4)
    ax3.set_xscale('log')
    ax3.set_yscale(yscale)
    ax3.set_xlim(.01, 4)
    ax3.set_ylim(ylim)
    ax3.set_xlabel('Photometric z')
    ax3.set_ylabel(ylabel)

    fig.delaxes(axs[0][1])
    for ax in axs.flat:
        ax.label_outer()

In [None]:
def compute_max_flux(dataset, return_time=False):
    """Max flux in all pbs"""
    obj_names = dataset.object_names
    pbs = dataset.filter_set
    max_flux = pd.DataFrame(index=obj_names, columns=pbs, dtype=float)
    time_max_flux = pd.DataFrame(index=obj_names, columns=pbs, dtype=float)
    
    for obj in obj_names:
        obj_gps = dataset.models[obj].to_pandas()
        for pb in pbs:
            is_pb = obj_gps['filter'] == pb
            obj_pb = obj_gps[is_pb].reset_index()
            max_flux.loc[obj, pb] = np.max(obj_pb['flux'])
            if return_time:
                index_max = np.argmax(obj_pb['flux'])
                time_max_flux.loc[obj, pb] = obj_pb.loc[index_max, 'mjd']
    if return_time:
        return max_flux, time_max_flux
    return max_flux

In [None]:
max_flux_train = compute_max_flux(dataset)
max_flux_aug = compute_max_flux(aug_data)
max_flux_test = compute_max_flux(test_data)

### 3.5. Select subset of Augmented Dataset<a name="selectAug"></a>

Here we select a subset of the augmented data to make it balanced.

In [None]:
aug_data = aug.only_new_dataset
aug_metadata = aug_data.metadata.copy()

In [None]:
unique_types = [90, 42, 62]

In [None]:
unique_types

In [None]:
collections.Counter(aug.only_new_dataset.labels)

In [None]:
np.random.seed(42)
number_per_class = 15440  # number of events to keep in each class
objs_to_keep = []

for sn_type in unique_types:
    is_sn = aug_metadata.target == sn_type
    indexes = np.where(is_sn)[0]  # positions of the events of this class
    try:
        indexes_to_stay = np.random.choice(indexes,
                                           size=number_per_class,
                                           replace=False)
    except ValueError:  # the class has fewer events than requested
        print(f'The class {sn_type} only has {len(indexes)} events.')
        indexes_to_stay = indexes
    # Convert the positions into event names
    objs_to_stay = aug_metadata.index[indexes_to_stay].to_numpy()
    objs_to_keep.append(objs_to_stay)
objs_to_keep = np.concatenate(objs_to_keep)
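
The selection above keeps the same number of events in each class, or all of them when a class is too small. A standard-library sketch of the same idea follows; `balanced_subset` is a hypothetical helper and the labels are toy values.

```python
import collections
import random


def balanced_subset(obj_labels, number_per_class, seed=42):
    """Keep at most `number_per_class` randomly chosen events of each class."""
    rng = random.Random(seed)
    objs_per_label = collections.defaultdict(list)
    for obj, label in obj_labels.items():
        objs_per_label[label].append(obj)

    kept = []
    for label, objs in objs_per_label.items():
        # Keep every event when the class has fewer than requested.
        size = min(number_per_class, len(objs))
        kept.extend(rng.sample(objs, size))
    return kept


obj_labels = {'a': 90, 'b': 90, 'c': 90, 'd': 90, 'e': 42, 'f': 42, 'g': 62}
kept = balanced_subset(obj_labels, number_per_class=2)
kept_per_class = collections.Counter(obj_labels[obj] for obj in kept)
print(dict(kept_per_class))
```

Taking the minimum of the requested size and the class size replaces the `try`/`except ValueError` pattern, at the cost of not reporting which classes came up short.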

In [None]:
new_metadata = aug_metadata.loc[objs_to_keep]

In [None]:
aug_data.object_names = list(new_metadata.index)

In [None]:
aug_data.update_dataset(list(new_metadata.index))
aug_data.update_dataset(list(aug_data.metadata.index))

In [None]:
collections.Counter(aug_data.metadata['target'])

In [None]:
len(aug_data.object_names)

In [None]:
len(aug_data.metadata)

In [None]:
print(1262*3)
print(2782*3)
print(6256*3)
print(18928*3)

## 4. Save Augmented Dataset<a name="saveAug"></a>

Now, we save the `PlasticcData` instance containing only the augmented events. **Choose** a path to save (`folder_path_to_save`) and the name of the file (`file_name`).

In [None]:
folder_path_to_save = folder_path[:-9] + 'augmented_data/'
file_name = 'aug_wfd_46k_updated.pckl'

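At this point we could also choose to save only part of the augmented dataset. Here we save `aug_data`, which holds only the events kept in Section 3.5 if that selection was run, and all the augmented events otherwise.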

**Add** an extra step to select your chosen subset. See the notebook [1_load_data](1_load_data.ipynb) for a tutorial on how to select a subset from a `PlasticcData` instance. This can be used, for example, to create an augmented training set with the same number of events in each class. For a working example of how to balance the augmented training set, see the notebook [example_plasticc](example_plasticc.ipynb).

Finally, save the `PlasticcData` instance.

In [None]:
#only_aug_dataset = aug.only_new_dataset
only_aug_dataset = aug_data

path_to_save = os.path.join(folder_path_to_save, file_name)
with open(path_to_save, 'wb') as f:
    pickle.dump(only_aug_dataset, f, pickle.HIGHEST_PROTOCOL)
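
After saving, it can be worth confirming that the file loads back correctly. The following sketch shows a round trip with a small stand-in dictionary instead of a `PlasticcData` instance; `save_and_reload` and the temporary path are hypothetical names used only for this illustration.

```python
import os
import pickle
import tempfile


def save_and_reload(obj, path):
    """Pickle `obj` to `path` and return the object read back from disk."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
    with open(path, 'rb') as f:
        return pickle.load(f)


# Stand-in for the augmented dataset; the values are made up.
payload = {'object_names': ['595791_aug32'], 'target': [90]}
with tempfile.TemporaryDirectory() as folder:
    path_to_check = os.path.join(folder, 'aug_example.pckl')
    reloaded = save_and_reload(payload, path_to_check)
print(reloaded == payload)
```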

## 5. Light curve visualization<a name="see"></a>

Here we show the light curve of an event along with one of the synthetic events generated from it.

In [None]:
obj_show = '595791'
sndata.PlasticcData.plot_obj_and_model(dataset.data[obj_show])
photo_z = dataset.metadata.loc[obj_show, 'hostgal_photoz']
plt.title(f'Event {obj_show}; z = {photo_z:.3f}')
print(dataset.metadata.loc[obj_show, 'hostgal_photoz'], 
      dataset.metadata.loc[obj_show, 'hostgal_specz'])
obj_data = dataset.data[obj_show] 
print(obj_data[obj_data['detected']==1])

In [None]:
obj_aug_show = obj_show + '_aug32'
sndata.PlasticcData.plot_obj_and_model(only_aug_dataset.data[obj_aug_show])
photo_z = only_aug_dataset.metadata.loc[obj_aug_show, 'hostgal_photoz']
plt.title(f'Event {obj_aug_show}; z = {photo_z:.3f}')
print(only_aug_dataset.metadata.loc[obj_aug_show, 'hostgal_photoz'], 
      only_aug_dataset.metadata.loc[obj_aug_show, 'hostgal_specz'])
obj_data = only_aug_dataset.data[obj_aug_show] 
print(obj_data[obj_data['detected']==1])

In [None]:
oo = only_aug_dataset.data[obj_aug_show]
oo[oo['detected']==1]

[Go back to top.](#index)