In [None]:
import pandas as pd

# Vehicle registrations per category and date.
# The path uses forward slashes, which resolve on Windows, macOS and Linux;
# the previous backslash-separated literal could only be opened on Windows.
df1 = pd.read_csv('datasets/ev_cat_01-24.csv')
df1

In [None]:
import plotly.express as px

# Every column other than 'Date' is summed as a per-category count.
vehicle_columns = list(filter(lambda name: name != 'Date', df1.columns))

# Column-wise totals over all rows: one value per vehicle column.
total_vehicles = df1[vehicle_columns].sum()

# Two-column table for Plotly; entries whose total is not positive are left
# out so they do not appear as slices.
pie_data = pd.DataFrame({
    'Vehicle Type': total_vehicles.index,
    'Total Count': total_vehicles.values,
})
pie_data = pie_data.loc[pie_data['Total Count'] > 0]

fig = px.pie(
    pie_data,
    names='Vehicle Type',
    values='Total Count',
    hole=0.3,  # non-zero hole renders the pie as a donut
    title='Distribution of Vehicle Types (Total Sum)',
)
fig.show()


Hence the market is dominated by personal two-wheelers, commercial three-wheelers and LMVs (light motor vehicles), making these our ideal target segments. However, three-wheelers are much cheaper in general, as they are mainly autos and totos, so we have to keep that in mind.

In [None]:
# Columns charted from `df_sorted` by this cell and the cells that follow.
TREND_COLUMNS = ['TWO WHEELER(NT)', 'THREE WHEELER(T)', 'LIGHT MOTOR VEHICLE']

# Coerce each charted column that exists to numeric; unparseable values become NaN.
numeric_updates = {
    column: pd.to_numeric(df1[column], errors='coerce')
    for column in TREND_COLUMNS
    if column in df1.columns
}

# Build the cleaned, date-ordered frame without mutating `df1`, so the cell can
# be re-run and earlier cells keep seeing the data as loaded.
# Dates that do not match dd/mm/yy become NaT and those rows are dropped.
df_sorted = (
    df1.assign(
        Date=pd.to_datetime(df1['Date'], errors='coerce', format='%d/%m/%y'),
        **numeric_updates,
    )
    .dropna(subset=['Date'])
    .sort_values(by='Date')
)

# Rows with a missing two-wheeler value are removed only for THIS chart.
# Dropping them from `df_sorted` itself would also remove those dates from the
# three-wheeler and LMV charts that reuse the same frame.
two_wheeler_trend = df_sorted.dropna(subset=['TWO WHEELER(NT)'])

fig = px.line(two_wheeler_trend,
              x='Date',
              y='TWO WHEELER(NT)',
              title='TWO WHEELER (NT) Sales Over Time',
              labels={'Date': 'Date', 'TWO WHEELER(NT)': 'Sales Count'})

# Range-selector buttons plus a range slider for interactive zooming.
fig.update_xaxes(
    rangeselector=dict(
        buttons=list([
            dict(count=1, label="1m", step="month", stepmode="backward"),
            dict(count=6, label="6m", step="month", stepmode="backward"),
            dict(count=1, label="YTD", step="year", stepmode="todate"),
            dict(count=1, label="1y", step="year", stepmode="backward"),
            dict(step="all")
        ])
    ),
    rangeslider=dict(
        visible=True
    ),
    type="date"
)
fig.show()

In [None]:
# Time series of the 'THREE WHEELER(T)' column, drawn from the date-sorted
# frame `df_sorted` built in an earlier cell.
fig = px.line(
    df_sorted,
    x='Date',
    y='THREE WHEELER(T)',
    labels={'Date': 'Date', 'THREE WHEELER(T)': 'Sales Count'},
    title='THREE WHEELER (T) Sales Over Time',
)

# Zoom presets shown above the chart; the final entry restores the full range.
zoom_presets = [
    {'count': 1, 'label': "1m", 'step': "month", 'stepmode': "backward"},
    {'count': 6, 'label': "6m", 'step': "month", 'stepmode': "backward"},
    {'count': 1, 'label': "YTD", 'step': "year", 'stepmode': "todate"},
    {'count': 1, 'label': "1y", 'step': "year", 'stepmode': "backward"},
    {'step': "all"},
]

