## [WIP] Epsilon-Beta Visualizations
This notebook aims to visualize $\hat{\lambda}_n^\beta$ for various values of $\beta$ (inverse temperature) and $\epsilon$ (step size). 

Adrian's Thesis roughly states that:
- $\beta$ can be tuned via graphing $\hat{\lambda}_n^\beta$ for a sweep of $\beta$, and using $\beta$ in a range around the critical points on the graph.
- $\epsilon$ should be the greatest possible value that doesn't cause excessive numerical instability or cause the SGLD chains to fail to converge. A MALA proposal acceptance rate (see `sgld_calibration.ipynb`) between 0.9 and 0.95 is roughly optimal.

## Set-up

In [None]:
%pip install devinterp transformers torchvision

The epsilon-beta sweep analyzer is fairly flexible. To sweep, all you need is:
- A callable function (typically a built-in DevInterp function) that returns local learning coefficient traces.
- Epsilon and beta ranges.

To start, we'll visualize an epsilon-beta LLC sweep for a pretrained MNIST classifier.

## Sweep LLC given a model

In [1]:
import torch
import torchvision
from transformers import AutoModelForImageClassification

from devinterp.slt import estimate_learning_coeff_with_summary
from devinterp.optim import SGLD

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Loss adapter for models that return an output object instead of a tensor
def transformers_cross_entropy(inputs, outputs):
    """Cross-entropy between ``inputs.logits`` and the targets ``outputs``.

    ``inputs`` is expected to expose a ``.logits`` attribute (a transformers
    model output); ``outputs`` is forwarded unchanged as the target argument of
    ``torch.nn.functional.cross_entropy``.
    """
    logits = inputs.logits  # transformers doesn't output a vector
    return torch.nn.functional.cross_entropy(logits, outputs)

# Pretrained MNIST classifier, fetched by name through transformers
model = AutoModelForImageClassification.from_pretrained("fxmarty/resnet-tiny-mnist")

# MNIST dataset stored under ../data (downloaded when missing); images become tensors
data = torchvision.datasets.MNIST(
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
    root="../data",
    download=True,
)

# Shuffled mini-batches of 256 samples
loader = torch.utils.data.DataLoader(data, shuffle=True, batch_size=256)

In [2]:
from devinterp.slt.mala import MalaAcceptanceRate
from tqdm import tqdm, trange
import typing
from typing import Type, Union, List, Any, Optional, Callable
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from collections.abc import Sequence
from devinterp.utils import optimal_temperature
import warnings

from pydantic import BaseModel, Field

def _positive_float_list(values, label):
    """Return *values* as a list of Python floats after validating it.

    ``label`` is ``"epsilon"`` or ``"beta"`` and is only used in error messages.

    Raises:
        TypeError: if *values* is not a list-like object (list, tuple, numpy array, ...).
        ValueError: if *values* is empty or contains a value that is not > 0.
    """
    is_list_like = isinstance(values, (Sequence, np.ndarray)) and not isinstance(values, (str, bytes))
    if not is_list_like:
        raise TypeError(f"{label}_range must be a list-like object (e.g list or numpy array)")
    floats = [float(v) for v in values]
    if not floats:
        raise ValueError(f"{label}_range must contain at least one value")
    # Written as `not all(v > 0)` so NaN (e.g. from log10 of a negative bound) is rejected too.
    if not all(v > 0 for v in floats):
        raise ValueError(f"All {label} values must be greater than 0")
    return floats


# Sampling config validates input parameters while allowing us to use **kwargs later on
class SweepConfig(BaseModel):
    """Validated settings for one epsilon-beta sweep.

    Fields:
        epsilon_range: step sizes to sweep over.
        beta_range: inverse temperatures to sweep over.
        llc_estimator: callable invoked once per (epsilon, beta) pair.
        llc_estimator_kwargs: extra keyword arguments forwarded to ``llc_estimator``.
    """

    epsilon_range: List[float]
    beta_range: List[float]
    llc_estimator: Callable
    llc_estimator_kwargs: dict

    # Pydantic-recognized field for custom settings
    class Config:
        arbitrary_types_allowed = True  # Allows Pydantic to accept pytorch models

    # Build epsilon_range and beta_range given different user input formats for beta and epsilon ranges
    @classmethod
    def setup(cls, llc_estimator, llc_estimator_kwargs,
              min_beta, max_beta, beta_samples, beta_range, min_epsilon, max_epsilon, epsilon_samples, epsilon_range,
              dataloader = None):
        """Build a ``SweepConfig`` from explicit ranges or from (min, max, samples) triples.

        For each of epsilon and beta: when an explicit range is given it is used as-is
        (and any min/max is ignored with a warning); otherwise a log-spaced range with
        ``*_samples`` points between min and max is generated.

        Missing epsilon bounds default to 1e-6 and 1e-2. Missing beta bounds are derived
        from ``optimal_temperature(dataloader)`` (1e-2x and 1e3x that value) when a
        dataloader is supplied.

        Raises:
            TypeError: if a supplied range is not list-like.
            ValueError: if a range is empty or not strictly positive, or if beta bounds
                are missing and no dataloader is given.
        """
        if epsilon_range is not None:
            epsilon_values = _positive_float_list(epsilon_range, "epsilon")
            if min_epsilon is not None or max_epsilon is not None:
                warnings.warn("min_epsilon and max_epsilon will be ignored as epsilon_range is provided")
        else:
            if min_epsilon is None:
                min_epsilon = 1e-6
            if max_epsilon is None:
                max_epsilon = 1e-2
            exponents = np.linspace(np.log10(min_epsilon), np.log10(max_epsilon), int(epsilon_samples))
            epsilon_values = _positive_float_list(np.power(10, exponents), "epsilon")

        if beta_range is not None:
            beta_values = _positive_float_list(beta_range, "beta")
            if min_beta is not None or max_beta is not None:
                warnings.warn("min_beta and max_beta will be ignored as beta_range is provided")
        else:
            if dataloader is not None:
                # Calculate default beta (inverse temperature) range.
                optimal_beta = optimal_temperature(dataloader)
                if min_beta is None:
                    min_beta = 1e-2 * optimal_beta
                if max_beta is None:
                    max_beta = 1e3 * optimal_beta
            elif min_beta is None or max_beta is None:
                raise ValueError("min_beta and max_beta must be provided if dataloader is not provided.")
            exponents = np.linspace(np.log10(min_beta), np.log10(max_beta), int(beta_samples))
            beta_values = _positive_float_list(np.power(10, exponents), "beta")

        if max(epsilon_values) > 1e-2:
            warnings.warn("Epsilon values greater than 1e-2 typically lead to instability in the sampling process. Consider reducing epsilon to between 1e-6 and 1e-2.")

        # Plain Python lists keep pydantic's List[float] validation independent of numpy support.
        return cls(epsilon_range=epsilon_values,
                   beta_range=beta_values,
                   llc_estimator=llc_estimator,
                   llc_estimator_kwargs=llc_estimator_kwargs)

