In [1]:
import networkx as nx
import pandas as pd
import numpy as np
%matplotlib inline

# Load csv data into pandas dataframe

In [3]:
csv_path = '/Users/Alexander/Desktop/MagnaTagATune metriclearning/MagnaTagATune/comparisons_final.csv'
df = pd.read_csv(csv_path, '\t')
df

Unnamed: 0,clip1_id,clip2_id,clip3_id,clip1_numvotes,clip2_numvotes,clip3_numvotes,clip1_mp3_path,clip2_mp3_path,clip3_mp3_path
0,42344,52148,53079,3,1,6,8/jacob_heringman-josquin_des_prez_lute_settin...,4/tim_rayborn-the_path_beyond-14-taqsim_ud-175...,9/the_wretch-ambulatory-15-release-146-175.mp3
1,44925,17654,56325,2,0,0,0/american_bach_soloists-j_s__bach__cantatas_v...,8/hybris-the_first_words-04-final_trust-146-17...,9/the_strap_ons-geeking_crime-20-pimps-59-88.mp3
2,25699,2619,15184,0,0,2,b/philharmonia_baroque-beethoven_symphonies_no...,4/jay_kishor-the_color_of_night-01-malkauns-10...,4/seth_carlin-schubert__works_for_solo_fortepi...
3,2308,57629,44657,0,0,2,0/american_bach_soloists-joseph_haydn__masses-...,c/magnatune-classical-24-la_primavera_robert_j...,f/ehren_starks-lines_build_walls-10-tunnel_sys...
4,45324,3858,13497,15,14,6,5/burnshee_thornside-rock_this_moon-11-city_gi...,9/sitar-cd1_the_sowebo_concert-01-raga_maru_bi...,6/electric_frankenstein-conquers_the_world-03-...
...,...,...,...,...,...,...,...,...,...
528,50394,35897,8457,1,0,0,e/joram-moments_of_clarity-13-plenilune-262-29...,6/farallon_recorder_quartet-ludwig_senfl-08-in...,6/mercy_machine-in_your_bed-02-my_joan_of_arc-...
529,12088,8446,34711,3,1,1,1/dac_crowell-sferica-03-chapel_hill_phantom_l...,8/mercy_machine-mercy_machine-02-my_fathers_ha...,c/liquid_zen-oscilloscope-08-autumn_glide-59-8...
530,55772,20576,26478,5,4,0,3/jag-four_strings-19-helena_street_corner_blu...,1/spinecar-passive_aggressive-04-true-262-291.mp3,f/strojovna_07-switch_on__switch_off-06-crysta...
531,27594,40264,26134,0,1,1,f/paul_berget-sl_weiss_on_11_strings-06-linfid...,2/duo_chambure-vihuela_duets_of_valderrabano-0...,1/artemis-undone-06-beside_u-59-88.mp3


# Calculate statistics

In [4]:
print('total number of votes: ', df['clip1_numvotes'].sum() + df['clip2_numvotes'].sum() + df['clip3_numvotes'].sum())
print('average number of votes per triplet: ', df[['clip1_numvotes', 'clip2_numvotes', 'clip3_numvotes']].sum(axis=1).mean())
print(f'number of unique triplets: ', len({tuple(sorted(ids)) for _, ids in df.iloc[:, :3].iterrows()}))

total number of votes:  7650
average number of votes per triplet:  14.352720450281426
number of unique triplets:  346


# Create directed graph from dataframe

In [5]:
def create_graph(df):
    graph = nx.DiGraph()

    for _, row in df.iterrows():
        clip_ids, clip_votes, clip_paths = row[:3].values, row[3:6].values, row[6:].values
        idx = range(len(clip_votes))
        for i in idx:
            votes = clip_votes[i]
            if votes > 0:
                odd_one_out_id = clip_ids[i]
                other_idx = np.setdiff1d(idx, [i])
                node1 = tuple(sorted(clip_ids[other_idx]))
                node2 = tuple(sorted([clip_ids[other_idx][0], odd_one_out_id]))
                node3 = tuple(sorted([clip_ids[other_idx][1], odd_one_out_id]))

                # Find existing edge
                ed = graph.get_edge_data(node1, node2)
                if ed is not None:
                    edge_weight = ed['weight']
                    graph[node1][node2]['weight'] += votes
                else:
                    graph.add_edge(node1, node2, weight=votes)

                # Find existing edge
                ed = graph.get_edge_data(node1, node3)
                if ed is not None:
                    edge_weight = ed['weight']
                    graph[node1][node3]['weight'] += votes
                else:
                    graph.add_edge(node1, node3, weight=votes)
    return graph

graph = create_graph(df)

# Remove length-2 cycles from graph

