In [4]:
import torch
import torchvision
from PyHessian.pyhessian.hessian import hessian
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# CIFAR-10 training split. Each image is converted to a tensor and every
# channel is normalised with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
	[
		transforms.ToTensor(),
		transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
	]
)
trainset = torchvision.datasets.CIFAR10(
	root='./data',
	train=True,
	download=True,
	transform=transform,
)
# Shuffled mini-batches of 100 images, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
	dataset=trainset,
	batch_size=100,
	shuffle=True,
	num_workers=2,
)

# Define a simple neural network model
class SimpleCNN(nn.Module):
	"""Small LeNet-style CNN for 3-channel 32x32 images.

	Two 5x5 convolutions, each followed by ReLU and 2x2 max-pooling, feed
	three fully connected layers. With a 32x32 input the feature map entering
	``fc1`` is 16 x 5 x 5.

	Args:
		num_classes: Size of the output layer (defaults to 10).
	"""

	def __init__(self, num_classes=10):
		super().__init__()
		# Layers are created in a fixed order so that a given RNG seed always
		# yields the same initial weights and the same state_dict keys.
		self.conv1 = nn.Conv2d(3, 6, 5)
		self.pool = nn.MaxPool2d(2, 2)
		self.conv2 = nn.Conv2d(6, 16, 5)
		self.fc1 = nn.Linear(16 * 5 * 5, 120)
		self.fc2 = nn.Linear(120, 84)
		self.fc3 = nn.Linear(84, num_classes)
		self.act = nn.ReLU()

	def forward(self, x):
		"""Return class logits of shape ``(batch, num_classes)``.

		Args:
			x: Image tensor of shape ``(batch, 3, 32, 32)``; a single
				unbatched ``(3, 32, 32)`` image is treated as a batch of one.

		Raises:
			RuntimeError: If the spatial size does not produce the 16 x 5 x 5
				feature map that ``fc1`` expects.
		"""
		if x.dim() == 3:
			# Keep accepting a single image: promote it to a batch of one.
			x = x.unsqueeze(0)
		x = self.pool(self.act(self.conv1(x)))
		x = self.pool(self.act(self.conv2(x)))
		# Flatten per sample. Unlike ``view(-1, 400)`` this can never fold
		# surplus features into the batch dimension; a wrong input size makes
		# ``fc1`` raise instead of silently changing the number of rows.
		x = x.flatten(1)
		x = self.act(self.fc1(x))
		x = self.act(self.fc2(x))
		x = self.fc3(x)
		return x

# Cross-entropy loss shared by every training run; train_model reads this
# module-level object.
criterion = nn.CrossEntropyLoss()

def _sharpness_penalty(model, loss, inputs, labels):
	"""Return |v^T H v| for the top Hessian eigenvector as a differentiable tensor.

	PyHessian's power iteration supplies the leading eigenvector ``v`` of the
	mini-batch loss Hessian ``H``. Per PyHessian's source, the eigenvalues it
	returns are plain Python floats, so they cannot be passed to ``torch.norm``
	and carry no gradient. The Rayleigh quotient ``v^T H v`` is therefore
	rebuilt here with autograd, holding ``v`` constant; for a unit-norm
	eigenvector its gradient equals the gradient of the eigenvalue itself.

	Args:
		model: Network being trained.
		loss: Scalar loss for ``(inputs, labels)`` with its autograd graph intact.
		inputs: Mini-batch inputs.
		labels: Mini-batch targets.

	Returns:
		Scalar tensor, the absolute value of the Rayleigh quotient.
	"""
	was_training = model.training
	hessian_comp = hessian(model, criterion, data=(inputs, labels), cuda=False)
	_, eigenvectors = hessian_comp.eigenvalues()
	# PyHessian's constructor is expected to switch the model to eval mode
	# (see its source); put the caller's mode back.
	model.train(was_training)

	# Same parameter filter and order PyHessian uses for its vectors
	# (assumption based on its source; verify if PyHessian is upgraded).
	params = [p for p in model.parameters() if p.requires_grad]
	direction = [v.detach() for v in eigenvectors[0]]

	# Double backward: first the gradient, then the Hessian-vector product.
	# create_graph=True keeps both differentiable for the final backward().
	grads = torch.autograd.grad(loss, params, create_graph=True)
	hess_vec = torch.autograd.grad(grads, params, grad_outputs=direction, create_graph=True)
	rayleigh = sum((hv * v).sum() for hv, v in zip(hess_vec, direction))
	return rayleigh.abs()


