# Presto Jupyter Notebook Extensions

Version: 2023-10-01

This notebook is loaded as a Jupyter extension by any notebook that runs Presto SQL. Add the following line to each notebook in which you want to use these commands:
<pre>
&#37;run presto.ipynb
</pre>

Note that this is a deliberately simple example of using IPython magic commands with Presto. It is not an official release of the watsonx.data Presto extensions; it is intended only to illustrate Presto SQL command syntax.

In [None]:
#
# Set up Jupyter MAGIC commands "sql" for watsonx.data
# %sql will return results from a Presto select statement or execute a Presto command
#
# IBM 2023: George Baklarz
# Version 2023-10-01
#


from IPython.display import HTML as pHTML, Image as pImage, display as pdisplay, Javascript as Javascript
from IPython.core.magic import (Magics, magics_class, line_magic,cell_magic, line_cell_magic, needs_local_scope)
from ipydatagrid import DataGrid 

import pandas
import json
import warnings
import sys
import prestodb
from prestodb import transaction
import pandas as pd
import sqlalchemy
from sqlalchemy import create_engine

# Silence every warning raised in this kernel -- presumably so that pandas'
# read_sql notices about non-SQLAlchemy connections do not clutter the output.
warnings.filterwarnings("ignore")

# Module-wide connection and display state shared by the magic and its helpers.
# Key order matters: connect() walks these entries in order and reports the
# first empty one.
_settings = dict(
    userid=None,
    password=None,
    hostname=None,
    port="8443",
    catalog=None,
    schema=None,
    connected=False,
    connection=None,
    cursor=None,
    certfile=None,
    display="grid",
    height=100,
)

# Boolean status constants.
SQLOK, SQLERROR = True, False

#
# Display help (link to documentation)
#

def help():
    """Print a summary of the %sql / %%sql magic syntax and its sub-commands.

    This intentionally shadows the builtin ``help`` in the notebook, because the
    magic dispatches ``%sql help`` to this name.

    The grid height is documented under the DISPLAY command. That is the only
    place the magic actually reads a pixel value; no per-statement size flag
    exists.
    """
    lines = [
        "Presto Magic Command Format",
        "Command Format:",
        "    %sql command (single line)",
        "    %%sql ",
        "       ... command over multiple lines ...",
        "    Separate multiple statements with ;",
        " ",
        "Display options (add to a statement)",
        "    --grid     - scrollable grid",
        "    --pandas   - pandas dataframe",
        "    --text     - text display",
        "    --raw      - array of rows/columns",
        " ",
        "Default display settings:",
        "    %sql display grid | pandas | text | raw",
        "    %sql display pixels  - grid height (50-500)",
        " ",
        "Connection parameters:",
        "%%sql connect ",
        "      userid=value password=value",
        "      certfile=filename",
        "      hostname=host port=port",
        "      catalog=catalog schema=schema",
        " ",
        "Change catalog: %sql use catalog.schema",
        "Change schema : %sql use schema",
        " ",
        "Force SQL Statement Type",
        "    [dml]      - The SQL should return an answer set",
        "    [ddl]      - This is a SQL command with no results",
    ]
    print("\n".join(lines))

#
# Update connection information
#

def setConnection(inSQL):
    """Parse ``key=value`` connection parameters and install them as the settings.

    Whitespace-separated tokens without ``=`` (such as the leading ``connect``
    keyword) are ignored. Recognised keys, matched case-insensitively, are
    USERID, PASSWORD, HOSTNAME, PORT, CATALOG, SCHEMA and CERTFILE. Only the
    first ``=`` separates key from value, so values may themselves contain
    ``=`` (common in passwords and tokens).

    The new settings start from the defaults (port 8443, grid display, height
    100, not connected). They replace the module-level ``_settings`` only when
    every token parses. A typo therefore no longer discards an existing,
    working connection.

    Args:
        inSQL: The full CONNECT command text.

    Returns:
        True when the settings were replaced. False when an unknown option was
        found; that option is printed and the current settings are left as-is.
    """
    global _settings

    newSettings = {
         "userid"     : None,
         "password"   : None,
         "hostname"   : None,
         "port"       : "8443",
         "catalog"    : None,
         "schema"     : None,
         "connected"  : False,
         "connection" : None,
         "cursor"     : None,
         "certfile"   : None,
         "display"    : "grid",
         "height"     : 100
    }
    allowed = ("USERID", "PASSWORD", "HOSTNAME", "PORT", "CATALOG", "SCHEMA", "CERTFILE")

    for token in inSQL.split():
        name, separator, value = token.partition("=")
        if not separator:
            continue  # bare word, e.g. the CONNECT keyword itself

        parm = name.upper()
        if parm not in allowed:
            print(f"Unknown option: {parm}")
            return False
        newSettings[parm.lower()] = value

    _settings = newSettings
    return True

