In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import jax.random as jrandom


import Project_library as pl

In [3]:
@jax.jit
def create_user(id, lat_range, longi_range, key : jax.typing.ArrayLike) -> pl.User:
    """Creates a list of users with random latitude and longitude."""
    # Generate a random key for the user
    key, subkey = jrandom.split(key, 2)

    latitude = jrandom.uniform(key, minval=lat_range[0], maxval=lat_range[1])
    key = jrandom.split(key, 1)[0]
    # Split the key for the next random number generation
    longitude = jrandom.uniform(key, minval=longi_range[0], maxval=longi_range[1])

    position = pl.spherical_to_cartesian(6378,latitude, longitude)

    return pl.User(
        latitude=latitude,
        longitude=longitude,
        position=position,
        id=id
    )


create_users = jax.vmap(create_user, in_axes=(0, None, None, 0))

In [4]:
# Specify dummy values for the evaluation
satellite_position = np.array([600+6378, 0, 0])

# Define users in area below the satellite
lat, long = pl.generate_latitude_longitude_divisions(1,1,[jnp.deg2rad(-3),jnp.deg2rad(3)],
                                          [jnp.deg2rad(-3),jnp.deg2rad(3)])

print(lat,long)

print(pl.calculate_area_on_sphere(6378, lat, long))

# Generate users in the latitude and longitude divisions
key = jax.random.PRNGKey(0)
users_amount = jax.random.poisson(key, 1*pl.calculate_area_on_sphere(6378, lat, long))
user_ids = jnp.arange(users_amount.item())

print(users_amount)
users = []

key = jrandom.split(key, (users_amount.item()))

print(key.shape)

users = create_users(user_ids, lat, long, key)

users

[-0.05235988  0.05235988] [-0.05235988  0.05235988]
[445890.12]
[444903]
(444903, 2)


User(id=[     0      1      2 ... 444900 444901 444902] 
 lat=[-0.02275351  0.00825531 -0.04514982 ... -0.03094647 -0.01136237
 -0.01934456] 
 lon=[-0.04484668  0.013107   -0.00148282 ... -0.00045968 -0.02093906
 -0.00912976])

In [5]:
# Calculate the distance between users and the satellite
distances = pl.calculate_distances(users.position[0], users.position[1], users.position[2], satellite_position)

# Calculate the elevation angles of the users
elevations = pl.calculate_elevations(users.position[0], users.position[1], users.position[2], satellite_position)

print(distances, elevations)
print(users.position[0], users.position[1], users.position[2])

[687.3965  608.8337  671.42163 ... 634.5302  620.6911  616.73566] [1.0353241 1.3923224 1.0822653 ... 1.2235554 1.2996889 1.3263669]
[6369.9375 6377.235  6371.4937 ... 6374.9453 6376.19   6376.541 ] [-285.86224     83.59119     -9.447804  ...   -2.9304507 -133.53098
  -58.217896 ] [-145.10938    52.651756 -287.8677   ... -197.34505   -72.46763
 -123.371925]


In [11]:
# Caculate the path loss
@jax.jit
def calculate_path_loss(distance: float, frequency: float) -> float:
    """Calculate the path loss using the free space path loss formula.
    
    Args:
        distance (float): The distance between the user and the satellite in m.
        frequency (float): The frequency of the signal in MHz.
    """
    # Calculate the path loss using the free space path loss formula
    path_loss = 32.45 + 20 * jnp.log10(frequency) + 20 * jnp.log10(distance*1000)

    return path_loss

# Calculate the gain from the user to the satellite
@jax.jit
def calculate_gain_fixed(elevation:float, frequency:float) -> float:
    """Calculate the gain from the user to the satellite.
    
    Args:
        elevation (float): The elevation angle of the satellite in degrees.
        frequency (float): The frequency of the signal in MHz.
        
    return:
        float: The gain from the user to the satellite in dB.
    """

    user_gain = 0.0
    satellite_gain = 30.0
    # Values from 3gpp TR 38.811 and 38.821

    return satellite_gain, user_gain

def calculate_noise(sky_temperature : float, Bandwidth :float , Kb : float) -> float:
    """Calculate the noise using the formula N = k * T * B.
    
    Args:
        sky_temperature (float): The sky temperature in Kelvin.
        Bandwidth (float): The bandwidth in Hz.
        Kb (float): The Boltzmann constant in J/K.
        
    return:
        float: The noise in dBm.
    """
    # Calculate the noise using the formula N = k * T * B
    noise = 10 * jnp.log10(Kb * sky_temperature * Bandwidth)
    return noise

# Calculate the SNR
@jax.jit
def calculate_snr(power : float, frequency : float, bandwidth : float, distance: float, elevation: float) -> float:
    """Calculate the SNR using the path loss and gain."""
    # Calculate the SNR using the formula SNR = gain - path_loss - nois

    gain_tx, gain_rx = calculate_gain_fixed(elevation, None)
    path_loss = calculate_path_loss(distance, frequency)
    noise = calculate_noise(300, bandwidth, 1.38e-23)

    power_db = 10 * jnp.log10(power)

    snr = power_db + gain_tx + gain_rx - path_loss - noise

    return snr

# Calculate the capacity of each channel with a uniform distribution of bandwidths
@jax.jit
def calculate_capacity(snr: float, bandwidth: float) -> float:
    """Calculate the capacity of a channel using the Shannon formula.
    
    args:
        snr (float): The signal-to-noise ratio in dB.
        bandwidth (float): The bandwidth of the channel in Hz.
    """
    # Calculate the capacity using the Shannon formula
    capacity = bandwidth * jnp.log2(1 + jnp.power(10,snr/10))

    return capacity


bandwidth = 1e6
power = 0.2 # W
snr = calculate_snr(0.2, 2, bandwidth, distances[0], elevations[0])
calculate_capacity(snr, 1e6)


Array(3957885.5, dtype=float32)

In [None]:
calculate_capacities = jax.vmap(calculate_capacity, in_axes=(0, None))
calculate_capacities(
    calculate_snrs(
        calculate_path_loss(distances, 2e9), 
        calculate_gain_fixed(elevations, 2e9), 
        calculate_noise(100, 1e6/len(users), 1.38e-23)
    ), 
    jnp.array(1e6/len(users))
)