# Data Exploration

In [None]:
from fax.config import get_cfg_defaults
from fax.utils.constants import CHARACTER_ID_TO_NAME, STAGE_ID_TO_NAME
import pandas as pd
import numpy as np
from PIL import Image
import sqlite3
import matplotlib.pyplot as plt; plt.style.use('seaborn-v0_8')
import matplotlib.image as mpimg
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

In [None]:
# Project configuration object returned by get_cfg_defaults(); its
# `paths.sql` entry is passed to sqlite3.connect() further down.
cfg = get_cfg_defaults()

def average_color(image_paths):
    """Return the mean RGB color of a collection of images as a hex string.

    Every image is opened with PIL, converted to RGB and reduced to its
    per-channel mean. Those per-image means are then averaged with equal
    weight per image (not per pixel) and truncated to integers.

    Args:
        image_paths: Iterable of image paths accepted by ``PIL.Image.open``.
            Any iterable works, including a generator.

    Returns:
        A ``'#rrggbb'`` hex color string.

    Raises:
        ValueError: If ``image_paths`` contains no entries.
    """
    # Materialize once so generators are supported and can be counted.
    paths = list(image_paths)
    if not paths:
        # Dividing by zero would otherwise produce NaNs and a garbage color.
        raise ValueError("average_color() requires at least one image path")

    channel_sum = np.zeros(3)
    for path in paths:
        # The context manager releases the underlying file handle right away
        # instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            pixels = np.asarray(img.convert('RGB'))
        channel_sum += pixels.mean(axis=(0, 1))

    red, green, blue = (channel_sum / len(paths)).astype(int)
    return '#%02x%02x%02x' % (red, green, blue)

In [None]:
conn = sqlite3.connect(cfg.paths.sql)

# Count how often each character appears in bucket 0 (the XvX dataset),
# counting both player slots of every replay.
query_0f = """
SELECT char, COUNT(*) as count
FROM (
    SELECT p1char AS char FROM replays WHERE bucket = 0
    UNION ALL
    SELECT p2char AS char FROM replays WHERE bucket = 0
)
GROUP BY char
ORDER BY count DESC;
"""

df_0fc = pd.read_sql_query(query_0f, conn)

# map IDs → names
df_0fc["char"] = df_0fc["char"].map(CHARACTER_ID_TO_NAME)
df_0fc = df_0fc.sort_values(by="count", ascending=False).reset_index(drop=True)

# collapse into top 10 + "others"
# Only add the "others" row when there is something to collapse; otherwise
# the label would read "others (0)" or a negative number and an empty bar
# would be plotted.
top_n = 10
n_collapsed = len(df_0fc) - top_n
if n_collapsed > 0:
    others_row = pd.DataFrame(
        [{"char": f"others ({n_collapsed})", "count": df_0fc["count"].iloc[top_n:].sum()}]
    )
    df_0fc = pd.concat([df_0fc.head(top_n), others_row], ignore_index=True)

# add percentage and color columns
df_0fc['popularity'] = df_0fc['count'] / df_0fc['count'].sum() * 100
# Real characters take the average color of their icon; the aggregated
# "others" bar has no icon and gets a neutral grey.
df_0fc['color'] = df_0fc['char'].map(
    lambda char: '#cccccc' if "others" in char else average_color([f'../../data/icons/chars/{char}.png'])
)

In [None]:
# Bar chart of character popularity in the XvX dataset, one bar per row of
# df_0fc, colored with the precomputed per-character color.
fig, ax = plt.subplots()
ax.bar(
    df_0fc['char'],
    df_0fc['popularity'],
    color=df_0fc['color'],
    edgecolor='black',
    linewidth=1,
)
ax.set_xticks(df_0fc['char'])
ax.set_xticklabels(df_0fc['char'].str.lower(), rotation=45, ha='right')
ax.set_title("Character distribution in XvX dataset")
ax.set_ylabel("Popularity (%)")

