In [1]:
import asyncio
import io
import os
import pickle
import zipfile
from urllib.parse import quote_plus

import dash
import dash_bootstrap_components as dbc
import numpy as np
import openai
import pandas as pd
import plotly.express as px
import plotly.figure_factory as ff
import plotly.graph_objs as go
import psycopg2
import requests
from dash import dcc
from dash import html
from dash import dash_table
from dash.dependencies import Input, Output, State
from openai import AsyncOpenAI
from sqlalchemy import create_engine

POSTGRES_PASSWORD = os.getenv('POSTGRES_PASSWORD')


# Read the OpenAI API key from the environment (None when the variable is unset).
key = os.getenv("OPENAI_API_KEY")
# The library-wide setting is `openai.api_key`; assigning `openai.key` only
# creates an unused attribute on the module and configures nothing.
openai.api_key = key

In [2]:
# SQLAlchemy engine for the `world` database on the `postgres` host.
# The password is percent-encoded before being placed in the URL so that
# reserved characters such as '@', ':' or '/' cannot corrupt the connection
# string. An unset POSTGRES_PASSWORD is treated as an empty password.
engine = create_engine('postgresql+psycopg2://{user}:{password}@{host}:{port}/{db}'.format(
    user = 'postgres',
    password = quote_plus(POSTGRES_PASSWORD or ''),
    host = 'postgres',
    port = 5432,
    db = 'world'
))

In [3]:
# Country names from the hdi table, sorted alphabetically by the query itself.
country_query = '''
SELECT country_name_hdi
FROM hdi
ORDER BY country_name_hdi'''

countries = pd.read_sql_query(country_query, engine)

# Dropdown options: each country name serves as both the label and the value.
countries_list = [
    dict(label=name, value=name)
    for name in countries['country_name_hdi']
]


In [4]:
# Stylesheets handed to the Dash app: the Bootstrap theme from
# dash-bootstrap-components. An earlier assignment of the codepen stylesheet
# (https://codepen.io/chriddyp/pen/bWLwgP.css) was overwritten immediately and
# never took effect, so it has been dropped.
external_stylesheets = [dbc.themes.BOOTSTRAP]

In [5]:
# Initialize the AsyncOpenAI client
# Shared asynchronous OpenAI client, authenticated with the key read from the environment.
client = AsyncOpenAI(api_key=key)

In [6]:
# HDI ranking (rank, country name, HDI value) ordered by rank.
# This frame is what the dashboard's HDI table displays.
query = (
    'SELECT rank, country_name_hdi, hdi_value\n'
    'FROM hdi\n'
    'ORDER BY rank'
)
df = pd.read_sql(sql=query, con=engine)
df

Unnamed: 0,rank,country_name_hdi,hdi_value
0,1,Switzerland,0.962
1,2,Norway,0.961
2,3,Iceland,0.959
3,5,Australia,0.951
4,6,Denmark,0.948
...,...,...,...
139,184,Burkina Faso,0.449
140,185,Mozambique,0.446
141,186,Mali,0.428
142,187,Burundi,0.426


In [7]:
import pickle

# Load the fitted HDI prediction model from disk.
# SECURITY: unpickling executes arbitrary code embedded in the file, so
# model.pkl must only ever come from a trusted source.
with open('model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

In [8]:
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)


def _graph_tab(label, graph_id):
    """Build a tab that holds a single graph whose figure is set by a callback."""
    return dcc.Tab(label=label, children=[
        html.Div([dcc.Graph(id=graph_id)])
    ])


# Side bar: the country selector.
sidebar = html.Div(
    [
        dcc.Markdown("Select a Country"),
        dcc.Dropdown(id='countries', options=countries_list, value='United States'),
    ],
    style={'width': '24%', 'float': 'left'},
)

# Trade tab: nested sub-tabs, one graph each for exports, imports and tariffs.
trade_tab = dcc.Tab(label='Trade Data', children=[
    dcc.Tabs(id='trade-tabs', children=[
        _graph_tab('Exports', 'exports-content'),
        _graph_tab('Imports', 'imports-content'),
        _graph_tab('Tariffs', 'tariff-content'),
    ]),
])

