In [565]:
import torch
from torch import Tensor
import random
import math
from typing import List, Tuple, Union, Iterable
from tqdm.notebook import tqdm

## Config

In [566]:
# Type aliases: every alias is a plain torch Tensor; the comments give the layout.
State = Tensor  # (vehicle, vuln, feature); features: 0=prob, 1=severity, 2=compromised, 3=membership
StateBatch = Tensor  # (batch, vehicle, vuln, feature)
DefenderAction = Tensor  # (vehicle,) platoon-membership mask, 1 = vehicle joins the platoon
DefenderActionBatch = Tensor  # (batch, vehicle)
AttackerAction = Tensor  # (vehicle,) attack mask, 1 = vehicle is attacked
AttackerActionBatch = Tensor  # (batch, vehicle)
Reward = Tensor  # reward of a single transition
RewardBatch = Tensor  # (batch,)
Terminal = Tensor  # bool flag: the episode ended on this transition
TerminalBatch = Tensor  # (batch,)

MAX_VEHICLES = 10  # vehicle slots per environment
MAX_VULNS = 3  # vulnerability slots per vehicle
MAX_ATTACK = 2  # vehicles the attacker targets per step
MEMORY_SIZE = 10000  # replay memory capacity (oldest transitions are evicted)
MEMORY_WARMUP = 2000  # transitions collected by the random-policy warmup
BATCH_SIZE = 100  # environments simulated in parallel during warmup
LEARNING_RATE = 0.001
# Conditional expression instead of the fragile `a and b or c` idiom.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [567]:
def get_random_starting_state() -> State:
    """Sample a fresh environment state in which nothing is compromised yet.

    Each vehicle receives ``random.randint(0, MAX_VULNS)`` vulnerabilities,
    filling slots from index 0 upward; unused slots stay all-zero. For every
    filled slot:

    * feature 0 (exploit probability) is a Normal(0.5, 0.25) draw clamped to [0.05, 1];
    * feature 1 (severity) is a Normal(2, 1) draw clamped to [1, 5], truncated
      with ``int`` and squared, so it is one of 1, 4, 9, 16, 25 (25 only when
      the draw is clamped at 5);
    * features 2 (compromised) and 3 (membership) stay 0.

    Returns:
        A float32 tensor of shape ``(MAX_VEHICLES, MAX_VULNS, 4)``.
    """
    exploit_prob = torch.distributions.Normal(
        loc=torch.as_tensor(0.5, dtype=torch.float32),
        scale=torch.as_tensor(0.25, dtype=torch.float32),
    )
    severity = torch.distributions.Normal(
        loc=torch.as_tensor(2, dtype=torch.float32),
        scale=torch.as_tensor(1, dtype=torch.float32),
    )
    state = torch.zeros((MAX_VEHICLES, MAX_VULNS, 4), dtype=torch.float32)

    for vehicle in range(MAX_VEHICLES):
        num_vulns = random.randint(0, MAX_VULNS)
        for slot in range(num_vulns):
            # draw order (probability first, then severity) matters for seeded reproducibility
            prob_draw = exploit_prob.sample().clamp(0.05, 1)
            sev_draw = severity.sample().clamp(1, 5)
            state[vehicle, slot, 0] = float(prob_draw)
            state[vehicle, slot, 1] = int(sev_draw) ** 2
            # compromised (2) and membership (3) are already zero from torch.zeros

    return state

In [568]:
def get_empty_state() -> State:
    """Return a new all-zero float32 state of shape ``(MAX_VEHICLES, MAX_VULNS, 4)``."""
    shape = (MAX_VEHICLES, MAX_VULNS, 4)
    return torch.zeros(*shape, dtype=torch.float32)

In [569]:
def batch_states(states: Iterable[State]) -> StateBatch:
    """Stack individual states into one batch along a new leading dimension.

    Args:
        states: any iterable of equally shaped states (list, tuple, generator, ...).

    Returns:
        Tensor of shape ``(number_of_states, *state.shape)``.

    Raises:
        RuntimeError: from ``torch.stack`` when ``states`` is empty or shapes differ.
    """
    # torch.stack only accepts a list/tuple, so materialize other iterables first.
    return torch.stack(list(states))

