In [None]:
#!/usr/bin/env python
 
'''
UTAS_DB_tools.py
A. J. McCulloch, February 2020
'''

####################################################################################################
# Import modules
####################################################################################################

import pandas as pd # Required for dataframe manipulation
import cx_Oracle # Required for connection to database
import numpy as np # Required from numerical operations
import matplotlib.pyplot as plt # Required for plotting
import matplotlib.colors as colors # Required for setting custom colourmap

from getpass import getpass # Required for password input

####################################################################################################
# Define classes
####################################################################################################
class DW_connect:
    """
    Connection to the UTAS data warehouse (Oracle), opened when the object is created.

    Attributes
    ----------
    username : str
        Account used to log in.
    dsn_tns : str
        Data Source Name built with cx_Oracle.makedsn.
    conn : cx_Oracle connection
        The open connection; pass it to runSQL / pandas.read_sql_query.
    connected : bool
        True while ``conn`` is open (class-level default is False).
    """

    connected = False # Initialise the connection attribute flag

    # Establish connection to the data warehouse
    def __init__(self, username, host='exa1-scan.its.utas.edu.au', port='1521',
                 service_name=r'edwprod_maa'):
        """
        Prompt for the password and open the connection.

        Parameters
        ----------
        username : str
            Username to connect to the data warehouse.
        host, port, service_name : optional
            Connection details; the defaults are the production warehouse values.
            Info comes from an email from Nathalie (FW: DB Client Installs for Data
            Warehouse Access), originating from Andrew with a document Oracle DB
            Drivers and Install.

        Raises
        ------
        cx_Oracle.DatabaseError
            If the connection cannot be made (a message is printed first).
        """
        self.username = username
        # Kept in a local only, so the password is never stored on (or displayed with) the object
        password = getpass()

        # Data Source Name (DSN) in Transparent Network Substrate (TNS) form
        self.dsn_tns = cx_Oracle.makedsn(host, port, service_name=service_name)

        # cx_Oracle.connect itself raises on failure, so the error must be caught here
        try:
            self.conn = cx_Oracle.connect(user = self.username, password = password, dsn = self.dsn_tns)
        except cx_Oracle.DatabaseError:
            print("Could not connect to database")
            raise

        self.connected = True
        print("Connected to database, Oracle version {}".format(self.conn.version))

    # Disconnect from the data warehouse
    def disconnect(self):
        """Close the connection if it is open; calling it again is a harmless no-op."""
        if not self.connected:
            # Avoid closing an already-closed (or never-opened) connection
            print("No connection to close")
            return
        self.conn.close()
        print("Connection to database ended")
        self.connected = False

# Class required to store data warehouse tables
class tablist:
    """
    Empty namespace object.

    init_tables() attaches one DataFrame per table owner to an instance of this
    class as attributes (e.g. ``tables.STUDENTS``).
    """
####################################################################################################
# Define functions
####################################################################################################
# A function to run SQL query over a particular connection
def runSQL(query, conn):
    """
    Execute ``query`` over the open connection ``conn`` and return the result set.

    Thin wrapper around pandas.read_sql_query; the returned DataFrame has one
    column per selected expression.
    """
    frame = pd.read_sql_query(sql=query, con=conn)
    return frame
     
# Return currently accessible tables
def get_tables(schema='owner', conn=None):
    """
    Return the tables listed in Oracle's ALL_TABLES view.

    Parameters
    ----------
    schema : str
        Name of the ALL_TABLES column to select and sort by alongside table_name
        (default 'owner'). It is formatted directly into the SQL, so only pass
        trusted column names.
    conn : connection, optional
        Database connection to query; defaults to the global ``EDW.conn``.

    Returns
    -------
    pandas.DataFrame
        Columns TABLE_NAME and ``schema`` (names as returned by the driver, e.g.
        OWNER), ordered by ``schema`` then table name.
    """
    if conn is None:
        conn = EDW.conn  # Fall back to the notebook-wide connection
    # SQL to return all tables accessible, grouped by the `schema` column.
    # The column name is formatted in because identifiers cannot be bind parameters.
    q = """
        SELECT
            table_name, {0}
        FROM
            all_tables
        ORDER BY
            {0}, table_name
        """.format(schema)
    # Run the SQL and return the result
    return runSQL(q, conn)

