In [1]:
import os
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
import base64
from IPython import display as dd

import norse.torch.functional.stdp as stdp
from norse.torch import PoissonEncoder
from norse.torch import LIFParameters, LIFFeedForwardState
from norse.torch.module.lif import LIFCell
from norse.torch.functional.stdp import STDPState, STDPParameters

In [2]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_data = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

test_data = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)

In [3]:
def get_k_winnners_(z: torch.Tensor, k: int):
    """
    スパイク活性を制限するために、入力テンソルの各サンプルについて、
    k個以上のスパイクがある場合にランダムにk個を選択し、それ以外のスパイクを0に変更します。

    Args:
        z (torch.Tensor): ニューロンの出力スパイク活性を表すテンソル。
        k (int): 選択するスパイクの最大数。

    Returns:
        None: 元のテンソルを変更し、返り値はありません。
    """
    for i in range(z.size(0)):
        indices_of_ones = torch.nonzero(z[i]).squeeze(1)
        if len(indices_of_ones) > k:
            random_index = torch.randint(indices_of_ones.size(0), (k,))
            z[i] = torch.zeros_like(z[i])
            for j in random_index:
                z[i][indices_of_ones[j]] = 1.0
    return


def win_take_all_(state: LIFFeedForwardState, z: torch.Tensor, w: torch.Tensor):
    """
    Win-Take-Allアルゴリズムに基づいて、ニューロンの入力電流を更新します。

    Args:
        state (LIFFeedForwardState): LIFニューロンの状態を表すオブジェクト。
        z (torch.Tensor): ニューロンの出力スパイク活性を表すテンソル。
        w (torch.Tensor): 側方抑制の重みを表すテンソル。

    Returns:
        None: 元の状態オブジェクトを変更し、返り値はありません。
    """
    state.i.data = state.i + w * (z - torch.sum(z, dim=1, keepdim=True))
    return

In [4]:
class MNIST(nn.Module):
    def __init__(self, input_neurons: int, excitatory_neurons: int):
        super(MNIST, self).__init__()
        self.input_neurons = input_neurons
        self.excitatory_neurons = excitatory_neurons

        # 入力層から興奮層への全結合(weight: 0~1)
        self.liner_inp_exc = nn.Linear(input_neurons, excitatory_neurons, bias=False)
        nn.init.normal_(self.liner_inp_exc.weight, mean=0.3, std=0.05)
        # 興奮層のニューロン
        self.lifcell_exc = LIFCell(
            p=LIFParameters(tau_mem_inv=torch.as_tensor(1.0 / 50e-3))
        )
        # STDP用のパラメータ
        self.stdp_parameter = STDPParameters(
            eta_plus=torch.as_tensor(1e-3),
            eta_minus=torch.as_tensor(1e-4),
            stdp_algorithm="additive_step",
            hardbound=False,
        )

    def forward(self, x: torch.Tensor, train: bool):
        time, batch_size, _, _, _ = x.shape
        state_exc = None

        with torch.no_grad():
            if train:
                state_stdp = STDPState(
                    t_pre=torch.zeros(self.input_neurons, device=x.device),
                    t_post=torch.zeros(self.excitatory_neurons, device=x.device),
                )
                for t in range(time):
                    z_pre = x[t, :, :, :].view(-1, self.input_neurons)
                    z = self.liner_inp_exc(z_pre)
                    z_post, state_exc = self.lifcell_exc(z, state_exc)
                    # get_k_winnners_(z=z_post, k=1)
                    win_take_all_(state=state_exc, z=z_post, w=torch.tensor(40.0))

                    w, state_stdp = stdp.stdp_step_linear(
                        z_pre=z_pre,
                        z_post=z_post,
                        w=self.liner_inp_exc.weight,
                        state_stdp=state_stdp,
                        p_stdp=self.stdp_parameter,
                    )
                    w = w / torch.max(w, dim=1, keepdim=True)[0]
                    self.liner_inp_exc.weight.data = w
                return w

            else:
                firing_rate = torch.zeros(
                    batch_size, self.excitatory_neurons, device=x.device
                )
                for t in range(time):
                    z = x[t, :, :, :].view(-1, self.input_neurons)
                    z = self.liner_inp_exc(z)
                    z, state_exc = self.lifcell_exc(z, state_exc)
                    firing_rate += z
                return firing_rate

    def save(self, file_path: str = "./mnist_model_parameters.pth"):
        torch.save(self.state_dict(), file_path)

    def load(self, file_path: str = "./mnist_model_parameters.pth"):
        self.load_state_dict(torch.load(file_path))

