# DataSet PyG

by Alejandro Fernández Sánchez

## Initialization

In [14]:
from neo4j import GraphDatabase, Driver, basic_auth
from dotenv import load_dotenv
import numpy as np
import os
from typing import Any
import torch

from torch_geometric.data import HeteroData

In [2]:
# Load the variables defined in the local .env file into the process
# environment so the os.getenv() lookups in this notebook can see them.
load_dotenv()

True

In [3]:
# Neo4j connection settings and the Last.fm API key, read from the process
# environment. os.getenv() returns None for any variable that is not set.
DB_HOST = os.getenv("NEO4J_HOST")
DB_PORT = os.getenv("NEO4J_PORT")
DB_USER = os.getenv("NEO4J_USER")
DB_PASS = os.getenv("NEO4J_PASS")
LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")

# .env validation
# An explicit raise is used instead of an `assert` statement so the check
# still runs when Python is started with -O (which strips asserts). The
# exception type and message are the same as the assert produced.
if (
    DB_HOST is None
    or DB_PORT is None
    or DB_USER is None
    or DB_PASS is None
    or LAST_FM_API_KEY is None
):
    raise AssertionError("INVALID .env")

In [4]:
# Build the Bolt URI and the credentials separately, then open the driver.
bolt_uri = f"bolt://{DB_HOST}:{DB_PORT}"
credentials = basic_auth(DB_USER, DB_PASS)
driver = GraphDatabase.driver(bolt_uri, auth=credentials)
driver

<neo4j._sync.driver.BoltDriver at 0x7fb54c309040>

In [6]:
# Shell escape: create the `ds` directory; `-p` makes this a no-op when the
# directory already exists.
!mkdir -p ds

In [17]:
# Stop the notebook early when PyTorch cannot see a CUDA device.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    raise Exception("No cuda found!")

In [18]:
# Use the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Dataset creation

### Helper class

In [54]:
class IdMap:
    """Helper class to store a two-way mapping between strings and ids.

    Ids are consecutive integers assigned in insertion order, starting at 0.
    Indexing with a string returns its id (registering the string first when
    it is new); indexing with an integer returns the string stored under it.
    """

    def __init__(self):
        # string -> id lookup table
        self.str_to_id = dict()
        # id -> string lookup table; the list position is the id
        self.id_to_str = list()

    def __len__(self):
        """Return the number of terms stored in the IdMap."""
        # Internal invariant: both tables always describe the same terms.
        assert len(self.str_to_id) == len(self.id_to_str)
        return len(self.id_to_str)

    def _get_str(self, i):
        """Return the string corresponding to a given id (`i`).

        Plain list indexing is used, so an id that was never assigned raises
        IndexError and negative values count from the end of the list.
        """
        return self.id_to_str[i]

    def _get_id(self, s):
        """Return the id corresponding to a string (`s`).

        If `s` is not in the IdMap yet, a new id (the current size of the
        map) is assigned to it and returned.
        """
        if s in self.str_to_id:
            idx = self.str_to_id[s]
        else:
            idx = len(self.str_to_id)
            self.str_to_id[s] = idx
            self.id_to_str.append(s)
        # Internal invariant: both tables always describe the same terms.
        assert len(self.str_to_id) == len(self.id_to_str)
        return idx

    def __getitem__(self, key):
        """Look up `key` in the direction implied by its type.

        If `key` is a string, return its id (see `_get_id`).
        If `key` is an integer, return its string (see `_get_str`).

        Raises:
            TypeError: if `key` is neither a string nor an integer.
        """
        # isinstance() is used so that subclasses of str/int are accepted too.
        if isinstance(key, str):
            return self._get_id(key)
        # bool is a subclass of int but is not a meaningful id, so it is still
        # rejected below as an unsupported type.
        if isinstance(key, int) and not isinstance(key, bool):
            return self._get_str(key)
        raise TypeError("Type not supported")

# Three independent id maps, one each for artists, releases and tags.
artist_map, release_map, tag_map = IdMap(), IdMap(), IdMap()

### Helper vars

In [None]:
# Node accumulators, one per node type. They start as empty Python lists
# despite the `_tensor` suffix.
artist_tensor = []
release_tensor = []
tag_tensor = []

# Relationship index accumulators, one per relationship type.
worked_in_index = []
worked_by_index = []
collab_with_index = []
musically_related_to_index = []
personally_related_to_index = []
linked_to_index = []
last_fm_match_index = []
has_tag_index = []
tags_index = []

# Relationship attribute accumulators, parallel to the index lists above.
worked_in_attributes = []
worked_by_attributes = []
collab_with_attributes = []
musically_related_to_attributes = []
personally_related_to_attributes = []
linked_to_attributes = []
last_fm_match_attributes = []
has_tag_attributes = []
tags_attributes = []

### Artists

In [57]:
# TODO: How should we handle dates?


# Exploratory Cypher query: match any directed relationship and return its
# start node (`a`), the relationship (`r`), the relationship type (`r_type`)
# and its end node (`b`), limited to 5 rows.
query = """
    MATCH (a)-[r]->(b) return a, r, type(r) as r_type, b limit 5
    ;
"""
with driver.session() as session:
    records = session.run(query)  # type: ignore
    for i, record in enumerate(records):
        # Only the position of each record is printed for now; the field
        # accesses below are left commented out as a reference.
        print(i)
        # print(record["a"]["main_id"])
        # print(record["r_type"])
        # print(record["b"].labels)
        # print()
        

0
1
2
3
4


In [None]:
# Close the Neo4j driver opened earlier in this notebook.
driver.close()