In [9]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Define Mish activation
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))

# LSUV (Layer-sequential unit variance) Initialization
@torch.no_grad()
def lsuv_init(model, data_loader, device):
    """Applies LSUV initialization to a model."""
    model.to(device)
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight.normal_(0, 1)
            module.bias.zero_()
    
    for images, _ in data_loader:
        images = images.to(device)
        break
    
    with torch.no_grad():
        output = model(images)
        std = output.std().item()
        if std > 1e-6:
            for module in model.modules():
                if isinstance(module, nn.Linear):
                    module.weight.data /= std
    return model

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Define the neural network model
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 32 * 3, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.mish1 = Mish()
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.mish2 = Mish()
        self.fc3 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.mish1(self.bn1(self.fc1(x)))
        x = self.mish2(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        return x

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet().to(device)
model = lsuv_init(model, trainloader, device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    running_loss = 0.0
    all_preds = []
    all_labels = []
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}")
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        progress_bar.set_postfix(loss=running_loss / len(trainloader))
    
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Epoch {epoch+1}: Loss: {running_loss / len(trainloader):.4f}, Acc: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

# Evaluation
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average='macro')
test_recall = recall_score(all_labels, all_preds, average='macro')
test_f1 = f1_score(all_labels, all_preds, average='macro')

print(f"Test Accuracy: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")


Epoch 1: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 54.39it/s, loss=1.71]


Epoch 1: Loss: 1.7099, Acc: 0.3808, Precision: 0.3744, Recall: 0.3808, F1: 0.3732


Epoch 2: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.71it/s, loss=1.44]


Epoch 2: Loss: 1.4441, Acc: 0.4866, Precision: 0.4818, Recall: 0.4866, F1: 0.4826


Epoch 3: 100%|████████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.20it/s, loss=1.31]


Epoch 3: Loss: 1.3139, Acc: 0.5331, Precision: 0.5286, Recall: 0.5331, F1: 0.5298


Epoch 4: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.16it/s, loss=1.22]


Epoch 4: Loss: 1.2179, Acc: 0.5654, Precision: 0.5614, Recall: 0.5654, F1: 0.5627


Epoch 5: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.55it/s, loss=1.13]


Epoch 5: Loss: 1.1318, Acc: 0.5971, Precision: 0.5937, Recall: 0.5971, F1: 0.5948


Epoch 6: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.60it/s, loss=1.06]


Epoch 6: Loss: 1.0576, Acc: 0.6226, Precision: 0.6201, Recall: 0.6226, F1: 0.6209


Epoch 7: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.39it/s, loss=0.994]


Epoch 7: Loss: 0.9941, Acc: 0.6463, Precision: 0.6437, Recall: 0.6463, F1: 0.6446


Epoch 8: 100%|███████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.24it/s, loss=0.925]


Epoch 8: Loss: 0.9246, Acc: 0.6696, Precision: 0.6678, Recall: 0.6696, F1: 0.6684


Epoch 9: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.14it/s, loss=0.867]


Epoch 9: Loss: 0.8667, Acc: 0.6905, Precision: 0.6888, Recall: 0.6905, F1: 0.6894


Epoch 10: 100%|██████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 54.18it/s, loss=0.811]


Epoch 10: Loss: 0.8114, Acc: 0.7094, Precision: 0.7078, Recall: 0.7094, F1: 0.7084
Test Accuracy: 0.5366, Precision: 0.5392, Recall: 0.5366, F1: 0.5356


In [11]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Define the Mish activation function
class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))