class EpsilonBetaAnalyzer:
    """Sweep an LLC estimator over a grid of (epsilon, beta) values and plot the result.

    Typical use: ``configure_sweep()`` -> ``sweep()`` -> ``plot()``.

    Attributes:
        sweep_config: the ``SweepConfig`` built by ``configure_sweep()``.
        plotting_config: initialised to ``None``; not written by any method here.
        sweep_df: one row per successful (epsilon, beta) point, holding whatever the
            estimator returned plus ``epsilon`` and ``beta`` columns.
        fig: the figure returned by the most recent ``plot()`` call.
    """

    def __init__(self):
        self.sweep_config = None
        self.plotting_config = None
        self.sweep_df = None
        self.fig = None

    def configure_sweep(self,
                        llc_estimator: Callable,
                        llc_estimator_kwargs: dict,
                        min_epsilon: Optional[float] = None,
                        max_epsilon: Optional[float] = None,
                        epsilon_samples: int = 8,
                        epsilon_range: Optional[List[float]] = None,
                        min_beta: Optional[float] = None,
                        max_beta: Optional[float] = None,
                        beta_samples: int = 8,
                        beta_range: Optional[List[float]] = None,
                        dataloader: Optional[torch.utils.data.DataLoader] = None) -> None:
        """
        Configure the sampling parameters for the LLC analysis.

        Either pass an explicit ``epsilon_range`` / ``beta_range`` or let a log-spaced
        range be generated from ``min_*``, ``max_*`` and ``*_samples``. When no epsilon
        range is given, missing epsilon bounds default to 1e-6 and 1e-2. When beta
        bounds are missing, ``dataloader`` is forwarded so that ``SweepConfig.setup``
        can derive them.
        """
        # Only fill in the epsilon defaults when they are actually used, so that passing
        # an explicit epsilon_range does not trigger an "ignored" warning.
        if epsilon_range is None:
            if min_epsilon is None:
                min_epsilon = 1e-6
            if max_epsilon is None:
                max_epsilon = 1e-2

        self.sweep_config = SweepConfig.setup(llc_estimator, llc_estimator_kwargs,
                                              min_beta, max_beta, beta_samples, beta_range,
                                              min_epsilon, max_epsilon, epsilon_samples, epsilon_range,
                                              dataloader=dataloader)

    def sweep(self, add_to_existing = False) -> None:
        """
        Sweeps the local learning coefficient using the given llc_estimator function and its associated arguments.
        Results are stored in self.sweep_df.

        Points whose estimator call raises are skipped with a warning. With
        ``add_to_existing=True`` the new rows are appended to any existing results.

        Raises:
            RuntimeError: if ``configure_sweep()`` was not called, or if every point failed.
        """
        if self.sweep_config is None:
            raise RuntimeError("Sweep configuration is not set. Please call configure_sweep() first.")

        epsilon_range = self.sweep_config.epsilon_range
        beta_range = self.sweep_config.beta_range
        llc_estimator = self.sweep_config.llc_estimator
        llc_estimator_kwargs = self.sweep_config.llc_estimator_kwargs

        device = llc_estimator_kwargs.get("device")
        if device is not None and torch.cuda.is_available() and torch.device(device).type == "cpu":
            warnings.warn("CUDA is available but not being used. Consider setting device='cuda' for faster computation.")

        all_sweep_stats = []
        last_error = None
        with tqdm(total=len(epsilon_range) * len(beta_range)) as pbar:
            for epsilon in epsilon_range:
                for beta in beta_range:
                    try:
                        sweep_stats = llc_estimator(
                            epsilon = epsilon,
                            beta = beta,
                            **llc_estimator_kwargs
                        )
                        all_sweep_stats.append(dict(sweep_stats, epsilon=epsilon, beta=beta))
                    except Exception as e:
                        # Deliberately best-effort: unstable (epsilon, beta) pairs are expected to fail.
                        last_error = e
                        warnings.warn(f"Error encountered for epsilon={epsilon}, beta={beta}. Skipping. Warning: {e}")
                    pbar.update(1)

        if not all_sweep_stats:
            raise RuntimeError(f"Every sweep point failed, so there are no results to store. Last error: {last_error!r}")

        sweep_df = pd.DataFrame(all_sweep_stats)
        # If there's only one chain, there'll only be one trace, so we need to add an extra dimension
        sweep_df["llc/trace"] = sweep_df["llc/trace"].apply(lambda x: x if len(x.shape) == 2 else x[np.newaxis, :])
        if add_to_existing and self.sweep_df is not None:
            self.sweep_df = pd.concat([self.sweep_df, sweep_df], ignore_index=True)
        else:
            self.sweep_df = sweep_df

    @staticmethod
    def _horizontal_plane(sweep_df, height, name):
        """Return a translucent surface at constant z=height spanning the swept epsilon/beta values."""
        epsilons = np.sort(sweep_df["epsilon"].unique())
        betas = np.sort(sweep_df["beta"].unique())
        # Surface expects z[i][j] to sit at (x[j], y[i]), i.e. shape (len(y), len(x)).
        level = float(height) * np.ones((len(betas), len(epsilons)))
        return go.Surface(
            x=epsilons,
            y=betas,
            z=level,
            opacity=0.4,
            surfacecolor=level,
            showscale=False,
            name=name,)

    def plot(self, true_lambda: Optional[float] = None,
        num_last_steps_to_average: int = 50,
        color: Optional[str] = None,
        slider: Optional[str] = None,
        slider_plane: bool = False,
        **kwargs) -> go.Figure:
        """
        Plots the local learning coefficient sweep.

        Args:
            true_lambda: known learning coefficient; adds error columns, and (without a
                slider) a reference plane at that height.
            num_last_steps_to_average: number of trailing draws of each trace used for
                the ``llc/final`` and ``llc/std_over_mean`` statistics.
            color: column used for colouring. Defaults to ``"log_lambda_delta"`` when
                ``true_lambda`` is given and ``"llc/std_over_mean"`` otherwise.
            slider: column name; when set, one slider step is built per unique value.
            slider_plane: with ``slider``, also draw a plane at height ``slider_val``.
            **kwargs: merged into the settings passed to ``plotly.express.scatter_3d``.

        Returns:
            The figure, also stored in ``self.fig``.

        Raises:
            RuntimeError: if there is no sweep data yet.
        """
        if self.sweep_df is None:
            raise RuntimeError("No data to plot. Please call sweep() first.")

        plot_config = {
                "title": "Local learning coefficient vs. epsilon and beta",
                "z": "llc/final",
                "log_y": True,
                "log_x": True,
                "log_z": True
                }

        sweep_df = self.sweep_df.copy()
        # Calculate additional statistics over the trailing window of every trace
        tail = slice(-num_last_steps_to_average, None)
        sweep_df["llc/std_over_mean"] = sweep_df["llc/trace"].apply(lambda x: x[:, tail].std() / x[:, tail].mean())
        sweep_df["llc/final"] = sweep_df["llc/trace"].apply(lambda x: x[:, tail].mean()).astype(float)

        plane_label = f"RLCT={true_lambda}"
        if true_lambda is not None:
            true_lambda = float(true_lambda)  # plain float keeps the derived columns numeric
            sweep_df["true_lambda"] = true_lambda
            sweep_df["log_true_lambda"] = np.log10(np.abs(true_lambda))
            sweep_df["log_lambda_hat"] = np.log10(np.abs(sweep_df["llc/final"]))
            sweep_df["log_delta_lambda"] = sweep_df["log_lambda_hat"] - sweep_df["log_true_lambda"]
            sweep_df["lambda_delta"] = sweep_df["llc/final"] - sweep_df["true_lambda"]
            sweep_df["log_lambda_delta"] = np.log10(np.abs(sweep_df["lambda_delta"])) * np.sign(sweep_df["lambda_delta"])
            if color is None:
                color = "log_lambda_delta" # default color when true_lambda is provided
                plot_config["range_color"] = [-4, 4]

        if color is None:
            color = "llc/std_over_mean" # default color when true_lambda is not provided

        if color == "llc/std_over_mean":
            plot_config["range_color"] = [0, 0.15]

        # Add any additional kwargs to the plot_config
        plot_config.update(kwargs)

        if slider is None:
            fig = px.scatter_3d(sweep_df, x="epsilon", y="beta", color=color, **plot_config)

            if true_lambda is not None:
                # Place a horizontal plane at height true_lambda
                fig.add_trace(self._horizontal_plane(sweep_df, true_lambda, plane_label))

        else:
            # Create base figure.
            fig = None

            # Determine fixed ranges for axes and colorbar
            x_range = [sweep_df['epsilon'].min() / 2, sweep_df['epsilon'].max() * 1.5] # Add some margin on both sides
            y_range = [sweep_df['beta'].min(), sweep_df['beta'].max()]
            z_range = [max(1e-2, sweep_df['llc/final'].min()), sweep_df['llc/final'].max()]
            color_range = [sweep_df[color].min(), sweep_df[color].max()]
            plot_config["range_color"] = color_range
            plot_config["range_z"] = z_range
            plot_config["range_x"] = x_range
            plot_config["range_y"] = y_range

            unique_slider_vals = sweep_df[slider].unique()
            # Add traces for each unique slider value
            for slider_val in unique_slider_vals:
                df_filtered = sweep_df[sweep_df[slider] == slider_val]
                plot_ = px.scatter_3d(df_filtered, x="epsilon", y="beta", color=color, **plot_config)
                if fig is None:
                    fig = plot_
                else:
                    fig.add_trace(plot_.data[0])

                if slider_plane:
                    # Place a horizontal plane with height = slider_val
                    fig.add_trace(self._horizontal_plane(sweep_df, slider_val, f"{slider}={slider_val}"))

            # Slider: each step shows only the trace(s) belonging to one slider value
            traces_per_value = 2 if slider_plane else 1
            steps = []
            for i, slider_val in enumerate(unique_slider_vals):
                visible = [False] * len(fig.data)
                for offset in range(traces_per_value):
                    visible[traces_per_value * i + offset] = True
                steps.append(dict(
                    method="update",
                    args=[{"visible": visible},
                        {"title": f"Local learning coefficient vs. epsilon and beta ({slider} = {slider_val})"}],
                    label=str(slider_val)
                ))

            sliders = [dict(
                active=0,
                currentvalue={"prefix": f"{slider}: "},
                pad={"t": 50},
                steps=steps
            )]

            # Axes and layout
            fig.update_layout(
                scene=dict(
                    xaxis_title='epsilon',
                    yaxis_title='beta',
                    zaxis_title='lambdahat',
                    xaxis_type="log",
                    yaxis_type="log",
                    zaxis_type="log",
                    xaxis_range=np.log10(x_range),
                    yaxis_range=np.log10(y_range),
                    zaxis_range=np.log10(z_range),
                    aspectmode='manual',
                    aspectratio=dict(x=0.7, y=1.2, z=1),
                ),
                sliders=sliders,
                title="Local learning coefficient vs. epsilon and beta"
            )

        self.fig = fig
        return fig

