# Presto Jupyter Notebook Extensions

Version: 2024-04-17

This code is loaded as a Jupyter notebook extension in any notebook that contains Presto code. Place the following line in any notebook where you want to use these commands:
<pre>
&#37;run presto.ipynb
</pre>

Note that this is a simple example of using magic commands with Presto. It is not an official release of the watsonx.data Presto extensions and is intended only to illustrate the Presto SQL command syntax.

In [None]:
#
# Set up Jupyter MAGIC commands "sql" for watsonx.data
# %sql will return results from a Presto select statement or execute a Presto command
#
# IBM 2024: George Baklarz
# Version 2024-04-17
#

from IPython.display import HTML as pHTML, Image as pImage, display as pdisplay, Javascript as Javascript
from IPython.core.magic import (Magics, magics_class, line_magic,cell_magic, line_cell_magic, needs_local_scope)
from ipydatagrid import DataGrid 

import pandas
import json
import warnings
import sys
import time
import prestodb
from prestodb import transaction
import pandas as pd
import sqlalchemy
from sqlalchemy import create_engine

# Suppress all Python warnings for the rest of the session so that library
# warnings do not clutter notebook output.
warnings.filterwarnings("ignore")

# Module-wide connection and display state shared by every %sql invocation.
# Key order matters: connect() walks the keys in this order when it reports
# the first empty setting.
_settings = dict(
    userid=None,        # Presto user (ibmlhapikey on watsonx.data SaaS)
    password=None,      # password or generated API key
    hostname=None,
    port="8443",        # kept as a string and handed to prestodb as-is
    catalog=None,
    schema=None,
    connected=False,    # set to True by connect() after a successful probe
    connection=None,    # prestodb DB-API connection object
    cursor=None,        # cursor opened on that connection
    certfile=None,      # optional value for the HTTP session's TLS verify setting
    display="grid",     # default result format: grid, pandas, text or raw
    height=100,         # DataGrid height in pixels
    quiet=False,        # when True, success messages are suppressed
)

# True while --prototype mode is echoing the equivalent Python code.
flag_echo = False

# Status constants (not referenced elsewhere in this cell).
SQLOK, SQLERROR = True, False

#
# Display help (link to documentation)
#

def help():
    """Print a quick reference for the %sql / %%sql magic commands.

    Covers the display flags, CONNECT parameters, catalog/schema switching,
    the [dml]/[ddl] statement-type overrides and watsonx.data SaaS
    credentials. Writes to stdout and returns None.
    """

    usage = (
        "Presto Magic Command Format",
        "Command Format:",
        "    %sql command (single line)",
        "    %%sql ",
        "       ... command over multiple lines with SQL separated with a semi-colons ; ...",
        " ",
        "Display options",
        "    --grid          - scrollable grid",
        "     [50-500]       - grid size",
        "    --pandas        - pandas dataframe",
        "    --text          - text display",
        "    --raw           - array of rows/columns",
        "    --timer         - return time to execute the SQL with no output",
        "    --quiet         - no messages are produced on success",
        "    --prototype     - turn on/off the prototype mode to display Python code",
        "    --help          - help text",
        " ",
        "Connection parameters:",
        "%%sql connect ",
        "      userid=value password=value",
        "      hostname=host port=port",
        "      catalog=catalog schema=schema",
        "      certfile=filename",
        " ",
        "Note: Set certfile=None or don't include it when connecting to watsonx.data SaaS servers.",
        " ",
        "Change catalog: %sql use catalog_name.schema_name",
        "Change schema : %sql use schema_name",
        " ",
        "Force SQL Statement Type",
        "    [dml]      - The SQL should return an answer set",
        "    [ddl]      - This is an SQL command with no results",
        " ",
        "Watsonx SaaS userid   : ibmlhapikey",
        "Watsonx Saas password : Get the generated API key from your system",
    )

    # One print per entry so every line ends with a single newline.
    for text in usage:
        print(text)

#
# Update connection information
#