In [570]:
def get_random_state_batch(num_batches: int) -> StateBatch:
    """Sample ``num_batches`` independent starting states and stack them.

    Args:
        num_batches: number of states to draw; 0 yields an empty batch.

    Returns:
        Tensor of shape ``(num_batches, MAX_VEHICLES, MAX_VULNS, 4)``.

    Raises:
        ValueError: if ``num_batches`` is negative.
    """
    if num_batches < 0:
        raise ValueError(f"num_batches must be non-negative, got {num_batches}")
    if num_batches == 0:
        # torch.stack rejects an empty list; the warmup reset calls this with the
        # number of finished episodes, which can be zero.
        return torch.zeros((0, MAX_VEHICLES, MAX_VULNS, 4), dtype=torch.float32)
    return batch_states([get_random_starting_state() for _ in range(num_batches)])

In [571]:
def get_attacker_actions(states: StateBatch) -> AttackerActionBatch:
    """Choose which vehicles the attacker targets in every state of a batch.

    Each vehicle is scored as ``sum over vuln slots of prob * severity * (1 - compromised)``,
    i.e. only not-yet-compromised vulnerabilities count. The ``MAX_ATTACK``
    highest-scoring vehicles are attacked (capped at the number of vehicles);
    ties are resolved by ``torch.topk``.

    Args:
        states: ``(batch, vehicles, vulns, 4)`` state batch.

    Returns:
        float32 ``(batch, vehicles)`` mask, 1 for attacked vehicles, on ``states.device``.
    """
    prob, severity, compromised = states[..., 0], states[..., 1], states[..., 2]
    priority = (prob * severity * (1 - compromised)).sum(dim=-1)
    # topk raises if k exceeds the number of vehicles
    num_targets = min(MAX_ATTACK, priority.shape[-1])
    targets = priority.topk(num_targets, dim=-1).indices
    # allocate on the states' device so scatter_ works off-CPU too
    mask = torch.zeros(priority.shape, dtype=torch.float32, device=states.device)
    return mask.scatter_(1, targets, 1)

In [572]:
def apply_attacker_actions(states: StateBatch, actions: AttackerActionBatch) -> StateBatch:
    """Resolve the attack: each vuln of an attacked vehicle is exploited with its probability.

    A vulnerability with exploit probability ``p`` (feature 0) on a vehicle whose
    action entry equals 1 gets its compromised flag (feature 2) set with
    probability ``p``. Existing compromise flags are kept.

    Args:
        states: ``(batch, vehicles, vulns, 4)`` state batch; not modified.
        actions: ``(batch, vehicles)`` attack mask; entries equal to 1 are attacked.

    Returns:
        A new state batch with updated compromised flags.
    """
    # batch size must be the same
    assert states.shape[0] == actions.shape[0]

    # create a copy so we don't modify the original
    states = states.clone()

    # one uniform roll per vulnerability slot, drawn on the states' device
    rolls = torch.rand(states.shape[:3], dtype=torch.float32, device=states.device)

    # only keep rolls for attacked vehicles (broadcast over vuln slots); a zero roll
    # never beats the threshold 1 - p as long as p <= 1 (the generator clamps p to <= 1)
    rolls = rolls.masked_fill((actions != 1).unsqueeze(-1), 0)

    # set the vulnerability compromised flag to 1 for each successful attack
    states[..., 2] += (rolls > 1 - states[..., 0]).float()
    states[..., 2] = states[..., 2].clamp(0, 1)
    return states

In [573]:
def get_random_defender_actions(states: StateBatch) -> DefenderActionBatch:
    """Draw a random platoon-membership mask for every state in the batch.

    Each vehicle joins independently with probability 1/2.

    Args:
        states: ``(batch, vehicles, ...)`` state batch; only its shape and device are used.

    Returns:
        float32 ``(batch, vehicles)`` mask of 0/1 on ``states.device``.
    """
    batch_size, num_vehicles = states.shape[0], states.shape[1]
    coin_flips = torch.rand((batch_size, num_vehicles), dtype=torch.float32, device=states.device)
    return (coin_flips > 0.5).float()

