# Import Libraries

In [1]:
import utility_func as util
import importlib
importlib.reload(util)
import gradio as gr

In [2]:
import warnings
# Silence every Python warning for the rest of the session to keep notebook
# output clean. NOTE(review): this also hides deprecation/runtime warnings from
# pandas, matplotlib and gradio — consider narrowing the filter (category/module).
warnings.filterwarnings('ignore')

# Functions

In [3]:
## Plot element concentration against longitude, one subplot per latitude
def plot_ppm_variation(df, element, area):
    """Plot ``element`` concentration against longitude, one subplot per latitude.

    Rows of ``df`` are grouped by exact ``latitude`` value. Each group is sorted by
    longitude and drawn as a line in its own subplot. Subplots run from the
    largest latitude (top) to the smallest (bottom). All subplots share the
    y-range ``[0, max concentration]`` so they can be compared directly.

    Args:
        df: DataFrame with ``latitude``, ``longitude`` and ``element`` columns.
        element: Name of the concentration column (element symbol) to plot.
        area: Study-area name. It only selects the figure width: 12 in for
            'Kodangal', 10 in otherwise.

    Returns:
        The matplotlib Figure containing the subplots.

    Raises:
        KeyError: If ``element`` is not a column of ``df``.
        ValueError: If ``df`` has no latitude values to plot.
    """
    if element not in df.columns:
        raise KeyError(f"The element '{element}' does not exist in the dataframe.")

    lat_list = util.np.sort(df['latitude'].unique())[::-1]
    n_plots = len(lat_list)
    if n_plots == 0:
        raise ValueError('No latitude values to plot.')

    # Kodangal gets a wider figure; any other area uses the Ramagiri width
    # instead of failing with an unbound `fig`.
    fig_width = 12 if area == 'Kodangal' else 10

    # squeeze=False always returns a 2-D axes array, so a single subplot needs no special case.
    fig, axes = util.plt.subplots(
        n_plots,
        1,
        figsize=(fig_width, 2 * n_plots),
        constrained_layout=True,
        squeeze=False,
    )

    # Shared upper y-limit, computed once; Series.max() skips NaN values.
    y_max = df[element].max()

    for ax, lat in zip(axes[:, 0], lat_list):
        subset = df[df['latitude'] == lat].sort_values(by='longitude')
        ax.plot(subset['longitude'], subset[element], marker='o', linestyle='-', color='b')
        ax.set_title(f'Concentration at Latitude {lat}')
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Concentration(ppm)')
        ax.set_ylim(0, y_max)
        ax.grid(True)

    return fig

## Plot contour map
def plot_contour(df, element, name, title, area, ax=None, levels=15):
    """Draw a filled, labelled contour map of ``element`` over longitude/latitude.

    The samples are interpolated with cubic ``util.griddata`` onto a 200x200 grid.
    The grid spans the data extent, padded by 0.005 coordinate units on every side.

    Args:
        df: DataFrame with ``longitude``, ``latitude`` and ``element`` columns.
        element: Column (element symbol) to contour.
        name: Human-readable element name used in the colorbar label.
        title: Axes title.
        area: 'Ramagiri' resizes the figure to 10x6 in, 'Kodangal' to 15x5 in.
            Other values leave the figure size unchanged.
        ax: Axes to draw on; defaults to the current pyplot axes.
        levels: Number of contour levels for both the fill and the lines.

    Raises:
        KeyError: If ``element`` is not a column of ``df``.
    """
    # Validate first so a bad request does not create an empty figure via gca().
    if element not in df.columns:
        raise KeyError(f"The element '{element}' does not exist in the dataframe.")
    if ax is None:
        # `plt` is not imported at top level here; pyplot is reached through util.
        ax = util.plt.gca()

    pad = 0.005
    lon = df['longitude']
    lat = df['latitude']
    min_long_limit = lon.min() - pad
    max_long_limit = lon.max() + pad
    min_lat_limit = lat.min() - pad
    max_lat_limit = lat.max() + pad

    grid_x, grid_y = util.np.mgrid[
        min_long_limit:max_long_limit:200j,
        min_lat_limit:max_lat_limit:200j
    ]

    grid_z = util.griddata(
        (lon, lat),
        df[element],
        (grid_x, grid_y),
        method='cubic'
    )

    cmap = util.LinearSegmentedColormap.from_list("green_to_red", ["green", "yellow", "red"])

    if area == 'Ramagiri':
        ax.figure.set_size_inches(10, 6)
    elif area == 'Kodangal':
        ax.figure.set_size_inches(15, 5)

    cp = ax.contourf(grid_x, grid_y, grid_z, levels=levels, cmap=cmap, alpha=0.7)
    util.plt.colorbar(cp, ax=ax, label=f'{name} concentration')

    cs = ax.contour(grid_x, grid_y, grid_z, levels=levels, colors='k', linewidths=0.5)
    ax.clabel(cs, inline=True, fontsize=8, fmt='%1.0f')

    ax.set_title(title)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.grid(True, color='gray', linestyle='--', linewidth=0.5)

