## Activation Function Research & Benchmarks

This notebook implements and compares:
- **Standard**: ReLU, LeakyReLU, PReLU,  
- **Squared-ReLU Variants**: Soft (SoftSqReLU), Exponential (ExpSqReLU)  
- **Advanced**: Adaptive Power Unit (APU), Neuron Resurrection Activation (NRA)  

**Datasets**: MNIST, EMNIST (letters), CIFAR-10  
**Architectures**: SimpleCNN, ResNet-18, Transformer, GAN  
**Metrics**: accuracy, loss, grad-norm, dead-neuron rate, GAN losses, runtime  
Export results for theoretical analysis and paper writing.

## Adaptive Power Unit (APU)

Block formula:

$$
p(z) = p_0 + w_p \,\tanh\bigl(w_z\,z + b_z\bigr),
\quad
f(z) = \mathrm{sign}(z)\,\lvert z\rvert^{p(z)}.
$$

- **Differentiability**  
  The function is smooth for all $z \neq 0$ (we clamp $\lvert z\rvert\ge\epsilon$ if you want to be extra safe near zero).

- **Lipschitz bound**  
  You can show
  $$
    f'(z)
    = p(z)\,\lvert z\rvert^{p(z)-1}
      \;+\;
      \frac{\partial p}{\partial z}\,\lvert z\rvert^{p(z)}\ln\lvert z\rvert,
  $$
  and then bound each term over your training range.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import os
from contextlib import redirect_stdout
from tqdm import tqdm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


### 1. ACTIVATION FUNCTIONS

In [3]:
class ParametricReLU(nn.Module):
    def __init__(self, init_alpha=0.1, learnable=True):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(init_alpha)) if learnable else init_alpha
        
    def forward(self, x):
        return torch.where(x > 0, x, self.alpha * x)

In [4]:
class SoftReLU(nn.Module):
    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta
        
    def forward(self, x):
        return x * torch.sigmoid(self.beta * x)

In [5]:
class ExponentialSquaredReLU(nn.Module):
    def __init__(self, alpha=0.1, gamma=1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        
    def forward(self, x):
        pos_part = torch.where(x > 0, x, torch.zeros_like(x))
        neg_part = torch.where(x < 0, self.alpha * (torch.exp(-self.gamma * x*x) - 1), torch.zeros_like(x))
        return pos_part + neg_part

In [6]:
class AdaptivePowerUnit(nn.Module):
    def __init__(self, p0=1.0):
        super().__init__()
        self.p0 = p0
        self.w_p = nn.Parameter(torch.randn(1) * 0.1)
        self.w_z = nn.Parameter(torch.randn(1) * 0.1)
        self.b_z = nn.Parameter(torch.zeros(1))
        
    def forward(self, z):
        p_z = self.p0 + self.w_p * torch.tanh(self.w_z * z + self.b_z)
        # small epsilon to avoid zero division
        eps = 1e-6
        return torch.sign(z) * torch.abs(z).clamp(min=eps)**p_z

In [7]:
# Dictionary of all activation functions for testing
activations = {
    'ReLU': nn.ReLU,
    'LeakyReLU': lambda: nn.LeakyReLU(0.01),
    'PReLU': lambda: ParametricReLU(0.1, True),
    'SoftReLU': lambda: SoftReLU(1.0),
    'ESReLU': lambda: ExponentialSquaredReLU(0.1, 1.0),
    'APU': AdaptivePowerUnit
}

### 2. MODELS

In [8]:
class SimpleCNN(nn.Module):
    def __init__(self, act_name, in_channels=1, num_classes=10):
        super().__init__()
        self.in_channels = in_channels
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        
        # Determine feature size based on input channels (MNIST vs CIFAR)
        feature_size = 9216 if in_channels == 1 else 12544
        
        self.fc1 = nn.Linear(feature_size, 128)
        self.fc2 = nn.Linear(128, num_classes)
        
        # instantiate activations per layer
        self.a1 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](32)
        self.a2 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](64)
        self.a3 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](128)
        
    def forward(self, x):
        x = self.a1(self.conv1(x))
        x = self.a2(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.a3(self.fc1(x))
        return self.fc2(x)

In [9]:
# Basic ResNet block with configurable activation
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, act_name='ReLU'):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.act1 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.act2 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.act2(out)
        return out


class ResNet(nn.Module):
    def __init__(self, act_name, num_blocks=[2, 2, 2, 2], num_classes=10, in_channels=3):
        super().__init__()
        self.in_channels = in_channels
        self.in_planes = 64
        
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.act1 = activations[act_name]() if act_name != 'NRA' else activations['NRA'](64)
        
        self.layer1 = self._make_layer(64, num_blocks[0], stride=1, act_name=act_name)
        self.layer2 = self._make_layer(128, num_blocks[1], stride=2, act_name=act_name)
        self.layer3 = self._make_layer(256, num_blocks[2], stride=2, act_name=act_name)
        self.layer4 = self._make_layer(512, num_blocks[3], stride=2, act_name=act_name)
        
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, planes, num_blocks, stride, act_name):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResidualBlock(self.in_planes, planes, stride, act_name))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [10]:
# Simple Transformer Block with configurable activation
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_dim, act_name='ReLU'):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](ff_dim),
            nn.Linear(ff_dim, embed_dim)
        )
        
    def forward(self, x):
        # Self-attention with residual connection
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        
        # Feed-forward with residual connection
        ff_output = self.ff(x)
        x = self.norm2(x + ff_output)
        
        return x


class SimpleTransformer(nn.Module):
    def __init__(self, act_name, input_dim=28*28, embed_dim=128, num_heads=4, 
                 ff_dim=512, num_blocks=2, num_classes=10):
        super().__init__()
        self.embedding = nn.Linear(input_dim, embed_dim)
        self.pos_encoding = nn.Parameter(torch.randn(1, embed_dim))
        
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, ff_dim, act_name) 
            for _ in range(num_blocks)
        ])
        
        self.classifier = nn.Linear(embed_dim, num_classes)
        
    def forward(self, x):
        batch_size = x.shape[0]
        
        # Flatten the input
        x = x.view(batch_size, -1)
        
        # Linear embedding + positional encoding
        x = self.embedding(x)
        x = x + self.pos_encoding
        
        # Add sequence dimension (batch, seq_len=1, embed_dim)
        x = x.unsqueeze(1)
        
        # Pass through transformer blocks
        for block in self.transformer_blocks:
            x = block(x)
        
        # Use the output of the last token for classification
        x = x.squeeze(1)
        
        # Classification layer
        return self.classifier(x)

In [11]:
# Simple GAN with configurable activation
class Generator(nn.Module):
    def __init__(self, latent_dim=100, act_name='ReLU', out_channels=1):
        super().__init__()
        self.latent_dim = latent_dim
        self.out_channels = out_channels
        
        self.model = nn.Sequential(
            # Input: latent_dim
            nn.Linear(latent_dim, 128),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](128),
            
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](256),
            
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](512),
            
            nn.Linear(512, 784 * out_channels),
            nn.Tanh()
        )
        
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), self.out_channels, 28, 28)
        return img


class Discriminator(nn.Module):
    def __init__(self, act_name='ReLU', in_channels=1):
        super().__init__()
        
        self.model = nn.Sequential(
            # Input: in_channels x 28 x 28
            nn.Flatten(),
            nn.Linear(784 * in_channels, 512),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](512),
            
            nn.Linear(512, 256),
            activations[act_name]() if act_name != 'NRA' else activations['NRA'](256),
            
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img):
        validity = self.model(img)
        return validity

### 3. TRAINING AND EVALUATION

