In [1]:
%load_ext autoreload
%autoreload 2
import ObjectiveFunction as of
import plotly.graph_objects as go
import plotly.figure_factory as ff
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
import helper_funcs as hf
from math import sqrt
from scipy.stats import uniform, norm
from torch.optim.optimizer import Optimizer, required
import torch
from typing import List, Optional

In [2]:
def cart2polar(x):
    """Return the polar angle, in radians, of a Cartesian point.

    The input is converted to a NumPy array; element 0 is used as the
    x-coordinate and element 1 as the y-coordinate. The angle is computed
    with ``np.arctan2`` and therefore lies in ``[-pi, pi]``.
    """
    coords = np.array(x)
    x_coord, y_coord = coords[0], coords[1]
    return np.arctan2(y_coord, x_coord)

def powlaw_samp(x_min, alpha, size=1):
    """
    Draw samples from a power-law distribution with minimum value x_min.

    Uniform draws ``u`` from ``np.random.random`` (NumPy's global random
    state) are mapped through ``x_min * (1 - u) ** (1 / (1 - alpha))``.

    Args:
        x_min: smallest value the distribution can produce.
        alpha: power-law exponent used in the transform above.
        size: passed to ``np.random.random``; the result has this shape.

    Returns:
        NumPy array of samples with shape ``size``.

    References:
        https://stats.stackexchange.com/questions/173242/random-sample-from-power-law-distribution
        https://arxiv.org/pdf/0706.1062.pdf
    """
    uniform_draws = np.random.random(size=size)
    exponent = 1 / (1 - alpha)
    survival = 1 - uniform_draws
    return x_min * survival ** exponent

In [3]:
class SGD_TC2_Width(Optimizer):
    """SGD with momentum, Levy-flight noise and a history-based adaptive exponent.

    Each ``step`` applies a momentum gradient update, adds a randomly oriented
    jump whose length is drawn from ``powlaw_samp`` with exponent
    ``self.alpha``, and then applies ``func.apply_period`` to the parameter.
    ``_metric`` sums Gaussian kernels between the latest stored position and
    all earlier ones and uses that sum (divided by ``step_count``) to raise
    ``alpha`` above its base value of 2.5.

    Notes:
        * ``state['history']`` and ``state['momentum_buffer']`` are stored
          under optimizer-wide keys (not per parameter), so the optimizer is
          effectively designed for a single parameter tensor.
        * The noise direction is built from ``[cos(theta), sin(theta)]``, so
          the parameter must be compatible with a 2-element vector.
    """

    def __init__(self,
                 params,
                 func=required,
                 lr: float = required,
                 height: float = required,
                 width: float = required,
                 momentum: float = 0
    ):
        """
        Args:
            params: iterable of parameters (or param groups) to optimise.
            func: object exposing ``apply_period(tensor)``; called on the
                parameter at the end of every update.
            lr: learning rate; also scales the Levy noise.
            height: multiplier applied to the summed kernel values. A value
                other than 1 triggers a printed warning.
            width: kernel width; used as ``-0.5 * (1 / width) ** 2`` in the
                exponent of the kernel.
            momentum: factor applied to the momentum buffer before the new
                gradient is added.
        """
        self.height = height
        if height != 1:
            print("Warning: given height is not compatible with a counting scheme.")
        # Pre-computed factor of the Gaussian kernel exponent: -1 / (2 * width^2).
        self.width_denom = -0.5*(1/width)**2
        self.step_count = 0
        self.alpha = 2.5

        self.func = func
        defaults = dict(lr=lr,
                        momentum=momentum)
        super().__init__(params, defaults)

    def _metric(self, pred):
        """Update ``self.alpha`` from the position history and return ``pred`` unchanged.

        Does nothing until at least one position has been recorded by ``step``.
        """
        history = self.state.get('history')
        if not history:
            return pred

        # The bias is only used as a number, so no autograd graph is needed.
        with torch.no_grad():
            last_ph = history[-1]
            Vbias = 0
            for ph in history[:-1]:
                v = last_ph - ph
                # Squared distance; torch.dot(v, v) avoids the deprecated
                # `.T` on a 1-D tensor and yields the same value.
                Vbias += torch.exp(self.width_denom * torch.dot(v, v))
            Vbias = self.height * Vbias

        # Keep the stored entry detached from any graph / live parameter.
        history[-1] = history[-1].detach().clone()

        # Adapt alpha - no phase preference
        self.adapt_alpha(Vbias, pred)

        return pred

    def adapt_alpha(self, Vbias, pred):
        """Set ``self.alpha`` to ``2.5 + Vbias / step_count``.

        Args:
            Vbias: summed kernel value (number or 0-d tensor).
            pred: unused; kept for interface compatibility.

        Raises:
            ValueError: if ``Vbias / step_count`` exceeds 1.
        """
        if self.step_count == 0:
            # No completed step yet, so there is nothing to normalise by.
            return

        p = float(Vbias/self.step_count)

        if p > 1:
            raise ValueError("""Should not have this much bias. 
                            Probably irregular walker behaviour - check the trajectory plots.""")

        self.alpha = 2.5 + p

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimisation step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        updated = False
        for group in self.param_groups:
            lr = group['lr']
            for p in group['params']:

                grad = p.grad
                if grad is None:
                    # Nothing to update for parameters without a gradient.
                    continue

                if 'momentum_buffer' not in self.state:
                    self.state['momentum_buffer'] = grad.detach().clone()
                else:
                    self.state['momentum_buffer'].mul_(group['momentum']).add_(grad, alpha=1)

                mom_grad = self.state['momentum_buffer']
                p.add_(mom_grad, alpha=-lr)

                # Levy Flight noise: power-law jump length scaled by the
                # gradient norm, in a uniformly random direction.
                sample = powlaw_samp(x_min=lr*0.01, alpha=self.alpha)
                # Extract the scalar explicitly; float() on a 1-element
                # array is deprecated in recent NumPy releases.
                levy_r = float(np.asarray(sample).reshape(-1)[0]) * torch.norm(grad)
                theta = float(uniform.rvs(loc=0, scale=2*np.pi))
                direction = torch.tensor([np.cos(theta), np.sin(theta)],
                                         dtype=p.dtype, device=p.device)
                levy_noise = levy_r * direction
                p.add_(levy_noise, alpha=-lr)

                # Periodic Boundary Conditions (INBUILT TO OPTIMISER TO ENSURE COORDINATES ARE BOUNDED)
                # Write the result back in place: rebinding the local name
                # would leave the real parameter unwrapped whenever
                # apply_period returns a new tensor.
                wrapped = self.func.apply_period(p)
                if wrapped is not None and wrapped is not p:
                    p.copy_(torch.as_tensor(wrapped, dtype=p.dtype, device=p.device))

                # Record a snapshot of the post-update position so history
                # entries never alias the live parameter.
                self.state.setdefault('history', []).append(p.detach().clone())
                updated = True

        if updated:
            # Keep step_count in line with the number of recorded positions.
            self.step_count += 1

        return loss