## Function to get element symbol from the element name and area
def get_element_symbol(element_name, area):
    """Return the column symbol for ``element_name`` from the area's name table.

    Args:
        element_name: Value from the ``Name`` column (e.g. as shown in a dropdown).
        area: 'Ramagiri' or 'Kodangal'.

    Raises:
        ValueError: If ``area`` is not a known area.
        KeyError: If ``element_name`` is not listed for that area.
    """
    if area == 'Ramagiri':
        names_df = rg_elem_name_df
    elif area == 'Kodangal':
        names_df = kg_elem_name_df
    else:
        raise ValueError(f"Unknown area: {area!r}")

    matches = names_df.loc[names_df['Name'] == element_name, 'Symbol'].values
    if len(matches) == 0:
        raise KeyError(f"Element '{element_name}' is not listed for area '{area}'.")
    return matches[0]

## Function to update element dropdown options based on selected area
def update_elements(area):
    """Return a Gradio update that sets an element dropdown's choices for ``area``.

    The choices are the ``Name`` column of the area's element-name table. Any
    other area value (e.g. ``None``, if the Area dropdown is empty) yields an
    empty choice list instead of raising.
    """
    if area == 'Ramagiri':
        elements = list(rg_elem_name_df['Name'])
    elif area == 'Kodangal':
        elements = list(kg_elem_name_df['Name'])
    else:
        elements = []

    return gr.update(choices=elements)

## Function to be called by Gradio
def show_plots(element_name, area):
    """Build a saved contour map and a ppm-variation figure for one element.

    This function is not wired to any Gradio event in this notebook section.

    Args:
        element_name: Display name of the element (``Name`` column).
        area: Key into ``DATASETS`` ('Ramagiri' or 'Kodangal').

    Returns:
        ``[contour_fig_path, ppm_fig]``: the path of the contour PNG written to
        ``<curr_dir>/Images/contour_plot.png``, and the Figure returned by
        ``plot_ppm_variation``.
    """
    df = DATASETS[area]

    element = get_element_symbol(element_name, area)

    img_path = util.os.path.join(curr_dir, 'Images').replace('\\', '/')
    # savefig does not create missing directories.
    util.os.makedirs(img_path, exist_ok=True)

    ppm_fig = plot_ppm_variation(df, element, area)

    contour_fig, contour_axes = util.plt.subplots()
    try:
        plot_contour(df, element, element_name, f'{area} - {element_name} Concentration contour map', area, ax=contour_axes)
        contour_fig_path = img_path + '/contour_plot.png'
        contour_fig.savefig(contour_fig_path)
    finally:
        # Release the figure even if plotting or saving fails.
        util.plt.close(contour_fig)

    return [contour_fig_path, ppm_fig]

## Plotting ppm concentration plot on interface
def ppm_plot_interface(element_name, area):
    """Gradio handler: return the ppm-variation Figure for ``element_name`` in ``area``.

    The display name is mapped to its column symbol with ``get_element_symbol``
    before plotting the area's dataset from ``DATASETS``.
    """
    df = DATASETS[area]
    element = get_element_symbol(element_name, area)
    return plot_ppm_variation(df, element, area)

