<a href="https://colab.research.google.com/github/LunaSeline/AHMD/blob/main/Prompt_completion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install cmake (if not already installed)
!apt-get install -y cmake

# Install transformers (should be pre-installed, but this ensures it)
!pip install transformers

# Clone KenLM (if not already cloned) and build its binaries.
!git clone https://github.com/kpu/kenlm.git || echo "KenLM repository already exists."
!cd kenlm && mkdir -p build && cd build && cmake .. && make -j4
!pip install https://github.com/kpu/kenlm/archive/master.zip

In [None]:

import kenlm
# Load the KenLM model from 'adl_model.klm'. The functions defined below read
# this module-level object whenever they score text.
kenlm_model = kenlm.Model('adl_model.klm')

# Path of the text corpus.
# NOTE(review): the functions in this cell take their own path arguments and do
# not read this global — confirm whether it is still needed.
file_path= 'adl_corpus.txt'
def load_corpus(file_path):
    """Read a UTF-8 text file and return its lines, each stripped of surrounding whitespace.

    Blank lines are kept and appear as empty strings in the result.
    """
    stripped_lines = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped_lines.append(raw_line.strip())
    return stripped_lines

def kenlm_complete(prompt, max_words=10, corpus_file="adl_corpus.txt"):
    """Greedily extend `prompt` one word at a time using KenLM scores.

    The candidate words are the whitespace-separated tokens of the corpus file.
    At every step each candidate is appended to the current text, the whole
    string is scored with the module-level `kenlm_model`, and the highest
    scoring word is kept.

    Args:
        prompt: Text to extend; surrounding whitespace is stripped first.
        max_words: Upper bound on the number of words appended.
        corpus_file: Path of the corpus used to build the vocabulary.

    Returns:
        The stripped prompt followed by the appended words. Generation stops
        early when no candidate is available, when the best word already occurs
        in the text, or right after a word ending in '.', '?' or '!'.
    """
    # Unique corpus tokens, collected in corpus order.
    vocab = list({token for line in load_corpus(corpus_file) for token in line.split()})

    text = prompt.strip()
    for _step in range(max_words):
        top_word, top_score = None, float('-inf')

        # Score the text with each vocabulary word appended; keep the first best.
        for token in vocab:
            token_score = kenlm_model.score(text + ' ' + token, bos=False, eos=False)
            if token_score > top_score:
                top_word, top_score = token, token_score

        # Nothing to add, or the best word would repeat one already present.
        if top_word is None or top_word in text.split():
            break
        text = text + ' ' + top_word

        # Sentence-ending punctuation finishes the completion.
        if top_word.endswith(('.', '?', '!')):
            break
    return text


from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Initialize GPT-2 fallback using DistilGPT2.
# The tokenizer and the causal-LM weights are both loaded from the "distilgpt2" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
distilgpt_model = AutoModelForCausalLM.from_pretrained("distilgpt2")
# Text-generation pipeline that complete_prompt() calls when it falls back from KenLM.
generator = pipeline("text-generation", model=distilgpt_model, tokenizer=tokenizer)

def complete_prompt(prompt):
    """Complete `prompt` with KenLM, falling back to DistilGPT2 sampling.

    The KenLM completion is kept unless it failed to extend the prompt or its
    KenLM score is below -10; in either case the `generator` pipeline produces
    the completion instead.

    Args:
        prompt: The text to complete.

    Returns:
        The completed text: either the KenLM completion or the pipeline's
        'generated_text' for the prompt.
    """
    kenlm_result = kenlm_complete(prompt)
    # Score the generated text using KenLM.
    kenlm_score = kenlm_model.score(kenlm_result, bos=False, eos=False)

    # KenLM "failed" if it added nothing or produced a low-scoring output.
    kenlm_failed = kenlm_result.strip() == prompt.strip() or kenlm_score < -10
    if not kenlm_failed:
        return kenlm_result

    # Bound the continuation with max_new_tokens: max_length counts tokens
    # (prompt included), so deriving it from a whitespace word count could leave
    # little or no room for new text when the prompt tokenizes into more pieces
    # than it has words.
    gpt_result = generator(prompt, max_new_tokens=20, do_sample=True, temperature=0.7)
    return gpt_result[0]['generated_text']

