In [1]:
import torch
import torchvision
import torchvision.transforms as transforms



In [2]:
# Preprocessing: convert each image to a tensor, then normalise it with the
# mean / std pair conventionally used for MNIST.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Number of samples per batch for both loaders.
batch_size = 1000

# Training split: downloaded on first use, reshuffled on every pass.
trainset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split: kept in a fixed order so a given batch index is repeatable.
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Label names: the ten digits as strings.
classes = tuple(str(digit) for digit in range(10))

In [3]:
from mnist import *
from viz import *

import torch
from torch.nn import Sequential, Module, CrossEntropyLoss
from torch.nn.functional import normalize
import numpy as np
from neurophoxTorch.torch import RMTorch
from scipy.stats import unitary_group
from tqdm import tqdm_notebook as pbar
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline
import warnings
def rc_mul(real: torch.Tensor, comp: torch.Tensor):
    """Multiply a real tensor by a stacked complex tensor.

    ``real`` gets a new leading axis so that it broadcasts against ``comp``,
    whose leading axis holds the (real, imaginary) components.
    """
    scale = torch.unsqueeze(real, 0)
    return scale * comp


def cc_mul(comp1: torch.Tensor, comp2: torch.Tensor) -> torch.Tensor:
    """Element-wise product of two stacked complex tensors.

    Index 0 along the leading axis is the real part and index 1 the imaginary
    part; the result is stacked the same way.
    """
    a_re, a_im = comp1[0], comp1[1]
    b_re, b_im = comp2[0], comp2[1]
    # (a_re + i a_im)(b_re + i b_im) expanded into its two components.
    real_part = a_re * b_re - a_im * b_im
    imag_part = a_re * b_im + a_im * b_re
    return torch.stack([real_part, imag_part], dim=0)

def phasor(real: torch.Tensor):
    """Return exp(i * real) as a stacked (cos, sin) tensor along a new axis 0."""
    cos_part = torch.cos(real)
    sin_part = torch.sin(real)
    return torch.stack([cos_part, sin_part], dim=0)


def cnorm(comp: torch.Tensor):
    """Element-wise modulus of a stacked complex tensor (real at [0], imaginary at [1])."""
    squared_modulus = comp[0].pow(2) + comp[1].pow(2)
    return torch.sqrt(squared_modulus)


def cnormsq(comp: torch.Tensor):
    """Element-wise squared modulus of a stacked complex tensor (real at [0], imaginary at [1])."""
    re_sq = comp[0].pow(2)
    im_sq = comp[1].pow(2)
    return re_sq + im_sq


def to_complex_t(nparray: np.ndarray):
    """Convert a NumPy array into a stacked (real, imaginary) torch tensor.

    The real and imaginary parts of ``nparray`` become index 0 and index 1 of
    a new leading axis.
    """
    parts = [torch.as_tensor(part) for part in (nparray.real, nparray.imag)]
    return torch.stack(parts, dim=0)


class ElectroopticNonlinearity(Module):
    """Intensity-dependent activation acting on stacked complex tensors.

    With ``phase = 0.5 * g * |z|^2 + 0.5 * phi_b`` the forward pass returns
    ``sqrt(1 - alpha) * cos(phase) * exp(-i * phase) * z``, where ``z`` is the
    input stored as (real, imaginary) along the leading axis.

    Args:
        alpha: enters the output only through the amplitude factor ``sqrt(1 - alpha)``.
        g: coefficient multiplying the squared modulus inside the phase.
        phi_b: constant offset added inside the phase.
    """

    def __init__(self, alpha: float=0.1, g: float=0.05 * np.pi, phi_b: float=np.pi):
        super().__init__()
        self.alpha, self.g, self.phi_b = alpha, g, phi_b

    def forward(self, inputs):
        """Apply the nonlinearity to ``inputs`` and return a tensor in the same layout."""
        phase = 0.5 * self.g * cnormsq(inputs) + 0.5 * self.phi_b
        # cos(phase) * exp(-i * phase): the complex factor applied to every element.
        response = rc_mul(phase.cos(), phasor(-phase))
        modulated = cc_mul(response, inputs)
        return np.sqrt(1 - self.alpha) * modulated


