# Playing with snake(s)

Taking our new activation function for a test-drive on some time-series data.  We'll be using a temperature time-series from an Irish weather station.  `maxtp` means "maximum temperature".

In [43]:
from snake.activations import Snake
import torch
import plotly.express as px
import pandas as pd
from datetime import datetime as dt

df = pd.read_csv("data/dly575.csv")

# Turn date column into an actual date, not a string
dtm = lambda x: dt.strptime(str(x), "%d-%b-%Y")
df.date = df.date.apply(lambda x: dtm(x))
df.set_index(['date'])
df.maxtp = pd.to_numeric(df['maxtp'],errors='coerce')

df["smooth_maxtp"] = df.maxtp.rolling(30, min_periods=1).mean()

df.head(n=5)

Unnamed: 0,date,ind,maxtp,ind.1,mintp,igmin,gmin,ind.2,rain,cbl,...,ind.5,hg,soil,pe,evap,smd_wd,smd_md,smd_pd,glorad,smooth_maxtp
0,2003-08-11,0,21.9,0,9.6,,,0,0.0,1017.5,...,0,12,,,,,,,,21.9
1,2003-08-12,0,23.5,0,8.6,,,0,0.0,1017.7,...,0,14,,,,,,,,22.7
2,2003-08-13,0,21.8,4,,,,8,,1017.9,...,0,14,,,,,,,,22.4
3,2003-08-14,0,21.0,0,7.7,,,0,0.0,1016.8,...,0,15,,,,,,,,22.05
4,2003-08-15,0,19.8,0,5.8,,,0,0.0,1015.2,...,0,10,,,,,,,,21.6


In [44]:
fig = px.line(df, x='date', y='maxtp', range_x=['2008-01-01','2010-12-31'])
fig.show()

## Time to model
We have our data and our activation function, time to build an MLP and try them out.

### Data and training
First load the data, and normalize it.  Then, we'll need a training loop.

In [125]:
from torch import nn
from torch import optim

def temp_loader(df_x, batch_size=32, target_col='maxtp', shuffle=True, infinite=True):
    
    if shuffle==True:
        df_x = df_x.sample(frac=1)
    
    # Turns our time series data frame into a series of batches using yield
    df_x = df_x[['date', target_col]].dropna(axis = 0)
    
    while(1):
        X = []
        Y = []
        for index, row in df_x.iterrows():
            
            # Embarassingly crude normalization
            date_as_float = (row['date'].timestamp() / 1000000.0 - 1000.0) / 500.0 - 0.5
            X.append(date_as_float)
            Y.append(row[target_col])
            if len(X) == batch_size:
                X = torch.tensor(X)
                Y = torch.tensor(Y).unsqueeze(dim=-1)
                
                yield X,Y
                X = []
                Y = []

        # Under-sized batch, to finish the epoch
        if len(X) != 0:
            X = torch.tensor(X)
            Y = torch.tensor(Y).unsqueeze(dim=-1)
            yield X,Y
            
        if infinite == False:
            return (None)
            
    return (None)
            
                

def test_model(mdl, test_loader):
    X,Y = next(test_loader)
    
    with torch.no_grad():
        mdl.eval()
        y_hat = mdl(X)
        model.train()
        
    mean_abs_error = (Y - y_hat).abs().mean()
    return (mean_abs_error)                


