## Prerequisites

1. Stacks must have been successfully deployed
2. The Translation Memory Table must have been created as described in the Readme.md

## Setup
Install the Python libraries required for the workshop.


In [None]:
# Install all the required prerequisite libraries - approx 3 min to complete
%pip install -r requirements.txt
%pip install -r bedrock_requirements.txt

## Generate Embeddings for Sample Data

### Load sample data

In [None]:
import boto3
import json

# Create the Bedrock clients using boto3's default credentials/region resolution.
bedrock = boto3.client("bedrock")
# Runtime client: this is the one the embedding helper calls invoke_model on.
bedrock_runtime = boto3.client("bedrock-runtime")

In [None]:
import pandas as pd

# Read the sample CSV into a DataFrame and report how many rows it contains.
df = pd.read_csv('data/wmt19_fr-de.csv')
record_count = len(df.index)
print("Total number of records : {}".format(record_count))

# Preview the first two rows.
display(df.head(2))

### Generate Embeddings

In [None]:
def generate_embeddings(query, model_id='amazon.titan-embed-text-v2:0', client=None):
    """Return the embedding vector for ``query`` from a Bedrock embedding model.

    Args:
        query: Text to embed. Must be a non-empty string.
        model_id: Model identifier passed to ``invoke_model``. Defaults to the
            Titan embedding model this notebook was written against.
        client: Object exposing ``invoke_model(body=, modelId=, accept=,
            contentType=)``, e.g. a boto3 ``bedrock-runtime`` client. When
            omitted, the notebook-level ``bedrock_runtime`` client is used.

    Returns:
        The value of the ``embedding`` field of the model's JSON response
        (``None`` if the response carries no such field).

    Raises:
        ValueError: If ``query`` is not a string or is the empty string.
    """
    # Fail early with a clear message (e.g. for a NaN coming from an empty
    # CSV cell) instead of sending a request the service cannot embed.
    if not isinstance(query, str) or not query:
        raise ValueError("query must be a non-empty string, got {!r}".format(query))

    runtime = bedrock_runtime if client is None else client
    payload = json.dumps({'inputText': query})

    response = runtime.invoke_model(
        body=payload,
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
    )
    # The response body is a stream; read it fully before decoding the JSON.
    response_body = json.loads(response["body"].read())
    return response_body.get("embedding")
    
# Smoke-test the helper on the second row's source text and report the vector length.
sample_source_text = df.iloc[1].get('source')
source_embeddings = generate_embeddings(sample_source_text)

print("Number of dimensions : {}".format(len(source_embeddings)))

In [None]:
# Generate embeddings for the source and target text of the first 20 rows.
# Every row costs two embedding calls; if any call fails, rerun this cell.

from pandarallel import pandarallel

pandarallel.initialize(progress_bar=True, nb_workers=8)

# .copy() gives df_20 its own data, so adding columns below does not raise
# pandas' SettingWithCopyWarning and never touches the original df.
df_20 = df.head(20).copy()

# parallel_apply is pandarallel's counterpart of apply; plain apply would run
# serially and leave the worker pool initialised above unused.
df_20['target_embeddings'] = df_20['target'].parallel_apply(generate_embeddings)
df_20['source_embeddings'] = df_20['source'].parallel_apply(generate_embeddings)

### Load Embeddings into Translation Memory Table

In [None]:
import boto3 
import json 

import os

# Identifiers passed to every RDS Data API call in this notebook. They are
# specific to one deployment (account, region, generated resource names), so
# each can be overridden through an environment variable; the fallbacks are
# the values this notebook was originally written with.
secret_arn = os.environ.get(
    "TM_SECRET_ARN",
    "arn:aws:secretsmanager:us-east-2:986528949439:secret:AuroraCredentials20EDD625-lm2rx5dUXdeO-OOCrM5",
)
cluster_arn = os.environ.get(
    "TM_CLUSTER_ARN",
    "arn:aws:rds:us-east-2:986528949439:cluster:databasestack-translationmemoryauroraclusterfd1dc6-axcfoistdywf",
)
database_name = os.environ.get("TM_DATABASE_NAME", 'MTEngineTranslationMemoryDb')

