In [1]:
%load_ext autoreload 
%autoreload 2

In [2]:
import sys
sys.tracebacklimit = 0

import numpy as np
import networkx as nx 
import matplotlib.pyplot as plt

In [3]:
from pettingzoo.test import parallel_api_test
from solution.custom_gym import CustomGymEnviornment
from solution.trainer import *
from solution.policy_net import *

In [4]:
from core.agent import *
from core.world import * 
from core.render import * 
from core.skill import * 
from core.models import *
from core.message import *

In [5]:
from sar.sar_agent import *
from sar.sar_world import *
from sar.sar_env_params import *

belief_initializer = SARBeliefInitializer(BELIEF_DIMS)
def initialize_swarm(world : BaseWorld):
    swarm = [SARAgent() for i in range(SWARM_SIZE)]
    for agent in swarm:
        agent.set_utility(SARUtilityFunction())
        world.add_agent(agent)
    swarm = initialize_positions_randomly(world, swarm)
    swarm = belief_initializer.initialize_beliefs(swarm)

In [6]:

from sar.urban_gen import * 
from sar.victims import * 
from sar.sar_comm import * 

terrain_generator = UrbanTerrainMapGenerator(padding = MAX_VISIBILITY)
victim_generator = VictimGenerator(padding = MAX_VISIBILITY)
def initialize_terrain(world : BaseWorld):
    terrain_map, population_map = terrain_generator.generate(world._dims)
    map_collection : BaseMapCollection = BaseMapCollection()
    map_collection.add_map("Terrain", terrain_map)
    map_collection.add_map("Population", population_map)

    victim_generator.set_density_map(population_map)
    victim_map = victim_generator.generate(world._dims)

    map_collection.add_map("Victims", victim_map)
    return map_collection


In [7]:
from sar.energy import EnergyModel
from sar.victims import VictimModel
from solution.sar_action_interpreter import *
from solution.encoder_net import *
from models.complex_model import * 

world = SARWorld(dims = WORLD_DIMS,
              swarm_initializer= initialize_swarm,
              generation_pipeline=initialize_terrain
              )
world.add_model("energy_model", EnergyModel())
world.add_model("victim_model", VictimModel())
world.reset()

In [8]:
policy_net= PolicyNet(1, 7, 4)
target_net=  PolicyNet(1, 7, 4)
encoder_model = Encoder()
decoder_model = Decoder()

complex_model = ComplexModel(
    policy_net= policy_net, 
    encoder_net = encoder_model, 
    decoder_net = decoder_model
)

comms_protocol = SARCommunicationProtocol(encoder_model, decoder_model)
action_interpreter = SARActionInterpreter(BELIEF_DIMS)

custom_gym : CustomGymEnviornment = CustomGymEnviornment(world, action_interpreter, comms_protocol)

In [9]:
custom_gym.reset(42)

({1: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 0., 1., 0., 1., 1., 1.],
          [0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 1., 0., 1., 0.],
          [0., 1., 1., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 1., 1., 1.]]),
   'Energy': 100.0},
  2: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[0., 1., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1., 1., 0.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 1.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 0., 1., 0.]]),
   'Energy': 100.0},
  3: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 1., 1., 1., 1., 0., 1.],
          [0., 1., 1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1., 1., 1.],
          [1., 1., 0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0., 1.],
          [1., 1., 0., 0.,

### Testing

In [10]:
parallel_api_test(custom_gym, num_cycles=1_000)
custom_gym.reset()

Passed Parallel API test


({1: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 0., 1., 0., 1., 1., 1.],
          [0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 1., 0., 1., 0.],
          [0., 1., 1., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 1., 1., 1.]]),
   'Energy': 100.0},
  2: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[0., 1., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1., 1., 0.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 1.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 0., 1., 0.]]),
   'Energy': 100.0},
  3: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 1., 1., 1., 1., 0., 1.],
          [0., 1., 1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1., 1., 1.],
          [1., 1., 0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0., 1.],
          [1., 1., 0., 0.,

# Training

In [11]:

from models.base import * 
from models.idqn import * 
from solution.policy_net import PolicyNet
import matplotlib.pyplot as plt

In [12]:
model = IDQN(env = custom_gym,
             feature_extractor= feature_extractor,
             target_net= target_net,
             model= complex_model,
             batch_size=1024,
             device = "cpu"
             )

In [None]:
rewards = train_loop(custom_gym, model, games=10, optimization_passes = 1, seed=42)

Training on thesis.


Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]

