In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
import itertools
import importlib

import adaptive_algos as aa
import helper_funcs as hf
import ObjectiveFunction as of

In [None]:
class EulerMaruyama:

    """
    Defines an SDE problem dX = mu(X, t) dt + sigma(X, t) dW to solve via
    the Euler-Maruyama scheme.

    Parameters
    ----------
    mu : callable (x, t) -> drift term, broadcastable against x.
    sigma : callable (x, t) -> diffusion coefficient, broadcastable against x.
    dW : callable (dt) -> Brownian increment for a step of size dt.
        NOTE(review): presumably a N(0, dt) sample per component; this class
        does not enforce it — confirm against callers.
    """

    def __init__(self, mu, sigma, dW):
        self.mu = mu
        self.sigma = sigma
        self.dW = dW

    def run_simulation(self, t0, tf, n, dim, x0=None):
        """
        Return the result of one full simulation.

        Integrates from t0 to tf in n equal steps, starting from x0
        (the origin when x0 is None).

        Returns
        -------
        ts : np.ndarray of shape (n + 1,), the time grid including both ends.
        xs : torch.Tensor of shape (n + 1, dim); xs[i] is the state at ts[i].
        """
        dt = float(tf - t0) / n
        # linspace always gives exactly n + 1 points; arange(t0, tf + dt, dt)
        # can produce an extra point through floating-point rounding.
        ts = np.linspace(t0, tf, n + 1)

        xs = torch.zeros(n + 1, dim)
        if x0 is not None:
            xs[0] = torch.as_tensor(x0, dtype=xs.dtype)

        # Fill every grid point up to and including tf (index n).
        for i in range(1, n + 1):
            # Euler-Maruyama evaluates the coefficients at the left endpoint
            # of the step, i.e. at (x_{i-1}, t_{i-1}).
            t = ts[i - 1]
            x = xs[i - 1]
            xs[i] = x + self.mu(x, t) * dt + self.sigma(x, t) * self.dW(dt)

        return ts, xs

In [74]:
def quadratic_well(x, width, depth):
    """
    Quadratic potential well a*|x|^2 - depth with a = 2*depth/width**2.

    The minimum is -depth at the origin regardless of the number of components
    in x, and the potential crosses zero at |x| = width/sqrt(2).

    Returns a 0-dim tensor so it can be used directly as a scalar loss.
    """
    curvature = 2 * depth / width**2
    # Subtract depth once rather than once per component, so the well depth
    # does not grow with len(x); escape checks compare against -depth/sqrt(2).
    return torch.sum(curvature * x**2) - depth

In [305]:
def brownian_simulation(func, lr, sigma, max_iterations, dim, plot=False, start=None,
                        escape_threshold=None):
    """
    Simulate overdamped Langevin dynamics in the potential `func` with the
    Euler-Maruyama scheme and return the escape time.

    Each step is  x <- x - lr * grad(func)(x) + sigma * xi,  xi ~ N(0, I).
    The noise is a unit normal scaled only by `sigma`, so any sqrt(dt) factor
    of a textbook Euler-Maruyama step must already be folded into `sigma`.

    Parameters
    ----------
    func : callable mapping a (dim,) tensor to a scalar potential tensor.
    lr : step size (the dt multiplying the drift).
    sigma : noise magnitude.
    max_iterations : step budget; returned as the escape time if no escape occurs.
    dim : dimension of the state.
    plot : if True, also return the visited positions.
    start : initial position (scalar or length-`dim` sequence); origin if None.
    escape_threshold : the particle has escaped once func(x) >= this value.
        If None, falls back to -depth/sqrt(2) using the notebook-level global
        `depth` (the original behaviour, kept for existing callers).

    Returns
    -------
    escape_time, or (positions, escape_time) when plot is True, where positions
    is a list of numpy arrays of shape (dim,).
    """
    if escape_threshold is None:
        # Legacy behaviour: read the module-level `depth` global.
        escape_threshold = -depth / np.sqrt(2)

    escape_time = max_iterations
    xs = []
    # Work in float64 throughout (matching the noise tensor) and broadcast a
    # scalar `start` to shape (dim,).
    if start is None:
        x = torch.zeros(dim, dtype=torch.float64)
    else:
        x = torch.as_tensor(start, dtype=torch.float64).expand(dim).clone()
    dW = torch.from_numpy(np.random.normal(loc=0, scale=1, size=(max_iterations, dim)))

    for i in range(max_iterations):
        x.requires_grad_(True)
        potential = func(x)

        if plot:
            xs.append(x.detach().clone().numpy())

        if potential.item() >= escape_threshold:
            escape_time = i
            break

        (grad,) = torch.autograd.grad(potential, x)
        # Euler-Maruyama step: the new state builds on the current one.
        x = x.detach() - lr * grad + sigma * dW[i]

    if plot:
        return xs, escape_time
    return escape_time
    

