In [3]:
import networkx as nx
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from datetime import datetime
import time
import altair as alt
import math
import csv

In [4]:
# all pieces have separate nodes: every move of every game becomes its own
# node, tagged with the moved piece ("type") and its index in the game ("layer").
progressive_id = 0
max_sequences = 2
graph = nx.Graph()

# The context manager closes the CSV as soon as the rows have been read.
# newline="" is what the csv module asks for when it is given a file object.
# After this cell `input_file` is still bound, but its file is closed.
with open("data/chess/games.csv", newline="") as games_file:
    input_file = csv.DictReader(games_file)

    for rownum, row in enumerate(input_file):
        # split() without an argument drops empty tokens, so a blank or
        # double-spaced moves field can no longer make entry[0] fail.
        for i, entry in enumerate(row['moves'].split()):
            # Moves starting with an uppercase letter keep that letter as the
            # piece type; everything else is treated as a pawn move.
            piece_name = "pawn"
            if entry[0].isupper():
                piece_name = entry[0]
            graph.add_node("n" + str(progressive_id), type=piece_name, layer=i)
            progressive_id += 1

        # Checked after the row is processed: rows 0..max_sequences
        # (inclusive) are read before the loop stops.
        if rownum >= max_sequences: break

print(list(graph.nodes(data=True)))

