In [32]:
from jupyter_dash import JupyterDash
import dash
import dash_bootstrap_components as dbc
import dash_core_components as dcc
import dash_vtk
import dash_html_components as html
from dash.dependencies import Input, Output, State
from dash import callback_context
import plotly.express as px
import numpy as np
import geojson
import pandas as pd
import numpy as np
import pickle
import xgboost as xgb
import heapq
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Load the GeoJSON feature collection; its features are later indexed by
# properties.NAME and drawn on both maps.
geojson_path = 'map/mo.geojson'
with open(geojson_path, encoding='utf-8') as geojson_file:
    gj = geojson.load(geojson_file)

In [3]:
# Preprocessed data for the regression model.
# Used below only to derive the min/max/mean bounds of the UI controls
# (build_year, full_sq, life_sq, kitch_sq, floor, max_floor).
df_regression = pd.read_csv('data/regression_data.csv')

In [4]:
# Preprocessed data for the multiclass classification model.
# Used below only to derive the bounds of the price slider (price_doc).
df_classification = pd.read_csv('data/classification_data.csv')

In [5]:
# Mapping between district names in the GeoJSON and in the training dataset.
# Columns read by this notebook: `json_areas` (compared against the district
# names selected on the map) and `label` (value written to the `sub_area`
# feature / matched against the classifier's class index).
areas_df = pd.read_csv('data/areas_df.csv')

In [6]:
# Bounds for the parameters: (min, max, mean) of each feature, truncated to int,
# used as the range and default value of the corresponding UI control.
def _int_bounds(column):
    """Return the column's minimum, maximum and mean, each truncated to int."""
    return int(np.min(column)), int(np.max(column)), int(np.mean(column))


min_year, max_year, mean_year = _int_bounds(df_regression.build_year)
min_fullsq, max_fullsq, mean_fullsq = _int_bounds(df_regression.full_sq)
min_life_sq, max_life_sq, mean_life_sq = _int_bounds(df_regression.life_sq)
min_kitch_sq, max_kitch_sq, mean_kitch_sq = _int_bounds(df_regression.kitch_sq)
min_floor, max_floor, mean_floor = _int_bounds(df_regression.floor)
min_max_floor, max_max_floor, mean_max_floor = _int_bounds(df_regression.max_floor)
min_price, max_price, mean_price = _int_bounds(df_classification.price_doc)

In [33]:
# Load the models.
# model_reg is called with an xgb.DMatrix (price regression); model_cl is
# called through predict_proba (district classification).
# NOTE: pickle.load can execute arbitrary code embedded in the file — only
# load model files that come from a trusted source.
with open('data/models/model.pkl', "rb") as model_file:
    model_reg = pickle.load(model_file)
with open('data/models/model_cl.pkl', "rb") as model_file:
    model_cl = pickle.load(model_file)

In [8]:
# Index every GeoJSON feature by its district name (properties.NAME).
district_lookup = dict(
    (feature['properties']['NAME'], feature) for feature in gj['features']
)

In [9]:
# One-column frame of district names; the choropleth calls match its "NAME"
# column against properties.NAME of the GeoJSON features.
district_names = [feature['properties']['NAME'] for feature in gj['features']]
df = pd.DataFrame(district_names, columns=['NAME'])

In [10]:
def get_highlights(selections, geojson=gj, district_lookup=district_lookup):
    """Return a copy of ``geojson`` restricted to the selected districts.

    Every top-level key is carried over unchanged except ``'features'``,
    which is replaced by the features looked up in ``district_lookup`` for
    each name in ``selections`` (in iteration order).

    Raises KeyError if a selected name is missing from ``district_lookup``.
    """
    return {
        key: ([district_lookup[name] for name in selections]
              if key == 'features' else value)
        for key, value in geojson.items()
    }

