In [None]:
# Install/upgrade the Google Cloud client libraries used by this notebook
# (Vertex AI, BigQuery, Cloud Storage, BigQuery DataFrames, pandas-gbq, db-dtypes).
! pip3 install --upgrade google-cloud-aiplatform \
                         google-cloud-bigquery\
                         google-cloud-storage\
                         bigframes\
                         pandas-gbq\
                         db-dtypes

In [None]:
# Authenticate the Colab session; the Google Cloud clients created in later
# cells presumably pick up these user credentials.
from google.colab import auth
auth.authenticate_user()

In [None]:
# BigQuery settings
PROJECT_ID = ""  # @param {type:"string"} Google Cloud project ID
REGION = "us-central1"  # @param {type:"string"} Region to use
BQ_DATASET_ID = ""  # @param {type:"string"} BigQuery dataset ID
BQ_TABLE_ID = ""  # @param {type:"string"} BigQuery table ID

# Feature Store settings
FEATURE_ONLINE_STORE_ID = ""  # @param {type:"string"} Feature Store online store ID
FEATURE_VIEW_ID = ""  # @param {type:"string"} Feature Store feature view ID

# Schedule settings
# The schedule is created from the CRON setting.
# Per the original notes, an empty CRON starts an immediate sync job instead
# (the code that consumes CRON_SCHEDULE is not part of this section).
CRON_SCHEDULE = "TZ=Asia/Tokyo 0 9 * * *"  # @param {type:"string"} Schedule (every day at 09:00 Tokyo time)

# Vector search settings
DIMENSIONS = 1408  # @param {type:"number"} Number of vector dimensions
EMBEDDING_COLUMN = "embedding"  # @param {type:"string"} Name of the column holding the embeddings

# Optional settings
LEAF_NODE_EMBEDDING_COUNT = 10000  # @param {type:"number"} Number of embeddings per leaf node (optional)
# NOTE(review): FILTER_COLUMNS holds a Python list but is declared as a
# "string" form field; editing it through the Colab form would likely replace
# the list with a plain string. Consider {type:"raw"} — verify.
FILTER_COLUMNS = ["title"]  # @param {type:"string"} Columns used for filtering (optional)

# BigQuery table used as the Feature Store data source.
# Despite its name, BQ_TABLE_ID_FQN is only "dataset.table"; the project ID is
# prepended when building the bq:// URI on the next line.
BQ_TABLE_ID_FQN = f"{BQ_DATASET_ID}.{BQ_TABLE_ID}"
DATA_SOURCE = f"bq://{PROJECT_ID}.{BQ_TABLE_ID_FQN}"

# Regional Vertex AI API endpoint (used for the online store admin client).
API_ENDPOINT = f"{REGION}-aiplatform.googleapis.com"

# Public endpoint of the online store (used for the data/search client).
# Replace the placeholder with your store's endpoint.
PUBLIC_ENDPOINT="xxxxx-sample.vdb.vertexai.goog" # @param {type:"string"}

# Base URL of the GCS location that serves the images publicly. Entity IDs that
# are not already full URLs are appended to it when displaying search results.
GCS_HOST_PATH = "" # @param {type:"string"}

In [None]:
# Google Cloud AI Platform関連のインポート
from google.cloud import aiplatform
from google.cloud.aiplatform_v1beta1 import (
    FeatureOnlineStoreAdminServiceClient,
    FeatureOnlineStoreServiceClient
)
from google.cloud.aiplatform_v1beta1.types import (
    NearestNeighborQuery,
    feature_online_store as feature_online_store_pb2,
    feature_online_store_admin_service as feature_online_store_admin_service_pb2,
    feature_online_store_service as feature_online_store_service_pb2,
    feature_view as feature_view_pb2
)
from google.protobuf import struct_pb2

# Google Cloud BigQuery関連のインポート
from google.cloud import bigquery

# Google Cloud Storage関連のインポート
from google.cloud import storage

# その他のインポート
import bigframes.pandas as bpd
import random
import base64
import time
import typing



# Set the default project and location for the Vertex AI SDK.
aiplatform.init(project=PROJECT_ID, location=REGION)