# Poisson-based weight initialization function
def poisson_weight_initializer(shape, lam=3.0, scale=0.01):
    """
    Poisson-based weight initialization.
    Args:
        shape (tuple): Shape of the weight tensor.
        lam (float): Mean of the Poisson distribution.
        scale (float): Scaling factor to control weight magnitude.
    Returns:
        A PyTorch tensor with Poisson-initialized values.
    """
    weights = np.random.poisson(lam, shape).astype(np.float32)
    weights = weights - np.mean(weights)  # Zero centering
    weights = weights / np.std(weights)   # Normalize variance
    return torch.tensor(weights * scale, dtype=torch.float32)

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Define the neural network model with Mish activation and Poisson-initialized weights
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(32 * 32 * 3, 256)
        self.mish1 = Mish()
        self.fc2 = nn.Linear(256, 128)
        self.mish2 = Mish()
        self.fc3 = nn.Linear(128, 10)
        
        # Apply Poisson weight initialization
        with torch.no_grad():
            self.fc1.weight = nn.Parameter(poisson_weight_initializer(self.fc1.weight.shape, lam=5, scale=0.05))
            self.fc2.weight = nn.Parameter(poisson_weight_initializer(self.fc2.weight.shape, lam=5, scale=0.05))
            self.fc3.weight = nn.Parameter(poisson_weight_initializer(self.fc3.weight.shape, lam=5, scale=0.05))
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.mish1(self.fc1(x))
        x = self.mish2(self.fc2(x))
        x = self.fc3(x)
        return x

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    running_loss = 0.0
    all_preds = []
    all_labels = []
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}")
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        
        progress_bar.set_postfix(loss=running_loss / len(trainloader))
    
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    print(f"Epoch {epoch+1}: Loss: {running_loss / len(trainloader):.4f}, Acc: {acc:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

# Evaluation
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

test_acc = accuracy_score(all_labels, all_preds)
test_precision = precision_score(all_labels, all_preds, average='macro')
test_recall = recall_score(all_labels, all_preds, average='macro')
test_f1 = f1_score(all_labels, all_preds, average='macro')

print(f"Test Accuracy: {test_acc:.4f}, Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

Epoch 1: 100%|████████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.25it/s, loss=1.61]


Epoch 1: Loss: 1.6052, Acc: 0.4314, Precision: 0.4255, Recall: 0.4314, F1: 0.4268


Epoch 2: 100%|████████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 57.24it/s, loss=1.39]


Epoch 2: Loss: 1.3900, Acc: 0.5094, Precision: 0.5040, Recall: 0.5094, F1: 0.5054


Epoch 3: 100%|████████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.00it/s, loss=1.28]


Epoch 3: Loss: 1.2834, Acc: 0.5464, Precision: 0.5423, Recall: 0.5464, F1: 0.5436


Epoch 4: 100%|█████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 55.06it/s, loss=1.2]


Epoch 4: Loss: 1.1975, Acc: 0.5781, Precision: 0.5739, Recall: 0.5781, F1: 0.5753


Epoch 5: 100%|████████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 54.43it/s, loss=1.12]


Epoch 5: Loss: 1.1243, Acc: 0.6054, Precision: 0.6022, Recall: 0.6054, F1: 0.6033


Epoch 6: 100%|████████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.05it/s, loss=1.05]


Epoch 6: Loss: 1.0538, Acc: 0.6270, Precision: 0.6243, Recall: 0.6270, F1: 0.6252


Epoch 7: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 54.90it/s, loss=0.991]


Epoch 7: Loss: 0.9912, Acc: 0.6479, Precision: 0.6452, Recall: 0.6479, F1: 0.6462


Epoch 8: 100%|███████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.44it/s, loss=0.931]


Epoch 8: Loss: 0.9305, Acc: 0.6707, Precision: 0.6686, Recall: 0.6707, F1: 0.6693


Epoch 9: 100%|███████████████████████████████████████████████████████████| 782/782 [00:14<00:00, 54.44it/s, loss=0.871]


Epoch 9: Loss: 0.8710, Acc: 0.6915, Precision: 0.6894, Recall: 0.6915, F1: 0.6901


Epoch 10: 100%|███████████████████████████████████████████████████████████| 782/782 [00:13<00:00, 56.87it/s, loss=0.82]


Epoch 10: Loss: 0.8204, Acc: 0.7090, Precision: 0.7075, Recall: 0.7090, F1: 0.7080
Test Accuracy: 0.5195, Precision: 0.5245, Recall: 0.5195, F1: 0.5202
