In [95]:
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from matplotlib.ticker import FormatStrFormatter
import imageio
## Implementation of random Fourier feature embeddings from the paper https://arxiv.org/abs/2006.10739

In [20]:
class RandomFourierFeatures:
    """Random Fourier feature embedding: x -> [sin(2*pi*x@W), cos(2*pi*x@W)].

    ``W`` is a fixed Gaussian projection matrix of shape
    ``(features_size, embedding_size)``. It is drawn once at construction
    time, with entries distributed as N(0, scale**2).

    Args:
        batch_size: Stored on the instance only. The embedding itself works
            for any leading batch dimension of the input.
        embedding_size: Number of random frequencies. The embedding has
            ``2 * embedding_size`` columns: all sines first, then all cosines.
        features_size: Dimensionality of each input vector (the last axis of
            the input passed to ``get_random_fourier_embedding``).
        scale: Standard deviation of the Gaussian frequencies. Larger values
            give higher-frequency features. Defaults to 1.0, the value that
            was previously hard-coded.
        seed: Optional seed for a private ``np.random.default_rng``. When
            ``None`` (the default), the global NumPy RNG is used exactly as
            before, so ``np.random.seed`` still controls the draw.
    """

    def __init__(self, batch_size: int = 1, embedding_size: int = 5,
                 features_size: int = 1, scale: float = 1.0, seed=None):
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.features_size = features_size
        self.scale = scale
        if seed is None:
            # Global RNG path: identical draw to the original implementation.
            self.W = np.random.randn(features_size, embedding_size) * scale  # [features, embedding]
        else:
            rng = np.random.default_rng(seed)
            self.W = rng.standard_normal((features_size, embedding_size)) * scale  # [features, embedding]

    def get_random_fourier_embedding(self, x):
        """Embed ``x`` with the fixed random frequencies ``W``.

        Args:
            x: Array-like whose last axis has length ``features_size``,
                e.g. shape ``(batch, features_size)``.

        Returns:
            ndarray of shape ``x.shape[:-1] + (2 * embedding_size,)``. It holds
            the ``embedding_size`` sine features followed by the
            ``embedding_size`` cosine features.

        Raises:
            AssertionError: if the last axis of ``x`` does not match
                ``features_size``.
        """
        x = np.asarray(x)
        assert x.shape[-1] == self.W.shape[0], (
            f"expected last axis of size {self.W.shape[0]}, got {x.shape[-1]}")
        x_proj = (x @ self.W) * 2 * np.pi  # [..., embedding_size]
        return np.concatenate([np.sin(x_proj), np.cos(x_proj)], axis=-1)

In [121]:
# Demo configuration: embed a scalar input in [0, 1) with `embedding_size` random frequencies.
batch_size = 1
features_size = 1
embedding_size = 5
nb_points = 15
np.random.seed(0)  # fix the random frequencies W so the rendered GIF is reproducible
# One (batch_size, features_size) array per animation frame, with value i / nb_points.
# Built with NumPy (torch is never imported here) so `x @ W` works against the NumPy weights.
x_list = [np.ones((batch_size, features_size)) * i / nb_points for i in range(nb_points)]
fourier_features = RandomFourierFeatures(batch_size, embedding_size, features_size)
# Bar hue labels matching the embedding layout: sine columns first, then cosine columns.
hue_values = ["sin"] * embedding_size + ["cos"] * embedding_size

In [122]:
# Render one bar chart per input value; each saved PNG becomes one frame of the GIF.
# The x-axis categories are the rounded random frequencies W[0], shared by the sin and cos bars.
# They do not change between frames, so they are computed once outside the loop.
frequency_labels = np.around(
    np.concatenate([fourier_features.W[0], fourier_features.W[0]]), decimals=3)
filenames = []
for frame_idx, x in enumerate(x_list):
    x_fourier = fourier_features.get_random_fourier_embedding(x)

    # Zero-padded index names keep frames ordered and unique (float-derived names could collide).
    filename = f"fourier_frame_{frame_idx:03d}.png"
    filenames.append(filename)

    fig, ax = plt.subplots(figsize=(12, 5))
    sns.barplot(x=frequency_labels, y=x_fourier[0], hue=hue_values, ax=ax)
    ax.set_ylim([-1, 1])
    ax.set_title('Fourier Embedding Output')
    ax.set_xlabel('Random frequency W (rounded)')
    ax.set_ylabel('Feature value')

    # Leave room on the left for a vertical slider that shows the current scalar input.
    fig.subplots_adjust(left=0.25, bottom=0.25)
    slider_ax = fig.add_axes([0.1, 0.25, 0.0225, 0.63])
    input_slider = Slider(
        ax=slider_ax,
        label="Scalar Input",
        valmin=0,
        valmax=1,
        valinit=float(x[0, 0]),
        orientation="vertical",
        facecolor=sns.color_palette()[5],
    )

    fig.savefig(filename)
    plt.close(fig)  # free the figure; many frames are created in this loop

In [123]:
# Stitch the saved frames, in list order, into a single animated GIF.
with imageio.get_writer('fourier_embedding.gif', mode='I') as gif_writer:
    for frame_path in filenames:
        gif_writer.append_data(imageio.imread(frame_path))

In [124]:
# Delete the temporary frame PNGs; deduplicating first avoids removing the same path twice.
unique_frame_paths = set(filenames)
for frame_path in unique_frame_paths:
    os.remove(frame_path)