# In [None]:
# NOTE: This code was used for the experiments as-is.
# Naming and structure may not follow programming best practices;
# the focus is on reproducibility.
# This code was developed for internal experimentation and contains hard-coded values for the various test cases.
# It was not refactored for modularity, but its logic matches the experiments reported in the paper.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torchvision.models import mobilenet_v3_small
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import random

def set_all_seeds(seed):
    """Seed every random generator used by the experiment and pin cuDNN.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    module with the same ``seed``. It also asks cuDNN for deterministic
    kernels, disables its benchmark mode, and then prints the seed that was
    applied.
    """
    # PyTorch generators: CPU first, then every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic kernels, no per-shape benchmarking.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Remaining Python-side generators.
    np.random.seed(seed)
    random.seed(seed)

    print(f"Seed set to: {seed}")

#######################################################
# Experiment seed. Candidate seeds kept for this script: 40, 41, 42, 43, 44.
# Change EXPERIMENT_SEED to reproduce a different run.
#######################################################
EXPERIMENT_SEED = 42
set_all_seeds(EXPERIMENT_SEED)

# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# FashionMNIST single-channel statistics (mean, std) used for normalization.
FASHION_MNIST_MEAN = (0.2860,)
FASHION_MNIST_STD = (0.3530,)

# Training pipeline: random 28x28 crop of a 4-pixel-padded image, random
# horizontal flip, upscale to 32 (chosen for MobileNetV3), then tensor + normalize.
transform_train = transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(FASHION_MNIST_MEAN, FASHION_MNIST_STD),
])

# Evaluation pipeline: no augmentation, only the same resize and normalization.
transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(FASHION_MNIST_MEAN, FASHION_MNIST_STD),
])

# Download (if missing) and wrap both splits.
_dataset_kwargs = dict(root='./data', download=True)
train_data = torchvision.datasets.FashionMNIST(train=True, transform=transform_train, **_dataset_kwargs)
test_data = torchvision.datasets.FashionMNIST(train=False, transform=transform_test, **_dataset_kwargs)

# Shuffled batches for training, fixed order for evaluation.
BATCH_SIZE = 64
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

# Dataset summary.
print(f"Number of training samples: {len(train_data)}")
print(f"Number of test samples: {len(test_data)}")
print(f"Number of classes: {len(train_data.classes)}")
print(f"Classes: {train_data.classes}")

# Sanity check on a single training batch.
# NOTE(review): drawing this batch presumably consumes the global torch RNG
# (shuffling sampler + random augmentations). Removing it would likely shift
# every later random draw under a fixed seed, so verify before deleting.
images, labels = next(iter(train_loader))
print(f'Batch of images shape: {images.shape}')
print(f'Batch of labels: {labels}')

# Custom activation that replaces every ReLU/Hardswish in the network.
class Idle(nn.Module):
    """Elementwise activation under study; currently MishPlus (UP).

    Active formula::

        f(x) = x * tanh(softplus(x)) + 0.025 * x * exp(-0.5 * x**2)

    Alternative formulas kept for this study (swap into ``forward`` to use):

    * Swish:            x * sigmoid(x)
    * ESwish (UP):      1.25 * x * sigmoid(x)
    * SwishPlus (UP):   x * (sigmoid(x) + 0.125 * exp(-0.5 * x**2))
    * ESwish (DOWN):    0.95 * x * sigmoid(x)
    * SwishPlus (DOWN): x * (sigmoid(x) - 0.025 * exp(-0.5 * x**2))
    * Mish:             x * tanh(softplus(x))
    * PMish (UP):       x * tanh(softplus(0.9454113159514 * x) / 0.9454113159514)
    * MishPlus (UP):    the active formula above
    * PMish (DOWN):     x * tanh(softplus(1.34198859922 * x) / 1.34198859922)
    * MishPlus (DOWN):  x * tanh(softplus(x)) - 0.125 * x * exp(-0.5 * x**2)
    * ReLU:             max(x, 0)
    """

    def forward(self, x):
        # Mish term plus a Gaussian-windowed linear "bump", evaluated and
        # summed in the same operation order as the reference expression.
        mish = x * torch.tanh(F.softplus(x))
        bump = 0.025 * x * torch.exp(-0.5 * x ** 2)
        return mish + bump

