In [None]:
# Numeric tuning constants. They are not referenced in the cells visible here.
# NOTE(review): presumably thresholds/rates for the M1/M2 memory logic — confirm at their use sites.
DELTA_1 = 0.3
DELTA_2 = 0.1
ALPHA = 0.0035
NUM_ENT_M2 = 0
# Output folder for logs; the active value lives under /content/drive/MyDrive
# (assumes Google Drive is mounted in Colab — TODO confirm).
LOGGING_FOLDER = "/content/drive/MyDrive/Smruti-GEC-for-Gujarati/"
# LOGGING_FOLDER = "/content/"
# Name of the Milvus collection searched by read_from_M1.
M1_COLLECTION = "gold_sentences"
# Name of the Milvus collection searched by read_from_M2 and appended to by write_to_M2.
M2_COLLECTION = "history_const"

# Installations

In [None]:
! pip install --quiet --upgrade langchain langchain-community langchain-openai pymilvus sentence-transformers openai colorama termcolor pytz

# Imports

In [None]:
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from google.colab import userdata
import numpy as np
import json
import pytz
from langchain.chains import SequentialChain, LLMChain
from colorama import init, Fore, Style
from openai import OpenAI
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
from sentence_transformers import SentenceTransformer
from langchain.globals import set_debug, set_verbose

# Constant variables and Configurations

In [None]:
# Credentials for the Milvus/Zilliz endpoint. They were never defined before this cell,
# so a fresh "Restart & Run All" failed with NameError. Reuse a value that is already
# defined in the kernel, otherwise read it from Colab secrets like the OpenAI key.
# NOTE(review): the secret names 'ZILLIZ_HOST' / 'ZILLIZ_TOKEN' are assumed — create them
# in the Colab secrets panel or adjust the names here.
ZILLIZ_HOST = globals().get("ZILLIZ_HOST") or userdata.get("ZILLIZ_HOST")
ZILLIZ_TOKEN = globals().get("ZILLIZ_TOKEN") or userdata.get("ZILLIZ_TOKEN")

# Register the default pymilvus connection used by every Collection(...) call below.
connections.connect(alias="default", uri=ZILLIZ_HOST, token=ZILLIZ_TOKEN)
# LangChain global switches: debug output off, verbose output on.
set_debug(False)
set_verbose(True)
# colorama: reset colour/style automatically after each print.
init(autoreset=True)

# LLMs

In [None]:
# Registry of chat models keyed by model name; every client shares the same
# API key lookup and sampling temperature.
llms = {
    model_name: ChatOpenAI(
        openai_api_key=userdata.get('OPENAI_API_KEY'),
        model_name=model_name,
        temperature=0.1,
    )
    for model_name in ("gpt-4o", "gpt-3.5-turbo", "gpt-4o-mini")
}

# Embedding models

In [None]:
 em_model = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Prompt templates

In [None]:
# Registry of prompt templates, keyed by a short template name.
templates = {}
templates["t1"] = '''
Task: Correct spelling and grammatical errors in the given Gujarati sentence.

Instructions:
    Only fix errors—do not modify correct sentences or make unnecessary changes.
    Be confident in corrections. If unsure, leave the sentence unchanged.
    Output only the corrected sentence, no explanations or extra text.
    Use reference data (if provided) to guide corrections while adhering to standard Gujarati rules.

Input Sentence:
{sentence_to_correct}

Reference Data (if available):

    {data_from_history}

    {data_from_gold_corpus}
'''

templates["t2"] = '''
Task: Correct spelling and grammatical errors in the given Gujarati sentence.

Instructions:
1. Make changes only to fix spelling or grammatical errors.
2. Do not make any changes unless you are confident about the correction.
3. If input doesn't contain any Gujarati text(It's fine to have other language text), return: INVALID_INPUT.
4. Output only the corrected sentence or the error message—no explanations or additional text.

The following is provided only to help understand the structure of the Gujarati language:

- {data_from_history}

- {data_from_gold_corpus}

Input Sentence:
{sentence_to_correct}

'''

templates["t_zero_shot"] = '''
# Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# Instructions:
* Only fix errors, do not modify correct sentences or make unnecessary changes.
* Be confident in corrections. If unsure, leave the sentence unchanged.
* Output only the corrected sentence, no explanations or extra text.

# Input Sentence:
{sentence_to_correct}
'''

templates["vanilla_m1"] = '''
# Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# Instructions:
* Only fix errors, do not modify correct sentences or make unnecessary changes.
* Be confident in corrections. If unsure, leave the sentence unchanged.
* Output only the corrected sentence, no explanations or extra text.
* Grammaticaly correct sentences are given only to understand the grammar and sentence structure.

# Grammatically correct Gujarati sentences:
{data_from_gold_corpus}

# Input Sentence:
{sentence_to_correct}
'''

templates["ltm_1_L=4"] = '''
Task: Correct only syntactic errors in the sentence, focusing strictly on word order and grammatical structure based on standard Gujarati syntax. Do **not** rephrase, paraphrase, or enhance the sentence in any way.

Instructions:
  1. Make corrections **only if there are clear syntactic errors** (e.g., misordered subject, object, verb, adjectives, or postpositions).
  2. If the sentence is already syntactically correct, leave it **completely unchanged**.
  3. If the input is not in Gujarati, return only: INVALID_INPUT.
  4. Output only the corrected sentence or 'INVALID_INPUT'. Do not include explanations, notes, or extra text.
  5. Use reference data (if provided) **only as a guide**, and follow standard Gujarati syntax—not stylistic preferences.

Input Sentence:
{sentence_to_correct}

Reference Data (if available):

{data_from_history}

{data_from_gold_corpus}
'''

templates["ltm_2_L=4"] = '''
Task: Correct only morphological errors in the sentence, ensuring proper use of gender (લિંગ), tense (કાલ), number (વચન), and person (પુરુષ). Do **not** modify other sentence aspects.

Instructions:
  1. If the input is 'INVALID_INPUT', return 'INVALID_INPUT' and do nothing else.
  2. Correct morphology **only when there's a clear mistake** in word inflection or agreement.
  3. Do **not** fix or adjust syntax, punctuation, or make improvements unless strictly morphological.
  4. If you are unsure about a correction or if the input is already correct, leave it **unchanged**.
  5. Output only the corrected sentence or 'INVALID_INPUT'. Do not include any explanations or modifications outside the scope.
  6. Use reference data (if available) to guide decisions, but follow strict Gujarati morphological rules.

Input Sentence:
{sentence_to_correct_1}

Reference Data (if available):

{data_from_history}

{data_from_gold_corpus}
'''

templates["ltm_3_L=4"] = '''
Task: Correct only spelling errors in the sentence, specifically focusing on issues involving hrasva and dirgha (short and long vowels), anusvara (nasalization), and sandhi (euphonic combination). Avoid making any stylistic or grammatical changes.

Instructions:
  1. If the input is 'INVALID_INPUT', return 'INVALID_INPUT' and do nothing else.
  2. Fix spelling **only if clearly incorrect**. Do **not** improve or adjust valid spellings or make stylistic edits.
  3. Leave already correct or ambiguous spellings **unchanged**.
  4. Do **not** fix syntax, punctuation, or morphology.
  5. Output only the corrected sentence or 'INVALID_INPUT'. Do not include explanations, formatting, or extra text.
  6. Use reference data (if available) only as support—not as the sole basis for spelling decisions.

Input Sentence:
{sentence_to_correct_2}

Reference Data (if available):

{data_from_history}

{data_from_gold_corpus}
'''

