# Data exploration

## Imports and loading

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import time
import ast
from collections import defaultdict
from typing import Optional, Tuple, Dict, Any
from plotly.graph_objects import Figure
import textwrap

In [None]:
# Cleaned arXiv records; later cells read the 'categories' and 'versions' columns.
df = pd.read_parquet(
    path='arxiv_cleaned.parquet',
    engine='pyarrow',
)
# Category hierarchy; later cells read the Domain / Area / SubArea / Code columns.
category_df = pd.read_parquet(
    path='categories.parquet',
    engine='pyarrow',
)

## Categories

In [None]:
# all_categories = df['categories'].str.split().explode()
# category_classes = []
# subcategory_data = []

# for category in all_categories:
#     if '.' in category:
#         # Split into category class and subcategory
#         category_class, subcategory = category.split('.', 1)
#     else:
#         # If no dot, the entire category is the category class
#         category_class, subcategory = category, None
    
#     category_classes.append(category_class)
#     subcategory_data.append({'category_class': category_class, 'subcategory': subcategory})


# subcategory_df = pd.DataFrame(subcategory_data)

# #  Plot for category class counts
# category_class_counts = pd.Series(category_classes).value_counts()
# plt.figure(figsize=(10, 6))
# sns.barplot(x=category_class_counts.index, y=category_class_counts.values, palette="viridis")
# plt.title("Counts of Category Classes")
# plt.xlabel("Category Class")
# plt.ylabel("Count")
# plt.xticks(rotation=45)
# plt.show()

# # Plots for Subcategories
# # Count the number of records per subcategory within each category class
# subcategory_counts = subcategory_df.groupby(['category_class', 'subcategory']).size().reset_index(name='count')

# # Create facet-wrapped bar plots
# g = sns.FacetGrid(subcategory_counts, col="category_class", col_wrap=4, sharex=False, sharey=False, height=4)
# g.map(sns.barplot, "subcategory", "count", order=None, palette="viridis")
# g.set_titles("{col_name}")
# g.set_xticklabels(rotation=45)
# g.set_axis_labels("Subcategory", "Count")
# plt.subplots_adjust(top=0.9)
# g.fig.suptitle("Counts of Records per Subcategory")
# plt.show()

In [None]:
# Plotly Category Plot
# Plots, per time period, the percentage of version timestamps that belong to
# each top-level category class (the part of a category code before the first
# '.', or the whole code when it contains no '.').

# --- 1. Flatten Data: Extract Timestamps and Category Classes ---
all_timestamps = []
all_category_classes = []
errors_parsing_versions = 0

print("\nStarting data flattening and extraction...")
# Use itertuples for performance
for row in df.itertuples(index=False, name='Record'):
    category_str = getattr(row, 'categories', None)
    versions_data = getattr(row, 'versions', None)
    versions_list = None  # Reset for each row

    # Normalise 'versions' to a Python list; any other type leaves it None.
    if isinstance(versions_data, str) and versions_data.startswith('[') and versions_data.endswith(']'):
        # String representation of a list: parse it without evaluating code.
        try:
            parsed_versions = ast.literal_eval(versions_data)
        except (ValueError, SyntaxError, MemoryError):
            errors_parsing_versions += 1
        else:
            if isinstance(parsed_versions, list):
                versions_list = parsed_versions
            else:
                errors_parsing_versions += 1
    elif isinstance(versions_data, list):
        versions_list = versions_data
    elif isinstance(versions_data, np.ndarray):
        # Convert NumPy array of objects (dicts) to a Python list
        versions_list = versions_data.tolist()

    if not versions_list or not isinstance(category_str, str):
        continue

    # The usable 'created' strings do not depend on the category, so collect
    # them once per record instead of once per category.
    created_strings = [
        version_info['created']
        for version_info in versions_list
        if isinstance(version_info, dict)
        and isinstance(version_info.get('created'), str)
        and version_info['created']
    ]
    if not created_strings:
        continue

    # str.split() with no argument already strips and never yields '' items.
    for category in category_str.split():
        # Main category class: everything before the first '.'
        category_class = category.split('.', 1)[0]
        all_timestamps.extend(created_strings)
        all_category_classes.extend([category_class] * len(created_strings))

if errors_parsing_versions > 0:
    print(f"Note: Encountered {errors_parsing_versions} errors parsing 'versions' column strings during flattening.")

# --- Create the initial flattened DataFrame ---
if not all_timestamps:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel
    # down, whereas an exception only stops this cell.
    raise RuntimeError("No valid version timestamps combined with category classes found after flattening.")

exploded_df = pd.DataFrame({
    'category_class': all_category_classes,
    'timestamp_str': all_timestamps
})
print(f"Flattened DataFrame shape: {exploded_df.shape}")
del all_timestamps, all_category_classes

# --- 2. Convert Timestamps and Clean ---
print("Converting timestamps and cleaning...")
# utc=True gives one tz-aware datetime dtype regardless of the offsets in the
# strings, consistent with plot_category_distribution further down.
exploded_df['timestamp'] = pd.to_datetime(exploded_df['timestamp_str'], errors='coerce', utc=True)
exploded_df.drop(columns=['timestamp_str'], inplace=True)  # Drop the original string column
exploded_df.dropna(subset=['timestamp'], inplace=True)  # Remove rows where conversion failed
print(f"Processed flattened DataFrame shape: {exploded_df.shape}")

if exploded_df.empty:
    raise RuntimeError("DataFrame is empty after timestamp conversion. Cannot proceed.")

# --- 3. Aggregate Counts per Time Period ---
time_freq = 'Y'  # Other pandas offset aliases such as 'M', 'Q' or 'W' also work
print(f"Aggregating counts by frequency: {time_freq}...")
# The timestamp must be the index for pd.Grouper(freq=...) to bin by period.
category_counts_over_time = exploded_df.set_index('timestamp').groupby([
    pd.Grouper(freq=time_freq),
    'category_class'
]).size()

# --- 4. Calculate Percentages ---
print("Calculating percentages...")
# Denominator: all (category, version) pairs that fall in the same period.
total_counts_over_time = category_counts_over_time.groupby(level=0).sum()
percentage_over_time = category_counts_over_time.div(total_counts_over_time, level=0, fill_value=0) * 100
percentage_df = percentage_over_time.reset_index(name='percentage')
del category_counts_over_time, total_counts_over_time  # Free up memory

