# DataSet PyG

by Alejandro Fernández Sánchez

## Initialization

In [3]:
from neo4j import GraphDatabase, Driver, basic_auth
from dotenv import load_dotenv
import numpy as np
import os
from typing import Optional
import torch

from torch_geometric.data import HeteroData

In [4]:
# Load the variables defined in the local .env file into the process
# environment so the os.getenv() calls below can read them.
load_dotenv()

True

In [5]:
# Connection settings and API key, read from the process environment.
DB_HOST = os.getenv("NEO4J_HOST")
DB_PORT = os.getenv("NEO4J_PORT")
DB_USER = os.getenv("NEO4J_USER")
DB_PASS = os.getenv("NEO4J_PASS")
LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")

# .env validation
# Raise explicitly instead of using `assert`: assertions are removed when
# Python runs with -O, which would let a broken configuration through, and
# the error now names exactly which variables are missing.
# LAST_FM_API_KEY is not part of the required set.
_missing_env = [
    name
    for name, value in (
        ("NEO4J_HOST", DB_HOST),
        ("NEO4J_PORT", DB_PORT),
        ("NEO4J_USER", DB_USER),
        ("NEO4J_PASS", DB_PASS),
    )
    if value is None
]
if _missing_env:
    raise RuntimeError(f"INVALID .env: missing {', '.join(_missing_env)}")

In [6]:
# Create the Neo4j driver over the Bolt protocol, using the host, port and
# credentials read from the environment above.
driver = GraphDatabase.driver(f"bolt://{DB_HOST}:{DB_PORT}", auth=basic_auth(DB_USER, DB_PASS))
# Bare expression: displays the driver's repr as the cell output.
driver

<neo4j._sync.driver.BoltDriver at 0x7f59cd3ac470>

In [7]:
!mkdir -p ds

In [8]:
# Abort early when torch reports that no CUDA device is available.
if not torch.cuda.is_available():
    raise Exception("No cuda found!")

In [9]:
# Pick the CUDA device when available, otherwise fall back to the CPU.
# NOTE(review): the preceding cell raises when CUDA is unavailable, so the
# "cpu" fallback is only reached if that cell is skipped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: displays the selected device as the cell output.
device

device(type='cuda')

## Dataset creation

### Helper class

In [23]:
class IdMap:
    """Bidirectional mapping between strings and consecutive integer ids.

    Ids are handed out in first-seen order, starting at 0. Indexing with a
    string returns its id (registering the string if it is new); indexing
    with an integer returns the string stored under that id.
    """

    def __init__(self):
        # Forward lookup: string -> id.
        self.str_to_id = dict()
        # Reverse lookup: the list position is the id.
        self.id_to_str = list()

    def __len__(self):
        """Return the number of terms stored in the IdMap."""
        # Both containers must always describe the same set of terms.
        assert len(self.str_to_id) == len(self.id_to_str)
        return len(self.id_to_str)

    def _get_str(self, i):
        """Return the string corresponding to a given id (`i`).

        Plain list indexing is used, so an unknown id raises `IndexError`
        and a negative `i` counts from the end of the list.
        """
        return self.id_to_str[i]

    def _get_id(self, s):
        """Return the id corresponding to a string (`s`).

        If `s` is not in the IdMap yet, a new id (the next consecutive
        integer) is assigned to it and returned.
        """
        if s in self.str_to_id:
            idx = self.str_to_id[s]
        else:
            idx = len(self.str_to_id)
            self.str_to_id[s] = idx
            self.id_to_str.append(s)
        assert len(self.str_to_id) == len(self.id_to_str)
        return idx

    def __getitem__(self, key):
        """Look up by id or by string depending on the type of `key`.

        If `key` is an integer, return the stored string (`_get_str`);
        if `key` is a string, return its id, registering it when new
        (`_get_id`).

        Raises:
            TypeError: if `key` is neither a string nor an integer.
        """
        # isinstance (rather than an exact type check) so that subclasses of
        # str and int are accepted as well.
        if isinstance(key, str):
            return self._get_id(key)
        # bool is a subclass of int; keep rejecting it so True/False are never
        # silently treated as ids 1/0.
        if isinstance(key, int) and not isinstance(key, bool):
            return self._get_str(key)
        raise TypeError("Type not supported")

# One independent id space per node type, so artists, tracks and tags are
# each numbered from 0.
artist_map = IdMap()
track_map = IdMap()
tag_map = IdMap()

### Helper vars

In [24]:
# Node accumulators: one list and one "seen" set per node type.
artist_tensor = []
artist_seen = set()
track_tensor = []
track_seen = set()
tag_tensor = []
tag_seen = set()