In [12]:
def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total, grad_norm = 0, 0, 0, 0
    
    for imgs, labels in tqdm(loader, desc="Training", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        
        # Calculate gradient norm
        grad_norm_batch = sum(p.grad.norm(2).item() for p in model.parameters() if p.grad is not None)
        grad_norm += grad_norm_batch / len(list(model.parameters()))
        
        optimizer.step()
        
        total_loss += loss.item() * imgs.size(0)
        
        # Calculate accuracy
        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    
    return total_loss / total, correct / total * 100, grad_norm / len(loader)

In [13]:
def evaluate_model(model, loader, device, criterion=None):
    model.eval()
    correct, total = 0, 0
    test_loss = 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="Evaluating", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            
            outputs = model(imgs)
            preds = outputs.argmax(1)
            
            if criterion:
                loss = criterion(outputs, labels)
                test_loss += loss.item() * imgs.size(0)
                
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    
    if criterion:
        return correct / total * 100, test_loss / total
    return correct / total * 100

In [14]:
def train_gan_epoch(gen, disc, g_opt, d_opt, loader, latent_dim, device):
    gen.train()
    disc.train()
    
    d_losses, g_losses = [], []
    real_scores, fake_scores = [], []
    
    for real_imgs, _ in loader:
        batch_size = real_imgs.size(0)
        real_imgs = real_imgs.to(device)
        
        # Train Discriminator
        d_opt.zero_grad()
        
        # Real images
        real_validity = disc(real_imgs)
        real_score = real_validity.mean().item()
        real_scores.append(real_score)
        
        # Fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = gen(z)
        fake_validity = disc(fake_imgs.detach())
        fake_score = fake_validity.mean().item()
        fake_scores.append(fake_score)
        
        # Discriminator loss
        d_real_loss = F.binary_cross_entropy(real_validity, 
                                             torch.ones(batch_size, 1).to(device))
        d_fake_loss = F.binary_cross_entropy(fake_validity, 
                                             torch.zeros(batch_size, 1).to(device))
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_losses.append(d_loss.item())
        
        d_loss.backward()
        d_opt.step()
        
        # Train Generator
        g_opt.zero_grad()
        
        # Generate new fake images
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = gen(z)
        fake_validity = disc(fake_imgs)
        
        # Generator loss
        g_loss = F.binary_cross_entropy(fake_validity, torch.ones(batch_size, 1).to(device))
        g_losses.append(g_loss.item())
        
        g_loss.backward()
        g_opt.step()
    
    return {
        'd_loss': sum(d_losses) / len(d_losses),
        'g_loss': sum(g_losses) / len(g_losses),
        'real_score': sum(real_scores) / len(real_scores),
        'fake_score': sum(fake_scores) / len(fake_scores)
    }

### 4. DATASETS

In [15]:
def get_dataset_loaders(dataset_name='MNIST', batch_size=64):
    # Define a context manager to suppress stdout
    with open(os.devnull, 'w') as devnull:
        with redirect_stdout(devnull):
            if dataset_name == 'MNIST':
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])
                train_ds = torchvision.datasets.MNIST('data/', train=True, download=True, transform=transform)
                test_ds = torchvision.datasets.MNIST('data/', train=False, download=True, transform=transform)
                in_channels, num_classes = 1, 10
                img_size = 28
                
            elif dataset_name == 'CIFAR10':
                transform_train = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ])
                transform_test = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ])
                train_ds = torchvision.datasets.CIFAR10('data/', train=True, download=True, transform=transform_train)
                test_ds = torchvision.datasets.CIFAR10('data/', train=False, download=True, transform=transform_test)
                in_channels, num_classes = 3, 10
                img_size = 32
                
            elif dataset_name == 'EMNIST':
                transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])
                train_ds = torchvision.datasets.EMNIST('data/', split='balanced', train=True, download=True, transform=transform)
                test_ds = torchvision.datasets.EMNIST('data/', split='balanced', train=False, download=True, transform=transform)
                in_channels, num_classes = 1, 47  # EMNIST balanced has 47 classes
                img_size = 28
                
            else:
                raise ValueError(f"Dataset {dataset_name} not implemented")
    
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size*2, shuffle=False)
    
    return train_loader, test_loader, in_channels, num_classes, img_size

### 5. BENCHMARK EXECUTION

In [16]:
def run_activation_benchmark(model_name, dataset_name, act_names, epochs=10, lr=0.001):
    print(f"\n===== Running {model_name} on {dataset_name} =====")
    
    # Get dataset with the updated function call to get img_size
    train_loader, test_loader, in_channels, num_classes, img_size = get_dataset_loaders(dataset_name)
    
    # Initialize results dictionary
    results = {name: {'train_loss': [], 'train_acc': [], 'test_acc': [], 'grad_norm': [], 'time': []}
              for name in act_names}
    
    for act_name in act_names:
        print(f"\n### Activation: {act_name}")
        
        # Initialize model based on model_name
        if model_name == 'SimpleCNN':
            model = SimpleCNN(act_name, in_channels=in_channels, num_classes=num_classes).to(device)
        elif model_name == 'ResNet':
            model = ResNet(act_name, num_blocks=[2, 2, 2, 2], num_classes=num_classes, in_channels=in_channels).to(device)
        elif model_name == 'Transformer':
            # Use the correct image size for calculating input dimensions
            input_dim = in_channels * img_size * img_size
            model = SimpleTransformer(act_name, input_dim=input_dim, num_classes=num_classes).to(device)
        else:
            raise ValueError(f"Model {model_name} not implemented")
        
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(1, epochs+1):
            epoch_start_time = time.time()
            
            # Train for one epoch
            train_loss, train_acc, grad_norm = train_epoch(model, train_loader, optimizer, criterion, device)
            
            # Evaluate
            test_acc = evaluate_model(model, test_loader, device)
            
            # Record time
            epoch_time = time.time() - epoch_start_time
            
            # Save results
            res = results[act_name]
            res['train_loss'].append(train_loss)
            res['train_acc'].append(train_acc)
            res['test_acc'].append(test_acc)
            res['grad_norm'].append(grad_norm)
            res['time'].append(epoch_time)
            
            print(f"Epoch {epoch}/{epochs}: loss={train_loss:.4f}, train_acc={train_acc:.1f}%, "
                  f"test_acc={test_acc:.1f}%, grad_norm={grad_norm:.2f}, time={epoch_time:.2f}s")
    
    # Plot results
    plot_benchmark_results(results, model_name, dataset_name)
    
    return results

In [17]:
def run_gan_benchmark(dataset_name, act_names, epochs=10, lr=0.0002, latent_dim=100):
    print(f"\n===== Running GAN on {dataset_name} =====")
    
    # Get dataset
    train_loader, _, in_channels, _, _ = get_dataset_loaders(dataset_name)
    
    # Initialize results dictionary
    results = {name: {'d_loss': [], 'g_loss': [], 'real_score': [], 'fake_score': [], 'time': []}
              for name in act_names}
    
    for act_name in act_names:
        print(f"\n### Activation: {act_name}")
        
        # Initialize Generator and Discriminator
        generator = Generator(latent_dim=latent_dim, act_name=act_name, out_channels=in_channels).to(device)
        discriminator = Discriminator(act_name=act_name, in_channels=in_channels).to(device)
        
        # Optimizers
        g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
        d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
        
        for epoch in range(1, epochs+1):
            epoch_start_time = time.time()
            
            # Train GAN for one epoch
            gan_metrics = train_gan_epoch(generator, discriminator, g_optimizer, d_optimizer, 
                                         train_loader, latent_dim, device)
            
            # Record time
            epoch_time = time.time() - epoch_start_time
            
            # Save results
            res = results[act_name]
            for k, v in gan_metrics.items():
                res[k].append(v)
            res['time'].append(epoch_time)
            
            print(f"Epoch {epoch}/{epochs}: d_loss={gan_metrics['d_loss']:.4f}, "
                  f"g_loss={gan_metrics['g_loss']:.4f}, real_score={gan_metrics['real_score']:.4f}, "
                  f"fake_score={gan_metrics['fake_score']:.4f}, time={epoch_time:.2f}s")
            
            # Generate and save sample images every few epochs
            if epoch % 5 == 0:
                save_gan_samples(generator, act_name, epoch, latent_dim, in_channels, device)
    
    # Plot results
    plot_gan_results(results, dataset_name)
    
    return results

In [18]:
def save_gan_samples(generator, act_name, epoch, latent_dim, channels, device, n_samples=16):
    """Generate and save sample images from the generator"""
    generator.eval()
    with torch.no_grad():
        z = torch.randn(n_samples, latent_dim).to(device)
        gen_imgs = generator(z).cpu()
    
    # Convert to matplotlib-compatible format and rescale to [0,1]
    imgs = gen_imgs.detach().numpy()
    imgs = (imgs + 1) / 2.0  # Scale from [-1,1] to [0,1]
    
    # Create directory if it doesn't exist
    os.makedirs(f"gan_samples/{act_name}", exist_ok=True)
    
    # Plot and save
    fig, axs = plt.subplots(int(np.sqrt(n_samples)), int(np.sqrt(n_samples)))
    cnt = 0
    for i in range(int(np.sqrt(n_samples))):
        for j in range(int(np.sqrt(n_samples))):
            if channels == 1:
                axs[i, j].imshow(imgs[cnt, 0], cmap='gray')
            else:
                axs[i, j].imshow(np.transpose(imgs[cnt], (1, 2, 0)))
            axs[i, j].axis('off')
            cnt += 1
    
    fig.savefig(f"gan_samples/{act_name}/epoch_{epoch}.png")
    plt.close()

### 6. VISUALIZATION

In [19]:
def plot_benchmark_results(results, model_name, dataset_name):
    """Plot the results of a benchmark"""
    metrics = ['train_loss', 'train_acc', 'test_acc', 'grad_norm', 'time']
    
    plt.figure(figsize=(20, 15))
    
    for i, metric in enumerate(metrics, 1):
        plt.subplot(3, 2, i)
        for name, res in results.items():
            plt.plot(res[metric], label=name)
        plt.title(f"{model_name} - {dataset_name} - {metric.replace('_', ' ').title()}")
        plt.xlabel('Epoch')
        plt.ylabel(metric.replace('_', ' ').title())
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend()
    
    plt.tight_layout()
    os.makedirs("results", exist_ok=True)
    plt.savefig(f"results/{model_name}_{dataset_name}_comparison.png")
    plt.close()

In [20]:
def plot_gan_results(results, dataset_name):
    """Plot the results of a GAN benchmark"""
    metrics = ['d_loss', 'g_loss', 'real_score', 'fake_score', 'time']
    
    plt.figure(figsize=(20, 15))
    
    for i, metric in enumerate(metrics, 1):
        plt.subplot(3, 2, i)
        for name, res in results.items():
            plt.plot(res[metric], label=name)
        plt.title(f"GAN - {dataset_name} - {metric.replace('_', ' ').title()}")
        plt.xlabel('Epoch')
        plt.ylabel(metric.replace('_', ' ').title())
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend()
    
    plt.tight_layout()
    os.makedirs("results", exist_ok=True)
    plt.savefig(f"results/GAN_{dataset_name}_comparison.png")
    plt.close()

