In [None]:
import sys
import os
import pandas as pd
import glob

# Visualization packages
import altair as alt

# Append system path
sys.path = [p for p in sys.path if not p.endswith('../..')]  # Cleans duplicated '../..'
sys.path.insert(0, '../')  # This adds `src` to the path

%load_ext autoreload
%autoreload 2

# Append system path
sys.path = [p for p in sys.path if not p.endswith('../..')]  # Cleans duplicated '../..'
sys.path.insert(0, '../')  # This adds `src` to the path

In [None]:
def load_data(pattern):
    """Load every CSV matching ``pattern`` into one long DataFrame.

    File names are parsed as ``<dataset>_<domain type>[_...].<ext>``: after the
    extension is stripped, the first two underscore-separated parts become the
    ``Dataset`` and ``Domain Type`` columns. Known dataset prefixes are mapped to
    display names (``c4`` -> ``C4``, ``rf`` -> ``RefinedWeb``, ``dolma`` ->
    ``Dolma``); unknown prefixes are kept verbatim instead of becoming NaN.

    Each CSV's first column is read as a (date-parsed) index and returned as the
    ``Date`` column.

    Args:
        pattern: glob pattern, e.g. ``'data/domain_estimates/robots/*'``.

    Returns:
        DataFrame with a fresh RangeIndex, a ``Date`` column, the CSV's columns,
        and ``Dataset`` / ``Domain Type``. Rows are ordered by sorted file path.

    Raises:
        FileNotFoundError: if no file matches ``pattern``.
        ValueError: if a file name has no ``_`` separating dataset and domain type.
    """
    dataset_names = {'c4': "C4", 'rf': 'RefinedWeb', 'dolma': 'Dolma'}

    # Sort so row order does not depend on the filesystem's listing order.
    file_paths = sorted(glob.glob(pattern))
    if not file_paths:
        raise FileNotFoundError(f"No files match pattern {pattern!r}")

    frames = []
    for path in file_paths:
        df = pd.read_csv(path, index_col=0, parse_dates=True)
        # Strip the extension so 'c4_all.csv' yields domain type 'all', not 'all.csv'.
        stem = os.path.splitext(os.path.basename(path))[0]
        parts = stem.split('_')
        if len(parts) < 2:
            raise ValueError(f"Cannot parse dataset/domain type from file name {path!r}")
        df['Dataset'] = dataset_names.get(parts[0], parts[0])
        df['Domain Type'] = parts[1]
        frames.append(df)

    # Name the index explicitly: renaming an 'index' column only works when the
    # CSV's first header cell is blank.
    return pd.concat(frames).rename_axis('Date').reset_index()

def prepare_data_for_plot1(data):
    """Reshape the corpus-wide rows into long format for the per-dataset plot.

    Keeps only rows whose ``Domain Type`` equals ``'all'`` and stacks their
    ``Head`` and ``Full Corpus`` columns into a single ``Percent`` column,
    labelled by ``Token Sample`` and keyed by ``Date`` and ``Dataset``.
    """
    is_whole_corpus = data['Domain Type'] == 'all'
    corpus_rows = data.loc[is_whole_corpus].reset_index()
    long_form = corpus_rows.melt(
        id_vars=['Date', 'Dataset'],
        value_vars=['Head', 'Full Corpus'],
        var_name='Token Sample',
        value_name='Percent',
    )
    return long_form

