<a href="https://colab.research.google.com/github/samhita-alla/geolocator/blob/main/geolocator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Video/Image Geolocator

Uses a [pre-trained GeoEstimation model](https://github.com/TIBHannover/GeoEstimation) to run inference on videos and images and predict where they were captured.

- Get the video/image path. If the input is a YouTube link, download the video.
- If the input is a video, extract its frames.
- Run model inference on the video frames or the image.
- Apply DBSCAN clustering to the predicted latitudes and longitudes.
- Retrieve the dense cluster.
- Compute the mean latitude and longitude of the data points in the dense cluster.
- Predict the location and retrieve a Plotly graph.
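The clustering steps above can be sketched with scikit-learn's `DBSCAN`. This is a minimal illustration, not the code in `post_processing.py`: the function name `densest_cluster_mean`, the `eps`/`min_samples` values, and treating the cluster with the most members as the "dense" one are all assumptions. It also uses plain Euclidean distance on degrees for simplicity; `DBSCAN(metric="haversine")` on inputs converted to radians would measure great-circle distances instead.

```python
import numpy as np
from sklearn.cluster import DBSCAN


def densest_cluster_mean(latitudes, longitudes, eps=0.5, min_samples=2):
    """Return (mean_lat, mean_lon) of the largest DBSCAN cluster.

    Hypothetical helper: eps and min_samples are illustrative defaults,
    not the project's values. Distances are Euclidean on degrees.
    """
    coords = np.column_stack([latitudes, longitudes]).astype(float)
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(coords).labels_

    # DBSCAN labels noise points as -1; leave them out when picking a cluster
    clustered = labels[labels != -1]
    if clustered.size == 0:
        # every point is noise: fall back to the mean of all predictions
        return tuple(coords.mean(axis=0))

    # assumption: the "dense" cluster is the one with the most members
    cluster_ids, counts = np.unique(clustered, return_counts=True)
    densest = cluster_ids[np.argmax(counts)]
    return tuple(coords[labels == densest].mean(axis=0))


# made-up per-frame predictions: three near Oia, Santorini, one outlier near Paris
lats = [36.46, 36.47, 36.45, 48.85]
lons = [25.37, 25.38, 25.36, 2.35]
print(densest_cluster_mean(lats, lons))
```

The outlier is labelled as noise, so only the three nearby frames contribute to the mean. A plain arithmetic mean of longitudes also misbehaves for clusters that straddle the ±180° meridian, which this sketch does not handle.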

## BentoML

- Create an ONNX version of the model.
- Generate a Bento.
- Spin up the BentoML service.

## Library dependencies

- Katna
- youtube-dl
- scikit-learn
- PyTorch Lightning
- s2sphere
- Geopy
- Gradio
- ONNX
- ONNX Runtime
- BentoML


In [1]:
# install dependencies
%pip install -q katna youtube_dl pytorch-lightning s2sphere scikit-learn geopy gradio bentoml onnx onnxruntime

Note: you may need to restart the kernel to use updated packages.


In [2]:
sh = """
URL="https://github.com/samhita-alla/GeoEstimation.git"
FOLDER="GeoEstimation"
if [ ! -d "$FOLDER" ] ; then
    git clone $URL $FOLDER
else
    cd "$FOLDER"
    git pull $URL
fi
"""

with open("clone_script.sh", "w") as file:
  file.write(sh)

!bash clone_script.sh

Cloning into 'GeoEstimation'...
remote: Enumerating objects: 574, done.
remote: Counting objects: 100% (115/115), done.
remote: Compressing objects: 100% (74/74), done.
remote: Total 574 (delta 83), reused 69 (delta 41), pack-reused 459
Receiving objects: 100% (574/574), 1.90 MiB | 361.00 KiB/s, done.
Resolving deltas: 100% (333/333), done.


In [3]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
  sh = """
  URL="https://github.com/samhita-alla/geolocator.git"
  FOLDER="."
  if [ ! -d "$FOLDER" ] ; then
      git clone $URL $FOLDER
  else
      cd "$FOLDER"
      git pull $URL
  fi
  """

  with open("clone_script.sh", "w") as file:
    file.write(sh)

  !bash clone_script.sh

In [4]:
import shutil

shutil.copy("service.py", "GeoEstimation")
shutil.copy("post_processing.py", "GeoEstimation")
shutil.copy("pre_processing.py", "GeoEstimation")
shutil.copy("bentofile.yaml", "GeoEstimation")


'GeoEstimation/bentofile.yaml'

In [5]:
%cd GeoEstimation

/Users/samhitaalla/Desktop/geolocator/GeoEstimation


In [6]:
import glob
import shutil

from IPython.display import Image, display


image_dir = None
image_parent_dir = "geolocator-images"


def display_video_frames(frames_directory: str):
    # show a thumbnail of every extracted frame, in a stable (sorted) order
    frames = sorted(glob.glob(f"{frames_directory}/*.jpeg"))

    for frame in frames:
        display(Image(filename=frame, width=200, height=100))


In [7]:
# download the model checkpoint & hyperparameters
!mkdir -p models/base_M
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt -O models/base_M/epoch=014-val_loss=18.4833.ckpt
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/hparams.yaml -O models/base_M/hparams.yaml

--2022-10-05 11:43:52--  https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt
Resolving github.com (github.com)... 20.207.73.82
Connecting to github.com (github.com)|20.207.73.82|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc162380-3e05-11eb-9190-3ec4e4ff49c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20221005%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20221005T061353Z&X-Amz-Expires=300&X-Amz-Signature=61111447ab4fee1e95c16e3b9fe28985a7349ca54476d0c8fdc62e675ce23167&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=142275851&response-content-disposition=attachment%3B%20filename%3Depoch.014-val_loss.18.4833.ckpt&response-content-type=application%2Foctet-stream [following]
--2022-10-05 11:43:53--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc1

In [8]:
!mkdir -p resources/s2_cells
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

--2022-10-05 11:44:12--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 177214 (173K) [text/plain]
Saving to: ‘resources/s2_cells/cells_50_5000.csv’


2022-10-05 11:44:13 (3.42 MB/s) - ‘resources/s2_cells/cells_50_5000.csv’ saved [177214/177214]

--2022-10-05 11:44:13--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 389388 (380K) [

In [28]:
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Tuple

import pandas as pd
import plotly.graph_objects  # also binds the name `plotly` used in type hints
from generate_map import get_plotly_graph
from post_processing import generate_prediction_logit


def data_engineering(image_dir: str) -> pd.DataFrame:
    # read the CSV that classification.inference is expected to write
    # for this image directory
    inference_file_path = os.path.join(
        "models/base_M", f"inference_{Path(image_dir).stem}.csv"
    )
    inference_df = pd.read_csv(inference_file_path)
    logging.info(f"Inference DF: {inference_df.head()}")

    return inference_df


def generate_prediction(image_dir: str, num_workers: int = 0) -> Tuple[str, plotly.graph_objects.Figure]:
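    # generate predictions on all the video frames (or the single image);
    # sys.executable runs the module with the notebook's own interpreter
    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "classification.inference",
            "--image_dir",
            image_dir,
            "--checkpoint",
            "models/base_M/epoch=014-val_loss=18.4833.ckpt",
            "--hparams",
            "models/base_M/hparams.yaml",
            "--num_workers",
            str(num_workers),
        ],
        capture_output=True,
        text=True,
    )
    # surface inference failures instead of failing later on a missing CSV
    if result.returncode != 0:
        raise RuntimeError(f"Inference failed:\n{result.stderr}")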

    # data engineering
    inference_df = data_engineering(image_dir=image_dir)

    # get location
    location, latitude, longitude = generate_prediction_logit(
        inference_df=inference_df
    )

    return location, get_plotly_graph(
        latitude=latitude, longitude=longitude, location=location
    )


ImportError: cannot import name 'generate_prediction_logit' from 'post_processing' (/Users/samhitaalla/Desktop/geolocator/post_processing.py)

In [10]:
from typing import Any, Dict

from pre_processing import capture_frames, extract_youtube_video


IMAGE_PARENT_DIR = "geolocator-images"


def create_image_dir(img_file: str) -> str:
    image_dir = os.path.join(IMAGE_PARENT_DIR, os.path.basename(img_file).split(".")[0])

    # clear the image directory before filling it up
    shutil.rmtree(image_dir, ignore_errors=True)
    os.makedirs(image_dir)
    shutil.copy(img_file, image_dir)

    return image_dir


def img_processor(img_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
    image_dir = create_image_dir(img_file=img_file)
    return generate_prediction(image_dir=image_dir)


def video_helper(
    video_file: str, info_dict: Dict[str, Any]
) -> Tuple[str, plotly.graph_objects.Figure]:
    # capture frames
    frames_directory = capture_frames(video_file_path=video_file, info_dict=info_dict)
    display_video_frames(frames_directory=frames_directory)

    return generate_prediction(image_dir=frames_directory)


def video_processor(video_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
    info_dict = {"id": os.path.basename(video_file).split(".")[0]}
    return video_helper(video_file=video_file, info_dict=info_dict)


def url_processor(url: str) -> Tuple[str, plotly.graph_objects.Figure]:
    video_file, info_dict = extract_youtube_video(url=url)
    return video_helper(video_file=video_file, info_dict=info_dict)


In [25]:
# validation
# url_processor(url="https://www.youtube.com/watch?v=ADt1LnbL2HI")
!wget -nc https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
img_processor("santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

File ‘santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg’ already there; not retrieving.



('Viewpoint, Karavokyridon Street, Ormos Ammoudiou, Community of Ia, Ia Municipal Unit, Municipality of Thira, Thira Regional Unit, South Aegean, Aegean, 847 02, Greece',
 Figure({
     'data': [{'hovertemplate': '<b>%{hovertext}</b><br><br>latitude=%{lat}<br>longitude=%{lon}<extra></extra>',
               'hovertext': array(['Viewpoint, Karavokyridon Street, Ormos Ammoudiou, Community of Ia, Ia Municipal Unit, Municipality of Thira, Thira Regional Unit, South Aegean, Aegean, 847 02, Greece'],
                                  dtype=object),
               'lat': array([36.46125]),
               'legendgroup': '',
               'lon': array([25.372522]),
               'marker': {'color': 'fuchsia'},
               'mode': 'markers',
               'name': '',
               'showlegend': False,
               'subplot': 'mapbox',
               'type': 'scattermapbox'}],
     'layout': {'height': 300,
                'legend': {'tracegroupgap': 0},
                'mapbox': {'acces

# BentoML with ONNX


In [19]:
import torch
from classification.dataset import FiveCropImageDataset
from classification.train_base import MultiPartitioningClassifier


model = MultiPartitioningClassifier.load_from_checkpoint(
    checkpoint_path="models/base_M/epoch=014-val_loss=18.4833.ckpt",
    hparams_file="models/base_M/hparams.yaml",
    map_location=None,
)

!wget -nc https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
image_dir = create_image_dir(img_file="santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

dataloader = torch.utils.data.DataLoader(
    FiveCropImageDataset(meta_csv=None, image_dir=image_dir),
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

images, meta_batch = next(iter(dataloader))
cur_batch_size = images.shape[0]
ncrops = images.shape[1]

# reshape crop dimension to batch
images = torch.reshape(images, (cur_batch_size * ncrops, *images.shape[2:]))

model.to_onnx(
  "geolocator.onnx",
  input_sample=images,
  export_params=True,
  opset_version=11,
  input_names=["input"],
  dynamic_axes={
      "input": {0: "batch_size"},
  }
)


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


The parameter 'pretrained' is deprecated since 0.13 and will be removed in 0.15, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and will be removed in 0.15. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.



File ‘santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg’ already there; not retrieving.



In [20]:
import onnxruntime


def to_numpy(tensor):
    return (
        tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
    )


ort_session = onnxruntime.InferenceSession("geolocator.onnx")
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)}
# ONNX Runtime will return a list of outputs
ort_outs = ort_session.run(None, ort_inputs)
ort_outs


[array([[ 0.90516365,  4.4899    , -2.804096  , ..., -5.473765  ,
          4.9408894 ,  1.3072193 ],
        [ 0.22281647,  1.1072205 , -3.549011  , ..., -4.007944  ,
         11.304101  , -0.20302217],
        [ 0.46286303,  4.413016  , -2.9810796 , ..., -5.889538  ,
          5.7547584 ,  0.85683393],
        [-0.06760192,  0.5161087 , -3.9409113 , ..., -3.7524354 ,
         10.674072  , -0.24188542],
        [ 1.5776274 ,  3.5012524 , -3.909003  , ..., -5.370172  ,
         10.660667  ,  0.135409  ]], dtype=float32),
 array([[ 1.153804  ,  4.026694  , -1.9779874 , ..., -1.13531   ,
          2.1075304 , -4.0744047 ],
        [-0.09019619,  1.675056  , -2.7516851 , ..., -0.61443245,
          0.59368336, -5.013054  ],
        [ 0.3058862 ,  3.6243377 , -2.1243274 , ..., -1.3385389 ,
          1.7169186 , -4.256903  ],
        [-0.14555266,  0.47077364, -3.1073918 , ..., -0.5177057 ,
          0.55621105, -4.7999754 ],
        [ 1.2242081 ,  4.621024  , -2.9370506 , ..., -1.1997478 ,
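`ort_outs` presumably holds one logits array per partitioning, and each array has one row per crop because the five crops were folded into the batch dimension before export. A minimal NumPy sketch of turning those rows back into per-image probabilities follows. The helper name `aggregate_crops` is hypothetical, and taking the maximum probability over crops is an assumed aggregation (averaging is an equally plausible choice), so check it against `classification/train_base.py` before relying on it.

```python
import numpy as np


def aggregate_crops(logits: np.ndarray, batch_size: int, ncrops: int) -> np.ndarray:
    """Map per-crop logits (batch_size * ncrops, num_classes) to
    per-image class probabilities (batch_size, num_classes)."""
    # numerically stable softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)

    # undo the "crops into batch" reshape done before the ONNX export
    probs = probs.reshape(batch_size, ncrops, -1)

    # assumption: keep each class's highest probability across crops
    return probs.max(axis=1)


# synthetic stand-in for ort_outs: 3 partitionings, 5 crops, 8 classes
rng = np.random.default_rng(0)
fake_outs = [rng.normal(size=(5, 8)).astype(np.float32) for _ in range(3)]
fake_outs[0][2, 6] = 50.0  # one crop is very confident about class 6
probs = [aggregate_crops(out, batch_size=1, ncrops=5) for out in fake_outs]
print([p.argmax(axis=1) for p in probs])
```

With the real session, the same call would be `aggregate_crops(out, cur_batch_size, ncrops)` for each `out` in `ort_outs`.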

In [21]:
import bentoml
import onnx


bentoml.onnx.save_model("onnx_geolocator", onnx.load("geolocator.onnx"))




Model(tag="onnx_geolocator:ojvvijceo2ytrlg6", path="/Users/samhitaalla/bentoml/models/onnx_geolocator/ojvvijceo2ytrlg6/")

In [27]:
# run bentoml service
!bentoml serve service:svc --reload

2022-10-05T12:05:51+0530 [INFO] [cli] Prometheus metrics for HTTP BentoServer from "service:svc" can be accessed at http://localhost:3000/metrics.
2022-10-05T12:05:53+0530 [INFO] [cli] Starting development HTTP BentoServer from "service:svc" running on http://0.0.0.0:3000 (Press CTRL+C to quit)
2022-10-05 12:05:53 circus[86476] [INFO] Loading the plugin...
2022-10-05 12:05:53 circus[86476] [INFO] Endpoint: 'tcp://127.0.0.1:65115'
2022-10-05 12:05:53 circus[86476] [INFO] Pub/sub: 'tcp://127.0.0.1:65116'
2022-10-05T12:05:53+0530 [INFO] [observer] Watching directories: ['/Users/samhitaalla/Desktop/geolocator/GeoEstimation', '/Users/samhitaalla/bentoml/models']
[youtube] ADt1LnbL2HI: Downloading webpage
[youtube] ADt1LnbL2HI: Downloading MPD manifest
Removing cache dir /Users/samhitaalla/.cache/youtube-dl ...
[youtube] ADt1LnbL2HI: Downloading webpage
[youtube] ADt1LnbL2HI: Downloading player 374003a5
[youtube] ADt1LnbL2HI: Downloading MPD manifest
[download] videos/ADt1LnbL2HI.mp4 has alr

# Gradio

The UI is built with Gradio.


In [None]:
import gradio as gr


with gr.Blocks() as demo:
    gr.Markdown("# GeoLocator")
    gr.Markdown(
        "An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗."
    )
    with gr.Tab("Image"):
        with gr.Row():
            img_input = gr.Image(type="filepath")
            with gr.Column():
                img_text_output = gr.Textbox(label="Location")
                img_plot = gr.Plot()
        img_text_button = gr.Button("Go locate!")
    with gr.Tab("Video"):
        with gr.Row():
            video_input = gr.Video()
            with gr.Column():
                video_text_output = gr.Textbox(label="Location")
                video_plot = gr.Plot()
        video_text_button = gr.Button("Go locate!")
    with gr.Tab("YouTube Link"):
        with gr.Row():
            url_input = gr.Textbox(label="YouTube video link")
            with gr.Column():
                url_text_output = gr.Textbox(label="Location")
                url_plot = gr.Plot()
        url_text_button = gr.Button("Go locate!")

    img_text_button.click(
        img_processor, inputs=img_input, outputs=[img_text_output, img_plot]
    )
    video_text_button.click(
        video_processor, inputs=video_input, outputs=[video_text_output, video_plot]
    )
    url_text_button.click(
        url_processor, inputs=url_input, outputs=[url_text_output, url_plot]
    )

    examples = gr.Examples(
        examples=["https://www.youtube.com/watch?v=wxeQkJTZrsw"], inputs=[url_input]
    )

demo.launch()
