# PyG Dataset

by Alejandro Fernández Sánchez

## Initialization

In [32]:
from neo4j.graph import Node
from neo4j import GraphDatabase, basic_auth
from dotenv import load_dotenv
import numpy as np
import os
from typing import Optional
import torch

from torch_geometric.data import HeteroData

In [33]:
# Populate os.environ from a local .env file so the next cell can read the
# NEO4J_* settings via os.getenv (the echoed True is load_dotenv's return value).
load_dotenv()

True

In [34]:
DB_HOST = os.getenv("NEO4J_HOST")
DB_PORT = os.getenv("NEO4J_PORT")
DB_USER = os.getenv("NEO4J_USER")
DB_PASS = os.getenv("NEO4J_PASS")
# Not used by the cells shown here, so it is deliberately not validated below.
LAST_FM_API_KEY = os.getenv("LAST_FM_API_KEY")

# .env validation. An explicit raise (rather than `assert`) keeps the check
# active under `python -O` and reports exactly which variables are missing.
_missing_env = [
    name
    for name, value in (
        ("NEO4J_HOST", DB_HOST),
        ("NEO4J_PORT", DB_PORT),
        ("NEO4J_USER", DB_USER),
        ("NEO4J_PASS", DB_PASS),
    )
    if value is None
]
if _missing_env:
    raise RuntimeError(f"INVALID .env: missing {', '.join(_missing_env)}")

In [35]:
# One Bolt driver shared by every cell below; it is closed in the final cell.
driver = GraphDatabase.driver(
    "bolt://{}:{}".format(DB_HOST, DB_PORT),
    auth=basic_auth(DB_USER, DB_PASS),
)
driver

<neo4j._sync.driver.BoltDriver at 0x7fc31efadcd0>

In [36]:
# Output directory for the generated dataset. os.makedirs works in plain Python
# too, unlike the IPython-only `!mkdir` shell escape; exist_ok mirrors `-p`.
os.makedirs("ds", exist_ok=True)

In [37]:
# Fail fast on machines without a CUDA device. RuntimeError (a subclass of
# Exception) is the conventional type for an unmet environment requirement.
# NOTE(review): none of the cells shown move tensors to the GPU; this check only
# matters if later training cells rely on CUDA.
if not torch.cuda.is_available():
    raise RuntimeError("No cuda found!")

In [38]:
# Select the compute device. After the CUDA check above this always resolves to
# CUDA; the CPU branch only matters if that check is removed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

## Dataset creation

### Helper class

In [39]:
class IdMap:
    """Bidirectional mapping between strings and dense integer ids.

    Ids are handed out in first-seen order starting at 0. Indexing with an
    ``int`` returns the string stored under that id; indexing with a ``str``
    returns its id, registering the string first if it is new.
    """

    def __init__(self):
        self.str_to_id = {}
        self.id_to_str = []

    def __len__(self):
        """Return how many strings have been registered so far."""
        count = len(self.id_to_str)
        assert len(self.str_to_id) == count
        return count

    def _get_str(self, i):
        """Return the string that was assigned id ``i``."""
        return self.id_to_str[i]

    def _get_id(self, s):
        """Return the id of ``s``, assigning the next free id when ``s`` is new."""
        idx = self.str_to_id.get(s)
        if idx is None:
            idx = len(self.str_to_id)
            self.str_to_id[s] = idx
            self.id_to_str.append(s)
        assert len(self.str_to_id) == len(self.id_to_str)
        return idx

    def __getitem__(self, key):
        """Dispatch on the *exact* type of ``key``.

        ``str`` -> id lookup/assignment, ``int`` -> string lookup. Subclasses
        such as ``bool`` are rejected with ``TypeError``.
        """
        key_type = type(key)
        if key_type is str:
            return self._get_id(key)
        if key_type is int:
            return self._get_str(key)
        raise TypeError("Type not supported")


# One map per node type, translating node ids into contiguous tensor rows.
# NOTE(review): keys must be str -- an int key is treated as a row lookup, not
# registered; confirm that `main_id` / `id` properties are strings.
artist_map, track_map, tag_map = IdMap(), IdMap(), IdMap()

### Helper variables

In [54]:
def get_x_count(x: str):
    """Return ``COUNT(*)`` for the Cypher ``MATCH`` pattern ``x``.

    ``x`` is spliced verbatim into the query text, so only pass trusted,
    hard-coded patterns.
    """
    query = f"MATCH {x} return COUNT(*) AS c"
    with driver.session() as session:
        rows = session.run(query).data()
    return rows[0]["c"]


