In [88]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time

In [89]:
# Preprocessing pipeline handed to the MNIST datasets: the only step is the
# conversion of each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

In [90]:
def get_data(y_min, y_max):
	"""Load MNIST as flat feature matrices with a two-class relabelling.

	Every image is flattened to 28*28 features and a constant 1 is appended
	as the last column so that a bias can be learned as part of ``theta``.
	Digits 0-4 are relabelled ``y_min`` and digits 5-9 are relabelled ``y_max``.

	Args:
		y_min: label assigned to digits strictly below 5.
		y_max: label assigned to digits 5 and above.

	Returns:
		``(X_train, y_train, X_test, y_test)`` where each ``X`` has
		28*28 + 1 columns and each ``y`` keeps the dtype of the original
		label tensor.
	"""
	train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
	test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

	# A single batch spanning the whole dataset materialises it as tensors.
	X_train, y_train = next(iter(DataLoader(train_dataset, batch_size=len(train_dataset))))
	X_test, y_test = next(iter(DataLoader(test_dataset, batch_size=len(test_dataset))))

	def _flatten_with_bias(images):
		# add a 1 to the end of each sample for the bias term
		flat = images.view(-1, 28*28)
		return torch.cat((flat, torch.ones(flat.shape[0], 1)), dim=1)

	def _binarize(digits):
		# The mask is taken from the ORIGINAL digits before anything is
		# overwritten; relabelling sequentially would let samples already
		# set to y_min be caught by the ">= 5" test whenever y_min >= 5.
		is_low = digits < 5
		digits[is_low] = y_min
		digits[~is_low] = y_max
		return digits

	X_train = _flatten_with_bias(X_train)
	X_test = _flatten_with_bias(X_test)
	y_train = _binarize(y_train)
	y_test = _binarize(y_test)

	return X_train, y_train, X_test, y_test


In [91]:
def logistic_loss(y, x, theta):
	"""Mean logistic loss ``log(1 + exp(-y * (x @ theta)))`` over the samples.

	Args:
		y: per-sample labels; the formula treats them as signed targets
			(the notebook passes -1 / +1).
		x: feature matrix, one sample per row.
		theta: parameter column vector multiplied as ``x @ theta``.

	Returns:
		Scalar tensor holding the mean loss.
	"""
	y_pred = (x @ theta).squeeze()
	# softplus(z) == log(1 + exp(z)) but is evaluated without overflowing:
	# the naive form returns inf as soon as exp(z) exceeds the float range.
	return torch.nn.functional.softplus(-y * y_pred).mean()

In [92]:
def non_linear_least_squares_with_sigmoid_loss(y, x, theta):
	"""Mean of ``y - sigmoid(y * (x @ theta))`` over the samples.

	NOTE(review): despite the name, the residual is not squared, so this is
	the mean signed residual rather than a least-squares objective. Confirm
	the intended formula and label convention before relying on it.

	Args:
		y: per-sample labels.
		x: feature matrix, one sample per row.
		theta: parameter column vector multiplied as ``x @ theta``.

	Returns:
		Scalar tensor holding the mean residual.
	"""
	scores = torch.squeeze(x @ theta)
	# Logistic sigmoid written out explicitly: 1 / (1 + e^{-y * score}).
	sigmoid_of_margin = 1 / (1 + torch.exp(-y * scores))
	residual = y - sigmoid_of_margin
	return residual.mean()

In [93]:
@torch.no_grad()
def line_search(f, x, w, y, loss, s, beta=0.5, c=1e-4, max_iter=16, grad=None):
	"""Backtracking line search with an Armijo sufficient-decrease test.

	Starting from ``alpha = 1`` the step is multiplied by ``beta`` until
	``f(y, x, w + alpha * s) < loss + c * alpha * slope``.

	Args:
		f: objective called as ``f(y, x, w)``; must return a single-element
			tensor (or a number).
		x: features forwarded to ``f``.
		w: current parameters.
		y: labels forwarded to ``f``.
		loss: single-element tensor holding ``f(y, x, w)``.
		s: search direction.
		beta: shrink factor applied to ``alpha`` after a rejected trial.
		c: sufficient-decrease constant.
		max_iter: the step is shrunk at most ``max_iter + 1`` times.
		grad: optional gradient of ``f`` at ``w``. When omitted the direction
			is assumed to be the negative gradient, so the directional
			derivative is ``-s.T @ s``.

	Returns:
		The accepted step size, or ``beta ** (max_iter + 1)`` if no trial
		satisfied the test within the iteration budget.
	"""
	loss = loss.item()
	# Directional derivative along s; it does not depend on alpha, so it is
	# computed once instead of on every trial.
	if grad is None:
		slope = -(s * s).sum().item()
	else:
		slope = (grad * s).sum().item()

	alpha = 1
	for _ in range(max_iter + 1):
		trial_loss = float(f(y, x, w + alpha * s))
		# Written as "accept if strictly smaller" so that a NaN trial loss
		# (every comparison with NaN is False) is rejected, not accepted.
		if trial_loss < loss + c * alpha * slope:
			break
		alpha *= beta
	return alpha