def prepare_data_for_plot2(data):
    """Average per-domain estimates across datasets and reshape for plotting.

    Keeps only the domain types listed in ``relevant_subsets``. For each
    (``Date``, ``Domain Type``) it averages ``Head`` and ``Full Corpus`` over all
    datasets (unweighted mean), relabels the domain type with its short display
    name, and melts the two value columns into long format.

    Args:
        data: frame with ``Date``, ``Domain Type``, ``Head`` and ``Full Corpus``
            columns (e.g. the output of ``load_data``). Other columns are ignored.

    Returns:
        Long DataFrame with columns ``Date``, ``Domain Type`` (short label),
        ``Token Sample`` (``'Head'`` / ``'Full Corpus'``) and ``Percent``.
    """
    relevant_subsets = {
        "Academic": "Acad",
        "News": "News",
        'E': "E-comm",
        'Encyclopedia': "Encyc",
        'Government': "Gov",
        "Organization": "Org/Pers",
        'Social Media': "Socials/Forum",
    }
    value_cols = ['Head', 'Full Corpus']

    # Filter before grouping: cheaper, and nothing is later assigned into a slice.
    subset_rows = data[data['Domain Type'].isin(list(relevant_subsets))]
    # Average only the value columns; a blanket .mean() raises on non-numeric
    # columns (e.g. 'Dataset') in pandas >= 2.
    grouped = (
        subset_rows.groupby(['Date', 'Domain Type'])[value_cols]
        .mean()
        .reset_index()
    )
    # Every remaining domain type is a key of relevant_subsets, so no fillna is needed.
    grouped['Domain Type'] = grouped['Domain Type'].map(relevant_subsets)
    return grouped.melt(id_vars=['Date', 'Domain Type'], value_vars=value_cols,
                        var_name='Token Sample', value_name='Percent')

def forecast_region(data, forecast_startdate, height, period_col="Date"):
    """Build the overlay layers that mark a forecasted time span.

    Returns a layer of three Altair charts:
      * a light-gray rectangle from ``forecast_startdate`` to the latest value in
        ``data[period_col]``;
      * a gray vertical rule at ``forecast_startdate``;
      * a bold "Forecast" label at the horizontal midpoint of that span, drawn at
        y=0 and shifted down by ``height - 20`` pixels.
    """
    start = pd.to_datetime(forecast_startdate)
    end = data[period_col].max()
    midpoint = start + (end - start) / 2

    band = (
        alt.Chart(pd.DataFrame({"start": [start], "end": [end]}))
        .mark_rect(opacity=0.1, color="gray")
        .encode(x=alt.X("start:T", title=""), x2="end:T")
    )

    rule = (
        alt.Chart(pd.DataFrame({"period": [start]}))
        .mark_rule(color="gray")
        .encode(x="period:T")
    )

    label = (
        alt.Chart(pd.DataFrame({"date": [midpoint], "text": ["Forecast"]}))
        .mark_text(
            align="center",
            baseline="middle",
            dx=0,
            dy=height - 20,
            color="black",
            fontWeight="bold",
        )
        .encode(x="date:T", y=alt.value(0), text="text:N")
    )

    return band + rule + label

    
def temporal_corpus_estimation_plot(
    data, title, x_title, y_title, font_style, font_size,
    y_max=0.5,
    forecast_startdate="2022",
    show_legend=True,
    height=400,
    width=800,
):
    """Plot token percentages over time, one line colour per dataset.

    Args:
        data: long frame with ``Date``, ``Percent``, ``Dataset`` and
            ``Token Sample`` columns (e.g. from ``prepare_data_for_plot1``).
        title: accepted for call-site compatibility; not rendered by this function.
        x_title, y_title: axis titles ('' hides them).
        font_style, font_size: font family and size for axes and legends.
        y_max: upper bound of the y-axis domain.
        forecast_startdate: if truthy, shade from this date to the last ``Date``.
        show_legend: whether to draw the colour and stroke-dash legends.
        height, width: plot size in pixels.

    Returns:
        A configured Altair chart (layered when a forecast region is drawn).
    """
    # Place both legends 30 px below the plot area so they stay clear of the
    # lines for any height (height=200 gives the previous fixed legendY=230).
    legend_y = height + 30
    colorLegend = alt.Legend(orient='none', title='Dataset',
        labelFont=font_style, labelFontSize=font_size,
        titleFont=font_style, titleFontSize=font_size, direction='horizontal',
        legendX=0, legendY=legend_y) if show_legend else None
    strokeDashLegend = alt.Legend(orient='none', title='Token Sample',
        labelFont=font_style, labelFontSize=font_size,
        titleFont=font_style, titleFontSize=font_size, direction='horizontal',
        legendX=300, legendY=legend_y) if show_legend else None

    chart = alt.Chart(data).mark_line().encode(
        # Yearly tick labels; the data itself is by month.
        x=alt.X('Date:T', title=x_title, axis=alt.Axis(format='%Y', tickCount={"interval": "year", "step": 1})),
        # The "%" format renders Percent values (fractions) as percentages.
        y=alt.Y('Percent:Q', title=y_title, scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format="%", orient='right')),
        color=alt.Color('Dataset:N', legend=colorLegend),
        strokeDash=alt.StrokeDash('Token Sample:N', legend=strokeDashLegend),
    )

    # Shade the forecasted region, if requested.
    if forecast_startdate:
        chart = chart + forecast_region(data, forecast_startdate, height)

    chart = chart.properties(
        width=width,
        height=height
    ).configure_axis(
        labelFontSize=font_size,
        titleFontSize=font_size,
        labelFont=font_style,
        titleFont=font_style,
        grid=False  # Remove gridlines
    ).configure_legend(
        labelFont=font_style,
        labelFontSize=font_size,
        titleFont=font_style,
        titleFontSize=font_size
    )
    return chart
        