templates["ltm_4_L=4"] = '''
Task: Correct only punctuation errors in the sentence by appropriately adding, removing, or fixing punctuation marks such as periods (.), commas (,), question marks (?), exclamation marks (!), hyphens (–), colons (:), semicolons (;), ellipsis (…), quotation marks (" " or ' '), and apostrophes (').

Instructions:
  1. If the input is 'INVALID_INPUT', return 'INVALID_INPUT' and do nothing else.
  2. Only correct punctuation **if clearly incorrect**. Do **not** change word order, spelling, or grammar.
  3. Do not introduce stylistic or expressive punctuation unless required by grammar.
  4. If unsure or the punctuation is already correct, leave the sentence **as is**.
  5. Output only the corrected sentence or 'INVALID_INPUT'. Do not explain your choices or add formatting.
  6. Reference data may help, but final decisions must follow **standard Gujarati punctuation rules**.

Input Sentence:
{sentence_to_correct_3}

Reference Data (if available):

{data_from_history}

{data_from_gold_corpus}
'''

templates["ltm_1_L=2"] = '''
Task: Correct the grammatical errors in the given sentence according to standard Gujarati grammar rules.

Instructions:
  1. Only fix grammatical errors—do not modify correct sentences or make unnecessary changes.
  2. Be confident in corrections. If unsure, leave the sentence unchanged.
  3. If the input doesn't contain Gujarati text, then and only then return: INVALID_INPUT.
  4. Output only the corrected sentence or INVALID_INPUT—no explanations or extra text.
  5. Use reference data (if provided) to guide corrections while adhering to standard Gujarati rules.

  Input Sentence:
  {sentence_to_correct}

  Reference Data (if available):

  {data_from_history}

  {data_from_gold_corpus}
'''

templates["ltm_2_L=2"] = '''
Task: Correct the spelling errors in the given sentence by properly applying hrasva and dirgha (short and long vowels), anusvara (nasal sound), and sandhi (word joining rules) as per standard Gujarati orthography.

Instructions:
  1. If the input is 'INVALID_INPUT', just give 'INVALID_INPUT' as output.
  2. Only fix spelling errors—do not modify correct sentences or make unnecessary changes.
  3. Be confident in corrections. If unsure or syntax is already correct, leave the sentence unchanged.
  4. Output only the corrected sentence or INVALID_INPUT—no explanations or extra text.
  5. Use reference data (if provided) to guide corrections while adhering to standard Gujarati rules.

  Input Sentence:
  {sentence_to_correct_1}

  Reference Data (if available):

  {data_from_history}

  {data_from_gold_corpus}
'''

templates["dac_1_L=4"] = '''
Task: Identify and correct only the syntactic errors in the sentence, such as incorrect word order, misplaced subject/object/verb, or improper use of postpositions — but only if correction is absolutely necessary.

Instructions:
  - Do not modify morphology, spelling, or punctuation.
  - If the sentence is already syntactically correct, leave it unchanged.
  - Return INVALID_INPUT **only if the input does not contain any Gujarati script (i.e., characters in the Unicode Gujarati block: U+0A80 to U+0AFF)**.
  - Output only the corrected sentence or the word INVALID_INPUT.

Input Sentence:
{input_sentence}
'''

templates["dac_2_L=4"] = '''
Task: Identify and fix only the morphological errors related to gender (લિંગ), number (વચન), tense (કાલ), and person (પુરુષ) according to standard Gujarati morphology.

Instructions:
  - Do not modify syntax, spelling, or punctuation.
  - Leave the sentence unchanged if there are no morphological errors.
  - Return INVALID_INPUT **only if the input does not contain any Gujarati script (Unicode range U+0A80 to U+0AFF)**.
  - Output only the corrected sentence or the word INVALID_INPUT.

Input Sentence:
{input_sentence}
'''

templates["dac_3_L=4"] = '''
Task: Correct only spelling errors in the sentence — such as incorrect use of hrasva-dirgha (short/long vowels), anusvara, chandrabindu, or sandhi — following Gujarati orthographic rules.

Instructions:
  - Do not modify syntax, morphology, or punctuation.
  - Leave the sentence unchanged if there are no spelling errors.
  - Return INVALID_INPUT **only if the input contains no Gujarati script (Unicode U+0A80–U+0AFF)**.
  - Output only the corrected sentence or the word INVALID_INPUT.

Input Sentence:
{input_sentence}
'''

templates["dac_4_L=4"] = '''
Task: Identify and fix punctuation errors such as missing or incorrect periods, commas, question marks, exclamations, semicolons, quotation marks, etc., according to Gujarati punctuation norms.

Instructions:
  - Do not fix syntax, morphology, or spelling.
  - Leave the sentence unchanged if punctuation is already correct.
  - Return INVALID_INPUT **only if the input does not contain any Gujarati characters (Unicode range U+0A80 to U+0AFF)**.
  - Output only the corrected sentence or the word INVALID_INPUT.

Input Sentence:
{input_sentence}
'''

templates["dac_5_L=4"] = '''
You are given the corrected versions of a sentence from four linguistic sub-tasks: syntax, morphology, spelling, and punctuation.

Your task is to generate a single final corrected sentence by combining these four outputs.

Instructions:
- If any one of the inputs is exactly 'INVALID_INPUT', your output must be 'INVALID_INPUT'. Do not generate or modify any sentence.
- Otherwise, combine the four corrected sentences into one final sentence that preserves the intended meaning.
- In case of conflicting corrections, prioritize the corresponding sub-task as follows:
  * Syntax corrections take priority for sentence structure.
  * Morphology corrections take priority for word forms (e.g., tense, gender, number).
  * Spelling corrections take priority for fixing misspelled words.
  * Punctuation corrections take priority for punctuation marks.
- Make sure the final output reflects the cumulative effect of all valid corrections.

Constraints:
- Your output must be a **single sentence**.
- Output only the final corrected sentence or the word **INVALID_INPUT** — nothing else.

Inputs:
Syntax Output: {syntax_output}
Morphology Output: {morph_output}
Spelling Output: {spelling_output}
Punctuation Output: {punct_output}

'''

templates["cot"] = '''
# Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# Instructions:

* Only fix errors, do not modify correct sentences or make unnecessary changes.
* Be confident in corrections. If unsure, leave the sentence unchanged.
* Output only the corrected sentence, no explanations or extra text.
* Given example is for analysis, don't just mimic it.

# Example:

incorrect sentence: પહેલો વરસ્યો વરસાદ કે રાફડામાંથી પાંખાવાળો મકોડા આકાશે ઊડ્યા આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઉડ્યા; બીજે દિવસે તેનો પાંખો જ્યાં ત્યાં રખડતી આવી જોવામાં?

Let's think step-by-step.

1. 'વરસાદ' is object and should be preceeded by 'વરસ્યો'(a verb).
2. 'પાંખાવાળો' should be replaced by 'પાંખવાળા' as મકોડા is plural of 'મકોડો'.
3. There should be a semi-colon(;) after 'આકાશે ઊડ્યા', because the first clause ends here and both the clauses are connected without a connector.
4. There will be a dirgha 'ઊ' in ઉડ્યા.
5. 'પાંખો' is plural and feminine, hence 'તેનો' will be replaced by 'તેની'.
6. 'આવી'(verb) should be preceeded by 'જોવામાં', which is a participle.
7. The overall sentence is Affirmative, so the question mark(?) will be removed and a period(.) should be added.

corrected sentence: પહેલો વરસાદ વરસ્યો કે રાફડામાંથી પાંખવાળા મકોડા આકાશે ઊડ્યા; આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઊડ્યા; બીજા દિવસે તેની પાંખો જ્યાં ત્યાં રખડતી જોવામાં આવી.

# Input Sentence:
{sentence_to_correct}
'''

