In [None]:
from lava.src.lib.dl.netx import hdf5
import os
import glob
import zipfile
import h5py
import numpy as np
import matplotlib.pyplot as plt
import typing as ty
import torch
from torch.utils.data import Dataset, DataLoader
import lava.lib.dl.slayer as slayer

In [100]:
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# Adapted from the Lava-DL NMNIST tutorial's `augment`.
def augment(event, x_shift=4, y_shift=4, theta=10):
    """Randomly rotate and translate an event stream (data augmentation).

    Parameters
    ----------
    event : object with array-like ``x`` and ``y`` attributes (e.g. ``slayer.io.Event``)
        Events to augment. ``event.x`` and ``event.y`` are replaced by new arrays.
    x_shift, y_shift : int, optional
        Translation jitter is drawn from ``[-shift, shift - 1]`` pixels,
        by default 4. Must be positive (``np.random.randint(0)`` raises).
    theta : float, optional
        Rotation jitter is drawn uniformly from ``[-theta/2, theta/2)`` degrees,
        by default 10.

    Returns
    -------
    The same ``event`` object with rotated and shifted coordinates.
    """
    xjitter = np.random.randint(2 * x_shift) - x_shift
    yjitter = np.random.randint(2 * y_shift) - y_shift
    ajitter = (np.random.rand() - 0.5) * theta / 180 * np.pi
    sin_theta = np.sin(ajitter)
    cos_theta = np.cos(ajitter)
    # Read both coordinates BEFORE writing either: updating event.x first and
    # then using the already-rotated x for y skews the events instead of
    # rotating them.
    x, y = event.x, event.y
    event.x = x * cos_theta - y * sin_theta + xjitter
    event.y = x * sin_theta + y * cos_theta + yjitter
    return event


