In [None]:
# Analyze the basic statistics of the reply network
import pandas as pd
import numpy as np
import json
import networkx as nx
import networkx as nx
from datetime import datetime
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm.auto import tqdm


def generate_network(nodes):
    """
    Build a directed reply graph from a flat list of post records.

    Every record contributes one node keyed by ``int(record["id"])``. When the
    record's ``in_reply_to_id`` is truthy, the parent node and an edge
    parent -> post are added as well. No per-post metadata is attached.

    :param nodes: Iterable of dict-like records providing the keys
                  ``"id"`` and ``"in_reply_to_id"``.
    :return: A NetworkX DiGraph whose edges point from a parent post to its reply.
    """
    reply_graph = nx.DiGraph()

    for record in tqdm(nodes):
        post_id = int(record["id"])
        reply_graph.add_node(post_id)

        parent_raw = record["in_reply_to_id"]
        if not parent_raw:
            continue  # not a reply: nothing to link

        parent_id = int(parent_raw)
        # The parent may not be part of `nodes`, so register it explicitly.
        reply_graph.add_node(parent_id)
        reply_graph.add_edge(parent_id, post_id)

    return reply_graph


def build_network(graph, thread, parent_id=None):
    """
    Recursively add a thread and all of its nested replies to ``graph``.

    The post's ``uri`` is used as the node key; the post text and the author's
    display name are stored as the ``text`` and ``author`` node attributes.
    When ``parent_id`` is truthy, an edge parent -> post is added.

    :param graph: Graph object exposing ``add_node`` and ``add_edge``.
    :param thread: Dict with a ``"post"`` entry and an optional ``"replies"``
                   list whose items have the same shape.
    :param parent_id: Node key of the post being replied to (None for the root).
    """
    current = thread["post"]
    node_key = current["uri"]  # the URI serves as the unique identifier
    attributes = {
        "text": current["record"]["text"],
        "author": current["author"]["displayName"],
    }
    graph.add_node(node_key, **attributes)

    if parent_id:
        graph.add_edge(parent_id, node_key)

    # Depth-first descent into the replies, linking each one to this post.
    children = thread.get("replies", [])
    for child in children:
        build_network(graph, child, node_key)


# Load the raw thread dumps for both platforms.
with open("../data/bsky_threads.json") as f:
    bsky = json.load(f)
with open("../data/ts_threads.json") as f:
    ts = json.load(f)
bsky_network = nx.DiGraph()
error_count = 0  # Bluesky threads that raised while being added to the graph
# Build the graph
for thread in tqdm(bsky):
    try:
        build_network(bsky_network, thread["thread"])
    except Exception:
        # Best-effort ingestion: a malformed thread must not abort the whole run.
        # `Exception` (instead of a bare `except:`) lets KeyboardInterrupt and
        # SystemExit propagate, so the cell can still be interrupted.
        # Nodes added before the failure stay in the graph.
        error_count += 1

if error_count:
    print(f"{error_count} Bluesky thread(s) could not be fully parsed")

# Generate the network graph
ts_network = generate_network(ts)

In [8]:
from statistics import mean


def calculate_cascade_statistics(graph):
    """
    Calculate cascade statistics for each tree (weakly connected component)
    in a directed graph.

    Parameters:
        graph (nx.DiGraph): A directed graph representing the network.

    Returns:
        list: One dictionary per component with the keys ``root_id``, ``size``,
        ``depth``, ``max_breadth`` and ``structural_virality``.

    Raises:
        ValueError: If ``graph`` is not a DiGraph, or if a component does not
        have exactly one node with in-degree 0 (several roots, or no root at
        all because the component contains a cycle).
    """
    if not isinstance(graph, nx.DiGraph):
        raise ValueError("The input graph must be a directed graph (DiGraph).")

    cascade_stats = []
    for component in nx.weakly_connected_components(graph):
        # Extract the subgraph for this component
        tree = graph.subgraph(component)
        size = len(tree.nodes)

        roots = [node for node in tree.nodes if tree.in_degree(node) == 0]
        if len(roots) > 1:
            raise ValueError("Multiple roots found in the tree.")
        if not roots:
            # Every node has a predecessor, so the component must contain a cycle.
            raise ValueError("No root found in the component (it contains a cycle).")
        root = roots[0]

        # One BFS from the root yields the level of every reachable node; depth and
        # breadth are both derived from it instead of one path query per node.
        levels = nx.single_source_shortest_path_length(tree, root)
        depth = max(levels.values())

        breadth_levels = defaultdict(int)
        for level in levels.values():
            breadth_levels[level] += 1
        max_breadth = max(breadth_levels.values())

        # Structural virality
        # NOTE(review): this is the mean over source nodes of the mean *directed*
        # shortest-path length (self-distance 0 included), not the average distance
        # over all node pairs -- confirm this is the intended definition.
        if size > 1:
            shortest_paths = nx.shortest_path_length(tree)
            virality = mean(
                [mean(lengths.values()) for _, lengths in shortest_paths if lengths]
            )
        else:
            virality = 0  # Single node has no structural virality

        cascade_stats.append(
            {
                "root_id": root,
                "size": size,
                "depth": depth,
                "max_breadth": max_breadth,
                "structural_virality": virality,
            }
        )

    return cascade_stats


