In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import pykerr
from scipy.constants import c, pi, G
from scipy.signal import correlate
from astropy.constants import M_sun

In [2]:
# Output folders: per-step PNG frames are written to "plots", the assembled GIF to "results".
# exist_ok=True makes the cell idempotent on re-run and avoids the check-then-create race.
os.makedirs("results", exist_ok=True)
os.makedirs("plots", exist_ok=True)

In [3]:
def save_gif_PIL(outfile, files, fps=5, loop=0):
    """Assemble the image files in ``files`` (in order) into an animated GIF.

    Parameters
    ----------
    outfile : path of the GIF to write.
    files : iterable of image paths; the first one becomes the base frame.
    fps : frames per second; each frame is shown for ``int(1000 / fps)`` ms.
    loop : GIF loop count forwarded to Pillow (0 = loop forever).

    Raises
    ------
    ValueError
        If ``files`` is empty (there is no first frame to save).
    """
    files = list(files)
    if not files:
        raise ValueError("save_gif_PIL: no frames to save")

    imgs = []
    try:
        for path in files:
            imgs.append(Image.open(path))
        imgs[0].save(fp=outfile, format='GIF', append_images=imgs[1:], save_all=True,
                     duration=int(1000 / fps), loop=loop)
    finally:
        # Release the file handles Image.open keeps until the image is closed.
        for img in imgs:
            img.close()

def plot_waveforms(time, pinn_waveform, leaver_waveform, l, m, match_score, a=None, step=None):
    """Plot the PINN ringdown against the Leaver reference on a new current figure.

    The figure is left as matplotlib's current figure, so callers follow up with
    ``plt.savefig`` / ``plt.close`` as before.

    Parameters
    ----------
    time : sample times in seconds (x-axis).
    pinn_waveform, leaver_waveform : waveform samples on ``time``.
    l, m : mode indices shown in the title (the overtone index is shown as 0).
    match_score : shown in the legend title with 6 decimals.
    a : spin shown in the title. When None, falls back to the notebook-level
        ``a`` (the original implementation read that global implicitly).
    step : zero-based training iteration; the label shows ``step + 1``. When
        None, falls back to the notebook-level loop variable ``i``.

    Raises
    ------
    NameError
        If ``a``/``step`` are omitted and no matching global exists.
    """
    namespace = globals()
    if a is None:
        if "a" not in namespace:
            raise NameError("plot_waveforms: pass `a` explicitly; no global `a` is defined")
        a = namespace["a"]
    if step is None:
        if "i" not in namespace:
            raise NameError("plot_waveforms: pass `step` explicitly; no global `i` is defined")
        step = namespace["i"]

    plt.figure()
    plt.title(f"a = {a}      (l, m, n) = ({l}, {m}, 0)")
    plt.plot(time, pinn_waveform, label="PINN", linewidth=1.5)
    plt.plot(time, leaver_waveform, linestyle="--", label="Leaver", linewidth=1.5)
    # Placed in data coordinates, so it relies on the time/amplitude ranges used in this notebook.
    plt.text(0.015, 0.9, "Training step: %i" % (step + 1), fontsize="xx-large", color="k")
    plt.xlabel("Time [s]")
    plt.ylabel("Amplitude")

    legend = plt.legend(
        loc="upper right",
        bbox_to_anchor=(1.0, 1.0),
        frameon=True,
        title=f"Match: {match_score:.6f}",
        title_fontproperties={"weight": "bold", "size": "medium"}
    )

    # Improve legend box aesthetics
    frame = legend.get_frame()
    frame.set_edgecolor("black")
    frame.set_linewidth(1.2)

### Future Modifications

**TODO:** sample the collocation points at Chebyshev nodes instead of the uniform `torch.linspace` grids currently used for `x` and `u`.

In [4]:
# Collocation points for the PINN residuals: x in [0, 1] feeds the radial
# coefficients F_0/F_1/F_2, u in [-1, 1] the angular ones G_0/G_1/G_2.
# requires_grad is needed so the loss can differentiate the network outputs
# with respect to these inputs.
N_COLLOCATION = 100
x = torch.linspace(0, 1, N_COLLOCATION).view(-1, 1).requires_grad_()
u = torch.linspace(-1, 1, N_COLLOCATION).view(-1, 1).requires_grad_()

# Time axis (seconds) on which the ringdown waveforms are compared.
time = np.linspace(0, 0.1, 100)

# Reference QNM frequency (l = 2, n = 0, Schwarzschild) in units with 2M = 1,
# omega = OMEGA_R + 1j * OMEGA_I, for a black hole of M_BH_SOLAR solar masses.
# Alternative source: f_, tau_ = pykerr.qnmfreq(200, 0, 2, 0, 0), pykerr.qnmtau(200, 0, 2, 0, 0)
M_BH_SOLAR = 200
OMEGA_R, OMEGA_I = 0.74734, -0.17793
M_BH = M_BH_SOLAR * M_sun.value  # kg

# Physical frequency f = OMEGA_R c^3 / (4 pi G M) and damping time tau = -2 G M / (c^3 OMEGA_I).
f_ = (c**3 / (4 * pi * G * M_BH)) * OMEGA_R
tau_ = -((2 * G * M_BH) / c**3) / OMEGA_I

leaver_waveform = np.exp(-time / tau_) * np.cos(2 * np.pi * f_ * time)
leaver_norm = (leaver_waveform - np.mean(leaver_waveform)) / np.std(leaver_waveform)

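The radial function $f(x)$ and angular function $g(u)$ satisfy $F_2 f'' + F_1 f' + F_0 f = 0$ and $G_2 g'' + G_1 g' + G_0 g = 0$ (primes are derivatives with respect to $x$ and $u$ respectively), with the coefficients below. Here $j$ denotes the imaginary unit, written in the same Python notation as the code.

\begin{align}
F_0(a, x, m, s, \omega, r_+, A) &= -a^4 x^2 \omega^2 - 2 a^3 m x^2 \omega + a^2 \bigg(-A x^2 + x^2 \Big(4 (r_+ + 1) \omega^2 + 2j (r_+ + 2) \omega + 2j s (\omega + 1j) - 2\Big) \nonumber \\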
&\quad + x \omega^2 - \omega^2\bigg) + 2 a m \Big(r_+ x^2 (2 \omega + 1j) - x (\omega + 1j) - \omega\Big) \nonumber \\
&\quad + A (x - 1) - 1j r_+ (2 \omega + 1j) \Big(x^2 (s - 2j \omega + 1) - 2 (s + 1) x + 2j \omega\Big) \nonumber \\
&\quad + (s + 1) (x - 2j \omega).
\end{align}

\begin{align}
F_1(a, x, m, s, \omega, r_+) &= 2 a^4 x^4 (x - 1j \omega) - 2j a^3 m x^4 + a^2 x^2 \Big(2 r_+ x^2 (-1 + 2j \omega) - (s + 3) x^2 + 2 x (s + 1j \omega + 2) - 4j \omega\Big) \nonumber \\
&\quad + 2j a m (x - 1) x^2 + (x - 1) \Big(2 r_+ x^2 (1 - 2j \omega) + (s + 1) x^2 - 2 (s + 1) x + 2j \omega\Big).
\end{align}

\begin{align}
F_2(a, x) &= a^4 x^6 - 2 a^2 (x - 1) x^4 + (x - 1)^2 x^2.
\end{align}

\begin{align}
G_0(a, u, m, s, \omega, A) &= 4 a^2 (u^2 - 1) \omega^2 - 4 a (u^2 - 1) \omega \Big((u - 1) |m - s| + (u + 1) |m + s| + 2 (s + 1) u\Big) \nonumber \\
&\quad + 4 \Big(A (u^2 - 1) + m^2 + 2 m s u + s \big((s + 1) u^2 - 1\big)\Big) \nonumber \\
&\quad - 2 (u^2 - 1) |m + s| - 2 (u^2 - 1) |m - s| (|m + s| + 1) \nonumber \\
&\quad - (u - 1)^2 |m - s|^2 - (u + 1)^2 |m + s|^2.
\end{align}

\begin{align}
G_1(a, u, m, s, \omega) &= -8 a (u^2 - 1)^2 \omega - 4 (u^2 - 1) \Big((u - 1) |m - s| + (u + 1) |m + s| + 2 u\Big).
\end{align}

\begin{align}
G_2(u) &= -4 (u^2 - 1)^2.
\end{align}


In [5]:
def F_0(a, x, m, s, omega, r_plus, A):
    """Coefficient of f itself in the radial ODE F_2 f'' + F_1 f' + F_0 f = 0.

    Term-by-term transcription of the F_0 formula, grouped by power of the spin
    ``a``. ``omega`` and ``A`` are the (complex) eigenvalues being learned; works
    for Python numbers and for tensors alike.
    """
    a2_bracket = (
        -A * x**2
        + x**2 * (4 * (r_plus + 1) * omega**2 + 2j * (r_plus + 2) * omega + 2j * s * (omega + 1j) - 2)
        + x * omega**2
        - omega**2
    )
    a1_bracket = r_plus * x**2 * (2 * omega + 1j) - x * (omega + 1j) - omega
    r_plus_bracket = x**2 * (s - 2 * 1j * omega + 1) - 2 * (s + 1) * x + 2 * 1j * omega

    # Accumulate in the same left-to-right order as the closed-form expression.
    total = -a**4 * x**2 * omega**2
    total = total - 2 * a**3 * m * x**2 * omega
    total = total + a**2 * a2_bracket
    total = total + 2 * a * m * a1_bracket
    total = total + A * (x - 1)
    total = total - 1j * r_plus * (2 * omega + 1j) * r_plus_bracket
    total = total + (s + 1) * (x - 2j * omega)
    return total

def F_1(a, x, m, s, omega, r_plus):
    """Coefficient of f' in the radial ODE F_2 f'' + F_1 f' + F_0 f = 0.

    Written as a sum of terms named by the power of the spin ``a`` they carry.
    """
    a2_bracket = (2 * r_plus * x**2 * (-1 + 2j * omega) - (s + 3) * x**2
                  + 2 * x * (s + 1j * omega + 2) - 4j * omega)
    a0_bracket = 2 * r_plus * x**2 * (1 - 2j * omega) + (s + 1) * x**2 - 2 * (s + 1) * x + 2j * omega

    a4_term = 2 * a**4 * x**4 * (x - 1j * omega)
    a3_term = 2j * a**3 * m * x**4
    a2_term = a**2 * x**2 * a2_bracket
    a1_term = 2j * a * m * (x - 1) * x**2
    a0_term = (x - 1) * a0_bracket
    return a4_term - a3_term + a2_term + a1_term + a0_term

def F_2(a, x):
    """Coefficient of f'' in the radial ODE F_2 f'' + F_1 f' + F_0 f = 0.

    Algebraically equal to x**2 * (a**2 * x**2 - x + 1)**2; the expanded form
    of the F_2 formula is evaluated here.
    """
    x_minus_1 = x - 1
    spin_part = a**4 * x**6 - 2 * a**2 * x_minus_1 * x**4
    return spin_part + x_minus_1**2 * x**2

def G_0(a, u, m, s, omega, A):
    """Coefficient of g itself in the angular ODE G_2 g'' + G_1 g' + G_0 g = 0.

    |m - s| and |m + s| are computed once with the built-in ``abs``, which works
    for plain ints and for tensors. The previous ``torch.abs(torch.tensor(...))``
    wrapping allocated new tensors for every occurrence on every call and warned
    when m or s were already tensors.
    """
    abs_m_minus_s = abs(m - s)
    abs_m_plus_s = abs(m + s)
    u_sq_minus_1 = u**2 - 1
    return (
        4 * a**2 * u_sq_minus_1 * omega**2
        - 4 * a * u_sq_minus_1 * omega * ((u - 1) * abs_m_minus_s + (u + 1) * abs_m_plus_s + 2 * (s + 1) * u)
        + 4 * (A * u_sq_minus_1 + m**2 + 2 * m * s * u + s * ((s + 1) * u**2 - 1))
        - 2 * u_sq_minus_1 * abs_m_plus_s
        - 2 * u_sq_minus_1 * abs_m_minus_s * (abs_m_plus_s + 1)
        - (u - 1)**2 * abs_m_minus_s**2
        - (u + 1)**2 * abs_m_plus_s**2
    )

def G_1(a, u, m, s, omega):
    """Coefficient of g' in the angular ODE G_2 g'' + G_1 g' + G_0 g = 0.

    |m - s| and |m + s| use the built-in ``abs`` (ints or tensors) instead of
    allocating a fresh ``torch.tensor`` for each occurrence on every call.
    """
    abs_m_minus_s = abs(m - s)
    abs_m_plus_s = abs(m + s)
    u_sq_minus_1 = u**2 - 1
    return (-8 * a * u_sq_minus_1**2 * omega
            - 4 * u_sq_minus_1 * ((u - 1) * abs_m_minus_s + (u + 1) * abs_m_plus_s + 2 * u))

def G_2(u):
    """Coefficient of g'' in the angular ODE: -4 (u^2 - 1)^2 (independent of a, m, s)."""
    u_sq_minus_1 = u**2 - 1
    return -4 * u_sq_minus_1**2

In [6]:
# Multipole index l of the mode being solved for; also used for the initial
# guess of A and in the plot titles.
l: int = 2

In [7]:
class _AnchoredMLP(nn.Module):
    """Two-hidden-layer GELU MLP whose output is pinned to 1 at input ``anchor``.

    The raw MLP output N(z) is wrapped as (exp(z - anchor) - 1) * N(z) + 1, so every
    output column equals exactly 1 at z == anchor, independent of the weights.
    Shared implementation of QNM_radial and QNM_angular, which differ only in anchor.
    """

    def __init__(self, input_size, hidden_size, output_size, anchor):
        super().__init__()
        self.anchor = anchor
        self.layer = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_size)
        )

        # Re-seeds the *global* torch RNG so every instance starts from the same
        # reproducible weights (side effect: also resets the RNG for later code).
        torch.manual_seed(42)
        for module in self.layer.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, mean=0, std=0.05)
                nn.init.constant_(module.bias, val=0.0)

    def _anchored(self, z):
        """Apply the MLP with the hard constraint output == 1 at z == anchor."""
        return (torch.exp(z - self.anchor) - 1) * self.layer(z) + 1


class QNM_radial(_AnchoredMLP):
    """Radial network: maps x (column tensor) to [real, imag] columns, both equal to 1 at x = 1."""

    def __init__(self, input_size=1, hidden_size=200, output_size=2):
        super().__init__(input_size, hidden_size, output_size, anchor=1.0)

    def forward(self, x):
        return self._anchored(x)


class QNM_angular(_AnchoredMLP):
    """Angular network: maps u (column tensor) to [real, imag] columns, both equal to 1 at u = -1."""

    def __init__(self, input_size=1, hidden_size=200, output_size=2):
        super().__init__(input_size, hidden_size, output_size, anchor=-1.0)

    def forward(self, u):
        return self._anchored(u)
    
# Trainable eigenvalues, each stored as a [real, imag] pair and recombined with
# torch.complex inside the loss. omega starts at 0.7 - 0.1j; A starts at
# l*(l+1) - 2 (presumably l(l+1) - s(s+1) for s = -2 -- TODO confirm).
# nn.Parameter already sets requires_grad=True.
omega = nn.Parameter(torch.tensor([0.7, -0.1]))
A = nn.Parameter(torch.tensor([l * (l + 1) - 2.0, 0.0]))

In [8]:
# Radial and angular networks; each constructor re-seeds torch with 42.
net1 = QNM_radial()
net2 = QNM_angular()

# Adam jointly updates both networks and the eigenvalues (omega, A); the learning
# rate is multiplied by 0.999 on every scheduler step.
trainable_params = [*net1.parameters(), *net2.parameters(), omega, A]
optimizer = optim.Adam(trainable_params, lr=0.005)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999)

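## Parameters for the Schwarzschild case (gravitational perturbations)

Lengths are measured in units where $2M = 1$, so the horizon formula below is real only for $|a| \le 1/2$.

- $a = 0$ (non-spinning black hole)
- $s = -2$ (gravitational perturbations)
- $m = 0$
- $r_+ = \frac{1 + \sqrt{1 - 4 a^2}}{2}$ (so $r_+ = 1$ for $a = 0$)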

In [9]:
# Schwarzschild configuration in units with 2M = 1: spin a, s = -2, m = 0.
a, s, m = 0.0, -2, 0

# r_plus is only real for |a| <= 1/2 in these units; fail loudly instead of
# letting np.sqrt silently produce NaN that would poison every coefficient.
if abs(a) > 0.5:
    raise ValueError(f"spin a = {a} must satisfy |a| <= 1/2 (units with 2M = 1)")
r_plus = (1 + np.sqrt(1 - 4 * a**2)) / 2

In [10]:
def _pointwise_derivative(y, z):
    """Return d(sum y)/dz, keeping the graph so higher derivatives can be taken.

    This equals the pointwise derivative dy_k/dz_k only when row k of y depends
    on row k of z alone, as for the row-wise MLPs QNM_radial / QNM_angular.
    """
    return torch.autograd.grad(y, z, grad_outputs=torch.ones_like(y), create_graph=True)[0]


def _ode_residual_modulus(c2, c1, c0, y_r, y_i, z):
    """Pointwise |c2 y'' + c1 y' + c0 y| for the complex function y = y_r + 1j * y_i of z.

    c2 is used as a real coefficient; c1 and c0 must be complex tensors (their
    .real/.imag parts are read). The complex products are expanded explicitly.
    """
    y_r_z = _pointwise_derivative(y_r, z)
    y_r_zz = _pointwise_derivative(y_r_z, z)
    y_i_z = _pointwise_derivative(y_i, z)
    y_i_zz = _pointwise_derivative(y_i_z, z)

    real = c2 * y_r_zz + c1.real * y_r_z - c1.imag * y_i_z + c0.real * y_r - c0.imag * y_i
    imag = c2 * y_i_zz + c1.real * y_i_z + c1.imag * y_r_z + c0.real * y_i + c0.imag * y_r
    return torch.abs(torch.complex(real, imag))


def loss(weight):
    """Physics-informed loss: weight * mean|radial residual| + mean|angular residual|.

    Reads the notebook globals net1, net2, x, u, omega, A, a, m, s and r_plus.
    Column 0/1 of each network output are the real/imaginary parts of f(x)
    and g(u).
    """
    # One forward pass per network (the original evaluated each network twice).
    f_out = net1(x)
    g_out = net2(u)
    f_r, f_i = f_out[:, 0:1], f_out[:, 1:2]
    g_r, g_i = g_out[:, 0:1], g_out[:, 1:2]

    omega_c = torch.complex(omega[0], omega[1])
    A_c = torch.complex(A[0], A[1])

    F_0_ = F_0(a=a, x=x, m=m, s=s, omega=omega_c, r_plus=r_plus, A=A_c)
    F_1_ = F_1(a=a, x=x, m=m, s=s, omega=omega_c, r_plus=r_plus)
    F_2_ = F_2(a=a, x=x)

    G_0_ = G_0(a=a, u=u, m=m, s=s, omega=omega_c, A=A_c)
    G_1_ = G_1(a=a, u=u, m=m, s=s, omega=omega_c)
    G_2_ = G_2(u=u)

    L_F_ = _ode_residual_modulus(F_2_, F_1_, F_0_, f_r, f_i, x)
    L_G_ = _ode_residual_modulus(G_2_, G_1_, G_0_, g_r, g_i, u)

    return weight * torch.mean(L_F_) + torch.mean(L_G_)


In [None]:
def _pinn_ringdown(omega_param):
    """Damped cosine on the global `time` grid for the current [Re, Im] omega estimate.

    omega is in the same 2M = 1 units as the Leaver reference and is converted to
    a frequency f and damping time tau for a 200 solar-mass black hole, using
    the same formulas as f_ and tau_.
    """
    mass_kg = 200 * M_sun.value
    # float64 numpy scalars: full precision, and omega_i == 0 yields inf (with a
    # RuntimeWarning) instead of raising mid-training.
    omega_r, omega_i = omega_param.detach().cpu().double().numpy()
    f = (c**3 / (4 * pi * G * mass_kg)) * omega_r
    tau = -((2 * G * mass_kg) / c**3) / omega_i
    return np.exp(-time / tau) * np.cos(2 * np.pi * f * time)


def _match_score(waveform, reference_norm):
    """Max of correlate(z-scored waveform, reference_norm, 'valid') divided by the sample count.

    For equal-length inputs 'valid' yields a single zero-lag value, i.e. the
    Pearson correlation when reference_norm is itself z-scored.
    """
    waveform_norm = (waveform - np.mean(waveform)) / np.std(waveform)
    corr = correlate(waveform_norm, reference_norm, mode='valid')
    return np.max(corr) / len(waveform)


files = []

# NOTE: the loop index must stay the global `i` -- plot_waveforms reads it for its label.
for i in range(2001):
    optimizer.zero_grad()
    loss_ = loss(10)  # radial residual weighted 10x relative to the angular one
    loss_.backward()
    optimizer.step()
    scheduler.step()

    # Snapshot every step for the first 50 steps, then every 20th step
    # (equivalent to the original `i < 50` / `i > 50 and i % 20 == 0` branches).
    if i < 50 or i % 20 == 0:
        pinn_waveform = _pinn_ringdown(omega)
        match_score = _match_score(pinn_waveform, leaver_norm)

        plot_waveforms(time, pinn_waveform, leaver_waveform, l, m, match_score)

        file = "plots/%.8i.png" % (i + 1)
        plt.savefig(file, bbox_inches='tight', pad_inches=0.1, dpi=100, facecolor="white")
        files.append(file)
        plt.close()

    if i % 100 == 0:
        # .item() prints plain float/complex values rather than tensor reprs.
        print(f"Epoch {i} | Loss: {loss_.item()} | Omega: {torch.complex(omega[0], omega[1]).item()}")

save_gif_PIL(f"results/QNM_pinn_a_{a}.gif", files, fps=30, loop=0)

Epoch 0 | Loss: 16.320016860961914 | Omega: (0.6949999928474426-0.10500000417232513j)
Epoch 100 | Loss: 0.2617379128932953 | Omega: (0.7640773057937622-0.1843014359474182j)
Epoch 200 | Loss: 0.16968567669391632 | Omega: (0.7590205669403076-0.18332539498806j)
Epoch 300 | Loss: 0.36453676223754883 | Omega: (0.7498103380203247-0.18139898777008057j)
Epoch 400 | Loss: 0.20284844934940338 | Omega: (0.7489073276519775-0.17660242319107056j)
Epoch 500 | Loss: 0.12225915491580963 | Omega: (0.7501217126846313-0.17724467813968658j)
Epoch 600 | Loss: 0.11630988866090775 | Omega: (0.749554455280304-0.17870664596557617j)
Epoch 700 | Loss: 0.12814973294734955 | Omega: (0.7476443648338318-0.17680999636650085j)
Epoch 800 | Loss: 0.09600942581892014 | Omega: (0.7489701509475708-0.17823877930641174j)
Epoch 900 | Loss: 0.07832834124565125 | Omega: (0.7471605539321899-0.17749987542629242j)
Epoch 1000 | Loss: 0.07464103400707245 | Omega: (0.748440682888031-0.1783640831708908j)
Epoch 1100 | Loss: 0.0754716768

In [12]:
# Learned omega after training (Leaver reference: 0.74734 - 0.17793j).
torch.complex(*omega)

tensor(0.7459-0.1781j, grad_fn=<ComplexBackward0>)