def setConnection(inSQL):
    """Reset all settings and load connection parameters from a CONNECT statement.

    ``inSQL`` is split on whitespace and every ``key=value`` token is
    applied. Keys are case-insensitive and must be one of USERID, PASSWORD,
    HOSTNAME, PORT, CATALOG, SCHEMA or CERTFILE. Tokens without ``=``, such
    as the leading ``connect`` keyword, are skipped.

    Before parsing, every setting is reset to its default, including the
    display mode, grid height and quiet mode. Any cursor or connection left
    over from an earlier connect is closed first.

    Args:
        inSQL: the raw CONNECT statement text.

    Returns:
        True if every option was recognised. Otherwise False, after printing
        ``Unknown option: <KEY>``.
    """

    global _settings

    # The reset below drops the only references to the previous session, so
    # close it first. This is best effort: a failing close must not stop a
    # reconnect. connect() may leave False in "connection", hence the check.
    for handle in (_settings.get("cursor"), _settings.get("connection")):
        if handle is None or handle is False:
            continue
        try:
            handle.close()
        except Exception:
            pass

    _settings = {
         "userid"     : None,
         "password"   : None,
         "hostname"   : None,
         "port"       : "8443",
         "catalog"    : None,
         "schema"     : None,
         "connected"  : False,
         "connection" : None,
         "cursor"     : None,
         "certfile"   : None,
         "display"    : "grid",
         "height"     : 100,
         "quiet"      : False
    }

    # Upper-cased option name -> key in _settings.
    options = {
        "USERID"   : "userid",
        "PASSWORD" : "password",
        "HOSTNAME" : "hostname",
        "PORT"     : "port",
        "CATALOG"  : "catalog",
        "SCHEMA"   : "schema",
        "CERTFILE" : "certfile",
    }

    for token in inSQL.split():
        # Split on the first "=" only, so values that themselves contain "="
        # (for example padded base64 API keys) are kept intact.
        parm, has_value, parmvalue = token.partition("=")
        if not has_value:
            continue
        parm = parm.upper()
        if parm not in options:
            print(f"Unknown option: {parm}")
            return False
        key = options[parm]
        # The help text says to use "certfile=None" for servers that need no
        # custom certificate. Store a real None (also for an empty value)
        # instead of the literal text.
        if key == "certfile" and parmvalue.upper() in ("NONE", ""):
            parmvalue = None
        _settings[key] = parmvalue

    return True

#
# Connect to the database
#