templates["cot_m1"] = '''
# Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# Instructions:

* Only fix errors, do not modify correct sentences or make unnecessary changes.
* Be confident in corrections. If unsure, leave the sentence unchanged.
* Output only the corrected sentence, no explanations or extra text.
* Grammaticaly correct sentences are given only to understand the grammar and sentence structure.


# Example:

incorrect sentence: પહેલો વરસ્યો વરસાદ કે રાફડામાંથી પાંખાવાળો મકોડા આકાશે ઊડ્યા આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઉડ્યા; બીજે દિવસે તેનો પાંખો જ્યાં ત્યાં રખડતી આવી જોવામાં?

Let's think step-by-step.

1. 'વરસાદ' is object and should be preceeded by 'વરસ્યો'(a verb).
2. 'પાંખાવાળો' should be replaced by 'પાંખવાળા' as મકોડા is plural of 'મકોડો'.
3. There should be a semi-colon(;) after 'આકાશે ઊડ્યા', because the first clause ends here and both the clauses are connected without a connector.
4. There will be a dirgha 'ઊ' in ઉડ્યા.
5. 'પાંખો' is plural and feminine, hence 'તેનો' will be replaced by 'તેની'.
6. 'આવી'(verb) should be preceeded by 'જોવામાં', which is a participle.
7. The overall sentence is Affirmative, so the question mark(?) will be removed and a period(.) should be added.

corrected sentence: પહેલો વરસાદ વરસ્યો કે રાફડામાંથી પાંખવાળા મકોડા આકાશે ઊડ્યા; આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઊડ્યા; બીજા દિવસે તેની પાંખો જ્યાં ત્યાં રખડતી જોવામાં આવી.

# Grammatically correct Gujarati sentences:
{data_from_gold_corpus}

#Input sentence:
{sentence_to_correct}
'''

# # Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# # Instructions:

# * Only fix errors, do not modify correct sentences or make unnecessary changes.
# * Be confident in corrections. If unsure, leave the error unchanged.
# * Output only the corrected sentence, no explanations or extra text.
# * Correction examples are your previous corrections, might not be accurate as correct sentences.

# # Example:

# incorrect sentence: પહેલો વરસ્યો વરસાદ કે રાફડામાંથી પાંખાવાળો મકોડા આકાશે ઊડ્યા આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઉડ્યા; બીજે દિવસે તેનો પાંખો જ્યાં ત્યાં રખડતી આવી જોવામાં?

# Let's think step-by-step.

# 1. Correction examples can be used for understanding the task properly.
# 2. 'વરસાદ' is object and should be preceeded by 'વરસ્યો'(a verb).
# 3. 'પાંખાવાળો' should be replaced by 'પાંખવાળા' as મકોડા is plural of 'મકોડો'.
# 4. There should be a semi-colon(;) after 'આકાશે ઊડ્યા', because the first clause ends here and both the clauses are connected without a connector.
# 5. There will be a dirgha 'ઊ' in ઉડ્યા.
# 7. 'પાંખો' is plural and feminine, hence 'તેનો' will be replaced by 'તેની'.
# 8. 'આવી'(verb) should be preceeded by 'જોવામાં', which is a participle.
# 9. The overall sentence is Affirmative, so the question mark(?) will be removed and a period(.) should be added.
# 10. Now, grammatically correct sentences can be used as a reference for correcting the errors, not covered in the above steps.


templates["cot_m1&m2"] = '''
# Task: Correct grammatical errors in the given Gujarati sentence by following standard Gujarati rules.

# Instructions:

* Only fix errors, do not modify correct sentences or make unnecessary changes.
* Be confident in corrections. If unsure, leave the sentence unchanged.
* Output only the corrected sentence, no explanations or extra text.
* Grammaticaly correct sentences are given only to understand the grammar and sentence structure; might not be accurate as Correct sentences

# Example:

incorrect sentence: પહેલો વરસ્યો વરસાદ કે રાફડામાંથી પાંખાવાળો મકોડા આકાશે ઊડ્યા આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઉડ્યા; બીજે દિવસે તેનો પાંખો જ્યાં ત્યાં રખડતી આવી જોવામાં?

Let's think step-by-step.

1. 'વરસાદ' is object and should be preceeded by 'વરસ્યો'(a verb).
2. 'પાંખાવાળો' should be replaced by 'પાંખવાળા' as મકોડા is plural of 'મકોડો'.
3. There should be a semi-colon(;) after 'આકાશે ઊડ્યા', because the first clause ends here and both the clauses are connected without a connector.
4. There will be a dirgha 'ઊ' in ઉડ્યા.
5. 'પાંખો' is plural and feminine, hence 'તેનો' will be replaced by 'તેની'.
6. 'આવી'(verb) should be preceeded by 'જોવામાં', which is a participle.
7. The overall sentence is Affirmative, so the question mark(?) will be removed and a period(.) should be added.

corrected sentence: પહેલો વરસાદ વરસ્યો કે રાફડામાંથી પાંખવાળા મકોડા આકાશે ઊડ્યા; આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઊડ્યા; બીજા દિવસે તેની પાંખો જ્યાં ત્યાં રખડતી જોવામાં આવી.

# Some examples for analysis:
{data_from_history}

# Grammatically correct Gujarati sentences:
{data_from_gold_corpus}

# Input sentence:
{sentence_to_correct}
'''


templates["cot_m2"] = '''
Task: Correct the spelling and grammatical errors in the given Gujarati sentence.

Instructions:
1. Only fix errors—do not modify correct sentences or make unnecessary changes.
2. Be confident in corrections. If unsure, leave the sentence unchanged.
3. If the input doesn't contain Gujarati text then and only then return: INVALID_INPUT.
4. Output only the corrected sentence -no explanations or extra text.

Example: પહેલો વરસ્યો વરસાદ કે રાફડામાંથી પાંખાવાળો મકોડા આકાશે ઊડ્યા આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઉડ્યા; બીજે દિવસે તેનો પાંખો જ્યાં ત્યાં રખડતી આવી જોવામાં?

Let's think step-by-step.

1. 'વરસાદ' is object and should be preceeded by 'વરસ્યો'(verb).
2. 'પાંખાવાળો' should be replaced by 'પાંખવાળા' as મકોડા is plural of 'મકોડો'.
3. There should be a semi-colon(;) after 'આકાશે ઊડ્યા', because the first clause ends and both the clauses are not connected with a connector.
4. There will be a dirgha 'ઊ' in ઉડ્યા.
5. 'પાંખો' is plural and feminine, hence 'તેનો' will be replaced by 'તેની'.
6. 'આવી'(verb) should be preceeded by 'જોવામાં' which is a verb used as adjective(called krudant in gujarati).
7. The overall sentence is Affirmative Sentence, so the question mark(?) will be removed and a period should be added.

corrected sentence: પહેલો વરસાદ વરસ્યો કે રાફડામાંથી પાંખવાળા મકોડા આકાશે ઊડ્યા; આખો દિવસ ઊડ્યા, એકાદ રાત પણ ઊડ્યા; બીજા દિવસે તેની પાંખો જ્યાં ત્યાં રખડતી જોવામાં આવી.

Some examples for analysis:
{data_from_history}

Input Sentence:
{sentence_to_correct}
'''

# Read from $M_1$ and $M_2$

In [None]:
# Cache of loaded Milvus Collection handles, keyed by collection name.
# (The name shadows the stdlib `collections` module inside this notebook.)
collections = {}