### Define the LLC estimator function.
Note: The local learning coefficient estimator function expected by EpsilonBetaAnalyzer must have the following signature:
```python
def estimator(epsilon: float, beta: float, **kwargs) -> dict
```
- Where kwargs correspond to the `llc_estimator_kwargs` that are passed in when `EpsilonBetaAnalyzer.configure_sweep()` is called.
- The return value must be a dict with a `"llc/trace"` key corresponding to a numpy array of shape `(num_chains, num_draws)`
- Additional keys can represent other values of interest (e.g. acceptance rates, true LLC.)

See below for an example that uses DevInterp's `estimate_learning_coeff_with_summary` to estimate the local learning coefficient for an arbitrary PyTorch model.

In [3]:
# Adapter around estimate_learning_coeff_with_summary.
# EpsilonBetaAnalyzer calls its estimator as
#     estimator(epsilon: float, beta: float, **kwargs) -> dict
# where kwargs are the remaining arguments of this function.
# The returned dict must contain "llc/trace": a numpy array of shape (num_chains, num_draws).
# Any further keys (e.g. acceptance rates, true LLC) are carried along as extra values.

def estimate_llc_given_model(model: torch.nn.Module, 
                            loader: torch.utils.data.DataLoader, 
                            criterion: typing.Callable, 
                            epsilon: float,
                            beta: float,
                            sampling_method: Type[torch.optim.Optimizer] = SGLD, 
                            localization: float = 100.0, 
                            num_chains: int = 5, 
                            num_draws: int = 300, 
                            num_burnin_steps: int = 0, 
                            num_steps_bw_draws: int = 1, 
                            device: torch.device = torch.device("cpu"), 
                            online: bool = True, 
                            verbose: bool = False):
    """Run ``estimate_learning_coeff_with_summary`` for one (epsilon, beta) pair.

    ``epsilon`` is passed to the sampler as ``lr`` and ``beta`` as ``temperature``,
    together with ``localization``. All other arguments are forwarded unchanged.

    Returns:
        The summary dict from the estimator, with ``"llc/trace"`` converted to a
        numpy array.
    """
    sampler_settings = {
        "lr": epsilon,
        "localization": localization,
        "temperature": beta,
    }

    summary = estimate_learning_coeff_with_summary(
        model,
        loader=loader,
        criterion=criterion,
        sampling_method=sampling_method,
        optimizer_kwargs=sampler_settings,
        num_chains=num_chains,                  # independent chains to run
        num_draws=num_draws,                    # samples drawn per chain
        num_burnin_steps=num_burnin_steps,      # samples discarded at the start of each chain
        num_steps_bw_draws=num_steps_bw_draws,  # steps taken between consecutive samples
        device=device,
        online=online,
        verbose=verbose,
    )

    summary["llc/trace"] = np.array(summary["llc/trace"])
    return summary