# helper function to train a model
def train_model(model, 
                train_df, 
                test_df, 
                samples_per_epoch, 
                epochs=2, 
                learning_rate=0.0001, 
                batch_size=128, 
                weight_decay=0.001,
                momentum=0.99):
    '''
    Function trains the model and prints out the training log.
    INPUT:
        model - initialized PyTorch model ready for training.
        trainloader - PyTorch dataloader for training data.
    '''
    #setup training
    test_loader = temp_loader(test_df, 1000, shuffle=False)

    #define loss function
    #criterion = nn.MSELoss()
    criterion = nn.SmoothL1Loss()
    
    #initialize optimizer
    #optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.99, weight_decay=0.001)

    #run training and print out the loss to make sure that we are actually fitting to the training set
    print(f'Training for {epochs} epochs...\n')
    best_mae = 1000.0
    for e in range(epochs):
        running_loss = 0
        smooth_loss = 0
        batch = 0
        
        
        # Note that we re-shuffle the data for each epoch
        train_loader = temp_loader(train_df, batch_size=batch_size, shuffle=True)
        
        for x, y in train_loader:
            y_hat = model(x)
            loss = criterion(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch = batch + 1
            running_loss += loss.item()
            smooth_loss = smooth_loss * 0.9 + loss.item() * 0.1
            
            if batch * x.shape[0] > samples_per_epoch:
                
                with torch.no_grad():
                    model.eval()
                    
                    x_test, y_test = next(test_loader)
                    test_y_hat = model(x_test)
                    test_loss = criterion(test_y_hat, y_test)
                    
                    model.train()
                
                
                
                train_mean_abs_error = test_model(model, train_loader)
                test_mean_abs_error = test_model(model, test_loader)
                print(f"\rEpoch {e}, batch {batch}, training loss: {loss.item():0.3f}, test loss = {test_loss.item():0.3f}, test mae = {test_mean_abs_error:0.2f}, train mae = {train_mean_abs_error:0.2f}", end="")
                if test_mean_abs_error < best_mae:
                    torch.save(model.state_dict, "best_temperature.mdl")
                    best_mae = test_mean_abs_error

                break
                
    return (test_loss)

### Plots
Let's make it easy to compare model predictions against our data.

In [115]:
import plotly.graph_objects as go

def evaluate_model(mdl, test_data, batch_size=1000):    
    # Makes predictions for an entire dataframe
    test_loader = temp_loader(test_data, batch_size, shuffle=False, infinite=False)
    
    with torch.no_grad():
        mdl.eval()

    X = []
    Y = []
    y_hats = []
    more_data = True
    while more_data:
        try:
            x,y = next(test_loader)
            y_hat = mdl(x)

            X.append(x)
            Y.append(y)
            y_hats.append(y_hat)            
        except StopIteration:
            more_data = False            
            break


    X = torch.cat(X).detach()
    Y = torch.cat(Y).detach()
    y_hats = torch.cat(y_hats).detach()
    
    return (X, Y, y_hats)
    


    
def plot_model(mdl, df_x, test_offset = None, ground_truth = "maxtp"):    
    x,y,y_hats = evaluate_model(model, df_x, batch_size=512)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=df.date, y=df[ground_truth], name='Ground truth',line=dict(color='green', width=2)))
    
    if test_offset is not None:    
        fig.add_trace(go.Scatter(x=df[test_offset:].date, y=y_hats[test_offset:].squeeze(), name='Predicted (test)',line=dict(color='firebrick', width=2)))
    
    if test_offset is None:
        test_offset = -1
        
    fig.add_trace(go.Scatter(x=df[0:test_offset].date, y=y_hats[0:test_offset].squeeze(), name='Predicted (train)',line=dict(color='blue', width=2)))    

    fig.update_layout(title='Temperatures at Moore Park, Ireland',
                       xaxis_title='Time',
                       yaxis_title='Temperature (degrees C)')

    fig.show()

### Model definition and training
We'l create a simple multi-layer perceptron - a stack of fully-connected layers - using the snake activation function.  Note that the final layer has no activation function, because this is a regression problem.

In [108]:
# create class for basic fully-connected deep neural network
from torch import nn
class RegressSnake(nn.Module):
    '''
    Simple fully-connected network to try out periodic activations with snake
    '''
    def __init__(self, width=256, alpha=1):
        super().__init__()

        # initialize layers
        self.fc1 = nn.Linear(1, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, 1)

        # initialize Soft Exponential activation
        self.a1 = Snake(width, alpha)
        self.a2 = Snake(width, alpha)
        self.a3 = Snake(width, alpha)

    def forward(self, x):
        # make sure the input tensor is flattened
        x = x.view(x.shape[0], -1)

        # apply snake
        x = self.a1(self.fc1(x))
        x = self.a2(self.fc2(x))
        x = self.a3(self.fc3(x))
        x = self.fc4(x) # No activation here because we are doing regression

        return x
    
model = RegressSnake(width=512, alpha=20)
df_train = df[0:5000]
df_test = df[5000:]

