# Classification Task

In this notebook we train, validate and test a neural network to distinguish between two classes:
* one galaxy behind one foreground galaxy
* two galaxies at the same redshift and with small angular separation (1 to 4 arcseconds)

These datasets were created in notebook `1_create_simulations`.

This notebook also incorporates the critical diagnostics described in [Training a Machine Learning Model](https://docs.google.com/document/d/1U17RNPqDA5uP9-M5V3ENweKig5eRZWHTpja-hh1z1MI/edit).

#### Index<a name="index"></a>
1. [Import Packages](#imports)
2. [Load data](#load)
    1. [Process data](#process)
    2. [Original data](#loadOri)
    3. [Processed data](#loadAug)
3. [Prepare data for network](#prepare)
    1. [Train/Test set split](#split)
    2. [Get dataloader](#dataloader)
    3. [Data Histogram](#dataHist)
4. [Train network](#train)
    1. [Save performance](#savePerformance)
    2. [Load network](#loadNetwork)
    3. [Predict class](#predict)
5. [Results](#results)
    1. [Confusion matrix](#cm)
    2. [Learning curve](#learningCurve)


## 1. Import Packages<a name="imports"></a>

In [None]:
import sys
import time
import glob

In [None]:
import numpy as np
import pandas as pd
import torch
import process
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
sys.path.append('../Network')
import data_utils
import networks
import training
import save
import predict

In [None]:
from sklearn.metrics import confusion_matrix

### Aesthetic settings

In [None]:
%matplotlib inline
%config Completer.use_jedi = False  # enable autocomplete

size_default = 1.5
size_larger = 1.9
sns.set(font_scale=size_default, style="ticks")

## Notebook Options

In [None]:
is_processed = 1  # Set to 0 if the data has not yet been processed
is_trained = 0  # Set to 0 if the network has not yet been trained

# Select the network type
# net_type = 'CNN'
# net_type = 'RNN'
net_type = 'ZIPPER'

## 2. Load data<a name="load"></a>

In [None]:
dataset_name = 'high_cad_1_2_data'
directory = dataset_name

[Go back to top.](#index)

### 2.1. Process data<a name="process"></a>

Following the [DeepZipper paper](https://arxiv.org/pdf/2112.01541.pdf), we condense the image information into a single-image input by averaging all images in the time series on a pixel-by-pixel basis within each band.

This step also augments the training set by rotating and mirroring the images.
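As an illustration of the two operations described above, here is a minimal sketch on synthetic data. It is not the code in `process.run`, whose details are not shown in this notebook. The array layout `(n_epochs, n_bands, height, width)`, the 4 bands, the image size and the helper name `augment` are all assumptions made for the example.

```python
import numpy as np

# Synthetic stand-in for one event: (n_epochs, n_bands, height, width).
# This layout and these sizes are assumptions for the sketch.
rng = np.random.default_rng(0)
time_series = rng.normal(size=(10, 4, 45, 45))

# Average over the time axis, pixel by pixel, within each band.
# The band axis is kept, so the result has shape (n_bands, height, width).
mean_image = time_series.mean(axis=0)


def augment(image):
    """Return the 8 rotated/mirrored copies of a (bands, H, W) image."""
    copies = []
    for k in range(4):
        # Rotate by k * 90 degrees in the image plane (axes 1 and 2).
        rotated = np.rot90(image, k=k, axes=(1, 2))
        copies.append(rotated)
        # Mirror the rotated image left-right.
        copies.append(np.flip(rotated, axis=2))
    return np.stack(copies)


augmented = augment(mean_image)
print(mean_image.shape, augmented.shape)
```

Four rotations, each with and without a mirror flip, give the eight distinct orientations of a square image. The real pipeline may use a different subset.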

In [None]:
if not is_processed:
    configurations = [x.split("/")[1].split("_images.")[0] for x in glob.glob(f"{directory}/*_images.npy")]
    
    ini_time = time.time()
    for configuration in sorted(configurations):
        process.run(directory, configuration)
    print(time.time() - ini_time)

### 2.2. Original data<a name="loadOri"></a>

Configuration 1:

In [None]:
images_config1_ori = np.load(dataset_name+'/CONFIGURATION_1_images.npy')
metadata_config1_ori = pd.read_csv(dataset_name+'/CONFIGURATION_1_metadata.csv', sep=',')

In [None]:
print(np.shape(images_config1_ori))
print(np.shape(metadata_config1_ori))
print('# events: ', len(np.unique(metadata_config1_ori['OBJID-g'])))

Configuration 2:

In [None]:
images_config2_ori = np.load(dataset_name+'/CONFIGURATION_2_images.npy')
metadata_config2_ori = pd.read_csv(dataset_name+'/CONFIGURATION_2_metadata.csv', sep=',')

In [None]:
print(np.shape(images_config2_ori))
print(np.shape(metadata_config2_ori))
print('# events: ', len(np.unique(metadata_config2_ori['OBJID-g'])))

### 2.3. Processed data<a name="loadAug"></a>

Configuration 1:

In [None]:
images_config1 = np.load(dataset_name+'/CONFIGURATION_1_proc_ims_15.npy')
metadata_config1 = np.load(dataset_name+'/CONFIGURATION_1_proc_mds_15.npy', 
                           allow_pickle=True).item()

In [None]:
print(np.shape(images_config1))
print(np.shape(metadata_config1[0]))
print('# events: ', len(metadata_config1.keys()))

Configuration 2:

In [None]:
images_config2 = np.load(dataset_name+'/CONFIGURATION_2_proc_ims_15.npy')
metadata_config2 = np.load(dataset_name+'/CONFIGURATION_2_proc_mds_15.npy', 
                           allow_pickle=True).item()

In [None]:
print(np.shape(images_config2))
print(np.shape(metadata_config2[0]))
print('# events: ', len(metadata_config2.keys()))

In [None]:
# Show rotation of the images
# from deeplenstronomy.visualize import view_image, view_image_rgb
# view_image_rgb(images_config2[1], Q=10, stretch=1)
# view_image_rgb(images_config2[138], Q=10, stretch=1)
# view_image_rgb(images_config2[275], Q=10, stretch=1)

## 3. Prepare data for network<a name="prepare"></a>

### 3.1. Train/Test set split<a name="split"></a>

In [None]:
train_dataset, test_dataset = data_utils.make_train_test_datasets(
    directory=dataset_name, class_names=['CONFIGURATION_1_proc', 'CONFIGURATION_2_proc'], 
    suffix='15', label_map={})
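Because the processed data contain rotated and mirrored copies of each event, it is worth checking that copies of the same event do not land on both sides of the split. How `data_utils.make_train_test_datasets` handles this is not shown here. The sketch below only illustrates one way to split by event; `split_by_event`, `event_ids` and the 20% test fraction are hypothetical.

```python
import numpy as np


def split_by_event(event_ids, test_fraction=0.2, seed=0):
    """Return boolean (train, test) masks that keep each event in one split.

    `event_ids` is a hypothetical 1D array giving the parent event of
    every (augmented) image.
    """
    rng = np.random.default_rng(seed)
    unique_ids = np.unique(event_ids)
    shuffled = rng.permutation(unique_ids)
    n_test = max(1, int(round(test_fraction * len(unique_ids))))
    test_ids = shuffled[:n_test]
    test_mask = np.isin(event_ids, test_ids)
    return ~test_mask, test_mask


# Synthetic example: 10 events, each with 8 augmented copies.
event_ids = np.repeat(np.arange(10), 8)
train_mask, test_mask = split_by_event(event_ids)

# No event should appear in both splits.
shared = set(event_ids[train_mask]) & set(event_ids[test_mask])
print(train_mask.sum(), test_mask.sum(), shared)
```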

In [None]:
np.shape(train_dataset.images)  # The band axis remains: images are averaged over time within each band, not across bands

### 3.2. Get dataloader<a name="dataloader"></a>

In [None]:
train_dataloader = data_utils.make_dataloader(train_dataset)

### 3.3. Data Histogram<a name="dataHist"></a>

In [None]:
dict_label_to_real = {0: 'gal-gal', 1: '2 gals'}

In [None]:
bins = np.arange(-.5, 1.5, 0.25)
sns.histplot(data=train_dataset.labels-.25, kde=False, stat='density', 
             bins=bins,
             color='C0', label='Train set', linewidth=3, fill=False)
sns.histplot(data=test_dataset.labels, kde=False, stat='density', 
             bins=bins, 
             color='C1', label='Test set', linewidth=3, fill=False)
plt.xticks(ticks=[0, 1], labels=[dict_label_to_real[0], dict_label_to_real[1]])
plt.xlim(-.5, 1.5)
plt.legend()
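The histogram compares the class balance of the two sets visually. A numeric version of the same check can be sketched as follows, assuming the labels are a 1D integer array of 0s and 1s. The arrays below are synthetic and the helper `class_fractions` is hypothetical.

```python
import numpy as np

# Synthetic labels standing in for train_dataset.labels / test_dataset.labels.
labels_train = np.array([0, 1, 1, 0, 1, 0, 0, 1])
labels_test = np.array([1, 0, 0, 1])


def class_fractions(labels, n_classes=2):
    """Fraction of samples in each class."""
    counts = np.bincount(labels, minlength=n_classes)
    return counts / counts.sum()


fractions_train = class_fractions(labels_train)
fractions_test = class_fractions(labels_test)
print(fractions_train, fractions_test)
```

If the two sets of fractions differ noticeably, accuracy on the test set is harder to compare with accuracy on the training set.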

[Go back to top.](#index)

## 4. Train network<a name="train"></a>

In [None]:
network_types = {'ZIPPER': networks.ZipperNN(4, 4, 4),
                 'CNN': networks.CNN_single(4, 2),
                 'RNN': networks.RNN_single(4, 3)}

network = network_types[net_type]

In [None]:
if not is_trained:
    ini_time = time.time()
    if net_type == 'ZIPPER':
        network = training.train_zipper(network, train_dataloader,
                                        train_dataset, test_dataset,
                                        monitor=True,
                                        outfile_prefix=f"{dataset_name}/{dataset_name}_{net_type}")
    else:
        if net_type == 'CNN':
            datatype = 'image'
        elif net_type == 'RNN':
            datatype = 'lightcurve'
        else:
            raise ValueError(f'`net_type` {net_type} not recognised')
        network = training.train_single(network, train_dataloader,
                                        train_dataset, test_dataset,
                                        datatype, monitor=True,
                                        outfile_prefix=f"{dataset_name}/{dataset_name}_{net_type}")
    print(time.time() - ini_time)

[Go back to top.](#index)

### 4.1. Save performance<a name="savePerformance"></a>

In [None]:
if not is_trained:
    print("Saving results")
    # Save the performance
    save.save_performance(dataset_name, dataset_name, net_type, network, test_dataset)
    save.save_performance(dataset_name, dataset_name, net_type, network, train_dataset, train=True)

[Go back to top.](#index)

### 4.2. Load network<a name="loadNetwork"></a>

If the network was previously trained, simply load it.

In [None]:
monitor_table = pd.read_csv(directory+f'/{directory}_{net_type}_monitoring.csv', sep=',')

In [None]:
if is_trained:
    network = network_types[net_type]
    network.load_state_dict(torch.load(directory+f'/{directory}_{net_type}_network.pt'))

In [None]:
network

In [None]:
network.state_dict()

[Go back to top.](#index)

### 4.3. Predict class<a name="predict"></a>

In [None]:
predictions, labels = predict.predict(network, test_dataset) 
accuracy = np.sum(predictions == labels) / len(labels)

[Go back to top.](#index)

## 5. Results<a name="results"></a>

### 5.1. Confusion matrix<a name="cm"></a>

In [None]:
def plot_confusion_matrix(y_true, y_pred, title=None, normalise=None,
                          dict_label_to_real=None, figsize=None, **kwargs):
    """Plot a confusion matrix.

    Uses the true and predicted class labels to compute a confusion matrix.
    This can be non-normalised, normalised by true class/row (the diagonals
    show the accuracy of each class), and by predicted class/column (the
    diagonals show the precision).
    
    This code is from snmachine: https://github.com/LSSTDESC/snmachine

    Parameters
    ----------
    y_true : 1D array-like
        Ground truth (correct) labels of shape (n_samples,).
    y_pred : 1D array-like
        Predicted class labels of shape (n_samples,).
    title : {None, str}, optional
        Title of the plot.
    normalise : {None, str}, optional
       If `None`, use the absolute numbers in each matrix entry. If 'accuracy',
       normalise per true class. If 'precision', normalise per predicted class.
    dict_label_to_real : dict, optional
        Dictionary containing the class labels as key and its real name as
        values. E.g. for PLAsTiCC
        `dict_label_to_real = {42: 'SNII', 62: 'SNIbc', 90: 'SNIa'}`.
        If `None`, the default class labels are used.
    figsize : {None, tuple}
        If `None`, use the default `figsize` of the plot. Otherwise, create a
        figure with the given size.

    Returns
    -------
    cm : np.array
       The confusion matrix, as computed by `sklearn.metrics.confusion_matrix`.
    """
    # Make and normalise the confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    if normalise == 'accuracy':
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        # Default colour limits; user-supplied `vmin`/`vmax` are kept
        kwargs.setdefault('vmin', 0)
        kwargs.setdefault('vmax', 1)
        print("Confusion matrix normalised by true class.")
    elif normalise == 'precision':
        cm = cm.astype('float') / cm.sum(axis=0)[np.newaxis, :]
        kwargs.setdefault('vmin', 0)
        kwargs.setdefault('vmax', 1)
        print("Confusion matrix normalised by predicted class.")
    else:
        print('Confusion matrix without normalisation')

    # Classes in the dataset
    target_names = np.unique(y_true)
    target_names_ori = np.copy(target_names)  # the labels might be strings
    if dict_label_to_real is not None:
        target_names = np.vectorize(dict_label_to_real.get)(target_names)
        if target_names[0] is None:  # fix the names being strings
            target_names = list(map(int, target_names_ori))

    # Plot the confusion matrix
    if figsize is not None:
        _, ax = plt.subplots(figsize=figsize)  # good values: (9, 7)
    else:
        _, ax = plt.subplots()
    sns.heatmap(cm, xticklabels=target_names,
                yticklabels=target_names, cmap='Blues',
                annot=True, fmt='.2f', lw=0.5,
                cbar_kws={'label': 'Fraction of events',
                          'shrink': .82}, **kwargs)
    ax.set_xlabel('Predicted class')
    ax.set_ylabel('True class')
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    ax.set_aspect('equal')
    if title is not None:
        plt.title(title)

    return cm

In [None]:
plot_confusion_matrix(y_true=labels, y_pred=predictions, 
                      normalise='accuracy', 
                      title=net_type+f'\nAccuracy: {accuracy:.3f}', 
                      dict_label_to_real={0:'gal-gal', 1:'2 gals'})
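To make the two normalisation options concrete, here is a small worked example on synthetic labels. It builds the matrix with plain NumPy rather than `sklearn.metrics.confusion_matrix`, purely so the sketch has no extra dependency. The label values are invented for illustration.

```python
import numpy as np

# Synthetic labels: 4 objects of class 0 and 6 of class 1.
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
y_pred = np.array([0, 0, 0, 1, 1, 1, 1, 1, 0, 0])

# Rows are true classes, columns are predicted classes.
n_classes = 2
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)

# Normalise per true class (rows sum to 1): the 'accuracy' option.
by_true = cm / cm.sum(axis=1)[:, np.newaxis]

# Normalise per predicted class (columns sum to 1): the 'precision' option.
by_pred = cm / cm.sum(axis=0)[np.newaxis, :]

# Overall accuracy, as computed in section 4.3.
accuracy = np.sum(y_pred == y_true) / len(y_true)

print(cm)
print(by_true)
print(by_pred)
print(accuracy)
```

Here the counts are `[[3, 1], [2, 4]]`: the row-normalised diagonal is 0.75 and about 0.67, the column-normalised diagonal is 0.6 and 0.8, and the overall accuracy is 0.7.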

[Go back to top.](#index)

### 5.2. Learning curve<a name="learningCurve"></a>

In [None]:
plt.plot(monitor_table['Loss'], linewidth=3, label='Loss')
plt.plot(monitor_table['Train Acc'], linewidth=3, label='Train Accuracy')
plt.plot(monitor_table['Test Acc'], linewidth=3, label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Score')
plt.legend()
plt.title('Learning curve')
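Two numbers that are often read off a learning curve are the step with the best test accuracy and the gap between train and test accuracy, which tends to widen when the network overfits. The sketch below uses the column names of the monitoring table but synthetic values, and assumes each row corresponds to one epoch.

```python
# Synthetic monitoring values; the real ones come from `monitor_table`.
monitor = {
    'Loss': [0.69, 0.52, 0.40, 0.31, 0.25, 0.20],
    'Train Acc': [0.55, 0.70, 0.80, 0.86, 0.91, 0.95],
    'Test Acc': [0.54, 0.68, 0.77, 0.81, 0.80, 0.79],
}

test_acc = monitor['Test Acc']
train_acc = monitor['Train Acc']

# Index of the row with the highest test accuracy.
best_epoch = max(range(len(test_acc)), key=lambda i: test_acc[i])

# Train-minus-test accuracy at every row.
gaps = [tr - te for tr, te in zip(train_acc, test_acc)]

print(best_epoch, test_acc[best_epoch], round(gaps[-1], 2))
```

In this synthetic case the test accuracy peaks at row 3 while the train accuracy keeps rising, so the later rows improve the fit to the training set only.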

[Go back to top.](#index)