In [None]:
import mcalf.models
from mcalf.utils import normalise_spectrum
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt

# Load the input data: three CSV arrays (original wavelength grid and the two
# prefilter reference responses) plus the raw spectral data cube from FITS

csv_options = {'delimiter': ','}
wavelengths = np.loadtxt('wavelengths.csv', **csv_options)  # Original wavelengths
prefilter_response_wvscl = np.loadtxt('prefilter_response_wvscl.csv', **csv_options)
prefilter_response_main = np.loadtxt('prefilter_response_main.csv', **csv_options)

with fits.open('spectral_data.fits') as hdul:  # Raw spectral data
    # Primary HDU data, converted to float64 while the file is still open
    datacube = np.asarray(hdul[0].data, dtype=np.float64)

In [None]:
# Build the model from the loaded wavelength grid and prefilter responses.
# The same model is later handed to normalise_spectrum, and its
# constant_wavelengths attribute sizes the normalised spectra.
model = mcalf.models.IBIS8542Model(
    original_wavelengths=wavelengths,
    prefilter_ref_main=prefilter_response_main,
    prefilter_ref_wvscl=prefilter_response_wvscl,
)

In [None]:
# Coordinates of the spectra to label: the saved array is unpacked along its
# first axis into the i (second cube axis) and j (third cube axis) indices
point_coords = np.load('labelled_points.npy')
i_points, j_points = point_coords

In [None]:
# Pull every selected (i, j) point out of the data cube, keeping the whole first
# axis (presumably the wavelength axis), then transpose so each row is one spectrum
raw_spectra = np.transpose(datacube[:, i_points, j_points])

In [None]:
# Normalise each spectrum to be in range [0, 1].
# Output has one row per raw spectrum and one column per entry of
# model.constant_wavelengths.
n_spectra = len(raw_spectra)
n_wavelengths = len(model.constant_wavelengths)
labelled_spectra = np.empty((n_spectra, n_wavelengths))
for index, spectrum in enumerate(raw_spectra):
    labelled_spectra[index] = normalise_spectrum(spectrum, model=model)

In [None]:
# # Script to semi-automate the classification process

# Type a number 0 - 4 to assign a classification to the plotted spectrum
# Type 5 to skip and move on to the next spectrum
# Type 'back' to move to the previous spectrum (has no effect on the first one)
# Type 'exit' to give up (keeping ones already done)
# Any other input re-shows the current spectrum

# The labels are present in the `labels` variable (-1 represents an unclassified spectrum)

labels = np.full(len(labelled_spectra), -1, dtype=int)
i = 0
while i < len(labelled_spectra):

    # Show the spectrum to be classified along with description
    plt.figure(figsize=(15, 10))
    plt.plot(labelled_spectra[i])
    plt.show()
    print("i = {}".format(i))
    print("absorption --- both --- emission / skip")
    print("       0    1    2    3    4         5 ")

    # Ask for user's classification; surrounding whitespace is ignored so that
    # e.g. 'back ' is still recognised as a command
    classification = input('Type [0-4]:').strip()

    try:  # Numeric answers are a classification (0-4) or a skip (5)
        classification_int = int(classification)
    except ValueError:
        classification_int = -1  # Not a number; only 'back'/'exit' mean anything

    if classification == 'back':
        # Clamp at the first spectrum: a negative index would silently wrap
        # around to the end of the array and eventually raise IndexError
        i = max(i - 1, 0)
    elif classification == 'exit':
        break  # Exit the loop, saving labels that were given
    elif 0 <= classification_int <= 4:  # Valid classification
        labels[i] = classification_int  # Assign the classification to the spectrum
        i += 1  # Move on to the next spectrum
    elif classification_int == 5:
        i += 1  # Skip and move on to the next spectrum
    # Otherwise (invalid input) i is unchanged, so the current spectrum is shown again

In [None]:
# Bar chart of how many spectra received each label (-1 = left unclassified).
# `unique` is kept as a notebook variable because the next cell loops over it.
unique, counts = np.unique(labels, return_counts=True)
fig, ax = plt.subplots()
ax.bar(unique, counts)
ax.set_title('Number of spectra in each classification')
ax.set_xlabel('Classification')
ax.set_ylabel('N_spectra')
plt.show()

In [None]:
# One figure per label present in `labels`, overplotting every normalised
# spectrum carrying that label against the model's constant wavelength grid
for label in unique:
    fig, ax = plt.subplots()
    for spectrum in labelled_spectra[labels == label]:
        ax.plot(model.constant_wavelengths, spectrum)
    ax.set_title('Classification {}'.format(label))
    ax.set_yticks([0, 1])
    plt.show()

In [None]:
# Persist the normalised spectra and their labels (-1 = unclassified) as .npy files
outputs = {
    'labelled_data.npy': labelled_spectra,
    'labels.npy': labels,
}
for filename, array in outputs.items():
    np.save(filename, array)