## Plotting contour map on interface
def contour_plot_interface(element_name, area):
    """Gradio handler: return a contour-map Figure for ``element_name`` in ``area``.

    The title includes the element's crustal abundance from the area's
    element-name table. It shows 'N/A' when that cell is empty.

    Raises:
        KeyError: If ``area`` is not in ``DATASETS``, or the name table has no
            crustal-abundance column.
        ValueError: If ``area`` has no element-name table.
    """
    df = DATASETS[area]

    element = get_element_symbol(element_name, area)
    if area == 'Ramagiri':
        names_df = rg_elem_name_df
    elif area == 'Kodangal':
        names_df = kg_elem_name_df
    else:
        raise ValueError(f"Unknown area: {area!r}")

    # The abundance header differs between the tables ('Crustal Abundance' for
    # Ramagiri, 'Crustal Aubandance (PPM)' for Kodangal), so match it by prefix
    # rather than hard-coding one spelling.
    abundance_col = next((col for col in names_df.columns if str(col).startswith('Crustal')), None)
    if abundance_col is None:
        raise KeyError(f"No crustal-abundance column in the element table for '{area}'.")
    abundance = names_df.loc[names_df['Symbol'] == element, abundance_col].values[0]
    if util.pd.isna(abundance):
        abundance = 'N/A'

    contour_fig, contour_axes = util.plt.subplots()
    plot_contour(df, element, element_name, f"{area} - {element_name} Concentration contour map(" + r"$\bf{Crustal\ Abundance}$" + f": {abundance})",
                 area, ax=contour_axes)

    return contour_fig

## Cluster map
def plot_cluster_map(area):
    """Return a Plotly map of the clustered samples for ``area``.

    Each cluster label becomes its own Scattermapbox trace with its own legend
    entry and hover text (lat/lon, sediment ``gid``, cluster). Colors cycle
    through a fixed 3-colour palette by ``(cluster - 1)``.

    Args:
        area: 'Ramagiri' or 'Kodangal'.

    Returns:
        A ``go.Figure`` with an OpenStreetMap base layer.

    Raises:
        ValueError: If ``area`` is not a known area (e.g. nothing selected yet).
    """
    colors = ['#b45f06', '#d30202', '#203254']

    # Per-area settings:
    # - pad: padding added around the data extent for the map bounds
    # - legend_x, legend_y: legend anchor
    # - size: square figure size in pixels
    if area == 'Ramagiri':
        clustered_data = ramagiri_df
        pad, legend_x, legend_y, size = 0.05, 0.45, 1.08, 700
    elif area == 'Kodangal':
        clustered_data = kodangal_df
        pad, legend_x, legend_y, size = 1, 0.65, 1.05, 1000
    else:
        raise ValueError(f"Unknown area: {area!r}")
    title = f'{area} Clusters'

    lon = clustered_data['longitude']
    lat = clustered_data['latitude']
    bounds = dict(
        south=util.np.round(lat.min() - pad, 4),
        north=util.np.round(lat.max() + pad, 4),
        west=util.np.round(lon.min() - pad, 4),
        east=util.np.round(lon.max() + pad, 4),
    )

    # One trace per cluster so each gets its own colour and legend entry.
    traces = []
    for clst in sorted(clustered_data['cluster'].unique()):
        clustered_subset = clustered_data[clustered_data['cluster'] == clst]
        hover_text = [
            f'Latitude: {s_lat}<br>Longitude: {s_lon}<br>Sediment ID: {gid}<br>Cluster: {clst}'
            for s_lat, s_lon, gid in zip(clustered_subset['latitude'],
                                         clustered_subset['longitude'],
                                         clustered_subset['gid'])
        ]

        traces.append(util.go.Scattermapbox(
            lat=clustered_subset['latitude'],
            lon=clustered_subset['longitude'],
            mode='markers',
            marker=dict(
                size=10,
                color=colors[(clst - 1) % len(colors)],
                opacity=0.8
            ),
            name=f'Cluster {clst}',  # legend label
            text=hover_text,
            hoverinfo='text'
        ))

    layout = util.go.Layout(
        title=title,
        hovermode='closest',
        mapbox=dict(
            style='open-street-map',
            center=dict(lat=lat.mean(), lon=lon.mean()),
            zoom=10,
            bounds=bounds
        ),
        legend=dict(
            orientation="h",
            yanchor="top",
            y=legend_y,
            xanchor="left",
            x=legend_x
        )
    )

    fig = util.go.Figure(data=traces, layout=layout)
    fig.update_layout(
        autosize=False,
        width=size,
        height=size,
    )
    return fig