In [None]:
def read_from_M1(query_sentence, k1=5):
    """
    Retrieve the top-k1 sentences most similar to the query from the M1 collection.

    The loaded Collection handle is cached in the module-level `collections`
    dict under the key M1_COLLECTION, so it is created and loaded only when
    that key is missing.

    Arguments:
        query_sentence: Sentence to be used as the search query.
        k1: Number of top similar sentences to retrieve.

    Returns:
        A list of dictionaries, each containing a matched sentence ("sentence")
        and its cosine distance ("cosine_distance").

    Raises:
        RuntimeError: If the collection does not exist or cannot be loaded;
            the underlying exception is attached as the cause.
    """
    try:
        if not utility.has_collection(M1_COLLECTION):
            raise ValueError(f"Collection '{M1_COLLECTION}' does not exist.")

        # Look up this specific collection. Testing whether the whole cache is
        # empty would raise KeyError as soon as any other key is cached first.
        collection = collections.get(M1_COLLECTION)
        if collection is None:
            collection = Collection(M1_COLLECTION)
            collection.load()
            collections[M1_COLLECTION] = collection

    except Exception as e:
        raise RuntimeError(f"Failed to access or load collection '{M1_COLLECTION}': {e}") from e

    query_embedding = em_model.encode([query_sentence]).tolist()
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 50}}

    results = collection.search(
        data=query_embedding,
        anns_field="embedding",
        param=search_params,
        limit=k1,
        output_fields=["sentence"]
    )

    # With the COSINE metric Milvus reports a similarity in `hit.distance`
    # (per the Milvus metric documentation), so 1 - value is the distance.
    return [
        {"sentence": hit.entity.get("sentence"), "cosine_distance": 1 - hit.distance}
        for hit in results[0]
    ]

In [None]:
def read_from_M2(query_sentence, k2=5):
    """
    Retrieve the top-k2 entries most similar to the query from the M2 collection.

    Arguments:
        query_sentence: Sentence to be used as the search query.
        k2: Number of top similar sentences to retrieve.

    Returns:
        A list of dictionaries, each containing the incorrect sentence
        ("incorrect_sentence"), the corrected sentence ("corrected_sentence"),
        and the cosine distance ("cosine_distance").

    Raises:
        RuntimeError: If the collection does not exist or cannot be loaded;
            the underlying exception is attached as the cause.
    """
    try:
        if not utility.has_collection(M2_COLLECTION):
            raise ValueError(f"Collection '{M2_COLLECTION}' does not exist.")

        # A fresh handle is created and loaded on every call (no caching here).
        collection = Collection(M2_COLLECTION)
        collection.load()

    except Exception as e:
        raise RuntimeError(f"Failed to access or load collection '{M2_COLLECTION}': {e}") from e

    query_embedding = em_model.encode([query_sentence]).tolist()
    search_params = {"metric_type": "COSINE", "params": {"nprobe": 50}}

    results = collection.search(
        data=query_embedding,
        anns_field="embedding",
        param=search_params,
        limit=k2,
        output_fields=["incorrect_sentence", "corrected_sentence"]
    )

    # With the COSINE metric Milvus reports a similarity in `hit.distance`
    # (per the Milvus metric documentation), so 1 - value is the distance.
    return [
        {
            "incorrect_sentence": hit.entity.get("incorrect_sentence"),
            "corrected_sentence": hit.entity.get("corrected_sentence"),
            "cosine_distance": 1 - hit.distance
        }
        for hit in results[0]
    ]

In [None]:
def avg_distance(results):
    """
    Return the mean 'cosine_distance' over the given search results.

    An empty (or otherwise falsy) `results` yields 1.
    """
    if results:
        total = sum(hit['cosine_distance'] for hit in results)
        return total / len(results)
    return 1

def min_distance(results):
    """
    Return the smallest 'cosine_distance' among the given search results.

    An empty (or otherwise falsy) `results` yields 1.
    """
    if results:
        closest = min(results, key=lambda hit: hit['cosine_distance'])
        return closest['cosine_distance']
    return 1

# Write to $M_2$

In [None]:
def write_to_M2(incorrect_sentence, corrected_sentence):
    """
    Insert an incorrect sentence and its corrected version into the M2 collection.

    The incorrect sentence is embedded with `em_model`; the row is inserted
    and the collection is flushed.

    Arguments:
        incorrect_sentence: The original (erroneous) sentence; this is what gets embedded.
        corrected_sentence: The corrected version of the sentence.

    Returns:
        True if the insert and flush succeeded, False if an exception was
        caught (the error is printed, not raised).
    """
    try:
        # Reuse a cached handle when present. Binding `col` on both paths avoids
        # the unbound-variable error that occurred whenever the cache lookup hit.
        col = collections.get(M2_COLLECTION)
        if col is None:
            col = Collection(M2_COLLECTION)
            col.load()

        embedding = np.array(em_model.encode([incorrect_sentence]), dtype=np.float32).tolist()
        # NOTE(review): column order must match the M2 collection schema; it differs
        # from the order used in store_feedback — confirm against the schema.
        col.insert([[corrected_sentence], [incorrect_sentence], embedding])
        col.flush()
        print("History updated successfully.")
        return True

    except Exception as e:
        print(f"Error updating history: {e}")
        return False

In [None]:
def format_retrieved_data(data):
    """
    Render retrieved records as newline-separated text.

    The shape is decided from the first record only:
      * records with a "sentence" key -> one sentence per line;
      * records with "incorrect_sentence" and "corrected_sentence" keys ->
        one "incorrect: ... corrected: ..." line per record.
    Empty input gives an empty string; any other shape raises ValueError.
    """
    if not data:
        return ""

    first_record = data[0]

    if "sentence" in first_record:
        return '\n'.join(record["sentence"] for record in data)

    has_pair = "incorrect_sentence" in first_record and "corrected_sentence" in first_record
    if not has_pair:
        raise ValueError("Unrecognized data format")

    rendered = []
    for record in data:
        # Plain concatenation on purpose: non-string values must still raise TypeError.
        line = "incorrect: " + record["incorrect_sentence"] + " corrected: " + record["corrected_sentence"]
        rendered.append(line)
    return '\n'.join(rendered)

# Human Feedback

In [None]:
# Name of the Milvus collection that store_feedback inserts human-verified pairs into.
HUMAN_CURATED_COL = "human_curated_dataset"

In [None]:
def store_feedback(correct_sentence, incorrect_sentence):
    """
    Store user feedback in the 'human_curated_dataset' collection.

    Embeds the incorrect sentence and inserts it along with the corrected version.

    Arguments:
        correct_sentence: The sentence confirmed or supplied as correct.
        incorrect_sentence: The original input sentence; this is what gets embedded.

    Returns:
        True if the insert succeeded, False if an exception was caught
        (the error is printed, not raised), so callers can tell the two apart.
    """
    try:
        col = Collection(HUMAN_CURATED_COL)
        col.load()

        embedding = np.array(em_model.encode([incorrect_sentence]), dtype=np.float32).tolist()
        # NOTE(review): column order must match the collection schema — confirm.
        col.insert([embedding, [correct_sentence], [incorrect_sentence]])

        print("Feedback stored successfully.")
        return True

    except Exception as e:
        print(f"Error storing feedback: {e}")
        return False

In [None]:
def take_human_feedback(output, user_input):
    """
    Prompt the user to verify the correctness of a system-generated output.

    If the user confirms the output is correct, the (output, input) pair is stored.
    If incorrect, the user can optionally provide a corrected version, which is
    stored instead. Outputs containing 'INVALID_INPUT' are never prompted for.

    Arguments:
        output: The system-generated correction (converted with str()).
        user_input: The original sentence supplied by the user (converted with str()).

    Returns:
        True if feedback was stored, False otherwise (invalid output, the user
        declined to give a correction, or storing reported a failure).
    """
    if output and "INVALID_INPUT" in str(output):
        return False

    # Re-prompt in a loop rather than recursing, so repeated invalid answers
    # cannot grow the call stack.
    while True:
        feedback = input("Is the output correct (spelling and grammar)? (y/n): ").strip().lower()
        if feedback in ("y", "n"):
            break
        print("Invalid input. Please enter 'y' or 'n'.")

    if feedback == "y":
        # Only an explicit False from store_feedback counts as failure; a None
        # return (no status reported) is treated as success.
        if store_feedback(str(output), str(user_input)) is False:
            return False
        print("Feedback recorded as correct.")
        return True

    res = input("Would you like to provide the correct output? (y/n): ").strip().lower()
    if res != "y":
        print("No correction provided.")
        return False

    correct_sentence = input("Enter the correct sentence: ").strip()
    if store_feedback(correct_sentence, str(user_input)) is False:
        return False
    print("Corrected sentence recorded.")
    return True

