# Train the StarNet Model

This notebook walks through the steps needed to train a StarNet model with PyTorch.
- Required Python packages: `numpy`, `h5py`, `torch` (PyTorch), `torchsummary`
- Required data files: `training_data.h5`, `mean_and_std.npy`

In [1]:
import os
import numpy as np
import h5py
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.utils.data import TensorDataset, DataLoader

from torchsummary import summary

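# Directory containing the data files (an empty string means the current directory)
datadir = ""
training_set = os.path.join(datadir, 'training_data.h5')
normalization_data = os.path.join(datadir, 'mean_and_std.npy')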

**Normalization**

Write a function to normalize the output labels. Each label is normalized to have approximately zero mean and unit variance.

NOTE: This puts the output labels on a similar scale, which the model needs in order to train properly. The process is reversed in the test stage to return the output labels to their proper units.

In [2]:
mean_and_std = np.load(normalization_data)
mean_labels = mean_and_std[0]
std_labels = mean_and_std[1]

def normalize(labels):
    # Normalize labels
    return (labels-mean_labels) / std_labels
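
The reversal mentioned in the note above can be sketched as follows. This is a minimal illustration, not code from the notebook: `denormalize`, `example_mean`, and `example_std` are hypothetical names, and the statistics are made-up values standing in for the contents of `mean_and_std.npy`.

```python
import numpy as np

# Made-up statistics for (TEFF, LOGG, FE_H); the real ones come from mean_and_std.npy
example_mean = np.array([5000.0, 3.0, -0.5])
example_std = np.array([800.0, 1.2, 0.6])

def normalize(labels, mean, std):
    # Shift each label column to zero mean and scale it to unit variance
    return (labels - mean) / std

def denormalize(normalized, mean, std):
    # Inverse of normalize: return labels to their physical units
    return normalized * std + mean

# Two example stars, one row each: TEFF [K], LOGG [dex], FE_H [dex]
labels = np.array([[5800.0, 4.4, 0.0],
                   [4200.0, 1.8, -1.1]])

normalized = normalize(labels, example_mean, example_std)
recovered = denormalize(normalized, example_mean, example_std)
print(np.allclose(recovered, labels))
```

Because both functions broadcast over the last axis, the same call works for a single label vector or a whole batch of predictions.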

**Obtain training data**

Here we load the spectra and output labels for the training and validation sets, then normalize the labels.

In [3]:
# Define the number of output labels (one column per label in the normalization file)
num_labels = mean_and_std.shape[1]

# Define the number of training spectra
num_train = 41000

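# Load spectra and labels. The first num_train spectra form the training set
# and the remainder the validation set. np.hstack only yields an (N, 3) label
# array if each label dataset is stored with shape (N, 1), which is assumed here.
with h5py.File(training_set, 'r') as data_F: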
    x_train = data_F['spectrum'][0:num_train]
    x_val = data_F['spectrum'][num_train:]
    y_train = np.hstack((data_F['TEFF'][0:num_train], 
                         data_F['LOGG'][0:num_train], 
                         data_F['FE_H'][0:num_train]))
    y_val = np.hstack((data_F['TEFF'][num_train:], 
                      data_F['LOGG'][num_train:], 
                      data_F['FE_H'][num_train:]))

# Normalize labels
y_train = normalize(y_train)
y_val = normalize(y_val)

# Define the number of output labels and the number of flux values per spectrum
num_labels = y_train.shape[1]
num_fluxes = x_train.shape[1]

print('Each spectrum contains ' + str(num_fluxes) + ' wavelength bins')
print('Training set includes ' + str(x_train.shape[0]) + 
      ' spectra and the validation set includes ' + str(x_val.shape[0])+' spectra')

Each spectrum contains 7214 wavelength bins
Training set includes 41000 spectra and the validation set includes 3784 spectra


**Build the StarNet model architecture**

The StarNet architecture is built with:
- input layer
- 2 convolutional layers
- 1 maxpooling layer followed by flattening for the fully connected layer
- 2 fully connected layers
- output layer

In [4]:
# Number of filters used in the convolutional layers
num_filters = [4,16]

# Length of the filters in the convolutional layers
filter_length = 8

# Length of the maxpooling window 
pool_length = 4

# Number of nodes in each of the hidden fully connected layers
num_hidden = [256,128]

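def compute_out_size(in_size, mod):
    """
    Compute output size of Module `mod` given an input with size `in_size`.
    """
    # Pass a dummy batch of one through the module. Only the output shape
    # matters, so zeros are used and no gradients are tracked.
    with torch.no_grad():
        f = mod(torch.zeros(1, *in_size))
    return f.size()[1:]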

class StarNet(nn.Module):
    def __init__(self, num_fluxes, num_filters, filter_length, 
                 pool_length, num_hidden, num_labels):
        super().__init__()
        
        # Convolutional and pooling layers
        self.conv1 = nn.Conv1d(1, num_filters[0], filter_length)
        self.conv2 = nn.Conv1d(num_filters[0], num_filters[1], filter_length)
        self.pool = nn.MaxPool1d(pool_length, pool_length)
        
        # Determine shape after pooling
        pool_output_shape = compute_out_size((1,num_fluxes), 
                                             nn.Sequential(self.conv1, 
                                                           self.conv2, 
                                                           self.pool))
        
        # Fully connected layers
        self.fc1 = nn.Linear(pool_output_shape[0]*pool_output_shape[1], num_hidden[0])
        self.fc2 = nn.Linear(num_hidden[0], num_hidden[1])
        self.output = nn.Linear(num_hidden[1], num_labels)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.output(x)
        return x


model = StarNet(num_fluxes, num_filters, filter_length, 
          pool_length, num_hidden, num_labels)

summary(model, (1, num_fluxes))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1              [-1, 4, 7207]              36
            Conv1d-2             [-1, 16, 7200]             528
         MaxPool1d-3             [-1, 16, 1800]               0
            Linear-4                  [-1, 256]       7,373,056
            Linear-5                  [-1, 128]          32,896
            Linear-6                    [-1, 3]             387
Total params: 7,406,903
Trainable params: 7,406,903
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.03
Forward/backward pass size (MB): 1.32
Params size (MB): 28.26
Estimated Total Size (MB): 29.60
----------------------------------------------------------------
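
The output shapes and parameter counts in the summary above can be checked by hand. The sketch below assumes the `nn.Conv1d` defaults used in this notebook (stride 1, no padding) and a pooling window equal to its stride; the helper names `conv1d_out_len` and `maxpool1d_out_len` are hypothetical and are not part of PyTorch.

```python
def conv1d_out_len(in_len, kernel):
    # Stride 1, no padding, no dilation
    return in_len - kernel + 1

def maxpool1d_out_len(in_len, pool):
    # Window length equal to stride, no padding
    return (in_len - pool) // pool + 1

num_fluxes = 7214
num_filters = [4, 16]
filter_length = 8
pool_length = 4
num_hidden = [256, 128]
num_labels = 3

len1 = conv1d_out_len(num_fluxes, filter_length)   # after conv1
len2 = conv1d_out_len(len1, filter_length)         # after conv2
len3 = maxpool1d_out_len(len2, pool_length)        # after pooling
flat = num_filters[1] * len3                       # inputs to the first Linear layer

# Weights plus biases for each trainable layer
params = [
    1 * num_filters[0] * filter_length + num_filters[0],
    num_filters[0] * num_filters[1] * filter_length + num_filters[1],
    flat * num_hidden[0] + num_hidden[0],
    num_hidden[0] * num_hidden[1] + num_hidden[1],
    num_hidden[1] * num_labels + num_labels,
]
total = sum(params)

print(len1, len2, len3, flat)
print(params, total)
```

Almost all of the parameters sit in the first fully connected layer, because the flattened pooling output has 16 × 1800 = 28,800 values.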


**More model techniques**
* The `Adam` optimizer is the gradient-based algorithm used to minimize the loss function, which here is the mean squared error between the predicted and the normalized true labels

In [5]:
# number of spectra fed into model at once during training
batch_size = 32

# number of epochs
num_epochs = 15

# initial learning rate for optimization algorithm
learning_rate = 0.0007

# Construct optimizer
optimizer = torch.optim.Adam(model.parameters(), learning_rate,
                             weight_decay=0)
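
For intuition about what the optimizer does, a single Adam update for one scalar parameter can be sketched in plain Python. This is a simplified illustration rather than PyTorch's implementation: `adam_step` and the toy loss are hypothetical, `lr` matches the learning rate above, and `beta1`, `beta2`, and `eps` are set to the `torch.optim.Adam` defaults.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.0007, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; t is the 1-based step count."""
    m = beta1 * m + (1 - beta1) * grad        # running average of gradients
    v = beta2 * v + (1 - beta2) * grad ** 2   # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)              # bias correction for the zero start
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Toy loss f(w) = (w - 3)^2 with gradient 2 * (w - 3)
w, m, v = 0.0, 0.0, 0.0
for t in range(1, 4):
    grad = 2 * (w - 3.0)
    w, m, v = adam_step(w, grad, m, v, t)

print(w)
```

While the gradient keeps the same sign and a similar size, each step moves the parameter by roughly `lr`, regardless of how large the gradient is. That is why the learning rate largely sets the early step size.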

In [6]:
train_dataset = TensorDataset(torch.Tensor(x_train),torch.Tensor(y_train))
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = TensorDataset(torch.Tensor(x_val),torch.Tensor(y_val))
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

**Train model**

In [7]:
print_iters = 200
train_losses = []
val_losses = []
# loop over the dataset multiple times
for epoch in range(num_epochs):

    running_loss = 0.0
    for i, data in enumerate(train_dataloader, 0):
        # Collect batch data
        x_batch, y_batch = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        y_pred = model(x_batch.unsqueeze(1))
        loss = torch.nn.MSELoss()(y_pred, y_batch)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if (i+1) % print_iters == 0:
            print('[Epoch %i, %0.0f%%] Train Loss: %0.4f' % (epoch+1, 
                                                       (i+1)/len(train_dataloader)*100, 
                                                       running_loss/(i+1)), end="\r")
    train_loss = running_loss/len(train_dataloader)
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(val_dataloader, 0):
            # Collect batch data
            x_batch, y_batch = data
            y_pred = model(x_batch.unsqueeze(1))
            loss = torch.nn.MSELoss()(y_pred, y_batch)
            running_loss += loss.item()
    val_loss = running_loss/len(val_dataloader)
    print('[Epoch %i] Train Loss: %0.4f, Val Loss: %0.4f' % (epoch+1,  
                                                             train_loss, 
                                                             val_loss))
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
print('Finished Training')

[Epoch 1] Train Loss: 0.0635, Val Loss: 0.0178
[Epoch 2] Train Loss: 0.0078, Val Loss: 0.0103
[Epoch 3] Train Loss: 0.0053, Val Loss: 0.0060
[Epoch 4] Train Loss: 0.0047, Val Loss: 0.0067
[Epoch 5] Train Loss: 0.0042, Val Loss: 0.0036
[Epoch 6] Train Loss: 0.0044, Val Loss: 0.0036
[Epoch 7] Train Loss: 0.0035, Val Loss: 0.0038
[Epoch 8] Train Loss: 0.0035, Val Loss: 0.0044
[Epoch 9] Train Loss: 0.0032, Val Loss: 0.0030
[Epoch 10] Train Loss: 0.0031, Val Loss: 0.0056
[Epoch 11] Train Loss: 0.0030, Val Loss: 0.0027
[Epoch 12] Train Loss: 0.0028, Val Loss: 0.0027
[Epoch 13] Train Loss: 0.0026, Val Loss: 0.0026
[Epoch 14] Train Loss: 0.0029, Val Loss: 0.0026
[Epoch 15] Train Loss: 0.0038, Val Loss: 0.0057
Finished Training
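
The validation loss in the log above fluctuates from epoch to epoch, and the final epoch is not the best one. The sketch below copies the printed validation losses into a list and finds the epoch with the lowest value. In the notebook the same lookup would apply to `val_losses` directly; `best_epoch` is a hypothetical name.

```python
import numpy as np

# Validation losses copied from the training log above (epochs 1 to 15)
val_losses = [0.0178, 0.0103, 0.0060, 0.0067, 0.0036, 0.0036, 0.0038, 0.0044,
              0.0030, 0.0056, 0.0027, 0.0027, 0.0026, 0.0026, 0.0057]

# np.argmin returns the first index of the minimum; epochs are 1-based
best_epoch = int(np.argmin(val_losses)) + 1
best_loss = val_losses[best_epoch - 1]
final_loss = val_losses[-1]

print('Best epoch: %i (val loss %0.4f), final epoch val loss: %0.4f'
      % (best_epoch, best_loss, final_loss))
```

One common option, which this notebook does not use, is to save the model whenever the validation loss improves, so that the stored weights correspond to the best epoch rather than the last.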


**Save model**

In [9]:
model_filename = os.path.join(datadir, 'starnet_cnn.pth.tar')
torch.save({'optimizer' : optimizer.state_dict(),
            'model' : model.state_dict()}, 
           model_filename)
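
A checkpoint saved this way can be restored by rebuilding the model and optimizer and then loading their state dictionaries. The sketch below is a self-contained illustration under stated assumptions: a small `nn.Linear` layer stands in for the StarNet model and the file is written to a temporary directory. In practice the model would be a `StarNet` instance constructed with the same hyperparameters used for training.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in model (assumption); in the notebook this would be the trained StarNet
model = nn.Linear(4, 3)
optimizer = torch.optim.Adam(model.parameters(), 0.0007)

with tempfile.TemporaryDirectory() as tmpdir:
    model_filename = os.path.join(tmpdir, 'starnet_cnn.pth.tar')
    torch.save({'optimizer': optimizer.state_dict(),
                'model': model.state_dict()},
               model_filename)

    # Rebuild the objects first, then restore their saved state
    restored_model = nn.Linear(4, 3)
    restored_optimizer = torch.optim.Adam(restored_model.parameters(), 0.0007)
    checkpoint = torch.load(model_filename, map_location='cpu')
    restored_model.load_state_dict(checkpoint['model'])
    restored_optimizer.load_state_dict(checkpoint['optimizer'])

# Switch to evaluation mode before making predictions
restored_model.eval()
print(torch.equal(restored_model.weight, model.weight))
```

Saving the optimizer state alongside the weights makes it possible to resume training later, while the `'model'` entry alone is enough for making predictions.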