# In [1]:
from pathlib import Path

import folium
import pandas as pd
import plotly.express as px
import plotly.io as pio
from folium.plugins import HeatMap

# Metric groups selectable through ``generate_heatmap``: group number ->
# (layer name used for the heatmap layer, columns of csv/df_sudo.csv to plot).
_HEATMAP_GROUPS = {
    1: (
        "Group1: Response",
        [
            'Median rent weekly',
            'Median mortgage repayment monthly',
            'House Median Sales Price',
            'House No of Transactions',
            'Attached dwellings Median Sales Price',
            'Attached dwellings No of Transactions',
        ],
    ),
    2: (
        "Group2: Housing/rental market: Supply Side",
        [
            'Building Approvals - Year ended 30 June Total private sector dwelling units (no.)',
            'Building Approvals - Year ended 30 June Total value of private sector dwelling units ($m)',
            'Building Approvals - Year ended 30 June Private sector houses (no.)',
            'Building Approvals - Year ended 30 June Value of private sector houses ($m)',
            'Value of New Houses ($000)',
            'Number of New Houses',
        ],
    ),
    3: (
        "Group3: Housing/rental market: Demand Side",
        [
            'Household Stress - Census Households with rent payments greater than or equal to 30% of household income (%)',
            'Household Stress - Census Households with mortgage repayments greater than or equal to 30% of household income (%)',
            'Housing Suitability - Occupied private dwellings - Census Dwellings with extra bedrooms needed (no.)',
            'Estimates of Personal Income - Year ended 30 June Mean employee income ($)',
            'Estimates of Personal Income - Year ended 30 June Median employee income ($)',
            'Gross Capital Gains reported by taxpayers - Year ended 30 June Gross Capital Gains reported by taxpayers - Mean ($)',
            'Gross Capital Gains reported by taxpayers - Year ended 30 June Gross Capital Gains reported by taxpayers  - Median ($)',
            'Estimates of Personal Income - Year ended 30 June Mean investment income ($)',
            'Estimates of Personal Income - Year ended 30 June Median investment income ($)',
            'Estimates of Personal Income - Year ended 30 June Total income (excl. Government pensions and allowances) - Gini coefficient',
        ],
    ),
}


def _build_region_popup(df_region, columns, region_name):
    """Return a folium popup holding a Plotly line chart for one SA4 region.

    Args:
        df_region: Rows of the merged data belonging to a single region; must
            contain a ``Year`` column and every column named in ``columns``.
        columns: Column names to draw, one line per column, against ``Year``.
        region_name: Region name inserted into the chart title.

    Returns:
        A ``folium.Popup`` wrapping an IFrame with the rendered chart.
    """
    fig = px.line(
        # Order the rows by year so each line is drawn left to right even if
        # the CSV rows are not stored chronologically (mergesort is stable).
        df_region.sort_values("Year", kind="mergesort"),
        x="Year",
        y=list(columns),
        title=f'Multiple Metrics for {region_name}',
    )

    # Same line/marker styling for every metric trace.
    fig.update_traces(
        mode='lines+markers',
        marker=dict(size=8),
        line=dict(width=2),
    )

    # Both axes share the same grid and border styling.
    axis_style = dict(
        gridcolor='lightgray',
        gridwidth=1,
        linecolor='gray',
        linewidth=1,
        tickfont=dict(size=12),
    )
    fig.update_layout(
        width=600,
        height=400,
        paper_bgcolor='white',
        plot_bgcolor='white',
        xaxis=dict(title="Year", **axis_style),
        yaxis=dict(title="Value", **axis_style),
        legend=dict(
            title="Data",
            orientation="h",
            yanchor="bottom",
            y=0.5,
            xanchor="right",
            x=1.6,
            font=dict(size=8),
            bgcolor='rgba(255, 255, 255, 0)',  # transparent legend background
        ),
        # The title text is already set by px.line; only style it here.
        title=dict(font=dict(size=16), x=0.5),
    )

    html_str = pio.to_html(fig, full_html=False, include_plotlyjs='cdn')
    html_str = f"<div style='overflow:auto;max-width:600px;max-height:500px;'>{html_str}</div>"
    iframe = folium.IFrame(html=html_str, width=600, height=500)
    return folium.Popup(iframe, max_width=600)


