In [None]:
# Only for Colab
import os

REPO_URL = "https://github.com/TommyR06/cross-sim-BTU-course.git"
REPO_NAME = "cross-sim-BTU-course"

if not os.path.exists(REPO_NAME):
    !git clone {REPO_URL}
%cd {REPO_NAME}

!pip install -r requirements.txt

In [None]:
%matplotlib inline
import sys
sys.path.append("../../")
import numpy as np
from simulator import AnalogCore
from simulator import CrossSimParameters
from applications.mvm_params import set_params
import scipy.linalg
import matplotlib
import matplotlib.pyplot as plt
import pickle
from PIL import Image
np.random.seed(498)

# Part 3: Device-aware training

The PyTorch model with CrossSim-compatible layers are not just useful for running inference. They can be trained as well! In this use case, the forward pass through the convolution and fully-connected layers will be executed through CrossSim's AnalogCores, but the backward pass will be idealized, providing a differentiable trace. This means that, to the extent the idealized operation matches the true AnalogCore forward pass, we can perform surrogate gradient descent. This can allow a network to adapt to these non-idealities and recover some of the performance that would be lost from post-training conversion.

For this simple demo, we will train on the very simple MNIST dataset, with and without device errors in the loop injected through CrossSim. First, we will load the dataset.

In [None]:
from simulator.algorithms.dnn.torch.convert import from_torch, reinitialize, synchronize
import torch
from torchvision import datasets, transforms

In [11]:
from tqdm import tqdm
from pathlib import Path

# Load the MNIST training set
mnist_data = datasets.MNIST("./", download=True, train=True,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.Compose([
                                lambda x:torch.tensor([x]), 
                                lambda x:torch.nn.functional.one_hot(x,10).float(),
                                lambda x:x.squeeze(),
                                ]))

# Load the MNIST test set
mnist_test = datasets.MNIST("./", download=True, train=False,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.Compose([
                                lambda x:torch.tensor([x]), 
                                lambda x:torch.nn.functional.one_hot(x,10).float(),
                                lambda x:x.squeeze(),
                                ]))

# Split dataset into training and validation and create data loaders
ds_train, ds_val = torch.utils.data.random_split(mnist_data, [0.8, 0.2])
mnist_loader_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True)
mnist_loader_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size, shuffle=False)

# Create test set loader
mnist_loader_test = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
N_test = len(mnist_loader_test.dataset)

We will train a simple three-layer CNN on MNIST, whose topology is defined below. This is a small network with only 7018 trainable weights.

In [12]:
# Define the CNN topology
def mnist_cnn():
    return torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, 3, padding='valid', stride=2),
        torch.nn.ReLU(),
        torch.nn.Conv2d(8, 16, 3, padding='valid', stride=2),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(576, 10)
        )

We will use the standard PyTorch wrapper below to train the CNN on MNIST. We will train using the Adam optimizer with a learning rate of $10^{-3}$.

In [13]:
# Wrapper for training the CNN
class SequentialWrapper():
    def __init__(self, net, loss, learning_rate=1e-3):
        self.net = net
        self.loss = loss
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate)

    def forward(self, x):
        return self.net(x)
    
    def training_step(self, batch):
        self.optimizer.zero_grad()
        pred = self.forward(batch[0])
        loss = self.loss(pred, batch[1])
        loss.backward()
        self.optimizer.step()
        return loss

    def validation_step(self, batch):
        pred = self.forward(batch[0])
        loss = self.loss(pred, batch[1])
        return loss
    
    def train_epoch(self, train_loader, val_loader):
        loss_train, loss_val = 0, 0
        for minibatch in iter(train_loader):
            loss_train += self.training_step(minibatch).detach()
        for minibatch in iter(val_loader):
            loss_val += self.validation_step(minibatch).detach()
        return loss_train/len(train_loader), loss_val/len(val_loader)
    
    def train(self, train_loader, val_loader, epochs):
        loss_train, loss_val = np.zeros(epochs), np.zeros(epochs)
        for e in tqdm(range(0, epochs)):
            lt, lv = self.train_epoch(train_loader, val_loader)
            loss_train[e] = lt
            loss_val[e] = lv
        return loss_train, loss_val