# --- 5. Plot Top N Categories with Plotly ---
print("Generating Plotly figure...")
if not percentage_df.empty:
    top_n = 30  # Number of top categories to display
    # Rank categories by their mean share across periods
    avg_percentage = percentage_df.groupby('category_class')['percentage'].mean()
    top_categories = avg_percentage.nlargest(top_n).index.tolist()
    # Filter the percentage DataFrame for plotting
    plot_df = percentage_df[percentage_df['category_class'].isin(top_categories)]

    if not plot_df.empty:
        print(f"Plotting top {len(top_categories)} categories...")
        fig = px.line(plot_df,
                      x='timestamp',
                      y='percentage',
                      color='category_class',
                      title=f'Top {len(top_categories)} Category Class Distribution Over Time (%)',
                      labels={'timestamp': f'Time Period ({time_freq})', 'percentage': 'Percentage of Records (%)'},
                      markers=False)  # Set markers=True to show points

        # Values are already on a 0-100 scale, so format them as fixed-point
        # numbers and append the sign with ticksuffix. '.1f%' is not a valid
        # d3-format specifier, and a bare '%' type would multiply by 100.
        fig.update_layout(yaxis_tickformat='.1f', yaxis_ticksuffix='%')
        fig.update_traces(hovertemplate='<b>%{fullData.name}</b><br>Period: %{x}<br>Percentage: %{y:.2f}%<extra></extra>')

        fig.show()  # Display the plot
    else:
        print(f"Warning: No data remained after filtering for top {top_n} categories.")
else:
    print("Warning: Percentage data was empty. No plot generated.")

In [None]:
# Same share-over-time plot as the previous cell, but each category code is
# first translated to its Area name (falling back to SubArea, then to the raw
# code) using category_df.

# --- 1. Flatten Data: Extract Timestamps and Category Classes ---
all_timestamps = []
all_category_classes = []
errors_parsing_versions = 0

# Prepare Category Lookup: code -> {'Domain', 'Area', 'SubArea'}
category_lookup = defaultdict(lambda: None)
for _, row in category_df.iterrows():
    category_lookup[row['Code']] = {
        'Domain': row['Domain'],
        # Missing names are stored as None so the `or` fallback below works.
        'Area': str(row['Area']) if pd.notna(row['Area']) else None,
        'SubArea': str(row['SubArea']) if pd.notna(row['SubArea']) else None,
    }

print("\nStarting data flattening and extraction...")
# Use itertuples for performance
for row in df.itertuples(index=False, name='Record'):
    category_str = getattr(row, 'categories', None)
    versions_data = getattr(row, 'versions', None)
    versions_list = None  # Reset for each row

    # Normalise 'versions' to a Python list; any other type leaves it None.
    if isinstance(versions_data, str) and versions_data.startswith('[') and versions_data.endswith(']'):
        # String representation of a list: parse it without evaluating code.
        try:
            parsed_versions = ast.literal_eval(versions_data)
        except (ValueError, SyntaxError, MemoryError):
            errors_parsing_versions += 1
        else:
            if isinstance(parsed_versions, list):
                versions_list = parsed_versions
            else:
                errors_parsing_versions += 1
    elif isinstance(versions_data, list):
        versions_list = versions_data
    elif isinstance(versions_data, np.ndarray):
        # Convert NumPy array of objects (dicts) to a Python list
        versions_list = versions_data.tolist()

    if not versions_list or not isinstance(category_str, str):
        continue

    # The usable 'created' strings do not depend on the category, so collect
    # them once per record instead of once per category.
    created_strings = [
        version_info['created']
        for version_info in versions_list
        if isinstance(version_info, dict)
        and isinstance(version_info.get('created'), str)
        and version_info['created']
    ]
    if not created_strings:
        continue

    # str.split() with no argument already strips and never yields '' items.
    for code in category_str.split():
        # .get() avoids inserting a None entry into the defaultdict for every
        # code that has no row in category_df.
        cat_info = category_lookup.get(code)
        if cat_info:
            # Prefer the Area name, then the SubArea name, then the raw code
            category_class = cat_info.get('Area') or cat_info.get('SubArea') or code
        else:
            # Fallback to the code itself if no mapping is found
            category_class = code

        all_timestamps.extend(created_strings)
        all_category_classes.extend([category_class] * len(created_strings))

if errors_parsing_versions > 0:
    print(f"Note: Encountered {errors_parsing_versions} errors parsing 'versions' column strings during flattening.")

# --- Create the initial flattened DataFrame ---
if not all_timestamps:
    # Raise instead of calling exit(): in a notebook exit() shuts the kernel
    # down, whereas an exception only stops this cell.
    raise RuntimeError("No valid version timestamps combined with category classes found after flattening.")

exploded_df = pd.DataFrame({
    'category_class': all_category_classes,
    'timestamp_str': all_timestamps
})
print(f"Flattened DataFrame shape: {exploded_df.shape}")
del all_timestamps, all_category_classes

# --- 2. Convert Timestamps and Clean ---
print("Converting timestamps and cleaning...")
# utc=True gives one tz-aware datetime dtype regardless of the offsets in the
# strings, consistent with plot_category_distribution further down.
exploded_df['timestamp'] = pd.to_datetime(exploded_df['timestamp_str'], errors='coerce', utc=True)
exploded_df.drop(columns=['timestamp_str'], inplace=True)  # Drop the original string column
exploded_df.dropna(subset=['timestamp'], inplace=True)  # Remove rows where conversion failed
print(f"Processed flattened DataFrame shape: {exploded_df.shape}")

if exploded_df.empty:
    raise RuntimeError("DataFrame is empty after timestamp conversion. Cannot proceed.")

# --- 3. Aggregate Counts per Time Period ---
time_freq = 'Y'  # Other pandas offset aliases such as 'M', 'Q' or 'W' also work
print(f"Aggregating counts by frequency: {time_freq}...")
# The timestamp must be the index for pd.Grouper(freq=...) to bin by period.
category_counts_over_time = exploded_df.set_index('timestamp').groupby([
    pd.Grouper(freq=time_freq),
    'category_class'
]).size()