# Demo run: complete a sample prompt and show the result. `completed_text` is
# reused below for speech synthesis and translation.
test_prompt = "I want to eat "
completed_text = complete_prompt(test_prompt)
print("Completed Text:\n", completed_text)


from gtts import gTTS
from io import BytesIO
from IPython.display import Audio, display

# Generate speech and store it in a BytesIO buffer
# NOTE(review): the spoken text is English but lang='hi' is requested — confirm
# this is the intended voice.
tts = gTTS(text=" Did you mean: "+ completed_text, lang='hi')
audio_buffer = BytesIO()
tts.write_to_fp(audio_buffer)
audio_buffer.seek(0)  # Move to the beginning of the buffer

# Play the audio directly in Colab. This statement is not the last expression of
# the cell, so the Audio object must be passed to display() explicitly;
# a bare expression here would be evaluated and discarded without rendering.
display(Audio(audio_buffer.read(), autoplay=True))



from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# Load the MBart-50 many-to-many translation model and its tokenizer.
# Note: `tokenizer` is rebound here; the GPT-2 pipeline created earlier keeps
# its own reference to the DistilGPT2 tokenizer object.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

# Translate the completed English text to Hindi: the source language is set on
# the tokenizer and the target language is forced as the first generated token.
tokenizer.src_lang = "en_XX"
inputs = tokenizer(completed_text, return_tensors="pt")
generated_tokens = model.generate(**inputs, forced_bos_token_id=tokenizer.lang_code_to_id["hi_IN"])
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
# Collapse the decoded batch into a single string.
if isinstance(translation, list):
    translation = " ".join(translation)
# Always show the translation, whatever shape the decoder returned.
print(translation)

# Speak the Hindi confirmation prompt followed by the translation. The prompt
# uses "हैं", matching the wording used by the interactive cell further down.
tts = gTTS(text="क्या आप ये बोलना चाहते हैं: "+translation, lang='hi')
audio_buffer = BytesIO()
tts.write_to_fp(audio_buffer)
audio_buffer.seek(0)

# Play the audio in Colab. This must stay the last expression of the cell so
# that the notebook renders the Audio widget.
Audio(audio_buffer.read(), autoplay=True)

Device set to use cpu


Completed Text:
 I want to eat breakfast?
मैं नाश्ता खाना चाहता हूँ?


In [None]:
import kenlm
import math
import random
import re
import time  # For time.sleep()
from transformers import pipeline, MBartForConditionalGeneration, MBart50TokenizerFast
from gtts import gTTS
from io import BytesIO
from IPython.display import Audio, display

#############################################
# Step 0: Load KenLM and Build Vocabulary
#############################################

# Load your KenLM model.
kenlm_model = kenlm.Model('adl_model.klm')

# Read the corpus once, keeping every line with its surrounding whitespace
# removed. kenlm_sampling() derives its vocabulary from these sentences.
file_path = 'adl_corpus.txt'
with open(file_path, 'r', encoding='utf-8') as file:
    corpus_sentences = [raw_line.strip() for raw_line in file]

#############################################
# Step 1: KenLM Sampling Function
#############################################