# HDI tab: prediction inputs on top, the HDI ranking table underneath.
hdi_tab = dcc.Tab(label='HDI', children=[
    html.Div(
        [
            dcc.Input(id='life_exp', type='number', placeholder='Life Expectancy'),
            dcc.Input(id='exp_years_of_school', type='number', placeholder='Expected Years of School'),
            dcc.Input(id='mean_years_of_school', type='number', placeholder='Mean Years of School'),
            dcc.Input(id='gni', type='number', placeholder='GNI'),
            html.Button('Predict HDI', id='predict_button'),
            html.Div(id='prediction_output'),
        ],
        # Dash style dictionaries use camelCased CSS property names.
        style={'marginBottom': '20px'},
    ),
    html.Div([
        dash_table.DataTable(
            id='country_hdi_table',
            columns=[
                {'name': 'Rank', 'id': 'rank'},
                {'name': 'Country', 'id': 'country_name_hdi'},
                {'name': 'HDI', 'id': 'hdi_value'},
            ],
            # Loaded once, when the layout is built, from the module-level `df`.
            data=df.to_dict('records'),
            sort_action='native',
            style_table={'overflowX': 'auto'},
        )
    ]),
])

# Chat tab: free-text prompt, send button and the response area.
chat_tab = dcc.Tab(label="Chat GPT", children=[
    html.Div(
        [
            dcc.Textarea(id='chat-input', placeholder='Select a country to get started...'),
            html.Button('Send', id='send-button'),
            html.Div(id='chat-output'),
        ],
        style={'width': '100%', 'display': 'block', 'clear': 'both'},
    )
])

# Location tab: a map for the selected country.
location_tab = dcc.Tab(label='Location', children=[
    dcc.Graph(id='country-map', style={'height': '1000px'})
])

app.layout = html.Div(
    [
        # Page headings
        html.H1("Understand the Global Economy"),
        html.H2("Data collected from the World Bank and Human Development"),
        html.H3("DS 6600: Data Engineering 1, UVA Data Science, Semester Project"),

        sidebar,

        # Main bar holding all tabs
        html.Div(
            [dcc.Tabs([trade_tab, hdi_tab, chat_tab, location_tab])],
            style={'width': '74%', 'float': 'right', 'backgroundColor': '#FAEBD7'},
        ),
    ]
)



# Latitude/longitude of every country in the hdi table, used to centre the map.
# Dedicated names are used on purpose: the layout built earlier in this cell
# reads the module-level `df` (the HDI ranking), so rebinding `df` here would
# feed coordinate rows into the HDI table the next time the cell is re-run.
coordinates_query = "SELECT country_name_hdi, latitude, longitude FROM hdi"
coordinates_df = pd.read_sql(coordinates_query, engine)

# Lookup: country name -> {'lat': ..., 'lon': ...}
country_coordinates = {
    name: {'lat': lat, 'lon': lon}
    for name, lat, lon in zip(
        coordinates_df['country_name_hdi'],
        coordinates_df['latitude'],
        coordinates_df['longitude'],
    )
}

@app.callback(
    Output('country-map', 'figure'),
    [Input('countries', 'value')]
)
def update_map(selected_country):
    """Draw a world map with a marker on the selected country.

    Args:
        selected_country: Current value of the 'countries' dropdown; None when
            the dropdown has been cleared.

    Returns:
        A Scattergeo figure centred on the country's coordinates, or
        ``dash.no_update`` (leaving the current map untouched) when no country
        is selected or the country has no entry in ``country_coordinates``.
    """
    # A cleared dropdown sends None and a name may be absent from the lookup;
    # direct indexing would raise KeyError in both cases.
    coordinates = country_coordinates.get(selected_country)
    if coordinates is None:
        return dash.no_update

    lat = coordinates['lat']
    lon = coordinates['lon']
    return go.Figure(
        data=[go.Scattergeo(
            lat=[lat],
            lon=[lon],
            mode='markers'
        )],
        layout=go.Layout(
            geo=dict(
                scope='world',
                showland=True,
                landcolor="rgb(217, 217, 217)",
                subunitcolor="rgb(217, 217, 217)",
                countrycolor="rgb(217, 217, 217)",
                showcountries=True,
                center={'lat': lat, 'lon': lon},
                projection={'type': 'mercator'}
            )
        )
    )