### Running the analyzer:
Methods:
- configure_sweep: Sets up the config for the following sweep.
- sweep: Runs a sweep over beta (inverse temperature) and epsilon (step size), using the provided llc_estimator function to calculate the LLC traces.
- plot: Plots $\hat{\lambda}$ over epsilon and beta with various options.

In [5]:
# Sweep the MNIST classifier over an 8 x 8 grid of (epsilon, beta) values.
analyzer = EpsilonBetaAnalyzer()
analyzer.configure_sweep(
    llc_estimator=estimate_llc_given_model,
    llc_estimator_kwargs={
        "model": model,
        "loader": loader,
        "criterion": transformers_cross_entropy,
        "device": DEVICE,
    },
    min_epsilon=1e-6,
    max_epsilon=1e-2,
    epsilon_samples=8,
    min_beta=None,
    max_beta=None,
    beta_samples=8,
    dataloader=loader,  # Automatically find a beta range from the optimal beta
)
analyzer.sweep()

TypeError: EpsilonBetaAnalyzer.configure_sweep() got an unexpected keyword argument 'dataloader'

In [22]:
# Default view: llc/final over epsilon and beta, coloured by llc/std_over_mean.
analyzer.plot()

In [82]:
# Plotting with options:
# All plotly express options are supported.
# Extra keyword arguments (here `template`) are merged into the settings that plot()
# passes to plotly express; `true_lambda` adds a reference plane at that height.
analyzer.plot(template="plotly_dark", true_lambda=400, num_last_steps_to_average=50)

In [129]:
# Want to plot slices over a particular subset of the data?
# `slider` names a sweep column; plot() builds one slider step per unique value of it.
analyzer.plot(template="plotly_dark", num_last_steps_to_average=50, slider="epsilon")

### Visualizing an Epsilon-Beta Sweep for a Deep Linear Network (DLN)

