In [26]:
import time
import scipy
import numpy as np
import pandas as pd
import winsound

import networkx as nx
import matplotlib

from sklearn.preprocessing import LabelEncoder
import torch

from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, GATConv, Linear, to_hetero
import torch_geometric.nn.functional as F

from HeteroDataFunctions import Encoder, add_types, complete_graph, flatten_lol, node_cat_dict, midi_type

print(scipy.__version__)
print(matplotlib.__version__)
print(nx.__version__)

1.7.3
3.6.2
2.8.4


In [27]:
# Complete Dataset
G = complete_graph(".\slac\embeddings\\all")

loading edgelists...
- notes.edgelist
- program.edgelist
- tempo.edgelist
- time.signature.edgelist
Nodes: 93553
Edges: 786635


In [28]:
nodes = pd.DataFrame((list(G.nodes)), columns=['name'])
edges = pd.DataFrame(np.array(list(G.edges)), columns=['source', 'target'])

In [29]:
node_categories = node_cat_dict(nodes)
node_categories.keys()

node_cat_dict took 0.17 secs to run


dict_keys(['note_group', 'pitch', 'program', 'MIDI', 'duration', 'velocity', 'time_sig', 'tempo'])

In [30]:
%%script false --no- raise -error

nodes_df_complete, edges_df_complete = add_types(nodes, edges, node_categories)

winsound.Beep(400, 700)

nodes_df_complete.to_csv('nodes_complete.csv')
edges_df_complete.to_csv('edges_complete.csv')

Couldn't find program: 'false'


In [140]:
nodes_df_complete = pd.read_csv('.\slac\Contents of Slac\\nodes_complete.csv')
edges_df_complete = pd.read_csv('.\slac\Contents of Slac\edges_complete.csv')
print('Done')

Done


In [141]:
edge_types = ["MIDI__has__tempo",
                   "MIDI__in__time_sig",
                   "MIDI__has__program",
                   "MIDI__has__note_group",
                   "note_group__has__velocity",
                   "note_group__has__duration",
                   "note_group__contains__pitch"]

In [142]:
full_categories = node_categories.copy()
full_categories['node_types'] = list(node_categories.keys())
full_categories['edge_types'] = edge_types  # Dictionary containing every string that may be found in our Dataframes
names_list_full = flatten_lol(full_categories.values())

len(names_list_full) == len(set(names_list_full))

True

In [143]:
encoder = Encoder(names_list_full, n_labels=5)

encoder.decode_value(5)

'g1601074'

In [144]:
node_types = set(nodes_df_complete['node_type'])
node_types

{'MIDI',
 'duration',
 'note_group',
 'pitch',
 'program',
 'tempo',
 'time_sig',
 'velocity'}

In [145]:
input_node_dict = {node_type: {'x': encoder.
                    encode_nodes(nodes_df_complete.
                    loc[nodes_df_complete['node_type'] == node_type, ['name']])}
                    for node_type in node_types}

encode_nodes took 0.01 secs to run
encode_nodes took 0.00 secs to run
encode_nodes took 0.00 secs to run
encode_nodes took 0.00 secs to run
encode_nodes took 0.00 secs to run
encode_nodes took 0.02 secs to run
encode_nodes took 3.14 secs to run
encode_nodes took 0.00 secs to run


In [146]:
edge_types

['MIDI__has__tempo',
 'MIDI__in__time_sig',
 'MIDI__has__program',
 'MIDI__has__note_group',
 'note_group__has__velocity',
 'note_group__has__duration',
 'note_group__contains__pitch']

In [147]:
edges_df_complete.loc[edges_df_complete['edge_type'] == 'MIDI__has__tempo', ['source', 'target']]