fig.update_xaxes(
    type="date",
    rangeslider={'visible': True},
    rangeselector={'buttons': zoom_presets},
)
fig.show()

In [None]:
# Time series of the 'LIGHT MOTOR VEHICLE' column, drawn from the date-sorted
# frame `df_sorted` built in an earlier cell.
fig = px.line(
    df_sorted,
    x='Date',
    y='LIGHT MOTOR VEHICLE',
    labels={'Date': 'Date', 'LIGHT MOTOR VEHICLE': 'Sales Count'},
    title='LIGHT MOTOR VEHICLE Sales Over Time',
)

# Zoom presets shown above the chart; the final entry restores the full range.
zoom_presets = [
    {'count': 1, 'label': "1m", 'step': "month", 'stepmode': "backward"},
    {'count': 6, 'label': "6m", 'step': "month", 'stepmode': "backward"},
    {'count': 1, 'label': "YTD", 'step': "year", 'stepmode': "todate"},
    {'count': 1, 'label': "1y", 'step': "year", 'stepmode': "backward"},
    {'step': "all"},
]

fig.update_xaxes(
    type="date",
    rangeslider={'visible': True},
    rangeselector={'buttons': zoom_presets},
)
fig.show()

The charts show that while the Three Wheeler (T) and LMV markets are booming, the Two Wheeler (NT) market has declined. With this in mind, we proceed to the analysis of the electric vehicle market.

In [None]:
# EV sales per maker and vehicle category, one column per year.
# The path uses forward slashes, which resolve on Windows, macOS and Linux;
# the previous backslash-separated literal could only be opened on Windows.
df2 = pd.read_csv('datasets/ev_sales_by_makers_and_cat_15-24.csv')
df2

In [None]:
# Distinct values of the 'Cat' column in the maker-level sales table.
df2.loc[:, 'Cat'].unique()

In [None]:
# Work on a frame without the 'Maker' column; if the column is absent, a plain
# copy is used so later steps never modify `df2` itself.
maker_present = 'Maker' in df2.columns
df2_no_maker = df2.drop(columns=['Maker']) if maker_present else df2.copy()

print(
    "\nDataFrame after dropping 'Maker' column:"
    if maker_present
    else "\n'Maker' column not found, proceeding with original DataFrame."
)

df2_no_maker

In [None]:
# Columns to total: every numeric column except the grouping key 'Cat'.
numeric_cols = [
    column
    for column in df2_no_maker.select_dtypes(include=['number']).columns.tolist()
    if column != 'Cat'
]

# One row per category, each numeric column summed within the category.
df2_grouped = df2_no_maker.groupby('Cat')[numeric_cols].sum()

print("\nDataFrame after grouping by 'Cat' and summing numeric columns:")
df2_grouped

# 'Cat' is left as the index here; use df2_grouped.reset_index() where a
# regular column is needed instead.

In [None]:
# Row total across the remaining numeric columns of each category.
# Any existing 'Sum' column is excluded first: summing the whole frame on a
# re-run would fold the previous total back into the new one and inflate it.
df2_grouped['Sum'] = df2_grouped.drop(columns=['Sum'], errors='ignore').sum(axis=1)
df2_grouped

In [None]:
# Donut chart of each category's share of the 'Sum' column.
fig = px.pie(
    df2_grouped,
    names=df2_grouped.index,  # the groupby left the 'Cat' labels in the index
    values='Sum',
    hole=0.3,
    title='Distribution of Total Sales by Vehicle Category (Cat)',
)

# Render inline in the notebook.
fig.show()

Comparing this with the earlier pie chart, we see that the proportions are almost equal. This means that EVs constitute a similar share compared to their fuelled counterparts in each of the 2W, 3W and LMV market segments. In the MMV market, the share of EVs relative to fuelled vehicles is significantly higher (0.0182% as compared to 0.0111%); hence it is not a suitable market for startups. Now let us look at the demand over time using the yearly graphs.

In [None]:


# Expects `df2_grouped` from the earlier cells: indexed by 'Cat', with one
# column per year plus the 'Sum' total, for example:
#
#        2015   2016  ...   2024    Sum
# Cat
# 2W      100    150  ...    800   3000
# 3W       50     70  ...    400   1500