def temporal_corpus_estimation_by_service_plot(
    data, title, x_title, y_title, font_style, font_size,
    y_max=0.7,
    forecast_startdate=None,
    show_legend=True,
    height=400,
    width=800,
):
    """Plot token percentages over time, one line colour per domain type.

    Args:
        data: long frame with ``Date``, ``Percent``, ``Domain Type`` and
            ``Token Sample`` columns (e.g. from ``prepare_data_for_plot2``).
            It is not modified.
        title: accepted for call-site compatibility; not rendered by this function.
        x_title, y_title: axis titles ('' hides them).
        font_style, font_size: font family and size for axes and legends.
        y_max: upper bound of the y-axis domain.
        forecast_startdate: if truthy, shade from this date to the last ``Date``.
        show_legend: whether to draw the domain-type colour legend.
        height, width: plot size in pixels.

    Returns:
        A configured Altair chart (layered when a forecast region is drawn).
    """
    # Parse and sort a copy rather than mutating the caller's DataFrame.
    data = data.assign(Date=pd.to_datetime(data['Date'])).sort_values('Date')

    # Legend sits 30 px below the plot area (height=200 gives the previous legendY=230).
    legend_color = alt.Legend(orient='none', title='Domain Type',
                              labelFont=font_style, labelFontSize=font_size,
                              titleFont=font_style, titleFontSize=font_size, direction='horizontal',
                              legendX=0, legendY=height + 30) if show_legend else None
    # Token Sample is shown only through the dash pattern; its legend is always hidden.
    legend_stroke_dash = None

    chart = alt.Chart(data).mark_line().encode(
        x=alt.X('Date:T', title=x_title, axis=alt.Axis(format='%Y', tickCount={"interval": "year", "step": 1})),
        y=alt.Y('Percent:Q', title=y_title, scale=alt.Scale(domain=[0, y_max]), axis=alt.Axis(format="%", orient='right')),
        color=alt.Color('Domain Type:N', legend=legend_color),
        strokeDash=alt.StrokeDash('Token Sample:N', legend=legend_stroke_dash)
    )

    # Shade the forecasted region, if requested.
    if forecast_startdate:
        chart = chart + forecast_region(data, forecast_startdate, height)

    chart = chart.properties(
        width=width,
        height=height
    ).configure_axis(
        labelFontSize=font_size,
        titleFontSize=font_size,
        labelFont=font_style,
        titleFont=font_style,
        grid=False  # Remove gridlines
    ).configure_legend(
        labelFont=font_style,
        labelFontSize=font_size,
        titleFont=font_style,
        titleFontSize=font_size
    )

    return chart


In [None]:
# Load the robots estimates and reshape them for the two figures:
# per-dataset rows for domain type 'all', and per-domain-type averages.
robots_df = load_data('data/domain_estimates/robots/*')
robots_df_full_plot = prepare_data_for_plot1(robots_df)
robots_df_service_plot = prepare_data_for_plot2(robots_df)