def train_model(model, optimizer, regularize=False, num_epochs=2, log_every=100):
	"""Train ``model`` on the module-level ``trainloader`` and return the loss history.

	Args:
		model: Network to optimise in place.
		optimizer: Optimizer already bound to ``model``'s parameters.
		regularize: If True, add the magnitude of the top Hessian eigenvalue
			of the mini-batch loss to the objective.
		num_epochs: Number of passes over ``trainloader``.
		log_every: Number of mini-batches averaged into each logged value.

	Returns:
		List of floats: the mean loss (penalty included when ``regularize``
		is True) over each consecutive group of ``log_every`` mini-batches.
	"""
	model.train()
	losses = []
	for epoch in range(num_epochs):
		running_loss = 0.0
		for i, data in enumerate(tqdm(trainloader, desc=f"Epoch {epoch+1}"), 0):
			inputs, labels = data
			outputs = model(inputs)
			loss = criterion(outputs, labels)

			if regularize:
				# Out-of-place add: keeps the original loss node untouched.
				loss = loss + _sharpness_penalty(model, loss, inputs, labels)

			# Clear gradients immediately before backward so that anything
			# left in .grad by the Hessian computation cannot leak into the step.
			optimizer.zero_grad()
			loss.backward()
			optimizer.step()
			running_loss += loss.item()
			if (i + 1) % log_every == 0:
				tqdm.write(f'[{epoch + 1}, {i + 1}] loss: {running_loss / log_every:.3f}')
				losses.append(running_loss / log_every)
				running_loss = 0.0
	return losses

# Reference model: its freshly initialised weights are copied into every model
# trained below, so that all runs start from the same parameters.
dummy_model = SimpleCNN()

def copy_weights(target_model, source_model):
	"""Load ``source_model``'s state dict into ``target_model``.

	Args:
		target_model: Module whose state is overwritten.
		source_model: Module whose state is read.

	Returns:
		None.
	"""
	source_state = source_model.state_dict()
	target_model.load_state_dict(source_state)

# Train one model per optimizer, all starting from dummy_model's weights, and
# plot the logged losses. Set RUN_BASELINES to True to also train the
# unregularized baselines (this doubles the number of training runs).
RUN_BASELINES = False


def _adam(params):
	"""Adam optimizer with the learning rate used throughout this experiment."""
	return optim.Adam(params, lr=0.001)


def _sgd(params):
	"""SGD-with-momentum optimizer used throughout this experiment."""
	return optim.SGD(params, lr=0.001, momentum=0.9)


def _train_from_shared_init(make_optimizer, regularize):
	"""Train a fresh SimpleCNN initialised with dummy_model's weights.

	Args:
		make_optimizer: Callable mapping an iterable of parameters to an optimizer.
		regularize: Forwarded to train_model.

	Returns:
		Tuple ``(model, optimizer, losses)``.
	"""
	model = SimpleCNN()
	copy_weights(model, dummy_model)
	optimizer = make_optimizer(model.parameters())
	return model, optimizer, train_model(model, optimizer, regularize=regularize)


loss_curves = []  # (losses, legend label) pairs in plotting order

if RUN_BASELINES:
	net, optimizer_adam, losses_adam_no_reg = _train_from_shared_init(_adam, regularize=False)
	loss_curves.append((losses_adam_no_reg, 'Adam without Regularization'))

	net, optimizer_sgd, losses_sgd_no_reg = _train_from_shared_init(_sgd, regularize=False)
	loss_curves.append((losses_sgd_no_reg, 'SGD without Regularization'))

net, optimizer_adam, losses_adam_reg = _train_from_shared_init(_adam, regularize=True)
loss_curves.append((losses_adam_reg, 'Adam with Regularization'))

net, optimizer_sgd, losses_sgd_reg = _train_from_shared_init(_sgd, regularize=True)
loss_curves.append((losses_sgd_reg, 'SGD with Regularization'))

# Plotting code
fig, ax = plt.subplots()
for curve, label in loss_curves:
	ax.plot(curve, label=label)
# The title only promises the curves that were actually trained and drawn.
if RUN_BASELINES:
	ax.set_title('Training Loss with and without Regularization')
else:
	ax.set_title('Training Loss with Regularization')
ax.set_xlabel('Iteration (per 100 mini-batches)')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

Files already downloaded and verified


Epoch 1:   0%|          | 0/500 [00:07<?, ?it/s]


KeyboardInterrupt: 