In [61]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import pickle


In [62]:
device = torch.device("cpu")

In [63]:
with open('data/reset_splines.pkl', 'rb') as f:
    reset_splines = pickle.load(f)
with open('data/set_splines.pkl', 'rb') as f:
    set_splines = pickle.load(f)

In [64]:
# Hyperparameters
input_size = 14*14  # 28*28 downsampled images
num_classes = 10
num_epochs = 10
batch_size = 64
learning_rate = 0.01

In [65]:
# Load MNIST dataset with downsampling
transform = transforms.Compose([
    transforms.Resize((14, 14)),     # Downsample to 7x7
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

train_dataset = MNIST(root='./data/MNIST', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data/MNIST', train=False, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


In [66]:
class HardwareAwareOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, set_splines=None, reset_splines = None):
        defaults = dict(lr=lr)
        self.set_splines = set_splines
        self.reset_splines = reset_splines
        super(HardwareAwareOptimizer, self).__init__(params, defaults)

    def step(self):
        def map_weights_to_conductance(weight, min_weight, max_weight, min_conductance, max_conductance):
            # Linear interpolation from weight to conductance
            normalized_weights = (weight - min_weight) / (max_weight - min_weight)
            mapped_conductance = min_conductance + normalized_weights * (max_conductance - min_conductance)
            mapped_conductance = torch.clamp(mapped_conductance, min_conductance, max_conductance)
            return mapped_conductance
        
        def new_conductance(G0, grad):
            
            # G_pos = G0*1.05
            # G_neg = G0*(0.95)
            # Apply the splines based on gradient
            positive_mask = torch.le(grad, 0)  # Check for grad >= 0
            G_pos = G0 * (self.set_splines['1 us'].ev(G0, 2.5))
            G_neg = G0 * (self.reset_splines['1 us'].ev(G0, -2.5))
            
            # # Combine both positive and negative cases
            G_new = torch.where(positive_mask, G_pos, G_neg)
            return G_new

        def map_conductance_to_weights(conductance, min_conductance, max_conductance, min_weight, max_weight):
            # Reverse mapping from conductance to weight

            normalized_conductance = (conductance - min_conductance) / (max_conductance - min_conductance)
            mapped_weight = min_weight + normalized_conductance * (max_weight - min_weight)
            mapped_weight = torch.clamp(mapped_weight, min_weight, max_weight)
            return mapped_weight

        def map_weights_to_conductance_sine(weight, min_weight, max_weight, min_conductance, max_conductance):
            mid_weight = (max_weight+min_weight)/2
            mid_conductance = (max_conductance + min_conductance)/2
            normalized_weight = (weight - mid_weight)/(max_weight-mid_weight)
            normalized_conductance = 2/torch.pi * torch.arcsin(normalized_weight)
            conductance = normalized_conductance * (max_conductance - mid_conductance) + mid_conductance
            return conductance
        
        def map_conductance_to_weights_sine(conductance, min_weight, max_weight, min_conductance, max_conductance):
            mid_weight = (max_weight+min_weight)/2
            mid_conductance = (max_conductance + min_conductance)/2
            normalized_conductance = (conductance - mid_conductance)/(max_conductance - mid_conductance)
            normalized_weight = torch.sin(normalized_conductance * torch.pi/2)
            weight = normalized_weight*(max_weight - mid_weight) + mid_weight
            return weight



        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                min_weight, max_weight = -1.5, 1.5
                min_conductance, max_conductance = 5e-7, 6e-6
                
                # Get the current weights and gradients
                weight, grad = p.data, p.grad.data

                # 1. Map weights to conductance values
                G0 = map_weights_to_conductance_sine(
                    weight=weight,
                    min_weight=min_weight,
                    max_weight=max_weight,
                    min_conductance=min_conductance,
                    max_conductance=max_conductance
                )

                # 2. Modify conductance using spline function based on the gradient
                G_new = new_conductance(G0=G0, grad=grad)

                # 3. Map the new conductance values back to weights
                new_weight = map_conductance_to_weights_sine(
                    conductance=G_new,
                    min_conductance=min_conductance,
                    max_conductance=max_conductance,
                    min_weight=min_weight,
                    max_weight=max_weight
                )
                
                # Update the weights in the optimizer
                p.data.copy_(new_weight)  # In-place update of the weights


In [67]:
# Define MLP model with no hidden layers
class SimpleMLP(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleMLP, self).__init__()
        self.fc = nn.Linear(input_size, num_classes)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = x.view(-1, input_size)  # Flatten the input
        x = self.fc(x)
        x = self.softmax(x)
        return x

In [70]:
# Initialize model, loss, and optimizer
model = SimpleMLP(input_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
# optimizer = HardwareAwareOptimizer(model.parameters(), lr=0.01, set_splines=set_splines, reset_splines=reset_splines)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Check total number of parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total number of parameters: {total_params}")

# Training loop
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)
        
        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Total number of parameters: 1970
Epoch [1/10], Loss: 1.6776
Epoch [2/10], Loss: 1.7287
Epoch [3/10], Loss: 1.7394
Epoch [4/10], Loss: 1.8299
Epoch [5/10], Loss: 1.7140
Epoch [6/10], Loss: 1.7114
Epoch [7/10], Loss: 1.6217
Epoch [8/10], Loss: 1.5759
Epoch [9/10], Loss: 1.6088
Epoch [10/10], Loss: 1.6846


In [71]:

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

    print(f"Test Accuracy of the model on the 10,000 test images: {100 * correct / total:.2f}%")

Test Accuracy of the model on the 10,000 test images: 82.97%
