<a href="https://colab.research.google.com/github/WinsonTruong/geolocator/blob/main/geolocator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video/Image Geolocator

Uses a [pre-trained GeoEstimation model](https://github.com/TIBHannover/GeoEstimation) to run inference on videos and images.

- Get the video/image path. If a YouTube video, download it.
- If a video, retrieve video frames.
- Perform model inferencing on video frames or an image.
- Apply DBSCAN clustering to the predicted latitudes and longitudes.
- Retrieve the dense cluster.
- Compute the mean latitude and longitude of the data points belonging to the dense cluster.
- Predict the location and retrieve a Plotly graph.

## BentoML

- Create an ONNX version of the model.
- Generate Bento.
- Spin up the bento service.

## Library dependencies

- Katna
- youtube-dl
- scikit-learn
- PyTorch Lightning
- s2sphere
- Geopy
- Gradio
- ONNX
- ONNX Runtime
- BentoML


In [None]:
# install dependencies
# NOTE(review): pandas, plotly, matplotlib, geopy and torch are imported by later
# cells but are not installed here — presumably they are preinstalled on Colab;
# confirm before running this notebook elsewhere.
%pip install -q katna youtube_dl pytorch-lightning s2sphere scikit-learn gradio bentoml onnx onnxruntime

In [None]:
sh = """
URL="https://github.com/samhita-alla/GeoEstimation.git"
FOLDER="GeoEstimation"
if [ ! -d "$FOLDER" ] ; then
    git clone $URL $FOLDER
else
    cd "$FOLDER"
    git pull $URL
fi
"""

with open("clone_script.sh", "w") as file:
  file.write(sh)

!bash clone_script.sh

In [None]:
import sys

in_colab = "google.colab" in sys.modules

if in_colab:
  sh = """
  URL="https://github.com/samhita-alla/geolocator.git"
  FOLDER="."
  if [ ! -d "$FOLDER" ] ; then
      git clone $URL $FOLDER
  else
      cd "$FOLDER"
      git pull $URL
  fi
  """

  with open("clone_script.sh", "w") as file:
    file.write(sh)

  !bash clone_script.sh

In [None]:
import shutil

# Stage the BentoML service definition, the pre/post-processing helpers and the
# Bento build file inside the GeoEstimation checkout.
for source_path in (
    "services/bentoml/service.py",
    "app/post_processing.py",
    "app/pre_processing.py",
    "services/bentoml/bentofile.yaml",
):
    shutil.copy(source_path, "GeoEstimation")


In [None]:
# Work from inside the GeoEstimation checkout: later cells use paths relative to
# it (models/, resources/) and import its `classification` package.
%cd GeoEstimation

In [None]:
import glob
import shutil

from IPython.display import Image, display


# Placeholder; reassigned from create_image_dir(...) in the ONNX export cell.
image_dir = None
# NOTE(review): holds the same value as IMAGE_PARENT_DIR defined in a later cell;
# this lower-case name is not referenced by any of the visible cells.
image_parent_dir = "geolocator-images"


def display_video_frames(frames_directory: str):
    """Render every ``*.jpeg`` file found directly in *frames_directory* inline.

    Args:
        frames_directory: Directory searched (non-recursively) for JPEG frames.
    """
    for frame_path in glob.glob(f"{frames_directory}/*.jpeg"):
        # Each frame is shown as a fixed-size 200x100 thumbnail.
        thumbnail = Image(filename=frame_path, width=200, height=100)
        display(thumbnail)


In [None]:
# download the model checkpoint & hyperparameters
# The release asset is named "epoch.014-val_loss.18.4833.ckpt"; it is saved as
# "epoch=014-val_loss=18.4833.ckpt", which is the path the inference and ONNX
# export cells load.
# NOTE(review): unlike the cell CSV downloads, these calls do not pass -nc, so
# re-running the cell downloads both files again.
!mkdir -p models/base_M
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt -O models/base_M/epoch=014-val_loss=18.4833.ckpt
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/hparams.yaml -O models/base_M/hparams.yaml

In [None]:
# download the three geo-cell CSVs into resources/s2_cells
# (-nc: do not re-download a file that is already present)
!mkdir -p resources/s2_cells
!wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv
!wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv
!wget -nc https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

In [None]:
import logging
import os
import subprocess
from pathlib import Path
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
from generate_map import get_plotly_graph
from geopy.extra.rate_limiter import RateLimiter
from geopy.geocoders import Nominatim
from IPython.core.profiledir import LoggingConfigurable
from post_processing import generate_prediction_logit
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler


def data_engineering(image_dir: str) -> pd.DataFrame:
    """Load the inference results written for *image_dir*.

    The CSV is looked up at ``models/base_M/inference_<name>.csv`` relative to
    the current working directory, where ``<name>`` is the stem of the final
    path component of *image_dir*.

    Args:
        image_dir: Directory of images that inference was run on.

    Returns:
        The contents of the inference CSV as a DataFrame.

    Raises:
        FileNotFoundError: If the expected CSV does not exist.
    """
    # Only the last path component contributes to the file name, so no
    # environment-specific root (such as Colab's /content) is prepended.
    inference_file_path = os.path.join(
        "models/base_M", f"inference_{Path(image_dir).stem}.csv"
    )
    inference_df = pd.read_csv(inference_file_path)
    # Lazy %-formatting: the preview is only rendered if INFO is enabled.
    logging.info("Inference DF: %s", inference_df.head())

    return inference_df


def generate_prediction(image_dir: str, num_workers: int = 0) -> Tuple[str, plotly.graph_objects.Figure]:
    """Run the classifier over *image_dir* and turn the result into a location.

    Args:
        image_dir: Directory holding the image(s) or video frames to classify.
        num_workers: Forwarded to the inference script's ``--num_workers`` flag.

    Returns:
        The predicted location and the Plotly figure built for the predicted
        latitude/longitude.

    Raises:
        RuntimeError: If the inference subprocess exits with a non-zero status.
    """
    import sys  # local import so this cell does not depend on an earlier one

    # generate predictions on all the video frames
    # sys.executable runs the script under the same interpreter as this kernel
    # instead of whichever "python" happens to be first on PATH.
    completed = subprocess.run(
        [
            sys.executable,
            "-m",
            "classification.inference",
            "--image_dir",
            image_dir,
            "--checkpoint",
            "models/base_M/epoch=014-val_loss=18.4833.ckpt",
            "--hparams",
            "models/base_M/hparams.yaml",
            "--num_workers",
            str(num_workers),
        ],
        capture_output=True,
        text=True,
    )
    # The output is captured, so without this check a failed run would stay
    # invisible and only surface later as a missing inference CSV.
    if completed.returncode != 0:
        raise RuntimeError(
            "classification.inference exited with status "
            f"{completed.returncode}:\n{completed.stderr}"
        )

    # data engineering
    inference_df = data_engineering(image_dir=image_dir)

    # get location
    location, latitude, longitude = generate_prediction_logit(
        inference_df=inference_df
    )

    return location, get_plotly_graph(
        latitude=latitude, longitude=longitude, location=location
    )


In [None]:
from typing import Any, Dict

from pre_processing import capture_frames, extract_youtube_video


# Root folder under which every processed image gets a directory of its own.
IMAGE_PARENT_DIR = "geolocator-images"


def create_image_dir(img_file: str) -> str:
    """Stage *img_file* in a directory of its own under IMAGE_PARENT_DIR.

    The directory is named after the part of the file's base name that precedes
    its first ".". Any existing directory of that name is removed first.

    Args:
        img_file: Path of the image to stage.

    Returns:
        Path of the directory that now holds a copy of the image.
    """
    file_name = os.path.basename(img_file)
    dir_name = file_name.partition(".")[0]
    image_dir = os.path.join(IMAGE_PARENT_DIR, dir_name)

    # wipe any previous contents, then recreate the directory and copy the image in
    shutil.rmtree(image_dir, ignore_errors=True)
    os.makedirs(image_dir)
    shutil.copy(img_file, image_dir)

    return image_dir


def img_processor(img_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
    """Predict the location for a single image file.

    Args:
        img_file: Path of the image to analyse.

    Returns:
        Whatever ``generate_prediction`` returns for the directory the image
        was staged in.
    """
    return generate_prediction(image_dir=create_image_dir(img_file=img_file))


def video_helper(
    video_file: str, info_dict: Dict[str, Any]
) -> Tuple[str, plotly.graph_objects.Figure]:
    """Extract frames from a video, show them, and predict on them.

    Args:
        video_file: Path of the video to process.
        info_dict: Passed through unchanged to ``capture_frames``.

    Returns:
        Whatever ``generate_prediction`` returns for the frames directory.
    """
    # capture frames, then preview them inline before running inference
    frames_directory = capture_frames(video_file_path=video_file, info_dict=info_dict)
    display_video_frames(frames_directory=frames_directory)

    return generate_prediction(image_dir=frames_directory)


def video_processor(video_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
    """Predict the location for a local video file.

    Args:
        video_file: Path of the video to process.

    Returns:
        Whatever ``video_helper`` returns for the video.
    """
    # The "id" entry is the file's base name up to its first ".".
    video_id = os.path.basename(video_file).split(".")[0]
    return video_helper(video_file=video_file, info_dict={"id": video_id})

def url_processor(url: str) -> Tuple[str, plotly.graph_objects.Figure]:
    """Predict the location for a video referenced by URL.

    Args:
        url: Link handed to ``extract_youtube_video``.

    Returns:
        Whatever ``video_helper`` returns for the extracted video.
    """
    video_file, info_dict = extract_youtube_video(url=url)
    prediction = video_helper(video_file=video_file, info_dict=info_dict)
    return prediction


In [None]:
# validation
# Smoke-test the YouTube path end to end. The commented-out lines below are the
# equivalent check for a single downloaded image via img_processor.
url_processor(url="https://www.youtube.com/watch?v=ADt1LnbL2HI")
# !wget -nc https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
# img_processor("santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

# BentoML w/ ONNX


In [None]:
from math import ceil

import torch
from classification.dataset import FiveCropImageDataset
from classification.train_base import MultiPartitioningClassifier
from tqdm.auto import tqdm


# Restore the classifier from the checkpoint and hyperparameter files that an
# earlier cell downloaded into models/base_M.
model = MultiPartitioningClassifier.load_from_checkpoint(
    checkpoint_path="models/base_M/epoch=014-val_loss=18.4833.ckpt",
    hparams_file="models/base_M/hparams.yaml",
    map_location=None,
)

# Fetch a sample image (skipped if already present) and stage it in its own
# directory so it can be read through the dataset class below.
!wget -nc https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
image_dir = create_image_dir(img_file="santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

# One image per batch, in order, loaded in the main process.
dataloader = torch.utils.data.DataLoader(
    FiveCropImageDataset(meta_csv=None, image_dir=image_dir),
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

# Take the first batch as the sample input for the export.
# Axis 0 is read as the batch size and axis 1 as the number of crops.
images, meta_batch = next(iter(dataloader))
cur_batch_size = images.shape[0]
ncrops = images.shape[1]

# reshape crop dimension to batch
images = torch.reshape(images, (cur_batch_size * ncrops, *images.shape[2:]))

# Export to geolocator.onnx using the sample batch. dynamic_axes marks axis 0 of
# the tensor named "input" as variable so the batch size is not fixed to the
# sample's.
model.to_onnx(
  "geolocator.onnx",
  input_sample=images,
  export_params=True,
  opset_version=11,
  input_names=["input"],
  dynamic_axes={
      "input": {0: "batch_size"},
  }
)

In [None]:
import onnxruntime


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Args:
        tensor: The tensor to convert.

    Returns:
        The tensor's data as a NumPy array.
    """
    # A tensor that requires grad is detached first, then moved to the CPU.
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()


# Load the exported model and run it on the same sample batch used for export.
# NOTE(review): no `providers` argument is passed — presumably fine for the
# CPU-only onnxruntime package installed above; confirm for GPU builds.
ort_session = onnxruntime.InferenceSession("geolocator.onnx")
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(images)}
# ONNX Runtime will return a list of outputs
ort_outs = ort_session.run(None, ort_inputs)
# bare expression: shown as the cell's output
ort_outs


In [None]:
import bentoml
import onnx


# Save the exported ONNX graph with BentoML under the name "onnx_geolocator".
# NOTE(review): presumably service.py looks the model up by this name — confirm.
bentoml.onnx.save_model("onnx_geolocator", onnx.load("geolocator.onnx"))


In [None]:
%cd GeoEstimation

In [None]:
# run bentoml service
# Serves the `svc` object from service.py on port 3000 with auto-reload.
# NOTE(review): `bentoml serve` presumably runs in the foreground, so this cell
# blocks until interrupted and later cells will not run in the meantime.
!bentoml serve service:svc --reload --port 3000

# Gradio

The UI is built with Gradio.


In [None]:
import gradio as gr


# Three-tab UI: each tab pairs one input widget with a location textbox and a
# plot, and a button that runs the matching processor function.
with gr.Blocks() as demo:
    gr.Markdown("# GeoLocator")
    gr.Markdown(
        "An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗."
    )
    with gr.Tab("Image"):
        with gr.Row():
            # type="filepath" passes the image's path to img_processor
            img_input = gr.Image(type="filepath")
            with gr.Column():
                img_text_output = gr.Textbox(label="Location")
                img_plot = gr.Plot()
        img_text_button = gr.Button("Go locate!")
    with gr.Tab("Video"):
        with gr.Row():
            # gr.Video has no `type` option and already passes a file path to
            # the handler; newer Gradio releases reject the unknown keyword, so
            # it is not passed here.
            video_input = gr.Video()
            with gr.Column():
                video_text_output = gr.Textbox(label="Location")
                video_plot = gr.Plot()
        video_text_button = gr.Button("Go locate!")
    with gr.Tab("YouTube Link"):
        with gr.Row():
            url_input = gr.Textbox(label="YouTube video link")
            with gr.Column():
                url_text_output = gr.Textbox(label="Location")
                url_plot = gr.Plot()
        url_text_button = gr.Button("Go locate!")

    # Each processor returns a pair, mapped onto (textbox, plot) in order.
    img_text_button.click(
        img_processor, inputs=img_input, outputs=[img_text_output, img_plot]
    )
    video_text_button.click(
        video_processor, inputs=video_input, outputs=[video_text_output, video_plot]
    )
    url_text_button.click(
        url_processor, inputs=url_input, outputs=[url_text_output, url_plot]
    )

    # Example link offered for the YouTube textbox.
    examples = gr.Examples(
        examples=["https://www.youtube.com/watch?v=wxeQkJTZrsw"], inputs=[url_input]
    )

demo.launch()
