In [1]:
import torch
import torch.nn as nn
import snntorch as snn
import torch.nn.functional as nnfunc
from torch.utils.data import DataLoader, TensorDataset
import ExpNeuron as en

В данной модели используется следующая формула для вычисления мембранного потенциала:
\begin{equation}
    U(t)=U_{0}\exp(-\beta t +\sum_{i=0}^{t}I_{in}(i)),
\end{equation}
Если $U(t) \geq threshold$, то $U_{0} = U_{min}, t = 0$. Используется сумма спайков за время $t$, так как это упрощает вычисления и решает проблему "мертового нейрона" для входных спайков при малом U".   


In [2]:
class Neuron(nn.Module):
	def __init__(self, beta, threshold, membrane_min, membrane_zero):
		super().__init__()
		self.fll = nn.Linear(1, 1, bias = False)
		self.nrn = en.ExpNeuron(
			beta= beta,
			threshold= threshold,
			membrane_zero= membrane_zero,
			membrane_min= membrane_min,
			learn_beta= True,
			learn_membrane_min= True,
			learn_threshold= True)
		
	
	def forward(self, spk_input):
		mem = self.nrn.init_neuron()
		spk_outpt = []

		for step in range(spk_input.shape[1]):
			cur = self.fll(spk_input[:, step])
			spk, mem = self.nrn(cur, mem)
			spk_outpt.append(spk)

		return torch.stack(spk_outpt, dim = 1) 
	
	
def gen_spike_train(lambda_, num_steps):
	spike_train = torch.tensor([1 if lambda_ > element else 0 for element in torch.rand(num_steps)], dtype=torch.float32)
	return spike_train

In [3]:
Lambda = 0.95
steps = 100
samples = 200
test_samples = 100
eps = 5e-2
btch_sz = 5


gnrtr = torch.Generator().manual_seed(0)

# Generating data in (a, b))
a = 0.8
b = 1
if a > 0.8:
	low_lambda_samples = int(samples * 0.3)
	low_lambda_arr = torch.ones(low_lambda_samples).uniform_(0, 0.5)
	lambd_arr = torch.cat((torch.ones(samples - low_lambda_samples).uniform_(a, b), low_lambda_arr))
else:
	lambd_arr = torch.ones(samples).uniform_(a, b)

test_lambd_arr = torch.ones(test_samples).uniform_(0, 1)

test_labels = ((abs(test_lambd_arr - Lambda) <= eps ) | (test_lambd_arr > Lambda)).float().unsqueeze(1)
labels = ((abs(lambd_arr - Lambda) <= eps ) | (lambd_arr > Lambda)).float().unsqueeze(1)

test_trains = torch.stack([gen_spike_train(lambd.item(), steps) for lambd in test_lambd_arr])
trains = torch.stack([gen_spike_train(lambd.item(), steps) for lambd in lambd_arr])

# train_data, test_data = random_split(TensorDataset(trains, labels), [samples * 80 // 100, samples * 20 // 100])
train_data = TensorDataset(trains, labels)
train_dataldr = DataLoader(dataset= train_data, batch_size= btch_sz)

test_data = TensorDataset(test_trains, test_labels) 
test_dataldr = DataLoader(dataset = test_data, batch_size= 1)

In [12]:
beta = torch.tensor(0.7)
threshold = torch.tensor(0.8)
membrane_min = torch.tensor(0.8)
membrane_zero = torch.tensor(0.6)

lrng_rt = 5e-3

epochs = 20
neuron = Neuron(beta = beta, threshold = threshold, membrane_min = membrane_min, membrane_zero = membrane_zero)
optim = torch.optim.Adam(neuron.parameters(), lr = lrng_rt, weight_decay= 0.01)

In [17]:
# BCE - функция потерь
for epoch in range(epochs):
		for trns, lbls in train_dataldr:
				optim.zero_grad()

				outputs = neuron(trns.unsqueeze(-1))
				spike_cnt = outputs.mean(dim=1)
				
				loss = nnfunc.binary_cross_entropy(input= spike_cnt, target= lbls, reduction= "mean")
				loss.backward()
				optim.step()

				with torch.no_grad():
					neuron.nrn.beta.clamp_(min = 0.1)
					neuron.nrn.threshold.clamp_(min = 0.1)
					neuron.nrn.membrane_min.clamp_(min = 0.1)


		if epoch % 5 == 0:
			print((f"Epoch {epoch}, Loss: {loss.item():.4f}, "
				f"Beta: {neuron.nrn.beta.item():.4f}, Threshold: {neuron.nrn.threshold.item():.4f}, "
				f"Membrane_min: {neuron.nrn.membrane_min.item():.4f}"))
		
# Testing accuracy
data = []
for trn, lbl in test_dataldr:
		prediction = neuron(trn).mean(dim=1).item()
		data.append([float(prediction > 0.5), lbl.item()])
		# print((f"pred: {prediction:.2f}, predicted: {prediction > 0.5}, "
		#        f"true: {lbl}"))
data = torch.tensor(data)
true_positive = data[(data[:, 0] == 1) & (data[:, 1] == 1), 0].numel()
false_positive = data[(data[:, 0] == 1) & (data[:, 1] == 0), 0].numel()
true_negative = data[(data[:, 0] == 0) & (data[:, 1] == 0), 0].numel()
false_negative = data[(data[:, 0] == 0) & (data[:, 1] == 1), 0].numel()

# Counting metrics
if true_positive and true_negative: 
		precision = true_positive / (true_positive + false_positive)
		recall = true_positive / (true_positive + false_negative)
		accuracy = (true_positive + true_negative) / (true_positive + true_negative 
																								+ false_positive + false_negative)

		print((f"precision: {precision:.2f}, recall: {recall:.2f}, accuracy : {accuracy:.2f}"))
		print((f"\ntrue_positive: {true_positive},\n"
				f"true_negative: {true_negative},\nfalse_negative: {false_negative},\n"
				f"false_positive: {false_positive}"))
else:
		print((f"Low accuracy of the model.\ntrue_positive: {true_positive},\n"
				f"true_negative: {true_negative},\nfalse_negative: {false_negative},\n"
				f"false_positive: {false_positive}"))
	
if precision and recall:
		f1 = 2 * precision * recall / (precision + recall)
		print(f"f1 metric: {f1:.3f}")

Epoch 0, Loss: 0.3089, Beta: 0.9307, Threshold: 0.7920, Membrane_min: 0.4281
Epoch 5, Loss: 0.3089, Beta: 0.9307, Threshold: 0.7920, Membrane_min: 0.4280
Epoch 10, Loss: 0.3089, Beta: 0.9307, Threshold: 0.7920, Membrane_min: 0.4280
Epoch 15, Loss: 0.3089, Beta: 0.9307, Threshold: 0.7920, Membrane_min: 0.4280
precision: 0.85, recall: 1.00, accuracy : 0.98

true_positive: 11,
true_negative: 87,
false_negative: 0,
false_positive: 2
f1 metric: 0.917


In [14]:
for par, val in neuron.named_parameters():
	print(par, val)

# torch.save(neuron.state_dict(), "binary_classification_weights.pth")


fll.weight Parameter containing:
tensor([[0.6245]], requires_grad=True)
nrn.threshold Parameter containing:
tensor(0.7920, requires_grad=True)
nrn.beta Parameter containing:
tensor(0.9307, requires_grad=True)
nrn.membrane_min Parameter containing:
tensor(0.4323, requires_grad=True)