# Custom weight initialization function
def initialize_weights(module):
    if isinstance(module, nn.Conv2d):
        # Using Xavier/Glorot Uniform initialization for convolutional layers
        nn.init.xavier_uniform_(module.weight, gain=1.0)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)
    elif isinstance(module, nn.Linear):
        # Using Orthogonal initialization for linear layers
        nn.init.orthogonal_(module.weight, gain=1.0)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0.0)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1.0)
        nn.init.constant_(module.bias, 0.0)


# MobileNetV3-Small adapted to 1-channel FashionMNIST inputs, with every
# ReLU/Hardswish replaced by the custom Idle activation.
class CustomMobileNetV3(nn.Module):
    """torchvision MobileNetV3-Small with a grayscale stem, a ``num_classes``
    classifier layer and all ReLU/Hardswish activations replaced by ``Idle``.

    Args:
        num_classes: output size of the final linear layer (10 for FashionMNIST).
        pretrained: forwarded to ``mobilenet_v3_small``.
            When True, the loaded backbone is frozen and keeps its weights.
            Only the newly created stem conv and classifier layer are trainable
            and receive ``initialize_weights``.
            When False (the default used in this script), every
            Conv2d/Linear/BatchNorm2d is re-initialized with ``initialize_weights``.
    """

    def __init__(self, num_classes=10, pretrained=False):
        super(CustomMobileNetV3, self).__init__()
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; confirm against the installed version.
        self.mobilenet = mobilenet_v3_small(pretrained=pretrained)

        # Replace the stock stem conv with a 1-input-channel one (FashionMNIST
        # is grayscale); kernel, stride, padding and width (16) are unchanged.
        stem = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1, bias=False)
        self.mobilenet.features[0][0] = stem

        if pretrained:
            # Freeze the loaded backbone, but keep the freshly created stem
            # trainable: it has no pretrained weights and would otherwise stay
            # at its random initialization forever.
            for param in self.mobilenet.parameters():
                param.requires_grad = False
            for param in stem.parameters():
                param.requires_grad = True

        # New classifier layer; created after any freeze, so it is trainable.
        in_features = self.mobilenet.classifier[3].in_features
        head = nn.Linear(in_features, num_classes)
        self.mobilenet.classifier[3] = head

        # Replace activations
        self.replace_activations(self.mobilenet)

        if pretrained:
            # Re-initializing the whole network would overwrite the weights just
            # loaded, so only the layers created above are initialized.
            stem.apply(initialize_weights)
            head.apply(initialize_weights)
            print("Applied custom initialization only to the new stem conv (Xavier/Uniform) and classifier layer (Orthogonal); pretrained weights kept")
        else:
            # Same construction/initialization order as before, so seeded runs
            # with pretrained=False are unaffected.
            self.apply(initialize_weights)
            print("Applied custom Xavier/Uniform initialization to convolutional layers and Orthogonal to linear layers")

    def replace_activations(self, module):
        """Recursively replace every ``nn.ReLU`` / ``nn.Hardswish`` below
        ``module`` with a fresh ``Idle`` instance, in place."""
        for name, child in module.named_children():
            if isinstance(child, (nn.ReLU, nn.Hardswish)):
                # setattr re-registers the submodule under its existing name,
                # which also keeps its position inside nn.Sequential containers.
                setattr(module, name, Idle())
            elif len(list(child.children())) > 0:
                # Container module: recurse into its children.
                self.replace_activations(child)

    def forward(self, x):
        """Run the wrapped MobileNetV3 and return its classifier output."""
        return self.mobilenet(x)

# Verification function to check if replacement worked
def verify_activation_replacement(model):
    """Check that no ``nn.ReLU`` or ``nn.Hardswish`` remains anywhere in ``model``.

    Prints how many ReLU, Hardswish and ``Idle`` modules ``model.modules()``
    yields (this includes ``model`` itself).

    Returns:
        True if at least one ``Idle`` module is present, otherwise False.

    Raises:
        AssertionError: if any ReLU or Hardswish is still present. It is raised
            explicitly rather than via ``assert``, so the check still runs
            under ``python -O``.
    """
    relu_count = hardswish_count = idle_count = 0
    for submodule in model.modules():
        if isinstance(submodule, nn.ReLU):
            relu_count += 1
        if isinstance(submodule, nn.Hardswish):
            hardswish_count += 1
        if isinstance(submodule, Idle):
            idle_count += 1

    print(f"Found {relu_count} ReLU activations, {hardswish_count} Hardswish activations, and {idle_count} Idle activations")
    if relu_count != 0 or hardswish_count != 0:
        raise AssertionError("Some activations were not replaced!")
    return idle_count > 0

