In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset, Subset

import matplotlib.pyplot as plt

import numpy as np

# Apply matplotlib's "ggplot" style sheet to the figures created after this point.
plt.style.use("ggplot")

# Data Preprocessing

In [None]:
from torchvision.datasets import MNIST
import torchvision.transforms as T

In [None]:
# Load the MNIST training split (downloading it if necessary) with images converted to tensors.
# NOTE(review): the data root is a hard-coded home-directory path; consider a configurable DATA_DIR.
data = MNIST("~/Developer/data", train=True, transform=T.ToTensor(), download=True)

In [None]:
# Flatten each 28x28 image into a 784-vector and standardise with the *global* mean and std
# (a single scalar each, not per-pixel statistics).
data_flat = data.data.view(-1, 28 * 28).float()
data_flat.sub_(data_flat.mean())
data_flat.div_(data_flat.std())

# Eigendecomposition of the symmetric matrix X^T X. torch.linalg.eigh returns eigenvalues in
# ascending order, so the leading principal directions are the *last* columns of eigvecs.
eigvals, eigvecs = torch.linalg.eigh(torch.matmul(data_flat.T, data_flat))

# Show the eigenvectors belonging to the five largest eigenvalues as 28x28 images.
n_shown = 5
fig, axs = plt.subplots(1, n_shown, figsize=(20, 4))
for rank, ax in enumerate(axs, start=1):
    ax.imshow(eigvecs[:, -rank].view(28, 28))
    ax.axis("off")

In [None]:
# Project the standardised images onto the n_dim eigenvectors with the largest eigenvalues.
n_dim = 10
data_compressed = data_flat @ eigvecs[:, -n_dim:]
full_mnist = TensorDataset(data_compressed, data.targets)

# Two small, disjoint subsets: samples 0-511 for training and 512-1023 for testing.
dataset_train = Subset(full_mnist, np.arange(512))
dataset_test = Subset(full_mnist, np.arange(512, 1024))

dataloader_train = DataLoader(dataset_train, batch_size=64, shuffle=True)
# The test loader must iterate over the held-out subset; building it from dataset_train would
# make every "test" loss a second measurement of the training loss.
dataloader_test = DataLoader(dataset_test, batch_size=64, shuffle=True)

# Stable-Baselines 3 Training

In [None]:
from typing import Callable, Optional

import torch.nn as nn

import gym
from gym.spaces import Box

import stable_baselines3 as sb3
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_checker import check_env