In [574]:
def apply_defender_actions(states: StateBatch, actions: DefenderActionBatch) -> StateBatch:
    """Write the defender's platoon choice into the membership feature.

    Args:
        states: ``(batch, vehicles, vulns, 4)`` state batch; not modified.
        actions: ``(batch, vehicles)`` mask; vehicles whose entry equals 1 join the platoon.

    Returns:
        A new state batch whose feature 3 is 1 on every vuln slot of a chosen
        vehicle and 0 everywhere else (previous membership is discarded).
    """
    # batch size must be the same
    assert states.shape[0] == actions.shape[0]

    # create a copy so we don't modify the original
    states = states.clone()

    # broadcast each vehicle's 0/1 choice over all of its vuln slots in one
    # assignment instead of looping over the batch in Python
    joined = (actions == 1).to(states.dtype).unsqueeze(-1)
    states[..., 3] = joined
    return states

In [575]:
def get_defender_utilities(states: StateBatch) -> RewardBatch:
    """Compute the defender's reward for every state in a batch.

    The reward is the platoon size (vehicles carrying the membership flag), or 0
    if any member vehicle has a compromised vulnerability.

    Args:
        states: ``(batch, vehicles, vulns, 4)`` state batch.

    Returns:
        Tensor of shape ``(batch,)``.
    """
    compromised = states[:, :, :, 2]
    membership = states[:, :, :, 3]

    # clean platoon <=> no (member vehicle, vuln) pair is flagged as compromised
    breaches = (compromised * membership).sum(dim=[-1, -2])
    is_clean = (breaches == 0).float()

    # a vehicle counts once if any of its vuln slots carries the membership flag
    platoon_size = membership.max(dim=-1).values.sum(dim=-1)

    return platoon_size * is_clean

In [576]:
import torch.nn as nn
import torch.nn.functional as F