def calculate_batch_accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class matches ``labels``.

    ``outputs`` holds per-class scores along dim 1; ``labels`` holds one class
    index per row.
    """
    predictions = outputs.max(dim=1).indices
    n_correct = (predictions == labels).sum().item()
    return n_correct / labels.size(0)

def calculate_accuracy(model, dataloader, target_device=None):
    """Return the sample-weighted top-1 accuracy of ``model`` over ``dataloader``.

    The model is run in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so calling this in the middle of
    training does not leave the model in eval mode.

    Args:
        model: classifier whose output holds per-class scores along dim 1.
        dataloader: iterable of ``(inputs, labels)`` batches.
        target_device: device every batch is moved to. Defaults to the
            module-level ``device``.

    Returns:
        Correct predictions divided by the total number of samples.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    if target_device is None:
        target_device = device
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs, labels = inputs.to(target_device), labels.to(target_device)
                _, predicted = torch.max(model(inputs), 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    return correct / total

# Model, optimizer, learning-rate schedule and loss.
model = CustomMobileNetV3().to(device)
optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# StepLR: multiply the learning rate by 0.1 every 10 scheduler steps
# (the training loop steps it once per epoch).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Stop early if any ReLU/Hardswish survived the activation swap.
verify_activation_replacement(model)

# Per-epoch metric histories, filled in by the training loop.
train_accuracies, test_accuracies = [], []
train_losses, test_losses = [], []
num_epochs = 40

# Per-epoch metrics are sample-weighted (per-sample sums / number of samples)
# rather than a mean of per-batch means, so a smaller final batch is not
# over-weighted. NOTE: this changes only the reported numbers, not training;
# they may differ slightly from runs that averaged per-batch means.
for epoch in range(num_epochs):
    # ---- Training phase ----
    model.train()
    train_loss_sum = 0.0
    train_correct = 0
    train_seen = 0

    for inputs, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs} (Training)'):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss (default reduction) averages over the batch, so
        # scale by the batch size to accumulate a per-sample sum.
        n_batch = labels.size(0)
        train_loss_sum += loss.item() * n_batch
        # Accuracy is measured on the outputs computed before this update.
        train_correct += (torch.max(outputs, 1)[1] == labels).sum().item()
        train_seen += n_batch

    # One StepLR step per epoch.
    scheduler.step()
    avg_train_loss = train_loss_sum / train_seen
    avg_train_acc = train_correct / train_seen

    # ---- Testing phase ----
    model.eval()
    test_loss_sum = 0.0
    test_correct = 0
    test_seen = 0

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc=f'Epoch {epoch + 1}/{num_epochs} (Testing)'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            n_batch = labels.size(0)
            test_loss_sum += loss.item() * n_batch
            test_correct += (torch.max(outputs, 1)[1] == labels).sum().item()
            test_seen += n_batch

    avg_test_loss = test_loss_sum / test_seen
    avg_test_acc = test_correct / test_seen

    # Store metrics
    train_accuracies.append(avg_train_acc)
    test_accuracies.append(avg_test_acc)
    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)

    print(f'Epoch {epoch + 1}/{num_epochs}')
    print(f'Training Loss: {avg_train_loss:.4f}, Training Accuracy: {avg_train_acc:.4f}')
    print(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {avg_test_acc:.4f}')

# Summarize the best value of each metric and the 1-based epoch where it
# first occurred.
def _best_with_epoch(history, pick):
    """Return ``(value, epoch)`` for ``pick(history)``, where ``epoch`` is the
    1-based index of the first occurrence of that value."""
    value = pick(history)
    return value, history.index(value) + 1


best_train_acc, best_train_acc_epoch = _best_with_epoch(train_accuracies, max)
best_test_acc, best_test_acc_epoch = _best_with_epoch(test_accuracies, max)
lowest_train_loss, lowest_train_loss_epoch = _best_with_epoch(train_losses, min)
lowest_test_loss, lowest_test_loss_epoch = _best_with_epoch(test_losses, min)

_rule = "=" * 50
print("\n" + _rule)
print("BEST METRICS SUMMARY")
print(_rule)
print(f"Best Training Accuracy: {best_train_acc:.4f} (Epoch {best_train_acc_epoch})")
print(f"Best Test Accuracy: {best_test_acc:.4f} (Epoch {best_test_acc_epoch})")
print(f"Lowest Training Loss: {lowest_train_loss:.4f} (Epoch {lowest_train_loss_epoch})")
print(f"Lowest Test Loss: {lowest_test_loss:.4f} (Epoch {lowest_test_loss_epoch})")
print(_rule)