# Function to index all available tables by owner
def init_tables(target=None):
    """
    Attach one DataFrame per table owner to ``target`` and print the owners found.

    Parameters
    ----------
    target : object, optional
        Object to receive the attributes; defaults to the global ``tables``
        (a tablist instance). Each attribute is named after an owner (e.g.
        ``STUDENTS``) and holds that owner's rows from get_tables().

    Returns
    -------
    None
    """
    if target is None:
        target = tables
    tbls = get_tables() # Get available tables
    owners = tbls['OWNER'].unique() # Get unique schema
    for owner in owners:
        # reset_index so each frame's row labels run 0..n-1: get_all looks rows up by label
        setattr(target, owner, tbls[tbls['OWNER'] == owner].reset_index(drop=True))
    print('Available schema are ' + ', '.join(map(str, owners)))

# Function to return an entire table
def get_all(table, row, conn=None):
    """
    Fetch an entire data-warehouse table as a DataFrame.

    Parameters
    ----------
    table : pandas.DataFrame
        A table listing with OWNER and TABLE_NAME columns, e.g. one of the
        per-owner frames init_tables attaches to ``tables`` (``tables.STUDENTS``).
    row : int
        Index label, within ``table``, of the table to fetch.
    conn : connection, optional
        Database connection to query; defaults to the global ``EDW.conn``.

    Returns
    -------
    pandas.DataFrame
        Every row and column of the table (no row limit, so large tables may be
        slow to fetch).
    """
    if conn is None:
        conn = EDW.conn  # Fall back to the notebook-wide connection
    # Look the name up in the listing that was passed in, so any owner's tables can be fetched
    owner = table['OWNER'].loc[row]
    t_name = table['TABLE_NAME'].loc[row]
    q = 'SELECT * FROM {}.{}'.format(owner, t_name)
    print('Retrieving table {}'.format(t_name))
    return runSQL(q, conn)

# Connecting to the data warehouse
## Example code
### Initialise connection

In [None]:
####################################################################################################
####################################################################################################
# Code starts here
####################################################################################################
####################################################################################################

# Data warehouse account to log in with; DW_connect prompts for the password interactively
DW_USERNAME = 'ajm32'
EDW = DW_connect(DW_USERNAME)

### Find available tables

In [None]:
tables = tablist()  # Container that init_tables() fills with one DataFrame per table owner
init_tables()       # Also prints the list of owners found

### Return a table

In [None]:
# Fetch the first table listed under the STUDENTS owner and preview it.
# NOTE(review): the preview may show personal student data — clear outputs before sharing.
dfraw = get_all(table=tables.STUDENTS, row=0)
dfraw.head(n=5)

### Disconnect from the database server

In [None]:
EDW.disconnect()  # Close the warehouse connection; dfraw is already a local DataFrame

# Manipulating the data

## Functions