# --- 4. Calculate Percentages ---
print("Calculating percentages...")
# Denominator: all (category, version) pairs that fall in the same period.
total_counts_over_time = category_counts_over_time.groupby(level=0).sum()
percentage_over_time = category_counts_over_time.div(total_counts_over_time, level=0, fill_value=0) * 100
percentage_df = percentage_over_time.reset_index(name='percentage')
del category_counts_over_time, total_counts_over_time  # Free up memory

# --- 5. Plot Top N Categories with Plotly ---
print("Generating Plotly figure...")
if not percentage_df.empty:
    top_n = 30  # Number of top categories to display
    # Rank categories by their mean share across periods
    avg_percentage = percentage_df.groupby('category_class')['percentage'].mean()
    top_categories = avg_percentage.nlargest(top_n).index.tolist()
    # Filter the percentage DataFrame for plotting
    plot_df = percentage_df[percentage_df['category_class'].isin(top_categories)]

    if not plot_df.empty:
        print(f"Plotting top {len(top_categories)} categories...")
        fig = px.line(plot_df,
                      x='timestamp',
                      y='percentage',
                      color='category_class',
                      title=f'Top {len(top_categories)} Category Class Distribution Over Time (%)',
                      labels={'timestamp': f'Time Period ({time_freq})', 'percentage': 'Percentage of Records (%)'},
                      markers=False)  # Set markers=True to show points

        # Values are already on a 0-100 scale, so format them as fixed-point
        # numbers and append the sign with ticksuffix. '.1f%' is not a valid
        # d3-format specifier, and a bare '%' type would multiply by 100.
        fig.update_layout(yaxis_tickformat='.1f', yaxis_ticksuffix='%')
        fig.update_traces(hovertemplate='<b>%{fullData.name}</b><br>Period: %{x}<br>Percentage: %{y:.2f}%<extra></extra>')

        fig.show()  # Display the plot
    else:
        print(f"Warning: No data remained after filtering for top {top_n} categories.")
else:
    print("Warning: Percentage data was empty. No plot generated.")