In [320]:
# Visualise brownian simulation
width = 100
depth = 10  # brownian_simulation reads this global to decide when the particle escaped
func = lambda x: quadratic_well(x, width=width, depth=depth)
xs, escape_time = brownian_simulation(func, 1/(depth*width**2), width/4, 10000, 1, plot=True, start=0)
# Each recorded position holds a single value; .item() avoids NumPy's
# deprecated float() conversion of arrays with ndim > 0.
xs = [x.item() for x in xs]

# Plot range spans the whole well (zero crossings at +-width/sqrt(2)) instead of
# a fixed +-5 window that only showed the flat bottom of a width-100 well.
bounds = 2 * width
X = torch.linspace(-bounds/2, bounds/2, 100)
Z = [float(func(x_)) for x_ in X]

fig = make_subplots(specs=[[{"secondary_y": True}]])
fig.add_trace(
    go.Scatter(x=X.numpy(), y=Z, line=dict(color='black')),
    secondary_y=False
)

hist_fig = px.histogram(xs, opacity=0.2)
fig.add_trace(
    hist_fig.data[0],
    secondary_y=True,
)
fig.update_layout(title=f"Escape time: {escape_time}")


In [282]:
########################## Fix Width and Vary Depth ##########################
# For each depth, average the escape time over n_exps independent simulations.
# Runs that never escape are recorded as max_iterations, so means are lower
# bounds whenever runs hit the cap.
width = 10
depths = np.arange(1, 100, 10)
n_exps, max_iterations = 1000, 10000
dim = 1
mean_escape_times = []

# Step size (effective potential strength) and noise magnitude as functions of the well shape
lr_func = lambda width, depth: 1/depth
sigma_func = lambda width, depth: width/4

# NOTE: `depth` must remain a module-level loop variable — brownian_simulation
# reads the global `depth` when checking for escape.
for depth in depths:
    lr, sigma = lr_func(width, depth), sigma_func(width, depth)
    func = lambda x: quadratic_well(x, width=width, depth=depth)
    escape_times = [
        brownian_simulation(func, lr, sigma, max_iterations=max_iterations, dim=dim)
        for _ in tqdm(range(n_exps))
    ]
    mean_escape_times.append(np.mean(escape_times))

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [285]:
# Mean escape time versus well depth (runs that never escape count as
# max_iterations, so points near the cap are lower bounds).
px.scatter(
    x=depths,
    y=mean_escape_times,
    labels={'x': 'Well depth', 'y': 'Mean escape time (iterations)'},
    title="Mean escape time vs well depth",
)