In [94]:
@torch.no_grad()
def accuracy_func(y, x, theta):
	"""Fraction of samples whose predicted sign equals the label.

	The prediction for each row of ``x`` is ``sign(x @ theta)``; the result
	is returned as a scalar float tensor.
	"""
	predicted_sign = torch.sign(x @ theta).squeeze()
	hits = predicted_sign == y
	return hits.float().mean()

In [111]:
def two_level(X, y, loss_func, p, sample_size, n_epochs, verbose=False):
	"""Two-level (fine / coarse) gradient descent on ``loss_func``.

	Every epoch consists of three stages:

	1. one line-searched gradient step on the full ("fine") data;
	2. ``p`` line-searched gradient steps on a coarse model, i.e. the loss on
	   a random subsample plus the linear term ``V_H.T @ theta``. ``V_H`` is
	   ``grad f_h - grad f_H`` at the current iterate, so the coarse model's
	   gradient there equals the fine gradient (first-order coherence);
	3. one line-searched step on the fine data along the displacement
	   produced by the coarse iterations.

	Args:
		X: feature matrix, one sample per row.
		y: labels, indexable the same way as the rows of ``X``.
		loss_func: objective called as ``loss_func(y, x, theta)``.
		p: number of coarse iterations per epoch.
		sample_size: fraction of the rows of ``X`` used for the coarse model.
		n_epochs: number of outer iterations.
		verbose: when true, print loss and accuracy every 10 epochs.

	Returns:
		``(theta, result)`` where ``result`` is a dict with keys
		``'Losses_h'`` (fine loss at the start of each epoch),
		``'Losses_H'`` (per epoch, list of coarse-model losses),
		``'accuracies'``, ``'delta_fine_iter'``, ``'delta_coarse_iter'``
		(loss changes caused by stage 1 and stage 3, as floats) and
		``'time'`` (elapsed seconds).
	"""
	theta = torch.randn(X.shape[1], 1, dtype=torch.float32, requires_grad=True)
	X_h = X
	y_h = y
	Losses_h = []
	Losses_H = []
	accuracies = []
	delta_fine_iter = []
	delta_coarse_iter = []
	time_start = time.time()
	for epoch in range(n_epochs):
		# --- Stage 1: gradient step on the fine level -------------------
		theta.grad = None
		loss_h = loss_func(y_h, X_h, theta)
		loss_h_value = loss_h.item()
		Losses_h.append(loss_h_value)

		loss_h.backward()
		s = -theta.grad
		alpha = line_search(loss_func, X_h, theta, y_h, loss_h, s)
		with torch.no_grad():
			theta += alpha * s

		# Fine loss AND gradient at the smoothed iterate. The coherence term
		# below must use the gradient at the point where the coarse
		# iterations start, not the one from before the step above.
		theta.grad = None
		loss_after_fine = loss_func(y_h, X_h, theta)
		loss_after_fine.backward()
		grad_h = theta.grad.detach().clone()
		# Detach so no autograd graph is kept alive by the bookkeeping.
		loss_after_fine = loss_after_fine.detach()
		loss_after_fine_value = loss_after_fine.item()
		delta_fine_iter.append(loss_after_fine_value - loss_h_value)

		# --- Stage 2: coarse model on a random subsample ----------------
		# At least one sample, otherwise the mean loss would be NaN.
		n_coarse = max(1, int(X_h.shape[0] * sample_size))
		idx = np.random.choice(X_h.shape[0], n_coarse, replace=False)
		X_H = X_h[idx]
		y_H = y_h[idx]
		theta_H = theta.clone().detach().requires_grad_(True)
		loss_H = loss_func(y_H, X_H, theta_H)
		loss_H.backward()
		grad_H = theta_H.grad.detach().clone()
		# First-order coherence: grad(coarse model) = grad_H + V_H = grad_h.
		V_H = grad_h - grad_H

		def coarse_model(y_, x_, w_):
			return loss_func(y_, x_, w_) + V_H.T @ w_

		Losses_H.append([])

		for _ in range(p):
			theta_H.grad = None
			loss_H = coarse_model(y_H, X_H, theta_H)

			Losses_H[-1].append(loss_H.item())

			loss_H.backward()
			s_H = -theta_H.grad
			alpha = line_search(coarse_model, X_H, theta_H, y_H, loss_H, s_H)
			with torch.no_grad():
				theta_H += alpha * s_H

		# --- Stage 3: apply the coarse correction on the fine level -----
		with torch.no_grad():
			s_h = theta_H - theta
		# NOTE(review): line_search's default decrease test uses -(s.T @ s)
		# as the slope, i.e. it assumes s is the negative gradient. s_h is
		# not, so here the test acts as a heuristic acceptance rule.
		alpha = line_search(loss_func, X_h, theta, y_h, loss_after_fine, s_h)
		with torch.no_grad():
			theta += alpha * s_h
			accuracy = accuracy_func(y_h, X_h, theta).item()
			loss_after_coarse = loss_func(y_h, X_h, theta).item()
		accuracies.append(accuracy)
		delta_coarse_iter.append(loss_after_coarse - loss_after_fine_value)

		if verbose and epoch % 10 == 0:
			print(f'Epoch {epoch}/{n_epochs}, loss {loss_after_coarse}, accuracy {accuracy}')

	result = {
		'Losses_h': Losses_h,
		'Losses_H': Losses_H,
		'accuracies': accuracies,
		'delta_fine_iter': delta_fine_iter,
		'delta_coarse_iter': delta_coarse_iter,
		'time': time.time() - time_start
	}
	return theta, result