print("df2_grouped before melting:")
print(df2_grouped)

# Select the year columns by label, which also skips 'Sum'. The labels are kept
# exactly as they appear in the frame (str or int): converting them to str, as
# was done before, makes pd.melt raise a KeyError when the labels are integers.
FIRST_YEAR, LAST_YEAR = 2015, 2024
year_columns = [
    col for col in df2_grouped.columns
    if str(col).isdigit() and FIRST_YEAR <= int(col) <= LAST_YEAR
]
if not year_columns:
    raise ValueError(
        f"No year columns between {FIRST_YEAR} and {LAST_YEAR} found in df2_grouped"
    )

# Reset the index so 'Cat' becomes a regular column usable as an identifier.
df_plot = df2_grouped.reset_index()

# Wide -> long: one row per (Cat, Year) pair with the value in 'Sales'.
df_melted = pd.melt(df_plot,
                    id_vars=['Cat'],
                    value_vars=year_columns,
                    var_name='Year',
                    value_name='Sales')

# Numeric years so the x-axis is ordered numerically rather than as text.
df_melted['Year'] = pd.to_numeric(df_melted['Year'])

print("\nMelted DataFrame for plotting:")
print(df_melted.head())

# One line per category.
fig = px.line(df_melted,
              x='Year',
              y='Sales',
              color='Cat',
              title='Sales Trends for Each Vehicle Category Over Years',
              labels={'Sales': 'Total Sales', 'Year': 'Year'},
              hover_data={'Cat': True, 'Sales': ':,.0f'})  # thousands-separated sales on hover

fig.update_traces(mode='lines+markers')   # show the yearly points on each line
fig.update_layout(hovermode="x unified")  # one hover box per year covering all lines

fig.show()

# To export the chart as a standalone page:
# fig.write_html('category_sales_trends_line_plot.html')

In [None]:

# Needs `df_melted` (columns 'Cat', 'Year', 'Sales') from the previous cell,
# which is built by melting `df2_grouped`.

# Categories that each get their own chart.
categories_to_plot = ['2W', '3W', 'LMV']

for category in categories_to_plot:
    # Rows belonging to this category only.
    df_filtered = df_melted.loc[df_melted['Cat'] == category]

    # Nothing to draw: report it and move on to the next category.
    if df_filtered.empty:
        print(f"No data found for category: {category}. Skipping plot.")
        continue

    fig = px.line(
        df_filtered,
        x='Year',
        y='Sales',
        markers=True,  # mark each yearly data point
        labels={'Sales': 'Total Sales', 'Year': 'Year'},
        title=f'Sales Trend for {category} Over Years',
    )
    fig.update_layout(hovermode="x unified")
    fig.show()

    # To export a chart as a standalone page:
    # fig.write_html(f'{category}_sales_trend_line_plot.html')

The graphs show that the growth in demand for electric 2Ws is saturating, the growth for 3Ws is linear, and that for LMVs is exponential. Thus, LMVs show the fastest market growth and offer the greatest scope for entry by startup companies.

In [None]:
# Operational PCS counts per state.
# The path uses forward slashes, which resolve on Windows, macOS and Linux;
# the previous backslash-separated literal could only be opened on Windows.
df3 = pd.read_csv('datasets/OperationalPC.csv')
df3.head()

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import json
import requests