class CNormSq(Module):
    """Read-out layer that maps a stacked complex tensor to its squared modulus.

    Args:
        normed: when true, the squared moduli are additionally passed through
            ``torch.nn.functional.normalize`` along dim 1.
    """

    def __init__(self, normed=True):
        super().__init__()
        self.normed = normed

    def forward(self, inputs):
        """Return ``|inputs|^2``, normalised along dim 1 if ``self.normed`` is set."""
        power = cnormsq(inputs)
        if self.normed:
            return normalize(power, dim=1)
        return power

  import scipy as sp


In [4]:
import torch.nn as nn
import torch.nn.functional as F

class LeNet5(nn.Module):
    """LeNet-5 style CNN whose classifier head is an ``RMTorch`` layer.

    Two conv / batch-norm / ReLU / max-pool stages feed two linear layers that
    produce 64 real features. Those features are lifted to a stacked complex
    tensor (imaginary part zero), passed through a 64-wide ``RMTorch`` layer
    and read out as squared moduli by ``CNormSq``.

    ``fc1`` expects ``16 * 5 * 5`` features, which is what the two conv / pool
    stages yield for single-channel 28x28 inputs.

    After every forward pass, detached snapshots of intermediate activations
    are available as ``feature3`` (output of ``fc2``), ``feature4`` (output of
    the ``RMTorch`` layer) and ``feature5`` (final output).
    """

    def __init__(self):
        super(LeNet5, self).__init__()
        # NOTE: the registration order below fixes the order of parameters()
        # and of the state-dict keys; keep it stable.
        self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 64)
        self.norm1 = nn.BatchNorm2d(6)
        self.norm2 = nn.BatchNorm2d(16)
        self.norm3 = nn.BatchNorm1d(120)
        # RMTorch comes from neurophoxTorch; here it is built with all error terms disabled.
        self.layer1 = RMTorch(64,phase_error = 0.0, phase_error_files = None,bs_error=0.0,bs_error_files = None)
        self.output = CNormSq()

    def forward(self,x):
        """Run the network on a batch ``x`` and return the ``CNormSq`` read-out."""
        x = F.max_pool2d(F.relu(self.norm1(self.conv1(x))), (2, 2))
        x = F.max_pool2d(F.relu(self.norm2(self.conv2(x))), (2, 2))

        # Flatten everything except the batch axis.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.norm3(x)
        x = self.fc2(x)
        self.feature3 = x.detach()
        # Lift the real features to the stacked (real, imaginary) layout with a zero imaginary part.
        x = torch.stack((x, torch.zeros(x.shape, dtype=torch.float32, device=x.device)), dim=0)
        # The RMTorch layer returns four values; only the first one is used here.
        x,_,_,_ = self.layer1(x)
        self.feature4 = x.detach()
        x = self.output(x)
        self.feature5 = x.detach()

        return x

    def apply_constraints(self):
        """Clamp the weights of ``conv1``, ``conv2``, ``fc1`` and ``fc2`` to [-1, 1].

        The clamp is done in place on whatever device each layer currently
        lives on, so it works on CPU-only hosts and never moves the model.
        Biases and the batch-norm / RMTorch parameters are left untouched.
        """
        with torch.no_grad():
            for layer in (self.conv1, self.conv2, self.fc1, self.fc2):
                layer.weight.clamp_(-1, 1)

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# On a CUDA machine the line below prints a CUDA device.

print(device)

# Build the network and move it to the selected device; as the last expression
# of the cell, the moved model's repr is what the notebook displays.
net = LeNet5()
net.to(device)

cuda:0


LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=64, bias=True)
  (norm1): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (norm3): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): RMTorch()
  (output): CNormSq()
)

