In [None]:
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import time
import numpy as np


# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

# Our proposed activation functions

In [None]:
class BAH(nn.Module):
    """Odd, bounded activation ``f(x) = sign(x) * (1 - exp(-|x|))``.

    The output lies in the open interval (-1, 1) and ``f(0) == 0``.
    The module has no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Apply the activation element-wise.

        Args:
            input: tensor of any shape.

        Returns:
            Tensor with the same shape as ``input``.
        """
        # -expm1(-|x|) is mathematically 1 - exp(-|x|), but it does not lose
        # precision through cancellation when |x| is close to zero.
        return torch.sign(input) * (-torch.expm1(-torch.abs(input)))

class DReLU(nn.Module):
    """Activation ``f(x) = x + sin(x)`` for ``x > 0`` and ``0`` otherwise."""

    def __init__(self):
        super().__init__()
        # Registered as a non-persistent buffer so that it follows the module
        # through .to()/.cuda()/.double() without adding a key to state_dict.
        self.register_buffer("zero", torch.zeros(1), persistent=False)

    def forward(self, input):
        """Apply the activation element-wise.

        Args:
            input: tensor of any shape.

        Returns:
            Tensor with the same shape as ``input``; entries where
            ``input <= 0`` are replaced by zero.
        """
        # self.zero has shape (1,) and broadcasts against input.
        return torch.where(input > 0, torch.sin(input) + input, self.zero)

class DReLU_Parameter(nn.Module):
    """Activation ``f(x) = x + alpha * sin(x)`` for ``x > 0`` and ``0`` otherwise.

    ``alpha`` is a learnable scalar, drawn from U[0, 1) at construction.
    """

    def __init__(self):
        super(DReLU_Parameter, self).__init__()
        # Created device-neutral: moving the owning model with .to(device)
        # moves the parameter (and the buffer below) along with it.
        self.alpha = nn.Parameter(torch.rand(1))
        # Non-persistent so state_dict keeps exactly one entry, 'alpha'.
        self.register_buffer("zero", torch.zeros(1), persistent=False)

    def forward(self, input):
        """Apply the activation element-wise.

        Args:
            input: tensor of any shape.

        Returns:
            Tensor with the same shape as ``input``; entries where
            ``input <= 0`` are replaced by zero.
        """
        return torch.where(input > 0, self.alpha * torch.sin(input) + input, self.zero)

## Proposed and benchmark activation functions list

In [None]:
# Activation classes (not instances) to compare; each network calls the
# class with no arguments to build its own activation modules.
activation_list = [
    DReLU_Parameter,
    nn.ReLU,
    DReLU,
    nn.Tanh,
    BAH,
    nn.Sigmoid,
    nn.SiLU,
]
# Number of epochs passed to the training runs in the experiments.
EPOCHS = 2

## Activation Function runtime

To see the impact of different activation functions on runtime, we define a simple network and train it for 1000 full-batch iterations with each activation function, repeating the measurement 10 times and reporting the mean.

