In [1]:
import json
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly
import numpy as np

In [2]:
def generate_pieces():
    '''Build the per-type piece ID lists and return them as
    (pawns, rooks, knights, bishops, kings, queens).

    Each ID is a one-letter type prefix followed by a 1-based number, e.g. 'p1'..'p16'
    for pawns, 'r1'..'r4' for rooks. Bishops use the prefix 's'.
    To be chained with populate_board and make_move_dict.
    '''
    def _ids(prefix, count):
        # e.g. _ids('k', 2) -> ['k1', 'k2']
        return [f'{prefix}{n}' for n in range(1, count + 1)]

    return (_ids('p', 16),
            _ids('r', 4),
            _ids('n', 4),
            _ids('s', 4),
            _ids('k', 2),
            _ids('q', 2))
    

In [3]:
def generate_piece_key():
    '''Map every piece ID and Sankey aggregate-node name to an integer node index.

    Indices follow the order pawns, rooks, knights, bishops, kings, queens (as
    returned by generate_pieces), then 'Pawns', 'Rooks', 'Knights', 'Bishops',
    'All Captures'. Used to replace string labels in capture DataFrames.
    '''
    aggregate_nodes = ['Pawns', 'Rooks', 'Knights', 'Bishops', 'All Captures']
    ordered_names = [piece for group in generate_pieces() for piece in group] + aggregate_nodes
    return {name: position for position, name in enumerate(ordered_names)}

In [4]:
def piece_labels():
    '''Return the Sankey node labels as a 37-element tuple, indexed like generate_piece_key.

    Indices 0-7 label the pawns A..H, 8-15 the pawns H..A, then come rooks,
    knights, bishops, kings, queens, and finally the five aggregate nodes.
    '''
    files = 'ABCDEFGH'
    first_pawns = [f'{letter} Pawn' for letter in files]
    second_pawns = [f'{letter} Pawn' for letter in reversed(files)]

    return tuple(
        first_pawns
        + second_pawns
        + ['A Rook', 'H Rook', 'H Rook', 'A Rook']
        + ['B Knight', 'G Knight', 'G Knight', 'B Knight']
        + ['Dark Bishop', 'Light Bishop', 'Light Bishop', 'Dark Bishop']
        + ['King', 'King']
        + ['Queen', 'Queen']
        + ['Pawns', 'Rooks', 'Knights', 'Bishops', 'All Captures']
    )


In [5]:
def df_prep(df, filter_list):
    '''Filter and reshape a capture/death DataFrame into long (Source, Target, Value) form.

    Keeps the rows of ``df`` whose labels are in ``filter_list``, drops any column
    that contains a NaN, then stacks the frame so each remaining cell becomes one
    row: Source = row label, Target = column label, Value = cell value.
    Used before Sankey and heatmap figure generation. Returns a new DataFrame.
    '''
    selected_rows = df.filter(items=filter_list, axis=0)
    long_series = selected_rows.dropna(axis=1).stack()

    return (
        pd.DataFrame(long_series)
        .reset_index()
        .rename(columns={'level_0': 'Source', 'level_1': 'Target', 0: 'Value'})
    )


