In [3]:
# Alcohol Consumption Network Graph (Dynamic)

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from itertools import combinations
import pandas as pd

# Replace the filename with the path to your local CSV file
df = pd.read_csv('/home/mrh/alcohol.csv')

# Make sure the file has these columns: 'country', 'year', 'consumption'
print(df.head())

# Load your alcohol consumption data here
# Ensure the DataFrame has columns: 'country', 'year', 'consumption'
# Example: df = pd.read_csv('your_data.csv')

# Uncomment and adjust as needed:
# df = pd.read_csv('alcohol_consumption.csv')

# Pivot the data to wide format: rows = year, columns = country
pivot = df.pivot(index='year', columns='country', values='consumption')

# Compute correlation matrix between countries
corr = pivot.corr()

# Build the graph
G = nx.Graph()
means = pivot.mean()
for country, avg in means.items():
    G.add_node(country, mean_consumption=avg)

# Add edges based on correlation
threshold = 0.3  # Adjust threshold as needed
for a, b in combinations(corr.columns, 2):
    weight = corr.loc[a, b]
    if abs(weight) >= threshold:
        G.add_edge(a, b, weight=weight)

# Draw the graph (static)
pos = nx.spring_layout(G, seed=42)
node_sizes = [500 + (G.nodes[n]['mean_consumption'] / means.max()) * 1500 for n in G.nodes()]
node_colors = [G.nodes[n]['mean_consumption'] for n in G.nodes()]
edge_widths = [abs(G[u][v]['weight']) * 5 for u, v in G.edges()]

plt.figure(figsize=(10, 8))
nx.draw_networkx_nodes(G, pos, node_size=node_sizes, node_color=node_colors, cmap=plt.cm.viridis, alpha=0.8)
nx.draw_networkx_edges(G, pos, width=edge_widths, edge_color='gray', alpha=0.7)
nx.draw_networkx_labels(G, pos, font_size=10, font_weight='bold')
edge_labels = { (u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True) }
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
plt.title('Alcohol Consumption Correlation Network')
plt.axis('off')
plt.show()

# Interactive Plot with Plotly
edge_x, edge_y = [], []
for u, v in G.edges():
    x0, y0 = pos[u]
    x1, y1 = pos[v]
    edge_x += [x0, x1, None]
    edge_y += [y0, y1, None]

edge_trace = go.Scatter(
    x=edge_x,
    y=edge_y,
    mode='lines',
    line=dict(width=0.5, color='#888'),
    hoverinfo='none'
)

node_x, node_y, hover = [], [], []
for n in G.nodes():
    x, y = pos[n]
    node_x.append(x)
    node_y.append(y)
    hover.append(f"{n}<br>Avg consumption: {G.nodes[n]['mean_consumption']:.1f}")

node_trace = go.Scatter(
    x=node_x,
    y=node_y,
    mode='markers+text',
    text=list(G.nodes()),
    textposition='bottom center',
    hovertext=hover,
    hoverinfo='text',
    marker=dict(
        size=node_sizes,
        color=node_colors,
        colorscale='Viridis',
        showscale=True,
        colorbar=dict(title='Avg consumption')
    )
)

fig = go.Figure(data=[edge_trace, node_trace])
fig.update_layout(
    title='Interactive Alcohol Consumption Network',
    showlegend=False,
    xaxis=dict(showgrid=False, zeroline=False, visible=False),
    yaxis=dict(showgrid=False, zeroline=False, visible=False),
    margin=dict(l=20, r=20, t=50, b=20)
)
fig.show()

  Country                                        Series Name  2000 [YR2000]  \
0    Iran  Total alcohol consumption per capita (liters o...          0.021   
1    Iran  Total alcohol consumption per capita, female (...          0.005   
2    Iran  Total alcohol consumption per capita, male (li...          0.037   
3   India  Total alcohol consumption per capita (liters o...          1.970   
4   India  Total alcohol consumption per capita, female (...          0.600   

   2001 [YR2001]  2002 [YR2002]  2003 [YR2003]  2004 [YR2004]  2005 [YR2005]  \
0          0.021          0.030          0.044          0.063          0.085   
1          0.005          0.007          0.010          0.014          0.019   
2          0.037          0.053          0.077          0.110          0.150   
3          1.970          2.110          2.210          2.290          2.390   
4          0.600          0.650          0.680          0.710          0.730   

   2006 [YR2006]  2007 [YR2007]  ...  2011 [

KeyError: 'year'