train_model(model, df_train, df_test, epochs=1000, samples_per_epoch=5000, learning_rate=0.001, batch_size=512)
plot_model(model, df, test_offset = df_train.shape[0])

Training for 1000 epochs...

Epoch 999, batch 11, training loss: 1.898, test loss = 4.734, test mae = 6.44, train mae = 2.66Training complete


## Better hyperparameters

The paper we are following makes it clear that some hyperparameter tuning was required.  Let's try that!  We'll use Optuna, not their choice but possibly a better option.

In [None]:
import optuna
from functools import partial 

def objective(trial):
    params = {'learning_rate': trial.suggest_loguniform('learning_rate', 0.00001, 0.01),
              'batch_size': trial.suggest_int('batch_size', 4, 10, step=1),
              'width': trial.suggest_int('width', 4, 9, step=1),
#              'hidden_dims':  trial.suggest_int('hidden_dims', 2,10),
              'momentum':  trial.suggest_loguniform('momentum',0.9,0.999)
              }

    params['batch_size'] = 2 ** params['batch_size']
    params['width'] = 2 ** params['width']

    
    model = RegressSnake(width=params['width'], alpha=20)
    del params['width']
    
    params["epochs"] = 500
    params["samples_per_epoch"] = 5000
    return train_model(model, df_train, df_test, **params)

def optimise_params(trials):


    #train_model(model, df_train, df_test, epochs=1000, samples_per_epoch=5000, learning_rate=0.001, batch_size=512)
    p_objective = partial(objective)

    study = optuna.create_study(direction='minimize')
    study.optimize(p_objective, n_trials=10)

    # log metrics
    print('Best Validation Loss: {}'.format(study.best_value))
    print('Best Params: {}'.format(study.best_params))

    #neptune.log_metric('validation auc', study.best_value)
    #neptune.set_property('best_parameters', study.best_params)
    #pickle_and_send_artifact(study, 'study.pkl')

    #neptune.stop()
    return 

                                              
optimise_params(trials=10)

Training for 500 epochs...

Epoch 240, batch 11, training loss: 3.113, test loss = 3.707, test mae = 2.89, train mae = 3.9640

## Snake Example
What does snake look like at different values of alpha?

In [104]:
def create_snake_fig(alpha):
    x = torch.arange(-1, 1, 0.01)
    s = Snake(1,alpha)

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=s(x).detach(), name='Snake(x)',line=dict(color='green', width=2)))
    fig.update_layout(title=f'Snake activation function, alpha = {alpha:.2f}',
                       xaxis_title='x',
                       yaxis_title=f'Snake(x) alpha={alpha:0.2f}')
    fig.update_yaxes(range=[-1, 1])
    
    return (fig)


f = create_snake_fig(alpha=10)
f.show()

## Wriggling snakes!
To get a better sense of what the range of alpha values does, let's make an animation.

In [113]:
import plotly
from subprocess import call
import json

alphas = torch.arange(0,25,0.1)
frame = 0
for a in alphas:
    f = create_snake_fig(a)
    file_name = f"snake_{frame:02d}.png"
    
    call(['orca', 'graph', json.dumps(f, cls=plotly.utils.PlotlyJSONEncoder), '-o', file_name])
    
    #f.write_image(file_name)
    frame += 1


Now, we can pull the frames together to make a video.

In [114]:
!ffmpeg -y -r 10 -i snake_%02d.png -vcodec  h264 -pix_fmt yuv420p  snake_alphas.mp4

ffmpeg version 3.4.6-0ubuntu0.18.04.1 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7 (Ubuntu 7.3.0-16ubuntu3)
  configuration: --prefix=/usr --extra-version=0ubuntu0.18.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --enable-gpl --disable-stripping --enable-avresample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librubberband --enable-librsvg --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --ena

In [101]:
model.a3.alpha