## Model choices available for each area
def get_models_for_area(area):
    """Return a Gradio update that sets the model dropdown choices for ``area``.

    Unknown areas get an empty choice list.
    """
    models_by_area = {
        'Ramagiri': ['Random Forest', 'XGBoost', 'CNN'],
        'Kodangal': ['Random Forest', 'XGBoost', 'CNN'],
    }

    # Pass the list as `choices=`; passed positionally it would not set the
    # dropdown options.
    return gr.update(choices=models_by_area.get(area, []))

## Predictor (feature) names for a target element and area
def get_features_for_target(target_element, area):
    """Return display names of the predictor elements used for ``target_element``.

    The predictor symbols are read from the ``Elements`` column of the sheet
    named ``target_element`` in
    ``<data_path>/<area> dataset/<area>_element_names.xlsx``. Names are
    returned in that sheet's order. ``pred`` also maps input values to symbols
    in sheet order, so textbox labels and values stay aligned. A symbol missing
    from the area's name table is returned as the raw symbol, so the list length
    still matches the sheet.

    Raises:
        ValueError: If ``area`` is not a known area.
    """
    if area == 'Kodangal':
        names_df = kg_elem_name_df
    elif area == 'Ramagiri':
        names_df = rg_elem_name_df
    else:
        raise ValueError(f"Unknown area: {area!r}")

    symbols_df = util.pd.read_excel(f'{data_path}/{area} dataset/{area}_element_names.xlsx', sheet_name=target_element)
    name_by_symbol = dict(zip(names_df['Symbol'], names_df['Name']))
    return [name_by_symbol.get(symbol, symbol) for symbol in symbols_df['Elements']]

## Prediction
def pred(target_element, model, area, *features):
    """Predict ``target_element`` for one new sample and return a result sentence.

    Args:
        target_element: Display name (``Name`` column) of the element to predict.
        model: 'Random Forest' or 'XGBoost'. Other choices (e.g. 'CNN') are not
            implemented and return an explanatory message instead of a prediction.
        area: 'Ramagiri' or 'Kodangal'.
        *features: Raw input-box values. Let N be the number of predictor
            symbols in the ``Elements`` column of the target's sheet in
            ``<area>_element_names.xlsx``. The first N values are used, and
            value i is assigned to symbol i.

    Returns:
        'Prediction result for <target> using <model> is <value>.' with the
        value rounded to 4 decimals.

    Raises:
        ValueError: For an unknown area, or when a required feature value is
            missing or not a number.
    """
    if area == 'Kodangal':
        names_df = kg_elem_name_df
    elif area == 'Ramagiri':
        names_df = rg_elem_name_df
    else:
        raise ValueError(f"Unknown area: {area!r}")

    # File-name prefix of the saved estimator for each implemented model.
    model_prefixes = {'Random Forest': 'RFR_RF_', 'XGBoost': 'XGB_RF_'}
    if model not in model_prefixes:
        return f'Prediction with {model} is not available yet.'

    symb = names_df.loc[names_df['Name'] == target_element, 'Symbol'].values[0]
    symb_df = util.pd.read_excel(f'{data_path}/{area} dataset/{area}_element_names.xlsx', sheet_name=target_element)
    symb_list = list(symb_df['Elements'])

    # Exactly one value per predictor, in sheet order. A blank box is reported
    # rather than silently shifting later values onto the wrong symbol.
    raw_values = list(features[:len(symb_list)])
    raw_values += [None] * (len(symb_list) - len(raw_values))
    missing = [sym for sym, val in zip(symb_list, raw_values) if val is None or not str(val).strip()]
    if missing:
        raise ValueError(f"Missing values for: {', '.join(map(str, missing))}")
    try:
        feature_values = [float(val) for val in raw_values]
    except ValueError as err:
        raise ValueError(f'All feature values must be numbers ({err}).') from err

    model_dir = util.os.path.join(curr_dir, 'Models', area).replace('\\', '/')
    # NOTE(review): util.pkl looks like the pickle module. Loading a pickle can
    # execute code, so these model files must come from a trusted source.
    with open(f'{model_dir}/rs_norm_{symb}.pkl', 'rb') as file:
        norm_rs = util.pkl.load(file)
    # Only the selected estimator is loaded.
    with open(f'{model_dir}/{model_prefixes[model]}{symb}.pkl', 'rb') as file:
        estimator = util.pkl.load(file)

    df = util.pd.DataFrame({sym: [val] for sym, val in zip(symb_list, feature_values)})
    result = util.pred_val(df, norm_rs, estimator)

    return f'Prediction result for {target_element} using {model} is {util.np.round(result[0], 4)}.'

