Kurzgesagt Star Shader Development

Stellar eruptions

In [14]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Circle
from matplotlib.colors import LinearSegmentedColormap
from IPython.display import HTML

In [None]:
class ExplosionRing:
    def __init__(self, alpha = 2, beta = 1, max_radius = 1.0, explosion_center = (0, 0)):
        """Configure the ring geometry and build the two colour maps.

        Args:
            alpha: Exponent applied to the eased time for the inner (blast) radius.
            beta: Exponent applied to the eased time for the outer (falloff) radius.
            max_radius: Scale factor shared by both radii.
            explosion_center: ``(x, y)`` centre stored for later drawing.

        Sets ``self.cmap`` (dark red -> orange -> white, used for the blast ring)
        and ``self.external_cmap`` (dark blue -> light blue -> white, used for the
        external displacement ring). Both maps share the same colour-stop positions.
        """
        self.alpha, self.beta = alpha, beta
        self.max_radius = max_radius
        self.explosion_center = explosion_center

        # Shared stop positions along [0, 1] for both gradients.
        stops = [0, 0.15, 0.25, 0.35, 0.75, 0.95, 1.0]

        blast_palette = ["#500000", "#960000", "#FF1E00", "#FF6400", "#FFC832", "#FFF0C8", "#FFFFFF"]
        self.cmap = LinearSegmentedColormap.from_list(
            "explosion",
            [(stop, colour) for stop, colour in zip(stops, blast_palette)],
        )

        halo_palette = ["#001020", "#003050", "#0060A0", "#40A0E0", "#80C0FF", "#C0E0FF", "#FFFFFF"]
        self.external_cmap = LinearSegmentedColormap.from_list(
            "external_displacement",
            [(stop, colour) for stop, colour in zip(stops, halo_palette)],
        )

    def log_time(self, time):
        return np.log(1 + time * 9) / np.log(10)

    def get_radii(self, time):
        log_time = self.log_time(time)
        blast_radius = self.max_radius * (log_time ** self.alpha)
        falloff_radius = self.max_radius * (log_time ** self.beta)

        return blast_radius, falloff_radius

    def create_combined_rings(self, size = 512, time = 0.5):
        ring_mask, gradient_values = self.create_ring_mask(size, time)

        external_mask, gradient_external = self.create_external_ring(size, time)

        explosion_image = np.zeros((size, size, 4))

        if np.any(ring_mask):
            colours = self.cmap(gradient_values[ring_mask])

        if np.any(external_mask):
            external_colours = self.external_cmap(gradient_external[external_mask])
            external_colours[:, 3] *= 0.6   # Reduce opacity
            
            for i in range(4):
                explosion_image[external_mask, i] = (
                    explosion_image[external_mask, i] * (1 - external_colours[:, 3]) + external_colours[:, i] * external_colours[:, 3]
                )

        return explosion_image, ring_mask, external_mask

    def create_ring_mask(self, size = 512, time = 0.5):
        abscissa = np.linspace(-1.5, 1.5, size)
        ordinate = np.linspace(-1.5, 1.5, size)
        xAxis, yAxis = np.meshgrid(abscissa, ordinate)
        circumference = np.sqrt(xAxis ** 2 + yAxis ** 2)
        blast_radius, falloff_radius = self.get_radii(time)
        ring_mask = (circumference >= blast_radius) & (circumference <= falloff_radius)

        gradient_values = np.zeros_like(circumference)
        ring_region = ring_mask

        if np.any(ring_region):
            ring_distances = (circumference[ring_region] - blast_radius) / (falloff_radius - blast_radius) if falloff_radius > blast_radius else 0

            gradient_values[ring_region] = ring_distances

        return ring_mask, gradient_values

    def plot_static(self, time = 0.5, size = 512, figsize = (8, 8), show_external = True):
        """Render a single frame of the eruption as a matplotlib figure.

        Args:
            time: Normalised animation time of the frame.
            size: Pixels per side of the rasterised ring image.
            figsize: Figure size, forwarded to ``plt.subplots``.
            show_external: If True, draw the image from ``create_combined_rings``
                (blast ring plus external ring); otherwise draw only the blast ring
                from ``create_ring_mask``.

        Returns:
            ``(fig, ax)`` of the newly created figure. The inner (blast) and outer
            (falloff) radii are outlined with circles centred on
            ``self.explosion_center``.
        """
        if show_external:
            explosion_image, _, _ = self.create_combined_rings(size, time)
        else:
            ring_mask, gradient_values = self.create_ring_mask(size, time)
            explosion_image = np.zeros((size, size, 4))
            if np.any(ring_mask):
                explosion_image[ring_mask] = self.cmap(gradient_values[ring_mask])

        fig, ax = plt.subplots(1, 1, figsize = figsize, facecolor = "black")
        ax.set_facecolor("black")

        # create_ring_mask samples row 0 at y = -1.5 (meshgrid over an increasing
        # linspace), so the image must be drawn with origin="lower" to agree with
        # the axis coordinates and the Circle patches below.
        ax.imshow(
            explosion_image,
            extent = [-1.5, 1.5, -1.5, 1.5],
            origin = "lower",
            interpolation = "bilinear",
            alpha = 0.9,
        )

        blast_radius, falloff_radius = self.get_radii(time)

        if blast_radius > 0:
            ax.add_patch(Circle(self.explosion_center, blast_radius, fill = False, color = "#FF4500", linewidth = 2))

        if falloff_radius > 0:
            ax.add_patch(Circle(self.explosion_center, falloff_radius, fill = False, color = "#FFAA00", linewidth = 2))

        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-1.5, 1.5)
        ax.set_aspect("equal")
        ax.axis("off")

        # ".2f" rather than " 0.2f": the sign-space flag produced a double space after "=".
        ax.set_title(f"Stellar eruption - time = {time:.2f}", color = "white", fontsize = 14)

        plt.tight_layout()

        return fig, ax