def generate_heatmap(group_number):
    """Build an interactive map for one metric group and save it as HTML.

    Reads ``csv/df_sudo.csv`` and ``csv/sydney_melbourne_extracted_sa4_info.csv``,
    joins them on the SA4 code, and draws:

    * a heatmap layer weighted by the first column of the selected group, and
    * one marker per SA4 region whose popup is a line chart of every column
      in the group plotted against ``Year``.

    Rows lacking coordinates or a value in any of the group's columns are
    dropped before plotting. The map is written to
    ``output/interactive_heatmap_group{group_number}.html`` (the directory is
    created when missing) and a confirmation message is printed.

    Args:
        group_number: Which metric group to plot; 1, 2 or 3.

    Raises:
        ValueError: If ``group_number`` is not one of the known groups.
    """
    try:
        group_name, columns_to_plot = _HEATMAP_GROUPS[group_number]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are equally invalid.
        raise ValueError("Invalid group number. Please choose 1, 2, or 3.") from None

    # Load df_sudo data
    df_sudo_major = pd.read_csv("csv/df_sudo.csv")

    # Read latitude and longitude data. Keep a single row per SA4 code so the
    # left merge below cannot multiply data rows; 'last' matches which entry
    # wins when the coordinate dict is built from repeated keys.
    df_lat_long = pd.read_csv("csv/sydney_melbourne_extracted_sa4_info.csv")
    df_lat_long = df_lat_long.drop_duplicates(subset='sa4_code', keep='last')
    sa4_coordinates = dict(
        zip(df_lat_long['sa4_code'], zip(df_lat_long['latitude'], df_lat_long['longitude']))
    )

    # Attach the region name to every data row
    df_sudo_major = pd.merge(
        df_sudo_major,
        df_lat_long[['sa4_code', 'sa4_name']],
        left_on='SA4 Code',
        right_on='sa4_code',
        how='left',
    )

    # Map SA4 Code to coordinates; codes without a match get missing values
    df_sudo_major['Latitude'] = df_sudo_major['SA4 Code'].map(
        lambda code: sa4_coordinates.get(code, (None, None))[0]
    )
    df_sudo_major['Longitude'] = df_sudo_major['SA4 Code'].map(
        lambda code: sa4_coordinates.get(code, (None, None))[1]
    )

    # Filter out rows with missing coordinates or missing metric values
    df_sudo_major = df_sudo_major.dropna(subset=['Latitude', 'Longitude', *columns_to_plot])

    # Create a base map centered on Australia
    australia_center = [-25.2744, 133.7751]  # Latitude and longitude of the center of Australia
    m = folium.Map(location=australia_center, zoom_start=4, tiles='OpenStreetMap')

    # Heatmap points are [lat, lon, weight]; the weight is the group's first column
    heatmap_data = df_sudo_major[['Latitude', 'Longitude', columns_to_plot[0]]].values.tolist()
    gradient = {0.2: 'blue', 0.4: 'lime', 0.6: 'yellow', 0.8: 'orange', 1.0: 'red'}
    HeatMap(
        heatmap_data,
        name=group_name,
        min_opacity=0.2,
        max_zoom=18,
        radius=25,
        blur=15,
        gradient=gradient,
    ).add_to(m)

    # Add one marker per SA4 region with its line chart as the popup
    for sa4_code, sa4_name in df_sudo_major[['SA4 Code', 'sa4_name']].drop_duplicates().values:
        df_sa4 = df_sudo_major[df_sudo_major['SA4 Code'] == sa4_code]
        popup = _build_region_popup(df_sa4, columns_to_plot, sa4_name)

        lat, lon = sa4_coordinates[sa4_code]
        folium.Marker([lat, lon], popup=popup).add_to(m)

    # Add the layer control to the map
    folium.LayerControl(position='topright', collapsed=False).add_to(m)

    # Save the map to an HTML file, creating the output directory if needed
    map_filename = f"output/interactive_heatmap_group{group_number}.html"
    Path(map_filename).parent.mkdir(parents=True, exist_ok=True)
    m.save(map_filename)

    print(f"Heatmap has been saved to '{map_filename}'")

# Example usage: generate and save the heatmap for Group 1, Group 2 and Group 3
for _group_number in (1, 2, 3):
    generate_heatmap(_group_number)

# Output:
# Heatmap has been saved to 'output/interactive_heatmap_group1.html'
# Heatmap has been saved to 'output/interactive_heatmap_group2.html'
# Heatmap has been saved to 'output/interactive_heatmap_group3.html'
