In [None]:
# Run the notebook from two directories up: later cells import the `emv`
# package and open "emv/features/cities.json" relative to the working directory.
%cd ../..
# Enable the autoreload extension so edits to imported modules are picked up.
%load_ext autoreload

# Mode 2: reload all modules before each cell execution.
%autoreload 2

In [None]:
from emv.db.dao import DataAccessObject
from emv.db.queries import get_features_by_type_paginated, count_features_by_type
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from ast import literal_eval
import numpy as np
from emv.api.models import Feature
from emv.api.models import Projection, MapProjectionFeatureCreate
from emv.db.queries import create_projection, create_map_projection_feature, create_feature
from emv.io.media import create_square_atlases
from umap import UMAP
import numba
import cv2
from PIL import Image

from emv.db.queries import get_all_media_by_library_id, get_library_id_from_name, get_library_from_name, check_media_exists
from emv.storage.storage import get_storage_client
from emv.features.image import embed_images

In [None]:
# Page through the "transcript+ner" features (short clips only), using the
# feature_id of the last row fetched as the cursor for the next page.
# Upper bound on rows to fetch; alternative:
# count_features_by_type("transcript+ner", short_clips_only=True) + 1
MAX_FEATURES = 100000
data = get_features_by_type_paginated("transcript+ner", page_size=10000, short_clips_only=True)

# `0 < len(data)`: an empty first page would make data[-1] raise IndexError.
while 0 < len(data) < MAX_FEATURES:
    last_seen_id = data[-1].get("feature_id", None)
    if last_seen_id is None:
        break
    next_page = get_features_by_type_paginated("transcript+ner", page_size=10000, last_seen_feature_id=last_seen_id, short_clips_only=True)
    if len(next_page) == 0:
        # Nothing left to fetch. Without this check the cursor never advances
        # and the loop would request the same empty page forever.
        break
    data.extend(next_page)

df = pd.DataFrame(data)
print(f"Retrieved {len(df)} instances")

In [None]:
# For every feature keep element [0] of each entity whose element [1] is the
# "LOC" label (presumably the entity text and its NER tag).
df["locations"] = df["data"].map(
    lambda feature_data: [
        entity[0] for entity in feature_data["entities"] if entity[1] == "LOC"
    ]
)

In [None]:
# Manual matching: `cities` maps a place name to its coordinates as loaded
# from the JSON file.
with open("emv/features/cities.json", "r") as f:
    cities = json.load(f)

# Only entries with exactly two coordinate values are usable.
# NOTE(review): the "lon"/"lat" labels below assume the JSON stores [lon, lat],
# but the map check cell plots element [1] on the x axis next to reference
# points given as (longitude, latitude) - confirm the actual order. Only the
# "locations" column of this frame is used by the following cells.
locations = pd.DataFrame(
    [
        {"locations": name, "lon": float(coords[0]), "lat": float(coords[1])}
        for name, coords in cities.items()
        if len(coords) == 2
    ]
)

In [None]:
# Names for which coordinates are available (kept as an array for later cells).
found_locations = locations.locations.values
# Set copy for O(1) membership tests; checking against the numpy array scans
# every known name for every extracted entity.
known_locations = set(found_locations)

# Keep only the features that mention at least one known location.
df = df[df.locations.map(lambda names: any(name in known_locations for name in names))]
print(f"Filtered to {len(df)} instances")

In [None]:
# Work on an explicit copy: assigning columns to a column-subset of `df`
# otherwise triggers pandas' SettingWithCopyWarning.
df = df[["data", "feature_id", "media_id", "locations"]].copy()
# Keep each known location once per feature. dict.fromkeys de-duplicates while
# preserving first-seen order, so the result is the same on every run
# (list(set(...)) order depends on string hashing).
df["locations"] = df["locations"].map(lambda names: list(dict.fromkeys(name for name in names if name in found_locations)))
# Coordinates in the same order as the names, so both columns explode together.
df["geo_coords"] = df["locations"].map(lambda names: [cities[name] for name in names])
# One row per (feature, location) pair.
df = df.explode(["locations", "geo_coords"])

In [None]:
# Preview the exploded frame instead of rendering the whole DataFrame.
df.head()