#################################

# Callback to update the chat-input placeholder based on selected country
@app.callback(
    Output('chat-input', 'placeholder'),
    [Input('countries', 'value')]
)
def update_chat_placeholder(selected_country):
    """Return the chat box placeholder text for the current dropdown selection.

    A falsy selection (nothing chosen) yields the generic prompt; otherwise the
    placeholder suggests a question about the selected country.
    """
    if not selected_country:
        return "Select a country to get started..."
    return f"Provide information on {selected_country} regarding their imports, exports, tariffs, and GDP over the last 10 years"
#################

import threading

# A single event loop, served by one daemon thread, is shared by every call.
# Starting a fresh loop and a non-daemon thread per call (each running forever)
# leaks one thread per request and keeps the interpreter from exiting.
_background_loop = None
_background_loop_lock = threading.Lock()


def _get_background_loop():
    """Return the shared background event loop, starting its thread on first use."""
    global _background_loop
    with _background_loop_lock:
        if _background_loop is None or _background_loop.is_closed():
            loop = asyncio.new_event_loop()

            def _serve():
                asyncio.set_event_loop(loop)
                loop.run_forever()

            # daemon=True so the serving thread never blocks process shutdown.
            threading.Thread(target=_serve, name='dash-async-loop', daemon=True).start()
            _background_loop = loop
        return _background_loop


def run_async(coroutine):
    """Schedule ``coroutine`` on the shared background event loop.

    Args:
        coroutine: The coroutine object to run.

    Returns:
        A ``concurrent.futures.Future``; call ``.result()`` on it to block
        until the coroutine finishes and obtain its return value (or re-raise
        its exception).
    """
    return asyncio.run_coroutine_threadsafe(coroutine, _get_background_loop())

async def collect_stream_data(user_query):
    """Send ``user_query`` to the chat completions API and return the full reply.

    The reply is requested as a stream and its text fragments are joined into
    one string.

    Args:
        user_query: The user's message, sent as a single 'user' role message.

    Returns:
        The concatenated text of all streamed fragments ('' if none arrived).
    """
    stream = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": user_query}],
        stream=True,
    )

    fragments = []
    async for chunk in stream:
        # Skip chunks that carry no choices; indexing [0] on them would raise.
        if not chunk.choices:
            continue
        text = chunk.choices[0].delta.content
        if text is not None:
            fragments.append(text)

    return "".join(fragments)

@app.callback(
    Output('chat-output', 'children'),
    [Input('send-button', 'n_clicks')],
    [State('chat-input', 'value')]
)
def update_output(n_clicks, user_query):
    """Answer the chat prompt when the Send button is clicked.

    Args:
        n_clicks: Click count of the send button; None before the first click.
        user_query: Current text of the chat input.

    Returns:
        A status string when there is nothing to send, an error paragraph when
        the OpenAI request fails, otherwise a Div showing the user's query and
        the model's reply.
    """
    if n_clicks is None:
        return "Awaiting input..."
    if user_query is None or user_query.strip() == "":
        return "Please enter a query."

    try:
        # Blocks this callback until the streamed reply has been fully collected.
        response_content = run_async(collect_stream_data(user_query)).result()
    except openai.OpenAIError as error:
        # Show API failures (bad key, rate limit, network) to the user instead
        # of letting the callback fail with an unhandled exception.
        return html.P(f"OpenAI request failed: {error}")

    return html.Div([
        html.P(f"User: {user_query}"),
        html.P(f"OpenAI: {response_content}")
    ])

############################################
@app.callback(
    Output('prediction_output', 'children'),
    [Input('predict_button', 'n_clicks')],
    [State('life_exp', 'value'), 
     State('exp_years_of_school', 'value'), 
     State('mean_years_of_school', 'value'), 
     State('gni', 'value')]
)
def predict_hdi(n_clicks, life_exp, exp_years_of_school, mean_years_of_school, gni):
    """Predict HDI from the four numeric inputs using the loaded model.

    Returns a prompt string until the button has been clicked and every input
    holds a value; otherwise returns the model's prediction as text.
    """
    features = {
        'life_exp': life_exp,
        'exp_years_of_school': exp_years_of_school,
        'mean_years_of_school': mean_years_of_school,
        'gni': gni,
    }
    if n_clicks is None or any(value is None for value in features.values()):
        return 'Enter all values and click Predict'

    # Single-row frame whose columns are the four feature names, in this order.
    input_data = pd.DataFrame([list(features.values())], columns=list(features))
    prediction = model.predict(input_data)[0]
    return f'Predicted HDI: {prediction}'