In [9]:
# Per-cascade statistics for both reply networks: one dict per weakly
# connected component.
bsky_stats = calculate_cascade_statistics(bsky_network)
ts_stats = calculate_cascade_statistics(ts_network)

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_ccdfs(cascade_stats, stat_names):
    """
    Plot one CCDF panel per statistic, side by side, sharing a logarithmic y-axis.

    Parameters:
        cascade_stats (list): List of dictionaries containing cascade statistics.
        stat_names (list): List of keys for the statistics to plot (e.g., ['size', 'depth']).
    """
    n_stats = len(stat_names)
    # squeeze=False always returns a 2-D array of axes, so a single statistic
    # needs no special-casing.
    fig, axes = plt.subplots(
        1, n_stats, figsize=(5 * n_stats, 5), sharey=True, squeeze=False
    )

    y_ticks = [1, 0.1, 0.01, 0.001, 0.0001]
    y_tick_labels = ["100%", "10%", "1%", "0.1%", "0.01%"]

    for ax, stat_name in zip(axes[0], stat_names):
        # Sorted values of this statistic across all cascades
        values = np.sort([stat[stat_name] for stat in cascade_stats])

        # CCDF as 1 - rank / n over the sorted values
        ccdf = 1 - np.arange(1, len(values) + 1) / len(values)

        ax.step(values, ccdf, where="post")
        ax.set_title(f"CCDF of {stat_name.capitalize()}", fontsize=14)
        ax.set_xlabel(stat_name.capitalize(), fontsize=12)
        ax.grid(True, linestyle="--", alpha=0.7)

        # Logarithmic y-axis with percentage tick labels. No formatter is set
        # afterwards: installing one would replace these labels.
        ax.set_yscale("log")
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_tick_labels)

    # Set shared y-axis label
    fig.supylabel("CCDF (Log Scale)", fontsize=12)

    plt.tight_layout()
    plt.show()


# CCDF panels of the four cascade statistics for the Bluesky reply network.
plot_ccdfs(bsky_stats, ["size", "depth", "max_breadth", "structural_virality"])

In [None]:
# Same four CCDF panels for the Truth Social reply network.
plot_ccdfs(ts_stats, ["size", "depth", "max_breadth", "structural_virality"])

In [34]:
# topics
# Load the topic tables of both platforms; they are joined to the cascade
# statistics on their "id" column below.
bsky_topics = pd.read_csv("../data/bsky_df_id_topic.csv")
ts_topics = pd.read_csv("../data/ts_df_id_topic.csv")

In [None]:
# Display the Bluesky topic table for inspection.
bsky_topics

In [41]:
# Left-join each cascade table with its platform's topic table on the root post
# id, then drop the join key "id" taken from the topic table.
bsky_stats = pd.merge(
    pd.DataFrame(bsky_stats),
    bsky_topics,
    how="left",
    left_on="root_id",
    right_on="id",
).drop(columns="id")
ts_stats = pd.merge(
    pd.DataFrame(ts_stats),
    ts_topics,
    how="left",
    left_on="root_id",
    right_on="id",
).drop(columns="id")

In [None]:
# Display the Truth Social cascade statistics after the topic join.
ts_stats

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np


