# Impact of Pixel Normalization on the Performance of Different CNN Activation Functions on FashionMNIST

## 1. Conceptual question
Does pixel normalization affect the performance of ReLU and Tanh activation functions in CNNs?

---

## 2. Hypothesis
We hypothesize that Tanh performs better when input pixel values are normalized to a narrow range near zero (e.g., [-1, 1]), whereas ReLU demonstrates superior performance when pixel values span a wider range (e.g., [-5, 5]).

---

## 3. Experimental design
In this experiment, a CNN architecture was designed in which all structural components remained constant except the activation function. The manipulated (independent) variables were the input pixel-value range ([-1, 1] vs. [-5, 5]) and the activation function type (ReLU vs. Tanh); the architecture, optimizer, learning rate, batch size, and number of epochs were held fixed. To test the hypothesis, we compared the loss convergence and test accuracy of models using ReLU versus Tanh across the two input ranges. This comparison aims to clearly elucidate the performance differences between CNNs that use different activation functions.

---

## 4. Experiment code
Code used to generate the results.

### 1. Setup

In [None]:
#Setup the dependencies
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import json
import os
import numpy as np
from matplotlib.patches import Patch

# Set the hyperparameters and the global state shared by every experiment below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook draws from: torch (weight init, DataLoader
# shuffling) and NumPy (np.random.choice when sub-sampling recorded values).
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)

# Root directory for saved figures. The plotting cells write to
# f"{save_dir}/figures/...", and savefig does not create missing folders,
# so create that folder up front.
save_dir = "./results_project2_DL"
os.makedirs(f"{save_dir}/figures", exist_ok=True)

config = {
    "batch_size": 128,
    "num_workers": 2,
    "num_epochs": 20,
    "learning_rate": 0.01,
    "switch_epoch": 5,  # NOTE(review): not read by any code visible in this notebook
    "device": device,
    "save_dir": save_dir,
}

### 2. Data Loading (Small Range [-1,1])

In [None]:
# Pixel-range setting 1: ToTensor() yields values in [0, 1]; Normalize with
# mean 0.5 and std 0.5 computes (x - 0.5) / 0.5, mapping them onto [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both splits share the same root, download flag and transform; only `train` differs.
_small_range_dataset_kwargs = {
    "root": "./data_project2_DL",
    "download": True,
    "transform": transform,
}
train_dataset = datasets.FashionMNIST(train=True, **_small_range_dataset_kwargs)
test_dataset = datasets.FashionMNIST(train=False, **_small_range_dataset_kwargs)

# Only the training split is shuffled.
_small_range_loader_kwargs = {
    "batch_size": config["batch_size"],
    "num_workers": config["num_workers"],
}
train_loader = DataLoader(train_dataset, shuffle=True, **_small_range_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_small_range_loader_kwargs)

### 3. Data Loading (Large Range [-5,5])

In [None]:
# Pixel-range setting 2: ToTensor() yields values in [0, 1]; Normalize with
# mean 0.5 and std 0.1 computes (x - 0.5) / 0.1, mapping them onto [-5, 5].
transform_new = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.1,)),
    ]
)

# Both splits share the same root, download flag and transform; only `train` differs.
_large_range_dataset_kwargs = {
    "root": "./data_project2_DL",
    "download": True,
    "transform": transform_new,
}
train_dataset_modified = datasets.FashionMNIST(train=True, **_large_range_dataset_kwargs)
test_dataset_modified = datasets.FashionMNIST(train=False, **_large_range_dataset_kwargs)

# Only the training split is shuffled.
_large_range_loader_kwargs = {
    "batch_size": config["batch_size"],
    "num_workers": config["num_workers"],
}
train_loader_modified = DataLoader(
    train_dataset_modified, shuffle=True, **_large_range_loader_kwargs
)
test_loader_modified = DataLoader(
    test_dataset_modified, shuffle=False, **_large_range_loader_kwargs
)

### 4. Define a CNN whose activation function can be either ReLU or Tanh