def connect(echo):
    """Open and verify a prestodb connection built from ``_settings``.

    Every setting except cursor, connection, height and certfile must be
    non-empty. Any session that is still marked as connected is closed
    first. When ``echo`` is True, the equivalent stand-alone Python code is
    printed (prototype mode).

    The new connection is checked with ``select 1``. Only after that check
    succeeds are the connection, the cursor and ``connected=True`` stored
    back into ``_settings``.

    Returns:
        True on success. False otherwise, after printing a message.
    """

    global _settings

    # Validate user-supplied settings. The skipped keys are runtime handles
    # or optional values.
    for setting in _settings:
        if setting in ["cursor", "connection", "height", "certfile"]:
            continue
        if _settings[setting] in [None, ""]:
            print(f"Connection setting {setting} is empty")
            return False

    # Close the previous session, if any. This is best effort: a stale
    # handle that fails to close must not prevent reconnecting.
    if _settings["connected"] == True:
        for handle in (_settings["cursor"], _settings["connection"]):
            if handle is not None:
                try:
                    handle.close()
                except Exception:
                    pass

    userid   = _settings["userid"]
    password = _settings["password"]
    hostname = _settings["hostname"]
    port     = _settings["port"]
    catalog  = _settings["catalog"]
    schema   = _settings["schema"]
    certfile = _settings["certfile"]
    # Treat "None" or "" typed on the CONNECT line as "no custom certificate".
    # Otherwise the literal text would be assigned to the HTTP session's
    # verify setting below.
    if isinstance(certfile, str) and certfile.strip().upper() in ("", "NONE"):
        certfile = None

    # Mark the session as closed until the new one has been verified.
    _settings["connection"] = None
    _settings["cursor"] = None
    _settings["connected"] = False

    if echo == True:
        print(f"# Connection Parameters")
        print(f"userid     = '{userid}'")
        print(f"password   = '{password}'")
        print(f"hostname   = '{hostname}'")
        print(f"port       = '{port}'")
        print(f"catalog    = '{catalog}'")
        print(f"schema     = '{schema}'")
        if certfile == None:
            print(f'certfile   = None')
        else:
            print(f'certfile   = "{certfile}"')

        print()
        print(f"# Connect Statement")

        print(f'try:')
        print(f'    connection = prestodb.dbapi.connect(')
        print(f'            host=hostname,')
        print(f'            port=port,')
        print(f'            user=userid,')
        print(f'            catalog=catalog,')
        print(f'            schema=schema,')
        print(f'            http_scheme=\'https\',')
        print(f'            auth=prestodb.auth.BasicAuthentication(userid, password)')
        print(f'    )')
        print(f'    if (certfile != None):')
        print(f'        connection._http_session.verify = certfile')
        print(f'    cursor = connection.cursor()')
        print(f'    print("Connection successful")')
        print(f'except Exception as e:')
        print(f'    print("Unable to connect to the database.")')
        print(f'    print(repr(e))')
        print()

    try:
        connection = prestodb.dbapi.connect(
                    host=hostname,
                    port=port,
                    user=userid,
                    catalog=catalog,
                    schema=schema,
                    http_scheme='https',
                    auth=prestodb.auth.BasicAuthentication(userid, password)
        )
        if certfile != None:
            connection._http_session.verify = certfile
        cursor = connection.cursor()
    except Exception as e:
        print("Unable to connect to the database.")
        printSQLerror(repr(e))
        return False

    # Run a trivial query to prove the host is reachable and the credentials
    # are accepted before declaring the session usable.
    try:
        pd.read_sql("select 1", connection)
    except Exception as e:
        print("Unable to connect to the database.")
        printSQLerror(repr(e))
        # Do not leak the unusable connection.
        try:
            connection.close()
        except Exception:
            pass
        return False

    _settings["connected"]  = True
    _settings["cursor"]     = cursor
    _settings["connection"] = connection
    return True
	
# Print out an error message

def printSQLerror(error):
    """Print a condensed, human-readable form of a driver/Presto error.

    ``error`` is usually ``repr()`` of an exception. The text is cleaned up
    first: escaped quotes are unescaped, literal ``\\n`` sequences become
    spaces, and every complete ``<...>`` span is removed.

    The message is then chosen as follows:
      * If ``message="`` is present, print the text up to the next double
        quote, prefixed with ``SQL Error:``. If no closing quote follows,
        print the cleaned text unchanged.
      * Otherwise, if ``Reason:`` is present, print the text starting 8
        characters after it, up to the first ``{`` (or to the end), prefixed
        with ``SQL Error:``.
      * Otherwise print the cleaned text unchanged.

    Returns None.
    """

    text = str(error).replace("\\'", "'").replace('\\n', ' ')

    # Strip <...> spans one at a time; a '<' that is never closed ends the scan.
    while True:
        opening = text.find("<")
        if opening < 0:
            break
        closing = text.find(">", opening + 1)
        if closing < 0:
            break
        text = text[:opening] + text[closing + 1:]

    message = text
    marker = 'message="'
    at = text.find(marker)
    if at >= 0:
        body_start = at + len(marker)
        body_end = text.find('"', body_start)
        if body_end >= 0:
            message = "SQL Error: " + text[body_start:body_end]
    else:
        at = text.find('Reason:')
        if at >= 0:
            body_start = at + 8
            body_end = text.find('{', at)
            tail = text[body_start:body_end] if body_end >= 0 else text[body_start:]
            message = "SQL Error: " + tail

    print(message)
	