## Get thumbnails and create atlases

In [None]:
# Page through the screenshot images of the "rts" library, using the
# (media_id, created_at) of the last row fetched as the cursor.
lib_id = get_library_id_from_name("rts")

max_medias = 500000
thumbnails = get_all_media_by_library_id(lib_id, media_type="image", sub_type="screenshot", page_size=1000)

# `0 < len(thumbnails)`: an empty first page would make thumbnails[-1] raise
# IndexError before the emptiness check inside the loop is reached.
while 0 < len(thumbnails) < max_medias:
    new_medias = get_all_media_by_library_id(lib_id,
                                             media_type="image",
                                             sub_type="screenshot",
                                             page_size=1000,
                                             last_seen_media_id=thumbnails[-1]["media_id"],
                                             last_seen_date=thumbnails[-1]["created_at"])
    if len(new_medias) == 0:
        break
    thumbnails.extend(new_medias)

thumbnails = pd.DataFrame(thumbnails)
print(f"Found {len(thumbnails)} images")

In [None]:
# Collapse to one row per parent_id, keeping the first media_path of each group.
thumbnails = (
    thumbnails[["parent_id", "media_path"]]
    .groupby("parent_id")
    .agg(list)
    .reset_index()
    .assign(media_path=lambda grouped: grouped["media_path"].map(lambda paths: paths[0]))
)
thumbnails.head()

In [None]:
# Attach the thumbnail path of each clip (media_id matches the screenshot's
# parent_id) and discard the rows for which no screenshot was found.
df = (
    df.merge(thumbnails, how="left", left_on="media_id", right_on="parent_id")
    .drop(columns=["parent_id"])
    .dropna(subset=["media_path"])
)
df.shape

In [None]:
storage_client = get_storage_client()

def get_thumbnail(media_path):
    """Download and decode one thumbnail from the "rts" bucket.

    Args:
        media_path: Path of the image, passed to ``storage_client.get_bytes``.

    Returns:
        The decoded image as an array converted with ``COLOR_BGR2RGB``, or
        ``None`` when the storage client returns no bytes, the payload is
        empty, or it cannot be decoded as an image.
    """
    frame_bytes = storage_client.get_bytes("rts", media_path)
    if not isinstance(frame_bytes, bytes) or len(frame_bytes) == 0:
        return None

    # Flag -1 (IMREAD_UNCHANGED) keeps the stored channels as they are.
    frame = cv2.imdecode(np.frombuffer(frame_bytes, np.uint8), -1)
    if frame is None:
        # imdecode signals undecodable data by returning None; cvtColor would
        # raise on it, so report the thumbnail as missing instead.
        return None

    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

# Fetch the decoded image for every row; rows whose image is unavailable get None.
df["thumbnail"] = df["media_path"].map(get_thumbnail)

In [None]:
# Keep only the rows whose thumbnail could be loaded.
df = df[df["thumbnail"].notna()]

In [None]:
# Check map
# geo_coords[1] goes on the x axis and geo_coords[0] on the y axis. The two
# reference markers are given as (longitude, latitude), so the point cloud
# should surround them; if it does not, the coordinate order is swapped.
plt.scatter(df["geo_coords"].map(lambda x: float(x[1])), df["geo_coords"].map(lambda x: float(x[0])), s=1, marker="o", label="Matched locations")
plt.scatter(6.15, 46.2, s=100, marker="x", color="red", label="Geneva")
plt.scatter(8.5, 47.4, s=100, marker="x", color="green", label="Zurich")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.title("Locations matched in transcripts")
plt.legend()
plt.show()

# Create Features, Atlases and Projection

## Features

In [None]:
# Preview the rows that the next cell turns into "locations" features.
df.head()

In [None]:
def _create_location_feature(row):
    """Store one 'locations' feature for a (clip, location) row.

    Returns whatever ``create_feature`` returns for the new feature.
    """
    feature = Feature(
        feature_type='locations',
        version="1",
        model_name='transcript+ner+geolocation',
        model_params={},
        data={
            "location": row["locations"],
            "geo_coords": row["geo_coords"],
            "media_path": row["media_path"],
        },
        media_id=row['media_id'],
    )
    return create_feature(feature)


