In [1]:
import jax
import jax.numpy as jnp
import jax.random as jrandom

import typing

import matplotlib.pyplot as plt

import Project_library

In [2]:
Project_library.test()

This is a test function from the core module.


True

In [3]:
# Define the user class

class User(typing.NamedTuple):
    """A user class that represents a user with a unique ID and a list of items."""
    latitude: float
    longitude: float
    id: int
    throughput: float = 0.0 # Consider making a secondary class that is only for all QoS. You kinda just need to add more information to the state of the user.
    position: jax.typing.ArrayLike = jnp.array([0.0, 0.0, 0.0])

    def __repr__(self):
        return f"User(id={self.id} \n lat={self.latitude} \n lon={self.longitude})"
    
# Define the function to create users

@jax.jit
def create_user(id, lat_range, longi_range, key : jax.typing.ArrayLike) -> User:
    """Creates a list of users with random latitude and longitude."""
    # Generate a random key for the user
    key, subkey = jrandom.split(key, 2)

    latitude = jrandom.uniform(key, minval=lat_range[0], maxval=lat_range[1])
    key = jrandom.split(key, 1)[0]
    # Split the key for the next random number generation
    longitude = jrandom.uniform(key, minval=longi_range[0], maxval=longi_range[1])

    position = Project_library.spherical_to_cartesian(latitude, longitude)

    return User(
        latitude=latitude,
        longitude=longitude,
        position=position,
        id=id
    )

# Define the function to create multiple users
create_users = jax.vmap(create_user, in_axes=(0, 0, 0, None))


In [4]:
import time

start = time.time()
for i in range(20):
    create_user(i, (0, 10), (0, 10), jrandom.PRNGKey(0))

stop = time.time()

print("Time taken to create 10 users:", stop - start)

start = time.time()
lats = jnp.ones((20, 2))
longs = jnp.ones((20, 2))

for i in range(10):
    lats.at[i].set([0, 10])
    longs.at[i].set([0, 10])

ids = jnp.arange(20)

start = time.time()
create_users(ids, lats, longs, jrandom.PRNGKey(0))
stop = time.time()

start = time.time()
users_list_test = create_users(ids, lats, longs, jrandom.PRNGKey(0))
stop = time.time()

print("Time taken to create 10 users with vmap:", stop - start)


print(users_list_test)

Time taken to create 10 users: 0.8307781219482422
Time taken to create 10 users with vmap: 0.0029926300048828125
User(id=[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19] 
 lat=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.] 
 lon=[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.])


In [None]:
class square_cell(typing.NamedTuple):
    """A class representing a square cell."""
    # Define the fields of the named tuple
    lat: float  # Latitude of the cell center
    longi: float  # Longitude of the cell center
    density: float  # Density of users in the cell
    id: int  # ID of the cell
    users: typing.List[User]  # List of users in the cell

Number_of_zones_latitude = 10
Number_of_zones_longitude = 10

# Define the latitude and longitude ranges for the grid
lat_range = (jnp.deg2rad(-10), jnp.deg2rad(10))
longi_range = (jnp.deg2rad(-10), jnp.deg2rad(10))

# Generate a grid of latitude and longitude points
density_grid = Project_library.generate_latitude_longitude_points( Number_of_zones_latitude, Number_of_zones_longitude,lat_range, longi_range)



def poisson_draw_users(density, lattitude_range, longitude_range, key) -> int:
    """Draws users from a Poisson distribution."""
    # Generate a random key for the user
    key, subkey = jrandom.split(key, 2)

    # Calculate the number of users to draw based on the density
    num_users = jnp.random.poisson(density)

    return num_users




SyntaxError: expected ':' (1410344310.py, line 1)