In [None]:
class SimpleNet(nn.Module):
    """Two-layer perceptron (100 -> 1000 -> 1) used for the runtime benchmark.

    Args:
        activation: zero-argument callable (typically an ``nn.Module`` class)
            that builds the non-linearity placed between the two layers.
    """

    def __init__(self, activation):
        super().__init__()
        self.fc1 = nn.Linear(in_features=100, out_features=1000)
        self.fc2 = nn.Linear(in_features=1000, out_features=1)
        self.activation = activation()

    def forward(self, x):
        """Map a ``(..., 100)`` tensor to a ``(..., 1)`` tensor."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)

## Dummy dataset used only for the runtime benchmark
# X: (1000, 100) inputs drawn from a normal distribution with mean 0 and std 1
X = torch.normal(0, 1, (1000, 100)).to(device)
# t: (1000, 1) targets drawn uniformly from [0, 1); unsqueeze adds the trailing dimension
t = torch.rand(1000).unsqueeze(1).to(device)

## Train function
def train(network, criterion, optimizer, epochs=1000, inputs=None, targets=None):
    """Run ``epochs`` full-batch optimisation steps on ``network``.

    Args:
        network: module to optimise; it is called as ``network(inputs)``.
        criterion: callable ``criterion(prediction, targets)`` returning a
            scalar loss tensor.
        optimizer: optimiser that already holds ``network``'s parameters.
        epochs: number of gradient steps to take.
        inputs: input batch; defaults to the notebook-level dummy dataset ``X``.
        targets: target batch; defaults to the notebook-level dummy targets ``t``.

    Returns:
        None. ``network`` is updated in place through ``optimizer``.
    """
    # Fall back to the module-level dummy dataset when no data is supplied.
    if inputs is None:
        inputs = X
    if targets is None:
        targets = t
    for _ in range(epochs):
        # Use the network that was passed in, not whatever global happens to
        # be called `model` at call time.
        prediction = network(inputs)
        loss = criterion(prediction, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# dict to keep track of run time: activation class name -> list of wall-clock
# durations in seconds, one entry per trial
timings = defaultdict(list)


def _wait_for_device():
    """Block until queued CUDA work has finished (no-op when running on CPU)."""
    # CUDA kernels are launched asynchronously; without this barrier the
    # timer can stop before the GPU has actually finished the work.
    if str(device).startswith('cuda'):
        torch.cuda.synchronize()


for activation in activation_list:
    # perform 10 trials with each activation function
    for repeat in range(10):
        model = SimpleNet(activation).to(device)
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        _wait_for_device()
        # perf_counter is a monotonic, high-resolution clock meant for intervals
        start = time.perf_counter()
        train(model, criterion, optimizer, 1000)
        _wait_for_device()
        end = time.perf_counter()
        timings[activation.__name__].append(end - start)

## print out the mean of 10 trials for each activation function
for key in timings:
    print(f"Network with {key} took {np.array(timings[key]).mean()}s")

Network with DReLU_Parameter took 1.0163837909698485s
Network with ReLU took 0.7716005802154541s
Network with DReLU took 0.8866305112838745s
Network with Tanh took 0.7778322219848632s
Network with BAH took 1.0490894794464112s
Network with Sigmoid took 0.7858211755752563s
Network with SiLU took 0.8022176265716553s


## Network Definition

In [None]:
# LeNet for MNIST
class LeNet(nn.Module):
    """LeNet-style CNN: two conv + average-pool stages, then three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features after flattening, which is what
    single-channel 28x28 inputs (e.g. MNIST) produce.

    Args:
        activation: zero-argument callable (typically an ``nn.Module`` class).
            It is called four times so that every non-linearity is a separate
            module instance.
    """

    def __init__(self, activation):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.avg_pool2d = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        # Classifier head
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        # One activation instance per use site, built in order 1..4.
        (self.activation1,
         self.activation2,
         self.activation3,
         self.activation4) = (activation() for _ in range(4))

    def forward(self, x):
        """Return the 10 class scores for each sample in ``x``."""
        features = self.avg_pool2d(self.activation1(self.conv1(x)))
        features = self.avg_pool2d(self.activation2(self.conv2(features)))
        hidden = self.activation3(self.fc1(self.flatten(features)))
        hidden = self.activation4(self.fc2(hidden))
        return self.fc3(hidden)


# Layer recipes per VGG variant: an integer is the number of output channels
# of a convolution stage, 'M' marks a 2x2 max-pooling layer.
cfg = dict(
    VGG11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    VGG16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    VGG19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)

# For CIFAR10
# https://github.com/icpm/pytorch-cifar10
# We will only use VGG11 for our experiments
class VGG(nn.Module):
    """VGG-style network with batch norm and a pluggable activation.

    Args:
        vgg_name: key into the module-level ``cfg`` dict selecting the layer recipe.
        activation: zero-argument callable (typically an ``nn.Module`` class);
            a fresh instance is created after every conv + batch-norm pair.
    """

    def __init__(self, vgg_name, activation):
        super().__init__()
        self.activation = activation
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Return the 10 class scores for each sample in ``x``."""
        feature_maps = self.features(x)
        # Collapse everything except the batch dimension before the classifier.
        flat = feature_maps.view(feature_maps.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layer recipe into an ``nn.Sequential`` feature extractor.

        Integers become Conv2d(3x3, padding 1) -> BatchNorm2d -> activation with
        that many output channels; ``'M'`` becomes a 2x2 max-pool. The first
        convolution takes 3 input channels.
        """
        modules = []
        channels_in = 3
        for spec in cfg:
            if spec != 'M':
                modules.append(nn.Conv2d(channels_in, spec, kernel_size=3, padding=1))
                modules.append(nn.BatchNorm2d(spec))
                modules.append(self.activation())
                channels_in = spec
                continue
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # Kernel size 1 / stride 1 average pooling closes the feature stack.
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*modules)


def VGG11(activation=nn.ReLU):
    """Build the VGG11 variant.

    ``VGG`` requires an activation class as its second argument; it defaults
    to ``nn.ReLU`` here so the zero-argument call ``VGG11()`` works.
    """
    return VGG('VGG11', activation)


def VGG13(activation=nn.ReLU):
    """Build the VGG13 variant.

    ``VGG`` requires an activation class as its second argument; it defaults
    to ``nn.ReLU`` here so the zero-argument call ``VGG13()`` works.
    """
    return VGG('VGG13', activation)


def VGG16(activation=nn.ReLU):
    """Build the VGG16 variant.

    ``VGG`` requires an activation class as its second argument; it defaults
    to ``nn.ReLU`` here so the zero-argument call ``VGG16()`` works.
    """
    return VGG('VGG16', activation)


def VGG19(activation=nn.ReLU):
    """Build the VGG19 variant.

    ``VGG`` requires an activation class as its second argument; it defaults
    to ``nn.ReLU`` here so the zero-argument call ``VGG19()`` works.
    """
    return VGG('VGG19', activation)

## Create train, validation and test data loaders for MNIST and CIFAR10

In [None]:
## Higher batch size for faster training? Current train bs is 64 and test is 1000
TRAIN_BS = 64

# --- MNIST -----------------------------------------------------------------
# Normalisation statistics for the single MNIST channel.
MEAN_TRANSFORM = 0.1307
STD_DEV_TRANSFORM = 0.3081

mnist_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((MEAN_TRANSFORM,), (STD_DEV_TRANSFORM,)),
])

# The 60000-sample training set is split 54000 / 6000 into train / validation.
mnist_train_dataset, mnist_validation_dataset = torch.utils.data.random_split(
    datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transforms),
    [54000, 6000],
)
mnist_test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=mnist_transforms
)

mnist_train_loader = DataLoader(mnist_train_dataset, batch_size=TRAIN_BS, shuffle=True)
mnist_validation_loader = DataLoader(mnist_validation_dataset, batch_size=1000, shuffle=True)
mnist_test_loader = DataLoader(mnist_test_dataset, batch_size=1000, shuffle=False)

# --- CIFAR10 ---------------------------------------------------------------
# From here on the two names hold the per-channel (3-value) CIFAR10 statistics.
MEAN_TRANSFORM = [0.4914, 0.4822, 0.4465]
STD_DEV_TRANSFORM = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(MEAN_TRANSFORM, STD_DEV_TRANSFORM)
cifar_transform = transforms.Compose([transforms.ToTensor(), normalize])

# The 50000-sample training set is split 45000 / 5000 into train / validation.
cifar_train_dataset, cifar_validation_dataset = torch.utils.data.random_split(
    datasets.CIFAR10(root='./data', train=True, download=True, transform=cifar_transform),
    [45000, 5000],
)
cifar_test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=cifar_transform
)

cifar_train_loader = DataLoader(cifar_train_dataset, batch_size=TRAIN_BS, shuffle=True)
cifar_validation_loader = DataLoader(cifar_validation_dataset, batch_size=1000, shuffle=True)
cifar_test_loader = DataLoader(cifar_test_dataset, batch_size=1000, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## Model Training and metric collection

In [None]:
# return model loss and accuracy
def model_loss_accuracy(model, data_loader, criterion):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        data_loader: iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.
        criterion: callable ``criterion(output, targets)`` returning a scalar
            tensor. Each batch loss is weighted by the batch size, which
            presumably assumes a mean-reduced loss.

    Returns:
        ``(loss, accuracy)``: the accumulated weighted loss and the number of
        correct arg-max predictions, each divided by ``len(data_loader.dataset)``.
    """
    model.eval()
    n_correct = 0
    loss_sum = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            loss_sum += batch_loss.item() * len(targets)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()

    n_samples = len(data_loader.dataset)
    return loss_sum / n_samples, n_correct / n_samples

# tracks training and validation loss and accuracy. Saves the model with best
# validation accuracy
def train(net, epochs, optimizer, criterion, train_dl, validation_dl, net_save_path):
    """Train ``net`` and record per-epoch training / validation metrics.

    After every epoch the network is evaluated on ``validation_dl``; whenever
    the validation accuracy is at least as good as the best seen so far, the
    state dict is written to ``net_save_path`` (so ties overwrite the file).

    Args:
        net: network to train; moved to the notebook-level ``device``.
        epochs: number of passes over ``train_dl``.
        optimizer: optimiser holding ``net``'s parameters.
        criterion: callable ``criterion(output, targets)`` returning a scalar tensor.
        train_dl: training loader of ``(inputs, targets)`` batches.
        validation_dl: validation loader of ``(inputs, targets)`` batches.
        net_save_path: destination file for the best checkpoint.

    Returns:
        ``defaultdict(list)`` with keys ``train_loss``, ``train_accuracy``,
        ``validation_loss`` and ``validation_accuracy``, one entry per epoch.
    """
    history = defaultdict(list)
    best_so_far = 0

    net = net.to(device)
    for epoch in tqdm(range(epochs)):
        net.train()
        hits = 0
        loss_sum = 0
        for inputs, targets in train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = net(inputs)
            batch_loss = criterion(logits, targets)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch figure is a per-sample average.
            loss_sum += batch_loss.item() * len(targets)
            hits += (logits.argmax(dim=1) == targets).sum().item()

        n_train = len(train_dl.dataset)
        training_loss = loss_sum / n_train
        training_accuracy = hits / n_train
        validation_loss, validation_accuracy = model_loss_accuracy(net, validation_dl, criterion)

        # Checkpoint on a new best (or equal) validation accuracy.
        if validation_accuracy >= best_so_far:
            best_so_far = validation_accuracy
            torch.save(net.state_dict(), net_save_path)

        print(f"Epoch: {epoch}, training loss: {training_loss:.5f}, training accuracy: {training_accuracy:.5f}, validation loss: {validation_loss:.5f}, validation accuracy: {validation_accuracy:.5f}")

        for key, value in (
            ("train_loss", training_loss),
            ("train_accuracy", training_accuracy),
            ("validation_loss", validation_loss),
            ("validation_accuracy", validation_accuracy),
        ):
            history[key].append(value)

    return history

## Experiments

In [None]:
# dictionary to store all metrics; the experiment cells index it as
# results[<network name>][<activation class name>]
results = defaultdict(dict)

In [None]:
# SGD hyper-parameters used by the LeNet runs in this cell
learning_rate = 0.01
momentum = 0.5

# directory that receives one "<activation name>.pt" checkpoint per run
root_lenet = "./saved_models/lenet/"

from pathlib import Path
# create the directory (and any missing parents); no error if it already exists
Path(root_lenet).mkdir(parents=True, exist_ok=True)

def init_weights(m):
    """Xavier-uniform initialise the weight of a Linear or Conv2d module.

    Intended for ``nn.Module.apply``. Only modules whose type is exactly
    ``nn.Linear`` or ``nn.Conv2d`` are touched; biases and every other module
    are left unchanged.
    """
    if type(m) in (nn.Linear, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)

# Train one LeNet per activation on MNIST, then score the best checkpoint on the test set.
for activation in activation_list:
    act_func_name = activation.__name__
    model_save_path = f"{root_lenet}{act_func_name}.pt"
    print("****************************************************************************", flush=True)
    print(f"For activation function {act_func_name}", flush=True)

    network = LeNet(activation)
    network.apply(init_weights)
    optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)

    results['lenet'][act_func_name] = train(
        network,
        EPOCHS,
        optimizer,
        F.cross_entropy,
        mnist_train_loader,
        mnist_validation_loader,
        model_save_path,
    )

    # Load model for test loss and accuracy
    loaded_model = LeNet(activation)
    loaded_model.load_state_dict(torch.load(model_save_path))
    loaded_model = loaded_model.to(device)
    test_loss, test_accuracy = model_loss_accuracy(loaded_model, mnist_test_loader, F.cross_entropy)
    # The two keys are spelled differently on purpose: they match the existing result layout.
    results['lenet'][act_func_name].update({
        'test_loss': test_loss,
        'test accuracy': test_accuracy,
    })