In [15]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pyhessian import hessian # Hessian computation
# This is a simple function, that will allow us to perturb the model parameters and get the result
def get_params(model_orig, direction, alpha):
    """Return a fresh ``LeNet5`` whose parameters are ``model_orig + alpha * direction``.

    A new model is created and initialised from the checkpoint at the
    module-level ``PATH`` (which also restores non-parameter state such as the
    batch-norm running statistics). Every parameter is then overwritten with
    the matching parameter of ``model_orig`` shifted by ``alpha`` times the
    corresponding entry of ``direction``.

    Args:
        model_orig: reference model; must have at least one parameter. Its
            parameters are read, never modified.
        direction: iterable of tensors, one per parameter of ``model_orig``,
            in ``model_orig.parameters()`` order.
        alpha: scalar step size along ``direction``.

    Returns:
        The perturbed model, placed on the same device as ``model_orig``.
    """
    # Follow the reference model's device instead of assuming CUDA, so the
    # function also works on CPU-only hosts.
    target_device = next(model_orig.parameters()).device

    model_perb = LeNet5()  # Create a new model
    # map_location lets checkpoints saved from a GPU be read without CUDA.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(PATH, map_location='cpu')
    # Drop every entry whose key mentions one of the 'cn1'..'cn6' tags; the
    # LeNet5 defined in this notebook has no attributes with those names.
    excluded_tags = ('cn1', 'cn2', 'cn3', 'cn4', 'cn5', 'cn6')
    state_dict = {k: v for k, v in state_dict.items()
                  if not any(tag in k for tag in excluded_tags)}
    model_perb.load_state_dict(state_dict)  # Copy original model parameters
    model_perb = model_perb.to(target_device)
    for m_orig, m_perb, d in zip(model_orig.parameters(), model_perb.parameters(), direction):
        m_perb.data = m_orig.data + alpha * d
    return model_perb

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load models
# Available checkpoints -- leave exactly one PATH assignment active:
# PATH = './Training Results/01-Standard BP.pth'
# PATH = './Training Results/02-PAT.pth'
# PATH = './Training Results/03-SAT-In silico.pth'
# PATH = './Training Results/04-SAT-In situ.pth'
# PATH = './Training Results/05-DAT.pth'
PATH = './Training Results/06-NAT.pth'

net = LeNet5()
# map_location keeps the load working on CPU-only hosts even when the
# checkpoint was written from a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(PATH, map_location=device)
# Drop every entry whose key mentions one of the 'cn1'..'cn6' tags; the LeNet5
# defined in this notebook has no attributes with those names.
state_dict = {k: v for k, v in state_dict.items()
              if not any(tag in k for tag in ('cn1', 'cn2', 'cn3', 'cn4', 'cn5', 'cn6'))}
# Last expression of the cell: the notebook displays the key-matching report.
net.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
# Define perturbation range
lams_wo = np.linspace(-0.3, 0.3, 201).astype(np.float32)
loss_list_wo = []

# Set model to evaluation mode
net.eval()

# Define loss function
criterion = torch.nn.CrossEntropyLoss()

# Select a fixed batch of data.
# n is 1-based: n = 1 picks the first batch. n = 0 also yields the first
# batch, because range(n - 1) is empty in both cases.
n = 0
dataiter = iter(testloader)
for _ in range(n - 1):
    next(dataiter)  # Skip the batches before the n-th one
n_batch = next(dataiter)  # Get the n-th batch
# Put the batch on the selected device (no unconditional .cuda(), so the cell
# also runs on CPU-only hosts).
inputs, targets = n_batch[0].to(device), n_batch[1].to(device)

# Create the Hessian computation module; tell pyhessian to use CUDA only when
# the model and data actually live on a CUDA device.
net.to(device)
hessian_comp = hessian(net, criterion, data=(inputs, targets), cuda=(device.type == 'cuda'))

# Compute the top eigenvalue and eigenvector
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues()
print("The top Hessian eigenvalue of this model is %.4f" % top_eigenvalues[-1])
# trace() returns a sequence of estimates; report their mean.
trace = hessian_comp.trace()
print("The trace of this model is: %.4f" % (np.mean(trace)))


The top Hessian eigenvalue of this model is 96.7924
The trace of this model is: 1045.5142
