In [4]:
import duckdb
import matplotlib.pyplot as plt
import json
from create_database import (
    add_initial_tables,
    add_comments_to_comments_tables,
    create_lookup_table,
    create_subreddit_tables,
    create_threads_table,
    sort_key,
)
from stats import (
    create_row_counts_table,
    get_depth_distribution,
    get_thread_score_distribution,
    get_subreddit_distribution,
    table_stats,
    calculate_weighted_average,
    get_thread_lengths,
)
from filter_database import make_threads_unique, filter_threads

# JSON file holding the precomputed statistics that the plotting cells read.
saved_stats = "../data/saved_stats.json"

# Direct database access is disabled here; uncomment to query the DuckDB file instead.
# db = "../data_scc/database_subset10.db"
# con = duckdb.connect(db)

In [5]:
# Subreddits that get their own per-community series in the comparison plots.
_SUBREDDITS = ["AskReddit", "memes", "distantsocializing", "ACTrade", "RedditSessions"]


def _thread_type_keys(base):
    """Baseline stat followed by its viral / non-viral splits."""
    return [base, f"{base}_viral", f"{base}_non_viral"]


def _author_count_keys(base):
    """Baseline stat followed by the 2-, 3-, 4- and 5-author variants."""
    return [base] + [f"{base}_{n}_authors" for n in range(2, 6)]


def _subreddit_keys(stat):
    """All-threads stat followed by one ``<stat>_<subreddit>_threads`` key per subreddit."""
    return [f"{stat}_threads"] + [f"{stat}_{sub}_threads" for sub in _SUBREDDITS]


def _lookup_keys(stat):
    """Global lookup-table stat followed by one ``<stat>_<subreddit>_lookup`` key per subreddit."""
    return [f"{stat}_lookup_table"] + [f"{stat}_{sub}_lookup" for sub in _SUBREDDITS]


# category -> {group name -> ordered list of stat keys drawn together in one figure}.
category_groups = {
    "depth": {
        "thread_types": _thread_type_keys("depth_distribution_threads"),
        "author_counts": _author_count_keys("depth_distribution_threads"),
        "subreddits": _subreddit_keys("depth_distribution"),
    },
    "author_distribution": {
        "thread_types": _thread_type_keys("author_distribution_threads"),
        "subreddits": _subreddit_keys("author_distribution"),
    },
    "lengths": {
        "thread_types": _thread_type_keys("thread_lengths_threads"),
        "author_counts": _author_count_keys("thread_lengths_threads"),
        "subreddits": _subreddit_keys("thread_lengths"),
        # The all-threads baseline is drawn alongside the lookup-table variants.
        "lookup": ["thread_lengths_threads"] + _lookup_keys("thread_lengths"),
    },
    "widths": {
        "lookup": _lookup_keys("thread_widths"),
    },
    "score": {
        "thread_types": _thread_type_keys("thread_score_distribution_threads"),
        "author_counts": _author_count_keys("thread_score_distribution_threads"),
        "subreddits": _subreddit_keys("thread_score_distribution"),
    },
}

In [6]:
import json
import matplotlib.pyplot as plt
import os
import numpy as np
from collections import defaultdict

# Every figure below is saved under ../plots; create the folder once, up front.
os.makedirs("../plots", exist_ok=True)

# Matplotlib's "tab10" qualitative palette written out as hex codes, so each series in a
# comparison plot gets a stable, easily distinguishable color.
PROFESSIONAL_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
]

# Read the precomputed statistics. Entries are either scalars or dicts mapping bucket
# labels to values; the plotting code below handles both.
with open(saved_stats, "r") as stats_file:
    data = json.load(stats_file)


def process_numerical_xaxis(x, y, key, score_range=(-50, 200)):
    """Densify a numeric-keyed distribution onto a contiguous integer x-axis.

    Args:
        x: Bucket labels as strings (e.g. ``"3"``, ``"-12"``).
        y: Values aligned with ``x``.
        key: Name of the statistic. Keys containing ``"score"`` are clipped to
            ``score_range`` before densifying.
        score_range: Inclusive ``(low, high)`` bounds applied to score statistics.

    Returns:
        ``(labels, values)``. ``labels`` are the string form of every integer from the
        smallest to the largest kept bucket, and ``values`` holds the matching y values,
        with ``0`` for buckets missing from the input.
        - If any label is not an integer, ``x`` and ``y`` are returned unchanged.
        - If nothing remains (empty input, or every score outside ``score_range``),
          ``([], [])`` is returned.
    """
    try:
        x_int = [int(k) for k in x]
    except ValueError:
        # Non-integer labels: leave the series to the caller's categorical handling.
        return x, y

    pairs = list(zip(x_int, y))
    if "score" in key:
        low, high = score_range
        pairs = [(xi, yi) for xi, yi in pairs if low <= xi <= high]

    if not pairs:
        # Handled explicitly: min() of an empty list used to raise ValueError. That error
        # was caught and returned the unfiltered x alongside an already-filtered (empty)
        # y, i.e. lists of different lengths.
        return [], []

    value_by_x = dict(pairs)
    full_range = range(min(value_by_x), max(value_by_x) + 1)
    return [str(xi) for xi in full_range], [value_by_x.get(xi, 0) for xi in full_range]


