# Simulación de un agujero negro de Kerr

## No terminado

Este cuaderno todavía no está terminado: el código que sigue es un trabajo en curso.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
from scipy.interpolate import interp1d
from tqdm import tqdm
import gc
from torchdiffeq import odeint

class KerrRayTracerBatch:
    """Batched geodesic integrator and thin-disk renderer for a Kerr black hole.

    Work in progress. Units are fixed by ``G = c = 1`` and the conserved
    energy is fixed to ``E = 1`` in ``__init__``.

    Attributes set by ``__init__``:
        G, c, E: constants, all ``1.0``.
        M: black-hole mass (``bh_mass``).
        a: spin parameter (``bh_a``).
        Lz: axial angular momentum shared by every ray.
        Q: Carter constant shared by every ray.
        rs: ``2 * G * M / c**2``.
        b_max: ``15 * rs``; half-width of the image plane used by ``render``.
        device: torch device used by ``integrate``.
    """

    # Inner and outer radius of the rendered disk, in units of ``rs``.
    _DISK_RMIN_RS = 4.32
    _DISK_RMAX_RS = 9.16

    def __init__(self, bh_mass=1.0, bh_a=0, Lz=1, device=None, Q=0.0):
        """Store the black-hole and orbit parameters.

        Args:
            bh_mass: black-hole mass ``M``.
            bh_a: spin parameter ``a``.
            Lz: axial angular momentum used by ``system``.
            device: torch device (or device string); when ``None``, CUDA is
                used if available, otherwise the CPU.
            Q: Carter constant used by ``system``. ``system`` needs this
                attribute, so it is always defined (default ``0.0``).
        """
        self.G = 1.0
        self.c = 1.0
        self.E = 1.0
        self.M = bh_mass
        self.a = bh_a
        self.Lz = Lz
        self.Q = Q
        self.rs = 2 * self.G * self.M / self.c**2
        self.b_max = 15 * self.rs
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

    def system(self, tau, y):
        """Return d(state)/d(tau) for a batch of geodesics.

        The equations are the first-order Kerr geodesic equations in
        Boyer-Lindquist coordinates with ``pr = Sigma * dr/dtau`` and
        ``ptheta = Sigma * dtheta/dtau``, so that ``pr**2 = R(r)`` and
        ``ptheta**2 = Theta(theta)`` with

            R     = (E (r^2 + a^2) - a Lz)^2 - Delta (r^2 + (Lz - a E)^2 + Q)
            Theta = Q - cos^2(theta) (a^2 (1 - E^2) + Lz^2 / sin^2(theta))

        and therefore ``dpr/dtau = R'(r) / (2 Sigma)`` and
        ``dptheta/dtau = Theta'(theta) / (2 Sigma)``.

        NOTE(review): the ``r^2`` term in ``R`` and the ``a^2 (1 - E^2)`` term
        in ``Theta`` look like the unit-rest-mass (timelike) form; photons
        would presumably drop them -- confirm the intended geodesic type.

        Args:
            tau: integration parameter (unused; the system is autonomous).
            y: tensor of shape ``(N, 6)`` with columns
                ``[t, r, theta, phi, pr, ptheta]``.

        Returns:
            Tensor of shape ``(N, 6)`` with the derivatives in the same order.
        """
        r = y[:, 1]
        theta = y[:, 2]
        pr = y[:, 4]
        ptheta = y[:, 5]

        M, a, E, Lz, Q = self.M, self.a, self.E, self.Lz, self.Q

        sin_theta = torch.sin(theta)
        cos_theta = torch.cos(theta)
        sin2 = sin_theta ** 2
        cos2 = cos_theta ** 2

        Delta = r**2 - 2 * M * r + a**2
        Sigma = r**2 + a**2 * cos2

        term1 = r**2 + a**2
        P = E * term1 - a * Lz

        dt_dtau = (term1 * P / Delta - a * (a * E * sin2 - Lz)) / Sigma
        dphi_dtau = (a * P / Delta - (a * E - Lz / sin2)) / Sigma

        dr_dtau = pr / Sigma
        dtheta_dtau = ptheta / Sigma

        # dR/dr of the radial potential given in the docstring.
        K = r**2 + (Lz - a * E)**2 + Q
        dR_dr = 4 * r * E * P - (2 * r - 2 * M) * K - 2 * Delta * r

        # dTheta/dtheta: d/dtheta[cos^2/sin^2] = -2 cos/sin^3, so the Lz term
        # scales as cos/sin^3 (not cos/sin).
        dTheta_dtheta = 2 * cos_theta * (
            a**2 * (1 - E**2) * sin_theta + Lz**2 / sin_theta**3
        )

        dpr_dtau = dR_dr / (2 * Sigma)
        dptheta_dtau = dTheta_dtheta / (2 * Sigma)

        return torch.stack(
            [dt_dtau, dr_dtau, dtheta_dtau, dphi_dtau, dpr_dtau, dptheta_dtau],
            dim=1,
        )

    def integrate(self, y0, t_span):
        """Integrate ``system`` with torchdiffeq's ``dopri5`` solver.

        Args:
            y0: initial state, shape ``(N, 6)`` for a batch or ``(6,)`` for a
                single geodesic.
            t_span: 1-D tensor of parameter values at which to sample.

        Returns:
            The tensor returned by ``odeint``: the state at every entry of
            ``t_span`` (leading dimension ``len(t_span)``), on ``self.device``.
        """
        y0 = y0.to(self.device)
        t_span = t_span.to(self.device)

        def func(t, y):
            # ``system`` expects a batch dimension; add and strip it for a
            # single un-batched state.
            if y.dim() == 1:
                return self.system(t, y.unsqueeze(0)).squeeze(0)
            return self.system(t, y)

        return odeint(func, y0, t_span, method='dopri5')

    def _disk_rgb(self, r, phi, gamma_deg):
        """Return the RGB shade of the disk at radius ``r`` and azimuth ``phi``.

        ``f`` is a ring pattern fading linearly from the inner to the outer
        disk radius; ``f_dop`` is an azimuth- and radius-dependent brightness
        factor. Channel values are not clamped here.
        """
        rs = self.rs
        Rmin = self._DISK_RMIN_RS * rs
        Rmax = self._DISK_RMAX_RS * rs

        argumento = (r - Rmin) / (Rmax - Rmin)
        f = np.abs(np.cos(argumento * 14 * np.pi)) * (1 - argumento)
        f_dop = Rmin*(1-0.75*np.cos(phi)*np.cos(gamma_deg/180*np.pi))*self.c*np.sqrt(self.rs/r)*(r/(r-rs))

        Red = int((220 * f + 30 * f_dop + 5))
        Green = int(50 * f + 8 * f_dop + 5)
        Blue = int(20 * f_dop * f)
        return (Red, Green, Blue)

    def color(self, gamma_deg, x1, y1, z1):
        """Colour of the first trajectory sample that lies on the disk.

        A sample is on the disk when its cylindrical radius
        ``sqrt(x1**2 + z1**2)`` is strictly between the inner and outer disk
        radii and its elevation angle above the ``y1 = 0`` plane is below
        0.01 rad in absolute value.

        Args:
            gamma_deg: camera inclination in degrees (used for shading).
            x1, y1, z1: 1-D arrays with the trajectory samples.

        Returns:
            ``(R, G, B)`` tuple of ints; ``(0, 0, 0)`` when no sample hits.
        """
        Rmin = self._DISK_RMIN_RS * self.rs
        Rmax = self._DISK_RMAX_RS * self.rs

        rho = np.sqrt(x1**2 + z1**2)
        in_ring = (rho > Rmin) & (rho < Rmax)
        # arctan2 avoids the 0/0 division that arctan(y1 / rho) performs for
        # samples on the axis.
        near_plane = np.abs(np.arctan2(y1, rho)) < 0.01

        hits = np.flatnonzero(in_ring & near_plane)
        if hits.size == 0:
            return (0, 0, 0)

        k = hits[0]
        return self._disk_rgb(rho[k], np.arctan2(z1[k], x1[k]), gamma_deg)

    def color2(self, x, y, gamma_deg):
        """Colour of the disk seen along a straight line (no light bending).

        The image-plane point ``(x, y)`` is projected onto the tilted disk
        plane via ``z = y / tan(gamma)``.

        Args:
            x, y: image-plane coordinates (scalars).
            gamma_deg: camera inclination in degrees; values with
                ``|gamma_deg| < 0.1`` are replaced by ``0.1`` to avoid the
                division by ``tan(0)``.

        Returns:
            ``(R, G, B)`` tuple of ints; ``(0, 0, 0)`` outside the disk.
        """
        if np.abs(gamma_deg) < 0.1:
            gamma_deg = 0.1

        Rmin = self._DISK_RMIN_RS * self.rs
        Rmax = self._DISK_RMAX_RS * self.rs

        z = y / np.tan(gamma_deg / 180 * np.pi)
        r = np.sqrt(x**2 + y**2 + z**2)
        if not (Rmin < r < Rmax):
            return (0, 0, 0)

        phi = np.arctan2(np.sqrt(z**2 + y**2), x)
        return self._disk_rgb(r, phi, gamma_deg)

    def interpolacion(self, tau_sub, r_sub, phi_sub, theta):
        """Resample a trajectory on 300 points and convert it to Cartesian.

        Cubic interpolation is used when at least 4 samples are available,
        linear interpolation otherwise (at least 2 samples are required).

        Args:
            tau_sub: parameter values of the samples.
            r_sub, phi_sub: radius and angle at each sample.
            theta: scalar angle applied to every point.

        Returns:
            ``(x, y, z)`` arrays of length 300 with
            ``x = r cos(theta) sin(phi)``, ``y = r sin(theta) sin(phi)``,
            ``z = r cos(phi)``.
        """
        kind = 'cubic' if len(tau_sub) >= 4 else 'linear'
        interp_r = interp1d(tau_sub, r_sub, kind=kind)
        interp_phi = interp1d(tau_sub, phi_sub, kind=kind)

        tau_fino = np.linspace(tau_sub[0], tau_sub[-1], num=300)
        r_interp = interp_r(tau_fino)
        phi_interp = interp_phi(tau_fino)

        x_interp = r_interp * np.cos(theta) * np.sin(phi_interp)
        y_interp = r_interp * np.sin(theta) * np.sin(phi_interp)
        z_interp = r_interp * np.cos(phi_interp)

        return x_interp, y_interp, z_interp

    def _mirror_pair_colors(self, points, R, gamma_deg):
        """Rotate ``points`` (shape ``(N, 3)``) by ``R`` and colour both x-mirrors."""
        rotated = points @ R.T
        x1, y1, z1 = rotated[:, 0], rotated[:, 1], rotated[:, 2]
        return (
            self.color(gamma_deg, x1, y1, z1),
            self.color(gamma_deg, -x1, y1, z1),
        )

    def render(self, N_x=100, N_y=100, gamma_deg=-10, save_path=None, t_span=None):
        """Render the disk and save two PNG images.

        Only the upper-left quadrant of the image is traced; the other three
        quadrants are filled by mirroring. Two images are written: the traced
        image, and a side-by-side composite of the traced image and the
        straight-line image produced by ``color2``.

        NOTE(review): the initial momenta are small ad-hoc offsets and every
        ray shares ``self.Lz`` and ``self.Q``; they are presumably not yet
        consistent with the potentials used in ``system``. The post-processing
        also treats ``(r, phi)`` as a planar orbit tilted by the pixel's
        image-plane angle. Confirm before trusting the traced image.

        Args:
            N_x, N_y: image width and height in pixels.
            gamma_deg: camera inclination in degrees.
            save_path: path of the traced image. The composite is saved next
                to it with ``_2`` inserted before the extension. When
                ``None``, ``../Multimedia/Imagen_disco_de_acreción.png`` is
                used. Missing directories are created.
            t_span: 1-D tensor of parameter values passed to ``integrate``.
                Defaults to ``torch.linspace(0, 100, 200)``, a placeholder
                range that has not been tuned.
        """
        ancho, alto = N_x, N_y
        half_w, half_h = ancho // 2, alto // 2

        imagen = Image.new('RGB', (ancho, alto), color=(0, 0, 0))
        pixeles = imagen.load()
        imagen2 = Image.new('RGB', (ancho, alto), color=(0, 0, 0))
        pixeles2 = imagen2.load()

        # Fixed camera position.
        r0 = 25 * self.rs
        theta0 = torch.pi / 2
        phi0 = 0.
        t0 = 0.

        # Normalised image-plane grid; only the left half of the columns is
        # traced because the right half is obtained by mirroring.
        ys = torch.linspace(-1, 0, half_h)
        xs = torch.linspace(-1, 1, N_x)
        yv, xv = torch.meshgrid(ys, xs, indexing='ij')
        yv = yv[:, :half_w]
        xv = xv[:, :half_w]
        xv_np = xv.numpy()
        yv_np = yv.numpy()

        # Small initial momentum offsets, one ray per traced pixel.
        scale = 0.05
        pr0 = scale * xv.reshape(-1)
        ptheta0 = scale * yv.reshape(-1)

        t0s = torch.full_like(pr0, t0)
        r0s = torch.full_like(pr0, r0)
        theta0s = torch.full_like(pr0, theta0)
        phi0s = torch.full_like(pr0, phi0)

        # Shape (number of rays, 6): [t, r, theta, phi, pr, ptheta].
        y0_batch = torch.stack([t0s, r0s, theta0s, phi0s, pr0, ptheta0], dim=1)

        if t_span is None:
            t_span = torch.linspace(0.0, 100.0, 200)
        tau = t_span.detach().cpu().numpy()

        if y0_batch.shape[0] > 0:
            # ``integrate`` moves the batch to ``self.device``.
            sol_batch = self.integrate(y0_batch, t_span).cpu()
            # State columns 1 and 3 are r and phi. ``contiguous`` copies the
            # column so the full solution can really be released below.
            r_vals = sol_batch[:, :, 1].contiguous().numpy()
            phi_vals = sol_batch[:, :, 3].contiguous().numpy()
            del sol_batch
            gc.collect()

        R = self.matriz_rotacion(gamma_deg=gamma_deg)

        with tqdm(total=half_w, desc="Renderizando", unit="px") as pbar:
            for i in range(half_w):
                for j in range(half_h):
                    # Row-major index of pixel (j, i) in the flattened batch.
                    idx = j * half_w + i

                    r_traj = r_vals[:, idx]
                    phi_traj = phi_vals[:, idx]

                    theta = np.arctan2(yv_np[j, i], xv_np[j, i])

                    # Keep only the samples with 6 < r < 20.
                    indices_validos = np.where((r_traj > 6) & (r_traj < 20))[0]

                    # Pixels without enough samples keep the black background.
                    if len(indices_validos) > 1:
                        x_interp, y_interp, z_interp = self.interpolacion(
                            tau[indices_validos],
                            r_traj[indices_validos],
                            phi_traj[indices_validos],
                            theta,
                        )

                        upper = np.column_stack((x_interp, y_interp, z_interp))
                        left, right = self._mirror_pair_colors(upper, R, gamma_deg)
                        pixeles[i, j] = left
                        pixeles[ancho - i - 1, j] = right

                        # Lower half: same trajectory mirrored in y.
                        lower = np.column_stack((x_interp, -y_interp, z_interp))
                        left, right = self._mirror_pair_colors(lower, R, gamma_deg)
                        pixeles[i, alto - j - 1] = left
                        pixeles[ancho - i - 1, alto - j - 1] = right

                    x = self.b_max * float(xv_np[j, i])
                    y = self.b_max * float(yv_np[j, i])

                    pixeles2[i, j] = self.color2(x, y, gamma_deg)
                    pixeles2[ancho - i - 1, j] = self.color2(-x, y, gamma_deg)
                    pixeles2[i, alto - j - 1] = self.color2(x, -y, gamma_deg)
                    pixeles2[ancho - i - 1, alto - j - 1] = self.color2(-x, -y, gamma_deg)

                pbar.update(1)

        if save_path is None:
            traced_path = "../Multimedia/Imagen_disco_de_acreción.png"
        else:
            traced_path = os.fspath(save_path)
        root, ext = os.path.splitext(traced_path)
        combined_path = f"{root}_2{ext}"

        os.makedirs(os.path.dirname(traced_path) or ".", exist_ok=True)

        imagen.save(traced_path)
        combined_image = Image.new('RGB', (2 * ancho, alto))
        combined_image.paste(imagen, (0, 0))
        combined_image.paste(imagen2, (ancho, 0))
        combined_image.save(combined_path)

    def matriz_rotacion(self, theta_deg=0, phi_deg=0, gamma_deg=0):
        """Return the 3x3 matrix ``Rz(theta) @ Ry(phi) @ Rx(gamma)``.

        Angles are in degrees. ``Rx`` is written with ``+sin(gamma)`` in row
        1, column 2 and ``-sin(gamma)`` in row 2, column 1.
        """
        theta = np.deg2rad(theta_deg)
        phi = np.deg2rad(phi_deg)
        gamma = np.deg2rad(gamma_deg)
        Rz = np.array([
            [np.cos(theta), -np.sin(theta), 0],
            [np.sin(theta),  np.cos(theta), 0],
            [0, 0, 1]
        ])
        Ry = np.array([
            [np.cos(phi), 0, np.sin(phi)],
            [0, 1, 0],
            [-np.sin(phi), 0, np.cos(phi)]
        ])
        Rx = np.array([
            [1, 0, 0],
            [0, np.cos(gamma), np.sin(gamma)],
            [0, -np.sin(gamma), np.cos(gamma)]
        ])
        return Rz @ Ry @ Rx

# Build a tracer for a unit-mass black hole; every other constructor argument keeps its default.
rt_batch = KerrRayTracerBatch(bh_mass=1.0)
# Render a 500 x 500 pixel image; render() writes the resulting PNG files to disk.
rt_batch.render(N_x=500, N_y=500)

KeyboardInterrupt: 