In [None]:
class NeuralNetworkTrainingEnv(gym.Env):
    """
    A meta-learning environment for training neural networks.

    The observation is the stack of every tensor in the model's state_dict, and an action is a
    replacement value for each of those tensors, so the action space is identical to the
    observation space. Each step loads the action into the model and rewards the agent with the
    negative per-sample loss over ``dataloader``. An episode lasts exactly ``max_cycles`` steps.

    Args:
        init_model: A function that returns a new model. All tensors in its state_dict must
            share one shape so they can be stacked into a single array.
        dataloader: A dataloader that returns a batch of data for evaluation.
        loss_fn: The loss function to use for training. Should return the _sum_ of the losses
            across the batch (``eval`` divides the accumulated total by the dataset size).
        max_cycles: Number of steps in one episode.
        action_bounds: Symmetric bound used for every entry of the observation/action Box.
    """

    metadata = {
        "name": "meta_learning",
        "render_modes": [],
    }

    def __init__(
        self,
        init_model: Callable[[], nn.Module],
        dataloader: DataLoader,
        loss_fn: nn.Module,
        max_cycles: int = 200,
        action_bounds=1e6
    ):
        self.init_model = init_model
        self.loss_fn = loss_fn
        self.dataloader = dataloader
        self.max_cycles = max_cycles

        # Instantiate a throw-away model only to discover the parameter names and shapes.
        state_dict = init_model().state_dict()
        self.weight_keys = list(state_dict.keys())

        # The action space is identical to the state space: one tensor per state_dict entry,
        # all of which must share a single shape so they can be stacked.
        self.action_shape = state_dict[self.weight_keys[0]].shape
        assert all(v.shape == self.action_shape for v in state_dict.values()), (
            "all tensors in the model's state_dict must have the same shape"
        )
        space = Box(-action_bounds, action_bounds, shape=(len(self.weight_keys),) + self.action_shape)
        self.observation_space = space
        self.action_space = space

        del state_dict

    def seed(self, seed):
        """Seed the global torch and numpy RNGs; returns the seed in a list."""
        torch.manual_seed(seed)
        np.random.seed(seed)
        return [seed]

    def to_state_dict(self, actions: np.ndarray):
        """Split a stacked action array into a ``{parameter name: tensor}`` mapping."""
        actions = np.asarray(actions).reshape(len(self.weight_keys), *self.action_shape)
        return {k: torch.from_numpy(action) for k, action in zip(self.weight_keys, actions)}

    def get_obs(self):
        """Return the current model weights stacked into one numpy array."""
        state_dict = self.model.state_dict()
        # Convert explicitly instead of relying on np.stack's implicit tensor conversion.
        return np.stack([state_dict[k].detach().cpu().numpy() for k in self.weight_keys])

    def reset(
        self,
        seed: Optional[int] = None,
        return_info: bool = False,
        options: Optional[dict] = None,
    ):
        """
        Start a new episode with a model obtained from ``init_model``.

        Args:
            seed: If given, seeds the RNGs via ``self.seed`` before the model is created.
            return_info: If True, return ``(observation, info)`` instead of the observation alone.
            options: Accepted for API compatibility; unused.
        """
        if seed is not None:
            self.seed(seed)

        self.model = self.init_model()
        self.model.eval()
        self.epoch = 0

        obs = self.get_obs()
        if return_info:
            return obs, {}
        return obs

    def step(self, actions: np.ndarray):
        """
        Overwrite the model weights with ``actions`` and score them.

        Returns:
            ``(observation, reward, done, info)`` where the reward is the negative per-sample
            loss and ``done`` becomes True on the ``max_cycles``-th step of the episode.
        """
        self.model.load_state_dict(self.to_state_dict(actions))
        test_loss = self.eval()
        # Count this step first, so the episode ends after exactly max_cycles steps
        # rather than max_cycles + 1.
        self.epoch += 1
        done = self.epoch >= self.max_cycles
        return (
            self.get_obs(),
            -test_loss,
            done,
            {},
        )

    @torch.no_grad()
    def eval(self):
        """
        Calculate the per-sample loss across the given dataloader.

        Batch losses are accumulated and divided by the dataset size, which is why ``loss_fn``
        is expected to return a sum rather than a mean.
        """
        test_loss = 0
        for x, y in self.dataloader:
            out = self.model(x)
            loss = self.loss_fn(out, y)
            test_loss += loss.item()
        return test_loss / len(self.dataloader.dataset)

In [None]:
def layer():
    """Return a fresh bias-free 10 -> 10 linear layer (a single 10x10 weight matrix)."""
    return nn.Linear(10, 10, bias=False)


def deep_model():
    """Return a fresh two-layer network: linear -> in-place ReLU -> linear."""
    return nn.Sequential(
        layer(),
        nn.ReLU(True),
        layer(),
    )


# The environment (and `train`) accumulate batch losses and divide by the dataset size themselves,
# so the loss must be *summed* over the batch; the default mean reduction would shrink every
# reported loss by a factor of the batch size.
loss_fn = nn.CrossEntropyLoss(reduction="sum")
env = NeuralNetworkTrainingEnv(layer, dataloader_train, loss_fn)
env.action_space.shape, env.observation_space.shape

In [None]:
# Run Stable-Baselines3's environment checker on the MNIST meta-learning environment.
check_env(env)

