In [2]:
# Enable IPython's autoreload extension. Per the IPython docs, mode 2 reloads
# all imported modules before each cell is executed, so edits to the local
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm.notebook import tqdm
import itertools
import importlib

import adaptive_algos as aa
import helper_funcs as hf
import ObjectiveFunction as of

In [11]:
def run_experiment(model, opt_alg, start, bounds, opt_params, n_epochs, repeat=False):
    """Run one optimisation experiment and return either figures or the final parameters.

    Parameters
    ----------
    model : callable
        Called as ``model(start=start, bounds=bounds)`` to build the objective.
        The result must expose ``parameters()``.
    opt_alg : callable
        Optimiser class/factory. ``torch.optim.Adam`` and ``torch.optim.SGD``
        are called with ``params`` and ``opt_params`` only; any other
        algorithm additionally receives the objective as ``func``.
    start
        Forwarded unchanged to ``model``.
    bounds
        Forwarded to ``model``; also used as the symmetric range
        ``(-bounds, bounds)`` for both axes of the visualisation.
    opt_params : dict
        Extra keyword arguments unpacked into the optimiser constructor.
    n_epochs
        Forwarded to ``of.OptimisationProblem`` as ``n_epochs``.
    repeat : bool, optional
        When truthy, the problem is run with ``logs=False`` and no figures
        are produced.

    Returns
    -------
    list or object
        With ``repeat`` falsy: ``[visualisation, density_plot]``, plus a
        scatter of ``opt_problem.opt.alpha_record`` when the optimiser has
        that attribute. With ``repeat`` truthy: the last entry of the
        ``params`` sequence returned by ``opt_problem.run``.
    """
    objective = model(start=start, bounds=bounds)

    # torch's built-in optimisers are not given the objective; every other
    # algorithm gets it through the `func` keyword.
    is_torch_builtin = any(
        opt_alg is builtin for builtin in (torch.optim.Adam, torch.optim.SGD)
    )
    objective_kwarg = {} if is_torch_builtin else {"func": objective}
    optimiser = opt_alg(params=objective.parameters(), **objective_kwarg, **opt_params)

    opt_problem = of.OptimisationProblem(objective, optimiser, n_epochs=n_epochs)

    # Repeated runs: explicitly disable logging and hand back only the final parameters.
    if repeat:
        _, params, _ = opt_problem.run(logs=False)
        return params[-1]

    # Single run: use run()'s default logging and build the diagnostic figures.
    _, params, _ = opt_problem.run()

    x_range = (-bounds, bounds)
    y_range = (-bounds, bounds)
    figures = [
        opt_problem.visualise(x_range, y_range, 0.1, render="contour"),
        hf.create_density_plot(params),
    ]

    # Only some optimisers record an `alpha_record`; plot it when present.
    if hasattr(opt_problem.opt, 'alpha_record'):
        figures.append(px.scatter(opt_problem.opt.alpha_record))

    return figures

In [7]:
# Force a re-import of the local modules so on-disk edits are used by this cell.
importlib.reload(aa)
importlib.reload(of)

# Hyperparameters forwarded to aa.SGD_TC through `opt_params` below.
lr = 0.05
bounds = 10
height = 1.0
width = bounds / 20
batch_size = 10  # not referenced in this cell; kept in case other cells read it
max_iterations = 1000

# Annealing schedule handed to the optimiser: maps `progress` to `1 - progress`.
scale_annealer = lambda progress: 1 - progress
# Alternative step-wise schedule:
# def scale_annealer(progress):
#     if progress < 0.33:
#         return 1
#     elif progress < 0.66:
#         return 0.8
#     else:
#         return 0.01

# Single source of truth for the optimiser keyword arguments; run_experiment
# unpacks this dict into the optimiser constructor for every objective.
opt_params = {
    'lr': lr,
    'height': height,
    'width': width,
    'n_epochs': max_iterations,
    'scale_annealer': scale_annealer,
}

# Run the same SGD_TC configuration on each test objective and collect all
# returned figures into one flat list, in objective order.
figs = []
for objective in (of.AlpineN1, of.Ackley, of.Rosenbrock):
    figs.extend(
        run_experiment(
            objective,
            aa.SGD_TC,
            start=[2.5, 2],
            bounds=bounds,
            opt_params=opt_params,
            n_epochs=max_iterations,
        )
    )

hf.figures_to_html(figs, 'SGD_TC_testfinal.html')

0
100
200
300
400
500
600
700
800
900
0
100
200
300
400
500
600
700
800
900
0
100
200
300
400
500
600
700
800
900



divide by zero encountered in log



In [None]:
# TODO: compare SGD, SGD with momentum, Adam, and SGD_TC
# TODO: quantitative comparisons, e.g. percentage of runs finding the global minimum
# TODO: time-complexity comparison
# TODO: highlight the fact that SGD_TC does not get trapped
# TODO: compare the escape dynamics of SGD_TC with metadynamics

In [14]:
# Force a re-import of the local modules so on-disk edits are used by this cell.
importlib.reload(aa)
importlib.reload(of)

# Size of the repeated-run study.
n_exps = 1000
max_iterations = 1000
final_params = []  # collects the value returned by each repeated run

# Hyperparameters forwarded to aa.SGD_TC through `opt_params`.
lr = 0.05
bounds = 10
height = 1.0
width = bounds / 20
batch_size = 10
scale_annealer = lambda progress: 1 - progress
opt_params = dict(
    lr=lr,
    height=height,
    width=width,
    n_epochs=max_iterations,
    scale_annealer=scale_annealer,
)

# Each iteration runs a fresh experiment from the same start point with
# repeat=True, so run_experiment returns only the final parameters.
for n in tqdm(range(n_exps)):
    final_param = run_experiment(
        of.AlpineN1,
        aa.SGD_TC,
        start=[2.5, 2],
        bounds=bounds,
        opt_params=opt_params,
        n_epochs=max_iterations,
        repeat=True,
    )
    final_params.append(final_param)

  0%|          | 0/1000 [00:00<?, ?it/s]