In [11]:
def get_figure(selections):
    """Build the price-mode map: all districts, with the selected ones on top.

    ``selections`` is a sized iterable of district names (keys of
    ``district_lookup``). The base layer is drawn half-transparent; when at
    least one district is selected, a fully opaque layer containing only
    those districts is added as a second trace.
    """
    match_by_name = dict(locations="NAME", featureidkey="properties.NAME")

    figure = px.choropleth_mapbox(df, geojson=gj, opacity=0.5, height=600,
                                  **match_by_name)

    if len(selections):
        overlay = px.choropleth_mapbox(df, geojson=get_highlights(selections),
                                       opacity=1, **match_by_name)
        figure.add_trace(overlay.data[0])

    # uirevision keeps the user's pan/zoom when the figure is re-sent.
    figure.update_layout(
        mapbox_style="carto-positron",
        mapbox_zoom=9,
        mapbox_center={"lat": 55.753995, "lon": 37.614069},
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
        uirevision='constant',
    )
    return figure

In [12]:
# District names plus a "color" column (initially 0) for the district-mode map.
areas_color_df = df.assign(color=0)

In [13]:
def get_figure2(datafig):
    """Build the district-mode map, coloring the predicted districts.

    Parameters
    ----------
    datafig : dict
        ``{'areas': [...], 'probs': [...]}`` — district names (keys of
        ``district_lookup``) and the value used to color each one, paired by
        position. If a name occurs more than once, its last value wins.

    Returns
    -------
    The plotly figure: a half-transparent base layer with every district and,
    when ``areas`` is non-empty, an opaque layer with only those districts
    colored by their value.
    """
    areas = datafig['areas']
    probs = datafig['probs']

    fig = px.choropleth_mapbox(areas_color_df, geojson=gj,
                               locations="NAME",
                               featureidkey="properties.NAME",
                               opacity=0.5, height=600)

    if len(areas) > 0:
        highlights = get_highlights(areas)

        # Work on a per-call frame instead of writing into the module-level
        # areas_color_df: the shared frame is read by every callback
        # invocation, and assigning floats into its int column relies on
        # implicit upcasting.
        prob_by_area = dict(zip(areas, probs))
        colored = areas_color_df.assign(
            color=areas_color_df['NAME'].map(prob_by_area).fillna(0.0)
        )

        fig.add_trace(
            px.choropleth_mapbox(colored, geojson=highlights,
                                 color='color',
                                 range_color=(colored['color'].min(), colored['color'].max()),
                                 locations="NAME",
                                 featureidkey="properties.NAME",
                                 opacity=1).data[0]
        )

    # uirevision keeps the user's pan/zoom when the figure is re-sent.
    fig.update_layout(mapbox_style="carto-positron",
                      mapbox_zoom=9,
                      mapbox_center={"lat": 55.753995, "lon": 37.614069},
                      margin={"r": 0, "t": 0, "l": 0, "b": 0},
                      uirevision='constant')

    return fig

In [26]:
def Header(name, app):
    """Build the page header row: title on the left, linked logo on the right.

    ``name`` is the title text; ``app`` is the Dash app, used to resolve the
    URL of the ``logo_sber.jpeg`` asset.
    """
    logo_link = html.A(
        html.Img(
            src=app.get_asset_url("logo_sber.jpeg"),
            style={"float": "right", "height": 60},
        ),
        href="https://www.sberbank.ru/",
    )
    columns = [dbc.Col(html.H1(name), md=8), dbc.Col(logo_link, md=4)]
    return dbc.Row(columns, align="center")


In [27]:
def money_format(money):
    """Format a number as roubles with thousands separators, e.g. ₽1,234,567."""
    return f'₽{money:,}'

In [28]:
# Initial setup for rendering the districts: the figures handed to the two
# dcc.Graph components when the layout is built.
first_five_districts = list(district_lookup)[:5]

selections = list(first_five_districts)
fig = get_figure(selections)

areas_test = list(first_five_districts)
probs_test = [0.3, 0.4, 0.01, 0.005, 0.9]
fig2 = get_figure2(dict(areas=areas_test, probs=probs_test))

In [29]:
# dash_vtk view component.
# NOTE(review): not referenced by the layout or callbacks visible in this
# notebook — confirm whether it is still needed.
vtk_view = dash_vtk.View(id="vtk-view")

