Skip to content

Commit

Permalink
#8 update rick.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhangyixue1537 committed Jun 10, 2024
1 parent a7bdfa7 commit c2cecc5
Showing 1 changed file with 135 additions and 85 deletions.
220 changes: 135 additions & 85 deletions rick.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Version 0.8.0
Version 0.9.0
"""
from psycopg2 import connect
import psycopg2.sql as pg
Expand All @@ -19,45 +19,50 @@
import pandas as pd
import copy
import datetime
import importlib

class font:
"""
Class defining the global font variables for all functions.
"""

leg_font = font_manager.FontProperties(family='DejaVu Sans',size=9)
normal = 'DejaVu Sans'
semibold = 'DejaVu Sans SemiBold'
leg_font = font_manager.FontProperties(family='Libre Franklin',size=9)
normal = 'Libre Franklin'
semibold = 'Libre Franklin SemiBold'


class colour():
"""
Class defining the global colour variables for all functions.
"""

purple = '#660159'
grey = '#7f7e7e'
orange = '#d95f02'
green = '#0D9F73'
green = '#0D9F73'
blue = '#253494'
light_grey = '#777777'
cmap = 'YlOrRd'
teal = '#23a87f'
blue_grey = '#1b5872'

# Purple shades
purple_0 = '#440436'
purple_1 = '#550347'
purple_2 = '#660159'
purple_3 = '#9c7b94'
purple_4 = '#c0abbb'

colours_map = {
1: purple,
2: grey,
3: orange,
4: blue,
5: green,
6: light_grey
1: purple_1,
2: purple_2,
3: purple_3,
4: light_grey
}

def get_colour_from_index(self, index):
return self.colours_map[index]


class geo:
"""
Class for additional GIS layers needed for the choropleth map.
Expand All @@ -84,7 +89,6 @@ def ttc(con):
'''
ttc = gpd.GeoDataFrame.from_postgis(query, con, geom_col='geom')
# ttc = ttc.to_crs({'init' :'epsg:3857'})
ttc = ttc.to_crs(epsg=3857)

# Below can be replaced by an apply lambda
Expand Down Expand Up @@ -116,7 +120,7 @@ def island(con):
SELECT
geom
FROM tts.zones_tts06
FROM gis.zones_tts06
WHERE gta06 = 81
'''
Expand All @@ -134,27 +138,26 @@ def island(con):

return island


class charts:
"""
Class defining all the charting functions.
"""

global func
def func():

"""Function to set global settings for the charts class.
"""

sns.set(font_scale=1.5)
mpl.rc('font',family='DejaVu Sans')

mpl.rc('font',family='Libre Franklin')
def chloro_map(con, df, lower, upper, title, **kwargs):
"""Creates a choropleth map
Parameters
Parameters
-----------
con : SQL connection object
Connection object needed to connect to the RDS
Expand Down Expand Up @@ -196,9 +199,9 @@ def chloro_map(con, df, lower, upper, title, **kwargs):
df.columns = ['geom', 'values']
light = '#d9d9d9'

fig, ax = plt.subplots()
fig, ax = plt.subplots(dpi=450.0, figsize=(12,12))
fig.set_size_inches(6.69,3.345)

ax.set_yticklabels([])
ax.set_xticklabels([])
ax.set_axis_off()
Expand Down Expand Up @@ -254,7 +257,54 @@ def chloro_map(con, df, lower, upper, title, **kwargs):


return fig, ax


def histogram_chart(data, ylab, xlab, nbin, **kwargs):
    """Creates a histogram chart with the specified number of bins
    Parameters
    -----------
    data : array like
        Data for the histogram.
    ylab : str
        Label for the y axis.
    xlab : str
        Label for the x axis.
    nbin : int or None
        The number of bins. If None is passed, 10 bins are used.
    xmax : int, optional, default is the max value of data (truncated to int)
        The max value of the x axis
    xmin : int, optional, default is 0
        The minimum value of the x axis
    Returns
    --------
    fig
        Matplotlib fig object
    ax
        Matplotlib ax object
    """

    func()
    xmax = kwargs.get('xmax', None)
    xmin = kwargs.get('xmin', 0)

    if xmax is None:
        xmax = int(max(data))

    if nbin is None:
        nbin = 10

    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and later dropped the old aliases, so use whichever name this
    # installation actually ships instead of hard-coding the old one.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig, ax = plt.subplots(1, 1, dpi=450.0)
    fig.set_size_inches(6.1, 4.2)
    ax.hist(data, bins=nbin, alpha=1.0, color=colour.purple)
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)
    ax.set_xlim(xmin, xmax)

    # Axes.get_legend() returns None when no legend has been drawn, and
    # nothing above creates one, so guard before removing it.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    return fig, ax

def line_chart(data, ylab, xlab, **kwargs):
"""Creates a line chart. x axis must be modified manually
Expand Down Expand Up @@ -355,6 +405,7 @@ def tow_chart(data, ylab, **kwargs):
Dictionary of the text annotation properties
"""

func()
ymax = kwargs.get('ymax', None)
ymin = kwargs.get('ymin', 0)
Expand All @@ -379,14 +430,15 @@ def tow_chart(data, ylab, **kwargs):
else:
upper = int(3*yinc+ymin)

fig, ax =plt.subplots()
fig, ax = plt.subplots(dpi=450.0)
ax.plot(data, linewidth = 2.5, color = colour.purple)

plt.grid()
ax.set_xlim(0, 168)
ax.set_facecolor('xkcd:white')

plt.xlabel('Time of week', fontname = font.normal, fontsize=9, horizontalalignment='left', x=0, labelpad=3, fontweight = 'bold')
ax.set_ylim([ymin,upper])
ax.set_ylim([ymin, upper])

ax.grid(color='k', linestyle='-', linewidth=0.2)
plt.ylabel(ylab, fontname = font.normal, fontsize=9, horizontalalignment='right', y=1, labelpad=7, fontweight = 'bold')
Expand All @@ -396,7 +448,6 @@ def tow_chart(data, ylab, **kwargs):
ax.yaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))
plt.yticks(range(ymin,upper+int(0.1*yinc), yinc), fontsize =9, fontname = font.normal)

ax.set_xticks(range(0,180,12))
ax.set_xticklabels(['0','12','0','12',
'0','12','0','12',
'0','12','0','12','0','12'], fontname = font.normal, fontsize = 7, color = colour.light_grey)
Expand Down Expand Up @@ -499,7 +550,7 @@ def stacked_chart(data_in, xlab, lab1, lab2, **kwargs):
data[['values1', 'values2']] = data[['values1', 'values2']].astype(int)
for i in data['values2']:
if i < 0.1*upper:
ax.annotate(str(format(round(i,precision), ',')), xy=(i+0.015*upper, j-0.05), ha = 'left', color = 'k', fontname = font.normal, fontsize=10)
ax.annotate(str(format(round(i,precision), ',')), xy=(i-0.015*upper, j-0.05), ha = 'right', color = 'w', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=(i-0.015*upper, j-0.05), ha = 'right', color = 'w', fontname = font.normal, fontsize=10)
j=j+1
Expand All @@ -521,34 +572,34 @@ def stacked_chart(data_in, xlab, lab1, lab2, **kwargs):
data_yoy['percent'] = (data['values2']-data['values1'])*100/data['values1']
j=0.15
for index, row in data_yoy.iterrows():
ax.annotate(('+' if row['percent'] > 0 else '')+str(format(int(round(row['percent'],0)), ','))+'%',
xy=(max(row[['values1', 'values2']]) + (0.12 if row['values2'] < 0.1*upper else 0.03)*upper, j), color = 'k', fontname = font.normal, fontsize=10)
ax.annotate('+'+str(format(int(round(row['percent'],0)), ','))+'%',
xy=(max(row[['values1', 'values2']]) + 0.03*upper, j), color = 'k', fontname = font.normal, fontsize=10)
j=j+1


return fig, ax

def stacked_chart_quad(data_in, xlab, lab1, lab2, lab3, lab4, **kwargs):
"""Creates a stacked bar chart comparing 4 sets of data
def multi_stacked_bar_chart(data, xlab, lab1, lab2, lab3, **kwargs):
"""Creates a stacked bar chart
Parameters
-----------
data : dataframe
Data for the stacked bar chart. The dataframe must have 5 columns, the first representing the y ticks, the second representing the baseline, and the third representing the next series of data.
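data : dataframe
Data for the stacked bar chart. The dataframe must have 4 columns: the first holds the y tick labels and the remaining three hold the data series, which are stacked left to right in that order.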
xlab : str
Label for the x axis.
lab1 : str
Label in the legend for the baseline
lab2 : str
Label in the legend for the next data series
lab3 : str
Label in the legend for the next data series
xmax : int, optional, default is the max x value
The max value of the x axis
xmin : int, optional, default is 0
The minimum value of the x axis
precision : int, optional, default is -1
precision : int, optional, default is 1
Decimal places in the annotations
percent : boolean, optional, default is False
Whether the annotations should be formatted as percentages
xinc : int, optional
The increment of ticks on the x axis.
Expand All @@ -563,18 +614,18 @@ def stacked_chart_quad(data_in, xlab, lab1, lab2, lab3, lab4, **kwargs):
"""

func()
data = data_in.copy(deep=True)
data = data.copy(deep=True)

data.columns = ['name', 'values1', 'values2', 'values3', 'values4']
data.columns = ['name', 'values1', 'values2', 'values3']

xmin = kwargs.get('xmin', 0)
xmax = kwargs.get('xmax', None)
precision = kwargs.get('precision', -1)
precision = kwargs.get('precision', 1)
percent = kwargs.get('percent', False)

xmax_flag = True
if xmax == None:
xmax = int(max(data[['values1', 'values2', 'values3', 'values4']].max()))
xmax = int(max(data[['values1', 'values2', 'values3']].max()))
xmax_flag = False

delta = (xmax - xmin)/4
Expand All @@ -594,57 +645,56 @@ def stacked_chart_quad(data_in, xlab, lab1, lab2, lab3, lab4, **kwargs):
ind = np.arange(len(data))

fig, ax = plt.subplots()
fig.set_size_inches(6.1, len(data)*1.5)
fig.set_size_inches(6.1, len(data))
ax.grid(color='k', linestyle='-', linewidth=0.25)

p1 = ax.barh(ind+0.6, data['values1'], 0.2, align='center', color = colour.green)
p2 = ax.barh(ind+0.4, data['values2'], 0.2, align='center', color = colour.blue)
p3 = ax.barh(ind+0.2, data['values3'], 0.2, align='center', color = colour.grey)
p4 = ax.barh(ind, data['values4'], 0.2, align='center', color=colour.purple)
p1 = ax.barh(ind, data['values1'], 0.4, align='center', color = colour.grey)
p2 = ax.barh(ind, data['values2'], 0.4, align='center', color = colour.purple, left = data['values1'])
p3 = ax.barh(ind, data['values3'], 0.4, align='center', color = colour.teal, left = (data['values1'] + data['values2']))
ax.xaxis.set_major_formatter(mpl.ticker.StrMethodFormatter('{x:,.0f}'))

ax.xaxis.grid(True)
ax.yaxis.grid(False)
ax.set_yticks(ind+0.6/2)
ax.set_yticks(ind)
ax.set_xlim(0,upper)
ax.set_yticklabels(data['name'])
ax.set_xlabel(xlab, horizontalalignment='left', x=0, labelpad=10, fontname = font.normal, fontsize=10, fontweight = 'bold')

ax.set_facecolor('xkcd:white')


if precision < 1:
data[['values1', 'values2', 'values3', 'values4']] = data[['values1', 'values2', 'values3', 'values4']].astype(int)

j = 0.0
for k in range(4,0,-1):

for i in data[f'values{k}']:
if i < 0.1*upper:
ax.annotate(str(format(round(i,precision), ',')), xy=(i+0.015*upper, j-0.05), ha = 'left', color = 'k', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=(i-0.015*upper, j-0.05), ha = 'right', color = 'w', fontname = font.normal, fontsize=10)
j=j+1
j = j-len(data[f'values{k}']) + 0.2


ax.legend((p1[0], p2[0], p3[0], p4[0]), (lab1, lab2, lab3, lab4), loc=4, frameon=False, prop=font.leg_font)
# if precision < 1: # removed this so it does not round or cast to int prematurely. Also, casting to int truncates the decimal WITHOUT rounding.
# data[['values1', 'values2','values3']] = data[['values1', 'values2','values3']].astype(int)
horiz_nudge = 0.2
for index, i in enumerate(data['values3']):
offset = data['values3'][index]+ data['values2'][index] + data['values1'][index]
# if value is less than 0.5%, do not show data label, if less than 4%, show data label above the bar, else show label on the bar
if i < 0.5:
continue
if i < 4:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index+0.3), ha = 'center', color = 'k', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index-0.07), ha = 'center', color = 'w', fontname = font.normal, fontsize=10)
for index, i in enumerate(data['values2']):
offset = data['values2'][index] + data['values1'][index]
if i < 0.5:
continue
if i < 4:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index+0.3), ha = 'center', color = 'k', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index-0.07), ha = 'center', color = 'w', fontname = font.normal, fontsize=10)
for index, i in enumerate(data['values1']):
offset = data['values1'][index]
if i < 0.5:
continue
if i < 4:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index+0.3), ha = 'center', color = 'k', fontname = font.normal, fontsize=10)
else:
ax.annotate(str(format(round(i,precision), ',')), xy=((offset+offset-i)/2+horiz_nudge, index-0.07), ha = 'center', color = 'w', fontname = font.normal, fontsize=10)

ax.legend((p1[0], p2[0], p3[0]), (lab1, lab2, lab3), bbox_to_anchor=(1.05, 1.0), loc='upper left', frameon=False, prop=font.leg_font)
# ax.legend((p1[0], p2[0], p3[0]), (lab1, lab2, lab3), bbox_to_anchor=(0.5, 1.15), loc='upper center', ncol=3, frameon=False, prop=font.leg_font)
# plt.subplots_adjust(bottom=0.2) # Adjust layout to make room for the legend above the plot
plt.xticks(range(xmin,upper+int(0.1*xinc), xinc), fontname = font.normal, fontsize =10)
plt.yticks( fontname = font.normal, fontsize =10)

if percent == True:
j = 0.15
data_yoy = data
for k in range(3,0,-1):
data_yoy[f'percent{k}'] = (data['values4']-data[f'values{k}'])*100/data[f'values{k}']

for index, row in data_yoy.iterrows():
ax.annotate(('+' if row[f'percent{k}'] > 0 else '')+str(format(int(round(row[f'percent{k}'],0)), ','))+'%',
xy=(max(row[['values1', 'values2', 'values3', 'values4']]) + (0.12 if row['values4'] < 0.1*upper else 0.03)*upper, j), color = 'k', fontname = font.normal, fontsize=10)
j+=1
j = j-len(data_yoy) + 0.2


return fig, ax

def horizontal_grouped_bar_chart(data: pd.DataFrame, **kwargs: dict) -> (plt.figure, plt.axes):
Expand Down

0 comments on commit c2cecc5

Please sign in to comment.