Credit for DLN code goes to Edmundlth. [Source notebook](https://colab.research.google.com/github/edmundlth/validating_lambdahat/blob/dev/DLN_lambdahat.ipynb)

### DLN Setup:

In [6]:
import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as jtree

import numpy as np
import optax
from typing import Sequence, NamedTuple
import json

import matplotlib.pyplot as plt
import plotly.graph_objects as go

import itertools

In [7]:
# Define the DLN model
class DeepLinearNetwork(hk.Module):
    """Chain of ``hk.Linear`` layers applied back to back, with nothing in between."""

    def __init__(self, layer_widths: Sequence[int], name: str = None, with_bias=False):
        """Store the output width of every layer and whether the layers carry a bias."""
        super().__init__(name=name)
        self.with_bias = with_bias
        self.layer_widths = layer_widths

    def __call__(self, x):
        """Feed ``x`` through one linear layer per entry of ``layer_widths``, in order."""
        out = x
        for n_units in self.layer_widths:
            layer = hk.Linear(n_units, with_bias=self.with_bias)
            out = layer(out)
        return out

# Forward pass: build a DLN with the given widths and apply it to x
def forward_fn(x, layer_widths):
    """Instantiate a ``DeepLinearNetwork`` for ``layer_widths`` and return its output on ``x``."""
    return DeepLinearNetwork(layer_widths)(x)

# Create a Haiku-transformed version of the model
def create_model(layer_widths):
    """Return ``forward_fn`` bound to ``layer_widths``, wrapped by ``hk.transform`` and ``hk.without_apply_rng``."""
    def _forward(x):
        return forward_fn(x, layer_widths)

    transformed = hk.transform(_forward)
    return hk.without_apply_rng(transformed)


def generate_training_data(true_param, model, input_dim, num_samples):
    """Sample uniform inputs in [-10, 10) and label them with the true model.

    Returns:
        ``(inputs, true_outputs)`` where ``inputs`` has shape
        ``(num_samples, input_dim)`` and ``true_outputs`` is
        ``model.apply(true_param, inputs)``.
    """
    shape = (num_samples, input_dim)
    inputs = np.random.uniform(low=-10, high=10, size=shape)
    return inputs, model.apply(true_param, inputs)

def mse_loss(param, model, inputs, targets):
    """Mean squared difference between ``model.apply(param, inputs)`` and ``targets``."""
    residual = model.apply(param, inputs) - targets
    return jnp.mean(residual ** 2)


def create_minibatches(inputs, targets, batch_size, shuffle=True):
    """Yield ``(inputs, targets)`` mini-batches of exactly ``batch_size`` rows.

    Rows are visited in a random permutation when ``shuffle`` is true, otherwise
    in their original order. Trailing rows that do not fill a whole batch are
    not yielded.
    """
    assert len(inputs) == len(targets)
    n_examples = len(inputs)
    order = np.random.permutation(n_examples) if shuffle else np.arange(n_examples)

    last_start = n_examples - batch_size
    for lo in range(0, last_start + 1, batch_size):
        picked = order[lo:lo + batch_size]
        yield inputs[picked], targets[picked]

In [8]:
class SGLDConfig(NamedTuple):
    """Immutable bundle of the three SGLD settings used by the DLN estimator."""

    # Step size handed to the SGLD optimizer.
    epsilon: float
    # Coefficient of the squared-distance term in the local log-posterior.
    gamma: float
    # Number of SGLD steps to take.
    num_steps: int

def generate_rngkey_tree(key_or_seed, tree_or_treedef):
    """Return a tree shaped like ``tree_or_treedef`` whose leaves are successive keys from one ``hk.PRNGSequence``."""
    key_stream = hk.PRNGSequence(key_or_seed)

    def _next_key(_leaf):
        return next(key_stream)

    return jtree.tree_map(_next_key, tree_or_treedef)

def optim_sgld(epsilon, rngkey_or_seed):
    """Build an optax ``GradientTransformation`` that performs SGLD updates.

    Each leaf update is ``-epsilon * g / 2 + eta`` with ``eta`` drawn from a normal
    distribution scaled by ``sqrt(epsilon)``. The optimizer state is a single PRNG
    key that is split on every update.

    Args:
        epsilon: SGLD step size.
        rngkey_or_seed: a JAX PRNG key, or an integer seed that is turned into one.
    """
    @jax.jit
    def sgld_delta(g, rngkey):
        eta = jax.random.normal(rngkey, shape=g.shape) * jnp.sqrt(epsilon)
        return -epsilon * g / 2 + eta

    def init_fn(_):
        # jax.random.split needs a key, so promote a plain integer seed first.
        if isinstance(rngkey_or_seed, int):
            return jax.random.PRNGKey(rngkey_or_seed)
        return rngkey_or_seed

    @jax.jit
    def update_fn(grads, state):
        rngkey, new_rngkey = jax.random.split(state)
        # One independent key per gradient leaf.
        rngkey_tree = generate_rngkey_tree(rngkey, grads)
        # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
        updates = jax.tree_util.tree_map(sgld_delta, grads, rngkey_tree)
        return updates, new_rngkey
    return optax.GradientTransformation(init_fn, update_fn)


def create_local_logposterior(avgnegloglikelihood_fn, num_training_data, w_init, gamma, itemp):
    """Return ``logprob(w, x, y)``: tempered log-likelihood plus a quadratic pull towards ``w_init``.

    ``logprob`` evaluates
    ``itemp * (-num_training_data * avgnegloglikelihood_fn(w, x, y))
    - gamma / 2 * ||w - w_init||^2``,
    where the squared norm is summed over every leaf of the parameter tree.
    """
    def _sq_dist(leaf, anchor):
        return jnp.sum((leaf - anchor) ** 2)

    def _logprior_fn(w):
        per_leaf = jax.tree_util.tree_map(_sq_dist, w, w_init)
        return jax.tree_util.tree_reduce(lambda acc, term: acc + term, per_leaf)

    def logprob(w, x, y):
        logprior = -gamma / 2 * _logprior_fn(w)
        loglike = -num_training_data * avgnegloglikelihood_fn(w, x, y)
        return itemp * loglike + logprior

    return logprob


In [9]:
def true_dln_learning_coefficient(true_rank, layer_widths, input_dim, verbose=False):
    """Evaluate a closed-form learning coefficient from the rank and layer sizes.

    The widths ``[input_dim] + layer_widths`` are reduced by ``true_rank``, a subset of
    them is picked by ``brute_force_search_subset`` and the result is the sum of four
    terms computed from that subset. (Presumably the known theoretical value for deep
    linear networks -- confirm against the credited source notebook.)

    ``verbose`` prints the intermediate lists and is also forwarded as
    ``early_return`` to the subset search.
    """
    M_list = np.array([input_dim] + list(layer_widths)) - true_rank
    indices = brute_force_search_subset(M_list, early_return=verbose)
    M_subset = M_list[indices]
    if verbose:
        print(f"M_list: {M_list}, indices: {indices}, M_subset: {M_subset}")

    subset_total = np.sum(M_subset)
    ell = len(M_subset) - 1
    M = np.ceil(subset_total / ell)
    a = subset_total - (M - 1) * ell
    output_dim = layer_widths[-1]

    rank_term = (-true_rank**2 + true_rank * (output_dim + input_dim)) / 2
    remainder_term = a * (ell - a) / (4 * ell)
    mean_term = -ell * (ell - 1) / 4 * (subset_total / ell)**2
    # Products of every unordered pair of subset entries.
    pair_products = [
        M_subset[i] * M_subset[j]
        for i in range(ell + 1)
        for j in range(i + 1, ell + 1)
    ]
    pair_term = 1 / 2 * np.sum(pair_products)
    return rank_term + remainder_term + mean_term + pair_term

def _condition(indices, intlist, verbose=False):
    intlist = np.array(intlist)
    ell = len(indices) - 1
    subset = intlist[indices]
    complement = intlist[[i for i in range(len(intlist)) if i not in indices]]
    has_complement = len(complement) > 0
    # print(indices, subset, complement)
    if has_complement and not (np.max(subset) < np.min(complement)):
        if verbose: print(f"max(subset) = {np.max(subset)}, min(complement) = {np.min(complement)}")
        return False
    if not (np.sum(subset) >= ell * np.max(subset)):
        if verbose: print(f"sum(subset) = {sum(subset)}, ell * max(subset) = {ell * np.max(subset)}")
        return False
    if has_complement and not (np.sum(subset) < ell * np.min(complement)):
        if verbose: print(f"sum(subset) = {sum(subset)}, ell * min(complement) = {ell * np.min(complement)}")
        return False
    return True


def generate_indices_subsets(length):
    """Yield every non-empty subset of ``range(length)`` as a numpy index array, smallest subsets first."""
    positions = list(range(length))
    for size in range(1, length + 1):
        yield from (np.array(combo) for combo in itertools.combinations(positions, size))


def brute_force_search_subset(intlist, early_return=False):
    """Return the first index subset of ``intlist`` accepted by ``_condition``.

    With ``early_return`` the search stops at the first hit. Otherwise every subset is
    tested, a notice is printed when several qualify, and the first one is returned.

    Raises:
        RuntimeError: if no subset satisfies the condition.
    """
    candidates = []
    for indices in generate_indices_subsets(len(intlist)):
        if not _condition(indices, intlist):
            continue
        if early_return:
            return indices
        candidates.append(indices)

    if not candidates:
        raise RuntimeError("No candidates")
    if len(candidates) > 1:
        print("More than one candidate")
    return candidates[0]


def to_float_or_list(x):
    """Convert a number, list/tuple, or array-like into a float or a plain list.

    Raises:
        ValueError: if ``x`` is none of the supported kinds.
    """
    if isinstance(x, (float, int)):
        return float(x)
    if isinstance(x, (list, tuple)):
        return [float(item) for item in x]
    if hasattr(x, "tolist"):  # For JAX or numpy arrays
        return x.tolist()
    raise ValueError(f"Unsupported type {type(x)}")

def to_json_friendly_tree(tree):
    """Apply ``to_float_or_list`` to every leaf of ``tree`` via ``jtree.tree_map``."""
    convert_leaf = to_float_or_list
    return jtree.tree_map(convert_leaf, tree)


def reduce_matrix_rank(matrix, reduction):
    """
    Zero out the smallest singular values of a matrix.

    :param matrix: Input matrix.
    :param reduction: How many of the trailing singular values to set to zero.
    :return: The matrix rebuilt from its SVD with those singular values removed.
    """
    left, singular_values, right = np.linalg.svd(matrix, full_matrices=False)

    # Keep the leading singular values only (never fewer than zero of them)
    n_kept = max(len(singular_values) - reduction, 0)
    singular_values[n_kept:] = 0

    # Scale the columns of `left` by the singular values, then recombine
    scaled_left = left * singular_values
    return np.dot(scaled_left, right)

def rand_reduce_matrix_rank(matrix):
    """Reduce the rank of ``matrix`` by a random amount drawn from ``[0, max(1, rank))``."""
    current_rank = np.linalg.matrix_rank(matrix)
    upper = max(1, current_rank)
    drop = np.random.randint(0, upper)
    return reduce_matrix_rank(matrix, drop)


### Randomly generate a DLN, calculate its true lambda, and sweep over epsilon and beta

In [10]:
num_training_data = 10000  # Number of training samples

itemp = 1 / np.log(num_training_data)  # beta* = 1 / log(n)
# betas: 12 log-spaced values from beta* * 1e-4 to beta* * 1e4
beta_range = np.power(10, np.linspace(-4, 4, 12)) * itemp
# epsilons: 12 log-spaced values from 1e-6 to 1e-1
epsilon_range = np.power(10, np.linspace(-6, -1, 12))
networks_to_generate = 10
num_epochs = 100
learning_rate = 1e-2
optimizer = optax.sgd(learning_rate)
rngkey = jax.random.PRNGKey(42)
batch_size = 200

# Generate a random network
num_layer = np.random.randint(2, 8)
layer_widths = list(np.random.randint(5, 30, size=num_layer))
input_dim = np.random.randint(5, 20)

2024-07-10 12:54:33.989864: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [137]:
def generate_random_DLN(rngkey, input_dim, layer_widths):
    """Create a random deep linear network, its training data, and its true lambda.

    Reads the module-level ``num_training_data``, ``optimizer``, ``num_epochs`` and
    ``batch_size``, and publishes the generated training set as the module-level
    ``x_train`` / ``y_train`` that ``estimate_llc_given_dln`` reads.

    Returns:
        ``(model, true_param, true_lambda)``: the Haiku model, the (possibly
        rank-reduced) true parameters, and the learning coefficient computed from the
        rank of the product of the true weight matrices.
    """
    # estimate_llc_given_dln looks these up in module scope, so they must not stay local.
    global x_train, y_train

    # Random true parameters
    model = create_model(layer_widths)
    dummy_input = jnp.zeros((1, input_dim))
    rngkey, _ = jax.random.split(rngkey)
    # randomly reduce rank of random matrices
    true_param = jtree.tree_map(
        lambda x: rand_reduce_matrix_rank(x) if np.random.rand() > 0.5 else x,
        model.init(rngkey, dummy_input),
    )
    x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
    loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))

    # Training the network
    # NOTE(review): the trained `param` is never returned or used afterwards; callers
    # sample around `true_param`. Confirm whether this training pass is still wanted.
    rngkey, _ = jax.random.split(rngkey)
    param = model.init(rngkey, dummy_input)
    opt_state = optimizer.init(param)
    grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

    for epoch in range(num_epochs):
        for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size):
            grads = grad_fn(param, x_batch, y_batch)
            updates, opt_state = optimizer.update(grads, opt_state)
            param = optax.apply_updates(param, updates)

    # Calculate true lambda for the generated network.
    layer_names = ["deep_linear_network/linear"] + [
        f"deep_linear_network/linear_{i}" for i in range(1, len(layer_widths))
    ]
    true_matrix = jnp.linalg.multi_dot([true_param[name]["w"] for name in layer_names])
    # Plain int keeps the numpy-based formula (and later pandas columns) free of JAX scalars.
    true_rank = int(jnp.linalg.matrix_rank(true_matrix))
    true_lambda = float(true_dln_learning_coefficient(true_rank, layer_widths, input_dim, verbose=False))
    return model, true_param, true_lambda

