In [1]:
# Load IPython's autoreload extension and enable mode 2, so edits to the
# project modules imported below are picked up without restarting the kernel.
%load_ext autoreload 
%autoreload 2

In [2]:
import sys
# A limit of 0 asks Python to suppress traceback frames and print only the
# exception type and value.
# NOTE(review): this applies to every later error in the session; remove it
# while debugging so failures show where they came from.
sys.tracebacklimit = 0

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch

In [3]:
from pettingzoo.test import parallel_api_test
from solution.custom_gym import CustomGymEnviornment
from solution.trainer import *
from solution.policy_net import *

In [4]:
from core.agent import *
from core.world import * 
from core.render import * 
from core.skill import * 
from core.models import *
from core.message import *

In [5]:
from sar.sar_agent import *
from sar.sar_world import *
from sar.sar_env_params import *

# Belief initializer built from BELIEF_DIMS; shared by every call to
# initialize_swarm below.
belief_initializer = SARBeliefInitializer(BELIEF_DIMS)


def initialize_swarm(world: BaseWorld):
    """Create SWARM_SIZE SARAgent instances and register them with `world`.

    Every agent receives its own SARUtilityFunction and is added to the world
    through `world.add_agent`. The full list is then handed to
    `initialize_positions_randomly` and the result of that call is handed to
    `belief_initializer.initialize_beliefs`.

    NOTE(review): the function has no return statement, so the list produced
    by the final call is discarded and callers receive None. Confirm that the
    world only relies on the `add_agent` registrations.
    """
    agents = [SARAgent() for _ in range(SWARM_SIZE)]
    for member in agents:
        member.set_utility(SARUtilityFunction())
        world.add_agent(member)

    positioned = initialize_positions_randomly(world, agents)
    belief_initializer.initialize_beliefs(positioned)

In [6]:

from sar.urban_gen import * 
from sar.victims import * 
from sar.sar_comm import * 

# Both generators are created once at module level and reused on every call
# to initialize_terrain; each receives MAX_VISIBILITY as its padding.
terrain_generator = UrbanTerrainMapGenerator(padding=MAX_VISIBILITY)
victim_generator = VictimGenerator(padding=MAX_VISIBILITY)


def initialize_terrain(world: BaseWorld):
    """Generate the "Terrain", "Population" and "Victims" maps for `world`.

    The terrain generator yields the terrain and population maps in one call.
    The population map is then installed as the victim generator's density map
    before the victim map is generated. All three maps use `world._dims`.

    Returns:
        A BaseMapCollection holding the three maps under the names above.
    """
    dims = world._dims  # NOTE: reads a private attribute of the world
    terrain_layer, population_layer = terrain_generator.generate(dims)

    maps: BaseMapCollection = BaseMapCollection()
    maps.add_map("Terrain", terrain_layer)
    maps.add_map("Population", population_layer)

    # The density map must be set before the victim map is generated.
    victim_generator.set_density_map(population_layer)
    maps.add_map("Victims", victim_generator.generate(dims))
    return maps


In [7]:
from sar.energy import EnergyModel
from sar.victims import VictimModel
from solution.sar_action_interpreter import *
from solution.encoder_net import *
from models.complex_model import * 

# Build the world from WORLD_DIMS, handing it the two callbacks defined in the
# earlier cells: initialize_swarm and initialize_terrain.
world = SARWorld(
    dims=WORLD_DIMS,
    swarm_initializer=initialize_swarm,
    generation_pipeline=initialize_terrain,
)

# Register the two models under the names "energy_model" and "victim_model".
world.add_model("energy_model", EnergyModel())
world.add_model("victim_model", VictimModel())

# Reset once after both models have been registered.
world.reset()

In [8]:
# Two separately constructed PolicyNet instances with identical constructor
# arguments: one is wrapped by ComplexModel below, the other is kept as
# `target_net` for the training cell.
policy_net, target_net = (PolicyNet(1, 7, 4) for _ in range(2))

encoder_model = Encoder()
decoder_model = Decoder()

# The same encoder/decoder objects are shared by the ComplexModel and by the
# communication protocol.
complex_model = ComplexModel(
    policy_net=policy_net,
    encoder_net=encoder_model,
    decoder_net=decoder_model,
)
comms_protocol = SARCommunicationProtocol(encoder_model, decoder_model)

action_interpreter = SARActionInterpreter(BELIEF_DIMS)

# Wrap the world, the action interpreter and the protocol in the environment
# used by the remaining cells.
custom_gym: CustomGymEnviornment = CustomGymEnviornment(
    world,
    action_interpreter,
    comms_protocol,
)