In [None]:
def make_cmap(colors, position=None, bit=False, showcmap = False, name='my_colormap'):
    '''
    Build a matplotlib LinearSegmentedColormap from a list of RGB tuples.

    Parameters
    ----------
    colors : sequence of RGB tuples
        Ordered so the first colour is the lowest value of the colour bar and
        the last is the highest. Only the first three entries of each tuple are
        used. The caller's sequence is not modified.
    position : sequence of float, optional
        Location (0 to 1) of each colour. It must have one entry per colour,
        start at 0 and end at 1. Defaults to equally spaced colours.
    bit : bool
        True if the RGB values are 8-bit integers [0 to 255]; False (default)
        if they are already arithmetic values [0 to 1].
    showcmap : bool
        If True, also draw a preview strip of the colour map.
    name : str
        Name given to the colormap (default 'my_colormap').

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        A colour map with 256 levels.

    Raises
    ------
    ValueError
        If ``position`` has the wrong length or does not start at 0 and end at 1.
    '''
    # `is None`, not `== None`: position may be a numpy array, where == is element-wise
    if position is None:
        position = np.linspace(0, 1, len(colors))
    elif len(position) != len(colors):
        raise ValueError("position length must be the same as colors")
    elif position[0] != 0 or position[-1] != 1:
        raise ValueError("position must start with 0 and end with 1")

    if bit:
        # bit_rgb[k] == k/255, so an 8-bit channel value indexes straight into it
        bit_rgb = np.linspace(0, 1, 256)
        # Build a new list instead of overwriting the caller's entries in place
        colors = [(bit_rgb[c[0]], bit_rgb[c[1]], bit_rgb[c[2]]) for c in colors]

    cdict = {'red':[], 'green':[], 'blue':[]}
    for pos, color in zip(position, colors):
        cdict['red'].append((pos, color[0], color[0]))
        cdict['green'].append((pos, color[1], color[1]))
        cdict['blue'].append((pos, color[2], color[2]))

    # Local import: the `colors` parameter shadows the module-level matplotlib.colors import
    from matplotlib.colors import LinearSegmentedColormap
    cmap = LinearSegmentedColormap(name, cdict, 256)

    if showcmap:
        gradient = np.linspace(0, 1, 256)
        gradient = np.vstack((gradient, gradient))

        f, ax = plt.subplots(figsize=(9,0.5))
        f.subplots_adjust(top=0.95, bottom=0.01, left=0.2, right=0.99)
        ax.imshow(gradient, aspect='auto', cmap=cmap)
        ax.set_axis_off()

    return cmap

# Function for true/false testing on columns
def tf_test(frame, test_column, new_column, test, inverse = False, inverse_column = None):
    """
    Flag, in place, which values of ``frame[test_column]`` are ``in test``.

    Writes the boolean result to ``new_column``. When ``inverse`` equals True,
    also writes its negation to ``inverse_column``. Returns None.
    """
    frame[new_column] = frame[test_column].apply(lambda value: value in test)
    if not inverse == True:
        return
    frame[inverse_column] = frame[new_column].apply(lambda flag: flag == False)

# Function to make a column to flag domestic students
def make_domestic(frame, make_international = False, test_column = 'State'):
    """
    Add a boolean 'Domestic' column to ``frame`` in place (optionally also 'International').

    Parameters
    ----------
    frame : pandas.DataFrame
        Modified in place.
    make_international : bool
        If True, also add 'International' as the negation of 'Domestic'.
    test_column : {'State', 'Country'}
        'State': a row is domestic if its value is one of the state/territory
        codes listed below. 'Country': a row is domestic if its value equals
        'Australia (includes External Territories)'.

    Raises
    ------
    ValueError
        If ``test_column`` is neither 'State' nor 'Country'.
    """
    # If original data is indexed at state level
    if test_column == 'State':
        domestic = ['ACT', 'NSW', 'NT', 'QLD', 'SA', 'TAS', 'VIC', 'WA']
        frame['Domestic'] = frame[test_column].isin(domestic)
    # If original data is sorted at country level
    elif test_column == 'Country':
        # Exact comparison, not a substring test: fragments such as 'Aus' must not
        # match, and missing values simply give False
        frame['Domestic'] = frame[test_column] == 'Australia (includes External Territories)'
    else:
        raise ValueError("test_column must be 'State' or 'Country', got {!r}".format(test_column))

    if make_international:
        frame['International'] = ~frame['Domestic']
    