In [2]:
class PointDataset(Dataset):
    """
    Dataset of `size` points drawn uniformly from [-bounds/2, bounds/2)^dim.

    `self.data` is a float64 tensor of shape (size, dim).

    Note: this reseeds NumPy's *global* RNG with `seed`, so any np.random draws
    made after construction (e.g. extra points appended to `data`) are also
    determined by `seed`. A local Generator is deliberately not used, to keep
    that behaviour.
    """

    def __init__(self, seed, dim, size, bounds):
        np.random.seed(seed)
        # A single (size, dim) draw consumes the global stream in the same order
        # as `size` successive draws of `dim` values, and also works for
        # size == 0 (torch.vstack of an empty list raises).
        samples = np.random.rand(size, dim) * bounds - 0.5 * bounds
        self.data = torch.from_numpy(samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def cyclic_loader(dataloader):
    """
    Yield items from `dataloader` forever, restarting it after each pass.

    Unlike itertools.cycle (which replays a saved copy of the first pass), each
    pass re-iterates `dataloader`, so a DataLoader with shuffle=True produces a
    fresh ordering every epoch.

    Raises
    ------
    ValueError
        If a full pass yields nothing (an empty or already-exhausted one-shot
        iterable), which would otherwise make next() loop forever.
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("cyclic_loader: dataloader yielded no items in a full pass")

In [6]:
# Plotting code
def plot_func(bounds, res, model, pointdata):
    """
    Plot the model's objective over a 1-D grid of parameter values.

    Parameters
    ----------
    bounds : width of the data region; the grid spans
        [-bounds/2 - 1, bounds/2 + 1).
    res : grid spacing; must be positive.
    model : object whose ``forward(data=..., X=...)`` returns a single-element
        objective for parameter value X (presumably an of.AdjustableWell —
        confirm for other callers).
    pointdata : dataset whose ``.data`` tensor is passed to the model.

    Returns
    -------
    plotly Figure containing the objective curve.

    Raises
    ------
    ValueError
        If res is not positive.
    """
    if res <= 0:
        raise ValueError(f"res must be positive, got {res}")

    # Use the caller's resolution (it used to be silently overwritten with 0.01).
    xlim = (-bounds/2 - 1, bounds/2 + 1)
    X = np.arange(xlim[0], xlim[1], res)
    Z = [float(model.forward(data=pointdata.data, X=torch.Tensor([x_]))) for x_ in X]

    fig = make_subplots(specs=[[{"secondary_y": True}]])
    fig.add_trace(
        go.Scatter(x=X, y=Z, line=dict(color='orange')),
        secondary_y=False
    )

    fig.update_layout(
        width=800,
        height=800,
        showlegend=False,
        xaxis=dict(showline=True, mirror=True, linewidth=2, linecolor='black'),
        yaxis=dict(showline=True, mirror=True, linewidth=2, linecolor='black'),
        plot_bgcolor='rgba(0,0,0,0)',
        yaxis_title='Objective Function',
        xaxis_title='Parameter'
    )

    return fig

In [7]:
importlib.reload(of)
importlib.reload(aa)
# Shared parameters
well_width = 6
outer, inner = well_width+1, well_width-1
seed = 22 # determines function
n_runs = 1000
max_iterations = 100
dim = 1
size = 400
signal_size = 50

batchsize = 10
start = [1.0]
lr = 5

bounds = 40
res = 0.01
dataset_params= {'seed': seed, 'dim':dim, 'size':size, 'bounds':bounds}
pointdata = PointDataset(**dataset_params)

# Append four clusters of `signal_size` points, each 3 units wide, at the
# offsets below. PointDataset reseeded NumPy's global RNG, so these draws are
# reproducible; keep the offset order fixed to reproduce the same points.
# NOTE(review): the negative clusters cover [-outer, -inner + 3) = [-7, -2) while
# the positive ones cover [inner, outer + 3) = [5, 10) — not mirror images.
# Confirm that this asymmetry is intended.
signal_offsets = [-outer, -inner, outer, inner]
signal_clusters = [
    torch.from_numpy(np.random.rand(signal_size, dim) * 3 + offset)
    for offset in signal_offsets
]
pointdata.data = torch.vstack([pointdata.data, *signal_clusters])

model_params = {'start':start, 'bounds':bounds, 'well_width': well_width}
model = of.AdjustableWell(**model_params)

plot_func(bounds, res, model, pointdata)

In [133]:
######################################## SGD ########################################
title = "Escape Time: SGD with Function Seed {}".format(seed)
opt_params = {'lr': lr}
model_params = {'start':start, 'bounds':bounds, 'well_width': well_width}
escape_times = []

for i in tqdm(range(n_runs)):

    # A distinct seed per run gives each run its own reproducible shuffle order
    # (with shuffle=False every run would be identical).
    torch.manual_seed(seed+i)
    dataloader = torch.utils.data.DataLoader(dataset=pointdata, batch_size=batchsize, shuffle=True)
    train_cdl = cyclic_loader(dataloader)
    model = of.AdjustableWell(**model_params)
    opt = torch.optim.SGD(params=model.parameters(), **opt_params)

    for j in range(max_iterations):
        batch = next(train_cdl)
        opt.zero_grad()
        loss = model(data=batch)
        loss.backward()
        opt.step()

        curr_param = list(model.parameters())[0]
        # any(): escaped once any component leaves the well (works for dim > 1 too).
        if torch.any(torch.abs(curr_param) > well_width):
            escape_times.append(j+1)
            break
    else:
        # Never escaped within the budget: record a sentinel one past it.
        escape_times.append(max_iterations+1)

fig1 = px.histogram(escape_times, nbins=max_iterations+2, title=title)
fig1.show()

  0%|          | 0/1000 [00:00<?, ?it/s]

In [130]:
########################################## SGD_TC #######################################
importlib.reload(aa)

title = "SGD_TC with Function Seed {}".format(seed)
scale_annealer = lambda progress: 1  # scale factor fixed at 1 for every progress value
opt_params={'lr': lr, 'height': 1.0, 'width': bounds/10,
            'scale_annealer': scale_annealer, 'n_epochs': max_iterations}
model_params = {'start':start, 'bounds':bounds, 'well_width': well_width}
escape_times = []

for i in tqdm(range(n_runs)):

    # A distinct seed per run gives each run its own reproducible shuffle order
    # (with shuffle=False every run would be identical).
    torch.manual_seed(seed+i)
    dataloader = torch.utils.data.DataLoader(dataset=pointdata, batch_size=batchsize, shuffle=True)
    train_cdl = cyclic_loader(dataloader)
    model = of.AdjustableWell(**model_params)
    opt = aa.SGD_TC(params=model.parameters(), **opt_params)

    for j in range(max_iterations):
        batch = next(train_cdl)
        opt.zero_grad()
        loss = model(data=batch)
        loss.backward()
        opt.step()
        curr_param = list(model.parameters())[0]
        # any(): escaped once any component leaves the well (works for dim > 1 too).
        if torch.any(torch.abs(curr_param) > well_width):
            escape_times.append(j+1)
            break
    else:
        # Never escaped within the budget: record a sentinel one past it.
        escape_times.append(max_iterations+1)

fig2 = px.histogram(escape_times, nbins=max_iterations+2, title=title)
fig2.show()

  0%|          | 0/1000 [00:00<?, ?it/s]

In [13]:
# generalise and compute average
def _escape_time_trial(opter, opt_params, model_params, trial_seed):
    """Run one seeded optimisation trial and return its escape time (max_iterations + 1 if none)."""
    torch.manual_seed(trial_seed)
    dataloader = torch.utils.data.DataLoader(dataset=pointdata, batch_size=batchsize, shuffle=True)
    train_cdl = cyclic_loader(dataloader)
    model = of.AdjustableWell(**model_params)
    opt = opter(params=model.parameters(), **opt_params)

    for j in range(max_iterations):
        batch = next(train_cdl)
        opt.zero_grad()
        loss = model(data=batch)
        loss.backward()
        opt.step()

        curr_param = list(model.parameters())[0]
        # any(): escaped once any component leaves the well (works for dim > 1 too).
        if torch.any(torch.abs(curr_param) > well_width):
            return j + 1
    return max_iterations + 1


def escape_time_analysis(opter, opt_params, repeat=True, title=None):
    """
    Measure how many optimiser steps it takes to leave the well.

    Runs `n_runs` independent trials. Trial i:
    - seeds torch with `seed + i`, so the shuffled batch order differs between
      trials but is reproducible;
    - builds a fresh of.AdjustableWell and an `opter(params=..., **opt_params)`
      optimiser;
    - steps on batches of `pointdata` until the first model parameter leaves
      [-well_width, well_width] or `max_iterations` steps have passed.
    A trial's escape time is the 1-based step at which it escaped, or
    max_iterations + 1 if it never did.

    Reads the notebook globals start, bounds, well_width, n_runs, seed,
    pointdata, batchsize and max_iterations.

    Parameters
    ----------
    opter : optimiser class, e.g. torch.optim.SGD or aa.SGD_TC.
    opt_params : keyword arguments passed to `opter`.
    repeat : if True, run silently and return the mean escape time (for use
        inside sweeps); if False, show a tqdm progress bar and return a histogram.
    title : optional histogram title (unused when repeat is True).

    Returns
    -------
    Mean escape time when repeat is True, otherwise a plotly histogram Figure.
    """
    model_params = {'start': start, 'bounds': bounds, 'well_width': well_width}
    runs = range(n_runs) if repeat else tqdm(range(n_runs))
    escape_times = [
        _escape_time_trial(opter, opt_params, model_params, seed + i) for i in runs
    ]

    if repeat:
        return np.mean(escape_times)
    fig = px.histogram(escape_times, nbins=max_iterations+2, title=title)
    fig.update_layout(
        showlegend=False
    )
    return fig

In [14]:
# Escape-time histogram for SGD_TC with its scale held constant at 1.
scale_annealer = lambda progress: 1

opt_params = dict(
    lr=lr,
    height=1.0,
    width=bounds / 10,
    scale_annealer=scale_annealer,
    n_epochs=max_iterations,
)
escape_time_analysis(aa.SGD_TC, opt_params, repeat=False)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [15]:
# Baseline: escape-time histogram for plain SGD at the same learning rate.
opt_params = dict(lr=lr)
escape_time_analysis(torch.optim.SGD, opt_params, repeat=False)

  0%|          | 0/1000 [00:00<?, ?it/s]