In [None]:
import numpy as np
import plotly.graph_objects as go

# Parameters for the spinning black holes
mass = 1.0
spin = 3.4  # radial wave number: multiplies the distance r in the wave phase
grid_size = 23  # sample points per axis of the cubic grid
speed_of_light = 299792458  # m/s (SI units); note that `mass` is dimensionless

# Positions of the two spinning black holes with parallel axes
position1 = np.array([-1.0, 0.0, 0.0])
position2 = np.array([1.0, 0.0, 0.0])

# Cubic sampling grid spanning [-4, 4] on every axis. np.meshgrid defaults to
# 'xy' indexing, so X varies along axis 1, Y along axis 0 and Z along axis 2.
x = np.linspace(-4, 4, grid_size)
y = np.linspace(-4, 4, grid_size)
z = np.linspace(-4, 4, grid_size)
X, Y, Z = np.meshgrid(x, y, z)

# Animation timeline: num_frames phase samples from 0 to 2*pi inclusive.
num_frames = 50
t_values = np.linspace(0, 2 * np.pi, num_frames)
delay_factor = 0.5  # scales the per-frame time shift applied when building frames

# Calculate the gravitational wave vectors at each point in the grid
def calculate_gravitational_wave_vector(x, y, z, t, position, spin_rate=None):
    """Return the vector components of a simple wave pattern around one source.

    This is a simple analytic pattern chosen for visualisation. A scalar
    amplitude ``sin(theta + k * r - t)`` is computed at each sample point.
    It scales the unit vector that points toward increasing elevation angle
    about ``position``.

    Parameters
    ----------
    x, y, z : float or numpy.ndarray
        Sample coordinates. Arrays must be broadcast-compatible.
    t : float
        Time / phase offset of the wave.
    position : sequence of 3 floats
        Centre of the source. All three components are honoured.
    spin_rate : float, optional
        Radial wave number ``k``. When omitted, the module-level ``spin``
        constant is used, looked up at call time.

    Returns
    -------
    tuple
        ``(direction_x, direction_y, direction_z)``, each with the broadcast
        shape of ``x``, ``y`` and ``z``.
    """
    k = spin if spin_rate is None else spin_rate

    # Offsets from the source along every axis. Measuring y and z from the
    # global origin is only correct for sources that lie on the x axis.
    dx = x - position[0]
    dy = y - position[1]
    dz = z - position[2]

    r = np.sqrt(dx**2 + dy**2 + dz**2)      # distance from the source
    theta = np.arctan2(dy, dx)              # azimuth in the horizontal plane
    phi = np.arctan2(dz, np.sqrt(dx**2 + dy**2))  # elevation in [-pi/2, pi/2]

    amplitude = np.sin(theta + k * r - t)
    # (-sin(phi)cos(theta), -sin(phi)sin(theta), cos(phi)) is the derivative of
    # the radial unit vector with respect to the elevation angle phi.
    direction_x = -np.sin(phi) * np.cos(theta) * amplitude
    direction_y = -np.sin(phi) * np.sin(theta) * amplitude
    direction_z = np.cos(phi) * amplitude

    return direction_x, direction_y, direction_z

# Create the cone trace for the vector field. It is seeded with the superposed
# field of both black holes at the first time sample, using the same data and
# styling as the first animation frame. This way the scene is populated before
# Play is pressed, and starting playback does not change the cone appearance.
U0, V0, W0 = (
    field1 + field2
    for field1, field2 in zip(
        calculate_gravitational_wave_vector(X, Y, Z, t_values[0], position1),
        calculate_gravitational_wave_vector(X, Y, Z, t_values[0], position2),
    )
)
fig = go.Figure(data=go.Cone(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    u=U0.flatten(),
    v=V0.flatten(),
    w=W0.flatten(),
    sizemode="absolute",
    sizeref=0.95,  # cone size, in the same units as the vector field
    anchor="tail",  # each cone starts at its grid point
    colorscale='edge',  # same colorscale as the animation frames
    showscale=True,  # show the colour bar
    opacity=0.2,  # opacity of the first animation frame
))