In [30]:
# Structure of the objects fed to the models for prediction.
# Key order matters: the prediction frame is built from these dicts, so the
# keys become its columns in this order. The values are placeholders that the
# prediction callback overwrites.
_closely_flags = dict.fromkeys(
    [
        'kindergarten_closely',
        'railroad_closely',
        'public_transport_station_closely',
        'metro_closely',
        'school_closely',
        'park_closely',
        'water_closely',
    ],
    0,
)

# Template for the price regression model.
dict_for_reg = {
    'build_year': 2003,
    'sub_area': 50,
    'full_sq': 50,
    'kitch_sq': np.nan,
    'life_sq': np.nan,
    'num_room': 5,
    'floor': 12,
    'max_floor': 22,
    'month': 5,
    'dow': 3,
    'year': 2015,
    'rel_floor': 0,
    'rel_kitch_sq': 0,
    'state': 0,
    **_closely_flags,
}

# Template for the district classification model.
dict_for_cl = {
    'build_year': 2003,
    'full_sq': 50,
    'kitch_sq': np.nan,
    'life_sq': np.nan,
    'num_room': 5,
    'floor': 12,
    'max_floor': 22,
    'rel_floor': 0,
    'rel_kitch_sq': 0,
    'state': 0,
    **_closely_flags,
    'price_doc': 6000000,
}

In [35]:
# Default state: the district pre-selected on the price map, and the payload
# shown on the district map until a prediction has been cached.
selections = {'Киевский'}
datafig = dict(areas=['Киевский'], probs=[1])

# Left-hand control panel: every input read by the prediction callback.
controls = [
    
    # Target price slider. Its value is the price_doc feature of the district
    # classifier. Hidden initially; show_hide_element toggles this container.
    html.Div(id='slider-container', children=
        [
            dbc.Label("Цена недвижимости"),
            dcc.Slider(
                min=min_price, max=max_price, value=mean_price, step=100000, id="price", marks=None,
                tooltip={"always_visible": True})

        ],
            style= {'display': 'none'}), 
    # District multi-select used for price prediction. Despite the id
    # ('price-container') it holds the district dropdown, which update_figure
    # keeps in sync with clicks on the map.
    html.Div(id='price-container', children=
        [
            dbc.Label("Районы Москвы"),
            html.Div(dcc.Dropdown(list(district_lookup.keys()), ['Киевский'], multi=True,  id='districts')),
        ],
        style= {'display': 'block'}
    ),
    # Year the building was built (bounds and default from the regression data).
    dbc.Row(
        [
            dbc.Label("Год постройки здания"),
            dcc.Slider(
                min=min_year, max=max_year, value=mean_year, step=1, id="builtyear", marks=None, tooltip={"always_visible": True},
            ),
        ]
    ),
    # Total area.
    dbc.Row(
        [
            dbc.Label("Площадь недвижимости"),
            dcc.Slider(
                min=min_fullsq, max=max_fullsq, value=mean_fullsq, step=1, id="full_sq", marks=None, tooltip={"always_visible": True},
            ),
        ]
    ),
    # Living area and kitchen area, side by side.
    dbc.Row(
        [
            
            dbc.Col([dbc.Label("Жилая площадь"),
                     dcc.Slider(
                min=min_life_sq, max=max_life_sq, value=mean_life_sq, step=1, id="life_sq", marks=None, tooltip={"always_visible": True},
            )]),
            dbc.Col([dbc.Label("Площадь кухни"),
                     dcc.Slider(
                min=min_kitch_sq, max=max_kitch_sq, value=mean_kitch_sq, step=1, id="kch_sq", marks=None, tooltip={"always_visible": True},
            )])
        ]
    ),
     # Number of rooms (options 1..8, default 2), floor, and floors in the building.
     dbc.Row(
        [dbc.Col([
            dbc.Label("Кол-во комнат"),
            html.Div(dcc.Dropdown(np.arange(1,9), 2,  id='num_rooms'))]),
         dbc.Col([dbc.Label("Этаж"),
                  dcc.Input(id="floor", type="number", value=mean_floor, min=min_floor, max=max_floor)
                ]),
         dbc.Col([dbc.Label("Этажей в доме"),
                  dcc.Input(id="max_floor", type="number", value=mean_max_floor, min=min_max_floor, max=max_max_floor)
                ])
        ]
     ),
    # Apartment condition, 1..5.
    dbc.Row(
        [
            dbc.Label("Состояние квартиры"),
            dcc.Slider(
                min=1, max=5, value=3, step=1, id="state", tooltip={"always_visible": True},
            ),
        ]
    ),
    # Nearby amenities. Each entry below is "feature_name;label": the part
    # before ';' becomes the checklist value (a model feature name), the part
    # after it the displayed label.
    dbc.Row(
        [
            dbc.Label("Близко находится:"),
            dbc.Checklist(
                options=[
                    {"label": x.split(";")[1], "value": x.split(";")[0]}
                    for x in ['kindergarten_closely;Детский сад',
                              'railroad_closely;Железнодорожная станция',
                              'public_transport_station_closely;Остановка общественного транспорта',
                              'metro_closely;Станция метро',
                              'school_closely;Школа',
                              'park_closely;Парк',
                              'water_closely;Водохранилища / Реки']
                ],
                value=["kindergarten_closely", "railroad_closely"],
                id="params-enabled",
                inline=True,
            ),
        ]
    ),
    # Predict button; predict_price only computes when this button is the trigger.
    dbc.Row(
        [
            html.Button('Предсказать', id='predict', n_clicks=0),
            html.Div(id='predict-show')
        ]
    ),
    
]