def success(message):
    """Print ``message`` unless it is None or the empty string; return None."""
    if message in (None, ""):
        return
    print(message)
    # To show the message in a green box instead, build an HTML <pre> block
    # and pass it to pdisplay(pHTML(...)).

#
# execSQL(sql, display_type, unknown)
# SQL          - The SQL to execute
# display_type - Grid, Pandas, Text, or Raw format. Timer only returns the execution time.
# unknown      - False=Returns results set True=We don't know.
# Execute SQL that has an answer set. If we don't know, try to see if it returns an answer set.
#
def execSQL(sql,display_type,unknown=False,echo=False):
    """Run ``sql`` and display or return its answer set.

    Connects first if ``_settings`` says there is no open session.

    Args:
        sql: the statement to run.
        display_type: how to present the result.
            "grid"   shows a DataGrid widget and returns None.
            "pandas" returns the DataFrame.
            "text"   prints the frame without its index and returns None.
            "raw"    returns the fetched rows.
            "timer"  returns the seconds spent in ``cursor.execute``.
            Any other value returns the DataFrame.
        unknown: True when the caller cannot tell whether ``sql`` returns
            rows. The statement is executed exactly once. If the cursor has
            no result description, "Command Completed." is printed (unless
            quiet). Otherwise the rows from that same execution are
            presented.
        echo: when True, print the equivalent stand-alone Python
            (prototype mode).

    Returns:
        The DataFrame, row list, elapsed time or None as described above.
        SQL errors are reported through printSQLerror and yield None.
    """

    global _settings
    if (_settings["connected"] == False):
        if (connect(echo) == False):
            return None

    connection = _settings["connection"]
    cursor     = _settings["cursor"]

    if (echo == True):
        print(f'# SQL')
        print(f"sql = '''")
        print(f"{sql}")
        print(f"'''")
        print()

    if (display_type == "timer"):

        if (echo == True):
            print(f"# Timer Execution")
            print(f'import time')
            print(f'start_time = time.perf_counter()')
            print(f'cursor.execute(sql)')
            print(f'end_time = time.perf_counter()')
            print(f'elapsed = end_time - start_time')
            print()

        # Report failures the same way as every other path instead of letting
        # the exception escape into the notebook.
        try:
            start_time = time.perf_counter()   # monotonic clock for intervals
            cursor.execute(sql)
            end_time = time.perf_counter()
        except Exception as e:
            printSQLerror(repr(e))
            return None
        return end_time - start_time

    if (unknown == True):

        if (echo == True):
            print(f"# Executing Unknown SQL (DDL or DML)")
            print(f'try:')
            print(f'    cursor.execute(sql)')
            print(f'    if (cursor.description == None):')
            print(f'        print("Command Completed.")')
            print(f'    else:')
            print(f'        columns = [column[0] for column in cursor.description]')
            print(f'        df = pd.DataFrame.from_records(cursor.fetchall(), columns=columns, coerce_float=True)')
            print(f'except Exception as e:')
            print(f'    print(repr(e))')
            print()

        try:
            cursor.execute(sql)
            if (cursor.description == None):
                if (_settings["quiet"] == False):
                    print("Command Completed.")
                return None
            # The statement produced an answer set. Read it from this cursor
            # instead of executing the SQL a second time.
            columns = [column[0] for column in cursor.description]
            rows = cursor.fetchall()
            if (display_type == "raw"):
                return rows
            # coerce_float=True mirrors what pd.read_sql does for DB-API rows.
            df = pd.DataFrame.from_records(rows, columns=columns, coerce_float=True)
            return _render_frame(df, display_type)
        except Exception as e:
            printSQLerror(repr(e))
            return None

    try:
        if (display_type != "raw"):

            if (echo == True):
                print(f"# Executing SQL Statement (Returning Dataframe)")
                print(f'# Variable df contains the answer set in a Pandas dataframe')
                print(f'try:')
                print(f'    df = pd.read_sql(sql,connection)')
                print(f'    if (len(df) == 0):')
                print(f'        print("No rows found.")')
                print(f'except Exception as e:')
                print(f'    print(repr(e))')
                print()

            df = pd.read_sql(sql,connection)
            return _render_frame(df, display_type)

        if (echo == True):
            print(f"# Executing SQL Statement (Returning Array)")
            print(f"# Variable rows contains an array of values from the answer set")
            print(f'try:')
            print(f'    cursor.execute(sql)')
            print(f'    rows = cursor.fetchall()')
            print(f'except Exception as e:')
            print(f'    print(repr(e))')
            print()

        cursor.execute(sql)
        return cursor.fetchall()

    except Exception as e:
        printSQLerror(repr(e))
        return None


