In [1]:
import os
import sys

# Add project root to Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


import ast

import numpy as np


from utils.gcloud_utilities import *

from utils.metadata import *

from utils.preprocessing_utilities import (
    import_operating_nodes,
    expand_parameters_col_and_format,
)

In [2]:
# Reference year: passed to the node loader and used as the name of the
# per-year value column on the node table.
year = "2023"

# The loader hands back the bucket together with the operating nodes; the
# same bucket object is reused for the remaining downloads.
bucket, nodes = import_operating_nodes(year)

# Both benchmark CSVs sit under the same preprocessed directory prefix.
benchmark_dir = GCLOUD_PREPROCESSED_DIR + BENCHMARK_PREPROCESSED_DIR
endUse_nodes = pull_from_gcs_csv(bucket, benchmark_dir + "endUse_nodes.csv")
edges = pull_from_gcs_csv(
    bucket, benchmark_dir + BENCHMARK_EDGES_DIR + BENCHMARK_EDGES_FILE
)

# "properties" is read back as text: parse each entry with ast.literal_eval
# and spread the resulting keys into ordinary columns next to the other
# edge fields.
dict_df = pd.json_normalize(edges["properties"].astype(str).map(ast.literal_eval))
edges = edges.drop(columns=["properties"]).join(dict_df)

# One node table: operating nodes followed by end-use nodes.
nodes_df = pd.concat([nodes, endUse_nodes])

# A node's type is the first non-null value among mine_type, process_type
# and product_type, checked in that order.
node_type = nodes_df["mine_type"]
for fallback_col in ("process_type", "product_type"):
    node_type = node_type.fillna(nodes_df[fallback_col])
nodes_df["type"] = node_type

# Untyped nodes are discarded; only the id, the type and the value column
# for the selected year are carried forward.
nodes_df = nodes_df.dropna(subset=["type"]).loc[:, ["node_id", "type", year]]

# Stage -> node types belonging to it. The key order is significant: a later
# cell derives the "next stage" relation from it.
stages_dict = {
    "mining": ["Brine", "Spodumene", "Mica", "Pegmatite"],
    "carbonate": ["Lithium Carbonate"],
    "hydroxide": ["Lithium Hydroxide"],
    "cathode": [
        "NCM mid nickel",
        "LFP",
        "4V Ni or Mn based",
        "NCA",
        "NCM high nickel",
        "LCO",
        "NCM low nickel",
        "5V Mn based",
    ],
    "battery": [
        "Cylindrical",
        "Pouch",
        "Cylindrical, Pouch",
        "Pouch, Prismatic",
        "Prismatic",
        "Cylindrical, Prismatic",
        "Cylindrical, Pouch, Prismatic",
    ],
    "end_use": ["EV", "ESS", "Portable"],
}

# Invert the dictionary into a flat type -> stage lookup.
type_to_stage = {}
for stage_name, stage_types in stages_dict.items():
    type_to_stage.update(dict.fromkeys(stage_types, stage_name))

nodes_df["stage"] = nodes_df["type"].map(type_to_stage)

# Types that are not listed in stages_dict have no stage and are dropped.
nodes_df = nodes_df.dropna(subset=["stage"])