In [None]:
class different_activation_CNN(nn.Module):
    """Five-conv / three-linear CNN for 1x28x28 inputs with a switchable activation.

    The layer stack is identical for every run; only the nonlinearity applied
    after each hidden layer ("ReLU" or "Tanh") changes, so activations can be
    compared in isolation.

    Args:
        activation: "ReLU" or "Tanh". Any other value raises ValueError.

    Attributes:
        activation_records: list of (layer_name, np.ndarray) pairs appended by
            ``forward`` while ``use_record`` is True. Each array holds up to
            ``max_recorded_values`` randomly sampled layer outputs taken
            *before* the activation is applied.
        use_record: when True, ``forward`` records pre-activation samples.
    """

    _SUPPORTED_ACTIVATIONS = ("ReLU", "Tanh")
    # Upper bound on the number of values sampled per layer per recorded pass.
    max_recorded_values = 5000

    def __init__(self, activation="ReLU"):
        super().__init__()
        self.activation = self._validate_activation(activation)
        # Spatial sizes for a 28x28 input:
        # conv1 (pad 1) 28 -> pool 14 -> conv2 (pad 1) 14 -> pool 7
        # -> conv3 (pad 1) 7 -> conv4 5 -> conv5 3, hence 64*3*3 flattened features.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.lin1 = nn.Linear(in_features=64 * 3 * 3, out_features=64)
        self.lin2 = nn.Linear(in_features=64, out_features=32)
        self.lin3 = nn.Linear(in_features=32, out_features=10)
        self.activation_records = []
        self.use_record = False

    @classmethod
    def _validate_activation(cls, activation):
        """Return ``activation`` unchanged, or raise ValueError if it is unsupported."""
        if activation not in cls._SUPPORTED_ACTIVATIONS:
            raise ValueError(
                f"activation must be one of {cls._SUPPORTED_ACTIVATIONS}, got {activation!r}"
            )
        return activation

    def _record(self, name, x):
        """Append up to ``max_recorded_values`` randomly sampled values of ``x`` under ``name``."""
        values = x.detach().cpu().numpy().flatten()
        if len(values) > self.max_recorded_values:
            values = np.random.choice(values, self.max_recorded_values, replace=False)
        self.activation_records.append((name, values))

    def forward(self, x):
        """Run the network and return raw logits of shape (batch, 10).

        When ``use_record`` is True, each layer's output is sampled into
        ``activation_records`` before the nonlinearity is applied.
        """
        layers = [
            ("conv1", self.conv1, True),
            ("conv2", self.conv2, True),
            ("conv3", self.conv3, False),
            ("conv4", self.conv4, False),
            ("conv5", self.conv5, False),
            ("lin1", self.lin1, False),
            ("lin2", self.lin2, False),
            ("lin3", self.lin3, False),
        ]

        for name, layer, if_pool in layers:
            # Flatten the conv feature map *before* the first linear layer:
            # nn.Linear only acts on the last dimension, so a (B, 64, 3, 3)
            # tensor cannot be fed to lin1 (in_features=576) directly.
            if name.startswith("lin") and x.dim() > 2:
                x = x.reshape(x.shape[0], -1)

            x = layer(x)

            if self.use_record:  # sample pre-activation values for later plotting
                self._record(name, x)

            # lin3 produces the logits, so no nonlinearity is applied to it.
            if name != "lin3":
                x = F.relu(x) if self.activation == "ReLU" else torch.tanh(x)

            if if_pool:
                x = F.max_pool2d(x, 2)

        return x

    def switch_activation(self, activation):
        """Change the nonlinearity ("ReLU" or "Tanh") used by subsequent forward passes."""
        self.activation = self._validate_activation(activation)

### 5. Training and Evaluation Functions