### 7. THEORETICAL ANALYSIS HELPER FUNCTIONS

In [21]:
def compute_lipschitz_estimate(activation, input_range=(-10, 10), num_samples=1000):
    """Empirically estimate the Lipschitz constant of an activation function"""
    x = torch.linspace(input_range[0], input_range[1], num_samples).to(device)
    y = activation(x)
    
    # Compute gradients using finite differences
    grad = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
    
    # Lipschitz constant is the maximum absolute gradient
    return torch.max(torch.abs(grad)).item()

In [22]:
def check_differentiability(activation, input_range=(-10, 10), num_samples=1000):
    """Check if the activation function is differentiable at different points"""
    x = torch.linspace(input_range[0], input_range[1], num_samples, requires_grad=True).to(device)
    y = activation(x)
    
    try:
        # Compute gradients
        gradients = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
        
        # Check if there are any NaN or Inf values
        non_differentiable_points = torch.isnan(gradients) | torch.isinf(gradients)
        
        if torch.any(non_differentiable_points):
            non_diff_inputs = x[non_differentiable_points]
            return False, non_diff_inputs.detach().cpu().numpy()
        else:
            return True, []
            
    except Exception as e:
        print(f"Error in differentiability check: {e}")
        return False, []

In [23]:
def analyze_gradient_flow(model, act_name, sample_data, device):
    """Analyze gradient flow through a model with a specific activation function"""
    model.train()
    
    # Get a batch of data
    inputs, targets = next(iter(sample_data))
    inputs, targets = inputs.to(device), targets.to(device)
    
    # Forward pass
    outputs = model(inputs)
    loss = F.cross_entropy(outputs, targets)
    
    # Backward pass
    loss.backward()
    
    # Collect gradients from different layers
    grad_stats = {}
    total_params = 0
    vanishing_grad_params = 0
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            grad_norm = param.grad.norm(2).item()
            grad_stats[name] = grad_norm
            
            total_params += param.numel()
            # Check for vanishing gradients (using a threshold)
            vanishing_grad_params += torch.sum(torch.abs(param.grad) < 1e-8).item()
    
    vanishing_grad_ratio = vanishing_grad_params / total_params
    
    # Calculate gradient statistics
    grad_norms = list(grad_stats.values())
    mean_grad = np.mean(grad_norms)
    std_grad = np.std(grad_norms)
    max_grad = np.max(grad_norms)
    min_grad = np.min(grad_norms)
    
    return {
        'activation': act_name,
        'mean_grad': mean_grad,
        'std_grad': std_grad,
        'max_grad': max_grad,
        'min_grad': min_grad,
        'vanishing_ratio': vanishing_grad_ratio,
        'detailed_grads': grad_stats
    }

In [24]:
def analyze_all_activations():
    """Perform theoretical analysis on all activation functions"""
    print("\n===== Theoretical Analysis of Activation Functions =====")
    
    results = {}
    input_range = (-5, 5)
    
    for name, act_fn in activations.items():
            
        print(f"\nAnalyzing {name}...")
        activation = act_fn().to(device)
        
        # 1. Lipschitz continuity
        lipschitz = compute_lipschitz_estimate(activation, input_range)
        
        # 2. Differentiability
        diff_status, non_diff_points = check_differentiability(activation, input_range)
        
        # 3. Plot activation function and its derivative
        x = torch.linspace(input_range[0], input_range[1], 1000, requires_grad=True).to(device)
        with torch.no_grad():
            y = activation(x)
        
        # Try to compute the derivative
        x_grad = torch.linspace(input_range[0], input_range[1], 1000, requires_grad=True).to(device)
        y_grad = activation(x_grad)
        try:
            dy_dx = torch.autograd.grad(y_grad.sum(), x_grad)[0]
            has_derivative = True
        except:
            has_derivative = False
            dy_dx = torch.zeros_like(x_grad)
        
        # Store results
        results[name] = {
            'lipschitz': lipschitz,
            'differentiable': diff_status,
            'non_diff_points': non_diff_points if not diff_status else [],
            'x': x.detach().cpu().numpy(),
            'y': y.detach().cpu().numpy(),
            'dy_dx': dy_dx.detach().cpu().numpy() if has_derivative else None
        }
        
        print(f"  - Lipschitz constant estimate: {lipschitz:.4f}")
        print(f"  - Differentiable: {diff_status}")
        if not diff_status and len(non_diff_points) > 0:
            print(f"  - Non-differentiable points: {non_diff_points}")
    
    # Create activation function comparison plots
    plot_activation_comparisons(results)
    
    return results

In [25]:
def plot_activation_comparisons(results):
    """Plot and compare all activation functions and their derivatives"""
    plt.figure(figsize=(20, 10))
    
    # Plot activation functions
    plt.subplot(1, 2, 1)
    for name, data in results.items():
        plt.plot(data['x'], data['y'], label=name)
    plt.title('Activation Functions')
    plt.xlabel('Input')
    plt.ylabel('Output')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    # Plot derivatives
    plt.subplot(1, 2, 2)
    for name, data in results.items():
        if data['dy_dx'] is not None:
            plt.plot(data['x'], data['dy_dx'], label=f"{name} derivative")
    plt.title('Derivatives of Activation Functions')
    plt.xlabel('Input')
    plt.ylabel('Derivative')
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()
    
    plt.tight_layout()
    os.makedirs("results", exist_ok=True)
    plt.savefig("results/activation_function_comparison.png")
    plt.close()
    
    # Create individual plots for each activation
    for name, data in results.items():
        plt.figure(figsize=(12, 6))
        
        plt.subplot(1, 2, 1)
        plt.plot(data['x'], data['y'])
        plt.title(f'{name} Activation Function')
        plt.xlabel('Input')
        plt.ylabel('Output')
        plt.grid(True, linestyle='--', alpha=0.7)
        
        plt.subplot(1, 2, 2)
        if data['dy_dx'] is not None:
            plt.plot(data['x'], data['dy_dx'])
            plt.title(f'{name} Derivative')
            plt.xlabel('Input')
            plt.ylabel('Derivative')
            plt.grid(True, linestyle='--', alpha=0.7)
        else:
            plt.text(0.5, 0.5, "Derivative not available", 
                     horizontalalignment='center', verticalalignment='center')
            plt.axis('off')
        
        plt.tight_layout()
        plt.savefig(f"results/{name}_function_analysis.png")
        plt.close()

In [26]:
def run_gradient_flow_analysis(datasets, model_types, act_names):
    """Analyze gradient flow for different models, datasets, and activations"""
    print("\n===== Gradient Flow Analysis =====")
    
    results = {}
    
    for dataset_name in datasets:
        train_loader, _, in_channels, num_classes, img_size = get_dataset_loaders(dataset_name)
        
        for model_name in model_types:
            model_results = {}
            
            for act_name in act_names:
                print(f"\nAnalyzing gradient flow for {model_name} with {act_name} on {dataset_name}...")
                
                # Initialize model
                if model_name == 'SimpleCNN':
                    model = SimpleCNN(act_name, in_channels=in_channels, num_classes=num_classes).to(device)
                elif model_name == 'ResNet':
                    model = ResNet(act_name, num_blocks=[2, 2, 2, 2], num_classes=num_classes, in_channels=in_channels).to(device)
                elif model_name == 'Transformer':
                    input_dim = in_channels * img_size * img_size
                    model = SimpleTransformer(act_name, input_dim=input_dim, num_classes=num_classes).to(device)
                else:
                    continue
                
                # Analyze gradient flow
                grad_flow = analyze_gradient_flow(model, act_name, train_loader, device)
                model_results[act_name] = grad_flow
                
                print(f"  - Mean gradient: {grad_flow['mean_grad']:.6f}")
                print(f"  - Gradient std: {grad_flow['std_grad']:.6f}")
                print(f"  - Vanishing gradient ratio: {grad_flow['vanishing_ratio']:.2%}")
            
            results.setdefault(dataset_name, {})[model_name] = model_results
    
    # Plot gradient flow results
    plot_gradient_flow_results(results)
    
    return results

In [27]:
def plot_gradient_flow_results(results):
    """Plot gradient flow analysis results"""
    metrics = ['mean_grad', 'std_grad', 'max_grad', 'min_grad', 'vanishing_ratio']
    metric_titles = {
        'mean_grad': 'Mean Gradient Magnitude',
        'std_grad': 'Gradient Standard Deviation',
        'max_grad': 'Maximum Gradient Magnitude',
        'min_grad': 'Minimum Gradient Magnitude',
        'vanishing_ratio': 'Vanishing Gradient Ratio'
    }
    
    for dataset_name, dataset_results in results.items():
        for model_name, model_results in dataset_results.items():
            plt.figure(figsize=(20, 15))
            
            for i, metric in enumerate(metrics, 1):
                plt.subplot(3, 2, i)
                
                act_names = list(model_results.keys())
                values = [model_results[act]['vanishing_ratio']*100 if metric == 'vanishing_ratio' 
                          else model_results[act][metric] for act in act_names]
                
                plt.bar(act_names, values)
                plt.title(f"{metric_titles[metric]} - {model_name} on {dataset_name}")
                plt.ylabel(metric_titles[metric])
                plt.xticks(rotation=45)
                plt.grid(True, linestyle='--', alpha=0.7, axis='y')
                
                # Add value labels
                for j, v in enumerate(values):
                    plt.text(j, v + max(values)*0.01, f"{v:.4f}" if metric != 'vanishing_ratio' else f"{v:.1f}%", 
                             ha='center', va='bottom')
            
            plt.tight_layout()
            os.makedirs("results/gradient_flow", exist_ok=True)
            plt.savefig(f"results/gradient_flow/{dataset_name}_{model_name}_gradient_analysis.png")
            plt.close()