# Prompt pipeline

In [None]:
def correct_the_sentence(sentence_to_correct, config, data_from_gold_corpus="", data_from_history="", verbose=True):

    llm = llms[config["llmName"]]

    input_data = {
        "sentence_to_correct": sentence_to_correct,
        "data_from_gold_corpus": data_from_gold_corpus,
        "data_from_history": data_from_history
    }

    input_variables = [x for x in input_data.keys()]

    # vanilla / zshot
    if config["name"] == "zeroshot":
      prompt0 = PromptTemplate.from_template(templates['t_zero_shot'])
      chain0 = LLMChain(llm=llm, prompt=prompt0, output_key="corrected_sentence")
      chain_zs = SequentialChain(
                chains=[chain0],
                input_variables=["sentence_to_correct"],
                output_variables=["corrected_sentence"],
                verbose=verbose
      )
      response = chain_zs.invoke({"sentence_to_correct":sentence_to_correct})["corrected_sentence"]
      return response


    #Least-to-Most
    elif config["name"] == "ltm":
      if config["L"] == 4:
            prompt1 = PromptTemplate(
            template=templates["ltm_1_L=4"],
            input_variables=input_variables
            )
            prompt2 = PromptTemplate(
                template=templates["ltm_2_L=4"],
                input_variables=input_variables
            )
            prompt3 = PromptTemplate(
                template=templates["ltm_3_L=4"],
                input_variables=input_variables
            )
            prompt4 = PromptTemplate(
                template=templates["ltm_4_L=4"],
                input_variables=input_variables
            )
            chain1 = LLMChain(llm=llm, prompt=prompt1, output_key="sentence_to_correct_1")
            chain2 = LLMChain(llm=llm, prompt=prompt2, output_key="sentence_to_correct_2")
            chain3 = LLMChain(llm=llm, prompt=prompt3, output_key="sentence_to_correct_3")
            chain4 = LLMChain(llm=llm, prompt=prompt4, output_key="corrected_sentence")
            chain_ltm_L4 = SequentialChain(
                chains=[chain1, chain2, chain3, chain4],
                input_variables=input_variables,
                output_variables=["corrected_sentence"],
                verbose=verbose
            )
            response = chain_ltm_L4.invoke(input_data)["corrected_sentence"]


      elif config["L"] == 2:
            prompt5 = PromptTemplate(
            template=templates["ltm_1_L=2"],
            input_variables=input_variables
            )
            prompt6 = PromptTemplate(
            template=templates["ltm_2_L=2"],
            input_variables=input_variables
            )
            chain5 = LLMChain(llm=llm, prompt=prompt5, output_key="sentence_to_correct_1")
            chain6 = LLMChain(llm=llm, prompt=prompt6, output_key="corrected_sentence")
            chain_ltm_L2 = SequentialChain(
            chains=[chain5, chain6],
            input_variables=input_variables,
            output_variables=["corrected_sentence"],
            verbose=verbose
            )
            response = chain_ltm_L2.invoke(input_data)["corrected_sentence"]

    # vanilla with m1
    elif config["name"] == "vanill_m1":
            promptx = PromptTemplate(
            template=templates["vanilla_m1"],
            input_variables=["sentence_to_correct", "data_from_gold_corpus"]
            )
            chainx = LLMChain(llm=llm, prompt=promptx, output_key="corrected_sentence", verbose=verbose)
            chain_v_m1 = SequentialChain(
                chains=[chainx],
                input_variables=["sentence_to_correct", "data_from_gold_corpus"],
                output_variables=["corrected_sentence"],
                verbose=verbose
            )
            response = chain_v_m1.invoke(input_data, verbose=True)["corrected_sentence"]
            return esponse

    # Divide-and-Conquer
    elif config["name"] == "dac":
        if config["L"] == 2:
          prompt12 = PromptTemplate(
          template=templates["dac_1_L=2"],
          input_variables=["input_sentence"]
          )
          prompt13 = PromptTemplate(
              template=templates["dac_2_L=2"],
              input_variables=["grammar_output"]
          )
          chain12 = LLMChain(llm=llm, prompt=prompt12, output_key="grammar_output")
          chain13 = LLMChain(llm=llm, prompt=prompt13, output_key="corrected_sentence")
          chain_dac_L2 = SequentialChain(
              chains=[chain12, chain13],
              input_variables=["input_sentence"],
              output_variables=["corrected_sentence"],
              verbose=verbose
          )
          response = chain_dac_L2.invoke(input_data)["corrected_sentence"]

        elif config["L"] == 4:
          prompt7 = PromptTemplate(
          template=templates["dac_1_L=4"],
          input_variables=["input_sentence"]
          )
          prompt8 = PromptTemplate(
              template=templates["dac_2_L=4"],
              input_variables=["input_sentence"]
          )
          prompt9 = PromptTemplate(
              template=templates["dac_3_L=4"],
              input_variables=["input_sentence"]
          )
          prompt10 = PromptTemplate(
              template=templates["dac_4_L=4"],
              input_variables=["input_sentence"]
          )
          prompt11 = PromptTemplate(
              template=templates["dac_5_L=4"],
              input_variables=["syntax_output", "morph_output", "spelling_output", "punct_output"]
          )
          chain7 = LLMChain(llm=llm, prompt=prompt7, output_key="syntax_output")
          chain8 = LLMChain(llm=llm, prompt=prompt8, output_key="morph_output")
          chain9 = LLMChain(llm=llm, prompt=prompt9, output_key="spelling_output")
          chain10 = LLMChain(llm=llm, prompt=prompt10, output_key="punct_output")
          syntax_output = chain7.run(input_sentence=sentence_to_correct)
          morph_output = chain8.run(input_sentence=sentence_to_correct)
          spelling_output = chain9.run(input_sentence=sentence_to_correct)
          punct_output = chain10.run(input_sentence=sentence_to_correct)

          chain_dac_L4 = SequentialChain(
              chains=[chain7, chain8, chain9, chain10],
              input_variables=["syntax_output", "morph_output", "spelling_output", "punct_output"],
              output_variables=["corrected_sentence"],
              verbose=verbose
          )
          print("warning:not tested this config")
          response=chain_dac_L4.invoke({
              "syntax_output": syntax_output,
              "morph_output": morph_output,
              "spelling_output": spelling_output,
              "punct_output": punct_output
          })["corrected_sentence"]


    # Chain-of-Thought
    elif config["name"] == "cot":
          prompt14 = PromptTemplate(
          template=templates["cot"],
          input_variables=["sentence_to_correct"]
          )
          chain14 = LLMChain(llm=llm, prompt=prompt14, output_key="corrected_sentence")
          chain_cot = SequentialChain(
              chains=[chain14],
              input_variables=["sentence_to_correct"],
              output_variables=["corrected_sentence"],
              verbose=verbose
          )
          response = chain_cot.invoke({"sentence_to_correct":sentence_to_correct})["corrected_sentence"]


    elif config["name"] == "cot_with_m1":
          prompt142 = PromptTemplate(
          template=templates["cot_m1"],
          input_variables=["sentence_to_correct", "data_from_gold_corpus"]
          )
          chain142 = LLMChain(llm=llm, prompt=prompt142, output_key="corrected_sentence", verbose=verbose)

          chain_cot_m1 = SequentialChain(
              chains=[chain142],
              input_variables=["sentence_to_correct", "data_from_gold_corpus"],
              output_variables=["corrected_sentence"],
              verbose=verbose
          )
          response = chain_cot_m1.invoke(input_data, verbose=True)["corrected_sentence"]


    elif config["name"] == "cot_with_m1&m2":
          prompt143 = PromptTemplate(
          template=templates["cot_m1&m2"],
          input_variables=input_variables
          )
          chain143 = LLMChain(llm=llm, prompt=prompt143, output_key="corrected_sentence", verbose=verbose)
          chain_cot_m1_m2 = SequentialChain(
                  chains=[chain143],
                  input_variables=["sentence_to_correct", "data_from_history", "data_from_gold_corpus"],
                  output_variables=["corrected_sentence"],
                  verbose=verbose
              )
          response = chain_cot_m1_m2.invoke(input_data)["corrected_sentence"]


    elif config["name"] == "cot_with_m2":
          prompt144 = PromptTemplate(
          template=templates["cot_m2"],
          input_variables=["sentence_to_correct", "data_from_history"]
          )
          chain144 = LLMChain(llm=llm, prompt=prompt144, output_key="corrected_sentence", verbose=True)
          chain_cot_m2 = SequentialChain(
              chains=[chain144],
              input_variables=["sentence_to_correct", "data_from_history"],
              output_variables=["corrected_sentence"],
              verbose=verbose
          )
          response = chain_cot_m2.invoke(input_data)["corrected_sentence"]

    # if "INVALID_INPUT" in response :
    #     print(f"{Fore.RED}Output: INVALID_INPUT")
    #     print(f"{Fore.YELLOW}Please enter a valid sentence.")
    #     return "INVALID_INPUT"
    if not response:
      print("Invalid config")
    return response