def kenlm_sampling(prompt, max_words=10, temperature=1.0, top_k=5):
    """
    Generate a raw completion by word-by-word sampling using KenLM scores.

    At every step each vocabulary word is appended to the current text and the
    whole string is scored with the module-level `kenlm_model`. The `top_k`
    best-scoring words are kept and one of them is drawn at random, weighted by
    its temperature-scaled probability relative to the best candidate.

    Args:
        prompt: Text to extend; surrounding whitespace is stripped first.
        max_words: Upper bound on the number of words appended.
        temperature: Softmax temperature, must be > 0. Values above 1 flatten
            the distribution, values below 1 sharpen it.
        top_k: Number of best candidates to sample from, must be >= 1.

    Returns:
        The stripped prompt followed by the sampled words. Sampling stops early
        after a word ending in '.', '?' or '!'. If the corpus yields no
        vocabulary, the stripped prompt is returned unchanged.

    Raises:
        ValueError: If `temperature` is not positive or `top_k` is below 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Build vocabulary from corpus. Sorting gives a stable candidate order, so
    # ties and the sampling order do not depend on set iteration order.
    vocab = sorted({word for sentence in corpus_sentences for word in sentence.split()})

    output = prompt.strip()
    if not vocab:
        # Nothing to sample from; avoid max()/choices() failing on empty input.
        return output

    # KenLM scores are log10 probabilities; this factor converts a log10
    # difference into a natural-log difference so math.exp() yields true
    # probability ratios (exp(d * ln 10) == 10 ** d).
    log10_to_ln = math.log(10)

    for _ in range(max_words):
        # Score each candidate word in the context of the text so far.
        scores = [kenlm_model.score(output + ' ' + word, bos=False, eos=False)
                  for word in vocab]

        # Limit to the top_k best-scoring words (ties keep vocabulary order).
        top_indices = sorted(range(len(vocab)), key=scores.__getitem__, reverse=True)[:top_k]
        best_score = scores[top_indices[0]]

        # Unnormalized softmax weights relative to the best score; the best
        # candidate always has weight 1.0, so the weights never sum to zero.
        # random.choices() normalizes them itself.
        top_weights = [math.exp((scores[i] - best_score) * log10_to_ln / temperature)
                       for i in top_indices]
        top_words = [vocab[i] for i in top_indices]

        # Sample one word from top candidates.
        chosen = random.choices(top_words, weights=top_weights, k=1)[0]
        output += ' ' + chosen

        # Stop early if punctuation is appended.
        if chosen.endswith(('.', '?', '!')):
            break
    return output

#############################################
# New Step: Remove Repeated Words
#############################################

def remove_repeated_words(text):
    """
    Collapse a word that is repeated back-to-back into a single occurrence.

    Matching ignores case and keeps the first occurrence as written, so
    "my my my" becomes "my" and "The the" becomes "The".
    """
    # \b(\w+) captures one whole word; (\s+\1\b)+ then consumes every immediate
    # whitespace-separated repeat of it. The match is replaced by the captured word.
    return re.sub(r'\b(\w+)(\s+\1\b)+', r'\1', text, flags=re.IGNORECASE)

#############################################
# Step 2: Neural Rewriter (Paraphraser)
#############################################

# Text-to-text pipeline built from the "Vamsi/T5_Paraphrase_Paws" checkpoint;
# neural_rewrite() sends its "paraphrase: ..." prompts through it.
paraphraser = pipeline("text2text-generation",
                       model="Vamsi/T5_Paraphrase_Paws",
                       tokenizer="Vamsi/T5_Paraphrase_Paws")

def neural_rewrite(text, max_length=50, num_beams=5):
    """
    Paraphrase `text` with the module-level `paraphraser` pipeline.

    The text is wrapped as "paraphrase: <text> </s>" before being passed to the
    pipeline together with `max_length` and `num_beams`. Returns the
    'generated_text' of the first result.
    """
    outputs = paraphraser(f"paraphrase: {text} </s>", max_length=max_length, num_beams=num_beams)
    first_output = outputs[0]
    return first_output['generated_text']

#############################################
# Step 3: Grammar Correction
#############################################

# Text-to-text pipeline built from the "prithivida/grammar_error_correcter_v1"
# checkpoint; grammar_correct() passes text through it.
grammar_corrector = pipeline("text2text-generation",
                             model="prithivida/grammar_error_correcter_v1",
                             tokenizer="prithivida/grammar_error_correcter_v1")

def grammar_correct(text, max_length=60):
    """
    Run `text` through the module-level `grammar_corrector` pipeline.

    `max_length` is forwarded to the pipeline call. Returns the
    'generated_text' of the first result.
    """
    results = grammar_corrector(text, max_length=max_length)
    first_result = results[0]
    return first_result['generated_text']

#############################################
# Step 4: Ensure Single Sentence
#############################################

def ensure_single_sentence(text):
    """
    Reduce multi-sentence text down to the first sentence.

    The text is split on '.', '!' and '?'. The first non-blank piece is
    returned, stripped, together with the terminator that followed it in the
    original text, so questions and exclamations keep their punctuation. If the
    piece had no terminator (it ran to the end of the text), '.' is appended.

    Returns `text.strip()` when the text contains no non-blank piece at all
    (for example an empty string or punctuation only).
    """
    # With a capturing group, re.split() alternates pieces and terminators:
    # [piece0, term0, piece1, term1, ..., last_piece].
    pieces = re.split(r'([.!?])', text)
    for index in range(0, len(pieces), 2):
        sentence = pieces[index].strip()
        if sentence:
            has_terminator = index + 1 < len(pieces)
            terminator = pieces[index + 1] if has_terminator else '.'
            return sentence + terminator
    return text.strip()

#############################################
# Step 5: Generate Multiple Options
#############################################

def generate_multiple_options(prompt, n_options=3, max_words=10, temperature=1.0, top_k=5,
                              paraphrase_max_length=50, paraphrase_num_beams=5, correct_max_length=60):
    """
    Produce `n_options` candidate sentences for `prompt`.

    Each candidate goes through the same chain: KenLM sampling, neural
    paraphrase, grammar correction, removal of consecutive repeated words, and
    reduction to a single sentence. The candidates are returned as a list in
    the order they were generated.
    """
    def _build_candidate():
        # One full pass through the pipeline, from raw sample to final sentence.
        sampled = kenlm_sampling(prompt, max_words=max_words, temperature=temperature, top_k=top_k)
        paraphrased = neural_rewrite(sampled, max_length=paraphrase_max_length,
                                     num_beams=paraphrase_num_beams)
        corrected = grammar_correct(paraphrased, max_length=correct_max_length)
        return ensure_single_sentence(remove_repeated_words(corrected))

    return [_build_candidate() for _ in range(n_options)]

#############################################
# Step 6: Translation Setup for MBart
#############################################

# Load the MBart-50 many-to-many checkpoint and its tokenizer; both are read as
# module-level objects by the translation code below.
model_name = "facebook/mbart-large-50-many-to-many-mmt"
tokenizer_mbart = MBart50TokenizerFast.from_pretrained(model_name)
model_mbart = MBartForConditionalGeneration.from_pretrained(model_name)
# Source language for every text passed to the tokenizer: English.
tokenizer_mbart.src_lang = "en_XX"

#############################################
# Step 7: Interactive Completion with Audio for Every Option
#############################################

def _speak_text(text, lang='hi'):
    """Synthesize `text` with gTTS and autoplay the audio in the notebook output."""
    audio_buf = BytesIO()
    gTTS(text=text, lang=lang).write_to_fp(audio_buf)
    display(Audio(audio_buf.getvalue(), autoplay=True))


def _translate_to_hindi(text):
    """Translate `text` with the module-level MBart model and return one string.

    Relies on `tokenizer_mbart.src_lang` having been set at module level; the
    target language is forced through the "hi_IN" language token.
    """
    encoded = tokenizer_mbart(text, return_tensors="pt")
    generated = model_mbart.generate(**encoded,
                                     forced_bos_token_id=tokenizer_mbart.lang_code_to_id["hi_IN"])
    decoded = tokenizer_mbart.batch_decode(generated, skip_special_tokens=True)
    if isinstance(decoded, list):
        decoded = " ".join(decoded)
    return decoded


def interactive_completion(prompt, n_options=3, max_words=10, temperature=1.0, top_k=5,
                           paraphrase_max_length=50, paraphrase_num_beams=5, correct_max_length=60):
    """
    Generate candidate outputs (each cleaned and reduced to a single sentence), display the text,
    produce TTS for each candidate and its translation (each followed by a 2-second pause),
    and let the user select one or request more options.

    Args:
        prompt: Text to complete.
        n_options: Number of candidates generated per round.
        max_words, temperature, top_k: Forwarded to the KenLM sampling step.
        paraphrase_max_length, paraphrase_num_beams: Forwarded to the paraphraser.
        correct_max_length: Forwarded to the grammar corrector.

    Returns:
        The candidate sentence chosen by the user.

    New candidates are generated only when the user types 'n'. An invalid
    answer re-asks the question for the candidates already presented instead of
    regenerating and replaying a whole new batch.
    """
    while True:
        candidates = generate_multiple_options(prompt, n_options=n_options, max_words=max_words,
                                               temperature=temperature, top_k=top_k,
                                               paraphrase_max_length=paraphrase_max_length,
                                               paraphrase_num_beams=paraphrase_num_beams,
                                               correct_max_length=correct_max_length)
        print("\nCandidate Options (single sentence outputs):\n")
        for idx, option in enumerate(candidates):
            print(f"Option {idx+1}: {option}\n")

            # Produce TTS for the candidate.
            print(f"Playing Speech for Option {idx+1}:")
            _speak_text("Did you mean: " + option)
            time.sleep(2)  # 2-second gap

            # Translate candidate using MBart.
            translation_option = _translate_to_hindi(option)
            print(f"Translated Option {idx+1}: {translation_option}\n")

            # Produce TTS for the translation.
            print(f"Playing Translated Speech for Option {idx+1}:")
            _speak_text("क्या आप ये बोलना चाहते हैं: " + translation_option)
            time.sleep(2)  # 2-second gap

        # Keep asking about the candidates just presented until the user either
        # picks a valid one or explicitly requests a new batch.
        while True:
            user_input = input("Enter the option number you prefer (e.g., 1) or type 'n' for more options: ").strip().lower()
            if user_input == 'n':
                print("Generating additional options...\n")
                break
            try:
                choice = int(user_input)
            except ValueError:
                print("Invalid input, please enter a valid number or 'n'.\n")
                continue
            if 1 <= choice <= len(candidates):
                return candidates[choice-1]
            print("Invalid option number, please try again.\n")

#############################################
# Step 8: Main Execution and Final Audio Output
#############################################

# Run the interactive flow for a fixed prompt; this blocks on input() until the
# user selects one of the presented options.
base_prompt = "I want to"
final_output = interactive_completion(base_prompt, n_options=3, max_words=10, temperature=1.0, top_k=5,
                                      paraphrase_max_length=50, paraphrase_num_beams=5, correct_max_length=60)
print("\nFinal accepted output:")
print(final_output)

# Produce TTS for the final accepted output.
# NOTE(review): the spoken text is English but lang='hi' is requested — confirm
# this is the intended voice.
tts_final = gTTS(text="Did you mean: " + final_output, lang='hi')
audio_buf_final = BytesIO()
tts_final.write_to_fp(audio_buf_final)
audio_buf_final.seek(0)  # rewind so read() returns the audio just written
print("\nPlaying audio for the final accepted output:")
display(Audio(audio_buf_final.read(), autoplay=True))
time.sleep(2)  # 2-second gap

# Translate the final accepted text.
# The "hi_IN" language token is forced as the first generated token.
inputs_final = tokenizer_mbart(final_output, return_tensors="pt")
generated_tokens_final = model_mbart.generate(**inputs_final, forced_bos_token_id=tokenizer_mbart.lang_code_to_id["hi_IN"])
translation_final = tokenizer_mbart.batch_decode(generated_tokens_final, skip_special_tokens=True)
# Collapse the decoded batch into a single string.
if isinstance(translation_final, list):
    translation_final = " ".join(translation_final)
print("\nTranslated final text:")
print(translation_final)

# Speak the Hindi confirmation prompt followed by the translated text.
tts_trans_final = gTTS(text="क्या आप ये बोलना चाहते हैं: " + translation_final, lang='hi')
audio_buf_trans_final = BytesIO()
tts_trans_final.write_to_fp(audio_buf_trans_final)
audio_buf_trans_final.seek(0)  # rewind so read() returns the audio just written
print("\nPlaying translated audio for the final accepted output:")
display(Audio(audio_buf_trans_final.read(), autoplay=True))


Device set to use cpu
Device set to use cpu



Candidate Options (single sentence outputs):

Option 1: I want to text for a bit.

Playing Speech for Option 1:


Translated Option 1: मैं कुछ समय के लिए पाठ करना चाहता हूँ।

Playing Translated Speech for Option 1:


Option 2: Do I want to go to the post office.

Playing Speech for Option 2:


Translated Option 2: क्या मैं डाकघर जाना चाहता हूँ।

Playing Translated Speech for Option 2:


Option 3: I want to get ready for bed.

Playing Speech for Option 3:


Translated Option 3: मैं बिस्तर के लिए तैयार होना चाहता हूँ।

Playing Translated Speech for Option 3:


Enter the option number you prefer (e.g., 1) or type 'n' for more options: 3

Final accepted output:
I want to get ready for bed.

Playing audio for the final accepted output:



Translated final text:
मैं बिस्तर के लिए तैयार होना चाहता हूँ।

Playing translated audio for the final accepted output:
