In [None]:
import copy

import pandas as pd
import snntorch as snn
import snntorch.functional as snnfunc
import snntorch.spikegen as snngen
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split

import ExpNeuron as en

In [None]:
class Net(nn.Module):
	"""Two-layer spiking network: Linear(4, 4) -> ExpNeuron -> Linear(4, 3) -> ExpNeuron.

	Both linear layers have no bias. Arguments ending in ``_first`` configure the
	hidden ExpNeuron layer and arguments ending in ``_second`` configure the output
	ExpNeuron layer; each one is forwarded unchanged to ``en.ExpNeuron``.
	"""

	def __init__(self,
			beta_first = 0.8, threshold_first = 0.7, membrane_zero_first = 0.2, membrane_min_first = 0.3,
			learn_beta_first = False, learn_threshold_first = False, learn_membrane_min_first = False,
			beta_second = 0.8, threshold_second = 0.5, membrane_zero_second = 0.2, membrane_min_second = 0.3,
			learn_beta_second = False, learn_threshold_second = False, learn_membrane_min_second = False
			):
		super().__init__()
		# Hidden layer: 4 input features -> 4 spiking neurons.
		self.fll = nn.Linear(4, 4, bias = False)
		self.fsl = en.ExpNeuron(beta = beta_first, threshold = threshold_first,
								membrane_zero = membrane_zero_first,
								membrane_min = membrane_min_first,
								learn_beta = learn_beta_first,
								learn_threshold = learn_threshold_first,
								learn_membrane_min = learn_membrane_min_first)

		# Output layer: 4 hidden spikes -> 3 spiking neurons (one per class).
		self.sll = nn.Linear(4, 3, bias = False)
		self.ssl = en.ExpNeuron(beta = beta_second, threshold = threshold_second,
								membrane_zero = membrane_zero_second,
								membrane_min = membrane_min_second,
								learn_beta = learn_beta_second,
								learn_threshold = learn_threshold_second,
								learn_membrane_min = learn_membrane_min_second)

	def forward(self, spk_input):
		"""Simulate the network over every time step of ``spk_input``.

		Args:
			spk_input: input spike trains with time on dim 1; ``spk_input[:, step]``
				must have 4 features on its last dim (presumably (batch, num_steps, 4)).

		Returns:
			Output-layer spikes stacked along a new leading time dim, one entry per
			time step (presumably (num_steps, batch, 3)).
		"""
		# Cast once up front rather than converting the whole tensor on every time step.
		spk_input = spk_input.to(torch.float32)

		# Fresh membrane state for each forward pass.
		mem1 = self.fsl.init_neuron()
		mem2 = self.ssl.init_neuron()

		spk_output = []
		for step in range(spk_input.shape[1]):
			cur1 = self.fll(spk_input[:, step])
			spk1, mem1 = self.fsl(cur1, mem1)
			cur2 = self.sll(spk1)
			spk2, mem2 = self.ssl(cur2, mem2)
			spk_output.append(spk2)

		return torch.stack(spk_output)

In [None]:
# Load Iris, encode species as integer class ids, normalise the features and
# rate-code them as spike trains for training.
torch.manual_seed(42)  # reproducible spike encoding, train/test split and later parameter draws

file_path = "../iris_folder/Iris.csv"
df = pd.read_csv(filepath_or_buffer= file_path, sep= ",", header= 0).drop(columns= "Id")

# Map each species name to an integer class id 0..len(species)-1, in order of first appearance.
species = df["Species"].unique()
df["Species"] = df["Species"].map({name: idx for idx, name in enumerate(species)}).astype("int")

columns_headers = df.columns

# Scale every feature column (all columns except the trailing Species) by its own maximum.
feature_headers = columns_headers[:-1]
df[feature_headers] = df[feature_headers] / df[feature_headers].max()

# Build the training dataset
num_steps = 100
batch_size = 3

# One row per sample: the feature columns followed by the class id.
data = torch.tensor(df.to_numpy(dtype= "float64"))

# Per the snntorch docs, rate() uses the feature values as spike probabilities and
# returns (num_steps, n_samples, n_features). Spikes are sampled once here, so every
# epoch sees the same spike trains.
trains = snngen.rate(data= data[:, :-1], num_steps= num_steps)
labels = data[:, -1].long()

# Put samples first so DataLoader batches over them -> (batch, num_steps, n_features),
# which is the layout Net.forward iterates over with spk_input[:, step].
trains = trains.permute(1, 0, 2)

dataset = TensorDataset(trains, labels)
train_data, test_data = random_split(dataset, [0.8, 0.2])
train_data_loader = DataLoader(train_data, shuffle= True, batch_size= batch_size)
# Order does not affect accuracy, so the test set is not shuffled.
test_data_loader = DataLoader(test_data)

In [None]:
# Training configuration.
epochs = 20   # epochs per candidate network
size = 5      # number of candidate networks (one random hyperparameter draw each)
states = []   # state_dicts of networks that reached the accuracy threshold