# Load Data

In [4]:
# Base directories, with backslashes normalised to '/' so the '/'-joined paths
# used throughout the notebook stay consistent on Windows too.
curr_dir = util.os.getcwd().replace('\\', '/')
data_path = util.os.path.join(curr_dir, 'Data').replace('\\', '/')

In [5]:
# Clustered sample tables, read from each area's "3 clusters" folder.
kodangal_df = util.pd.read_csv(f'{data_path}/Kodangal dataset/3 clusters/clst3_Kodangal_gcM.csv')
ramagiri_df = util.pd.read_csv(f'{data_path}/Ramagiri dataset/3 clusters/clst3_stream_sediments_57F11.csv')

In [6]:
# Element lookup tables (Symbol, Name, crustal abundance), one per area.
# The Kodangal table is an Excel workbook; the Ramagiri table is a CSV.
kg_elem_name_df = util.pd.read_excel(f'{data_path}/Kodangal dataset/Kodangal_element_names.xlsx')
rg_elem_name_df = util.pd.read_csv(f'{data_path}/Ramagiri dataset/Ramagiri_element_names.csv')

# Gradio interface

In [7]:
# Element lists derived from each clustered dataset via util.elem_list
# (not referenced elsewhere in this section of the notebook).
rg_elems, kg_elems = util.elem_list(ramagiri_df), util.elem_list(kodangal_df)

In [8]:
# Preview the first rows of the Ramagiri element-name table.
rg_elem_name_df.head()

Unnamed: 0,Symbol,Name,Crustal Abundance
0,sio2,SiO2,
1,al2o3,Al2O3,
2,fe2o3,Fe2O3,
3,tio2,TiO2,
4,cao,CaO,


In [9]:
# Preview the first rows of the Kodangal element-name table.
kg_elem_name_df.head()

Unnamed: 0,Symbol,Name,Crustal Aubandance (PPM)
0,sio2,SiO2,
1,al2o3,Al2O3,
2,fe2o3,Fe2O3,
3,tio2,TiO2,
4,cao,CaO,


In [10]:
# Clustered sample data per study area, keyed by the names offered in the Area dropdowns.
DATASETS = dict(Ramagiri=ramagiri_df, Kodangal=kodangal_df)

In [12]:
# Create Gradio interface with dynamic element dropdown
# Gradio needs a fixed set of output components, so this many predictor
# textboxes are pre-created; unused ones stay hidden.
MAX_FEATURE_INPUTS = 40

