<a href="https://colab.research.google.com/github/samhita-alla/geolocator/blob/main/geolocator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video/Image Geolocator

Uses a [pre-trained GeoEstimation model](https://github.com/TIBHannover/GeoEstimation) to run inference on videos and images.

* Get the video/image path. If the input is a YouTube link, download the video first.
* If the input is a video, extract keyframes from it.
* Run model inference on the video frames or on the image.
* Apply DBSCAN clustering to the predicted latitudes and longitudes.
* Select the largest (most populated) cluster.
* Compute the mean latitude and longitude of the points in that cluster.
* Look up the location name for that coordinate and plot it on a Plotly map.

## BentoML

* Export the model to ONNX.
* Save the ONNX model to the BentoML model store and build a Bento.
* Spin up the Bento service.

## Library dependencies

* Katna
* youtube-dl
* scikit-learn
* PyTorch Lightning
* s2sphere
* Geopy
* Gradio
* ONNX
* ONNX Runtime
* BentoML


In [1]:
# install dependencies
%pip install -q katna youtube_dl pytorch-lightning s2sphere scikit-learn gradio bentoml onnx onnxruntime

Note: you may need to restart the kernel to use updated packages.


In [4]:
!git clone https://github.com/samhita-alla/GeoEstimation.git

Cloning into 'GeoEstimation'...
remote: Enumerating objects: 556, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 556 (delta 69), reused 56 (delta 33), pack-reused 459[K
Receiving objects: 100% (556/556), 1.89 MiB | 2.27 MiB/s, done.
Resolving deltas: 100% (319/319), done.


In [5]:
import shutil

# Copy the BentoML service/processing assets into the cloned repo so the
# service can be built and served from there.
for bento_asset in ("service.py", "post_processing.py", "pre_processing.py", "bentofile.yaml"):
  shutil.copy(bento_asset, "GeoEstimation")

'GeoEstimation/pre_processing.py'

In [6]:
%cd GeoEstimation

/Users/samhitaalla/Desktop/geolocator/GeoEstimation


In [7]:
from __future__ import unicode_literals
import youtube_dl
from pathlib import Path
from Katna.config import Video as VideoConfig
from Katna.writer import KeyFrameDiskWriter
import os
from typing import Dict, Any, Tuple

# Size cap for downloaded videos. Used to pick a YouTube format and passed to
# youtube_dl as its "max_filesize" option (youtube_dl counts this in bytes).
MAX_FILESIZE = 10000000

def sort_key(key):
  """Return a youtube_dl format's ``filesize`` as an int, for use as a sort key.

  Formats whose size is unknown (``filesize`` missing or falsy, e.g. ``None``)
  map to 0, so they end up last in a descending sort.

  Args:
    key: A format dict from youtube_dl's ``info_dict["formats"]``.

  Returns:
    The file size as an int, or 0 when it is unknown.
  """
  # Format dicts are not guaranteed to carry "filesize"; avoid a KeyError.
  file_size = key.get("filesize")
  return int(file_size) if file_size else 0

def validate_extension(selected_format: Dict[str, Any]) -> str:
  """Return the format's extension, checking that Katna can process it.

  The extension is compared, without dots, against Katna's
  ``VideoConfig.video_extensions``. A missing or empty extension is returned
  unchanged without any check.

  Raises:
    ValueError: If the extension is set but not supported by Katna.
  """
  ext = selected_format.get("ext")
  if not ext:
    return ext
  supported = [candidate.replace(".", "") for candidate in VideoConfig.video_extensions]
  if ext in supported:
    return ext
  raise ValueError(f"{ext} isn't supported.")

def select_video_format(formats: list, max_filesize: int) -> Dict[str, Any]:
  """Pick the largest downloadable video format that stays under ``max_filesize``.

  A format qualifies when it has an extension other than ``webm``, a known
  ``filesize`` strictly below ``max_filesize``, and a video stream
  (``vcodec`` is not ``"none"``). Among qualifying formats the largest one wins.
  On ties, the earliest one in ``formats`` is kept.

  Args:
    formats: youtube_dl format dicts (``info_dict["formats"]``).
    max_filesize: Exclusive upper bound on ``filesize``.

  Returns:
    The chosen format dict, or ``{}`` if no format qualifies.
  """
  candidates = [
    fmt
    for fmt in formats
    if fmt.get("ext") and fmt["ext"] != "webm"
    and fmt.get("filesize") and fmt["filesize"] < max_filesize
    and fmt.get("vcodec") != "none"
  ]
  # max() keeps the first of equally sized formats, which matches a stable
  # descending sort followed by taking the first match.
  return max(candidates, key=lambda fmt: fmt["filesize"], default={})


def extract_youtube_video(url: str) -> Tuple[str, Dict[str, Any]]:
  """Download a YouTube video small enough to process and return its local path.

  Args:
    url: YouTube video URL.

  Returns:
    ``(saved_location, info_dict)``: the path of the downloaded video and
    youtube_dl's metadata dict for it.

  Raises:
    ValueError: If no format fits under ``MAX_FILESIZE``, or the chosen
      extension isn't supported by Katna.
  """
  # Fetch metadata only; nothing is downloaded yet.
  with youtube_dl.YoutubeDL({}) as ydl:
    info_dict = ydl.extract_info(url, download=False)

  selected_format = select_video_format(info_dict.get("formats", []), MAX_FILESIZE)
  if not selected_format:
    # Without a selected format there is no format id to request and no
    # extension to build the saved path from.
    raise ValueError(f"No downloadable video format under {MAX_FILESIZE} bytes found for {url}.")

  # verify if the extension is valid
  extension = validate_extension(selected_format)

  videos_path = "videos"
  ydl_opts = {
    "max_filesize": MAX_FILESIZE,
    "format": selected_format.get("format_id"),
    "outtmpl": f"{videos_path}/%(id)s.%(ext)s",
  }

  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
    ydl.cache.remove()
    ydl.download([url])

  # Mirror the "outtmpl" pattern above so the returned path matches the file written.
  saved_location = f"{videos_path}/{info_dict['id']}.{extension}"
  return saved_location, info_dict

In [8]:
from Katna.video import Video

# Number of keyframes Katna is asked to extract from each video (passed as no_of_frames).
NUMBER_OF_FRAMES = 20

def capture_frames(video_file_path: str, info_dict: Dict[str, Any]) -> str:
  """Extract ``NUMBER_OF_FRAMES`` keyframes from a video and write them to disk.

  Frames are written to ``selected-frames/<info_dict["id"]>``. The directory is
  wiped first so frames from an earlier run of the same video are not mixed in.

  Args:
    video_file_path: Path of the video to sample.
    info_dict: Video metadata; only its ``"id"`` key is read.

  Returns:
    The directory the keyframes were written to.

  Raises:
    ValueError: If Katna fails to extract the keyframes. The original
      exception is attached as ``__cause__``.
  """
  frames_directory = f"selected-frames/{info_dict['id']}"
  shutil.rmtree(frames_directory, ignore_errors=True)
  os.makedirs(frames_directory, exist_ok=True)
  diskwriter = KeyFrameDiskWriter(location=frames_directory)

  vd = Video()
  try:
    vd.extract_video_keyframes(
        no_of_frames=NUMBER_OF_FRAMES,
        file_path=video_file_path,
        writer=diskwriter,
    )
  except Exception as e:
    # Present any Katna failure to callers as a single error type, chaining
    # the original so its traceback is preserved as the explicit cause.
    raise ValueError(f"Error capturing the frames: {e}") from e

  return frames_directory

In [9]:
import shutil
import glob
from IPython.display import Image, display

# Placeholder global: the functions in this notebook take image_dir as a
# parameter or local variable; a later cell rebinds it for the ONNX export.
image_dir = None
# Parent folder under which create_image_dir creates one sub-folder per input image.
image_parent_dir = "geolocator-images"

def display_video_frames(frames_directory: str):
  """Render every ``*.jpeg`` frame in ``frames_directory`` inline in the notebook.

  Frames are shown in sorted filename order. ``glob`` returns files in an
  arbitrary, filesystem-dependent order, so sorting keeps repeated runs
  consistent.

  Args:
    frames_directory: Directory holding the extracted keyframes.
  """
  for frame in sorted(glob.glob(f"{frames_directory}/*.jpeg")):
    display(Image(filename=frame, width=200, height=100))

In [10]:
# download the model checkpoint & hyperparameters
!mkdir -p models/base_M
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt -O models/base_M/epoch=014-val_loss=18.4833.ckpt
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/hparams.yaml -O models/base_M/hparams.yaml

--2022-10-02 18:04:14--  https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt
Resolving github.com (github.com)... 20.207.73.82
Connecting to github.com (github.com)|20.207.73.82|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc162380-3e05-11eb-9190-3ec4e4ff49c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20221002%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20221002T123415Z&X-Amz-Expires=300&X-Amz-Signature=a1667f6d00537d9c9903b5b38769555a2f81899584e4fd85bad9429340dd6ba0&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=142275851&response-content-disposition=attachment%3B%20filename%3Depoch.014-val_loss.18.4833.ckpt&response-content-type=application%2Foctet-stream [following]
--2022-10-02 18:04:15--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc1

In [11]:
!mkdir -p resources/s2_cells
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

--2022-10-02 18:06:49--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 177214 (173K) [text/plain]
Saving to: ‘resources/s2_cells/cells_50_5000.csv’


2022-10-02 18:06:49 (2.37 MB/s) - ‘resources/s2_cells/cells_50_5000.csv’ saved [177214/177214]

--2022-10-02 18:06:50--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 389388 (380K) [

In [12]:
from IPython.core.profiledir import LoggingConfigurable
import subprocess
import pandas as pd
from pathlib import Path
import logging
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from geopy.geocoders import Nominatim
import plotly
import plotly.express as px
import pandas as pd
from geopy.extra.rate_limiter import RateLimiter


def dbscan_clustering(
  lats_and_longs: List[List[float]],
  eps: float = 0.5,
  min_samples: int = 3,
) -> Tuple[np.ndarray, np.ndarray]:
  """Cluster predicted (latitude, longitude) pairs with DBSCAN.

  DBSCAN is used because it is robust to outliers, needs no cluster count up
  front, and groups points by density.

  Each column is standardized first, so ``eps`` is measured in standard
  deviations of that column rather than in degrees. The neighbourhood size
  therefore scales with the spread of the predictions.

  Args:
    lats_and_longs: One ``[lat, lng]`` pair per prediction.
    eps: DBSCAN neighbourhood radius, in standardized units.
    min_samples: Minimum number of neighbours for a point to be a core sample.

  Returns:
    ``(labels, lats_and_longs_standardized)``: the DBSCAN label of every input
    point (``-1`` marks noise) and the standardized coordinates, one row per
    input point.
  """
  lats_and_longs_standardized = StandardScaler().fit_transform(lats_and_longs)
  labels = DBSCAN(eps=eps, min_samples=min_samples).fit(lats_and_longs_standardized).labels_
  logging.info(f"DBSCAN cluster labels: {labels}")

  return labels, lats_and_longs_standardized


def plot_clusters(labels, lats_and_longs_standardized):
  """Scatter-plot the standardized coordinates, coloured by DBSCAN label.

  Debugging aid; nothing in this notebook calls it at the moment.
  """
  lat_axis = lats_and_longs_standardized[:, 0]
  lng_axis = lats_and_longs_standardized[:, 1]
  plt.scatter(lat_axis, lng_axis, c=labels, cmap="Paired")


def data_engineering(image_dir: str) -> List[List[float]]:
  """Load the inference CSV produced for ``image_dir`` and return its coordinates.

  The CSV is read from ``models/base_M/inference_<stem>.csv``, relative to the
  working directory. ``<stem>`` is the last path component of ``image_dir``
  without its suffix.

  Args:
    image_dir: Directory that was passed to the inference script.

  Returns:
    One ``[pred_lat, pred_lng]`` pair per row of the CSV.

  Raises:
    ValueError: If the CSV contains no predictions.
  """
  # Only the final path component matters for the file name, so there is no
  # need to prefix a (Colab-specific) base directory.
  inference_file_path = os.path.join("models/base_M", f"inference_{Path(image_dir).stem}.csv")
  inference_df = pd.read_csv(inference_file_path)
  logging.info(f"Inference DF: {inference_df.head()}")

  inference_df_lats_longs = inference_df[["pred_lat", "pred_lng"]]
  logging.info(f"Inference Lats & Longs only DF: {inference_df_lats_longs}")

  if inference_df_lats_longs.empty:
    # Fail here with a clear message rather than deep inside the clustering step.
    raise ValueError(f"No predictions found in {inference_file_path}.")

  return inference_df_lats_longs.to_numpy().tolist()


def get_location(latitude: float, longitude: float) -> str:
  """Reverse-geocode a coordinate pair into an English address via Nominatim.

  Args:
    latitude: Latitude in degrees.
    longitude: Longitude in degrees.

  Returns:
    The address Nominatim reports for the point. If Nominatim returns no
    result, a fallback string containing the raw coordinates is returned.
  """
  geolocator = Nominatim(user_agent="geolocater")
  # Coordinates -> address is a reverse lookup, so use Nominatim's reverse API
  # rather than a free-text search on the "lat,lon" string.
  reverse_geocode = RateLimiter(geolocator.reverse, min_delay_seconds=1)
  location = reverse_geocode((latitude, longitude), language="en")
  if location is None:
    # Presumably a point with no nearby addressable feature (e.g. open sea).
    return f"Unknown location ({latitude:.4f}, {longitude:.4f})"
  return location.address


def get_plotly_graph(latitude: float, longitude: float, location: str) -> plotly.graph_objects.Figure:
  """Plot a single marker for the predicted location on a dark map.

  Args:
    latitude: Latitude of the marker.
    longitude: Longitude of the marker.
    location: Label shown when hovering over the marker.

  Returns:
    A 300 px tall Plotly map figure at zoom level 5, with no outer margins.
  """
  map_df = pd.DataFrame(
    [[latitude, longitude, location]], columns=["latitude", "longitude", "location"]
  )

  # Access tokens must not live in source control, so read the Mapbox token
  # from the environment. Without one, fall back to a dark basemap that needs
  # no token.
  # NOTE(review): the token previously committed in this notebook is public and
  # should be revoked.
  mapbox_token = os.environ.get("MAPBOX_ACCESS_TOKEN")
  if mapbox_token:
    px.set_mapbox_access_token(mapbox_token)
    mapbox_style = "dark"
  else:
    mapbox_style = "carto-darkmatter"

  fig = px.scatter_mapbox(map_df, lat="latitude", lon="longitude", hover_name="location",
                          color_discrete_sequence=["fuchsia"], zoom=5, height=300)
  fig.update_layout(mapbox_style=mapbox_style, margin={"r": 0, "t": 0, "l": 0, "b": 0})
  return fig


def select_dense_cluster_label(labels: np.ndarray) -> int:
  """Return the label of the most populated DBSCAN cluster.

  DBSCAN labels outliers ``-1``. Noise is skipped unless every point is noise,
  so a scatter of outliers can never outvote a genuine cluster. Ties go to the
  smallest label.

  Args:
    labels: Per-point cluster labels as returned by ``dbscan_clustering``.

  Returns:
    The chosen cluster label, or ``-1`` if there are no clustered points.
  """
  labels = np.asarray(labels)
  clustered = labels[labels != -1]
  if clustered.size == 0:
    return -1
  unique_labels, counts = np.unique(clustered, return_counts=True)
  return int(unique_labels[np.argmax(counts)])


def generate_prediction(image_dir: str) -> Tuple[str, "plotly.graph_objects.Figure"]:
  """Geolocate the images in ``image_dir``.

  The steps are:

  1. Run the GeoEstimation inference script on the directory.
  2. Cluster the per-image coordinate predictions with DBSCAN.
  3. Average the coordinates of the largest cluster.
  4. Look up a location name for the mean via ``get_location``.

  Args:
    image_dir: Directory holding the image(s) or video frames to locate.

  Returns:
    ``(location, figure)``: the location string and a Plotly map marking it.

  Raises:
    subprocess.CalledProcessError: If the inference script exits non-zero
      (its stderr is logged first). Failing here avoids reading a stale or
      missing inference CSV.
  """
  import sys

  # Resolve everything against the working directory. The checkpoint and
  # hparams are downloaded into models/base_M there, frames/images are written
  # relative to it, and data_engineering reads the CSV from models/base_M
  # relative to it.
  checkpoint_dir = "models/base_M"
  inference_cmd = [
    # Use the interpreter running this notebook, not whatever "python" is on PATH.
    sys.executable, "-m", "classification.inference",
    "--image_dir", os.path.abspath(image_dir),
    "--checkpoint", os.path.abspath(os.path.join(checkpoint_dir, "epoch=014-val_loss=18.4833.ckpt")),
    "--hparams", os.path.abspath(os.path.join(checkpoint_dir, "hparams.yaml")),
  ]
  result = subprocess.run(inference_cmd, capture_output=True, text=True)
  if result.returncode != 0:
    logging.error(f"Inference failed with exit code {result.returncode}:\n{result.stderr}")
    result.check_returncode()

  # data engineering
  lats_and_longs = data_engineering(image_dir=image_dir)
  labels, _ = dbscan_clustering(lats_and_longs=lats_and_longs)

  # find the dense cluster
  dense_cluster_label = select_dense_cluster_label(labels)
  logging.info(f"Dense cluster label: {dense_cluster_label}")

  # keep only the predictions belonging to that cluster
  dense_cluster_data = np.asarray(lats_and_longs, dtype=float)[np.asarray(labels) == dense_cluster_label]
  logging.info(f"Dense cluster data: {dense_cluster_data}")

  # mean latitude / longitude of the cluster
  latitude, longitude = dense_cluster_data.mean(axis=0)
  logging.info(f"Latitude: {latitude}, Longitude: {longitude}")

  location = get_location(latitude=latitude, longitude=longitude)
  return location, get_plotly_graph(latitude=latitude, longitude=longitude, location=location)

In [13]:
def create_image_dir(img_file: str) -> str:
  """Create a fresh folder under ``image_parent_dir`` holding a copy of ``img_file``.

  The folder is named after the file's basename up to its first ".". Any
  existing folder with that name is deleted first.

  Returns:
    The path of the created folder.
  """
  file_name = os.path.basename(img_file)
  target_dir = os.path.join(image_parent_dir, file_name.split(".")[0])

  # Start from an empty folder so earlier contents don't leak into this run.
  shutil.rmtree(target_dir, ignore_errors=True)
  os.makedirs(target_dir)
  shutil.copy(img_file, target_dir)
  return target_dir


def img_processor(img_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
  """Geolocate a single image: stage it in its own folder, then run the prediction pipeline."""
  staged_dir = create_image_dir(img_file=img_file)
  location_and_map = generate_prediction(image_dir=staged_dir)
  return location_and_map


def video_helper(video_file: str, info_dict: Dict[str, Any]) -> Tuple[str, plotly.graph_objects.Figure]:
  """Extract keyframes from ``video_file``, show them inline, and geolocate them."""
  frames_dir = capture_frames(video_file_path=video_file, info_dict=info_dict)
  display_video_frames(frames_directory=frames_dir)
  return generate_prediction(image_dir=frames_dir)


def video_processor(video_file: str) -> Tuple[str, plotly.graph_objects.Figure]:
  """Geolocate an uploaded video file.

  The basename up to its first "." is used as the video id. The id names the
  folder the keyframes are written to.
  """
  video_id, *_ = os.path.basename(video_file).split(".")
  return video_helper(video_file=video_file, info_dict={"id": video_id})


def url_processor(url: str) -> Tuple[str, plotly.graph_objects.Figure]:
  """Download the YouTube video at ``url`` and geolocate its keyframes."""
  downloaded_path, metadata = extract_youtube_video(url=url)
  return video_helper(video_file=downloaded_path, info_dict=metadata)

In [14]:
# validation
# url_processor(url="https://www.youtube.com/watch?v=ADt1LnbL2HI")
# img_processor("/content/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

# BentoML w/ ONNX

In [16]:
from classification.train_base import MultiPartitioningClassifier
from classification.dataset import FiveCropImageDataset
from math import ceil
import torch
from tqdm.auto import tqdm

# Load the pretrained GeoEstimation classifier from the checkpoint and hparams
# downloaded into models/base_M above.
# NOTE(review): map_location=None keeps tensors on the device they were saved
# from. If the checkpoint was saved on a GPU and none is available here,
# loading may fail; consider map_location="cpu". Confirm.
model = MultiPartitioningClassifier.load_from_checkpoint(
    checkpoint_path="models/base_M/epoch=014-val_loss=18.4833.ckpt",
    hparams_file="models/base_M/hparams.yaml",
    map_location=None,
)

# Download a sample image (-nc: don't re-download if present) and stage it in
# its own folder. It is only used to build an example input for the export.
!wget -nc https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
image_dir = create_image_dir(img_file="santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg")

# Reuse the project's FiveCrop dataset so the sample input is preprocessed
# exactly as at inference time; only the first batch is taken.
dataloader = torch.utils.data.DataLoader(
    FiveCropImageDataset(meta_csv=None, image_dir=image_dir),
    batch_size=1,
    shuffle=False,
    num_workers=0,
)

images, meta_batch = next(iter(dataloader))
# Axis 0 is the batch and axis 1 the crops; the remaining axes are presumably
# (C, H, W). TODO confirm against FiveCropImageDataset.
cur_batch_size = images.shape[0]
ncrops = images.shape[1]

# reshape crop dimension to batch
images = torch.reshape(images, (cur_batch_size * ncrops, *images.shape[2:]))

# Export the model's default forward() to ONNX (the override below is left disabled).
# model.forward = model.inference
# Only the input's first axis is declared dynamic, so any number of crops/images
# can be fed at run time.
# NOTE(review): output axes are not listed in dynamic_axes; confirm the outputs'
# first dimension is also dynamic in the exported graph.
model.to_onnx(
  "geolocator.onnx",
  input_sample=images,
  export_params=True,
  opset_version=11,
  input_names=["input"],
  dynamic_axes={
      "input": {0: "batch_size"},
  }
)



--2022-10-02 18:07:24--  https://thumbs.dreamstime.com/b/santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg
Resolving thumbs.dreamstime.com (thumbs.dreamstime.com)... 192.229.144.114
Connecting to thumbs.dreamstime.com (thumbs.dreamstime.com)|192.229.144.114|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 103072 (101K) [image/jpeg]
Saving to: ‘santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg’


2022-10-02 18:07:24 (4.68 MB/s) - ‘santorini-island-greece-santorini-island-greece-oia-town-traditional-white-houses-churches-blue-domes-over-caldera-146011399.jpg’ saved [103072/103072]



In [17]:
import onnxruntime

def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    The tensor is detached from autograd first. This is required when it needs
    grad and harmless when it doesn't.
    """
    return tensor.detach().cpu().numpy()

ort_session = onnxruntime.InferenceSession("geolocator.onnx")
# The model was exported with a single named input; feed the crop batch under that name.
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: to_numpy(images)}
# Passing None as the output names asks ONNX Runtime for every graph output, returned as a list.
ort_outs = ort_session.run(None, ort_inputs)
print(ort_outs)

[array([[ 0.90516365,  4.4899    , -2.804096  , ..., -5.473765  ,
         4.9408894 ,  1.3072193 ],
       [ 0.22281647,  1.1072205 , -3.549011  , ..., -4.007944  ,
        11.304101  , -0.20302217],
       [ 0.46286303,  4.413016  , -2.9810796 , ..., -5.889538  ,
         5.7547584 ,  0.85683393],
       [-0.06760192,  0.5161087 , -3.9409113 , ..., -3.7524354 ,
        10.674072  , -0.24188542],
       [ 1.5776274 ,  3.5012524 , -3.909003  , ..., -5.370172  ,
        10.660667  ,  0.135409  ]], dtype=float32), array([[ 1.153804  ,  4.026694  , -1.9779874 , ..., -1.13531   ,
         2.1075304 , -4.0744047 ],
       [-0.09019619,  1.675056  , -2.7516851 , ..., -0.61443245,
         0.59368336, -5.013054  ],
       [ 0.3058862 ,  3.6243377 , -2.1243274 , ..., -1.3385389 ,
         1.7169186 , -4.256903  ],
       [-0.14555266,  0.47077364, -3.1073918 , ..., -0.5177057 ,
         0.55621105, -4.7999754 ],
       [ 1.2242081 ,  4.621024  , -2.9370506 , ..., -1.1997478 ,
         1.083598

In [18]:
import onnx
import bentoml

# Register the exported ONNX graph in the local BentoML model store as "onnx_geolocator".
onnx_model = onnx.load("geolocator.onnx")
bentoml.onnx.save_model("onnx_geolocator", onnx_model)



Model(tag="onnx_geolocator:76lfzwccj2chdlg6", path="/Users/samhitaalla/bentoml/models/onnx_geolocator/76lfzwccj2chdlg6/")

In [None]:
# run bentoml service
# !bentoml serve service:svc --reload

# Gradio

The user interface is built with Gradio.

In [None]:
import gradio as gr


# Three tabs: image, video and YouTube link. Each tab's "Go locate!" button
# calls the matching processor. Its (location string, Plotly figure) return
# value fills the tab's Textbox and Plot.
with gr.Blocks() as demo:
  gr.Markdown("# GeoLocator")
  gr.Markdown("An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗.")
  with gr.Tab("Image"):
    with gr.Row():
      # type="filepath" hands img_processor a path on disk, which it copies into its own folder.
      img_input = gr.Image(type="filepath")
      with gr.Column():
        img_text_output = gr.Textbox(label="Location")
        img_plot = gr.Plot()
    img_text_button = gr.Button("Go locate!")
  with gr.Tab("Video"):
    with gr.Row():
      # video_processor expects a file path string (it calls os.path.basename on it).
      # NOTE(review): confirm the installed Gradio version accepts type= on gr.Video.
      video_input = gr.Video(type="filepath")
      with gr.Column():
        video_text_output = gr.Textbox(label="Location")
        video_plot = gr.Plot()
    video_text_button = gr.Button("Go locate!")
  with gr.Tab("YouTube Link"):
    with gr.Row():
      url_input = gr.Textbox(label="YouTube video link")
      with gr.Column():
        url_text_output = gr.Textbox(label="Location")
        url_plot = gr.Plot()
    url_text_button = gr.Button("Go locate!")

  img_text_button.click(img_processor, inputs=img_input, outputs=[img_text_output, img_plot])
  video_text_button.click(video_processor, inputs=video_input, outputs=[video_text_output, video_plot])
  url_text_button.click(url_processor, inputs=url_input, outputs=[url_text_output, url_plot])

  # Clickable example that fills url_input; the returned object isn't used further.
  examples = gr.Examples(examples=["https://www.youtube.com/watch?v=wxeQkJTZrsw"], inputs=[url_input])

# Start the Gradio app.
demo.launch()