# Example function to calculate cascade size (can be extended for other stats)
def calculate_cascade_stats(df, group_col, stat_col):
    """
    Count the rows of each group and keep the group's first topic label.

    Parameters:
        df (pd.DataFrame): DataFrame with data; must contain a 'topic_label' column.
        group_col (str): Column to group by (e.g., 'root_id').
        stat_col (str): Column whose non-null entries are counted (e.g., 'id').

    Returns:
        pd.DataFrame: One row per group with the group key, the count renamed to
        'size', and 'topic_label' (a separate column unless it is the group key).
    """
    aggregations = {stat_col: "count"}  # Count posts in each cascade
    if group_col != "topic_label":
        # When grouping by the topic itself, the key already carries the label;
        # aggregating it as well would collide with the key on reset_index().
        aggregations["topic_label"] = "first"

    grouped = df.groupby(group_col).agg(aggregations).reset_index()
    return grouped.rename(columns={stat_col: "size"})


# Plot CCDFs grouped by topic
def plot_grouped_ccdfs(cascade_stats, stat_col, group_col):
    """
    Plot one CCDF curve per group for a given statistic on log-log axes.

    Rows whose group value is missing (NaN) are left out: they can never match
    an equality filter and would only add an empty curve to the legend.

    Parameters:
        cascade_stats (pd.DataFrame): DataFrame with cascade statistics.
        stat_col (str): The statistic column to compute the CCDF for (e.g., 'size').
        group_col (str): The column representing the group (e.g., 'topic_label').
    """
    unique_topics = cascade_stats[group_col].dropna().unique()
    fig, ax = plt.subplots(figsize=(10, 6))

    for topic in unique_topics:
        # Sorted values of the statistic within this group
        in_group = cascade_stats[group_col] == topic
        values = np.sort(cascade_stats.loc[in_group, stat_col].values)

        # CCDF as 1 - rank / n over the sorted values
        ccdf = 1 - np.arange(1, len(values) + 1) / len(values)

        ax.step(values, ccdf, where="post", label=topic)

    # Label the x-axis after the plotted statistic ("size" -> "Cascade Size").
    stat_label = stat_col.replace("_", " ").title()

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel(f"Cascade {stat_label} (Log Scale)", fontsize=12)
    ax.set_ylabel("CCDF (Log Scale)", fontsize=12)
    ax.set_title("CCDF by Topic", fontsize=14)
    ax.grid(True, linestyle="--", alpha=0.7)
    ax.legend(title="Topic Label", fontsize=10)
    fig.tight_layout()
    plt.show()


# Example Usage
if __name__ == "__main__":
    # bsky_stats already holds one row per cascade, with its "size" and
    # "topic_label". Counting rows per topic first would leave a single value per
    # topic, so each topic's CCDF would collapse to one point. Plot it directly.
    data = bsky_stats
    cascade_stats = data

    # Plot CCDFs grouped by topic
    plot_grouped_ccdfs(cascade_stats, stat_col="size", group_col="topic_label")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_stats_by_topic(cascade_stats, stats_to_plot, topic_col):
    """
    Plot, for every requested statistic, one figure with a CCDF curve per topic.

    Rows whose topic is missing (NaN) are left out: they can never match an
    equality filter and would only add an empty curve to the legend.

    Parameters:
        cascade_stats (pd.DataFrame): DataFrame with cascade statistics and topics.
        stats_to_plot (list): List of statistic columns to plot (e.g., ['size', 'depth']).
        topic_col (str): Column representing topics (e.g., 'topic_label').
    """
    unique_topics = cascade_stats[topic_col].dropna().unique()
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_topics)))

    for stat in stats_to_plot:
        # One figure per statistic. No figure is opened outside this loop, so no
        # blank canvas is shown.
        fig, ax = plt.subplots(figsize=(8, 6))
        for topic, color in zip(unique_topics, colors):
            in_topic = cascade_stats[topic_col] == topic
            values = np.sort(cascade_stats.loc[in_topic, stat].values)

            # CCDF as 1 - rank / n over the sorted values
            ccdf = 1 - np.arange(1, len(values) + 1) / len(values)

            ax.step(values, ccdf, where="post", label=f"{topic}", color=color)

        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel(f"{stat.capitalize()} (Log Scale)", fontsize=12)
        ax.set_ylabel("CCDF (Log Scale)", fontsize=12)
        ax.set_title(f"CCDF of {stat.capitalize()} by Topic", fontsize=14)
        ax.grid(True, linestyle="--", alpha=0.7)
        ax.legend(title="Topic", fontsize=10, bbox_to_anchor=(1.05, 1))
        fig.tight_layout()
        plt.show()