### 8. MAIN EXECUTION PIPELINE

In [28]:
def main():
    # Create directories for results
    os.makedirs("results", exist_ok=True)
    os.makedirs("gan_samples", exist_ok=True)
    
    # Activation functions
    act_names = ['ReLU', 'LeakyReLU', 'PReLU', 'SoftReLU', 'ESReLU', 'APU']
    
    # Perform theoretical analysis
    activation_analysis = analyze_all_activations()
    
    # Datasets to test
    datasets = ['MNIST', 'CIFAR10', 'EMNIST']
    
    # Models to test
    models = ['SimpleCNN', 'ResNet', 'Transformer']
    
    # Initialize counters for total models
    total_initialized = 0
    total_used = 0
    
    # Run benchmarks for each model and dataset combination
    results = {}
    
    for dataset in datasets:
        results[dataset] = {}
        
        for model in models:
            print(f"\n\n{'='*50}")
            print(f"Running benchmark for {model} on {dataset}")
            print(f"{'='*50}")
            
            # Run benchmark with fewer epochs for quick results
            benchmark_results = run_activation_benchmark(model, dataset, act_names, epochs=5)
            results[dataset][model] = benchmark_results
            
            # Assume run_activation_benchmark initializes and uses len(act_names) models per call
            initialized_count = len(act_names)  # One model per activation
            used_count = len(act_names)  # All initialized models are used, assuming no errors
            total_initialized += initialized_count
            total_used += used_count
            print(f"Models initialized for {model} on {dataset}: {initialized_count}")
            print(f"Models used for {model} on {dataset}: {used_count}")
    
    # Run GAN benchmarks
    gan_results = {}
    for dataset in ['MNIST', 'EMNIST']:  # GANs are typically simpler to train on MNIST-like datasets
        print(f"\n\n{'='*50}")
        print(f"Running GAN benchmark on {dataset}")
        print(f"{'='*50}")
        
        # Run GAN benchmark with fewer epochs for quick results
        gan_benchmark = run_gan_benchmark(dataset, act_names, epochs=5)
        gan_results[dataset] = gan_benchmark
        
        # Assume run_gan_benchmark initializes and uses 2 models (Generator, Discriminator) per activation
        initialized_count = 2 * len(act_names)  # Two models per activation
        used_count = 2 * len(act_names)  # All initialized models are used
        total_initialized += initialized_count
        total_used += used_count
        print(f"Models initialized for GAN on {dataset}: {initialized_count}")
        print(f"Models used for GAN on {dataset}: {used_count}")
    
    # Run gradient flow analysis
    gradient_flow_results = run_gradient_flow_analysis(datasets, models, act_names)
    
    # Print total counts
    print("\n\n===== All benchmarks completed =====")
    print(f"Total models initialized: {total_initialized}")
    print(f"Total models used: {total_used}")
    print("Results saved to 'results/' directory")

if __name__ == "__main__":
    main()


===== Theoretical Analysis of Activation Functions =====

Analyzing ReLU...
  - Lipschitz constant estimate: 1.0000
  - Differentiable: True

Analyzing LeakyReLU...
  - Lipschitz constant estimate: 1.0000
  - Differentiable: True

Analyzing PReLU...
  - Lipschitz constant estimate: 1.0000
  - Differentiable: True

Analyzing SoftReLU...
  - Lipschitz constant estimate: 1.0998
  - Differentiable: True

Analyzing ESReLU...
  - Lipschitz constant estimate: 1.0000
  - Differentiable: True

Analyzing APU...
  - Lipschitz constant estimate: 1.0153
  - Differentiable: True


Running benchmark for SimpleCNN on MNIST

===== Running SimpleCNN on MNIST =====


100%|██████████| 9.91M/9.91M [00:00<00:00, 38.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.14MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.99MB/s]



### Activation: ReLU


                                                           

Epoch 1/5: loss=0.1133, train_acc=96.5%, test_acc=98.5%, grad_norm=0.29, time=17.90s


                                                           

Epoch 2/5: loss=0.0375, train_acc=98.8%, test_acc=98.8%, grad_norm=0.16, time=17.13s


                                                           

Epoch 3/5: loss=0.0230, train_acc=99.3%, test_acc=98.9%, grad_norm=0.13, time=17.22s


                                                           

Epoch 4/5: loss=0.0154, train_acc=99.5%, test_acc=98.9%, grad_norm=0.10, time=17.06s


                                                           

Epoch 5/5: loss=0.0114, train_acc=99.6%, test_acc=98.6%, grad_norm=0.09, time=17.10s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=0.1179, train_acc=96.5%, test_acc=98.9%, grad_norm=0.27, time=16.88s


                                                           

Epoch 2/5: loss=0.0361, train_acc=98.9%, test_acc=99.0%, grad_norm=0.15, time=17.01s


                                                           

Epoch 3/5: loss=0.0224, train_acc=99.3%, test_acc=98.5%, grad_norm=0.12, time=17.26s


                                                           

Epoch 4/5: loss=0.0152, train_acc=99.5%, test_acc=98.8%, grad_norm=0.10, time=17.05s


                                                           

Epoch 5/5: loss=0.0098, train_acc=99.7%, test_acc=98.9%, grad_norm=0.07, time=16.80s

### Activation: PReLU


                                                           

Epoch 1/5: loss=0.1100, train_acc=96.7%, test_acc=98.7%, grad_norm=0.20, time=18.03s


                                                           

Epoch 2/5: loss=0.0331, train_acc=99.0%, test_acc=98.7%, grad_norm=0.13, time=18.16s


                                                           

Epoch 3/5: loss=0.0201, train_acc=99.3%, test_acc=98.5%, grad_norm=0.10, time=18.06s


                                                           

Epoch 4/5: loss=0.0138, train_acc=99.6%, test_acc=98.8%, grad_norm=0.08, time=18.07s


                                                           

Epoch 5/5: loss=0.0103, train_acc=99.7%, test_acc=98.9%, grad_norm=0.07, time=18.04s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=0.1055, train_acc=96.7%, test_acc=98.6%, grad_norm=0.23, time=17.80s


                                                           

Epoch 2/5: loss=0.0328, train_acc=99.0%, test_acc=98.7%, grad_norm=0.14, time=17.86s


                                                           

Epoch 3/5: loss=0.0190, train_acc=99.4%, test_acc=98.7%, grad_norm=0.11, time=17.88s


                                                           

Epoch 4/5: loss=0.0139, train_acc=99.6%, test_acc=98.8%, grad_norm=0.09, time=17.94s


                                                           

Epoch 5/5: loss=0.0125, train_acc=99.6%, test_acc=98.9%, grad_norm=0.09, time=17.62s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=0.1112, train_acc=96.6%, test_acc=98.6%, grad_norm=0.24, time=19.86s


                                                           

Epoch 2/5: loss=0.0343, train_acc=98.9%, test_acc=98.8%, grad_norm=0.14, time=19.72s


                                                           

Epoch 3/5: loss=0.0209, train_acc=99.3%, test_acc=98.5%, grad_norm=0.11, time=19.85s


                                                           

Epoch 4/5: loss=0.0131, train_acc=99.6%, test_acc=98.3%, grad_norm=0.09, time=19.93s


                                                           

Epoch 5/5: loss=0.0113, train_acc=99.6%, test_acc=98.7%, grad_norm=0.08, time=20.03s

### Activation: APU


                                                           

Epoch 1/5: loss=0.1675, train_acc=94.9%, test_acc=97.5%, grad_norm=0.20, time=21.14s


                                                           

Epoch 2/5: loss=0.0655, train_acc=98.0%, test_acc=97.9%, grad_norm=0.15, time=21.04s


                                                           

Epoch 3/5: loss=0.0409, train_acc=98.6%, test_acc=98.5%, grad_norm=0.16, time=20.79s


                                                           

Epoch 4/5: loss=0.0298, train_acc=99.0%, test_acc=98.4%, grad_norm=0.18, time=21.00s


                                                           

Epoch 5/5: loss=0.0234, train_acc=99.2%, test_acc=98.4%, grad_norm=0.21, time=20.92s
Models initialized for SimpleCNN on MNIST: 6
Models used for SimpleCNN on MNIST: 6


Running benchmark for ResNet on MNIST

===== Running ResNet on MNIST =====

### Activation: ReLU


                                                           

Epoch 1/5: loss=0.1047, train_acc=96.8%, test_acc=98.8%, grad_norm=0.06, time=58.62s


                                                           

Epoch 2/5: loss=0.0445, train_acc=98.7%, test_acc=98.5%, grad_norm=0.03, time=58.36s


                                                           

Epoch 3/5: loss=0.0341, train_acc=99.0%, test_acc=98.6%, grad_norm=0.02, time=58.27s


                                                           

Epoch 4/5: loss=0.0281, train_acc=99.2%, test_acc=98.9%, grad_norm=0.01, time=58.38s


                                                           

Epoch 5/5: loss=0.0261, train_acc=99.2%, test_acc=98.9%, grad_norm=0.01, time=58.28s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=0.1101, train_acc=96.6%, test_acc=98.4%, grad_norm=0.06, time=58.46s


                                                           