# Admin client on the regional API endpoint; presumably used to create/manage
# the online store and feature view (not used within this section).
admin_client = FeatureOnlineStoreAdminServiceClient(
    client_options={"api_endpoint": API_ENDPOINT}
)

# Data client used for search_nearest_entities; it connects to the online
# store's public endpoint rather than the regional API endpoint.
data_client = FeatureOnlineStoreServiceClient(
    client_options={"api_endpoint": PUBLIC_ENDPOINT}
)


# Set BigQuery DataFrames options
# NOTE(review): the location is hard-coded to "us" while REGION may differ;
# it presumably must match the BigQuery dataset's location — confirm.
bpd.options.bigquery.project = PROJECT_ID
bpd.options.bigquery.location = "us"

def list_gcs_files(bucket_name, prefix):
    """Return the names of objects under ``prefix`` in ``bucket_name``, relative to the prefix.

    Args:
        bucket_name: Name of the GCS bucket (without ``gs://``).
        prefix: Object-name prefix to list, e.g. ``"images/"``. ``None`` or
            ``""`` lists the whole bucket and returns full object names.

    Returns:
        A list of object names with the *leading* ``prefix`` removed. A name
        that contains the prefix again further along (``"images/a/images/b.jpg"``)
        keeps its inner occurrence (``"a/images/b.jpg"``).
    """
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=prefix)
    # Strip only the leading prefix; str.replace() would also delete any later
    # occurrence of the prefix inside the object name.
    return [_strip_prefix(blob.name, prefix) for blob in blobs]


def _strip_prefix(name, prefix):
    """Return ``name`` without a leading ``prefix`` (unchanged when absent or prefix is falsy)."""
    if prefix and name.startswith(prefix):
        return name[len(prefix):]
    return name

def list_files_from_metadata(metadata_paths):
    """Return image file names listed in GCS metadata text files.

    Each metadata file is downloaded as text. Every non-blank line is treated
    as an image name without extension; surrounding whitespace is stripped and
    ``".jpg"`` is appended.

    Args:
        metadata_paths: Iterable of ``gs://bucket/path/to/file.txt`` URIs.

    Returns:
        Flat list of ``"<line>.jpg"`` names, in file order, then line order.

    Raises:
        ValueError: if a path is not of the form ``gs://bucket/object``.
    """
    file_list = []
    storage_client = storage.Client()

    for metadata_path in metadata_paths:
        bucket_name, file_path = _split_gcs_uri(metadata_path)
        blob = storage_client.bucket(bucket_name).blob(file_path)
        content = blob.download_as_text()
        file_list.extend(_image_names_from_listing(content))

    return file_list


def _split_gcs_uri(uri):
    """Split ``gs://bucket/object`` into ``(bucket, object)``; raise ValueError otherwise."""
    scheme = "gs://"
    if not uri.startswith(scheme):
        raise ValueError(f"Expected a gs:// URI, got {uri!r}")
    bucket_name, sep, object_path = uri[len(scheme):].partition("/")
    if not bucket_name or not sep or not object_path:
        raise ValueError(f"Expected gs://bucket/object, got {uri!r}")
    return bucket_name, object_path


def _image_names_from_listing(content):
    """Map each non-blank line of ``content`` to ``"<stripped line>.jpg"``."""
    # Blank lines would otherwise yield a bogus ".jpg" entry, and stray
    # whitespace would end up inside the file name.
    return [line.strip() + ".jpg" for line in content.splitlines() if line.strip()]

def upload_to_gcs(local_file, bucket_name, gcs_path):
    """Copy the local file ``local_file`` to ``gs://{bucket_name}/{gcs_path}``."""
    destination = storage.Client().bucket(bucket_name).blob(gcs_path)
    destination.upload_from_filename(local_file)

In [None]:
import base64
import time
import typing

from google.cloud import aiplatform
from google.protobuf import struct_pb2


class EmbeddingResponse(typing.NamedTuple):
    """Embeddings returned by ``EmbeddingPredictionClient.get_embedding``.

    Each field is None when the corresponding input (text or image) was not
    part of the request.
    """

    text_embedding: typing.Optional[typing.Sequence[float]]  # None if no text was embedded
    image_embedding: typing.Optional[typing.Sequence[float]]  # None if no image was embedded