Unnamed: 0,source,target
936,Blues_-_Modern-Albert_King_-_Born_Under_A_Bad_...,11
571698,Blues_-_Modern-B_B_King_-_How_Blue_Can_You_Get,6
658833,Blues_-_Modern-B_B_King_-_Rock_Me_Baby,9
662826,Blues_-_Modern-B_B_King_-_The_Thrill_Is_Gone,9
666395,Blues_-_Modern-Buddy_Guy_-_Don't_Answer_the_Door,5
...,...,...
786125,Rock_-_Metal-Rage_Against_the_Machine_-_Bulls_...,8
786237,Rock_-_Metal-Rage_Against_the_Machine_-_Gueril...,11
786347,Rock_-_Metal-Rage_Against_the_Machine_-_Killin...,12
786441,Rock_-_Metal-Rage_Against_the_Machine_-_Know_Y...,8


In [148]:
input_edge_dict = {edge_type: {'edge_index': encoder.encode_edges(edges_df_complete.loc[
                    edges_df_complete['edge_type'] == edge_type, ['source', 'target']])} for edge_type in edge_types}

encode_edges took 0.02 secs to run
encode_edges took 0.02 secs to run
encode_edges took 0.09 secs to run
encode_edges took 8.36 secs to run
encode_edges took 7.25 secs to run
encode_edges took 5.60 secs to run
encode_edges took 26.60 secs to run


In [149]:
for key in input_edge_dict.keys():
    dim1 = input_edge_dict[key]['edge_index'].shape[0]
    input_edge_dict[key]['edge_index'] = input_edge_dict[key]['edge_index'].reshape(-1, dim1)

In [150]:
# Extract the label of each Midi.
midi_val = nodes_df_complete.loc[nodes_df_complete['node_type'] == 'MIDI', ['name']].values
midi_class = [midi_type(s[0]) for s in midi_val]

lb = LabelEncoder()
y = torch.from_numpy(lb.fit_transform(midi_class))

lb.classes_

array(['Blues', 'Classical', 'Jazz', 'Rap', 'Rock'], dtype='<U9')

In [151]:
input_node_dict['MIDI']['y'] = y

In [152]:
H = HeteroData(input_node_dict, **input_edge_dict)
H['MIDI', 'has', 'tempo']

