In [None]:
import jax
import jax.numpy as jnp
import numpy as np

import Project_library as pl

In [None]:
# Specify dummy values for the evaluation
# Satellite position as a 3-component vector
# (presumably Cartesian x, y, z in km -- TODO confirm units and frame).
satellite_position = np.array([600, 0, 0])

#

In [None]:
# Calculate the distance between users and the satellite
# NOTE(review): `calculate_distances`, `calculate_elevations` and `users` are
# not defined in the visible cells; presumably they come from Project_library
# (imported as `pl`) or from an earlier cell -- confirm.
# The first three entries of `users.position` are passed as separate arguments,
# followed by the satellite position vector.
distances = calculate_distances(users.position[0], users.position[1], users.position[2], satellite_position)

# Calculate the elevation angles of the users
elevations = calculate_elevations(users.position[0], users.position[1], users.position[2], satellite_position)

# Calculate the path loss
@jax.jit
def calculate_path_loss(distance: float, frequency: float) -> float:
    """Calculate the free-space path loss (FSPL) of a link.

    Uses FSPL(dB) = 20*log10(d) + 20*log10(f) + 20*log10(4*pi/c), with the
    distance d in metres and the frequency f in Hz.

    Args:
        distance (float): Link distance in kilometres (converted to metres
            internally).
        frequency (float): Carrier frequency in Hz.

    Returns:
        float: The free-space path loss in dB.
    """
    # 20*log10(4*pi/c) for d in metres and f in Hz. The textbook constant
    # 32.45 is only valid for d in km and f in MHz; combining it with metres
    # and Hz overestimates the loss by 180 dB.
    fspl_constant_db = -147.55

    distance_m = distance * 1000
    path_loss = (
        fspl_constant_db
        + 20 * jnp.log10(frequency)
        + 20 * jnp.log10(distance_m)
    )

    return path_loss

# Calculate the gain from the user to the satellite

@jax.jit
def calculate_gain_fixed(elevation:float, frequency:float) -> float:
    """Return the fixed combined antenna gain of the user-satellite link.

    The gain is a constant: both arguments are accepted but not used.

    Args:
        elevation (float): Elevation angle of the satellite; ignored.
        frequency (float): Frequency of the signal; ignored.

    Returns:
        float: Sum of the satellite and user antenna gains in dB.
    """
    # Fixed antenna gains; values from 3GPP TR 38.811 and 38.821.
    sat_antenna_db, ue_antenna_db = 30.0, 0.0

    total_gain_db = sat_antenna_db + ue_antenna_db
    return total_gain_db

def calculate_noise(sky_temperature, Bandwidth, Kb) -> float:
    """Return the thermal noise power N = k * T * B expressed in dBm.

    Args:
        sky_temperature (float): The sky temperature in Kelvin.
        Bandwidth (float): The bandwidth in Hz.
        Kb (float): The Boltzmann constant in J/K.

    Returns:
        float: The noise power in dBm.
    """
    # Linear noise power k * T * B (watts for the documented input units).
    noise_linear = Kb * sky_temperature * Bandwidth

    noise_db = 10 * jnp.log10(noise_linear)

    # Adding 30 dB converts dBW to dBm.
    return noise_db + 30

# Calculate the SNR
@jax.jit
def calculate_snr(path_loss: float, gain: float, noise: float) -> float:
    """Return the SNR as gain minus path loss minus noise.

    All three inputs are combined by plain subtraction, so they are expected
    to be on the same logarithmic scale (presumably dB / dBm -- confirm with
    the callers).

    Args:
        path_loss (float): Path loss of the link.
        gain (float): Total antenna gain of the link.
        noise (float): Noise power.

    Returns:
        float: gain - path_loss - noise.
    """
    # Signal level after antenna gain and propagation loss.
    signal_level = gain - path_loss

    return signal_level - noise

# Vectorize the functions
# in_axes=(0, None, None): map over axis 0 of the path loss only; the gain and
# the noise are passed unchanged to every mapped call.
calculate_snrs = jax.vmap(calculate_snr, in_axes=(0,None,None))
# in_axes=(0, None): map over axis 0 of the distances; the frequency is shared.
calculate_path_losses = jax.vmap(calculate_path_loss, in_axes=(0, None))


# Calculate the SNR for each user
# NOTE(review): the result of this call is not assigned to a name.
# `distances` must already be defined by an earlier cell.
calculate_snrs(
    calculate_path_losses(distances, 28e9), 
    calculate_gain_fixed(None, None), 
    calculate_noise(100, 1e6, 1.38e-23)
)

# Calculate the capacity of each channel with a uniform distribution of bandwidths
@jax.jit
def calculate_capacity(snr: float, bandwidth: float) -> float:
    """Calculate the capacity of a channel using the Shannon formula.

    C = B * log2(1 + SNR), where SNR is the linear power ratio.

    Args:
        snr (float): The signal-to-noise ratio in dB.
        bandwidth (float): The bandwidth of the channel in Hz.

    Returns:
        float: The channel capacity in bit/s.
    """
    # Shannon's formula needs the linear power ratio, and a value in dB maps
    # to linear as 10 ** (dB / 10). Raising 10 to the raw dB value grossly
    # overestimates the capacity and overflows float32 for moderate SNRs.
    snr_linear = jnp.power(10.0, snr / 10.0)

    capacity = bandwidth * jnp.log2(1 + snr_linear)

    return capacity

# Print the path loss for a single (distance, frequency) pair: 600 and 2e9.
print(calculate_path_loss(600, 2e9))

# Vectorize the capacity over axis 0 of the SNRs; the bandwidth is shared.
calculate_capacities = jax.vmap(calculate_capacity, in_axes=(0, None))
# Capacity per user, with 1e6 split evenly across len(users) as the bandwidth.
# NOTE(review): the non-vmapped calculate_path_loss is called on `distances`
# here, whereas the earlier SNR call used calculate_path_losses -- confirm
# which is intended.
# NOTE(review): the noise is computed with a bandwidth of 1e6 while the
# capacity uses 1e6/len(users) -- confirm this is intended.
calculate_capacities(
    calculate_snrs(
        calculate_path_loss(distances, 28e9), 
        calculate_gain_fixed(elevations, 28e9), 
        calculate_noise(100, 1e6, 1.38e-23)
    ), 
    jnp.array(1e6/len(users))
)