def _render_frame(df, display_type):
    """Show or return ``df`` according to ``display_type`` (helper for execSQL).

    An empty frame prints "No rows found." and yields None. "grid" displays
    a DataGrid whose height comes from ``_settings["height"]``. "text"
    prints the frame without its index. "pandas" and any other value return
    the frame itself.
    """
    if (len(df) == 0):
        print("No rows found.")
        return None
    if (display_type == "grid"):
        height = _settings["height"]
        pdisplay(DataGrid(df,auto_fit_columns=True,layout={"height": f"{height}px"}))
        return None
    if (display_type == "text"):
        # Turn literal "\n" sequences inside values into real line breaks.
        output = df.to_string(index=False).replace("\\n","\n")
        print(output)
        return None
    return df

def execDDL(sql,echo=False):
    """Execute a statement that is not expected to return an answer set.

    Connects first if there is no open session. When ``echo`` is True, the
    equivalent stand-alone Python is printed (prototype mode). On success,
    prints "Command completed." unless quiet mode is on. Errors are
    reported through printSQLerror. Always returns None.
    """

    global _settings
    if (_settings["connected"] == False):
        if (connect(echo) == False):
            return

    cursor = _settings["cursor"]

    if (echo == True):
        print(f'# Parameters')
        print(f"sql = '''")
        print(f"{sql}")
        print(f"'''")

        print()
        print(f"# DDL Statement")
        print(f'try:')
        print(f'    cursor.execute(sql)')
        print(f'    cursor.fetchall()')
        print(f'except Exception as err:')
        print(f'    printSQLerror(repr(err))')
        print()

    try:
        cursor.execute(sql)
        # Drain the (normally empty) result before reporting success. The
        # intent is that the statement has finished and any server-side
        # failure has been raised here. NOTE(review): this relies on prestodb
        # streaming results after execute(); confirm against the client
        # version in use.
        cursor.fetchall()
    except Exception as err:
        printSQLerror(repr(err))
        return

    if (_settings["quiet"] == False):
        print("Command completed.")

def getSQL(line,cell):
    """Split the magic's line and cell text into individual SQL statements.

    ``line`` (followed by a space) and ``cell`` are concatenated. The result
    is cut at every ``;`` that is not inside a single- or double-quoted
    string; a doubled quote such as ``'it''s'`` stays inside the literal.
    Empty or whitespace-only pieces are dropped. The remaining statements
    are returned unstripped, in their original order.

    Returns:
        The list of statements; ``[]`` when there is no SQL at all.
    """

    delimiter = ";"

    inSQL = ""
    if (line not in [None,""]):
        inSQL = inSQL + line + " "
    if (cell not in [None,""]):
        inSQL = inSQL + cell

    if (inSQL.strip() == ""):
        return []

    results = []
    arg = ""
    quoteCH = ""      # quote character of the literal we are inside, or ""

    for ch in inSQL:
        if (ch in ('"',"'")):
            arg = arg + ch
            if (ch == quoteCH):           # closing quote
                quoteCH = ""
            elif (quoteCH == ""):         # opening quote
                quoteCH = ch
            # otherwise: the other quote kind, literal text inside a string
        elif (quoteCH != ""):             # inside a quoted literal
            arg = arg + ch
        elif (ch == delimiter):           # statement boundary
            if (arg.strip() != ""):
                results.append(arg)
            arg = ""
        else:
            arg = arg + ch

    # Keep a final statement that is not terminated by ';'. A trailing
    # newline/space after the last ';' must not count as a statement.
    if (arg.strip() != ""):
        results.append(arg)

    return results

