In [4]:
import numpy as np
import torch
import torch.nn as nn
from typing import List, Tuple

In [5]:
class Q_Value_NN(nn.Module):
    """
    Value function approximator
    """

    def __init__(self, num_locs: int, max_cap: int) -> None:
        super().__init__()

        self.num_locs = num_locs
        self.max_cap = max_cap

        # num_embeddings 指定了嵌入层能够表示的离散值的数量
        # embedding_dim 表示嵌入的维度，即每个离散值将被映射到多少维的实数向量
        # padding_idx 是用于指定填充值的索引，当输入序列中出现填充值时，嵌入层会将其映射为全零向量
        self.location_embed = nn.Embedding(
            num_embeddings=num_locs + 1, embedding_dim=100, padding_idx=0
        )

        self.lstm = nn.LSTM(
            input_size=max_cap * 2 + 1,
            hidden_size=200,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        self.time_embed = nn.Linear(in_features=1, out_features=100)
        self.state_embed = nn.Sequential(
            nn.Linear(in_features=200 * 2 + 100 + 2, out_features=300),
            nn.ReLU(),
            nn.Linear(in_features=300, out_features=300),
            nn.ReLU(),
        )
        self.output_layer = nn.Linear(300, 1)

    def forward(self, data_inputs: List[Tuple[torch.Tensor]]) -> List[torch.Tensor]:

        # path_location_input (num_agents, max_cap*2+1)
        # delay_input (num_agents, max_cap*2+1)
        # current_time_input (num_agents,)
        # other_agents_input (num_agents,)
        # num_requests_input (num_agents,)

        outputs: List[torch.Tensor] = []

        for _, data in enumerate(data_inputs):
            (
                path_location_input,
                delay_input,
                current_time_input,
                other_agents_input,
                num_requests_input,
            ) = data

            path_location_embed = self.location_embed(
                path_location_input
            )  # (num_actions*num_agents, max_cap*2+1, embedding_dim)
            delay_masked = torch.masked_fill(
                delay_input, delay_input == -1, 0
            )  # (num_actions*num_agents, max_cap*2+1)
            path_input = torch.concatenate(
                [path_location_embed, delay_masked], dim=-1
            )  # (num_actions*num_agents, max_cap*2+1, embedding_dim+1)

            path_embed: torch.Tensor = self.lstm(
                path_input
            )  # (num_actions*num_agents, max_cap*2+1, hidden_size*2)

            current_time_embed: torch.Tensor = self.time_embed(
                current_time_input
            )  # (100)
            current_time_embed = (
                current_time_embed.unsqueeze(0)
                .unsqueeze(1)
                .repeat(len(path_embed[0]), len(path_embed[1]), 1)
            )  # (num_actions*num_agents, max_cap*2+1, 100)

            other_agents_input: torch.Tensor = (
                other_agents_input.unsqueeze(-1)
                .unsqueeze(-1)
                .repeat(1, len(path_embed[1]), 1)
            )
            # (num_actions*num_agents, max_cap*2+1, 1)

            num_requests_input: torch.Tensor = (
                num_requests_input.unsqueeze(-1)
                .unsqueeze(-1)
                .repeat(1, len(path_embed[1]), 1)
            )
            # (num_actions*num_agents, max_cap*2+1, 1)

            state_embed_input = torch.concatenate(
                [
                    path_embed,
                    current_time_embed,
                    other_agents_input,
                    num_requests_input,
                ],
                dim=-1,
            ).view(-1, len(path_embed[1]) * (len(path_embed[2]) + 102))
            # (num_actions*num_agents, max_cap*2+1, hidden_size*2+100+1+1)
            state_embed: torch.Tensor = self.state_embed(state_embed_input)
            # (num_actions*num_agents, max_cap*2+1, 300)

            outputs.append(torch.sum(self.output_layer(state_embed), dim=1))
            # (num_actions*num_agents, max_cap*2+1, 1)
            # (num_actions*num_agents, 1)

        return outputs

In [36]:
net1 = Q_Value_NN(5,6)
net2 = Q_Value_NN(5,6).load_state_dict(net1.state_dict())


In [41]:
a = torch.tensor([1,2,3,4])
a.unsqueeze(-1)

tensor([[1],
        [2],
        [3],
        [4]])

In [37]:
def fun(model1, model2):
    params = model1.state_dict()
    print(params)
    print(model2.state_dict())

In [38]:
fun(net1,net2)

OrderedDict([('location_embed.weight', tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,

AttributeError: '_IncompatibleKeys' object has no attribute 'state_dict'