In [577]:
class DefenderActor(nn.Module):
    """Actor network: maps a batch of states to one tanh-bounded score per vehicle slot.

    All layers are lazy, so their input sizes are inferred on the first forward pass.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.LazyConv2d(out_channels=5, kernel_size=2, stride=1)
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.LazyLinear(128)
        self.fc3 = nn.LazyLinear(MAX_VEHICLES)

    def forward(self, x: StateBatch) -> DefenderActionBatch:
        """Score every vehicle for each state in the batch.

        Args:
            x: state batch ``(batch, vehicles, vulns, features)``. Conv2d reads a
                4-D input as ``(N, C, H, W)``, so vehicles act as input channels
                here. NOTE(review): confirm that is intended.

        Returns:
            ``(batch, MAX_VEHICLES)`` tensor with values in ``[-1, 1]``.
            NOTE(review): apply_defender_actions only treats entries equal to 1 as
            chosen; confirm how these scores are turned into a membership mask.
        """
        x = torch.cat(
            (
                F.gelu(self.conv1(x)).flatten(start_dim=1),
                x.flatten(start_dim=1),  # skip connection after conv
            ),
            # join per-sample feature vectors; the default dim=0 would try to stack
            # along the batch axis and fail because the widths differ
            dim=1,
        )
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

In [578]:
class DefenderCritic(nn.Module):
    """Critic network: maps a (state batch, defender action batch) pair to one scalar per sample.

    All layers are lazy, so their input sizes are inferred on the first forward pass.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.LazyConv2d(out_channels=5, kernel_size=2, stride=1)
        self.fc1 = nn.LazyLinear(256)
        self.fc2 = nn.LazyLinear(128)
        self.fc3 = nn.LazyLinear(1)

    def forward(self, x1: StateBatch, x2: DefenderActionBatch) -> Reward:
        """Estimate the value of taking defender actions ``x2`` in states ``x1``.

        Args:
            x1: state batch ``(batch, vehicles, vulns, features)``.
            x2: defender action batch ``(batch, vehicles)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        conv_features = F.gelu(self.conv1(x1)).flatten(start_dim=1)
        raw_state = x1.flatten(start_dim=1)  # skip connection after conv
        # column-wise join of 2-D tensors (what hstack does for them)
        hidden = torch.cat((conv_features, raw_state, x2), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.gelu(layer(hidden))
        return self.fc3(hidden)

In [579]:
# Online networks plus target networks meant to start from the same weights.
# NOTE(review): every layer of DefenderActor/DefenderCritic is Lazy*, so the
# parameters are presumably still uninitialized here. Per the PyTorch lazy-module
# docs, loading an uninitialized state dict leaves the target uninitialized, in
# which case online and target nets get independent weights on their first
# forward pass. Consider a dry run with a dummy batch through both networks
# before copying.
actor, actor_target = DefenderActor(), DefenderActor()
actor_target.load_state_dict(actor.state_dict())
critic, critic_target = DefenderCritic(), DefenderCritic()
critic_target.load_state_dict(critic.state_dict())

<All keys matched successfully>

In [580]:
# Adam for each network; each StepLR multiplies its optimizer's learning rate
# by 0.9 every 100 scheduler steps.
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=LEARNING_RATE)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE)
actor_scheduler = torch.optim.lr_scheduler.StepLR(actor_optimizer, step_size=100, gamma=0.9)
critic_scheduler = torch.optim.lr_scheduler.StepLR(critic_optimizer, step_size=100, gamma=0.9)

In [581]:
from collections import deque
from dataclasses import dataclass

@dataclass
class Transition:
    """One stored step of experience for the defender.

    NOTE: the dataclass-generated ``__eq__`` compares fields with ``==``, which can
    raise for multi-element tensors; compare transitions field-wise with
    ``torch.equal`` instead.
    """

    state: State  # state the defender acted in
    action: DefenderAction  # membership mask the defender chose
    reward: Reward  # reward received for this step
    next_state: State  # resulting state (the warmup loop stores zeros on terminal steps)
    terminal: Terminal  # whether the episode ended on this step


@dataclass
class TransitionBatch:
    """Transitions stacked field-wise, each field gaining a leading batch dimension."""

    states: StateBatch
    actions: DefenderActionBatch
    rewards: RewardBatch
    next_states: StateBatch
    terminals: TerminalBatch


# Replay memory: once MEMORY_SIZE transitions are stored, appending evicts the oldest.
memory: deque[Transition] = deque([], MEMORY_SIZE)

In [582]:
def sample_memory(batch_size: int) -> TransitionBatch:
    """Draw ``batch_size`` distinct transitions uniformly from ``memory`` and stack them.

    Raises:
        ValueError: from ``random.sample`` if ``batch_size`` exceeds the number of
            stored transitions (or is negative).
    """
    picked = random.sample(memory, batch_size)

    def stack_field(get) -> Tensor:
        # stack one field of every picked transition along a new batch dimension
        return torch.stack([get(transition) for transition in picked])

    return TransitionBatch(
        states=stack_field(lambda t: t.state),
        actions=stack_field(lambda t: t.action),
        rewards=stack_field(lambda t: t.reward),
        next_states=stack_field(lambda t: t.next_state),
        terminals=stack_field(lambda t: t.terminal),
    )

## Training

### Warmup memory

In [583]:
# Fill the replay memory with transitions from a random defender policy facing the
# attacker, simulating BATCH_SIZE environments in parallel.
states = get_random_state_batch(BATCH_SIZE)

for _ in tqdm(range(MEMORY_WARMUP // BATCH_SIZE)):
    # defender forms a platoon, then the attacker strikes the resulting state
    defender_actions = get_random_defender_actions(states)
    next_states = apply_defender_actions(states, defender_actions)
    attacker_actions = get_attacker_actions(next_states)
    next_states = apply_attacker_actions(next_states, attacker_actions)
    rewards = get_defender_utilities(next_states)
    # an episode ends whenever the defender earns nothing
    terminals = rewards == 0

    # track next states as empty if terminal
    next_states[terminals] = get_empty_state()

    for i in range(BATCH_SIZE):
        memory.append(Transition(
            state=states[i],
            action=defender_actions[i],
            reward=rewards[i],
            next_state=next_states[i],
            terminal=terminals[i],
        ))

    # reset environment for terminal states; only sample fresh states when at
    # least one episode ended (there is nothing to replace otherwise)
    states = next_states.clone()
    if terminals.any():
        states[terminals] = get_random_state_batch(int(terminals.sum()))

  0%|          | 0/20 [00:00<?, ?it/s]

In [606]:
# Sanity check: total reward carried by 1000 randomly sampled stored transitions.
torch.sum(sample_memory(1000).rewards)

tensor(1178.)

### Train loop