In [1]:
import pandas as pd
import intertrans.protos_pb2 as ptpb
from intertrans.utils import submit_request_cak
from intertrans.data import response_to_pandas, get_ca_metric

In [2]:
# Load the HumanEval-X subset from a JSON Lines file (one JSON object per line).
dataset_path = '../data/datasets/humanevalx_dataset_subset.jsonl'
df = pd.read_json(dataset_path, lines=True)

In [3]:
# Keep only the Java -> Python pairs and renumber rows 0..n-1.
is_java_to_python = (df['source_lang'] == 'Java') & (df['target_lang'] == 'Python')
df_java_python = df.loc[is_java_to_python].reset_index(drop=True)

In [None]:
# Preview the first rows of the Java -> Python subset.
df_java_python.head(n=5)

In [5]:
# Build one TranslationRequest per Java -> Python row and collect them in a batch.
# Settings shared by every request are hoisted out of the loop.
MODEL_NAME = "ise-uiuc/Magicoder-S-DS-6.7B"
PROMPT_TEMPLATE_NAME = "prompt_humanevalx"
# NOTE(review): "temperature" is an unusual value for a *regex* template name
# (compare "prompt_humanevalx") — confirm this is the intended template.
REGEX_TEMPLATE_NAME = "temperature"
# Languages listed in every request's used_languages field.
USED_LANGUAGES = ["Go", "Java", "Python", "C++", "JavaScript", "Rust"]

batch_request = ptpb.BatchTranslationRequest()

for index, row in df_java_python.iterrows():
    request = ptpb.TranslationRequest()

    # The row index is used as the request id.
    request.id = str(index)
    request.seed_language = row['source_lang']
    request.target_language = row['target_lang']
    request.seed_code = row['input_code']
    request.model_name = MODEL_NAME
    request.prompt_template_name = PROMPT_TEMPLATE_NAME
    request.regex_template_name = REGEX_TEMPLATE_NAME
    request.used_languages.extend(USED_LANGUAGES)

    # Attach the target-language test code to the request.
    # (Named unit_test so the stdlib `unittest` module name is not shadowed.)
    unit_test = ptpb.UnitTestCase()
    unit_test.language = row['target_lang']
    unit_test.test_case = row['test_code']
    request.test_suite.unit_test_suite.append(unit_test)

    # Attach the expected function signature in the target language.
    signature = ptpb.TargetSignature()
    signature.language = row['target_lang']
    signature.signature = row['target_signature']
    request.target_signatures.append(signature)

    batch_request.translation_requests.append(request)

# Sanity check: exactly one request per selected row.
assert len(batch_request.translation_requests) == len(df_java_python)

In [6]:
# Submit the whole batch to the translation server listening at server_address.
server_address = "localhost:50051"
response = submit_request_cak(batch_request, server_address)

In [7]:
# Convert the batch response into a DataFrame; the metric cells below read its
# 'request_id' and 'status' columns. NOTE(review): presumably one row per
# candidate translation — confirm in intertrans.data.response_to_pandas.
df_response = response_to_pandas(response)

In [15]:
# Denominator: number of distinct requests present in the response.
# Counting ids directly (instead of groupby(...)['status'].any().sum()) does not
# depend on the truthiness of the status values, so a request whose statuses are
# all empty/missing is no longer silently dropped from the count.
# NOTE(review): a request that produced no rows at all in df_response is still not
# counted; if that can happen, len(batch_request.translation_requests) is safer.
total_requests = df_response['request_id'].nunique()

# Rows for which a successful translation was reported.
total_translations_found = df_response[df_response['status'] == 'TRANSLATION_FOUND']

# Numerator: distinct requests with at least one successful translation.
# nunique() already returns a plain int (0 for an empty frame), so no .item() is needed.
total_found_at_least_one_translation = total_translations_found['request_id'].nunique()

In [16]:
# Computational accuracy (CA) in percent: share of requests that received
# at least one successful translation.
solved_fraction = total_found_at_least_one_translation / total_requests
ca_metric = solved_fraction * 100

In [None]:
# Report the manually computed CA.
# NOTE(review): the manual computation counts any TRANSLATION_FOUND row regardless
# of attempt index, so the "@10" label assumes at most 10 attempts per request — confirm.
print(f"CA@10: {ca_metric}%")

In [None]:
# Cross-check against the library helper; 10 is presumably k in CA@k.
ca_at_10 = get_ca_metric(df_response, 10)
print(f"CA@10: {ca_at_10}%")