In [23]:
from tqdm import tqdm
import typing, json
from absl import app
from google.cloud import aiplatform
import base64
from google.cloud import storage
from google.protobuf import struct_pb2
from google.oauth2 import service_account

In [2]:
class EmbeddingResponse(typing.NamedTuple):
  """Pair of embeddings returned for one text and/or image input.

  Field order matters: callers may index the tuple positionally
  (``[0]`` is the text embedding, ``[1]`` is the image embedding).
  """
  # Embedding of the text input; None when no text was supplied.
  text_embedding: typing.Optional[typing.Sequence[float]]
  # Embedding of the image input; None when no image was supplied.
  image_embedding: typing.Optional[typing.Sequence[float]]

In [201]:
class EmbeddingPredictionClient:
    """Wrapper around Prediction Service Client.

    Builds a regional ``PredictionServiceClient`` authenticated with a service
    account key file and uses it to request embeddings from the
    ``multimodalembedding@001`` publisher model.
    """

    def __init__(
        self,
        project: str,
        location: str = "asia-northeast3",
        credentials_path: str = './ai-solution-genai-hack2skill-ee1fa6f57032.json',
    ):
        """Create the underlying prediction client.

        Args:
            project: Project id used when building the model endpoint path.
            location: Region; selects both the API host and the endpoint path.
            credentials_path: Path to the service account JSON key file. The
                default keeps the path this class previously had hard-coded.
        """
        api_regional_endpoint = f"{location}-aiplatform.googleapis.com"
        client_options = {"api_endpoint": api_regional_endpoint}
        cred = service_account.Credentials.from_service_account_file(credentials_path)

        # Initialize client that will be used to create and send requests.
        # This client only needs to be created once, and can be reused for multiple requests.
        self.client = aiplatform.gapic.PredictionServiceClient(
            credentials=cred, client_options=client_options
        )
        self.location = location
        self.project = project

    def get_embedding(
        self,
        text: typing.Optional[str] = None,
        image_bytes: typing.Optional[bytes] = None,
    ) -> "EmbeddingResponse":
        """Request embeddings for a text and/or an image.

        Args:
            text: Text to embed. Empty or None means "no text".
            image_bytes: Raw image bytes to embed; they are base64-encoded
                before being sent. Empty or None means "no image".

        Returns:
            An ``EmbeddingResponse`` whose ``text_embedding`` /
            ``image_embedding`` is a list of the values returned by the model,
            or None for an input that was not supplied.

        Raises:
            ValueError: If neither ``text`` nor ``image_bytes`` is given.
        """
        if not text and not image_bytes:
            raise ValueError('At least one of text or image_bytes must be specified.')

        # Only populate the request fields for the inputs that were provided.
        instance = struct_pb2.Struct()
        if text:
            instance.fields['text'].string_value = text

        if image_bytes:
            encoded_content = base64.b64encode(image_bytes).decode("utf-8")
            image_struct = instance.fields['image'].struct_value
            image_struct.fields['bytesBase64Encoded'].string_value = encoded_content

        endpoint = (
            f"projects/{self.project}/locations/{self.location}"
            "/publishers/google/models/multimodalembedding@001"
        )
        response = self.client.predict(endpoint=endpoint, instances=[instance])
        # A single instance was sent, so only the first prediction is read.
        prediction = response.predictions[0]

        text_embedding = list(prediction['textEmbedding']) if text else None
        image_embedding = list(prediction['imageEmbedding']) if image_bytes else None

        return EmbeddingResponse(
            text_embedding=text_embedding,
            image_embedding=image_embedding,
        )


def limit_dict_values(data, limit=500):
  """Return a new dict in which every string value is cut to ``limit`` characters.

  Values that are not strings are carried over as they are. The input
  mapping is not modified.
  """
  return {
    field: (content[:limit] if isinstance(content, str) else content)
    for field, content in data.items()
  }


def dict2text(dict_data):
    """Render a mapping as ``"key is value"`` phrases separated by ``". "``.

    Example: ``{"name": "chair", "color": "red"}`` becomes
    ``"name is chair. color is red"``.

    The previous implementation appended ``". "`` after every entry and then
    called ``rstrip(".")``, which never removed anything because the string
    ended with a space. Joining the phrases avoids emitting the trailing
    separator in the first place.

    Args:
        dict_data: Mapping whose keys and values are formatted with ``str()``.

    Returns:
        The joined text, or an empty string for an empty mapping.
    """
    return ". ".join(f"{key} is {value}" for key, value in dict_data.items())


def rename_keys(d):
  """Return a new dict keyed by the part of each original key before its first ``_``.

  Keys without an underscore are kept whole. When several original keys
  shorten to the same name, the value of the last one wins.
  """
  return {name.partition("_")[0]: payload for name, payload in d.items()}

In [216]:
# Project / storage configuration for the IKEA image dataset.
project = "ai-solution-genai-hack2skill"
location = "asia-northeast3"
bucket_name = "hack2skill_dataset"
prefix = "ikea/images"
blob_name = "ikea/ikea_results.json"

# Embedding client used by the indexing loop.
client = EmbeddingPredictionClient(project = project, location = location)

# Cloud Storage access: fetch the bucket once and reuse the handle for both
# the image listing and the metadata download (avoids a second get_bucket call).
cred = service_account.Credentials.from_service_account_file('./ai-solution-genai-hack2skill-ee1fa6f57032.json')
storage_client = storage.Client(credentials=cred)
bucket = storage_client.get_bucket(bucket_name)
files = bucket.list_blobs(prefix=prefix)

# Load the item metadata JSON and shorten its keys to the part before the
# first "_" so they can be matched against image file names.
# download_as_bytes replaces the deprecated download_as_string alias.
blob = bucket.blob(blob_name)
json_string = blob.download_as_bytes().decode('utf-8')
ikea_meta = json.loads(json_string)
ikea_meta = rename_keys(ikea_meta)

In [None]:
# Embed every image under the prefix together with its metadata text and
# append one JSON line per image to ikea_indexData.json.
for file in tqdm(files):
    # Skip non-image blobs; `or ""` guards against a blob whose content_type is None.
    if "image" not in (file.content_type or ""):
        continue

    # The part of the file name before the first "_" is the metadata key.
    item_key = file.name.split("/")[-1].split("_")[0]

    # Build a filtered copy instead of pop()-ing "link" from the shared
    # ikea_meta entry: popping in place raised KeyError for a second image of
    # the same item and whenever the cell was re-run.
    item_meta = {k: v for k, v in ikea_meta[item_key].items() if k != "link"}
    text_file_contents = dict2text(limit_dict_values(item_meta))

    with file.open('rb') as image_file:
        image_file_contents = image_file.read()
    response = client.get_embedding(
        text=text_file_contents,
        image_bytes=image_file_contents,
    )

    # NOTE(review): str() of a bytes object yields its repr, so the id is
    # written as b'...'. Kept as-is so the output format stays unchanged;
    # confirm whether consumers of the index file expect this form.
    encoded_name = file.name.encode(encoding='UTF-8', errors='strict')

    text_values = ",".join(str(x) for x in response[0])
    image_values = ",".join(str(x) for x in response[1])

    # Append mode: each image's line is persisted as soon as it is embedded.
    with open("ikea_indexData.json", "a") as f:
        f.write('{"id":"' + str(encoded_name) + '",')
        f.write('"embedding":[[' + text_values + "],[" + image_values + "]]}")
        f.write("\n")

131it [01:24,  1.55it/s]