# Build one random DLN from the key and architecture defined above; `model`,
# `true_param` and `true_lambda` are reused by the sweep and plot cells below.
model, true_param, true_lambda = generate_random_DLN(rngkey, input_dim, layer_widths)

### Define the LLC estimator function:

In [139]:
def estimate_llc_given_dln(model: hk.Transformed,
                           true_param: hk.Params,
                            epsilon: float,
                            beta: float,
                            gamma: float = 10.0,
                            num_steps: int = 1000,
                            rngkey = jax.random.PRNGKey(42)):
    """Estimate the LLC of a deep linear network with SGLD started at ``true_param``.

    Reads the module-level ``x_train``, ``y_train``, ``batch_size`` and
    ``num_training_data``.

    Args:
        model: Haiku-transformed network.
        true_param: parameters to sample around (initial point and prior centre).
        epsilon: SGLD step size.
        beta: inverse temperature applied to the log-likelihood.
        gamma: strength of the quadratic pull towards ``true_param``.
        num_steps: exact number of SGLD steps (and therefore trace entries).
        rngkey: PRNG key seeding the SGLD noise.

    Returns:
        dict with ``"lambdahat"`` (float) and ``"llc/trace"`` (1-D numpy array with one
        entry per SGLD step).

    Raises:
        ValueError: if no mini-batch can be formed from the training data.
    """
    sgld_config = SGLDConfig(
        epsilon=epsilon,
        gamma=gamma,
        num_steps=num_steps,
    )
    param_init = true_param

    loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
    local_logprob = create_local_logposterior(
            avgnegloglikelihood_fn=loss_fn,
            num_training_data=num_training_data,
            w_init=param_init,
            gamma=sgld_config.gamma,
            itemp=beta,
        )
    sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

    rngkey, _ = jax.random.split(rngkey)
    sgldoptim = optim_sgld(sgld_config.epsilon, rngkey)
    opt_state = sgldoptim.init(param_init)
    param = param_init

    # Record the full-data loss after every step instead of keeping every parameter tree.
    loss_trace = []
    while len(loss_trace) < sgld_config.num_steps:
        steps_before_epoch = len(loss_trace)
        for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size):
            _, grads = sgld_grad_fn(param, x_batch, y_batch)
            updates, opt_state = sgldoptim.update(grads, opt_state)
            param = optax.apply_updates(param, updates)
            loss_trace.append(loss_fn(param, x_train, y_train))
            if len(loss_trace) >= sgld_config.num_steps:
                break  # stop mid-epoch so exactly num_steps steps are taken
        if len(loss_trace) == steps_before_epoch:
            # Without this guard an empty batch iterator would spin forever.
            raise ValueError("No mini-batches were produced; batch_size is larger than the training set.")

    init_loss = float(loss_fn(param_init, x_train, y_train))
    loss_values = np.array([float(loss) for loss in loss_trace])
    lambdahat = float((loss_values.mean() - init_loss) * num_training_data * beta)
    llc_trace = (loss_values - init_loss) * num_training_data * beta

    sweep_stats = {"lambdahat": lambdahat,
                    "llc/trace": llc_trace,}
    return sweep_stats

