In [1]:
import os
import math
import time
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import mpn 

class SequentialMNIST(torch.utils.data.Dataset):
    """MNIST presented as a 1-D pixel sequence ("sequential MNIST").

    Each image tensor produced by ``ToTensor`` is flattened with
    ``reshape(-1)`` into a sequence with one scalar pixel per time step
    (T = 784 for 28x28 MNIST images).

    Args:
        root: directory where torchvision stores / looks for the MNIST files.
        train: if True use the training split, otherwise the test split.
        download: forwarded to ``torchvision.datasets.MNIST``.
        normalize: if True, standardize pixels as ``(x - mean) / std``.
        mean: normalization mean (default 0.1307, the conventional MNIST value).
        std: normalization std (default 0.3081, the conventional MNIST value).

    Raises:
        ValueError: if ``normalize`` is True and ``std`` is zero.

    Each item is a pair:
        x_seq: float tensor of shape (T, 1).
        y: int class label.
    """
    def __init__(self, root, train=True, download=True, normalize=True,
                 mean=0.1307, std=0.3081):
        if normalize and std == 0:
            raise ValueError("std must be non-zero when normalize=True")
        self.ds = datasets.MNIST(root=root, train=train, download=download,
                                 transform=transforms.ToTensor())
        self.normalize = normalize
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        x, y = self.ds[idx]
        x = x.reshape(-1)               # flatten the image into a 1-D pixel sequence
        if self.normalize:
            x = (x - self.mean) / self.std
        return x.unsqueeze(-1), y       # (T, 1): a single input feature per time step


def collate_seq(batch):
    """Collate ``(sequence, label)`` pairs into batched tensors.

    Args:
        batch: list of ``(x_seq, y)`` pairs, each ``x_seq`` a (T, 1) tensor
            and each ``y`` an integer label.

    Returns:
        x: (B, T, 1) tensor, the sequences stacked along a new leading axis.
        y: (B,) int64 tensor of labels.
    """
    sequences, labels = zip(*batch)
    return (
        torch.stack(list(sequences), dim=0),
        torch.tensor(labels, dtype=torch.long),
    )