def barplot(df, numeric = True, slc = 1, year = 2020, cmap = None):
    """
    Bar-plot the first column of ``df`` against its index.

    Parameters
    ----------
    df : pandas.DataFrame
        The index supplies the x positions/labels; the first column the bar heights.
    numeric : bool
        True if the index is numeric (e.g. years): bars are placed at the index
        values and coloured by comparison with ``year``. False places bars at
        0..n-1, all in ``cmap(0)``.
    slc : int
        Label only every ``slc``-th bar.
    year : number
        Bars whose index is <= ``year`` use ``cmap(0)``; later ones use
        ``cmap(0.6)``. Default 2020.
    cmap : callable, optional
        Colour map; defaults to the global ``utas_cmap``.

    Returns
    -------
    matplotlib Axes holding the plot.
    """
    if cmap is None:
        cmap = utas_cmap  # Resolved at call time: the cell defining utas_cmap must have run

    _, ax = plt.subplots(figsize=(15,7)) # Make a subplot to place the axis

    obj = df.index # Extract the objects
    y_vals = df.iloc[:,0] # Bar heights come from the first column

    if numeric:
        x_pos = obj
        # Bars up to `year` take the first colour, later bars one further along the map
        bar_colours = [cmap(0) if y <= year else cmap(.6) for y in obj]
    else:
        # Non-numeric objects cannot be used as positions, so space the bars evenly
        x_pos = np.arange(len(obj))
        bar_colours = [cmap(0)] * len(obj)

    ax.bar(x_pos, y_vals, color = bar_colours)
    ax.set_xticks(x_pos[::slc])
    ax.set_xticklabels(obj[::slc])
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_xlabel(obj.name, fontsize=20)
    ax.set_ylabel(y_vals.name, fontsize=20)
    return ax

In [None]:
"""
Create a custom colour map for the UTAS style
Colours come from template documentation
"""
colours = [(226,0,1), (192,0,0), (96,96,96), (155,155,155), (202,202,202)] # Create a list of RGB tuples
utas_cmap = make_cmap(colours, bit=True) # Call the function make_cmap which returns your colormap

### Example usage

In [None]:
# Courses supplied for analysis
course_codes = ['73M', '73N', '73U', 'K3T', 'P3K', 'S4A', 'S4X', 'S9A', 'Z1C', 'Z1J', 'Z2A', 'Z2J']

# Keep only the supplied courses and the needed columns, renamed and sorted by year then state.
# .copy() gives an independent frame, so the column additions below don't trigger
# pandas' SettingWithCopyWarning against dfraw.
df = (
    dfraw.loc[dfraw['COURSE_CODE'].isin(course_codes),
              ['ENROL_YEAR', 'STUDENT_ID', 'ORIG_COUNTRY_STATE_DESC']]
    .rename(columns={'ENROL_YEAR': 'Year', 'STUDENT_ID': 'ID', 'ORIG_COUNTRY_STATE_DESC': 'State'})
    .sort_values(by=['Year', 'State'])
    .copy()
)

make_domestic(df, True) # Add domestic and international flags
df['Tasmanian'] = df['State'] == 'TAS' # Add a Tasmanian flag
# Mainland = domestic but not Tasmanian; built directly as a bool column (no NaN placeholder to fill)
df['Mainland'] = df['Domestic'] & ~df['Tasmanian']

In [None]:
# Number of flagged records per enrolment year, split by Tasmanian vs mainland origin
toplot = df.groupby('Year')[['Tasmanian', 'Mainland']].sum()
ax = toplot.plot.bar(rot=0, colormap = utas_cmap, stacked = True)
ax.set_ylabel('Number of records')
ax.set_title('Tasmanian and mainland domestic records by year');

In [None]:
# Bring course level and citizenship onto the filtered records; df keeps dfraw's index labels
extra_cols = {'COURSE_LEVEL_BROAD_DESC': 'Course level', 'COUNTRY_OF_CITIZENSHIP_DESC': 'Country'}
df_country = dfraw.loc[dfraw['COURSE_CODE'].isin(course_codes), list(extra_cols)]
df2 = df.merge(df_country, left_index=True, right_index=True).rename(columns=extra_cols)
# Replace the boolean flag with readable labels
df2['International'] = df2['International'].apply(lambda is_intl: 'International' if is_intl else 'Domestic')

In [None]:
# NOTE(review): with no `index`, pivot keeps df2's own row index, so each row holds a single
# non-null Year under its Course level. If counts per course level were intended,
# pd.crosstab(df2['Year'], df2['Course level']) may be what's wanted — confirm.
pd.pivot(df2, columns='Course level', values='Year')

In [None]:
# Split the filtered records on the boolean Domestic flag
is_domestic = df['Domestic']
df_dom = df.loc[is_domestic]
df_int = df.loc[~is_domestic]