In [27]:
%cd /om2/user/valmiki/bioplnn

/rdma/vast-rdma/user/valmiki/bioplnn


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


## Imports

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import CIFAR10
from tqdm import tqdm

In [29]:
!nvidia-smi

Tue Dec 19 17:04:21 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |


|   0  NVIDIA A100 80G...  On   | 00000000:61:00.0 Off |                    0 |
| N/A   39C    P0    74W / 300W |   1999MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|    0   N/A  N/A   1356877      C   ...s/pytorch-3.10/bin/python     1997MiB |
+-----------------------------------------------------------------------------+


## Utils

In [30]:
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

def get_dataloaders(config):
    # Load the MNIST dataset
    cifar_train = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    # Load the MNIST test dataset
    cifar_test = torchvision.datasets.CIFAR10(
        root="./data",
        train=False,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )

    train_loader = DataLoader(
        dataset=cifar_train,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
    )

    test_loader = DataLoader(
        dataset=cifar_test,
        batch_size=config.batch_size,
        shuffle=False,
        pin_memory=torch.cuda.is_available(),
    )

    return train_loader, test_loader

## Parameters

In [31]:
config = AttrDict(
    # Model parameters
    num_neurons=100000,
    synapses_per_neuron=1000,
    num_timesteps=1000,
    bias=True,
    mm_function=torch.sparse.mm,
    # Training parameters
    batch_size=32,
    optimizer=optim.AdamW,
    lr=1e-3,
    criterion=nn.CrossEntropyLoss,
    log_freq=100,
    num_epochs=30,
)

## Model

In [32]:
class CorticalSheet(nn.Module):
    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        bias=True,
        mm_function=torch.sparse.mm,
        **kwargs
    ):
        super().__init__()
        # Save the sparse matrix multiplication function
        self.mm_function = mm_function

        # Create a sparse tensor for the weight matrix
        indices = []

        # Create adjacency matrix with normal distribution randomized weights
        for i in range(num_neurons):
            synapses = torch.randint(0, num_neurons, (synapses_per_neuron,))
            synapse_root = torch.full_like(synapses, i)
            indices.append(torch.stack((synapses, synapse_root)))
        indices = torch.cat(indices, dim=1)
        values = torch.randn(num_neurons * synapses_per_neuron)

        coo_matrix = torch.sparse_coo_tensor(
            indices, values, (num_neurons, num_neurons)
        )
        csr_matrix = coo_matrix.coalesce().to_sparse_csr()
        self.weight = nn.Parameter(csr_matrix)

        # Initialize the bias vector
        self.bias = nn.Parameter(torch.zeros(num_neurons)) if bias else None

    def forward(self, x):
        # Perform sparse matrix multiplication
        output = self.mm_function(self.weight, x)

        # Add the bias vector
        if self.bias is not None:
            output += self.bias

        return output


class CorticalRNN(nn.Module):
    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        num_timesteps,
        activation=nn.GELU,
        **kwargs
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.activation = activation()

        # Define the CorticalSheet layer
        self.cortical_sheet = CorticalSheet(num_neurons, synapses_per_neuron)

    def forward(self, x):
        # Pass the input through the CorticalSheet layer num_timesteps times
        for _ in range(self.num_timesteps):
            x = self.activation(self.cortical_sheet(x))

        return x

In [33]:
def train(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CorticalRNN(**config).to(device)  # type: ignore
    optimizer = config.optimizer(model.parameters(), lr=config.lr)
    criterion = config.criterion()
    train_loader, test_loader = get_dataloaders(config)

    for epoch in range(config.num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        running_loss = 0.0
        running_correct = 0
        running_total = 0

        bar = tqdm(
            train_loader,
            desc=(
                f"Training | Epoch: {epoch} | "
                f"Loss: {0:.4f} | "
                f"Acc: {0:.2%}"
            ),
        )
        for i, (images, labels) in enumerate(bar):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update statistics
            train_loss += loss.item()
            running_loss += loss.item()

            predicted = outputs.argmax(-1)
            correct = (predicted == labels).sum().item()
            train_correct += correct
            running_correct += correct
            train_total += len(labels)
            running_total += len(labels)

            # Log statistics
            if (i + 1) % config.log_freq == 0:
                running_loss /= config.log_freq
                acc = running_correct / running_total
                # if config.log_wandb:
                #     wandb.log(dict(
                #         running_train_loss = running_loss,
                #         running_train_acc = running_acc)
                #     )
                bar.set_description(
                    f"Training | Epoch: {epoch} | "
                    f"Loss: {running_loss:.4f} | "
                    f"Acc: {acc:.2%}"
                )
                running_loss = 0
                running_correct = 0
                running_total = 0

        # Calculate average training loss and accuracy
        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # if config.log_wandb:
        #     wandb.log(dict(
        #         train_loss = train_loss,
        #         train_acc = train_acc)
        #     )

        # Evaluate the model on the test set
        model.eval()
        test_loss = 0.0
        test_correct = 0
        test_total = 0

        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Update statistics
                test_loss += loss.item()
                predicted = outputs.argmax(-1)
                correct = (predicted == labels).sum().item()
                test_correct += correct
                test_total += len(labels)

        # Calculate average test loss and accuracy
        test_loss /= len(train_loader)
        test_acc = test_correct / test_total

        # if config.log_wandb:
        #     wandb.log(dict(
        #         test_loss = test_loss,
        #         test_acc = test_acc)
        #     )

        # Print the epoch statistics
        print(
            f"Epoch [{epoch}/{config.num_epochs}] | "
            f"Train Loss: {train_loss:.4f} | "
            f"Train Accuracy: {train_acc:.2%} | "
            f"Test Loss: {test_loss:.4f}, "
            f"Test Accuracy: {test_acc:.2%}"
        )

In [34]:
train(config)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:15<00:00, 11075698.04it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Training | Epoch: 0 | Loss: 0.0000 | Acc: 0.00%:   0%|          | 0/1563 [00:00<?, ?it/s]


RuntimeError: mat2 must be a matrix, got 4-D tensor