In [140]:
import warnings
# Silence all warnings from here on. This also hides the per-point "Skipping" warnings
# that sweep() emits when an estimator call fails.
warnings.filterwarnings("ignore")

# Sweep the DLN generated above over an 8 x 8 grid of (epsilon, beta) values.
DLNanalyzer = EpsilonBetaAnalyzer()
DLNanalyzer.configure_sweep(
    llc_estimator=estimate_llc_given_dln,
    llc_estimator_kwargs={
        "model": model,
        "true_param": true_param,
        "gamma": 10.0,
        "num_steps": 1000,
    },
    min_epsilon=1e-6,
    max_epsilon=1e-2,
    epsilon_samples=8,
    min_beta=1e-6,
    max_beta=1e2,
    beta_samples=8,
)
DLNanalyzer.sweep()

100%|██████████| 64/64 [03:35<00:00,  3.37s/it]


In [141]:
# With true_lambda given and no explicit color, plot() colours the points by
# "log_lambda_delta" and adds a reference plane at height true_lambda.
DLNanalyzer.plot(true_lambda = true_lambda, num_last_steps_to_average = 50)

In [58]:
# To plot the convergence metric as color:
# "llc/std_over_mean" is the std/mean ratio of each trace over its last
# num_last_steps_to_average draws, as computed inside plot().
DLNanalyzer.plot(true_lambda = true_lambda, num_last_steps_to_average = 50, color="llc/std_over_mean")

