In [1]:
import torch, torch.nn as nn
import snntorch as snn

# === Modello SNN ===
class SAE(nn.Module):
    def __init__(self,num_inputs,num_hidden, num_outputs, num_steps=25,beta=0.95):
        super().__init__()
        
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs
        self.num_steps = num_steps
        self.beta = beta
        
        self.fc1 = nn.Linear(self.num_inputs,self.num_hidden)
        self.lif1 = snn.Leaky(beta=self.beta)
        self.fc2 = nn.Linear(self.num_hidden, self.num_outputs)
        self.lif2 = snn.Leaky(beta=self.beta)

    def forward(self, x):

        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        
        spk2_rec = []
        mem2_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [2]:
import pandas as pd

df = pd.read_csv('C:\\Users\\anton\\Documents\\PhD\\Spiking\\PotentialField_Sim\\simulation_data\\simulation_log.csv')
sensor_df = df[[col for col in df.columns if 'sensor' in col]]  
sensor_df = pd.concat([df["run"], sensor_df], axis=1)
# sensor_df

In [3]:
class utils:
    @staticmethod
    def preprocess(image):
        return cv2.resize(image, (512, 512))

    @staticmethod
    def low_res(image):
        return utils.preprocess(image)[0::12, 0::12]

    @staticmethod
    def preprocess_predict(image):
        return np.expand_dims(utils.preprocess(tf.keras.applications.vgg19.preprocess_input(image)), axis=0)
    
    @staticmethod
    def acquire_image(path):
        return cv2.imread(path,0)
    
    @staticmethod
    def preprocess_2828(image):
        return cv2.resize(image, (28, 28))

In [None]:
import cv2 
import numpy as np
import os
from tqdm import tqdm   

img_path = "C:\\Users\\anton\\Documents\\PhD\\Spiking\\PotentialField_Sim\\simulation_data\\"
dirs = os.listdir(img_path)
dirs.sort()

tot_images = []

for item in dirs:
    if item.endswith(".csv"):
        continue
    else:
        run_path = os.path.join(img_path, item)
        images = []

        for img_file in tqdm(os.listdir(run_path)):
            if img_file.endswith(".png"):
                image = utils.acquire_image(os.path.join(run_path, img_file))
                low_res_image = utils.preprocess_2828(image)
                images.append(low_res_image)
        tot_images.append(np.array(images))


