# In [None]:
import os
from IPython.display import Image, display, clear_output
from IPython.core.display import HTML
from astroquery.cadc import Cadc
from astropy.table import Table as ATable
from astropy.io.votable.tree import VOTableFile, Resource, Table, Field
from astropy.io.votable import from_table, writeto
import pandas as pd
import glue_jupyter as gj
import ipywidgets as widgets
from ipywidgets import interact
from cadcutils import net
from cadctap import CadcTapClient
from six import BytesIO
#import sh
import base64

# Output area that the subset buttons render their tables into.
out = widgets.Output()

# Fixed palette of 11 colours; entries are looked up by subset position.
color_list = [
    '#F780BF', '#C7E6A1', '#A1E6E2', '#FAFBA1',
    '#FF9C33', '#CBC6C0', '#D8A56E', '#F7A480',
    '#D798EE', '#80A0C8', '#FACCED',
]

# Integer input holding the observation to look up.
observation_id = widgets.IntText(
    value=1013372,
    description='Observation ID',
    continuous_update=False,
    style={'description_width': 'initial'},
)
ui_database = widgets.HBox([observation_id])

# Selector for how the retrieved data is presented.
G_Type = widgets.Dropdown(
    options=['Scatter', 'Histogram', 'Table'],
    description='View Data',
    style={'description_width': 'initial'},
)
ui_graph = widgets.HBox([G_Type])


def background_color(row):
    """Return per-cell CSS colouring a results row by the subset it is in.

    Written for use with ``DataFrame.style.apply(..., axis=1)``. Reads the
    module-level ``out_data`` and ``color_list``.

    Args:
        row: One row of the results table. Its fourth value is compared
            against each subset's ``'Obs. ID'`` values.

    Returns:
        A list with one ``'background-color: <colour> '`` string per cell.
        The colour is empty when the row is in no subset; when the row is
        in several subsets, the last matching one wins.
    """
    obs_id = row.values[3]  # Obs. ID is the 4th column
    color = ''
    for position, subset in enumerate(out_data.subsets):
        if obs_id in subset['Obs. ID']:
            # Wrap around the palette so having more subsets than colours
            # cannot raise IndexError.
            color = color_list[position % len(color_list)]
    return ['background-color: %s ' % color] * len(row.values)


def subset_background_color(row, color):
    """Build the same background-colour style for every cell of ``row``.

    Args:
        row: Row object exposing ``values``; only its length is used.
        color: CSS colour interpolated into the style string.

    Returns:
        A list of ``'background-color: <color> '`` strings, one per cell.
    """
    cell_style = 'background-color: %s ' % color
    return [cell_style for _ in range(len(row.values))]
    

def select_method(obtain_images):
    """Offer the current subset's images as a download list or as previews.

    Reads the module-level ``url_dictionary``: its keys are written to the
    download list, and each value is indexed as
    ``[image_url, observation_id]``.

    Args:
        obtain_images: Selected dropdown option. ``'Download Images'``
            renders a link to a text file with one key per line; any other
            value displays the images inline (at most 10).
    """
    if obtain_images == 'Download Images':
        # Entries stored without a download URL have a ``None`` key; skip
        # them so the list never contains the literal text "None".
        txt = ''.join('{}\n'.format(key)
                      for key in url_dictionary if key is not None)
        payload = base64.b64encode(txt.encode()).decode()
        html = ('<a download="testfile" href="data:text;base64,{payload}" '
                'target="_blank">Download Subset</a>')
        display(HTML(html.format(payload=payload)))
        return

    if len(url_dictionary) > 10:
        print("Can not display more than 10 images")
        return

    for value in url_dictionary.values():
        image_url = str(value[0])
        print("Observation Id: {}".format(value[1]))
        display(HTML('<a href="{}">{}</a>'.format(image_url, image_url)))
        display(Image(url=image_url, width=100, height=100, unconfined=True))


# (column name in the subset table, glue component it is read from).
# 'publisherID' is the only column whose name differs from its component;
# presumably that spelling is what ``Cadc.get_data_urls`` looks for --
# TODO confirm.
_SUBSET_COLUMNS = (
    ('Preview', 'Preview'),
    ('publisherID', 'Publisher ID'),
    ('Collection', 'Collection'),
    ('Obs. ID', 'Obs. ID'),
    ('Product ID', 'Product ID'),
    ('Instrument', 'Instrument'),
    ('Int. Time', 'Int. Time'),
    ('Overall Quality', 'Overall Quality'),
    ('Really Bad Tracking', 'Really Bad Tracking'),
    ('Bad Tracking', 'Bad Tracking'),
    ('Bad Weather', 'Bad Weather'),
    ('Background Problem', 'Background Problem'),
    ('Dead CCDs', 'Dead CCDs'),
)