In [8]:
def visualize_tensor_as_grid(
    tensor: torch.Tensor,
    grid_rows: int,
    grid_cols: int,
    image_size: int = 28,
    save_path: str = None,
):
    """
    テンソルを指定された行数と列数のグリッドに変換して保存する関数。

    Args:
        tensor (torch.Tensor): 表示するテンソル
        grid_rows (int): グリッドの行数
        grid_cols (int): グリッドの列数
        image_size (int, optional): 画像のサイズ。デフォルトは28。
        save_path (str, optional): 画像を保存するパス。デフォルトはNone。

    Returns:
        None
    """
    # 勾配を切り離してテンソルをnumpy配列に変換
    tensor = tensor.to("cpu").detach().numpy()

    # 画像を配置するキャンバスの初期化
    canvas = np.zeros((grid_rows * image_size, grid_cols * image_size))

    # テンソルを画像に変換し、キャンバスに配置
    for i in range(grid_rows):
        for j in range(grid_cols):
            image = tensor[i * grid_cols + j].reshape(image_size, image_size)
            canvas[
                i * image_size : (i + 1) * image_size,
                j * image_size : (j + 1) * image_size,
            ] = image

    # カスタムカラーマップの作成
    cmap = LinearSegmentedColormap.from_list(
        "custom_cmap", [(0, "black"), (1, "green")]
    )

    # 画像を保存
    plt.figure(figsize=(5, 5))
    plt.imshow(canvas, cmap=cmap)
    plt.axis("off")
    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    plt.close()

In [5]:
seed = 0
gpu = True

# GPU使用の設定
if gpu and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(seed)
else:
    device = torch.device("cpu")
    torch.manual_seed(seed)
    torch.set_num_threads(os.cpu_count() - 1)

In [None]:
# オンライン学習
BATCH_SIZE = 1
TIME = 300
interval = 10
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=False
)
encoder = PoissonEncoder(TIME, f_max=64)

model = MNIST(28 * 28, 10 * 10)
model.to(device)

for i, (data, target) in enumerate(tqdm(train_loader, desc="Training")):
    data = encoder(data).to(device)
    w = model(data, True)
    if i < 70000 and i % interval == 0:
        visualize_tensor_as_grid(
            w, 10, 10, save_path=f"./weight/image_{int(i/interval)}.png"
        )
    if i % 10000 == 0:
        model.save("./mnist_model_parameters.pth")

model.save("./mnist_model_parameters.pth")

In [11]:
class LabelAssignment:
    def __init__(
        self,
        n_labels: int,
        n_neurons: int,
        device: str,
    ):
        self.n_labels = n_labels
        self.n_neurons = n_neurons
        self.firing_rates = torch.zeros((n_neurons, n_labels), device=device)
        self.assignments = torch.zeros((n_neurons, n_labels), device=device)

SyntaxError: invalid syntax (2770953460.py, line 2)

In [10]:
# ラベル付け
BATCH_SIZE = 3
TIME = 300
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=BATCH_SIZE, shuffle=False
)
encoder = PoissonEncoder(TIME, f_max=64)


model = MNIST(28 * 28, 10 * 10)
model.load("./mnist_model_parameters.pth")
model.to(device)
for i, (data, target) in enumerate(train_loader):
    data = encoder(data).to(device)
    firing_rate = model(data, False)
    print(firing_rate)
    print(target)
    if i == 0:
        break

tensor([[ 77.,  79.,  76.,  78.,  75.,  81.,  62.,  77.,  72.,  58.,  91.,  89.,
          53.,  74.,  80.,  63.,  70.,  72.,  87.,  74.,  59.,  73.,  94.,  68.,
          70.,  54.,  78.,  76.,  88.,  90.,  78.,  71.,  93., 105.,  67.,  77.,
          60.,  66.,  92.,  72.,  77.,  74.,  85.,  68.,  46.,  80.,  89.,  56.,
          59.,  54.,  74.,  75.,  61.,  98.,  99.,  74.,  98.,  74.,  61.,  73.,
          95.,  66.,  57.,  74.,  56., 100.,  75.,  82.,  90.,  84.,  58.,  78.,
          53.,  66.,  85.,  78.,  59.,  85.,  85.,  68.,  77.,  96.,  76.,  81.,
          89.,  75.,  68.,  61.,  73.,  68.,  99.,  93.,  74.,  68.,  70.,  92.,
          57.,  76.,  83.,  97.],
        [ 74., 118.,  81.,  70.,  70.,  93.,  70.,  71.,  77.,  53.,  69.,  79.,
          53.,  69.,  84.,  71.,  71.,  67.,  79.,  75.,  62.,  69., 142.,  66.,
          76.,  48.,  81.,  79.,  88.,  72.,  63.,  62.,  96., 103.,  84.,  74.,
          59.,  60.,  71.,  75.,  83.,  79.,  78.,  75.,  51., 124.,  85., 

In [None]:
# テスト
BATCH_SIZE = 10
TIME = 300
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
)
encoder = PoissonEncoder(TIME, f_max=64)