{'edge_index': tensor([[92762, 93541, 92801, 93549, 92790, 93545, 92888, 93545, 92783, 93542,
         92714, 93537, 92692, 93545, 92886, 93545, 92805, 93542, 92875, 93542,
         92710, 93539, 92766, 93549, 92905, 93535, 92760, 93557, 92730, 93543,
         92785, 93535, 92939, 93540, 92738, 93540, 92774, 93545, 92919, 93541,
         92814, 93540, 92836, 93556, 92928, 93539, 92789, 93554, 92797, 93535,
         92757, 93535, 92731, 93554, 92866, 93545, 92834, 93535, 92773, 93555,
         92898, 93555, 92713, 93555, 92896, 93546, 92907, 93543, 92744, 93535,
         92852, 93546, 92817, 93535, 92846, 93555, 92703, 93545, 92761, 93545,
         92911, 93554, 92860, 93538, 92763, 93555, 92903, 93535, 92858, 93555,
         92767, 93541, 92829, 93535, 92853, 93555, 92862, 93555, 92930, 93555,
         92882, 93555, 92743, 93557, 92802, 93557, 92927, 93555, 92709, 93557,
         92865, 93555, 92804, 93543, 92742, 93543, 92782, 93538, 92818, 93541,
         92741, 93541, 92778, 93555, 

In [153]:
write_from = torch.tensor(np.random.choice(10, 50, replace=True))
write_to = torch.tensor(np.random.choice(20, 50, replace=True))
write = torch.concat((write_from, write_to)).reshape(-1,50).shape
write

torch.Size([2, 50])

In [154]:
input_edge_dict['MIDI__has__tempo']['edge_index']

tensor([[92762, 93541, 92801, 93549, 92790, 93545, 92888, 93545, 92783, 93542,
         92714, 93537, 92692, 93545, 92886, 93545, 92805, 93542, 92875, 93542,
         92710, 93539, 92766, 93549, 92905, 93535, 92760, 93557, 92730, 93543,
         92785, 93535, 92939, 93540, 92738, 93540, 92774, 93545, 92919, 93541,
         92814, 93540, 92836, 93556, 92928, 93539, 92789, 93554, 92797, 93535,
         92757, 93535, 92731, 93554, 92866, 93545, 92834, 93535, 92773, 93555,
         92898, 93555, 92713, 93555, 92896, 93546, 92907, 93543, 92744, 93535,
         92852, 93546, 92817, 93535, 92846, 93555, 92703, 93545, 92761, 93545,
         92911, 93554, 92860, 93538, 92763, 93555, 92903, 93535, 92858, 93555,
         92767, 93541, 92829, 93535, 92853, 93555, 92862, 93555, 92930, 93555,
         92882, 93555, 92743, 93557, 92802, 93557, 92927, 93555, 92709, 93557,
         92865, 93555, 92804, 93543, 92742, 93543, 92782, 93538, 92818, 93541,
         92741, 93541, 92778, 93555, 92807, 93535, 9

In [155]:
H.edge_stores

[{'edge_index': tensor([[92762, 93541, 92801, 93549, 92790, 93545, 92888, 93545, 92783, 93542,
          92714, 93537, 92692, 93545, 92886, 93545, 92805, 93542, 92875, 93542,
          92710, 93539, 92766, 93549, 92905, 93535, 92760, 93557, 92730, 93543,
          92785, 93535, 92939, 93540, 92738, 93540, 92774, 93545, 92919, 93541,
          92814, 93540, 92836, 93556, 92928, 93539, 92789, 93554, 92797, 93535,
          92757, 93535, 92731, 93554, 92866, 93545, 92834, 93535, 92773, 93555,
          92898, 93555, 92713, 93555, 92896, 93546, 92907, 93543, 92744, 93535,
          92852, 93546, 92817, 93535, 92846, 93555, 92703, 93545, 92761, 93545,
          92911, 93554, 92860, 93538, 92763, 93555, 92903, 93535, 92858, 93555,
          92767, 93541, 92829, 93535, 92853, 93555, 92862, 93555, 92930, 93555,
          92882, 93555, 92743, 93557, 92802, 93557, 92927, 93555, 92709, 93557,
          92865, 93555, 92804, 93543, 92742, 93543, 92782, 93538, 92818, 93541,
          92741, 93541, 9

In [156]:
H_und = T.ToUndirected()(H)
H_und = T.NormalizeFeatures()(H_und)

RuntimeError: result type Float can't be cast to the desired output type __int64

# GNN

In [136]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

model = GNN(hidden_channels=64, out_channels=len(lb.classes_))
model = to_hetero(model, H_und.metadata(), aggr='sum')

In [137]:
optimizer_name = "Adam"
lr = 1e-1
optimizer = getattr(torch.optim, optimizer_name)(model.parameters(), lr=lr)


In [138]:
def train():
    model.train()
    optimizer.zero_grad()
    out = model(H_und.x_dict, H_und.edge_index_dict)
    mask = H_und['MIDI'].train_mask
    loss = F.cross_entropy(out['paper'][mask], H_und['MIDI'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)

In [139]:
train()

ValueError: Found indices in 'edge_index' that are larger than 249 (got 93557). Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 250) in your node feature matrix and try again.

 # Old Implementation

In [None]:
# nodes_ten_ = encoder.encode_nodes(nodes_df_complete)
# edges_ten_ = encoder.encode_edges(edges_df_complete)

# node_type_ = nodes_df_complete.iloc[:, 1]

# Get the source and target indices from the edges tensor
# edge_index = edges_ten_[:, :2]

## Get the edge types from the edges tensor
#edge_type_ = edges_df_complete.iloc[:, 2]

#full_hetero_graph = HeteroData(x=nodes_ten_, node_type=node_type_, edge_index=edge_index, edge_type=edge_type_)