# Implementing semantic video search using open-source Large Vision Models on Amazon SageMaker and OpenSearch Serverless

---
This notebook supplements the _[Implementing semantic video search using Large Vision Models on Amazon SageMaker and OpenSearch Serverless](https://link-to-blog)_ blog post and contains code samples for deploying and operating the proposed semantic video search architecture with Large Vision Models on Amazon SageMaker.

---

## Contents

1. [Architecture overview](#sec-1)
2. [Setup the environment](#sec-2)
3. [Setup and create Amazon OpenSearch Serverless (AOSS) collection](#sec-3)
4. [Select the Large Vision Model (LVM) and configure indexing parameters](#sec-4)
5. [Create an AOSS vector index](#sec-5)
6. [Download a few sample videos for testing](#sec-6)
7. [Locally embed video and perform semantic text search for testing](#sec-7)
8. [Deploy video indexing and video search endpoints with SageMaker](#sec-8)
9. [Send videos to SageMaker asynchronous endpoint for indexing](#sec-9)
10. [Search AOSS index with text/image queries and visualize results](#sec-10)
11. [Clean up resource](#sec-11)
  

---

## 1. Architecture overview

In this notebook, we demonstrate how to deploy and operate the semantic video search architecture (shown below). While the solution is suitable for any open-source Large Vision Model (LVM), the sample code in this demo can leverage any version of the **CLIP** or **SigLIP** model families from the Hugging Face Model Hub. We leverage **SageMaker Asynchronous Inference** endpoints to embed video frames and ingest them into an **Amazon OpenSearch Serverless** vector index. Using SageMaker asynchronous inference endpoints allows us to handle typically large video payloads and to scale resources down to zero when there are no new videos to index. The vector search collections in Amazon OpenSearch Serverless (AOSS) provide a scalable, high-performing similarity search capability.

<a id='sec-1'></a>

![Architecture](doc/architecture.png)

---

## 2. Setup the environment

<a id='sec-2'></a>

### 2.1 Install required packages

In [None]:
!pip install -q -r requirements.txt

### 2.2 Install Git LFS and pull sample video files (if needed)

This repo includes a few sample videos stored remotely as [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/collaboration-with-git-large-file-storage) objects. If you did not have a Git LFS client installed when pulling this repo, the sample video files might not have been pulled from Git LFS. If this is the case, please install Git LFS and pull the video samples by executing the cells below:

In [None]:
%%sh

# optional: update system packages in Amazon SageMaker Studio Ubuntu environment
sudo bash -c 'export DEBIAN_FRONTEND=noninteractive && apt-get -qq update -y && apt-get -qq upgrade -y'

# install system packages
sudo bash -c 'export DEBIAN_FRONTEND=noninteractive && apt-get -qq install -y git git-lfs'

In [None]:
%%sh

# pull sample video files from Git LFS
git lfs pull
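
If you are unsure whether the sample videos were actually pulled, one quick check is to look at the first bytes of each file: a Git LFS pointer is a small text file that begins with `version https://git-lfs.github.com/spec/v1`. The snippet below is a minimal sketch of such a check. The helper names `is_lfs_pointer` and `find_lfs_pointers` are our own (they are not part of this repo), and the `sample_videos` directory in the usage comment is an assumption based on the paths used later in this notebook.

```python
from pathlib import Path

# First line of every Git LFS pointer file
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path):
    """Return True if the file looks like a Git LFS pointer rather than real content."""
    with open(path, "rb") as f:
        head = f.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


def find_lfs_pointers(directory, pattern="*.mp4"):
    """List files under `directory` that are still LFS pointers (hypothetical helper)."""
    return sorted(str(p) for p in Path(directory).glob(pattern) if is_lfs_pointer(p))


# Hypothetical usage: an empty list means every sample video was pulled
# print(find_lfs_pointers("sample_videos"))
```

If the list is not empty, running `git lfs pull` as shown above should replace the pointers with the actual video files.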

### 2.3 Setup clients and global variables

In [None]:
# To keep the cell outputs clean, let's suppress non-critical logging messages from SageMaker libs
import logging
logging.getLogger('sagemaker.config').setLevel(logging.CRITICAL)
logging.getLogger('sagemaker').setLevel(logging.CRITICAL)

In [None]:
import os
import json
import time

import boto3
import sagemaker

# Create boto3 session, set AWS region and get account ID
boto_session = boto3.Session()
region = boto_session.region_name
account_id = boto3.client('sts').get_caller_identity().get('Account')

# Create SageMaker session, setup S3 bucket, and get IAM role
bucket_prefix = 'video-search-demo'
sm_session = sagemaker.session.Session(boto_session=boto_session, default_bucket_prefix=bucket_prefix)
bucket_name = sm_session.default_bucket()
iam_role = sagemaker.get_execution_role(sagemaker_session=sm_session)

# Define SageMaker instance types for video indexing and search endpoints
deploy_instance_type_index = 'ml.g5.2xlarge'
deploy_instance_type_search = 'ml.c5.2xlarge'

# Define Amazon OpenSearch Serverless client and collection/index names
aoss_client = boto3.client('opensearchserverless')
aoss_collection_name = 'semantic-video-collection'
aoss_index_name = 'frame-index'

# Some temporary local paths
local_video_dir = 'sample_videos'
local_sagemaker_artifact_dir = 'sagemaker_artifact'
local_sagemaker_artifact_tarball = 'model.tar.gz'

###
print("AWS Account ID:", account_id)
print("AWS Region:", region)
print("IAM Role:", iam_role)
print("Bucket name:", bucket_name)
print("Bucket prefix:", bucket_prefix)

### 2.4 Attach required permission to IAM role / user

The AWS identity you assume from your notebook environment (which is the [*Studio/notebook Execution Role*](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) from SageMaker, or could be a role or IAM User for self-managed notebooks), must have sufficient [AWS IAM permissions](https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html) to read/write from/to your S3 bucket and operate the OpenSearch Serverless service.

To grant these permissions to your identity, you can:

- Open the [AWS IAM Console](https://us-east-1.console.aws.amazon.com/iam/home?#)
- Find your [Role](https://us-east-1.console.aws.amazon.com/iamv2/home?#/roles) (if using SageMaker or otherwise assuming an IAM Role), or else [User](https://us-east-1.console.aws.amazon.com/iamv2/home?#/users)
- Select *Add Permissions > Create Inline Policy* to attach the required permissions, open the *JSON* editor and paste in the below example policy:

> ⚠️ **Important:** Replace `<bucket_name>` with the name of your Amazon S3 bucket (see printouts of the previous cell).

```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Action": [
                "s3:GetObject",
                "s3:PutObject",
                "s3:AbortMultipartUpload",
                "s3:ListBucket"
            ],
            "Effect": "Allow",
            "Resource": [
                "arn:aws:s3:::<bucket_name>",
                "arn:aws:s3:::<bucket_name>/*"
            ]
        },
        {
          "Action": [
                "aoss:CreateCollection",
                "aoss:ListCollections",
                "aoss:DeleteCollection",
                "aoss:CreateAccessPolicy",
                "aoss:DeleteAccessPolicy",
                "aoss:CreateSecurityPolicy",
                "aoss:DeleteSecurityPolicy",
                "aoss:APIAccessAll"
          ],
          "Effect": "Allow",
          "Resource": "*"
        }
    ]
    
}
```

> ⚠️ **Note:** With Amazon SageMaker, your notebook execution role will typically be *separate* from the user or role that you log in to the AWS Console with. If you'd like to explore the OpenSearch Serverless collections/indices from the AWS Console, you'll need to grant the relevant permissions to your Console user/role too.
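
If you prefer to attach the policy programmatically instead of using the console, the sketch below builds the same policy document for a given bucket and attaches it with the boto3 IAM `put_role_policy` call. This is only an illustration under a few assumptions: the helper names and the policy name `video-search-demo-policy` are hypothetical, and the identity running the code must itself be allowed to call `iam:PutRolePolicy` (a notebook execution role usually is not). Because `s3:ListBucket` is evaluated against the bucket rather than its objects, the sketch lists both the bucket ARN and the object ARN pattern.

```python
import json


def build_inline_policy(bucket_name):
    """Build the example policy document for the given S3 bucket."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": [
                    "s3:GetObject",
                    "s3:PutObject",
                    "s3:AbortMultipartUpload",
                    "s3:ListBucket",
                ],
                # Bucket ARN for ListBucket, object ARN pattern for the object actions
                "Resource": [
                    f"arn:aws:s3:::{bucket_name}",
                    f"arn:aws:s3:::{bucket_name}/*",
                ],
            },
            {
                "Effect": "Allow",
                "Action": [
                    "aoss:CreateCollection",
                    "aoss:ListCollections",
                    "aoss:DeleteCollection",
                    "aoss:CreateAccessPolicy",
                    "aoss:DeleteAccessPolicy",
                    "aoss:CreateSecurityPolicy",
                    "aoss:DeleteSecurityPolicy",
                    "aoss:APIAccessAll",
                ],
                "Resource": "*",
            },
        ],
    }


def attach_inline_policy(iam_client, role_name, bucket_name,
                         policy_name="video-search-demo-policy"):
    """Attach the policy to a role; `iam_client` is expected to be a boto3 IAM client."""
    iam_client.put_role_policy(
        RoleName=role_name,
        PolicyName=policy_name,
        PolicyDocument=json.dumps(build_inline_policy(bucket_name)),
    )


# Hypothetical usage (requires sufficient IAM permissions):
# import boto3
# attach_inline_policy(boto3.client("iam"), "<role_name>", "<bucket_name>")
```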

---

## 3. Setup and create Amazon OpenSearch Serverless (AOSS) collection

<a id='sec-3'></a>

Before creating the collection, you need to define its security configuration: an encryption policy, a network policy, and a data access policy. For overall guidance on OpenSearch security and data access control, refer to the corresponding section in the OpenSearch documentation. You can create these policies programmatically, as in the code below, or through the AWS console.

First, we use the boto3 client for AOSS to create the encryption, network, and data access policies. They are matched by name to the AOSS collection for vector search, which we create afterwards.

### 3.1 Create security components of AOSS

In [None]:
# Create an encryption policy that matches our AOSS collection name
aoss_encryption_policy = aoss_client.create_security_policy(
    name=aoss_collection_name + '-ep',
    type='encryption',
    policy=json.dumps(
        {
            'Rules': [
                {
                    'Resource': ['collection/' + aoss_collection_name],
                    'ResourceType': 'collection'
                }
            ],
            'AWSOwnedKey': True
        }
    )
)

print('Encryption policy created:')
print(aoss_encryption_policy)

In [None]:
# Create a network policy that matches our AOSS collection name
aoss_network_policy = aoss_client.create_security_policy(
    name=aoss_collection_name + '-np',
    type='network',
    policy=json.dumps(
        [
            {
                'Rules': [
                    {
                        'Resource': ['collection/' + aoss_collection_name],
                        'ResourceType': 'collection'
                    }
                ],
                'AllowFromPublic': True
            }
        ]
    )
)

print('Network policy created:')
print(aoss_network_policy)

> ⚠️ **Note:** _To keep setup overhead at a minimum, this proof-of-concept implementation **allows public internet access** to the OpenSearch Serverless collection resource. For production environments, however, we strongly suggest using a private connection between your VPC and Amazon OpenSearch Serverless resources via a VPC endpoint, as described [here](https://docs.aws.amazon.com/opensearch-service/latest/developerguide/serverless-vpc.html)._

In [None]:
# Create a data access policy that matches our AOSS collection name and grants our IAM role access.
aoss_data_policy = aoss_client.create_access_policy(
    name=aoss_collection_name + '-dp',
    type='data',
    policy=json.dumps(
        [
            {
                'Rules': [
                    {
                        'Resource': ['collection/' + aoss_collection_name],
                        'Permission': [
                            'aoss:CreateCollectionItems',
                            'aoss:DeleteCollectionItems',
                            'aoss:UpdateCollectionItems',
                            'aoss:DescribeCollectionItems'],
                        'ResourceType': 'collection'
                    },
                    {
                        'Resource': ['index/' + aoss_collection_name + '/*'],
                        'Permission': [
                            'aoss:CreateIndex',
                            'aoss:DeleteIndex',
                            'aoss:UpdateIndex',
                            'aoss:DescribeIndex',
                            'aoss:ReadDocument',
                            'aoss:WriteDocument'],
                        'ResourceType': 'index'
                    }],
                'Principal': [iam_role],
                'Description': 'My custom easy data policy'
            }
        ]
    )
)

print('Data policy created:')
print(aoss_data_policy)

### 3.2 Create AOSS collection for vector search

With all AOSS security components properly set up, we can now create an AOSS collection for vector search.

In [None]:
# Request to create collection
aoss_collection = aoss_client.create_collection(name=aoss_collection_name, type='VECTORSEARCH')
aoss_collection_id = aoss_collection['createCollectionDetail']['id']

# Wait until collection becomes active
while True:
    aoss_reply = aoss_client.list_collections(collectionFilters={'name': aoss_collection_name})
    aoss_status = aoss_reply['collectionSummaries'][0]['status']
    if aoss_status in ('ACTIVE', 'FAILED'):
        print('!')
        break
    print('-', end='')
    time.sleep(5)

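# Stop here if the collection could not be created
if aoss_status != 'ACTIVE':
    raise RuntimeError(f"AOSS collection creation ended with status: {aoss_status}")

print("An AOSS collection created with collection ID:", aoss_collection_id)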
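
The waiting loop above polls indefinitely. For longer-running or unattended notebooks, a bounded variant can be useful. The following is a minimal sketch of a generic polling helper with a timeout; `wait_for_status` and its parameters are hypothetical names introduced here for illustration, not part of boto3 or this repo.

```python
import time


def wait_for_status(get_status, done_states=("ACTIVE", "FAILED"),
                    timeout_s=600, poll_s=5, sleep=time.sleep):
    """Poll `get_status()` until it returns one of `done_states` or the timeout expires."""
    deadline = time.monotonic() + timeout_s
    while True:
        status = get_status()
        if status in done_states:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still in status {status!r} after {timeout_s} seconds")
        sleep(poll_s)


# Hypothetical usage with the AOSS client defined earlier:
# status = wait_for_status(
#     lambda: aoss_client.list_collections(
#         collectionFilters={'name': aoss_collection_name}
#     )['collectionSummaries'][0]['status']
# )
```

Passing the status lookup as a callable keeps the helper independent of any particular AWS service, so the same function could also be reused for other resources that report a status.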

---

## 4. Select the Large Vision Model (LVM) and configure indexing parameters

<a id='sec-4'></a>

This semantic video search architecture is flexible enough to, in principle, leverage any type of open-source LVMs that can generate text and image embeddings (see, for example, this [list of zero-shot image classification models](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification)). However, the deployable sample code in this demo has been tailored to use either of these two LVM model families available on Hugging Face (HF) Model Hub: **CLIP** (see all models [here](https://huggingface.co/models?library=open_clip)), or **SigLIP** (see all models [here](https://huggingface.co/models?other=siglip)).

### 4.1 Inspect the config file
Let's choose a concrete model version from **CLIP** or **SigLIP** model families that we want to deploy by specifying its name from HF Model Hub in the `config.yaml` file. Besides specifying the name of a vision model, there are a few other parameters that we'd need to define before deploying the SageMaker endpoints (marked with `<TO_BE_PROVIDED>` placeholders). Let's have a look at the config template:

In [None]:
from omegaconf import OmegaConf

config = OmegaConf.load('src/config_template.yaml')
print(OmegaConf.to_yaml(config))

### 4.2 Specify the LVM model name and its embedding dimensions

There are plenty of models to choose from (see, e.g., this list for [CLIP models](https://huggingface.co/models?library=open_clip), or this list for [SigLIP models](https://huggingface.co/models?other=siglip)) and you should typically test a few different models and see how they perform on your particular use-case in terms of search quality and latency requirements. Here are just a few suggestions:

- `laion/CLIP-ViT-B-32-laion2B-s34B-b79K` (embedding dimensions: `512`) - one of the *base* OpenCLIP models
- `laion/CLIP-ViT-H-14-laion2B-s32B-b79K` (embedding dimensions: `1024`) - one of the *best* OpenCLIP models
- `google/siglip-base-patch16-224` (embedding dimensions: `768`) - one of the *base* SigLIP models
- `google/siglip-so400m-patch14-384` (embedding dimensions: `1152`) - one of the *best multi-lingual* SigLIP models

As a default for this demo, we'll go with the `google/siglip-base-patch16-224` model in order to prioritize indexing and lookup speed over quality. However, SigLIP models tend to outperform CLIP models of similar sizes and pre-training procedures (see, e.g., _[Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/pdf/2303.15343.pdf)_ by *Zhai et al.* for more details), so it is a fair compromise. You are, of course, welcome to explore other LVMs.

> ⚠️ **Important:** When experimenting with other LVMs, **don't forget to edit the model's embedding dimension** in the `config.yaml` accordingly.
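
A mismatch between the model and the configured embedding dimension only surfaces later, when vectors are written to the index. As a small safeguard, the sketch below compares a configured dimension against the values listed above. The lookup table only contains the four suggested models, and `check_embedding_dim` is a hypothetical helper, not part of this repo.

```python
# Embedding sizes for the models suggested above (values taken from the list in this section)
KNOWN_EMBEDDING_DIMS = {
    "laion/CLIP-ViT-B-32-laion2B-s34B-b79K": 512,
    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K": 1024,
    "google/siglip-base-patch16-224": 768,
    "google/siglip-so400m-patch14-384": 1152,
}


def check_embedding_dim(model_name, embedding_dim):
    """Raise if `embedding_dim` contradicts the known size of `model_name`.

    Returns the expected dimension, or None if the model is not in the table.
    """
    expected = KNOWN_EMBEDDING_DIMS.get(model_name)
    if expected is None:
        return None  # unknown model, nothing to compare against
    if expected != embedding_dim:
        raise ValueError(
            f"{model_name} produces {expected}-dimensional embeddings, "
            f"but the config says {embedding_dim}"
        )
    return expected


# Hypothetical usage once the config values are set:
# check_embedding_dim(config.model_name, config.model_embedding_dim)
```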

In [None]:
# Specify our vision model and its output dimension:
config.model_name = 'google/siglip-base-patch16-224'
config.model_embedding_dim = 768

# Specify AOSS resources to be used by SageMaker endpoints:
config.aws_region = region
config.opensearch.collection_id = aoss_collection_id
config.opensearch.index_name = aoss_index_name

Finally, let's save the final config back as `config.yaml`:

In [None]:
OmegaConf.save(config, 'src/config.yaml')

print("Saved the following config:", "\n -------")
print(OmegaConf.to_yaml(config))

---

## 5. Create an AOSS vector index

<a id='sec-5'></a>

Now that we know the embedding dimension of the vectors produced by the selected LVM, we can create the AOSS index:

In [None]:
from requests_aws4auth import AWS4Auth
from opensearchpy import OpenSearch, RequestsHttpConnection

# Use default credential configuration for authentication
credentials = boto_session.get_credentials()
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    region,
    'aoss',
    session_token=credentials.token)

# Construct AOSS endpoint host
host = f"{aoss_collection_id}.{region}.aoss.amazonaws.com"

# Build the OpenSearch client
os_client = OpenSearch(
    hosts=[{'host': host, 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    timeout=300
)

In our code we provide a method to create a vector database index in OpenSearch Serverless. For our case we chose three parameters for each frame to send to OpenSearch: `timestamp` (time indication in seconds of the picked frame), `video_id` (video name) and `frame_vector` (vector embeddings generated by the chosen model). Feel free to add any further parameters depending on your business use case. 
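
To make the document layout concrete, the sketch below builds one document per frame with the three fields named above. It is an illustration only: `make_frame_documents` is a hypothetical helper and the deployed `inference.py` may assemble its documents differently. Storing the timestamp as a string is an assumption that matches the `text` mapping used for that field in the index definition below.

```python
def make_frame_documents(video_id, timestamps, frame_vectors):
    """Build one index document per sampled frame (hypothetical helper)."""
    if len(timestamps) != len(frame_vectors):
        raise ValueError("timestamps and frame_vectors must have the same length")
    return [
        {
            "timestamp": str(ts),                        # position of the frame in seconds
            "video_id": video_id,                        # video name
            "frame_vector": [float(x) for x in vector],  # embedding from the chosen model
        }
        for ts, vector in zip(timestamps, frame_vectors)
    ]


# Toy example with 2-dimensional vectors; real embeddings have hundreds of dimensions
docs = make_frame_documents("trucking", [0.0, 0.5], [[1, 0], [0, 1]])
print(docs[1])
```

Any additional fields you need for your use case (for example a source URL or a category label) would be added as further keys in each document, with a matching entry in the index mapping.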

In OpenSearch we use [Approximate k-NN (ANN)](https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/) search, which trades a small amount of search accuracy for lower latency and better scalability on large indexes compared to exact k-NN. The OpenSearch [index space setting](https://opensearch.org/docs/latest/search-plugins/knn/approximate-knn/#spaces) is set to `cosinesimil` when creating the index, so that the nearest vectors are determined by cosine similarity.

In [None]:
# Define the configuration for the AOSS vector index
index_body = {
    "mappings": {
        "properties": {
            "timestamp": {"type": "text"},
            "video_id": {"type": "text"},
            "frame_vector": {
                "type": "knn_vector",
                "dimension": config.model_embedding_dim,
                "method": {
                    "engine": "nmslib",
                    "space_type": "cosinesimil",
                    "name": "hnsw",
                    "parameters": {"ef_construction": config.model_embedding_dim, "m": 16}
                }
            }
        }
    },
    "settings": {
        "index": {
            "number_of_shards": 2,
            "knn.algo_param": {"ef_search": config.model_embedding_dim},
            "knn": True
        }
    }
}

# Create AOSS index
response = os_client.indices.create(aoss_index_name, body=index_body)

print('A new index created:')
print(response)

---

## 6. Download a few sample videos for testing

<a id='sec-6'></a>

In [None]:
# Define our video library for testing
video_samples_youtube = {
    'fashionshow': 'https://www.youtube.com/watch?v=0Py6W56LMK4',
    'formulaone': 'https://www.youtube.com/watch?v=6UsInj7lNOk',
    'airplanes': 'https://www.youtube.com/watch?v=mCTb7fcPhqU',
    'trucking': 'https://www.youtube.com/watch?v=itY07uKxiYQ'
}

In [None]:
from pytube import YouTube

def download_youtube_video(vid_dir, vid_name, vid_url):
    os.makedirs(vid_dir, exist_ok=True)
    yt = YouTube(vid_url)
    yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download(
        output_path=vid_dir, filename=vid_name)
    return os.path.join(vid_dir, vid_name)

In [None]:
video_samples_local = {}

# Download videos that are not yet available locally
for video_name, video_url in video_samples_youtube.items():
    video_file = video_name + '.mp4'
    video_path = os.path.join(local_video_dir, video_file)
    
    if not os.path.isfile(video_path):
        video_path = download_youtube_video(
            vid_dir=local_video_dir,
            vid_name=video_file,
            vid_url=video_url
        )
    
    video_samples_local[video_name] = video_path
    print("Video available:", video_path)

In [None]:
from IPython.display import Video

# Let's take the last video and review it:
Video(video_path, width=900)

---

## 7. Locally embed video and perform semantic text search for testing

<a id='sec-7'></a>

In order to outline the main concepts behind our LVM-based approach for semantic video search, let's first locally embed a video and illustrate a search procedure with a couple of text prompts. 

> ⚠️ **Note:** In order to keep code duplication for this local demo as small as possible, we will directly call a few functions from the `inference.py` script and a few other helper libraries from our deployable package, which we will later host on a SageMaker endpoint. It is also *good practice* to test parts of any custom inference script *locally* before deploying it to a SageMaker endpoint.

### 7.1 Import inference module

In [None]:
# As the `inference` module initializes its parameters from the accompanying config file, let's set its local path for testing here
os.environ["CONFIG_FILE"] = 'src/config.yaml'

# Import our `inference.py` script and a few other helper libraries as Python modules
from src import inference, processing_funcs, video_loader

# Check that our config file was read successfully in the inference module
print(inference.CFG)

### 7.2 Define LVM models from config

Any custom inference script for SageMaker hosting must implement a few handler functions. One of these is the `model_fn` that defines how and where to load the ML model by the model server. Concretely, once we deploy our inference scripts as SageMaker endpoint, the SageMaker PyTorch model server will load our model by invoking a `model_fn` function (see, e.g., [this guide](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#load-a-model) for more details).

In [None]:
# Call `model_fn` to load the text and image preprocessing and encoder functions (as specified in config-file).
model = inference.model_fn(model_dir=None)

print("Model and runtime specification loaded:")
print(model)
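
For readers unfamiliar with the handler contract mentioned above, the sketch below shows the general shape of the four handler functions that the SageMaker PyTorch model server looks for in a custom inference script. It is a toy illustration with a stand-in "model" (a function that returns the text length) and is not the `inference.py` from this repo; only the function names and argument order follow the SageMaker convention.

```python
import json


def model_fn(model_dir):
    """Load the model once per worker; here a stand-in callable instead of a real LVM."""
    return {"embed": lambda text: [float(len(text))]}


def input_fn(request_body, request_content_type):
    """Deserialize the incoming request payload."""
    if request_content_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)


def predict_fn(input_data, model):
    """Run inference on the deserialized input with the loaded model."""
    return {"embedding": model["embed"](input_data["text"])}


def output_fn(prediction, response_content_type):
    """Serialize the prediction for the response."""
    return json.dumps(prediction)


# Local round trip, in the same order the model server would call the handlers
model = model_fn(model_dir=None)
request = input_fn('{"text": "dog"}', "application/json")
print(output_fn(predict_fn(request, model), "application/json"))
```

Because the handlers are plain Python functions, they can be exercised locally in exactly this way before the script is packaged and deployed.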

### 7.3 Embed video locally

Now we can embed the video locally by calling the same `embed_and_index_video` function from the `inference` module that would be invoked on the SageMaker endpoint to process an incoming video. Note that we set `do_index=False` to disable indexing with AOSS and instead keep all video embeddings local for our tests here. See `inference.py` for other implementation details.

In [None]:
# Pick a video from our video library (defined earlier) to vectorize locally
vid_name = 'trucking'

# Embed video and keep frame vectors locally
with open(video_samples_local[vid_name], 'rb') as vid_bytes:
    vid_embs, vid_times, vid_inds = inference.embed_and_index_video(model, vid_bytes.read(), vid_name, do_index=False)

print("Frame embedding shape:", vid_embs.shape)
print("Frame timestamps shape:", vid_times.shape)
print("Frame indices shape:", vid_inds.shape)

### 7.4 Embed text prompts locally

In [None]:
# Define prompt keys and embed text prompts
prompt_keys = ['dog', 'truck', 'cables', 'sketch on whiteboard', 'red truck']
prompt_template = 'a photo of a %s'
prompts = [prompt_template % k for k in prompt_keys]

text_embs = processing_funcs.get_text_embeddings(
    prompts,
    text_processor=model.processor_text,
    text_model=model.model_text,
    device=inference.CFG.device,
    return_tensors=True
)

print("Prompt keys:", prompt_keys)

### 7.5 Calculate and plot search signal

In [None]:
# Calculate search signal scores
scores = (vid_embs @ text_embs.T).cpu().numpy()

print("Search signal shape:", scores.shape)
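
The matrix product above yields one score per (frame, prompt) pair. Such a dot product equals the cosine similarity when both embeddings are L2-normalized. We assume here that the helper functions return normalized embeddings; check `processing_funcs` if in doubt. The following pure-Python sketch with toy 2-dimensional vectors illustrates that equivalence; the function names are ours and chosen for illustration only.

```python
import math


def dot(a, b):
    """Dot product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(dot(vec, vec))
    return [x / norm for x in vec]


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Toy "frame" and "text" embeddings
frame_emb = [3.0, 4.0]
text_emb = [4.0, 3.0]

score_cosine = cosine_similarity(frame_emb, text_emb)
score_dot = dot(l2_normalize(frame_emb), l2_normalize(text_emb))

# Both values agree up to floating point error
print(score_cosine, score_dot)
```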

#### Plot search signal scores for each prompt key

Let's visualize the raw search signal scores for each prompt key. We'll use the `plotly` library for interactive plotting so that you can explore the signal scores closely. Note that the scores generated locally and shown on the chart differ from those returned by OpenSearch, because OpenSearch reports its own search scores rather than the raw k-NN similarity scores.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

fig = px.line(pd.DataFrame(scores, columns=prompt_keys), title=vid_name)
fig.show()

#### Plot search hits and print found frame samples

Let's illustrate a realistic search scenario by visualizing a few search hits for our text prompts. 

> ⚠️ **Note:** When deployed to a SageMaker endpoint, we will use a similar approach for semantic video search using AOSS vector index. However, there will be a few differences:
> - AOSS search results will **not** return *all* video frames, but only the ones with highest scores, according to the defined approx. nearest neighbor (ANN) search parameters;
> - we will **not** use `frame_distance` parameter as a scene deduplication step, but will be using frame clustering for this (see `src/search_utils.py` for more details).

In [None]:
# Define a few helper functions for plotting
import scipy.signal  # import the submodule explicitly so that `scipy.signal.find_peaks` is available
import numpy as np

def filter_hits(scores, top_k=None, **kwargs):
    peaks, properties = scipy.signal.find_peaks(scores, **kwargs)
    hit_scores = properties['peak_heights']
    _ids = hit_scores.argsort()[::-1]
    hits = peaks[_ids]
    hit_scores = hit_scores[_ids]
    if top_k:
        hits = hits[:top_k]
        hit_scores = hit_scores[:top_k]
    return hits, hit_scores

def plot_hits(scores, hit_indices, timestamps, threshold=None, title=''):
    _steps = len(scores)
    _x_coord = np.arange(_steps)
    hit_mask = np.zeros(_steps, dtype=bool)
    hit_mask[hit_indices] = True
    hit_times = timestamps[hit_indices]
    fig, ax = plt.subplots(figsize=(12, 4), dpi=200)
    ax.plot(scores)

    if threshold:
        plt.plot(_x_coord, 0 * _x_coord + threshold)

    def forward(x):
        return np.exp(20 * x)

    def inverse(x):
        return np.log(x + 1e-3) / 20
    
    plt.fill_between(_x_coord, hit_mask, alpha=0.2, color='red')
    plt.ylim(ymin=0.0, ymax=scores.max() * 1.1)
    plt.xticks(_x_coord[::500], timestamps[:_steps][::500].astype('int'))
    ax.set_yscale('function', functions=(forward, inverse))
    plt.xlabel('timestamp')
    plt.ylabel('score')
    plt.title(title)
    plt.show()

def plot_images(images):
    n = len(images)
    f = plt.figure(figsize=(5*len(images), 5))
    for i, img in enumerate(images):
        f.add_subplot(1, n, i + 1)
        plt.axis('off')
        plt.imshow(img)

In [None]:
# Define a few search parameters to filter results
min_score = 0.09       # - show/filter hits above this score
frame_distance = 20     # - show/filter hits that are sufficiently spaced apart
top_k = 5              # - show/filter only top K hits

for sc, key in zip(scores.T, prompt_keys):
    hit_indices, hit_scores = filter_hits(sc, height=min_score, top_k=top_k, distance=frame_distance)
    hit_indices_expanded = [inference.CFG.video_decoder.sampling_rate * ix for ix in hit_indices]
    frames, stamps = video_loader.get_frames_with_indices_jumping(video_samples_local[vid_name], hit_indices_expanded)
    plot_hits(sc, hit_indices, vid_times, threshold=min_score, title=f"The {len(hit_indices)} hit(s) for '{key}'")
    plot_images(frames)
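
To clarify what the three search parameters do, here is a simplified pure-Python sketch of the filtering idea: keep local maxima at or above `min_score`, prefer higher peaks, drop peaks that lie closer than `distance` frames to an already kept one, and return at most `top_k` hits. It is an illustration of the concept under simplifying assumptions (for example, flat plateaus are ignored) and not a replacement for `scipy.signal.find_peaks`; the name `pick_hits` is hypothetical.

```python
def pick_hits(scores, min_score, distance, top_k):
    """Greedy peak picking over a 1-D score signal (simplified illustration)."""
    n = len(scores)
    # Candidate peaks: strictly higher than both neighbours and at or above the threshold
    peaks = [
        i for i in range(1, n - 1)
        if scores[i] > scores[i - 1] and scores[i] > scores[i + 1] and scores[i] >= min_score
    ]
    kept = []
    # Visit peaks from highest to lowest score; keep a peak only if it is far enough
    # from every peak that has already been kept
    for i in sorted(peaks, key=lambda i: scores[i], reverse=True):
        if all(abs(i - j) >= distance for j in kept):
            kept.append(i)
    return kept[:top_k]


toy_scores = [0.0, 0.2, 0.0, 0.3, 0.0, 0.05, 0.0, 0.25, 0.0]
print(pick_hits(toy_scores, min_score=0.1, distance=3, top_k=5))  # → [3, 7]
```

In the toy signal, the peak at index 1 is dropped because it is only two frames away from the higher peak at index 3, and the peak at index 5 falls below the score threshold.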
    

---

## 8. Deploy video indexing and video search endpoints with SageMaker

<a id='sec-8'></a>

### 8.1 Package and upload SageMaker deployable asset to S3

In [None]:
%%bash -s $local_sagemaker_artifact_dir $local_sagemaker_artifact_tarball

echo "Staging all SageMaker deployable assets in '$1' dir:"
rm -rf $1
mkdir -vp $1
mkdir -vp $1/code
cp -v ./src/* $1/code

echo
echo "Packaging SageMaker deployable assets to '$2' tarball:"
rm -f $2
tar cvzf $2 -C $1/ . 

In [None]:
# Upload the deployable tarball package to S3
model_artifact = sm_session.upload_data(
    local_sagemaker_artifact_tarball,
    bucket=bucket_name,
    key_prefix=f'{bucket_prefix}/model'
)

print("Model artifact:", model_artifact)

### 8.2 Define SageMaker model package

Let's put all the pieces together and define a SageMaker model configuration that we will later deploy as an endpoint. Since video indexing and video search require an identical LVM to produce consistent embeddings, we use a single model package for both workflows. This helps us keep the codebase for both workflows [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself) and automatically reflects any code changes (e.g. an update of the LVM) in both the indexing and search arms of our architecture.

In [None]:
from sagemaker.pytorch import PyTorchModel

env_dict = {
    'TS_MAX_REQUEST_SIZE': '200000000',     # increase the TorchServe default max request size to 200MB
    'TS_MAX_RESPONSE_SIZE': '100000000', 
    'TS_DEFAULT_RESPONSE_TIMEOUT': '1000',
    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',
    'CONFIG_FILE': 'config.yaml'
}

sm_model_name = f"lvm-{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}"

sm_model = PyTorchModel(
    name=sm_model_name,
    role=iam_role,
    model_data=model_artifact,
    framework_version='2.0',
    py_version='py310',
    entry_point='inference.py',
    env=env_dict,
    sagemaker_session=sm_session,
)

### 8.3 Deploy SageMaker asynchronous inference endpoint for video indexing

Let's deploy the indexing workflow as a SageMaker Asynchronous Inference endpoint. Asynchronous endpoints queue incoming requests and are ideal for workloads where both the request payloads and the inference processing times are large, which is typical for video. Moreover, unlike SageMaker Real-Time Inference endpoints, SageMaker Asynchronous Inference endpoints can scale the number of instances backing the endpoint down to zero. This allows us to deprovision the compute resources of the asynchronous endpoint when there is no traffic (i.e. no new videos to index) and pay only when payloads (i.e. new videos) arrive.

> ⚠️ **Note:** In order to keep this notebook compact, we do not set scaling policies for SageMaker Asynchronous Inference endpoints in this demo, but you can [read more about autoscaling policies for asynchronous endpoints here](https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference-autoscale.html).
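
For orientation, the sketch below shows what a scale-to-zero setup could look like with the Application Auto Scaling API, following the pattern described in the AWS documentation linked above. It only builds the request parameters; the helper name, the policy name, and the concrete values (`max_capacity`, target backlog, cooldowns) are assumptions for illustration, and `AllTraffic` is assumed to be the variant name of the deployed endpoint. The AWS documentation also describes an additional policy for scaling out from zero instances, which is omitted here.

```python
def build_async_autoscaling_requests(endpoint_name, variant_name="AllTraffic",
                                     max_capacity=2, backlog_per_instance=5.0):
    """Build parameters for register_scalable_target and put_scaling_policy (sketch)."""
    resource_id = f"endpoint/{endpoint_name}/variant/{variant_name}"
    dimension = "sagemaker:variant:DesiredInstanceCount"

    target = {
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "MinCapacity": 0,  # allow the endpoint to scale down to zero instances
        "MaxCapacity": max_capacity,
    }
    policy = {
        "PolicyName": f"{endpoint_name}-backlog-tracking",  # hypothetical name
        "ServiceNamespace": "sagemaker",
        "ResourceId": resource_id,
        "ScalableDimension": dimension,
        "PolicyType": "TargetTrackingScaling",
        "TargetTrackingScalingPolicyConfiguration": {
            "TargetValue": backlog_per_instance,
            "CustomizedMetricSpecification": {
                "MetricName": "ApproximateBacklogSizePerInstance",
                "Namespace": "AWS/SageMaker",
                "Dimensions": [{"Name": "EndpointName", "Value": endpoint_name}],
                "Statistic": "Average",
            },
            "ScaleInCooldown": 600,
            "ScaleOutCooldown": 300,
        },
    }
    return target, policy


# Hypothetical usage after the endpoint has been deployed:
# import boto3
# autoscaling = boto3.client("application-autoscaling")
# target, policy = build_async_autoscaling_requests("<endpoint_name>")
# autoscaling.register_scalable_target(**target)
# autoscaling.put_scaling_policy(**policy)
```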

In [None]:
from sagemaker.async_inference import AsyncInferenceConfig

sm_endpoint_name_index = sm_model_name + '-endpoint-index'

async_config = AsyncInferenceConfig(
    output_path=f"s3://{bucket_name}/{bucket_prefix}/output",
    max_concurrent_invocations_per_instance=2,
)

sm_model.deploy(
    initial_instance_count=1,
    instance_type=deploy_instance_type_index,
    endpoint_name=sm_endpoint_name_index,
    async_inference_config=async_config
)

print("")
print("Deployed endpoint name for video indexing:", sm_endpoint_name_index)

### 8.4 Deploy SageMaker real-time inference endpoint for video search

As we typically expect (relatively) frequent search queries and have low latency requirements, we use SageMaker Real-Time Inference endpoints to handle the user queries and return semantic search results. We'll now deploy the same model package (as used for video indexing) to the real-time endpoint to facilitate quick video search.

In [None]:
sm_endpoint_name_search = sm_model_name + '-endpoint-search'

sm_model.deploy(
    initial_instance_count=1,
    instance_type=deploy_instance_type_search,
    endpoint_name=sm_endpoint_name_search
)

print("")
print("Deployed endpoint name for video search:", sm_endpoint_name_search)

---

## 9. Send videos to SageMaker asynchronous endpoint for indexing

<a id='sec-9'></a>

In [None]:
import urllib.parse  # import the submodule explicitly so that `urllib.parse.urlparse` is available
from botocore.exceptions import ClientError

def upload_video(video_path, bucket, prefix):
    return sm_session.upload_data(
        video_path,
        bucket=bucket,
        key_prefix=prefix,
        extra_args={"ContentType": "video/mp4"}
    )

def get_async_process_output(output_location):
    output_url = urllib.parse.urlparse(output_location)
    bucket = output_url.netloc
    key = output_url.path[1:]
    while True:
        try:
            return sm_session.read_s3_file(bucket=bucket, key_prefix=key)
        except ClientError as e:
            if e.response["Error"]["Code"] == "NoSuchKey":
                print("-", end='')
                time.sleep(2)
                continue
            raise


def index_video(predictor_async, video_path, video_name):
    # uploading
    print('Uploading:', video_path)
    video_uri = upload_video(video_path, bucket_name, f'{bucket_prefix}/input')
    print('Video uploaded:', video_uri)

    # async process
    async_response = predictor_async.predict_async(
        input_path=video_uri,
        initial_args={'CustomAttributes': video_name}
    )

    output_location = async_response.output_path
    print('Indexing...')

    # waiting
    output = get_async_process_output(output_location)
    print('')
    print('Done!')
    for k, v in json.loads(output).items():
        print(f'{k}: {v}')

Everything is now set up to send our sample video files from S3 to the SageMaker asynchronous endpoint for processing with the LVM and indexing into an AOSS vector index:

In [None]:
from sagemaker import Predictor
from sagemaker.predictor_async import AsyncPredictor

# Create an index endpoint client
sm_endpoint_predictor_index = AsyncPredictor(
    Predictor(
        endpoint_name=sm_endpoint_name_index,
        sagemaker_session=sm_session
    )
)

# Send a few sample videos for indexing
for video_name, video_path in video_samples_local.items():
    print("Processing video:", video_name)
    index_video(sm_endpoint_predictor_index, video_path, video_name)

---

## 10. Search AOSS index with text/image queries and visualize results

<a id='sec-10'></a>

When using the OpenSearch API to look up vectors, you need to set two parameters: `k` for k-Nearest Neighbors and `search_size`. The `search_size` parameter defines the number of results returned in the response. The `k` parameter is the number of neighbors that the search of each graph returns, and it affects both search quality (precision and recall) and speed. For example, decreasing `k` sacrifices recall but speeds up search processing significantly. `k` supports a maximum of 10,000 nearest neighbors. For best-quality results, use a `k` that is greater than or equal to `search_size`.
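To make the roles of the two parameters concrete, here is a minimal sketch of an OpenSearch k-NN query body. The search endpoint builds its own query, which is not shown in this notebook, so the helper name `build_knn_query` and the field name `vector` are assumptions; use the field name from your own index mapping.

```python
import json

MAX_K = 10_000  # upper bound for k, as stated in the text above


def build_knn_query(query_vector, search_size=10, k=10, vector_field="vector"):
    """Build a k-NN query body: `size` caps the returned hits, `k` sets the neighbors per graph search."""
    if not 1 <= k <= MAX_K:
        raise ValueError(f"k must be between 1 and {MAX_K}, got {k}")
    return {
        "size": search_size,
        "query": {
            "knn": {
                # vector_field is a hypothetical field name
                vector_field: {"vector": list(query_vector), "k": k}
            }
        },
    }


body = build_knn_query([0.1, 0.2, 0.3], search_size=5, k=20)
print(json.dumps(body, indent=2))
```

Here `k=20` with `search_size=5` follows the recommendation to keep `k` at least as large as the number of results requested.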

> ⚠️ **Note:** To test search results _with_ or _without_ temporal smoothing, create different indexes in the same collection and restart video processing after switching smoothing ON or OFF.

In [None]:
import base64


def search_index(predictor, image=None, text=None, search_size=10, k=10, time_offset=1, video_name=None):
    """ Search for relevant video segments based on an input image or text query.
    
        Args:
            predictor (obj): A SageMaker SDK predictor attached to the video search endpoint.
            image (str, optional): Path to the input image file. Defaults to None.
            text (str, optional): Input text query. Defaults to None.
            search_size (int, optional): Number of results returned in the response. Defaults to 10.
            k (int, optional): Number of nearest neighbors to consider. Defaults to 10.
            time_offset (int, optional): Time offset (in seconds) for video segment search. Defaults to 1.
            video_name (str, optional): Name of the video file. Defaults to None.
    
        Returns:
            dict: A dictionary containing the search results.
    """
    if image:
        with open(image, 'rb') as f:
            image = base64.b64encode(f.read()).decode('utf-8')

    data = {
        "query": {
            "image": image,
            "text": text
        },
        "search_args": {
            "size": search_size,
            "k": k,
            "time_offset": time_offset,
            "name": video_name
        }
    }
    
    prediction = predictor.predict(data=data)
    return prediction

In [None]:
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

endpoint_predictor_search = Predictor(
    endpoint_name=sm_endpoint_name_search, 
    sagemaker_session=sm_session, 
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer()
)

Now that we have a few videos indexed in our AOSS vector index, let's run our first natural language query:

In [None]:
search_index(
    predictor=endpoint_predictor_search,
    image=None,
    text="formula one",
    search_size=10, 
    k=10, 
    time_offset=1, 
    video_name=None
)

The output of the search function is a dictionary that maps each video with matches to a set of **temporal clusters**. Each temporal cluster represents a video segment in which the content semantically matches the search query. A cluster is short if the match is based on just one or a few video frames, and longer if, for example, a larger video scene displays exactly what we are looking for (so that multiple frames match across the entire temporal cluster).
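The idea of a temporal cluster can be illustrated with a small sketch. The actual clustering runs inside the search endpoint and is not shown in this notebook, so the grouping rule used here (merge matching frames whose gap is at most `time_offset` seconds, and keep the best score) is an assumption, and `cluster_timestamps` is a hypothetical helper.

```python
def cluster_timestamps(hits, time_offset=1.0):
    """Group (timestamp_seconds, score) pairs into temporal clusters.

    Assumption: two matching frames belong to the same cluster when they are
    at most `time_offset` seconds apart; the cluster score is the best frame score.
    """
    clusters = []
    for t, score in sorted(hits):
        if clusters and t - clusters[-1]["t_end"] <= time_offset:
            # Extend the current cluster and keep its highest score
            clusters[-1]["t_end"] = t
            clusters[-1]["score"] = max(clusters[-1]["score"], score)
        else:
            # Gap is too large: start a new cluster
            clusters.append({"t_start": t, "t_end": t, "score": score})
    return clusters


hits = [(12.0, 0.31), (3.0, 0.28), (4.0, 0.35), (5.0, 0.30)]
clusters = cluster_timestamps(hits, time_offset=1.0)
for c in clusters:
    print(c)
```

With these sample hits, the frames at 3 s, 4 s, and 5 s merge into one longer cluster, while the isolated frame at 12 s forms a single-frame cluster.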

We need to implement a couple of functions to visualize search results.

In [None]:
from PIL import Image
from PIL import ImageDraw
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np


def image_grid(imgs, rows, cols):
    """Paste equally sized images into a single rows x cols grid image."""
    if len(imgs) != rows * cols:
        raise ValueError("number of images must match the grid size (rows * cols)")

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Fill the grid row by row, left to right
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def get_indices(t_start, t_end, num_indices):
    """Convert evenly spaced timestamps (seconds) to frame indices, assuming 30 fps."""
    indices = 30 * np.linspace(t_start, t_end, num=num_indices)
    indices = indices.astype('int')
    return list(indices)


def get_frames_from_timestamps(video_path, t_start, t_end, num_frames):
    """Read `num_frames` evenly spaced frames between t_start and t_end (seconds)."""
    output_frames = []
    output_timestamps = []
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError("Failed to open video '%s'" % video_path)

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Evenly spaced timestamps converted to frame indices
        indices = (np.linspace(t_start, t_end, num_frames) * fps).astype('int')

        for i in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
            success, frame = cap.read()
            if not success:
                # Happens, for example, when the index is past the end of the video
                raise IOError("Failed to read frame %d from '%s'" % (i, video_path))
            # OpenCV decodes frames as BGR; PIL expects RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            output_frames.append(Image.fromarray(frame))
            output_timestamps.append(cap.get(cv2.CAP_PROP_POS_MSEC) / 1000)
    finally:
        cap.release()

    return output_frames, output_timestamps



def visualize_results(results, video_dir=local_video_dir):
    # Map video names returned by the search endpoint to local file paths;
    # results for videos that are not available locally are skipped
    results_local = {video_samples_local[k]: v for (k, v) in results.items() if k in video_samples_local}
    
    all_results = []
    for k, v_list in results_local.items():
        for v in v_list:
            all_results.append({
                'score': v['score'],
                'path': k, 
                't_start': v['t_start'],
                't_end': v['t_end']
            })

    all_results = sorted(all_results, key=lambda x: x['score'], reverse=True)
    # Visualize each temporal cluster, best match first
    for x in all_results:
        # Pad the cluster by 0.5 s on both sides without going below 0 s
        t_from = max(x['t_start'] - 0.5, 0)
        t_to = x['t_end'] + 0.5
        title = f"{x['path'].split('/')[-1].replace('.mp4', '')}: {t_from:.0f}s - {t_to:.0f}s"
        print(title)
        frames, _ = get_frames_from_timestamps(x['path'], t_from, t_to, 5)

        central_frame = frames[len(frames) // 2]
        draw = ImageDraw.Draw(central_frame)
        width, height = central_frame.size
        draw.rectangle([(0, 0), (width, height)], outline='red', width=30)
        
        image = image_grid(frames, 1, 5)
        plt.figure(dpi=180)
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])
        plt.box(False)
        plt.show()

### A. Text Search Across Videos

Semantic search across all videos in the AOSS vector index (no keyword filters or search terms):

In [None]:
results = search_index(
    predictor=endpoint_predictor_search,
    image=None,
    text="a photo of a dog",
    search_size=10, 
    k=10, 
    time_offset=2, 
    video_name=None
)


visualize_results(results)

### B. Text Search Within a Video

Hybrid search example (semantic and keyword search) to query scenes from a specific video:

In [None]:
results = search_index(
    predictor=endpoint_predictor_search,
    image=None,
    text="A photo of a sketch on whiteboard",
    search_size=10, 
    k=10, 
    time_offset=1, 
    video_name='trucking'
)
visualize_results(results)  
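Passing `video_name` restricts the semantic search to a single video. One way such a hybrid query could be expressed in OpenSearch is a `bool` query that combines a `term` filter with a `knn` clause. This is a sketch under assumptions: the query built by the search endpoint is not shown in this notebook, and the helper name `build_filtered_knn_query` and the field names `vector` and `name` are hypothetical.

```python
def build_filtered_knn_query(query_vector, video_name=None, search_size=10, k=10,
                             vector_field="vector", name_field="name"):
    """Build a k-NN query body, optionally restricted to one video by an exact-match filter."""
    knn_clause = {"knn": {vector_field: {"vector": list(query_vector), "k": k}}}
    if video_name is None:
        # Pure semantic search across all videos
        return {"size": search_size, "query": knn_clause}
    return {
        "size": search_size,
        "query": {
            "bool": {
                # Keyword part: exact match on the (hypothetical) video name field
                "filter": [{"term": {name_field: video_name}}],
                # Semantic part: nearest neighbors of the query embedding
                "must": [knn_clause],
            }
        },
    }


hybrid_body = build_filtered_knn_query([0.1, 0.2, 0.3], video_name="trucking")
plain_body = build_filtered_knn_query([0.1, 0.2, 0.3])
print(hybrid_body)
```

Because the filter in this form narrows down the neighbors found by the k-NN clause, a larger `k` may be needed to still obtain `search_size` results from the selected video.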

### C. Reverse Image Search

Reverse image search to query video segments that are similar to the provided image sample:

<img src="sample_images/fashionshow.jpg" width="700px">

In [None]:
results = search_index(
    predictor=endpoint_predictor_search,
    image="sample_images/fashionshow.jpg",
    text=None,
    search_size=10, 
    k=10, 
    time_offset=3, 
    video_name=None
)
visualize_results(results)  
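For image queries, `search_index` sends the image as base64 text inside the JSON payload, because JSON cannot carry raw bytes. The round trip below is a minimal sketch; how the endpoint decodes the field is not shown in this notebook, so `decode_image_field` is a hypothetical counterpart added for illustration.

```python
import base64
import json


def encode_image_bytes(raw):
    """Encode raw image bytes as base64 text, as search_index does before sending."""
    return base64.b64encode(raw).decode("utf-8")


def decode_image_field(encoded):
    """Recover the raw bytes on the receiving side (hypothetical endpoint logic)."""
    return base64.b64decode(encoded)


raw = bytes(range(256))  # stand-in for the content of an image file
payload = json.dumps({"query": {"image": encode_image_bytes(raw), "text": None}})
restored = decode_image_field(json.loads(payload)["query"]["image"])
print(restored == raw)
```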

---

## 11. Clean up resources

<a id='sec-11'></a>

In [None]:
# Clean up AOSS resources
aoss_client.delete_collection(id=aoss_collection_id)
aoss_client.delete_access_policy(type="data", name=aoss_data_policy['accessPolicyDetail']['name'])
aoss_client.delete_security_policy(type="network", name=aoss_network_policy['securityPolicyDetail']['name'])
aoss_client.delete_security_policy(type="encryption", name=aoss_encryption_policy['securityPolicyDetail']['name'])

In [None]:
# Clean up SageMaker resources
sm_client = boto3.client('sagemaker')
sm_client.delete_model(ModelName=sm_model_name)
sm_client.delete_endpoint(EndpointName=sm_endpoint_name_index)
sm_client.delete_endpoint_config(EndpointConfigName=sm_endpoint_name_index)
sm_client.delete_endpoint(EndpointName=sm_endpoint_name_search)
sm_client.delete_endpoint_config(EndpointConfigName=sm_endpoint_name_search)

In [None]:
# Clean up S3 resources
s3_resource = boto3.resource('s3')
s3_bucket = s3_resource.Bucket(bucket_name)
s3_bucket.objects.filter(Prefix=bucket_prefix).delete()
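If one of the delete calls above fails (for example, because a resource was already removed), the rest of that cell does not run. A minimal sketch of a more tolerant cleanup follows; `run_cleanup` is a hypothetical helper, and the demo uses stand-in callables instead of real AWS calls.

```python
def run_cleanup(steps):
    """Run each (label, action) pair, continue on errors, and return the failures."""
    failures = []
    for label, action in steps:
        try:
            action()
            print(f"deleted: {label}")
        except Exception as exc:
            # Record the failure and continue with the remaining resources
            failures.append((label, exc))
            print(f"skipped: {label} ({exc})")
    return failures


def _already_gone():
    raise RuntimeError("resource not found")


deleted = []
failures = run_cleanup([
    ("model", lambda: deleted.append("model")),
    ("index endpoint", _already_gone),
    ("search endpoint", lambda: deleted.append("search endpoint")),
])
```

In the notebook, each step could wrap one of the calls above, for example `("index endpoint", lambda: sm_client.delete_endpoint(EndpointName=sm_endpoint_name_index))`.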