### Load the Data Source

In [None]:
-- Select every column of every row from the public.grocery_sales table.
SELECT *
FROM public.grocery_sales;

### Importing other libraries

In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt

# Start here...

def extract(df1, df2_path):
    df2 = pd.read_parquet(df2_path)
    full_df = pd.merge(df1, df2, how = "left", on = "index")
    return full_df

# Run the extract step: merge grocery_sales with the data in extra_data.parquet.
# NOTE(review): grocery_sales is not defined by any Python code in this file;
# presumably it is the DataFrame produced by the SQL cell above -- confirm.
merged_df = extract(grocery_sales, "extra_data.parquet")

In [None]:
def transform(merged_df):
    """Clean the merged sales data and keep only the reporting columns.

    Missing values in CPI, Weekly_Sales and Unemployment are replaced by the
    mean of their column, rows with Weekly_Sales of 10000 or less are dropped,
    and a Month column is derived from Date.

    The input frame is left untouched; all work happens on new objects.

    Args:
        merged_df: DataFrame containing at least the CPI, Weekly_Sales,
            Unemployment and Date columns. Date must support the ``.dt``
            accessor (i.e. be datetime-like).

    Returns:
        A new DataFrame restricted to whichever of Store_ID, Month, Dept,
        IsHoliday, Weekly_Sales, CPI and Unemployment are present.
    """
    # Means are taken over the full input, before any rows are filtered out.
    fill_values = {
        column: merged_df[column].mean()
        for column in ("CPI", "Weekly_Sales", "Unemployment")
    }
    # Non-inplace fill so the caller's DataFrame is not mutated.
    filled_df = merged_df.fillna(fill_values)

    # Explicit copy: assigning a column to a boolean-filtered slice triggers
    # SettingWithCopyWarning and is not guaranteed to take effect.
    clean_df = filled_df[filled_df["Weekly_Sales"] > 10000].copy()
    clean_df["Month"] = clean_df["Date"].dt.month

    return clean_df.filter(
        items=[
            "Store_ID",
            "Month",
            "Dept",
            "IsHoliday",
            "Weekly_Sales",
            "CPI",
            "Unemployment",
        ]
    )
# Run the transform step and preview the first ten cleaned rows.
clean_data = transform(merged_df)
print(clean_data.head(10))

In [None]:
def avg_monthly_sales(clean_data):
    """Compute the mean of Weekly_Sales for each Month.

    Args:
        clean_data: DataFrame with "Month" and "Weekly_Sales" columns.

    Returns:
        DataFrame with one row per Month and an "Avg_Sales" column, with
        values rounded to two decimal places.
    """
    by_month = clean_data.groupby("Month")
    monthly_means = by_month.agg(Avg_Sales=("Weekly_Sales", "mean"))
    flattened = monthly_means.reset_index()
    return flattened.round(2)
# Aggregate the cleaned data into one average-sales row per month.
agg_data = avg_monthly_sales(clean_data)

In [None]:
def load(agg_data, clean_data, agg_save_path, clean_save_path):
    """Write both DataFrames to CSV without their index.

    Args:
        agg_data: Aggregated DataFrame, written to ``agg_save_path``.
        clean_data: Cleaned DataFrame, written to ``clean_save_path``.
        agg_save_path: Destination path for the aggregated CSV.
        clean_save_path: Destination path for the cleaned CSV.
    """
    outputs = (
        (agg_data, agg_save_path),
        (clean_data, clean_save_path),
    )
    for frame, destination in outputs:
        frame.to_csv(destination, index=False)
# Persist both outputs as CSV files; the relative paths resolve against the
# current working directory.
load(agg_data, clean_data, "agg_data.csv", "clean_data.csv")

