# In [None]:
# Importing necessary libraries
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

def load_model_and_tokenizer(model_name="t5-large"):
    """Fetch a pretrained seq2seq model together with its tokenizer.

    Args:
        model_name: Identifier handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Exception: Anything raised by either ``from_pretrained`` call is
            reported on stdout and then propagated unchanged.
    """
    try:
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        matching_tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as err:
        # Surface the failure for interactive runs, but never swallow it.
        print(f"Error loading model and tokenizer: {err}")
        raise
    else:
        return seq2seq_model, matching_tokenizer

def freeze_model_parameters(model):
    """Disable gradient tracking on every parameter the model currently owns.

    Only parameters yielded by ``model.parameters()`` at call time are
    affected; anything attached to the model afterwards is untouched.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.
    """
    for tensor in model.parameters():
        # Frozen tensors receive no gradient and are skipped by optimizers.
        tensor.requires_grad = False

def initialize_prompt_embedding(prompt_length, d_model):
    """Create a trainable prompt matrix drawn from a standard normal.

    Args:
        prompt_length: Number of prompt rows.
        d_model: Width of each row.

    Returns:
        A ``torch.nn.Parameter`` of shape ``(prompt_length, d_model)``.
    """
    initial_values = torch.randn((prompt_length, d_model))
    return torch.nn.Parameter(initial_values)

def set_input_embeddings(model, prompt_embedding):
    """Extend the model's input embedding table with extra prompt rows.

    The prompt rows are appended AFTER the existing vocabulary rows, so every
    existing token id keeps pointing at its original vector and the prompt
    occupies the new ids ``vocab_size .. vocab_size + prompt_length - 1``.
    (Placing the prompt rows first would shift every vocabulary row by
    ``prompt_length`` and make all token lookups return the wrong vector.)

    Args:
        model: Object exposing ``get_input_embeddings()`` (returning a module
            with a 2-D ``weight``) and ``set_input_embeddings(module)``.
        prompt_embedding: 2-D tensor whose second dimension matches the width
            of the existing embedding weight.

    Returns:
        None. The model is updated in place with a new ``torch.nn.Embedding``.

    NOTE(review): the replacement table is created with ``freeze=False``, so
    the whole matrix (vocabulary rows included) is trainable, and the values of
    ``prompt_embedding`` are copied rather than shared with the new table.
    """
    vocab_weight = model.get_input_embeddings().weight

    # Work on detached values in the table's own dtype/device so the
    # concatenation cannot fail on a mismatch or drag along autograd history.
    prompt_rows = prompt_embedding.detach().to(
        dtype=vocab_weight.dtype, device=vocab_weight.device
    )
    extended_weight = torch.cat([vocab_weight.detach(), prompt_rows], dim=0)

    model.set_input_embeddings(
        torch.nn.Embedding.from_pretrained(extended_weight, freeze=False)
    )