In [2]:
@torch.no_grad()
def evaluate(net, loader, device):
    """Compute the classification accuracy of ``net`` over ``loader``.

    For every batch the network state is reset and the sequence is replayed
    one time step at a time through ``net.network_step``; only the output of
    the final step is used for the prediction.

    Args:
        net: module exposing ``reset_state(B=...)`` and
            ``network_step(x_t, run_mode=..., seq_idx=...)`` returning a
            3-tuple whose first element holds per-class scores (argmax over
            the last dim is the prediction).
        loader: iterable of ``(x, y)`` batches with ``x`` of shape (B, T, D).
        device: device the batches are moved to.

    Returns:
        float: fraction of correctly classified samples (0.0 for an empty loader).

    Raises:
        ValueError: if a batch has zero time steps (there is no output to score).

    The module's previous train/eval mode is restored on exit, so calling
    this mid-training does not silently leave the network in eval mode.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            batch_size, seq_len, _ = x.shape
            if seq_len == 0:
                raise ValueError("evaluate: received a batch with zero time steps")
            net.reset_state(B=batch_size)

            out = None
            for t in range(seq_len):
                out, _, _ = net.network_step(x[:, t, :], run_mode="minimal", seq_idx=t)

            pred = out.argmax(dim=-1)
            correct += (pred == y).sum().item()
            total += y.numel()
    finally:
        # Restore whatever mode the caller had the module in.
        net.train(was_training)

    return correct / max(total, 1)

def count_parameter(net):
    """Print a per-tensor summary of ``net``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are listed and counted, so
    frozen tensors are excluded.

    Args:
        net: any ``torch.nn.Module``.

    Returns:
        int: total number of trainable scalar parameters (also printed).
    """
    total = 0
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        num = param.numel()
        total += num
        print(f"{name:50s}  shape={tuple(param.shape)}  numel={num}")

    print(f"\nTotal trainable parameters: {total}")
    return total

def _build_sequential_mnist_loaders(data_root, batch_size, num_workers, pin_memory):
    """Build the (train_loader, test_loader) pair over pixel-by-pixel MNIST."""
    train_ds = SequentialMNIST(root=data_root, train=True, download=True, normalize=True)
    test_ds = SequentialMNIST(root=data_root, train=False, download=True, normalize=True)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=pin_memory, collate_fn=collate_seq)
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)
    return train_loader, test_loader


def _sequential_mnist_net_params(hidden_dim, mpn_depth):
    """Return the ``mpn.DeepMultiPlasticNet`` config dict used for sequential MNIST."""
    return {
        "n_neurons": [1] + [hidden_dim] * mpn_depth + [10],
        "linear_embed": 100,
        "dt": 1.0,
        "activation": "tanh",
        "output_bias": True,
        "W_output_init": "xavier",
        "input_layer_add": True,
        'input_layer_add_trainable': True,
        'input_layer_bias': False,
        "output_matrix": "",

        'ml_params': {
            'bias': True, # Bias of layer
            'mp_type': 'mult',
            'm_update_type': 'hebb_assoc', # hebb_assoc, hebb_pre
            'eta_type': 'matrix', # scalar, pre_vector, post_vector, matrix
            'eta_train': True,
            'lam_type': 'matrix', # scalar, pre_vector, post_vector, matrix
            'm_time_scale': 100,
            'lam_train': True,
            'W_freeze': False, # different combination with [input_layer_add_trainable]
        },
    }


def _log_mp_layer_stats(net, ep, batch_idx, loss_value):
    """Print |M| and |W| magnitude statistics of the first MP layer."""
    mpl = net.mp_layers[0]
    with torch.no_grad():
        M_abs_mean = mpl.M.abs().mean().item()
        M_abs_max = mpl.M.abs().max().item()
        W_abs_mean = mpl.W.abs().mean().item()
    print(f"[ep {ep} | batch {batch_idx}] loss={loss_value:.4f} "
          f"|M| mean={M_abs_mean:.3e} max={M_abs_max:.3e} |W| mean={W_abs_mean:.3e}")


def train_sequential_mnist(
    device="cuda",
    data_root="./data",
    hidden_dim=256,
    batch_size=64,
    lr=1e-3,
    epochs=5,
    mpn_depth=5,
    seed=None,
    grad_clip=None,
    log_every=100,
    num_workers=2,
):
    """Train an ``mpn.DeepMultiPlasticNet`` on sequential (pixel-by-pixel) MNIST.

    Each batch is fed one pixel per time step through ``net.network_step``.
    The output after the last step is scored with cross-entropy, and the loss
    is backpropagated through the whole unrolled sequence. Test accuracy is
    reported after every epoch.

    Args:
        device: requested device; falls back to CPU when CUDA is unavailable.
        data_root: MNIST download/cache directory.
        hidden_dim: value repeated in ``n_neurons`` between input and output
            (presumably the plastic-layer width; confirm against ``mpn``).
        batch_size: mini-batch size for both train and test loaders.
        lr: Adam learning rate.
        epochs: number of passes over the training set.
        mpn_depth: how many times ``hidden_dim`` is repeated in ``n_neurons``.
        seed: if not None, passed to ``torch.manual_seed`` before building
            loaders/model, making init and shuffling reproducible.
        grad_clip: if not None, max global grad-norm passed to
            ``clip_grad_norm_`` (long BPTT through 784 steps can explode).
        log_every: print first-MP-layer stats every this many batches
            (falsy disables the logging).
        num_workers: DataLoader worker processes.

    Returns:
        The trained network.

    Raises:
        ValueError: if a batch has zero time steps.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    if seed is not None:
        torch.manual_seed(seed)

    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    train_loader, test_loader = _build_sequential_mnist_loaders(
        data_root, batch_size, num_workers, pin_memory=(device.type == "cuda"))

    net_params = _sequential_mnist_net_params(hidden_dim, mpn_depth)
    # NOTE(review): `forzihan` is an mpn-specific flag whose meaning is not visible here.
    net = mpn.DeepMultiPlasticNet(net_params, verbose=True, forzihan=False).to(device)
    count_parameter(net)

    # Optimizer / loss
    opt = torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=lr)
    criterion = nn.CrossEntropyLoss()

    for ep in range(1, epochs + 1):
        net.train()
        t0 = time.time()
        running_loss = 0.0
        n_batches = 0

        for x, y in train_loader:
            x = x.to(device, non_blocking=True)   # (B, T, 1)
            y = y.to(device, non_blocking=True)

            B, T, _ = x.shape
            if T == 0:
                raise ValueError("train_sequential_mnist: received a batch with zero time steps")
            net.reset_state(B=B)

            opt.zero_grad(set_to_none=True)

            out = None
            for t in range(T):
                out, _, _ = net.network_step(x[:, t, :], run_mode="minimal", seq_idx=t)

            loss = criterion(out, y)
            loss.backward()
            loss_value = loss.item()   # single host sync per batch

            if log_every and n_batches % log_every == 0:
                # Stats are taken after the forward pass, before the optimizer step.
                _log_mp_layer_stats(net, ep, n_batches, loss_value)

            if grad_clip is not None:
                nn.utils.clip_grad_norm_(net.parameters(), grad_clip)

            opt.step()
            # Presumably re-applies parameter constraints after the update; see mpn.
            net.param_clamp()

            running_loss += loss_value
            n_batches += 1

        train_loss = running_loss / max(n_batches, 1)
        test_acc = evaluate(net, test_loader, device=device)
        dt = time.time() - t0

        print(f"Epoch {ep:02d}/{epochs} | loss={train_loss:.4f} | test_acc={test_acc*100:.2f}% | time={dt:.1f}s")

    return net