# Mode switch: 'Цена' predicts a price for the selected districts,
# 'Район' predicts the most likely districts for a given price.
action_options = [
    dict(label='Предсказать цену на недвижимость', value='Цена'),
    dict(label='Предсказать местоположения недвижимости', value='Район'),
]

radioactions = [
    dbc.RadioItems(id='radioaction', options=action_options, value='Цена'),
]

# The two map panels; show_hide_element displays one of them per mode.
price_map_panel = html.Div(
    id='price-predict',
    children=[
        dbc.Label("Предсказание цен на недвижимость"),
        dcc.Graph(id='choropleth', figure=fig),
    ],
)

areas_map_panel = html.Div(
    id='areas-predict',
    children=[
        dbc.Label("Предсказание районов"),
        dcc.Graph(id='areasmap', figure=fig2),
    ],
)

maps = [price_map_panel, areas_map_panel]




# Dash application (Bootstrap theme); `server` keeps a handle on the app's
# underlying server object.
app = JupyterDash(__name__,external_stylesheets=[dbc.themes.BOOTSTRAP])
server = app.server 

# Page layout: header, then a 4/8 split — mode switch and controls on the
# left, maps and the prediction text on the right.
app.layout = dbc.Container(
    fluid=True,
    style={"height": "100vh"},
    children=[
        Header("Предсказание цен на недвижимость и местоположения районов", app),
        html.Hr(),
        dbc.Row(
            [                
                # Left column: mode switch + input controls.
                dbc.Col(
                    width=4,
                    children=dbc.Card(
                        [dbc.CardHeader("Actions"), dbc.CardBody(radioactions),
                         dbc.CardHeader("Controls"), dbc.CardBody(controls)
                         ]
                    ),
                ),
                # Right column: the maps and the Markdown element that
                # predict_price fills with the prediction text.
                dbc.Col(
                    width=8,
                    children=dbc.Card(
                        [
                            dbc.CardHeader("Карта Москвы и Московской области"),
                            dbc.CardBody(maps),
                            dcc.Markdown(id='predicted-areas', children=[f'''
                            ## Привет!
                            '''], 
                                         style={'marginLeft': 10, 'marginRight': 10, 'marginTop': 30, 'marginBottom': 10})
                        ],
                        style={"height": "77vh"},
                    ),
                ),
            ],
        # 'cache' stores the latest district prediction ({'areas', 'probs'})
        # written by predict_price and read by update_figure_for_areas.
        ), dcc.Store(id='cache')
    ],
)