# Draw one comparison line plot per (category, group) in category_groups and save it as
# ../plots/{category}_{group}.png.


def _prepare_series(key, values):
    """Convert one saved distribution into plottable data.

    Returns ``(x, y, x_numeric)``:
        - ``x``: the string labels.
        - ``y``: values clamped to be non-negative.
        - ``x_numeric``: the labels as ints, or ``None`` when they are not all integers.
    Returns ``None`` (after printing a notice) when the values are not numerical.
    """
    x = list(values.keys())
    y = list(values.values())

    if all(k.lstrip("-").isdigit() for k in x):
        x, y = process_numerical_xaxis(x, y, key)
    elif "subreddit" in key:
        # Only the first 10 subreddits are shown.
        x, y = x[:10], y[:10]
    elif "score" not in key:
        x = sorted(x, key=sort_key)[:50]
        y = [values[k] for k in x]

    try:
        y = [float(v) for v in y]
    except (ValueError, TypeError):
        print(f"Skipping {key} - values are not numerical")
        return None

    # Negative values are clamped to 0 before plotting.
    y = [max(0, v) for v in y]

    try:
        x_numeric = [int(k) for k in x]
    except ValueError:
        x_numeric = None
    return x, y, x_numeric


for category, groups in category_groups.items():
    for group_name, keys in groups.items():
        available_keys = [k for k in keys if k in data]
        if not available_keys:
            continue

        # Prepare every series first: whether the x-axis can be numeric depends on all of them.
        # `i` is the position in available_keys, so colors match the original assignment.
        series = []
        for i, key in enumerate(available_keys):
            values = data[key]
            if not isinstance(values, dict):
                continue
            prepared = _prepare_series(key, values)
            if prepared is not None:
                series.append((i, key, *prepared))

        if not series:
            # Nothing plottable in this group; don't write an empty figure.
            continue

        # Plot integer buckets on a real numeric axis. As strings, matplotlib builds a
        # categorical axis ordered by first appearance across series, so a later series whose
        # range starts below the first series' minimum would be drawn out of order.
        numeric_axis = all(x_numeric is not None for *_, x_numeric in series)

        plt.figure(figsize=(12, 8))
        for i, key, x, y, x_numeric in series:
            plt.plot(
                x_numeric if numeric_axis else x,
                y,
                color=PROFESSIONAL_COLORS[i % len(PROFESSIONAL_COLORS)],
                marker="o",
                linestyle="-",
                linewidth=2,
                markersize=6,
                alpha=0.8,
                label=key.replace("_", " ").title(),
            )

        plt.title(
            f"{category.capitalize()} - {group_name.replace('_', ' ').capitalize()}",
            fontsize=14,
            pad=20,
        )
        plt.xlabel("Categories", fontsize=12)
        plt.yscale("log")
        plt.ylabel("Values (log scale)", fontsize=12)
        plt.grid(True, which="both", ls="--", alpha=0.3)

        if numeric_axis:
            # About 10 integer ticks spanning every series. Previously the ticks were derived
            # from the last series' x only, leaving the rest of the axis unlabeled.
            plt.locator_params(axis="x", nbins=10, integer=True)
        plt.xticks(rotation=45, ha="right", fontsize=10)

        plt.legend(loc="best", fontsize="small", framealpha=1)
        plt.tight_layout()

        safe_group_name = "".join(c if c.isalnum() else "_" for c in group_name)
        output_path = f"../plots/{category}_{safe_group_name}.png"
        plt.savefig(output_path, dpi=300, bbox_inches="tight")
        plt.close()

# Plot every remaining dict-valued statistic as its own bar chart (../plots/{key}.png) and
# print scalar / single-bucket / variance stats as text.