Average loss 9.076769726380705
Average loss 8.383879623189568
Average loss 8.20101273828745
Model has been saved.

Starting evaluation on thesis (num_games=1)


Training Progress:  10%|█         | 1/10 [04:35<41:22, 275.86s/it]

Avg reward: 396.729  std: 261.2850236025785
Avg reward per agent, per game:  {1: 286.0, 2: 834.0, 3: 1354.0, 4: 195.0, 5: 99.0, 6: 464.0, 7: 385.0, 8: 296.0, 9: 1172.0, 10: 584.0, 11: 664.0, 12: 574.0, 13: 579.0, 14: 382.0, 15: 289.0, 16: 196.0, 17: 466.0, 18: 98.0, 19: 471.0, 20: 483.0, 21: 195.0, 22: 182.0, 23: 194.0, 24: 194.0, 25: 385.0, 26: 474.0, 27: 756.0, 28: 661.0, 29: 367.0, 30: 386.0, 31: 1087.0, 32: 582.0, 33: 581.0, 34: 755.0, 35: 748.0, 36: 489.0, 37: 97.0, 38: 481.0, 39: 1200.0, 40: 191.0, 41: 861.0, 42: 476.0, 43: 583.0, 44: 655.0, 45: 197.0, 46: 98.0, 47: 95.0, 48: 371.0, 49: 194.0, 50: 288.0, 51: 0.0, 52: 99.0, 53: 295.0, 54: 189.0, 55: 294.0, 56: 393.0, 57: 196.0, 58: 388.0, 59: 382.0, 60: 570.0, 61: 99.0, 62: 290.0, 63: 482.0, 64: 387.0, 65: 296.0, 66: 659.0, 67: 98.0, 68: 198.0, 69: 769.0, 70: 196.0, 71: 562.0, 72: 576.0, 73: 0.0, 74: 198.0, 75: 198.0, 76: 660.0, 77: 391.0, 78: 99.0, 79: 528.0, 80: 290.0, 81: 543.0, 82: 666.0, 83: 433.0, 84: 198.0, 85: 753.0, 86: 0

Training Progress:  20%|██        | 2/10 [09:31<38:18, 287.30s/it]

Avg reward: 433.129  std: 276.9970620042747
Avg reward per agent, per game:  {1: 383.0, 2: 480.0, 3: 380.0, 4: 578.0, 5: 99.0, 6: 765.0, 7: 393.0, 8: 296.0, 9: 481.0, 10: 672.0, 11: 996.0, 12: 759.0, 13: 670.0, 14: 97.0, 15: 99.0, 16: 196.0, 17: 386.0, 18: 0.0, 19: 481.0, 20: 670.0, 21: 197.0, 22: 372.0, 23: 99.0, 24: 742.0, 25: 385.0, 26: 578.0, 27: 393.0, 28: 767.0, 29: 284.0, 30: 482.0, 31: 664.0, 32: 1023.0, 33: 296.0, 34: 669.0, 35: 658.0, 36: 657.0, 37: 384.0, 38: 390.0, 39: 1117.0, 40: 292.0, 41: 953.0, 42: 294.0, 43: 296.0, 44: 483.0, 45: 296.0, 46: 0.0, 47: 294.0, 48: 382.0, 49: 387.0, 50: 291.0, 51: 471.0, 52: 420.0, 53: 489.0, 54: 0.0, 55: 197.0, 56: 182.0, 57: 387.0, 58: 574.0, 59: 662.0, 60: 1133.0, 61: 389.0, 62: 860.0, 63: 198.0, 64: 487.0, 65: 296.0, 66: 198.0, 67: 387.0, 68: 389.0, 69: 487.0, 70: 475.0, 71: 849.0, 72: 572.0, 73: 0.0, 74: 198.0, 75: 675.0, 76: 573.0, 77: 460.0, 78: 99.0, 79: 0.0, 80: 291.0, 81: 196.0, 82: 672.0, 83: 485.0, 84: 384.0, 85: 769.0, 86: 98.0

In [None]:
plt.plot(rewards)