def load_image_bytes(image_uri: str) -> typing.Optional[bytes]:
    """Load image bytes from a remote (http/https) URL or a local file path.

    Args:
        image_uri: An ``http://``/``https://`` URL, or a local filesystem path.

    Returns:
        The raw image bytes. For URLs, ``None`` is returned when the server
        answers with a status other than 200.

    Raises:
        OSError: if a local file cannot be opened or read.
    """
    if image_uri.startswith(("http://", "https://")):
        # Imported locally: this cell does not import `requests` itself, so
        # relying on a later cell's global import raised NameError when the
        # function was called before that cell ran.
        import requests

        with requests.get(image_uri, stream=True) as response:
            if response.status_code == 200:
                return response.content
        return None

    # Context manager so the file handle is closed deterministically.
    with open(image_uri, "rb") as image_file:
        return image_file.read()


class EmbeddingPredictionClient:
    """Wrapper around the Vertex AI Prediction Service client for the
    ``multimodalembedding@001`` publisher model.

    Args:
        project: Google Cloud project ID used in the model endpoint path.
        location: Region of the model endpoint.
        api_regional_endpoint: API host to connect to. When omitted it is
            derived from ``location`` as ``f"{location}-aiplatform.googleapis.com"``,
            which equals the former fixed default for the default location.
    """

    def __init__(
        self,
        project: str,
        location: str = "us-central1",
        api_regional_endpoint: typing.Optional[str] = None,
    ):
        if api_regional_endpoint is None:
            # Follow `location` so a non-default region does not silently talk
            # to the us-central1 API host.
            api_regional_endpoint = f"{location}-aiplatform.googleapis.com"
        client_options = {"api_endpoint": api_regional_endpoint}
        # Initialize the client used to create and send requests. It only needs
        # to be created once and can be reused for multiple requests.
        self.client = aiplatform.gapic.PredictionServiceClient(
            client_options=client_options
        )
        self.location = location
        self.project = project

    def get_embedding(
        self,
        text: typing.Optional[str] = None,
        image_file: typing.Optional[str] = None,
    ) -> EmbeddingResponse:
        """Return text and/or image embeddings for the given inputs.

        Args:
            text: Text to embed, or None.
            image_file: Image URL (http/https) or local path to embed, or None.

        Returns:
            EmbeddingResponse whose ``text_embedding`` / ``image_embedding`` is a
            list of floats for each input that was sent, and None otherwise.

        Raises:
            ValueError: if neither input is given, or if ``image_file`` was given
                but no image bytes could be loaded from it.
        """
        if not text and not image_file:
            raise ValueError("At least one of text or image_file must be specified.")

        image_bytes = None
        if image_file:
            image_bytes = load_image_bytes(image_file)
            if not image_bytes:
                # Without this check the request would be sent without the
                # image (or as an empty instance) and the failure would only
                # surface as a missing/odd embedding.
                raise ValueError(f"Could not load image bytes from {image_file!r}.")

        instance = struct_pb2.Struct()
        if text:
            instance.fields["text"].string_value = text
        if image_bytes:
            encoded_content = base64.b64encode(image_bytes).decode("utf-8")
            image_struct = instance.fields["image"].struct_value
            image_struct.fields["bytesBase64Encoded"].string_value = encoded_content

        endpoint = (
            f"projects/{self.project}/locations/{self.location}"
            "/publishers/google/models/multimodalembedding@001"
        )
        response = self.client.predict(endpoint=endpoint, instances=[instance])
        prediction = response.predictions[0]

        # Copy the protobuf repeated values into plain Python lists.
        text_embedding = list(prediction["textEmbedding"]) if text else None
        image_embedding = list(prediction["imageEmbedding"]) if image_bytes else None

        return EmbeddingResponse(
            text_embedding=text_embedding, image_embedding=image_embedding
        )

In [None]:
import copy
from typing import List, Optional

import numpy as np
import requests
from tenacity import retry, stop_after_attempt

# Shared prediction client used by the encode_* helpers below; created once
# and reused for every request.
client = EmbeddingPredictionClient(project=PROJECT_ID)