### Sweep over randomly-generated DLNs with different true lambdas:

In [None]:
DLNanalyzer = EpsilonBetaAnalyzer()

# Accumulate sweeps from several randomly generated DLNs into one results table.
for _ in range(networks_to_generate):
    rngkey = jax.random.PRNGKey(np.random.randint(0, 10000))
    model, true_param, true_lambda = generate_random_DLN(rngkey, input_dim, layer_widths)
    # The estimator requires `model` and `true_param`; without them every sweep
    # point raises and is skipped.
    DLNanalyzer.configure_sweep(llc_estimator=estimate_llc_given_dln,
                            llc_estimator_kwargs=dict(model=model, true_param=true_param,
                                                      gamma=10.0, num_steps=2000),
                            min_epsilon = 1e-6, max_epsilon = 1e-2, epsilon_samples = 8,
                            min_beta = 1e-5, max_beta = 1e2, beta_samples = 8)
    DLNanalyzer.sweep(add_to_existing=True)

In [107]:
import pandas as pd
import plotly.express as px
# `all_sweep_stats` only exists inside EpsilonBetaAnalyzer.sweep(); the stored results
# of the most recent DLN sweep are used instead.
df = DLNanalyzer.sweep_df.copy()
df["llc/std_over_mean"] = df["llc/trace"].apply(lambda x: x[:, -20:].std() / x[:, -20:].mean())
# Average over the last 10 draws (a bare `-10` index would pick a single draw).
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Keep a handle on the figure so it can be exported afterwards.
fig = px.scatter_3d(df, x="epsilon", y="beta", z="llc/final", color="llc/std_over_mean", log_y=True, log_x=True, log_z=True, 
              title="Local learning coefficient vs. epsilon and beta",
              # Set max for color
              range_color=[0, 0.15])
fig

In [108]:
# Export the figure currently bound to `fig` to epsilon_beta_sweep.html.
fig.write_html("epsilon_beta_sweep.html")

In [None]:
import seaborn as sns
# Draw one line per entry along the first axis of the fourth-from-last row's LLC trace
# (one per chain, given the (num_chains, num_draws) layout), then print that row's
# std/mean ratio.
for chain in df.iloc[-4]["llc/trace"]:
    sns.lineplot(data=chain)
print(df.iloc[-4]["llc/std_over_mean"])

In [103]:
import pandas as pd
import plotly.express as px
# `all_sweep_stats` only exists inside EpsilonBetaAnalyzer.sweep(); the stored results
# of the most recent DLN sweep are used instead.
df = DLNanalyzer.sweep_df.copy()
df["llc/std_over_mean"] = df["llc/trace"].apply(lambda x: x[:, -20:].std() / x[:, -20:].mean())
# Average over the last 10 draws (a bare `-10` index would pick a single draw).
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Keep a handle on the figure so it can be exported afterwards.
fig = px.scatter_3d(df, x="epsilon", y="beta", z="llc/final", color="llc/std_over_mean", log_y=True, log_x=True, log_z=True, 
              title="Local learning coefficient vs. epsilon and beta",
              # Set max for color
              range_color=[0, 0.15])
fig

In [106]:
# Export the figure currently bound to `fig` to epsilon_beta_sweep.html.
fig.write_html("epsilon_beta_sweep.html")

In [76]:
import plotly.graph_objects as go
import pandas as pd
# `all_sweep_stats` only exists inside EpsilonBetaAnalyzer.sweep(); the stored results
# of the most recent DLN sweep are used instead.
df = DLNanalyzer.sweep_df.copy()
# 3d contour plot
# Average over the last 10 draws (a bare `-10` index would pick a single draw).
df["llc/final"] = df["llc/trace"].apply(lambda x: x[:, -10:].mean())
# Log scale
df["llc/log_final"] = df["llc/final"].apply(lambda x: np.log10(x))
df["log_epsilon"] = df["epsilon"].apply(lambda x: np.log10(x))
df["log_beta"] = df["beta"].apply(lambda x: np.log10(x))

# Surface needs z as a (len(y), len(x)) grid with x and y as the axis values. Pivoting
# builds that grid from the data itself, so it matches whatever ranges were swept and
# tolerates skipped points (they become NaN); repeated points are averaged.
surface_grid = df.pivot_table(index="log_beta", columns="log_epsilon", values="llc/log_final")

fig = go.Figure(data=[go.Surface(
    x=surface_grid.columns.to_numpy(),
    y=surface_grid.index.to_numpy(),
    z=surface_grid.to_numpy(),
    colorscale='Viridis',
    opacity=0.6,
    contours=dict(z=dict(show=True, usecolormap=True, highlightcolor="limegreen", project=dict(z=True)))
)])
# NOTE(review): the fixed axis ranges below are kept as-is; they may not cover the
# swept beta/epsilon values -- confirm or derive them from the data.
fig.update_layout(scene = dict(
                    xaxis_title='epsilon',
                    yaxis_title='beta',
                    zaxis_title='llc/final',    
                    xaxis = dict(nticks=8, range=[-6,0],),
                    yaxis = dict(nticks=8, range=[0,4],),
                    zaxis = dict(nticks=4, range=[0,4],),
                    ))

fig.show()

In [33]:
# Mean of the last five columns of the "llc/trace" array held by the first value of `sweep_stats`.
# NOTE(review): `sweep_stats` is not defined at module level in the cells visible here -- it
# presumably comes from an earlier session; confirm before relying on this cell.
list(sweep_stats.values())[0]["llc/trace"][:, -5:].mean()

8.317402