In [6]:
def generate_sankey_plot(df, filter_list):
    '''Build a Sankey trace of capture totals for one colour's pieces.

    Rows of ``df`` whose labels are in ``filter_list`` are reshaped with df_prep,
    piece names are replaced by node indices from generate_piece_key, and the
    values are totalled per row piece. Each piece is linked from its piece-type
    node ('Pawns', 'Rooks', 'Knights', 'Bishops'); kings and queens are linked
    straight from 'All Captures'. Each piece-type node is in turn linked from
    'All Captures' with that type's total.

    Parameters
    ----------
    df : pandas.DataFrame
        Capture-query results whose row labels are piece IDs (see generate_pieces).
    filter_list : list of str
        Piece IDs (one colour) whose rows are kept.

    Returns
    -------
    plotly.graph_objects.Sankey
        The trace itself (not a Figure); wrap it in ``go.Figure`` to display it.
    '''
    labels = piece_labels()
    piece_key = generate_piece_key()

    stacked = df_prep(df, filter_list).replace(to_replace=piece_key)

    # Total per row piece; that piece becomes the link *target*. Only Value is
    # summed -- summing the stacked Target IDs would be meaningless.
    links = (stacked.groupby('Source')['Value'].sum()
             .reset_index()
             .rename(columns={'Source': 'Target'}))

    # Node-index ranges follow the ordering built by generate_piece_key.
    type_ranges = [
        (0, 16, piece_key['Pawns']),
        (16, 20, piece_key['Rooks']),
        (20, 24, piece_key['Knights']),
        (24, 28, piece_key['Bishops']),
        (28, 32, piece_key['All Captures']),  # kings and queens link to the total directly
    ]
    # Sentinel; every piece index (< 32) is overwritten by the loop below.
    links['Source'] = -1
    for low, high, node in type_ranges:
        in_range = (links['Target'] >= low) & (links['Target'] < high)
        links.loc[in_range, 'Source'] = node

    # One 'All Captures' -> piece-type link per type, built in a single concat
    # (DataFrame.append was removed in pandas 2.0).
    aggregate_rows = [
        {'Target': node,
         'Source': piece_key['All Captures'],
         'Value': links.loc[links['Source'] == node, 'Value'].sum()}
        for node in range(piece_key['Pawns'], piece_key['All Captures'])
    ]
    links = pd.concat([links, pd.DataFrame(aggregate_rows)], ignore_index=True)

    link = dict(source=list(links['Source']),
                target=list(links['Target']),
                value=list(links['Value']))
    node = dict(label=labels)

    return go.Sankey(link=link, node=node)

In [7]:
def generate_captures_heatmap(df, filter_list):
    '''Build a board-shaped heatmap of per-piece capture totals for one colour.

    The colour is chosen from ``filter_list``. If it contains 'k2', the black layout
    is used and the axes are reversed so black's back rank comes first. Otherwise,
    if it contains 'k1', the white layout is used. The first heatmap row holds the
    back-rank pieces' totals and the second row the pawns' totals (each piece's row
    of ``df`` summed via df_prep). The remaining six rows are filled with 1 so they
    stay neutral when later divided by a baseline.

    Parameters
    ----------
    df : pandas.DataFrame
        Capture-query results whose row labels are piece IDs.
    filter_list : list of str
        Piece IDs for one colour; must contain 'k1' (white) or 'k2' (black).

    Returns
    -------
    plotly.graph_objects.Heatmap

    Raises
    ------
    ValueError
        If ``filter_list`` contains neither 'k1' nor 'k2'.
    KeyError
        If one of the colour's pieces has no row in the filtered data.
    '''
    pawns, rooks, knights, bishops, kings, queens = generate_pieces()

    x = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    y = ['1', '2', '3', '4', '5', '6', '7', '8']

    # 'k2' takes precedence when both king IDs are present.
    if 'k2' in filter_list:
        pieces = ([rooks[2], knights[2], bishops[2], kings[1],
                   queens[1], bishops[3], knights[3], rooks[3]] + pawns[8:])
        x.reverse()
        y.reverse()
    elif 'k1' in filter_list:
        pieces = ([rooks[0], knights[0], bishops[0], queens[0],
                   kings[0], bishops[1], knights[1], rooks[1]] + pawns[:8])
    else:
        raise ValueError("filter_list must contain 'k1' (white) or 'k2' (black)")

    # Sum only the numeric Value column per piece (not the string Target column).
    totals = df_prep(df, filter_list).groupby('Source')['Value'].sum()

    counts = [totals[p] for p in pieces]
    hm_data = [counts[:8], counts[8:]]
    # Six independent filler rows rather than six references to one list.
    hm_data += [[1] * 8 for _ in range(6)]

    return go.Heatmap(z=hm_data, x=x, y=y, xgap=1, ygap=1)