def _draw_layer_params():
	"""Return one layer's [beta, threshold, membrane_zero, membrane_min] as a 4-element tensor.

	The first three values are uniform draws from torch.rand clamped from below at 0.3;
	membrane_min is fixed at 0.2.
	"""
	drawn = torch.rand(3).clamp(min = 0.3)
	fixed_membrane_min = torch.Tensor([0.2])
	return torch.cat([drawn, fixed_membrane_min])


# params[i] == [hidden-layer params, output-layer params] for candidate network i.
params = [[_draw_layer_params(), _draw_layer_params()] for _ in range(size)]

In [None]:
# Train `size` candidate networks, each with its own neuron hyperparameters from
# `params`, and snapshot every epoch whose test accuracy reaches the threshold.
# NOTE(review): snapshots are selected on the test set, so their test accuracy is an
# optimistic estimate; a separate validation split would avoid that bias.
LR_INITIAL = 1e-2
LR_DECAYED = 1e-3
LR_DECAY_EPOCH = 10
ACCURACY_THRESHOLD = 0.9

for i in range(size):
	print(f"Net numb: {i}. Parameters: {params[i]}")
	lrng_rt = LR_INITIAL

	net = Net(
			beta_first = params[i][0][0], threshold_first = params[i][0][1], membrane_zero_first = params[i][0][2], membrane_min_first = params[i][0][3],
			learn_beta_first = False, learn_threshold_first = False, learn_membrane_min_first = True,
			beta_second = params[i][1][0], threshold_second = params[i][1][1], membrane_zero_second = params[i][1][2], membrane_min_second = params[i][1][3],
			learn_beta_second = False, learn_threshold_second = False, learn_membrane_min_second = True
		)

	optim = torch.optim.AdamW(net.parameters(), lr = lrng_rt)
	loss_fn = snnfunc.loss.mse_count_loss(correct_rate=0.3, incorrect_rate= 0.1, num_classes=3)

	for epoch in range(epochs):

		if epoch == LR_DECAY_EPOCH:
			# Rebinding lrng_rt alone never reaches the optimizer; update its param groups.
			lrng_rt = LR_DECAYED
			for group in optim.param_groups:
				group["lr"] = lrng_rt

		for trns, lbls in train_data_loader:
			optim.zero_grad()
			outputs = net(trns)

			loss = loss_fn(spk_out=outputs, targets= lbls.to(torch.long))
			loss.backward()
			optim.step()

		# Loss of the last mini-batch of this epoch.
		print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

		# Test accuracy: the predicted class is the output neuron with the most spikes
		# summed over the time dim (dim 0 of Net's output).
		correct = 0
		total = 0
		with torch.no_grad():
			for trn, lbl in test_data_loader:
				spike_counts = net(trn).sum(dim= 0)
				predicted = spike_counts.argmax(dim= 1)
				correct += (predicted == lbl.to(torch.long)).sum().item()
				total += lbl.numel()
		accuracy = correct / total

		if accuracy >= ACCURACY_THRESHOLD:
			# state_dict() returns references to the live tensors, which later optim.step()
			# calls update in place; deep-copy to freeze this epoch's weights.
			states.append(copy.deepcopy(net.state_dict()))

		print(f"accuracy : {accuracy:.2f}")


In [None]:
# Re-evaluate every stored snapshot on the test set and record its accuracy.
# NOTE(review): Net() is built with its default neuron hyperparameters, not the randomly
# drawn `params` each snapshot was trained with. That is only equivalent if en.ExpNeuron
# keeps beta/threshold/membrane_zero/membrane_min in its state_dict (as parameters or
# buffers) - TODO confirm; otherwise store the matching params next to each snapshot
# and pass them to Net(...).
acc = []
for state in states:
    load_net = Net()
    load_net.load_state_dict(state)

    # Predicted class = output neuron with the most spikes summed over time (dim 0).
    correct = 0
    total = 0
    with torch.no_grad():
        for trn, lbl in test_data_loader:
            spike_counts = load_net(trn).sum(dim= 0)
            predicted = spike_counts.argmax(dim= 1)
            correct += (predicted == lbl.to(torch.long)).sum().item()
            total += lbl.numel()
    accuracy = correct / total

    print(f"accuracy : {accuracy:.2f}")
    acc.append(accuracy)

In [None]:
# Save the stored snapshot with the highest re-evaluated test accuracy.
if not states:
    print("No network reached the accuracy threshold; nothing saved.")
else:
    # torch.argmax returns the first index among ties; `acc` stays a list so re-running is safe.
    max_acc_idx = int(torch.argmax(torch.tensor(acc), dim= 0))
    torch.save(states[max_acc_idx], "Iris_ExpNeuron_weights.pth")

In [None]:
# for par, val in net.named_parameters():
#     print(par, val)