def _build_feature_frame(template, overrides):
    """Return a one-row DataFrame of model features.

    ``template`` fixes the column order; ``overrides`` supplies the current
    values for keys that already exist in the template. The template itself
    is not modified. The derived columns ``rel_floor`` and ``rel_kitch_sq``
    are recomputed from the final values.
    """
    features = {**template, **overrides}
    frame = pd.DataFrame(features.values(), index=features.keys()).T
    frame['rel_floor'] = frame['floor'] / frame['max_floor']
    frame['rel_kitch_sq'] = frame['kitch_sq'] / frame['full_sq']
    return frame


@app.callback(
    [Output(component_id='predicted-areas', component_property='children'),
     Output(component_id='cache', component_property='data')],
    [Input(component_id='radioaction', component_property='value'),
     Input(component_id='predict', component_property='n_clicks'),
     Input(component_id='price', component_property='value'),
     Input(component_id='districts', component_property='value'),
     Input(component_id='builtyear', component_property='value'),
     Input(component_id='full_sq', component_property='value'),
     Input(component_id='life_sq', component_property='value'),
     Input(component_id='kch_sq', component_property='value'),
     Input(component_id='num_rooms', component_property='value'),
     Input(component_id='floor', component_property='value'),
     Input(component_id='max_floor', component_property='value'),
     Input(component_id='state', component_property='value'),
     Input(component_id='params-enabled', component_property='value'),
     ]
)
def predict_price(action, n_clicks, price, district, builtyear, fullsq, lifesq, kchsq, numrooms, floor, maxfloor, state, params):
    """Run the model for the active mode when the predict button is pressed.

    Parameters mirror the control values in the order of the Inputs above:
    ``action`` is the mode ('Цена' or 'Район'), ``district`` the list of
    selected district names, ``params`` the list of enabled ``*_closely``
    flags; the rest are the numeric apartment features.

    Returns ``(markdown_text, cache_data)``:

    * 'Цена': the predicted price, averaged over the selected districts, and
      ``no_update`` for the cache. With no district selected a hint is shown
      instead of raising.
    * 'Район': the five most probable districts as text, and
      ``{'areas': [...], 'probs': [...]}`` for the cache.
    * Any trigger other than the predict button: ``no_update`` for both.

    NOTE(review): every control is registered as an Input, so this callback
    runs on each control change and exits early. Registering them as State
    would avoid those round-trips but changes the argument order.
    """
    changed_id = [p['prop_id'] for p in dash.callback_context.triggered][0]
    if 'predict' not in changed_id:
        return dash.no_update, dash.no_update

    all_params = ['kindergarten_closely',
                  'railroad_closely',
                  'public_transport_station_closely',
                  'metro_closely',
                  'school_closely',
                  'park_closely',
                  'water_closely']
    enabled = params or []

    # Values shared by both models; the module-level templates are only read
    # so that concurrent or repeated invocations never see each other's input.
    shared_values = {
        'build_year': builtyear,
        'full_sq': fullsq,
        'kitch_sq': kchsq,
        'life_sq': lifesq,
        'num_room': numrooms,
        'floor': floor,
        'max_floor': maxfloor,
        'state': state,
    }
    shared_values.update({p: 1 if p in enabled else 0 for p in all_params})

    if action == 'Цена':
        if not district:
            # Previously this path fell through to an unbound result variable.
            return '## Выберите хотя бы один район', dash.no_update

        now_date = datetime.now()
        data_for_predict = _build_feature_frame(
            dict_for_reg,
            {**shared_values,
             'year': now_date.year,
             'month': now_date.month,
             'dow': now_date.weekday()},
        )

        # One prediction per selected district, then the mean.
        # np.exp: presumably the model predicts log(price) — confirm against training.
        prices = []
        for name in district:
            # Raises IndexError if the district has no row in areas_df.
            data_for_predict['sub_area'] = areas_df[areas_df.json_areas == name]['label'].values[0]
            dmatrix = xgb.DMatrix(data_for_predict,
                                  feature_names=data_for_predict.columns.values)
            prices.append(np.exp(model_reg.predict(dmatrix)))

        res = '''
                ## Прогноз цены: _{}_
                '''.format(money_format(int(np.mean(prices))))

        return res, dash.no_update

    if action == 'Район':
        data_for_predict = _build_feature_frame(
            dict_for_cl, {**shared_values, 'price_doc': price}
        )

        f_el = model_cl.predict_proba(data_for_predict)[0]

        # Assumes class index i of predict_proba corresponds to
        # areas_df.label == i — TODO confirm against the label encoding.
        dict_probs = {idx: el for idx, el in enumerate(f_el)}
        top5_labels = heapq.nlargest(5, dict_probs, key=dict_probs.get)
        # Plain floats so the cache payload is JSON-serializable.
        top5_probs = [float(dict_probs[i]) for i in top5_labels]
        top5_sub_areas = [areas_df[areas_df.label == x]['json_areas'].values[0] for x in top5_labels]

        res = '''
                ## Прогноз районов: 
                ### _{}_
                '''.format(', '.join(top5_sub_areas))

        return res, {'areas': top5_sub_areas, 'probs': top5_probs}

    return dash.no_update, dash.no_update

        
            

    