def stripComments(sql):
    """Remove ``--`` line comments from ``sql`` and trim the result.

    Each comment runs from the first ``--`` on a line to the end of that
    line. It is replaced, together with its line break, by a single space,
    so the following line is joined onto the current one. A comment on the
    final line is removed even when no newline follows it. Lines without a
    comment keep their newlines.

    NOTE: this is not quote-aware; a ``--`` inside a string literal is also
    treated as the start of a comment.
    """

    import re

    # Match from the FIRST "--" on the line (the old greedy "(.*)--" kept
    # every earlier comment). Also handle a comment at the end of the text.
    sql = re.sub(r"--[^\n]*(?:\n|\Z)", " ", sql)
    return sql.strip()

@magics_class
class presto(Magics):
    """IPython magics that send Presto SQL to watsonx.data.

    Provides the ``%sql`` line magic and the ``%%sql`` cell magic. Run
    ``%sql --help`` (or ``%sql help``) for the command summary.
    """

    @needs_local_scope
    @line_cell_magic
    def sql(self, line, cell=None, local_ns=None):
        """Execute one or more ``;``-separated statements from ``line``/``cell``.

        Flags in ``line`` are handled first: --help, --quiet, --prototype,
        then one of --text/--pandas/--grid/--raw/--timer. Each statement is
        then dispatched on its first keyword:

          * DISPLAY, HELP, CONNECT and USE are handled inside the magic.
          * Statements starting with a DML keyword (or tagged ``[dml]``) go
            to execSQL.
          * EXECUTE/CALL, and anything unrecognised, go to execSQL with
            ``unknown=True``.
          * DDL keywords (or a ``[ddl]`` tag) go to execDDL.

        With a single statement, its result (DataFrame, rows or elapsed time,
        depending on the display type) is returned so the notebook shows it.
        With several statements nothing is returned. ``local_ns`` is
        supplied by IPython and is not used here.
        """

        global _settings, flag_echo

        # First keyword of a statement -> how it is executed.
        DML = ["VALUES","SHOW","SELECT","WITH","DESCRIBE","EXPLAIN"]
        DDL = ["PREPARE","DISPLAY","CONNECT","USE","CATALOG","DROP","CREATE","ALTER","INSERT","DELETE","UPDATE","HELP"]
        UNK = ["EXECUTE","CALL"]

        if ("--help" in line):
            help()
            return

        # Quiet mode stays on until a later CONNECT resets the settings.
        if ("--quiet" in line):
            _settings["quiet"] = True
            line = line.replace("--quiet","")

        # --prototype toggles echoing of equivalent Python code and ends the call.
        if ("--prototype" in line):
            if (flag_echo == False):
                flag_echo = True
                print(f"# All Python code will be displayed during execution")
                print(f'# Imports required to connect to watsonx.data')
                print(f'import prestodb')
                print(f'from prestodb import transaction')
                print(f'import pandas as pd')
                print(f'import sqlalchemy')
                print(f'from sqlalchemy import create_engine')
                print(f'import warnings')
                print(f'warnings.filterwarnings("ignore")')
                print()
            else:
                flag_echo = False
                print("Echo OFF. SQL Statements will not be displayed.")
            return

        # A per-call display flag overrides the default set with DISPLAY.
        if ("--text" in line):
            display_type = "text"
            line = line.replace("--text","")
        elif ("--pandas" in line):
            display_type = "pandas"
            line = line.replace("--pandas","")
        elif ("--grid" in line):
            display_type = "grid"
            line = line.replace("--grid","")
        elif ("--raw" in line):
            display_type = "raw"
            line = line.replace("--raw","")
        elif ("--timer" in line):
            display_type = "timer"
            line = line.replace("--timer","")
        else:
            display_type = _settings["display"]

        sqllines = getSQL(line,cell)

        if (len(sqllines) == 0):
            return
        # Only a single statement hands its result back to the notebook.
        show = (len(sqllines) == 1)

        for sql in sqllines:

            if (sql == ""):
                continue

            # Explicit type overrides for statements the keyword lists miss.
            if ("[dml]" in sql):
                sql_type = "dml"
                sql = sql.replace("[dml]","")
            elif ("[ddl]" in sql):
                sql_type = "ddl"
                sql = sql.replace("[ddl]","")
            else:
                sql_type = None

            tokens = sql.split()

            if (len(tokens) == 0):
                continue
            sqlType = tokens[0].upper()

            if (sqlType == "DISPLAY"):
                if (len(tokens) != 2):
                    print("DISPLAY syntax: DISPLAY [pixels | GRID | PANDAS | TEXT | RAW] where 50 <= pixels <= 500")
                    return
                pixelText = tokens[1].strip().upper()
                if (pixelText in ("GRID", "PANDAS", "TEXT", "RAW")):
                    _settings["display"] = pixelText.lower()
                else:
                    try:
                        pixels = int(tokens[1])
                    except ValueError:
                        print(f"Non-numeric value for Display height: {tokens[1]}")
                        continue
                    # The bounds match help() ("[50-500]") and the message.
                    if (pixels < 50 or pixels > 500):
                        print("Pixel size for results sets is outside the range of 50 - 500")
                    else:
                        _settings["height"] = pixels

            elif (sqlType == "HELP"):
                # HELP is a magic command, not a Presto statement.
                help()

            elif (sqlType == "CONNECT"):
                if (setConnection(sql) == True):
                    connected = connect(flag_echo)
                    if (connected == True):
                        print("Connection successful.")
                    else:
                        return
                else:
                    return

            elif (sqlType == "USE"):
                if (len(tokens) != 2):
                    print("USE syntax: USE [schema] | USE [catalog.schema]")
                else:
                    catalogschema = tokens[1].strip()
                    if ("." in catalogschema):
                        args = catalogschema.split(".")
                        newCatalog = args[0].strip()
                        newSchema  = args[1].strip()
                        _settings["catalog"] = newCatalog
                        _settings["schema"]  = newSchema
                        connected = connect(flag_echo)
                        if (connected == True):
                            print(f"Using catalog {newCatalog} with schema {newSchema}")
                        else:
                            return
                    else:
                        newSchema = catalogschema
                        _settings["schema"] = newSchema
                        connected = connect(flag_echo)
                        if (connected == True):
                            print(f"Using schema {newSchema}")
                        else:
                            return

            elif (sqlType in DML or sql_type == "dml"):
                results = execSQL(sql,display_type,unknown=False,echo=flag_echo)
                if (show == True):
                    return results

            elif (sqlType in UNK):
                results = execSQL(sql,display_type,unknown=True,echo=flag_echo)
                if (show == True):
                    return results

            elif (sqlType in DDL or sql_type == "ddl"):
                execDDL(sql,flag_echo)
                if (show == True):
                    return

            else:
                # Unrecognised keyword: let execSQL discover whether rows come back.
                results = execSQL(sql,display_type,unknown=True,echo=flag_echo)
                if (show == True):
                    return results
			
# Register the Magic extension in Jupyter    
# get_ipython is not imported in this cell. It is expected to be provided by
# the running IPython/Jupyter kernel, so this cell only works when executed
# there (e.g. via %run presto.ipynb).
ip = get_ipython()          
# Make the %sql / %%sql magics defined by the presto class available.
ip.register_magics(presto)
# Confirmation message, printed via success().
success("Presto Extensions Loaded.")

##### Credits: IBM 2024, George Baklarz [baklarz@ca.ibm.com]