<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Few_Shot_Learning_with_Prompt_Engineering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer
from transformers import AdamW as TransformersAdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
import nltk

# The synonym-based augmentation further down needs NLTK's WordNet corpus.
nltk.download("wordnet")

# Everything runs on the CPU to keep memory usage down.
device = torch.device("cpu")

# Define a custom dataset
class TextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=64, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()

        if self.for_classification:
            label = item["label"]
            return input_ids, attention_mask, label
        else:
            return input_ids, attention_mask

# Define the MAML model class
class MAMLModel(nn.Module):
    """Wrap an encoder and adapt it per task with first-order MAML.

    ``forward`` returns the encoder's hidden state at position 0 (the CLS
    token) for every sequence in the batch.

    NOTE(review): there is no classification head. ``fast_adapt`` feeds that
    hidden vector straight into ``F.cross_entropy``, so the score for class
    ``k`` is simply hidden dimension ``k``. A linear head sized to the number
    of labels is probably what was intended; confirm before relying on it.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]  # Use CLS token

    def clone_parameters(self):
        """Return a detached snapshot ``{name: tensor}`` of every parameter."""
        # detach() so the snapshot carries no autograd history and costs no
        # extra graph memory.
        return {name: param.detach().clone() for name, param in self.named_parameters()}

    def fast_adapt(self, support_data, query_data, optimizer, n_steps=5, lr_inner=0.01,
                   meta_update=True):
        """Adapt on a support batch, score a query batch, then restore weights.

        Runs ``n_steps`` plain gradient-descent steps (step size ``lr_inner``)
        on the support cross-entropy and evaluates the adapted model on the
        query batch. The pre-adaptation weights are always restored afterwards,
        even if an error is raised.

        When ``meta_update`` is true, the query-loss gradient taken at the
        adapted weights is applied to the restored weights through
        ``optimizer`` (first-order MAML). Without that step the outer loop
        never learns anything.

        Args:
            support_data: ``(input_ids, attention_mask, targets)`` used to adapt.
            query_data: ``(input_ids, attention_mask, targets)`` used to score.
            optimizer: Outer-loop optimizer over this module's parameters.
            n_steps: Number of inner adaptation steps.
            lr_inner: Inner-loop learning rate.
            meta_update: Apply the first-order meta-gradient (default) or
                only evaluate the query loss.

        Returns:
            The query loss as a detached scalar tensor.
        """
        support_input, support_attention, support_target = support_data
        query_input, query_attention, query_target = query_data

        original_params = self.clone_parameters()
        try:
            for _ in range(n_steps):
                self.zero_grad()
                support_logits = self(support_input, support_attention)
                F.cross_entropy(support_logits, support_target).backward()
                with torch.no_grad():
                    for param in self.parameters():
                        if param.grad is not None:  # frozen/unused params have no grad
                            param.sub_(param.grad, alpha=lr_inner)

            self.zero_grad()
            if meta_update:
                query_logits = self(query_input, query_attention)
                query_loss = F.cross_entropy(query_logits, query_target)
                # Gradients land in .grad and survive the in-place restore below.
                query_loss.backward()
            else:
                with torch.no_grad():
                    query_logits = self(query_input, query_attention)
                    query_loss = F.cross_entropy(query_logits, query_target)
        finally:
            with torch.no_grad():
                for name, param in self.named_parameters():
                    param.copy_(original_params[name])  # Restore original parameters

        if meta_update:
            # First-order approximation: apply the gradient computed at the
            # adapted weights to the original weights.
            optimizer.step()
        self.zero_grad()
        return query_loss.detach()

# Synonym replacement for data augmentation
def _first_distinct_lemma(word, synsets):
    """Return the first lemma for *word* that differs from it, else ``None``.

    Multi-word lemmas such as ``domestic_dog`` are returned with spaces.
    """
    for synset in synsets(word):
        for lemma in synset.lemmas():
            candidate = lemma.name().replace("_", " ")
            if candidate.lower() != word.lower():
                return candidate
    return None


def synonym_replacement(text, n=2, synsets=None):
    """Return *text* with up to *n* distinct words swapped for a WordNet synonym.

    The text is split on whitespace and re-joined with single spaces. Words
    are visited in random order (``random.shuffle``). A word counts toward *n*
    only when a synonym different from the word itself exists, and every
    occurrence of a chosen word is replaced. Leading/trailing punctuation on a
    token (e.g. ``"dog."``) is kept in place and excluded from the lookup.

    Args:
        text: Sentence to augment.
        n: Maximum number of distinct words to replace.
        synsets: Lookup ``word -> iterable of synsets`` (objects exposing
            ``.lemmas()`` whose items expose ``.name()``). Defaults to
            ``nltk.corpus.wordnet.synsets``; injectable for testing.

    Returns:
        The augmented sentence.
    """
    import string

    lookup = wordnet.synsets if synsets is None else synsets
    tokens = text.split()

    def split_token(token):
        # (prefix punctuation, core word, suffix punctuation); the three parts
        # always concatenate back to the original token.
        core = token.strip(string.punctuation)
        start = len(token) - len(token.lstrip(string.punctuation))
        return token[:start], core, token[start + len(core):]

    # Unique, non-empty word cores so repeated words are only counted once.
    candidates = [core for core in dict.fromkeys(split_token(t)[1] for t in tokens) if core]
    random.shuffle(candidates)

    replacements = {}
    for core in candidates:
        if len(replacements) >= n:
            break
        synonym = _first_distinct_lemma(core, lookup)
        if synonym is not None:
            replacements[core] = synonym

    rebuilt = []
    for token in tokens:
        prefix, core, suffix = split_token(token)
        rebuilt.append(prefix + replacements.get(core, core) + suffix)
    return " ".join(rebuilt)

# Tokenizer matching the pretrained Longformer checkpoint used for the model.
_tokenizer_checkpoint = "allenai/longformer-base-4096"
tokenizer = LongformerTokenizer.from_pretrained(_tokenizer_checkpoint)

# Base examples: eight proverbs, four per label.
texts = [
    {"text": "The quick brown fox jumps over the lazy dog.", "label": 0},
    {"text": "A journey of a thousand miles begins with a single step.", "label": 0},
    {"text": "To be or not to be, that is the question.", "label": 0},
    {"text": "All that glitters is not gold.", "label": 0},
    {"text": "The early bird catches the worm.", "label": 1},
    {"text": "A picture is worth a thousand words.", "label": 1},
    {"text": "Better late than never.", "label": 1},
    {"text": "Actions speak louder than words.", "label": 1}
]

# Split the ORIGINAL sentences before augmenting. Otherwise synonym-replaced
# copies of one sentence can land on both sides of the split and leak
# validation content into training. Stratifying keeps both labels in the
# (two-example) validation split. train_test_split shuffles on its own, and
# random_state keeps the split reproducible.
train_data, val_data = train_test_split(
    texts,
    test_size=0.2,
    random_state=42,
    stratify=[item["label"] for item in texts],
)

# Augment the training split only: three synonym-replaced copies per sentence.
augmented_texts = [
    {"text": synonym_replacement(item["text"]), "label": item["label"]}
    for item in train_data
    for _ in range(3)
]
train_data = train_data + augmented_texts

# Create datasets and dataloaders
train_dataset = TextDataset(train_data, tokenizer, for_classification=True)
val_dataset = TextDataset(val_data, tokenizer, for_classification=True)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Initialize Longformer model
base_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
maml_model = MAMLModel(base_model).to(device)

# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are pinned to the transformers defaults (1e-6, 0.0) because
# torch's own defaults differ (1e-8, 1e-2); this keeps the optimisation unchanged.
optimizer = torch.optim.AdamW(maml_model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Define the prompt generation function
def generate_prompt(text, task_description):
    """Prefix *text* with *task_description* and tokenize it as a PyTorch batch.

    The prompt has the form ``"<task_description>: <text>"`` and is encoded
    with the module-level ``tokenizer``.
    """
    return tokenizer(f"{task_description}: {text}", return_tensors="pt")

# Example usage: push one prompted input through the model.
prompted_inputs = generate_prompt("This is an example text.", "Classify sentiment")

# Make sure to pass the input tensors to the model
input_ids = prompted_inputs["input_ids"].to(device)
attention_mask = prompted_inputs["attention_mask"].to(device)

# Inference only: no_grad avoids building an autograd graph over the whole
# Longformer just to look at the output. maml_model returns the CLS hidden
# state, which the training code treats as logits.
with torch.no_grad():
    logits = maml_model(input_ids, attention_mask)

# Train the MAML model.
# Support and query batches both come from the training split, from two
# independently shuffled loaders over the same dataset. Pairing with
# val_dataloader instead would feed validation data into training, and zip()
# would silently cut every epoch short to the much smaller validation loader.
query_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

for epoch in range(3):  # Adjust number of epochs as needed
    for support_batch, query_batch in zip(train_dataloader, query_dataloader):
        # Each batch is (input_ids, attention_mask, labels); move all to device.
        support_data = tuple(tensor.to(device) for tensor in support_batch)
        query_data = tuple(tensor.to(device) for tensor in query_batch)

        query_loss = maml_model.fast_adapt(support_data, query_data, optimizer)
        print(f"Epoch {epoch + 1}, Query Loss: {query_loss.item()}")