# Draw each character's icon just above its bar. The frame's index label is
# used as the x position in data coordinates; the "others" bar has no icon.
for x_pos, name, pct in zip(df_0fc.index, df_0fc['char'], df_0fc['popularity']):
    if "others" in name:
        continue
    icon = OffsetImage(mpimg.imread(f'../../data/icons/chars/{name}.png'), zoom=1)
    ax.add_artist(
        AnnotationBbox(
            icon,
            (x_pos, pct + 1),        # anchor one unit above the bar top
            frameon=False,
            box_alignment=(0.5, 0),  # horizontally centered, bottom-anchored
        )
    )

# Leave headroom above the tallest bar so its icon is not clipped.
ax.set_ylim(0, df_0fc['popularity'].max() + 5)
fig.tight_layout()
fig.savefig("../../figs/char-dist-xvx.png", dpi=500)

In [None]:
# For every bucket-1 replay that has character id 2 in one of the player
# slots, pick the character in the other slot and count occurrences.
query_1f = """
SELECT opponent AS char, COUNT(*) AS count
FROM (
    SELECT CASE 
             WHEN p1char = 2 THEN p2char   -- Fox id = 2 in your CHARACTER_NAME_TO_ID
             WHEN p2char = 2 THEN p1char
           END AS opponent
    FROM replays
    WHERE bucket = 1
      AND (p1char = 2 OR p2char = 2)
) AS sub
GROUP BY opponent
ORDER BY count DESC;
"""

df_1fc = pd.read_sql_query(query_1f, conn)

# map IDs → names
df_1fc["char"] = df_1fc["char"].map(CHARACTER_ID_TO_NAME)
df_1fc = df_1fc.sort_values(by="count", ascending=False).reset_index(drop=True)

# collapse into top 10 + "others"
# Only add the "others" row when there is something to collapse; otherwise
# the label would read "others (0)" or a negative number and an empty bar
# would be plotted.
top_n = 10
n_collapsed = len(df_1fc) - top_n
if n_collapsed > 0:
    others_row = pd.DataFrame(
        [{"char": f"others ({n_collapsed})", "count": df_1fc["count"].iloc[top_n:].sum()}]
    )
    df_1fc = pd.concat([df_1fc.head(top_n), others_row], ignore_index=True)

# add percentage and color columns
df_1fc['popularity'] = df_1fc['count'] / df_1fc['count'].sum() * 100
# Real characters take the average color of their icon; the aggregated
# "others" bar has no icon and gets a neutral grey.
df_1fc['color'] = df_1fc['char'].map(
    lambda char: '#cccccc' if "others" in char else average_color([f'../../data/icons/chars/{char}.png'])
)

In [None]:
# Bar chart of opponent-character popularity in the XvF / FvX dataset, one
# bar per row of df_1fc, colored with the precomputed per-character color.
fig, ax = plt.subplots()
ax.bar(
    df_1fc['char'],
    df_1fc['popularity'],
    color=df_1fc['color'],
    edgecolor='black',
    linewidth=1,
)
ax.set_xticks(df_1fc['char'])
ax.set_xticklabels(df_1fc['char'].str.lower(), rotation=45, ha='right')
ax.set_title("Character distribution in XvF / FvX dataset")
ax.set_ylabel("Popularity (%)")

# Draw each character's icon just above its bar. The frame's index label is
# used as the x position in data coordinates; the "others" bar has no icon.
for x_pos, name, pct in zip(df_1fc.index, df_1fc['char'], df_1fc['popularity']):
    if "others" in name:
        continue
    icon = OffsetImage(mpimg.imread(f'../../data/icons/chars/{name}.png'), zoom=1)
    ax.add_artist(
        AnnotationBbox(
            icon,
            (x_pos, pct + 1),        # anchor one unit above the bar top
            frameon=False,
            box_alignment=(0.5, 0),  # horizontally centered, bottom-anchored
        )
    )

# Leave headroom above the tallest bar so its icon is not clipped.
ax.set_ylim(0, df_1fc['popularity'].max() + 5)
fig.tight_layout()
fig.savefig("../../figs/char-dist-xvf.png", dpi=500)

In [None]:
# Number of replays for every (bucket, stage) combination.
query = """
SELECT bucket, stage, COUNT(*) as count
FROM replays
GROUP BY bucket, stage
ORDER BY bucket, stage;
"""

