In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (required to download the gated Llama
# weights). The token is read from the HF_TOKEN environment variable, or
# prompted for interactively, instead of being hardcoded: a token written into
# a notebook leaks with every copy, export and commit of the file.
# NOTE(review): the token that used to be hardcoded in this cell is exposed and
# should be revoked in the Hugging Face account settings.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=hf_token)


In [None]:
# Sanity check: ask the Hugging Face CLI which account the current credentials belong to.
!huggingface-cli whoami              

In [None]:
import re
import torch
import pandas as pd
from transformers import pipeline
from tqdm.notebook import tqdm


In [None]:
# Load the per-sentence results to evaluate; later cells read its
# 'Predicted Sentence' and 'True Sentence' columns.
df = pd.read_csv(filepath_or_buffer="results/1.csv")

In [None]:
# Eyeball the first two rows: the prediction and its reference sentence,
# separated by a blank line.
for i in (0, 1):
    print(f"sentence: {df.loc[i, 'Predicted Sentence']}")
    print(f"corrected: {df.loc[i, 'True Sentence']}")
    print()


In [None]:
# Load the instruction-tuned Llama model as a text-generation pipeline in
# bfloat16, letting device_map="auto" place it on the available devices.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
pipe = pipeline(
    task="text-generation",
    model=model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
def get_messages(sentence):
    """Build the chat prompt that asks the model to fix typos in ``sentence``.

    The prompt is a system message that sets the model's role, followed by a
    user message with two few-shot examples (one containing typos, one already
    correct) and then the sentence to correct.

    Args:
        sentence: Text inserted verbatim at the end of the user prompt.

    Returns:
        A list of two ``{"role": ..., "content": ...}`` dicts (system, user),
        in the chat format passed to the text-generation pipeline.
    """
    user_template = """
Here are examples of sentences with typos; learn from them:

    sentence: by 2480 genetic engineers will havegcreated organisgs capable of surviving in space without life support
    corrected: by 2480 genetic engineers will have created organisms capable of surviving in space without life support

    sentence: by 2510 the local government will have been investing in green infrastructure for generations
    corrected: by 2510 the local government will have been investing in green infrastructure for generations

Now, please correct this sentence and output only the corrected version with no additional text:
    
{target_sentence}
        """
    system_message = {"role": "system", "content": "You are an expert in correcting typos in sentences."}
    user_message = {"role": "user", "content": user_template.format(target_sentence=sentence)}
    return [system_message, user_message]

In [None]:
def get_llm_sentence(sentence, max_new_tokens=256):
    """Ask the chat model to correct the typos in ``sentence``.

    Builds the prompt with ``get_messages`` and sends it through the
    module-level text-generation ``pipe``.

    Decoding is forced to be greedy (``do_sample=False``). Without it, the
    model's default generation config decides. If that config enables
    sampling, the same sentence can get a different correction on every run,
    which makes the reported accuracy non-reproducible.
    NOTE(review): the published Llama-3.1-Instruct generation config enables
    sampling — confirm for the model in use.

    Args:
        sentence: The (possibly misspelled) sentence to correct.
        max_new_tokens: Maximum number of tokens to generate (default 256).

    Returns:
        The raw text of the model's reply (not yet post-processed).
    """
    messages = get_messages(sentence)
    outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    # For chat-style input the pipeline returns the whole conversation; the
    # last message is the assistant turn it just generated.
    return outputs[0]["generated_text"][-1]["content"]

In [None]:
# get_llm_sentence("My name it John.")

In [None]:
def llm_postprocess(sentence):
    """Normalise raw LLM output so it can be compared with the reference sentence.

    Steps:
    - Lower-case the text.
    - Drop a leading ``corrected:`` label. The few-shot examples in the prompt
      present answers in that format, and the model may echo it.
    - Remove every character other than ``a-z``, ``0-9`` and whitespace.
    - Collapse whitespace runs (including newlines) into single spaces.

    Args:
        sentence: Raw text generated by the model.

    Returns:
        The normalised sentence, without leading/trailing whitespace.
    """
    sentence = sentence.lower().strip()
    # Drop an echoed "corrected:" label; it would otherwise count as a wrong word.
    sentence = re.sub(r'^corrected\s*:\s*', '', sentence)
    # Keep only a-z, 0-9 and whitespace.
    sentence = re.sub(r'[^a-z0-9\s]', '', sentence)
    # Normalise whitespace *after* filtering, so punctuation removed at the
    # edges (e.g. "- word") or line breaks do not leave stray spaces behind.
    return ' '.join(sentence.split())

In [None]:
import difflib

def compute_accuracy_and_wrong_syllables(true_sentence, predicted_sentence):
    """Score a predicted sentence against its reference.

    Args:
        true_sentence: Reference text.
        predicted_sentence: Text to evaluate.

    Returns:
        A tuple ``(accuracy, wrong_syllables)``:
        - ``accuracy`` is the character-level ``difflib.SequenceMatcher.ratio()``
          (1.0 for identical strings).
        - ``wrong_syllables`` is the number of non-``equal`` opcodes when both
          sentences are compared as lists of whitespace-separated words.

    NOTE(review): ``wrong_syllables`` counts differing *spans*, not words. A
    single ``replace`` opcode that covers several words adds only 1.
    """
    accuracy = difflib.SequenceMatcher(None, true_sentence, predicted_sentence).ratio()

    word_opcodes = difflib.SequenceMatcher(
        None, true_sentence.split(), predicted_sentence.split()
    ).get_opcodes()
    edit_tags = ('insert', 'delete', 'replace')
    wrong_syllables = len([tag for tag, *_ in word_opcodes if tag in edit_tags])

    return accuracy, wrong_syllables


In [None]:
from tqdm.notebook import tqdm

In [None]:
# Evaluate every row against its reference sentence:
# - score the baseline prediction (labelled "CoAtNet" in the log);
# - score the LLM-corrected version of the same prediction;
# - print both every 100 index labels;
# - keep the LLM sentence and its metrics for the DataFrame.
llm_accs, llm_ws, llm_sen = [], [], []
total = len(df)

for row_idx, record in tqdm(df.iterrows(), total=total):
    verbose = row_idx % 100 == 0
    raw_prediction = record['Predicted Sentence']
    reference = record['True Sentence']

    baseline_acc, baseline_ws = compute_accuracy_and_wrong_syllables(reference, raw_prediction)
    if verbose:
        print(f"Index: {row_idx} of {total}")
        print("CoAtNet", baseline_acc, baseline_ws)

    corrected = llm_postprocess(get_llm_sentence(raw_prediction))
    llm_acc, llm_wrong = compute_accuracy_and_wrong_syllables(reference, corrected)
    if verbose:
        print("LLM", llm_acc, llm_wrong)
        print("==========")

    llm_sen.append(corrected)
    llm_accs.append(llm_acc)
    llm_ws.append(llm_wrong)
    

In [None]:
# Attach the LLM corrections and their per-row metrics as new columns
# (the dict keeps insertion order, so the columns are added in this order).
new_columns = {
    'LLM Sentence': llm_sen,
    'LLM Accuracy': llm_accs,
    'LLM Wrong syllables': llm_ws,
}
for column_name, column_values in new_columns.items():
    df[column_name] = column_values

In [None]:
# average accuracy
# Summarise the run: mean character-level accuracy and total wrong-syllable count.
n_scored = len(llm_accs)
llm_avg_accuracy = sum(llm_accs) / n_scored
llm_sum_wrong_syllables = sum(llm_ws)

print(f"LLM Average Accuracy: {llm_avg_accuracy}")
print(f"LLM Sum of Wrong Syllables: {llm_sum_wrong_syllables}")

- 1B
- NF 1
- LLM Average Accuracy: 0.9564564414385444
- LLM Sum of Wrong Syllables: 525
---
- 1B
- NF 5
- LLM Average Accuracy: 0.7679998762433198
- LLM Sum of Wrong Syllables: 2239
---
- 1B
- NF 6
- LLM Average Accuracy: 0.6005106454866648
- LLM Sum of Wrong Syllables: 2343

---
- 3B
- NF 1
- LLM Average Accuracy: 0.9935726024752972
- LLM Sum of Wrong Syllables: 220
---
- 3B
- NF 5
- LLM Average Accuracy: 0.8926553222468829
- LLM Sum of Wrong Syllables: 1701
---
- 3B
- NF 6
- LLM Average Accuracy: 0.6951608020737897
- LLM Sum of Wrong Syllables: 2622

---

- 8B
- NF 1
- LLM Average Accuracy: 0.9964405469123979
- LLM Sum of Wrong Syllables: 118
---
- 8B
- NF 5
- LLM Average Accuracy: 0.9221289253837536
- LLM Sum of Wrong Syllables: 1248
---
- 8B
- NF 6
- LLM Average Accuracy: 0.7523719200663457
- LLM Sum of Wrong Syllables: 2413


In [None]:
# Save the per-row LLM corrections and metrics under results/llama3_1_8b/.
# DataFrame.to_csv does not create missing directories, so create the target
# folder first; otherwise the save fails after the whole (slow) evaluation.
# NOTE(review): the folder name encodes the model; keep it in sync with model_id.
from pathlib import Path

output_path = Path('results/llama3_1_8b/1.csv')
output_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(output_path, index=False)

In [None]:
# Final marker showing that execution reached the last cell.
print("Done!", end="\n")