# The Cannon Workshop
#### Import important python packages

In [29]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from astropy.io import fits
from astropy.table import Table

from IPython.display import clear_output

import thecannon as tc

#### Useful functions we'll use throughout the workshop

In [449]:
def convert_labels_into_training_set_labels(list_of_all_labels, labels):
    """
    Build the training set labels table from a larger label catalogue.

    Args:
        list_of_all_labels (pandas.DataFrame): Catalogue holding every available label.
            It is converted with ``Table.from_pandas``.
        labels (list): Names (str) of the columns to keep, in the order they should appear.

    Returns:
        astropy.table.Table: A new table with one column per entry of ``labels``.
    """
    full_table = Table.from_pandas(list_of_all_labels)

    # Add the requested columns one at a time so the output follows the order of ``labels``.
    selected = Table()
    for name in labels:
        selected[name] = full_table[name]
    return selected

def plot_1to1(test_labels,original_labels,model_labels):
    """
    Create a 1-to-1 scatter plot comparing predicted and expected labels.

    One panel is drawn per label, two panels per row. Both axes of a panel share the
    same limits, so the dashed line drawn corner-to-corner is the 1:1 line.

    Args:
        test_labels (numpy.ndarray): 2D array of predicted labels; column ``i`` holds the
            predictions for ``model_labels[i]``.
        original_labels (pandas.DataFrame): Expected labels, indexable by label name.
        model_labels (list): List of label names to be plotted.

    Returns:
        None - displays a plot
    """
    ncols = 2
    nrows = int(np.ceil(len(model_labels)/ncols))

    # squeeze=False always returns a 2D array of axes, whatever the grid shape.
    fig, ax = plt.subplots(nrows,ncols,figsize=(5*ncols,5*nrows),squeeze=False)

    for idx,axes in enumerate(ax.flatten()):
        if idx>=len(model_labels):
            # Odd number of labels: hide the spare panel instead of leaving an empty frame.
            axes.set_visible(False)
            continue

        label = model_labels[idx]
        expected = np.ravel(np.asarray(original_labels[label],dtype=float))
        predicted = np.ravel(np.asarray(test_labels[:,idx],dtype=float))
        axes.scatter(expected,predicted,s=10)

        # Take the limits from BOTH expected and predicted values so that no point is
        # cropped when the expected labels span a wider range than the predictions.
        values = np.concatenate([expected,predicted])
        low, high = np.nanmin(values), np.nanmax(values)
        buffer = (high-low)/3
        if buffer==0:
            # All values identical: fall back to a unit margin to avoid zero-width axes.
            buffer = 1.0
        axes.set_xlim(low-buffer,high+buffer)
        axes.set_ylim(low-buffer,high+buffer)

        # Identical x/y limits make the axes-coordinate diagonal the 1:1 relation.
        axes.plot([0, 1], [0, 1], transform=axes.transAxes,color='k',ls='dashed',lw=1)
        axes.set_xlabel(f"Expected {label}",fontsize=12)
        axes.set_ylabel(f"Cannon Predicted {label}",fontsize=12)
            
def plot_coefficient(term,model,model_labels,linelist=None,xlim=None):
    """
    Plot the coefficient of a specific term in the model.

    The top panel shows column 0 of ``model.theta`` (labelled 'Stellar Flux') with the
    lines from ``linelist`` marked on it; the bottom panel shows the coefficient of ``term``.

    Args:
        term (str): The term for which the coefficient will be plotted.
        model (object): The Cannon model containing the coefficient values
                        (``model.theta``) and wavelengths (``model.dispersion``).
        model_labels (list): List of label names in the model.
        linelist (pandas.DataFrame, optional): DataFrame with 'Wavelength' and 'Species' columns.
                                               Defaults to the global ``line_list``, looked up
                                               when the function is called.
        xlim (tuple, optional): The limits of the x-axis. Defaults to None.

    Returns:
        None - displays a plot
    """
    if linelist is None:
        # Resolve the default at call time rather than definition time, so this function
        # can be defined before the cell that loads ``line_list`` has been run.
        linelist = line_list

    # Column 0 of theta is plotted as the flux panel, so label columns start at 1.
    theta_idx = model_labels.index(term)+1

    fig, ax = plt.subplots(2,1,figsize=(15,7),sharex=True)
    plt.subplots_adjust(hspace=0.05)

    ax[0].plot(model.dispersion,model.theta[:,0])
    ax[0].set_ylabel('Stellar Flux',fontsize=12)
    ax[1].plot(model.dispersion,model.theta[:,theta_idx],label=term)
    ax[1].set_ylabel(term+' coefficient',fontsize=12)
    ax[1].set_xlabel(r"Wavelength [$\AA$]",fontsize=12)
    ax[0].set_ylim(0,1.3)

    # ``is None`` (not ``== None``) so that array-like limits are accepted as well.
    if xlim is None:
        wave_min, wave_max = model.dispersion.min(), model.dispersion.max()
        xlim = ax[0].get_xlim()
    else:
        wave_min, wave_max = min(xlim), max(xlim)
        ax[0].set_xlim(xlim)

    # Only annotate the lines that fall inside the plotted wavelength range.
    in_range = (linelist['Wavelength']>=wave_min) & (linelist['Wavelength']<=wave_max)
    subset = linelist[in_range]

    # Small horizontal offset (1/200 of the x-range) applied to each rotated species label.
    xshift = (max(xlim)-min(xlim))/200
    for wavelength,species in zip(subset['Wavelength'],subset['Species']):
        ax[0].arrow(x=wavelength,y=1.1,dx=0,dy=-0.08)
        # Label with the first whitespace-separated token of the species name.
        ax[0].text(x=wavelength-xshift,y=1.14,s=species.split()[0],rotation=90)
        
def compare_models_theta(model1,model2,term,model_labels):
    """
    Compare the coefficient of a specific term between two models.

    Both coefficient spectra are drawn on the same axes against each model's own
    dispersion; the second is semi-transparent so overlapping features stay visible.

    Args:
        model1 (object): The first model object.
        model2 (object): The second model object.
        term (str): The term for which the coefficient will be compared.
        model_labels (list): List of label names in the models.

    Returns:
        None - displays a plot
    """
    fig, ax = plt.subplots(1,1,figsize=(15,5))

    # The coefficient of label i is read from column i+1 of theta.
    theta_idx = model_labels.index(term)+1

    ax.plot(model1.dispersion,model1.theta[:,theta_idx],label='model 1',alpha=1)
    ax.plot(model2.dispersion,model2.theta[:,theta_idx],label='model 2',alpha=0.7)
    ax.legend()

    # Raw string: "\A" is not a valid Python escape sequence and triggers a warning otherwise.
    ax.set_xlabel(r"Wavelength [$\AA$]",fontsize=12)
    ax.set_ylabel(f"{term} coefficient",fontsize=12)
    
def compare_models_1to1(test_labels1,test_set1,test_labels2,test_set2,label,model_labels):
    """
    Compare the label transfer of two models for a single label, side by side.

    Each panel is a 1-to-1 scatter plot of expected against predicted values with equal
    x/y limits, so the dashed corner-to-corner line is the 1:1 relation.

    Args:
        test_labels1 (numpy.ndarray): 2D array of labels predicted by the first model.
        test_set1 (pandas.DataFrame): Expected labels for the first model, indexable by name.
        test_labels2 (numpy.ndarray): 2D array of labels predicted by the second model.
        test_set2 (pandas.DataFrame): Expected labels for the second model, indexable by name.
        label (str): The label to compare.
        model_labels (list): List of label names; the position of ``label`` in this list
            selects the column of the predicted arrays.

    Returns:
        None - displays a plot
    """
    fig, ax = plt.subplots(1,2,figsize=(12,5))

    idx = model_labels.index(label)

    panels = [
        (ax[0],test_set1,test_labels1,'Model 1'),
        (ax[1],test_set2,test_labels2,'Model 2'),
    ]
    for axes,test_set,test_labels,title in panels:
        expected = np.ravel(np.asarray(test_set[label],dtype=float))
        predicted = np.ravel(np.asarray(test_labels[:,idx],dtype=float))
        axes.scatter(expected,predicted,s=10)

        # Take the limits from BOTH expected and predicted values so that no point is
        # cropped when the expected labels span a wider range than the predictions.
        values = np.concatenate([expected,predicted])
        low, high = np.nanmin(values), np.nanmax(values)
        buffer = (high-low)/3
        if buffer==0:
            # All values identical: fall back to a unit margin to avoid zero-width axes.
            buffer = 1.0
        axes.set_xlim(low-buffer,high+buffer)
        axes.set_ylim(low-buffer,high+buffer)

        axes.plot([0, 1], [0, 1], transform=axes.transAxes,color='k',ls='dashed',lw=1)
        axes.set_xlabel(f"Expected {label}",fontsize=12)
        axes.set_ylabel(f"Cannon Predicted {label}",fontsize=12)
        axes.set_title(title)

#### Load data necessary for The Cannon

In [416]:
# This loads in the labels we'll use to train The Cannon
all_labels = pd.read_csv('./labels.csv')

# This loads in the data we're going to use
flux_all = []
errs_all = []

# One FITS file per object, named after its sobject_id
for i in all_labels['sobject_id'].values:
    with fits.open(f"./data/{str(i)}.fits") as data:
        # HDU 0 is read as the flux, HDU 1 as the dispersion and HDU 2 as the flux errors.
        # `dispersion` is overwritten on every iteration, so the last file's grid is kept
        # (assumes every spectrum shares the same wavelength grid - TODO confirm).
        dispersion = data[1].data
        flux_all.append(data[0].data)
        errs_all.append(data[2].data)

flux_all = np.array(flux_all)
errs_all = np.array(errs_all)

# Inverse variances. A pixel whose error is zero or NaN would give an infinite or NaN
# weight, so such pixels get zero inverse variance (i.e. they are ignored) instead.
with np.errstate(divide='ignore',invalid='ignore'):
    ivar_all = 1/errs_all**2
ivar_all[~np.isfinite(ivar_all)] = 0

# This is a list of atomic lines identified in the spectra we're analysing
line_list = pd.read_csv('./lines.csv')

## Separating the training objects from the test objects
Here we're going to separate the total data set into a training set of size $N$ and a test set.

In [417]:
# Generate a list of 100 random sobject_id entries with no repeats



# Separate the all_labels DataFrame into training and test sets based on the above list



# Separate training/test flux and ivars




## Setting up The Cannon model
What we need:
* training set labels
* training set flux
* training set inverse variances
* a vectoriser to declare the polynomial function used to predict each flux pixel (and the subsequent order of that polynomial)
* the dispersion (optional)

In [457]:
# Let's set up The Cannon model
# Declare the order of your polynomial


# Which labels are we going to use in our model?


# Set up a model file name based on those labels


# Set up the training set labels Table


# Declare The Cannon model



## Training The Cannon model

In [455]:
# Write some code that will train The Cannon model we set up



## Testing The Cannon model

In [456]:
# Write some code that will perform the test step of The Cannon



## Let's look at the results!
### Label transfer
Let's see how well the test step went!

In [454]:
# Use the plot_1to1 function here



### Coefficients of the model
Let's have a closer look at the model coefficients...

In [453]:
# Use the plot_coefficient function here



## Let's investigate some more
(If there's time)

### What happens when we have models with different training set sizes?

In [420]:
# Copy and paste the cell where we separated the training and test sets
# (but be sure to change the variable names!)



In [421]:
# Copy and paste the cell where we set up The Cannon model
# (be sure to change the variable names!)



In [422]:
# Copy and paste training The Cannon cell
# (be sure to change the variable names!)



In [None]:
# Copy and paste the testing cell
# (be sure to change the variable names!)



### Let's compare the coefficients

### Let's compare the label transfer