In [11]:
%cd /om2/user/valmiki/bioplnn

/rdma/vast-rdma/user/valmiki/bioplnn


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


## Imports

In [12]:
import os
import math
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
import torchvision
from torchvision.datasets import CIFAR10, MNIST
from torchsparsegradutils import sparse_mm
from tqdm import tqdm

In [13]:
!nvidia-smi

Thu Feb  1 16:41:12 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:83:00.0 Off |                    0 |
| N/A   42C    P0    76W / 300W |    489MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Utils

In [14]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def print_mem_stats():
    f, t = torch.cuda.mem_get_info()
    print(f"Free/Total: {f/(1024**3):.2f}GB/{t/(1024**3):.2f}GB")


def get_dataloaders(config):
    # Load the MNIST dataset
    mnist_train = MNIST(
        root="./data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Load the MNIST test dataset
    mnist_test = MNIST(
        root="./data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    train_loader = DataLoader(
        dataset=mnist_train,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
    )

    test_loader = DataLoader(
        dataset=mnist_test,
        batch_size=config.batch_size,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )

    return train_loader, test_loader

## Parameters

In [22]:
config = AttrDict(
    # Model parameters
    num_neurons=10000,
    synapses_per_neuron=100,
    sheet_size=(100, 100),
    connectivity_std=10,
    num_timesteps=100,
    sheet_bias=True,
    sheet_mm_function=torch.sparse.mm,
    sheet_batch_first=False,
    model_dir="models",
    # Training parameters
    batch_size=16,
    optimizer=optim.SGD,
    lr=1e-3,
    criterion=nn.CrossEntropyLoss,
    log_freq=100,
    num_epochs=30,
    log_wandb=False,
)
try:
    os.mkdir(config.model_dir)  # type: ignore
except FileExistsError:
    pass

## Model

In [23]:
class CorticalSheet(nn.Module):
    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        bias=True,
        mm_function=sparse_mm,
        batch_first=False,
        **kwargs
    ):
        super().__init__()
        # Save the sparse matrix multiplication function
        self.mm_function = mm_function
        self.batch_first = batch_first

        # Create a sparse tensor for the weight matrix
        indices = []

        # Create adjacency matrix with normal distribution randomized weights
        for i in range(num_neurons):
            synapses = torch.randint(0, num_neurons, (synapses_per_neuron,))
            synapse_root = torch.full_like(synapses, i)
            indices.append(torch.stack((synapses, synapse_root)))
        indices = torch.cat(indices, dim=1)
        # Xavier initialization of values (synapses_per_neuron is the fan-in/out)
        values = torch.randn(num_neurons * synapses_per_neuron) * math.sqrt(
            1 / synapses_per_neuron
        )

        coo_matrix = torch.sparse_coo_tensor(
            indices, values, (num_neurons, num_neurons), check_invariants=True
        ).coalesce()
        self.weight = nn.Parameter(coo_matrix)
        self.weight.register_hook(lambda grad: print(grad))
        # csr_matrix = coo_matrix.coalesce().to_sparse_csr()
        # self.weight = nn.Parameter(csr_matrix)

        # Initialize the bias vector
        self.bias = nn.Parameter(torch.zeros(num_neurons, 1)) if bias else None

    def coalesce(self):
        self.weight.data = self.weight.data.coalesce()

    def forward(self, x):
        assert self.weight.is_coalesced()
        # x: Dense (strided) tensor of shape (batch_size, num_neurons) if
        # batch_first, otherwise (num_neurons, batch_size)

        # Transpose input if batch_first
        if self.batch_first:
            x = x.t()
            
        # Perform sparse matrix multiplication with or without bias
        x = (
            self.mm_function(self.weight, x)
            if self.bias is None
            else self.mm_function(self.weight, x)
        )

        # Transpose output back to batch first
        if self.batch_first:
            x = x.t()

        return x


