# In [None]:
import cupy as cp
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from scipy import optimize

from cupyx.scipy.special import ellipkinc


class GPU_SchwarzschildRayTracer:
    """Render the lensed sky around a Schwarzschild black hole on the GPU.

    Every image pixel is treated as a light ray traced back from the
    observer with impact parameter ``b``. Its periapsis ``r_m(b)`` comes
    from a table precomputed on the CPU, the exact deflection angle is
    evaluated on the GPU with an incomplete elliptic integral, and the
    deflected direction is looked up in the background image. Rays with
    ``b <= b_crit`` are captured by the hole and painted black (the shadow).
    """

    def __init__(self, beta=0.0, bh_mass=1.0, background_path="Imágenes/Background_1.jpg"):
        """Load the background image and precompute the r_m(b) table.

        Args:
            beta: Yaw offset (radians) subtracted from the sky longitude.
            bh_mass: Black-hole mass M; the Schwarzschild radius is 2 M.
            background_path: RGB background image. Assumed equirectangular
                (longitude along columns, colatitude along rows) -- confirm
                for custom backgrounds.
        """
        # Background image; the context manager releases the file handle.
        with Image.open(background_path) as img:
            self.stars_img = np.array(img.convert("RGB"))
        self.stars_height, self.stars_width = self.stars_img.shape[:2]

        # The rendered image has the same resolution as the background.
        self.image_size_x = self.stars_width
        self.image_size_y = self.stars_height

        # Physical parameters.
        self.M = bh_mass
        self.rs = 2 * bh_mass

        # Observer distance; render() uses it as the pinhole-camera distance.
        self.r_start = 100 * self.rs
        self.r_max = self.r_start * 1.1

        # Half-extent of the image plane, in impact-parameter units.
        # NOTE(review): 20 * rs / M evaluates to 40 for every mass, so the
        # field of view does not scale with M -- confirm this is intended.
        self.b_max_x = 20 * self.rs / self.M
        self.b_max_y = self.b_max_x / self.image_size_x * self.image_size_y
        # Largest impact parameter on the image (the corners).
        self.b_max = np.sqrt(self.b_max_x**2 + self.b_max_y**2)

        self.beta = beta

        # Periapsis table r_m(b), computed on the CPU.
        self._precompute_rm_table()

    @property
    def b_crit(self):
        """Critical impact parameter 3*sqrt(3)/2 * rs (= 3*sqrt(3) M).

        Rays with b <= b_crit cross the photon sphere at r = 1.5 rs and are
        captured; b_crit is the radius of the shadow on the image plane.
        """
        return 1.5 * np.sqrt(3.0) * self.rs

    def _closest_approach(self, b):
        """Return the periapsis r_m of an escaping ray with impact parameter b.

        Solves the turning-point condition 1/b**2 = 1/r**2 - rs/r**3 for its
        largest (physical) root. Requires b >= b_crit. The bracket
        [1.5 rs, b] is always valid: at the photon sphere the residual is
        1/b**2 - 1/b_crit**2 <= 0, and at r = b it equals rs/b**3 > 0.
        """
        def turning_point(r):
            return 1.0 / b**2 - 1.0 / r**2 + self.rs / r**3

        sol = optimize.root_scalar(turning_point, bracket=[1.5 * self.rs, b], method="brentq")
        return sol.root

    def _precompute_rm_table(self, N=1000):
        """Tabulate r_m(b) on N points for escaping rays (b > b_crit).

        The grid starts just above b_crit, where r_m -> 1.5 rs and the
        deflection diverges. It ends at the image-plane corner b_max, so
        every unshadowed pixel is covered. Results are stored in
        ``b_vals_cpu`` and ``rm_vals_cpu`` for ``cp.interp`` in ``render``.
        """
        b_lo = self.b_crit * (1.0 + 1e-4)
        # Keep the grid increasing (cp.interp needs that) even when the
        # whole image lies inside the shadow.
        b_hi = max(self.b_max, 2.0 * b_lo)
        b_vals = np.linspace(b_lo, b_hi, N)

        rm_vals = np.array(
            [self._closest_approach(b) for b in tqdm(b_vals, desc="Precomputando rm(b)", unit="val")]
        )

        self.b_vals_cpu = b_vals
        self.rm_vals_cpu = rm_vals
        # Kept for compatibility; render() calls cp.interp directly.
        self.rm_interp_gpu = cp.interp

    def deflexion_gpu(self, rm):
        """Exact Schwarzschild deflection angle for periapsis ``rm`` (GPU).

        Evaluates alpha = 4 * sqrt(rm / s) * F(phi | m) - pi, with:

        - s = sqrt((rm - rs)(rm + 3 rs))
        - m = (s - rm + 3 rs) / (2 s)
        - sin(phi) = sqrt(2 s / (3 rm - 3 rs + s))

        The total azimuth sweep is reduced modulo 2 pi before subtracting
        pi, so the result lies in [-pi, pi). Rays that loop around the hole
        map to the equivalent direction. Requires rm > 1.5 rs; at the photon
        sphere m -> 1 and F diverges.
        """
        s = cp.sqrt((rm - self.rs) * (rm + 3 * self.rs))
        m = (s - rm + 3 * self.rs) / (2 * s)
        # Clip guards against round-off pushing the arcsin argument past 1.
        arg = cp.clip(cp.sqrt(2 * s / (3 * rm - 3 * self.rs + s)), -1.0, 1.0)
        varphi = cp.arcsin(arg)
        return (4 * cp.sqrt(rm / s) * ellipkinc(varphi, m)) % (2 * cp.pi) - cp.pi

    def spherical_to_pixel(self, delta, alpha):
        """Map sky coordinates to background pixel indices.

        Args:
            delta: Longitude in radians; [-pi, pi) spans the full width.
                Values outside that range wrap around.
            alpha: Colatitude (angle from the north pole) in radians;
                [0, pi] spans the top row to the bottom row.

        Returns:
            ``(x_pix, y_pix)`` int32 column and row indices, usable as
            ``stars_img[y_pix, x_pix]``. Works for NumPy and CuPy arrays.
        """
        u = delta / (2 * np.pi) + 0.5
        v = alpha / np.pi
        # Floor (not truncation toward zero) so negative longitudes wrap
        # onto the correct column. Columns wrap in longitude; rows are
        # clamped so the south pole (v == 1) stays on the last row.
        x_pix = ((u * self.stars_width) // 1).astype(np.int32) % self.stars_width
        y_pix = ((v * self.stars_height) // 1).astype(np.int32).clip(0, self.stars_height - 1)
        return x_pix, y_pix

    def render(self, save_path="../Figuras/Sombra_GPU.png"):
        """Trace one ray per pixel and save the lensed image to ``save_path``.

        A pinhole observer at distance ``r_start`` looks along the optical
        axis through the hole. Each pixel is traced as follows:

        - Its image-plane offset (x, y) gives the impact parameter
          b = |(x, y)| and the azimuth psi.
        - The ray leaves the observer at eps = atan(b / r_start) from the axis.
        - It is bent back toward the hole by the deflection alpha(b).
        - Its final direction therefore makes the signed angle eps - alpha
          with the axis, in the same azimuthal plane.

        alpha is the infinity-to-infinity deflection, an approximation for
        a finite observer distance.
        """
        w, h = self.image_size_x, self.image_size_y

        # Ray grid with shape (h, w): rows follow image y and columns follow
        # image x, the same (row, col) layout as the saved picture.
        cols, rows = cp.meshgrid(cp.arange(w), cp.arange(h), indexing="xy")
        x = (cols - w / 2) / (w / 2) * self.b_max_x
        y = (rows - h / 2) / (h / 2) * self.b_max_y  # positive downwards

        b = cp.sqrt(x**2 + y**2)
        psi = cp.arctan2(y, x)

        # Rays with b <= b_crit are captured (shadow). Give them an in-table
        # value so the maths below stays finite; they are blacked out later.
        mask = b > self.b_crit
        b_valid = cp.where(mask, b, float(self.b_vals_cpu[0]))

        # Interpolate r_m(b) on the GPU and evaluate the deflection.
        rm = cp.interp(b_valid, cp.asarray(self.b_vals_cpu), cp.asarray(self.rm_vals_cpu))
        alpha = self.deflexion_gpu(rm)

        # Signed angle between the outgoing ray and the optical axis.
        theta_out = cp.arctan2(b_valid, self.r_start) - alpha

        # Outgoing direction in the camera frame (x right, y down, z forward).
        sin_out = cp.sin(theta_out)
        x_cam = sin_out * cp.cos(psi)
        y_cam = sin_out * cp.sin(psi)
        z_cam = cp.cos(theta_out)

        # Camera -> sky frame: forward becomes +Y, image-down becomes -Z.
        x_sky, y_sky, z_sky = x_cam, z_cam, -y_cam

        # Longitude (yaw-shifted by beta, wrapped to [-pi, pi)) and colatitude.
        longitude = (cp.arctan2(y_sky, x_sky) - self.beta + cp.pi) % (2 * cp.pi) - cp.pi
        colatitude = cp.arccos(cp.clip(z_sky, -1.0, 1.0))

        x_pix, y_pix = self.spherical_to_pixel(longitude, colatitude)

        # Sample the background on the CPU; images are indexed [row, col].
        colors = self.stars_img[y_pix.get(), x_pix.get()]

        # Paint the shadow black.
        colors[~mask.get()] = 0

        # Save the image.
        plt.imshow(colors.astype(np.uint8))
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(save_path, dpi=300)
        plt.close()

        print(f"Imagen guardada en: {save_path}")


if __name__ == "__main__":
    # Build the tracer (loads the background, precomputes r_m(b)) and render
    # only when run as a script or notebook cell, not on import.
    rt = GPU_SchwarzschildRayTracer(beta=0.0, bh_mass=1.0)
    rt.render()
