<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Adding_Evaluation_Metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score

# Compute device: the GPU when CUDA is usable, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Args:
        data: Indexable, sized collection of records. Each record is indexed
            with ``"text"`` and, when ``for_classification`` is true, also
            with ``"label"``.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; its
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length handed to the tokenizer for padding/truncation.
        for_classification: When true, each item also carries its label.
    """

    def __init__(self, data, tokenizer, max_length=128, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx``.

        Returns:
            ``(input_ids, attention_mask, label)`` when ``for_classification``
            is true, otherwise ``(input_ids, attention_mask)``.
        """
        record = self.data[idx]
        encoding = self.tokenizer(
            record["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Remove only the leading axis (assumed to be the size-1 batch axis the
        # tokenizer adds for a single text). A bare squeeze() would also drop
        # the sequence axis whenever max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, record["label"]
        return input_ids, attention_mask

# Define the FoundationModel class
class FoundationModel(nn.Module):
    """Pretrained BERT encoder bundled with its matching tokenizer.

    Both pieces are loaded with ``from_pretrained(model_name)``.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        """Run the encoder and return its ``last_hidden_state``."""
        encoded = self.model(input_ids, attention_mask=attention_mask)
        return encoded.last_hidden_state

    def encode_text(self, texts, max_length=128):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)``.

        The tokenizer is called with ``padding=True``, ``truncation=True``
        and ``return_tensors="pt"``.
        """
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        ids = batch["input_ids"]
        mask = batch["attention_mask"]
        return ids, mask

# Define the MultiTaskFoundationModel class for multitask learning
class MultiTaskFoundationModel(FoundationModel):
    """FoundationModel with one linear classification head per task.

    ``tasks`` maps each task name to its number of labels; a head of that
    output size is created for every entry and stored in ``classifiers``.
    """

    def __init__(self, model_name="bert-base-uncased", tasks=None):
        super().__init__(model_name)
        self.tasks = tasks or {}
        hidden_size = self.model.config.hidden_size
        heads = {}
        for name, n_labels in self.tasks.items():
            heads[name] = nn.Linear(hidden_size, n_labels)
        self.classifiers = nn.ModuleDict(heads)

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Classify a batch with the head registered for ``task``.

        Returns:
            ``(loss, logits)``. ``loss`` is the cross-entropy against
            ``labels`` when they are given, otherwise ``None``.
        """
        encoded = self.model(input_ids, attention_mask=attention_mask)
        # The head reads the hidden state at sequence position 0 (CLS token)
        first_token = encoded.last_hidden_state[:, 0, :]
        head = self.classifiers[task]
        logits = head(first_token)
        if labels is None:
            return None, logits
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, head.out_features), labels.view(-1))
        return loss, logits

    def add_task_tokens(self, texts, task):
        """Prefix every text with a task marker, then tokenize the result."""
        tagged = []
        for text in texts:
            tagged.append(f"[TASK-{task}] {text}")
        return self.encode_text(tagged)

# Train the multitask model
def train_multitask_model(model, train_data, epochs=3, batch_size=32, learning_rate=5e-5):
    """Fine-tune a multitask model, visiting the tasks one after another each epoch.

    Args:
        model: Module exposing ``parameters()``, ``train()``, a ``tokenizer``
            attribute, and a call ``model(input_ids, attention_mask, task,
            labels=...)`` that returns ``(loss, logits)``.
        train_data: Mapping from task name to that task's records; each
            non-empty entry is wrapped in ``TextDataset`` with
            ``for_classification=True``.
        epochs: Number of passes over all tasks.
        batch_size: Batch size of every task's DataLoader.
        learning_rate: Learning rate given to AdamW.

    Prints the mean batch loss for every (epoch, task) pair. Tasks without
    records are skipped with a notice.
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # The datasets never change between epochs, so the loaders are built once;
    # shuffle=True still draws a new order each time a loader is iterated.
    loaders = {}
    for task, task_data in train_data.items():
        # An empty task cannot be trained on: the shuffling DataLoader or the
        # mean-loss division below can fail on it, so it is skipped up front.
        if len(task_data) == 0:
            print(f"Task: {task} has no training examples; skipping")
            continue
        dataset = TextDataset(task_data, model.tokenizer, for_classification=True)
        loaders[task] = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        for task, loader in loaders.items():
            total_loss = 0
            for input_ids, attention_mask, labels in loader:
                input_ids = input_ids.to(device)
                attention_mask = attention_mask.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                loss, _ = model(input_ids, attention_mask, task, labels=labels)
                total_loss += loss.item()
                loss.backward()
                optimizer.step()

            print(f"Epoch [{epoch + 1}/{epochs}], Task: {task}, Loss: {total_loss / len(loader)}")

# Evaluation function with metrics
def evaluate_with_metrics(model, test_data, task, batch_size=32):
    """Score ``model`` on one task and report weighted precision, recall and F1.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task)`` and
            returning ``(loss, logits)``. It is switched to eval mode and left
            in that mode.
        test_data: Dataset whose items are ``(input_ids, attention_mask,
            label)``, e.g. a ``TextDataset`` built with
            ``for_classification=True``.
        task: Task name forwarded to the model.
        batch_size: Batch size of the evaluation DataLoader.

    Returns:
        ``(precision, recall, f1)``, each computed with ``average='weighted'``.
    """
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    model.eval()
    all_labels, all_preds = [], []

    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            # Labels are only compared on the host, so they are not copied to
            # the device; only the model inputs are.
            _, logits = model(input_ids.to(device), attention_mask.to(device), task)
            predictions = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predictions.cpu().numpy())

    # zero_division=0 gives the same values as the default ("warn" acts as 0)
    # without an UndefinedMetricWarning when some class is never predicted.
    metric_options = {"average": 'weighted', "zero_division": 0}
    precision = precision_score(all_labels, all_preds, **metric_options)
    recall = recall_score(all_labels, all_preds, **metric_options)
    f1 = f1_score(all_labels, all_preds, **metric_options)

    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
    return precision, recall, f1

# Example usage
# train_data maps each task name to a list of dictionaries with "text" and
# "label" fields; the single-record lists below are placeholders.

# Number of labels for each task
tasks = {"task1": 2, "task2": 2}

train_data = {
    "task1": [{"text": "example sentence for task 1", "label": 0}],  # Replace with actual data
    "task2": [{"text": "example sentence for task 2", "label": 1}],  # Replace with actual data
}

# Build the multitask model, then place it on the selected device
multitask_model = MultiTaskFoundationModel(model_name="bert-base-uncased", tasks=tasks)
multitask_model = multitask_model.to(device)

# Fine-tune on every task in train_data
train_multitask_model(multitask_model, train_data)

# Test records for task1
test_data_task1 = [{"text": "example test sentence for task 1", "label": 0}]  # Replace with actual data
test_dataset_task1 = TextDataset(
    test_data_task1,
    multitask_model.tokenizer,
    for_classification=True,
)

# Report precision, recall and F1 for task1
evaluate_with_metrics(multitask_model, test_dataset_task1, task="task1")