# Create the wrapped PyTorch model
mnist_cnn_pt = SequentialWrapper(mnist_cnn(), torch.nn.CrossEntropyLoss())

We will first train this CNN as we would normally do in PyTorch, without any analog error injection during training. After training, we'll evaluate the test accuracy, again without any analog errors.

In [14]:
# Number of training epochs
N_epochs = 20

# Train the standard PyTorch CNN
loss_train_pt, loss_val_pt = mnist_cnn_pt.train(mnist_loader_train, mnist_loader_val, N_epochs)

# Perform inference on the test set, with no analog errors
y_pred, y, k = np.zeros(N_test), np.zeros(N_test), 0
for inputs, labels in mnist_loader_test:
    output = mnist_cnn_pt.net(inputs)
    y_pred_k = output.data.detach().numpy()
    y_pred = np.append(y_pred,y_pred_k.argmax(axis=-1))
    y = np.append(y,labels.detach().numpy().argmax(axis=1))

# Evaluate accuracy
accuracy_digitalTrain_digitalTest = np.sum(y == y_pred)/len(y)
print('===========')
print('No analog errors during training, no analog errors during test')
print('Test accuracy: {:.2f}%\n'.format(accuracy_digitalTrain_digitalTest*100))

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [12:46<00:00, 38.34s/it]


No analog errors during training, no analog errors during test
Test accuracy: 99.16%



How well does this CNN do when analog errors are injected at inference time? Since this is MNIST, we will simulate inference assuming a memory device that has very large errors. This device will have state-independent conductance errors with $\alpha = 0.3$. We will disable all other error models to keep this demo simple.

We will run inference by first passing our trained CNN through our PyTorch layer converter as we did in Part 2. Since the device error is large, we will simulate inference ten times with re-sampled random device errors each time. This will give us a good statistical picture of the network's accuracy.

In [15]:
# Create a parameters object that models a memory device with very large errors
params_analog = set_params(weight_bits = 8, wtmodel = "BALANCED", 
                         error_model = "generic",
                         proportional_error = "False",
                         alpha_error = 0.3)

# Convert the layers in the trained CNN
analog_mnist_cnn_pt = from_torch(mnist_cnn_pt.net, params_analog)

# Number of inference simulations with re-sampled random analog errors
N_runs = 10

# Perform analog inference on the test set
accuracies = np.zeros(N_runs)
for i in range(N_runs):
    print("Inference simulation {:d} of {:d}".format(i+1,N_runs), end="\r")
    y_pred, y, k = np.zeros(N_test), np.zeros(N_test), 0
    for inputs, labels in mnist_loader_test:
        output = analog_mnist_cnn_pt.forward(inputs)
        y_pred_k = output.data.detach().numpy()
        y_pred = np.append(y_pred,y_pred_k.argmax(axis=-1))
        y = np.append(y,labels.detach().numpy().argmax(axis=1))
    accuracies[i] = np.sum(y == y_pred)/len(y)
    reinitialize(analog_mnist_cnn_pt)

# Evaluate average test accuracy
print('\n===========')
print('No analog errors during training, CrossSim analog errors during test')
accuracy_digitalTrain_analogTest = np.mean(accuracies)
std_digitalTrain_analogTest = np.std(accuracies)
print('Test accuracy: {:.2f}% +/- {:.3f}%'.format(100*accuracy_digitalTrain_analogTest,100*std_digitalTrain_analogTest))

Inference simulation 10 of 10
No analog errors during training, CrossSim analog errors during test
Test accuracy: 93.33% +/- 1.606%


With the inclusion of these large conductance errors, our model loses quite a bit of accuracy on MNIST.

Now let's try to see if we can make up this accuracy loss by simulating the conductance errors at inference time during the training process. As before, we will disable all other error models to keep things simple. For a practical hardware-aware training scenario, we can specify our parameters to represent the exact analog hardware configuration that would be used during inference and enable as many different error models in CrossSim as we would like.

