In [None]:
# Move the working directory two levels up (presumably to the repository root so
# the project-local `emv` package is importable — confirm against the repo layout).
%cd ../..
# Load IPython's autoreload extension.
%load_ext autoreload

# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [70]:
import os
import random
import textwrap as tw
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
from ast import literal_eval
from scipy.spatial.distance import pdist, squareform
import numba
from emv.db.dao import DataAccessObject
from sqlalchemy.sql import text

from umap import UMAP
from emv.embeddings.dr_eval import plot_embeddings_with_images, normalize_embedding

from emv.db.queries import get_all_media_by_library_id, get_library_id_from_name, get_library_from_name
from emv.storage.storage import get_storage_client
from emv.features.image import embed_images

# Load images from DB

In [None]:
lib_id = get_library_id_from_name("rts")

max_medias = 10000  # stop paging once at least this many rows have been collected
page_size = 100     # rows requested per call

# First page; follow-up requests pass the id and creation date of the last row
# fetched so far as the cursor.
medias = get_all_media_by_library_id(lib_id, media_type="image", page_size=page_size)

# The `0 < len(medias)` guard covers an empty library, where `medias[-1]`
# would otherwise raise IndexError.
while 0 < len(medias) < max_medias:
    last_media = medias[-1]
    new_medias = get_all_media_by_library_id(
        lib_id,
        media_type="image",
        page_size=page_size,
        last_seen_media_id=last_media["media_id"],
        last_seen_date=last_media["created_at"],
    )
    if len(new_medias) == 0:
        # An empty page means there is nothing left to fetch.
        break
    medias.extend(new_medias)

medias = pd.DataFrame(medias)
print(f"Found {len(medias)} images")

In [120]:
storage_client = get_storage_client()

def get_thumbnail(media_path):
    """Download one image via ``storage_client.get_bytes("rts", ...)`` and decode it to RGB.

    Args:
        media_path: Path passed to the storage client as the object to fetch.

    Returns:
        The decoded image as an array with channels reordered to RGB, or
        ``None`` when the storage client does not return bytes or OpenCV
        cannot decode the payload.
    """
    frame_bytes = storage_client.get_bytes("rts", media_path)
    if not isinstance(frame_bytes, (bytes, bytearray)):
        # Without this guard `frame` would be unbound at the return statement.
        return None

    # Flag -1 is cv2.IMREAD_UNCHANGED: the image is decoded with its stored channel layout.
    frame = cv2.imdecode(np.frombuffer(frame_bytes, np.uint8), -1)
    if frame is None:
        # imdecode signals an undecodable buffer by returning None.
        return None
    if frame.ndim == 2:
        # Single-channel images cannot go through BGR2RGB; expand them instead.
        return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

medias["thumbnail"] = medias.media_path.map(get_thumbnail)

# Rows without a usable thumbnail cannot be embedded or placed in an atlas.
# Drop them and renumber the index so row labels stay equal to row positions.
n_missing = int(medias["thumbnail"].isna().sum())
if n_missing:
    print(f"Skipping {n_missing} media without a decodable thumbnail")
    medias = medias[medias["thumbnail"].notna()].reset_index(drop=True)

In [121]:
def batch(iterable, n=1):
    """Yield consecutive slices of ``iterable`` holding at most ``n`` items each.

    ``iterable`` must support ``len()`` and slicing. The final slice is shorter
    when the length is not a multiple of ``n``.
    """
    size = len(iterable)
    for start in range(0, size, n):
        stop = start + n
        if stop > size:
            stop = size
        yield iterable[start:stop]
        
# Embed the thumbnails 500 at a time and stack the per-chunk results into one array.
thumbnail_list = medias.thumbnail.values.tolist()
features = np.concatenate(
    [embed_images(chunk) for chunk in batch(thumbnail_list, 500)]
)

In [None]:
# 2-D UMAP projection of the image features; the fixed random_state seeds the run.
umap_2d_params = dict(
    n_components=2,
    n_neighbors=100,
    min_dist=0.1,
    metric="euclidean",
    random_state=42,
)
reducer = UMAP(**umap_2d_params)
embedding = reducer.fit_transform(features)

In [123]:
# Pass the 2-D coordinates through the project's normalize_embedding helper
# (presumably rescales them for plotting — confirm in emv.embeddings.dr_eval).
# NOTE: this overwrites `embedding`, so re-running the cell feeds the already
# normalized values back into the helper.
embedding = normalize_embedding(embedding)

In [125]:
# Shrink every thumbnail to 64x64 pixels for the embedding plot.
thumbnails = [
    cv2.resize(image, (64, 64))
    for image in medias.thumbnail.values
]

In [None]:
# Plot the normalized 2-D embedding together with the 64x64 thumbnails
# (presumably each image is drawn at its embedding position — see
# plot_embeddings_with_images in emv.embeddings.dr_eval).
plot_embeddings_with_images(embedding, thumbnails, zoom = 0.07, figsize=16)

# Create new Projection based on the visual features