In [None]:
# Grid of coarse-iteration counts and subsample fractions to compare.
ps = [1, 5, 10]
sample_sizes = [0.1, 0.3, 0.5]
losses = [logistic_loss]

# Alternative grid (more coarse iterations on smaller subsamples):
# ps = [5, 10, 20]
# sample_sizes = [0.01, 0.05, 0.1]

N_EPOCHS = 100
SEED = 0

# two_level draws its initial point with torch and its subsamples with numpy;
# seeding both makes a fresh "Restart & Run All" reproduce the same numbers.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Keyed by "{p}_{sample_size}"; with more than one entry in `losses` later
# losses overwrite earlier ones under the same key.
results = {}

for loss in losses:
	X_train, y_train, X_test, y_test = get_data(-1, 1)
	for p in ps:
		for sample_size in sample_sizes:
			theta, result = two_level(X_train, y_train, loss, p, sample_size, N_EPOCHS)
			train_accuracy = accuracy_func(y_train, X_train, theta).item()
			test_accuracy = accuracy_func(y_test, X_test, theta).item()
			results[f"{p}_{sample_size}"] = result
			print(f"p={p}, sample_size={sample_size}, train_accuracy={train_accuracy}, test_accuracy={test_accuracy}")

In [None]:
# One panel per (p, sample_size) pair: fine-level loss on the left axis,
# accuracy on a twin right axis.
n_rows, n_cols = len(ps), len(sample_sizes)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
for row, p in enumerate(ps):
	for col, sample_size in enumerate(sample_sizes):
		run = results[f"{p}_{sample_size}"]

		loss_ax = axes[row, col]
		loss_ax.plot(run['Losses_h'], label="Loss_h", color='blue')
		loss_ax.set_title(f"p={p}, sample_size={sample_size}")
		loss_ax.set_xlabel("Epoch")
		loss_ax.set_ylabel("Loss")
		loss_ax.legend()

		# Accuracy lives on its own y-scale, fixed to [0, 1].
		acc_ax = loss_ax.twinx()
		acc_ax.plot(run['accuracies'], color='red', label="Accuracy")
		acc_ax.set_ylabel("Accuracy")
		acc_ax.set_ylim(0, 1)
		acc_ax.legend()

fig.tight_layout()
plt.show()