In [None]:
# NOTE(review): `topic_stats` is not assigned anywhere in the visible notebook;
# on a fresh kernel this cell raises NameError -- confirm where it comes from.
topic_stats

# Start building the network for the overall analysis

In [54]:
# import repost and following data
import json
import pandas as pd
import numpy as np
from datetime import datetime
from cascade_analysis import InformationCascadeGraph


In [55]:
# Load the Bluesky repost dump.
with open("../data/bsky_reposts.json") as f:
    bsky_repost = json.load(f)

# Load the Bluesky follow dump (reshaped into a key -> DID-set mapping below).
with open("../data/bsky_follows.json") as f:
    bsky_follow = json.load(f)

In [16]:
# NOTE(review): despite the name, this reads the plain thread dump; `ts_repost`
# is reassigned from ts_threads_withReblogs.json in the Truth Social section.
with open("../data/ts_threads.json") as f:
    ts_repost = json.load(f)

In [67]:
from importlib import reload
import cascade_analysis

# Re-import cascade_analysis so that edits to the module are picked up without
# restarting the kernel.
reload(cascade_analysis)


<module 'cascade_analysis' from '/home/maolee/projects/information-diffusion/src/cascade_analysis.py'>

In [68]:
from collections import defaultdict
from itertools import chain

original_list = bsky_follow

# key -> set of "did" values collected from every mapping in the list
merged = defaultdict(set)

# Flatten the (key, records) pairs of all mappings into one stream.
all_pairs = chain.from_iterable(entry.items() for entry in original_list)
for key, records in all_pairs:
    # `records` is a list of dicts; only their "did" values are kept.
    merged[key] |= {r["did"] for r in records}

# Plain dict copy of the merged mapping
merged_dict = dict(merged)


In [69]:
# Build the Bluesky cascade helper from the repost records and the merged
# follow mapping (key -> set of DIDs).
cascade_graph = cascade_analysis.InformationCascadeGraph(
    bsky_repost, merged_dict, platform="bsky"
)

In [70]:
# Build the repost graph (implemented in cascade_analysis, not shown here).
reposts_graph = cascade_graph.build_repost_graph()

Building Repost Graph:   0%|          | 0/195616 [00:00<?, ?it/s]

In [71]:
# Sanity check: number of nodes in the repost graph.
reposts_graph.number_of_nodes()

379582

In [72]:
# Build the reply graph (implemented in cascade_analysis, not shown here).
reply_graph = cascade_graph.build_reply_graph()


Building Reply Graph:   0%|          | 0/195616 [00:00<?, ?it/s]

In [64]:
# Sanity check: number of edges in the reply graph.
reply_graph.number_of_edges()

116224

In [10]:
# Merge reply and repost edges into one graph (implemented in cascade_analysis).
combined_graph = cascade_graph.build_combined_graph()

Merging:   0%|          | 0/116224 [00:00<?, ?it/s]

The node is not in repost graph
The node is not in repost graph
The node is not in repost graph
The node is not in repost graph
The node is not in repost graph
Step 2.1: Merged 23138 reply edges into repost edges out of 59813 total reply edges


Merging:   0%|          | 0/183966 [00:00<?, ?it/s]

In [73]:
# NOTE(review): same call as the earlier cell that builds `combined_graph`
# (re-run after the module reload); one of the two cells is redundant.
combined_graph = cascade_graph.build_combined_graph()

Merging:   0%|          | 0/116224 [00:00<?, ?it/s]

Step 2.1: Merged 23138 reply edges into repost edges out of 59813 total reply edges
Step 2.1: 5 nodes not in repost graph


Merging:   0%|          | 0/183966 [00:00<?, ?it/s]

In [11]:
# Sanity check: number of nodes in the combined graph.
combined_graph.number_of_nodes()

379587

In [74]:
# Compute the statistics of all graphs; the result is indexed below by
# "combined_graph", "repost_graph" and "reply_graph".
stats = cascade_graph.calculate_statistics()

Calculating Tree Statistics:   0%|          | 0/79397 [00:00<?, ?it/s]

Calculating Tree Statistics:   0%|          | 0/195616 [00:00<?, ?it/s]