# Use a retry handler in case of failure: up to 3 attempts, and with
# reraise=True the last exception is re-raised instead of a RetryError.
@retry(reraise=True, stop=stop_after_attempt(3))
def encode_texts_to_embeddings_with_retry(text: List[str]) -> List[List[float]]:
    """Embed a single text and return it wrapped in a one-element list.

    Args:
        text: A list containing exactly one string.

    Returns:
        ``[embedding]`` for that string.

    Raises:
        ValueError: if ``text`` does not contain exactly one item.
        RuntimeError: if the embedding request fails; the original error is
            chained as ``__cause__``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(text) != 1:
        raise ValueError(f"Expected exactly one text, got {len(text)}.")

    try:
        return [client.get_embedding(text=text[0], image_file=None).text_embedding]
    except Exception as ex:
        raise RuntimeError("Error getting embedding.") from ex


def encode_texts_to_embeddings(text: List[str]) -> List[Optional[List[float]]]:
    """Best-effort text embedding.

    Args:
        text: Texts to embed (the retrying helper accepts exactly one).

    Returns:
        The embeddings on success, or ``[None] * len(text)`` when embedding
        fails, so the result stays aligned with the input.
    """
    try:
        return encode_texts_to_embeddings_with_retry(text=text)
    except Exception as ex:
        # Deliberately best-effort, but report the failure (as the image
        # variant does) instead of swallowing it silently.
        print(ex)
        return [None] * len(text)


@retry(reraise=True, stop=stop_after_attempt(3))
def encode_images_to_embeddings_with_retry(image_uris: List[str]) -> List[List[float]]:
    """Embed a single image and return it wrapped in a one-element list.

    Retried up to 3 attempts; with reraise=True the last exception propagates.

    Args:
        image_uris: A list containing exactly one image URL or local path.

    Returns:
        ``[embedding]`` for that image.

    Raises:
        ValueError: if ``image_uris`` does not contain exactly one item.
        RuntimeError: if the embedding request fails; the original error is
            chained as ``__cause__``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(image_uris) != 1:
        raise ValueError(f"Expected exactly one image URI, got {len(image_uris)}.")

    try:
        return [
            client.get_embedding(text=None, image_file=image_uris[0]).image_embedding
        ]
    except Exception as ex:
        print(ex)
        raise RuntimeError("Error getting embedding.") from ex


def encode_images_to_embeddings(image_uris: List[str]) -> List[Optional[List[float]]]:
    """Best-effort image embedding.

    Any failure from the retrying helper is printed and replaced by one
    ``None`` placeholder per input URI, keeping the output aligned with the input.
    """
    try:
        embeddings = encode_images_to_embeddings_with_retry(image_uris=image_uris)
    except Exception as err:
        print(err)
        embeddings = [None] * len(image_uris)
    return embeddings

In [None]:
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generator, List

from tqdm.auto import tqdm


def generate_batches(
    inputs: List[str], batch_size: int
) -> Generator[List[str], None, None]:
    """Lazily yield consecutive slices of ``inputs`` with at most ``batch_size`` items each.

    The last slice is shorter when ``len(inputs)`` is not a multiple of
    ``batch_size``; an empty ``inputs`` yields nothing.
    """
    yield from (
        inputs[start : start + batch_size]
        for start in range(0, len(inputs), batch_size)
    )


# Throttle rate for encode_to_embeddings_chunked: it sleeps
# batch_size / API_IMAGES_PER_SECOND seconds after submitting each batch.
API_IMAGES_PER_SECOND = 2


def encode_to_embeddings_chunked(
    process_function: Callable[[List[str]], List[Optional[List[float]]]],
    items: List[str],
    batch_size: int = 1,
) -> List[Optional[List[float]]]:
    """Encode ``items`` into embeddings by submitting batches to a thread pool.

    Submissions are throttled: after each batch the loop sleeps
    ``batch_size / API_IMAGES_PER_SECOND`` seconds. Results are collected in
    submission order, so the output is aligned with ``items`` as long as
    ``process_function`` returns one entry (or None) per input.

    Args:
        process_function: Callable that embeds one batch, e.g.
            ``encode_images_to_embeddings``.
        items: Inputs (texts or image URIs) to embed.
        batch_size: Number of items passed to ``process_function`` per call.

    Returns:
        Concatenated results of all batches, in input order.

    Raises:
        Any exception raised by ``process_function`` is re-raised by
        ``future.result()``.
    """
    embeddings_list: List[Optional[List[float]]] = []

    # Prepare the batches using a generator
    batches = generate_batches(items, batch_size)
    # Ceiling division so the final, partial batch is counted in the progress
    # bar total (floor division under-reported it).
    num_batches = (len(items) + batch_size - 1) // batch_size

    seconds_per_job = batch_size / API_IMAGES_PER_SECOND

    with ThreadPoolExecutor() as executor:
        futures = []
        for batch in tqdm(batches, total=num_batches, position=0):
            futures.append(executor.submit(process_function, batch))
            time.sleep(seconds_per_job)

        for future in futures:
            embeddings_list.extend(future.result())
    return embeddings_list

