In [3]:
import os
import sys
import glob

# Make the CARLA Python API importable from a local checkout by putting the
# matching `carla-*<major>.<minor>-<platform>.egg` (looked up relative to the
# current working directory) on sys.path. When no egg matches, nothing is
# added and `import carla` below must resolve through the normal search path
# (presumably an installed `carla` package — confirm for your setup).
_carla_platform = 'win-amd64' if os.name == 'nt' else 'linux-x86_64'
_carla_egg_pattern = 'carla/PythonAPI/carla/dist/carla-*%d.%d-%s.egg' % (
    sys.version_info.major,
    sys.version_info.minor,
    _carla_platform)
_carla_eggs = glob.glob(_carla_egg_pattern)
if _carla_eggs:
    sys.path.append(_carla_eggs[0])

import carla
import hydra

from omegaconf import DictConfig, OmegaConf
from hydra.core.config_store import ConfigStore

from core.pgm import PGM
from utils.weather import Weather
from schemas.pgm_schema import PGMModel
from schemas.weather_schema import WeatherSchema, SunSchema

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Register the structured-config schemas in Hydra's ConfigStore under their
# config groups (presumably used as base schemas by the YAML files in `conf/`).
cs = ConfigStore.instance()
for _group, _name, _node in (
    ("weather", "base_weather_model", WeatherSchema),
    ("weather/sun", "base_sun_model", SunSchema),
    ("model", "base_pgm_model", PGMModel),
):
    cs.store(group=_group, name=_name, node=_node)

# Compose the top-level `config` with clear, daytime weather selected.
WEATHER_OVERRIDES = ["weather=clear", "weather/sun=day"]
with hydra.initialize(version_base=None, config_path="conf"):
    cfg = hydra.compose(config_name="config", overrides=WEATHER_OVERRIDES)

# Build the PGM from the composed `model` node and show the conditional
# probability table of the traffic node 'T'.
pgm = PGM(cfg.model)
pgm.print_cpd('T')

+----------+------------+---------------+---------------+
| R        | R(NO_RAIN) | R(LIGHT_RAIN) | R(HEAVY_RAIN) |
+----------+------------+---------------+---------------+
| T(LOW)   | 0.1        | 0.4           | 0.9           |
+----------+------------+---------------+---------------+
| T(HEAVY) | 0.9        | 0.6           | 0.1           |
+----------+------------+---------------+---------------+


In [9]:
# Per-variable state members (Enum-like: each member exposes `.name`); the
# `.name` string is what the queries below pass as evidence values.
states = pgm.get_states()
no_rain = states.Rain.NO_RAIN
no_rain.name

'NO_RAIN'

In [10]:
# Lookup from readable variable names to the PGM's node ids, so queries need
# not hard-code single-letter ids (the rain variable maps to 'R').
variables = pgm.get_variables()
rain_node = variables.Rain
rain_node

'R'

In [11]:
# Joint distribution over Traffic and Speed given light rain.
# `print` is deliberate: the factor's table comes from str(), not its repr.
query_nodes = [variables.Traffic, variables.Speed]
light_rain_evidence = {variables.Rain: states.Rain.LIGHT_RAIN.name}
res = pgm.predict_dist(query_nodes, evidence=light_rain_evidence)
print(res)

Finding Elimination Order: : : 0it [00:00, ?it/s]
0it [00:00, ?it/s]

+----------+---------+------------+
| T        | S       |   phi(T,S) |
| T(LOW)   | S(LOW)  |     0.2400 |
+----------+---------+------------+
| T(LOW)   | S(HIGH) |     0.1600 |
+----------+---------+------------+
| T(HEAVY) | S(LOW)  |     0.3600 |
+----------+---------+------------+
| T(HEAVY) | S(HIGH) |     0.2400 |
+----------+---------+------------+





In [12]:
# Most probable Traffic/Speed assignment given light rain (it matches the
# argmax of the joint distribution printed by the previous query).
# Node ids come from the `variables` lookup rather than hard-coded
# 'T' / 'S' / 'R' strings, so this query cannot drift from predict_dist's.
res = pgm.predict_state(
    [variables.Traffic, variables.Speed],
    evidence={variables.Rain: states.Rain.LIGHT_RAIN.name})
res

Finding Elimination Order: : : 0it [00:00, ?it/s]
0it [00:00, ?it/s]


{'T': 'HEAVY', 'S': 'LOW'}