df_stage = pd.read_sql_query(query, conn)

# Reshape to wide form: one row per bucket, one column per stage.
# Combinations that never occur are filled with 0 before casting to int.
counts_wide = df_stage.pivot(index="bucket", columns="stage", values="count")
counts_wide = counts_wide.fillna(0).astype(int)

# Replace numeric stage ids in the header with their names.
counts_wide.columns = counts_wide.columns.map(STAGE_ID_TO_NAME)

df_stage_pivot = counts_wide.reset_index()
df_stage_pivot

In [None]:
# Number of replays for every (bucket, stage) combination.
query = """
SELECT bucket, stage, COUNT(*) as count
FROM replays
GROUP BY bucket, stage
ORDER BY bucket, stage;
"""

df_stage = pd.read_sql_query(query, conn)
df_stage["stage"] = df_stage["stage"].map(STAGE_ID_TO_NAME)

# Compute each stage's icon color once, even though the same stage shows up
# in several buckets.
stage_colors = {
    stage: average_color([f'../../data/icons/stages/{stage}.png'])
    for stage in df_stage["stage"].unique()
}
df_stage["color"] = df_stage["stage"].map(stage_colors)

# Normalize within each bucket so the bars of every subplot sum to 100 %.
# (Scaling the global share by the number of buckets is only correct when
# all buckets contain exactly the same number of replays.)
bucket_totals = df_stage.groupby("bucket")["count"].transform("sum")
df_stage["popularity"] = df_stage["count"] / bucket_totals * 100

# One subplot per bucket, titled with the dataset name.
buckets = {0: "XvX", 1: "XvF / FvX", 2: "FvF"}

fig, axes = plt.subplots(1, len(buckets), figsize=(12, 6), sharey=True)

for ax, (bucket, title) in zip(axes, buckets.items()):
    df_b = df_stage[df_stage["bucket"] == bucket]
    ax.bar(df_b["stage"], df_b["popularity"], color=df_b["color"], edgecolor='black', linewidth=1)
    ax.set_title(title)
    ax.set_xticks(df_b["stage"])
    ax.set_xticklabels(df_b["stage"].map(lambda x: x.replace('_', ' ').lower()), rotation=45, ha='right')

# The y axis is shared, so only the left-most subplot gets a label.
axes[0].set_ylabel("Popularity (%)", fontsize=10)
fig.suptitle("Stage distribution per dataset", fontsize=16)
fig.tight_layout()
fig.savefig("../../figs/stage-dist.png", dpi=500)

In [None]:
# Share of bucket-1 replays (restricted to those with character id 2 in a
# player slot) where the slot holding character id 2 is the winner.
fox_winrate_sql = """
SELECT AVG(
    CASE
        WHEN p1char = 2 AND winner = 1 THEN 1
        WHEN p2char = 2 AND winner = 2 THEN 1
        ELSE 0
    END
) AS fox_winrate
FROM replays
WHERE bucket = 1
  AND (p1char = 2 OR p2char = 2);
"""

# Share of bucket-2 replays with winner = 1.
fvf_p1_winrate_sql = """
SELECT AVG(CASE WHEN winner = 1 THEN 1 ELSE 0 END) AS p1_winrate
FROM replays
WHERE bucket = 2;
"""

# Share of bucket-0 replays with winner = 1.
xvx_p1_winrate_sql = """
SELECT AVG(CASE WHEN winner = 1 THEN 1 ELSE 0 END) AS p1_winrate
FROM replays
WHERE bucket = 0;
"""

# Each query returns a single row with a single column; unpack it directly.
(fox_winrate,) = conn.execute(fox_winrate_sql).fetchone()
print(f"Fox winrate in XvF / FvX dataset: {fox_winrate:.2%}")

(fvf_p1_winrate,) = conn.execute(fvf_p1_winrate_sql).fetchone()
print(f"P1 winrate in FvF dataset: {fvf_p1_winrate:.2%}")

(xvx_p1_winrate,) = conn.execute(xvx_p1_winrate_sql).fetchone()
print(f"P1 winrate in XvX dataset: {xvx_p1_winrate:.2%}")