<a href="https://colab.research.google.com/github/Strojove-uceni/2024-final-letadylka-prochazka-belohlavek/blob/main/architectures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Architectures
This file contains the description of the main architectures used within this project. Below we provide a detailed description of them.

First, we describe the models that were used for Q-value prediction:

    - DGN
    - DQN
    - Comm_net

Then we go over the methods that were used to aggregate hidden graph representations:

    - SUM
    - GCN

and lastly we describe the **NetMon** class that was originally provided by the authors.




## Q-values prediction
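The models in this section map the observations of all agents to per-agent Q-values, i.e. one value for every available action.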


### Classical MLP


In [None]:
class MLP(nn.Module):
    """
    Multi-layer perceptron that is the underlying module for all models within this work.

    A chain of ``nn.Linear`` layers is built from ``mlp_units``. Every hidden layer is
    followed by the (optional) activation and by dropout. The last layer receives the
    same treatment only when ``activation_on_output`` is set.

    :param in_features: size of the last dimension of the input
    :param mlp_units: output width of every layer, either a single int or an iterable of ints
    :param activation_fn: callable applied after the linear layers, or None to skip the activation
    :param activation_on_output: whether activation and dropout are also applied after the last layer
    """

    def __init__(self, in_features, mlp_units, activation_fn, activation_on_output = True):
        super(MLP, self).__init__()

        self.activation = activation_fn
        self.dropout = nn.Dropout(0.3)

        self.linear_layers = nn.ModuleList() # Storage for L layers
        previous_units = in_features

        # Transform units into a list
        if isinstance(mlp_units, int):
            mlp_units = [mlp_units]

        # Create a chain of layers, each one consuming the output width of its predecessor
        for units in mlp_units:
            self.linear_layers.append(nn.Linear(previous_units, units))
            previous_units = units

        # Width of the tensor returned by forward(); equals in_features when no layer was requested
        self.out_features = previous_units
        # The attribute keeps its historical (misspelled) name so existing readers of it keep working
        self.activation_on_ouput = activation_on_output

    # Forward pass
    def forward(self, x):
        """
        Pass ``x`` through all layers.

        :param x: tensor whose last dimension has size ``in_features``
        :return: tensor whose last dimension has size ``out_features``
        """
        # Without any layer the module is the identity, which matches out_features == in_features
        if len(self.linear_layers) == 0:
            return x

        # Inter layers
        for module in self.linear_layers[:-1]:
            x = module(x)
            if self.activation is not None:
                x = self.activation(x)
            x = self.dropout(x)

        # Pass through the last layer
        x = self.linear_layers[-1](x)
        if self.activation_on_ouput:
            # Same None guard as for the hidden layers: a missing activation must not crash here
            if self.activation is not None:
                x = self.activation(x)
            x = self.dropout(x)

        return x

### Attention model

In [None]:
class AttModel(nn.Module):
    """
    Basic multi-head attention model with masking and scaling.

    Processing steps of the forward pass:
        Input x
        -> linear layers producing Values, Queries, Keys
        -> optional activation (vkq_activation_fn)
        -> reshape and transpose to separate the attention heads
        -> attention scores (dot product, scale, mask, softmax, dropout)
        -> attention applied to the Values
        -> skip connection (Values are added to the attended Values)
        -> heads are concatenated again
        -> final linear layer, optional activation and dropout

    :param in_features: size of the last dimension of the input
    :param k_features: number of Key/Query features per head
    :param v_features: number of Value features per head
    :param out_features: size of the last dimension of the output
    :param num_heads: number of attention heads
    :param activation_fn: callable applied after the output layer, or None to skip it
    :param vkq_activation_fn: callable applied to Values, Keys and Queries, or None to skip it
    """

    def __init__(self, in_features, k_features, v_features, out_features, num_heads, activation_fn, vkq_activation_fn):
        super(AttModel, self).__init__()

        self.k_features = k_features
        self.v_features = v_features
        self.num_heads = num_heads      # Number of attention heads

        self.fc_v = nn.Linear(in_features, v_features * num_heads)  # Transforming input features into Values for attention
        self.fc_k = nn.Linear(in_features, k_features * num_heads)  # Transforming input features into Keys for attention
        self.fc_q = nn.Linear(in_features, k_features * num_heads)  # Transforming input features into Queries for attention

        self.fc_out = nn.Linear(v_features * num_heads, out_features)   # Transforms the outputs from all attention heads into output dimension

        self.activation = activation_fn
        self.vkq_activation = vkq_activation_fn     # Activation function that can be applied to Values, Keys, Queries

        # Scaling factor 1 / sqrt(d_k), as in "Attention is All You Need". Without it the
        # dot products of queries and keys grow with the feature dimension and the
        # gradients become too large.
        self.attention_scale = 1 / (k_features **0.5)

        self.dropout = nn.Dropout(0.1)

    # Forward pass
    def forward(self, x, mask):
        """
        Compute masked multi-head attention over the agent dimension.

        :param x: input of shape (batch_size, num_agents, in_features)
        :param mask: tensor broadcastable to (batch_size, num_agents, num_agents) once a head
            axis is inserted at position 1; entries equal to 0 are excluded from the attention
        :return: tuple ``(out, att_weights)`` where ``out`` has shape
            (batch_size, num_agents, out_features) and ``att_weights`` are the scaled
            query-key scores of shape (batch_size, num_heads, num_agents, num_agents)
            *before* masking and softmax
        """
        batch_size, num_agents = x.shape[0], x.shape[1]

        # Linear maps produce Values, Queries and Keys; the view separates the attention
        # heads and results in (batch_size, num_agents, num_heads, features_per_head)
        v = self.fc_v(x).view(batch_size, num_agents, self.num_heads, self.v_features)
        q = self.fc_q(x).view(batch_size, num_agents, self.num_heads, self.k_features)
        k = self.fc_k(x).view(batch_size, num_agents, self.num_heads, self.k_features)

        if self.vkq_activation is not None:
            v = self.vkq_activation(v)
            q = self.vkq_activation(q)
            k = self.vkq_activation(k)

        # We rearrange the tensors to shape (batch_size, num_heads, num_agents, features_per_head)
        # This is done so we can perform batch multiplication over the batch size and heads
        q, k, v = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2)

        # Add head axis (we are keeping the same mask for all attention heads)
        mask = mask.unsqueeze(1)    # (batch_size, 1, num_agents, num_agents)

        # The attention is calculated as a dot product of all queries with all keys,
        # scaled with the attention scale so it does not explode.
        #   - q is of shape             (batch_size, num_heads, num_agents, features_per_head)
        #   - k transposed is of shape  (batch_size, num_heads, features_per_head, num_agents)
        #   - the multiplication result is of shape (batch_size, num_heads, num_agents, num_agents)
        # masked_fill sets positions where mask == 0 to a large negative value, which
        # practically removes them from the softmax.
        att_weights = torch.matmul(q, k.transpose(2, 3)) * self.attention_scale
        att = att_weights.masked_fill(mask==0, -1e9)
        att = F.softmax(att, dim=-1)    # Softmax is applied along the last dimension to obtain normalized attention probabilities
        att = self.dropout(att)

        # Now we combine the Values with respect to the attention we just computed
        #   - att is of shape (batch_size, num_heads, num_agents, num_agents)
        #   - v is of shape (batch_size, num_heads, num_agents, v_features)
        #   - the multiplication result is of shape (batch_size, num_heads, num_agents, v_features)
        out = torch.matmul(att, v)

        # We add a skip connection
        out  = torch.add(out, v)    # This additionally promotes gradient flow and mitigates vanishing gradient

        # Now "remove" the transpose and concatenate all heads together
        #   - out is of shape (batch_size, num_heads, num_agents, v_features)
        #   - out after transpose is of shape (batch_size, num_agents, num_heads, v_features)
        #   - contiguous() ensures that the tensor is stored in a contiguous chunk of memory so that view can be applied
        #   - view flattens the last two dimensions into a single one
        #   - final out is of shape (batch_size, num_agents, num_heads * v_features)
        out = out.transpose(1,2).contiguous().view(batch_size, num_agents, -1)

        out = self.fc_out(out)  # Linear map into the desired feature dimension
        # Same None guard as for vkq_activation: a missing output activation must not crash
        if self.activation is not None:
            out = self.activation(out)
        out = self.dropout(out)

        return out, att_weights

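### DGN
DGN first encodes the observation of every agent with an MLP. The encoding is then refined by a stack of masked multi-head attention layers, where the mask restricts which agents can attend to each other. Finally, the encoder output and the outputs of all attention layers are concatenated and passed to the Q-network.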

In [None]:
class DGN(nn.Module):
    """
    Q-value model that stacks attention layers on top of an MLP encoder.

    The encoder output and the output of every attention layer are concatenated
    along the feature dimension and handed to the Q-network.

    :param in_features: size of the last dimension of the input
    :param mlp_units: layer widths of the encoder MLP
    :param num_actions: number of actions, passed on to the Q-network
    :param num_heads: number of heads of every attention layer
    :param num_attention_layers: number of stacked attention layers
    :param activation_fn: activation used by the encoder and by the attention layers
    :param kv_values: number of Key and Value features per attention head
    """

    def __init__(self, in_features, mlp_units, num_actions, num_heads, num_attention_layers, activation_fn, kv_values):
        super().__init__()

        self.encoder = MLP(in_features, mlp_units, activation_fn)
        width = self.encoder.out_features

        print("In features of DGN: ", in_features)
        print("MLP units are: ", mlp_units)

        # Every attention layer keeps the encoder width so the layers can be chained
        self.att_layers = nn.ModuleList(
            [
                AttModel(width, kv_values, kv_values, width, num_heads, activation_fn, activation_fn)
                for _ in range(num_attention_layers)
            ]
        )

        # The Q-network sees the encoder output plus one block of features per attention layer
        self.q_net = Q_Net(width * (num_attention_layers + 1), num_actions)

        # Second return value of every attention layer from the most recent forward pass
        self.att_weights = []

    def forward(self, x, mask):
        """
        Compute the Q-values.

        Each attention layer refines the previous representation; the concatenation
        of all representations forms the input of the Q-network.

        :param x: input of shape (batch_size, num_agents, in_features)
        :param mask: mask handed unchanged to every attention layer
        :return: Q-values of shape (batch_size, num_agents, num_actions)
        """
        encoded = self.encoder(x)   # (batch_size, num_agents, hidden_features)

        # Drop the values stored by the previous forward pass (in place, the list object is kept)
        self.att_weights.clear()

        representations = [encoded]
        current = encoded
        for layer in self.att_layers:
            current, layer_weights = layer(current, mask)
            self.att_weights.append(layer_weights)
            representations.append(current)

        # (batch_size, num_agents, hidden_features * (num_attention_layers + 1))
        q_input = torch.cat(representations, dim=-1)

        return self.q_net(q_input)


### DQN

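DQN is the simplest model: the observation of every agent is encoded by an MLP and mapped to Q-values independently of the other agents. The mask argument is accepted so that all models share the same interface, but it is not used.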

In [None]:
class DQN(nn.Module):
    """
    Plain deep Q-network: an MLP encoder followed by the Q-network.

    :param in_features: size of the last dimension of the input
    :param mlp_units: layer widths of the encoder MLP
    :param num_actions: number of actions, passed on to the Q-network
    :param activation_fn: activation used by the encoder; also stored as ``self.activation``
    """

    def __init__(self, in_features, mlp_units, num_actions, activation_fn):
        super().__init__()

        # Encodes incoming features
        self.encoder = MLP(in_features, mlp_units, activation_fn)
        # Outputs Q-values
        self.q_net = Q_Net(self.encoder.out_features, num_actions)
        self.activation = activation_fn

    def forward(self, x, mask):
        """
        Compute the Q-values.

        :param x: input with exactly three dimensions (batch, agent, features)
        :param mask: not used by this model
        :return: output of the Q-network applied to the encoded input
        """
        # The unpacking only serves as a check that x has exactly three dimensions
        _batch, _agents, _features = x.shape
        return self.q_net(self.encoder(x))


### Comm_net
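The recurrent baseline `DQNR` places an LSTM cell between the encoder and the Q-network. `CommNet` builds on it: in every communication round each agent averages the hidden states of its neighbours (excluding itself), adds the result to its own hidden state and passes it through the LSTM cell again.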

In [None]:
class DQNR(nn.Module):
    """
    Recurrent DQN: an MLP encoder, an LSTM cell and the Q-network.

    The recurrent state is kept in ``self.state``. Between forward passes it has the
    shape (batch_size, n_agents, self.get_state_len()); ``None`` means that the LSTM
    cell starts from its default initial state.

    :param in_features: size of the last dimension of the input
    :param mlp_units: layer widths of the encoder MLP
    :param num_actions: number of actions, passed on to the Q-network
    :param activation_fn: activation used by the encoder
    """

    def __init__(self, in_features, mlp_units, num_actions, activation_fn):
        super().__init__()
        self.encoder = MLP(in_features, mlp_units, activation_fn)
        width = self.encoder.out_features
        # The LSTM cell keeps the encoder width for both its input and its hidden state
        self.lstm = nn.LSTMCell(input_size=width, hidden_size=width)
        self.state = None
        self.q_net = Q_Net(width, num_actions)

    def get_state_len(self):
        """Return the per-agent state length (hidden state plus cell state)."""
        hidden_size = self.lstm.hidden_size
        return 2 * hidden_size

    def _state_reshape_in(self, batch_size, n_agents):
        """
        Bring the stored state into the layout used by the LSTM cell, i.e. from
            (batch_size, n_agents, self.get_state_len())
        to
            (2, batch_size * n_agents, hidden_size).

        :param batch_size: the batch size
        :param n_agents: the number of agents
        """
        per_agent = self.state.reshape(batch_size * n_agents, 2, self.lstm.hidden_size)
        self.state = per_agent.transpose(0, 1).contiguous()

    def _state_reshape_out(self, batch_size, n_agents):
        """
        Bring the stored state back into the external layout, i.e. from
            (2, batch_size * n_agents, hidden_size)
        to
            (batch_size, n_agents, self.get_state_len()).

        :param batch_size: the batch size
        :param n_agents: the number of agents
        """
        agent_major = self.state.transpose(0, 1)
        self.state = agent_major.reshape(batch_size, n_agents, -1)

    def _lstm_forward(self, x, reshape_state=True):
        """
        Run the LSTM cell once and store the new state.

        :param x: cell input of shape (batch_size, n_agents, features)
        :param reshape_state: convert the state from and back to the external layout
            (batch_size, n_agents, -1) around the cell call
        :return: new hidden state of shape (batch_size, n_agents, hidden_size)
        """
        batch_size, n_agents, _ = x.shape
        # Agents are treated as additional batch entries by the cell
        flat_input = x.view(batch_size * n_agents, -1)

        previous = None     # None lets the cell use its default initial state
        if self.state is not None:
            if reshape_state:
                self._state_reshape_in(batch_size, n_agents)
            previous = (self.state[0], self.state[1])

        hidden, cell = self.lstm(flat_input, previous)
        self.state = torch.stack((hidden, cell))

        if reshape_state:
            self._state_reshape_out(batch_size, n_agents)

        # Separate batch and agent dimension again
        return hidden.view(batch_size, n_agents, -1)

    def forward(self, x, mask):
        """
        Compute the Q-values; ``mask`` is not used by this model.
        """
        encoded = self.encoder(x)
        recurrent = self._lstm_forward(encoded)
        return self.q_net(recurrent)


class CommNet(DQNR):
    """
    Recurrent DQN with communication rounds between neighbouring agents.

    After the first LSTM step, every communication round averages the hidden states
    of the neighbours of each agent (self-connections are removed from the mask),
    adds the result to the agent's own hidden state and runs the LSTM cell again.

    :param in_features: size of the last dimension of the input
    :param mlp_units: layer widths of the encoder MLP
    :param num_actions: number of actions, passed on to the Q-network
    :param comm_rounds: number of communication rounds, must not be negative
    :param activation_fn: activation used by the encoder
    :raises ValueError: if ``comm_rounds`` is negative
    """

    def __init__(
        self,
        in_features,
        mlp_units,
        num_actions,
        comm_rounds,
        activation_fn,
    ):
        super().__init__(in_features, mlp_units, num_actions, activation_fn)
        # An explicit exception instead of an assert, so the check also runs with python -O
        if comm_rounds < 0:
            raise ValueError(f"comm_rounds must be non-negative, got {comm_rounds}")
        self.comm_rounds = comm_rounds

    def forward(self, x, mask):
        """
        Compute the Q-values.

        :param x: input of shape (batch_size, n_agents, features)
        :param mask: adjacency of shape (batch_size, n_agents, n_agents); it is used in
            a batched matrix product with the hidden states, so it presumably has to be
            a floating point tensor -- TODO confirm against the caller
        :return: output of the Q-network applied to the final hidden state
        """
        batch_size, n_agents, feature_dim = x.shape
        h = self.encoder(x)

        # manually reshape state, it stays in the LSTM layout during all rounds
        if self.state is not None:
            self._state_reshape_in(batch_size, n_agents)

        h = self._lstm_forward(h, reshape_state=False)

        # explicitly exclude self-communication from mask
        mask = mask * ~torch.eye(n_agents, dtype=bool, device=x.device).unsqueeze(0)

        # number of neighbors per agent, at least 1 to avoid a division by zero;
        # the mask does not change between rounds, so this is computed only once
        neighbor_count = torch.clamp(mask.sum(dim=-1).unsqueeze(-1), min=1)

        for _ in range(self.comm_rounds):
            # combine hidden state h according to mask
            # first add up hidden states according to mask
            #    h has dimensions (batch, agents, features)
            #    and mask has dimensions (batch, agents, neighbors)
            #    => we have to transpose the mask to aggregate over all neighbors
            c = torch.bmm(h.transpose(1, 2), mask.transpose(1, 2)).transpose(1, 2)
            # then normalize according to number of neighbors per agent
            c = c / neighbor_count

            # skip connection for hidden state and communication
            h = h + c
            # use new hidden state (index 0 of the state holds the LSTM hidden state)
            self.state[0] = h.view(batch_size * n_agents, -1)

            # pass through forward module
            h = self._lstm_forward(h, reshape_state=False)

        # manually reshape state in the end
        self._state_reshape_out(batch_size, n_agents)
        return self.q_net(h)




## State aggregation

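The modules in this section aggregate node features over the neighbourhood that is defined by the adjacency matrix of the graph.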

### SUM


In [None]:
class SimpleAggregation(nn.Module):
    """
    Aggregate the features of neighbouring nodes by summing or averaging them.

    :param agg: aggregation type, either "mean" or "sum"
    :param mask_eye: if set, the diagonal of the adjacency is zeroed so that a node
        does not contribute to its own aggregate
    :raises ValueError: if ``agg`` is neither "mean" nor "sum"
    """

    def __init__(self, agg: str, mask_eye: bool) -> None:
        super().__init__()
        # An explicit exception instead of an assert, so the check also runs with python -O
        if agg not in ("mean", "sum"):
            raise ValueError(f"agg must be 'mean' or 'sum', got {agg!r}")
        self.agg = agg
        self.mask_eye = mask_eye

    def forward(self, node_features, node_adjacency):
        """
        Aggregate ``node_features`` over the neighbours given by ``node_adjacency``.

        :param node_features: tensor of shape (batch, nodes, features)
        :param node_adjacency: tensor of shape (batch, nodes, nodes); used in a batched
            matrix product with ``node_features``
        :return: tensor of shape (batch, nodes, features) with the summed or averaged
            neighbour features
        """
        if self.mask_eye:
            n_nodes = node_adjacency.shape[1]
            # A single (nodes, nodes) mask broadcasts over the batch dimension, so no
            # batch-sized copy of the identity has to be materialized
            off_diagonal = ~torch.eye(
                n_nodes, dtype=torch.bool, device=node_adjacency.device
            )
            node_adjacency = node_adjacency * off_diagonal

        feature_sum = torch.bmm(node_adjacency, node_features)
        if self.agg == "sum":
            return feature_sum
        if self.agg == "mean":
            # Clamp to 1 so that nodes without neighbours do not cause a division by zero
            num_neighbors = torch.clamp(node_adjacency.sum(dim=-1), min=1).unsqueeze(-1)
            return feature_sum / num_neighbors

        # Only reachable if self.agg was changed after construction
        raise ValueError(f"agg must be 'mean' or 'sum', got {self.agg!r}")


### GCN
GCN is a graph convolutional operator that handles the message-passing phase within the GNN. An implementation is available as [GCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html#torch_geometric.nn.conv.GCNConv) in the pytorch_geometric library, which specializes in GNNs.

GCN is based on a first-order approximation of spectral graph convolutions.

## NetMon