### Knowledge Graph Creation for Relation Extraction

In [None]:
def create_graphs(dataset):
    """Build one graph object per dataset item for relation extraction.

    For every item the graph contains:

    * one node per distinct entity (subject, object), keyed by the tuple
      ``(entity type, entity tokens)``; its feature vector is taken from the
      entity's *start* token;
    * one node per token index returned by ``get_shortest_dependency_path``
      for the subject and object start positions;
    * a single directed edge from the subject node to the object node whose
      attribute is ``relation_to_one_hot(item['relation'])``.

    Every node feature is the concatenation of the token's
    ``stanford_pos_one_hot`` and ``stanford_ner_one_hot`` vectors.

    NOTE(review): entity nodes are keyed by tuples while path nodes are keyed
    by integer token index, so a path token never matches an entity key. If
    the dependency path includes the entity start tokens they will receive a
    second node -- confirm this is intended. Path nodes are also not
    connected by any edge here.

    Args:
        dataset: iterable of items supporting ``item[key]`` lookup for the
            keys ``tokens``, ``stanford_head``, ``stanford_pos_one_hot``,
            ``stanford_ner_one_hot``, ``relation`` and
            ``{subj,obj}_{type,start,end}``.

    Returns:
        A list with one ``Data`` object per item (presumably
        ``torch_geometric.data.Data`` -- confirm against the file's imports),
        carrying ``x``, ``edge_index`` (shape ``(2, 1)``) and ``edge_attr``.
    """

    def token_feature(sample, position):
        """Concatenate the POS and NER one-hot vectors of a single token."""
        pos_vec = torch.tensor(sample['stanford_pos_one_hot'][position], dtype=torch.float)
        ner_vec = torch.tensor(sample['stanford_ner_one_hot'][position], dtype=torch.float)
        return torch.cat((pos_vec, ner_vec), dim=0)

    def mention_key(sample, role):
        """Return the ``(type, tokens)`` tuple identifying the *role* entity."""
        kind = sample[f'{role}_type']
        words = sample['tokens']
        # The end index is inclusive in the data, hence the +1.
        return (kind, tuple(words[sample[f'{role}_start']:sample[f'{role}_end'] + 1]))

    def ensure_node(slot_of, rows, sample, key, position):
        """Register *key* as a new node unless it already has a slot."""
        if key in slot_of:
            return
        rows.append(token_feature(sample, position))
        slot_of[key] = len(rows) - 1

    def build_graph(sample):
        """Assemble the graph for a single dataset item."""
        slot_of = {}  # node key -> row index in the feature matrix
        rows = []     # per-node feature vectors, in insertion order

        # Entity nodes first, so they occupy the lowest row indices.
        subj_key = mention_key(sample, 'subj')
        ensure_node(slot_of, rows, sample, subj_key, sample['subj_start'])
        obj_key = mention_key(sample, 'obj')
        ensure_node(slot_of, rows, sample, obj_key, sample['obj_start'])

        # Then one node per token on the dependency path between the two.
        path = get_shortest_dependency_path(
            sample['stanford_head'], sample['subj_start'], sample['obj_start']
        )
        for position in path:
            ensure_node(slot_of, rows, sample, position, position)

        # A single subject -> object edge labelled with the relation.
        pair = [slot_of[subj_key], slot_of[obj_key]]
        relation_vec = relation_to_one_hot(sample['relation'])

        return Data(
            x=torch.stack(rows),
            edge_index=torch.tensor([pair], dtype=torch.long).t().contiguous(),
            edge_attr=torch.stack([relation_vec]),
        )

    return [build_graph(sample) for sample in dataset]