In [1]:
import torch
import torch.nn as nn
from config import *
from gnn_layer import GraphAttentionLayer


# Module-level configuration object used to build the model below.
# `Config` is not defined in this cell; presumably it is supplied by the
# `from config import *` star import above (verify against config.py).
configs = Config()

class CustomGAT(nn.Module):
    """Stack of ``GraphAttentionLayer`` modules applied one after another.

    Sizes and hyper-parameters are read from ``configs``:

    * ``configs.feat_dim`` -- input feature size of the first layer.
    * ``configs.gnn_dims`` -- comma-separated output sizes, one per layer.
    * ``configs.att_heads`` -- comma-separated attention-head counts, one
      per layer.
    * ``configs.dp`` -- forwarded unchanged to every ``GraphAttentionLayer``.

    Layers are registered in ``self.gnn_layer_stack`` (an ``nn.ModuleDict``)
    keyed by layer name. For ``name == "v_custom_gat"`` the names come from
    the comma-separated ``configs.name_list_v_gat``. For any other ``name``
    they default to ``"<name>_layer<i>"``.

    Args:
        configs: configuration object exposing the attributes listed above.
        name: model name; stored as ``self.name`` and used to pick the
            layer-name list.

    Raises:
        ValueError: if ``att_heads`` or the layer-name list has fewer
            entries than there are layers.
    """

    def __init__(self, configs, name):
        super().__init__()
        self.name = name
        # gnn_dims[0] is the raw input size; gnn_dims[i + 1] is layer i's output size.
        self.gnn_dims = [configs.feat_dim] + [
            int(dim) for dim in configs.gnn_dims.strip().split(',')
        ]
        self.gnn_layers = len(self.gnn_dims) - 1
        self.att_heads = [
            int(att_head) for att_head in configs.att_heads.strip().split(',')
        ]

        if name == "v_custom_gat":
            self.name_list = configs.name_list_v_gat.strip().split(',')
        else:
            # Without this fallback any other model name crashed with an
            # AttributeError when the layers were registered below.
            self.name_list = [f"{name}_layer{i}" for i in range(self.gnn_layers)]

        if len(self.att_heads) < self.gnn_layers:
            raise ValueError(
                f"att_heads has {len(self.att_heads)} entries but "
                f"{self.gnn_layers} GNN layers are configured"
            )
        if len(self.name_list) < self.gnn_layers:
            raise ValueError(
                f"layer-name list has {len(self.name_list)} entries but "
                f"{self.gnn_layers} GNN layers are configured"
            )

        self.gnn_layer_stack = nn.ModuleDict()
        for i in range(self.gnn_layers):
            # Layers after the first take gnn_dims[i] * att_heads[i - 1] input
            # features -- presumably because GraphAttentionLayer concatenates
            # the outputs of its heads (TODO confirm against gnn_layer.py).
            if i > 0:
                in_dim = self.gnn_dims[i] * self.att_heads[i - 1]
            else:
                in_dim = self.gnn_dims[i]
            layer_name = self.name_list[i]
            self.gnn_layer_stack[layer_name] = GraphAttentionLayer(
                self.att_heads[i], in_dim, self.gnn_dims[i + 1], configs.dp,
                name=layer_name,
            )

    def forward(self, feat_in, adj=None):
        """Run ``feat_in`` through every registered layer in insertion order.

        Args:
            feat_in: features fed to the first layer (expected shape is
                defined by GraphAttentionLayer, not visible here).
            adj: passed unchanged to every layer; may be ``None``.

        Returns:
            The output of the last layer.
        """
        # Iterating an nn.ModuleDict yields its string keys; iterate .values()
        # so each step calls the layer module itself.
        for gnn_layer in self.gnn_layer_stack.values():
            feat_in = gnn_layer(feat_in, adj)
        return feat_in

model = CustomGAT(configs, "v_custom_gat")
# Dump the module tree: named_modules() yields the root first (empty name),
# then every registered submodule under its dotted path. The separator
# reproduces "<name> \n <module>" exactly.
for name, module in model.named_modules():
    print(name, module, sep=' \n ')

  from .autonotebook import tqdm as notebook_tqdm


 
 CustomGAT(
  (gnn_layer_stack): ModuleDict(
    (v_clause-level_gat): GraphAttentionLayer (4096 -> 256)
    (v_sentence-level_gat): GraphAttentionLayer (256 -> 32)
    (v_text-level_gat): GraphAttentionLayer (32 -> 4)
  )
)
gnn_layer_stack 
 ModuleDict(
  (v_clause-level_gat): GraphAttentionLayer (4096 -> 256)
  (v_sentence-level_gat): GraphAttentionLayer (256 -> 32)
  (v_text-level_gat): GraphAttentionLayer (32 -> 4)
)
gnn_layer_stack.v_clause-level_gat 
 GraphAttentionLayer (4096 -> 256)
gnn_layer_stack.v_clause-level_gat.H 
 Linear(in_features=4096, out_features=4096, bias=True)
gnn_layer_stack.v_sentence-level_gat 
 GraphAttentionLayer (256 -> 32)
gnn_layer_stack.v_sentence-level_gat.H 
 Linear(in_features=256, out_features=256, bias=True)
gnn_layer_stack.v_text-level_gat 
 GraphAttentionLayer (32 -> 4)
gnn_layer_stack.v_text-level_gat.H 
 Linear(in_features=32, out_features=32, bias=True)
