<a href="https://colab.research.google.com/github/Aliiior/snntorch/blob/master/SNN_ocr_farsi.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the third-party packages used below into the running kernel's environment.
%pip install snntorch -q
%pip install ranger-adabelief -q

In [None]:
# Mount Google Drive inside the Colab VM (presumably where the dataset .npy files live).
from google.colab import drive

drive.mount(mountpoint='/content/drive')

# Test Accuracy: 98.43%

1. Excitatory neurons (positive influence): use a higher firing threshold of 1.0, with threshold learning enabled.

2. Inhibitory neurons (negative influence): use a lower firing threshold of 0.5, with threshold learning disabled.

3. Combined the spikes of the excitatory and inhibitory populations in the forward pass.

4. Increased Temporal Resolution

- Reason: A higher number of time steps allows for better temporal dynamics and more accurate spike-based computation

- Change: Increased num_steps from 50 to 60 for the spike sequence processing


5. Enhanced Dropout Rates

- Reason: Improved regularization to prevent overfitting during training

- Change: Slightly increased dropout rates for better generalization

  - `Dropout(0.5)` for the first fully connected layer

  - `Dropout(0.4)` for the second fully connected layer



6. Optimized Training with Mixed Precision

- Added PyTorch AMP for mixed-precision training
- Faster computation on GPUs
- Reduced memory usage without losing accuracy
- Uses `GradScaler` together with autocast

7. Weight Initialization Improvements

- Xavier (Glorot) normal initialization of the fully connected layers' weights, for better convergence.

8. Optimizer and Scheduler

- Used RangerAdaBelief for better generalization and optimization.
- Scheduler: `CosineAnnealingWarmRestarts` (`T_0=10`, `T_mult=2`) for cyclic learning-rate annealing with warm restarts, intended to encourage better convergence in later epochs.

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import snntorch as snn
from snntorch import surrogate
import numpy as np
import random
from ranger_adabelief import RangerAdaBelief

# Set Random Seed
# Seed PyTorch (CPU and CUDA), NumPy and Python's `random` with one value, and
# force cuDNN onto deterministic, non-autotuned kernels for reproducible runs.
seed = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# Class Distribution and Weights
# Per-class sample counts (keyed by float label) and inverse-frequency weights
# total / count, packed into a float32 tensor ordered by class index 0..9.
class_counts = {
    0.0: 19393, 1.0: 18148, 2.0: 15879, 3.0: 12010, 4.0: 11207,
    5.0: 11993, 6.0: 11323, 7.0: 13356, 8.0: 13666, 9.0: 11851,
}
total_samples = sum(class_counts.values())
class_weights = {label: total_samples / count for label, count in class_counts.items()}
_ordered_weights = [class_weights[float(c)] for c in range(10)]
weights = torch.tensor(_ordered_weights, dtype=torch.float32)

# Data Transformations
# Both pipelines resize to 28x28, force a single grayscale channel, convert to a
# tensor and normalise with mean 0.5 / std 0.5. Only the training pipeline adds
# random rotation (up to 15 degrees) and random scaling (0.8x-1.2x).
def _prefix_steps():
    """Fresh resize + grayscale steps shared by both pipelines."""
    return [transforms.Resize((28, 28)), transforms.Grayscale(num_output_channels=1)]


def _suffix_steps():
    """Fresh tensor-conversion + normalisation steps shared by both pipelines."""
    return [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]


train_transform = transforms.Compose(
    _prefix_steps()
    + [transforms.RandomRotation(15), transforms.RandomAffine(0, scale=(0.8, 1.2))]
    + _suffix_steps()
)

test_transform = transforms.Compose(_prefix_steps() + _suffix_steps())

