In [1]:
# Run from the project root so relative paths used below (./data, models/) resolve there.
# NOTE(review): hardcoded absolute cluster path — adjust for your own checkout.
%cd /om2/user/valmiki/bioplnn

/rdma/vast-rdma/user/valmiki/bioplnn


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


## Imports

In [2]:
import os
import math
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import CIFAR10, MNIST
from tqdm import tqdm

In [3]:
# Show the GPU model, driver/CUDA versions, and current GPU memory usage.
!nvidia-smi

Thu Jan  4 18:01:31 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.86.01    Driver Version: 515.86.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:84:00.0 Off |                    0 |
| N/A   39C    P0    52W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Utils

In [4]:
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    The instance attribute namespace is the mapping itself, so ``d.key`` and
    ``d["key"]`` always refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias the instance __dict__ to the mapping: attribute get/set now
        # reads and writes dict entries directly.
        object.__setattr__(self, "__dict__", self)


def print_mem_stats(device=None):
    """Print free and total memory of a CUDA device, in GiB.

    Args:
        device: Forwarded to ``torch.cuda.mem_get_info``; ``None`` (the
            default) selects the current CUDA device.

    If CUDA is not available, a notice is printed instead of raising.
    """
    if not torch.cuda.is_available():
        print("Free/Total: CUDA not available")
        return
    gib = 1024**3
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    print(f"Free/Total: {free_bytes / gib:.2f}GB/{total_bytes / gib:.2f}GB")


def get_dataloaders(config):
    """Build the MNIST train and test DataLoaders.

    Both splits are stored under ``./data`` (downloaded if missing) and
    converted to tensors with ``ToTensor``. Only the train loader shuffles.
    Memory pinning is enabled whenever CUDA is available.

    Args:
        config: Object exposing ``batch_size``.

    Returns:
        (train_loader, test_loader)
    """
    splits = ("train", "test")
    pin = torch.cuda.is_available()

    # Load both splits first, then wrap each one in a loader.
    datasets = {
        split: MNIST(
            root="./data",
            train=(split == "train"),
            download=True,
            transform=torchvision.transforms.ToTensor(),
        )
        for split in splits
    }

    train_loader, test_loader = (
        DataLoader(
            dataset=datasets[split],
            batch_size=config.batch_size,
            shuffle=(split == "train"),
            pin_memory=pin,
        )
        for split in splits
    )
    return train_loader, test_loader

## Parameters

In [10]:
config = AttrDict(
    # Model parameters
    num_neurons=10000,
    synapses_per_neuron=100,
    num_timesteps=100,
    sheet_bias=True,
    sheet_mm_function=torch.sparse.mm,
    sheet_addmm_function=torch.sparse.addmm,
    sheet_batch_first=False,
    model_dir="models",
    # Training parameters
    batch_size=16,
    optimizer=optim.SGD,
    lr=1e-3,
    criterion=nn.CrossEntropyLoss,
    log_freq=10,
    num_epochs=30,
    log_wandb=True,
    # Reproducibility
    seed=0,
)

# Seed torch's global RNG so that, when the notebook runs top to bottom, the
# random sparse connectivity, weight init and DataLoader shuffling are
# reproducible. This does not cover nondeterministic CUDA kernels.
torch.manual_seed(config.seed)

# Create the checkpoint directory. This is a no-op if it already exists, and
# it raises if the path exists but is not a directory.
os.makedirs(config.model_dir, exist_ok=True)

## Model

