In [3]:
import os
import sys

# Expose the directory one level above the current working directory to the
# import system, so that project packages can be imported from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(project_root) == 0:
    # Front of the list: takes precedence over same-named installed packages.
    sys.path.insert(0, project_root)

In [None]:
import ast
import numpy as np
import networkx as nx
from tqdm import tqdm

from utils.gcloud_utilities import *
from utils.metadata import *
from utils.preprocessing_utilities import (
    import_operating_nodes,
    expand_parameters_col_and_format,
)

In [None]:
year = "2023"

# Fetch the inputs: the bucket handle and node table returned by
# `import_operating_nodes`, the end-use node table, and the edge list.
bucket, nodes = import_operating_nodes(year)

end_use_csv_path = (
    GCLOUD_PREPROCESSED_DIR + BENCHMARK_PREPROCESSED_DIR + "endUse_nodes.csv"
)
endUse_nodes = pull_from_gcs_csv(bucket, end_use_csv_path)

edges_csv_path = (
    GCLOUD_PREPROCESSED_DIR
    + BENCHMARK_PREPROCESSED_DIR
    + BENCHMARK_EDGES_DIR
    + BENCHMARK_EDGES_FILE
)
edges = pull_from_gcs_csv(bucket, edges_csv_path)

# Parse every `properties` value as a Python literal (presumably a dict, since
# json_normalize then expands it into one column per key), and replace the raw
# column with the expanded columns. The join is index-based.
edges["properties"] = edges["properties"].astype(str).map(ast.literal_eval)
dict_df = pd.json_normalize(edges["properties"])
edges = edges.drop(columns="properties").join(dict_df)

# Supply-chain stages, listed in upstream -> downstream order, with the node
# `type` values that belong to each stage.
stages_dict = {
    "mining": ["Brine", "Spodumene", "Mica", "Pegmatite"],
    "carbonate": ["Lithium Carbonate"],
    "hydroxide": ["Lithium Hydroxide"],
    "cathode": [
        "NCM mid nickel",
        "LFP",
        "4V Ni or Mn based",
        "NCA",
        "NCM high nickel",
        "LCO",
        "NCM low nickel",
        "5V Mn based",
    ],
    "battery": [
        "Cylindrical",
        "Pouch",
        "Cylindrical, Pouch",
        "Pouch, Prismatic",
        "Prismatic",
        "Cylindrical, Prismatic",
        "Cylindrical, Pouch, Prismatic",
    ],
    "end_use": ["EV", "ESS", "Portable"],
}

# Stack both node tables, then derive/complete the descriptive columns:
#   type    - first non-null of mine_type, process_type, product_type
#   country - falls back to region when missing
#   company - falls back to operator_short_clean when missing
#   stage   - stage whose type list contains `type` (NaN when none does)
nodes_df = pd.concat([nodes, endUse_nodes]).assign(
    type=lambda df: (
        df["mine_type"].fillna(df["process_type"]).fillna(df["product_type"])
    ),
    country=lambda df: df["country"].fillna(df["region"]),
    company=lambda df: df["company"].fillna(df["operator_short_clean"]),
    stage=lambda df: df["type"].map(
        {
            type_name: stage_name
            for stage_name, type_names in stages_dict.items()
            for type_name in type_names
        }
    ),
)

# Attach source-node attributes to every edge. A node row matches an edge only
# when its id equals `source` AND its type equals `edge_type`; because this is
# a left join, edges without a match are kept with NaN attributes.
inputs = edges.merge(
    nodes_df[["node_id", "type", "stage", "country", "company"]],
    left_on=["source", "edge_type"],
    right_on=["node_id", "type"],
    how="left",
)