# Edge indexes, one per relationship type. Each holds two parallel lists:
# position 0 collects the first endpoint ids, position 1 the second.
worked_in_index = [[], []]
worked_by_index = [[], []]
collab_with_index = [[], []]
musically_related_to_index = [[], []]
personally_related_to_index = [[], []]
linked_to_index = [[], []]
last_fm_match_index = [[], []]
artist_has_tag_index = [[], []]
artist_tags_index = [[], []]
track_has_tag_index = [[], []]
track_tags_index = [[], []]

# Edge attributes, kept in step with the matching edge index.
last_fm_match_attributes = []

In [25]:
def add_relationship_to_dataset(
    id0: int,
    id1: int,
    relationship_index: list[list[int]],
    attributes: Optional[list] = None,
    relationship_attributes: Optional[list] = None
) -> None:
    """Append one relationship (`id0` -> `id1`) to a relationship index.

    Args:
        id0: Id of the first endpoint; appended to `relationship_index[0]`.
        id1: Id of the second endpoint; appended to `relationship_index[1]`.
        relationship_index: Two parallel lists of ids, mutated in place.
        attributes: Attribute values for this relationship. Only stored when
            `relationship_attributes` is given.
        relationship_attributes: Optional list collecting one `attributes`
            entry per relationship, mutated in place. When it is given,
            `attributes` is appended as-is (including `None`), so the list
            stays aligned with `relationship_index`.

    Raises:
        IndexError: if `relationship_index` has fewer than two rows.
    """
    # Resolve both rows before mutating anything, so a malformed index cannot
    # end up with one endpoint appended and the other missing.
    first_ids, second_ids = relationship_index[0], relationship_index[1]
    first_ids.append(id0)
    second_ids.append(id1)
    if relationship_attributes is not None:
        relationship_attributes.append(attributes)

### Relationships (artists, tracks and tags)

In [27]:

query = """
    MATCH (a)-[r]->(b)
    return a, r, type(r) as r_type, b
    limit 5
    ;
"""
with driver.session() as session:
    records = session.run(query)  # type: ignore
    for i, record in enumerate(records):
        match record["r_type"], record["a"], "Artist" in record["a"].labels, record["b"], "Artist" in record["b"].labels, record["r"]:

            # ('Artist', 'HAS_TAG', 'Tag')
            case "HAS_TAG", artist, True, tag, _, _:
                add_relationship_to_dataset(
                    artist_map[artist["main_id"]],
                    tag_map[tag["id"]],
                    artist_has_tag_index
                )

            # ('Tag', 'TAGS', 'Artist')
            case "TAGS", tag, _, artist, True, _:
                add_relationship_to_dataset(
                    tag_map[tag["id"]],
                    artist_map[artist["main_id"]],
                    artist_tags_index
                )

            # ('Track', 'HAS_TAG', 'Tag')
            case "HAS_TAG", track, False, tag, _, _:
                add_relationship_to_dataset(
                    track_map[track["id"]],
                    tag_map[tag["id"]],
                    track_has_tag_index
                )

            # ('Tag', 'TAGS', 'Track')
            case "TAGS", tag, _, track, False, _:
                add_relationship_to_dataset(
                    tag_map[tag["id"]],
                    track_map[track["id"]],
                    track_tags_index
                )

            # ('Artist', 'WORKED_IN', 'Track')
            case "WORKED_IN", artist, _, track, _, _:
                add_relationship_to_dataset(
                    artist_map[artist["main_id"]],
                    track_map[track["id"]],
                    worked_in_index
                )

            # ('Track', 'WORKED_BY', 'Artist')
            case "WORKED_BY", track, _, artist, _, _:
                add_relationship_to_dataset(
                    track_map[track["id"]],
                    artist_map[artist["main_id"]],
                    worked_by_index
                )

            # ('Artist', 'COLLAB_WITH', 'Artist')
            case "COLLAB_WITH", artist0, _, artist1, _, _:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    collab_with_index
                )
            
            # TODO: Is there something wrong with the inverse of a relationship being itself?

            # ('Artist', 'MUSICALLY_RELATED_TO', 'Artist')
            case "MUSICALLY_RELATED_TO", artist0, _, artist1, _, _:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    musically_related_to_index
                )

            # ('Artist', 'PERSONALLY_RELATED_TO', 'Artist')
            case "PERSONALLY_RELATED_TO", artist0, _, artist1, _, _:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    personally_related_to_index
                )

            # ('Artist', 'MUSICALLY_RELATED_TO', 'Artist')
            case "LINKED_TO", artist0, _, artist1, _, _:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    linked_to_index
                )

            # ('Artist', 'LAST_FM_MATCH', 'Artist')
            case "LAST_FM_MATCH", artist0, _, artist1, _, attributes:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    last_fm_match_index,
                    [attributes["weight"]],
                    last_fm_match_attributes,
                )

            case _: raise NotImplementedError(f"Unknown case in match. Relationship type: {record["r_type"]}")
        

In [68]:
# Close the Neo4j driver created at the top of the notebook.
driver.close()