Calculating Tree Statistics:   0%|          | 0/79397 [00:00<?, ?it/s]

In [82]:
import pandas as pd
# build the dataframe, column and row switch

# Transpose so that the outer keys of each stats mapping become the row index
# (presumably one row per cascade root -- confirm against calculate_statistics).
combined_stats_df = pd.DataFrame(stats["combined_graph"]).T
repost_stats_df = pd.DataFrame(stats["repost_graph"]).T
reply_stats_df = pd.DataFrame(stats["reply_graph"]).T

In [76]:
# import the topic data
# (joined to the statistics frames on its "id" column below)
bsky_topics = pd.read_csv("../data/bsky_df_id_topic.csv")


In [83]:
# find root id for each repost id
import networkx as nx


def find_root(G, child):
    """
    Walk up from ``child`` through first predecessors until a node without
    predecessors is reached, and return that node.

    The walk is iterative, so chains deeper than Python's recursion limit are
    handled. Only the first predecessor of each node is followed.

    Parameters:
        G: Graph exposing ``predecessors(node)``.
        child: Node to start from.

    Returns:
        The first node on the upward walk that has no predecessors
        (``child`` itself when it has none).

    Raises:
        ValueError: If the upward walk revisits a node (cycle).
    """
    current = child
    visited = {current}
    while True:
        parents = list(G.predecessors(current))
        if not parents:
            return current
        current = parents[0]
        if current in visited:
            raise ValueError("Cycle detected while walking up to the root.")
        visited.add(current)


# Resolve the cascade root of every repost by walking up combined_graph.
# find_root already returns the node itself when it has no predecessors, so no
# separate in-degree check is needed. Assigning the whole column at once avoids
# growing the frame cell by cell with `.loc`, which is slow and fills the new
# column with NaN first (numeric ids can be coerced to float that way).
repost_stats_df["root_id"] = [
    find_root(combined_graph, repost_id) for repost_id in repost_stats_df.index
]

In [77]:
# Move the repost ids from the index into an "index" column so they survive
# the merge and the CSV export (which is written with index=False).
# The original cell referenced `repost_stats`, which is not defined at this
# point of a fresh run.
repost_stats_df = repost_stats_df.reset_index()

In [85]:
# Attach the topic row of each repost's cascade root (left join keeps reposts
# whose root has no topic entry); the join key "id" is dropped afterwards.
repost_original = pd.merge(
    repost_stats_df,
    bsky_topics,
    how="left",
    left_on="root_id",
    right_on="id",
).drop(columns="id")

In [87]:
# Persist the Bluesky repost statistics (the row index is not written).
repost_original.to_csv("../data/bsky_repost_stat.csv", index=False)

In [88]:
# Move the row labels into an "index" column, then left-join the topic table
# on it. NOTE: reset_index mutates reply_stats_df in place, so re-running this
# cell adds another index column.
reply_stats_df.reset_index(inplace=True)
reply_original = reply_stats_df.merge(
    bsky_topics, left_on="index", right_on="id", how="left"
).drop(columns="id")


In [90]:
# Persist the Bluesky reply statistics (the row index is not written).
reply_original.to_csv("../data/bsky_reply_stats.csv", index=False)

In [91]:
# Same treatment for the combined graph: expose the row labels as "index",
# left-join the topic table on it, and write the result to CSV.
# NOTE: reset_index mutates combined_stats_df in place.
combined_stats_df.reset_index(inplace=True)
combined_original = combined_stats_df.merge(
    bsky_topics, left_on="index", right_on="id", how="left"
).drop(columns="id")

combined_original.to_csv("../data/bsky_combined_stats.csv", index=False)

# Truth Social

In [2]:
# Load data
import json

# Truth Social threads including reblogs.
with open("../data/ts_threads_withReblogs.json") as f:
    ts_repost = json.load(f)

# Truth Social user following map.
with open("../data/ts_user_following_map.json") as f:
    ts_follow = json.load(f)

In [3]:
import importlib
import cascade_analysis

# Pick up the latest version of the module before building the helper.
importlib.reload(cascade_analysis)

# NOTE(review): no `platform` argument here, unlike the Bluesky call
# (platform="bsky") -- presumably the default targets Truth Social; confirm.
# This rebinds `cascade_graph`, replacing the Bluesky instance.
cascade_graph = cascade_analysis.InformationCascadeGraph(ts_repost, ts_follow)