In [8]:
def generate_death_heatmap(df, filter_list):
    '''Generate piece-by-piece heatmaps of where pieces go to die.

    Each piece's column of ``df`` is reshaped to an 8x8 board and divided by that
    column's total, so every cell is the fraction of the piece's deaths on that
    square. King boards are always drawn as all zeros on a fixed 0-1 colour scale.

    Parameters
    ----------
    df : pandas.DataFrame
        Death-query results with one column per piece ID and 64 rows (reshaped to
        8x8). Assumes the row order matches the white-side board orientation --
        TODO confirm against the query. For black the rows are reversed.
    filter_list : list of str
        Piece IDs for one colour; must contain 'k1' (white) or 'k2' (black).

    Returns
    -------
    plotly.graph_objects.Figure
        A 2 x 8 subplot figure: pawns on row 1 and back-rank pieces on row 2.

    Raises
    ------
    ValueError
        If ``filter_list`` contains neither 'k1' nor 'k2'.
    '''
    pawns, rooks, knights, bishops, kings, queens = generate_pieces()

    x = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H']
    y = ['1', '2', '3', '4', '5', '6', '7', '8']

    # 'k2' takes precedence when both king IDs are present.
    if 'k2' in filter_list:
        pieces = (pawns[8:] + [rooks[2], knights[2], bishops[2], kings[1],
                               queens[1], bishops[3], knights[3], rooks[3]])
        x.reverse()
        y.reverse()
        # Flip the board rows for black's point of view.
        df = df.reindex(index=df.index[::-1])
    elif 'k1' in filter_list:
        pieces = (pawns[:8] + [rooks[0], knights[0], bishops[0], queens[0],
                               kings[0], bishops[1], knights[1], rooks[1]])
    else:
        raise ValueError("filter_list must contain 'k1' (white) or 'k2' (black)")

    figures = {}
    for piece in pieces:
        if piece in kings:
            figures[piece] = go.Heatmap(z=np.zeros((8, 8)), x=x, y=y, xgap=1, ygap=1,
                                        colorscale='Blues', showscale=False, zmin=0, zmax=1)
            continue

        caps = df[piece].sum()
        board = np.array(df[piece]).reshape((8, 8))
        # `out` is required: with `where=` alone NumPy leaves the skipped (zero)
        # cells uninitialised, which would plot arbitrary memory contents.
        share = np.divide(board, caps, out=np.zeros(board.shape, dtype=float),
                          where=board != 0)
        figures[piece] = go.Heatmap(z=share, x=x, y=y, xgap=1, ygap=1,
                                    colorscale='Blues', showscale=False)

    fig = make_subplots(rows=2, cols=8)
    fig.add_traces([figures[p] for p in pieces[:8]], rows=1, cols=list(range(1, 9)))
    fig.add_traces([figures[p] for p in pieces[8:]], rows=2, cols=list(range(1, 9)))
    fig.update_layout(height=800, width=2400,
                      paper_bgcolor='#494952',
                      margin=dict(t=150),
                      title=dict(
                          text='Death of a Chess Piece',
                          font=dict(
                              family='Droid Sans',
                              size=36,
                              color='white'),
                          x=0.5, y=0.95),
                      font=dict(
                          family='Droid Sans',
                          color='white'
                      )
                      )

    return fig

In [9]:
def add_piece_img(sub_plot, piece_color='white'):
    '''Add chess-piece images above each column of a 2 x 8 death-heatmap figure.

    A pawn image is placed above each of the eight first-row columns (y=1.03).
    The back-rank piece images, in starting order for that colour, are placed at
    y=0.457. Positions are in paper coordinates and are tuned for a 2400 x 800 plot.

    Parameters
    ----------
    sub_plot : plotly.graph_objects.Figure
        Figure modified in place via ``add_layout_image``.
    piece_color : {'white', 'black'}
        Which image set to use. For black, king and queen are swapped, matching
        the black back-rank order in generate_death_heatmap.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``piece_color`` is neither 'white' nor 'black'.
    '''
    image_urls = {
        'white': {
            'pawn': 'https://upload.wikimedia.org/wikipedia/commons/0/04/Chess_plt60.png',
            'rook': 'https://upload.wikimedia.org/wikipedia/commons/5/5c/Chess_rlt60.png',
            'queen': 'https://upload.wikimedia.org/wikipedia/commons/4/49/Chess_qlt60.png',
            'knight': 'https://upload.wikimedia.org/wikipedia/commons/2/28/Chess_nlt60.png',
            'bishop': 'https://upload.wikimedia.org/wikipedia/commons/9/9b/Chess_blt60.png',
            'king': 'https://upload.wikimedia.org/wikipedia/commons/3/3b/Chess_klt60.png',
        },
        'black': {
            'pawn': 'https://upload.wikimedia.org/wikipedia/commons/c/cd/Chess_pdt60.png',
            'rook': 'https://upload.wikimedia.org/wikipedia/commons/a/a0/Chess_rdt60.png',
            'queen': 'https://upload.wikimedia.org/wikipedia/commons/a/af/Chess_qdt60.png',
            'knight': 'https://upload.wikimedia.org/wikipedia/commons/f/f1/Chess_ndt60.png',
            'bishop': 'https://upload.wikimedia.org/wikipedia/commons/8/81/Chess_bdt60.png',
            'king': 'https://upload.wikimedia.org/wikipedia/commons/e/e3/Chess_kdt60.png',
        },
    }
    back_rank_order = {
        'white': ['rook', 'knight', 'bishop', 'queen', 'king', 'bishop', 'knight', 'rook'],
        'black': ['rook', 'knight', 'bishop', 'king', 'queen', 'bishop', 'knight', 'rook'],
    }
    if piece_color not in image_urls:
        raise ValueError(f"piece_color must be 'white' or 'black', got {piece_color!r}")

    urls = image_urls[piece_color]
    # Pawns above the first subplot row, then the back rank above the second.
    image_rows = [
        (1.03, [urls['pawn']] * 8),
        (0.457, [urls[name] for name in back_rank_order[piece_color]]),
    ]
    for y_pos, sources in image_rows:
        for num, source in enumerate(sources):
            sub_plot.add_layout_image(
                source=source,
                xref='paper', yref='paper',
                x=0.0516 + (num * 0.128), y=y_pos,
                sizex=0.06, sizey=0.06,
                xanchor='center', yanchor='middle'
            )

    return None