In [None]:
def plot_category_distribution(df, category_df,
                               level_to_plot,
                               filter_domain,
                               filter_area=None,
                               time_freq='Y',
                               top_n=10,
                               start_year: Optional[int] = None):
    """
    Build a Plotly line chart of category shares over time within one domain.

    Every (category code, version) pair of a record contributes one data
    point dated by the version's 'created' string. Points are restricted to
    codes whose Domain equals ``filter_domain`` (and whose Area equals
    ``filter_area`` when given), grouped by time period and by the chosen
    hierarchy level, and converted to a percentage of the period's total
    *within that filtered scope*.

    Args:
        df (pd.DataFrame): DataFrame with 'categories' and 'versions' columns.
        category_df (pd.DataFrame): Category hierarchy; the columns Domain,
            Area, SubArea and Code are read.
        level_to_plot (str): Hierarchy level to draw one line per value for
            ('Area' or 'SubArea'). When a code has no name at that level the
            code itself is used.
        filter_domain (str): Domain to keep (e.g. "Physics").
        filter_area (str, optional): Area within the domain to keep. Defaults
            to None (no Area filter).
        time_freq (str, optional): pandas frequency alias used for binning
            (e.g. 'M', 'Q', 'Y', 'W'). Defaults to 'Y'.
        top_n (int, optional): Number of lines to draw, ranked by mean
            percentage across periods. Defaults to 10.
        start_year (int, optional): Earliest year to include (inclusive).
            Defaults to None (no start year filter).

    Returns:
        plotly.graph_objects.Figure: The figure, or None when no data point
        survives the filters.
    """
    grouping_col = level_to_plot.lower()  # column name used in the flattened frame

    # --- 1. Prepare Category Lookup ---
    category_lookup = {}
    for _, row in category_df.iterrows():
        category_lookup[row['Code']] = {
            'Domain': row['Domain'],
            'Area': str(row['Area']) if pd.notna(row['Area']) else None,
            'SubArea': str(row['SubArea']) if pd.notna(row['SubArea']) else None,
        }

    # --- 2. Flatten Data & Filter ---
    all_timestamps = []
    grouping_values = []  # value at `level_to_plot` (or the code) per timestamp
    errors_parsing_versions = 0

    for row in df.itertuples(index=False, name='Record'):
        category_str = getattr(row, 'categories', None)
        versions_data = getattr(row, 'versions', None)
        versions_list = None

        # Normalise 'versions' to a Python list; any other type leaves it None.
        if isinstance(versions_data, str) and versions_data.startswith('[') and versions_data.endswith(']'):
            try:
                parsed_versions = ast.literal_eval(versions_data)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Only parsing failures are counted; anything else (e.g.
                # KeyboardInterrupt) must propagate.
                errors_parsing_versions += 1
            else:
                if isinstance(parsed_versions, list):
                    versions_list = parsed_versions
                else:
                    errors_parsing_versions += 1
        elif isinstance(versions_data, list):
            versions_list = versions_data
        elif isinstance(versions_data, np.ndarray):
            versions_list = versions_data.tolist()

        if not versions_list or not isinstance(category_str, str):
            continue

        # The usable 'created' strings do not depend on the category code.
        created_strings = [
            version_info['created']
            for version_info in versions_list
            if isinstance(version_info, dict)
            and isinstance(version_info.get('created'), str)
            and version_info['created']
        ]
        if not created_strings:
            continue

        for code in category_str.split():
            cat_info = category_lookup.get(code)
            if not cat_info or cat_info['Domain'] != filter_domain:
                continue
            # Apply Area filter only if specified
            if filter_area and cat_info['Area'] != filter_area:
                continue

            # Fall back to the code when the requested level has no name.
            value_to_append = cat_info.get(level_to_plot)
            if value_to_append is None or pd.isna(value_to_append):
                value_to_append = code

            all_timestamps.extend(created_strings)
            grouping_values.extend([value_to_append] * len(created_strings))

    if errors_parsing_versions > 0:
        print(f"Note: Encountered {errors_parsing_versions} errors parsing 'versions'.")
    if not all_timestamps:
        print("No matching category/version data found. No plot generated.")
        return None

    # --- 3. Create Flattened DataFrame & Convert Timestamps ---
    exploded_df = pd.DataFrame({
        grouping_col: grouping_values,
        'timestamp_str': all_timestamps
    })
    del all_timestamps, grouping_values  # Free memory

    exploded_df['timestamp'] = pd.to_datetime(exploded_df['timestamp_str'], errors='coerce', utc=True)
    exploded_df.drop(columns=['timestamp_str'], inplace=True)
    print(f"Shape before dropping NaT timestamps: {exploded_df.shape}")
    initial_rows = len(exploded_df)
    exploded_df.dropna(subset=['timestamp'], inplace=True)  # Remove rows where conversion failed
    dropped_rows = initial_rows - len(exploded_df)
    if dropped_rows > 0:
        print(f"Dropped {dropped_rows} rows due to NaT timestamps.")
    print(f"Shape AFTER dropping NaT timestamps: {exploded_df.shape}")

    # --- Apply Start Year Filter (if specified) ---
    if start_year is not None:
        exploded_df = exploded_df[exploded_df['timestamp'].dt.year >= start_year]

    if exploded_df.empty:
        return None

    # --- 4. Aggregate & Calculate Percentages ---
    counts_over_time = exploded_df.set_index('timestamp').groupby([
        pd.Grouper(freq=time_freq),
        grouping_col
    ]).size()
    del exploded_df  # Free memory

    # Total counts *within the filtered scope* for each time period
    total_counts_over_time = counts_over_time.groupby(level=0).sum()
    # Zero totals become NaN before dividing so they end up as 0% rather than inf
    percentage_over_time = counts_over_time.div(total_counts_over_time.replace(0, np.nan), level=0).fillna(0) * 100
    percentage_df = percentage_over_time.reset_index(name='percentage')
    del counts_over_time, total_counts_over_time  # Free memory

    # --- 5. Plot Top N with Plotly ---
    if percentage_df.empty:
        return None

    # Rank by mean percentage across periods and keep the top N
    avg_percentage = percentage_df.groupby(grouping_col)['percentage'].mean()
    top_categories = avg_percentage.nlargest(top_n).index.tolist()
    plot_df = percentage_df[percentage_df[grouping_col].isin(top_categories)]
    if plot_df.empty:
        return None

    plot_title = f'Top {len(top_categories)} {level_to_plot} Distribution within {filter_domain}'
    if filter_area:
        plot_title += f' ({filter_area})'
    if start_year is not None:
        plot_title += f' (from {start_year})'
    plot_title += ' Over Time'

    y_label = f'Percentage within {filter_domain}{f" ({filter_area})" if filter_area else ""} [%]'

    print(f"Plotting top {len(top_categories)} categories...")
    fig = px.line(plot_df,
                  x='timestamp',
                  y='percentage',
                  color=grouping_col,  # One line per value of the chosen level
                  title=plot_title,
                  labels={'timestamp': 'Time Period',
                          'percentage': y_label,
                          grouping_col: level_to_plot},  # Legend title
                  markers=False)

    # Values are already on a 0-100 scale, so format them as fixed-point
    # numbers and append the sign with ticksuffix. '.1f%' is not a valid
    # d3-format specifier, and a bare '%' type would multiply by 100.
    fig.update_layout(yaxis_tickformat='.1f', yaxis_ticksuffix='%')
    fig.update_traces(hovertemplate='<b>%{fullData.name}</b><br>Period: %{x|%Y-%m-%d}<br>Percentage: %{y:.2f}%<extra></extra>')
    return fig

In [None]:
# Yearly Area shares inside Computer Science from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Computer Science',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Quantitative Biology from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Quantitative Biology',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Economics from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Economics',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Electrical Engineering and Systems Science
# from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Electrical Engineering and Systems Science',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Mathematics from 2020 onward (top 10).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Mathematics',
    time_freq='Y',
    top_n=10,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Physics from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Physics',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly Area shares inside Statistics from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='Area',
    filter_domain='Statistics',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly SubArea shares inside Physics / Astrophysics from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='SubArea',
    filter_domain='Physics',
    filter_area='Astrophysics',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
# Yearly SubArea shares inside Physics / Physics from 2020 onward (top 5).
plot_category_distribution(
    df,
    category_df,
    level_to_plot='SubArea',
    filter_domain='Physics',
    filter_area='Physics',
    time_freq='Y',
    top_n=5,
    start_year=2020,
)

In [None]:
def _prepare_filtered_data(
    df: pd.DataFrame,
    category_df: pd.DataFrame,
    start_date_str: str = "2020-01-01",
    end_date_str: str = "2025-12-31"
) -> Optional[pd.DataFrame]:
    """
    Internal helper function to flatten data, convert timestamps, and filter
    by a specific date range.

    Args:
        df (pd.DataFrame): Original DataFrame with 'categories', 'versions'.
        category_df (pd.DataFrame): DataFrame with category hierarchy.
        start_date_str (str): Start date string (YYYY-MM-DD).
        end_date_str (str): End date string (YYYY-MM-DD).

    Returns:
        Optional[pd.DataFrame]: Filtered DataFrame with 'timestamp', 'domain',
                                 'area', 'subarea' columns, or None if processing fails.
    """
    start_time_prep = time.time()
    print(f"--- Starting Data Preparation ({start_date_str} to {end_date_str}) ---")

    start_date = pd.Timestamp(start_date_str, tz='UTC')
    end_date = pd.Timestamp(end_date_str, tz='UTC')

    # --- 1. Prepare Category Lookup ---
    category_lookup = defaultdict(lambda: None)
    for _, row in category_df.iterrows():
        area_val = str(row['Area']) if pd.notna(row['Area']) else None
        subarea_val = str(row['SubArea']) if pd.notna(row['SubArea']) else None
        category_lookup[row['Code']] = {
            'Domain': row['Domain'],
            'Area': area_val,
            'SubArea': subarea_val,
            'Code': row['Code']
        }
    # print("Category lookup created.") # Less verbose for helper

    # --- 2. Flatten Data ---
    all_data = []
    errors_parsing_versions = 0
    # print("Starting data flattening...") # Less verbose

    for row in df.itertuples(index=False, name='Record'):
        category_str = getattr(row, 'categories', None)
        versions_data = getattr(row, 'versions', None)
        versions_list = None

        # Safely parse/convert 'versions'
        if isinstance(versions_data, str) and versions_data.startswith('[') and versions_data.endswith(']'):
            try:
                versions_list = ast.literal_eval(versions_data)
                if not isinstance(versions_list, list): versions_list = None; errors_parsing_versions += 1
            except: versions_list = None; errors_parsing_versions += 1
        elif isinstance(versions_data, list): versions_list = versions_data
        elif isinstance(versions_data, np.ndarray): versions_list = versions_data.tolist()

        if versions_list and isinstance(category_str, str) and category_str.strip():
            categories = category_str.split()
            for code in categories:
                code = code.strip()
                if not code: continue
                cat_info = category_lookup[code]

                if cat_info:
                    domain = cat_info['Domain']
                    area = cat_info['Area'] if cat_info['Area'] else f"Code:{cat_info['Code']}"
                    subarea = cat_info['SubArea'] if cat_info['SubArea'] else f"Code:{cat_info['Code']}"

                    for version_info in versions_list:
                        if isinstance(version_info, dict) and 'created' in version_info:
                            timestamp_str = version_info.get('created')
                            if isinstance(timestamp_str, str) and timestamp_str:
                                all_data.append({
                                    'timestamp_str': timestamp_str,
                                    'domain': domain,
                                    'area': area,
                                    'subarea': subarea
                                })

    if errors_parsing_versions > 0: print(f"Note: Encountered {errors_parsing_versions} errors parsing 'versions'.")
    if not all_data: print(f"ERROR: No category/version data found during flattening."); return None

    # --- 3. Create DataFrame & Convert/Filter Timestamps ---
    flat_df = pd.DataFrame(all_data)
    print(f"Flattened DataFrame shape: {flat_df.shape}")
    # del all_data

    # print("Converting timestamps...")
    flat_df['timestamp'] = pd.to_datetime(flat_df['timestamp_str'], errors='coerce', utc=True)

    failed_timestamps_mask = flat_df['timestamp'].isna()
    num_failed = failed_timestamps_mask.sum()
    if num_failed > 0: print(f"Warning: {num_failed} timestamps failed conversion.")

    flat_df.drop(columns=['timestamp_str'], inplace=True)
    initial_rows = len(flat_df)
    flat_df.dropna(subset=['timestamp'], inplace=True)
    dropped_rows = initial_rows - len(flat_df)
    if dropped_rows > 0: print(f"Dropped {dropped_rows} rows due to NaT timestamps.") 

    if flat_df.empty:
        print("ERROR: DataFrame is empty after timestamp conversion. Cannot proceed.")
        return None

    # --- Apply Date Range Filter ---
    # print(f"Filtering for dates between {start_date.date()} and {end_date.date()}...") 
    initial_rows_before_date_filter = len(flat_df)
    flat_df = flat_df[(flat_df['timestamp'] >= start_date) & (flat_df['timestamp'] <= end_date)]
    rows_after_date_filter = len(flat_df)
    removed_rows = initial_rows_before_date_filter - rows_after_date_filter
    if removed_rows > 0: print(f"Removed {removed_rows} rows outside the date range.") 

    if flat_df.empty:
        print(f"ERROR: DataFrame is empty after applying date filter ({start_date.date()} to {end_date.date()}). Cannot proceed.")
        return None

    print(f"Data preparation complete. Filtered data shape: {flat_df.shape}. Time: {time.time() - start_time_prep:.2f}s")
    return all_data, flat_df


In [None]:
# Flatten and date-filter the records once; the bar-chart cell below reads flat_df.
all_data, flat_df = _prepare_filtered_data(
    df,
    category_df,
    start_date_str="2015-01-01",
    end_date_str="2025-12-31",
)

In [None]:
def plot_top_10_domains(
    flat_df: pd.DataFrame
) -> Optional[plt.Figure]:
    """
    Draw a bar chart of the (up to) ten most frequent domains in ``flat_df``.

    The year span shown in the title is taken from the 'timestamp' column
    when it exists and holds datetimes, so the label follows whatever date
    range the data was prepared with; otherwise "2020-2025" is shown.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least a 'domain'
            column; every row counts once. An optional datetime 'timestamp'
            column is used only for the title.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if there is no domain data.
    """
    # Aggregate counts per Domain
    domain_counts = flat_df['domain'].value_counts().reset_index()
    domain_counts.columns = ['domain', 'count']

    # Get the top 10 domains
    top_10_domains = domain_counts.nlargest(10, 'count')

    if top_10_domains.empty:
        print("No domain data found after aggregation. Cannot generate plot.")
        return None

    num_domains = len(top_10_domains)

    # Derive the period label from the data rather than hard-coding it, so
    # the title cannot disagree with the range the caller filtered on.
    period_label = "2020-2025"
    if 'timestamp' in flat_df.columns and pd.api.types.is_datetime64_any_dtype(flat_df['timestamp']):
        years = flat_df['timestamp'].dt.year.dropna()
        if not years.empty:
            first_year, last_year = int(years.min()), int(years.max())
            period_label = str(first_year) if first_year == last_year else f"{first_year}-{last_year}"

    # Create the plot using seaborn.barplot
    plt.figure(figsize=(12, 7))  # Wide figure so the rotated labels stay readable
    ax = sns.barplot(
        data=top_10_domains,
        x='domain',
        y='count',
        palette="colorblind",
        order=top_10_domains['domain'].tolist()  # Ensure bars are ordered by count descending
    )

    # Customize plot
    ax.set_title(f'Overall Top {num_domains} Domains by Count ({period_label})', fontsize=16, pad=20)
    ax.set_xlabel("Domain", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=45, labelsize=10)  # Rotate labels for better fit
    ax.tick_params(axis='y', labelsize=10)

    # Wrap long domain names onto several lines of at most this many characters
    max_label_width = 20
    labels = [item.get_text() for item in ax.get_xticklabels()]
    wrapped_labels = [textwrap.fill(label, width=max_label_width) for label in labels]
    ax.set_xticklabels(wrapped_labels)
    # Anchor the rotation at the label's right end so multi-line labels line
    # up with their ticks
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Add count labels to bars
    for container in ax.containers:
        ax.bar_label(container, fmt='%d', label_type='edge', fontsize=9, padding=3)

    # Leave headroom above the tallest bar for its count label
    ax.margins(y=0.1)

    plt.tight_layout()

    # Return the figure object associated with the axes
    fig = ax.get_figure()
    return fig

# Build the domain ranking chart from the prepared data and keep the figure handle.
fig_top10 = plot_top_10_domains(flat_df=flat_df)

# Render the figure.
plt.show()

In [None]:
def plot_top_10_areas(
    flat_df: pd.DataFrame
) -> Optional[plt.Figure]:
    """
    Draw a bar chart of the ten most frequent values in the 'subarea' column.

    Rows are counted per 'subarea' value, the ten largest counts are kept and
    drawn as bars in descending order, each annotated with its count.

    Args:
        flat_df (pd.DataFrame): DataFrame with a 'subarea' column. The chart
                                title mentions 2020-2025, but no date filter
                                is applied here; the caller is expected to
                                pass data already limited to that range.

    Returns:
        Optional[matplotlib.figure.Figure]: The figure holding the chart, or
                                             None when there is nothing to plot.
    """
    # Frequency table: one row per subarea with its number of occurrences
    frequency = flat_df['subarea'].value_counts().reset_index()
    frequency.columns = ['subarea', 'count']

    # Keep the ten largest counts
    leaders = frequency.nlargest(10, 'count')
    bar_total = len(leaders)

    if bar_total == 0:
        print("No area data found after aggregation. Cannot generate plot.")
        return None

    # New figure; seaborn draws into its current axes
    plt.figure(figsize=(12, 7))
    ax = sns.barplot(
        data=leaders,
        x='subarea',
        y='count',
        palette="colorblind",
        # Explicit order keeps the bars sorted by count, largest first
        order=leaders['subarea'].tolist(),
    )

    # Title, axis captions and tick styling
    ax.set_title(
        f'Overall Top {bar_total} Areas by Count (2020-2025)',
        fontsize=16,
        pad=20,
    )
    ax.set_xlabel("Area", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=45, labelsize=10)
    ax.tick_params(axis='y', labelsize=10)

    # Break long tick labels into lines of at most `wrap_at` characters
    wrap_at = 20
    ax.set_xticklabels(
        [
            textwrap.fill(tick.get_text(), width=wrap_at)
            for tick in ax.get_xticklabels()
        ]
    )
    # Anchor-based rotation keeps multi-line rotated labels aligned to their tick
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Print the count above every bar
    for bars in ax.containers:
        ax.bar_label(bars, fmt='%d', label_type='edge', fontsize=9, padding=3)

    # Head-room so the bar labels are not clipped
    ax.margins(y=0.1)

    plt.tight_layout()

    return ax.get_figure()

# Build the top-areas bar chart from the notebook-level `flat_df` and render it.
# NOTE: this rebinds `fig_top10`, the same name the top-domains cell assigns.
fig_top10 = plot_top_10_areas(flat_df)

plt.show()


In [None]:
def plot_top_10_areas_stacked_by_domain(
    flat_df: pd.DataFrame
) -> Optional[plt.Figure]:
    """
    Generates a static stacked bar chart of the ten most frequent 'subarea'
    values, with each bar split by the 'domain' column.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least 'subarea' and
                                'domain' columns. The title mentions 2020-2025,
                                but no date filtering happens here; the data is
                                assumed to be limited to that range already.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if plotting is not possible.
    """

    # --- Data Preparation ---

    # 1. Count rows per subarea and keep the names of the ten largest
    total_area_counts = flat_df['subarea'].value_counts()
    top_10_area_names = total_area_counts.nlargest(10).index.tolist()

    if not top_10_area_names:
        print("No area data found after aggregation. Cannot generate plot.")
        return None

    # 2. Restrict the data to rows belonging to those subareas
    top_10_df = flat_df[flat_df['subarea'].isin(top_10_area_names)]

    # 3. Count rows per (subarea, domain) and pivot domains into columns.
    # observed=False is spelled out to match the other groupby calls in this
    # file; it only matters when the grouping columns are categorical.
    stacked_data = (
        top_10_df
        .groupby(['subarea', 'domain'], observed=False)
        .size()
        .unstack(fill_value=0)
    )

    # 4. Reorder rows to follow the total counts (descending) so the bars
    # appear from largest to smallest
    stacked_data = stacked_data.loc[top_10_area_names]

    num_areas = len(stacked_data)
    print(f"Plotting Top {num_areas} Areas overall, stacked by Domain...")

    # --- Plotting ---

    fig, ax = plt.subplots(figsize=(14, 8))
    stacked_data.plot(
        kind='bar',
        stacked=True,
        ax=ax,
        colormap='tab20b'
    )

    # --- Customization ---
    ax.set_title(f'Overall Top {num_areas} Areas by Count (2020-2025), Stacked by Domain', fontsize=16, pad=20)
    ax.set_xlabel("Area", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=45, labelsize=10)
    ax.tick_params(axis='y', labelsize=10)

    # Wrap long x-axis labels onto several lines
    max_label_width = 20
    labels = [item.get_text() for item in ax.get_xticklabels()]
    wrapped_labels = [textwrap.fill(label, width=max_label_width) for label in labels]
    ax.set_xticklabels(wrapped_labels)
    # Anchor-based rotation keeps wrapped, rotated labels aligned to their tick
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Write each bar's total just above the stack. The vertical offset is 1%
    # of the current upper y-limit and is the same for every bar, so it is
    # computed once.
    totals = stacked_data.sum(axis=1)
    label_offset = ax.get_ylim()[1] * 0.01
    for i, total in enumerate(totals):
        ax.text(i, total + label_offset, f'{total}', ha='center', va='bottom', fontsize=9)

    # Extra head-room above the bars for the total labels
    ax.margins(y=0.1)

    # Legend goes outside the plot area, to the right
    ax.legend(title='Domain', bbox_to_anchor=(1.02, 1), loc='upper left')

    # Leave the right 10% of the figure free for the legend
    plt.tight_layout(rect=[0, 0, 0.9, 1])

    return fig

# --- Execution ---
# Build the stacked top-areas chart from the notebook-level `flat_df` and render it.
fig_top10_stacked = plot_top_10_areas_stacked_by_domain(flat_df)

plt.show()

In [None]:
def plot_flop_10_areas_stacked_by_domain(
    flat_df: pd.DataFrame
) -> Optional[plt.Figure]:
    """
    Generates a static stacked bar chart of the ten least frequent 'subarea'
    values ("flop 10"), with each bar split by the 'domain' column.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least 'subarea' and
                                'domain' columns. The title mentions 2020-2025,
                                but no date filtering happens here; the data is
                                assumed to be limited to that range already.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if plotting is not possible
                                             (no subarea values to count).
    """

    # --- Data Preparation ---

    # 1. Count rows per subarea and keep the names of the ten smallest.
    # NOTE(review): if 'subarea' is categorical, value_counts also lists unused
    # categories with a count of 0, and those would be selected first here -
    # confirm the column dtype.
    total_area_counts = flat_df['subarea'].value_counts()
    flop_10_area_names = total_area_counts.nsmallest(10).index.tolist()

    # Same guard as the top-10 variant: without it an empty input fails later
    # inside DataFrame.plot instead of returning None.
    if not flop_10_area_names:
        print("No area data found after aggregation. Cannot generate plot.")
        return None

    # 2. Restrict the data to rows belonging to those subareas
    flop_10_df = flat_df[flat_df['subarea'].isin(flop_10_area_names)]

    # 3. Count rows per (subarea, domain) and pivot domains into columns.
    # observed=False is spelled out to match the other groupby calls in this
    # file; it only matters when the grouping columns are categorical.
    stacked_data = (
        flop_10_df
        .groupby(['subarea', 'domain'], observed=False)
        .size()
        .unstack(fill_value=0)
    )

    # 4. Reorder rows to follow the total counts (ascending, as returned by
    # nsmallest) so the bars appear from smallest to largest
    stacked_data = stacked_data.loc[flop_10_area_names]

    num_areas = len(stacked_data)
    print(f"Plotting Flop {num_areas} Areas overall, stacked by Domain...")

    # --- Plotting ---

    fig, ax = plt.subplots(figsize=(14, 8))
    stacked_data.plot(
        kind='bar',
        stacked=True,
        ax=ax,
        colormap='tab20b'
    )

    # --- Customization ---
    ax.set_title(f'Overall Least Represented {num_areas} Areas by Count (2020-2025), Stacked by Domain', fontsize=16, pad=20)
    ax.set_xlabel("Area", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=45, labelsize=10)
    ax.tick_params(axis='y', labelsize=10)

    # Wrap long x-axis labels onto several lines
    max_label_width = 20
    labels = [item.get_text() for item in ax.get_xticklabels()]
    wrapped_labels = [textwrap.fill(label, width=max_label_width) for label in labels]
    ax.set_xticklabels(wrapped_labels)
    # Anchor-based rotation keeps wrapped, rotated labels aligned to their tick
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Write each bar's total just above the stack. The vertical offset is 1%
    # of the current upper y-limit and is the same for every bar, so it is
    # computed once.
    totals = stacked_data.sum(axis=1)
    label_offset = ax.get_ylim()[1] * 0.01
    for i, total in enumerate(totals):
        ax.text(i, total + label_offset, f'{total}', ha='center', va='bottom', fontsize=9)

    # Extra head-room above the bars for the total labels
    ax.margins(y=0.1)

    # Legend goes outside the plot area, to the right
    ax.legend(title='Domain', bbox_to_anchor=(1.02, 1), loc='upper left')

    # Leave the right 10% of the figure free for the legend
    plt.tight_layout(rect=[0, 0, 0.9, 1])

    return fig

# --- Execution ---
# Build the stacked least-represented-areas chart from the notebook-level `flat_df` and render it.
fig_flop10_stacked = plot_flop_10_areas_stacked_by_domain(flat_df)

plt.show()

In [None]:
def plot_top_physics_subareas_from_physics_area(
    flat_df: pd.DataFrame,
) -> Optional[plt.Figure]:
    """
    Generates a static bar chart showing the Top 10 'subarea' values among the
    rows whose 'area' column equals 'Physics'.
    Adds line breaks to long subarea names on the x-axis.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least 'area' and 'subarea'
                                 columns, with one row per item to count. The
                                 title mentions 2020-2025, but no date filtering
                                 happens here.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if plotting is not possible (e.g., no data).
    """

    # Keep only the rows of the Physics area
    physics_df = flat_df[flat_df['area'] == 'Physics'].copy()

    if physics_df.empty:
        print("No data found for the 'Physics' area.")
        return None

    # Count rows per subarea within Physics
    subarea_counts = physics_df.groupby('subarea', observed=False).size().reset_index(name='count')

    # Keep the ten largest counts (returned in descending order)
    top_10_physics_subareas = subarea_counts.nlargest(10, 'count')

    if top_10_physics_subareas.empty:
        print("No subareas found within the 'Physics' area after filtering.")
        return None

    num_subareas = len(top_10_physics_subareas)

    # NOTE: plt.style.use changes the global Matplotlib style, so it also
    # affects figures created after this function returns.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))

    sns.barplot(
        data=top_10_physics_subareas,
        x='subarea',
        y='count',
        ax=ax,
        palette='colorblind',
        # Pass the order explicitly, like the other top-10 bar charts in this
        # file: for a categorical 'subarea' column seaborn would otherwise use
        # the category order (and every category) instead of the count ranking.
        order=top_10_physics_subareas['subarea'].tolist()
    )

    ax.set_title(f'Top {num_subareas} Physics Areas by Paper Count (2020-2025)', fontsize=16)
    ax.set_xlabel("Physics Area", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=90, labelsize=8)
    ax.tick_params(axis='y', labelsize=10)

    # Wrap long x-axis labels onto lines of at most `max_label_width` characters
    max_label_width = 20
    labels = [item.get_text() for item in ax.get_xticklabels()]
    wrapped_labels = [textwrap.fill(label, width=max_label_width) for label in labels]
    ax.set_xticklabels(wrapped_labels)
    # Anchor-based rotation keeps multi-line rotated labels aligned to their tick
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Print the count above every bar
    for container in ax.containers:
        ax.bar_label(container, fmt='%d', label_type='edge', fontsize=9, padding=3)

    # Head-room so the bar labels are not clipped
    ax.margins(y=0.1)

    # Layout is adjusted after the wrapped labels are in place
    plt.tight_layout()

    return fig

# Build the chart of top subareas within the 'Physics' area from the notebook-level `flat_df` and render it.
fig = plot_top_physics_subareas_from_physics_area(flat_df)

plt.show()

In [None]:
def plot_top_areas_per_domain(
    flat_df: pd.DataFrame = flat_df
) -> Optional[plt.Figure]:
    """
    Generates a static, faceted bar chart showing, for every 'domain' value,
    its three most frequent 'subarea' values.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least 'domain' and
                                'subarea' columns. Defaults to the notebook-level
                                `flat_df` as it was bound when this function was
                                defined; pass the frame explicitly if that
                                variable has been reassigned since. The title
                                mentions 2020-2025, but no date filtering
                                happens here.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if plotting is not possible.
    """

    # Count rows per (domain, subarea) pair
    area_counts = flat_df.groupby(['domain', 'subarea'], observed=False).size().reset_index(name='count')

    # Without any counted pair there is nothing to facet, and pd.concat would
    # raise on an empty list below.
    if area_counts.empty:
        print("No area data found after aggregation. Cannot generate plot.")
        return None

    # Keep the three largest counts within each domain
    top_areas_per_domain = pd.concat(
        [
            group.nlargest(3, 'count')
            for _, group in area_counts.groupby('domain', observed=False)
        ]
    ).reset_index(drop=True)

    num_domains = top_areas_per_domain['domain'].nunique()
    print(f"Plotting Top Areas across {num_domains} Domains...")

    # Facets are laid out three per row; seaborn derives the number of rows
    cols = 3

    # One bar-chart facet per domain, each with its own x and y axis
    g = sns.catplot(
        data=top_areas_per_domain,
        x='subarea',
        y='count',
        col='domain',
        kind='bar',
        col_wrap=cols,
        sharex=False,
        sharey=False,
        height=4,
        aspect=1.2,
        palette="colorblind"
    )

    # Overall title, per-facet titles and axis captions
    g.fig.suptitle('Top 3 Areas per Domain by Count (2020-2025)', y=1.03, fontsize=16)
    g.set_titles("Domain: {col_name}")
    g.set_axis_labels("Area", "Total Count")

    # Per facet: rotate tick labels, print the count above each bar and leave
    # head-room for those labels
    for ax in g.axes.flat:
        ax.tick_params(axis='x', rotation=45, labelsize=8)
        for container in ax.containers:
            ax.bar_label(container, fmt='%d', label_type='edge', fontsize=8, padding=2)
        ax.margins(y=0.1)

    # Reserve the top 3% of the figure for the overall title
    plt.tight_layout(rect=[0, 0, 1, 0.97])

    return g.fig

# Build the per-domain facet chart using the function's default `flat_df` argument and render it.
fig = plot_top_areas_per_domain()

plt.show()


In [None]:
def plot_top_physics_subareas_from_physics_domain(
    flat_df: pd.DataFrame,
) -> Optional[plt.Figure]:
    """
    Generates a static bar chart showing the Top 10 'subarea' values among the
    rows whose 'domain' column equals 'Physics'.
    Adds line breaks to long subarea names on the x-axis.

    Args:
        flat_df (pd.DataFrame): DataFrame containing at least 'domain' and 'subarea'
                                 columns, with one row per item to count. The
                                 title mentions 2020-2025, but no date filtering
                                 happens here.

    Returns:
        Optional[matplotlib.figure.Figure]: The generated Matplotlib figure object,
                                             or None if plotting is not possible (e.g., no data).
    """

    # Keep only the rows of the Physics domain
    physics_df = flat_df[flat_df['domain'] == 'Physics'].copy()

    if physics_df.empty:
        print("No data found for the 'Physics' domain.")
        return None

    # Count rows per subarea within Physics
    subarea_counts = physics_df.groupby('subarea', observed=False).size().reset_index(name='count')

    # Keep the ten largest counts (returned in descending order)
    top_10_physics_subareas = subarea_counts.nlargest(10, 'count')

    if top_10_physics_subareas.empty:
        print("No subareas found within the 'Physics' domain after filtering.")
        return None

    num_subareas = len(top_10_physics_subareas)

    # NOTE: plt.style.use changes the global Matplotlib style, so it also
    # affects figures created after this function returns.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))

    sns.barplot(
        data=top_10_physics_subareas,
        x='subarea',
        y='count',
        ax=ax,
        palette='colorblind',
        # Pass the order explicitly, like the other top-10 bar charts in this
        # file: for a categorical 'subarea' column seaborn would otherwise use
        # the category order (and every category) instead of the count ranking.
        order=top_10_physics_subareas['subarea'].tolist()
    )

    ax.set_title(f'Top {num_subareas} Physics Areas by Paper Count (2020-2025)', fontsize=16)
    ax.set_xlabel("Physics Area", fontsize=12)
    ax.set_ylabel("Total Count", fontsize=12)
    ax.tick_params(axis='x', rotation=90, labelsize=8)
    ax.tick_params(axis='y', labelsize=10)

    # Wrap long x-axis labels onto lines of at most `max_label_width` characters
    max_label_width = 20
    labels = [item.get_text() for item in ax.get_xticklabels()]
    wrapped_labels = [textwrap.fill(label, width=max_label_width) for label in labels]
    ax.set_xticklabels(wrapped_labels)
    # Anchor-based rotation keeps multi-line rotated labels aligned to their tick
    plt.setp(ax.get_xticklabels(), rotation_mode="anchor", ha="right")

    # Print the count above every bar
    for container in ax.containers:
        ax.bar_label(container, fmt='%d', label_type='edge', fontsize=9, padding=3)

    # Head-room so the bar labels are not clipped
    ax.margins(y=0.1)

    # Layout is adjusted after the wrapped labels are in place
    plt.tight_layout()

    return fig

# Build the chart of top subareas within the 'Physics' domain from the notebook-level `flat_df` and render it.
fig = plot_top_physics_subareas_from_physics_domain(flat_df)

plt.show()