In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go

In [2]:
df = pd.read_csv("ReFED_US_Food_Surplus_Summary.csv",skiprows=3 )
# pd.read_csv('filename.csv', skiprows=3)
# Filter the required columns
df = df[['sector', 'tons_uneaten', 'tons_eaten']]

# Print the original data
print("Original data:")
print(df)

# Aggregate the 'tons_uneaten' and 'tons_eaten' values for each unique 'sector'
df = df.groupby('sector').agg({'tons_uneaten': 'sum', 'tons_eaten': 'sum'}).reset_index()

# Print the aggregated data
print("\nAggregated data:")
print(df)

# Create an empty DataFrame to store the Sankey data
sankey_data = pd.DataFrame(columns=['source', 'target', 'value', 'type'])

# Define a dictionary to map the sector to the corresponding target names for food consumed
food_consumed_target_map = {
    'Farm': ['Manufacturing_Consumed', 'Retail_Consumed', 'Food_service_Consumed'],
    'Manufacturing': ['Residential_Consumed'],
    'Retail': ['Residential_Consumed'],
    'Foodservice': ['Residential_Consumed'],
}

# Loop through each row in the aggregated DataFrame and populate the sankey_data DataFrame
for index, row in df.iterrows():
    sector = row['sector']
    tons_eaten = row['tons_eaten']
    tons_uneaten = row['tons_uneaten']

    if sector in food_consumed_target_map:
        targets = food_consumed_target_map[sector]
        for target in targets:
            sankey_data = sankey_data.append({
                'source': sector,
                'target': target,
                'value': tons_eaten / len(targets),
                'type': 'Food Consumed',
            }, ignore_index=True)

    surplus_target = f"{sector}_Surplus"
    sankey_data = sankey_data.append({
        'source': sector,
        'target': surplus_target,
        'value': tons_uneaten,
        'type': 'Food Surplus',
    }, ignore_index=True)

# Save the sankey_data DataFrame as a CSV file
sankey_data.to_csv('sankey_data.csv', index=False)


Original data:
           sector  tons_uneaten     tons_eaten
0            Farm  1.725822e+05       8.061332
1            Farm  1.508448e+07  250428.269786
2     Foodservice  7.123902e+02       9.886955
3     Foodservice  3.388924e+02       4.703341
4     Foodservice  6.057685e+02       8.407198
...           ...           ...            ...
3475       Retail  3.523333e+05   84443.490174
3476       Retail  5.254138e+04   12410.649567
3477       Retail  1.382478e+05   32655.127204
3478       Retail  9.541376e+05  225374.113290
3479       Retail  2.339707e+05   77410.512162

[3480 rows x 3 columns]

Aggregated data:
          sector  tons_uneaten    tons_eaten
0           Farm  1.969084e+08  3.018754e+06
1    Foodservice  1.436554e+08  2.980561e+08
2  Manufacturing  1.420142e+08  5.081355e+06
3    Residential  4.779282e+08  1.354578e+09
4         Retail  4.315339e+07  1.048139e+07


  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({
  sankey_data = sankey_data.append({


In [10]:

# Read the Sankey data from the CSV file
sankey_data = pd.read_csv('sankey_data.csv')

# Create the node labels
labels = pd.concat([sankey_data['source'], sankey_data['target']]).unique().tolist()

# Define a dictionary to map the sectors and their colors
sector_color_map = {
    'Farm': 'lightgreen',
    'Manufacturing': 'blue',
    'Retail': 'maroon',
    'Foodservice': 'purple',
    'Residential': 'yellow',
}

# Function to get sector color based on the sector name
def get_sector_color(sector_name):
    for key, color in sector_color_map.items():
        if key.lower() in sector_name.lower():
            return color
    return 'gray'  # Default color if no match is found

# Create the node colors
colors = [get_sector_color(label) for label in labels]

# Create the Sankey diagram
fig = go.Figure(data=[go.Sankey(
    node=dict(
        pad=15,
        thickness=20,
        line=dict(color='black', width=0.5),
        label=labels,
        color=colors,
    ),
    link=dict(
        source=sankey_data['source'].apply(lambda x: labels.index(x)),
        target=sankey_data['target'].apply(lambda x: labels.index(x)),
        value=sankey_data['value'],
        color=sankey_data['type'].apply(lambda x: 'rgba(144,238, 144, 0.5)' if x == 'Food Surplus' else 'rgba(235, 168, 128, 0.5)'),
    ),
)])

# Set the title and layout for the Sankey diagram
fig.update_layout(
    title_text="Food Flow Sankey Diagram",
    font_size=20,
    autosize=False,
    width=1200,
    height=800,
)

# Show the Sankey diagram
fig.show()