[32m2025-05-22 11:19:37.994[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mfetch_gcs_bucket[0m:[36m16[0m - [1mFetching GCS bucket: lithium-datasets in project: critical-minerals'[0m


[32m2025-05-22 11:19:42.860[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/benchmark_nodes.csv in bucket lithium-datasets[0m
[32m2025-05-22 11:19:46.416[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/endUse_nodes.csv in bucket lithium-datasets[0m
[32m2025-05-22 11:19:46.649[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_csv[0m:[36m27[0m - [1mPulling data from preprocessed/benchmark/edge_creation/benchmark_combined_edges.csv in bucket lithium-datasets[0m


In [3]:
# Name of the edge volume column for the selected year (e.g. "2023_volume"),
# derived from `year` instead of being hard-coded.
volume_col = year + "_volume"

# Attach the source node's type and stage to each edge. The join key is
# (source, edge_type) == (node_id, type), so an edge only picks up a source
# stage when its edge_type equals the type of its source node.
inputs = edges.merge(
    nodes_df[["node_id", "type", "stage"]],
    left_on=["source", "edge_type"],
    right_on=["node_id", "type"],
    how="left",
)

# Attach the target node's type and stage. The overlapping "stage"/"type"
# columns become *_source (from the edge side) and *_target (from the node).
outputs = inputs[
    ["stage", "type", "target", volume_col, "edge_type", "edge_destination"]
].merge(
    nodes_df[["node_id", "stage", "type"]],
    left_on="target",
    right_on="node_id",
    how="left",
    suffixes=("_source", "_target"),
)

# Keep edges with no declared destination, or whose declared destination
# equals the type of the target node. The explicit copy makes `outputs` an
# independent frame, so the column assignment done on it in a later cell
# does not operate on a slice of the merge result.
destination_ok = outputs["edge_destination"].isna() | (
    outputs["edge_destination"] == outputs["type_target"]
)
outputs = outputs[destination_ok].copy()

# Total volume per (source stage/type, target stage/type, edge type).
# groupby drops rows whose key columns contain NaN, i.e. edges that did not
# match a source or target node above.
flow_keys = ["stage_source", "type_source", "stage_target", "type_target", "edge_type"]
all_flows = outputs.groupby(flow_keys)[volume_col].sum().reset_index()

# Stage names in the key order of stages_dict; each stage is paired with the
# one that follows it.
stages = list(stages_dict)
next_stage_map = dict(zip(stages[:-1], stages[1:]))

src_stage = all_flows["stage_source"]
dst_stage = all_flows["stage_target"]

# Flows retained: a step to the immediately following stage, plus the two
# permitted shortcuts mining -> hydroxide and carbonate -> cathode.
to_next_stage = src_stage.map(next_stage_map) == dst_stage
mining_to_hydroxide = (src_stage == "mining") & (dst_stage == "hydroxide")
carbonate_to_cathode = (src_stage == "carbonate") & (dst_stage == "cathode")

# Stacked in this order: next-stage flows first, then each shortcut group.
real_flows = pd.concat(
    [
        all_flows[mask]
        for mask in (to_next_stage, mining_to_hydroxide, carbonate_to_cathode)
    ]
)

# Per-type conversion factors; the sheet supplies one factor for edges
# ("edge_conversion") and one for nodes ("node_conversion").
unit_conversion = pull_from_gcs_excel(
    bucket, MAPPINGS_DIR + "Li_unit_conversion.xlsx", sheet_name="Sheet1"
)

# Edge volume column for the selected year, derived from `year` so that it
# stays in step with the node column handled below.
volume_col = year + "_volume"

# Edge factors are looked up by edge_type. With a left join, an edge type
# that is missing from the sheet gets a NaN factor and therefore a NaN volume.
converted_edges = real_flows.merge(
    unit_conversion[["type", "edge_conversion"]],
    left_on="edge_type",
    right_on="type",
    how="left",
)
converted_edges[volume_col] = (
    converted_edges[volume_col] * converted_edges["edge_conversion"]
)

# Node factors are looked up by node type; the same NaN caveat applies.
converted_nodes = nodes_df.merge(
    unit_conversion[["type", "node_conversion"]],
    on="type",
    how="left",
)
converted_nodes[year] = converted_nodes[year] * converted_nodes["node_conversion"]

[32m2025-05-22 11:19:50.122[0m | [1mINFO    [0m | [36mutils.gcloud_utilities[0m:[36mpull_from_gcs_excel[0m:[36m46[0m - [1mPulling data from raw/mappings/Li_unit_conversion.xlsx in bucket lithium-datasets[0m


In [5]:
# Balance each node type's converted total against the flows entering and
# leaving it, and add the differences as extra "Losses & Stock" / "Stock"
# edges.
#
# NOTE(review): this cell rebinds `converted_edges` to include the balancing
# rows, so running it a second time without re-running the previous cell
# would compute the balance on already-balanced data and append again.
volume_col = year + "_volume"

# Pick the value column before aggregating so that only it is summed (the
# frames also carry text columns that should not be aggregated).
totals = converted_nodes.groupby("type")[year].sum().reset_index()
sources = (
    converted_edges.groupby(["stage_source", "type_source"])[volume_col]
    .sum()
    .reset_index()
)
targets = (
    converted_edges.groupby(["stage_target", "type_target"])[volume_col]
    .sum()
    .reset_index()
)

# Outgoing side: node total minus the volume leaving the type, drawn as an
# edge from the type to "<Stage> Losses & Stock" within the same stage.
losses = totals.merge(sources, left_on="type", right_on="type_source")
losses[volume_col] = losses[year] - losses[volume_col]
losses["stage_target"] = losses["stage_source"]
losses["type_target"] = losses["stage_source"].str.capitalize() + " Losses & Stock"
losses["edge_type"] = losses["type_source"]

# Incoming side: node total minus the volume arriving at the type, drawn as
# an edge from "<Stage> Stock" to the type within the same stage.
from_stocks = totals.merge(targets, left_on="type", right_on="type_target")
from_stocks[volume_col] = from_stocks[year] - from_stocks[volume_col]
from_stocks["stage_source"] = from_stocks["stage_target"]
from_stocks["type_source"] = from_stocks["stage_target"].str.capitalize() + " Stock"
from_stocks["edge_type"] = from_stocks["type_target"]

# Where the incoming balance is negative (more arrives than the node total),
# the excess is added to that type's losses. After the merge, the *_x
# columns come from `losses` and the *_y columns from `from_stocks`.
extra_losses = losses.merge(from_stocks, on="type")
extra_losses[volume_col + "_x"] = np.where(
    extra_losses[volume_col + "_y"] < 0,
    extra_losses[volume_col + "_x"] - extra_losses[volume_col + "_y"],
    extra_losses[volume_col + "_x"],
)

# NOTE(review): both merges here are inner joins on "type", so a type that
# has outgoing flows but no incoming ones (it is absent from `from_stocks`)
# loses its losses row at this point. Confirm that this is intended.
losses = losses.merge(extra_losses[["type", volume_col + "_x"]], on="type")
losses[volume_col] = losses[volume_col + "_x"]
losses = losses.drop(columns=[volume_col + "_x"])

# Negative stock draws were moved to losses above; floor them at zero here.
from_stocks[volume_col] = np.where(
    from_stocks[volume_col] < 0, 0, from_stocks[volume_col]
)

converted_edges = pd.concat((converted_edges, losses, from_stocks))

In [6]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Carbonate and hydroxide are presented as a single "processing" stage.
mapping = {"carbonate": "processing", "hydroxide": "processing"}

# Relabel the two stages; values without an entry in `mapping` keep their
# original label.
outputs["stage_source"] = (
    outputs["stage_source"].map(mapping).fillna(outputs["stage_source"])
)

converted_edges["stage_source"] = (
    converted_edges["stage_source"].map(mapping).fillna(converted_edges["stage_source"])
)

# Stages that need a base colour: every source stage seen in `outputs`, plus
# "end_use", which is appended explicitly.
# NOTE(review): if `outputs["stage_source"]` contains NaN (edges without a
# matched source node), NaN is counted as a stage here and takes a colour slot.
unique_stages = list(outputs["stage_source"].unique()) + ["end_use"]

# "tab10" resampled to one colour more than there are stages.
# plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in
# Matplotlib 3.7 and later removed.
cmap = plt.get_cmap("tab10", len(unique_stages) + 1)

# stage_base_colors: stage_source -> RGBA tuple with channels in [0, 1].
# Every stage takes the colour at its own position except the last one,
# which skips ahead by one slot.
last_position = len(unique_stages) - 1
stage_base_colors = {}
for i, stage in enumerate(unique_stages):
    stage_base_colors[stage] = cmap(i if i < last_position else i + 1)


# ----------------------------------------------------------------------------
# 3) HELPER FUNCTION TO LIGHTEN/DARKEN A COLOR
# ----------------------------------------------------------------------------
def adjust_color(rgb, amount=1.2):
    """
    Scale the brightness of an RGB colour.

    Every channel is multiplied by ``amount`` and the product is clamped to
    the interval [0, 1]: a factor above 1 brightens the colour, a factor
    below 1 darkens it.

    rgb must unpack into exactly three channel values (r, g, b), expected to
    lie in [0, 1]; amount is the multiplier (1.2 is roughly 20% brighter).

    Returns the adjusted (r, g, b) tuple.
    """
    red, green, blue = rgb

    def scale(channel):
        # Multiply, then clamp into [0, 1].
        return min(1, max(0, channel * amount))

    return tuple(scale(channel) for channel in (red, green, blue))


# ----------------------------------------------------------------------------
# 4) BUILD A DICT (stage_source, edge_type) -> HEX COLOR CODE
# ----------------------------------------------------------------------------
# (stage_source, edge_type) -> hex colour string.
color_map = {}

for stage in unique_stages:
    # Only the RGB part of the stage's base colour is used; alpha is dropped.
    base_rgb = stage_base_colors[stage][:3]

    # Distinct edge types that leave this stage.
    in_stage = converted_edges["stage_source"] == stage
    edges_for_stage = converted_edges.loc[in_stage, "edge_type"].unique()

    # Evenly spaced brightness factors from 0.7 to 1.4, one per edge type,
    # so that edge types within a stage get different shades of its colour.
    amounts = np.linspace(0.7, 1.4, len(edges_for_stage))

    color_map.update(
        {
            (stage, edge): mcolors.to_hex(adjust_color(base_rgb, amount=amt))
            for edge, amt in zip(edges_for_stage, amounts)
        }
    )

# ----------------------------------------------------------------------------
# 5) CREATE A NEW COLUMN IN THE DATAFRAME WITH THESE HEX CODES
# ----------------------------------------------------------------------------
# Look up each row's colour by its (stage_source, edge_type) pair; a pair
# that is missing from color_map raises KeyError.
converted_edges["color_hex"] = [
    color_map[stage_and_edge]
    for stage_and_edge in zip(
        converted_edges["stage_source"], converted_edges["edge_type"]
    )
]

  cmap = plt.cm.get_cmap("tab10", len(unique_stages) + 1)


In [None]:
# Edge volume column for the selected year (e.g. "2023_volume").
volume_col = year + "_volume"

# One text line per edge in the form "<source> [<volume>] <target> <colour>",
# with the volume rounded to one decimal. ", " inside type names is replaced
# by "/" afterwards, which also keeps commas out of the lines so that to_csv
# has no reason to wrap them in quotes.
converted_edges["SankeyMatic"] = converted_edges.apply(
    lambda row: f"{row['type_source']} [{row[volume_col]:.1f}] {row['type_target']} {row['color_hex']}",
    axis=1,
).str.replace(", ", "/")

output_path = "/figures/main_results/"
filename = "P3c_sankeyMatic"

# Write the lines under <project_root>/figures/main_results/, creating the
# directory first so that a fresh checkout does not fail on a missing folder.
output_dir = project_root + output_path
os.makedirs(output_dir, exist_ok=True)

converted_edges["SankeyMatic"].to_csv(
    output_dir + filename + ".txt",
    index=False,
    header=False,
)

In [12]:
# Share of the volume arriving at end-use nodes, split by end-use type.
# The value column is selected before aggregating so only it is summed.
eu = converted_edges[converted_edges["stage_target"] == "end_use"]
eu.groupby("type_target")[year + "_volume"].sum() / eu[year + "_volume"].sum()

type_target
ESS         0.141287
EV          0.785234
Portable    0.073478
Name: 2023_volume, dtype: float64

In [13]:
# Total volume over all edges that end at an end-use node.
end_use_total = eu["2023_volume"].sum()
end_use_total

np.float64(97279.83649707901)

In [14]:
# Share of the volume arriving at end-use nodes, split by the type it comes
# from. The value column is selected before aggregating so only it is summed.
eu = converted_edges[converted_edges["stage_target"] == "end_use"]
eu.groupby("type_source")[year + "_volume"].sum() / eu[year + "_volume"].sum()

type_source
Cylindrical                      0.151191
Cylindrical, Pouch               0.066582
Cylindrical, Pouch, Prismatic    0.000245
Cylindrical, Prismatic           0.011065
End_use Stock                    0.000000
Pouch                            0.165830
Pouch, Prismatic                 0.007762
Prismatic                        0.597325
Name: 2023_volume, dtype: float64

In [15]:
# Share of the volume leaving the mining stage, split by mined type.
# The value column is selected before aggregating so only it is summed.
rms = converted_edges[converted_edges["stage_source"] == "mining"]
rms.groupby("type_source")[year + "_volume"].sum() / rms[year + "_volume"].sum()

type_source
Brine        0.426810
Mica         0.098294
Pegmatite    0.002908
Spodumene    0.471988
Name: 2023_volume, dtype: float64

In [16]:
# Total volume over all edges whose source stage is mining.
mining_total = rms["2023_volume"].sum()
mining_total

np.float64(161527.93850317987)

In [17]:
# Edges whose source stage is "cathode" (the same selection is repeated in a
# later cell before `caths` is first used).
is_cathode_source = converted_edges["stage_source"] == "cathode"
caths = converted_edges[is_cathode_source]

In [18]:
# Ratio of the volume arriving at end uses to the volume leaving mining.
end_use_volume = eu["2023_volume"].sum()
mining_volume = rms["2023_volume"].sum()
end_use_volume / mining_volume

np.float64(0.6022477436320649)

In [19]:
# Edges whose source stage is "cathode"; consumed by the share table below.
is_cathode_source = converted_edges["stage_source"] == "cathode"
caths = converted_edges[is_cathode_source]

In [20]:
# Share of the volume leaving the cathode stage, split by source type.
# The value column is selected before aggregating so only it is summed.
caths.groupby("type_source")[year + "_volume"].sum() / caths[year + "_volume"].sum()

type_source
4V Ni or Mn based    0.041727
5V Mn based          0.000000
Cathode Stock        0.129612
LCO                  0.039082
LFP                  0.330350
NCA                  0.064549
NCM high nickel      0.195814
NCM low nickel       0.006606
NCM mid nickel       0.192259
Name: 2023_volume, dtype: float64

In [21]:
# Combined share of the three NCM cathode types. The literals are the
# "NCM mid nickel", "NCM low nickel" and "NCM high nickel" shares typed in by
# hand from the table printed by the previous cell.
# NOTE(review): hand-copied values go stale when the data changes; consider
# summing the corresponding rows of that table instead.
0.192259 + 0.006606 + 0.195814

0.394679

In [22]:
# Comparison against external reference figures, labelled "Petravatzi" by
# the author.
# NOTE(review): presumably the Petavratzi lithium flow study - confirm the
# spelling, the source and the units of these figures.
reference_flows = [12334, 4628, 934, 1543, 4855, 617, 2241, 309]

# Modelled end-use total divided by the sum of all reference figures.
eu["2023_volume"].sum() / sum(reference_flows)

np.float64(3.542472469942064)

In [23]:
# Portable: modelled end-use volume relative to the reference figures.
# The numerator is computed from `eu` instead of the hand-pasted 7147.967663,
# which matches the Portable share times the end-use total printed earlier.
eu.loc[eu["type_target"] == "Portable", "2023_volume"].sum() / (4855 + 617)

1.3062806401681286

In [24]:
# EV: modelled end-use volume relative to the reference figures.
# The numerator is computed from `eu` instead of the hand-pasted 76387.471415,
# which matches the EV share times the end-use total printed earlier.
eu.loc[eu["type_target"] == "EV", "2023_volume"].sum() / (12334 + 4628)

4.503447200507016

In [25]:
# ESS: modelled end-use volume relative to the reference figures.
# The numerator is computed from `eu` instead of the hand-pasted 13744.397419,
# which matches the ESS share times the end-use total printed earlier.
eu.loc[eu["type_target"] == "ESS", "2023_volume"].sum() / (934 + 1543)

5.5488080012111425