<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Prefix_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class PrefixTuning(nn.Module):
    """Prepend a learned prefix to token embeddings and extend the attention mask.

    A trainable tensor of ``prefix_length`` vectors is concatenated in front of
    the embeddings of ``input_ids``. The attention mask gets ``prefix_length``
    ones prepended so the downstream model attends to every prefix position.

    Args:
        prefix_length: Number of learned prefix vectors.
        embedding_dim: Size of each prefix vector. It must equal the size of the
            vectors produced by the embedding layer.
        vocab_size: Number of rows in the fallback embedding table created when
            ``embedding_layer`` is ``None``. The default 30522 is the
            ``bert-base-uncased`` vocabulary size.
        embedding_layer: Optional module used to embed ``input_ids``. Pass the
            base model's pretrained input embeddings (for Hugging Face models,
            ``model.get_input_embeddings()``) so tokens are embedded with the
            pretrained table. When ``None``, a new, randomly initialised
            ``nn.Embedding(vocab_size, embedding_dim)`` is created. This is the
            original behavior, and its outputs are unrelated to any pretrained
            model.
    """

    def __init__(self, prefix_length=10, embedding_dim=768, vocab_size=30522,
                 embedding_layer=None):
        super().__init__()
        self.prefix_tokens = nn.Parameter(torch.randn(1, prefix_length, embedding_dim))
        if embedding_layer is None:
            # NOTE: random init, NOT a pretrained embedding table.
            embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        self.embedding_layer = embedding_layer

    def forward(self, input_ids, attention_mask=None):
        """Return ``(embeddings_with_prefix, attention_mask_with_prefix)``.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_len)``.
            attention_mask: Mask of shape ``(batch, seq_len)``, or ``None`` to
                attend to every token.

        Returns:
            A tensor of shape ``(batch, prefix_length + seq_len, embedding_dim)``
            and a mask of shape ``(batch, prefix_length + seq_len)``. The mask
            has the same dtype as ``attention_mask``, or ``input_ids``' dtype
            when no mask is given.
        """
        batch_size = input_ids.size(0)

        input_embeddings = self.embedding_layer(input_ids)
        # Share one prefix across the batch (expand does not copy). Cast it to
        # the embeddings' dtype so torch.cat works, e.g. for a half-precision
        # embedding table.
        prefix = self.prefix_tokens.expand(batch_size, -1, -1).to(dtype=input_embeddings.dtype)
        input_with_prefix = torch.cat((prefix, input_embeddings), dim=1)

        if attention_mask is None:
            # No padding information: every real token is attended to.
            attention_mask = torch.ones_like(input_ids)
        # Match the incoming mask's dtype/device instead of relying on implicit
        # type promotion inside torch.cat.
        prefix_attention_mask = torch.ones(
            batch_size,
            prefix.size(1),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        attention_with_prefix = torch.cat((prefix_attention_mask, attention_mask), dim=1)
        return input_with_prefix, attention_with_prefix

# Wrapper that feeds the prefix-augmented embeddings and attention mask into a base model
class PrefixTuningModel(nn.Module):
    """Run a base model on prefix-augmented input embeddings.

    Args:
        model: Base model, called as
            ``model(inputs_embeds=..., attention_mask=...)`` (for example a
            Hugging Face ``AutoModel``).
        prefix_tuning: Module mapping ``(input_ids, attention_mask)`` to
            ``(embeddings_with_prefix, attention_mask_with_prefix)``.
    """

    def __init__(self, model, prefix_tuning):
        super().__init__()
        self.model = model
        self.prefix_tuning = prefix_tuning

    def forward(self, input_ids, attention_mask=None):
        """Embed ``input_ids`` with the prefix and return the base model's output.

        Args:
            input_ids: Token ids passed to ``prefix_tuning``.
            attention_mask: Mask for ``input_ids``. When ``None``, an all-ones
                mask of the same shape as ``input_ids`` is used.

        Returns:
            Whatever ``self.model`` returns.
        """
        if attention_mask is None:
            # The prefix module concatenates the mask with its own prefix mask,
            # so it cannot receive None. Default to attending to every token.
            attention_mask = torch.ones_like(input_ids)
        input_with_prefix, attention_with_prefix = self.prefix_tuning(input_ids, attention_mask)
        return self.model(inputs_embeds=input_with_prefix, attention_mask=attention_with_prefix)

# Load a pre-trained transformer model and its tokenizer from Hugging Face.
base_model = AutoModel.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Prefix tuning trains only the prefix vectors, so freeze the pretrained weights.
for param in base_model.parameters():
    param.requires_grad_(False)

# Initialize the PrefixTuning and PrefixTuningModel.
# The prefix vectors must have the model's hidden size (768 for bert-base-uncased).
prefix_tuning = PrefixTuning(prefix_length=10, embedding_dim=base_model.config.hidden_size)
# Embed tokens with BERT's pretrained (frozen) word-embedding table instead of
# the randomly initialised table PrefixTuning builds by default. Otherwise the
# model sees meaningless token vectors.
prefix_tuning.embedding_layer = base_model.get_input_embeddings()
prefix_model = PrefixTuningModel(base_model, prefix_tuning)

# Create a dummy input.
input_text = "This is an example sentence."
inputs = tokenizer(input_text, return_tensors="pt")

# Forward pass through the PrefixTuningModel. This is inference only, so skip
# building an autograd graph.
with torch.no_grad():
    outputs = prefix_model(inputs["input_ids"], inputs["attention_mask"])

# The sequence length is prefix_length (10) plus the number of tokenized input tokens.
print(outputs.last_hidden_state.shape)