In [None]:
!pip install transformers

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import time

# Checkpoints to benchmark, keyed by the display name printed in the results.
# The commented-out entries are alternative checkpoints that are currently disabled.
model_paths = {
    # "T5 Large": "t5-large",
    # "LED Base": "allenai/led-base-16384",
    # "mBART Large": "facebook/mbart-large-cc25",
    # "DialoGPT Large": "microsoft/DialoGPT-large",
    "BART CNN Samsum": "philschmid/bart-large-cnn-samsum",
    "BLOOM 560M": "bigscience/bloom-560m",
}

def load_model(model_name):
    """
    Load the model and tokenizer registered under ``model_name`` in ``model_paths``.

    The checkpoint's config is read first to choose the right auto class:
    encoder-decoder checkpoints (e.g. BART) load through
    ``AutoModelForSeq2SeqLM``, and decoder-only checkpoints (e.g. BLOOM) load
    through ``AutoModelForCausalLM``. The seq2seq auto class rejects
    decoder-only configs, so using it unconditionally breaks on BLOOM.

    Args:
        model_name (str): A key of ``model_paths``.

    Returns:
        tuple: ``(model, tokenizer)`` for the checkpoint.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_paths``.
    """
    from transformers import AutoConfig, AutoModelForCausalLM

    print(f"Loading {model_name} model...")
    checkpoint = model_paths[model_name]

    config = AutoConfig.from_pretrained(checkpoint)
    model_cls = AutoModelForSeq2SeqLM if config.is_encoder_decoder else AutoModelForCausalLM
    # Reuse the already-fetched config instead of resolving it a second time.
    model = model_cls.from_pretrained(checkpoint, config=config)

    # The slow (Python) tokenizer is requested, as in the original setup.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    return model, tokenizer

# Instantiate every configured checkpoint up front, keyed by its display name
# (same insertion order as model_paths).
models = dict(zip(model_paths, map(load_model, model_paths)))

def create_prompt(case_text=None):
    """
    Build a zero-shot summarization prompt around a case description.

    Args:
        case_text (str, optional): The case to summarize. When omitted, the
            built-in example story is used, so the zero-argument call behaves
            exactly as before.

    Returns:
        str: The template instructions followed by the case text wrapped in
        double quotes.
    """
    if case_text is None:
        case_text = """For context, this happened in 2020 during the pandemic, and I never knew the names of anyone here. 
Back in 2020, I got sick (not with covid) and I went to my doctor to get tested for strep. It was all going normal at first. 
I walked in, put on a mask, checked in, and sat down to wait. While I was on my phone, I see a male Karen, looking to be in his 40s or 50s maybe, 
come through the doors without a mask. The polite nurse tells him, "Sir please put a mask on," but he refuses. Instead, he starts rambling on about how he has breathing problems. 
Again, the nurse steps in and informs this guy that breathing trouble can be a sign of covid. He then proceeds to start violently screaming at the nurse and cussing her out. 
While all this is happening, I am sitting a short distance away. Being the only other in the room, the nurse gives me a desperate look, but socially awkward me just stayed where I was. 
About 5 minutes later, 2 other nurses get this guy kicked out and I never saw him come back in. So, was I the jerk for doing nothing for the poor nurse?
"""
    # Adjacent literals are concatenated; only the last one interpolates.
    prompt = (
        'Summarize the case text using this template as accurately as possible while maintaining correct English grammar. Do not add extra information: '
        'The active agent did action to passive agent which led to consequence. The active agent had good/bad/neutral moral intention, '
        f'however, the action violated ethical principle which caused ethical issue. Case text is as follows: "{case_text}"'
    )
    return prompt

def perform_inference(model, tokenizer, prompt, max_new_tokens=200, max_length=512):
    """
    Generate a response for ``prompt`` and time the ``generate`` call.

    The tokenizer's attention mask is passed to ``generate`` explicitly.

    Decoder-only (causal) models return the prompt tokens followed by the new
    tokens, so the prompt prefix is sliced off before decoding. Without this,
    the "response" would echo the whole prompt. Encoder-decoder models return
    only decoder output and are decoded as-is.

    Args:
        model: The pre-trained model to perform inference.
        tokenizer: The tokenizer to process the input.
        prompt (str): The formatted prompt string.
        max_new_tokens (int): Upper bound on generated tokens (default 200).
        max_length (int): Truncation length for the tokenized prompt (default 512).

    Returns:
        str: The model's generated answer.
        float: Wall-clock seconds spent inside ``generate``.
    """
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # perf_counter is monotonic and high-resolution, so it suits interval timing.
    start_time = time.perf_counter()
    output = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens)
    inference_time = time.perf_counter() - start_time

    generated = output[0]
    # If the config can't be inspected, assume encoder-decoder and leave the output untouched.
    is_encoder_decoder = getattr(getattr(model, "config", None), "is_encoder_decoder", True)
    if not is_encoder_decoder:
        generated = generated[input_ids.shape[-1]:]

    response = tokenizer.decode(generated, skip_special_tokens=True)
    return response, inference_time

def run_tests(models):
    """
    Run the zero-shot summarization prompt through every model and print results.

    For each model, this prints the generated response and the time spent
    in generation.

    Args:
        models (dict): Maps a display name to a ``(model, tokenizer)`` pair,
            as returned by ``load_model``.
    """
    # The prompt does not depend on the model, so build it once outside the loop.
    prompt = create_prompt()

    for model_name, (model, tokenizer) in models.items():
        print(f"\nTesting with {model_name}...\n")
        response, time_taken = perform_inference(model, tokenizer, prompt)
        print(f"Zero-Shot Response: {response}")
        print(f"Zero-Shot Inference Time: {time_taken:.4f} seconds\n")

# Kick off the benchmark over every model loaded above.
run_tests(models=models)