Put libraries here

In [1]:
# !pip install dgl

In [2]:
import torch, os, pickle, sys,torch.nn.init as init, dgl
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from collections import Counter
import dgl.function as fn
from dgl.nn.functional import edge_softmax

# from GAT import GAT
%load_ext autoreload
%autoreload 2

In [3]:
script_path = os.path.abspath("utils\\constans.py")  # Replace __file__ with the path to your script if in a notebook

# Determine the project directory by moving up two levels (adjust as needed)
project_directory = os.path.dirname(os.path.dirname(script_path))

print("Script Path:", script_path)
print("Project Directory:", project_directory)

Script Path: C:\Users\edayo\Downloads\4y2t\THSST-2\ug_thesis\ER_GAT\utils\constans.py
Project Directory: C:\Users\edayo\Downloads\4y2t\THSST-2\ug_thesis\ER_GAT


Code related to GAT

In [4]:

class DialogueGraphDataLoader(DataLoader):
    def __init__(self, node_features_list, edge_index_list, batch_size=1, shuffle=False):
        graph_dataset = DialogueGraphDataset(node_features_list, edge_index_list)
        super().__init__(graph_dataset, batch_size, shuffle, collate_fn=dialogue_graph_collate_fn)

class DialogueGraphDataset(Dataset):
    def __init__(self, node_features_list, edge_index_list):
        self.node_features_list = node_features_list
        self.edge_index_list = edge_index_list

    def __len__(self):
        return len(self.edge_index_list)

    def __getitem__(self, idx):
        return self.node_features_list[idx], self.edge_index_list[idx]

def dialogue_graph_collate_fn(batch):
    node_features_list, edge_index_list = zip(*batch)
    
    node_features_list_combined = []
    num_nodes_seen = 0

    for node_features, edge_index in zip(node_features_list, edge_index_list):
        # Assuming node_features is a tuple (text_embeddings, speakers_list)
        text_embeddings, speakers_list = node_features
        combined_features = (text_embeddings, speakers_list)

        node_features_list_combined.append(combined_features)

        # Translate the range of edge_index
        edge_index_list.append(edge_index + num_nodes_seen)
        num_nodes_seen += len(text_embeddings)

    # Merge the dialogue graphs into a single graph with multiple connected components
    node_features_combined = [torch.cat(features, 1) for features in zip(*node_features_list_combined)]
    edge_index = torch.cat(edge_index_list, 1)

    return node_features_combined, edge_index

# get one hot encoding to normalize values
def get_ohe(edge_types):
    # Number of classes
    num_classes = torch.max(edge_types) + 1

    # Convert to one-hot encoding
    one_hot_encoding = F.one_hot(edge_types, num_classes).float()

    return one_hot_encoding

# print(edge_feats)

In [46]:
class GATLayerWithEdgeType(nn.Module):
    def __init__(self, num_in_features_per_head, num_out_features_per_head, num_heads, num_edge_types):
        super(GATLayerWithEdgeType, self).__init__()
        self.num_in_features_per_head = num_in_features_per_head
        self.num_out_features_per_head = num_out_features_per_head
        self.num_heads = num_heads
        self.num_edge_types = num_edge_types

        # Linear projection for node features
        torch.manual_seed(42)
        self.linear_proj = nn.Linear(self.num_in_features_per_head, self.num_heads * self.num_out_features_per_head)
        
        # Edge type embeddings
        torch.manual_seed(42)  # Set your desired seed value
        self.edge_type_embedding = nn.Embedding(self.num_edge_types, self.num_heads)
        
    def forward(self, input_data, edge_type):
        node_features, edge_indices = input_data

        # Linear projection for node features
        print("node_features.shape: ",node_features.shape, " edge_indices: ", edge_indices.shape)
        print("edge_type.shape: ",  edge_type.shape)
        h_linear = self.linear_proj(node_features.view(-1, self.num_in_features_per_head))
        print("h_linear.shape after linear_proj of node_features: ",h_linear.shape)
        h_linear = h_linear.view(-1, self.num_heads, self.num_out_features_per_head)
        print("h_linear.shape after view: ",h_linear.shape)
        # Transpose dimensions of h_linear to match edge_type_embedding's shape
        h_linear = h_linear.permute(0, 2, 1)
        print("h_linear.shape after permuting dimension: ",h_linear.shape)

        # Edge type embedding
        edge_type_embedding = self.edge_type_embedding(edge_type).transpose(0, 1)
        print("edge_type_embedding.shape after transpose: ",edge_type_embedding.shape)

        # Perform matrix multiplication
        attention_scores = torch.matmul(h_linear, edge_type_embedding).squeeze(-1)
        print("attention_scores..shape after matmul h_linear and edge_type_emb: ",attention_scores.shape)

        # Softmax to get attention coefficients
        attention_coefficients = F.softmax(attention_scores, dim=-1)

         # Weighted sum of neighbor node representations
        print("attention_coefficients.shape after softmax: ",attention_coefficients.shape)