class CorticalRNN(nn.Module):
    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        num_timesteps,
        activation=nn.GELU,
        sheet_bias=True,
        sheet_mm_function=torch.sparse.mm,
        sheet_batch_first=False,
        **kwargs
    ):
        super().__init__()
        self.num_neurons = num_neurons
        self.num_timesteps = num_timesteps
        self.activation = activation()
        self.sheet_batch_first = sheet_batch_first

        # Create the CorticalSheet layer
        self.cortical_sheet = CorticalSheet(
            num_neurons,
            synapses_per_neuron,
            sheet_bias,
            sheet_mm_function,
            sheet_batch_first,
        )

        # Create output block
        self.out_block = nn.Sequential(
            nn.Linear(28 * 28, 64), activation(), nn.Linear(64, 10)
        )

    def forward(self, x):
        # x: Dense (strided) tensor of shape (batch_size, 1, 32, 32)

        # Coallesce weight matrix
        self.cortical_sheet.coalesce()

        # Flatten spatial and channel dimensions
        x = x.flatten(1)
        # Pad with zeros for rest of neurons
        x = F.pad(x, (0, self.num_neurons - x.shape[1]))

        # To avoid tranposing x before and after every iteration, we tranpose
        # before and after ALL iterations and do not tranpose within forward()
        # of self.cortical_sheet
        if not self.sheet_batch_first:
            x = x.t()

        # Pass the input through the CorticalSheet layer num_timesteps times
        for _ in range(self.num_timesteps):
            x = self.activation(self.cortical_sheet(x))

        # Transpose back
        if not self.sheet_batch_first:
            x = x.t()

        # Extract output from last 28*28 neurons (can be arbitrarily large number of neurons)
        x = x[:, -28 * 28 :]

        # Return classification from out_block
        return self.out_block(x)