In [9]:
# Reset the environment, passing 42 positionally (presumably the seed; confirm
# against CustomGymEnviornment.reset). As the last expression in the cell, the
# returned tuple is displayed in full; its first element maps each agent id to
# a dict with 'Belief', 'Victims' and 'Energy' entries.
custom_gym.reset(42)

({1: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 0., 1., 0., 1., 1., 1.],
          [0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 1., 0., 1., 0.],
          [0., 1., 1., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 1., 1., 1.]]),
   'Energy': 100.0},
  2: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[0., 1., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1., 1., 0.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 1.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 0., 1., 0.]]),
   'Energy': 100.0},
  3: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 1., 1., 1., 1., 0., 1.],
          [0., 1., 1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1., 1., 1.],
          [1., 1., 0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0., 1.],
          [1., 1., 0., 0.,

# Testing

In [10]:
# Run PettingZoo's parallel API test against the environment for 1,000
# cycles, then reset the environment (no argument this time) and display
# whatever reset() returns as the cell output.
parallel_api_test(custom_gym, num_cycles=1_000)
custom_gym.reset()

Passed Parallel API test


({1: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 0., 1., 0., 1., 1., 1.],
          [0., 0., 1., 0., 0., 1., 1.],
          [0., 0., 1., 1., 0., 1., 0.],
          [0., 1., 1., 0., 1., 0., 1.],
          [1., 1., 1., 1., 1., 1., 1.],
          [1., 0., 1., 1., 1., 1., 1.],
          [1., 0., 1., 0., 1., 1., 1.]]),
   'Energy': 100.0},
  2: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[0., 1., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 1., 1., 0.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 1., 0., 0., 0., 1., 1.],
          [1., 0., 1., 0., 0., 1., 0.],
          [0., 0., 1., 1., 0., 1., 1.],
          [0., 0., 0., 1., 0., 1., 0.]]),
   'Energy': 100.0},
  3: {'Belief': array([0., 0., 0., 0., 0.]),
   'Victims': array([[1., 1., 1., 1., 1., 0., 1.],
          [0., 1., 1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1., 1., 1.],
          [1., 1., 0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0., 1.],
          [1., 1., 0., 0.,

# Training

In [11]:

from models.base import * 
from models.idqn import * 
from solution.policy_net import PolicyNet
import matplotlib.pyplot as plt

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
# NOTE(review): assumes IDQN accepts a torch device string; confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): `feature_extractor` is not assigned in any cell of this
# notebook; presumably it is provided by one of the star imports. Confirm,
# otherwise this cell raises NameError on a fresh kernel.
model = IDQN(
    env=custom_gym,
    feature_extractor=feature_extractor,
    target_net=target_net,
    model=complex_model,
    batch_size=1024,
    device=device,
)

In [None]:
# Train for 10 games with one optimization pass and seed 42; the value
# returned by train_loop is kept in `rewards` for the plotting cell.
rewards = train_loop(
    custom_gym,
    model,
    games=10,
    optimization_passes=1,
    seed=42,
)

Training on thesis.


Training Progress:   0%|          | 0/10 [00:00<?, ?it/s]

Average loss 8.960081766419112
Model has been saved.

Starting evaluation on thesis (num_games=1)


Training Progress:  10%|█         | 1/10 [02:35<23:22, 155.88s/it]

Avg reward: 371.784  std: 272.68312992189306
Avg reward per agent, per game:  {1: 0.0, 2: 557.0, 3: 296.0, 4: 98.0, 5: 99.0, 6: 759.0, 7: 197.0, 8: 388.0, 9: 99.0, 10: 579.0, 11: 666.0, 12: 296.0, 13: 290.0, 14: 160.0, 15: 99.0, 16: 99.0, 17: 1018.0, 18: 97.0, 19: 99.0, 20: 577.0, 21: 99.0, 22: 193.0, 23: 196.0, 24: 381.0, 25: 383.0, 26: 171.0, 27: 1095.0, 28: 581.0, 29: 189.0, 30: 1117.0, 31: 384.0, 32: 390.0, 33: 764.0, 34: 393.0, 35: 735.0, 36: 853.0, 37: 288.0, 38: 932.0, 39: 389.0, 40: 359.0, 41: 544.0, 42: 0.0, 43: 579.0, 44: 371.0, 45: 198.0, 46: 650.0, 47: 294.0, 48: 477.0, 49: 177.0, 50: 468.0, 51: 0.0, 52: 98.0, 53: 757.0, 54: 349.0, 55: 197.0, 56: 1005.0, 57: 473.0, 58: 197.0, 59: 190.0, 60: 292.0, 61: 446.0, 62: 748.0, 63: 553.0, 64: 737.0, 65: 393.0, 66: 99.0, 67: 189.0, 68: 292.0, 69: 486.0, 70: 0.0, 71: 190.0, 72: 99.0, 73: 193.0, 74: 198.0, 75: 198.0, 76: 261.0, 77: 387.0, 78: 184.0, 79: 0.0, 80: 0.0, 81: 461.0, 82: 655.0, 83: 387.0, 84: 570.0, 85: 485.0, 86: 76.0, 87: 

Training Progress:  20%|██        | 2/10 [04:26<17:16, 129.50s/it]

Avg reward: 363.474  std: 273.3506819526887
Avg reward per agent, per game:  {1: 0.0, 2: 665.0, 3: 696.0, 4: 99.0, 5: 99.0, 6: 761.0, 7: 480.0, 8: 562.0, 9: 662.0, 10: 847.0, 11: 198.0, 12: 292.0, 13: 99.0, 14: 190.0, 15: 93.0, 16: 99.0, 17: 484.0, 18: 98.0, 19: 0.0, 20: 296.0, 21: 197.0, 22: 0.0, 23: 289.0, 24: 0.0, 25: 292.0, 26: 753.0, 27: 674.0, 28: 295.0, 29: 95.0, 30: 666.0, 31: 657.0, 32: 1514.0, 33: 457.0, 34: 660.0, 35: 295.0, 36: 728.0, 37: 291.0, 38: 485.0, 39: 575.0, 40: 196.0, 41: 196.0, 42: 197.0, 43: 659.0, 44: 99.0, 45: 198.0, 46: 92.0, 47: 288.0, 48: 285.0, 49: 196.0, 50: 563.0, 51: 197.0, 52: 0.0, 53: 484.0, 54: 94.0, 55: 293.0, 56: 468.0, 57: 197.0, 58: 293.0, 59: 98.0, 60: 393.0, 61: 192.0, 62: 485.0, 63: 577.0, 64: 754.0, 65: 393.0, 66: 385.0, 67: 0.0, 68: 99.0, 69: 858.0, 70: 98.0, 71: 570.0, 72: 99.0, 73: 0.0, 74: 198.0, 75: 198.0, 76: 195.0, 77: 291.0, 78: 99.0, 79: 191.0, 80: 0.0, 81: 275.0, 82: 197.0, 83: 578.0, 84: 475.0, 85: 857.0, 86: 0.0, 87: 196.0, 88: 99

Training Progress:  30%|███       | 3/10 [07:19<17:24, 149.22s/it]

Avg reward: 413.398  std: 284.1667461122078
Avg reward per agent, per game:  {1: 0.0, 2: 479.0, 3: 628.0, 4: 388.0, 5: 290.0, 6: 742.0, 7: 296.0, 8: 296.0, 9: 825.0, 10: 673.0, 11: 758.0, 12: 485.0, 13: 484.0, 14: 96.0, 15: 99.0, 16: 0.0, 17: 194.0, 18: 0.0, 19: 666.0, 20: 296.0, 21: 99.0, 22: 0.0, 23: 99.0, 24: 556.0, 25: 387.0, 26: 579.0, 27: 851.0, 28: 580.0, 29: 193.0, 30: 1021.0, 31: 476.0, 32: 1102.0, 33: 489.0, 34: 1206.0, 35: 198.0, 36: 936.0, 37: 98.0, 38: 579.0, 39: 1111.0, 40: 654.0, 41: 1363.0, 42: 664.0, 43: 584.0, 44: 382.0, 45: 198.0, 46: 0.0, 47: 0.0, 48: 278.0, 49: 480.0, 50: 574.0, 51: 578.0, 52: 441.0, 53: 863.0, 54: 97.0, 55: 99.0, 56: 286.0, 57: 99.0, 58: 674.0, 59: 99.0, 60: 670.0, 61: 388.0, 62: 937.0, 63: 385.0, 64: 389.0, 65: 391.0, 66: 373.0, 67: 484.0, 68: 197.0, 69: 745.0, 70: 368.0, 71: 390.0, 72: 488.0, 73: 0.0, 74: 198.0, 75: 584.0, 76: 479.0, 77: 393.0, 78: 99.0, 79: 376.0, 80: 0.0, 81: 197.0, 82: 197.0, 83: 677.0, 84: 572.0, 85: 767.0, 86: 0.0, 87: 198.

Training Progress:  40%|████      | 4/10 [10:03<15:29, 154.96s/it]

Avg reward: 425.954  std: 285.41055671435845
Avg reward per agent, per game:  {1: 283.0, 2: 761.0, 3: 264.0, 4: 578.0, 5: 198.0, 6: 296.0, 7: 578.0, 8: 489.0, 9: 912.0, 10: 665.0, 11: 770.0, 12: 485.0, 13: 392.0, 14: 96.0, 15: 99.0, 16: 196.0, 17: 194.0, 18: 98.0, 19: 483.0, 20: 391.0, 21: 197.0, 22: 0.0, 23: 196.0, 24: 379.0, 25: 387.0, 26: 386.0, 27: 929.0, 28: 736.0, 29: 287.0, 30: 843.0, 31: 294.0, 32: 1132.0, 33: 296.0, 34: 1465.0, 35: 923.0, 36: 486.0, 37: 575.0, 38: 944.0, 39: 1006.0, 40: 1202.0, 41: 1030.0, 42: 743.0, 43: 678.0, 44: 384.0, 45: 669.0, 46: 475.0, 47: 294.0, 48: 189.0, 49: 567.0, 50: 483.0, 51: 570.0, 52: 370.0, 53: 668.0, 54: 97.0, 55: 197.0, 56: 635.0, 57: 99.0, 58: 766.0, 59: 193.0, 60: 671.0, 61: 485.0, 62: 487.0, 63: 295.0, 64: 485.0, 65: 391.0, 66: 1042.0, 67: 484.0, 68: 740.0, 69: 860.0, 70: 99.0, 71: 843.0, 72: 673.0, 73: 0.0, 74: 198.0, 75: 863.0, 76: 387.0, 77: 581.0, 78: 99.0, 79: 376.0, 80: 0.0, 81: 197.0, 82: 99.0, 83: 763.0, 84: 474.0, 85: 489.0, 86:

Training Progress:  50%|█████     | 5/10 [12:35<12:49, 153.81s/it]

Avg reward: 409.039  std: 276.68851707109206
Avg reward per agent, per game:  {1: 289.0, 2: 469.0, 3: 99.0, 4: 480.0, 5: 198.0, 6: 393.0, 7: 390.0, 8: 296.0, 9: 512.0, 10: 1130.0, 11: 198.0, 12: 581.0, 13: 579.0, 14: 97.0, 15: 99.0, 16: 196.0, 17: 196.0, 18: 0.0, 19: 384.0, 20: 296.0, 21: 287.0, 22: 0.0, 23: 99.0, 24: 381.0, 25: 99.0, 26: 855.0, 27: 484.0, 28: 576.0, 29: 285.0, 30: 1456.0, 31: 483.0, 32: 1193.0, 33: 296.0, 34: 1646.0, 35: 642.0, 36: 393.0, 37: 390.0, 38: 1289.0, 39: 372.0, 40: 574.0, 41: 863.0, 42: 578.0, 43: 954.0, 44: 293.0, 45: 393.0, 46: 383.0, 47: 294.0, 48: 99.0, 49: 475.0, 50: 385.0, 51: 578.0, 52: 294.0, 53: 771.0, 54: 98.0, 55: 197.0, 56: 99.0, 57: 574.0, 58: 386.0, 59: 444.0, 60: 944.0, 61: 541.0, 62: 579.0, 63: 390.0, 64: 393.0, 65: 391.0, 66: 566.0, 67: 390.0, 68: 293.0, 69: 1041.0, 70: 291.0, 71: 753.0, 72: 393.0, 73: 0.0, 74: 198.0, 75: 863.0, 76: 387.0, 77: 575.0, 78: 287.0, 79: 98.0, 80: 197.0, 81: 294.0, 82: 99.0, 83: 1035.0, 84: 293.0, 85: 674.0, 86: 

In [None]:
# Plot the values returned by train_loop on an explicit, labelled axes so the
# figure can be read on its own, and end with plt.show() so the cell does not
# echo the list of Line2D objects.
fig, ax = plt.subplots()
ax.plot(rewards)
ax.set_title("Rewards returned by train_loop")
# NOTE(review): the x-axis is the position within `rewards`; rename it once
# the unit (e.g. training game) is confirmed against train_loop.
ax.set_xlabel("Index")
ax.set_ylabel("Reward")
plt.show()