#       the one below is for edges
        updated_representation = torch.matmul(attention_coefficients.transpose(1, 2), h_linear).mean(dim=2)
#         the one below is for attention heads
#         updated_representation = torch.matmul(attention_coefficients.transpose(1, 2), h_linear).sum(dim=1)
        print("updated_representation.shape after matmul of trasposed attn_coef and h_linear and sum at dim=1: ",updated_representation.shape)

        return updated_representation, attention_coefficients
    
class GATWithEdgeType(nn.Module):
    def __init__(self, num_of_layers, num_heads_per_layer, num_features_per_layer, num_edge_types):
        super(GATWithEdgeType, self).__init__()

        self.gat_net = nn.ModuleList()

        for layer in range(num_of_layers):
            num_in_features = num_heads_per_layer[layer - 1] * num_features_per_layer[layer - 1] if layer > 0 else num_features_per_layer[0]
            num_out_features = num_heads_per_layer[layer] * num_features_per_layer[layer]
            self.gat_net.append(GATLayerWithEdgeType(num_in_features, num_out_features, num_heads_per_layer[layer], num_edge_types))

    def forward(self, node_features, edge_indices, edge_types):
        h = node_features

        attention_scores = []

        for layer in self.gat_net:
            h, attention_coefficients = layer((h, edge_indices), edge_types)
            attention_scores.append(attention_coefficients)

        return h, attention_scores