In [None]:
def single_activation_train_with_recording(model, train_loader, device, optimizer, num_epochs=None):
    """Train ``model`` with cross-entropy, sampling layer outputs once per epoch.

    On the first batch of every epoch, ``model.use_record`` is switched on so the
    model fills ``model.activation_records``. Those samples are collected for
    the tracked layers (conv1-conv5, lin1, lin2). Any other layer a model
    records, such as its logits layer ``lin3``, is ignored.

    Args:
        model: network exposing ``use_record`` and ``activation_records``.
        train_loader: iterable of (images, labels) batches that supports ``len``.
        device: device the batches are moved to.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``; defaults to
            ``config["num_epochs"]``.

    Returns:
        (losses, epoch_activations):
        - ``losses`` is a numpy array holding the mean batch loss of each epoch.
        - ``epoch_activations`` maps each tracked layer name to a list with one
          sampled array per epoch.
    """
    if num_epochs is None:
        num_epochs = config["num_epochs"]

    model.train()
    criterion = nn.CrossEntropyLoss()
    losses = []
    tracked_layers = [f"conv{i}" for i in range(1, 6)] + [f"lin{i}" for i in range(1, 3)]
    epoch_activations = {name: [] for name in tracked_layers}

    for _ in range(num_epochs):
        one_epoch_loss = 0.0
        for idx, (img, label) in enumerate(train_loader):
            record = idx == 0  # sample only the first batch of each epoch
            model.use_record = record
            model.activation_records = []
            img = img.to(device)
            label = label.to(device)
            output = model(img)
            loss = criterion(output, label)
            one_epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if record:
                for layer_name, values in model.activation_records:
                    # Skip untracked layers (e.g. lin3) instead of raising KeyError.
                    if layer_name in epoch_activations:
                        epoch_activations[layer_name].append(values)
                model.activation_records = []
                model.use_record = False
        losses.append(one_epoch_loss)

    losses = np.array(losses) / len(train_loader)
    print("Average loss:", np.mean(losses))
    return losses, epoch_activations