In [4]:
# Build the Truth Social reply graph (rebinds `reply_graph`).
reply_graph = cascade_graph.build_reply_graph()

Building Reply Graph:   0%|          | 0/1369696 [00:00<?, ?it/s]

In [5]:
# Build the Truth Social repost graph and report its node count.
repost_graph = cascade_graph.build_repost_graph()
print(repost_graph.number_of_nodes())

Building Repost Graph:   0%|          | 0/1369696 [00:00<?, ?it/s]

5217599


In [6]:
# Merge reply and repost edges for Truth Social (rebinds `combined_graph`).
combined_graph = cascade_graph.build_combined_graph()

Merging:   0%|          | 0/1325878 [00:00<?, ?it/s]

Step 2.1: Merged 191575 reply edges into repost edges out of 747571 total reply edges
Step 2.1: 0 nodes not in repost graph


Merging:   0%|          | 0/3847903 [00:00<?, ?it/s]

In [16]:
%env NX_CUGRAPH_AUTOCONFIG=True
import networkx as nx
from tqdm.auto import tqdm
from collections import defaultdict


def calculate_tree_statistics(graph):
    """
    Compute per-root cascade statistics for a directed acyclic graph.

    Every node with in-degree 0 is treated as a root; its tree consists of the
    root plus all nodes reachable from it.

    Parameters:
        graph (nx.DiGraph): Directed acyclic graph.

    Returns:
        dict: root -> dict with the keys ``max_depth``, ``size``, ``breadth``
        (largest number of nodes at one distance from the root),
        ``structural_virality`` (sum of directed shortest-path lengths divided by
        the number of reachable ordered pairs, 0 when there are none) and
        ``reach`` (node count of the tree).

    Raises:
        ValueError: If ``graph`` is not a directed acyclic graph.
    """
    if not nx.is_directed_acyclic_graph(graph):
        raise ValueError("Graph must be a directed acyclic graph (DAG).")

    tree_statistics = {}
    roots = [node for node, in_deg in graph.in_degree() if in_deg == 0]

    for root in tqdm(roots, desc="Calculating Tree Statistics"):
        members = nx.descendants(graph, root) | {root}
        tree = graph.subgraph(members)

        # Distance of every tree node from the root
        depth_of = nx.single_source_shortest_path_length(tree, root)

        # Number of nodes found at each distance from the root
        per_level = defaultdict(int)
        for level in depth_of.values():
            per_level[level] += 1

        # Accumulate directed distances from every node to the nodes it reaches.
        distance_sum = 0
        ordered_pairs = 0
        for source in tree.nodes:
            reachable = nx.single_source_shortest_path_length(tree, source)
            distance_sum += sum(reachable.values())
            ordered_pairs += len(reachable) - 1  # exclude source -> source

        if ordered_pairs > 0:
            virality = distance_sum / ordered_pairs
        else:
            virality = 0

        tree_statistics[root] = {
            "max_depth": max(depth_of.values()),
            "size": tree.number_of_nodes(),
            "breadth": max(per_level.values()),
            "structural_virality": virality,
            "reach": len(tree.nodes),
        }

    return tree_statistics


from collections import defaultdict

import networkx as nx
from tqdm.auto import tqdm
import cudf
import cugraph
import networkx as nx
import pandas as pd
from tqdm.auto import tqdm