def call_rds_data_api(source_lang, target_lang, source_text, target_text,
                      source_text_embedding, target_text_embedding, rds_client=None):
    """Insert one translation pair and its embeddings into ``translation_memory``.

    Args:
        source_lang: Value stored in the ``source_lang`` column.
        target_lang: Value stored in the ``target_lang`` column.
        source_text: Value stored in the ``source_text`` column.
        target_text: Value stored in the ``target_text`` column.
        source_text_embedding: Embedding of ``source_text`` as a string; the
            SQL statement casts it to ``VECTOR``.
        target_text_embedding: Embedding of ``target_text`` as a string; the
            SQL statement casts it to ``VECTOR``.
        rds_client: Optional object exposing ``execute_statement`` (e.g. a
            boto3 ``rds-data`` client) to reuse across calls. When omitted a
            new ``rds-data`` client is created for this call.

    Returns:
        The response returned by ``execute_statement``.
    """
    client = rds_client if rds_client is not None else boto3.client('rds-data')

    sql = """
          INSERT INTO translation_memory(source_text, target_text, source_lang, target_lang, source_text_embedding, target_text_embedding)
          VALUES( :source_text, :target_text, :source_lang, :target_lang, CAST(:source_text_embedding AS VECTOR), CAST(:target_text_embedding AS VECTOR))
          """

    # Every value is sent as a string parameter bound to the :name
    # placeholder of the same name in the statement above.
    values = {
        'source_text': source_text,
        'target_text': target_text,
        'source_lang': source_lang,
        'target_lang': target_lang,
        'source_text_embedding': source_text_embedding,
        'target_text_embedding': target_text_embedding,
    }
    param_set = [
        {'name': name, 'value': {'stringValue': value}}
        for name, value in values.items()
    ]

    # cluster_arn, secret_arn and database_name are the notebook-level settings.
    return client.execute_statement(
        resourceArn=cluster_arn,
        secretArn=secret_arn,
        database=database_name,
        sql=sql,
        parameters=param_set,
    )

# Insert every sampled row as an "fr" -> "de" pair, sending each embedding as its string form.
for _, row in df_20.iterrows():
    call_rds_data_api(
        source_lang="fr",
        target_lang="de",
        source_text=row['source'],
        target_text=row['target'],
        source_text_embedding=str(row['source_embeddings']),
        target_text_embedding=str(row['target_embeddings']),
    )

## Test translation memory table vector search


In [None]:
import numpy
from IPython.display import display, Markdown, Latex, HTML


def similarity_search(search_text):
    """Print the three ``translation_memory`` rows closest to ``search_text``.

    The text is embedded with ``generate_embeddings`` and the rows are ordered
    by ``source_text_embedding <=> embedding`` (cosine distance, assuming the
    table uses the pgvector extension).

    Args:
        search_text: Text to look up in the translation memory.

    Returns:
        None. The raw ``execute_statement`` response is printed.
    """
    # Vectors travel as text: str() of the list of floats, cast to VECTOR in SQL.
    embedding_str = str(generate_embeddings(search_text))

    # Bind the embedding as a named parameter rather than splicing it into
    # the SQL text, mirroring how the INSERT statement passes its vectors.
    sql_text = (
        "SELECT unique_id, source_text, target_text "
        "FROM translation_memory "
        "ORDER BY source_text_embedding <=> CAST(:embedding AS VECTOR) "
        "LIMIT 3;"
    )
    parameters = [{'name': 'embedding', 'value': {'stringValue': embedding_str}}]

    rds_data = boto3.client('rds-data')
    response = rds_data.execute_statement(
        resourceArn=cluster_arn,
        secretArn=secret_arn,
        database=database_name,
        sql=sql_text,
        parameters=parameters,
    )

    print(response)

# Example query against the translation memory table.
similarity_search(search_text="Reprise de la session")


{'ResponseMetadata': {'RequestId': '02293b1a-f53c-464a-80c2-2d86049a3359', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': '02293b1a-f53c-464a-80c2-2d86049a3359', 'date': 'Thu, 03 Apr 2025 19:30:05 GMT', 'content-type': 'application/json', 'content-length': '854', 'connection': 'keep-alive'}, 'RetryAttempts': 0}, 'records': [[{'longValue': 1}, {'stringValue': 'Reprise de la session'}, {'stringValue': 'Wiederaufnahme der Sitzungsperiode'}], [{'longValue': 2}, {'stringValue': 'Je déclare reprise la session du Parlement européen qui avait été interrompue le vendredi 17 décembre dernier et je vous renouvelle tous mes vux en espérant que vous avez passé de bonnes vacances.'}, {'stringValue': 'Ich erkläre die am Freitag, dem 17. Dezember unterbrochene Sitzungsperiode des Europäischen Parlaments für wiederaufgenommen, wünsche Ihnen nochmals alles Gute zum Jahreswechsel und hoffe, daß Sie schöne Ferien hatten.'}], [{'longValue': 8}, {'stringValue': "Madame la Présidente, c'est une m