@app.callback(
    [
        Output(component_id='slider-container', component_property='style'),
        Output(component_id='price-container', component_property='style'),
        Output(component_id='price-predict', component_property='style'),
        Output(component_id='areas-predict', component_property='style'),
    ],
    [Input(component_id='radioaction', component_property='value')],
)
def show_hide_element(visibility_state):
    """Show the controls and the map that belong to the selected mode.

    Returns the styles for, in order: the price slider container, the
    district dropdown container, the price map and the district map.
    'Район' shows the slider and the district map; 'Цена' shows the dropdown
    and the price map. Any other value returns None.
    """
    shown = {'display': 'block'}
    hidden = {'display': 'none'}

    if visibility_state == 'Район':
        return shown, hidden, hidden, shown
    if visibility_state == 'Цена':
        return hidden, shown, shown, hidden
    return None

    

@app.callback(
   [Output('choropleth', 'figure'),
    Output('districts', 'value')],
    [Input('choropleth', 'clickData'),
     Input('districts', 'value'),
     Input('radioaction', 'value'),
     Input(component_id='predict', component_property='n_clicks')])
def update_figure(clickData, value, action, n_clicks):
    """Keep the price map and the district dropdown in sync.

    ``value`` (the dropdown selection) is the source of truth. A click on the
    map toggles the clicked district, but only when the click itself
    triggered this callback. ``clickData`` keeps its last value, so without
    that check any other trigger (for example switching the mode) would
    toggle the last clicked district again.

    The selection is derived per call rather than kept in a module-level
    set, so sessions do not share state. ``action`` and ``n_clicks`` are
    unused; they only make the callback fire on those inputs.

    Returns ``(figure, selected_district_names)``.
    """
    # Preserve the dropdown order and drop accidental duplicates.
    chosen = list(dict.fromkeys(value or []))

    triggered_ids = [p['prop_id'] for p in dash.callback_context.triggered]
    clicked_on_map = any(prop_id.startswith('choropleth.') for prop_id in triggered_ids)

    if clicked_on_map and clickData is not None:
        location = clickData['points'][0]['location']
        if location in chosen:
            chosen.remove(location)
        else:
            chosen.append(location)

    return get_figure(chosen), chosen


@app.callback(
   Output('areasmap', 'figure'),
    [Input('radioaction', 'value'),
     Input(component_id='predict', component_property='n_clicks'),
     Input('cache', 'data')
     ])
def update_figure_for_areas(action, n_clicks, cache):
    """Redraw the district map from the cached prediction.

    ``cache`` is the ``{'areas': [...], 'probs': [...]}`` payload stored by
    the prediction callback; until one exists the module-level default
    ``datafig`` is drawn. ``action`` and ``n_clicks`` are unused; they only
    make the callback fire on those inputs. The previous version branched on
    them, but every branch returned the same figure.
    """
    return get_figure2(datafig if cache is None else cache)




# Start the app from the notebook. debug=True is meant for development —
# NOTE(review): switch it off before exposing the app to other users.
app.run_server(debug=True)

Dash app running on http://127.0.0.1:8050/