#
# Connect to the database
#

def connect():
    """Open a Presto connection from the values held in ``_settings``.

    Every entry except ``cursor``, ``connection`` and ``height`` must be
    non-empty. The first empty one (in dictionary order) is printed and the
    attempt is abandoned.

    Otherwise any previous connection state is cleared and an HTTPS connection
    with basic authentication is opened. TLS verification is pointed at the
    configured ``certfile``, and a ``select 1`` round trip confirms the
    connection works.

    Returns:
        True on success; ``_settings`` then holds the new connection and cursor
        and ``connected`` is True. False after printing the reason otherwise.
    """
    global _settings

    ignored = ("cursor", "connection", "height")
    empty = next(
        (name for name, value in _settings.items()
         if name not in ignored and value in (None, "")),
        None,
    )
    if empty is not None:
        print(f"Connection setting {empty} is empty")
        return False

    # Forget any previous connection before attempting a new one.
    _settings.update(connected=False, cursor=None, connection=None)

    try:
        connection = prestodb.dbapi.connect(
            host=_settings["hostname"],
            port=_settings["port"],
            user=_settings["userid"],
            catalog=_settings["catalog"],
            schema=_settings["schema"],
            http_scheme='https',
            auth=prestodb.auth.BasicAuthentication(_settings["userid"], _settings["password"]),
        )
        # Verify the server against the user-supplied certificate file.
        connection._http_session.verify = _settings["certfile"]
        cursor = connection.cursor()
        # Probe query so bad credentials or host are reported immediately.
        pd.read_sql("select 1", connection)
    except Exception as problem:
        formatSQLError(repr(problem))
        return False

    _settings.update(connected=True, cursor=cursor, connection=connection)
    return True
	
# Print out an error message

def formatSQLError(error):
    """Print a condensed, human-readable form of a database error.

    Escaped quotes and escaped newlines in the text are unescaped or flattened.
    Every complete ``<...>`` span is removed, stopping at the first ``<`` that
    has no closing ``>``.

    The message is then chosen as follows:
    - If the text contains ``message="..."``, the quoted part is printed,
      prefixed with "SQL Error: ".
    - Otherwise, if it contains ``Reason:``, the text after it (up to a ``{``
      when present) is printed with the same prefix.
    - Otherwise the cleaned text is printed unchanged.

    Args:
        error: The error, typically ``repr()`` of an exception.

    Returns:
        None; the message is printed.
    """
    text = str(error).replace("\\'", "'").replace('\\n', ' ')

    while True:
        open_at = text.find("<")
        if open_at < 0:
            break
        close_at = text.find(">", open_at + 1)
        if close_at < 0:
            break
        text = text[:open_at] + text[close_at + 1:]

    prefix = "SQL Error: "
    marker = 'message="'
    message = text

    pos = text.find(marker)
    if pos != -1:
        stop = text.find('"', pos + len(marker))
        if stop != -1:
            message = prefix + text[pos + len(marker):stop]
    else:
        pos = text.find('Reason:')
        if pos != -1:
            stop = text.find('{', pos)
            tail = text[pos + 8:] if stop == -1 else text[pos + 8:stop]
            message = prefix + tail

    print(message)
	
def success(message):
    """Print an informational message unless it is None or the empty string.

    Args:
        message: Text to show.

    Returns:
        None.
    """
    if message in (None, ""):
        return
    print(message)

