In [1]:
import random

import torch

from agents import BasicAttackerAgent, WolpertingerDefenderAgent
from evaluation import Evaluator
from game import GameConfig
from models import StateShapeData, StateTensorBatch
from vehicles import JsonVehicleProvider, Vehicle, Vulnerability

In [2]:
# Seed the RNGs up front so the defender's network initialisation (and anything
# else drawing from these generators) is reproducible under Restart & Run All.
# NOTE(review): if the game/evaluator also draw from other RNGs (e.g. numpy),
# seed those here too — not visible from this notebook.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

vehicle_provider = JsonVehicleProvider("../subgame/python/solutions.json")
game_config = GameConfig(
    max_vehicles=30,
    cycle_every=3,
    cycle_num=5,
    cycle_allow_platoon=False,
)

# NOTE(review): the meaning of the positional `1` is not visible here; consider
# passing it by keyword. `attacker` is not used in the cells below.
attacker = BasicAttackerAgent(1)

# The defender's input shapes are derived from the game config and the vehicle
# provider so the network matches the tensors produced by the game state.
defender = WolpertingerDefenderAgent(
    state_shape_data=StateShapeData(
        num_vehicles=game_config.max_vehicles,
        num_vehicle_features=Vehicle.get_shape()[0],
        num_vulns=vehicle_provider.max_vulns,
        num_vuln_features=Vulnerability.get_shape()[0],
    )
)

engine = Evaluator(
    vehicle_provider=vehicle_provider,
    game_config=game_config,
    num_rounds=1000,
)
engine.reset()



#### State shape expectations

In [3]:
# Build a small batch by replicating the current game state along dim 0.
# For dim 0 to end up equal to batch_size, as_tensors must already return a
# batch of one (leading dim 1). Check that explicitly so a change in as_tensors
# fails here with a clear message instead of as a confusing shape mismatch later.
batch_size = 5
single_state = engine.game.state.as_tensors(defender.state_shape_data)
assert single_state.vulnerabilities.shape[0] == 1, single_state.vulnerabilities.shape
assert single_state.vehicles.shape[0] == 1, single_state.vehicles.shape
state_batch = StateTensorBatch(
    vulnerabilities=torch.cat([single_state.vulnerabilities] * batch_size, dim=0),
    vehicles=torch.cat([single_state.vehicles] * batch_size, dim=0),
)

shape_data = defender.state_shape_data

# vulnerabilities: (batch, vehicle, vulnerability, vulnerability feature).
# Comparing against a tuple also pins the rank.
shape = state_batch.vulnerabilities.shape
print("vulnerabilities", shape)
assert shape == (
    batch_size,
    game_config.max_vehicles,
    shape_data.num_vulns,
    shape_data.num_vuln_features,
), shape

# vehicles: (batch, vehicle, vehicle feature)
shape = state_batch.vehicles.shape
print("vehicles", shape)
assert shape == (
    batch_size,
    game_config.max_vehicles,
    shape_data.num_vehicle_features,
), shape

vulnerabilities torch.Size([5, 30, 7, 4])
vehicles torch.Size([5, 30, 2])


#### Proto-action shape expectations

In [4]:
# The actor emits a single proto-action per state, so both the members and the
# monitor tensors are expected to be (batch, 1, num_vehicles).
proto_action_batch = defender.actor(state_batch)
expected_proto_shape = (batch_size, 1, defender.state_shape_data.num_vehicles)
for field_name, field_tensor in (
    ("members", proto_action_batch.members),
    ("monitor", proto_action_batch.monitor),
):
    shape = field_tensor.shape
    print(field_name, shape)
    # Tuple equality checks rank and every dimension in one go.
    assert shape == expected_proto_shape, (field_name, shape)

torch.Size([5, 1, 30])
torch.Size([5, 1, 30])


#### Critic accepts the actor's raw (uncollapsed) proto-actions

In [5]:
print("state", *(t.shape for t in (state_batch.vehicles, state_batch.vulnerabilities)))
print("actions", *(t.shape for t in (proto_action_batch.members, proto_action_batch.monitor)))

# Feed the actor's output straight into the critic, without collapsing it to
# candidate actions first: expect exactly one Q-value per state in the batch.
q_values = defender.critic(state_batch, proto_action_batch)
shape = q_values.shape
print(shape)
assert tuple(shape) == (batch_size, 1)

state torch.Size([5, 30, 2]) torch.Size([5, 30, 7, 4])
actions torch.Size([5, 1, 30]) torch.Size([5, 1, 30])
torch.Size([5, 1])




#### Proto-action collapse expectations

In [6]:
# Collapsing expands each proto-action into several candidate actions:
# expected shape (batch, num_candidates > 1, num_vehicles) for both tensors.
actions = defender.collapse_proto_actions(proto_action_batch)
num_vehicles = defender.state_shape_data.num_vehicles
for field_name, field_tensor in (
    ("members", actions.members),
    ("monitor", actions.monitor),
):
    shape = field_tensor.shape
    print(field_name, shape)
    assert len(shape) == 3, (field_name, shape)
    assert shape[0] == batch_size, (field_name, shape)
    assert shape[1] > 1, (field_name, shape)  # should propose multiple actions for each proto-action
    assert shape[2] == num_vehicles, (field_name, shape)

# Both tensors are passed to the critic together as one action batch, so they
# presumably must line up candidate-for-candidate.
assert actions.members.shape == actions.monitor.shape, (
    actions.members.shape,
    actions.monitor.shape,
)

members torch.Size([5, 5, 30])
monitor torch.Size([5, 5, 30])


#### Critic shape expectations

In [7]:
print("state", *(t.shape for t in (state_batch.vehicles, state_batch.vulnerabilities)))
print("actions", *(t.shape for t in (actions.members, actions.monitor)))

# Score every candidate: one Q-value per (state, candidate action) pair.
q_values = defender.critic(state_batch, actions)
shape = q_values.shape
print(shape)

assert tuple(shape) == (batch_size, actions.members.shape[1])

# Index of the highest-scoring candidate for each state in the batch.
print(q_values.argmax(dim=1))

state torch.Size([5, 30, 2]) torch.Size([5, 30, 7, 4])
actions torch.Size([5, 5, 30]) torch.Size([5, 5, 30])
torch.Size([5, 5])
tensor([3, 3, 3, 3, 4])
