In [None]:
# !pip install snntorch

In [None]:
import urllib.request
import torch, torch.nn as nn
import snntorch as snn
import snntorch.functional as SF

from torch import randn_like

import numpy as np

In [None]:
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Directory where the FashionMNIST dataset is stored / downloaded.
# Set the FMNIST_DATA_DIR environment variable to point elsewhere instead of
# hardcoding a machine-specific absolute path.
data_path = os.environ.get("FMNIST_DATA_DIR", "./data")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # Use GPU if available

# Preprocessing pipeline applied to every image.
# Normalize((0,), (1,)) computes (x - 0) / 1, i.e. it is an identity; presumably kept
# to mirror the pipeline the checkpoint was trained with — TODO confirm before removing.
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

# Load the FashionMNIST test split (the training split is only needed for training).
# fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)
fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)

In [None]:
# Hyperparameters for the spiking CNN. The network parameters must match the
# checkpoint being evaluated, because `Net` builds its layers from these values.
config = {
    "num_epochs": 100,  # Number of epochs to train for (per trial)
    "batch_size": 128,  # Batch size
    "seed": 0,  # Random seed


    # Network parameters
    "batch_norm": True,  # Whether or not to use batch normalization
    "dropout": 0.13,  # Dropout rate
    "beta": 0.39,  # Decay rate parameter (beta)
    "threshold": 1.5,  # Threshold parameter (theta)
    "lr": 2.0e-3,  # Initial learning rate
    "slope": 7.7,  # Slope value (k)

    # Fixed params
    "num_steps": 100,  # Number of timesteps to encode input for
    "correct_rate": 0.8,  # Correct rate
    "incorrect_rate": 0.2,  # Incorrect rate
    "betas": (0.9, 0.999),  # Adam optimizer beta values
    "eta_min": 0,  # Minimum learning rate
}

# Apply the declared seed; previously "seed" was defined but never used.
# This covers torch-side randomness (weight init, dropout) and numpy.
# NOTE(review): aihwkit's simulated device noise may use its own RNG — confirm
# whether it needs separate seeding for fully reproducible analog results.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

In [None]:
# Evaluation loader over the FashionMNIST test split; batches come in a fixed order.
batch_size = config['batch_size']
# trainloader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)
testloader = DataLoader(
    dataset=fmnist_test,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
from snntorch import surrogate
import torch.nn.functional as F


class Net(nn.Module):
    """Two-stage convolutional spiking network with a spiking linear readout.

    Per time step:
        conv1 (1->16, 5x5) -> 2x2 avg-pool -> [BatchNorm2d] -> Leaky
        conv2 (16->64, 5x5) -> 2x2 avg-pool -> [BatchNorm2d] -> Leaky
        flatten -> fc1 (64*4*4 -> 10) -> Dropout -> Leaky

    The same input ``x`` is fed to conv1 on each of ``num_steps`` steps, and
    every spiking layer carries its membrane potential from one step to the next.

    Args:
        config: dict providing "threshold", "slope", "beta", "num_steps",
            "batch_norm" and "dropout".
    """

    def __init__(self, config):
        super().__init__()
        # Neuron and regularisation hyperparameters.
        self.thr = config["threshold"]
        self.slope = config["slope"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]
        self.batch_norm = config["batch_norm"]
        self.p1 = config["dropout"]
        # One fast-sigmoid surrogate gradient object, shared by all spiking layers.
        self.spike_grad = surrogate.fast_sigmoid(self.slope)

        def make_leaky():
            # Every spiking layer uses the same beta, threshold and surrogate.
            return snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)

        # Construction order is fixed (it determines RNG use during weight init),
        # and attribute names must match the keys of the saved state_dict.
        self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
        self.conv1_bn = nn.BatchNorm2d(16)
        self.lif1 = make_leaky()
        self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
        self.conv2_bn = nn.BatchNorm2d(64)
        self.lif2 = make_leaky()
        self.fc1 = nn.Linear(64 * 4 * 4, 10, bias=False)
        self.lif3 = make_leaky()
        self.dropout = nn.Dropout(self.p1)

    def _conv_stage(self, conv, bn, inp):
        """Apply ``conv``, 2x2 average pooling and, if enabled, batch norm ``bn``."""
        pooled = F.avg_pool2d(conv(inp), 2)
        return bn(pooled) if self.batch_norm else pooled

    def forward(self, x):
        """Run the network for ``num_steps`` steps on a constant input.

        Args:
            x: input batch; assumed (batch, 1, H, W) with H, W such that the
                pooled conv2 output is 4x4 (e.g. 28x28) — TODO confirm.

        Returns:
            Tuple ``(spikes, membranes)``: the output layer's spikes and membrane
            potentials, each stacked along a new leading time dimension.
        """
        # Initial membrane state of each spiking layer.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        out_spikes, out_mems = [], []
        for _ in range(self.num_steps):
            # conv1 is re-evaluated every step on purpose: training-mode BatchNorm
            # updates its statistics per call, and analog layers may resample noise.
            syn1 = self._conv_stage(self.conv1, self.conv1_bn, x)
            spk1, mem1 = self.lif1(syn1, mem1)

            syn2 = self._conv_stage(self.conv2, self.conv2_bn, spk1)
            spk2, mem2 = self.lif2(syn2, mem2)

            syn3 = self.dropout(self.fc1(spk2.flatten(1)))
            spk3, mem3 = self.lif3(syn3, mem3)

            out_spikes.append(spk3)
            out_mems.append(mem3)

        return torch.stack(out_spikes, dim=0), torch.stack(out_mems, dim=0)