class EGATConv(nn.Module):
    r"""
    
    Description
    -----------
    Apply Graph Attention Layer over input graph. EGAT is an extension
    of regular `Graph Attention Network <https://arxiv.org/pdf/1710.10903.pdf>`__ 
    handling edge features, detailed description is available in
    `Rossmann-Toolbox <https://pubmed.ncbi.nlm.nih.gov/34571541/>`__ (see supplementary data).
     The difference appears in the method how unnormalized attention scores :math:`e_{ij}`
     are obtain:
        
    .. math::
        e_{ij} &= \vec{F} (f_{ij}^{\prime})

        f_{ij}^{\prim} &= \mathrm{LeakyReLU}\left(A [ h_{i} \| f_{ij} \| h_{j}]\right)

    where :math:`f_{ij}^{\prim}` are edge features, :math:`\mathrm{A}` is weight matrix and 
    :math: `\vec{F}` is weight vector. After that resulting node features 
    :math:`h_{i}^{\prim}` are updated in the same way as in regular GAT. 
   
    Parameters
    ----------
    in_node_feats : int
        Input node feature size :math:`h_{i}`.
    in_edge_feats : int
        Input edge feature size :math:`f_{ij}`.
    out_node_feats : int
        Output nodes feature size.
    out_edge_feats : int
        Output edge feature size.
    num_heads : int
        Number of attention heads.
    bias : bool, optional
        If True, learns a bias term. Defaults: ``True``.
        
    Examples
    ----------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import EGATConv
    >>> 
    >>> num_nodes, num_edges = 8, 30
    >>>#define connections
    >>> u, v = th.randint(num_nodes, num_edges), th.randint(num_nodes, num_edges) 
    >>> graph = dgl.graph((u,v))    

    >>> node_feats = th.rand((num_nodes, 20)) 
    >>> edge_feats = th.rand((num_edges, 12))
    >>> egat = EGATConv(in_node_feats=20,
                          in_edge_feats=12,
                          out_node_feats=15,
                          out_edge_feats=10,
                          num_heads=3)
    >>> #forward pass                    
    >>> new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
    >>> new_node_feats.shape, new_edge_feats.shape
    ((8, 3, 12), (30, 3, 10))
    """
    def __init__(self,
                 in_node_feats,
                 in_edge_feats,
                 out_node_feats,
                 out_edge_feats,
                 num_heads,
                 bias=True,
                 **kw_args):
        
        super().__init__()
        self._num_heads = num_heads
        self._out_node_feats = out_node_feats
        self._out_edge_feats = out_edge_feats
        
        self.fc_node = nn.Linear(in_node_feats, out_node_feats * num_heads, bias=bias)
        self.fc_ni = nn.Linear(in_node_feats, out_edge_feats * num_heads, bias=False)
        self.fc_fij = nn.Linear(in_edge_feats, out_edge_feats * num_heads, bias=False)
        self.fc_nj = nn.Linear(in_node_feats, out_edge_feats * num_heads, bias=False)
        
        # Attention parameter
        self.attn = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_edge_feats)))
        
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(size=(num_heads * out_edge_feats,)))
        else:
            self.register_buffer('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        """
        Reinitialize learnable parameters.
        """
        torch.manual_seed(42)  # You can use any integer value as the seed
        gain = init.calculate_gain('relu')
        init.xavier_normal_(self.fc_node.weight, gain=gain)
        init.xavier_normal_(self.fc_ni.weight, gain=gain)
        init.xavier_normal_(self.fc_fij.weight, gain=gain)
        init.xavier_normal_(self.fc_nj.weight, gain=gain)
        init.xavier_normal_(self.attn, gain=gain)
        
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, graph, nfeats, efeats, get_attention=False):
        r"""
        Compute new node and edge features.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        nfeats : torch.Tensor
            The input node feature of shape :math:`(*, D_{in})`
            where:
                :math:`D_{in}` is size of input node feature,
                :math:`*` is the number of nodes.
        efeats: torch.Tensor
             The input edge feature of shape :math:`(*, F_{in})`
             where:
                 :math:`F_{in}` is size of input node feauture,
                 :math:`*` is the number of edges.
        get_attention : bool, optional
                Whether to return the attention values. Default to False.
            
        Returns
        -------
        pair of torch.Tensor
            node output features followed by edge output features
            The node output feature of shape :math:`(*, H, D_{out})` 
            The edge output feature of shape :math:`(*, H, F_{out})`
            where:
                :math:`H` is the number of heads,
                :math:`D_{out}` is size of output node feature,
                :math:`F_{out}` is size of output edge feature.            
        """
        
        with graph.local_scope():
            # TODO allow node src and dst feats
            graph.edata['f'] = efeats
            graph.ndata['h'] = nfeats
            # calc edge attention
            # same trick way as in dgl.nn.pytorch.GATConv, but also includes edge feats
            # https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py#L297
            f_ni = self.fc_ni(nfeats)
            f_nj = self.fc_nj(nfeats)
            f_fij = self.fc_fij(efeats)
            graph.srcdata.update({'f_ni' : f_ni})
            graph.dstdata.update({'f_nj' : f_nj})
            #graph.edata.update({'f_fij' : f_fij})
            # add ni, nj factors
            graph.apply_edges(fn.u_add_v('f_ni', 'f_nj', 'f_tmp'))
            # add fij to node factor
            f_out = graph.edata.pop('f_tmp') + f_fij 
            if self.bias is not None:
                f_out+= self.bias
            f_out = nn.functional.leaky_relu(f_out)
            f_out = f_out.view(-1, self._num_heads, self._out_edge_feats)
            # compute attention factor
            e = (f_out * self.attn).sum(dim=-1).unsqueeze(-1)
            graph.edata['a'] = edge_softmax(graph, e)
            graph.ndata['h_out'] = self.fc_node(nfeats).view(-1, self._num_heads, self._out_node_feats)
            # calc weighted sum 
            graph.update_all(fn.u_mul_e('h_out', 'a', 'm'),
                            fn.sum('m', 'h_out'))

            h_out = graph.ndata['h_out'].view(-1, self._num_heads, self._out_node_feats)
            if get_attention:
                return h_out, f_out, graph.edata.pop('a')
            else:
                return h_out, f_out
    


<h3>Methods definition

In [6]:
def create_node_pairs_list(start_idx, end_idx):
    # Initialize an empty list to store pairs
    list_node_i = []
    list_node_j = []
#     node_pairs_dict = {}
    end_idx = end_idx - start_idx
    start_idx = 0
    for i in range(start_idx, end_idx+1):
        val = 0
        while (val <= 3)  and (i+val <= end_idx):
            target_idx = i+val
#                 print(target_idx)
            if target_idx >= 0:
                list_node_i.append(i)
                list_node_j.append(target_idx)
#                 node_pairs_dict[i] = target_idx
            val = val+1
    
    return [list_node_i, list_node_j]

def create_adjacency_dict(node_pairs):
    adjacency_list_dict = {}

    # Iterate through pairs of nodes
    for i in range(0, len(node_pairs[0])):
        source_node, target_node = node_pairs[0][i], node_pairs[1][i]

#         # Add source node to target node's neighbors
#         if target_node not in adjacency_list_dict:
#             adjacency_list_dict[target_node] = [source_node]
#         else:
#             adjacency_list_dict[target_node].append(source_node)

        # Add target node to source node's neighbors
        if source_node not in adjacency_list_dict:
            adjacency_list_dict[source_node] = [target_node]
        else:
            adjacency_list_dict[source_node].append(target_node)

    return adjacency_list_dict
# print(ranges[:1])

def get_all_adjacency_list(ranges, key=0):
    all_adjacency_list = []
    for range_pair in ranges:
        start_idx, end_idx = range_pair
        
        if key == 0:
            output = create_node_pairs_list(start_idx, end_idx)
            output = create_adjacency_dict(output)
        elif key == 1:
            output = create_node_pairs_list(start_idx, end_idx)
            output = torch.tensor(output)
        else:
            print("N/A")
        all_adjacency_list.append(output)
    return all_adjacency_list

def get_all_edge_type_list(edge_indices, encoded_speaker_list):
    dialogs_len = len(edge_indices)
    whole_edge_type_list = []
    
    for i in range(dialogs_len): #2140 dialogs
        dialog_nodes_pairs = edge_indices[i]
        dialog_speakers = list(encoded_speaker_list[i])
        dialog_len = len(dialog_nodes_pairs.keys())
        edge_type_list = []
#         print(i, " th dialogue")
#         print(i, dialog_speakers)
        for j in range(dialog_len): #num utterances
            src_node = dialog_nodes_pairs[j] # j = key = src node
            node_i_idx = j
            win_len = len(src_node)
            for k in range(win_len):
                node_j_idx = src_node[k] # k = value = targ node
                # edge_types = torch.tensor([0, 1, 2]) 
                # 0: cur-self, 1: past-self, 2: past-other/past-inter
                                
                if node_i_idx == node_j_idx:
                    edge_type_list.append(0)
#                     print("This is 0 ", node_i_idx, node_j_idx)
                else:
                    if dialog_speakers[node_i_idx] != dialog_speakers[node_j_idx]:
                        edge_type_list.append(1)
#                         print("This is 1 ", node_i_idx, node_j_idx)
                    else:
                        edge_type_list.append(2)
#                         print("This is 2 ", node_i_idx, node_j_idx)
        whole_edge_type_list.append(torch.tensor(edge_type_list).to(torch.int64))  
                    
    return whole_edge_type_list

In [7]:
# print(edge_indices[0][0][3])
# len(edge_indices[0].keys())
# list(encoded_speaker_list[1])

In [8]:
# assume this is working
# edge_indices = get_all_adjacency_list(ranges)
# edge_types = get_all_edge_type_list(edge_indices, encoded_speaker_list)
# edge_indices = get_all_adjacency_list(ranges, key=1)

In [9]:
# print((edge_types[:10]))
# edge_indices[:10]
# (updated_representations[0].shape)
# edge_indices[0]
# edge_types[0]

In [10]:
checkFile = os.path.isfile("data/dump/speaker_encoder.pkl")
encoded_speaker_list = []
if checkFile is False:
    print("Run first the prototype_context_encoder to generate this file")
else:
    file = open('data/dump/speaker_encoder.pkl', "rb")
    encoded_speaker_list, ranges = pickle.load(file)
    file.close()

In [11]:
# need update
# checkFile = os.path.isfile("data/dump/all_adjacency_list.pkl")
# adjacency_list = []
# if checkFile is False:
#     adjacency_list = get_all_adjacency_list(ranges)
# else:
#     file = open('data/dump/all_adjacency_list.pkl', "rb")
#     adjacency_list = pickle.load(file)
#     file.close()

In [12]:
# adjacency_list[:2]

In [13]:
# len(adjacency_list)

In [14]:
file_path = 'embed/updated_representation_list.pkl'

# Load the list from the file using pickle
with open(file_path, 'rb') as file:
    updated_representations = pickle.load(file)

In [15]:
# print(updated_representations[0].shape)
# print(updated_representations[0].shape)

<h3> Making Progress...

In [16]:
edge_indices = get_all_adjacency_list(ranges)
sample = edge_indices[0]
edge_types = get_all_edge_type_list(edge_indices, encoded_speaker_list)
edge_indices = get_all_adjacency_list(ranges, key=1)

In [17]:
# sample

In [18]:
updated_representations[0].shape

torch.Size([14, 300])

In [19]:
edge_types[0]

tensor([0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1,
        0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 1, 0, 1, 2, 0,
        1, 0])

In [47]:
num_in_features = 300
num_out_features = 300
num_heads = 4
num_edge_types = 3

gat_layer = GATLayerWithEdgeType(num_in_features, num_out_features, num_heads, num_edge_types)

i = 0
h_prime, attention_coef = gat_layer((updated_representations[i], edge_indices[i]), edge_types[i])
print(f"dialogue_representation[{i}] shape:", updated_representations[i].shape)
print("Attention coef shape:", attention_coef.shape)
print("h_prime shape:", h_prime.shape)

node_features.shape:  torch.Size([14, 300])  edge_indices:  torch.Size([2, 50])
edge_type.shape:  torch.Size([50])
h_linear.shape after linear_proj of node_features:  torch.Size([14, 1200])
h_linear.shape after view:  torch.Size([14, 4, 300])
h_linear.shape after permuting dimension:  torch.Size([14, 300, 4])
edge_type_embedding.shape after transpose:  torch.Size([4, 50])
attention_scores..shape after matmul h_linear and edge_type_emb:  torch.Size([14, 300, 50])
attention_coefficients.shape after softmax:  torch.Size([14, 300, 50])
updated_representation.shape after matmul of trasposed attn_coef and h_linear and sum at dim=1:  torch.Size([14, 50])
dialogue_representation[0] shape: torch.Size([14, 300])
Attention coef shape: torch.Size([14, 300, 50])
h_prime shape: torch.Size([14, 50])


In [48]:
print(h_prime.shape)
print(h_prime)

torch.Size([14, 50])
tensor([[-2.6777e-02, -3.8597e-02, -1.1790e-02, -3.8597e-02, -2.6777e-02,
         -3.8597e-02, -1.1790e-02, -3.8597e-02, -2.6777e-02, -3.8597e-02,
         -1.1790e-02, -3.8597e-02, -2.6777e-02, -3.8597e-02, -1.1790e-02,
         -3.8597e-02, -2.6777e-02, -3.8597e-02, -1.1790e-02, -3.8597e-02,
         -2.6777e-02, -3.8597e-02, -1.1790e-02, -3.8597e-02, -2.6777e-02,
         -3.8597e-02, -1.1790e-02, -3.8597e-02, -2.6777e-02, -3.8597e-02,
         -1.1790e-02, -3.8597e-02, -2.6777e-02, -3.8597e-02, -1.1790e-02,
         -3.8597e-02, -2.6777e-02, -3.8597e-02, -1.1790e-02, -3.8597e-02,
         -2.6777e-02, -3.8597e-02, -1.1790e-02, -3.8597e-02, -2.6777e-02,
         -3.8597e-02, -1.1790e-02, -2.6777e-02, -3.8597e-02, -2.6777e-02],
        [-1.3126e-02, -1.8014e-02, -7.0997e-03, -1.8014e-02, -1.3126e-02,
         -1.8014e-02, -7.0997e-03, -1.8014e-02, -1.3126e-02, -1.8014e-02,
         -7.0997e-03, -1.8014e-02, -1.3126e-02, -1.8014e-02, -7.0997e-03,
         -1.8014

In [58]:
# source node
# index represent the edge
target_nodes = edge_indices[0][1].tolist()
print(edge_indices[0][1].tolist())

[0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6, 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10, 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 12, 13, 13]


In [82]:
sample = {}
for i in set(target_nodes):
    sample[i] = []

for target_node, idx in zip(target_nodes, range(len(target_nodes))):
    sample[target_node].append([idx, h_prime[target_node][idx].tolist()])

In [84]:
sample

{0: [[0, -0.026776738464832306]],
 1: [[1, -0.018013518303632736], [4, -0.013125997968018055]],
 2: [[2, -0.0007867645472288132],
  [5, -0.008494682610034943],
  [8, -0.005044706631451845]],
 3: [[3, -0.004851858131587505],
  [6, 0.002407621592283249],
  [9, -0.004851858131587505],
  [12, -0.0015193652361631393]],
 4: [[7, -0.003075276967138052],
  [10, 0.0039338660426437855],
  [13, -0.003075276967138052],
  [16, 0.00017807260155677795]],
 5: [[11, -0.0020052697509527206],
  [14, 0.005349091254174709],
  [17, -0.0020052697509527206],
  [20, 0.0015172697603702545]],
 6: [[15, 0.000844285823404789],
  [18, 0.007752683945000172],
  [21, 0.000844285823404789],
  [24, 0.0041534025222063065]],
 7: [[19, -8.43442976474762e-05],
  [22, 0.007420366629958153],
  [25, -8.43442976474762e-05],
  [28, 0.0035506240092217922]],
 8: [[23, 0.0015317141078412533],
  [26, 0.008486921899020672],
  [29, 0.0015317141078412533],
  [32, 0.004794551990926266]],
 9: [[27, 0.001454437617212534],
  [30, 0.0088685

In [85]:
edge_indices[0]

tensor([[ 0,  0,  0,  0,  1,  1,  1,  1,  2,  2,  2,  2,  3,  3,  3,  3,  4,  4,
          4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  7,  8,  8,  8,  8,
          9,  9,  9,  9, 10, 10, 10, 10, 11, 11, 11, 12, 12, 13],
        [ 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,  4,  5,
          6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10,  8,  9, 10, 11,
          9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 12, 13, 13]])

The GATLayerWithExpEdgeType class does not work!!

In [22]:
# num_in_features = 300
# num_out_features = 300
# num_heads = 4
# num_edge_types = 3

# gat_exp_layer = GATLayerWithExpEdgeType(num_in_features, num_out_features, num_heads, num_edge_types)

# # torch.manual_seed(42)
# i = 0
# attention_scores, attention_coefficients = gat_exp_layer((updated_representations[i], edge_indices[i]), edge_types[i])
# print(f"dialogue_representation[{i}] shape:", updated_representations[i].shape)
# print("Attention coef shape:", attention_coefficients.shape)
# print("Attention score shape:", attention_score.shape)

In [23]:
# print(attention_coef[:10].shape)
# print(attention_score[:10])


In [24]:
attention_coef[0,:,:].shape

torch.Size([300, 50])

In [25]:
# Assuming input_tensor is your tensor of shape (14, 150, 50)
# input_tensor = torch.rand((14, 150, 50))
input_tensor = attention_coef
# Set the value of k for top-k
k = 5

# Initialize a list to store top-k frequent edges for each node
top_k_frequent_edges_per_node = []

# Loop over each node
for node_index in range(input_tensor.shape[0]):
    # Extract the edges for the current node
    node_edges = input_tensor[node_index]

    # Flatten the tensor to have shape (150, 50)
    flat_tensor = node_edges.view(-1)

    # Find the indices of the top-k influential edges for the current node
    top_k_indices = torch.argsort(flat_tensor, descending=True)[:k]

    # Ensure the top-k indices are within the correct range (0-50)
    top_k_indices = top_k_indices % 50

    # Flatten the top-k indices and convert them to a list
    top_k_flat_list = top_k_indices.view(-1).tolist()

    # Count the occurrences of each edge index
    edge_counts = Counter(top_k_flat_list)

    # Find the top-k most frequent edges for the current node
    top_k_frequent_edges = [edge for edge, count in edge_counts.most_common(k)]

    # Append the result to the list
    top_k_frequent_edges_per_node.append(top_k_frequent_edges)

# Print the results
for node_index, edges in enumerate(top_k_frequent_edges_per_node):
    print("Top {} Frequent Edges for Node {}: {}".format(k, node_index, edges))

Top 5 Frequent Edges for Node 0: [25, 39, 23, 5, 45]
Top 5 Frequent Edges for Node 1: [43, 3, 35, 17, 33]
Top 5 Frequent Edges for Node 2: [3, 17, 15, 13, 11]
Top 5 Frequent Edges for Node 3: [11, 25, 27, 33, 13]
Top 5 Frequent Edges for Node 4: [1, 3, 13, 33, 41]
Top 5 Frequent Edges for Node 5: [11, 45, 3, 19, 1]
Top 5 Frequent Edges for Node 6: [31, 39, 1, 27, 11]
Top 5 Frequent Edges for Node 7: [45, 5, 21, 43, 48]
Top 5 Frequent Edges for Node 8: [37, 13, 15, 11, 48]
Top 5 Frequent Edges for Node 9: [46, 38, 34, 30, 2]
Top 5 Frequent Edges for Node 10: [14, 18, 42, 22, 46]
Top 5 Frequent Edges for Node 11: [34, 46, 14, 26, 2]
Top 5 Frequent Edges for Node 12: [26, 2, 46, 14, 42]
Top 5 Frequent Edges for Node 13: [46, 22, 34, 10, 26]


In [26]:
num_nodes = 8
num_edges = 30
#define connections
u = torch.randint(num_nodes, (num_edges,))
print("u: ", u)
v = torch.randint(num_nodes, (num_edges,)) 
print("v: ", v)
graph = dgl.graph((u,v))    
print(graph)
node_feats = torch.rand((num_nodes, 20)) 
print("node_feats.shape: ", node_feats.shape)
edge_feats = torch.rand((num_edges, 12))
print("edge_feats.shape: ", edge_feats.shape)


u:  tensor([3, 5, 5, 1, 7, 3, 4, 0, 3, 1, 5, 4, 3, 0, 0, 2, 2, 6, 1, 7, 3, 3, 7, 6,
        5, 5, 6, 5, 2, 3])
v:  tensor([6, 3, 7, 0, 2, 4, 2, 6, 4, 0, 6, 1, 3, 0, 3, 5, 1, 1, 0, 1, 4, 1, 3, 3,
        6, 3, 6, 3, 4, 7])
Graph(num_nodes=8, num_edges=30,
      ndata_schemes={}
      edata_schemes={})
node_feats.shape:  torch.Size([8, 20])
edge_feats.shape:  torch.Size([30, 12])


In [27]:
#use as regular torch/dgl layer work similar as GATConv from dgl library
# egat = EGATConv(in_node_feats=num_node_feats,
#                 in_edge_feats=num_edge_feats,
#                 out_node_feats=10,
#                 out_edge_feats=10,
#                 num_heads=3)

In [28]:
# new_node_feats, new_edge_feats = egat(graph, node_feats, edge_feats)
#new_node_feats.shape = (*, num_heads, out_node_feats)
#new_eode_feats.shape = (*, num_heads, out_edge_feats)

In [29]:
i = 0
graph = dgl.graph((edge_indices[i][0],edge_indices[i][1]))    
edge_feats = get_ohe(edge_types[i])
egat = EGATConv(in_node_feats=300,
                in_edge_feats=3,
                out_node_feats=300,
                out_edge_feats=3,
                num_heads=4)
new_node_feats, new_edge_feats = egat(graph, updated_representations[i], edge_feats)

In [51]:
print(new_edge_feats.shape, new_edge_feats.shape)

torch.Size([50, 4, 3]) torch.Size([50, 4, 3])


In [30]:
new_edge_feats

tensor([[[-2.2043e-03,  1.3946e-01,  1.3355e-01],
         [-1.1724e-02, -1.4797e-02, -4.4118e-04],
         [-8.2909e-03, -3.0248e-03, -3.6196e-03],
         [ 7.0107e-02, -4.0372e-03, -2.1249e-03]],

        [[ 3.1490e-01,  5.6027e-01,  3.6082e-01],
         [-4.0730e-03, -6.6862e-03, -7.2717e-04],
         [ 7.3918e-01, -2.0522e-03, -4.5301e-03],
         [ 9.0749e-01, -4.2760e-03, -3.3131e-03]],

        [[ 3.8424e-01, -2.7835e-03,  9.7298e-02],
         [ 5.8915e-01, -7.6024e-03,  2.3478e-01],
         [ 4.3970e-01,  1.5603e-01, -1.1243e-03],
         [ 7.3272e-01,  3.9763e-01, -4.7669e-03]],

        [[ 3.9363e-01,  5.2008e-01,  2.9877e-01],
         [-4.6900e-03, -5.4016e-03, -5.5663e-04],
         [ 8.6497e-01, -2.4066e-03, -4.9672e-03],
         [ 8.2114e-01, -1.8289e-03, -5.7593e-03]],

        [[-2.7848e-03, -4.0790e-05,  5.7089e-02],
         [-1.3146e-02, -1.3630e-02, -3.3268e-03],
         [-7.6839e-03, -2.8498e-03, -6.6084e-04],
         [-3.3714e-05, -2.7149e-04, -1.549