In [3]:
from sklearn.neighbors import KernelDensity
from shapely.geometry import Point, Polygon
from transformers import DistilBertTokenizerFast
import matplotlib.pyplot as plt
import geopandas
# from sklearn.cluster import DBSCAN
import pandas as pd
import torch
import numpy as np
from scipy.spatial import ConvexHull
from api.model import GeoModel

In [4]:
import folium
from folium.plugins import HeatMap

In [5]:
# Checkpoint file handed to GeoModel as `model_string`. The path has no leading
# slash, so it is resolved relative to the notebook's working directory.
BASE_MODEL = (
    'api/model/'
    'checkpoints_2021-12-17_model-distilbert-base-uncased_loss-haversine-combined-data_model.pt'
)
# Tokenizer name handed to GeoModel as `tokenizer_string`
# (presumably a Hugging Face model id — confirm against GeoModel).
TOKEN_MODEL = 'distilbert-base-uncased'

In [6]:
# Build the project's GeoModel wrapper from the checkpoint and tokenizer named
# in the configuration cell. All arguments are passed by keyword, so their
# order is irrelevant.
# NOTE(review): max_seq_length=200 is presumably a token limit — confirm in GeoModel.
geobert = GeoModel(
    tokenizer_class=DistilBertTokenizerFast,
    tokenizer_string=TOKEN_MODEL,
    model_string=BASE_MODEL,
    max_seq_length=200,
)

In [60]:
# Sample news article used as the model input for the rest of the notebook.
text = '''
Germany has summoned the Russian ambassador and expelled two Russian diplomats after a court ruled on Wednesday that Moscow had ordered the 2019 murder of a Georgian citizen in Berlin, German Foreign Minister Annalena Baerbock said.

The court found Russian citizen Vadim Krasikov guilty of the August 2019 murder of Tornike Khangoshvili, an ethnic Chechen of Georgian nationality, in a central Berlin park, and sentenced him to life imprisonment.
The court said his crime had been ordered by the Russian government.

"This murder, ordered by a state as the court found today, constitutes a severe breach of German law and the sovereignty of the Federal Republic of Germany. Therefore we have just summoned the Russian ambassador for a talk," Baerbock said in a statement.
"The Russian ambassador was notified that two members of the diplomatic personnel will be declared persona non grata," she added.
Later that day, Russian Foreign Ministry spokeswoman Maria Zakharova addressed the expulsion and summons of the ambassador.

"Berlin's unfriendly actions will not remain without an adequate response. A statement on this matter will be made in the near future," she wrote on her Telegram channel.
The judge had earlier emphasized the connection of the Russian state with Khangoshvili's murder.
"In June 2019 at the latest, state organs of the central government of the Russian Federation took the decision to liquidate Tornike Khangoshvili in Berlin," the judge said.
'''
# Word count of the article. split() with no argument splits on any run of
# whitespace; split(" ") would split on the space character only, so the last
# word of one line and the first word of the next would be counted as one.
# NOTE(review): this counts words, not tokenizer tokens, so it is only a rough
# guide to whether the text fits the model's max_seq_length.
len(text.split())


231

In [61]:
# Run the article through the model. The return value is discarded here and a
# later cell reads `geobert.output`, so forward() presumably stores its result
# on the wrapper — TODO confirm in GeoModel.
geobert.forward(text=text)
# Single-point location estimate; it is later passed straight to folium as a
# map location, i.e. it is treated as [lat, lon].
point_prediction = geobert.predict_point()
# NOTE(review): the radians-to-degrees conversion below is disabled, so the
# prediction is assumed to already be in degrees — confirm against GeoModel.
# point_prediction = np.rad2deg(point_prediction)
point_prediction

array([53.14446 , 25.925411], dtype=float32)

In [62]:
# Leave-one-dimension-out ablation of the classification head.
# For every dimension of the pooled vector, zero that single dimension, push
# the result through pre_classifier -> ReLU -> classifier, and record the
# predicted location. The spread of these predictions is used below as a rough
# uncertainty footprint around the point estimate.
last_hidden_state = geobert.output['hidden_states'][-1]
hidden_state = last_hidden_state  # (bs, seq_len, dim)
pooled_output = hidden_state[:, 0]  # (bs, dim) - vector at sequence position 0
pooled_output_dim = pooled_output.shape[1]

rand_locs = []
relu = torch.nn.ReLU()
with torch.no_grad():
    geobert.model.eval()
    for dim_idx in range(pooled_output_dim):
        # Multiplicative mask: all ones except a zero at (row 0, dim_idx).
        # Only row 0 is masked, so this assumes a batch of one.
        keep_mask = torch.ones_like(pooled_output)
        keep_mask[0, dim_idx] = 0
        head_input = keep_mask * pooled_output
        head_hidden = relu(geobert.model.pre_classifier(head_input))  # (bs, dim)
        rand_locs.append(
            geobert.model.classifier(head_hidden).cpu().detach().numpy().squeeze()
        )