****************************************************************************
For activation function DReLU_Parameter


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.32594, training accuracy: 0.90750, validation loss: 0.14350, validation accuracy: 0.95933
Epoch: 1, training loss: 0.11544, training accuracy: 0.96548, validation loss: 0.11447, validation accuracy: 0.96450

****************************************************************************
For activation function ReLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.45516, training accuracy: 0.86167, validation loss: 0.21228, validation accuracy: 0.93967
Epoch: 1, training loss: 0.14703, training accuracy: 0.95596, validation loss: 0.13422, validation accuracy: 0.95783

****************************************************************************
For activation function DReLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.28310, training accuracy: 0.91880, validation loss: 0.13327, validation accuracy: 0.96283
Epoch: 1, training loss: 0.10856, training accuracy: 0.96748, validation loss: 0.10066, validation accuracy: 0.97167

****************************************************************************
For activation function Tanh


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.53165, training accuracy: 0.85885, validation loss: 0.28119, validation accuracy: 0.92017
Epoch: 1, training loss: 0.22456, training accuracy: 0.93480, validation loss: 0.18739, validation accuracy: 0.94817

****************************************************************************
For activation function BAH


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.69273, training accuracy: 0.82189, validation loss: 0.34055, validation accuracy: 0.90767
Epoch: 1, training loss: 0.28391, training accuracy: 0.91965, validation loss: 0.24664, validation accuracy: 0.92967

