In [1]:
import numpy as np
import numpy.random as npr
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt

In [2]:
def circular_harmonics(L, theta):
    """Evaluate the real orthonormal circular-harmonic basis at given angles.

    The basis functions, in order, are
        1/sqrt(2*pi),
        cos(1*theta)/sqrt(pi), sin(1*theta)/sqrt(pi),
        ...,
        cos(L*theta)/sqrt(pi), sin(L*theta)/sqrt(pi),
    i.e. 2*L + 1 functions that are orthonormal over [0, 2*pi).

    Args:
        L: highest harmonic order (non-negative int).
        theta: angles in radians. A 1-D tensor of length N, or anything
            ``torch.as_tensor`` accepts (a scalar is treated as N = 1).

    Returns:
        float32 tensor of shape (N, 2*L + 1, 1). The trailing singleton axis
        lets it be used directly with ``torch.bmm`` against (N, 1, 2*L + 1)
        coefficient batches.
    """
    theta = torch.atleast_1d(torch.as_tensor(theta))
    if not theta.is_floating_point():
        # Integer angles would otherwise truncate the constant term to 0.
        theta = theta.double()
    # full_like keeps the constant term on theta's dtype/device, so the
    # stack below never mixes dtypes.
    B = [torch.full_like(theta, 1 / np.sqrt(2 * np.pi))]
    for l in range(1, L + 1):
        # Order l oscillates l times around the circle: l scales the angle,
        # not the amplitude.
        B.append(torch.cos(l * theta) / np.sqrt(np.pi))
        B.append(torch.sin(l * theta) / np.sqrt(np.pi))
    # (N, 2L+1) -> (N, 2L+1, 1)
    return torch.stack(B, dim=1).unsqueeze(2).float()

In [3]:
# Highest harmonic order. The network outputs 2*lmax + 1 coefficients, one per
# circular-harmonic basis function (a constant term plus a cos/sin pair for
# each order 1..lmax).
lmax = 5
# Seed torch so the random initialisation below (and later torch.rand draws,
# when the notebook is run top to bottom) is reproducible.
SEED = 0
torch.manual_seed(SEED)
# Small MLP mapping a scalar input to a vector of harmonic coefficients.
model = nn.Sequential(
    nn.Linear(1, 8),
    nn.ReLU(inplace=True),
    nn.Linear(8, 2*lmax + 1)
)

In [4]:
# Draw two random scalar inputs and map each one to a coefficient vector of
# length 2*lmax + 1. Autograd is switched off because only the values are used.
with torch.set_grad_enabled(False):
    w = model(torch.rand(2, 1))

In [8]:
# Evaluate every coefficient vector in w on a grid of angles in one shot.
# circular_harmonics returns (n_angles, 2*lmax + 1, 1). Dropping the trailing
# axis and multiplying by w.T gives Es[i, b] = sum_k B_k(theta_i) * w[b, k],
# which works for any number of rows in w (not only 2).
N_ANGLES = 32
thetas = np.linspace(0, 2 * np.pi, N_ANGLES)
harmonics = circular_harmonics(lmax, torch.from_numpy(thetas)).squeeze(-1)
Es = harmonics @ w.T

In [9]:
# Sanity check: one energy value per (angle, coefficient vector) pair.
assert Es.ndim == 2 and Es.shape[1] == w.shape[0], Es.shape
Es.shape

torch.Size([32, 2])


In [None]:
# Polar plot of the energy profile of the first coefficient vector.
E = Es[:, 0].tolist()
# Rebuild the angle grid from the number of samples so it always matches Es,
# whatever grid size the previous cells used.
angles = np.linspace(0, 2 * np.pi, len(E))

fig, ax = plt.subplots(subplot_kw={'projection': 'polar'})
ax.plot(angles, E)
ax.set_rmax(np.max(E) + 0.2)
# Rounding to one decimal can collapse nearby ticks onto the same value;
# np.unique drops the duplicates (and keeps them sorted).
ax.set_rticks(np.unique(np.round(np.linspace(np.min(E), np.max(E), 5), 1)))
ax.grid(True)

ax.set_title("Energy", va='bottom')
plt.show()