def single_activation_test(model, test_loader, device):
    """Measure the classification accuracy of ``model`` on ``test_loader``.

    Switches the model to eval mode and runs without gradient tracking. The
    predicted class is the arg-max over dimension 1 of the model output.

    Returns:
        Accuracy as a percentage (0-100), which is also printed.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)
            predictions = model(images).max(dim=1).indices
            n_seen += targets.shape[0]
            n_correct += (predictions == targets).sum().item()
    accuracy = 100 * n_correct / n_seen
    print("Accuracy:", accuracy)
    return accuracy

### 5b. Visualization Function

In [None]:
def plot_activation_violinplots_comparison(epoch_activations1, activation_name1, epoch_activations2, activation_name2, save_dir, split_epoch=5):
    """Draw per-layer violin plots comparing the recorded layer outputs of two runs.

    The figure has one row per layer (conv1-conv5, lin1, lin2):
    - The left column shows the first ``split_epoch`` epochs.
    - The right column shows the remaining epochs.

    Within each epoch, run 1 is drawn in light blue and run 2 in light coral.
    Only epochs present in both runs are plotted. The figure is saved to
    ``{save_dir}/figures/`` (created if missing) and then shown.

    Args:
        epoch_activations1, epoch_activations2: dicts mapping layer name to a
            list holding one 1-D array of sampled values per epoch.
        activation_name1, activation_name2: legend labels, also used in the file name.
        save_dir: root output directory.
        split_epoch: number of leading epochs shown in the left column.
    """
    def plot_violin_subplot(ax, data1, data2, epoch_indices, xlabel_suffix, colors):
        # With no epochs to show, hide the axis rather than call violinplot on no data.
        if not epoch_indices:
            ax.axis('off')
            return

        combined, positions = [], []
        for i, epoch_idx in enumerate(epoch_indices):
            combined.append(data1[epoch_idx])
            combined.append(data2[epoch_idx])
            # Each epoch occupies 3 x-units; the two runs sit side by side.
            positions.append(i * 3 + 1)
            positions.append(i * 3 + 1.8)

        parts = ax.violinplot(combined, positions=positions, widths=0.7,
                              showmeans=True, showmedians=True)
        for i, pc in enumerate(parts['bodies']):
            pc.set_facecolor(colors[i % 2])  # alternate run-1 / run-2 colours
            pc.set_alpha(0.7)
            pc.set_edgecolor('black')
            pc.set_linewidth(0.5)
        for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians', 'cmeans'):
            if partname in parts:
                parts[partname].set_edgecolor('black')
                parts[partname].set_linewidth(1)

        ax.set_xticks([i * 3 + 1.4 for i in range(len(epoch_indices))])
        ax.set_xticklabels([str(i + 1) for i in epoch_indices])  # 1-based epoch labels
        ax.set_xlabel(f"Epoch {xlabel_suffix}", fontsize=11)
        # The model samples each layer's output before its nonlinearity is applied.
        ax.set_ylabel("Pre-activation Value", fontsize=11)
        ax.grid(True, alpha=0.3, axis='y')
        ax.tick_params(axis="x", rotation=45)

        # Fit the y-range to the data with a 5% margin (0.1 if all values are equal).
        flat = np.concatenate([np.ravel(arr) for arr in combined])
        vmin, vmax = flat.min(), flat.max()
        margin = (vmax - vmin) * 0.05 if vmax != vmin else 0.1
        ax.set_ylim(vmin - margin, vmax + margin)

    layer_names = ["conv1", "conv2", "conv3", "conv4", "conv5", "lin1", "lin2"]
    # Only epochs recorded for both runs can be paired.
    num_epochs = min(len(epoch_activations1[layer_names[0]]),
                     len(epoch_activations2[layer_names[0]]))
    early_epochs = list(range(0, min(split_epoch, num_epochs)))
    late_epochs = list(range(split_epoch, num_epochs))
    colors = ['lightblue', 'lightcoral']

    fig, axes = plt.subplots(len(layer_names), 2, figsize=(18, 28),
                             gridspec_kw={'width_ratios': [8, 12]})
    fig.suptitle(f"Activation Distribution Comparison: {activation_name1} vs {activation_name2}",
                 fontsize=18, y=0.995)

    for idx, layer_name in enumerate(layer_names):
        ax_early = axes[idx, 0]
        plot_violin_subplot(ax_early, epoch_activations1[layer_name],
                            epoch_activations2[layer_name],
                            early_epochs, f"(first {split_epoch})", colors)
        ax_early.set_title(f"{layer_name} (first {split_epoch})", fontsize=13,
                           fontweight="bold", pad=10)

        ax_late = axes[idx, 1]
        if late_epochs:
            plot_violin_subplot(ax_late, epoch_activations1[layer_name],
                                epoch_activations2[layer_name],
                                late_epochs, "(later)", colors)
            ax_late.set_title(f"{layer_name} (later)", fontsize=13, fontweight="bold", pad=10)
        else:
            ax_late.axis('off')

    legend_elements = [Patch(facecolor='lightblue', alpha=0.7, label=activation_name1),
                       Patch(facecolor='lightcoral', alpha=0.7, label=activation_name2)]
    fig.legend(handles=legend_elements, loc='upper right', fontsize=13, bbox_to_anchor=(0.98, 0.99))
    plt.tight_layout(rect=[0, 0, 1, 0.995])
    # savefig does not create directories, so make sure the target folder exists.
    os.makedirs(f"{save_dir}/figures", exist_ok=True)
    plt.savefig(f"{save_dir}/figures/activation_violinplots_comparison_{activation_name1}_vs_{activation_name2}.png",
                dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)

### 6. Train, Evaluate, and Visualize for models with small pixel value range

In [None]:
# Experiment 1: pixel values normalized to [-1, 1].
# Re-seed right before each model is built. __init__ creates the same layers
# regardless of activation, so both models start from identical weights, and
# the training DataLoader should then draw the same shuffle order for each run.
torch.manual_seed(42)
model_relu = different_activation_CNN(activation="ReLU").to(device)
optimizer_relu = optim.SGD(model_relu.parameters(), lr=config["learning_rate"], momentum=0.5)
losses_relu, epoch_activations_relu = single_activation_train_with_recording(
    model_relu, train_loader, device, optimizer_relu
)

torch.manual_seed(42)
model_tanh = different_activation_CNN(activation="Tanh").to(device)
optimizer_tanh = optim.SGD(model_tanh.parameters(), lr=config["learning_rate"], momentum=0.5)
losses_tanh, epoch_activations_tanh = single_activation_train_with_recording(
    model_tanh, train_loader, device, optimizer_tanh
)

accuracy_relu = single_activation_test(model_relu, test_loader, device)
accuracy_tanh = single_activation_test(model_tanh, test_loader, device)
print(f"[-1, 1] test accuracy -> ReLU: {accuracy_relu:.2f}%, Tanh: {accuracy_tanh:.2f}%")


# Here is visualization. Epochs are numbered from 1, matching the violin-plot tick labels.
os.makedirs(f"{save_dir}/figures", exist_ok=True)
loss_epochs = range(1, len(losses_relu) + 1)
plt.figure()
plt.plot(loss_epochs, losses_relu, "-o")
plt.plot(loss_epochs, losses_tanh, "-o")
plt.title("Losses between ReLU and Tanh (pixel range [-1, 1])")
plt.legend(["ReLU", "Tanh"])
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(linestyle='-', alpha=0.3)
plt.savefig(f"{save_dir}/figures/comparisons.png")
plt.show()

plot_activation_violinplots_comparison(epoch_activations_relu, 'ReLU', epoch_activations_tanh, 'Tanh', save_dir)

### 7. Train, Evaluate, and Visualize for models with large pixel value range

In [None]:
# Experiment 2: pixel values normalized to [-5, 5].
# Re-seed right before each model is built. __init__ creates the same layers
# regardless of activation, so both models start from identical weights, and
# the training DataLoader should then draw the same shuffle order for each run.
torch.manual_seed(42)
model_relu_new = different_activation_CNN(activation="ReLU").to(device)
optimizer_relu_new = optim.SGD(model_relu_new.parameters(), lr=config["learning_rate"], momentum=0.5)
losses_relu_new, epoch_activations_relu_new = single_activation_train_with_recording(
    model_relu_new, train_loader_modified, device, optimizer_relu_new
)

torch.manual_seed(42)
model_tanh_new = different_activation_CNN(activation="Tanh").to(device)
optimizer_tanh_new = optim.SGD(model_tanh_new.parameters(), lr=config["learning_rate"], momentum=0.5)
losses_tanh_new, epoch_activations_tanh_new = single_activation_train_with_recording(
    model_tanh_new, train_loader_modified, device, optimizer_tanh_new
)

accuracy_relu_new = single_activation_test(model_relu_new, test_loader_modified, device)
accuracy_tanh_new = single_activation_test(model_tanh_new, test_loader_modified, device)
print(f"[-5, 5] test accuracy -> ReLU: {accuracy_relu_new:.2f}%, Tanh: {accuracy_tanh_new:.2f}%")


# Here is visualization. Epochs are numbered from 1, matching the violin-plot tick labels.
os.makedirs(f"{save_dir}/figures", exist_ok=True)
loss_epochs_new = range(1, len(losses_relu_new) + 1)
plt.figure()
plt.plot(loss_epochs_new, losses_relu_new, '-o')
plt.plot(loss_epochs_new, losses_tanh_new, '-o')
plt.title("Losses between ReLU and Tanh (pixel range [-5, 5])")
plt.legend(["ReLU", "Tanh"])
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(linestyle='-', alpha=0.3)
plt.savefig(f"{save_dir}/figures/comparisons_new.png")
plt.show()

plot_activation_violinplots_comparison(epoch_activations_relu_new, 'ReLU_new', epoch_activations_tanh_new, 'Tanh_new', save_dir)