<a href="https://colab.research.google.com/github/abusumon/Neural-Network/blob/main/SANN_Neural_Network.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
class SANNActivation(nn.Module):
    """Learnable piecewise-linear activation: ReLU plus a sum of hinge terms.

    For every input element ``x`` the output is::

        relu(x) + sum_s a[s] * relu(b[s] - x)

    where the slopes ``a`` and breakpoints ``b`` are trainable parameters.
    With ``shared=True`` a single set of ``num_segment`` (a, b) pairs is used
    for every neuron; otherwise each neuron owns its own set, stored as
    ``(num_neurons, num_segment)`` tensors.

    Args:
        num_neurons: Number of neurons, i.e. the size of dim 1 of the input
            (only used for the parameter shapes when ``shared`` is False).
        num_segment: Number of hinge terms per neuron.
        shared: If True, all neurons share the same slopes and breakpoints.
    """

    def __init__(self, num_neurons, num_segment, shared=False) -> None:
        super().__init__()
        self.num_neurons = num_neurons
        self.num_segment = num_segment
        self.shared = shared

        # Slopes start at zero, so the module is exactly ReLU at initialisation.
        # Breakpoints must start at *distinct* values: if every b[s] were equal
        # (as with an all-zeros init), every segment would receive identical
        # gradients on every step and the hinges would collapse into one.
        breakpoints = self._initial_breakpoints(num_segment)
        if shared:
            self.a = nn.Parameter(torch.zeros(num_segment))
            self.b = nn.Parameter(breakpoints)
        else:
            self.a = nn.Parameter(torch.zeros(num_neurons, num_segment))
            # Same breakpoint layout for every neuron: shape (num_neurons, num_segment).
            self.b = nn.Parameter(breakpoints.repeat(num_neurons, 1))

    @staticmethod
    def _initial_breakpoints(num_segment):
        """Return ``num_segment`` distinct breakpoints evenly spaced over [-1, 1].

        With zero or one segment there is no symmetry to break, so the
        breakpoint(s) stay at 0.
        """
        if num_segment <= 1:
            return torch.zeros(num_segment)
        return torch.linspace(-1.0, 1.0, num_segment)

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: Input tensor. When ``shared`` is False, dim 1 must have size
                ``num_neurons`` (e.g. ``(batch, num_neurons)`` or
                ``(batch, num_neurons, H, W)``).

        Returns:
            Tensor with the same shape as ``x`` for inputs of rank >= 2.
        """
        output = F.relu(x)
        if self.shared:
            for s in range(self.num_segment):
                output = output + self.a[s] * F.relu(self.b[s] - x)
            return output

        # Per-neuron parameters: reshape each column to (1, num_neurons, 1, ...)
        # so it broadcasts along dim 1 of x for any input rank. The shape is
        # loop-invariant, so compute it once.
        param_shape = [1, self.num_neurons] + [1] * (x.dim() - 2)
        for s in range(self.num_segment):
            b_s = self.b[:, s].reshape(param_shape)
            a_s = self.a[:, s].reshape(param_shape)
            output = output + a_s * F.relu(b_s - x)
        return output

In [None]:
  # def activation(self, x_vals, neuron_idx = 0):
  #   x_tensor = torch.tensor(x_vals, dtype=torch.float32)
  #   if self.shared:
  #     a_vals = self.a.detach()
  #     b_vals = self.b.detach()
  #   else:
  #     a_vals = self.a[neuron_idx].detach()
  #     b_vals = self.b[neuron_idx].detach()
  #   output = torch.relu(x_tensor)
  #   for s in range(self.num_segment):
  #     hinge = torch.relu(b_vals[s]-x_tensor)
  #     output = output + a_vals[s] * hinge
  #   return output.numpy()

In [None]:
class DenseLayer(nn.Module):
    """Fully connected layer computing ``x @ w + b`` with an optional activation.

    Args:
        input_size: Number of input features (last dim of ``x``).
        output_size: Number of output features.
        activation: Optional callable (e.g. an ``nn.Module``) applied to the
            affine output; ``None`` means the raw affine output is returned.
    """

    def __init__(self, input_size, output_size, activation=None):
        super().__init__()
        # Weights: Gaussian scaled by 0.1; biases: zeros.
        self.w = nn.Parameter(torch.randn(input_size, output_size) * 0.1)
        self.b = nn.Parameter(torch.zeros(output_size))
        self.activation = activation

    def forward(self, x):
        """Return ``activation(x @ w + b)``, or ``x @ w + b`` if no activation."""
        out = x @ self.w + self.b
        # Compare with None explicitly: module truthiness is not a reliable
        # "is set" test (container modules such as nn.Sequential define
        # __len__, so an empty one is falsy).
        if self.activation is not None:
            return self.activation(out)
        return out

In [None]:
class SANNModel(nn.Module):
    """Five-layer MLP whose four hidden layers use SANNActivation.

    Hidden widths grow as hidden_size, 2x, 4x and 8x hidden_size; the last
    layer is a plain affine projection to ``output_size`` with no activation.
    """

    def __init__(self, input_size, hidden_size, output_size, num_segments):
        super().__init__()
        w1, w2, w3, w4 = (hidden_size * k for k in (1, 2, 4, 8))
        self.layer1 = DenseLayer(input_size, w1, activation=SANNActivation(w1, num_segments))
        self.layer2 = DenseLayer(w1, w2, activation=SANNActivation(w2, num_segments))
        self.layer3 = DenseLayer(w2, w3, activation=SANNActivation(w3, num_segments))
        self.layer4 = DenseLayer(w3, w4, activation=SANNActivation(w4, num_segments))
        self.layer5 = DenseLayer(w4, output_size)

    def forward(self, x):
        """Run ``x`` through all five layers in order."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        return x