# Custom Dataset Class
class FarsiDigitsDataset(Dataset):
    """Indexable dataset over in-memory image and label arrays.

    Each item is squeezed, converted to a float32 tensor, turned into a PIL
    image with ``ToPILImage`` and, when a transform is set, passed through it.
    Returns ``(image, label)`` with the label taken as-is from ``labels``.
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        target = self.labels[idx]
        # squeeze() drops singleton axes (presumably a channel axis) before PIL conversion.
        pixels = torch.tensor(self.data[idx].squeeze(), dtype=torch.float32)
        sample = transforms.ToPILImage()(pixels)
        if self.transform:
            sample = self.transform(sample)
        return sample, target

# Load and Preprocess Data
# Set DATA_DIR to the directory containing the .npy files (for example a folder
# under /content/drive once Google Drive is mounted). Every file path is built
# from it, so the location only has to be edited in one place.
DATA_DIR = '...'
# File-name suffixes identifying the five source files, in concatenation order.
SOURCE_IDS = (1241, 873, 12610, 14579, 109523)

data_parts = [np.load(f'{DATA_DIR}/data_{i}_before_preprocessing.npy') for i in SOURCE_IDS]
label_parts = [np.load(f'{DATA_DIR}/labels_{i}_before_preprocessing.npy') for i in SOURCE_IDS]

# Divide by 255 to scale pixels to [0, 1] (assumes raw arrays hold 0-255 intensities).
data = np.concatenate(data_parts, axis=0) / 255.0
labels = np.concatenate(label_parts, axis=0)
# Release the per-file arrays so they do not double peak memory alongside `data`.
del data_parts, label_parts

if len(data) != len(labels):
    raise ValueError(f"data has {len(data)} samples but labels has {len(labels)}")

# Dataset Splitting
# Randomly partition the samples 70% / 15% / remainder into train / validation /
# test, then rebuild each part as its own FarsiDigitsDataset so training uses
# train_transform while validation and test use test_transform.
dataset = FarsiDigitsDataset(data, labels, transform=None)
n_samples = len(dataset)
train_size = int(0.7 * n_samples)
val_size = int(0.15 * n_samples)
test_size = n_samples - (train_size + val_size)
_splits = random_split(dataset, [train_size, val_size, test_size])


def _with_transform(subset, transform):
    """Copy the rows picked by a random_split Subset into a new FarsiDigitsDataset."""
    rows = subset.indices
    return FarsiDigitsDataset(subset.dataset.data[rows],
                              subset.dataset.labels[rows],
                              transform=transform)


train_dataset = _with_transform(_splits[0], train_transform)
val_dataset = _with_transform(_splits[1], test_transform)
test_dataset = _with_transform(_splits[2], test_transform)

# Only the training loader shuffles; evaluation order stays fixed.
_loader_specs = ((train_dataset, True), (val_dataset, False), (test_dataset, False))
train_loader, val_loader, test_loader = [
    DataLoader(ds, batch_size=128, shuffle=flag) for ds, flag in _loader_specs
]

# Enhanced SNN Model (including exc / inh)

class SNN_Enhanced(nn.Module):
    """Convolutional front-end feeding excitatory and inhibitory LIF populations.

    Pipeline:
    - Two conv -> ReLU -> 2x2 max-pool stages, then LayerNorm over the 64x7x7
      feature map. This shape requires single-channel 28x28 inputs.
    - fc1 (+ Dropout 0.5), then population coding: the 512 fc1 features are
      tiled ``population_size`` times.
    - fc2 (+ Dropout 0.4).
    - ``num_steps`` simulation steps in which the same fc2 activation drives
      both LIF populations. Each step's excitatory-minus-inhibitory spike
      vector is read out by fc3.

    Returns:
        Tensor of shape (batch, 10): the per-step fc3 outputs summed over time.
    """

    def __init__(self):
        super(SNN_Enhanced, self).__init__()
        self.num_steps = 60
        self.population_size = 10

        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.layer_norm = nn.LayerNorm([64, 7, 7])

        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.dropout1 = nn.Dropout(0.5)

        # fc2 consumes the population-coded fc1 output (512 features x population_size).
        self.fc2 = nn.Linear(512 * self.population_size, 128)
        self.dropout2 = nn.Dropout(0.4)

        self.fc3 = nn.Linear(128, 10)

        # Xavier-normal weights for the fully connected layers (biases keep PyTorch defaults).
        for fc in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(fc.weight)

        # Excitatory population: threshold 1.0, learnable.
        self.lif_exc = snn.Leaky(beta=0.9, spike_grad=surrogate.fast_sigmoid(), threshold=1.0, learn_threshold=True)
        # Inhibitory population: lower, fixed threshold 0.5 and faster leak (beta 0.85).
        self.lif_inh = snn.Leaky(beta=0.85, spike_grad=surrogate.fast_sigmoid(), threshold=0.5, learn_threshold=False)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.layer_norm(self.pool(torch.relu(self.conv2(x))))

        x = x.view(-1, 64 * 7 * 7)
        x = self.dropout1(torch.relu(self.fc1(x)))
        # Population coding: tile the features -> (batch, 512 * population_size).
        x = x.repeat(1, self.population_size)

        x = self.dropout2(torch.relu(self.fc2(x)))

        mem_exc = self.lif_exc.init_leaky()
        mem_inh = self.lif_inh.init_leaky()

        # The input current is constant over time, so every step reuses `x`
        # directly instead of materialising a (num_steps, batch, 128) copy.
        step_outputs = []
        for _ in range(self.num_steps):
            spk_exc, mem_exc = self.lif_exc(x, mem_exc)
            # Both populations receive the same non-negative (post-ReLU) current.
            # Driving the inhibitory one with -x would keep its membrane <= 0,
            # so it could never reach its positive threshold and never fire.
            spk_inh, mem_inh = self.lif_inh(x, mem_inh)
            # Inhibitory spikes exert a negative influence on the read-out.
            step_outputs.append(self.fc3(spk_exc - spk_inh))

        return torch.stack(step_outputs).sum(dim=0)

# Training Configuration
# Use the GPU when CUDA is available; the model and loss weights are moved to `device`.
_device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_type)

model = SNN_Enhanced()
model = model.to(device)

# Cross-entropy weighted by the per-class weights tensor, placed on the same device.
_loss_weights = weights.to(device)
loss_fn = nn.CrossEntropyLoss(weight=_loss_weights)

optimizer = RangerAdaBelief(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Cosine annealing with warm restarts: cycle lengths 10, 20, 40, ... epochs.
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

# Training Loop
# Trains for `num_epochs` epochs, validating after each one and keeping a
# snapshot of the weights with the lowest mean validation loss in `best_model`.
# Mixed precision is enabled only on CUDA; on CPU, autocast and the GradScaler
# are disabled, so training runs in plain float32.
num_epochs = 50
use_amp = device.type == "cuda"
best_model = None
best_val_loss = float('inf')
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    # `images`/`targets` (not `data`/`labels`) so the module-level dataset arrays are not shadowed.
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device, dtype=torch.int64)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            spk_out = model(images)
            loss = loss_fn(spk_out, targets)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()

    # The warm-restart schedule is advanced once per epoch.
    scheduler.step()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, targets in val_loader:
            images, targets = images.to(device), targets.to(device, dtype=torch.int64)
            spk_out = model(images)
            val_loss += loss_fn(spk_out, targets).item()
    val_loss /= len(val_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {total_loss/len(train_loader):.4f}, Validation Loss: {val_loss:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() returns references to the live parameter tensors; clone
        # them, otherwise later epochs overwrite the "best" checkpoint in place.
        best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}

# Test Accuracy
# Evaluate the lowest-validation-loss checkpoint on the held-out test split.
if best_model is None:
    # e.g. no epoch produced a validation loss below +inf (all NaN), or training never ran.
    print("Warning: no best checkpoint was recorded; evaluating the current weights.")
else:
    model.load_state_dict(best_model)
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, targets in test_loader:
        images, targets = images.to(device), targets.to(device, dtype=torch.int64)
        # The predicted class is the index of the largest time-summed output.
        preds = model(images).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")