In [24]:
class TopographicalCorticalSheet(nn.Module):
    def __init__(
        self,
        sheet_size,
        connectivity_std,
        synapses_per_neuron,
        bias=True,
        mm_function=sparse_mm,
        batch_first=False,
        **kwargs
    ):
        super().__init__()
        # Save the sparse matrix multiplication function
        self.sheet_size = sheet_size
        num_neurons = sheet_size[0] * sheet_size[1]
        self.mm_function = mm_function
        self.batch_first = batch_first

        # Create a sparse tensor for the weight matrix
        indices = []

        # Create adjacency matrix with normal distribution randomized weights
        for i in range(sheet_size[0]):
            for j in range(sheet_size[1]):
                synapses = (
                    torch.randn(2, synapses_per_neuron)
                    * torch.tensor((connectivity_std, connectivity_std))[
                        :, None
                    ]
                    + torch.tensor((i, j))[:, None]
                ).long()
                synapses = synapses.clamp(
                    torch.tensor((0, 0))[:, None],
                    torch.tensor((sheet_size[0] - 1, sheet_size[1] - 1))[
                        :, None
                    ],
                )
                synapses = self.idx_2D_to_1D(synapses)
                synapse_root = torch.full_like(
                    synapses, self.idx_2D_to_1D(torch.tensor((i, j)))
                )
                indices.append(torch.stack((synapses, synapse_root)))
        indices = torch.cat(indices, dim=1)
        # Sort indices by synapses
        # indices = indices[:, torch.argsort(indices[0])]
        # Xavier initialization of values (synapses_per_neuron is the fan-in/out)
        values = torch.randn(indices.shape[1]) * math.sqrt(
            1 / synapses_per_neuron
        )

        coo_matrix = torch.sparse_coo_tensor(
            indices, values, (num_neurons, num_neurons), check_invariants=True
        ).coalesce()
        csr_matrix = coo_matrix.to_sparse_csr()
        self.weight = nn.Parameter(csr_matrix)
        # self.weight.register_hook(lambda grad: grad.coalesce())
        # self.weight.register_hook(lambda grad: print(grad))
        # self.weight = nn.Parameter(csr_matrix)

        # Initialize the bias vector
        self.bias = nn.Parameter(torch.zeros(num_neurons, 1)) if bias else None

    def coalesce(self):
        self.weight.data = self.weight.data.coalesce()

    def idx_1D_to_2D(self, x):
        return torch.stack((x // self.sheet_size[1], x % self.sheet_size[1]))

    def idx_2D_to_1D(self, x):
        return x[0] * self.sheet_size[1] + x[1]

    def forward(self, x):
        # x: Dense (strided) tensor of shape (batch_size, num_neurons) if
        # batch_first, otherwise (num_neurons, batch_size)
        # assert self.weight.is_coalesced()

        # Transpose input if batch_first
        if self.batch_first:
            x = x.t()

        # Perform sparse matrix multiplication with or without bias
        x = (
            self.mm_function(self.weight, x)
            if self.bias is None
            else self.mm_function(self.weight, x)
        )

        # Transpose output back to batch first
        if self.batch_first:
            x = x.t()

        return x


class TopographicalCorticalRNN(nn.Module):
    def __init__(
        self,
        sheet_size,
        connectivity_std,
        synapses_per_neuron,
        num_timesteps,
        activation=nn.GELU,
        sheet_bias=True,
        sheet_mm_function=sparse_mm,
        sheet_batch_first=False,
        **kwargs
    ):
        super().__init__()
        self.num_neurons = sheet_size[0] * sheet_size[1]
        self.num_timesteps = num_timesteps
        self.activation = activation()
        self.sheet_batch_first = sheet_batch_first

        # Create the CorticalSheet layer
        self.cortical_sheet = TopographicalCorticalSheet(
            sheet_size,
            connectivity_std,
            synapses_per_neuron,
            sheet_bias,
            sheet_mm_function,
            sheet_batch_first,
        )

        # Create output block
        self.out_block = nn.Sequential(
            nn.Linear(28 * 28, 64), activation(), nn.Linear(64, 10)
        )

    def forward(self, x):
        # x: Dense (strided) tensor of shape (batch_size, 1, 32, 32)

        # Coallesce weight matrix
        # self.cortical_sheet.coalesce()

        # Flatten spatial and channel dimensions
        x = x.flatten(1)
        # Pad with zeros for rest of neurons
        x = F.pad(x, (0, self.num_neurons - x.shape[1]))

        # To avoid tranposing x before and after every iteration, we tranpose
        # before and after ALL iterations and do not tranpose within forward()
        # of self.cortical_sheet
        if not self.sheet_batch_first:
            x = x.t()

        # Pass the input through the CorticalSheet layer num_timesteps times
        for _ in range(self.num_timesteps):
            x = self.activation(self.cortical_sheet(x))

        # Transpose back
        if not self.sheet_batch_first:
            x = x.t()

        # Extract output from last 28*28 neurons (can be arbitrarily large number of neurons)
        x = x[:, -28 * 28 :]

        # Return classification from out_block
        return self.out_block(x)

In [25]:
def train(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TopographicalCorticalRNN(**config).to(device)  # type: ignore
    optimizer = config.optimizer(model.parameters(), lr=config.lr)
    criterion = config.criterion()
    train_loader, test_loader = get_dataloaders(config)

    if config.log_wandb:
        wandb.init(project="Cortical RNN", config=config)

    for epoch in range(config.num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        running_loss = 0.0
        running_correct = 0
        running_total = 0

        bar = tqdm(
            train_loader,
            desc=(
                f"Training | Epoch: {epoch} | "
                f"Loss: {0:.4f} | "
                f"Acc: {0:.2%}"
            ),
        )
        for i, (images, labels) in enumerate(bar):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update statistics
            train_loss += loss.item()
            running_loss += loss.item()

            predicted = outputs.argmax(-1)
            correct = (predicted == labels).sum().item()
            train_correct += correct
            running_correct += correct
            train_total += len(labels)
            running_total += len(labels)

            # Log statistics
            if (i + 1) % config.log_freq == 0:
                running_loss /= config.log_freq
                running_acc = running_correct / running_total
                if config.log_wandb:
                    wandb.log(
                        dict(
                            running_loss=running_loss, running_acc=running_acc
                        )
                    )
                bar.set_description(
                    f"Training | Epoch: {epoch} | "
                    f"Loss: {running_loss:.4f} | "
                    f"Acc: {running_acc:.2%}"
                )
                running_loss = 0
                running_correct = 0
                running_total = 0

        # Calculate average training loss and accuracy
        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        if config.log_wandb:
            wandb.log(dict(train_loss=train_loss, train_acc=train_acc))

        # Evaluate the model on the test set
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Update statistics
                test_loss += loss.item()
                predicted = outputs.argmax(-1)
                correct = (predicted == labels).sum().item()
                test_correct += correct
                test_total += len(labels)

        # Calculate average test loss and accuracy
        test_loss /= len(train_loader)
        test_acc = test_correct / test_total

        if config.log_wandb:
            wandb.log(
                dict(test_loss=test_loss, test_acc=test_acc, epoch=epoch)
            )

        # Print the epoch statistics
        print(
            f"Epoch [{epoch}/{config.num_epochs}] | "
            f"Train Loss: {train_loss:.4f} | "
            f"Train Accuracy: {train_acc:.2%} | "
            f"Test Loss: {test_loss:.4f}, "
            f"Test Accuracy: {test_acc:.2%}"
        )

        # Save Model
        # Save Model
        file_path = os.path.abspath(
            os.path.join(config.model_dir, f"model_{epoch}.pt")
        )
        link_path = os.path.abspath(os.path.join(config.model_dir, "model.pt"))
        torch.save(model, file_path)
        try:
            os.remove(link_path)
        except FileNotFoundError:
            pass
        os.symlink(file_path, link_path)

In [26]:
print_mem_stats()

Free/Total: 78.30GB/79.21GB


In [27]:
print_mem_stats()
torch.cuda.empty_cache()
print_mem_stats()

Free/Total: 78.30GB/79.21GB
Free/Total: 78.66GB/79.21GB


In [28]:
train(config)

Training | Epoch: 0 | Loss: 0.0000 | Acc: 0.00%:   0%|          | 0/3750 [00:00<?, ?it/s]


NotImplementedError: Could not run 'aten::_foreach_add_.List' with arguments from the 'SparseCsrCUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::_foreach_add_.List' is only available for these backends: [CPU, CUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:31034 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:43986 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:144 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:491 [backend fallback]
Functionalize: registered at aten/src/ATen/RegisterFunctionalization_2.cpp:21384 [kernel]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradHIP: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradMPS: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradIPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradVE: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradMeta: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradMTIA: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_4.cpp:15276 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_0.cpp:16728 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:487 [backend fallback]
AutocastCUDA: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:354 [backend fallback]
FuncTorchBatched: registered at ../aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:815 [backend fallback]
FuncTorchVmapMode: fallthrough registered at ../aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at ../aten/src/ATen/LegacyBatchingRegistrations.cpp:1073 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at ../aten/src/ATen/functorch/TensorWrapper.cpp:210 [backend fallback]
PythonTLSSnapshot: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:152 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at ../aten/src/ATen/functorch/DynamicLayer.cpp:487 [backend fallback]
PythonDispatcher: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:148 [backend fallback]


In [None]:
# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
#     train(config)

In [None]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
autograd::engine::evaluate_function: SparseAddmmBack...         0.41%      30.035ms        62.90%        4.593s       4.176ms       0.000us         0.00%        4.674s       4.249ms          1100  
                                              aten::add         0.43%      31.321ms        12.65%     923.936ms     422.081us       0.000us         0.00%        4.071s       1.860ms          2189  
         

In [None]:
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        aten::_coalesce       -10.96%  -9162002.000us        96.56%       80.749s      26.090ms       70.803s        87.27%       75.788s      24.487ms          3095  
                                         aten::coalesce         2.88%        2.406s        96.57%       80.758s      24.509ms       0.000us         0.00%       73.554s      22.323ms          3295  
autogra

In [30]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]).long()
values = torch.tensor([1, 2, 3, 4]).float()
weight = torch.sparse_coo_tensor(
    indices, values, (100000, 100000), check_invariants=True
).coalesce()
weight = weight.to_sparse_csr()
weight = weight.to(device)
weight.requires_grad = True
B = weight.clone()

weight + B
(weight + B).sum().backward()
print(weight.grad)


indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]]).long()
values = torch.tensor([1, 2, 3, 4]).float()
weight = torch.sparse_coo_tensor(
    indices, values, (100000, 100000), check_invariants=True
).coalesce()
weight = weight.to(device)
# weight = weight.to_sparse_csr()
weight.requires_grad = True

