In [85]:
import jax
import jax.numpy as jnp
import jax.random as jrandom

import typing

import matplotlib.pyplot as plt

import Project_library

In [86]:
Project_library.test()

This is a test function from the core module.


In [None]:
# Define the user class

class User(typing.NamedTuple):
    """A user class that represents a user with a unique ID and a list of items."""
    latitude: float
    longitude: float
    id: int
    throughput: float = 0.0 # Consider making a secondary class that is only for all QoS. You kinda just need to add more information to the state of the user.
    position: jax.typing.ArrayLike = jnp.array([0.0, 0.0, 0.0])

    def __repr__(self):
        return f"User(id={self.id} \n lat={self.latitude} \n lon={self.longitude})"
    
# Define the function to create users

@jax.jit
def create_user(id, lat_range, longi_range, key : jax.typing.ArrayLike) -> User:
    """Creates a list of users with random latitude and longitude."""
    # Generate a random key for the user
    key, subkey = jrandom.split(key, 2)

    latitude = jrandom.uniform(key, minval=lat_range[0], maxval=lat_range[1])
    key = jrandom.split(key, 1)[0]
    # Split the key for the next random number generation
    longitude = jrandom.uniform(key, minval=longi_range[0], maxval=longi_range[1])

    position = Project_library.spherical_to_cartesian(6378,latitude, longitude)

    return User(
        latitude=latitude,
        longitude=longitude,
        position=position,
        id=id
    )

# Define the function to create multiple users
create_users = jax.vmap(create_user, in_axes=(0, None, None, 0))


In [88]:
import time

start = time.time()
for i in range(20):
    create_user(i, (0, 10), (0, 10), jrandom.PRNGKey(0))

stop = time.time()

print("Time taken to create 10 users:", stop - start)

start = time.time()
lats = jnp.array([0, 10])
long = jnp.array([0, 10])
ids = jnp.arange(20)
jrandom_key = jrandom.split(jrandom.PRNGKey(0), 20)
# Create 10 users with random latitude and longitude
users = create_users(ids, lats, long, jrandom_key)



#start = time.time()
#create_users(ids, lats, longs, jrandom.PRNGKey(0))
#stop = time.time()

#start = time.time()
#users_list_test = create_users(ids, lats, longs, jrandom.PRNGKey(0))
#stop = time.time()

#print("Time taken to create 10 users with vmap:", stop - start)


#print(users_list_test)

Time taken to create 10 users: 0.36763691902160645


In [89]:
print(users[0])
print(users)

[2.8271997  5.788324   0.68850994 9.536983   8.066345   5.116377
 1.5063298  2.7780986  9.560162   1.6895592  7.120054   7.125561
 1.8494904  1.6308415  2.6688766  4.2351856  4.116026   1.3831294
 5.8395185  8.0540085 ]
User(id=[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] 
 lat=[2.8271997  5.788324   0.68850994 9.536983   8.066345   5.116377
 1.5063298  2.7780986  9.560162   1.6895592  7.120054   7.125561
 1.8494904  1.6308415  2.6688766  4.2351856  4.116026   1.3831294
 5.8395185  8.0540085 ] 
 lon=[0.71745753 6.251626   4.858401   9.478201   6.5973616  7.7300167
 8.439302   4.218855   0.8808291  8.056244   7.55505    9.911102
 4.568754   9.737341   9.891549   4.0398192  4.074993   0.63726425
 2.099477   0.07335663])


In [90]:
class square_cell(typing.NamedTuple):
    """A class representing a square cell."""
    # Define the fields of the named tuple
    lat: float  # Latitude of the cell center
    longi: float  # Longitude of the cell center
    lat_width : typing.List[float]  # Latitude range of the cell
    longi_width : typing.List[float]  # Longitude range of the cell
    density: float  # Density of users in the cell
    id: int  # ID of the cell
    users_amount : int = 0  # Number of users in the cell
    users: User = 0  # List of users in the cell



Number_of_zones_latitude = 1
Number_of_zones_longitude = 2

# Define the latitude and longitude ranges for the grid
lat_range = (jnp.deg2rad(-10), jnp.deg2rad(10))
longi_range = (jnp.deg2rad(-10), jnp.deg2rad(10))

print(jnp.deg2rad(-10), jnp.deg2rad(10))
print(jnp.deg2rad(-10), jnp.deg2rad(10))

# Generate a grid of latitude and longitude points
point_grid_of_cells = Project_library.generate_latitude_longitude_points( Number_of_zones_latitude, Number_of_zones_longitude,lat_range, longi_range)

# Add an extra dimension according to the density of the cell
density_mesh = jnp.ones((Number_of_zones_latitude, Number_of_zones_longitude)) * 5

# Add an extra dimension to the density mesh to match the shape of the point grid
cell_density_mesh = jnp.stack((*point_grid_of_cells, density_mesh), axis=-1)