def create_india_map_plotly(df3):
    """
    Show a state-level choropleth of 'No. of Operational PCS' using Plotly Express.

    State boundaries are downloaded as GeoJSON when the function is called.
    If the download or the rendering fails, the error is printed and the
    function falls back to ``create_bar_chart`` so a figure is still produced.

    Parameters
    ----------
    df3 : pandas.DataFrame
        Needs a string column 'State' and a column 'No. of Operational PCS'.
        The frame is copied, not modified.

    Returns
    -------
    None
    """
    # Clean state names to match standard names
    df_clean = df3.copy()
    df_clean['State'] = df_clean['State'].str.strip()

    # Spelling variants mapped to the names used for matching GeoJSON features
    state_mapping = {
        'Andaman & Nicobar': 'Andaman and Nicobar Islands',
        'Dadra & Nagar Haveli': 'Dadra and Nagar Haveli',
        'Jammu & Kashmir': 'Jammu and Kashmir',
        'NCT of Delhi': 'Delhi',
        'Telangana': 'Telangana'  # Sometimes needed for consistency
    }

    # Names without a mapping entry are kept unchanged
    df_clean['State_Standard'] = df_clean['State'].map(state_mapping).fillna(df_clean['State'])

    # Download India geojson (you can also use a local file)
    url = "https://raw.githubusercontent.com/geohacker/india/master/state/india_state.geojson"

    try:
        # The timeout keeps the cell from hanging on a stalled connection;
        # raise_for_status turns an HTTP error page into an exception instead
        # of a confusing JSON decoding failure.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        india_geojson = response.json()

        # Create the choropleth map
        fig = px.choropleth(
            df_clean,
            geojson=india_geojson,
            locations='State_Standard',
            color='No. of Operational PCS',
            hover_name='State',
            hover_data={'No. of Operational PCS': True, 'State_Standard': False},
            featureidkey="properties.NAME_1",
            color_continuous_scale='Viridis',
            title='Number of Operational PCS by State in India'
        )

        # Zoom to the plotted states and hide the base map
        fig.update_geos(
            fitbounds="locations",
            visible=False
        )

        fig.update_layout(
            title_x=0.5,
            font=dict(size=12),
            coloraxis_colorbar=dict(
                title="No. of Operational PCS"
            )
        )

        fig.show()

    except Exception as e:
        # Deliberately broad: any failure here should degrade to the bar chart
        # rather than stop the notebook. The previous fallback called
        # create_india_map_alternative, which is not defined anywhere and
        # raised a NameError from inside this handler.
        print(f"Error with online geojson: {e}")
        print("Trying alternative method...")
        create_bar_chart(df_clean)


def create_bar_chart(df3):
    """
    Draw a horizontal matplotlib bar chart of 'No. of Operational PCS' per state.

    Bars are ordered by ascending count. The figure is shown with
    ``plt.show()``; nothing is returned and ``df3`` is not modified.
    """
    import matplotlib.pyplot as plt

    pcs_column = 'No. of Operational PCS'

    # Ascending order of the counts decides the order of the bars
    ranked = df3.sort_values(pcs_column, ascending=True)

    plt.figure(figsize=(12, 10))
    plt.barh(ranked['State'], ranked[pcs_column])
    plt.xlabel('Number of Operational PCS')
    plt.title('Number of Operational PCS by State')
    plt.tight_layout()

    # Slant the x tick labels
    plt.xticks(rotation=45)
    plt.show()

if __name__ == "__main__":
    # Render the choropleth first; it prints its own message if it has to fall back
    print("Creating India map visualization...")
    create_india_map_plotly(df3)
    create_bar_chart(df3)  # Always shown as well: the same counts as a ranked bar chart

Thus, we see that Maharashtra, Delhi and the southern states (Karnataka, Kerala and Tamil Nadu) have the highest number of charging stations, and hence are the largest target geographical zones for the EV market.

In [None]:
# EV sales records used for the state x vehicle-class x month analysis.
# The path uses forward slashes, which resolve on Windows, macOS and Linux;
# the previous backslash-separated literal could only be opened on Windows.
df4 = pd.read_csv('datasets/EV_Dataset.csv')
df4.head()

In [None]:
import xarray as xr

# Three-letter month abbreviation -> zero-padded month number ('jan' -> '01').
month_mapping = dict(zip(
    ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
     'jul', 'aug', 'sep', 'oct', 'nov', 'dec'],
    [f'{number:02d}' for number in range(1, 13)],
))

# Period key per row. Note: the column is called 'mm-yy' but the strings it
# holds are built as '<year>-<month number>', e.g. '2021-03'.
year_text = df4['Year'].astype(int).astype(str)
month_text = df4['Month_Name'].str.lower().map(month_mapping)
df4['mm-yy'] = year_text + '-' + month_text

# Total EV sales for every (period, state, vehicle class) combination.
cube_dims = ['mm-yy', 'State', 'Vehicle_Class']
grouped = df4.groupby(cube_dims)['EV_Sales_Quantity'].sum().reset_index()