Epoch 2/5: loss=0.0437, train_acc=98.7%, test_acc=98.6%, grad_norm=0.02, time=58.37s


                                                           

Epoch 3/5: loss=0.0371, train_acc=98.9%, test_acc=99.0%, grad_norm=0.02, time=58.43s


                                                           

Epoch 4/5: loss=0.0288, train_acc=99.1%, test_acc=99.0%, grad_norm=0.01, time=58.38s


                                                           

Epoch 5/5: loss=0.0241, train_acc=99.2%, test_acc=99.0%, grad_norm=0.01, time=58.30s

### Activation: PReLU


                                                           

Epoch 1/5: loss=0.1071, train_acc=96.7%, test_acc=95.1%, grad_norm=0.05, time=65.86s


                                                           

Epoch 2/5: loss=0.0443, train_acc=98.7%, test_acc=98.8%, grad_norm=0.02, time=65.80s


                                                           

Epoch 3/5: loss=0.0340, train_acc=99.0%, test_acc=99.3%, grad_norm=0.02, time=65.89s


                                                           

Epoch 4/5: loss=0.0248, train_acc=99.2%, test_acc=99.2%, grad_norm=0.01, time=65.89s


                                                           

Epoch 5/5: loss=0.0241, train_acc=99.3%, test_acc=99.3%, grad_norm=0.01, time=65.85s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=0.1079, train_acc=96.7%, test_acc=97.2%, grad_norm=0.05, time=65.62s


                                                           

Epoch 2/5: loss=0.0427, train_acc=98.7%, test_acc=99.2%, grad_norm=0.02, time=65.68s


                                                           

Epoch 3/5: loss=0.0333, train_acc=99.0%, test_acc=99.2%, grad_norm=0.02, time=65.61s


                                                           

Epoch 4/5: loss=0.0273, train_acc=99.1%, test_acc=99.3%, grad_norm=0.01, time=65.66s


                                                           

Epoch 5/5: loss=0.0223, train_acc=99.3%, test_acc=98.9%, grad_norm=0.01, time=65.70s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=0.1097, train_acc=96.5%, test_acc=98.4%, grad_norm=0.06, time=78.07s


                                                           

Epoch 2/5: loss=0.0446, train_acc=98.7%, test_acc=98.6%, grad_norm=0.02, time=77.93s


                                                           

Epoch 3/5: loss=0.0356, train_acc=98.9%, test_acc=98.8%, grad_norm=0.02, time=78.02s


                                                           

Epoch 4/5: loss=0.0289, train_acc=99.1%, test_acc=99.0%, grad_norm=0.01, time=78.09s


                                                           

Epoch 5/5: loss=0.0250, train_acc=99.2%, test_acc=99.0%, grad_norm=0.01, time=77.85s

### Activation: APU


                                                           

Epoch 1/5: loss=0.3276, train_acc=90.2%, test_acc=97.5%, grad_norm=0.05, time=96.60s


                                                           

Epoch 2/5: loss=0.0725, train_acc=97.7%, test_acc=98.4%, grad_norm=0.03, time=96.58s


                                                           

Epoch 3/5: loss=0.0488, train_acc=98.5%, test_acc=98.5%, grad_norm=0.02, time=96.54s


                                                           

Epoch 4/5: loss=0.0402, train_acc=98.7%, test_acc=98.7%, grad_norm=0.02, time=96.57s


                                                           

Epoch 5/5: loss=0.0347, train_acc=99.0%, test_acc=98.9%, grad_norm=0.02, time=96.57s
Models initialized for ResNet on MNIST: 6
Models used for ResNet on MNIST: 6


Running benchmark for Transformer on MNIST

===== Running Transformer on MNIST =====

### Activation: ReLU


                                                           

Epoch 1/5: loss=0.2441, train_acc=92.5%, test_acc=95.9%, grad_norm=0.14, time=17.41s


                                                           

Epoch 2/5: loss=0.1153, train_acc=96.3%, test_acc=96.9%, grad_norm=0.11, time=17.33s


                                                           

Epoch 3/5: loss=0.0911, train_acc=97.2%, test_acc=97.0%, grad_norm=0.10, time=17.44s


                                                           

Epoch 4/5: loss=0.0765, train_acc=97.6%, test_acc=97.5%, grad_norm=0.09, time=17.36s


                                                           

Epoch 5/5: loss=0.0727, train_acc=97.7%, test_acc=97.0%, grad_norm=0.10, time=17.35s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=0.2407, train_acc=92.7%, test_acc=96.6%, grad_norm=0.13, time=17.47s


                                                           

Epoch 2/5: loss=0.1172, train_acc=96.2%, test_acc=95.7%, grad_norm=0.11, time=17.25s


                                                           

Epoch 3/5: loss=0.0987, train_acc=96.9%, test_acc=96.7%, grad_norm=0.11, time=17.46s


                                                           

Epoch 4/5: loss=0.0798, train_acc=97.5%, test_acc=97.1%, grad_norm=0.10, time=17.24s


                                                           

Epoch 5/5: loss=0.0673, train_acc=97.8%, test_acc=97.4%, grad_norm=0.09, time=17.30s

### Activation: PReLU


                                                           

Epoch 1/5: loss=0.2421, train_acc=92.6%, test_acc=94.0%, grad_norm=0.14, time=17.78s


                                                           

Epoch 2/5: loss=0.1189, train_acc=96.4%, test_acc=96.6%, grad_norm=0.11, time=17.83s


                                                           

Epoch 3/5: loss=0.0963, train_acc=97.0%, test_acc=97.0%, grad_norm=0.11, time=17.83s


                                                           

Epoch 4/5: loss=0.0843, train_acc=97.3%, test_acc=97.0%, grad_norm=0.11, time=17.93s


                                                           

Epoch 5/5: loss=0.0698, train_acc=97.8%, test_acc=97.8%, grad_norm=0.09, time=17.91s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=0.2607, train_acc=92.0%, test_acc=95.2%, grad_norm=0.13, time=17.52s


                                                           

Epoch 2/5: loss=0.1309, train_acc=96.0%, test_acc=96.4%, grad_norm=0.11, time=17.75s


                                                           

Epoch 3/5: loss=0.0959, train_acc=97.0%, test_acc=96.8%, grad_norm=0.10, time=17.58s


                                                           

Epoch 4/5: loss=0.0833, train_acc=97.4%, test_acc=96.6%, grad_norm=0.10, time=17.64s


                                                           

Epoch 5/5: loss=0.0674, train_acc=97.8%, test_acc=97.5%, grad_norm=0.08, time=17.55s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=0.2488, train_acc=92.4%, test_acc=95.2%, grad_norm=0.13, time=18.46s


                                                           

Epoch 2/5: loss=0.1237, train_acc=96.1%, test_acc=96.3%, grad_norm=0.11, time=18.48s


                                                           

Epoch 3/5: loss=0.0972, train_acc=97.0%, test_acc=96.3%, grad_norm=0.10, time=18.44s


                                                           

Epoch 4/5: loss=0.0838, train_acc=97.4%, test_acc=97.5%, grad_norm=0.10, time=18.44s


                                                           

Epoch 5/5: loss=0.0853, train_acc=97.2%, test_acc=97.7%, grad_norm=0.11, time=17.96s

### Activation: APU


                                                           

Epoch 1/5: loss=0.3787, train_acc=88.6%, test_acc=92.3%, grad_norm=0.13, time=18.94s


                                                           

Epoch 2/5: loss=0.2115, train_acc=93.6%, test_acc=95.5%, grad_norm=0.11, time=18.82s


                                                           

Epoch 3/5: loss=0.1477, train_acc=95.6%, test_acc=95.2%, grad_norm=0.10, time=18.83s


                                                           

Epoch 4/5: loss=0.1189, train_acc=96.4%, test_acc=96.5%, grad_norm=0.10, time=18.92s


                                                           

Epoch 5/5: loss=0.1050, train_acc=96.7%, test_acc=96.3%, grad_norm=0.10, time=18.89s
Models initialized for Transformer on MNIST: 6
Models used for Transformer on MNIST: 6


Running benchmark for SimpleCNN on CIFAR10

===== Running SimpleCNN on CIFAR10 =====


100%|██████████| 170M/170M [00:02<00:00, 74.5MB/s] 



### Activation: ReLU


                                                           

Epoch 1/5: loss=1.4909, train_acc=45.8%, test_acc=60.5%, grad_norm=0.90, time=23.58s


                                                           

Epoch 2/5: loss=1.1584, train_acc=58.9%, test_acc=65.8%, grad_norm=1.00, time=23.49s


                                                           

Epoch 3/5: loss=1.0586, train_acc=62.6%, test_acc=68.2%, grad_norm=1.00, time=23.51s


                                                           

Epoch 4/5: loss=0.9945, train_acc=64.9%, test_acc=69.2%, grad_norm=1.01, time=23.69s


                                                           

Epoch 5/5: loss=0.9483, train_acc=66.5%, test_acc=69.1%, grad_norm=1.01, time=23.28s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=1.4321, train_acc=47.7%, test_acc=62.0%, grad_norm=0.92, time=23.39s


                                                           

Epoch 2/5: loss=1.1043, train_acc=60.5%, test_acc=67.3%, grad_norm=1.01, time=23.45s


                                                           