In [None]:
def train(model, loss_fn, loader, get_acc=False, n_epochs=100):
    """
    Baseline gradient-descent training with Adam (lr=1e-3).

    Args:
        model: The network to optimise; it is put into train mode and updated in place.
        loss_fn: Loss returning the *sum* over the batch; it is divided by the batch size
            before back-propagation and by the dataset size for reporting.
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        get_acc: If True, also track classification accuracy (argmax over the last dim).
        n_epochs: Number of passes over ``loader``.

    Returns:
        A list with the per-sample loss of every epoch, or ``(losses, accuracies)`` when
        ``get_acc`` is True.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loss_history, acc_history = [], []
    for _epoch in range(n_epochs):
        running_loss = 0
        n_correct = 0
        for inputs, targets in loader:
            optimizer.zero_grad()
            preds = model(inputs)
            batch_loss = loss_fn(preds, targets)
            running_loss += batch_loss.item()
            # Back-propagate the mean loss of this batch.
            (batch_loss / len(targets)).backward()
            optimizer.step()

            if get_acc:
                n_correct += (preds.argmax(dim=-1) == targets).sum().item()

        dataset_size = len(loader.dataset)
        loss_history.append(running_loss / dataset_size)
        if get_acc:
            acc_history.append(n_correct / dataset_size)

    return (loss_history, acc_history) if get_acc else loss_history

In [None]:
def train_meta(model, policy, loss_fn, loader, n_epochs=100):
    """
    Let a trained policy "optimise" ``model`` by repeatedly proposing new weights.

    A NeuralNetworkTrainingEnv is wrapped around the given model and the policy is rolled out
    deterministically for ``n_epochs`` steps.

    Returns:
        ``(losses, model)``: the loss after every step (the negated reward) and the
        environment's model after the last step.
    """
    meta_env = NeuralNetworkTrainingEnv(lambda: model, loader, loss_fn, max_cycles=n_epochs)
    observation = meta_env.reset()

    loss_history = []
    for _step in range(n_epochs):
        action, _state = policy.predict(observation, deterministic=True)
        observation, reward, _done, _info = meta_env.step(action)
        # The environment's reward is the negative loss.
        loss_history.append(-reward)

    return loss_history, meta_env.model

In [None]:
# PPO agent with an MLP policy acting on the MNIST meta-learning env (discount factor gamma=0.7).
policy = sb3.PPO("MlpPolicy", env, verbose=1, gamma=0.7)

In [None]:
# Meta-train the PPO policy on the environment.
# NOTE(review): total_timesteps (1024) is below PPO's default rollout length (n_steps=2048), so
# presumably one full 2048-step rollout is collected -- which would match the "2048 steps"
# quoted in the plot titles further down; confirm against the SB3 version in use.
# NOTE(review): `policy` is constructed without tensorboard_log, so tb_log_name presumably has
# no effect here -- verify.
policy.learn(total_timesteps=1024, progress_bar=True, tb_log_name="ppo_mnist")

In [None]:
# Persist the trained policy; the path is relative to the notebook's working directory.
policy.save("saved_models/ppo_mnist")

In [None]:
# Roll out the trained policy on a freshly initialised layer, once scored on the training loader
# and once on the test loader. Each call returns the per-step losses and the resulting model.
# Gradient-descent baseline (disabled):
# losses_train = train(layer(), loss_fn, dataloader_train)
losses_meta_train, train_model = train_meta(layer(), policy, loss_fn, dataloader_train)
losses_meta_test, test_model = train_meta(layer(), policy, loss_fn, dataloader_test)

In [None]:
# Sanity check: recompute the per-sample cross-entropy of the meta-trained model on the whole
# training subset. dataset_train holds 512 samples, hence the division by 512.
X, y = dataset_train[:]
nn.CrossEntropyLoss(reduction="sum")(train_model(X), y) / 512

In [None]:
# Reconstruct one training image from its n_dim PCA coefficients and display it.
# X holds one coefficient vector per sample (512 x n_dim), so a single row has to be selected:
# multiplying the (784 x n_dim) basis by the whole X matrix is a shape error.
sample_idx = 0
reconstruction = eigvecs[:, -n_dim:] @ X[sample_idx]
plt.imshow(reconstruction.reshape(28, 28), cmap="gray")
plt.axis("off")
plt.show()

In [None]:
# Compare the loss trajectories of the meta-learned optimiser on the train and test loaders.
# plt.scatter(np.arange(100), losses, label="Gradient Descent Learning")
epochs = np.arange(100)
for series, label in (
    (losses_meta_train, "Meta-RL Learning (Train)"),
    (losses_meta_test, "Meta-RL Learning (Test)"),
):
    plt.scatter(epochs, series, label=label)

plt.title("Training losses using Meta-RL-Optimization on MNIST \n after 2048 steps of meta training")
plt.xlabel("Epoch")
plt.ylabel("Cross-entropy loss with target labels")
plt.legend()
# Save before show(), while the figure is still the current one.
plt.savefig("losses-mnist.png")
plt.show()

In [None]:
# Evaluate the MNIST policy deterministically for a single episode on `env`.
# NOTE(review): by default evaluate_policy presumably returns (mean_reward, std_reward); with one
# episode the std carries no information.
evaluate_policy(policy, env, n_eval_episodes=1, render=False, deterministic=True)

# Student-Teacher Loss

In [None]:
@torch.no_grad()
def _make_teacher(n_features=10, n_hidden=100):
    """Build a random teacher: linear -> ReLU -> linear -> ReLU with N(0, 1) parameters."""
    teacher = nn.Sequential(
        nn.Linear(n_features, n_hidden),
        nn.ReLU(True),
        nn.Linear(n_hidden, n_features),
        nn.ReLU(True),
    ).eval()
    for p in teacher.parameters():
        p.normal_(0, 1)
    return teacher


@torch.no_grad()
def gen_dataset(teacher=None, n_samples=512, n_features=10, n_hidden=100):
    """
    Generate a teacher-student regression dataset.

    Inputs are drawn from a standard normal distribution and the targets are the outputs of
    ``teacher`` on those inputs.

    Args:
        teacher: Network that labels the inputs. Pass the *same* teacher for the train and test
            sets so both come from one target function. If None, a new random teacher is
            created for this call.
        n_samples: Number of samples to draw.
        n_features: Input dimension (and output dimension of the default teacher). Must match
            the input dimension of ``teacher`` when one is supplied.
        n_hidden: Hidden width of the default teacher; ignored when ``teacher`` is supplied.

    Returns:
        A TensorDataset of ``(X, y)`` with ``y = teacher(X)``.
    """
    X = torch.randn(n_samples, n_features)
    if teacher is None:
        teacher = _make_teacher(n_features, n_hidden)
    y = teacher(X)

    return TensorDataset(X, y)


# One shared teacher labels both sets; with a separate random teacher per call the "test" set
# would measure a different target function than the one trained on.
teacher_net = _make_teacher()
dataloader_teacher = DataLoader(gen_dataset(teacher_net), batch_size=64, shuffle=True)
testloader_teacher = DataLoader(gen_dataset(teacher_net), batch_size=64, shuffle=True)

In [None]:
# Sum-reduced MSE, matching the environment's expectation of a summed batch loss.
mse_loss = nn.MSELoss(reduction="sum")
# NOTE: this rebinds `env` -- from here on it refers to the teacher-student environment built on
# the two-layer `deep_model`, no longer the MNIST one.
env = NeuralNetworkTrainingEnv(deep_model, dataloader_teacher, mse_loss)

In [None]:
# Run Stable-Baselines3's environment checker on the teacher-student environment.
check_env(env)

In [None]:
# Separate PPO agent (MLP policy, discount factor gamma=0.7) for the teacher-student environment.
ppo_teacher = sb3.PPO("MlpPolicy", env, verbose=1, gamma=0.7)

In [None]:
# Meta-train the teacher-student PPO agent.
# NOTE(review): total_timesteps (1024) is below PPO's default rollout length (n_steps=2048), so
# presumably one full 2048-step rollout is collected; confirm against the SB3 version in use.
# NOTE(review): `ppo_teacher` is constructed without tensorboard_log, so tb_log_name presumably
# has no effect here -- verify.
ppo_teacher.learn(total_timesteps=1024, progress_bar=True, tb_log_name="ppo_teacher")

In [None]:
# Roll out the agent that was trained on the teacher-student environment. That environment is
# built on `deep_model`, so the rollout must use `ppo_teacher` with a `deep_model()` student;
# the MNIST `policy` with a single `layer()` would evaluate a different agent on a different
# architecture.
# Gradient-descent baseline (disabled):
# losses = train(deep_model(), mse_loss, dataloader_teacher)
losses_meta_train, train_model = train_meta(deep_model(), ppo_teacher, mse_loss, dataloader_teacher)
losses_meta_test, test_model = train_meta(deep_model(), ppo_teacher, mse_loss, testloader_teacher)

In [None]:
# Log-scale comparison of the meta-learned optimiser's loss on the teacher-student train and test sets.
# plt.scatter(np.arange(100), losses, label="Gradient Descent Learning")
plt.figure(figsize=(6, 6))
epochs = np.arange(100)
for series, label in (
    (losses_meta_train, "Loss on training dataset"),
    (losses_meta_test, "Loss on test dataset"),
):
    # The natural log of each loss is plotted, not the raw value.
    plt.scatter(epochs, np.log(series), label=label)

plt.title("Mean squared error using Meta-RL on teacher-student model\nafter 2048 timesteps of meta learning")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
# Save before show(), while the figure is still the current one.
plt.savefig("losses-teacher-student.png")
plt.show()