# The same totals as a labelled 3-D array with one dimension per key.
data_array = grouped.set_index(cube_dims)['EV_Sales_Quantity'].to_xarray()

print(f"Shape: {data_array.shape}")
print(f"Dimensions: {data_array.dims}")
print(f"Time periods: {sorted(data_array.coords['mm-yy'].values)}")

In [None]:
import pandas as pd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
import os
from matplotlib.backends.backend_pdf import PdfPages
import warnings
warnings.filterwarnings('ignore')

import re

def sanitize_filename(name):
    """Return `name` rewritten so it can be used as a file name.

    '&' is spelled out as 'and', spaces become underscores, and the characters
    backslash, slash, colon, asterisk, question mark, double quote, angle
    brackets and pipe are removed.
    """
    readable = name.replace('&', 'and')
    underscored = readable.replace(' ', '_')
    return re.sub(r'[\\/:*?"<>|]', '', underscored)

# Plot styling for the figures generated below: select matplotlib's 'default'
# style first, then set seaborn's "husl" palette.
# NOTE(review): the order looks intentional (style first, palette second) — confirm before reordering.
plt.style.use('default')
sns.set_palette("husl")

# Output folders for the generated plots, one per plot family
def create_directories():
    """Create the plot output folders under 'plots/' and print their paths.

    Folders that already exist are left as they are. Paths are relative to the
    current working directory. Returns None.
    """
    directories = (
        'plots/vehicle_types_over_time',
        'plots/state_pie_charts',
        'plots/states_over_time',
        'plots/state_vehicle_combinations',
    )

    # Create everything first, then report.
    for path in directories:
        os.makedirs(path, exist_ok=True)

    print("Created directory structure:")
    print(*(f"  - {path}" for path in directories), sep="\n")

# Make sure the output folders exist before any figure is saved.
create_directories()

# Flatten the 3-D sales array back into a long table and parse the
# 'YYYY-MM' period strings into timestamps for time-axis plotting.
df_plot = (
    data_array.to_dataframe()
    .reset_index()
    .assign(**{'mm-yy': lambda frame: pd.to_datetime(frame['mm-yy'], format='%Y-%m')})
)


print("Data prepared. Starting plot generation...")

# ----------------------------------------------------------------------------
# 1. One line chart per vehicle class: sales over time, summed over all states
# ----------------------------------------------------------------------------

print("1. Creating line graphs for each vehicle type over time...")

# Sales per (period, vehicle class), with the state dimension summed away.
vehicle_time_data = (
    df_plot.groupby(['mm-yy', 'Vehicle_Class'])['EV_Sales_Quantity']
    .sum()
    .reset_index()
)

vehicle_classes = sorted(vehicle_time_data['Vehicle_Class'].unique())

for i, vehicle in enumerate(vehicle_classes):
    vehicle_data = vehicle_time_data.loc[vehicle_time_data['Vehicle_Class'] == vehicle]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(vehicle_data['mm-yy'], vehicle_data['EV_Sales_Quantity'],
            marker='o', linewidth=2, markersize=4)
    ax.set_title(f'{vehicle} - Sales Over Time (All States)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Time Period', fontsize=12)
    ax.set_ylabel('Total Sales', fontsize=12)

    # Label every 12th period to keep the axis readable.
    xtick_locs = vehicle_data['mm-yy'].iloc[::12]
    ax.set_xticks(xtick_locs)
    ax.set_xticklabels(xtick_locs.dt.strftime('%Y-%m'), rotation=45)

    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # File names carry a two-digit running number followed by the class name.
    safe_filename = sanitize_filename(vehicle)
    filename = f"plots/vehicle_types_over_time/{i+1:02d}_{safe_filename}.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"   Saved: {filename}")

# ----------------------------------------------------------------------------
# 2. One pie chart per state: vehicle-class mix, summed over the whole period
# ----------------------------------------------------------------------------

print("2. Creating pie charts for each state by vehicle type...")

# Sales per (state, vehicle class), with the time dimension summed away.
state_vehicle_data = (
    df_plot.groupby(['State', 'Vehicle_Class'])['EV_Sales_Quantity']
    .sum()
    .reset_index()
)