Parameter containing:
tensor([11.5899, 11.5801, 11.5819, 11.5873, 11.6058, 11.5841, 11.5832, 11.5897,
        11.5920, 11.5897, 11.5834, 11.5995, 11.5911, 11.5834, 11.5920, 11.6005,
        11.5871, 11.6194, 11.5947, 11.6064, 11.5915, 11.5919, 11.6004, 11.5615,
        11.5826, 11.5733, 11.5917, 11.5942, 11.5914, 11.5971, 11.5910, 11.5679,
        11.5828, 11.5869, 11.5682, 11.5842, 11.5916, 11.5933, 11.5821, 11.6103,
        11.5912, 11.5794, 11.5907, 11.5951, 11.5914, 11.5912, 11.5795, 11.5970,
        11.5783, 11.5886, 11.5861, 11.5810, 11.5892, 11.6563, 11.5873, 11.5892,
        11.5814, 11.5892, 11.5929, 11.5949, 11.5832, 11.5914, 11.5896, 11.5884,
        11.5888, 11.7966, 11.5912, 11.5880, 11.5907, 11.5917, 11.6023, 11.5919,
        11.5985, 11.5682, 11.5907, 11.5923, 11.5874, 11.5899, 11.5791, 11.5896,
        11.5845, 11.5883, 11.5835, 11.6096, 11.5848, 11.5850, 11.5889, 11.5856,
        11.5910, 11.5904, 11.5917, 11.5890, 11.5913, 11.5840, 11.5834, 11.5901,
        11.5764, 1

In [103]:
model.a2.alpha

Parameter containing:
tensor([11.5915, 11.6005, 11.5918, 11.5923, 11.5945, 11.5922, 11.5920, 11.0550,
        11.5914, 11.5915, 11.5920, 11.5915, 11.5921, 11.5918, 11.5920, 11.5920,
        11.5979, 11.5913, 11.5913, 11.5935, 11.5915, 11.5922, 11.5920, 11.5920,
        11.6013, 11.5938, 11.5912, 11.5918, 11.5920, 11.5917, 11.5924, 11.5956,
        11.5920, 11.5888, 11.6690, 11.5923, 11.5919, 11.5918, 11.5828, 11.4144,
        11.5910, 11.5901, 11.5933, 11.5882, 11.5924, 11.5916, 11.5913, 11.5266,
        11.5925, 11.5921, 11.5910, 11.5909, 11.5136, 11.5940, 11.5941, 11.5920,
        11.5921, 11.3803, 11.5917, 11.5793, 11.5914, 11.5911, 11.5918, 11.5922,
        11.5915, 11.5873, 11.5913, 11.5915, 11.5921, 11.5922, 11.5903, 11.5926,
        11.5915, 11.5915, 11.5913, 11.5922, 11.5917, 11.5915, 11.5918, 11.5929,
        11.3575, 11.4948, 11.5911, 11.4774, 11.6014, 11.5359, 11.5915, 11.5915,
        11.3488, 11.5919, 11.5942, 11.7541, 11.4122, 11.5924, 11.5918, 11.5920,
        11.5925, 1

In [102]:
model.a1.alpha

Parameter containing:
tensor([11.6353, 11.6013, 11.6180, 11.5909, 11.5980, 11.5931, 11.5924, 11.5910,
        11.5977, 11.5915, 11.5918, 11.5918, 11.6057, 11.6155, 11.6043, 11.5964,
        11.6326, 11.7047, 11.6121, 11.6140, 11.6112, 11.6081, 11.6604, 11.5974,
        11.6101, 11.6146, 11.5936, 11.6150, 11.5916, 11.6289, 11.5930, 11.6063,
        11.5905, 11.6180, 11.5925, 11.5913, 11.5973, 11.5946, 11.5918, 11.5944,
        11.6067, 11.5915, 11.5896, 11.5926, 11.6200, 11.5998, 11.6161, 11.6003,
        11.5988, 11.5961, 11.6173, 11.6143, 11.6024, 11.5928, 11.5914, 11.5909,
        11.6133, 11.6139, 11.6119, 11.5924, 11.6231, 11.6106, 11.5926, 11.6101,
        11.6097, 11.5904, 11.5935, 11.6073, 11.5930, 11.5990, 11.6767, 11.5988,
        11.5986, 11.5917, 11.6311, 11.5947, 11.5989, 11.5930, 11.5938, 11.6138,
        11.5934, 11.6042, 11.5999, 11.5989, 11.5918, 11.5909, 11.5936, 11.6248,
        11.5933, 11.6071, 11.6059, 11.6722, 11.5921, 11.6807, 11.6074, 11.5887,
        11.6104, 1