100%|██████████| 994/994 [00:00<00:00, 6040.95it/s]
100%|██████████| 1602/1602 [00:00<00:00, 6011.08it/s]
100%|██████████| 1968/1968 [00:00<00:00, 5878.05it/s]
100%|██████████| 1576/1576 [00:00<00:00, 5686.56it/s]
100%|██████████| 1561/1561 [00:00<00:00, 6012.72it/s]
100%|██████████| 1425/1425 [00:00<00:00, 6323.74it/s]
100%|██████████| 2025/2025 [00:00<00:00, 6249.22it/s]
100%|██████████| 3354/3354 [00:00<00:00, 6374.09it/s]
100%|██████████| 1068/1068 [00:00<00:00, 6477.68it/s]
100%|██████████| 1333/1333 [00:00<00:00, 6226.25it/s]
100%|██████████| 2303/2303 [00:00<00:00, 6321.98it/s]
100%|██████████| 1610/1610 [00:00<00:00, 6437.35it/s]
100%|██████████| 1973/1973 [00:00<00:00, 6433.67it/s]
100%|██████████| 2641/2641 [00:00<00:00, 5811.99it/s]
100%|██████████| 1777/1777 [00:00<00:00, 6088.67it/s]
100%|██████████| 2485/2485 [00:00<00:00, 6055.97it/s]
100%|██████████| 2673/2673 [00:00<00:00, 5887.60it/s]
100%|██████████| 2251/2251 [00:00<00:00, 5973.13it/s]
100%|██████████| 1233/1233 [00

In [96]:
run = []
for i in range(len(tot_images)):
    for n in range(len(tot_images[i])):
        run.append(i)
run = np.array(run)

In [None]:
X_img = np.concatenate(tot_images, axis=0)

np.save('X_img.npy', X_img)
np.save('run.npy', run)

In [25]:
import numpy as np
X_img = np.load('X_img.npy')
run = np.load('run.npy')

In [7]:
def train(model,train_dataloader,val_dataloader,epochs,loss_fn,optimizer):
    
    train_losses = []
    val_losses = []

    for e in range(epochs):
        train_batch = iter(train_dataloader)
    
        train_loss_epoch = 0
        val_loss_epoch = 0
    
        for data, targets in tqdm(train_batch):
            data = data.to(device)
            targets = targets.to(device)
        
            model.train()
            spk_rec, mem_rec = model(data.view(len(data), -1))

            loss_val = torch.zeros((1), dtype=dtype, device=device)
            for step in range(model.num_steps):
                loss_val += loss_fn(mem_rec[step], targets)


            loss_val = loss_val / model.num_steps
                #spk_rec
         
            train_loss_epoch += loss_val.item()

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            with torch.no_grad():
                model.eval()
                val_data, val_targets = next(iter(val_dataloader))
                val_data = val_data.to(device)
                val_targets = val_targets.to(device)

                val_spk, val_mem = model(val_data.view(len(val_data), -1))

                val_loss = torch.zeros((1), dtype=dtype, device=device)
                for step in range(model.num_steps):
                    val_loss += loss_fn(val_mem[step], val_targets)
                    #val_spk

                val_loss = val_loss / model.num_steps
          
                #recon = torch.mean(val_mem,axis=0)
                #val_loss = loss_fn(recon,val_targets)
        
            val_loss_epoch += val_loss.item()
        
        print(f"Train loss at epoch: {e+1}: {train_loss_epoch}")
        print(f"Val loss at epoch: {e+1}: {val_loss_epoch}")
    
        train_losses.append(train_loss_epoch)
        val_losses.append(val_loss_epoch)
    
    return train_losses,val_losses

In [8]:
X_img_tensor = torch.tensor(X_img).unsqueeze(1)  # [N, 1, 28, 28]
X_img_tensor = X_img_tensor.view(len(X_img_tensor), -1)  # [N, 784]
print("Immagini caricate:", X_img_tensor.shape)
X_sensors = sensor_df.iloc[:,1:].to_numpy()
X_sensors = torch.tensor(X_sensors, dtype=torch.float32)
print("Sensori caricati:", X_sensors.shape)


Immagini caricate: torch.Size([37360, 784])
Sensori caricati: torch.Size([37360, 16])


In [14]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler    

In [77]:
scaler = MinMaxScaler(feature_range=(-1, 1))
X_scaled = scaler.fit_transform(X_img_tensor)


X_sensor_tensor = torch.tensor(sensor_df.iloc[:,1:].to_numpy(), dtype=torch.float32)
# X_sensor_tensor.shape


X_img_tensor_scaled = torch.tensor(X_scaled, dtype=torch.float32)
X_total = torch.concat((X_img_tensor_scaled, X_sensor_tensor), axis=-1)

run_elements = np.bincount(run)
run_elements
test_run = np.sum(run_elements[-2:])
val_run = np.sum(run_elements[-4:-2])
train_run = np.sum(run_elements[:-4])

#  Split 80/20
N_train = train_run
N_val = val_run
N_test = test_run

X_train = X_total[:N_train]
X_val = X_total[N_train:N_train+N_val]
X_test = X_total[N_train+N_val:N_train+N_val+N_test]

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

In [82]:
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
batch_size = 256

train_dataset = TensorDataset(X_train, X_train)
test_dataset = TensorDataset(X_test, X_test)

# CREA DATALOADER
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


# SpikingAE = SAE(X_train.shape[-1],500,num_steps=30)
SpikingAE = SAE(num_inputs= (28*28) + 16 , num_hidden=500, num_outputs= (28*28)+16, num_steps=25, beta=0.95)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(SpikingAE.parameters(), lr=1e-3, betas=(0.9, 0.999))
train_losses, val_losses = train(SpikingAE,train_loader,test_loader,50,loss_fn,optimizer)

100%|██████████| 116/116 [00:49<00:00,  2.35it/s]


Train loss at epoch: 1: 33.23077738285065
Val loss at epoch: 1: 28.890399426221848


100%|██████████| 116/116 [00:49<00:00,  2.32it/s]


Train loss at epoch: 2: 21.771459475159645
Val loss at epoch: 2: 21.289051726460457


100%|██████████| 116/116 [00:50<00:00,  2.28it/s]


Train loss at epoch: 3: 22.19031612575054
Val loss at epoch: 3: 21.69804534316063


100%|██████████| 116/116 [00:48<00:00,  2.38it/s]


Train loss at epoch: 4: 21.256640881299973
Val loss at epoch: 4: 21.33943897485733


100%|██████████| 116/116 [00:50<00:00,  2.31it/s]


Train loss at epoch: 5: 20.58378331363201
Val loss at epoch: 5: 20.54596820473671


100%|██████████| 116/116 [00:49<00:00,  2.34it/s]


Train loss at epoch: 6: 19.843179687857628
Val loss at epoch: 6: 19.803601056337357


100%|██████████| 116/116 [00:46<00:00,  2.48it/s]


Train loss at epoch: 7: 19.371784642338753
Val loss at epoch: 7: 19.198521986603737


 13%|█▎        | 15/116 [00:05<00:34,  2.92it/s]


KeyboardInterrupt: 