In [None]:
class Trainer:
    """Minimal Keras-style training loop around a PyTorch classification model.

    Call ``compile(optimizer, loss)`` first, then ``fit`` / ``evaluate`` /
    ``predict``. Predictions are the index of the highest score along dim 1
    of the model output.
    """

    def __init__(self, model, device=None):
        self.model = model
        # Falsy/omitted device -> CUDA when available, otherwise CPU.
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.criterion = None
        self.optimizer = None

    def compile(self, optimizer, loss):
        """Attach the optimizer and loss function used by ``fit``."""
        self.optimizer = optimizer
        self.criterion = loss

    def fit(self, train_loader, epochs=1, val_loader=None):
        """Train for ``epochs`` passes over ``train_loader``.

        After each epoch, prints the mean per-batch training loss and, when
        ``val_loader`` is given, the accuracy on it (via ``evaluate``).

        Raises:
            RuntimeError: if ``compile`` has not been called.
        """
        if self.optimizer is None or self.criterion is None:
            raise RuntimeError("Trainer.compile(optimizer, loss) must be called before fit().")
        for epoch in range(epochs):
            self.model.train()
            running_loss = 0.0
            num_batches = 0
            for data, target in train_loader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()
                running_loss += loss.item()
                num_batches += 1

            # Count batches ourselves: len(loader) is undefined for loaders
            # over an IterableDataset and would divide by zero when empty.
            avg_loss = running_loss / num_batches if num_batches else float('nan')
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

            # `is not None` rather than truthiness: bool(loader) calls len(),
            # which raises for IterableDataset-backed loaders.
            if val_loader is not None:
                val_acc = self.evaluate(val_loader)
                print(f"Validation Accuracy: {val_acc:.2f}%")

    def evaluate(self, data_loader):
        """Return classification accuracy, in percent, over ``data_loader``.

        Returns ``nan`` when the loader yields no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in data_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                _, preds = torch.max(output, 1)
                correct += (preds == target).sum().item()
                total += target.size(0)
        if total == 0:
            return float('nan')
        return 100 * correct / total

    def predict(self, data_loader):
        """Return predicted class indices for every sample in ``data_loader``.

        Each batch may be a bare input tensor or a list/tuple whose first
        element is the input (any labels are ignored). Predictions are
        concatenated on the CPU; an empty loader gives an empty long tensor.
        Assumes model outputs are (batch, classes) — TODO confirm for other shapes.
        """
        self.model.eval()
        all_preds = []
        with torch.no_grad():
            for batch in data_loader:
                inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
                output = self.model(inputs.to(self.device))
                _, preds = torch.max(output, 1)
                all_preds.append(preds.cpu())
        if not all_preds:
            # torch.cat raises on an empty list.
            return torch.empty(0, dtype=torch.long)
        return torch.cat(all_preds)

In [None]:
# Architecture: flattened 28x28 inputs, 10 output classes, 4 hinges per SANN activation.
input_size, hidden_size, output_size = 784, 128, 10
num_segments = 4

# Optimisation settings.
batch_size, epochs, lr = 64, 5, 0.001

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def _flatten_image(img):
    """Flatten the ToTensor output into a 1-D vector (784 values for MNIST)."""
    return img.view(-1)


transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(_flatten_image)])

In [None]:
def _mnist_split(train):
    """Return the MNIST train/test split from ./data (download=True), with `transform` applied."""
    return datasets.MNIST(root='./data', train=train, download=True, transform=transform)


train_dataset = _mnist_split(train=True)
test_dataset = _mnist_split(train=False)

# Shuffle only the training data; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Build the network plus a standalone loss/optimiser pair.
# NOTE(review): `criterion` and `optimizer` are not handed to the Trainer
# (trainer.compile creates its own); confirm they are still needed.
model = SANNModel(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_segments=num_segments,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Wrap the model built in the cell above rather than constructing a second
# SANNModel here: rebuilding it would discard that instance and leave the
# module-level `optimizer` holding parameters that are never trained.
trainer = Trainer(model, device=device)

In [None]:
# Attach a fresh Adam optimiser over the model's parameters and a cross-entropy loss.
trainer.compile(optim.Adam(model.parameters(), lr=lr), nn.CrossEntropyLoss())

In [None]:
import numpy

# Environment diagnostic: report the installed NumPy version.
print(f"{numpy.__version__}")

2.3.2


In [None]:
import PIL

# Environment diagnostic: report the installed Pillow version.
print(f"{PIL.__version__}")

11.3.0


In [None]:
# NOTE(review): the test split doubles as the validation set, so the per-epoch
# "Validation Accuracy" and the final test accuracy measure the same data.
# Consider holding out part of the training set for validation instead.
trainer.fit(train_loader=train_loader, val_loader=test_loader, epochs=epochs)

Epoch 1/5, Loss: 0.2895
Validation Accuracy: 95.78%
Epoch 2/5, Loss: 0.1101
Validation Accuracy: 95.84%
Epoch 3/5, Loss: 0.0827
Validation Accuracy: 96.46%
Epoch 4/5, Loss: 0.0705
Validation Accuracy: 96.94%
Epoch 5/5, Loss: 0.0598
Validation Accuracy: 97.24%


In [None]:
# Final accuracy, in percent, on the MNIST test split.
test_acc = trainer.evaluate(data_loader=test_loader)
print(f"Test Accuracy: {test_acc:.2f}%")

Test Accuracy: 97.24%