To do this, we will use a modified training wrapper below that includes only a single new line. The "synchronize" method is called after the backward pass to update the conductance values in the AnalogCores with the new updated weight values found using the optimizer. These updated AnalogCores will then be used for the forward pass of the next training epoch.

We create another PyTorch CNN, convert its layers to be CrossSim-compatible, then wrap it with the modified training wrapper. Then we will train this model with the same large conductance errors injected during training.

In [16]:
# Modified training warpper for CrossSim-in-the-loop training
class SequentialWrapper_CrossSim(SequentialWrapper):
    def __init__(self, net, loss, learning_rate=1e-3):
        super().__init__(net, loss, learning_rate)
        
    def training_step(self, batch):
        self.optimizer.zero_grad()
        pred = self.forward(batch[0])
        loss = self.loss(pred, batch[1])
        loss.backward()
        self.optimizer.step()
        synchronize(self.net)  # <--- The only changed line in all of training!
        return loss

# Create a PyTorch model with CrossSim-compatible layers
analog_mnist_cnn = from_torch(mnist_cnn(), params_analog)

# Create the wrapped analog PyTorch model
analog_mnist_cnn_CS = SequentialWrapper_CrossSim(analog_mnist_cnn, torch.nn.CrossEntropyLoss())

# Train the analog PyTorch model
loss_train_CS, loss_val_CS = analog_mnist_cnn_CS.train(mnist_loader_train, mnist_loader_val, N_epochs)

100%|██████████| 20/20 [16:20<00:00, 49.05s/it]


Finally, let's perform inference simulation with conductance errors to see if our model that had device-aware training (with the same conductance errors as inference) achieves higher accuracy than the model with standard training.

In [17]:
# Perform analog inference on the test set
accuracies = np.zeros(N_runs)
for i in range(N_runs):
    print("Inference simulation {:d} of {:d}".format(i+1,N_runs), end="\r")
    y_pred, y, k = np.zeros(N_test), np.zeros(N_test), 0
    for inputs, labels in mnist_loader_test:
        output = analog_mnist_cnn_CS.net(inputs)
        y_pred_k = output.data.detach().numpy()
        y_pred = np.append(y_pred,y_pred_k.argmax(axis=-1))
        y = np.append(y,labels.detach().numpy().argmax(axis=1))
    accuracies[i] = np.sum(y == y_pred)/len(y)
    reinitialize(analog_mnist_cnn_CS.net)

# Evaluate average test accuracy
print('\n===========')
print('CrossSim analog errors during training, CrossSim analog errors during test')
accuracy_analogTrain_analogTest = np.mean(accuracies)
std_analogTrain_analogTest = np.std(accuracies)
print('Test accuracy: {:.2f}% +/- {:.3f}%'.format(accuracy_analogTrain_analogTest*100,std_analogTrain_analogTest*100))

Inference simulation 10 of 10
CrossSim analog errors during training, CrossSim analog errors during test
Test accuracy: 98.27% +/- 0.172%


Let's summarize our results.

In [18]:
print("Accuracy on MNIST test set")
print("================")
print("Standard training, standard inference: {:.2f}%".format(100*accuracy_digitalTrain_digitalTest))
print("Standard training, CrossSim inference: {:.2f}% +/- {:.3f}%".format(100*accuracy_digitalTrain_analogTest, 100*std_digitalTrain_analogTest))
print("CrossSim training, CrossSim inference: {:.2f}% +/- {:.3f}%".format(100*accuracy_analogTrain_analogTest, 100*std_analogTrain_analogTest))

Accuracy on MNIST test set
Standard training, standard inference: 99.16%
Standard training, CrossSim inference: 93.33% +/- 1.606%
CrossSim training, CrossSim inference: 98.27% +/- 0.172%


Device-aware training using CrossSim yielded a substantial recovery of the test accuracy in the presence of very large conductance errors!