def preprocess_data(dataset, tokenizer, task, max_length=512, max_target_length=None):
    """Turn raw examples into fixed-length seq2seq features with in-context prompts.

    Each example becomes a prompt made of a solved demonstration followed by
    the same fields with the answer slot left open; the answer (as text) is
    the target sequence.

    NOTE(review): the demonstration is built from the very same example, so
    the prompt already contains the gold answer — confirm this is intended.

    Features are returned as plain, un-batched token-id lists padded to a
    fixed length, so a collator that simply stacks examples can batch them.
    Padding positions in ``labels`` are replaced by -100 so they are ignored
    by the loss.

    Args:
        dataset: Object exposing ``map(function, batched=False)``.
        tokenizer: Callable tokenizer that accepts ``text_target=`` and exposes
            ``pad_token_id`` (padding requires a pad token to be defined).
        task: ``"QA"`` (reads ``question`` and ``answers['text'][0]``),
            ``"NLI"`` (reads ``premise``, ``hypothesis``, ``label``) or
            ``"Classification"`` (reads ``text`` and ``label-coarse``, falling
            back to ``coarse_label`` when the former is absent).
        max_length: Length the prompt is truncated/padded to.
        max_target_length: Length the target is truncated/padded to; defaults
            to ``max_length``.

    Returns:
        The result of ``dataset.map``; every example additionally carries the
        tokenizer outputs for the prompt plus a ``labels`` list.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    if task not in ("QA", "NLI", "Classification"):
        raise ValueError("Unsupported task type specified.")
    if max_target_length is None:
        max_target_length = max_length

    def build_prompt(example):
        """Return ``(prompt_text, target_text)`` for one raw example."""
        if task == "QA":
            label = example['answers']['text'][0]
            in_context_examples = f"Question: {example['question']}\nAnswer: {label}"
            inputs = f"{in_context_examples}\nQuestion: {example['question']}\nAnswer:"
        elif task == "NLI":
            label = example['label']
            in_context_examples = f"Premise: {example['premise']}\nHypothesis: {example['hypothesis']}\nLabel: {label}"
            inputs = f"{in_context_examples}\nPremise: {example['premise']}\nHypothesis: {example['hypothesis']}\nLabel:"
        else:
            # The coarse label column has been published under two names.
            coarse_key = 'label-coarse' if 'label-coarse' in example else 'coarse_label'
            label = example[coarse_key]
            in_context_examples = f"Question: {example['text']}\nCategory: {label}"
            inputs = f"{in_context_examples}\nQuestion: {example['text']}\nCategory:"
        return inputs, str(label)

    def create_in_context_prompt(example):
        try:
            inputs, target = build_prompt(example)

            # No return_tensors here: this runs per example, and tensor output
            # would add a leading batch dimension to every stored feature.
            tokenized_inputs = tokenizer(
                inputs,
                padding="max_length",
                truncation=True,
                max_length=max_length,
            )
            tokenized_target = tokenizer(
                text_target=target,
                padding="max_length",
                truncation=True,
                max_length=max_target_length,
            )

            pad_id = tokenizer.pad_token_id
            tokenized_inputs['labels'] = [
                -100 if token_id == pad_id else token_id
                for token_id in tokenized_target['input_ids']
            ]
            return tokenized_inputs
        except Exception as e:
            # Report which step failed, then let the caller see the error.
            print(f"Error creating in-context prompt: {e}")
            raise

    return dataset.map(create_in_context_prompt, batched=False)

def train_model(model, tokenized_datasets, train_key="SQuAD", eval_key="MultiNLI"):
    """Fine-tune ``model`` with ``Trainer`` and report evaluation metrics.

    Args:
        model: Model passed straight to ``Trainer``.
        tokenized_datasets: Mapping from dataset name to a tokenized dataset.
        train_key: Name of the entry used for training. Defaults to
            ``"SQuAD"``, the previously hard-coded choice.
        eval_key: Name of the entry used for evaluation. Defaults to
            ``"MultiNLI"``, the previously hard-coded choice.

    Returns:
        The metrics dictionary produced by ``trainer.evaluate()`` (also
        printed to stdout).

    Raises:
        KeyError: If ``train_key`` or ``eval_key`` is missing from
            ``tokenized_datasets``.
    """
    import inspect

    for key in (train_key, eval_key):
        if key not in tokenized_datasets:
            raise KeyError(
                f"Dataset '{key}' not found; available: {sorted(tokenized_datasets)}"
            )

    # The per-epoch evaluation option was renamed from `evaluation_strategy`
    # to `eval_strategy`; use whichever the installed version accepts.
    accepted = inspect.signature(TrainingArguments).parameters
    strategy_arg = "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir='./results',
        learning_rate=5e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
        **{strategy_arg: "epoch"},
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets[train_key],
        eval_dataset=tokenized_datasets[eval_key],
    )

    # Train, then run one final evaluation pass and hand the metrics back.
    trainer.train()
    eval_results = trainer.evaluate()
    print(f"Evaluation results: {eval_results}")
    return eval_results

def main():
    """Run the whole pipeline: load, freeze, add prompt rows, preprocess, train.

    Steps, in order: load the ``t5-large`` model and tokenizer, freeze the
    model's parameters, create a 20-row prompt embedding and install it,
    load the three datasets, preprocess each one's ``train`` split for its
    task, and finally train/evaluate.
    """
    checkpoint = "t5-large"
    num_prompt_rows = 20

    # (name used downstream, dataset identifier, preprocessing task)
    corpora = (
        ("SQuAD", "squad", "QA"),
        ("MultiNLI", "multi_nli", "NLI"),
        ("TREC", "trec", "Classification"),
    )

    model, tokenizer = load_model_and_tokenizer(checkpoint)

    # Freeze first so that only what is installed afterwards stays trainable.
    freeze_model_parameters(model)
    set_input_embeddings(
        model,
        initialize_prompt_embedding(num_prompt_rows, model.config.d_model),
    )

    # Load every dataset before preprocessing any of them.
    raw = {name: load_dataset(identifier) for name, identifier, _ in corpora}
    encoded = {
        name: preprocess_data(raw[name]['train'], tokenizer, task=task)
        for name, _, task in corpora
    }

    train_model(model, encoded)

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()