In [143]:
from emv.api.models import Feature
from emv.api.models import Projection, MapProjectionFeatureCreate
from emv.db.queries import create_projection, create_map_projection_feature, create_feature
from emv.io.media import create_square_atlases

## Save features to DB

In [None]:
# Keep each raw feature vector next to its media row (one list of floats per row).
medias["embedding"] = features.tolist()

# The Feature record stores the vector in the fixed-width `embedding_2048` field,
# so a vector of any other length is rejected before anything is written.
embedding_size = 2048

# Explicit loop instead of DataFrame.apply: every iteration calls create_feature
# (a DB write), which should not be hidden inside a dataframe transformation.
# NOTE(review): re-running this cell calls create_feature again for every row.
feature_ids = []
for media_id, vector in zip(medias["media_id"], medias["embedding"]):
    if len(vector) != embedding_size:
        raise ValueError(
            f"Expected {embedding_size}-dimensional embeddings, got {len(vector)} for media {media_id}"
        )
    created = create_feature(
        Feature(
            feature_type='rts_visual',
            version="1",
            model_name='resnet50',
            model_params={},
            data={},
            media_id=media_id,
            embedding_size=embedding_size,
            embedding_2048=vector,
        )
    )
    feature_ids.append(created.feature_id)

# Only the ids of the created features are stored in the dataframe.
medias["feature_id"] = feature_ids

## Create Projection

In [None]:
total_tiles = len(medias) # either all features or a subset of features
atlas_width = 4096
max_tile_size = 512
# Tiles are laid out on a square grid: (4096 // 512) ** 2 = 64 tiles per atlas.
max_tiles_per_atlas = (atlas_width // max_tile_size) ** 2
# Ceiling division: the previous `int(total / per) + 1` reported one atlas too
# many whenever total_tiles was an exact multiple of max_tiles_per_atlas.
# NOTE(review): confirm create_square_atlases produces the same number of atlases.
atlas_count = (total_tiles + max_tiles_per_atlas - 1) // max_tiles_per_atlas

# Create the projection, replace the names with the desired ones
projection = Projection(
    projection_name="RTS Visual 3D",
    version="0.0.1",
    library_id=get_library_id_from_name("rts"),
    model_name="resnet50",
    model_params={},
    data={},
    dimension=3,
    atlas_folder_path="",
    atlas_width=atlas_width,
    tile_size=max_tile_size,
    atlas_count=atlas_count,
    total_tiles=total_tiles,
    tiles_per_atlas=max_tiles_per_atlas,
)

# create_projection returns a mapping; keep its id for the atlas and coordinate cells.
projection_id = create_projection(projection)['projection_id']
print(f"Projection ID: {projection_id}")

## Create Atlases

In [153]:
# get_thumbnail already reorders the decoded frames from BGR to RGB, which is the
# channel order PIL expects. Applying COLOR_BGR2RGB a second time here swapped
# red and blue in the atlas images, so the arrays are wrapped as they are.
images = [Image.fromarray(img) for img in medias.thumbnail.values.tolist()]

In [156]:
# Build the atlases for this projection from the PIL images, using the same
# atlas width and tile size that were recorded on the projection.
square_atlases = create_square_atlases(
    atlas_name="atlas_rts_visual",
    projection_id=projection_id,
    images=images,
    width=atlas_width,
    max_tile_size=max_tile_size,
    no_border=True,
)

## Compute embeddings in 3D and save to DB

In [None]:
# 3-D UMAP over the feature matrix computed earlier in this notebook.
# `medias_with_features` was never defined here (NameError on a fresh kernel).
# `features` holds the same vectors in the same row order as `medias`, which
# the coordinate-saving loop relies on when it indexes `embedding_3d` by row.
reducer = UMAP(n_components=3, n_neighbors=100, min_dist=0.1, metric="euclidean", random_state=42)
embedding_3d = reducer.fit_transform(features)

In [None]:
# Quick visual check of the 3-D embedding as a point cloud.
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={"projection": "3d"})
xs, ys, zs = (embedding_3d[:, axis] for axis in range(3))
ax.scatter(xs, ys, zs, s=0.1)
plt.show()

In [188]:
# Create an entry in the map_projection_feature table for each feature, linking
# the projection, the media, the feature and its 3-D coordinates.
if len(embedding_3d) != len(medias):
    raise ValueError(
        f"embedding_3d has {len(embedding_3d)} rows but medias has {len(medias)}"
    )

# `position` is the row's ordinal position, not its index label, so the atlas
# slot and the embedding row stay correct even if `medias` was filtered.
for position, (media_id, feature_id) in enumerate(zip(medias["media_id"], medias["feature_id"])):
    create_map_projection_feature(MapProjectionFeatureCreate(
        projection_id=projection_id,
        media_id=media_id,
        # Tiles fill atlases in row order, max_tiles_per_atlas at a time.
        atlas_order=position // max_tiles_per_atlas,
        index_in_atlas=position % max_tiles_per_atlas,
        # .tolist() converts the numpy scalars to plain Python floats.
        coordinates=embedding_3d[position, :3].tolist(),
        feature_id=feature_id
    ))