Epoch 3/5: loss=0.9916, train_acc=65.1%, test_acc=69.7%, grad_norm=1.01, time=23.36s


                                                           

Epoch 4/5: loss=0.9198, train_acc=67.6%, test_acc=71.4%, grad_norm=1.00, time=23.19s


                                                           

Epoch 5/5: loss=0.8626, train_acc=69.8%, test_acc=71.5%, grad_norm=0.99, time=23.33s

### Activation: PReLU


                                                           

Epoch 1/5: loss=1.3962, train_acc=49.4%, test_acc=63.6%, grad_norm=0.79, time=24.60s


                                                           

Epoch 2/5: loss=1.0784, train_acc=62.0%, test_acc=67.5%, grad_norm=0.89, time=24.48s


                                                           

Epoch 3/5: loss=0.9551, train_acc=66.1%, test_acc=68.3%, grad_norm=0.93, time=24.51s


                                                           

Epoch 4/5: loss=0.8848, train_acc=68.8%, test_acc=73.1%, grad_norm=0.96, time=24.61s


                                                           

Epoch 5/5: loss=0.8298, train_acc=70.7%, test_acc=73.2%, grad_norm=0.98, time=24.78s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=1.4353, train_acc=48.3%, test_acc=62.4%, grad_norm=0.79, time=24.47s


                                                           

Epoch 2/5: loss=1.1179, train_acc=60.5%, test_acc=66.6%, grad_norm=0.93, time=24.58s


                                                           

Epoch 3/5: loss=1.0168, train_acc=64.4%, test_acc=68.9%, grad_norm=0.95, time=24.43s


                                                           

Epoch 4/5: loss=0.9490, train_acc=66.5%, test_acc=70.3%, grad_norm=0.95, time=24.60s


                                                           

Epoch 5/5: loss=0.9068, train_acc=68.3%, test_acc=71.8%, grad_norm=0.98, time=24.79s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=1.4110, train_acc=48.9%, test_acc=62.6%, grad_norm=0.90, time=26.97s


                                                           

Epoch 2/5: loss=1.0860, train_acc=61.5%, test_acc=66.5%, grad_norm=0.99, time=27.10s


                                                           

Epoch 3/5: loss=0.9690, train_acc=65.6%, test_acc=69.6%, grad_norm=1.01, time=27.00s


                                                           

Epoch 4/5: loss=0.8990, train_acc=68.5%, test_acc=71.7%, grad_norm=1.00, time=26.91s


                                                           

Epoch 5/5: loss=0.8498, train_acc=70.2%, test_acc=73.7%, grad_norm=1.01, time=26.84s

### Activation: APU


                                                           

Epoch 1/5: loss=1.5606, train_acc=44.8%, test_acc=57.1%, grad_norm=0.55, time=30.10s


                                                           

Epoch 2/5: loss=1.2629, train_acc=55.2%, test_acc=62.4%, grad_norm=0.70, time=29.95s


                                                           

Epoch 3/5: loss=1.1636, train_acc=58.8%, test_acc=64.6%, grad_norm=0.89, time=30.05s


                                                           

Epoch 4/5: loss=1.1226, train_acc=60.7%, test_acc=65.8%, grad_norm=1.05, time=29.81s


                                                           

Epoch 5/5: loss=1.0693, train_acc=62.4%, test_acc=66.3%, grad_norm=1.06, time=29.76s
Models initialized for SimpleCNN on CIFAR10: 6
Models used for SimpleCNN on CIFAR10: 6


Running benchmark for ResNet on CIFAR10

===== Running ResNet on CIFAR10 =====

### Activation: ReLU


                                                           

Epoch 1/5: loss=1.4600, train_acc=46.5%, test_acc=59.2%, grad_norm=0.15, time=60.43s


                                                           

Epoch 2/5: loss=0.9839, train_acc=65.0%, test_acc=69.7%, grad_norm=0.13, time=60.39s


                                                           

Epoch 3/5: loss=0.7746, train_acc=72.8%, test_acc=74.8%, grad_norm=0.12, time=60.45s


                                                           

Epoch 4/5: loss=0.6377, train_acc=77.9%, test_acc=78.8%, grad_norm=0.11, time=60.54s


                                                           

Epoch 5/5: loss=0.5477, train_acc=80.9%, test_acc=80.2%, grad_norm=0.10, time=60.39s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=1.4129, train_acc=48.2%, test_acc=62.0%, grad_norm=0.16, time=60.38s


                                                           

Epoch 2/5: loss=0.9437, train_acc=66.9%, test_acc=64.6%, grad_norm=0.13, time=60.79s


                                                           

Epoch 3/5: loss=0.7398, train_acc=74.1%, test_acc=74.5%, grad_norm=0.11, time=60.70s


                                                           

Epoch 4/5: loss=0.6171, train_acc=78.7%, test_acc=78.3%, grad_norm=0.11, time=60.62s


                                                           

Epoch 5/5: loss=0.5326, train_acc=81.6%, test_acc=81.2%, grad_norm=0.10, time=60.43s

### Activation: PReLU


                                                           

Epoch 1/5: loss=1.4495, train_acc=46.7%, test_acc=58.6%, grad_norm=0.12, time=68.39s


                                                           

Epoch 2/5: loss=0.9351, train_acc=66.7%, test_acc=70.5%, grad_norm=0.11, time=68.31s


                                                           

Epoch 3/5: loss=0.6994, train_acc=75.6%, test_acc=76.4%, grad_norm=0.10, time=68.24s


                                                           

Epoch 4/5: loss=0.5560, train_acc=80.7%, test_acc=79.6%, grad_norm=0.10, time=68.27s


                                                           

Epoch 5/5: loss=0.4775, train_acc=83.6%, test_acc=81.6%, grad_norm=0.09, time=68.18s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=1.5446, train_acc=42.5%, test_acc=57.3%, grad_norm=0.11, time=68.37s


                                                           

Epoch 2/5: loss=0.9868, train_acc=65.1%, test_acc=67.6%, grad_norm=0.10, time=68.39s


                                                           

Epoch 3/5: loss=0.7487, train_acc=73.6%, test_acc=75.6%, grad_norm=0.10, time=68.36s


                                                           

Epoch 4/5: loss=0.6067, train_acc=78.9%, test_acc=79.7%, grad_norm=0.09, time=68.30s


                                                           

Epoch 5/5: loss=0.5245, train_acc=81.9%, test_acc=81.8%, grad_norm=0.08, time=68.23s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=1.5199, train_acc=43.8%, test_acc=54.7%, grad_norm=0.15, time=82.73s


                                                           

Epoch 2/5: loss=1.0236, train_acc=63.3%, test_acc=66.9%, grad_norm=0.12, time=82.67s


                                                           

Epoch 3/5: loss=0.8204, train_acc=71.0%, test_acc=75.2%, grad_norm=0.11, time=82.59s


                                                           

Epoch 4/5: loss=0.6698, train_acc=76.7%, test_acc=76.5%, grad_norm=0.10, time=83.34s


                                                           

Epoch 5/5: loss=0.5804, train_acc=79.7%, test_acc=80.6%, grad_norm=0.10, time=83.07s

### Activation: APU


                                                           

Epoch 1/5: loss=1.8697, train_acc=31.2%, test_acc=41.5%, grad_norm=0.07, time=102.62s


                                                           

Epoch 2/5: loss=1.3922, train_acc=49.3%, test_acc=54.8%, grad_norm=0.08, time=102.81s


                                                           

Epoch 3/5: loss=1.1221, train_acc=60.2%, test_acc=63.9%, grad_norm=0.12, time=102.65s


                                                           

Epoch 4/5: loss=0.9941, train_acc=64.9%, test_acc=68.2%, grad_norm=0.16, time=102.68s


                                                           

Epoch 5/5: loss=0.9112, train_acc=68.1%, test_acc=71.2%, grad_norm=0.20, time=102.95s
Models initialized for ResNet on CIFAR10: 6
Models used for ResNet on CIFAR10: 6


Running benchmark for Transformer on CIFAR10

===== Running Transformer on CIFAR10 =====

### Activation: ReLU


                                                           

Epoch 1/5: loss=1.9101, train_acc=30.3%, test_acc=36.0%, grad_norm=0.20, time=23.86s


                                                           

Epoch 2/5: loss=1.7805, train_acc=35.2%, test_acc=38.8%, grad_norm=0.21, time=23.43s


                                                           

Epoch 3/5: loss=1.7338, train_acc=37.1%, test_acc=39.3%, grad_norm=0.22, time=23.53s


                                                           

Epoch 4/5: loss=1.7077, train_acc=37.9%, test_acc=39.5%, grad_norm=0.24, time=23.48s


                                                           

Epoch 5/5: loss=1.6639, train_acc=39.8%, test_acc=40.6%, grad_norm=0.23, time=23.84s

### Activation: LeakyReLU


                                                           

Epoch 1/5: loss=1.9189, train_acc=30.0%, test_acc=35.4%, grad_norm=0.19, time=23.74s


                                                           

Epoch 2/5: loss=1.7806, train_acc=35.1%, test_acc=37.3%, grad_norm=0.20, time=23.50s


                                                           

Epoch 3/5: loss=1.7233, train_acc=37.4%, test_acc=40.3%, grad_norm=0.21, time=23.39s


                                                           

Epoch 4/5: loss=1.6780, train_acc=39.3%, test_acc=38.5%, grad_norm=0.20, time=23.40s


                                                           