def _renderResult(df, display_type):
    """Present a result DataFrame; return it only for "pandas"/unrecognised types.

    Empty results print "No rows found.". "grid" shows an ipydatagrid using the
    configured height. "text" prints the frame with escaped ``\\n`` sequences
    turned into real line breaks.
    """
    if len(df) == 0:
        success("No rows found.")
        return None
    if display_type == "grid":
        from ipydatagrid import DataGrid
        height = _settings["height"]
        pdisplay(DataGrid(df, auto_fit_columns=True, layout={"height": f"{height}px"}))
        return None
    if display_type == "text":
        print(df.to_string(index=False).replace("\\n", "\n"))
        return None
    return df


def execSQL(sql,display_type,unknown=False):
    """Execute a statement that may return rows and present the result.

    Connects first (via connect()) when no connection is open yet.

    Args:
        sql: Statement text to run.
        display_type: "grid", "pandas", "text" or "raw". Any other value behaves
            like "pandas": the DataFrame is returned.
        unknown: True when the caller cannot tell whether ``sql`` returns rows.
            The statement is then run exactly once on the shared cursor. If
            ``cursor.description`` is None, it is reported as a completed
            command; otherwise its rows are fetched from that same execution.

    Returns:
        The row list for "raw", the DataFrame for "pandas" or unrecognised types,
        and None otherwise (the result is displayed or printed instead). Errors
        are printed via formatSQLError and yield None.
    """
    import pandas as pd

    global _settings
    if not _settings["connected"]:
        if not connect():
            return None

    connection = _settings["connection"]
    cursor     = _settings["cursor"]

    try:
        if unknown:
            cursor.execute(sql)
            if cursor.description is None:
                print("Command Completed.")
                return None
            # Use the rows from this execution instead of running the statement
            # again through pandas, which would repeat its side effects
            # (CALL/EXECUTE etc.).
            rows = cursor.fetchall()
            if display_type == "raw":
                return rows
            columns = [column[0] for column in cursor.description]
            df = pd.DataFrame(rows, columns=columns)
        elif display_type == "raw":
            cursor.execute(sql)
            return cursor.fetchall()
        else:
            df = pd.read_sql(sql, connection)
        return _renderResult(df, display_type)
    except Exception as err:
        formatSQLError(repr(err))
        return None

def execDDL(sql):
    """Execute a statement that produces no answer set and report completion.

    Connects first (via connect()) when no connection is open yet. connect()
    returns a single bool; the previous code unpacked it into two names, which
    raised TypeError whenever no connection existed.

    The statement's result is drained with fetchall() before "Command
    completed." is printed. The Presto client only submits the query on
    execute(); draining lets the statement finish and surfaces server-side
    errors instead of reporting success prematurely.

    Args:
        sql: Statement text to run.

    Returns:
        None. Errors are printed via formatSQLError.
    """
    global _settings
    if not _settings["connected"]:
        if not connect():
            return

    cursor = _settings["cursor"]

    try:
        cursor.execute(sql)
        cursor.fetchall()
    except Exception as err:
        formatSQLError(repr(err))
        return
    print("Command completed.")