# Mark both black holes with black dots, drawn as a single Scatter3d trace.
fig.add_trace(
    go.Scatter3d(
        x=[centre[0] for centre in (position1, position2)],
        y=[centre[1] for centre in (position1, position2)],
        z=[centre[2] for centre in (position1, position2)],
        mode='markers',
        marker=dict(color='black', size=5, opacity=1),
    )
)

# Each rotation axis is a vertical segment reaching axis_length above and
# below its black hole: red for the first hole, blue for the second.
axis_length = 2.5
fig.add_traces([
    go.Scatter3d(
        x=[centre[0]] * 2,
        y=[centre[1]] * 2,
        z=[centre[2] - axis_length, centre[2] + axis_length],
        mode='lines',
        line=dict(color=colour, width=1),
    )
    for centre, colour in ((position1, 'red'), (position2, 'blue'))
])

# Title, axis labels and an equal aspect ratio for the 3-D scene.
fig.update_layout(
    title='Gravitational Wave Vectors and Axis of Rotation around Two Spinning Black Holes with Parallel Axes',
    scene={
        'xaxis_title': 'X',
        'yaxis_title': 'Y',
        'zaxis_title': 'Z',
        'aspectratio': {'x': 1, 'y': 1, 'z': 1},
    },
)

# Build one animation frame per time sample. Each frame carries a single Cone
# trace. Plotly applies it to trace 0 (the vector field) and leaves the marker
# and axis traces unchanged.
frames = []

for i, t in enumerate(t_values):
    # Time shift shared by both black holes.
    # NOTE(review): `mass` is dimensionless but is divided by the SI speed of
    # light (m/s), so this shift is at most ~1e-8 and has no visible effect.
    # It is also identical for both sources, so it does not model a relative
    # propagation delay. It is kept unchanged to preserve the output; revisit
    # the units if a real delay is intended.
    t_shifted = t - (delay_factor * i / num_frames) * 2 * np.pi * mass / speed_of_light

    U1, V1, W1 = calculate_gravitational_wave_vector(X, Y, Z, t_shifted, position1)
    U2, V2, W2 = calculate_gravitational_wave_vector(X, Y, Z, t_shifted, position2)

    # Opacity depends only on the frame index. It ramps linearly from 0.2 on
    # the first frame to just under 1.0 on the last.
    opacity = (i / num_frames) * 0.8 + 0.2

    # Superpose the contributions from both black holes.
    U = U1 + U2
    V = V1 + V2
    W = W1 + W2

    frames.append(go.Frame(data=[go.Cone(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        u=U.flatten(),
        v=V.flatten(),
        w=W.flatten(),
        sizemode="absolute",
        sizeref=0.95,  # cone size, in the same units as the vector field
        anchor="tail",  # each cone starts at its grid point
        colorscale='edge',  # Plotly's cyclical 'edge' scale, applied to the vector norm
        showscale=True,  # show the colour bar
        opacity=opacity,
    )]))

fig.frames = frames

# Play/Pause controls.
# - Play advances one frame every 50 ms, resumes from the current frame and
#   redraws each frame.
# - Pause sends an immediate 'animate' call with [None], which is Plotly's
#   idiom for stopping playback.
fig.update_layout(updatemenus=[{
    'type': 'buttons',
    'showactive': False,
    'x': 0.1,
    'y': 0,
    'xanchor': 'right',
    'yanchor': 'top',
    'buttons': [
        {
            'label': 'Play',
            'method': 'animate',
            'args': [
                None,
                {
                    'frame': {'duration': 50, 'redraw': True},
                    'fromcurrent': True,
                    'transition': {'duration': 0},
                },
            ],
        },
        {
            'label': 'Pause',
            'method': 'animate',
            'args': [
                [None],
                {
                    'frame': {'duration': 0, 'redraw': False},
                    'mode': 'immediate',
                    'transition': {'duration': 0},
                },
            ],
        },
    ],
}])

# Render the figure.
fig.show()