class CIFARDataset(Dataset):
    """Event-based CIFAR dataset read from a folder-per-class directory tree.

    ``path`` must contain one sub-directory per class; every file below a
    class directory is one sample. Each file is loaded with ``np.load`` and
    must expose ``'t'``, ``'x'`` and ``'y'`` fields.

    Parameters
    ----------
    path : str, optional
        path of dataset root, by default 'data'
    train : bool, optional
        train/test flag, by default True. Currently unused: the split is
        selected by ``path``.
    sampling_time : int, optional
        sampling time of event data, by default 20
    sample_length : int, optional
        length of sample data, by default 8000
    transform : None or lambda or fx-ptr, optional
        transformation applied to the event before binning. None means no
        transform. By default None.
    download : bool, optional
        currently unused (nothing is downloaded), by default True
    sensor_size : tuple of int, optional
        (height, width) of the event grid used for binning, by default (128, 128)
    pool_size : int, optional
        average-pooling kernel size and stride over the spatial dims,
        by default 7
    """
    def __init__(
        self, path='data',
        train=True,
        sampling_time=20, sample_length=8000,
        transform=None, download=True,
        sensor_size=(128, 128), pool_size=7,
    ):
        super().__init__()
        self.path = path
        self.classes, self.class_to_idx = self._find_classes()
        self.samples = self._make_dataset()

        self.sampling_time = sampling_time
        self.num_time_bins = int(sample_length / sampling_time)
        self.transform = transform
        self.sensor_size = tuple(sensor_size)
        self.pool_size = pool_size

    def _find_classes(self):
        """Return the sorted class-folder names and a name -> index map."""
        classes = sorted(
            entry for entry in os.listdir(self.path)
            if os.path.isdir(os.path.join(self.path, entry))
        )
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def _make_dataset(self):
        """Return a deterministic list of ``(file_path, class_index)`` pairs."""
        samples = []
        for target_class in self.classes:
            class_index = self.class_to_idx[target_class]
            target_dir = os.path.join(self.path, target_class)
            for root, _, fnames in sorted(os.walk(target_dir)):
                # os.walk lists files in filesystem order; sort them so the
                # index -> file mapping is reproducible across machines.
                for fname in sorted(fnames):
                    samples.append((os.path.join(root, fname), class_index))
        return samples

    @staticmethod
    def _pool_and_binarize(spike, pool_size):
        """Pool a ``(C, H, W, T)`` spike tensor over H and W, flatten and binarize.

        Returns a float32 tensor of shape ``(C * H' * W', T)`` with
        ``H' = H // pool_size`` and ``W' = W // pool_size``; an entry is 1.0
        when its pooling window contains any spike in that time bin.
        """
        spike = torch.as_tensor(spike, dtype=torch.float32)
        # avg_pool2d pools the last two dims: (C, H, W, T) -> (C, T, H, W).
        pooled = F.avg_pool2d(
            spike.permute(0, 3, 1, 2),
            kernel_size=pool_size, stride=pool_size, padding=0,
        )
        num_bins = pooled.shape[1]
        # (C, T, H', W') -> (C, H', W', T) -> (C * H' * W', T)
        pooled = pooled.permute(0, 2, 3, 1).reshape(-1, num_bins)
        return (pooled > 0).float()

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(spikes, label)``.

        ``spikes`` is a binary float tensor of shape
        ``(2 * (H // pool_size) * (W // pool_size), num_time_bins)``,
        i.e. 648 x 400 with the default arguments.
        """
        filename, label = self.samples[i]
        ev = np.load(filename)
        try:
            t_event, x_event, y_event = ev['t'], ev['x'], ev['y']
        finally:
            # .npz archives keep their file open until closed; plain arrays
            # have no close().
            close = getattr(ev, 'close', None)
            if close is not None:
                close()
        # Every event is given channel index 0 (c_event is all zeros).
        event = slayer.io.Event(
            t_event=t_event, y_event=y_event, x_event=x_event,
            c_event=np.zeros_like(x_event),
        )

        if self.transform is not None:
            event = self.transform(event)

        spike = event.fill_tensor(
            torch.zeros(2, *self.sensor_size, self.num_time_bins),
            sampling_time=self.sampling_time,
        )
        return self._pool_and_binarize(spike, self.pool_size), label

    def __len__(self):
        return len(self.samples)

In [101]:
# Training dataset with random rotate/shift augmentation applied per sample.
root_folder = "events_np/content/events_np/content/CIFAR/events_np"
spike_dataset = CIFARDataset(
    path=root_folder,
    train=True,
    transform=augment,
)

In [18]:
# Load the pre-saved test set from test_set.pt; the next cell indexes it as
# testing_set[i] -> (spike_tensor, label) — presumably a list/Dataset, confirm.
# NOTE(review): torch.load unpickles the file, so only load trusted files. On
# torch versions where weights_only defaults to True, custom objects may be
# rejected — verify against the installed torch version.
testing_set = torch.load("test_set.pt")

In [195]:
# Materialise up to the first 1000 test samples as numpy arrays; the input
# ProcessModel reads these `inputs` / `labels` globals directly.
num_test_samples = min(1000, len(testing_set))
inputs = []
labels = []
for sample_idx in range(num_test_samples):
    sample_spikes, sample_label = testing_set[sample_idx]
    inputs.append(sample_spikes.numpy())
    # Keep labels as plain ints: they are later sent over an int32 port.
    labels.append(int(sample_label))

In [105]:
# Process primitives, used to define the structure of the whole
# pipeline and to connect all the modules together. ProcessModels
# specify how each Process actually behaves; they are
# backend-specific (we deal with them later).

from lava.magma.core.process.process import AbstractProcess
from lava.magma.core.process.variable import Var
from lava.magma.core.process.ports.ports import InPort, OutPort

In [196]:
# Takes the output of the model and outputs a classification prediction
# based on which neuron has the highest spike count.

class OutputProcess(AbstractProcess):
    """Process collecting network output spikes and ground-truth labels.

    Parameters
    ----------
    num_images : int
        Number of images whose prediction and label will be stored.
    num_steps_per_image : int, optional
        Length of one image window in time steps, by default 400. It must
        match the window length used on the input side; otherwise predictions
        are recorded at the wrong time steps, or never.
    """
    def __init__(self, num_images, num_steps_per_image=400):
        super().__init__()
        shape = (10,)  # one entry per output neuron / class

        # Ports: output-layer spikes and the ground-truth label stream.
        self.spikes_in = InPort(shape=shape)
        self.label_in = InPort(shape=(1,))

        self.num_images = Var(shape=(1,), init=num_images)
        # Spike counts accumulated over the current image window.
        self.spikes_accum = Var(shape=shape)
        self.num_steps_per_image = Var(shape=(1,), init=num_steps_per_image)
        # Predicted label for every image.
        self.pred_labels = Var(shape=(num_images,))
        # Ground-truth label for every image.
        self.gt_labels = Var(shape=(num_images,))

# Feeds the input to the model.

class InputProcess(AbstractProcess):
    """Source Process that feeds pre-binned spike samples and their labels.

    Parameters
    ----------
    num_images : int
        Number of samples to present.
    num_steps_per_image : int
        Number of time steps each sample is presented for.
    num_inputs : int, optional
        Number of input neurons, by default 648. This matches the flattened
        CIFARDataset output: 2 channels x 18 x 18 pooled pixels.
    num_time_bins : int, optional
        Number of time bins stored per sample (columns of ``input_img``),
        by default 400.
    """
    def __init__(self, num_images, num_steps_per_image, num_inputs=648,
                 num_time_bins=400):
        super().__init__()
        shape = (num_inputs,)

        # OutPorts
        self.spikes_out = OutPort(shape=shape)
        self.label_out = OutPort(shape=(1,))

        # Vars
        self.num_images = Var(shape=(1,), init=num_images)
        self.num_steps_per_image = Var(shape=(1,), init=num_steps_per_image)
        # One (num_inputs, num_time_bins) sample; one column is sent per step.
        self.input_img = Var(shape=shape + (num_time_bins,))
        self.ground_truth_label = Var(shape=(1,))
        

In [107]:
# Import parent classes for ProcessModels
from lava.magma.core.model.sub.model import AbstractSubProcessModel
from lava.magma.core.model.py.model import PyLoihiProcessModel

# Import ProcessModel ports, data-types
from lava.magma.core.model.py.ports import PyInPort, PyOutPort
from lava.magma.core.model.py.type import LavaPyType

# Import execution protocol and hardware resources
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
from lava.magma.core.resources import CPU, Loihi2NeuroCore

# Import decorators
from lava.magma.core.decorator import implements, requires

In [201]:
# Output ProcessModel

@implements(proc=OutputProcess, protocol=LoihiProtocol)
@requires(Loihi2NeuroCore)
class PyOutputProcessModel(PyLoihiProcessModel):
    """Python ProcessModel for OutputProcess.

    Accumulates the incoming output-layer spikes over each image window. At
    the end of the window it stores ``argmax`` of the accumulated counts as
    the prediction, together with the ground-truth label received on
    ``label_in``.

    NOTE(review): this Python model is tagged ``@requires(Loihi2NeuroCore)``
    while ``CPU`` is imported but unused; Lava's tutorials tag
    PyLoihiProcessModels with ``@requires(CPU)``. Confirm which is intended.
    """
    label_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, int, precision=32)
    spikes_in: PyInPort = LavaPyType(PyInPort.VEC_DENSE, bool, precision=1)
    num_images: int = LavaPyType(int, int, precision=32)
    spikes_accum: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=32)
    num_steps_per_image: int = LavaPyType(int, int, precision=32)
    pred_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    gt_labels: np.ndarray = LavaPyType(np.ndarray, int, precision=32)

    def __init__(self, proc_params):
        super().__init__(proc_params=proc_params)
        self.current_img_id = 0  # index of the next result slot to fill

    def post_guard(self):
        """Return True on the last time step of an image window.

        That is, when ``time_step`` is a multiple of ``num_steps_per_image``;
        time step 1 is excluded.
        """
        return (self.time_step % self.num_steps_per_image == 0
                and self.time_step > 1)

    def run_post_mgmt(self):
        """Record ground truth and prediction for the finished image, then
        clear the spike accumulator for the next one."""
        # Presumably shape (1,), matching label_in's InPort shape.
        gt_label = self.label_in.recv()
        pred_label = np.argmax(self.spikes_accum)
        # Store a scalar rather than assigning the 1-element array.
        self.gt_labels[self.current_img_id] = gt_label[0]
        self.pred_labels[self.current_img_id] = pred_label

        # Setting up for next image
        self.current_img_id += 1
        self.spikes_accum = np.zeros_like(self.spikes_accum)

    def run_spk(self):
        """Add the output spikes received at this time step to the counts."""
        spk_in = self.spikes_in.recv()
        self.spikes_accum = self.spikes_accum + spk_in
# Input ProcessModel

@implements(proc=InputProcess, protocol=LoihiProtocol)
@requires(Loihi2NeuroCore)
class PySpikeInputModel(PyLoihiProcessModel):
    """Python ProcessModel for InputProcess.

    Sends one time-bin column of the current sample on ``spikes_out`` per time
    step. When a new sample is loaded, it also sends that sample's label on
    ``label_out``.

    NOTE(review): if, as in Lava's LoihiProtocol, the spiking phase runs
    before post-management, time step 1 sends column 0 of the initial (zero)
    ``input_img`` before the first sample is loaded. Each sample would then be
    presented one step later than its window. Confirm this offset is
    acceptable.
    """
    num_images: int = LavaPyType(int, int, precision=32)
    spikes_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, bool, precision=1)
    label_out: PyOutPort = LavaPyType(PyOutPort.VEC_DENSE, np.int32,
                                      precision=32)
    num_steps_per_image: int = LavaPyType(int, int, precision=32)
    input_img: np.ndarray = LavaPyType(np.ndarray, int, precision=32)
    ground_truth_label: int = LavaPyType(int, int, precision=32)

    def __init__(self, proc_params):
        super().__init__(proc_params=proc_params)
        # NOTE(review): reads the notebook-level `inputs` / `labels` globals;
        # they must exist before the runtime builds this model.
        self.input_data = inputs
        self.gt_labels = labels
        self.curr_img_id = 0    # index of the sample to load next
        self.curr_img_time = 0  # column of input_img to send next

    def post_guard(self):
        """Return True on the first time step of each image window.

        Written as ``(time_step - 1) % n == 0`` instead of
        ``time_step % n == 1``. The two agree for n > 1, but the latter is
        never true when ``num_steps_per_image == 1``, so no image would ever
        be loaded.
        """
        return (self.time_step - 1) % self.num_steps_per_image == 0

    def run_post_mgmt(self):
        """Load the next sample, send its label, and restart the column index."""
        self.input_img = self.input_data[self.curr_img_id]

        self.ground_truth_label = self.gt_labels[self.curr_img_id]
        self.label_out.send(np.array([self.ground_truth_label]))
        self.curr_img_id += 1
        self.curr_img_time = 0

    def run_spk(self):
        """Send the current time-bin column of the loaded sample."""
        self.spikes_out.send(self.input_img[:, self.curr_img_time])
        self.curr_img_time += 1
        
        

In [203]:
# Number of test images to evaluate.
num_images = 10
# Time steps each image is presented for.
# NOTE(review): InputProcess uses this value, but OutputProcess keeps its own
# num_steps_per_image Var (400 unless set otherwise where it is defined). Its
# post_guard then never fires within num_images * num_steps_per_image = 40
# steps, so no predictions get recorded. Also, each sample has 400 time bins,
# so only its first few bins are fed. The two window lengths must agree.
num_steps_per_image = 4

input_process = InputProcess(num_images, num_steps_per_image)
# Loading the network as a Process
slayer_network = hdf5.Network(net_config='network.net')
print(slayer_network)
output_process = OutputProcess(num_images)

# Connecting Processes
input_process.spikes_out.connect(slayer_network.inp)
slayer_network.out_layer.out.connect(output_process.spikes_in)
# Connecting input and output processes to allow ground truth to flow
input_process.label_out.connect(output_process.label_in)

(648, 1)
|   Type   |  W  |  H  |  C  | ker | str | pad | dil | grp |delay|
|Dense     |    1|    1|  512|     |     |     |     |     |True |
|Dense     |    1|    1|  512|     |     |     |     |     |True |
|Dense     |    1|    1|   10|     |     |     |     |     |False|


In [None]:
# Run the network on `num_images` test samples in simulation. Each `run` call
# advances `num_steps_per_image` time steps (one image window).
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi2SimCfg

sim_run_cfg = Loihi2SimCfg(select_sub_proc_model=True)
try:
    for img_id in range(num_images):
        print(f"\rCurrent image: {img_id+1}", end="")
        input_process.run(
            condition=RunSteps(num_steps=num_steps_per_image),
            run_cfg=sim_run_cfg)
    print("\ndone")
    ground_truth = output_process.gt_labels.get().astype(np.int32)
    predictions = output_process.pred_labels.get().astype(np.int32)
    print(f"Accuracy: {np.mean(ground_truth == predictions):.2%}")
finally:
    # Release the Lava runtime even if a run step or a Var read fails.
    input_process.stop()
    

# Metrics

In [None]:
# Hardware profiling setup: these names were used without being imported.
from lava.magma.core.run_configs import Loihi2HwCfg
from lava.utils.profiler import Profiler

# Number of time steps to execute (and profile) on Loihi 2.
num_steps = 400
run_config = Loihi2HwCfg()
profiler = Profiler.init(run_config)

In [None]:
# Register the energy, activity and memory probes on the profiler. The energy
# probe is given the number of steps that will be executed.
profiler.energy_probe(num_steps=num_steps)
profiler.activity_probe()
profiler.memory_probe()

In [None]:
# Execute the network under the hardware run config. Running `input_process`
# also runs the Processes connected to it (slayer_network, output_process).
# `lif_src` was never defined in this notebook.
# NOTE(review): `input_process` was already stopped by the evaluation cell.
# Lava appears to build the runtime only on the first `run` call, so re-create
# the processes (re-run the setup cell) before this one; otherwise
# `run_config` may be ignored.
from lava.magma.core.run_conditions import RunSteps

try:
    input_process.run(condition=RunSteps(num_steps=num_steps), run_cfg=run_config)
finally:
    input_process.stop()

In [None]:
# Summary of the profiled run. Each line is rendered lazily, so profiler
# attributes are read in the same order the lines are printed.
for _render_line in (
    lambda: f"Total execution time: {np.round(np.sum(profiler.execution_time), 6)} s",
    lambda: f"Total power: {np.round(profiler.power, 6)} W",
    lambda: f"Total energy: {np.round(profiler.energy, 6)} J",
    lambda: f"Static energy: {np.round(profiler.static_energy, 6)} J",
):
    print(_render_line())

In [None]:
# Show the profiler's power breakdown for the hardware run.
profiler.power_breakdown()

In [None]:
# Show the profiler's energy breakdown for the hardware run.
profiler.energy_breakdown()

In [None]:
# Re-print the total execution time (same line as in the summary cell).
print(f"Total execution time: {np.round(np.sum(profiler.execution_time), 6)} s")