In [18]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torch.optim import SGD

import numpy as np
import matplotlib.pyplot as plt

In [19]:
input_size = 784
hidden_size = 2048
num_classes = 10
num_epochs = 50
batch_size = 50
learning_rate = 0.01

M_train = 5000
M_test = 1000

In [20]:
# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)

test_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

reduced_train_index=list(range(M_train))
reduced_test_index=list(range(M_test))

subset_train_dataset=Subset(train_dataset, reduced_train_index)
subset_test_dataset=Subset(test_dataset, reduced_test_index)

# Data loader
train_loader = DataLoader(dataset=subset_train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(dataset=subset_test_dataset, batch_size=batch_size, shuffle=False) 

hessian_loader = DataLoader(dataset=subset_train_dataset, batch_size=1, shuffle=False)

In [21]:
# Fully connected neural network
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.input_layer = nn.Linear(input_size, hidden_size) 
        self.output_layer = nn.Linear(hidden_size, num_classes)  
    
    def forward(self, x):
        x = self.input_layer(x)
        x = x.relu()
        x = self.output_layer(x)
        return x
    
def ReLU_glorot_init(model):
    for name, param in model.named_parameters():
        
        if name.endswith(".bias"):
            param.data.fill_(0)
        else:
            nn.init.xavier_normal_(param)

In [22]:
model = NeuralNet(input_size, hidden_size, num_classes)
ReLU_glorot_init(model)

criterion = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=learning_rate)

def train_epoch():
    for images, labels in train_loader:
        images = images.reshape(-1, 28*28)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.item()

for epoch in range(num_epochs):
    trainloss = train_epoch()
    print('Epoch: {} TrainLoss: {:.3f}'.format(epoch+1,trainloss))

Epoch: 1 TrainLoss: 1.379
Epoch: 2 TrainLoss: 0.879
Epoch: 3 TrainLoss: 0.855
Epoch: 4 TrainLoss: 0.624
Epoch: 5 TrainLoss: 0.389
Epoch: 6 TrainLoss: 0.520
Epoch: 7 TrainLoss: 0.425
Epoch: 8 TrainLoss: 0.492
Epoch: 9 TrainLoss: 0.448
Epoch: 10 TrainLoss: 0.555
Epoch: 11 TrainLoss: 0.385
Epoch: 12 TrainLoss: 0.343
Epoch: 13 TrainLoss: 0.257
Epoch: 14 TrainLoss: 0.372
Epoch: 15 TrainLoss: 0.394
Epoch: 16 TrainLoss: 0.392
Epoch: 17 TrainLoss: 0.323
Epoch: 18 TrainLoss: 0.314
Epoch: 19 TrainLoss: 0.181
Epoch: 20 TrainLoss: 0.228
Epoch: 21 TrainLoss: 0.263
Epoch: 22 TrainLoss: 0.186
Epoch: 23 TrainLoss: 0.280
Epoch: 24 TrainLoss: 0.367
Epoch: 25 TrainLoss: 0.224
Epoch: 26 TrainLoss: 0.265
Epoch: 27 TrainLoss: 0.196
Epoch: 28 TrainLoss: 0.347
Epoch: 29 TrainLoss: 0.375
Epoch: 30 TrainLoss: 0.139
Epoch: 31 TrainLoss: 0.248
Epoch: 32 TrainLoss: 0.351
Epoch: 33 TrainLoss: 0.169
Epoch: 34 TrainLoss: 0.115
Epoch: 35 TrainLoss: 0.256
Epoch: 36 TrainLoss: 0.156
Epoch: 37 TrainLoss: 0.176
Epoch: 38 

In [41]:
def network_layer_weight_extraction(model):
    network_weights=[]
    for para in model.named_parameters():
        if 'weight' in para[0]:
            network_weights.append(para[1])
    return network_weights

def network_layer_bias_extraction(model):
    network_weights=[]
    for para in model.named_parameters():
        if 'bias' in para[0]:
            network_weights.append(para[1])
    return network_weights

In [57]:
weights=network_layer_weight_extraction(model)
bias=network_layer_bias_extraction(model)