In [1]:
import ExpNeuron as en

import torch
import torch.nn as nn
import snntorch as snn
import snntorch.functional as snnfunc
import snntorch.spikegen as snngen
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd

In [None]:
class Net(nn.Module):
	def __init__(self, beta = 0.5, threshold = 0.5, membrane_zero = 0.3):
		super().__init__()
		self.fll = nn.Linear(4, 4, bias = False)
		self.fsl = en.ExpNeuron(beta = beta, threshold = threshold,
								   	membrane_zero= membrane_zero,
									learn_beta= True, 
									learn_threshold= True, 
									learn_membrane_min= True)
		
		self.sll = nn.Linear(4, 3, bias = False)
		self.ssl = en.ExpNeuron(beta = beta, threshold = threshold,
								   	membrane_zero= membrane_zero,
									learn_beta= True, 
									learn_threshold= True, 
									learn_membrane_min= True)
		
		# self.fll.weight.data.clamp_(min=0.05)
		# self.sll.weight.data.clamp_(min=0.05)
		
	def forward(self, spk_input):
		mem1 = self.fsl.init_neuron()
		mem2 = self.ssl.init_neuron()

		spk_output = []
		for step in range(spk_input.shape[1]):
			cur1 = self.fll(spk_input.to(torch.float32)[:, step])
			spk1, mem1 = self.fsl(cur1, mem1)
			cur2 = self.sll(spk1)
			spk2, mem2 = self.ssl(cur2, mem2)
			
			spk_output.append(spk2)
		
		return torch.stack(spk_output)

In [None]:
file_path = "../iris_folder/Iris.csv"
df = pd.read_csv(filepath_or_buffer= file_path, sep= ",", header= 0)
df.drop("Id", axis= 1, inplace= True)

# преобразование датафрейма
species = df["Species"].unique()

# При присваивании происходит изменение типа данных на object
for i in range(len(species)):
	df.loc[df["Species"] == species[i], "Species"] = i
df["Species"] = df["Species"].astype("int")

columns_headers = df.columns

# Не учитываю species
for header in columns_headers[:-1]:
	# Нормализую знеачения
	df[header] = df[header] / df[header].max()
# создание датасета для обучения
data = []
num_steps = 100

for header in columns_headers:
	data.append(torch.tensor(df[header].values))

# Транспонируем тензор, чтобы иметь features и target каждого образца
data = torch.stack(data, dim = 0).T

trains = snngen.rate(data= data[:, :-1], num_steps= num_steps)

# labels = snngen.targets_rate(data[:, -1], num_classes=3)
labels = data[:, -1]
print(trains.shape)
# Возможно полная хрень. Требует проверки
# Я уверен, что с осями какой-то косяк, так что придется переделывать при плохих результатах обучения
trains = trains.permute(1, 0, 2)
dataset = TensorDataset(trains, labels)
train_data, test_data = random_split(dataset, [0.8, 0.2])
train_data_loader = DataLoader(train_data, shuffle= True)
test_data_loader = DataLoader(test_data, shuffle= True)
lrng_rt = 5e-3

epochs = 5
net = Net()
optim = torch.optim.Adam(net.parameters(), lr = lrng_rt)

In [24]:
for epoch in range(epochs):
		for trns, lbls in train_data_loader:
				optim.zero_grad()
				outputs = net(trns)
				
				loss_fn = snnfunc.loss.ce_count_loss()
				loss = loss_fn(spk_out=outputs, targets= lbls.to(torch.long))
				loss.backward()
				optim.step()

				# net.fll.weight.data.clamp_(min=0.05)
				# net.sll.weight.data.clamp_(min=0.05)

		print((f"Epoch {epoch}, Loss: {loss.item():.4f}"))

Epoch 0, Loss: 0.1269
Epoch 1, Loss: 0.1269
Epoch 2, Loss: 1.0986
Epoch 3, Loss: 3.0486
Epoch 4, Loss: 0.0550


In [None]:
# Testing accuracy
data = []
for trn, lbl in test_data_loader:
		prediction = net(trn)
		prediction = torch.mean(input= prediction.unsqueeze(0), dim = 1)

		data.append([prediction.argmax().item(), lbl.item()])
data = torch.tensor(data)
accurate = data[(data[:, 0] == data[:, 1]), 0].numel()
accuracy = accurate/data.shape[0]

print((f"accuracy : {accuracy:.2f}"))

In [None]:
for par, val in net.named_parameters():
    print(par, val)

# torch.save(net.state_dict(), "fisher_iris_weights.pth")


fll.weight Parameter containing:
tensor([[0.0500, 0.0500, 0.0618, 0.0765],
        [0.4804, 0.1619, 0.9437, 1.2211],
        [0.0938, 0.0762, 0.6437, 0.8028],
        [0.0500, 0.2858, 0.0505, 0.1506]], requires_grad=True)
fsl.threshold Parameter containing:
tensor(1.3287, requires_grad=True)
fsl.beta Parameter containing:
tensor(0.7791, requires_grad=True)
fsl.membrane_min Parameter containing:
tensor(0.4176, requires_grad=True)
sll.weight Parameter containing:
tensor([[0.3084, 0.6793, 0.5332, 0.1823],
        [0.4148, 0.0708, 0.5344, 0.0545],
        [0.0500, 0.1883, 0.0500, 0.0823]], requires_grad=True)
ssl.threshold Parameter containing:
tensor(-0.0929, requires_grad=True)
ssl.beta Parameter containing:
tensor(0.2737, requires_grad=True)
ssl.membrane_min Parameter containing:
tensor(-0.4422, requires_grad=True)