In [10]:
pawns, rooks, knights, bishops, kings, queens = generate_pieces()

# White owns the first half of each piece list, black the second half.
w_filter = rooks[:2] + knights[:2] + bishops[:2] + queens[:1] + kings[:1] + pawns[:8]
b_filter = rooks[2:] + knights[2:] + bishops[2:] + kings[1:] + queens[1:] + pawns[8:]

In [None]:
# Opponent-capture heatmaps for white, one per Elo bucket, expressed relative to all games.
# Forward slashes work on Windows and POSIX alike; backslash separators only work on
# Windows, and sequences like '\c' are invalid escapes that newer Pythons warn about.
capture_paths = [
    'query results/cap_dict.txt',
    'query results/cap_dict_elosum_0-2500.txt',
    'query results/cap_dict_elosum_2500-3000.txt',
    'query results/cap_dict_elosum_3000-3500.txt',
    'query results/cap_dict_elosum_3500-4000.txt',
    'query results/cap_dict_elosum_4000-10000.txt',
]
overall, twentyfive, three, thirtyfive, four, ten = [
    generate_captures_heatmap(pd.read_json(path), w_filter) for path in capture_paths
]

row1 = [overall, twentyfive, three]
row2 = [thirtyfive, four, ten]

fig = make_subplots(rows=2, cols=3,
                    subplot_titles=('All Games', 'Elo 0-1250', 'Elo 1250-1500', 'Elo 1500-1750', 'Elo 1750-2000', 'Elo 2000+'))

# Divide every bucket by the all-games baseline (1.0 == same as overall).
# `out` is required: with `where=` alone NumPy leaves cells whose baseline is 0
# uninitialised; those cells get the neutral ratio 1.0 instead.
baseline = np.array(list(overall['z']), dtype=float)
for trace in row1 + row2:
    ratio = np.divide(np.array(list(trace['z']), dtype=float), baseline,
                      out=np.ones_like(baseline), where=baseline != 0)
    trace['z'] = tuple(ratio.tolist())

fig.add_traces(row1, rows=1, cols=[1, 2, 3])
fig.add_traces(row2, rows=2, cols=[1, 2, 3])

fig.update_traces(zmin=0.5, zmax=1.5, colorscale='RdBu', reversescale=True)
fig.update_layout(height=750, width=1000, title_text='Opponent Piece Captures Per Game Breakdown')

In [None]:
# White "Death of a Chess Piece" figures, one per Elo bucket ('All Games' is unfiltered).
# Forward slashes work on Windows and POSIX; backslash separators are Windows-only.
death_paths = [
    'query results/death_dict.txt',
    'query results/death_dict_elosum_0-2500.txt',
    'query results/death_dict_elosum_2500-3000.txt',
    'query results/death_dict_elosum_3000-3500.txt',
    'query results/death_dict_elosum_3500-4000.txt',
    'query results/death_dict_elosum_4000-10000.txt',
]
subplot_titles = ('All Games', 'Elo 0-1250', 'Elo 1250-1500', 'Elo 1500-1750', 'Elo 1750-2000', 'Elo 2000+')