In [None]:
def validation(data_path, label="Agg data"):
    """Report whether ``data_path`` exists.

    Args:
        data_path: Path checked with ``os.path.exists``.
        label: Name used for the data set in the returned message. Defaults
            to "Agg data" so existing single-argument calls return exactly
            the same text as before; pass another label when checking a
            different file.

    Returns:
        "<label> file is in the home directory" when the path exists,
        otherwise "<label> does not exist".
    """
    if os.path.exists(data_path):
        return f"{label} file is in the home directory"
    return f"{label} does not exist"
# Print each result explicitly: a bare call's return value is discarded when
# this runs as a script, and a notebook cell only echoes its last expression.
print(validation("agg_data.csv"))
print(validation("clean_data.csv"))

In [None]:
# Line chart of the monthly average sales held in agg_data.
plt.figure(figsize=(10, 6))
plt.plot(
    agg_data['Month'],
    agg_data['Avg_Sales'],
    color='#87CEEB',
    linestyle='-',
    marker='o',
)

# Dashed red vertical line at month 9, with a 'New Policy' label just to its right.
plt.axvline(x=9, linewidth=1, linestyle='--', color='red')
plt.text(9.2, 38000, 'New Policy', fontsize=12, color='black', rotation=0)

# Axis labels and chart title
plt.xlabel('Month')
plt.ylabel('Average Sales')
plt.title('Average Sales by Month')

# Tilt the x tick labels, keep the grid off, then render the chart
plt.xticks(rotation=45)
plt.grid(False)
plt.show()

### Average sales grow substantially after the new policy is introduced and the holiday season begins

In [None]:
# Mean of Weekly_Sales per department, rounded to two decimals.
dept_data = (
    clean_data
    .groupby("Dept")
    .agg(Avg_Sales=("Weekly_Sales", "mean"))
    .reset_index()
    .round(2)
)

# Ten departments with the highest and the lowest average sales.
top_10 = dept_data.sort_values(by="Avg_Sales", ascending=False).head(10)
bottom_10 = dept_data.sort_values(by="Avg_Sales", ascending=True).head(10)

### Plotting the top 10 departments

In [None]:
# Bar chart of the ten departments with the highest average sales.
# figsize is passed to DataFrame.plot() directly: it opens its own figure, so a
# preceding plt.figure() call would only leave an empty figure behind.
# x/y are given explicitly so the bars are labelled by department and only
# Avg_Sales is drawn; by default every numeric column is plotted against the
# row index.
top_10.plot(
    kind='bar',
    x='Dept',
    y='Avg_Sales',
    color='#87CEEB',
    figsize=(8, 6),
    legend=False,  # single series; the y-axis label already names it
)

# Add titles and labels (the plotted values are means, not totals)
plt.title('Top 10 Performing Departments by Average Sales')
plt.xlabel('Department')
plt.ylabel('Average Sales')

# Rotate x-axis labels for better readability
plt.xticks(rotation=45)

# Display the plot
plt.show()

![Bar chart of the top 10 departments](../../../images/ETL_graph2.png)

### Plotting the bottom 10 departments

In [None]:
# Bar chart of the ten departments with the lowest average sales.
# Plots bottom_10 (this cell previously re-plotted top_10 under a "Top 10" title).
# figsize is passed to DataFrame.plot() directly: it opens its own figure, so a
# preceding plt.figure() call would only leave an empty figure behind.
# x/y are given explicitly so the bars are labelled by department and only
# Avg_Sales is drawn; by default every numeric column is plotted against the
# row index.
bottom_10.plot(
    kind='bar',
    x='Dept',
    y='Avg_Sales',
    color='#87CEEB',
    figsize=(8, 6),
    legend=False,  # single series; the y-axis label already names it
)

# Add titles and labels (the plotted values are means, not totals)
plt.title('Bottom 10 Performing Departments by Average Sales')
plt.xlabel('Department')
plt.ylabel('Average Sales')

# Rotate x-axis labels for better readability
plt.xticks(rotation=45)

# Display the plot
plt.show()

![Bar chart of the bottom 10 departments](../../../images/ETL_graph3.png)