states = sorted(state_vehicle_data['State'].unique())

for i, state in enumerate(states):
    state_data = state_vehicle_data.loc[state_vehicle_data['State'] == state]
    # Only classes with positive sales get a slice.
    state_data = state_data.loc[state_data['EV_Sales_Quantity'] > 0]

    # A state without any positive total gets no chart.
    if state_data.empty:
        continue

    fig, ax = plt.subplots(figsize=(10, 8))
    colors = sns.color_palette("husl", len(state_data))

    wedges, texts, autotexts = ax.pie(
        state_data['EV_Sales_Quantity'],
        labels=state_data['Vehicle_Class'],
        colors=colors,
        autopct='%1.1f%%',
        startangle=90,
    )

    ax.set_title(f'{state} - Vehicle Type Distribution (Total Sales)',
                 fontsize=14, fontweight='bold')

    # Bold white percentage labels inside the wedges.
    for autotext in autotexts:
        autotext.set_color('white')
        autotext.set_fontweight('bold')

    fig.tight_layout()

    safe_filename = sanitize_filename(state)
    filename = f"plots/state_pie_charts/{i+1:02d}_{safe_filename}.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"   Saved: {filename}")

# ============================================================================
# 3. LINE GRAPHS FOR EVERY STATE OVER TIME (TOTALLED OVER ALL VEHICLE TYPES)
# ============================================================================

print("3. Creating line graphs for each state over time...")

# Sales per (period, state), with the vehicle-class dimension summed away.
state_time_data = df_plot.groupby(['mm-yy', 'State'])['EV_Sales_Quantity'].sum().reset_index()

# One figure per state, using the `states` list built in section 2.
for i, state in enumerate(states):
    state_data = state_time_data[state_time_data['State'] == state]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(state_data['mm-yy'], state_data['EV_Sales_Quantity'],
            marker='o', linewidth=2, markersize=4, color='steelblue')
    ax.set_title(f'{state} - Total EV Sales Over Time', fontsize=14, fontweight='bold')
    ax.set_xlabel('Time Period', fontsize=12)
    ax.set_ylabel('Total Sales', fontsize=12)

    # Label every 12th period of THIS state's series. The ticks were previously
    # taken from `vehicle_data`, a variable left over from section 1's loop, so
    # the axis depended on unrelated data and on cell execution order.
    xtick_locs = state_data['mm-yy'].iloc[::12]
    ax.set_xticks(xtick_locs)
    ax.set_xticklabels(xtick_locs.dt.strftime('%Y-%m'), rotation=45)

    ax.grid(True, alpha=0.3)
    fig.tight_layout()

    # Save as PNG
    safe_filename = sanitize_filename(state)
    filename = f"plots/states_over_time/{i+1:02d}_{safe_filename}.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure; many are created in this loop

    print(f"   Saved: {filename}")

# ============================================================================
# 4. LINE GRAPHS FOR EVERY STATE FOR EVERY VEHICLE TYPE OVER TIME
# ============================================================================

print("4. Creating line graphs for each state-vehicle combination over time...")