# Compact figures with no axis titles and no legends.
robots_corpus_plot = temporal_corpus_estimation_plot(
    robots_df_full_plot,
    'Robots: Head Tokens vs Combined Tokens for each Dataset',
    x_title='',
    y_title='',
    font_style='Times',
    font_size=13.5,
    y_max=0.5,
    forecast_startdate=None,
    show_legend=False,
    height=200,
    width=450,
)
robots_services_plot = temporal_corpus_estimation_by_service_plot(
    robots_df_service_plot,
    'Robots: Average Percent of Tokens by Subset',
    x_title='',
    y_title='',
    font_style='Times',
    font_size=13.5,
    y_max=0.5,
    forecast_startdate=None,
    show_legend=False,
    height=200,
    width=450,
)

robots_corpus_plot.display()
robots_services_plot.display()

# Notes: Head, Combined
#        Robots vs ToS

In [None]:
# Load the Terms-of-Service estimates
tos_df = load_data('data/domain_estimates/tos/*')

# Prepare data for plots
tos_df_full_plot = prepare_data_for_plot1(tos_df)
tos_df_service_plot = prepare_data_for_plot2(tos_df)

# Create plots (titles previously copy-pasted as "Robots: ..."; they now name the ToS data)
tos_corpus_plot = temporal_corpus_estimation_plot(
    tos_df_full_plot, 'ToS: Head Tokens vs Combined Tokens for each Dataset', '', '', 'Times', 13.5,
    y_max=0.7,
    forecast_startdate=None,
    show_legend=True,
    height=200, width=450,
)
tos_services_plot = temporal_corpus_estimation_by_service_plot(
    tos_df_service_plot, 'ToS: Average Percent of Tokens by Subset', '', '', 'Times', 13.5,
    y_max=0.7,
    forecast_startdate=None,
    show_legend=True,
    height=200, width=450,
)
# Display the plots
tos_corpus_plot.display()
tos_services_plot.display()

In [None]:
# TODO(investigate): 'head' == 'rand' for every single dataset — check whether this is expected.
# TODO(investigate): the E-commerce mean looks too high — double-check.

In [None]:
# alt.hconcat(robots_corpus_plot, robots_services_plot).configure_axis(
#     grid=False
# ).configure_view(
#     strokeWidth=0 # Remove the frame around the chart
# ).resolve_legend(
#     color='independent'
# )


In [None]:
# Robots per-dataset rows dated 2024-04-01
robots_df_full_plot.loc[robots_df_full_plot["Date"] == "2024-04-01"]

In [None]:
# Robots per-dataset rows dated 2023-01-01
robots_df_full_plot.loc[robots_df_full_plot["Date"] == "2023-01-01"]

In [None]:
# Robots per-domain-type averages dated 2024-04-01
robots_df_service_plot.loc[robots_df_service_plot["Date"] == "2024-04-01"]


In [None]:
# Robots per-domain-type averages dated 2023-04-01
robots_df_service_plot.loc[robots_df_service_plot["Date"] == "2023-04-01"]


In [None]:
# ToS per-dataset rows dated 2024-04-01
tos_df_full_plot.loc[tos_df_full_plot["Date"] == "2024-04-01"]

In [None]:
# ToS per-dataset rows dated 2023-04-01
tos_df_full_plot.loc[tos_df_full_plot["Date"] == "2023-04-01"]

In [None]:
# Ratios of hard-coded Full Corpus values (presumably read off the tables above,
# later date over earlier date) — TODO: derive these from tos_df_full_plot instead.
print("\n".join(
    f"{name} Full Corpus: {later / earlier}"
    for name, later, earlier in (("C4", 43, 28), ("Dolma", 52, 41), ("RW", 53, 42))
))

In [None]:
# ToS per-domain-type averages dated 2024-04-01
tos_df_service_plot.loc[tos_df_service_plot["Date"] == "2024-04-01"]


In [None]:
# ToS per-domain-type averages dated 2023-01-01
tos_df_service_plot.loc[tos_df_service_plot["Date"] == "2023-01-01"]