# The column temporarily holds the full create_feature result; the id itself
# is extracted in the following cell.
df["feature_id"] = df.apply(_create_location_feature, axis=1)

In [None]:
# Replace each stored create_feature result by its "feature_id" entry.
df["feature_id"] = df["feature_id"].map(lambda created: created["feature_id"])

## Atlases

**Note**: the same clip can mention multiple locations. Since the map has one point per (clip, location) pair, the same clip can appear multiple times in `df`.
The atlases only need each thumbnail once, so the clips are deduplicated by `media_id` before the atlases are built.

In [None]:
# One row per clip: group the (clip, location) rows by media_id and keep the
# first media_path and thumbnail of each group.
clips = (
    df[["media_id", "media_path", "thumbnail"]]
    .groupby("media_id")
    .agg(list)
    .reset_index()
    .assign(
        media_path=lambda grouped: grouped["media_path"].map(lambda paths: paths[0]),
        thumbnail=lambda grouped: grouped["thumbnail"].map(lambda frames: frames[0]),
    )
)
clips.head(2)

In [None]:
total_tiles = len(clips) # either all features or a subset of features
atlas_width = 4096
max_tile_size = 512
# Tiles form a square grid of (atlas_width // max_tile_size) tiles per side.
max_tiles_per_atlas = (atlas_width // max_tile_size) ** 2
# Ceiling division: int(total / per_atlas) + 1 reports one atlas too many
# whenever total_tiles is an exact multiple of max_tiles_per_atlas.
atlas_count = (total_tiles + max_tiles_per_atlas - 1) // max_tiles_per_atlas

In [None]:
# Create the projection, replace the names with the desired ones
projection = Projection(
    # What is projected and for which library.
    projection_name="RTS locations",
    library_id=get_library_id_from_name("rts"),
    version="1",
    # How the coordinates were obtained.
    model_name="whisper+spacy",
    model_params={},
    data={},
    dimension=3,
    # Atlas layout computed in the previous cell.
    atlas_folder_path="",
    atlas_count=atlas_count,
    atlas_width=atlas_width,
    tile_size=max_tile_size,
    tiles_per_atlas=max_tiles_per_atlas,
    total_tiles=total_tiles,
)

# Store the projection and keep its id for the atlases and map points.
created_projection = create_projection(projection)
projection_id = created_projection['projection_id']
print(f"Projection ID: {projection_id}")

In [None]:
# The thumbnails were already converted with COLOR_BGR2RGB when they were
# decoded, so they are handed to PIL unchanged. Applying the conversion a
# second time swaps the red and blue channels back and produces atlases with
# wrong colours.
images = clips.thumbnail.values.tolist()
images = [Image.fromarray(img) for img in images]

In [None]:
# Build the atlases for this projection from the de-duplicated clip thumbnails.
square_atlases = create_square_atlases(
    images=images,
    atlas_name="atlas_rts_locations",
    projection_id=projection_id,
    width=atlas_width,
    max_tile_size=max_tile_size,
    no_border=True,
)

In [None]:
# Derive each clip's atlas number and slot from its row position in `clips`.
# NOTE(review): assumes create_square_atlases fills the atlases in the order
# of the `images` list - confirm against its implementation.
clips = clips.assign(
    atlas_order=clips.index // max_tiles_per_atlas,
    index_in_atlas=clips.index % max_tiles_per_atlas,
)

# Merge with locations df
df = df.merge(
    clips[["media_id", "atlas_order", "index_in_atlas"]],
    how="left",
    on="media_id",
)

## Projection

In [None]:
# Store one map point per (clip, location) row, linked to its feature and to
# the clip's tile in the atlases.
for i, row in df.iterrows():
    map_point = MapProjectionFeatureCreate(
        projection_id=projection_id,
        feature_id=row.feature_id,
        media_id=row.media_id,
        atlas_order=row.atlas_order,
        index_in_atlas=row.index_in_atlas,
        # Third coordinate is fixed to 0.
        coordinates=[row.geo_coords[0], row.geo_coords[1], 0],
    )
    create_map_projection_feature(map_point)