# One figure per state, holding a grid of small charts (one per vehicle class).
for i, state in enumerate(states):
    state_data = df_plot[df_plot['State'] == state]
    state_vehicles = sorted(state_data['Vehicle_Class'].unique())

    n_vehicles = len(state_vehicles)
    if n_vehicles == 0:
        # Nothing to draw; also avoids a division by zero in the grid maths.
        print(f"   Skipped: {state} (no vehicle classes)")
        continue

    cols = min(3, n_vehicles)  # Max 3 columns
    rows = (n_vehicles + cols - 1) // cols  # Ceiling division: rows needed

    # squeeze=False always returns a 2-D array of axes, so a single flatten()
    # covers the 1x1, 1xN and MxN layouts alike.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5*rows), squeeze=False)
    axes = axes.flatten()
    fig.suptitle(f'{state} - All Vehicle Types Over Time', fontsize=16, fontweight='bold')

    for ax, vehicle in zip(axes, state_vehicles):
        vehicle_data = state_data[state_data['Vehicle_Class'] == vehicle]

        ax.plot(vehicle_data['mm-yy'], vehicle_data['EV_Sales_Quantity'],
                marker='o', linewidth=2, markersize=3)
        ax.set_title(vehicle, fontsize=10, fontweight='bold')

        # Set the ticks on this subplot's own axes. plt.xticks acts on the
        # "current" axes, which is the last subplot created, so every other
        # panel was left with default ticks.
        xtick_locs = vehicle_data['mm-yy'].iloc[::12]
        ax.set_xticks(xtick_locs)
        ax.set_xticklabels(xtick_locs.dt.strftime('%Y-%m'), rotation=45)

        ax.tick_params(axis='y', labelsize=8)
        ax.grid(True, alpha=0.3)

    # Hide the unused cells of the grid
    for unused_ax in axes[n_vehicles:]:
        unused_ax.set_visible(False)

    fig.tight_layout()

    # Save as PNG
    safe_filename = sanitize_filename(state)
    filename = f"plots/state_vehicle_combinations/{i+1:02d}_{safe_filename}_all_vehicles.png"
    fig.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure; many are created in this loop

    print(f"   Saved: {filename}")

# ----------------------------------------------------------------------------
# 5. Summary table (states x vehicle classes, summed over time) and heatmap
# ----------------------------------------------------------------------------

print("5. Creating summary dataframe and heatmap...")

# Long table of totals per (state, vehicle class).
summary_data = (
    df_plot.groupby(['State', 'Vehicle_Class'])['EV_Sales_Quantity']
    .sum()
    .reset_index()
)

# Wide table: one row per state, one column per vehicle class; gaps become 0.
summary_df = summary_data.pivot(
    index='State',
    columns='Vehicle_Class',
    values='EV_Sales_Quantity',
).fillna(0)

print("Summary DataFrame - States × Vehicle Types (Total Sales Over Entire Time Period):")
print(summary_df.head())

# Persist the table next to the notebook.
summary_df.to_csv('ev_sales_summary_states_x_vehicles.csv')
print("   Saved: ev_sales_summary_states_x_vehicles.csv")

# Heatmap of the transposed table: vehicle classes as rows, states as columns.
plt.figure(figsize=(16, 10))
sns.heatmap(
    summary_df.T,
    cmap='YlOrRd',
    annot=False,
    cbar_kws={'label': 'Total Sales'},
)
plt.title("EV Sales Heatmap: Vehicle Types × States (Total Over Time)",
          fontsize=16, fontweight='bold')
plt.xlabel('State', fontsize=12)
plt.ylabel('Vehicle Class', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

plt.savefig("plots/summary_heatmap.png", dpi=300, bbox_inches='tight')
plt.close()
print("   Saved: plots/summary_heatmap.png")

def _save_top_bar_chart(totals, *, title, xlabel, color, figsize, path):
    """Save a labelled bar chart of `totals` (a Series) to `path`; return the bars."""
    positions = range(len(totals))

    plt.figure(figsize=figsize)
    drawn = plt.bar(positions, totals.values, color=color)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.xlabel(xlabel, fontsize=12)
    plt.ylabel('Total Sales', fontsize=12)
    plt.xticks(positions, totals.index, rotation=45, ha='right')

    # Print each value just above its bar (offset: 1% of the tallest bar).
    for bar, value in zip(drawn, totals.values):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(totals.values)*0.01,
                 f'{value:,.0f}', ha='center', va='bottom', fontsize=10)

    plt.grid(True, alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"   Saved: {path}")
    return drawn


# Headline numbers taken from the states x vehicle-classes summary table.
print("\nSUMMARY STATISTICS:")
print(f"DataFrame Shape: {summary_df.shape}")
print(f"Total Sales Across All States and Vehicle Types: {summary_df.sum().sum():,.0f}")

# Row sums rank the states; column sums rank the vehicle classes.
print("\nTop 5 States by Total Sales:")
top_states = summary_df.sum(axis=1).sort_values(ascending=False).head()
for state, sales in top_states.items():
    print(f"  {state}: {sales:,.0f}")

print("\nTop 5 Vehicle Types by Total Sales:")
top_vehicles = summary_df.sum(axis=0).sort_values(ascending=False).head()
for vehicle, sales in top_vehicles.items():
    print(f"  {vehicle}: {sales:,.0f}")

