In [4]:
import math
import numpy as np

import torch

In [14]:
class Branin(object):
    def __init__(self):
        self.n_vertices = np.array([21, 21])
        self.n_factors = len(self.n_vertices)
        self.suggested_init = torch.Tensor(self.n_vertices).long().unsqueeze(0) / 2
        for _ in range(1, 2):
            random_init = torch.cat([torch.randint(0, int(elm), (1, 1)) for elm in self.n_vertices], dim=1)
            self.suggested_init = torch.cat([self.suggested_init, random_init], dim=0)
        self.adjacency_mat = []
        self.fourier_freq = []
        self.fourier_basis = []
        for i in range(len(self.n_vertices)):
            n_v = self.n_vertices[i]
            adjmat = torch.diag(torch.ones(n_v - 1), -1) + torch.diag(torch.ones(n_v - 1), 1)
            adjmat *= (n_v - 1.0)
            self.adjacency_mat.append(adjmat)
            degmat = torch.sum(adjmat, dim=0)
            laplacian = (torch.diag(degmat) - adjmat)
            
            # TEST
            print(laplacian.size())
            
            eigval, eigvec = torch.linalg.eigh(laplacian)
            self.fourier_freq.append(eigval)
            self.fourier_basis.append(eigvec)

    def evaluate(self, x_g):
        flat = x_g.dim() == 1
        if flat:
            x_g = x_g.view(1, -1)
        ndim = x_g.size(1)
        assert ndim == len(self.n_vertices)
        n_repeat = int(ndim / 2)
        n_dummy = int(ndim % 2)

        x_e = torch.ones(x_g.size())
        for d in range(len(self.n_vertices)):
            x_e[:, d] = torch.linspace(-1, 1, int(self.n_vertices[d]))[x_g[:, d].long()]

        shift = torch.cat([torch.FloatTensor([2.5, 7.5]).repeat(n_repeat), torch.zeros(n_dummy)])

        x_e = x_e * 7.5 + shift

        a = 1
        b = 5.1 / (4 * math.pi ** 2)
        c = 5.0 / math.pi
        r = 6
        s = 10
        t = 1.0 / (8 * math.pi)
        output = 0
        for i in range(n_repeat):
            output += a * (x_e[:, 2 * i + 1] - b * x_e[:, 2 * i] ** 2 + c * x_e[:, 2 * i] - r) ** 2 \
                      + s * (1 - t) * torch.cos(x_e[:, 2 * i]) + s
        output /= float(n_repeat)
        
        if flat:
            return output.squeeze(0)
        else:
            return output

In [15]:
b = Branin()

torch.Size([21, 21])
torch.Size([21, 21])


In [13]:
b.suggested_init

tensor([[10.5000, 10.5000],
        [ 0.0000, 12.0000]])