Epoch 5/5: loss=1.6672, train_acc=39.5%, test_acc=41.9%, grad_norm=0.22, time=23.54s

### Activation: PReLU


                                                           

Epoch 1/5: loss=1.9129, train_acc=29.9%, test_acc=36.2%, grad_norm=0.19, time=23.77s


                                                           

Epoch 2/5: loss=1.7658, train_acc=36.1%, test_acc=39.3%, grad_norm=0.20, time=23.81s


                                                           

Epoch 3/5: loss=1.7130, train_acc=37.8%, test_acc=36.9%, grad_norm=0.21, time=23.74s


                                                           

Epoch 4/5: loss=1.6827, train_acc=39.0%, test_acc=42.1%, grad_norm=0.22, time=23.67s


                                                           

Epoch 5/5: loss=1.6440, train_acc=40.5%, test_acc=42.0%, grad_norm=0.21, time=23.66s

### Activation: SoftReLU


                                                           

Epoch 1/5: loss=1.9241, train_acc=29.8%, test_acc=35.7%, grad_norm=0.18, time=23.51s


                                                           

Epoch 2/5: loss=1.7861, train_acc=35.1%, test_acc=38.5%, grad_norm=0.20, time=23.43s


                                                           

Epoch 3/5: loss=1.7252, train_acc=37.5%, test_acc=40.9%, grad_norm=0.20, time=23.41s


                                                           

Epoch 4/5: loss=1.6865, train_acc=39.0%, test_acc=41.7%, grad_norm=0.20, time=23.83s


                                                           

Epoch 5/5: loss=1.6587, train_acc=39.9%, test_acc=40.9%, grad_norm=0.19, time=23.69s

### Activation: ESReLU


                                                           

Epoch 1/5: loss=1.9289, train_acc=29.3%, test_acc=33.4%, grad_norm=0.20, time=24.04s


                                                           

Epoch 2/5: loss=1.8031, train_acc=34.6%, test_acc=35.7%, grad_norm=0.21, time=23.84s


                                                           

Epoch 3/5: loss=1.7526, train_acc=36.4%, test_acc=39.7%, grad_norm=0.22, time=23.95s


                                                           

Epoch 4/5: loss=1.6994, train_acc=38.0%, test_acc=41.5%, grad_norm=0.21, time=24.10s


                                                           

Epoch 5/5: loss=1.6695, train_acc=39.3%, test_acc=40.8%, grad_norm=0.20, time=24.01s

### Activation: APU


                                                           

Epoch 1/5: loss=2.0148, train_acc=25.9%, test_acc=29.2%, grad_norm=0.14, time=24.40s


                                                           

Epoch 2/5: loss=1.9047, train_acc=30.8%, test_acc=34.9%, grad_norm=0.14, time=24.67s


                                                           

Epoch 3/5: loss=1.8178, train_acc=34.2%, test_acc=38.9%, grad_norm=0.15, time=24.82s


                                                           

Epoch 4/5: loss=1.7415, train_acc=36.8%, test_acc=38.7%, grad_norm=0.15, time=24.65s


                                                           

Epoch 5/5: loss=1.7314, train_acc=37.2%, test_acc=38.3%, grad_norm=0.17, time=24.68s
Models initialized for Transformer on CIFAR10: 6
Models used for Transformer on CIFAR10: 6


Running benchmark for SimpleCNN on EMNIST

===== Running SimpleCNN on EMNIST =====


100%|██████████| 562M/562M [00:03<00:00, 148MB/s]  



### Activation: ReLU


                                                             

Epoch 1/5: loss=0.6310, train_acc=79.9%, test_acc=84.5%, grad_norm=0.75, time=31.97s


                                                             

Epoch 2/5: loss=0.3770, train_acc=86.7%, test_acc=86.1%, grad_norm=0.56, time=31.15s


                                                             

Epoch 3/5: loss=0.3123, train_acc=88.5%, test_acc=86.2%, grad_norm=0.49, time=31.21s


                                                             

Epoch 4/5: loss=0.2658, train_acc=89.8%, test_acc=86.8%, grad_norm=0.46, time=30.89s


                                                             

Epoch 5/5: loss=0.2269, train_acc=91.1%, test_acc=86.8%, grad_norm=0.43, time=30.96s

### Activation: LeakyReLU


                                                             

Epoch 1/5: loss=0.5945, train_acc=80.8%, test_acc=85.5%, grad_norm=0.67, time=30.73s


                                                             

Epoch 2/5: loss=0.3513, train_acc=87.4%, test_acc=86.2%, grad_norm=0.50, time=30.83s


                                                             

Epoch 3/5: loss=0.2867, train_acc=89.2%, test_acc=87.1%, grad_norm=0.44, time=30.89s


                                                             

Epoch 4/5: loss=0.2374, train_acc=90.7%, test_acc=86.6%, grad_norm=0.40, time=31.15s


                                                             

Epoch 5/5: loss=0.1969, train_acc=92.0%, test_acc=86.8%, grad_norm=0.38, time=31.38s

### Activation: PReLU


                                                             

Epoch 1/5: loss=0.5852, train_acc=81.1%, test_acc=85.6%, grad_norm=0.53, time=33.36s


                                                             

Epoch 2/5: loss=0.3517, train_acc=87.4%, test_acc=86.1%, grad_norm=0.47, time=33.05s


                                                             

Epoch 3/5: loss=0.2884, train_acc=89.2%, test_acc=87.0%, grad_norm=0.44, time=33.24s


                                                             

Epoch 4/5: loss=0.2371, train_acc=90.7%, test_acc=87.1%, grad_norm=0.42, time=33.05s


                                                             

Epoch 5/5: loss=0.1949, train_acc=92.1%, test_acc=86.8%, grad_norm=0.40, time=33.22s

### Activation: SoftReLU


                                                             

Epoch 1/5: loss=0.5772, train_acc=81.3%, test_acc=85.0%, grad_norm=0.60, time=32.79s


                                                             

Epoch 2/5: loss=0.3531, train_acc=87.4%, test_acc=85.8%, grad_norm=0.49, time=32.84s


                                                             

Epoch 3/5: loss=0.2872, train_acc=89.2%, test_acc=86.5%, grad_norm=0.45, time=32.90s


                                                             

Epoch 4/5: loss=0.2381, train_acc=90.7%, test_acc=86.7%, grad_norm=0.43, time=32.84s


                                                             

Epoch 5/5: loss=0.1964, train_acc=92.0%, test_acc=86.3%, grad_norm=0.41, time=32.98s

### Activation: ESReLU


                                                             

Epoch 1/5: loss=0.5893, train_acc=81.0%, test_acc=85.7%, grad_norm=0.64, time=36.39s


                                                             

Epoch 2/5: loss=0.3532, train_acc=87.3%, test_acc=86.2%, grad_norm=0.48, time=36.79s


                                                             

Epoch 3/5: loss=0.2884, train_acc=89.1%, test_acc=86.8%, grad_norm=0.44, time=36.60s


                                                             

Epoch 4/5: loss=0.2391, train_acc=90.7%, test_acc=86.4%, grad_norm=0.41, time=36.64s


                                                             

Epoch 5/5: loss=0.1992, train_acc=92.0%, test_acc=86.5%, grad_norm=0.39, time=36.78s

### Activation: APU


                                                             

Epoch 1/5: loss=0.7308, train_acc=77.3%, test_acc=82.4%, grad_norm=0.42, time=38.60s


                                                             

Epoch 2/5: loss=0.4845, train_acc=83.6%, test_acc=83.1%, grad_norm=0.53, time=38.70s


                                                             

Epoch 3/5: loss=0.4405, train_acc=84.8%, test_acc=83.6%, grad_norm=0.93, time=38.51s


                                                             

Epoch 4/5: loss=0.4349, train_acc=84.9%, test_acc=80.3%, grad_norm=1.72, time=38.60s


                                                             

Epoch 5/5: loss=0.4302, train_acc=85.0%, test_acc=83.3%, grad_norm=2.16, time=38.45s
Models initialized for SimpleCNN on EMNIST: 6
Models used for SimpleCNN on EMNIST: 6


Running benchmark for ResNet on EMNIST

===== Running ResNet on EMNIST =====

### Activation: ReLU


                                                             

Epoch 1/5: loss=0.5110, train_acc=82.5%, test_acc=86.2%, grad_norm=0.09, time=109.23s


                                                             

Epoch 2/5: loss=0.3411, train_acc=87.4%, test_acc=87.2%, grad_norm=0.05, time=109.27s


                                                             

Epoch 3/5: loss=0.3032, train_acc=88.7%, test_acc=88.3%, grad_norm=0.04, time=109.33s


                                                             

Epoch 4/5: loss=0.2792, train_acc=89.4%, test_acc=88.3%, grad_norm=0.03, time=109.16s


                                                             

Epoch 5/5: loss=0.2571, train_acc=90.1%, test_acc=89.3%, grad_norm=0.03, time=109.29s

### Activation: LeakyReLU


                                                             

Epoch 1/5: loss=0.5122, train_acc=82.5%, test_acc=86.6%, grad_norm=0.08, time=109.65s


                                                             

Epoch 2/5: loss=0.3421, train_acc=87.5%, test_acc=86.8%, grad_norm=0.05, time=109.52s


                                                             

Epoch 3/5: loss=0.3035, train_acc=88.6%, test_acc=87.9%, grad_norm=0.04, time=109.63s


                                                             