death_list = [generate_death_heatmap(pd.read_json(path), w_filter) for path in death_paths]
death_oa, death_twentyfive, death_three, death_thirtyfive, death_four, death_max = death_list

source_note = dict(xref='paper', yref='paper',
                   x=-0.03, y=-0.14,
                   xanchor='left', yanchor='bottom',
                   text='<i>~4.7 million games analyzed<br>Source: https://database.lichess.org/standard/lichess_db_standard_rated_2016-01.pgn.bz2</i>',
                   showarrow=False,
                   font=dict(family='Droid Sans', size=14, color='white'))

for title, figure in zip(subplot_titles, death_list):
    add_piece_img(figure, 'white')
    # The unfiltered dataset is labelled plainly; Elo buckets are marked as averaged.
    subtitle = title if title == 'All Games' else f'Averaged {title}'
    figure.update_layout(annotations=[
        dict(xref='paper', yref='paper',
             xanchor='center', yanchor='top',
             x=0.535, y=1.15,
             showarrow=False,
             text=f'<i>{subtitle}</i>',
             font=dict(size=24)),
        source_note,
    ])

    # Only the first (top-left) trace shows a colour bar; ticks span its z range.
    max_val = figure.data[0]['z'].max()
    figure.update_traces(showscale=True,
                         row=1, col=1,
                         colorbar=dict(
                             title=dict(text='Relative Frequency of<br>  Capture on Square',
                                        font=dict(size=18),
                                        side='top'),
                             tickvals=[0, max_val / 2, max_val],
                             ticktext=['<i>None</i>', '<i>Occasional</i>', '<i>Most</i>'],
                             ticks='outside',
                             ticklabelposition='outside bottom',
                             tickfont=dict(size=16)))

death_oa

In [None]:
# Black "Death of a Chess Piece" figures, one per Elo bucket ('All Games' is unfiltered).
# Forward slashes work on Windows and POSIX; backslash separators are Windows-only.
b_death_paths = [
    'query results/death_dict.txt',
    'query results/death_dict_elosum_0-2500.txt',
    'query results/death_dict_elosum_2500-3000.txt',
    'query results/death_dict_elosum_3000-3500.txt',
    'query results/death_dict_elosum_3500-4000.txt',
    'query results/death_dict_elosum_4000-10000.txt',
]
subplot_titles = ('All Games', 'Elo 0-1250', 'Elo 1250-1500', 'Elo 1500-1750', 'Elo 1750-2000', 'Elo 2000+')

b_death_list = [generate_death_heatmap(pd.read_json(path), b_filter) for path in b_death_paths]
b_death_oa, b_death_twentyfive, b_death_three, b_death_thirtyfive, b_death_four, b_death_max = b_death_list

b_source_note = dict(xref='paper', yref='paper',
                     x=-0.03, y=-0.14,
                     xanchor='left', yanchor='bottom',
                     text='<i>~4.7 million games analyzed<br>Source: https://database.lichess.org/standard/lichess_db_standard_rated_2016-01.pgn.bz2</i>',
                     showarrow=False,
                     font=dict(family='Droid Sans', size=14, color='white'))

for title, figure in zip(subplot_titles, b_death_list):
    add_piece_img(figure, 'black')
    # The unfiltered dataset is labelled plainly; Elo buckets are marked as averaged.
    subtitle = title if title == 'All Games' else f'Averaged {title}'
    figure.update_layout(annotations=[
        dict(xref='paper', yref='paper',
             xanchor='center', yanchor='top',
             x=0.535, y=1.15,
             showarrow=False,
             text=f'<i>{subtitle}</i>',
             font=dict(size=24)),
        b_source_note,
    ])

    # Only the first (top-left) trace shows a colour bar; ticks span its z range.
    max_val = figure.data[0]['z'].max()
    figure.update_traces(showscale=True,
                         row=1, col=1,
                         colorbar=dict(
                             title=dict(text='Relative Frequency of<br>  Capture on Square',
                                        font=dict(size=18),
                                        side='top'),
                             tickvals=[0, max_val / 2, max_val],
                             ticktext=['<i>None</i>', '<i>Occasional</i>', '<i>Most</i>'],
                             ticks='outside',
                             ticklabelposition='outside bottom',
                             tickfont=dict(size=16)))