# Attach target-node attributes. Columns present on both sides (stage, type,
# country, company) are disambiguated with `_source` / `_target` suffixes.
# The volume column is derived from `year` so the cell follows the notebook's
# year setting instead of a hard-coded "2023_volume".
outputs = inputs[
    [
        "stage",
        "type",
        "source",
        "target",
        f"{year}_volume",
        "edge_type",
        "edge_destination",
        "country",
        "company",
    ]
].merge(
    nodes_df[["node_id", "stage", "type", "country", "company"]],
    left_on="target",
    right_on="node_id",
    how="left",
    suffixes=("_source", "_target"),
)

# Keep an edge when it names no destination type, or when the destination it
# names equals the type of the target node row it was joined to.
outputs = outputs[
    (outputs["edge_destination"].isna())
    | (outputs["edge_destination"] == outputs["type_target"])
]

[32m2025-05-10 16:28:58.932[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mfetch_gcs_bucket[0m:[36m16[0m - [1mFetching GCS bucket: lithium-datasets in project: critical-minerals'[0m
[32m2025-05-10 16:29:15.927[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/benchmark_nodes.csv in bucket lithium-datasets[0m
[32m2025-05-10 16:29:17.467[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/endUse_nodes.csv in bucket lithium-datasets[0m
[32m2025-05-10 16:29:17.657[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/edge_creation/benchmark_combined_edges.csv in bucket lithium-datasets[0m


In [None]:
# 2. Keep only rows where stage_target matches next_stage_map[stage_source],
#    plus two explicitly allowed stage-skipping routes.
stages = list(stages_dict.keys())
# Each stage mapped to the stage that follows it in `stages_dict` order.
next_stage_map = dict(zip(stages, stages[1:]))

real_flows = pd.concat(
    (
        # Flows into the immediately following stage.
        outputs[outputs["stage_source"].map(next_stage_map) == outputs["stage_target"]],
        # Mining output going straight to hydroxide.
        outputs[
            (outputs["stage_source"] == "mining")
            & (outputs["stage_target"] == "hydroxide")
        ],
        # Carbonate going straight to cathode.
        outputs[
            (outputs["stage_source"] == "carbonate")
            & (outputs["stage_target"] == "cathode")
        ],
    )
)

# Drop self-loops. `.copy()` makes the filtered frame independent, so the
# column assignments below do not trigger pandas' chained-assignment warning.
real_flows = real_flows[real_flows["source"] != real_flows["target"]].copy()

cathode_types = stages_dict["cathode"]

# Add node id prefix based on cathode type due to multiple cathode types from some facilities
# `zip` stops at the shorter input, so only as many prefixes as there are
# cathode types are used.
prefix_values = [str(i) for i in range(10, 56, 5)]  # Define prefix order
prefix_map = dict(zip(cathode_types, prefix_values))  # Create mapping

# Rows whose type is not a cathode type get an empty prefix, leaving the id
# unchanged. The prefixed id is the decimal concatenation prefix + id.
# NOTE(review): a prefixed id could coincide with an existing, unprefixed
# node id (e.g. "10" + "55" vs. node 1055) - confirm the id ranges rule this out.
for id_col, type_col in (("source", "type_source"), ("target", "type_target")):
    real_flows[id_col] = (
        real_flows[type_col].map(prefix_map).fillna("")
        + real_flows[id_col].astype(str)
    ).astype(int)

In [None]:
def compute_flow_fractions(G, target, vol_attr="vol"):
    """
    Computes the fraction of flow at each node that eventually reaches `target`.
    Assumes the graph is a DAG.

    A node's outgoing flow is assumed to split across its out-edges in
    proportion to the edge volumes, so
    ``fraction[u] = sum(vol(u, v) / total_out(u) * fraction[v])`` over the
    successors ``v`` of ``u``, with ``fraction[target] = 1``.

    Parameters
    ----------
    G : directed graph
        Must be acyclic; every edge must carry the `vol_attr` attribute.
    target : node
        Sink node. If it is not in `G` it is still added to the result with
        fraction 1.0 and every other node gets 0.
    vol_attr : str, default "vol"
        Name of the edge attribute holding the flow volume.

    Returns
    -------
    dict
        Maps every node of `G` (plus `target`) to its fraction as a float.

    Raises
    ------
    ValueError
        If `G` contains a cycle.
    """
    try:
        # Get a topological ordering (requires a DAG)
        topo_order = list(nx.topological_sort(G))
    except nx.NetworkXUnfeasible as err:
        raise ValueError(
            "The graph contains cycles. This method assumes a DAG."
        ) from err

    # Initialize all fractions to zero; set target's fraction to 1.
    fractions = {node: 0.0 for node in G.nodes()}
    fractions[target] = 1.0

    # Total outgoing volume per node, summed once per node on first use
    # instead of once per edge.
    total_out_by_node = {}

    def total_out_volume(node):
        if node not in total_out_by_node:
            total_out_by_node[node] = sum(
                edge_data[vol_attr]
                for _, _, edge_data in G.out_edges(node, data=True)
            )
        return total_out_by_node[node]

    # Process nodes in reverse topological order (from target upstream): by
    # the time a node is visited all of its successors have been processed,
    # so its own fraction is final.
    for node in reversed(topo_order):
        share = fractions[node]
        if share == 0:
            # Nothing from this node reaches the target, so it contributes
            # nothing to its predecessors.
            continue
        for pred in G.predecessors(node):
            total_out = total_out_volume(pred)
            if total_out > 0:
                # The fraction of pred's flow that goes to this child
                flow_ratio = G[pred][node][vol_attr] / total_out
                fractions[pred] += share * flow_ratio
    return fractions

In [None]:
# Directed flow graph: one edge per (source, target) row of `real_flows`,
# carrying that row's volume for `year` as an edge attribute of the same name.
# NOTE(review): a DiGraph stores a single edge per ordered node pair, so if
# `real_flows` ever holds several rows for the same pair only one row's volume
# survives - confirm the pairs are unique.
G = nx.from_pandas_edgelist(
    real_flows,
    "source",
    "target",
    edge_attr=f"{year}_volume",
    create_using=nx.DiGraph,
)

In [None]:
# Ids of all battery-stage nodes; each one is used as the sink of one
# upstream tree.
battery_node_ids = nodes_df.loc[
    nodes_df["type"].isin(stages_dict["battery"]), "node_id"
].unique()

volume_col = year + "_volume"
upstream_trees = []

for battery_node_id in tqdm(battery_node_ids):

    # Fraction of each node's flow that eventually reaches this battery node.
    fractions = compute_flow_fractions(G, battery_node_id, vol_attr=volume_col)

    # Adjust each edge's volume to only account for the portion that
    # eventually reaches the battery node:
    #     adj_vol(i -> j) = vol(i -> j) * fraction[j]
    # Done column-wise rather than with a row-wise `apply`; targets missing
    # from `fractions` count as 0.
    target_share = real_flows["target"].map(fractions).fillna(0)
    adj_vol = real_flows[volume_col] * target_share
    reaches_battery = adj_vol > 0

    # Only edges that carry some flow towards the battery node belong to its
    # upstream tree.
    upstream_trees.append(
        real_flows[reaches_battery].assign(
            adj_vol=adj_vol[reaches_battery],
            battery_node_id=battery_node_id,
        )
    )

# Concatenate once at the end: growing a DataFrame inside the loop re-copies
# all accumulated rows on every iteration.
all_trees = pd.concat(upstream_trees) if upstream_trees else pd.DataFrame()

100%|██████████| 247/247 [10:26<00:00,  2.53s/it]


In [None]:
# Write the combined upstream trees under <project_root>/figures/main_results,
# creating the folder first so the export cannot fail on a missing directory
# after the long computation above.
output_dir = os.path.join(project_root, "figures", "main_results")
os.makedirs(output_dir, exist_ok=True)
all_trees.to_csv(os.path.join(output_dir, "upstream_trees.csv"), index=False)