x = torch.ones(16, 1, 28, 28).to(device)
x = x.flatten(1)
x = F.pad(x, (0, 100000 - x.shape[1]))
out = x.t()
for _ in range(100):
    out = F.relu(sparse_mm(weight, out))
loss = out.sum()
loss.backward()
weight.grad
model = CorticalRNN(**config)
optimizer = config.optimizer(model.parameters(), lr=config.lr)
train_loader, _ = get_dataloaders(config)
train_iter = iter(train_loader)

for _ in range(10):
    x = next(train_iter)[0]
    optimizer.zero_grad()
    out = model(x)
    loss = out.sum()
    loss.backward()
    print(model.cortical_sheet.weight.grad._nnz())
    optimizer.step()
    print(model.cortical_sheet.weight.grad._nnz())
model.cortical_sheet.weight

In [None]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print("CorticalRNN")
model = CorticalRNN(**config)
total_params = 0
for param in model.named_parameters():
    num_params = (
        param[1]._nnz()
        if param[0] == "cortical_sheet.weight"
        else param[1].numel()
    )
    total_params += num_params
    print(param[0], num_params)
print(f"Total Parameters: {total_params}\n")

print("TopographicalCorticalRNN")
model = TopographicalCorticalRNN(**config)
total_params = 0
for param in model.named_parameters():
    num_params = (
        param[1]._nnz()
        if param[0] == "cortical_sheet.weight"
        else param[1].numel()
    )
    total_params += num_params
    print(param[0], num_params)
print(f"Total Parameters: {total_params}")

CorticalRNN
cortical_sheet.weight 995140
cortical_sheet.bias 10000
out_block.0.weight 50176
out_block.0.bias 64
out_block.2.weight 640
out_block.2.bias 10
Total Parameters: 1056030

TopographicalCorticalRNN
cortical_sheet.weight 907791
cortical_sheet.bias 10000
out_block.0.weight 50176
out_block.0.bias 64
out_block.2.weight 640
out_block.2.bias 10
Total Parameters: 968681