b_death_oa


In [111]:
# Black back-rank IDs, in the order generate_death_heatmap places them
# (b_death_oa.data[8:] holds their heatmaps in this order).
pieces = [rooks[2], knights[2], bishops[2], kings[1],
          queens[1], bishops[3], knights[3], rooks[3]]

new_fig = go.Figure()
for name, trace in zip(pieces, b_death_oa.data[8:]):
    # np.sort returns a sorted copy. The z array plotly stores on a trace may be
    # read-only, so sorting a view of it in place is unsafe.
    shares = np.sort(np.asarray(trace['z']).reshape(64))
    # Largest squares first, so the curve shows how concentrated deaths are.
    new_fig.add_scatter(y=shares[::-1].cumsum(), x=list(range(1, 65)), name=name)

new_fig.update_layout(title_text='Cumulative Sum of Capture Frequencies, All Games')

In [112]:
# Display the cumulative-sum figure built in the previous cell.
new_fig

In [109]:
# Black back-rank IDs, in the order generate_death_heatmap places them
# (figure.data[8:] of each black death figure holds their heatmaps in this order).
pieces = [rooks[2], knights[2], bishops[2], kings[1],
          queens[1], bishops[3], knights[3], rooks[3]]

color_list = ['black', 'blue', 'maroon', 'brown', 'blueviolet', 'red', 'navy', 'silver']

new_fig = make_subplots(rows=2, cols=3, subplot_titles=('All Games', 'Elo 0-1250', 'Elo 1250-1500', 'Elo 1500-1750', 'Elo 1750-2000', 'Elo 2000+'))

for fig_index, figure in enumerate(b_death_list):
    row, col = divmod(fig_index, 3)
    for name, color, trace in zip(pieces, color_list, figure.data[8:]):
        # Sort a copy; plotly may store trace z arrays read-only.
        shares = np.sort(np.asarray(trace['z']).reshape(64))
        new_fig.add_trace(
            go.Scatter(x=list(range(1, 65)), y=shares[::-1].cumsum(),
                       name=name, line_color=color),
            row=row + 1, col=col + 1)

new_fig.update_layout(title_text='Cumulative Sum of Capture Frequencies')

In [110]:
# One legend for the whole grid: hide every entry, then re-show only the
# top-left subplot's traces (one per piece, same colours in every subplot).
new_fig.update_traces(showlegend=False).update_traces(showlegend=True, row=1, col=1)

In [85]:
# Scratch cell: `new` was only ever defined in a since-deleted cell (its recorded
# output is an empty Scatter), so Restart & Run All raised NameError here.
# Define it explicitly; this cell can be removed from the final notebook.
new = go.Scatter()
new

Scatter()

In [89]:
# Scratch: inspect go.Scatter's constructor arguments while developing; not part of the analysis.
help(go.Scatter)

Help on class Scatter in module plotly.graph_objs._scatter:

class Scatter(plotly.basedatatypes.BaseTraceType)
 |  Scatter(arg=None, cliponaxis=None, connectgaps=None, customdata=None, customdatasrc=None, dx=None, dy=None, error_x=None, error_y=None, fill=None, fillcolor=None, groupnorm=None, hoverinfo=None, hoverinfosrc=None, hoverlabel=None, hoveron=None, hovertemplate=None, hovertemplatesrc=None, hovertext=None, hovertextsrc=None, ids=None, idssrc=None, legendgroup=None, line=None, marker=None, meta=None, metasrc=None, mode=None, name=None, opacity=None, orientation=None, r=None, rsrc=None, selected=None, selectedpoints=None, showlegend=None, stackgaps=None, stackgroup=None, stream=None, t=None, text=None, textfont=None, textposition=None, textpositionsrc=None, textsrc=None, texttemplate=None, texttemplatesrc=None, tsrc=None, uid=None, uirevision=None, unselected=None, visible=None, x=None, x0=None, xaxis=None, xcalendar=None, xperiod=None, xperiod0=None, xperiodalignment=None, xsrc