In [None]:
import math
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image

# Run a semantic search.
# Set the sentence or keywords to search for.
text_query = "please teach me okinawa soul food"

# Calculate the text embedding of the query. encode_texts_to_embeddings returns
# [None] when embedding fails; stop here with a clear message instead of
# sending a None vector in the search request.
text_embedding = encode_texts_to_embeddings(text=[text_query])[0]
if text_embedding is None:
    raise RuntimeError(f"Could not compute an embedding for query {text_query!r}.")

result = data_client.search_nearest_entities(
    request=feature_online_store_service_pb2.SearchNearestEntitiesRequest(
        feature_view=f"projects/{PROJECT_ID}/locations/{REGION}/featureOnlineStores/{FEATURE_ONLINE_STORE_ID}/featureViews/{FEATURE_VIEW_ID}",
        query=NearestNeighborQuery(
            embedding=NearestNeighborQuery.Embedding(value=text_embedding),
            neighbor_count=20,
            # Optionally restrict results with string filters, e.g.:
            # string_filters=[country_filter],
        ),
        return_full_entity=True,  # returning entities with metadata
    )
)

neighbors = result.nearest_neighbors.neighbors
selected_paths = [neighbor.entity_id for neighbor in neighbors]
distances = [neighbor.distance for neighbor in neighbors]


# Set the maximum number of images to display
MAX_IMAGES = 20

# Pair each entity ID with its distance, keeping only the first occurrence of
# a duplicated ID.
unique_pairs = {}
for path, distance in zip(selected_paths, distances):
    unique_pairs.setdefault(path, distance)

# Sort the de-duplicated pairs by ascending distance and keep the closest ones.
sorted_data = sorted(unique_pairs.items(), key=lambda x: x[1])[:MAX_IMAGES]

# Calculate the number of rows and columns needed to display the images
num_cols = 4
num_rows = math.ceil(len(sorted_data) / num_cols)

if not sorted_data:
    # plt.subplots(nrows=0) would raise, so report the empty result instead.
    print(f"No neighbors were returned for query {text_query!r}.")
else:
    # squeeze=False keeps `axs` a 2-D array even when there is a single row,
    # so axs[row_idx, col_idx] indexing is always valid.
    fig, axs = plt.subplots(
        nrows=num_rows, ncols=num_cols, figsize=(10, 12), squeeze=False
    )

    for i, (image_path, distance) in enumerate(sorted_data):
        row_idx, col_idx = divmod(i, num_cols)
        ax = axs[row_idx, col_idx]

        # Entity IDs that are not already URLs are treated as paths under
        # /content/ and mapped onto the public GCS host.
        if not image_path.startswith(("http://", "https://")):
            image_path = GCS_HOST_PATH + image_path.replace("/content/", "")
        response = requests.get(image_path)
        # Surface HTTP errors directly rather than as a PIL decode error.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))

        ax.imshow(image, cmap="gray")

        # Title: the parent path segment of the image (the whole path if it
        # has no "/"). `distance` is also available here for a rank/score title.
        path_parts = image_path.split("/")
        ax.set_title(path_parts[-2] if len(path_parts) >= 2 else image_path)

        # Remove ticks from the subplot
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide the unused grid cells in the last row.
    for unused_ax in axs.ravel()[len(sorted_data):]:
        unused_ax.axis("off")

    # Adjust the spacing between subplots and display the plot
    plt.subplots_adjust(hspace=0.3, wspace=0.1)
    plt.show()