@magics_class
class presto(Magics):
    """IPython magics providing ``%sql`` / ``%%sql`` for Presto statements."""

    def _setDisplay(self, tokens):
        """Handle ``DISPLAY GRID|PANDAS|TEXT|RAW`` or ``DISPLAY <pixels>`` (50-500)."""
        if len(tokens) != 2:
            print("DISPLAY syntax: DISPLAY [GRID | PANDAS | TEXT | RAW | pixels] where 50 <= pixels <= 500")
            return

        option = tokens[1].strip().upper()
        modes = {"GRID": "grid", "PANDAS": "pandas", "TEXT": "text", "RAW": "raw"}
        if option in modes:
            _settings["display"] = modes[option]
            return

        try:
            pixels = int(tokens[1])
        except ValueError:
            print(f"Non-numeric value for Display height: {tokens[1]}")
            return
        if pixels < 50 or pixels > 500:
            print("Pixel size for results sets is outside the range of 50 - 500")
        else:
            _settings["height"] = pixels

    def _useCatalogSchema(self, tokens):
        """Handle ``USE schema`` or ``USE catalog.schema`` and reconnect."""
        if len(tokens) != 2:
            print("USE syntax: USE [schema] | USE [catalog.schema]")
            return

        catalogschema = tokens[1].strip()
        if "." in catalogschema:
            args = catalogschema.split(".")
            newCatalog = args[0].strip()
            newSchema  = args[1].strip()
            _settings["catalog"] = newCatalog
            _settings["schema"]  = newSchema
            if connect():
                print(f"Using catalog {newCatalog} with schema {newSchema}")
        else:
            _settings["schema"] = catalogschema
            if connect():
                print(f"Using schema {catalogschema}")

    @needs_local_scope
    @line_cell_magic
    def sql(self, line, cell=None, local_ns=None):
        """Run one or more ``;``-separated Presto statements or magic sub-commands.

        One of ``--text``, ``--pandas``, ``--grid`` or ``--raw`` anywhere in the
        text overrides the default display for this invocation. ``[dml]`` or
        ``[ddl]`` in a statement forces how it is executed when its leading
        keyword is not recognised. Sub-commands are HELP, DISPLAY, CONNECT
        and USE.

        Returns:
            The value produced by execSQL (a DataFrame for pandas display, a row
            list for raw display) when the input holds exactly one statement;
            otherwise None. Empty statements, such as the one after a trailing
            ``;``, do not count.
        """
        DML = ["VALUES","SHOW","SELECT","WITH","DESCRIBE","EXPLAIN"]
        DDL = ["PREPARE","DISPLAY","CONNECT","USE","CATALOG","DROP","CREATE","ALTER","INSERT","DELETE","UPDATE","HELP"]
        UNK = ["EXECUTE","CALL"]

        sql = line.replace("\n", " ").strip()
        if cell not in [None, ""]:
            sql = sql + " " + cell.replace("\n", " ").strip()

        if not sql.strip():
            return

        display_type = _settings["display"]
        for flag, kind in (("--text", "text"), ("--pandas", "pandas"), ("--grid", "grid"), ("--raw", "raw")):
            if flag in sql:
                display_type = kind
                sql = sql.replace(flag, "")
                break

        # Drop empty pieces so a trailing ';' does not count as a second
        # statement (which previously suppressed returning the result).
        statements = [piece.strip() for piece in sql.split(";") if piece.strip()]
        show = len(statements) == 1

        for sql in statements:

            if "[dml]" in sql:
                sql_type = "dml"
                sql = sql.replace("[dml]", "")
            elif "[ddl]" in sql:
                sql_type = "ddl"
                sql = sql.replace("[ddl]", "")
            else:
                sql_type = None

            tokens = sql.split()
            if not tokens:
                continue
            sqlType = tokens[0].upper()

            if sqlType == "HELP":
                help()
            elif sqlType == "DISPLAY":
                self._setDisplay(tokens)
            elif sqlType == "CONNECT":
                if setConnection(sql) and connect():
                    print("Connection successful")
            elif sqlType == "USE":
                self._useCatalogSchema(tokens)
            else:
                # Precedence: DML keyword or [dml] first, then CALL/EXECUTE, then
                # DDL keyword or [ddl]; anything else is probed as unknown.
                if sqlType in DML or sql_type == "dml":
                    unknown = False
                elif sqlType in UNK:
                    unknown = True
                elif sqlType in DDL or sql_type == "ddl":
                    execDDL(sql)
                    continue
                else:
                    unknown = True

                results = execSQL(sql, display_type, unknown)
                # `is None`, not `== None`: comparing a DataFrame to None is
                # element-wise and its truth value raises ValueError.
                if results is not None and show:
                    return results
			
# Register the Magic extension in Jupyter
# get_ipython is not imported in this file; it is presumably the name IPython
# injects into the namespace, so this cell must run inside IPython/Jupyter
# (e.g. through %run presto.ipynb).
ip = get_ipython()
# Make %sql / %%sql from the presto Magics class available in the shell.
ip.register_magics(presto)
success("Presto Extensions Loaded.")


##### Credits: IBM 2023, George Baklarz [baklarz@ca.ibm.com]