****************************************************************************
For activation function Sigmoid


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 2.30503, training accuracy: 0.10846, validation loss: 2.30192, validation accuracy: 0.10250
Epoch: 1, training loss: 2.30182, training accuracy: 0.11226, validation loss: 2.30312, validation accuracy: 0.10250

****************************************************************************
For activation function SiLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 0.46382, training accuracy: 0.85544, validation loss: 0.17413, validation accuracy: 0.94800
Epoch: 1, training loss: 0.15114, training accuracy: 0.95509, validation loss: 0.12836, validation accuracy: 0.96317



In [None]:
# directory that receives one "<activation name>.pt" checkpoint per run
root_vgg = "./saved_models/vgg/"

from pathlib import Path
# create the directory (and any missing parents); no error if it already exists
Path(root_vgg).mkdir(parents=True, exist_ok=True)

# Built once and used for training, validation and testing. With default
# arguments nn.CrossEntropyLoss computes the same value as F.cross_entropy.
criterion = nn.CrossEntropyLoss().to(device)

# Train one VGG11 per activation on CIFAR10, then score the best checkpoint on the test set.
for activation in activation_list:
    act_func_name = activation.__name__
    model_save_path = root_vgg + act_func_name + ".pt"

    print("****************************************************************************", flush=True)
    print(f"For activation function {act_func_name}", flush=True)
    network = VGG('VGG11', activation)
    optimizer = optim.Adam(network.parameters(), lr=1e-3)
    results['vgg'][act_func_name] = train(network, EPOCHS, optimizer, criterion, cifar_train_loader, cifar_validation_loader, model_save_path)

    # Load model for test loss and accuracy
    loaded_model = VGG('VGG11', activation)
    loaded_model.load_state_dict(torch.load(model_save_path))
    loaded_model = loaded_model.to(device)
    test_loss, test_accuracy = model_loss_accuracy(loaded_model, cifar_test_loader, criterion)
    results['vgg'][act_func_name]['test_loss'] = test_loss
    results['vgg'][act_func_name]['test accuracy'] = test_accuracy