# Logging

In [None]:
import os
import json
from datetime import datetime

def init_log_file(config, folder_path):
    """Create a new timestamped JSON log file and return its path.

    The file is named ``<YYYYmmdd_HHMMSS>.json`` inside ``folder_path`` (the
    folder is created when missing). It is seeded with a ``metadata`` object
    holding the current ISO timestamp and ``config``, plus an empty ``log``
    list.

    Args:
        config: JSON-serialisable run configuration stored under
            ``metadata.config``.
        folder_path: Directory in which the log file is created.

    Returns:
        The path of the newly written log file.
    """
    os.makedirs(folder_path, exist_ok=True)

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(folder_path, stamp + ".json")

    skeleton = {
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "config": config,
        },
        "log": [],
    }
    with open(filepath, "w", encoding="utf-8") as handle:
        json.dump(skeleton, handle, ensure_ascii=False, indent=2)

    return filepath

def append_correction_to_log(filepath, input_text, output_text, index, x, stored_in_M2):
    """Append one correction record to the ``log`` list of an existing log file.

    The JSON document at ``filepath`` is read, extended with a record built
    from the arguments, and written back in place.

    Args:
        filepath: Path of a JSON log file containing a ``log`` list.
        input_text: Stored under ``"input"``.
        output_text: Stored under ``"output"``.
        index: Stored under ``"index"``.
        x: Stored under ``"x"``.
        stored_in_M2: Stored under ``"stored_in_M2"``.
    """
    with open(filepath, "r+", encoding="utf-8") as handle:
        document = json.load(handle)
        document["log"].append(dict(
            index=index,
            input=input_text,
            output=output_text,
            x=x,
            stored_in_M2=stored_in_M2,
        ))
        # Overwrite from the beginning, then drop any leftover bytes of the
        # previous (possibly longer) content.
        handle.seek(0)
        json.dump(document, handle, ensure_ascii=False, indent=2)
        handle.truncate()

# δ$_1$ tuning

In [None]:
import math

def fun(x, name, alpha, DELTA_1):
    """Evaluate the schedule called ``name`` at ``x``.

    Supported schedules:
        ``"const"``  -> ``DELTA_1``
        ``"linear"`` -> ``DELTA_1 + alpha * x``
        ``"exp"``    -> ``DELTA_1 + alpha * (exp(alpha * x) - 1)``

    Args:
        x: Point at which the schedule is evaluated.
        name: Schedule name, one of ``"exp"``, ``"linear"``, ``"const"``.
        alpha: Growth coefficient used by the non-constant schedules.
        DELTA_1: Base value of the schedule.

    Returns:
        The schedule value, or ``None`` (after printing a message) when
        ``name`` is not recognised.
    """
    if name == "const":
        return DELTA_1
    if name == "linear":
        return DELTA_1 + alpha * x
    if name == "exp":
        growth = math.exp(alpha * x) - 1
        return DELTA_1 + alpha * growth

    print("invalid fun name.")
    return None

In [None]:
def get_collection_and_count(collection_name):
    """
    Return the current entity count of the named collection.

    The collection is loaded and flushed before ``num_entities`` is read.
    Despite the function's name, only the count is returned, not the
    collection object. Any exception is printed and reported as a count of 0.
    """
    try:
        target = Collection(collection_name)
        target.load()
        target.flush()
        return target.num_entities
    except Exception as err:
        print(f"Error loading collection '{collection_name}': {err}")
        return 0

# Input

In [None]:
def load_input_sentences(filepath):
  """Read a JSON list of entries and return the value of every 'input' key.

  Entries without an 'input' key are ignored. When no entry provides one, a
  message is printed and None is returned instead of an empty list.
  """
  with open(filepath, 'r', encoding='utf-8') as source:
    entries = json.load(source)

  incorrect_sentences = []
  for entry in entries:
    if 'input' in entry:
      incorrect_sentences.append(entry['input'])

  if incorrect_sentences:
    return incorrect_sentences

  print("No incorrect sentences found in the file.")
  return None

