# Make photoz plots

In this notebook we make photo-z plots that include the training set, augmented training set, and test set.

#### Index<a name="index"></a>
1. [Import Packages](#imports)
2. [Load Dataset](#loadData)
    1. [Load train dataset](#loadTrain)
    2. [Load augmented train metadata](#loadAug)
    3. [Load test metadata](#loadTest)
3. [Plot photo-z distribution](#plot)
    1. [Setup](#setup)
    2. [Photo-z distribution](#photoz)
4. [Number of events](#nEvents)

## 1. Import Packages<a name="imports"></a>

In [None]:
import collections
import os
import pickle
import time

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

In [None]:
from snmachine import snaugment, sndata
from utils.plasticc_pipeline import get_directories, load_dataset

In [None]:
# Switch off the Jedi-based completer for this IPython session; the author's
# stated purpose is to get tab-autocomplete working.
%config Completer.use_jedi = False  # enable autocomplete

#### Aesthetic settings

In [None]:
%matplotlib inline

size_default = 1.5
size_larger = 1.9
sns.set(font_scale=size_default, style="ticks", context="paper")
sns.set(font_scale=size_default, style="ticks")

## 2. Load Dataset<a name="loadData"></a>

First, **write** the path to the folder that contains the train and test datasets, `folder_path`.

In [None]:
# Label used further down to build the names of the saved plots.
# Keep exactly one of these uncommented.
# os_name = 'baseline_v2_0_paper'
# os_name = 'noroll_v2_0_paper'
os_name = 'presto_v2_0_paper'

# Placeholder: replace with the folder that holds the train/test `.pckl` files.
folder_path = '/folder/path/'
# The analyses folder is `folder_path` without its last 14 characters, plus
# 'analyses'.
# NOTE(review): the hard-coded 14 presumably matches the length of the real
# path's final directory name; with the 13-character placeholder above the
# slice is empty and this evaluates to just 'analyses'. Confirm after editing
# `folder_path`.
folder_analysis_path = folder_path[:-14] + 'analyses'

**Set** `is_only_roll` to $1$ to consider only the rolling part of the cadence. **Set** `is_updated` to $1$ to load the files and analyses whose names end in `_updated`.

In [None]:
# Truthy -> the `_roll` variants of the dataset/analysis names are used below.
is_only_roll = 0
# Truthy -> the `_updated` variants of the dataset/analysis names are used below.
is_updated = 1

### 2.1. Load train dataset<a name="loadTrain"></a>

In [None]:
extra_name_to_save = 'ddf_wfd'
file_id = '000'

# Assemble the training file name from its underscore-separated parts:
# train_<extra>_<id>[_roll]_gapless50[_updated].pckl
train_name_parts = ['train', extra_name_to_save, file_id]
if is_only_roll:
    print('only roll')
    train_name_parts.append('roll')
train_name_parts.append('gapless50')
if is_updated:
    train_name_parts.append('updated')
data_file_name = '_'.join(train_name_parts) + '.pckl'
data_file_name

In [None]:
# Full path of the training file chosen above, loaded with the project helper.
data_path = os.path.join(folder_path, data_file_name)
train_data = load_dataset(data_path)

In [None]:
# Metadata of the training set; it is the first dataset in the plots and
# counts below, which read its 'target' column.
metadata_train = train_data.metadata

### 2.2. Load augmented train metadata<a name="loadAug"></a>

In [None]:
# Assemble the augmentation analysis name: aug_wfd[_roll]_46k[_updated]
analysis_name_parts = ['aug_wfd']
if is_only_roll:
    print('only roll')
    analysis_name_parts.append('roll')
analysis_name_parts.append('46k')
if is_updated:
    analysis_name_parts.append('updated')
analysis_name = '_'.join(analysis_name_parts)

# Both folders live inside the directory of this analysis.
analysis_directory = os.path.join(folder_analysis_path, analysis_name)
path_saved_photoz = os.path.join(analysis_directory, 'wavelet_features')
path_saved_plots = os.path.join(analysis_directory, 'plots')

In [None]:
# NOTE: unpickling can run arbitrary code; only load files written by this
# project's own pipeline.
# The handles are deliberately not called `input`, so the builtin `input`
# stays usable in the rest of the notebook.
with open(os.path.join(path_saved_photoz, 'features.pckl'), 'rb') as features_file:
    metadata_aug = pickle.load(features_file)  # this is not really the metadata
with open(os.path.join(path_saved_photoz, 'data_labels.pckl'), 'rb') as labels_file:
    data_labels = pickle.load(labels_file)  # load the class labels
# Attach the labels so this table can be filtered by 'target' like the others.
# NOTE(review): assumes the labels line up with the rows of the features
# table; confirm against the code that wrote the two files.
metadata_aug['target'] = data_labels  # add the labels to the "metadata"

### 2.3. Load test metadata<a name="loadTest"></a>

In [None]:
# Load the extended metadata of every test batch and stack them into one
# table, `metadata_test`. The elapsed time in seconds is printed at the end.
# NOTE: this cell rebinds `extra_name_to_save`, `data_file_name`, `data_path`
# and `analysis_name`, which the training/augmentation cells also set.
time_ini = time.time()
extra_name_to_save = 'wfd'

# Zero-padded batch identifiers '000' ... '012'
batch_ids = [f'{batch_number:03d}' for batch_number in range(13)]

# Same for every batch, so computed once outside the loop.
folder_analysis_path = folder_path[:-14] + 'analyses'

# Collect the aggregated data
metadata_test_ids = []

for batch_id in batch_ids:
    print(f'Batch {batch_id}')

    # Name and path of the test subset
    data_file_name = f'test_{extra_name_to_save}_{batch_id}_gapless50.pckl'
    if is_only_roll:
        print('only roll')
        data_file_name = f'test_{extra_name_to_save}_{batch_id}_roll_gapless50.pckl'
    if is_updated:
        data_file_name = data_file_name[:-5] + '_updated.pckl'
    data_path = os.path.join(folder_path, data_file_name)
    print(data_path)

    # Path to the test subset features; the analysis is named after the data
    # file without its '.pckl' extension.
    analysis_name = data_file_name[:-5]
    directories = get_directories(folder_analysis_path, analysis_name)
    path_saved_reduced_wavelets = directories['features_directory']

    # Load the extended metadata.
    # NOTE: unpickling can run arbitrary code; only load files written by
    # this project's own pipeline. The handle is not called `input`, so the
    # builtin `input` is left alone.
    extended_metadata_path = os.path.join(path_saved_reduced_wavelets,
                                          'extended_metadata.pckl')
    with open(extended_metadata_path, 'rb') as metadata_file:
        extended_metadata = pickle.load(metadata_file)

    # Aggregate the data
    metadata_test_ids.append(extended_metadata)
metadata_test = pd.concat(metadata_test_ids)
print(time.time()-time_ini)

[Go back to top.](#index)

## 3. Plot photo-z distribution<a name="plot"></a>

### 3.0. Setup <a name="setup"></a>

In [None]:
# Six colours taken from seaborn's "Set2" palette.
diverg_color = sns.color_palette("Set2", 6, desat=1)

# Class id -> position of its colour in `diverg_color`.
palette_position = {42: 1, 62: 0, 90: 2, 52: 3, 67: 4, 95: 5}
sn_type_color = {class_id: diverg_color[position]
                 for class_id, position in palette_position.items()}

# Class id -> name shown in the plot titles.
sn_type_name = {
    42: 'SN II',
    62: 'SN Ibc',
    90: 'SN Ia',
    52: 'SN Iax',
    67: 'SN 91bg',
    95: 'SLSN',
}

# Classes that get a figure; 52, 67 and 95 are currently left out.
unique_types = [90, 42, 62]

# Per-dataset line properties; position i goes with the i-th dataset plotted.
datasets_ls = ['-', '-', '--']            # line style
datasets_linewidth = [1, 3, 3]            # line width
datasets_bw_adjust = [.3, .4, .4]         # KDE bandwidth adjustment

In [None]:
# Datasets to compare and their legend labels; both lists share the order
# used by the per-dataset style lists (`datasets_ls`, `datasets_linewidth`,
# `datasets_bw_adjust`).
datasets_metadata = [metadata_train, metadata_aug, metadata_test]
datasets_label = ['Train. set', 'Aug. set', 'Test set']

In [None]:
# Short label for the plot file names: `os_name` minus its last 11 characters
# (the '_v2_0_paper' part of the names listed above), with '_onlyroll' added
# for the rolling-only baseline.
os_name_save = os_name[:-11]
if is_only_roll and os_name_save == 'baseline':
    os_name_save += '_onlyroll'

### 3.1. Photo-z distribution<a name="photoz"></a>

In [None]:
# One figure per supernova class: the KDE of `hostgal_photoz` for each
# dataset (training, augmented, test), drawn with that dataset's line style.
for sn_type in unique_types:  # sns scale 2
    fig, ax = plt.subplots()
    type_color = sn_type_color[sn_type]
    dataset_styles = zip(datasets_metadata, datasets_label, datasets_ls,
                         datasets_linewidth, datasets_bw_adjust)
    for metadata, label, ls, linewidth, bw_adjust in dataset_styles:
        # Keep only the events of the current class.
        is_sn_type = (metadata['target'] == sn_type)
        sn_type_metadata = metadata[is_sn_type]
        try:
            sns.kdeplot(data=sn_type_metadata['hostgal_photoz'],
                        label=label, color=type_color,
                        linestyle=ls, linewidth=linewidth,
                        bw_adjust=bw_adjust, clip=(0, None), ax=ax)
        except (ValueError, NameError):
            # Fallback kept from the original notebook for older plotting
            # library versions.
            # NOTE(review): `bw_adjust` was introduced in seaborn 0.11, so the
            # relevant library is presumably seaborn rather than matplotlib.
            # This branch also hides genuine ValueErrors from `kdeplot` and
            # does not clip at 0 -- confirm it is still needed.
            sns.distplot(a=sn_type_metadata['hostgal_photoz'],
                         label=label, color=type_color,
                         kde_kws={'linestyle': ls,
                                  'linewidth': linewidth,
                                  'bw_adjust': bw_adjust},
                         ax=ax)
    sn_name = sn_type_name[sn_type]
    ax.set_title(sn_name)
    ax.set_xlim(-.1, 1.2)
    ax.set_ylim(0, 3.6)  # same range for every class so figures are comparable
    ax.set_xlabel('Photometric redshift')
    ax.set_ylabel('Density')
    ax.legend(handletextpad=.3)

    # Where the figure would be stored; saving is disabled, only the path is
    # printed. Uncomment the `savefig` line to write the PDF.
    sn_name_save = sn_name.replace(' ', '').lower()
    plot_path = os.path.join(path_saved_plots,
                             f'photoz_{os_name_save}_{sn_name_save}_36.pdf')
    print(plot_path)
#     fig.savefig(plot_path, bbox_inches='tight')

[Go back to top.](#index)

## 4. Number of events<a name="nEvents"></a>

Here we check the number of events and proportions per dataset.

In [None]:
# Class id -> display name. All six classes are listed (not only the three
# plotted ones) so that the plotting cell still works if it is re-run after
# this one, and so that the counting loop can name every one of these classes.
sn_type_name = {42: 'SN II', 62: 'SN Ibc', 90: 'SN Ia',
                52: 'SN Iax', 67: 'SN 91bg', 95: 'SLSN'}

In [None]:
# For each dataset print, per class: name, fraction of the dataset, and number
# of events; then the total number of events and a blank separator line.
for label, metadata in zip(datasets_label, datasets_metadata):
    print(label)
    targets = metadata['target']
    n_events = len(targets)
    counts = collections.Counter(targets)
    for target_class, n_class in counts.items():
        # Classes without an entry in `sn_type_name` are shown by their raw
        # id instead of raising a KeyError.
        class_name = sn_type_name.get(target_class, target_class)
        print(class_name, n_class/n_events, n_class)
    print(n_events)
    print('')

[Go back to top.](#index)