In [None]:
import numpy as np
import json
import matplotlib.pyplot as plt
import os
from astropy.io import fits

In [None]:
# ---- Inputs for this run ---------------------------------------------------
# Supernova spectrum to clean. sn_type selects the WISEREP_<sn_type> folder
# that the spectrum is read from and written back to.
iau_name = 'SNiPTF13bvn'
sn_type = 'Ib'
epoch = '346'
sn_fname = 'iPTF13bvn_discovery346.csv'

# Host-galaxy template and the redshift used to divide its wavelengths
# by (1 + redshift) when it is read from JSON.
# NOTE(review): the galaxy file is named for PTF12gzk, not the SN being
# processed (iPTF13bvn) -- confirm the template and redshift are intended.
gal_fname = 'PTF12gzk_gal.json'
redshift = 0.0137

In [None]:
def line_remover_manual(wl, flux):
    """Interactively interpolate over line complexes in a spectrum.

    The user is first asked whether the spectrum contains galaxy lines
    (``y``/``n``/``exit``, case-insensitive). On ``y`` they enter the rough
    wavelength edges of each line complex, one per prompt, as ``[lo,hi]``
    (spaces, decimals and any number of digits are accepted), and finish
    with ``done``. Inside each complex the flux is replaced by a straight
    line joining the last sample at/below ``lo`` and the first sample above
    ``hi``. Entering ``exit`` at any prompt aborts by raising ``SystemExit``.

    Parameters
    ----------
    wl : array_like
        Wavelength grid; assumed to be sorted in increasing order (edges are
        located via the first sample above each edge value).
    flux : array_like
        Flux on ``wl``. It is not modified; a float copy is edited instead.

    Returns
    -------
    wl : numpy.ndarray
        ``wl`` as an array (unchanged).
    flux : numpy.ndarray
        Float copy of ``flux`` with the requested complexes interpolated over.

    Raises
    ------
    SystemExit
        If the user answers ``exit``.
    """
    wl = np.asarray(wl)
    # Edit a float copy: the original overwrote the caller's array in place,
    # and an integer array would silently truncate the interpolated values.
    flux = np.array(flux, dtype=float)

    first_prompt = "Are there galaxy lines? [y/n/exit]"
    answer = input(first_prompt).strip().lower()
    while answer not in ('y', 'n', 'exit'):
        print('That is an invalid input, try again')
        answer = input(first_prompt).strip().lower()

    if answer == 'exit':
        # `raise SystemExit` is what sys.exit() does; the original called
        # sys.exit() without importing sys (NameError).
        raise SystemExit
    if answer == 'n':
        return wl, flux

    print("Give rough edges to the line complexes (format: [xxxx,xxxx]). "
          "When all complexes are given, write 'done'")
    complexes = []
    while True:
        entry = input("Line complex edges: ").strip()
        if entry.lower() == 'exit':
            raise SystemExit
        if entry.lower() == 'done':
            break
        # Accept e.g. "[6550,6580]", "[ 6550.5 , 6580 ]" or "6550,6580";
        # reversed edges are sorted. Any malformed entry re-prompts instead
        # of crashing the whole interactive session.
        try:
            lo, hi = sorted(float(part) for part in entry.strip('[]() ').split(','))
        except ValueError:
            print('Could not parse %r; expected [xxxx,xxxx]' % entry)
            continue
        complexes.append((lo, hi))

    for lo, hi in complexes:
        above_lo = np.flatnonzero(wl > lo)
        above_hi = np.flatnonzero(wl > hi)
        if above_lo.size == 0:
            print('Complex [%g,%g] lies beyond the spectrum; skipped' % (lo, hi))
            continue
        # Anchor points: last sample at/below lo and first sample above hi,
        # clamped to the ends of the spectrum (the original indexed -1 or
        # raised IndexError here).
        left = max(above_lo[0] - 1, 0)
        right = above_hi[0] if above_hi.size else wl.size - 1
        if right <= left:
            continue

        wl_diff = wl[right] - wl[left]
        flux_diff = flux[right] - flux[left]
        flux[left:right + 1] = (flux_diff / wl_diff) * (wl[left:right + 1] - wl[left]) + flux[left]

    print("Now follows the galaxy corrected spectrum")

    return wl, flux

def save_standardised_output_spectrum(wl, flux, iau_name, epoch, sn_type):
    
    sn_standard_name = iau_name.replace(" ", "")
    obs_epoch = epoch
    
    save_folder = '/home/stba7609/SECRETO/WISEREP_' + sn_type + '/Standardised_spectra/'
    save_name = sn_fname
    final_array = np.zeros((len(wl), 2))
    final_array[:, 0] = wl
    final_array[:, 1] = flux
                    
    np.savetxt(save_folder + save_name, final_array, delimiter = ',', header = 'Wavelength [Å], Flux [10**-15 erg s-1 cm-2]')  

In [None]:
# Read in the host-galaxy spectrum (JSON trace export or FITS binary table).
# NOTE(review): trace key is hard-coded for this particular JSON file.
GAL_TRACE_ID = '1163225900049786880'

