<a href="https://colab.research.google.com/github/samhita-alla/geolocator/blob/main/geolocator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Video Geolocator

Uses a [pre-trained GeoEstimation model](https://github.com/TIBHannover/GeoEstimation) to run inference on videos.

* Get the video path. If it is a YouTube link, download the video first.
* Retrieve video frames.
* Perform model inferencing on every video frame.
* Apply DBSCAN clustering on the predicted lats and longs.
* Select the densest cluster.
* Compute the mean latitude and longitude of the data points belonging to the densest cluster.
* Predict location.

## Library dependencies

* Katna
* Youtube DL
* Scikit Learn
* PyTorch Lightning
* s2sphere
* Geopy


In [2]:
# install dependencies
!pip install -q katna youtube_dl pytorch-lightning s2sphere scikit-learn

In [3]:
from __future__ import unicode_literals
import youtube_dl
from pathlib import Path
from Katna.config import Video as VideoConfig
from Katna.writer import KeyFrameDiskWriter
import os
from typing import Dict, Any, Tuple

# Upper bound on the size of the video to download: compared against each
# format's "filesize" and passed to youtube_dl as "max_filesize"
# (presumably in bytes -- confirm against the youtube_dl documentation).
MAX_FILESIZE = 10000000

def sort_key(key):
  """Return the sort key for a format dict: its file size as an int.

  Args:
    key: a format dictionary; its "filesize" entry may be missing or None.

  Returns:
    The "filesize" value converted to int, or 0 when it is missing or falsy,
    so that formats of unknown size sort as the smallest.
  """
  # .get() instead of indexing: a format without a "filesize" entry must not
  # abort the whole sort with a KeyError.
  file_size = key.get("filesize")
  return int(file_size) if file_size else 0

def validate_extension(selected_format: Dict[str, Any]) -> str:
  """Check that the selected format's extension is one Katna can read.

  Args:
    selected_format: a format dictionary; its "ext" entry is inspected.

  Returns:
    The value of "ext" unchanged. A missing or empty extension is returned
    as-is (None or "") without being validated.

  Raises:
    ValueError: if the extension is not listed in VideoConfig.video_extensions.
  """
  extension = selected_format.get("ext")
  if not extension:
    return extension

  # VideoConfig lists extensions with a leading dot; strip the dots to compare.
  supported_extensions = {ext.replace(".", "") for ext in VideoConfig.video_extensions}
  if extension not in supported_extensions:
    # Raising a bare string is itself a TypeError; raise a real exception.
    raise ValueError(f"{extension} isn't supported.")
  return extension

def extract_youtube_video(url: str) -> Tuple[str, Dict[str, Any]]:
  """Download a YouTube video in the largest suitable format under MAX_FILESIZE.

  Args:
    url: link to the video.

  Returns:
    A tuple of (path of the downloaded file, info dictionary returned by
    youtube_dl's extract_info).

  Raises:
    ValueError: if no non-webm video format with a known size below
      MAX_FILESIZE exists, or (via validate_extension) if the selected
      extension is not supported.
  """
  # extract information about the video without downloading it
  with youtube_dl.YoutubeDL({}) as ydl:
    info_dict = ydl.extract_info(url, download=False)
  formats = info_dict.get("formats") or []

  # drop "webm" formatted videos and sort the rest in descending order
  # w.r.t. the file size
  candidate_formats = sorted(
      (fmt for fmt in formats if fmt.get("ext") != "webm"),
      key=sort_key,
      reverse=True,
  )

  # select the best format -- the first (largest) one that has a known size
  # below MAX_FILESIZE and actually carries a video stream
  selected_format = next(
      (
          fmt
          for fmt in candidate_formats
          if fmt.get("filesize")
          and fmt["filesize"] < MAX_FILESIZE
          and fmt.get("vcodec") != "none"
      ),
      None,
  )
  if selected_format is None:
    # Without this guard the download would run with format=None and the
    # returned path would end in ".None".
    raise ValueError(
        f"No downloadable video format smaller than {MAX_FILESIZE} bytes was found for {url}."
    )

  # verify if the extension is valid
  extension = validate_extension(selected_format)

  # extract YT video
  videos_path = "videos"
  ydl_opts = {
      "max_filesize": MAX_FILESIZE,
      "format": selected_format.get("format_id"),
      "outtmpl": f"{videos_path}/%(id)s.%(ext)s",
  }

  with youtube_dl.YoutubeDL(ydl_opts) as ydl:
    ydl.cache.remove()
    ydl.download([url])

  # mirrors the "outtmpl" pattern above
  saved_location = f"{videos_path}/{info_dict['id']}.{extension}"
  return saved_location, info_dict

In [4]:
from Katna.video import Video

# Number of key frames requested per video; passed to Katna's
# extract_video_keyframes as no_of_frames.
NUMBER_OF_FRAMES = 20

def capture_frames(video_file_path: str, info_dict: Dict[str, Any]) -> str:
  """Extract key frames from a video into a fresh per-video directory.

  Args:
    video_file_path: path of the video file to read.
    info_dict: dictionary whose "id" entry names the output directory.

  Returns:
    The directory ("selected-frames/<id>") the frames were written to.

  Raises:
    RuntimeError: if Katna fails while extracting the key frames.
  """
  # Imported here because the notebook only imports shutil in a later cell.
  import shutil

  # create a clean directory to store video frames
  frames_directory = f"selected-frames/{info_dict['id']}"
  shutil.rmtree(frames_directory, ignore_errors=True)
  os.makedirs(frames_directory, exist_ok=True)
  diskwriter = KeyFrameDiskWriter(location=frames_directory)

  try:
    Video().extract_video_keyframes(
        no_of_frames=NUMBER_OF_FRAMES,
        file_path=video_file_path,
        writer=diskwriter,
    )
  except Exception as e:
    # Raising a bare string is itself a TypeError; raise a real exception
    # and keep the original one as its cause.
    raise RuntimeError(f"Error capturing the frames: {e}") from e

  return frames_directory

In [5]:
import shutil
import glob
from IPython.display import Image, display

# Module-level placeholder; the functions visible in this notebook use their
# own local/parameter `image_dir` and do not read this global.
image_dir = None
# Parent directory under which img_processor creates one sub-directory per image.
image_parent_dir = "geolocator-images"

def display_video_frames(frames_directory: str):
  """Show every ".jpeg" file found directly inside `frames_directory`.

  Args:
    frames_directory: directory that holds the extracted frames.
  """
  jpeg_pattern = f"{frames_directory}/*.jpeg"
  for frame_path in glob.glob(jpeg_pattern):
    thumbnail = Image(filename=frame_path, width=200, height=100)
    display(thumbnail)

In [6]:
!git clone https://github.com/yiyixuxu/GeoEstimation.git

Cloning into 'GeoEstimation'...
remote: Enumerating objects: 451, done.
remote: Counting objects: 100% (123/123), done.
remote: Compressing objects: 100% (72/72), done.
remote: Total 451 (delta 63), reused 98 (delta 46), pack-reused 328
Receiving objects: 100% (451/451), 1.94 MiB | 5.27 MiB/s, done.
Resolving deltas: 100% (236/236), done.


In [7]:
# download the model checkpoint & hyperparameters
!mkdir -p models/base_M
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt -O models/base_M/epoch=014-val_loss=18.4833.ckpt
!wget https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/hparams.yaml -O models/base_M/hparams.yaml

--2022-09-24 12:15:53--  https://github.com/TIBHannover/GeoEstimation/releases/download/pytorch/epoch.014-val_loss.18.4833.ckpt
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc162380-3e05-11eb-9190-3ec4e4ff49c1?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220924%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220924T121553Z&X-Amz-Expires=300&X-Amz-Signature=52a1015ac6178895715c836f02f1dff0c145a9c52e9a5903a32262e4db5da2c3&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=142275851&response-content-disposition=attachment%3B%20filename%3Depoch.014-val_loss.18.4833.ckpt&response-content-type=application%2Foctet-stream [following]
--2022-09-24 12:15:53--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/142275851/fc1

In [8]:
import os

import yaml

# Point the partitioning files listed in hparams.yaml at the /content directory.
hparams_path = "/content/models/base_M/hparams.yaml"

with open(hparams_path) as f:
  list_doc = yaml.safe_load(f)

# os.path.join leaves an already-absolute path untouched, so re-running this
# cell does not prepend "/content/" a second time.
list_doc["partitionings"]["files"] = [
    os.path.join("/content", partition_file)
    for partition_file in list_doc["partitionings"]["files"]
]

with open(hparams_path, "w") as f:
  yaml.dump(list_doc, f)

In [9]:
!mkdir -p resources/s2_cells
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv -O resources/s2_cells/cells_50_5000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv -O resources/s2_cells/cells_50_2000.csv
!wget https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_1000.csv -O resources/s2_cells/cells_50_1000.csv

--2022-09-24 12:17:05--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_5000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 177214 (173K) [text/plain]
Saving to: ‘resources/s2_cells/cells_50_5000.csv’


2022-09-24 12:17:05 (10.4 MB/s) - ‘resources/s2_cells/cells_50_5000.csv’ saved [177214/177214]

--2022-09-24 12:17:05--  https://raw.githubusercontent.com/TIBHannover/GeoEstimation/original_tf/geo-cells/cells_50_2000.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 389388 (380K) [

In [18]:
from IPython.core.profiledir import LoggingConfigurable
import subprocess
import pandas as pd
from pathlib import Path
import logging
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
from typing import List
import matplotlib.pyplot as plt
import numpy as np
from geopy.geocoders import Nominatim


def dbscan_clustering(lats_and_longs: List[List[float]]) -> Tuple[np.ndarray, np.ndarray]:
  """Cluster predicted coordinates with DBSCAN after standardizing them.

  Why DBSCAN? -- robust to outliers; needn't specify the number of clusters;
  density-based.

  Args:
    lats_and_longs: one [latitude, longitude] pair per prediction.

  Returns:
    A tuple of (cluster label per input row, standardized coordinates).
    Both are arrays, not nested lists. scikit-learn labels points it treats
    as noise with -1.
  """
  lats_and_longs_standardized = StandardScaler().fit_transform(lats_and_longs)
  labels = DBSCAN(eps=0.5, min_samples=3).fit(lats_and_longs_standardized).labels_
  logging.info(f"DBSCAN cluster labels: {labels}")

  return labels, lats_and_longs_standardized


def plot_clusters(labels, lats_and_longs_standardized):
  """Scatter-plot the standardized coordinates, colored by DBSCAN label.

  Not called anywhere in this notebook at the moment.

  Args:
    labels: cluster label per point, used as the color value.
    lats_and_longs_standardized: 2-D array; column 0 is plotted on the x axis
      and column 1 on the y axis.
  """
  x_values = lats_and_longs_standardized[:, 0]
  y_values = lats_and_longs_standardized[:, 1]
  plt.scatter(x_values, y_values, c=labels, cmap="Paired")


def data_engineering(image_dir: str) -> List[List[float]]:
  """Load the inference CSV for `image_dir` and return its predicted coordinates.

  The CSV is read from "models/base_M/inference_<name>.csv", where <name> is
  the stem of the last component of `image_dir`.

  Args:
    image_dir: directory of the images the inference was run on.

  Returns:
    One [pred_lat, pred_lng] pair per row of the CSV.
  """
  frames_stem = Path(os.path.join('/content', image_dir)).stem
  inference_df = pd.read_csv(os.path.join("models/base_M", f"inference_{frames_stem}.csv"))
  logging.info(f"Inference DF: {inference_df.head()}")

  # keep only the two prediction columns
  inference_df_lats_longs = inference_df[["pred_lat", "pred_lng"]]
  logging.info(f"Inference Lats & Longs only DF: {inference_df_lats_longs}")

  return inference_df_lats_longs.to_numpy().tolist()


def get_location(latitude: float, longitude: float) -> str:
  """Look up the place at the given coordinates through Nominatim.

  Args:
    latitude: latitude of the point.
    longitude: longitude of the point.

  Returns:
    Whatever Nominatim's geocode returns for the query "<latitude>,<longitude>".
    NOTE(review): annotated as str, but geopy presumably returns a Location
    object (or None when nothing is found) -- confirm.
  """
  query = f"{latitude},{longitude}"
  return Nominatim(user_agent="geolocater").geocode(query)


def generate_prediction(image_dir: str) -> str:
  """Run GeoEstimation inference on a directory of images and locate it.

  Steps: run the inference script, load the per-image predictions, cluster
  them with DBSCAN, average the coordinates of the largest cluster, and look
  the averaged point up with get_location.

  Args:
    image_dir: directory (relative to /content) holding the images.

  Returns:
    The result of get_location for the averaged latitude and longitude.

  Raises:
    RuntimeError: if the inference subprocess exits with a non-zero code.
  """
  # generate predictions on all the video frames; `cwd` runs the script from
  # the GeoEstimation checkout without changing this process's directory
  inference = subprocess.run(
      [
          "python", "-m", "classification.inference",
          "--image_dir", os.path.join("/content", image_dir),
          "--checkpoint", "/content/models/base_M/epoch=014-val_loss=18.4833.ckpt",
          "--hparams", "/content/models/base_M/hparams.yaml",
      ],
      cwd="GeoEstimation",
      capture_output=True,
      text=True,
  )
  if inference.returncode != 0:
    # Fail loudly instead of reading a missing or stale predictions file.
    raise RuntimeError(
        f"GeoEstimation inference failed (exit code {inference.returncode}): {inference.stderr}"
    )

  # data engineering
  lats_and_longs = data_engineering(image_dir=image_dir)

  labels, _ = dbscan_clustering(lats_and_longs=lats_and_longs)
  labels = np.asarray(labels)

  # find the dense cluster; -1 is DBSCAN's noise label, so it only wins when
  # every point is noise (e.g. fewer than three frames), in which case all
  # points are averaged
  cluster_ids, cluster_sizes = np.unique(labels[labels != -1], return_counts=True)
  if cluster_ids.size:
    dense_cluster_label = cluster_ids[np.argmax(cluster_sizes)]
  else:
    dense_cluster_label = -1
  logging.info(f"Dense cluster label: {dense_cluster_label}")

  # get data points belonging to the dense cluster
  dense_cluster_data = np.asarray(lats_and_longs, dtype=float)[labels == dense_cluster_label]
  logging.info(f"Dense cluster data: {dense_cluster_data.tolist()}")

  # fetch lat and long mean
  latitude, longitude = dense_cluster_data.mean(axis=0)
  logging.info(f"Latitude: {latitude}, Longitutde: {longitude}")

  # get location
  return get_location(latitude=latitude, longitude=longitude)

In [32]:
def img_processor(img_file: str) -> str:
  """Predict the location of a single image.

  The image is copied into its own directory under `image_parent_dir`, named
  after the file name up to its first dot, and that directory is passed to
  generate_prediction.

  Args:
    img_file: path of the image file.

  Returns:
    The result of generate_prediction for the image's directory.
  """
  image_name = os.path.basename(img_file).partition(".")[0]
  image_dir = os.path.join(image_parent_dir, image_name)

  # start from an empty directory so leftovers from earlier runs are dropped
  shutil.rmtree(image_dir, ignore_errors=True)
  os.makedirs(image_dir)
  shutil.copy(img_file, image_dir)

  return generate_prediction(image_dir=image_dir)


def video_helper(video_file: str, info_dict: Dict[str, Any]) -> str:
  """Extract frames from a video, display them, and predict the location.

  Args:
    video_file: path of the video file.
    info_dict: dictionary forwarded to capture_frames.

  Returns:
    The result of generate_prediction for the directory of captured frames.
  """
  # capture the frames and show them before running the prediction
  frames_directory = capture_frames(video_file_path=video_file, info_dict=info_dict)
  display_video_frames(frames_directory=frames_directory)

  return generate_prediction(image_dir=frames_directory)


def video_processor(video_file: str) -> str:
  """Predict the location of a local video file.

  Args:
    video_file: path of the video file; its base name up to the first dot is
      used as the video id.

  Returns:
    The result of video_helper for the file.
  """
  video_id = os.path.basename(video_file).partition(".")[0]
  return video_helper(video_file=video_file, info_dict={"id": video_id})


def url_processor(url: str) -> str:
  """Predict the location of a YouTube video given its link.

  Args:
    url: link to the video; it is downloaded with extract_youtube_video.

  Returns:
    The result of video_helper for the downloaded file.
  """
  downloaded_path, video_info = extract_youtube_video(url=url)
  return video_helper(video_file=downloaded_path, info_dict=video_info)

In [None]:
# validation
# url_processor(url="https://www.youtube.com/watch?v=ADt1LnbL2HI")

In [23]:
# gradio
!pip install -q gradio

In [36]:
import gradio as gr


# Build the Gradio UI: one tab per kind of input (image, video file, YouTube
# link). Each tab has an input widget, a "Location" text box and a button.
with gr.Blocks() as demo:
  gr.Markdown("# GeoLocator")
  gr.Markdown("An app that guesses the location of an image 🌌, a video 📹 or a YouTube link 🔗.")
  with gr.Tab("Image"):
    with gr.Row():
      img_input = gr.Image(type="filepath")
      img_text_output = gr.Textbox(label="Location")
    img_text_button = gr.Button("Go locate!")
  with gr.Tab("Video"):
    with gr.Row():
      video_input = gr.Video(type="filepath")
      video_text_output = gr.Textbox(label="Location")
    video_text_button = gr.Button("Go locate!")
  with gr.Tab("YouTube Link"):
    with gr.Row():
      url_input = gr.Textbox(label="YouTube video link")
      url_text_output = gr.Textbox(label="Location")
    url_text_button = gr.Button("Go locate!")

  # Wire each button to its processor; the processor's return value is shown
  # in the tab's "Location" text box.
  img_text_button.click(img_processor, inputs=img_input, outputs=img_text_output)
  video_text_button.click(video_processor, inputs=video_input, outputs=video_text_output)
  url_text_button.click(url_processor, inputs=url_input, outputs=url_text_output)

  # Sample YouTube link offered as an example for the URL text box.
  examples = gr.Examples(examples=["https://www.youtube.com/watch?v=wxeQkJTZrsw"], inputs=[url_input])

# Start the app.
demo.launch()



Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`
Running on public URL: https://21030.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


(<gradio.routes.App at 0x7efc66b52f50>,
 'http://127.0.0.1:7863/',
 'https://21030.gradio.app')