def _subset_position(label):
    """Return the index in ``out_data.subsets`` of the subset named ``label``."""
    for position, subset in enumerate(out_data.subsets):
        if subset.label == label:
            return position
    # No subset carries this label: fall back to the earlier convention of
    # reading a 1-based number from the second word of the label.
    return int(label.split(' ')[1]) - 1


def on_button_clicked(b):
    """Render the table for a subset button into the shared ``out`` widget.

    While ``click_flag`` is false, the whole ``results_df`` is shown,
    colour-coded by subset. Otherwise the clicked subset is shown in its
    own colour, its image URLs are stored in the module-level
    ``url_dictionary``, and a dropdown wired to ``select_method`` is
    displayed.

    Args:
        b: The clicked button; ``b.description`` is matched against the
            subset labels to find which subset to show.
    """
    global url_dictionary
    with out:
        clear_output()
        if not click_flag:
            display(results_df.style.apply(background_color, axis=1))
            return

        # Resolve the subset by label rather than by parsing a number out of
        # the label, so renamed or renumbered subsets still resolve.
        i = _subset_position(b.description)
        subset = out_data.subsets[i]
        subset_df = pd.DataFrame({column: subset[component]
                                  for column, component in _SUBSET_COLUMNS})
        url_dictionary = {}

        if subset_num_records[i] <= 0:
            print("\x1b[31m Subset does not contain any record. \x1b[0m")
            return

        cadc = Cadc()
        subset_table = ATable.from_pandas(subset_df)
        for idx in range(len(subset_df)):
            # Query one row at a time so the URLs can be tied to that row.
            urls = cadc.get_data_urls(subset_table[idx:idx+1],
                                      include_auxiliaries=True)
            fz_url = next((u for u in urls if '.fz' in u), None)
            jpg_url = next((u for u in urls if '1024.jpg' in u), None)
            if jpg_url:
                # NOTE(review): rows with a preview but no '.fz' URL all
                # share the key None, so later rows overwrite earlier ones.
                url_dictionary[fz_url] = [jpg_url, subset_df['Obs. ID'][idx]]

        # Printed as well, just in case the button tooltip does not work.
        print("Number of record: {}".format(subset_num_records[i]))
        # Wrap around the palette so more subsets than colours cannot fail.
        subset_color = color_list[i % len(color_list)]
        display(subset_df.style.apply(subset_background_color,
                                      color=subset_color, axis=1))
        obtain_images = widgets.Dropdown(
            options=['Download Images', 'Display Images'],
            description='Obtain Images',
            style={'description_width': 'initial'})
        selection = widgets.HBox([obtain_images])
        select_output = widgets.interactive_output(
            select_method, {'obtain_images': obtain_images})
        display(selection, select_output)


def graphs(G_Type):
    """Show the loaded data as a scatter plot, a histogram or a table.

    Reads the module-level ``app``, ``out_data`` and ``results_df``. In the
    table view it sets the module-level ``click_flag`` and, when subsets
    exist, ``subset_num_records``.

    Args:
        G_Type: Selected view. ``'Scatter'`` and ``'Histogram'`` open the
            matching glue viewer; any other value shows the table view.
    """
    global click_flag, subset_num_records
    if G_Type == 'Scatter':
        app.scatter2d(x='Really Bad Tracking',
                      y='Overall Quality',
                      data=out_data,
                      show=True)
        return
    if G_Type == 'Histogram':
        app.histogram1d(x='Overall Quality',
                        data=out_data,
                        show=True)
        return

    click_flag = False
    subsets = out_data.subsets
    if not subsets:
        # Nothing selected yet: show the plain results table.
        display(results_df)
        return

    subset_num_records = [len(subset['Product ID']) for subset in subsets]
    items_auto = []
    for position, subset in enumerate(subsets):
        t_tip = "number of records: {}".format(subset_num_records[position])
        # Wrap around the palette so having more subsets than colours
        # cannot raise IndexError.
        button_color = color_list[position % len(color_list)]
        button = widgets.Button(
            description=subset.label,
            layout=widgets.Layout(flex='1 1 auto', width='auto'),
            style=widgets.ButtonStyle(button_color=button_color),
            tooltip=t_tip)
        button.on_click(on_button_clicked)
        items_auto.append(button)

    box_layout = widgets.Layout(display='flex',
                                flex_flow='row',
                                align_items='stretch',
                                width='100%')
    box_auto = widgets.Box(children=items_auto, layout=box_layout)
    display(widgets.VBox([box_auto, out]))
    # Simulate one click while click_flag is still False so that
    # on_button_clicked first renders the full, colour-coded results table.
    items_auto[-1].click()
    click_flag = True