In [None]:
def take_input(list_of_sentences, config, verbose=True, update_M2=True, take_feedback=False, log=True, start_index=0):
  """Correct a list of sentences with the configured pipeline and record the results.

  For each sentence, context is retrieved from M1 and/or M2 when the config
  name contains "m1"/"m2", the sentence is corrected with
  `correct_the_sentence`, human feedback is optionally collected, the pair is
  optionally written to M2, and an entry is optionally appended to a JSON log
  file. After the loop, the per-sentence delta1 trace is written to
  "<config name>_delta1.json" and the (incorrect, corrected) pairs to the
  synthetic results folder.

  Args:
    list_of_sentences: Sentences to correct.
    config: Dict with at least "name"; "k1" / "k2" are read when the name
      contains "m1" / "m2". It is also passed to `correct_the_sentence`.
    verbose: Print each corrected sentence and forward the flag to
      `correct_the_sentence`.
    update_M2: When True, write the (sentence, correction) pair to M2 if the
      distance thresholds (or positive feedback) allow it.
    take_feedback: When True, ask for human feedback after each correction.
    log: When True, create a log file in LOGGING_FOLDER and append one entry
      per processed sentence.
    start_index: 1-based position of the first sentence to process. Values
      below 1 (including the default 0) mean "start from the first sentence".

  Returns:
    None. Returns early (after printing the error) when the M1 collection is
    missing or the Milvus calls fail.
  """
  global NUM_ENT_M2

  if log: log_file_path = init_log_file(config, LOGGING_FOLDER)

  # start_index is 1-based. Clamp it so the default 0 no longer turns into the
  # slice [-1:], which silently processed only the last sentence.
  first_index = max(start_index, 1)
  pending_sentences = list_of_sentences[first_index - 1:]

  corrected_sentences = []
  incorrect_sentences = []
  xs, inds, delta1s = [], [], []
  delta1 = DELTA_1

  try:
    if not utility.has_collection(M1_COLLECTION):
      raise ValueError(f"Collection '{M1_COLLECTION}' does not exist.")

    m2_count = get_collection_and_count(M2_COLLECTION)
  except Exception as e:
    print(f"[Milvus Error] {e}")
    return

  # enumerate keeps `index` equal to the sentence's 1-based position even when
  # a sentence is skipped because retrieval failed.
  for index, sentence in enumerate(pending_sentences, start=first_index):
    dm1, dm2 = "", ""
    data_from_m1, data_from_m2 = "", ""

    if "m1" in config["name"]:
      try:
        dm1 = read_from_M1(sentence, k1=config["k1"])
        data_from_m1 = format_retrieved_data(dm1)
      except Exception as e:
        print(f"[Milvus Error in M1] {e}")
        continue

    if "m2" in config["name"]:
      try:
        dm2 = read_from_M2(sentence, k2=config["k2"])
        data_from_m2 = format_retrieved_data(dm2)
      except Exception as e:
        print(f"[Milvus Error in M2] {e}")
        continue

    delta1 = fun(m2_count, "const", ALPHA, DELTA_1)
    xs.append(m2_count)
    delta1s.append(delta1)
    inds.append(index)

    # Computed once per sentence and reused by the M2 update decision below.
    avg_dist = avg_distance(dm1)
    min_dist = min_distance(dm2)
    print(f"index:{index}", f"x:{m2_count}", f"delta1:{delta1}", f"avg_dist:{avg_dist}", f"min_dist:{min_dist}")

    corrected = correct_the_sentence(config=config, sentence_to_correct=sentence, data_from_gold_corpus=data_from_m1, data_from_history=data_from_m2, verbose=verbose)
    if verbose: print(f"corrected: {corrected}")
    corrected_sentences.append(corrected)
    incorrect_sentences.append(sentence)

    fb = take_human_feedback() if take_feedback else False
    if fb:
      store_feedback(corrected, sentence)

    updated_M2 = False
    if update_M2 and (avg_dist <= delta1 or fb) and min_dist >= DELTA_2:
      try:
        write_to_M2(sentence, corrected)
        m2_count += 1
        NUM_ENT_M2 = m2_count
        updated_M2 = True
      except Exception as e:
        print(f"[Milvus Error in write_to_M2] {e}")

    if log:
      # Logged with the same index that was printed and traced for this
      # sentence (previously the log entry was off by one).
      append_correction_to_log(log_file_path, sentence, corrected, index, m2_count, updated_M2)

  nm = config["name"]

  with open(f"{nm}_delta1.json", "w", encoding="utf-8") as f:
    json.dump([{
        "x": x,
        "delta1": d1,
        "i": i
    } for x, d1, i in zip(xs, delta1s, inds)], fp=f)

  with open(f"/content/drive/MyDrive/Smruti-GEC-for-Gujarati/results/synthetic/{nm}_results.json", "w", encoding="utf-8") as f:
    json.dump([{
        "incorrect_sentence": incorrect_sentence,
        "corrected_sentence": corrected_sentence
    } for corrected_sentence, incorrect_sentence in zip(corrected_sentences, incorrect_sentences)], fp=f)

# Run the evaluation

In [None]:
# Run the interactive (non-batch) pipeline over the synthetic evaluation set.
# The two commented lines below are alternative inputs kept for reference.
# input_sentences =["દીકરીઓને સનસ આવી ગયેલા છે , ને આ શિકારીનું ટોળું પણ ગંધ લીધા વગર નહિ ગયું હોય ."]*10
# load_input_sentences("/content/drive/MyDrive/Gujarati_Spelling_and_Grammar_Autocorrect/evaluation_set.json")
input_sentences = load_input_sentences("/content/drive/MyDrive/Smruti-GEC-for-Gujarati/data/synthetic_eval_set.json")
# "name" selects the prompting strategy; because it contains "m1" and "m2",
# take_input reads "k1" and "k2" and passes them to read_from_M1 / read_from_M2.
config = {
  "llmName":"gpt-4o-mini",
  "name":"cot_with_m1&m2",
  "k1":5,
  "k2":2,
}
# take_input slices the list with [start_index-1:], so 1 means "from the first sentence".
start_index = 1
# update_M2=False: corrections are not written back to M2 during this run.
take_input(input_sentences, config, update_M2=False, start_index=start_index, verbose=True)

In [None]:
import json

def convert_format1_with_references(format1_path, references_path, output_path):
    """Pair logged corrections with their references and save them as one JSON list.

    Args:
        format1_path: JSON file with a ``"log"`` list whose entries have
            ``"input"`` and ``"output"`` keys.
        references_path: JSON list whose entries have a ``"reference"`` key,
            in the same order as the log entries.
        output_path: Destination JSON file. Each item has ``"input"``,
            ``"prediction"``, ``"reference"`` and a 1-based ``"index"``.

    Raises:
        ValueError: If the log and the references differ in length.
    """
    def _read_json(path):
        with open(path, 'r', encoding='utf-8') as handle:
            return json.load(handle)

    data = _read_json(format1_path)
    references = _read_json(references_path)

    log_entries = data["log"]
    if len(log_entries) != len(references):
        raise ValueError("Mismatch between log and reference lengths.")

    result = [
        {
            "input": entry["input"],
            "prediction": entry["output"],
            "reference": ref["reference"],
            "index": position,
        }
        for position, (entry, ref) in enumerate(zip(log_entries, references), start=1)
    ]

    with open(output_path, 'w', encoding='utf-8') as handle:
        json.dump(result, handle, ensure_ascii=False, indent=2)

# Convert one specific run log (20250623_074913.json) into the
# input/prediction/reference format, pairing it with the synthetic eval set.
# The paths are hard-coded: update the log filename when converting another run.
convert_format1_with_references("/content/drive/MyDrive/Smruti-GEC-for-Gujarati/20250623_074913.json", "/content/drive/MyDrive/Smruti-GEC-for-Gujarati/data/synthetic_eval_set.json", "/content/drive/MyDrive/Smruti-GEC-for-Gujarati/results/synthetic/cot_with_m1&m2_k1=5_k2=2.json")

# Batch API

In [None]:
!%pip install openai --upgrade

In [None]:
import json
from openai import OpenAI
import pandas as pd
from IPython.display import Image, display

## prompt builder

In [None]:
def build_prompt(sentence_to_correct, config, verbose=True):
    """Render the prompt string for one sentence under the given config.

    Context is retrieved from M1 and/or M2 when the config name contains
    "m1" / "m2", then the template that belongs to ``config["name"]`` is
    filled in.

    Args:
        sentence_to_correct: The sentence to embed in the prompt.
        config: Dict with "name" (one of "zeroshot", "cot", "cot_with_m1",
            "cot_with_m2", "vanilla_m1", "cot_with_m1&m2") and, when the name
            contains "m1" / "m2", the retrieval sizes "k1" / "k2".
        verbose: Accepted for interface compatibility; not used here.

    Returns:
        The formatted prompt string, or ``None`` (after printing
        "Invalid config") when the config name is not recognised.

    Raises:
        ValueError: If the M1 collection does not exist.
    """
    if not utility.has_collection(M1_COLLECTION):
        raise ValueError(f"Collection '{M1_COLLECTION}' does not exist.")

    # The previous guard tested a name (`collections`) that is never assigned in
    # this function, so it either raised NameError or silently skipped the load.
    # Load unconditionally instead.
    Collection(M1_COLLECTION).load()

    name = config["name"]

    data_from_m1, data_from_m2 = "", ""
    if "m1" in name:
        dm1 = read_from_M1(sentence_to_correct, k1=config["k1"])
        data_from_m1 = format_retrieved_data(dm1)
    if "m2" in name:
        dm2 = read_from_M2(sentence_to_correct, k2=config["k2"])
        data_from_m2 = format_retrieved_data(dm2)

    input_data = {
        "sentence_to_correct": sentence_to_correct,
        "data_from_gold_corpus": data_from_m1,
        "data_from_history": data_from_m2
    }

    # config name -> (key into `templates`, variables the template is filled with)
    prompt_specs = {
        "zeroshot": ("t_zero_shot", ["sentence_to_correct"]),
        "cot": ("cot", ["sentence_to_correct"]),
        "cot_with_m1": ("cot_m1", ["sentence_to_correct", "data_from_gold_corpus"]),
        "cot_with_m2": ("cot_m2", ["sentence_to_correct", "data_from_history"]),
        "vanilla_m1": ("vanilla_m1", ["sentence_to_correct", "data_from_gold_corpus"]),
        "cot_with_m1&m2": ("cot_m1&m2", list(input_data.keys())),
    }

    spec = prompt_specs.get(name)
    if spec is None:
        print("Invalid config")
        return None

    template_key, variables = spec
    if name == "zeroshot":
        # Kept as in the original: the zero-shot template infers its own variables.
        prompt = PromptTemplate.from_template(templates[template_key])
    else:
        prompt = PromptTemplate(template=templates[template_key], input_variables=variables)

    return prompt.format(**{key: input_data[key] for key in variables})

