<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Modular_Head_Architectures_for_Task_Specific_Outputs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random  # Add this import statement
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer, AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
import nltk

# WordNet backs the synonym-based data augmentation further down.
nltk.download('wordnet')

# Prefer a CUDA device whenever PyTorch can see one; otherwise run on the CPU.
_device_kind = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_kind)

# Define the ModularFoundationModel class
class ModularFoundationModel(nn.Module):
    """Shared BERT encoder with one linear output head per task.

    Heads, all reading the encoder's ``last_hidden_state``:
        * ``classification`` -- ``num_classes`` logits from position 0 ([CLS]).
        * ``question_answering`` -- 2 scores (start, end) at every position.
        * ``summarization`` -- a vocabulary-sized score vector at every position.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=2):
        super().__init__()
        self.core_model = BertModel.from_pretrained(model_name)
        width = self.core_model.config.hidden_size
        # Heads are registered in a fixed order: classification, QA, summarization.
        self.classification_head = nn.Linear(width, num_classes)
        self.qa_head = nn.Linear(width, 2)  # start / end score per token
        self.summarization_head = nn.Linear(width, self.core_model.config.vocab_size)

    def forward(self, input_ids, attention_mask, task_type):
        """Encode the batch and apply the head selected by ``task_type``.

        Raises:
            ValueError: if ``task_type`` names no known head. This is raised only
                after the encoder has run.
        """
        hidden = self.core_model(input_ids, attention_mask=attention_mask).last_hidden_state

        if task_type == "classification":
            # Sequence-level task: only the first ([CLS]) position feeds the head.
            return self.classification_head(hidden[:, 0, :])

        if task_type == "question_answering":
            token_head = self.qa_head
        elif task_type == "summarization":
            token_head = self.summarization_head
        else:
            raise ValueError("Unsupported task type: {}".format(task_type))

        # Token-level tasks: the head is applied independently at every position.
        return token_head(hidden)

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes ``{"text": ..., "label": ...}`` records on access.

    Every item is padded/truncated to ``max_length`` tokens. Despite its name,
    ``for_classification`` only controls whether the record's ``"label"`` is
    returned as a third element. Any label shape is allowed, e.g. a QA
    ``[start, end]`` pair.
    """

    def __init__(self, data, tokenizer, max_length=128, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask[, label])`` for record ``idx``.

        ``input_ids`` and ``attention_mask`` are 1-D tensors of length
        ``max_length``. ``label`` is converted to a tensor so that list-valued
        labels collate into a ``(batch, k)`` tensor. Left as Python lists, they
        would collate into a list of per-position tensors, which callers
        indexing ``labels[i]`` would misread.
        """
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Remove only the leading batch dim added by return_tensors="pt". A bare
        # squeeze() would also collapse the sequence dim when max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if not self.for_classification:
            return input_ids, attention_mask
        return input_ids, attention_mask, torch.as_tensor(item["label"])

# Synonym replacement for data augmentation
def synonym_replacement(text, n=2, synonym_lookup=None):
    """Return ``text`` with up to ``n`` distinct words swapped for a synonym.

    This is EDA-style augmentation. Distinct words are visited in random order.
    For each word that has at least one synonym different from itself, every
    occurrence of that word is replaced by one randomly chosen synonym.

    Args:
        text: Sentence to augment; tokenised on whitespace.
        n: Maximum number of distinct words to actually change.
        synonym_lookup: Optional callable ``word -> iterable of str`` giving
            candidate synonyms. Defaults to WordNet lemma names.

    Returns:
        The augmented sentence, re-joined with single spaces.
    """
    if synonym_lookup is None:
        synonym_lookup = _wordnet_synonyms

    words = text.split()
    # Unique words only: processing a repeated word twice would waste the budget.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        if len(replacements) >= n:
            break
        # Drop the word itself: WordNet usually lists it as the first lemma,
        # which previously counted as a "replacement" while changing nothing.
        options = [
            syn for syn in dict.fromkeys(synonym_lookup(word))
            if syn.lower() != word.lower()
        ]
        if options:
            replacements[word] = random.choice(options)

    return " ".join(replacements.get(w, w) for w in words)


def _wordnet_synonyms(word):
    """Yield WordNet lemma names for ``word``, with multi-word '_' joins turned into spaces."""
    for synset in wordnet.synsets(word):
        for lemma in synset.lemmas():
            yield lemma.name().replace("_", " ")

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Seed Python's RNG so augmentation (and therefore the whole run) is reproducible,
# matching the fixed random_state used for the split below.
random.seed(42)

# Base examples for the binary classification task.
texts = [
    {"text": "The quick brown fox jumps over the lazy dog.", "label": 0},
    {"text": "A journey of a thousand miles begins with a single step.", "label": 0},
    {"text": "To be or not to be, that is the question.", "label": 0},
    {"text": "All that glitters is not gold.", "label": 0},
    {"text": "The early bird catches the worm.", "label": 1},
    {"text": "A picture is worth a thousand words.", "label": 1},
    {"text": "Better late than never.", "label": 1},
    {"text": "Actions speak louder than words.", "label": 1}
]

# Split the ORIGINAL sentences first. Augmenting before the split let synonym
# paraphrases of validation sentences land in the training set (data leakage).
# train_test_split already shuffles, so no separate random.shuffle is needed.
# stratify keeps both labels in the validation set; it needs >= 2 examples per label.
train_data, val_data = train_test_split(
    texts,
    test_size=0.2,
    random_state=42,
    stratify=[example["label"] for example in texts],
)

# Augment the training split only: 3 synonym-replaced copies per sentence.
train_data = train_data + [
    {"text": synonym_replacement(example["text"]), "label": example["label"]}
    for example in train_data
    for _ in range(3)
]

# Create datasets and dataloaders
train_dataset = TextDataset(train_data, tokenizer, for_classification=True)
val_dataset = TextDataset(val_data, tokenizer, for_classification=True)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Initialize model
model = ModularFoundationModel(model_name="bert-base-uncased", num_classes=2).to(device)
# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to that class's defaults so optimisation is unchanged
# (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)

# Define the prompt generation function
def generate_prompt(text, task_description, tok=None):
    """Prefix ``text`` with a task tag and tokenize the result for the model.

    Args:
        text: Raw input text.
        task_description: Tag prepended as ``"<task_description>: <text>"``.
        tok: Optional tokenizer to use; defaults to the module-level ``tokenizer``.

    Returns:
        The tokenizer's encoding with PyTorch tensors (``return_tensors="pt"``).
    """
    if tok is None:
        tok = tokenizer
    prompt = f"{task_description}: {text}"
    # truncation=True caps the prompt at the tokenizer's maximum model length, so
    # long inputs cannot overflow BERT's position embeddings.
    return tok(prompt, return_tensors="pt", truncation=True)

# Placeholder records for each task; replace with real data.
summarization_data = [
    {"text": "Example summarization text.", "label": 0},
    # Add your summarization data here
]

qa_data = [
    {"text": "Example QA text.", "label": [0, 1]},  # [start, end] token positions
    # Add your QA data here
]

classification_data = [
    {"text": "Example classification text.", "label": 0},
    # Add your classification data here
]

# One labelled TextDataset per task, built in the order listed.
summarization_dataset, qa_dataset, classification_dataset = (
    TextDataset(records, tokenizer, for_classification=True)
    for records in (summarization_data, qa_data, classification_data)
)

# Shuffled, single-example batches for every task.
summarization_loader, qa_loader, classification_loader = (
    DataLoader(ds, batch_size=1, shuffle=True)
    for ds in (summarization_dataset, qa_dataset, classification_dataset)
)

# Task name -> loader. Insertion order is the order tasks are visited each epoch.
task_dataloaders = dict(zip(
    ("summarization", "question_answering", "classification"),
    (summarization_loader, qa_loader, classification_loader),
))

def _example_label(labels, i, target_device):
    """Return example ``i``'s label from a collated batch as a tensor on ``target_device``.

    Labels stored as Python lists (e.g. QA ``[start, end]``) are collated by
    DataLoader into a *list* of per-position tensors of shape ``(batch,)``.
    Those are re-stacked here. Tensor labels are indexed directly.
    """
    if isinstance(labels, (list, tuple)):
        return torch.stack([torch.as_tensor(part)[i] for part in labels]).to(target_device)
    return labels[i].to(target_device)


def _task_loss(task, logits, label):
    """Return the scalar loss for one prompted example, or None if its label is unusable.

    ``logits`` is the model output for a single example (leading batch dim of 1).
    ``label`` is that example's label tensor on the same device.
    """
    if task == "classification":
        return F.cross_entropy(logits, label.view(1))
    if task == "question_answering":
        positions = label.view(-1)
        if positions.numel() != 2:
            print(f"Invalid label tensor length for QA task: {positions.numel()}")
            return None
        # NOTE(review): start/end positions index the un-prompted text, but the model
        # sees "question_answering: <text>", so the tokens are shifted. Confirm offsets.
        start_logits, end_logits = logits.unbind(dim=-1)  # each (1, seq_len)
        start_loss = F.cross_entropy(start_logits, positions[0:1])
        end_loss = F.cross_entropy(end_logits, positions[1:2])
        return (start_loss + end_loss) / 2
    if task == "summarization":
        # NOTE(review): with the placeholder data the single integer label is broadcast
        # to every position. That is not a real summarisation objective; supply
        # per-token target ids of length seq_len for a meaningful loss.
        targets = label.view(-1)
        flat_logits = logits.reshape(-1, logits.size(-1))
        if flat_logits.size(0) != targets.size(0):
            targets = targets.expand(flat_logits.size(0))
        return F.cross_entropy(flat_logits, targets)
    raise ValueError("Unsupported task type: {}".format(task))


# from_pretrained() returns BERT in eval mode; switch the whole model to training
# mode so dropout is active while fine-tuning.
model.train()

for epoch in range(3):  # Adjust number of epochs as needed
    for task, dataloader in task_dataloaders.items():
        batch_losses = []
        for batch_input_ids, _batch_mask, labels in dataloader:
            # Kept distinct from the per-example prompt tensors below. Reusing the names
            # previously clobbered the batch and broke any batch_size > 1.
            optimizer.zero_grad()
            example_losses = []
            for i in range(batch_input_ids.size(0)):
                input_text = tokenizer.decode(batch_input_ids[i], skip_special_tokens=True)
                prompted_inputs = generate_prompt(input_text, task)
                input_ids = prompted_inputs['input_ids'].to(device)
                attention_mask = prompted_inputs['attention_mask'].to(device)

                logits = model(input_ids, attention_mask, task_type=task)
                loss = _task_loss(task, logits, _example_label(labels, i, device))
                if loss is not None:
                    example_losses.append(loss)

            if not example_losses:
                continue
            # One optimiser step per batch on the mean example loss. Stepping per example
            # after a single zero_grad let gradients pile up across the batch.
            batch_loss = torch.stack(example_losses).mean()
            batch_loss.backward()
            optimizer.step()
            batch_losses.append(batch_loss.item())

        if batch_losses:
            mean_loss = sum(batch_losses) / len(batch_losses)
            print(f"Task: {task}, Epoch: {epoch + 1}, Loss: {mean_loss}")
        else:
            print(f"Task: {task}, Epoch: {epoch + 1}, Loss: n/a (no valid batches)")