# In [None]:

from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model,PeftModel
import config

class Modelinference:
    """Generate text with a saved PEFT adapter applied to a base model.

    The adapter stored in ``output_dir_name`` is loaded with
    ``PeftModel.from_pretrained`` the first time :meth:`inference` runs and is
    reused by later calls, so repeated prompts do not reload it from disk.
    Create a new instance to pick up an adapter that was re-saved afterwards.
    """

    def __init__(self, model, tokenizer, output_dir_name):
        """
        Args:
            model: Base (foundational) model the adapter is attached to.
            tokenizer: Tokenizer used to encode prompts and decode generations.
            output_dir_name: Location of the saved adapter, passed as the
                model id to ``PeftModel.from_pretrained``.
        """
        self.output_directory = output_dir_name
        self.foundational_model = model
        self.tokenizer = tokenizer
        # Adapter-wrapped model, built lazily by _get_peft_model on first use.
        self._peft_model = None

    def _get_peft_model(self):
        """Return the adapter-wrapped model, loading it from disk only once."""
        # getattr keeps this working for instances created without __init__.
        peft_model = getattr(self, "_peft_model", None)
        if peft_model is None:
            peft_model = PeftModel.from_pretrained(
                self.foundational_model,
                self.output_directory,
                device_map='auto',
                is_trainable=False,
            )
            self._peft_model = peft_model
        return peft_model

    def get_outputs(self, model, inputs, do_sample=True, num_beams=None,
                    num_return_sequences=3, max_length=30):
        """
        Generate one or more sequences for the given inputs.

        Args:
            model: Model whose ``generate`` method is called.
            inputs: Mapping with ``'input_ids'`` and ``'attention_mask'``.
            do_sample (bool, optional): Use sampling when beam search is not
                requested (default: True).
            num_beams (int, optional): Number of beams for beam search. When
                given, sampling is disabled regardless of ``do_sample``.
            num_return_sequences (int, optional): Sequences to generate per
                input (default: 3).
            max_length (int, optional): Maximum total length passed to
                ``generate`` (default: 30).

        Returns:
            Whatever ``model.generate`` returns (for Hugging Face models, a
            tensor of generated token ids).
        """
        use_beam_search = num_beams is not None
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            # Forwarded to generate(); it only affects beam search.
            early_stopping=False,
            temperature=1,
            # Without an explicit beam count, use a single beam (sampling or
            # greedy decoding, depending on do_sample).
            num_beams=num_beams if use_beam_search else 1,
            # Sampling is only enabled when beam search was not requested.
            do_sample=not use_beam_search and do_sample,
            num_return_sequences=num_return_sequences,
        )
        return outputs

    def inference(self, input_prompt, num_return_sequences=3, num_beams=5,
                  max_length=30):
        """
        Generate text for a prompt using the adapter saved in the output directory.

        Args:
            input_prompt (str): Prompt text to generate from.
            num_return_sequences (int, optional): Number of sequences to return
                (default: 3).
            num_beams (int, optional): Beam count for beam search (default: 5).
            max_length (int, optional): Maximum total generation length
                (default: 30).

        Returns:
            list[str]: Generated sequences decoded by the tokenizer, with
            special tokens removed.
        """
        peft_model = self._get_peft_model()
        encoded = self.tokenizer(input_prompt, return_tensors="pt")
        # The tokenizer produces CPU tensors. With device_map='auto' the model
        # may live on an accelerator, and generate() needs its inputs on the
        # model's device.
        device = getattr(peft_model, "device", None)
        if device is not None:
            encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        outputs = self.get_outputs(
            peft_model,
            encoded,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            max_length=max_length,
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    


# In [None]:
from training_helper import setup_model

# Load the device, tokenizer and base model via the project helper.
# NOTE(review): `device` is not used in this cell; confirm whether later cells need it.
device,tokenizer, model = setup_model(config.MODEL_PATH)
# Directory holding the saved PEFT adapter that Modelinference loads at inference time.
output_directory = './peft_outputs'

# Prompt containing a source sentence, the <|perturb|> marker, a [negation]
# control code, and a template with a [BLANK] slot (presumably filled by the model).
input_prompt = "Her decision is rational love. <|perturb|> [negation] Her decision is [BLANK] love"

# Alias the loaded base model as the foundational model the adapter is applied to.
foundational_model = model
# NOTE(review): model_name is not used in this cell; it may be referenced by other cells — confirm before removing.
model_name = "uw-hai/polyjuice"

# Wrap the base model, tokenizer and adapter directory in the Modelinference helper.
inf_class =  Modelinference(foundational_model, tokenizer, output_directory)

# Generate with inference()'s defaults (num_return_sequences=3, num_beams=5);
# the result is the list of decoded strings returned by tokenizer.batch_decode.
sequence = inf_class.inference (input_prompt)