In [11]:
class CorticalSheet(nn.Module):
    """One step of a randomly wired, sparse recurrent "sheet" of neurons.

    The weight is a ``num_neurons x num_neurons`` sparse COO parameter. For
    every neuron ``i``, ``synapses_per_neuron`` row indices are drawn uniformly
    at random and placed in column ``i``. ``coalesce()`` merges duplicate
    draws by summing their values. Non-zero values are standard normal,
    scaled by ``sqrt(1 / synapses_per_neuron)``.

    Args:
        num_neurons: Size of the sheet.
        synapses_per_neuron: Number of random entries drawn per column.
        bias: If True, add a learnable zero-initialised ``(num_neurons, 1)``
            bias.
        mm_function: Sparse matmul used when there is no bias.
        addmm_function: Sparse ``bias + W @ x`` used when there is a bias.
        batch_first: If True, ``forward`` takes and returns
            ``(batch, num_neurons)``; otherwise ``(num_neurons, batch)``.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        bias=True,
        mm_function=torch.sparse.mm,
        addmm_function=torch.sparse.addmm,
        batch_first=False,
        **kwargs
    ):
        super().__init__()
        self.mm_function = mm_function
        self.addmm_function = addmm_function
        self.batch_first = batch_first

        def _edges_for(neuron):
            # Row 0 holds random target rows; row 1 holds this neuron's column.
            targets = torch.randint(0, num_neurons, (synapses_per_neuron,))
            return torch.stack((targets, torch.full_like(targets, neuron)))

        edge_index = torch.cat([_edges_for(n) for n in range(num_neurons)], dim=1)
        scale = math.sqrt(1 / synapses_per_neuron)
        edge_values = torch.randn(num_neurons * synapses_per_neuron) * scale

        self.weight = nn.Parameter(
            torch.sparse_coo_tensor(
                edge_index, edge_values, (num_neurons, num_neurons)
            ).coalesce()
        )
        # A (num_neurons, 1) bias broadcasts across the batch columns.
        self.bias = None if not bias else nn.Parameter(torch.zeros(num_neurons, 1))

    def forward(self, x):
        """Apply ``W @ x (+ bias)``.

        Args:
            x: Dense tensor of shape ``(batch, num_neurons)`` if
                ``batch_first``, else ``(num_neurons, batch)``.

        Returns:
            A tensor with the same layout as ``x``.
        """
        if self.batch_first:
            x = x.t()
        out = (
            self.mm_function(self.weight, x)
            if self.bias is None
            else self.addmm_function(self.bias, self.weight, x)
        )
        return out.t() if self.batch_first else out


class CorticalRNN(nn.Module):
    """Recurrent classifier built from a single sparse CorticalSheet.

    The flattened input is written into the first neurons of the sheet, and
    the remaining neurons start at zero. The sheet is applied
    ``num_timesteps`` times, each step followed by ``activation``. The last
    ``input_size`` neurons are then read out and classified by a small MLP.

    Args:
        num_neurons: Number of neurons in the sheet. Must be >= ``input_size``.
        synapses_per_neuron: Forwarded to CorticalSheet.
        num_timesteps: Number of recurrent applications of the sheet.
        activation: Activation module class. It is instantiated once for the
            recurrence and once inside the output MLP.
        sheet_bias, sheet_mm_function, sheet_addmm_function, sheet_batch_first:
            Forwarded positionally to CorticalSheet.
        input_size: Number of input features after flattening. It is also the
            number of read-out neurons. The default 28 * 28 corresponds to a
            1x28x28 image.
        num_classes: Number of output logits.
        **kwargs: Ignored, so extra keys in a shared config are accepted.

    Raises:
        ValueError: If ``input_size`` is not in ``1..num_neurons``.
    """

    def __init__(
        self,
        num_neurons,
        synapses_per_neuron,
        num_timesteps,
        activation=nn.GELU,
        sheet_bias=True,
        sheet_mm_function=torch.sparse.mm,
        sheet_addmm_function=torch.sparse.addmm,
        sheet_batch_first=False,
        input_size=28 * 28,
        num_classes=10,
        **kwargs
    ):
        super().__init__()
        if not 0 < input_size <= num_neurons:
            raise ValueError(
                f"input_size ({input_size}) must be in 1..num_neurons "
                f"({num_neurons})"
            )
        self.num_neurons = num_neurons
        self.num_timesteps = num_timesteps
        self.input_size = input_size
        self.activation = activation()
        self.sheet_batch_first = sheet_batch_first

        # Create the CorticalSheet layer
        self.cortical_sheet = CorticalSheet(
            num_neurons,
            synapses_per_neuron,
            sheet_bias,
            sheet_mm_function,
            sheet_addmm_function,
            sheet_batch_first,
        )

        # Create output block: read-out neurons -> hidden -> class logits
        self.out_block = nn.Sequential(
            nn.Linear(input_size, 64), activation(), nn.Linear(64, num_classes)
        )

    def forward(self, x):
        """Classify a batch.

        Args:
            x: Dense tensor of shape ``(batch_size, ...)`` whose trailing dims
                flatten to at most ``num_neurons`` features. For example, use
                ``(batch_size, 1, 28, 28)`` with the default ``input_size``.

        Returns:
            Logits of shape ``(batch_size, num_classes)``.

        Raises:
            ValueError: If the flattened input has more features than
                ``num_neurons``. In that case ``F.pad`` would silently crop it.
        """
        # Flatten spatial and channel dimensions
        x = x.flatten(1)
        if x.shape[1] > self.num_neurons:
            raise ValueError(
                f"Input has {x.shape[1]} features but the sheet only has "
                f"{self.num_neurons} neurons"
            )
        # Pad with zeros for the rest of the neurons
        x = F.pad(x, (0, self.num_neurons - x.shape[1]))

        # The sheet expects (num_neurons, batch) when not batch_first, so
        # transpose once around the whole loop instead of on every step.
        if not self.sheet_batch_first:
            x = x.t()

        # Pass the state through the CorticalSheet num_timesteps times
        for _ in range(self.num_timesteps):
            x = self.activation(self.cortical_sheet(x))

        # Transpose back to (batch, num_neurons)
        if not self.sheet_batch_first:
            x = x.t()

        # Read out the last input_size neurons
        x = x[:, -self.input_size :]

        # Return classification logits from out_block
        return self.out_block(x)

In [15]:
def _train_one_epoch(model, loader, optimizer, criterion, device, epoch, config):
    """Run one optimization pass over ``loader``.

    Every ``config.log_freq`` batches, the running loss/accuracy window is
    shown on the progress bar (and logged to wandb if ``config.log_wandb``)
    and then reset.

    Returns:
        (mean loss per batch, accuracy over all samples). Both are 0.0 when
        the loader is empty.
    """
    model.train()
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_total = 0
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    bar = tqdm(
        loader,
        desc=(
            f"Training | Epoch: {epoch} | "
            f"Loss: {0:.4f} | "
            f"Acc: {0:.2%}"
        ),
    )
    for i, (images, labels) in enumerate(bar):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics (single .item() call per batch)
        batch_loss = loss.item()
        correct = (outputs.argmax(-1) == labels).sum().item()
        epoch_loss += batch_loss
        running_loss += batch_loss
        epoch_correct += correct
        running_correct += correct
        epoch_total += len(labels)
        running_total += len(labels)

        # Report the running window, then reset it
        if (i + 1) % config.log_freq == 0:
            running_loss /= config.log_freq
            running_acc = running_correct / running_total
            if config.log_wandb:
                wandb.log(dict(running_loss=running_loss, running_acc=running_acc))
            bar.set_description(
                f"Training | Epoch: {epoch} | "
                f"Loss: {running_loss:.4f} | "
                f"Acc: {running_acc:.2%}"
            )
            running_loss = 0.0
            running_correct = 0
            running_total = 0

    num_batches = len(loader)
    mean_loss = epoch_loss / num_batches if num_batches else 0.0
    accuracy = epoch_correct / epoch_total if epoch_total else 0.0
    return mean_loss, accuracy


def _evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Returns:
        (mean loss per batch, accuracy over all samples). Both are 0.0 when
        the loader is empty.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, labels).item()
            total_correct += (outputs.argmax(-1) == labels).sum().item()
            total_seen += len(labels)

    # Normalize by *this* loader's batch count.
    num_batches = len(loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = total_correct / total_seen if total_seen else 0.0
    return mean_loss, accuracy


def _save_checkpoint(model, model_dir, epoch):
    """Save ``model`` as ``model_<epoch>.pt`` and repoint the ``model.pt`` symlink at it."""
    file_path = os.path.abspath(os.path.join(model_dir, f"model_{epoch}.pt"))
    link_path = os.path.abspath(os.path.join(model_dir, "model.pt"))
    # NOTE(review): this pickles the whole module, so loading it needs the
    # class definitions to be importable. Consider saving model.state_dict().
    torch.save(model, file_path)
    try:
        os.remove(link_path)
    except FileNotFoundError:
        pass
    os.symlink(file_path, link_path)


def train(config):
    """Train a CorticalRNN and evaluate it on the test set after every epoch.

    Args:
        config: Mapping with attribute access (e.g. AttrDict). It is splatted
            into ``CorticalRNN(**config)``. It must also provide ``optimizer``
            (class), ``lr``, ``criterion`` (class), ``log_freq``,
            ``num_epochs``, ``log_wandb``, ``model_dir``, and whatever
            ``get_dataloaders`` reads.

    Side effects:
        - Prints per-epoch metrics.
        - Logs to wandb when ``config.log_wandb``. The run is finished on
          exit, including on error or interrupt.
        - Writes ``model_<epoch>.pt`` plus a ``model.pt`` symlink into
          ``config.model_dir``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CorticalRNN(**config).to(device)  # type: ignore
    optimizer = config.optimizer(model.parameters(), lr=config.lr)
    criterion = config.criterion()
    train_loader, test_loader = get_dataloaders(config)

    if config.log_wandb:
        wandb.init(project="Cortical RNN", config=config)

    try:
        for epoch in range(config.num_epochs):
            train_loss, train_acc = _train_one_epoch(
                model, train_loader, optimizer, criterion, device, epoch, config
            )
            if config.log_wandb:
                wandb.log(dict(train_loss=train_loss, train_acc=train_acc))

            test_loss, test_acc = _evaluate(model, test_loader, criterion, device)
            if config.log_wandb:
                wandb.log(dict(test_loss=test_loss, test_acc=test_acc))

            # Print the epoch statistics
            print(
                f"Epoch [{epoch}/{config.num_epochs}] | "
                f"Train Loss: {train_loss:.4f} | "
                f"Train Accuracy: {train_acc:.2%} | "
                f"Test Loss: {test_loss:.4f}, "
                f"Test Accuracy: {test_acc:.2%}"
            )

            _save_checkpoint(model, config.model_dir, epoch)
    finally:
        if config.log_wandb:
            wandb.finish()

In [16]:
# Report free/total GPU memory before building the model and starting training.
print_mem_stats()

Free/Total: 46.57GB/79.21GB


In [17]:
# Train and evaluate with the settings in `config`; checkpoints are written to config.model_dir.
train(config)



VBox(children=(Label(value='0.021 MB of 0.021 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
running_acc,▁
running_loss,▁

0,1
running_acc,0.09375
running_loss,2.3082


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111943867419743, max=1.0…

Training | Epoch: 0 | Loss: 2.2931 | Acc: 11.25%:  44%|████▍     | 1661/3750 [15:26<19:26,  1.79it/s]

: 