width_of_latitude = jnp.abs(lat_range[1] - lat_range[0]) / (Number_of_zones_latitude*2)
width_of_longitude = jnp.abs(longi_range[1] - longi_range[0]) / (Number_of_zones_longitude*2)

print("Width of latitude:", width_of_latitude)
print("Width of longitude:", width_of_longitude)
print("Density mesh shape:", cell_density_mesh)


-0.17453292 0.17453292
-0.17453292 0.17453292
Width of latitude: 0.17453292
Width of longitude: 0.08726646
Density mesh shape: [[[ 0.         -0.08726646  5.        ]
  [ 0.          0.08726647  5.        ]]]


In [91]:
cells = []

key = jrandom.split(jrandom.PRNGKey(0), 5)
# Populate each cell with users
for i in range(cell_density_mesh.shape[0]):
    for j in range(cell_density_mesh.shape[1]):
        print(cell_density_mesh[i, j])
        long_borders = [cell_density_mesh[i, j, 1]-width_of_longitude, cell_density_mesh[i, j, 1]+width_of_longitude]
        lat_borders = [cell_density_mesh[i, j, 0]-width_of_latitude, cell_density_mesh[i, j, 0]+width_of_latitude]
        density = cell_density_mesh[i, j, 2]
        cell_id = i * cell_density_mesh.shape[1] + j
        print(long_borders, lat_borders, density, cell_id)
        # Calculate area of the cell
        cell_area = Project_library.calculate_area_on_sphere(6378,lat_borders, long_borders)
        print("Area of cell:", cell_area)
        print(key)
        amount_of_users = jax.random.poisson(key[0], density * cell_area)
        print(amount_of_users)
        key = jrandom.split(key[2], amount_of_users.item())
        
        # Create users for the cell
        users = create_users(jnp.arange(amount_of_users.item()), lat_borders, long_borders, key)
        # Create the cell
        cells.append(square_cell(
            lat=cell_density_mesh[i, j, 0],
            longi=cell_density_mesh[i, j, 1],
            lat_width=lat_borders,
            longi_width=long_borders,
            density=density,
            id=cell_id,
            users_amount=amount_of_users.item(),
            users=users
        ))
        

[ 0.         -0.08726646  5.        ]
[Array(-0.17453292, dtype=float32), Array(0., dtype=float32)] [Array(-0.17453292, dtype=float32), Array(0.17453292, dtype=float32)] 5.0 0
Area of cell: [2465736.2]
[[1797259609 2579123966]
 [ 928981903 3453687069]
 [4146024105 2718843009]
 [2467461003 3840466878]
 [2285895361  433833334]]
[12323492]
[0.         0.08726647 5.        ]
[Array(7.450581e-09, dtype=float32), Array(0.17453292, dtype=float32)] [Array(-0.17453292, dtype=float32), Array(0.17453292, dtype=float32)] 5.0 1
Area of cell: [2465736.2]
[[2799984767 1105366846]
 [3777617834  145086855]
 [ 915694800 1641710144]
 ...
 [1411171344 1532011838]
 [ 719547843 2037078956]
 [1079085352 2481610848]]
[12330025]


In [None]:
# Calculate distance between users and a satellite.
satellite_position = jnp.array([6730+600, 0.0, 0.0])

@jax.jit
def calculate_distance(x_user, y_user, z_user, satellite_position: jax.typing.ArrayLike) -> float:
    """Calculate the distance between a user and a satellite."""
    # Calculate the distance using the Euclidean formula
    return jnp.sqrt((satellite_position[0] - x_user) ** 2 + (satellite_position[1] - y_user) ** 2 + (satellite_position[2] - z_user) ** 2)

# Calculate the elevation angle of the satellite
@jax.jit
def calculate_elevation(x_user, y_user, z_user, satellite_position : jax.typing.ArrayLike) -> float:
    user_pos = jnp.asarray((x_user, y_user, z_user))
    elev = jnp.arcsin(jnp.dot(satellite_position-user_pos, user_pos/jnp.linalg.norm(user_pos))/jnp.linalg.norm(satellite_position-user_pos))
    return elev

# Calculate the distance between users and the satellite
calculate_distances = jax.vmap(calculate_distance, in_axes=(0,0,0, None))
calculate_elevations = jax.vmap(calculate_elevation, in_axes=(0,0,0, None))

In [101]:
calculate_distance(users.position[0][0], users.position[1][0], users.position[2][0], satellite_position)
jnp.rad2deg(calculate_elevation(users.position[0][0], users.position[1][0], users.position[2][0], satellite_position))

Array(87.798325, dtype=float32)

In [108]:
calculate_distances(users.position[0], users.position[1], users.position[2], satellite_position).block_until_ready()
calculate_elevations(users.position[0], users.position[1], users.position[2], satellite_position).block_until_ready()

Array([1.5323699, 1.3952194, 1.3879647, ..., 1.4723645, 1.4073749,
       1.3908669], dtype=float32)