# Build the network from `config` and place it on the selected device
# (nn.Module.to modifies the module in place and returns it).
net = Net(config)
net = net.to(device)

In [None]:
def test(config, net, testloader, device=device):
    """Return the accuracy (in percent) of ``net`` over every batch of ``testloader``.

    Args:
        config: unused; kept so existing call sites keep working.
        net: model whose forward returns ``(spike_record, membrane_record)``;
            only the spike record is scored.
        testloader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Sample-weighted accuracy over the whole loader, in percent.

    Raises:
        ValueError: if ``testloader`` yields no samples (previously this was an
            unexplained ZeroDivisionError).

    Side effect: switches ``net`` to eval mode.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            spk_rec, _ = net(images)
            n = labels.size(0)
            # accuracy_rate gives the batch's fraction correct; weighting by the batch
            # size keeps a smaller final batch from skewing the overall accuracy.
            correct += SF.accuracy_rate(spk_rec, labels) * n
            total += n

    if total == 0:
        raise ValueError("testloader yielded no samples; cannot compute accuracy")
    return 100 * correct / total

In [None]:
# Adam optimizer over the network's parameters, plus a spike-count MSE loss whose
# targets are set by the configured correct/incorrect firing rates.
optimizer = torch.optim.Adam(
    params=net.parameters(),
    lr=config["lr"],
    betas=config["betas"],
)
criterion = SF.mse_count_loss(
    correct_rate=config["correct_rate"],
    incorrect_rate=config["incorrect_rate"],
)

In [None]:
# install aihwkit first
# !pip install aihwkit
# !conda install -c conda-forge aihwkit-gpu

In [None]:
# import aihwkit libraries here
from aihwkit.simulator.configs import (
    InferenceRPUConfig,
)
from aihwkit.inference import PCMLikeNoiseModel, ReRamWan2022NoiseModel, GlobalDriftCompensation

# Analog inference configuration passed to `convert_to_analog`.
rpu_config = InferenceRPUConfig()

# Device noise model, paired with a matching drift-compensation setting.
# PCM (active): enable global drift compensation explicitly, as the original
# comment intended ("Enable only for PCM devices").
rpu_config.noise_model = PCMLikeNoiseModel(g_max=25.0)  # PCM noise model
rpu_config.drift_compensation = GlobalDriftCompensation()  # PCM: compensate conductance drift
# For RRAM, use these two lines instead:
# rpu_config.noise_model = ReRamWan2022NoiseModel(g_max=40.0) # RRAM noise model
# rpu_config.drift_compensation = None  # drift compensation only for PCM devices

# Maximum tile input/output size; presumably larger layers are split across tiles
# (aihwkit mapping semantics — confirm).
rpu_config.mapping.max_input_size = 256
rpu_config.mapping.max_output_size = 256

# Forward-pass non-idealities: additive output noise and input/output
# quantization resolution (2**8 steps per aihwkit's inp_res/out_res convention — confirm).
rpu_config.forward.out_noise = 0.04
rpu_config.forward.inp_res = 2**8
rpu_config.forward.out_res = 2**8

from aihwkit.nn.conversion import convert_to_analog
# from aihwkit.simulator.presets import StandardHWATrainingPreset
from aihwkit.inference.calibration import (
    calibrate_input_ranges,
    InputRangeCalibrationType,
)


In [None]:
# Times (seconds) at which the analog model's accuracy is evaluated.
t_inferences = [0.0, 3600.0, 86400.0]  # Times to perform inference on PCM.
# t_inferences = [1.0, 86400.0, 172800.0] # Times to perform inference on RRAM.
n_reps = 5  # Number of inference repetitions.

# Baseline: accuracy of the trained digital network.
model = Net(config)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (torch >= 1.13 supports weights_only=True to restrict loading to tensors).
model.load_state_dict(torch.load(r'fmnist_fp32.pt', map_location=torch.device('cpu'))) # load the trained model here
model.to(device)
model.eval()
accuracy = test(config, model, testloader, device)
print(f"Original accuracy: {accuracy}")

# Convert supported layers to simulated analog layers described by `rpu_config`.
model = convert_to_analog(model, rpu_config=rpu_config)
model.eval() # Determine the inference accuracy with the specified rpu configuration.
print("Evaluating converted analog model.")
inference_accuracy_values = torch.zeros((len(t_inferences), n_reps))
for t_id, t in enumerate(t_inferences):
    for i in range(n_reps):
        # Drift the analog weights to time `t` (aihwkit API) before each repetition.
        model.drift_analog_weights(t)
        accuracy = test(config, model, testloader, device)
        inference_accuracy_values[t_id, i] = accuracy
    # Report once all repetitions for this time are filled in. Printing inside the
    # inner loop averaged over still-zero placeholder entries and gave wrong mean/std.
    accs_at_t = inference_accuracy_values[t_id]
    print(
        f"Test set accuracy (%) at t={t}s: mean: {accs_at_t.mean()}, "
        f"std: {accs_at_t.std()}"
    )