In [1]:
%run GameEnv.ipynb

import math

import numpy as np
from torch import cat
from torch import tensor
from torch.nn import Module, Conv1d, Conv2d, Linear, ReLU, LayerNorm, Sequential, MultiheadAttention
from torch.nn.functional import pad

In [2]:
class SAPnet(Module):
    """
    Network that maps one game observation (team, shop, foods, scalar data)
    to one estimated q-value per action.

    Each of the four inputs has its own encoder; the encoder outputs are
    flattened, concatenated and passed through a 3-layer feed-forward head:

    * team  -> Conv2d over (3 slots x full feature width) -> ReLU -> Conv1d(3)
               -> LayerNorm over the slot axis
    * shop  -> right-pad features -> 2 x (self-attention + residual + LayerNorm)
               -> feed-forward -> LayerNorm
    * foods -> right-pad features -> 1 x (self-attention + residual + LayerNorm)
               -> feed-forward -> LayerNorm
    * data  -> LayerNorm
    """

    def __init__(self, env):
        """
        param env: environment object; the attributes n_species, n_foods,
                   n_status, n_data, n_actions and max_n_foods are read from it.

        Also relies on the module-level names ``Team`` and ``Shop`` (their
        ``max_size`` attributes give the number of team / shop slots).
        """
        super().__init__()

        n_species = env.n_species
        n_foods = env.n_foods
        n_status = env.n_status
        n_team_pet_data = 7      # atk, hp, lvl, exp, pos, temp_atk, temp_hp
        n_shop_pet_data = 4      # atk, hp, cost, is_frozen
        n_foods_data = 2         # cost, is_frozen
        self.n_data = env.n_data       # gold, lives, wins, turn, lost_last_battle, n_actions_left
        n_actions = env.n_actions
        # Kept on the module so forward() can validate its input without
        # reaching for a global `env` (which may not exist, or may be a
        # different environment than the one this network was built for).
        self.max_n_foods = env.max_n_foods

        # --- Team encoder ---------------------------------------------------
        # (B, 1, team_slots, team_pet_length) -> Conv2d -> ReLU -> Conv1d -> LayerNorm
        # -> (B, team_nchannels2, team_slots)
        team_pet_length = n_team_pet_data + n_species + n_status
        team_nchannels1 = 20
        team_nchannels2 = 50
        team_kernel_size = 3     # number of neighbouring team slots seen by each conv

        # The 2d kernel spans the whole feature width, so its output has width 1
        # (squeezed away in forward); padding of 1 along the slot axis keeps the
        # slot count unchanged.
        self.conv2d = Conv2d(1, team_nchannels1, (team_kernel_size, team_pet_length), padding=(1, 0))
        self.relu1 = ReLU()
        self.conv1d = Conv1d(team_nchannels1, team_nchannels2, team_kernel_size, padding=1)
        self.team_norm = LayerNorm(Team.max_size)

        # --- Shop encoder ---------------------------------------------------
        # (B, shop_slots, shop_input_size) -> 2 x (self-attention + residual + norm)
        # -> Linear -> ReLU -> Linear -> norm -> (B, shop_slots, shop_output_size)
        shop_nheads = 4
        shop_input_size = n_shop_pet_data + n_species
        # The embedding size must be divisible by the number of heads, so the
        # features are zero-padded on the right. NOTE: this formula pads between
        # 1 and shop_nheads columns (a full extra group when the size is already
        # divisible); it is left unchanged so parameter shapes stay stable.
        shop_npad = shop_nheads - (shop_input_size % shop_nheads)
        shop_input_size += shop_npad
        shop_output_size = 32

        self.shop_pad = (0, shop_npad)
        # batch_first=True: inputs are laid out (batch, slots, features), so the
        # attention must run over the slot axis, not over the batch axis.
        self.attention1 = MultiheadAttention(shop_input_size, shop_nheads, batch_first=True)
        self.attention2 = MultiheadAttention(shop_input_size, shop_nheads, batch_first=True)
        self.shop_norm1 = LayerNorm(shop_input_size)
        self.shop_norm2 = LayerNorm(shop_input_size)
        self.ff1 = Sequential(Linear(shop_input_size, 2 * shop_input_size),
                              ReLU(),
                              Linear(2 * shop_input_size, shop_output_size))
        self.shop_norm3 = LayerNorm(shop_output_size)

        # --- Food encoder ---------------------------------------------------
        # (B, food_slots, foods_input_size) -> self-attention + residual + norm
        # -> Linear -> ReLU -> Linear -> norm -> (B, food_slots, foods_output_size)
        foods_nheads = 4
        foods_input_size = n_foods_data + n_foods
        foods_npad = foods_nheads - (foods_input_size % foods_nheads)  # same padding rule as the shop
        foods_input_size += foods_npad
        foods_output_size = 16

        self.foods_pad = (0, foods_npad)
        self.attention3 = MultiheadAttention(foods_input_size, foods_nheads, batch_first=True)
        self.foods_norm1 = LayerNorm(foods_input_size)
        self.ff2 = Sequential(Linear(foods_input_size, 2 * foods_input_size),
                              ReLU(),
                              Linear(2 * foods_input_size, foods_output_size))
        self.foods_norm2 = LayerNorm(foods_output_size)

        # --- Scalar data ----------------------------------------------------
        # (B, n_data) -> LayerNorm -> (B, n_data)
        self.data_norm = LayerNorm(self.n_data)

        # --- Head -----------------------------------------------------------
        # concat of all flattened encoder outputs -> Linear -> ReLU -> Linear
        # -> ReLU -> Linear -> (B, n_actions); no mask / softmax is applied here.
        final_input_size = Team.max_size*team_nchannels2 + \
                            Shop.max_size*shop_output_size + \
                            env.max_n_foods*foods_output_size + \
                            self.n_data
        final_layer1_size = 400
        final_layer2_size = 200
        final_output_size = n_actions

        self.final_ff = Sequential(Linear(final_input_size, final_layer1_size),
                                   ReLU(),
                                   Linear(final_layer1_size, final_layer2_size),
                                   ReLU(),
                                   Linear(final_layer2_size, final_output_size),
                                   # mask?
                                   # softmax?
                                  )

    def _as_input(self, array):
        """Turn an input array into a tensor matching the dtype and device of the network weights."""
        reference = self.conv2d.weight
        return tensor(array, dtype=reference.dtype, device=reference.device)

    def forward(self, team, shop, foods, data):
        """
        param team: np.array with dim (Team.max_size, x) or (B, Team.max_size, x)
        param shop: np.array with dim (Shop.max_size, x) or (B, Shop.max_size, x)
        param foods: np.array with dim (max_n_foods, x) or (B, max_n_foods, x)
        param data: np.array with dim (n_data,) or (B, n_data)

        Either all four inputs are batched or none of them is.

        output: tuple (qvalues, is_batch)
            qvalues: tensor of estimated qvalues for each action, with dim
                     (n_actions,) for non-batched and (B, n_actions) for batched input
            is_batch: bool, True if the input was batched
        """
        # convert non-batch input into batch input
        if team.ndim == 3 or shop.ndim == 3 or foods.ndim == 3 or data.ndim == 2:
            assert team.ndim == shop.ndim == foods.ndim == 3 and data.ndim == 2, \
            f'batched input must have ndim (3, 3, 3, 2), got ({team.ndim}, {shop.ndim}, {foods.ndim}, {data.ndim})'
            is_batch = True
        else:
            assert team.ndim == shop.ndim == foods.ndim == 2 and data.ndim == 1, \
            f'non-batched input must have ndim (2, 2, 2, 1), got ({team.ndim}, {shop.ndim}, {foods.ndim}, {data.ndim})'
            is_batch = False
            team = np.expand_dims(team, 0)
            shop = np.expand_dims(shop, 0)
            foods = np.expand_dims(foods, 0)
            data = np.expand_dims(data, 0)

        assert team.shape[1] == Team.max_size, f'expected {Team.max_size} but got {team.shape[1]}'
        assert shop.shape[1] == Shop.max_size, f'expected {Shop.max_size} but got {shop.shape[1]}'
        assert foods.shape[1] == self.max_n_foods, f'expected {self.max_n_foods} but got {foods.shape[1]}'
        assert data.shape[1] == self.n_data, f'expected {self.n_data} but got {data.shape[1]}'

        # team: add a channel axis -> (B, 1, slots, features); the conv2d kernel
        # covers the full feature width, leaving a trailing axis of size 1.
        team = self._as_input(np.expand_dims(team, -3))
        x1 = self.conv2d(team).squeeze(-1)
        x1 = self.relu1(x1)
        x1 = self.conv1d(x1)
        x1 = self.team_norm(x1)

        # shop: two self-attention blocks, each with residual connection + norm
        shop = pad(self._as_input(shop), self.shop_pad)
        x2 = self.shop_norm1(self.attention1(shop, shop, shop)[0] + shop)
        x2 = self.shop_norm2(self.attention2(x2, x2, x2)[0] + x2)
        x2 = self.shop_norm3(self.ff1(x2))

        # foods: one self-attention block with residual connection + norm
        foods = pad(self._as_input(foods), self.foods_pad)
        x3 = self.foods_norm1(self.attention3(foods, foods, foods)[0] + foods)
        x3 = self.foods_norm2(self.ff2(x3))

        data = self._as_input(data)
        x4 = self.data_norm(data)

        # flatten everything but the batch axis and join the four encodings
        full = cat([x.flatten(1) for x in [x1, x2, x3, x4]], 1)
        if not is_batch:
            # drop the batch axis that was added above
            full = full.flatten()

        x = self.final_ff(full)

        return x, is_batch
        
            