In [None]:
def generate_batch_file(input_json_path, config, output_file_path, verbose=True):
    """Write one Batch API request line (JSONL) per evaluation item.

    Args:
        input_json_path: JSON list of items with an "input" sentence and,
            optionally, an "index" used as the request's ``custom_id``. When
            "index" is missing, the item's 1-based position is used instead.
        config: Passed to ``build_prompt``; "llmName" selects the model
            (defaults to "gpt-4o-mini").
        output_file_path: Destination ``.jsonl`` file.
        verbose: When True, print every generated prompt.

    Items for which ``build_prompt`` returns an empty prompt are skipped with a
    warning, and the final message reports the number of requests actually
    written.
    """
    with open(input_json_path, "r", encoding="utf-8") as infile:
        data = json.load(infile)

    written = 0
    with open(output_file_path, "w", encoding="utf-8") as outfile:
        for idx, item in enumerate(data):
            sentence = item["input"]
            custom_id = item.get("index", idx + 1)

            prompt_str = build_prompt(
                sentence_to_correct=sentence,
                config=config,
                verbose=verbose
            )

            if not prompt_str:
                print(f"Warning: Empty prompt for item #{idx+1}. Skipping.")
                continue

            if verbose:
                print(f"Prompt for index {custom_id}: ===============================================================")
                print(prompt_str)

            record = {
                "custom_id": str(custom_id),
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": config.get("llmName", "gpt-4o-mini"),
                    "messages": [
                        {"role": "system", "content": "You are a grammar correction model."},
                        {"role": "user", "content": prompt_str}
                    ],
                    "temperature": 0
                }
            }

            outfile.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1

    # Count what was actually written, not len(data): skipped items must not be reported.
    print(f"Successfully written {written} prompts to {output_file_path}")

## batch file generator

In [None]:
# Configuration for the batch run. "name" selects the prompt built by
# build_prompt, "llmName" is the model written into every request, and
# "k1" / "k2" are the retrieval sizes passed to read_from_M1 / read_from_M2.
config = {
    "name": "cot_with_m1&m2",
    "llmName": "gpt-4o-mini",
    "k1":5,
    "k2":2
}

# Build one request per evaluation item and write them to a local JSONL file;
# the same filename is used by the upload cell further down.
generate_batch_file(
    input_json_path="/content/drive/MyDrive/Smruti-GEC-for-Gujarati/data/synthetic_eval_set.json",
    config=config,
    output_file_path="batch_16-cot-with-m1&m2-k1=5-k2=2-const-gpt-4o-mini.jsonl",
    verbose=True
)

## Submit the batch and collect results

In [None]:
# Read the API key from Colab's secret store, as the LLM setup cell does.
# Never paste a key into the notebook. The key that used to be hard-coded on
# this line must be treated as leaked and revoked.
client = OpenAI(api_key=userdata.get('OPENAI_API_KEY'))

In [None]:
# Upload the generated JSONL file so it can be referenced by a batch job.
file_name = "batch_16-cot-with-m1&m2-k1=5-k2=2-const-gpt-4o-mini.jsonl"
# The context manager closes the file handle once the upload has finished.
with open(file_name, "rb") as batch_input:
  batch_file = client.files.create(
    file=batch_input,
    purpose="batch"
  )

In [None]:
# Show the file uploaded in the previous cell. Uploading again here would
# create a second, orphaned copy on the account and leave a file handle open.
print(batch_file)

In [None]:
# Create the batch job from the uploaded input file. Every request targets the
# chat completions endpoint and the job is given a 24h completion window.
batch_job = client.batches.create(
  input_file_id=batch_file.id,
  endpoint="/v1/chat/completions",
  completion_window="24h"
)

In [None]:
# Poll the batch job; re-run this cell to refresh its status.
# Set job_id to a known batch id to resume an earlier job after a kernel
# restart; leave it as None to poll the job created in this session.
job_id = None
batch_job = client.batches.retrieve(job_id or batch_job.id)
print(batch_job)

In [None]:
# Download the raw content of the job's output file.
# NOTE(review): output_file_id is presumably unset until the job has completed;
# confirm the job status in the previous cell before running this one.
result_file_id = batch_job.output_file_id
result = client.files.content(result_file_id).content

In [None]:
# Save the downloaded result locally: `result` is decoded as UTF-8 and written
# as text, so later cells can read it line by line.
result_file_name = "/content/batch_results.jsonl"
with open(result_file_name, 'w', encoding='utf-8') as file:
    file.write(result.decode('utf-8'))

In [None]:
def extract_batch_outputs(batch_result_file, evaluation_file, output_file):
    """Join batch predictions with the evaluation set and save them as JSON.

    Each result line is matched to its evaluation item through ``custom_id``
    rather than through its line number, because the line order of a batch
    output file is not guaranteed to follow the input order. A ``custom_id``
    is expected to equal the item's "index" field or, when the item has none,
    its 1-based position. Lines whose ``custom_id`` cannot be matched fall
    back to positional pairing.

    Args:
        batch_result_file: JSONL file downloaded from the Batch API.
        evaluation_file: JSON list of items with "input" and "reference"
            (and optionally "index").
        output_file: Destination JSON file. Records hold "input",
            "prediction", "reference" and the item's 1-based position as
            "index", sorted by that position.
    """
    with open(evaluation_file, 'r', encoding='utf-8') as f:
        references = json.load(f)

    # custom_id -> position of the item in the evaluation set.
    position_by_id = {}
    for pos, ref in enumerate(references):
        position_by_id.setdefault(str(ref.get("index", pos + 1)), pos)

    records = []
    with open(batch_result_file, 'r', encoding='utf-8') as infile:
        line_no = 0
        for line in infile:
            if not line.strip():
                continue
            data = json.loads(line)
            content = data['response']['body']['choices'][0]['message']['content']
            pos = position_by_id.get(str(data.get("custom_id")), line_no)
            line_no += 1
            records.append({
                "input": references[pos]["input"],
                "prediction": content,
                "reference": references[pos]["reference"],
                "index": pos + 1
            })

    # Restore evaluation-set order regardless of the order of the result lines.
    records.sort(key=lambda record: record["index"])

    with open(output_file, 'w', encoding='utf-8') as outfile:
        json.dump(records, outfile, ensure_ascii=False, indent=2)

In [None]:
# Join the downloaded batch results with the evaluation set and store them in
# the synthetic results folder.
RES_FOLDER = "/content/drive/MyDrive/Smruti-GEC-for-Gujarati/results/synthetic/"
# NOTE(review): this filename says "vanilla_with_m1_k1=3", while the batch file
# generated above is named for cot-with-m1&m2 (k1=5, k2=2). Confirm the name
# matches the run being extracted, to avoid overwriting another experiment.
output_file = RES_FOLDER + "vanilla_with_m1_k1=3.json"
batch_result_file = result_file_name
evaluation_file = "/content/drive/MyDrive/Smruti-GEC-for-Gujarati/data/synthetic_eval_set.json"
extract_batch_outputs(batch_result_file, evaluation_file, output_file)

In [None]:
# List the batches returned by the API and display the status of each one.
batches = client.batches.list()
[batch.status for batch in batches.data]