In [1]:
!uname -a

Linux compute-0-26.local 2.6.32-642.el6.x86_64 #1 SMP Tue May 10 17:27:01 UTC 2016 x86_64 x86_64 x86_64 GNU/Linux


In [2]:
!pwd

/home/tallam/plasticc/snmachine/examples


In [3]:
!pip install ../.

Processing /home/tallam/plasticc/snmachine
Building wheels for collected packages: snmachine
  Running setup.py bdist_wheel for snmachine ... done
  Stored in directory: /tmp/pip-ephem-wheel-cache-nztb3k7y/wheels/cd/65/db/fda56ff3f0d6fa8ba1e7b69dab8a17be3a2bbe7940a42d6151
Successfully built snmachine
Installing collected packages: snmachine
  Found existing installation: snmachine 1.1.1
    Uninstalling snmachine-1.1.1:
      Successfully uninstalled snmachine-1.1.1
Successfully installed snmachine-1.1.1


# Notebook for running the snmachine pipeline on PLAsTiCC simulated data

This notebook illustrates the use of the `snmachine` supernova classification package by classifying a subset of simulated data from the Photometric LSST Astronomical Time-Series Classification Challenge (PLAsTiCC).

See Lochner et al. (2016) http://arxiv.org/abs/1603.00882 for the original SPCC-challenge test.

<img src="pipeline.png" width=600>

This image illustrates how the pipeline works. As the user, you can choose which feature extraction method to use. Here we have three (four, technically, since there are two parametric models), but it's straightforward to write a new feature extraction method. Once features have been extracted, they can be run through one of several machine learning algorithms and, again, it's easy to write your own algorithm into the pipeline. There's a convenience function in `snclassifier` to run a feature set through multiple algorithms and plot the result. The rest of this notebook goes through applying each of the feature extraction methods to a set of simulations and running all feature sets through different classification algorithms.

In [4]:
%%capture --no-stdout 
# I use this to suppress unnecessary warnings for clarity
%load_ext autoreload
# Reload modules if they are changed on disk while the notebook is running
%autoreload 2
from snmachine import sndata, snfeatures, snclassifier, tsne_plot
import numpy as np
import matplotlib.pyplot as plt
import time, os, pywt,subprocess
from sklearn.decomposition import PCA
from astropy.table import Table,join,vstack,unique
from astropy.io import fits
import sklearn.metrics 
import sncosmo
import pickle
%matplotlib nbagg

In [6]:
# Set the number of processes you want to use throughout the notebook
import multiprocessing
num_cpu = multiprocessing.cpu_count()
nproc=num_cpu
print("Running with {} cores".format(num_cpu))

Running with 40 cores


## Set up output structure

We make lots of output files, so it makes sense to put them in one place. The recommended output file structure is set up in the cell that defines `outdir` below, once the dataset name is known.

## Initialise dataset object

Load a subset of the PLAsTiCC simulated data (https://arxiv.org/abs/1810.00001)

In [7]:
# Please specify the data root:
# the path to where you have pulled all the data from
rt='/share/hypatia/snmachine_resources/data/cwp/DDFY1/RH_kraken_2026_ddf_DDF_1aONLY_Y1_G10/'
prefixIa='RH_DDF_1aONLY_Y1_G10_Ia-'
prefixNONIa='RH_DDF_1aONLY_Y1_G10_NONIa-'
# Name for the dataset
dataset='kraken_2026_ddf_Y1'

In [8]:
# WARNING...
# MultiNest uses a hardcoded character limit for the output file names. I believe it's a limit of 100 characters,
# so avoid making this file path too lengthy if using nested sampling, or the MultiNest output file names will be truncated

#Change outdir to somewhere on your computer if you like
outdir=os.path.join('output_{}_no_z'.format(dataset),'')
out_features=os.path.join(outdir,'features') #Where we save the extracted features to
out_class=os.path.join(outdir,'classifications') #Where we save the classification probabilities and ROC curves
out_int=os.path.join(outdir,'int') #Any intermediate files (such as multinest chains or GP fits)

final_outdir="/share/hypatia/snmachine_resources/data/LSST_Cadence_WhitePaperClassResults/output_data/output_{}_no_z/".format(dataset)

subprocess.call(['mkdir',outdir])
subprocess.call(['mkdir',out_features])
subprocess.call(['mkdir',out_class])
subprocess.call(['mkdir',out_int])

1
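The `1` shown above is the return value of the last `subprocess.call`, that is, the exit status of `mkdir`, which is non-zero when the directory already exists or cannot be created. As a minimal alternative sketch using only the standard library, `os.makedirs` with `exist_ok=True` creates the same tree and is safe to re-run. The `make_output_tree` helper, the `max_path_len` guard and the temporary root directory are illustrative assumptions, not part of `snmachine`.

```python
import os
import tempfile

def make_output_tree(outdir, max_path_len=100):
    """Create the features/classifications/int sub-directories under outdir.

    max_path_len is a hypothetical guard reflecting the MultiNest file-name
    limit mentioned in the warning above; adjust it if the real limit differs.
    """
    paths = {name: os.path.join(outdir, name)
             for name in ('features', 'classifications', 'int')}
    for path in paths.values():
        os.makedirs(path, exist_ok=True)  # no error if it already exists
    # Report any directory whose absolute path exceeds the guard length
    too_long = [p for p in paths.values()
                if len(os.path.abspath(p)) > max_path_len]
    return paths, too_long

root = tempfile.mkdtemp()  # stand-in for a real output location
paths, too_long = make_output_tree(os.path.join(root, 'output_demo_no_z'))
print(paths['int'], too_long)
```

Calling `make_output_tree` a second time on the same location leaves the existing directories untouched.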

In [9]:
dat=sndata.LSSTCadenceSimulations(folder=rt,prefix_Ia=prefixIa, prefix_NONIa=prefixNONIa, indices=range(1,21))
#dat=sndata.plasticc_data(folder=rt,pickle_file='dataset_full.pickle',from_pickle=True)

Reading data...
chunk 01
chunk 02
chunk 03
chunk 04
chunk 05
chunk 06
chunk 07
chunk 08
chunk 09
chunk 10
chunk 11
chunk 12
chunk 13
chunk 14
chunk 15
chunk 16
chunk 17
chunk 18
chunk 19
chunk 20
0k
10k
20k
30k
33862 objects read into memory.


Now we can plot all the data and cycle through it (using the left and right arrow keys on your keyboard).

In [10]:
dat.plot_all(mix=True, sep_detect=False)

<IPython.core.display.Javascript object>

In [11]:
# Get the types, note these are internal snmachine datatypes
types=dat.get_types()

Each light curve is represented in the Dataset object as an astropy table, compatible with `sncosmo`:

Note: the types in the `types` table obtained above are internal to `snmachine`.

In [12]:
dat.data['1002130']

| mjd | flux | flux_error | filter | zp | zpsys |
|---|---|---|---|---|---|
| float32 | float32 | float32 | str5 | float64 | str2 |
| 0.0 | 12.7787 | 1.05079 | lsstr | 27.5 | ab |
| 0.0152 | 19.5901 | 1.5723 | lssti | 27.5 | ab |
| 0.0262 | 14.1931 | 2.4764 | lsstz | 27.5 | ab |
| 0.0371 | 9.8337 | 7.46217 | lssty | 27.5 | ab |
| 2.8961 | 12.955 | 1.50977 | lsstr | 27.5 | ab |
| 2.9113 | 16.7763 | 2.06007 | lssti | 27.5 | ab |
| 2.9223 | 10.5178 | 2.47485 | lsstz | 27.5 | ab |
| 2.9332 | 10.6499 | 7.6048 | lssty | 27.5 | ab |
| 16.9412 | 7.41427 | 0.839142 | lsstr | 27.5 | ab |
| 16.9565 | 12.4458 | 1.38946 | lssti | 27.5 | ab |
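To make the `flux`, `zp` and `zpsys` columns concrete, here is a small sketch that converts the first row of the table above into an AB magnitude using the standard zero-point relation `mag = zp - 2.5 * log10(flux)`. The helper names are illustrative and are not part of `snmachine` or `sncosmo`; the pipeline itself works in flux space.

```python
import math

def flux_to_mag(flux, zp):
    """Magnitude (in the zpsys system) from a flux scaled to zero point zp."""
    return zp - 2.5 * math.log10(flux)

def mag_error(flux, flux_error):
    """First-order magnitude uncertainty propagated from the flux uncertainty."""
    return 2.5 / math.log(10) * flux_error / flux

# First row of the table above: r band, flux=12.7787, flux_error=1.05079, zp=27.5
mag = flux_to_mag(12.7787, 27.5)
err = mag_error(12.7787, 1.05079)
print(round(mag, 2), round(err, 2))
```

The relation only holds for positive fluxes; noisy observations with negative flux have no defined magnitude, which is one reason to keep working with fluxes.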


### Inspect GP fitting capability for individual objects

In [13]:
#test_obj = '3211874' # a nice Ia
test_obj = '1173217'

In [14]:
sn = dat.data[test_obj]
# Fit a Gaussian process to a single object and attach it as that object's model
g=snfeatures._GP(test_obj, dat, ngp=100, xmin=0, xmax=dat.get_max_length(), initheta=[500,20],
                 save_output=True, output_root=os.path.join(final_outdir, 'int', ''))
dat.models[test_obj] = g
plt.figure()
dat.plot_lc(test_obj, plot_model=True)

1173217


<IPython.core.display.Javascript object>
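For readers unfamiliar with Gaussian process (GP) regression, the following is a minimal sketch of the idea behind fitting a smooth curve to one band of a light curve. It is a generic zero-mean GP with a squared-exponential kernel written in `numpy`, not the implementation inside `snfeatures._GP`. Treating the two hyperparameters as a signal variance and a length scale in days is an assumption about what `initheta=[500,20]` means, and the light curve is synthetic.

```python
import numpy as np

def sq_exp_kernel(t1, t2, variance, length):
    """Squared-exponential covariance between two sets of times."""
    diff = t1[:, None] - t2[None, :]
    return variance * np.exp(-0.5 * (diff / length) ** 2)

def gp_predict(t, flux, flux_err, t_new, variance=500.0, length=20.0):
    """Posterior mean and standard deviation of a zero-mean GP."""
    # Observation noise enters on the diagonal of the training covariance
    K = sq_exp_kernel(t, t, variance, length) + np.diag(flux_err ** 2)
    K_s = sq_exp_kernel(t_new, t, variance, length)
    mean = K_s @ np.linalg.solve(K, flux)
    cov = (sq_exp_kernel(t_new, t_new, variance, length)
           - K_s @ np.linalg.solve(K, K_s.T))
    return mean, np.sqrt(np.clip(np.diag(cov), 0.0, None))

# Synthetic single-band light curve (not taken from the dataset)
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0, 100, 25))
flux_err = np.full_like(t, 1.0)
flux = 30.0 * np.exp(-0.5 * ((t - 40.0) / 12.0) ** 2) + rng.normal(0, flux_err)

t_grid = np.linspace(0, 100, 100)  # 100 points, mirroring ngp=100
mean, std = gp_predict(t, flux, flux_err, t_grid)
print(mean.shape, std.shape)
```

In this sketch the hyperparameters are fixed; a full fit would normally optimise them, starting from an initial guess.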

## Extract features for the data

The next step is to extract useful features from the data. This can often take a long time, depending on the feature extraction method, so it's a good idea to save these to file (by default, `snmachine` saves to astropy tables).

In [15]:
read_from_file=False #We can use this flag to quickly rerun from saved features
run_name=os.path.join(out_features,'{}_all'.format(dataset))
read_from_pickle=False
pickle_location = rt
restart_from_GP = False
restart_from_wavefeats=False
restart_from_wavelets=False

### Wavelet features

The wavelet feature extraction process is quite complicated, although it is fairly fast. Remember to save the PCA eigenvalues, vectors and mean for later reconstruction!
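To see why the eigenvalues, eigenvectors and mean all need to be kept, here is a minimal `numpy` sketch of PCA projection and reconstruction. It is a generic illustration, assuming eigenvectors stored as columns and random data standing in for wavelet coefficients; `snmachine`'s own PCA code and storage layout may differ, and the function names are hypothetical.

```python
import numpy as np

def fit_pca(X, n_components):
    """Return eigenvalues, eigenvectors (as columns) and the mean of X."""
    mean = X.mean(axis=0)
    cov = np.cov(X - mean, rowvar=False)
    vals, vecs = np.linalg.eigh(cov)                # ascending eigenvalues
    order = np.argsort(vals)[::-1][:n_components]   # keep the largest ones
    return vals[order], vecs[:, order], mean

def project(X, vecs, mean):
    """Coefficients of X in the PCA basis (these are the features)."""
    return (X - mean) @ vecs

def reconstruct(coeffs, vecs, mean):
    """Map PCA coefficients back to the original space."""
    return coeffs @ vecs.T + mean

rng = np.random.default_rng(1)
# Stand-in for a matrix of wavelet coefficients: 200 objects, 6 coefficients
X = rng.normal(size=(200, 6)) @ rng.normal(size=(6, 6))
vals, vecs, mean = fit_pca(X, n_components=6)
X_back = reconstruct(project(X, vecs, mean), vecs, mean)
print(np.allclose(X, X_back))
```

Without the mean and the eigenvectors, the stored coefficients cannot be mapped back to a light curve; the eigenvalues record how much variance each component carries.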

In [16]:
#waveFeats=snfeatures.WaveletFeatures()
wavelet_feats=snfeatures.WaveletFeatures(wavelet='sym2', ngp=100)

In [17]:
#%%capture --no-stdout
if read_from_file:
    # Reload previously saved features and PCA components
    wavelet_features=Table.read('%s_wavelets.dat' %run_name, format='ascii')
    # Crucial for this format of ids: keep the object names as strings
    object_ids=wavelet_features['Object'].astype(str)
    wavelet_features.replace_column('Object', object_ids)
    vals=np.loadtxt('%s_wavelets_PCA_vals.dat' %run_name)
    vec=np.loadtxt('%s_wavelets_PCA_vec.dat' %run_name)
    means=np.loadtxt('%s_wavelets_PCA_mean.dat' %run_name)

elif read_from_pickle:
    raise NotImplementedError('Reading wavelet features from a pickle is not currently implemented')

elif restart_from_GP:
    # RESTART FROM GP
    wavelet_features=wavelet_feats.extract_features(dat,nprocesses=nproc,output_root=rt,save_output='all',restart='gp')
    wavelet_features.write('%s_wavelets.dat' %run_name, format='ascii')
    vals=wavelet_feats.PCA_eigenvals
    vec=wavelet_feats.PCA_eigenvectors
    means=wavelet_feats.PCA_mean
    np.savetxt('%s_wavelets_PCA_vals.dat' %run_name,vals)
    np.savetxt('%s_wavelets_PCA_vec.dat' %run_name,vec)
    np.savetxt('%s_wavelets_PCA_mean.dat' %run_name,means)

elif restart_from_wavefeats:
    # RESTART FROM WAVELET FEATURES
    # Read the features and pickled PCA components stored under rt
    wavelet_features=Table.read(rt  + 'wavelet_features.fits',format='fits')
    wavelet_features.write('%s_wavelets.dat' %run_name, format='ascii')
    with open(rt+'PCA_eigenvals.pickle','rb') as f:
        vals=pickle.load(f)
    with open(rt+'PCA_eigenvectors.pickle','rb') as f:
        vec=pickle.load(f)
    with open(rt+'PCA_mean.pickle','rb') as f:
        means=pickle.load(f)
    np.savetxt('%s_wavelets_PCA_vals.dat' %run_name,vals)
    np.savetxt('%s_wavelets_PCA_vec.dat' %run_name,vec)
    np.savetxt('%s_wavelets_PCA_mean.dat' %run_name,means)

elif restart_from_wavelets:
    # RESTART FROM WAVELETS
    # Copy int to finaldir and read in raw wavelets
    wavelet_feats=snfeatures.WaveletFeatures(wavelet='sym2', ngp=100)
    wave_raw, wave_err=wavelet_feats.restart_from_wavelets(dat, os.path.join(final_outdir, 'int', ''))
    wavelet_features,vals,vec,means=wavelet_feats.extract_pca(dat.object_names.copy(), wave_raw)

else:
    # Full run: GP fits, wavelet decomposition and PCA
    wavelet_features=wavelet_feats.extract_features(dat,nprocesses=nproc,output_root=out_int,save_output='all')
    wavelet_features.write('%s_wavelets.dat' %run_name, format='ascii')
    np.savetxt('%s_wavelets_PCA_vals.dat' %run_name,wavelet_feats.PCA_eigenvals)
    np.savetxt('%s_wavelets_PCA_vec.dat' %run_name,wavelet_feats.PCA_eigenvectors)
    np.savetxt('%s_wavelets_PCA_mean.dat' %run_name,wavelet_feats.PCA_mean)

    vals=wavelet_feats.PCA_eigenvals
    vec=wavelet_feats.PCA_eigenvectors
    means=wavelet_feats.PCA_mean

[Output truncated: a warning fragment followed by interleaved, cut-off `KeyboardInterrupt` tracebacks from `multiprocessing` `ForkPoolWorker` processes.]

In [20]:
wavelet_feats

<snmachine.snfeatures.WaveletFeatures at 0x2acd7e6f6908>

In [21]:
#dat.set_model(waveFeats.fit_sn,wave_features,PCA_vec,PCA_mean,0,dat.get_max_length(),dat.filter_set)
dat.set_model(wavelet_feats.fit_sn,wavelet_features,vec,means,0,dat.get_max_length(),dat.filter_set)

In [31]:
dat.plot_all(mix=True)

<IPython.core.display.Javascript object>

In [24]:
plt.figure()
dat.plot_lc('1173217', plot_model=True)

<IPython.core.display.Javascript object>

## Classify

Finally, we're ready to run the machine learning algorithm. There's a utility function in the `snclassifier` library that makes it easy to run all the available algorithms. It converts the features to `numpy` arrays, rescales them, and automatically generates ROC curves and metrics. Hyperparameters are automatically selected using a grid search combined with cross-validation. All functionality can also be run individually from `snclassifier`.

Classifiers can be run in parallel: set `nprocesses` to the number of processors on your machine. It won't help to set this any higher than the number of algorithms being run; the cell below runs only `random_forest`.
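As a reminder of what a ROC curve and its area measure, the following `numpy` sketch computes both by hand for a toy binary problem (for example, Ia versus non-Ia). The labels and scores are made up, the functions are illustrative rather than `snclassifier`'s own, and the simple threshold sweep assumes there are no tied scores.

```python
import numpy as np

def roc_curve(y_true, scores):
    """False and true positive rates as the threshold sweeps over scores.

    Assumes binary labels (1 = positive class) and no tied scores.
    """
    order = np.argsort(scores)[::-1]          # highest score first
    y_sorted = np.asarray(y_true)[order]
    tpr = np.concatenate([[0.0], np.cumsum(y_sorted) / y_sorted.sum()])
    fpr = np.concatenate([[0.0], np.cumsum(1 - y_sorted) / (1 - y_sorted).sum()])
    return fpr, tpr

def auc(fpr, tpr):
    """Area under the ROC curve by the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))

y_true = np.array([1, 1, 0, 1, 0, 0])
scores = np.array([0.9, 0.8, 0.7, 0.6, 0.4, 0.2])  # toy class probabilities
fpr, tpr = roc_curve(y_true, scores)
print(auc(fpr, tpr))
```

An area of 1 means every positive object is scored above every negative one, while 0.5 corresponds to random ranking.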

In [25]:
#Available classifiers 
print(snclassifier.choice_of_classifiers)

### SPCC-like pre-processing

In [26]:
# As in the SPCC example notebook, we restrict ourselves to three supernova types,
# Ia (1), II (2) and Ibc (3), by carrying out the following pre-processing steps
types['Type'] = types['Type']-100

types['Type'][np.floor(types['Type']/10)==2]=2
types['Type'][np.floor(types['Type']/10)==3]=3
types['Type'][np.floor(types['Type']/10)==4]=2
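The relabelling above can be checked on a small array. This sketch applies the same rule (subtract 100, then map a tens digit of 2 or 4 to class 2 and a tens digit of 3 to class 3) to a handful of raw codes. The specific raw codes are hypothetical examples chosen for illustration, and `to_spcc_like` is not an `snmachine` function.

```python
import numpy as np

def to_spcc_like(raw_types):
    """Collapse three-digit type codes into 1 (Ia), 2 (II) or 3 (Ibc)."""
    t = np.asarray(raw_types) - 100
    tens = np.floor(t / 10)   # computed once, before any reassignment
    out = t.copy()
    out[tens == 2] = 2
    out[tens == 3] = 3
    out[tens == 4] = 2
    return out

# Hypothetical raw codes chosen for illustration
raw = np.array([101, 120, 121, 132, 133, 141])
print(to_spcc_like(raw))
```

Codes whose tens digit is 0 after the subtraction, such as the Ia code in this example, are left unchanged.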

In [27]:
fig = plt.figure()
clss, cms=snclassifier.run_pipeline(wavelet_features,types,output_name=os.path.join(out_class,'wavelets'),
                          classifiers=['random_forest'], nprocesses=nproc, return_classifier=True,
                              classifiers_for_cm_plots='all')

<IPython.core.display.Javascript object>



### Plot confusion matrix

In [28]:
import seaborn as sns
from astropy.table import Table,join,unique

In [30]:
cm = cms[0]
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
annot = np.around(cm, 2)

labels=[]
for tp_row in unique(types, keys='Type'):
    labels.append(tp_row['Type'])

fig, ax = plt.subplots(figsize=(9,7))
sns.heatmap(cm, xticklabels=labels, yticklabels=labels, cmap='Blues', annot=annot, lw=0.5)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_aspect('equal')

<IPython.core.display.Javascript object>
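The row normalisation used in the cell above turns raw counts into the fraction of each true class assigned to each predicted class. A small sketch with made-up counts shows the effect; `normalise_rows` is an illustrative helper, not part of `snclassifier`.

```python
import numpy as np

def normalise_rows(cm):
    """Divide each row by its total so entries are fractions of the true class."""
    cm = np.asarray(cm, dtype=float)
    return cm / cm.sum(axis=1)[:, np.newaxis]

# Made-up counts: rows are true labels, columns are predicted labels
counts = np.array([[80, 15,  5],
                   [10, 60, 30],
                   [ 5, 20, 25]])
fractions = normalise_rows(counts)
print(np.around(fractions, 2))
```

After normalisation each row sums to one, so the diagonal can be read as the per-class recall, which makes classes with very different sample sizes comparable.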