In [6]:
def remove_inconsistencies(graph):
    count = 0
    weight_points = 0
    for node in graph:
        to_remove = []
        for (u,v,d) in graph.edges(node, data=True):
            if graph.has_edge(v,u):
                weight = d['weight']
                weight_rev = graph.get_edge_data(v, u)['weight']
                
                # If contradicting edges have equal votes, remove both
                if weight == weight_rev:
                    to_remove.append((u, v))
                    to_remove.append((v, u))
                    count += 2
                    weight_points += weight * 2
                    
                elif weight > weight_rev:
                    to_remove.append((v, u))
                    graph[u][v]['weight'] = weight - weight_rev
                    count += 1
                    weight_points += 2* weight_rev
                elif weight < weight_rev:
                    to_remove.append((v, u))
                    graph[v][u]['weight'] = weight_rev - weight
                    count += 1
                    weight_points += 2 * weight
        graph.remove_edges_from(to_remove)
                    
    print(f'removed {count} inconsistent edges.')
    print(f'removed {weight_points} weight points.') 

remove_inconsistencies(graph)

removed 738 inconsistent edges.
removed 8402 weight points.


# Remove isolated vertices

In [7]:
isolates = list(nx.isolates(graph))
graph.remove_nodes_from(isolates)
print(f'removed {len(isolates)} isolated nodes')

removed 27 isolated nodes


# Calculate number of subgraphs & edges in resulting graph

In [17]:
subgraphs = list(nx.weakly_connected_components(graph))
sg_sizes = set([len(sg) for sg in subgraphs])
referenced_clips = {clip_id for node in graph.nodes() for clip_id in node}
print(f'graph consists of {len(subgraphs)} disjoint subgraphs containing {sg_sizes} vertices each')
print(f'total graph contains {len(graph.edges)} edges/triplet constraints')
print('number of referenced clips: ', len(referenced_clips))

graph consists of 337 disjoint subgraphs containing {3} vertices each
total graph contains 860 edges/triplet constraints
number of referenced clips:  993


# Split into 10 disjoint sets of subgraphs, for k-fold CV

In [11]:
array = np.array(subgraphs)
np.random.shuffle(array)
splits = np.array_split(array, 10)
split_nodes = lambda split: [node for sg in split for node in sg]
split_edges = lambda split: [edges for node in split_nodes(split) for edges in graph.edges(node)]
constraints_per_split = [len(split_edges(split)) for split in splits]
print('average number of constraints per split: ', np.array(constraints_per_split).mean())

average number of constraints per split:  86.0


In [15]:
from sys import getsizeof
getsizeof(array)
splits[0]

array([{(8446, 34711), (8446, 12088), (12088, 34711)},
       {(6717, 30974), (21752, 30974), (6717, 21752)},
       {(26836, 28405), (28405, 44292), (26836, 44292)},
       {(25776, 29873), (21993, 25776), (21993, 29873)},
       {(37428, 55029), (37428, 49015), (49015, 55029)},
       {(5783, 40683), (5783, 50307), (40683, 50307)},
       {(26497, 37286), (37286, 39679), (26497, 39679)},
       {(6959, 19598), (6959, 53758), (19598, 53758)},
       {(39938, 41134), (158, 39938), (158, 41134)},
       {(36147, 50477), (15309, 50477), (15309, 36147)},
       {(9194, 23808), (4975, 23808), (4975, 9194)},
       {(23020, 35492), (35492, 42331), (23020, 42331)},
       {(227, 45936), (5881, 45936), (227, 5881)},
       {(3914, 23780), (3914, 11633), (11633, 23780)},
       {(38341, 39174), (24125, 38341), (24125, 39174)},
       {(4642, 6716), (4642, 16862), (6716, 16862)},
       {(37173, 48228), (25211, 48228), (25211, 37173)},
       {(8725, 37140), (37140, 58843), (8725, 58843)},
    

In [18]:
def triplet_from_edge(edge):
    all_ids = tuple({clip_id for node in edge for clip_id in node})
    odd_one_out_id = np.intersect1d(edge[0], edge[1])
    similar_pair = np.setdiff1d(all_ids, odd_one_out_id)
    return similar_pair, odd_one_out_id

In [22]:
(a, b), odd_one_out = triplet_from_edge(((6959, 19598), (19598, 53758)))
# a, b = similar_pair
b

53758

In [177]:
[triplet_from_edge(edge) for edge in split_edges(splits[0])]

[(array([21950, 37093]), array([43565])),
 (array([21950, 43565]), array([37093])),
 (array([37093, 43565]), array([21950])),
 (array([ 9938, 20536]), array([46254])),
 (array([20536, 46254]), array([9938])),
 (array([ 4316, 46294]), array([9356])),
 (array([ 9356, 46294]), array([4316])),
 (array([ 6848, 10757]), array([29305])),
 (array([10757, 29305]), array([6848])),
 (array([ 6848, 29305]), array([10757])),
 (array([ 2139, 49609]), array([45535])),
 (array([ 2139, 45535]), array([49609])),
 (array([ 7035, 43358]), array([18230])),
 (array([18230, 43358]), array([7035])),
 (array([ 7035, 18230]), array([43358])),
 (array([ 3450, 42896]), array([48198])),
 (array([ 3450, 48198]), array([42896])),
 (array([42896, 48198]), array([3450])),
 (array([21334, 38057]), array([50779])),
 (array([21334, 50779]), array([38057])),
 (array([36572, 43201]), array([52271])),
 (array([43201, 52271]), array([36572])),
 (array([36572, 52271]), array([43201])),
 (array([18221, 58717]), array([37812]))