with gr.Blocks() as interface:
    gr.Markdown("<h1 style='text-align: center;'><b>CMT GSC</b></h1>")
    gr.Markdown("<h2 style='text-align: center;'>Geochemical Data Prediction & Classification App</h2>")

    with gr.Tab('Visualization'):
        with gr.Tab('Element concentrations'):
            with gr.Row():
                area = gr.Dropdown(choices=["Ramagiri", "Kodangal"], label="Area")
                elements = gr.Dropdown(label="Element")

            area.change(fn=update_elements, inputs=[area], outputs=[elements])

            with gr.Row():
                ppm_button = gr.Button("Generate ppm concentration variation Plots")
                contour_button = gr.Button("Generate Contour Map")

            with gr.Row():
                contour_output = gr.Plot()

            with gr.Row():
                ppm_output = gr.Plot()

            ppm_button.click(fn=ppm_plot_interface, inputs=[elements, area], outputs=[ppm_output], scroll_to_output=True)
            contour_button.click(fn=contour_plot_interface, inputs=[elements, area], outputs=[contour_output], scroll_to_output=True)

        with gr.Tab('PCA & Maps'):
            with gr.Row():
                area = gr.Dropdown(choices=["Ramagiri", "Kodangal"], label="Area")

            with gr.Row():
                # TODO: pca_button and clst_button have no click handlers yet.
                pca_button = gr.Button('Generate PCA loadings')
                clst_button = gr.Button('Generate PC1 vs PC2')
                map_button = gr.Button('Generate map')

            plot_output = gr.Plot(elem_id='center-plot')

            map_button.click(fn=plot_cluster_map, inputs=[area], outputs=[plot_output], scroll_to_output=True)

    with gr.Tab('Prediction'):
        with gr.Tab('New point prediction'):
            with gr.Row():
                pred_area = gr.Dropdown(choices=["Ramagiri", "Kodangal"], label="Area")
                target_element = gr.Dropdown(label="Target Element")
                model = gr.Dropdown(choices=['Random Forest', 'XGBoost', 'CNN'], label="Select the model")

            pred_area.change(fn=update_elements, inputs=[pred_area], outputs=[target_element])

            with gr.Row():
                feature_inputs = [gr.Textbox(visible=False) for _ in range(MAX_FEATURE_INPUTS)]

                @target_element.change(inputs=[target_element, pred_area], outputs=feature_inputs)
                def set_feature_inputs(target_element, area):
                    """Show one labelled textbox per predictor of the selected target; hide the rest."""
                    # Nothing selected (or selection cleared): hide every box instead of
                    # asking for the features of a missing target.
                    if not target_element or not area:
                        return [gr.Textbox(visible=False) for _ in range(MAX_FEATURE_INPUTS)]

                    features = get_features_for_target(target_element, area)
                    if len(features) > MAX_FEATURE_INPUTS:
                        raise gr.Error(
                            f'{target_element} needs {len(features)} predictors, '
                            f'but only {MAX_FEATURE_INPUTS} input boxes are available.'
                        )

                    visible_boxes = [gr.Textbox(visible=True, label=feature, interactive=True) for feature in features]
                    hidden_boxes = [gr.Textbox(visible=False) for _ in range(MAX_FEATURE_INPUTS - len(features))]
                    return visible_boxes + hidden_boxes

            with gr.Row():
                predict_button = gr.Button("Predict")

            with gr.Row():
                prediction_output = gr.Textbox(label="Prediction Result")

            predict_button.click(fn=pred, inputs=[target_element, model, pred_area] + feature_inputs, outputs=prediction_output)

        with gr.Tab('Actual vs Predicted Contours'):
            with gr.Row():
                area = gr.Dropdown(choices=["Ramagiri", "Kodangal"], label="Area")
                target_element = gr.Dropdown(label="Target Element")
                model = gr.Dropdown(choices=['Random Forest', 'XGBoost', 'CNN'], label="Select the model")

            # Refresh this tab's own Target Element dropdown, not the Element
            # dropdown on the Visualization tab.
            area.change(fn=update_elements, inputs=[area], outputs=[target_element])

            with gr.Row():
                # TODO: no click handler is attached to this button yet.
                contour_button = gr.Button("Generate Contour Map")

            @gr.render(inputs=area)
            def area_plot(area):
                """Lay out the actual/predicted plot slots: side by side for Ramagiri, stacked for Kodangal."""
                # NOTE(review): these plots are created but nothing populates them yet.
                if area == 'Ramagiri':
                    with gr.Row():
                        act_plot = gr.Plot()
                        pred_plot = gr.Plot()
                elif area == 'Kodangal':
                    with gr.Row():
                        act_plot = gr.Plot()
                    with gr.Row():
                        pred_plot = gr.Plot()

interface.launch()

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