Epoch 4/5: loss=0.2809, train_acc=89.3%, test_acc=89.0%, grad_norm=0.03, time=109.77s


                                                             

Epoch 5/5: loss=0.2585, train_acc=90.0%, test_acc=89.5%, grad_norm=0.03, time=109.58s

### Activation: PReLU


                                                             

Epoch 1/5: loss=0.5061, train_acc=82.8%, test_acc=84.7%, grad_norm=0.07, time=123.43s


                                                             

Epoch 2/5: loss=0.3323, train_acc=87.7%, test_acc=87.4%, grad_norm=0.04, time=123.45s


                                                             

Epoch 3/5: loss=0.2946, train_acc=88.9%, test_acc=88.6%, grad_norm=0.03, time=123.46s


                                                             

Epoch 4/5: loss=0.2674, train_acc=89.8%, test_acc=88.5%, grad_norm=0.03, time=123.44s


                                                             

Epoch 5/5: loss=0.2448, train_acc=90.4%, test_acc=89.4%, grad_norm=0.03, time=123.55s

### Activation: SoftReLU


                                                             

Epoch 1/5: loss=0.4955, train_acc=83.1%, test_acc=86.2%, grad_norm=0.07, time=122.87s


                                                             

Epoch 2/5: loss=0.3295, train_acc=87.9%, test_acc=87.6%, grad_norm=0.04, time=122.96s


                                                             

Epoch 3/5: loss=0.2975, train_acc=88.8%, test_acc=88.7%, grad_norm=0.03, time=122.97s


                                                             

Epoch 4/5: loss=0.2704, train_acc=89.7%, test_acc=88.2%, grad_norm=0.03, time=122.95s


                                                             

Epoch 5/5: loss=0.2477, train_acc=90.4%, test_acc=88.7%, grad_norm=0.03, time=122.82s

### Activation: ESReLU


                                                             

Epoch 1/5: loss=0.5123, train_acc=82.4%, test_acc=84.2%, grad_norm=0.09, time=146.40s


                                                             

Epoch 2/5: loss=0.3431, train_acc=87.4%, test_acc=86.8%, grad_norm=0.05, time=146.20s


                                                             

Epoch 3/5: loss=0.3070, train_acc=88.6%, test_acc=88.0%, grad_norm=0.04, time=146.33s


                                                             

Epoch 4/5: loss=0.2804, train_acc=89.4%, test_acc=88.2%, grad_norm=0.03, time=146.37s


                                                             

Epoch 5/5: loss=0.2601, train_acc=89.9%, test_acc=89.0%, grad_norm=0.03, time=146.42s

### Activation: APU


                                                             

Epoch 1/5: loss=0.7572, train_acc=76.2%, test_acc=85.0%, grad_norm=0.06, time=181.40s


                                                             

Epoch 2/5: loss=0.3717, train_acc=86.7%, test_acc=87.5%, grad_norm=0.04, time=181.58s


                                                             

Epoch 3/5: loss=0.3270, train_acc=87.9%, test_acc=87.5%, grad_norm=0.04, time=181.59s


                                                             

Epoch 4/5: loss=0.2983, train_acc=88.8%, test_acc=88.4%, grad_norm=0.04, time=181.61s


                                                             

Epoch 5/5: loss=0.2770, train_acc=89.4%, test_acc=88.1%, grad_norm=0.04, time=181.70s
Models initialized for ResNet on EMNIST: 6
Models used for ResNet on EMNIST: 6


Running benchmark for Transformer on EMNIST

===== Running Transformer on EMNIST =====

### Activation: ReLU


                                                             

Epoch 1/5: loss=0.8273, train_acc=74.5%, test_acc=80.5%, grad_norm=0.23, time=32.83s


                                                             

Epoch 2/5: loss=0.5311, train_acc=81.8%, test_acc=81.8%, grad_norm=0.22, time=32.75s


                                                             

Epoch 3/5: loss=0.4739, train_acc=83.5%, test_acc=82.8%, grad_norm=0.21, time=33.07s


                                                             

Epoch 4/5: loss=0.4394, train_acc=84.3%, test_acc=83.5%, grad_norm=0.21, time=32.79s


                                                             

Epoch 5/5: loss=0.4164, train_acc=85.1%, test_acc=82.9%, grad_norm=0.21, time=32.71s

### Activation: LeakyReLU


                                                             

Epoch 1/5: loss=0.8271, train_acc=74.3%, test_acc=80.6%, grad_norm=0.23, time=33.30s


                                                             

Epoch 2/5: loss=0.5300, train_acc=81.8%, test_acc=80.9%, grad_norm=0.21, time=32.98s


                                                             

Epoch 3/5: loss=0.4634, train_acc=83.8%, test_acc=82.4%, grad_norm=0.20, time=32.67s


                                                             

Epoch 4/5: loss=0.4440, train_acc=84.2%, test_acc=83.6%, grad_norm=0.21, time=33.05s


                                                             

Epoch 5/5: loss=0.4332, train_acc=84.5%, test_acc=82.5%, grad_norm=0.23, time=32.90s

### Activation: PReLU


                                                             

Epoch 1/5: loss=0.8208, train_acc=74.5%, test_acc=80.1%, grad_norm=0.23, time=33.51s


                                                             

Epoch 2/5: loss=0.5252, train_acc=82.0%, test_acc=81.7%, grad_norm=0.22, time=33.73s


                                                             

Epoch 3/5: loss=0.4722, train_acc=83.4%, test_acc=82.1%, grad_norm=0.22, time=33.65s


                                                             

Epoch 4/5: loss=0.4293, train_acc=84.7%, test_acc=84.1%, grad_norm=0.21, time=33.28s


                                                             

Epoch 5/5: loss=0.4054, train_acc=85.4%, test_acc=83.4%, grad_norm=0.21, time=33.89s

### Activation: SoftReLU


                                                             

Epoch 1/5: loss=0.8469, train_acc=73.8%, test_acc=81.0%, grad_norm=0.21, time=33.36s


                                                             

Epoch 2/5: loss=0.5135, train_acc=82.3%, test_acc=82.7%, grad_norm=0.19, time=33.41s


                                                             

Epoch 3/5: loss=0.4470, train_acc=84.2%, test_acc=83.2%, grad_norm=0.18, time=33.66s


                                                             

Epoch 4/5: loss=0.4103, train_acc=85.2%, test_acc=84.3%, grad_norm=0.18, time=33.60s


                                                             

Epoch 5/5: loss=0.3830, train_acc=86.0%, test_acc=84.0%, grad_norm=0.18, time=33.51s

### Activation: ESReLU


                                                             

Epoch 1/5: loss=0.8435, train_acc=73.9%, test_acc=80.0%, grad_norm=0.23, time=34.78s


                                                             

Epoch 2/5: loss=0.5341, train_acc=81.8%, test_acc=82.4%, grad_norm=0.21, time=34.69s


                                                             

Epoch 3/5: loss=0.4816, train_acc=83.2%, test_acc=83.0%, grad_norm=0.21, time=34.52s


                                                             

Epoch 4/5: loss=0.4441, train_acc=84.2%, test_acc=82.8%, grad_norm=0.21, time=34.09s


                                                             

Epoch 5/5: loss=0.4165, train_acc=85.0%, test_acc=83.7%, grad_norm=0.21, time=34.39s

### Activation: APU


                                                             

Epoch 1/5: loss=1.0938, train_acc=67.6%, test_acc=77.1%, grad_norm=0.22, time=35.69s


                                                             

Epoch 2/5: loss=0.6424, train_acc=78.8%, test_acc=78.5%, grad_norm=0.21, time=35.78s


                                                             

Epoch 3/5: loss=0.5559, train_acc=81.3%, test_acc=81.0%, grad_norm=0.20, time=36.21s


                                                             

Epoch 4/5: loss=0.4968, train_acc=82.8%, test_acc=81.9%, grad_norm=0.18, time=36.00s


                                                             

Epoch 5/5: loss=0.4790, train_acc=83.3%, test_acc=82.9%, grad_norm=0.19, time=36.22s
Models initialized for Transformer on EMNIST: 6
Models used for Transformer on EMNIST: 6


Running GAN benchmark on MNIST

===== Running GAN on MNIST =====

### Activation: ReLU
Epoch 1/5: d_loss=0.1732, g_loss=3.8635, real_score=0.9083, fake_score=0.1765, time=15.00s
Epoch 2/5: d_loss=0.0505, g_loss=6.0432, real_score=0.9728, fake_score=0.0394, time=15.28s
Epoch 3/5: d_loss=0.0373, g_loss=6.6036, real_score=0.9808, fake_score=0.0259, time=14.97s
Epoch 4/5: d_loss=0.0263, g_loss=7.0473, real_score=0.9866, fake_score=0.0172, time=15.35s
Epoch 5/5: d_loss=0.0211, g_loss=7.6734, real_score=0.9898, fake_score=0.0130, time=15.07s

### Activation: LeakyReLU
Epoch 1/5: d_loss=0.1767, g_loss=3.7994, real_score=0.9040, fake_score=0.1791, time=15.18s
Epoch 2/5: d_loss=0.0525, g_loss=5.8269, real_score=0.9709, fake_score=0.0413, time=15.23s
Epoch 3/5: d_loss=0.0353, g_loss=6.4890, real_score=0.9818, fake_score=0.