gal_fname_lower = gal_fname.lower()
if '.json' in gal_fname_lower:
    # `with` closes the file handle (the original left it open).
    with open(gal_fname) as fh:
        data = json.load(fh)
    trace = data['traces'][GAL_TRACE_ID]
    wl_gal = np.array(trace['wavelength'], dtype=float)
    flux_gal = np.array(trace['flambda'], dtype=float)
    # Divide by (1 + z) to move the galaxy to the rest frame. Out-of-place
    # so an integer wavelength list cannot trigger a casting error.
    wl_gal = wl_gal / (1 + redshift)

elif '.fits' in gal_fname_lower:
    # First extension: column 0 is taken as wavelength, column 1 as flux.
    # Copy the columns while the file is still open.
    with fits.open(gal_fname) as hdul:
        spectrum = hdul[1].data
        wl_gal = np.array([row[0] for row in spectrum], dtype=float)
        flux_gal = np.array([row[1] for row in spectrum], dtype=float)
    # NOTE(review): unlike the JSON branch, no (1 + redshift) correction is
    # applied here -- confirm FITS templates are already in the rest frame.

else:
    # Without this, wl_gal/flux_gal would be undefined and later cells would
    # fail with a confusing NameError.
    raise ValueError('Unsupported galaxy file type: %r (expected .json or .fits)'
                     % gal_fname)

# Show the galaxy spectrum, let the user interpolate over any galaxy line
# complexes, then show the result for comparison (both zoomed to 5000-8000).
fig, ax = plt.subplots()
ax.plot(wl_gal, flux_gal)
ax.set_xlim(5000, 8000)
plt.show()

wl_gal, flux_gal = line_remover_manual(wl_gal, flux_gal)

fig, ax = plt.subplots()
ax.plot(wl_gal, flux_gal)
ax.set_xlim(5000, 8000)
plt.show()

In [None]:
# Fit the galaxy by taking the 7800-8000 region as being all galaxy light,
# then subtract the scaled galaxy template from the SN spectrum.
GAL_SCALE = 0.7                 # NOTE(review): empirical template scale; origin undocumented
GAL_FIT_WINDOW = (7800, 8000)   # wavelength window assumed to be pure galaxy

# Read in the spectrum.
# NOTE(review): the save cell writes back to this same folder and file name,
# so re-running this cell after saving subtracts the galaxy a second time.
spec_dir = '/home/stba7609/SECRETO/WISEREP_' + sn_type + '/Standardised_spectra/'
spectrum = np.loadtxt(spec_dir + sn_fname, delimiter=',')
sn_wl, sn_flux = spectrum[:, 0], spectrum[:, 1]

plt.plot(sn_wl, sn_flux)
plt.plot(wl_gal, flux_gal * GAL_SCALE)
plt.xlim(5000, 8000)
plt.show()

# For each SN sample, the galaxy sample nearest in wavelength. np.argmin
# returns the first minimum, matching np.where(d == d.min())[0][0] but
# without the second pass over the array.
nearest = np.array([np.argmin(np.abs(w - wl_gal)) for w in sn_wl], dtype=int)
closest_flux = flux_gal[nearest]
ratios = sn_flux / closest_flux

# Now take the fit window as being all due to the galaxy:
mask = (sn_wl > GAL_FIT_WINDOW[0]) & (sn_wl < GAL_FIT_WINDOW[1])
if not mask.any():
    # np.mean of an empty selection would silently give NaN and wipe sn_flux.
    raise ValueError('SN spectrum has no samples in the galaxy-fit window %s'
                     % (GAL_FIT_WINDOW,))
tot_ratio = np.mean(ratios[mask])

# Remove the galaxy
sn_flux = sn_flux - closest_flux * tot_ratio * GAL_SCALE

plt.plot(sn_wl, sn_flux / np.max(sn_flux))
plt.xlim(5000, 8000)
plt.ylim(0, 1)
plt.show()

In [None]:
# Galaxy-subtracted SN spectrum, peak-normalised, with the y-axis left free.
fig, ax = plt.subplots()
ax.plot(sn_wl, sn_flux / np.max(sn_flux))
ax.set_xlim(5000, 8000)
plt.show()

In [None]:
# Remove lines from the SN spectrum: show it, let the user interpolate over
# residual line complexes, then show the cleaned spectrum.
fig, ax = plt.subplots()
ax.plot(sn_wl, sn_flux)
ax.set_xlim(5000, 8000)
plt.show()

sn_wl, sn_flux = line_remover_manual(sn_wl, sn_flux)

fig, ax = plt.subplots()
ax.plot(sn_wl, sn_flux)
ax.set_xlim(5000, 8000)
plt.show()

In [None]:
# Write the cleaned spectrum to the standardised-spectra folder.
save_standardised_output_spectrum(
    sn_wl,
    sn_flux,
    iau_name=iau_name,
    epoch=epoch,
    sn_type=sn_type,
)