print("6. Creating summary statistics plots...")

bars = _save_top_bar_chart(
    top_states,
    title="Top 5 States by Total EV Sales",
    xlabel='State',
    color='steelblue',
    figsize=(10, 6),
    path="plots/top_states_bar_chart.png",
)

bars = _save_top_bar_chart(
    top_vehicles,
    title="Top 5 Vehicle Types by Total EV Sales",
    xlabel='Vehicle Type',
    color='darkorange',
    figsize=(12, 6),
    path="plots/top_vehicles_bar_chart.png",
)

print(f"\n✅ ALL PLOTS SAVED SUCCESSFULLY!")
print(f"📁 Check the 'plots/' directory for all generated visualizations")

# Pie charts are only written for states with at least one positive total.
# A single mask on `state_vehicle_data` replaces the earlier chained boolean
# indexing, which applied a mask built from a different (unfiltered) index.
has_sales = state_vehicle_data['EV_Sales_Quantity'] > 0
n_pie_charts = state_vehicle_data.loc[has_sales, 'State'].nunique()

# Stand-alone summary figures: heatmap, top-states chart, top-vehicles chart.
# The total previously added 4 here although only these 3 files are written.
N_SUMMARY_PLOTS = 3

total_plots = (len(vehicle_classes) +   # section 1: one per vehicle class
               n_pie_charts +           # section 2: one per state with sales
               len(states) +            # section 3: one per state
               len(states) +            # section 4: one grid per state
               N_SUMMARY_PLOTS)

print(f"📊 Total plots created: {total_plots}")

# Print directory structure
print(f"\n📂 DIRECTORY STRUCTURE:")
print(f"plots/")
print(f"├── vehicle_types_over_time/     ({len(vehicle_classes)} plots)")
print(f"├── state_pie_charts/            (pie charts for states with data)")
print(f"├── states_over_time/            ({len(states)} plots)")
print(f"├── state_vehicle_combinations/  ({len(states)} plots)")
print(f"├── summary_heatmap.png")
print(f"├── top_states_bar_chart.png")
print(f"└── top_vehicles_bar_chart.png")

# Display final summary dataframe
print(f"\n📋 FINAL SUMMARY DATAFRAME:")
print(summary_df)

In [None]:
# How many states / vehicle classes to rank in this cell.
TOP_N_STATES = 10
TOP_N_VEHICLES = 5

# Display some statistics
print(f"\nSUMMARY STATISTICS:")
print(f"DataFrame Shape: {summary_df.shape}")
print(f"Total Sales Across All States and Vehicle Types: {summary_df.sum().sum():,.0f}")

# The headings and the chart title are derived from the actual number of rows
# selected. They previously said "Top 5" while ten states were listed and plotted.
top_states = summary_df.sum(axis=1).sort_values(ascending=False).head(TOP_N_STATES)
print(f"\nTop {len(top_states)} States by Total Sales:")
for state, sales in top_states.items():
    print(f"  {state}: {sales:,.0f}")

top_vehicles = summary_df.sum(axis=0).sort_values(ascending=False).head(TOP_N_VEHICLES)
print(f"\nTop {len(top_vehicles)} Vehicle Types by Total Sales:")
for vehicle, sales in top_vehicles.items():
    print(f"  {vehicle}: {sales:,.0f}")

# Create summary statistics plots
print("6. Creating summary statistics plots...")

# Top states bar chart, shown inline rather than saved
plt.figure(figsize=(10, 6))
bars = plt.bar(range(len(top_states)), top_states.values, color='steelblue')
plt.title(f"Top {len(top_states)} States by Total EV Sales", fontsize=14, fontweight='bold')
plt.xlabel('State', fontsize=12)
plt.ylabel('Total Sales', fontsize=12)
plt.xticks(range(len(top_states)), top_states.index, rotation=45, ha='right')

# Add value labels just above the bars (offset: 1% of the tallest bar)
for bar, value in zip(bars, top_states.values):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(top_states.values)*0.01,
             f'{value:,.0f}', ha='center', va='bottom', fontsize=10)

plt.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.show()