Imports: every library used by this notebook is loaded in the cell below.

In [9]:
import torch, os, pickle
from torch.utils.data import Dataset, DataLoader

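Data-loading code for the GAT (graph attention network) model: the dialogue-graph dataset, its data loader, and the collate function that merges graphs into a batch.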

In [4]:

class DialogueGraphDataLoader(DataLoader):
    """DataLoader over dialogue graphs.

    Wraps the given per-graph lists in a ``DialogueGraphDataset`` and batches
    samples with ``dialogue_graph_collate_fn``.

    Args:
        node_features_list: per-graph node features, indexable by graph.
        edge_index_list: per-graph edge indices, indexable by graph.
        batch_size: number of graphs per batch (default 1).
        shuffle: whether the loader reshuffles graphs each epoch (default False).
    """

    def __init__(self, node_features_list, edge_index_list, batch_size=1, shuffle=False):
        dataset = DialogueGraphDataset(node_features_list, edge_index_list)
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=dialogue_graph_collate_fn,
        )

class DialogueGraphDataset(Dataset):
    """Index-aligned view over per-graph node features and edge indices.

    Sample ``idx`` is the pair ``(node_features_list[idx], edge_index_list[idx])``.
    The dataset length is taken from ``edge_index_list``.
    """

    def __init__(self, node_features_list, edge_index_list):
        self.node_features_list, self.edge_index_list = node_features_list, edge_index_list

    def __len__(self):
        # The edge index list defines how many graphs the dataset holds.
        graph_count = len(self.edge_index_list)
        return graph_count

    def __getitem__(self, idx):
        features = self.node_features_list[idx]
        edges = self.edge_index_list[idx]
        return features, edges

def dialogue_graph_collate_fn(batch):
    """Merge a batch of dialogue graphs into one graph with disconnected components.

    Args:
        batch: sequence of ``(node_features, edge_index)`` samples, where
            ``node_features`` is a ``(text_embeddings, speakers_list)`` pair.
            Assumes both members are tensors whose first dimension indexes the
            graph's nodes, and that ``edge_index`` is an integer tensor of
            shape ``(2, num_edges)`` -- TODO confirm against the caller.

    Returns:
        ``(node_features_combined, edge_index)`` where
        ``node_features_combined`` is the list
        ``[all_text_embeddings, all_speakers]`` with every graph's nodes
        stacked along dim 0, and ``edge_index`` is all edge indices
        concatenated along dim 1 after shifting each graph's node ids by the
        number of nodes that precede it in the batch.
    """
    # zip(*batch) yields tuples, so shifted edge indices need their own list.
    node_features_list, edge_index_list = zip(*batch)

    shifted_edge_index_list = []
    num_nodes_seen = 0

    for (text_embeddings, _speakers_list), edge_index in zip(node_features_list, edge_index_list):
        # Translate this graph's node ids into the merged graph's id range.
        shifted_edge_index_list.append(edge_index + num_nodes_seen)
        num_nodes_seen += len(text_embeddings)

    # Nodes are counted along dim 0 above, so features are stacked along dim 0
    # as well; this also lets graphs with different node counts be merged.
    node_features_combined = [torch.cat(features, 0) for features in zip(*node_features_list)]
    edge_index = torch.cat(shifted_edge_index_list, 1)

    return node_features_combined, edge_index


Helper functions that build the node pairs and adjacency lists for each dialogue graph.

In [90]:
def create_node_pairs_dict(start_idx, end_idx, window=3):
    """Build (source, target) node pairs linking each node to its recent past.

    Node ids are rebased so the node at ``start_idx`` becomes 0 and the node
    at ``end_idx`` becomes ``end_idx - start_idx``. Every node ``i`` is paired
    with itself and with up to ``window`` immediately preceding nodes, targets
    listed in ascending order.

    Despite the name, the result is a pair of parallel lists, not a dict.

    Args:
        start_idx: index of the first node of the range (inclusive).
        end_idx: index of the last node of the range (inclusive).
        window: how many preceding nodes each node is linked to (default 3).
            A negative window produces no pairs.

    Returns:
        ``[list_node_i, list_node_j]`` where pair ``k`` is
        ``(list_node_i[k], list_node_j[k])``. Both lists are empty when
        ``end_idx < start_idx``.
    """
    num_nodes = end_idx - start_idx + 1
    list_node_i = []
    list_node_j = []
    for i in range(num_nodes):
        # Clamp at 0 so nodes near the start only link to nodes that exist.
        for target_idx in range(max(0, i - window), i + 1):
            list_node_i.append(i)
            list_node_j.append(target_idx)

    return [list_node_i, list_node_j]

def create_adjacency_list(node_pairs):
    """Group target nodes by their source node.

    Args:
        node_pairs: two parallel sequences ``[sources, targets]``; pair ``k``
            is ``(sources[k], targets[k])``.

    Returns:
        Dict mapping each source node to the list of its targets, in the order
        the pairs appear. Edges are recorded in the source -> target direction
        only; the reverse direction is not added.
    """
    neighbours_by_source = {}

    for position, source in enumerate(node_pairs[0]):
        target = node_pairs[1][position]
        # First sighting of a source creates its neighbour list.
        neighbours_by_source.setdefault(source, []).append(target)

    return neighbours_by_source
# print(ranges[:1])

def get_all_adjacency_list(ranges):
    """Build one adjacency dict per ``(start_idx, end_idx)`` range.

    Args:
        ranges: iterable of ``(start_idx, end_idx)`` pairs.

    Returns:
        List with, for each range in order, the adjacency dict produced by
        ``create_adjacency_list`` from that range's node pairs.
    """
    return [
        create_adjacency_list(create_node_pairs_dict(first_idx, last_idx))
        for first_idx, last_idx in ranges
    ]

In [48]:
# Dump holding the encoded speakers and the per-dialogue index ranges.
SPEAKER_ENCODER_PATH = "data/dump/speaker_encoder.pkl"

checkFile = os.path.isfile(SPEAKER_ENCODER_PATH)
# Defaults keep both names defined for later cells even when the dump is missing.
encoded_speaker_list, ranges = [], []
if not checkFile:
    print("Run first the prototype_context_encoder to generate this file")
else:
    # NOTE: pickle.load can execute arbitrary code; only load dumps produced
    # by this project. The `with` block closes the file even if loading fails.
    with open(SPEAKER_ENCODER_PATH, "rb") as file:
        encoded_speaker_list, ranges = pickle.load(file)

In [93]:
# Cached adjacency lists; when the cache is absent they are rebuilt from `ranges`.
ADJACENCY_LIST_PATH = "data/dump/all_adjacency_list.pkl"

checkFile = os.path.isfile(ADJACENCY_LIST_PATH)
adjacency_list = []
if not checkFile:
    # No cache on disk: compute in memory (the result is not written back).
    adjacency_list = get_all_adjacency_list(ranges)
else:
    # NOTE: pickle.load can execute arbitrary code; only load dumps produced
    # by this project. The `with` block closes the file even if loading fails.
    with open(ADJACENCY_LIST_PATH, "rb") as file:
        adjacency_list = pickle.load(file)

In [92]:
# adjacency_list[:2]

[{0: [0],
  1: [0, 1],
  2: [0, 1, 2],
  3: [0, 1, 2, 3],
  4: [1, 2, 3, 4],
  5: [2, 3, 4, 5],
  6: [3, 4, 5, 6],
  7: [4, 5, 6, 7],
  8: [5, 6, 7, 8],
  9: [6, 7, 8, 9],
  10: [7, 8, 9, 10],
  11: [8, 9, 10, 11],
  12: [9, 10, 11, 12],
  13: [10, 11, 12, 13]},
 {0: [0],
  1: [0, 1],
  2: [0, 1, 2],
  3: [0, 1, 2, 3],
  4: [1, 2, 3, 4],
  5: [2, 3, 4, 5],
  6: [3, 4, 5, 6]}]

In [96]:
# Number of entries in `adjacency_list` (one adjacency dict per range when it
# was built by get_all_adjacency_list).
len(adjacency_list)

2160

In [97]:
# Location of the pickled representations, relative to the working directory.
file_path = 'embed/updated_representation_list.pkl'

# Deserialize the representations with pickle. pickle.load can execute
# arbitrary code, so this file should only be a dump produced by this project.
# NOTE(review): presumably one tensor per dialogue -- confirm against the producer.
with open(file_path, 'rb') as file:
    updated_representations = pickle.load(file)

In [103]:
# Inspect the first loaded representation: its shape first, then its values.
print(updated_representations[0].shape)
print(updated_representations[0])

torch.Size([14, 300])
tensor([[-2.8721e-01,  5.8134e-01, -1.3142e-01,  ...,  1.8101e-02,
         -4.6824e-04,  1.9901e-02],
        [-1.6920e-01,  1.8220e-01, -1.2245e-01,  ...,  1.3620e-02,
         -2.0732e-03,  8.3473e-03],
        [-8.1502e-02,  7.7161e-02, -6.6144e-02,  ...,  1.3882e-02,
          3.4588e-03, -1.4834e-03],
        ...,
        [-4.1162e-03,  2.6335e-02,  2.8706e-02,  ..., -1.6475e-01,
         -1.3978e-01,  2.8344e-02],
        [-1.7579e-02,  1.8380e-02,  3.3130e-02,  ..., -2.5659e-01,
         -2.2489e-01,  1.5857e-02],
        [-2.9680e-02,  8.5039e-03,  3.3814e-02,  ..., -3.8804e-01,
         -2.8153e-01,  1.1250e-03]], requires_grad=True)