****************************************************************************
For activation function DReLU_Parameter


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.38205, training accuracy: 0.49482, validation loss: 1.19696, validation accuracy: 0.57540
Epoch: 1, training loss: 0.86949, training accuracy: 0.69378, validation loss: 1.05125, validation accuracy: 0.65640

****************************************************************************
For activation function ReLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.31582, training accuracy: 0.51213, validation loss: 1.08177, validation accuracy: 0.61960
Epoch: 1, training loss: 0.83479, training accuracy: 0.70669, validation loss: 0.85263, validation accuracy: 0.70280

****************************************************************************
For activation function DReLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.43549, training accuracy: 0.47836, validation loss: 1.27204, validation accuracy: 0.57500
Epoch: 1, training loss: 0.95652, training accuracy: 0.66273, validation loss: 1.00454, validation accuracy: 0.66820

****************************************************************************
For activation function Tanh


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.50948, training accuracy: 0.44444, validation loss: 1.32389, validation accuracy: 0.52480
Epoch: 1, training loss: 1.11296, training accuracy: 0.59878, validation loss: 1.36601, validation accuracy: 0.54980

****************************************************************************
For activation function BAH


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.48911, training accuracy: 0.44896, validation loss: 1.31457, validation accuracy: 0.53960
Epoch: 1, training loss: 1.09850, training accuracy: 0.60456, validation loss: 1.07862, validation accuracy: 0.62360

****************************************************************************
For activation function Sigmoid


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 2.00763, training accuracy: 0.20462, validation loss: 2.39243, validation accuracy: 0.18080
Epoch: 1, training loss: 1.49734, training accuracy: 0.42022, validation loss: 1.55739, validation accuracy: 0.39480

****************************************************************************
For activation function SiLU


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))

Epoch: 0, training loss: 1.30542, training accuracy: 0.52133, validation loss: 1.19702, validation accuracy: 0.58520
Epoch: 1, training loss: 0.83505, training accuracy: 0.70700, validation loss: 0.75379, validation accuracy: 0.74000