# Node counts (used to pre-size the node feature tensors).
artist_count = get_x_count("(:Artist)")
track_count = get_x_count("(:Track)")
tag_count = get_x_count("(:Tag)")

print("\n".join(f"{name}: {value}" for name, value in [
    ("artist_count", artist_count),
    ("track_count", track_count),
    ("tag_count", tag_count),
]))

print()

# Relationship counts (used to pre-size the edge buffers). HAS_TAG and TAGS are
# counted separately for Artist and Track endpoints.
worked_in_count = get_x_count("()-[:WORKED_IN]->()")
worked_by_count = get_x_count("()-[:WORKED_BY]->()")
collab_with_count = get_x_count("()-[:COLLAB_WITH]->()")
musically_related_to_count = get_x_count("()-[:MUSICALLY_RELATED_TO]->()")
personally_related_to_count = get_x_count("()-[:PERSONALLY_RELATED_TO]->()")
linked_to_count = get_x_count("()-[:LINKED_TO]->()")
last_fm_match_count = get_x_count("()-[:LAST_FM_MATCH]->()")
artist_has_tag_count = get_x_count("(:Artist)-[:HAS_TAG]->()")
artist_tags_count = get_x_count("()-[:TAGS]->(:Artist)")
track_has_tag_count = get_x_count("(:Track)-[:HAS_TAG]->()")
track_tags_count = get_x_count("()-[:TAGS]->(:Track)")

print("\n".join(f"{name}: {value}" for name, value in [
    ("worked_in_count", worked_in_count),
    ("worked_by_count", worked_by_count),
    ("collab_with_count", collab_with_count),
    ("musically_related_to_count", musically_related_to_count),
    ("personally_related_to_count", personally_related_to_count),
    ("linked_to_count", linked_to_count),
    ("last_fm_match_count", last_fm_match_count),
    ("artist_has_tag_count", artist_has_tag_count),
    ("artist_tags_count", artist_tags_count),
    ("track_has_tag_count", track_has_tag_count),
    ("track_tags_count", track_tags_count),
]))

artist_count: 1489250
track_count: 24324100
tag_count: 23

worked_in_count: 27661673
worked_by_count: 27661673
collab_with_count: 2463052
musically_related_to_count: 373262
personally_related_to_count: 26720
linked_to_count: 23128
last_fm_match_count: 154865250
artist_has_tag_count: 2410207
artist_tags_count: 2410207
track_has_tag_count: 4030735
track_tags_count: 4030735


In [55]:
# Per-node feature layout. The tensors are allocated with one extra column
# ("+ 1"), which add_node_to_dataset fills with the constant 1.

# Artist: begin_date, end_date, ended, gender_1-5, popularity_scaled, type_1-6
artist_attr_count = len((
    "begin_date",
    "end_date",
    "ended",
    *(f"gender_{j}" for j in range(1, 6)),
    "popularity_scaled",
    *(f"type_{j}" for j in range(1, 7)),
))

# Track: popularity_scaled, year, sem1 (month <= 6), sem2 (month > 6)
track_attr_count = len(("popularity_scaled", "year", "sem1", "sem2"))

In [64]:
# TODO: Check if everything fits into memory or if I have to process everything separately
# NOTE(review): with the counts printed above, these buffers total ~2.4 GB. The
# largest is last_fm_match_index (154,865,250 x 2 int32, ~1.2 GB). That is
# before the Python-side IdMaps and `*_seen` sets, which hold ~26M entries.

# Nodes. Feature rows are zero-initialised instead of torch.empty: any node the
# export loop never visits (isolated nodes, or everything past a LIMIT) would
# otherwise keep uninitialised memory as its features.
# NOTE(review): float16 represents integers exactly only up to 2048 -- confirm
# the date/year/count values fit, or switch to float32.
artist_seen = set()
artist_x = torch.zeros((artist_count, artist_attr_count + 1), dtype=torch.float16)

track_seen = set()
track_x = torch.zeros((track_count, track_attr_count + 1), dtype=torch.float16)

tag_seen = set()
tag_x = torch.ones((tag_count, 1), dtype=torch.float16)