def retrieve_data(observation_id):
    """Fetch quality data for an observation and load it into glue.

    Steps:
      1. Query ``caom2`` through ``Cadc.exec_sync`` for the CFHT MegaPrime
         planes of ``observation_id`` and write the result as a VOTable
         under ``tmp/``.
      2. Upload that VOTable as ``tap_upload.tmptable`` through
         ``CadcTapClient.query`` (authenticated with the proxy certificate
         in ``~/.ssl/cadcproxy.pem``) and join it with
         ``ml.MegaprimeQuality``; the result is written as CSV.
      3. Clean the CSV with pandas, load it into a new glue application and
         display the view selector wired to ``graphs``.

    Sets the module-level ``app``, ``out_data`` and ``results_df``.

    Args:
        observation_id: Observation ID interpolated into the first query.
    """
    global app, out_data, results_df
    obs_query = """SELECT Observation.observationURI,
    Plane.publisherID,
    Observation.collection,
    Observation.observationID,
    Plane.productID,
    Observation.instrument_name,
    Plane.time_exposure
    FROM caom2.Plane AS Plane
    JOIN caom2.Observation AS Observation
    ON Plane.obsID = Observation.obsID
    WHERE (Observation.observationID = '{obs_id}'
    AND collection = '{collection}'
    AND instrument_name = '{instrument}')"""
    obs_query_param = {'obs_id': observation_id,
                       'collection': 'CFHT',
                       'instrument': 'MegaPrime'}

    # Both intermediate files live in tmp/; create it so a fresh checkout
    # does not fail on the first write.
    tmp_dir = 'tmp'
    os.makedirs(tmp_dir, exist_ok=True)
    votable_path = os.path.join(tmp_dir, 'output.xml')
    csv_path = os.path.join(tmp_dir, 'output_file.csv')

    cadc = Cadc()
    output = cadc.exec_sync(obs_query.format(**obs_query_param))
    writeto(from_table(output), votable_path)

    # expanduser() resolves the home directory even when HOME is not set,
    # where os.environ['HOME'] would raise KeyError.
    certificate = os.path.join(os.path.expanduser('~'),
                               '.ssl', 'cadcproxy.pem')
    client = CadcTapClient(net.Subject(certificate=certificate))
    quality_query = """SELECT Tmp.observationURI AS Preview,
    Tmp.publisherID AS "Publisher ID",
    Tmp.collection AS Collection,
    Tmp.observationID AS "Obs. ID",
    Tmp.productID AS "Product ID",
    Tmp.instrument_name AS Instrument,
    Tmp.time_exposure AS "Int. Time",
    Quality.overallQuality AS "Overall Quality",
    Quality.reallyBadTracking AS "Really Bad Tracking",
    Quality.badTracking AS "Bad Tracking",
    Quality.badWeather AS "Bad Weather",
    Quality.backgroundProblem AS "Background Problem",
    Quality.deadCCDs AS "Dead CCDs"
    FROM ml.MegaprimeQuality AS Quality
    JOIN tap_upload.tmptable AS Tmp
    ON Quality.observationID=Tmp.observationID
    WHERE publisherID LIKE '%p'"""
    client.query(quality_query,
                 response_format='csv',
                 output_file=csv_path,
                 tmptable='tmptable:' + votable_path, timeout=30)

    # Tidy the CSV: strip the double quotes from the column names, drop rows
    # without an Obs. ID, and write the cleaned table back for glue to load.
    results_df = pd.read_csv(csv_path)
    results_df.columns = [col.replace('"', '') for col in results_df.columns]
    results_df = results_df.dropna(subset=['Obs. ID'])
    results_df.to_csv(csv_path, index=False)
    print('Total row count: {}'.format(len(results_df.index)))

    # display data using glue
    app = gj.jglue()
    out_data = app.load_data(csv_path)
    print("displaying data.")
    graph_output = widgets.interactive_output(graphs, {'G_Type': G_Type})
    display(ui_graph, graph_output)


# Bind retrieve_data to the observation ID input through
# interactive_output, then show the input box together with the output area.
database_output = widgets.interactive_output(
    retrieve_data,
    {"observation_id": observation_id},
)
display(ui_database, database_output)