###########################

def normalize_column(column):
    """Min-max scale a pandas Series into the range [0, 1].

    Args:
        column: Series of numeric values.

    Returns:
        A Series where the minimum maps to 0 and the maximum to 1. When every
        value is identical (including a single-row Series) the range is zero,
        so all values map to 0.0 rather than the NaN that 0/0 would give.
    """
    low = column.min()
    span = column.max() - low
    if span == 0:
        return (column - low).astype(float)
    return (column - low) / span

def _trade_indicator_figure(selected_country, column, label):
    """Build the one-element figure list shared by the three trade callbacks.

    Plots the min-max normalized indicator next to normalized GDP for the ten
    most recent years stored for the country.

    Args:
        selected_country: Value of the 'countries' dropdown (None if cleared).
        column: Name of the year_country column to plot. It is interpolated
            into the SQL text, so it must be a trusted constant and never user
            input.
        label: Display name used for the trace and the chart title.

    Returns:
        ``[figure]``: a list, because each callback declares its outputs as a
        one-element list. With no country selected the figure is an empty
        chart titled 'Please select a country.'.
    """
    if selected_country is None:
        placeholder = go.Figure()
        placeholder.update_layout(title='Please select a country.')
        return [placeholder]

    # The country arrives from the browser, so it is passed as a bound
    # parameter (psycopg2 '%(name)s' style) rather than formatted into the SQL.
    query = f"""
        SELECT year, {column}, gdp
        FROM year_country
        WHERE country_name_exp = %(country)s
        ORDER BY year DESC
        LIMIT 10
        """
    df = pd.read_sql(query, engine, params={'country': selected_country})

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df['year'],
        y=normalize_column(df[column]),
        mode='lines+markers',
        name=label
    ))
    fig.add_trace(go.Scatter(
        x=df['year'],
        y=normalize_column(df['gdp']),
        mode='lines+markers',
        name='GDP'
    ))
    fig.update_layout(
        title=f'{label} and GDP for {selected_country} (Last 10 Years)',
        xaxis_title='Year',
        yaxis_title='Value (USD | Normalized)',
        legend_title='Indicator',
        hovermode='closest'
    )
    return [fig]


@app.callback(
    ([Output(component_id = 'exports-content', component_property = 'figure')]),
    [Input(component_id ='countries', component_property = 'value')]
)
def update_exports_tab(selected_country):
    """Refresh the Exports chart for the selected country."""
    return _trade_indicator_figure(selected_country, 'exports', 'Exports')


##########################

@app.callback(
    ([Output(component_id = 'imports-content', component_property = 'figure')]),
    [Input(component_id ='countries', component_property = 'value')]
)
def update_imports_tab(selected_country):
    """Refresh the Imports chart for the selected country."""
    return _trade_indicator_figure(selected_country, 'imports', 'Imports')


##################

@app.callback(
    ([Output(component_id = 'tariff-content', component_property = 'figure')]),
    [Input(component_id ='countries', component_property = 'value')]
)
def update_tariffs_tab(selected_country):
    """Refresh the Tariffs chart for the selected country."""
    return _trade_indicator_figure(selected_country, 'tariff_rate', 'Tariffs')


# Start the Dash server when this code runs as the main program.
# NOTE(review): `mode='external'` looks like a JupyterDash-style argument; confirm
# that the installed Dash version accepts it on `dash.Dash.run_server`.
# NOTE(review): debug=True together with host 0.0.0.0 exposes the debug tooling
# on every network interface; confirm this is only used in a trusted setup.
if __name__ == "__main__":
    app.run_server(host="0.0.0.0", port=8050, debug=True, mode='external')