# Keys already drawn in a comparison plot are skipped here, except the four "all threads"
# baselines, which additionally get their own standalone chart.
_BASELINE_KEYS = {
    "depth_distribution_threads",
    "author_distribution_threads",
    "thread_lengths_threads",
    "thread_score_distribution_threads",
}
_GROUPED_KEYS = {
    stat_key
    for groups in category_groups.values()
    for group_keys in groups.values()
    for stat_key in group_keys
} - _BASELINE_KEYS

for key, values in data.items():
    if key in _GROUPED_KEYS:
        continue

    # Scalars, single-bucket dicts and variance stats are reported as text, not plotted.
    if not isinstance(values, dict) or len(values) == 1 or "variance" in key:
        print(f"{key}: {values}")
        continue

    x = list(values.keys())
    y = list(values.values())

    if all(k.lstrip("-").isdigit() for k in x):
        x, y = process_numerical_xaxis(x, y, key)
    elif "subreddit" in key:
        # Only the first 10 subreddits are shown.
        x, y = x[:10], y[:10]
    elif "score" not in key:
        x = sorted(x, key=sort_key)[:50]
        y = [values[k] for k in x]

    try:
        y = [float(v) for v in y]
    except (ValueError, TypeError):
        print(f"Skipping {key} - values are not numerical")
        continue

    # Negative values are clamped to 0 before plotting.
    y = [max(0, v) for v in y]

    # Open the figure only once the data is known to be plottable. Opening it earlier left an
    # empty, never-closed figure behind for every skipped key (Jupyter's inline backend
    # typically renders figures still open at the end of a cell).
    plt.figure(figsize=(10, 6))
    plt.bar(x, y, color=PROFESSIONAL_COLORS[0], alpha=0.7)

    plt.title(key.replace("_", " ").title(), fontsize=14)
    plt.xlabel("Categories", fontsize=12)
    plt.ylabel("Values", fontsize=12)
    plt.grid(True, which="both", ls="--", alpha=0.3)

    if all(k.lstrip("-").isdigit() for k in x):
        # Numeric buckets: label roughly 10 evenly spaced ticks.
        step = max(1, len(x) // 10)
        plt.xticks(x[::step], rotation=45, ha="right", fontsize=10)
    else:
        plt.xticks(rotation=45, ha="right", fontsize=10)

    plt.tight_layout()

    safe_key = "".join(c if c.isalnum() else "_" for c in key)
    output_path = f"../plots/{safe_key}.png"
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

print("All visualizations complete!")

all_widths_lookup_table: {'1': 41974531}
thread_lengths_lookup_table_weighted_average: 1.506064277275699
thread_widths_lookup_table_weighted_average: 0.5930788083154582
all_widths_lookup_table_weighted_average: 0.9999999999976176
number_of_threads_threads: 56245856
depth_distribution_threads_weighted_average: 1.8369052946157853
author_distribution_threads_weighted_average: 1.3962651932946024
thread_score_distribution_threads_weighted_average: 11.325471106686818
thread_score_distribution_threads_median: 2.0
thread_lengths_threads_weighted_average: 1.6613872865411554
number_of_threads_threads_2_authors: 1997988
depth_distribution_threads_2_authors_weighted_average: 3.3153060742538276
thread_score_distribution_threads_2_authors_weighted_average: 14.365053242844048
thread_score_distribution_threads_2_authors_median: 8.0
thread_lengths_threads_2_authors_weighted_average: 5.305259411776219
number_of_threads_threads_3_authors: 66659
depth_distribution_threads_3_authors_weighted_average: 5.892

In [None]:
import json
import matplotlib.pyplot as plt
import os
from collections import defaultdict

# Guarantee the output folder exists before any savefig call.
os.makedirs("../plots", exist_ok=True)

# Re-read the precomputed statistics from disk.
with open(saved_stats, "r") as stats_file:
    data = json.load(stats_file)

# Earlier bar-chart version of the grouped comparison plots.
# NOTE(review): this writes the same ../plots/{category}_{group}.png paths as the line-plot
# cell above, so running it afterwards overwrites those figures.
colors = plt.cm.tab10.colors  # one palette for every group

for category, groups in category_groups.items():
    for group_name, keys in groups.items():
        available_keys = [k for k in keys if k in data]
        if not available_keys:
            continue

        plt.figure(figsize=(12, 8))

        legend_handles = []
        legend_labels = []
        # Ordered union of the x categories of every plotted series. Matplotlib's categorical
        # axis orders categories by first appearance, so this list matches the axis order.
        axis_categories = {}

        for i, key in enumerate(available_keys):
            values = data[key]
            if not isinstance(values, dict):
                continue

            x = list(values.keys())
            y = list(values.values())

            try:
                numeric_keys = sorted(x, key=int)
            except ValueError:
                numeric_keys = None

            if numeric_keys is not None:
                if "score" in key:
                    # Keep scores in [-50, 200], in ascending numeric order. The old code
                    # filtered the unsorted keys, so bars came out in dict order.
                    x = [k for k in numeric_keys if -50 <= int(k) <= 200]
                else:
                    x = numeric_keys[:50]
                y = [values[k] for k in x]
            elif "subreddit" in key:
                # Only the first 10 subreddits are shown.
                x, y = x[:10], y[:10]
            elif "score" not in key:
                x = sorted(x, key=sort_key)[:50]
                y = [values[k] for k in x]

            try:
                y = [float(v) for v in y]
            except (ValueError, TypeError):
                print(f"Skipping {key} - values are not numerical")
                continue

            # Negative values are clamped to 0 before plotting.
            y = [max(0, v) for v in y]

            bars = plt.bar(x, y, alpha=0.6, color=colors[i % len(colors)], label=key)
            axis_categories.update(dict.fromkeys(x))
            if len(bars):  # bars[0] would raise IndexError for an empty series
                legend_handles.append(bars[0])
                legend_labels.append(key)

        if not legend_handles:
            # Nothing was drawn; don't save an empty figure.
            plt.close()
            continue

        plt.title(
            f"{category.capitalize()} - {group_name.replace('_', ' ').capitalize()}"
        )
        plt.xlabel("Categories")
        plt.ylabel("Values")
        plt.yscale("log")

        # Ticks cover every plotted series, not just the last one.
        tick_labels = list(axis_categories)
        if "score" in category:
            plt.xticks(tick_labels[::10], rotation=45, ha="right")
        else:
            plt.xticks(tick_labels, rotation=45, ha="right")

        plt.legend(
            handles=legend_handles, labels=legend_labels, loc="best", fontsize="small"
        )

        plt.tight_layout()

        safe_group_name = "".join(c if c.isalnum() else "_" for c in group_name)
        output_path = f"../plots/{category}_{safe_group_name}.png"
        plt.savefig(output_path)
        plt.close()

# Earlier version of the standalone per-statistic bar charts.
# NOTE(review): this writes the same ../plots/{key}.png files as the standalone-plot cell
# above, so running it afterwards overwrites those figures.
for key, values in data.items():
    # Skip every key that appears in a comparison group (baselines included).
    if any(
        key in group_keys
        for groups in category_groups.values()
        for group_keys in groups.values()
    ):
        continue

    # Scalars and single-bucket dicts are reported as text, not plotted.
    if not isinstance(values, dict) or len(values) == 1:
        print(f"{key}: {values}")
        continue

    x = list(values.keys())
    y = list(values.values())

    try:
        numeric_keys = sorted(x, key=int)
    except ValueError:
        numeric_keys = None

    if numeric_keys is not None:
        if "score" in key:
            # Keep scores in [-50, 200], in ascending numeric order (the old code filtered
            # the unsorted keys).
            x = [k for k in numeric_keys if -50 <= int(k) <= 200]
        else:
            x = numeric_keys[:50]
        y = [values[k] for k in x]
    elif "subreddit" in key:
        # Only the first 10 subreddits are shown.
        x, y = x[:10], y[:10]
    elif "score" not in key:
        x = sorted(x, key=sort_key)[:50]
        y = [values[k] for k in x]

    try:
        y = [float(v) for v in y]
    except (ValueError, TypeError):
        print(f"Skipping {key} - values are not numerical")
        continue

    # Negative values are clamped to 0 before plotting.
    y = [max(0, v) for v in y]

    # Open the figure only after validation so skipped keys don't leave an open empty figure.
    plt.figure(figsize=(10, 6))
    plt.bar(x, y)

    plt.title(key)
    plt.xlabel("Categories")
    plt.ylabel("Values")
    plt.yscale("log")

    # Score distributions are long; label only every 10th bucket.
    if "score" in key:
        plt.xticks(x[::10], rotation=45, ha="right")
    else:
        plt.xticks(x, rotation=45, ha="right")

    plt.tight_layout()

    safe_key = "".join(c if c.isalnum() else "_" for c in key)
    output_path = f"../plots/{safe_key}.png"
    plt.savefig(output_path)
    plt.close()

print("All visualizations complete!")

all_widths_lookup_table: {'1': 41974531}
thread_lengths_lookup_table_weighted_average: 1.506064277275699
thread_widths_lookup_table_weighted_average: 0.5930788083154582
all_widths_lookup_table_weighted_average: 0.9999999999976176
number_of_threads_threads: 56245856
depth_distribution_threads_weighted_average: 1.8369052946157853
author_distribution_threads_weighted_average: 1.3962651932946024
thread_score_distribution_threads_weighted_average: 11.325471106686818
thread_score_distribution_threads_median: 2.0
thread_lengths_threads_weighted_average: 1.6613872865411554
number_of_threads_threads_2_authors: 1997988
depth_distribution_threads_2_authors_weighted_average: 3.3153060742538276
thread_score_distribution_threads_2_authors_weighted_average: 14.365053242844048
thread_score_distribution_threads_2_authors_median: 8.0
thread_lengths_threads_2_authors_weighted_average: 5.305259411776219
number_of_threads_threads_3_authors: 66659
depth_distribution_threads_3_authors_weighted_average: 5.892

In [None]:
# Plot every dict-valued statistic in saved_stats as its own bar chart.
# NOTE(review): unlike the cells above, this does not skip keys covered by category_groups.
# It writes the same ../plots/{key}.png files as the standalone-plot cells, so running it
# afterwards overwrites those figures.
ANNOTATE_BARS = False  # set True to write each bar's value above it

with open(saved_stats, "r") as f:
    data = json.load(f)

for key, values in data.items():
    # Scalars and single-bucket dicts are reported as text, not plotted.
    if not isinstance(values, dict) or len(values) == 1:
        print(f"{key}: {values}")
        continue

    x = list(values.keys())
    y = list(values.values())

    # If keys can be converted to int, work with them in ascending numeric order.
    try:
        numeric_keys = sorted(x, key=int)
    except ValueError:
        numeric_keys = None

    if numeric_keys is not None:
        if "score" in key:
            # Keep scores in [-50, 200], in ascending numeric order (the old code filtered
            # the unsorted keys).
            x = [k for k in numeric_keys if -50 <= int(k) <= 200]
        else:
            x = numeric_keys[:50]
        y = [values[k] for k in x]
    elif "subreddit" in key:
        # Only the first 10 subreddits are shown.
        x, y = x[:10], y[:10]
    elif "score" not in key:
        x = sorted(x, key=sort_key)[:50]
        y = [values[k] for k in x]

    try:
        y = [float(v) for v in y]
    except (ValueError, TypeError):
        print(f"Skipping {key} - values are not numerical")
        continue

    # Negative values are clamped to 0 before plotting.
    y = [max(0, v) for v in y]

    # Open the figure only after validation so skipped keys don't leave an open empty figure.
    plt.figure(figsize=(10, 6))
    bars = plt.bar(x, y)

    if ANNOTATE_BARS:
        for bar in bars:
            height = bar.get_height()
            plt.text(
                bar.get_x() + bar.get_width() / 2.0,
                height,
                f"{height:.2f}",
                ha="center",
                va="bottom",
            )

    plt.title(key)
    plt.xlabel("Categories")
    plt.ylabel("Values")
    plt.yscale("log")

    # Score distributions are long; label only every 10th bucket.
    if "score" in key:
        plt.xticks(x[::10], rotation=45, ha="right")
    else:
        plt.xticks(x, rotation=45, ha="right")
    plt.tight_layout()

    safe_key = "".join(c if c.isalnum() else "_" for c in key)
    output_path = f"../plots/{safe_key}.png"
    plt.savefig(output_path)
    plt.close()

all_widths_lookup_table: {'1': 41974531}
thread_lengths_lookup_table_weighted_average: 1.506064277275699
thread_widths_lookup_table_weighted_average: 0.5930788083154582
all_widths_lookup_table_weighted_average: 0.9999999999976176
number_of_threads_threads: 56245856
depth_distribution_threads_weighted_average: 1.8369052946157853
author_distribution_threads_weighted_average: 1.3962651932946024
thread_score_distribution_threads_weighted_average: 11.325471106686818
thread_score_distribution_threads_median: 2.0
thread_lengths_threads_weighted_average: 1.6613872865411554
number_of_threads_threads_2_authors: 1997988
depth_distribution_threads_2_authors_weighted_average: 3.3153060742538276
thread_score_distribution_threads_2_authors_weighted_average: 14.365053242844048
thread_score_distribution_threads_2_authors_median: 8.0
thread_lengths_threads_2_authors_weighted_average: 5.305259411776219
number_of_threads_threads_3_authors: 66659
depth_distribution_threads_3_authors_weighted_average: 5.892