In [3]:
# Run configuration for the full training job (784-step BPTT per batch: slow).
# NOTE(review): the saved output below appears to come from a different
# configuration (3 MP layers of width 128 rather than hidden_dim=256,
# mpn_depth=5) and ends in a KeyboardInterrupt; re-run to refresh it.
TRAIN_CONFIG = dict(
    device="cuda",
    hidden_dim=256,
    batch_size=64,
    lr=1e-3,
    epochs=20,
    mpn_depth=5,
)
trained_net = train_sequential_mnist(**TRAIN_CONFIG)

MultiPlastic Net:
  output neurons: 10
  Act: tanh

1.0
  MP Layer1 parameters:
    n_neurons - input: 100, output: 128
    M matrix parameters:    update bounds - Max mult: 1.0, Min mult: -1.0
      type: mult // Update - type: hebb_assoc // Act fn: linear
      Eta: matrix (train) // Lambda: matrix (train) // Lambda_max: 0.99 (tau: 1.0e+02)
  MP Layer2 parameters:
    n_neurons - input: 128, output: 128
    M matrix parameters:    update bounds - Max mult: 1.0, Min mult: -1.0
      type: mult // Update - type: hebb_assoc // Act fn: linear
      Eta: matrix (train) // Lambda: matrix (train) // Lambda_max: 0.99 (tau: 1.0e+02)
  MP Layer3 parameters:
    n_neurons - input: 128, output: 128
    M matrix parameters:    update bounds - Max mult: 1.0, Min mult: -1.0
      type: mult // Update - type: hebb_assoc // Act fn: linear
      Eta: matrix (train) // Lambda: matrix (train) // Lambda_max: 0.99 (tau: 1.0e+02)
  No Hidden Recurrency.
W_output                                            s

KeyboardInterrupt: 