def calculate_tree_statistics_cugraph(nx_graph):
    """
    Given a NetworkX DiGraph (assumed to be a tree or forest),
    convert it to a cuGraph directed graph and compute per-root tree statistics.

    Statistics computed for each tree (root):
    - max_depth: maximum distance from the root to any node
    - size: number of nodes in the tree
    - breadth: maximum number of nodes at any distance from the root
    - structural_virality: sum of directed BFS distances from every tree node,
      divided by the number of reachable ordered pairs (0 when there are none)
    - reach: same as size

    Roots are the nodes with in-degree 0 in ``nx_graph``. Nodes without any edge
    are roots too and get single-node statistics; they never enter the cuGraph
    graph, which is built from the edge list only.

    Raises:
        ValueError: If the edge list lacks 'source'/'target' columns.
    """
    # --- Step 1: Convert the NetworkX graph to a cuGraph graph ---
    df_edges = nx.to_pandas_edgelist(nx_graph)
    if "source" not in df_edges.columns or "target" not in df_edges.columns:
        raise ValueError("The edge list must have 'source' and 'target' columns.")

    # Only the endpoints are needed on the GPU; edge attributes are left out.
    df_edges = df_edges[["source", "target"]]
    edge_vertices = set(df_edges["source"]) | set(df_edges["target"])
    n_vertices = len(edge_vertices)

    G_cu = None
    if n_vertices:
        cudf_edges = cudf.DataFrame.from_pandas(df_edges)
        G_cu = cugraph.Graph(directed=True)
        G_cu.from_cudf_edgelist(cudf_edges, source="source", destination="target")

    def _reachable_rows(bfs_pdf):
        # A real BFS distance lies in [0, n_vertices). Rows outside that range
        # are treated as unreachable.
        # NOTE(review): assumes cugraph.bfs reports unreachable vertices with an
        # out-of-range sentinel distance -- confirm for the installed version.
        distance = bfs_pdf["distance"]
        return bfs_pdf[(distance >= 0) & (distance < n_vertices)]

    # --- Step 2: Identify Root Nodes ---
    # Taken from the NetworkX graph itself, so isolated nodes are not missed.
    # (Looking the roots up in the grouped in-degree frame by "target" fails:
    # after the rename that column no longer exists.)
    roots = [node for node, in_deg in nx_graph.in_degree() if in_deg == 0]

    tree_statistics = {}
    # --- Step 3: For Each Root, Run BFS and Compute Statistics ---
    for root in tqdm(roots, desc="Calculating Tree Statistics (cuGraph)"):
        if root not in edge_vertices:
            # Isolated node: a cascade consisting of the root only.
            tree_statistics[root] = {
                "max_depth": 0,
                "size": 1,
                "breadth": 1,
                "structural_virality": 0,
                "reach": 1,
            }
            continue

        # BFS result columns: 'vertex', 'distance', 'predecessor'
        bfs_pdf = _reachable_rows(cugraph.bfs(G_cu, root).to_pandas())

        max_depth = int(bfs_pdf["distance"].max())
        size = len(bfs_pdf)
        max_breadth = int(bfs_pdf.groupby("distance").size().max())

        # One BFS per tree node; expensive for large trees.
        total_distance = 0
        pair_count = 0
        for v in bfs_pdf["vertex"]:
            distances = _reachable_rows(cugraph.bfs(G_cu, v).to_pandas())["distance"]
            total_distance += distances.sum()
            # Subtract one so that the distance from v to itself is not counted
            pair_count += len(distances) - 1
        structural_virality = (
            float(total_distance / pair_count) if pair_count > 0 else 0
        )

        tree_statistics[root] = {
            "max_depth": max_depth,
            "size": size,
            "breadth": max_breadth,
            "structural_virality": structural_virality,
            "reach": size,
        }

    return tree_statistics


# Run the NetworkX implementation defined above on the three Truth Social graphs.
reply_stats = calculate_tree_statistics(reply_graph)
repost_stats = calculate_tree_statistics(repost_graph)
combined_stats = calculate_tree_statistics(combined_graph)

env: NX_CUGRAPH_AUTOCONFIG=True


Calculating Tree Statistics:   0%|          | 0/43818 [00:00<?, ?it/s]

Calculating Tree Statistics:   0%|          | 0/1369696 [00:00<?, ?it/s]

Calculating Tree Statistics:   0%|          | 0/43818 [00:00<?, ?it/s]

In [9]:
# NOTE(review): the recorded output of this cell is
# "AttributeError: module 'cugraph' has no attribute 'DiGraph'", i.e. the
# module-level cuGraph implementation failed when this was last run.
reply_stats = (
    cascade_analysis.InformationCascadeGraph.calculate_tree_statistics_cugraph(
        reply_graph
    )
)
repost_stats = (
    cascade_analysis.InformationCascadeGraph.calculate_tree_statistics_cugraph(
        repost_graph
    )
)
combined_stats = (
    cascade_analysis.InformationCascadeGraph.calculate_tree_statistics_cugraph(
        combined_graph
    )
)

AttributeError: module 'cugraph' has no attribute 'DiGraph'