[('n0', {'type': 'pawn', 'layer': 0}), ('n1', {'type': 'pawn', 'layer': 1}), ('n2', {'type': 'pawn', 'layer': 2}), ('n3', {'type': 'pawn', 'layer': 3}), ('n4', {'type': 'pawn', 'layer': 4}), ('n5', {'type': 'pawn', 'layer': 5}), ('n6', {'type': 'pawn', 'layer': 6}), ('n7', {'type': 'pawn', 'layer': 7}), ('n8', {'type': 'N', 'layer': 8}), ('n9', {'type': 'B', 'layer': 9}), ('n10', {'type': 'N', 'layer': 10}), ('n11', {'type': 'B', 'layer': 11}), ('n12', {'type': 'B', 'layer': 12}), ('n13', {'type': 'pawn', 'layer': 0}), ('n14', {'type': 'N', 'layer': 1}), ('n15', {'type': 'pawn', 'layer': 2}), ('n16', {'type': 'pawn', 'layer': 3}), ('n17', {'type': 'pawn', 'layer': 4}), ('n18', {'type': 'pawn', 'layer': 5}), ('n19', {'type': 'pawn', 'layer': 6}), ('n20', {'type': 'pawn', 'layer': 7}), ('n21', {'type': 'pawn', 'layer': 8}), ('n22', {'type': 'N', 'layer': 9}), ('n23', {'type': 'Q', 'layer': 10}), ('n24', {'type': 'N', 'layer': 11}), ('n25', {'type': 'Q', 'layer': 12}), ('n26', {'type': 'N

In [242]:
# Open the games CSV again for this cell. DictReader is lazy: rows are pulled
# from the file as the loop further down iterates over `input_file`, and the
# underlying file handle is not closed anywhere in this cell.
input_file = csv.DictReader(open("data/chess/games.csv"))

def get_piece_name(entry):
    """Return the name of the piece that made the move ``entry``.

    Only the first character of ``entry`` is inspected:

    * ``N``, ``B``, ``R``, ``Q``, ``K`` map to "knight", "bishop", "rook",
      "queen" and "king";
    * any other uppercase character is returned unchanged;
    * everything else is reported as "pawn".

    Raises IndexError when ``entry`` is an empty string.
    """
    piece_names = {
        "N": "knight",
        "B": "bishop",
        "R": "rook",
        "Q": "queen",
        "K": "king",
    }
    initial = entry[0]
    if initial in piece_names:
        return piece_names[initial]
    return initial if initial.isupper() else "pawn"

def find_node(piece_name, layer):
    """Return the nodes of the global ``graph`` matching a piece and a layer.

    Parameters:
        piece_name: value compared against each node's 'type' attribute.
        layer: value compared against each node's 'layer' attribute.

    Returns:
        A list of ``(node_id, attribute_dict)`` pairs, in the order the graph
        yields them; empty when nothing matches.
    """
    # Compare against the `layer` argument. The previous version read the
    # global loop variable `i`, so the argument was silently ignored.
    return [
        node
        for node in graph.nodes(data=True)
        if node[1]['layer'] == layer and node[1]['type'] == piece_name
    ]

def find_edge(source, target):
    """Return the edges of the global ``graph`` going from ``source`` to ``target``.

    Each result is a ``(source, target, attribute_dict)`` triple exactly as
    ``graph.edges(data=True)`` yields it; the list is empty when no edge has
    these endpoints in this order.
    """
    matches = []
    for edge in graph.edges(data=True):
        if edge[0] == source and edge[1] == target:
            matches.append(edge)
    return matches

def add_edge(source, target):
    """Count one more transition from ``source`` to ``target`` in the global ``graph``.

    If the edge does not exist yet it is created with ``weight=1``;
    otherwise its 'weight' attribute is incremented by one.

    The existence check uses ``graph.has_edge`` and the weight is updated
    through the adjacency view, so no scan of the whole edge list is needed.
    On a directed graph ``has_edge`` only matches the source -> target
    direction.
    """
    if graph.has_edge(source, target):
        graph[source][target]['weight'] += 1
    else:
        graph.add_edge(source, target, weight=1)


# same piece = same node: moves that share a piece type and a move index
# ("layer") are merged into one node, and consecutive moves of a game are
# linked by an edge whose weight counts how often that transition occurred.
progressive_id = 0
max_sequences = 100
max_layers = 10
graph = nx.DiGraph()

for rowcount, row in enumerate(input_file):
    prevnode = None

    # split() without an argument drops empty tokens, so a blank or
    # double-spaced moves field cannot pass an empty string to get_piece_name.
    for i, entry in enumerate(row['moves'].split()):
        piece_name = get_piece_name(entry)

        existing_nodes_same_features = find_node(piece_name, i)
        if existing_nodes_same_features:
            # Reuse the first node that already holds this (piece, layer) pair
            # and remember the raw move that was merged into it.
            current_node = existing_nodes_same_features[0]
            current_node[1]['original_entries'].append(entry)
        else:
            node_id = "n" + str(progressive_id)
            graph.add_node(node_id, type=piece_name, layer=i, original_entries=[entry])
            progressive_id += 1
            # Same (id, attribute dict) pair that graph.nodes(data=True) yields.
            current_node = (node_id, graph.nodes[node_id])

        if prevnode is not None:
            add_edge(prevnode[0], current_node[0])

        prevnode = current_node

        # Checked after the move is handled: layers 0..max_layers (inclusive)
        # are kept for each game.
        if i == max_layers: break

    # Checked after the game is handled: rows 0..max_sequences (inclusive)
    # are read before the loop stops.
    if rowcount >= max_sequences: break

# print(list(graph.nodes(data=True)))
# print(list(graph.edges(data=True)))

In [243]:
import plotly.graph_objects as go

# graph = read_chess_games(chosen_sequences=[1, 2, 5, 10], max_layers=10)
# print(graph)

# Nodes ordered by layer. A node's position in this list is the integer index
# the Sankey trace uses to refer to it.
nodelist = sorted(graph.nodes(data=True), key=lambda node: node[1]['layer'])

nodelabelarray = [node[1]['type'] for node in nodelist]
nodexposarray = [node[1]['layer'] for node in nodelist]
# Computed but currently not passed to the figure (the "y" entry below is
# commented out): pawns get 0, every other piece gets 1.
nodeyposarray = [0 if node[1]['type'] == "pawn" else 1 for node in nodelist]

# Node ids in display order, plus a lookup from id to position in that order.
nodelisttmp = [node[0] for node in nodelist]
node_position = {node_id: position for position, node_id in enumerate(nodelisttmp)}

# Edges ordered by the display position of their first endpoint.
edgelist = sorted(graph.edges(data=True), key=lambda edge: node_position[edge[0]])

edgesources = []
edgetargets = []
edgeweigths = []

for first, second in graph.edges():
    # Orient every link from the lower layer to the higher one.
    if graph.nodes[first]['layer'] < graph.nodes[second]['layer']:
        lower, upper = first, second
    else:
        lower, upper = second, first
    edgesources.append(node_position[lower])
    edgetargets.append(node_position[upper])
    edgeweigths.append(graph.edges[first, second]['weight'])

# NOTE(review): raw layer indices are passed as "x"; check the plotly Sankey
# reference for the value range node.x expects - TODO confirm.
node_spec = {
    "label": nodelabelarray,
    "x": nodexposarray,
    # "y": nodeyposarray,
    'pad': 10,  # 10 Pixels
}
link_spec = {
    "source": edgesources,
    "target": edgetargets,
    "value": edgeweigths,
}

fig = go.Figure(go.Sankey(arrangement="snap", node=node_spec, link=link_spec))

fig.show()

In [145]:
seed = "hello"

# Print, for every row still available from `input_file`, how many
# space-separated tokens its 'moves' field contains.
for rowcount, row in enumerate(input_file):
    move_tokens = row['moves'].split(" ")
    print(len(move_tokens))

75
135
81
46
70
39
91
137
46
18
28
40
111
54
33
90
90
40
49
50
59
26
23
118
66
120
36
26
62
37
24
47
37
24
9
47
85
39
58
95
2
97
12
30
43
11
72
156
18
77
32
35
85
50
111
65
73
125
76
65
41
59
70
34
69
63
107
58
34
100
80
44
40
105
50
144
81
69
77
158
107
117
44
40
59
110
30
54
69
178
136
124
96
53
39
10
51
18
106
38
59
111
28
101
15
42
118
34
60
66
21
73
159
74
28
112
104
40
44
150
60
37
88
113
55
135
137
97
33
29
55
91
53
49
48
20
87
11
110
24
23
81
34
103
29
22
38
46
23
31
51
87
31
38
43
105
87
36
38
110
48
95
70
42
177
114
57
57
82
29
70
97
99
74
48
77
80
35
33
87
40
54
67
21
14
71
72
62
48
6
69
72
40
29
44
59
115
70
36
83
99
43
89
96
116
78
77
70
27
102
21
33
74
65
81
45
52
51
48
20
50
98
106
19
22
114
69
34
40
108
27
63
31
71
92
47
56
48
124
49
64
29
45
96
92
40
123
129
120
71
97
60
57
38
111
119
33
32
96
92
64
10
85
40
58
26
4
118
11
163
95
37
101
80
78
89
42
8
101
65
60
71
44
20
41
40
27
59
19
50
34
15
52
65
67
31
130
125
76
121
53
116
47
53
76
110
67
17
39
94
37
37
34
41
30
51
