# SNN on the Ising Model

In this notebook I train a spiking neural network (SNN) to classify spin configurations of the Ising model as lying below or above a threshold temperature.


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/{user_name}/{repo_name}/blob/{branch_name}/mnist.ipynb)

In [None]:
%pip install snntorch
%pip install torchmetrics


### Imports

In [None]:

import matplotlib.pyplot as plt
import snntorch.functional as sf
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

from torchmetrics.classification import MulticlassAccuracy

from snn_net import Net, train

### Helper functions to import data and plot the accuracy

In [None]:
from Ising.IsingData import generate_Ising_configurations

class ISING(Dataset):
    """Train or test split of temperature-labelled Ising configurations.

    All configurations in ``all_data`` are pooled, shuffled and split into a
    train and a test part; ``train`` selects which part this instance exposes.
    Each item is a ``(temperature, configuration, label)`` tuple.

    The shuffle is driven by ``seed`` so that a ``train=True`` and a
    ``train=False`` instance built from the same data and the same seed see
    the *same* permutation and therefore form disjoint, complementary splits.
    (With an unseeded shuffle per instance the two splits overlap, leaking
    training samples into the test set.)

    Args:
        all_data: Mapping from ``'%.3f' % T`` to an array of configurations
            sampled at temperature ``T`` (first axis indexes the samples).
        train: ``True`` to expose the training split, ``False`` for the test
            split.
        Ts: Iterable of temperatures to include.
        seed: Seed for the shuffle. Use the same value for the train and the
            test instance. ``None`` gives an unseeded shuffle, in which case
            two instances are NOT guaranteed to be disjoint.
    """

    def __init__(self, all_data, train, Ts, seed=0):
        super().__init__()
        self.train = train
        self.Ts = Ts
        self.seed = seed

        self.T, self.data, self.targets = self._get_data(all_data)

    def _get_data(self, all_data):
        """Return the ``(T, data, targets)`` tensors of the selected split."""
        _, train_all, test_all = self._get_training_data(
            all_data, self.Ts, seed=self.seed
        )
        Temp, data, targets = train_all if self.train else test_all
        return Temp, data, targets

    def _get_training_data(self, all_data, Ts, Tc=2.7, train_fraction=0.8, seed=0):
        """Pool, label, shuffle and split the configurations.

        Args:
            all_data: Mapping from ``'%.3f' % T`` to the samples at ``T``.
            Ts: Temperatures to include.
            Tc: Threshold temperature; samples with ``T < Tc`` get label 1,
                all others label 0.
            train_fraction: Fraction of the shuffled samples put in the
                training split.
            seed: Seed of the shuffle (see the class docstring).

        Returns:
            Three lists ``[T, x, y]``: the unshuffled raw numpy arrays, the
            training tensors and the test tensors.

        Raises:
            ValueError: If ``Ts`` is empty.
        """
        # NOTE(review): Tc defaults to 2.7 while the temperature grid built in
        # import_data is scaled by 2.27 -- confirm the default is intentional.
        raw_T = []
        raw_x = []
        raw_y = []

        for T in Ts:
            samples = all_data['%.3f' % (T)]
            n = len(samples)
            label = 1 if T < Tc else 0
            raw_x.append(samples)
            raw_y.append(np.full(n, label))
            raw_T.append(np.full(n, T))

        if not raw_x:
            raise ValueError("Ts must contain at least one temperature")

        raw_T = np.concatenate(raw_T, dtype=np.float32)
        raw_x = np.concatenate(raw_x, axis=0, dtype=np.float32)
        raw_y = np.concatenate(raw_y, axis=0, dtype=np.longlong)

        # Seeded shuffle: identical for every instance built with this seed,
        # which is what keeps the train and test splits disjoint.
        indices = np.random.default_rng(seed).permutation(len(raw_x))
        all_T = raw_T[indices]
        all_x = raw_x[indices]
        all_y = raw_y[indices]

        # Split into train and test sets; clipping to [0, 1] maps any
        # negative entries (e.g. -1 spins) to 0.
        train_split = int(train_fraction * len(all_x))
        train_T = torch.from_numpy(all_T[:train_split])
        train_x = torch.from_numpy(np.clip(all_x[:train_split], 0, 1))
        train_y = torch.from_numpy(all_y[:train_split])
        test_T = torch.from_numpy(all_T[train_split:])
        test_x = torch.from_numpy(np.clip(all_x[train_split:], 0, 1))
        test_y = torch.from_numpy(all_y[train_split:])

        return [raw_T, raw_x, raw_y], [train_T, train_x, train_y], [test_T, test_x, test_y]

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(temperature, configuration, label)`` for ``index``."""
        return self.T[index], self.data[index], self.targets[index]

In [None]:

def import_data():
    """Generate Ising configurations and wrap them in train/test loaders.

    Returns:
        A ``(train_loader, test_loader)`` pair of ``DataLoader`` objects. The
        training loader uses batches of 128 and reshuffles every epoch; the
        test loader uses batches of 64 in fixed order.
    """
    # Sampling temperatures: 1.95, 1.85, ... (step -0.1, stopping before
    # 0.04), each scaled by 2.27.
    temperatures = np.arange(1.95, 0.04, -0.1) * 2.27

    # TODO: cache generated data
    # Presumably 10 is the lattice size L and 1000 the number of samples per
    # temperature -- verify against generate_Ising_configurations.
    configurations = generate_Ising_configurations(10, 1000, temperatures)

    train_set = ISING(configurations, train=True, Ts=temperatures)
    test_set = ISING(configurations, train=False, Ts=temperatures)

    return (
        DataLoader(train_set, batch_size=128, shuffle=True),
        DataLoader(test_set, batch_size=64),
    )


In [None]:


def plot_accuracy(acc_hist, title, test=False):
    """Plot an accuracy history and show the figure.

    Args:
        acc_hist: Sequence of accuracy values to plot.
        title: Figure title.
        test: If true the x-axis is labelled "Epoch", otherwise "Batch".
    """
    if test:
        x_label = "Epoch"
    else:
        x_label = "Batch"

    plt.plot(acc_hist)
    plt.xlabel(x_label)
    plt.ylabel("Accuracy")
    plt.title(title)
    plt.show()

def plot_loss(loss_hist, title):
    """Plot a loss history against the batch index and show the figure.

    Args:
        loss_hist: Sequence of loss values to plot.
        title: Figure title.
    """
    plt.plot(loss_hist)
    # Apply title and axis labels in one pass.
    for apply_text, text in (
        (plt.title, title),
        (plt.xlabel, "Batch"),
        (plt.ylabel, "Loss"),
    ):
        apply_text(text)
    plt.show()


### Hyper parameters

In [None]:

# --- training schedule ---
# number of epochs passed to train()
epochs = 10

# number of time steps handed to the spiking network (original note: ms)
n_steps = 25

# --- network layout ---
# one input per entry of a 10 x 10 configuration
inputs = 10 ** 2
# hidden-layer width and number of output classes
hiddens, outputs = 200, 2

# --- neuron dynamics ---
# membrane potential decay
decay = 0.9


In [None]:

# Build the training and test DataLoaders; import_data() generates the
# Ising configurations, so this cell does the data generation work.
train_loader, test_loader = import_data()


In [None]:

# Draw one batch from the (shuffled) training loader and show a 4 x 5 grid
# of its configurations, each titled with its temperature.
train_T, train_x, train_y = next(iter(train_loader))

n_rows, n_cols = 4, 5
stride = 4  # show every 4th sample of the batch

fig, ax = plt.subplots(n_rows, n_cols, figsize=(14, 10))
for i in range(n_rows):
    for j in range(n_cols):
        # Row-major flat index: each row holds n_cols panels. Using the row
        # count here instead makes neighbouring rows repeat samples.
        idx = (n_cols * i + j) * stride
        ax[i, j].matshow(np.reshape(train_x[idx], (10, 10)), cmap='Greys')
        ax[i, j].set_title("$T = %.3f$" % train_T[idx])
        ax[i, j].set_xticks([])
        ax[i, j].set_yticks([])
fig.tight_layout()


### Train command for the snn

#### Encoding schemes

##### Rate encoded

In [None]:

# Train the spiking network with rate encoding (enc_type='rate').
print("SNN rate:")

# initialize net
rate_snn = Net(inputs, hiddens, outputs, decay, n_steps, enc_type='rate')

# optimization algorithm
optimizer = torch.optim.Adam(rate_snn.parameters()) # (NOTE: Adam was used in the tutorial; there may be a better algorithm)

# loss function: snntorch's spike-count cross-entropy loss
loss_fn = sf.ce_count_loss() # type: ignore

# accuracy function: snntorch's rate (spike-count) accuracy
accuracy = sf.accuracy_rate

# train() result is presumably the test accuracy per epoch (it is plotted
# against "Epoch" in the comparison cell) -- confirm in snn_net.train.
test_acc_snn_rate = train(rate_snn, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


##### Time encoded

In [None]:

# Train the spiking network with the 'latency' encoding.
print("SNN temporal:")

# initialize net
# NOTE(review): 'latency' is passed positionally here while the rate cell uses
# enc_type='rate' -- presumably the same parameter; verify against Net.
temp_snn = Net(inputs, hiddens, outputs, decay, n_steps, 'latency')

# optimization algorithm
optimizer = torch.optim.Adam(temp_snn.parameters()) # (NOTE: Adam was used in the tutorial; there may be a better algorithm)

# loss function: snntorch's temporal (spike-timing) cross-entropy loss
loss_fn = sf.ce_temporal_loss() # type: ignore

# accuracy function: snntorch's temporal (spike-timing) accuracy
accuracy = sf.accuracy_temporal

# train() result is presumably the test accuracy per epoch (it is plotted
# against "Epoch" in the comparison cell) -- confirm in snn_net.train.
test_acc_snn_temp = train(temp_snn, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


#### Test the network on a single sample

In [None]:

# ISING: run the rate-encoded network on the first sample of a freshly drawn
# training batch and plot what it returns.
train_T, train_x, train_y = next(iter(train_loader))

output = rate_snn(train_x[0])

# One panel per output column of the network's result.
fig, ax = plt.subplots(2)
for i, panel in enumerate(ax):
    trace = output[:, i].detach().numpy()
    panel.plot(trace)
    panel.set_title(f'{i}')
    panel.set_ybound(-0.2, 1.2)
    panel.set_xticks([])
    panel.set_yticks([])
fig.tight_layout()

# The input configuration itself, titled with its temperature and label.
fig, ax = plt.subplots(1, figsize=(7, 5))
sample_title = "$T = %.3f$, Label = %d" % (train_T[0], train_y[0])
ax.matshow(np.reshape(train_x[0], (10, 10)), cmap='Greys')
ax.set_title(sample_title)
ax.set_xticks([])
ax.set_yticks([])
fig.tight_layout()


### Train command for Feed Forward net

In [None]:

# Conventional baseline: a fully connected network with one ReLU hidden layer,
# using the same input/hidden/output sizes as the spiking networks.
feed_fwd_net = nn.Sequential(nn.Linear(inputs, hiddens),
                            nn.ReLU(),
                            nn.Linear(hiddens, outputs))

# optimization algorithm
optimizer = torch.optim.Adam(feed_fwd_net.parameters()) # (NOTE: Adam was used in the tutorial; there may be a better algorithm)

# loss function
loss_fn = nn.CrossEntropyLoss()

# accuracy function (torchmetrics multiclass accuracy over `outputs` classes)
accuracy = MulticlassAccuracy(num_classes=outputs)

print("FFN:")
# The same train() helper as for the SNNs is reused here; its result is
# presumably the test accuracy per epoch -- confirm in snn_net.train.
test_acc_feed = train(feed_fwd_net, optimizer, loss_fn, accuracy, train_loader, test_loader, epochs)


### Comparison

In [None]:

# Overlay the accuracy histories returned by the three training runs.
fig = plt.figure(1)
for curve_label, history in (
    ("SNN rate", test_acc_snn_rate),
    ("SNN temporal", test_acc_snn_temp),
    ("FFN", test_acc_feed),
):
    plt.plot(history, label=curve_label)
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Test accuracy")
plt.legend()
plt.show()