# One row per ablated dimension. The values are used as-is: no radians-to-degrees
# conversion is applied.
out_df = pd.DataFrame(rand_locs, columns=['lat', 'lon'])

In [63]:
# Outline of the ablation predictions: take the convex hull of the (lat, lon)
# rows and turn its corner points into a polygon.
masked_area_hull = ConvexHull(out_df.values)
masked_area_vertecies = masked_area_hull.points[masked_area_hull.vertices]
# The hull rows are (lat, lon); shapely points are (x, y) = (lon, lat), so the
# two coordinates are swapped when building the points.
masked_area_points = geopandas.GeoSeries(
    [Point(lon, lat) for lat, lon in masked_area_vertecies]
)
masked_area_polygon = geopandas.GeoSeries(Polygon(masked_area_points))

In [64]:
# Map 1: the point estimate, the convex hull of the ablation predictions, and a
# heat map of those predictions.
#
# The original basemap, "Stamen Toner", is no longer available as a built-in
# tile set in current folium releases, so Map() fails on a fresh environment.
# "CartoDB positron" is a built-in basemap on both old and new folium versions.
maploc = folium.Map(
    # Pass a plain [lat, lon] list, the same form the Marker below receives.
    location=point_prediction.tolist(),
    zoom_start=8,
    tiles="CartoDB positron",
)
maploc.add_child(folium.Marker(point_prediction.tolist()))
maploc.add_child(
    folium.GeoJson(
        data=masked_area_polygon.to_json(),
        # The same style is applied to every feature, so the argument is unused.
        style_function=lambda _feature: {'fillColor': 'blue', 'stroke': False, 'fillOpacity': 0.5},
    )
)
# HeatMap reads the first two columns of each row as (lat, lon).
maploc.add_child(HeatMap(out_df))
maploc

In [65]:
# Gaussian kernel density estimator over points on a sphere.
# With the haversine metric, scikit-learn measures distances as angles in
# radians, so the bandwidth is an angle too: 0.006 rad is roughly 38 km on
# Earth (0.006 * ~6371 km).
kde = KernelDensity(bandwidth=0.006, metric='haversine', kernel='gaussian')

In [66]:
# The haversine metric expects (lat, lon) in radians. out_df is treated as
# degrees elsewhere in the notebook (it is plotted directly on the map), so it
# is converted before fitting.
kde.fit(np.radians(out_df))

KernelDensity(bandwidth=0.006, metric='haversine')

In [67]:
# Draw synthetic locations from the fitted density and outline them with a
# convex hull.
# A fixed random_state makes the sample, and therefore the hull, its area and
# the map below, identical on every Restart & Run All. Without it each run
# draws a different sample.
kde_sample = np.rad2deg(kde.sample(1000, random_state=0))
kde_area_hull = ConvexHull(kde_sample)
kde_area_vertecies = kde_area_hull.points[kde_area_hull.vertices]
# Samples are (lat, lon); shapely points are (x, y) = (lon, lat), hence the swap.
kde_area_points = geopandas.GeoSeries(
    [Point(lon, lat) for lat, lon in kde_area_vertecies]
)
kde_area_polygon = geopandas.GeoSeries(Polygon(kde_area_points))
# NOTE(review): no CRS is set, so .area is computed on raw lon/lat degrees and
# the result is in "square degrees". It is only useful for comparing polygons
# at a similar latitude; project to an equal-area CRS for a real area.
kde_area_polygon.area

0    4.388364
dtype: float64

In [68]:
# Map 2: the KDE sample, its convex hull, and a marker at the sample mean.
#
# The mean (lat, lon) of the sample is computed once and reused for both the
# map centre and the marker.
kde_center = kde_sample.mean(axis=0).tolist()

# The original basemap, "Stamen Toner", is no longer available as a built-in
# tile set in current folium releases, so Map() fails on a fresh environment.
# "CartoDB positron" is a built-in basemap on both old and new folium versions.
maploc = folium.Map(
    location=kde_center,
    zoom_start=7,
    tiles="CartoDB positron",
)
maploc.add_child(folium.Marker(kde_center))
maploc.add_child(
    folium.GeoJson(
        data=kde_area_polygon.to_json(),
        # The same style is applied to every feature, so the argument is unused.
        style_function=lambda _feature: {'fillColor': 'blue', 'stroke': False, 'fillOpacity': 0.5},
    )
)
# Each row of kde_sample is one (lat, lon) heat-map point.
maploc.add_child(HeatMap(kde_sample))
maploc