# Relationship indexes: one (E, 2) buffer of [src, dst] rows per relation, plus
# the write cursor the export loop advances. Pre-filled with -1 so any slot that
# is never written stays detectable as an invalid node index.
# NOTE(review): PyG expects `edge_index` as int64 with shape (2, E); these will
# presumably need `.t().long()` when building the HeteroData object.
worked_in_index, worked_in_helper_idx = torch.full((worked_in_count, 2), -1, dtype=torch.int32), 0
worked_by_index, worked_by_helper_idx = torch.full((worked_by_count, 2), -1, dtype=torch.int32), 0
collab_with_index, collab_with_helper_idx = torch.full((collab_with_count, 2), -1, dtype=torch.int32), 0
musically_related_to_index, musically_related_to_helper_idx = torch.full((musically_related_to_count, 2), -1, dtype=torch.int32), 0
personally_related_to_index, personally_related_to_helper_idx = torch.full((personally_related_to_count, 2), -1, dtype=torch.int32), 0
linked_to_index, linked_to_helper_idx = torch.full((linked_to_count, 2), -1, dtype=torch.int32), 0
last_fm_match_index, last_fm_match_helper_idx = torch.full((last_fm_match_count, 2), -1, dtype=torch.int32), 0
artist_has_tag_index, artist_has_tag_helper_idx = torch.full((artist_has_tag_count, 2), -1, dtype=torch.int32), 0
artist_tags_index, artist_tags_helper_idx = torch.full((artist_tags_count, 2), -1, dtype=torch.int32), 0
track_has_tag_index, track_has_tag_helper_idx = torch.full((track_has_tag_count, 2), -1, dtype=torch.int32), 0
track_tags_index, track_tags_helper_idx = torch.full((track_tags_count, 2), -1, dtype=torch.int32), 0

# Relationship attributes: one scalar per edge, zero-initialised for the same
# reason as the node features.
collab_with_attributes = torch.zeros((collab_with_count, 1), dtype=torch.float16)
musically_related_to_attributes = torch.zeros((musically_related_to_count, 1), dtype=torch.float16)
personally_related_to_attributes = torch.zeros((personally_related_to_count, 1), dtype=torch.float16)
linked_to_attributes = torch.zeros((linked_to_count, 1), dtype=torch.float16)
last_fm_match_attributes = torch.zeros((last_fm_match_count, 1), dtype=torch.float16)

In [65]:
def add_relationship_to_dataset(
    id0: int,
    id1: int,
    rel_index: int,
    rel_tensor: torch.Tensor,
    attribute: Optional[float] = None,
    relationship_attributes: Optional[torch.Tensor] = None,
    missing_value: float = -1,
):
    """Write one directed edge, and optionally its scalar attribute, into slot ``rel_index``.

    Args:
        id0: Row index of the source node.
        id1: Row index of the target node.
        rel_index: Slot to write in ``rel_tensor`` (and ``relationship_attributes``).
        rel_tensor: ``(E, 2)`` edge buffer; row ``rel_index`` becomes ``[id0, id1]``.
        attribute: Edge attribute value. ``None`` (what the driver yields for a
            property the relationship lacks) is stored as ``missing_value``.
        relationship_attributes: Optional per-edge attribute buffer. When it
            is omitted, ``attribute`` is ignored.
        missing_value: Substitute for a ``None`` attribute, because a float16
            tensor cannot hold ``None``. Defaults to -1, the same sentinel the
            node features use for missing dates.
    """
    rel_tensor[rel_index, 0] = id0
    rel_tensor[rel_index, 1] = id1
    if relationship_attributes is not None:
        relationship_attributes[rel_index] = missing_value if attribute is None else attribute

In [66]:
def add_node_to_dataset(node: Node):

    def add_node(
        id: int,
        seen: set[int],
        attributes: list,
        attributes_tensor: torch.Tensor
    ):
        if id in seen:
            return
        seen.add(id)
        for j, attribute in enumerate(attributes):
            print(attribute)
            attributes_tensor[id, j] = attribute

    match node, "Artist" in node.labels, "Track" in node.labels, "Tag" in node.labels:

        # Artist
        case node, True, False, False:
            add_node(
                artist_map[node["main_id"]],
                artist_seen,
                [
                    node.get("begin_date", -1),
                    node.get("end_date", -1),
                    node["ended"],
                    *[node[f"gender_{j}"] for j in range(1, 6)],
                    node["popularity_scaled"],
                    *[node[f"type_{j}"] for j in range(1, 7)],
                    1
                ],
                artist_x
            )

        # Track
        case node, False, True, False:
            add_node(
                track_map[node["id"]],
                track_seen,
                [
                    node["popularity_scaled"],
                    node["year"],
                    node["month"] <= 6,
                    node["month"] > 6,
                    1
                ],
                track_x
            )
        
        # Tag
        case node, False, False, True:
            add_node(
                tag_map[node["id"]],
                tag_seen,
                [1],
                tag_x
            )
        
        case _: raise NotImplementedError(f"Unknown node: {node}")

### Main algorithm

In [67]:

query = """
    MATCH (a)-[r]->(b)
    return a, r, type(r) as r_type, b
    limit 5
    ;
"""
with driver.session() as session:
    records = session.run(query)  # type: ignore
    for i, record in enumerate(records):
        
        # Nodes
        add_node_to_dataset(record["a"])
        add_node_to_dataset(record["b"])
        
        # Relationships
        # TODO: Maybe change TAGS and HAS_TAGS to be different depending on the node type (Artist or Track), should make this loop faster
        match record["r_type"], record["a"], "Artist" in record["a"].labels, record["b"], "Artist" in record["b"].labels, record["r"]:

            # ('Artist', 'HAS_TAG', 'Tag')
            case "HAS_TAG", artist, True, tag, _, _:
                add_relationship_to_dataset(
                    artist_map[artist["main_id"]],
                    tag_map[tag["id"]],
                    artist_has_tag_helper_idx,
                    artist_has_tag_index,
                )
                artist_has_tag_helper_idx += 1

            # ('Tag', 'TAGS', 'Artist')
            case "TAGS", tag, _, artist, True, _:
                add_relationship_to_dataset(
                    tag_map[tag["id"]],
                    artist_map[artist["main_id"]],
                    artist_tags_helper_idx,
                    artist_tags_index
                )
                artist_tags_helper_idx += 1

            # ('Track', 'HAS_TAG', 'Tag')
            case "HAS_TAG", track, False, tag, _, _:
                add_relationship_to_dataset(
                    track_map[track["id"]],
                    tag_map[tag["id"]],
                    track_has_tag_helper_idx,
                    track_has_tag_index
                )
                track_has_tag_helper_idx += 1

            # ('Tag', 'TAGS', 'Track')
            case "TAGS", tag, _, track, False, _:
                add_relationship_to_dataset(
                    tag_map[tag["id"]],
                    track_map[track["id"]],
                    track_tags_helper_idx,
                    track_tags_index
                )
                track_tags_helper_idx += 1

            # ('Artist', 'WORKED_IN', 'Track')
            case "WORKED_IN", artist, _, track, _, _:
                add_relationship_to_dataset(
                    artist_map[artist["main_id"]],
                    track_map[track["id"]],
                    worked_in_helper_idx,
                    worked_in_index
                )
                worked_in_helper_idx += 1

            # ('Track', 'WORKED_BY', 'Artist')
            case "WORKED_BY", track, _, artist, _, _:
                add_relationship_to_dataset(
                    track_map[track["id"]],
                    artist_map[artist["main_id"]],
                    worked_by_helper_idx,
                    worked_by_index
                )
                worked_by_helper_idx += 1

            # ('Artist', 'COLLAB_WITH', 'Artist')
            case "COLLAB_WITH", artist0, _, artist1, _, r:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    collab_with_helper_idx,
                    collab_with_index,
                    r["count"],
                    collab_with_attributes
                )
                collab_with_helper_idx += 1

            # ('Artist', 'MUSICALLY_RELATED_TO', 'Artist')
            case "MUSICALLY_RELATED_TO", artist0, _, artist1, _, r:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    musically_related_to_helper_idx,
                    musically_related_to_index,
                    r["count"],
                    musically_related_to_attributes
                )
                musically_related_to_helper_idx += 1

            # ('Artist', 'PERSONALLY_RELATED_TO', 'Artist')
            case "PERSONALLY_RELATED_TO", artist0, _, artist1, _, r:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    personally_related_to_helper_idx,
                    personally_related_to_index,
                    r["count"],
                    personally_related_to_attributes
                )
                personally_related_to_helper_idx += 1

            # ('Artist', 'MUSICALLY_RELATED_TO', 'Artist')
            case "LINKED_TO", artist0, _, artist1, _, r:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    linked_to_helper_idx,
                    linked_to_index,
                    r["count"],
                    linked_to_attributes
                )
                linked_to_helper_idx += 1

            # ('Artist', 'LAST_FM_MATCH', 'Artist')
            case "LAST_FM_MATCH", artist0, _, artist1, _, r:
                add_relationship_to_dataset(
                    artist_map[artist0["main_id"]],
                    artist_map[artist1["main_id"]],
                    last_fm_match_helper_idx,
                    last_fm_match_index,
                    r["weight"],
                    last_fm_match_attributes,
                )
                last_fm_match_helper_idx += 1

            case _: raise NotImplementedError(f"Unknown case in match. Relationship type: {record["r_type"]}")
        

-1
-1
0.0
False
False
False
False
False
None


TypeError: can't assign a NoneType to a torch.HalfTensor

In [68]:
# Close the Neo4j driver created at the top of the notebook, releasing the
# connections it holds.
driver.close()