In [None]:
# Same three statistics via the InformationCascadeGraph method (implemented in
# cascade_analysis, not shown here).
reply_stats = cascade_graph.calculate_tree_statistics(reply_graph)
repost_stats = cascade_graph.calculate_tree_statistics(repost_graph)
combined_stats = cascade_graph.calculate_tree_statistics(combined_graph)

In [17]:
import pandas as pd

# Transpose so that the outer keys of each stats mapping (the roots) become
# the row index and the statistic names become columns.
reply_stats_df = pd.DataFrame(reply_stats).T
repost_stats_df = pd.DataFrame(repost_stats).T
combined_stats_df = pd.DataFrame(combined_stats).T

In [18]:
# find root id for each repost id
import networkx as nx

# NOTE(review): not referenced by any code visible in this notebook.
repost_root = []


def find_root(G, child):
    """
    Walk up from ``child`` through first predecessors until a node without
    predecessors is reached, and return that node.

    The walk is iterative, so chains deeper than Python's recursion limit are
    handled. Only the first predecessor of each node is followed.

    Parameters:
        G: Graph exposing ``predecessors(node)``.
        child: Node to start from.

    Returns:
        The first node on the upward walk that has no predecessors
        (``child`` itself when it has none).

    Raises:
        ValueError: If the upward walk revisits a node (cycle).
    """
    current = child
    visited = {current}
    while True:
        parents = list(G.predecessors(current))
        if not parents:
            return current
        current = parents[0]
        if current in visited:
            raise ValueError("Cycle detected while walking up to the root.")
        visited.add(current)


# Resolve the cascade root of every repost by walking up combined_graph.
# find_root already returns the node itself when it has no predecessors, so no
# separate in-degree check is needed. Assigning the whole column at once avoids
# growing the frame cell by cell with `.loc`, which is slow and fills the new
# column with NaN first (numeric ids can be coerced to float that way and lose
# precision when they are large).
repost_stats_df["root_id"] = [
    find_root(combined_graph, repost_id) for repost_id in repost_stats_df.index
]


In [19]:
# Truth Social topic table; joined to the statistics frames on "id" below.
ts_topics = pd.read_csv("../data/ts_df_id_topic.csv")


In [20]:
# Expose the row labels as an "index" column and cast both join keys to int so
# they compare equal, then left-join the topic table.
# NOTE: reset_index mutates reply_stats_df in place, and the frame is rebound
# to the merged result, so this cell is not safe to re-run.
reply_stats_df.reset_index(inplace=True)
reply_stats_df["index"] = reply_stats_df["index"].astype(int)
ts_topics["id"] = ts_topics["id"].astype(int)
reply_stats_df = reply_stats_df.merge(
    ts_topics, left_on="index", right_on="id", how="left"
).drop(columns="id")

In [21]:
# Persist the Truth Social reply statistics (the row index is not written).
reply_stats_df.to_csv("../data/ts_reply_stats.csv", index=False)

In [22]:
# Expose the repost ids as an "index" column and cast the ids and join keys to
# int. NOTE: reset_index mutates repost_stats_df in place.
repost_stats_df.reset_index(inplace=True)
repost_stats_df["root_id"] = repost_stats_df["root_id"].astype(int)
repost_stats_df["index"] = repost_stats_df["index"].astype(int)
ts_topics["id"] = ts_topics["id"].astype(int)

# Topics are attached via the cascade root, not via the repost's own id.
repost_stats_df_test = repost_stats_df.merge(
    ts_topics, left_on="root_id", right_on="id", how="left"
).drop(columns="id")

In [23]:
# Persist the Truth Social repost statistics (the row index is not written).
repost_stats_df_test.to_csv("../data/ts_repost_stats.csv", index=False)

In [24]:
# Expose the row labels as an "index" column, cast both join keys to int, and
# left-join the topic table.
# NOTE: reset_index mutates combined_stats_df in place.
combined_stats_df.reset_index(inplace=True)
combined_stats_df["index"] = combined_stats_df["index"].astype(int)
ts_topics["id"] = ts_topics["id"].astype(int)
combined_stats_df_output = combined_stats_df.merge(
    ts_topics, left_on="index", right_on="id", how="left"
).drop(columns="id")

In [25]:
# Persist the Truth Social combined-graph statistics (the row index is not written).
combined_stats_df_output.to_csv("../data/ts_combined_stats.csv", index=False)