<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Mixed_Precision_Training_with_Automatic_Mixed_Precision_(AMP).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from nltk.corpus import wordnet
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Args:
        data: Indexable collection of dicts. Each dict must hold a "text"
            entry, plus a "label" entry when ``for_classification`` is True.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')`` whose
            result exposes "input_ids" and "attention_mask" tensors with a
            leading batch dimension of size 1.
        max_length: Length every sequence is padded/truncated to.
        for_classification: When True, items also carry the record's label.
    """

    def __init__(self, data, tokenizer, max_length=128, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` or, for classification,
        ``(input_ids, attention_mask, label)`` for record ``idx``.

        Raises:
            KeyError: If the record lacks "text" (or "label" when
                ``for_classification`` is True).
        """
        record = self.data[idx]
        encoding = self.tokenizer(
            record["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading batch dimension. A bare squeeze() would also
        # collapse the sequence dimension when max_length == 1, yielding a
        # 0-d tensor that cannot be batched with the other items.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, record["label"]
        return input_ids, attention_mask

# Define FoundationModel class
class FoundationModel(nn.Module):
    """Pretrained BERT encoder bundled with its tokenizer and an output dropout.

    Args:
        model_name: Identifier handed to ``BertModel.from_pretrained`` and
            ``BertTokenizer.from_pretrained``.
        dropout_rate: Probability used by the dropout applied in ``forward``.
    """

    def __init__(self, model_name="bert-base-uncased", dropout_rate=0.1):
        super().__init__()
        self.model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids, attention_mask):
        """Return the encoder's last hidden state after dropout."""
        token_states = self.model(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.dropout(token_states)

    def encode_text(self, texts, max_length=128):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)``.

        The tokenizer is asked to pad, to truncate at ``max_length``, and to
        return PyTorch tensors.
        """
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        ids = batch["input_ids"]
        mask = batch["attention_mask"]
        return ids, mask

# Define Adapter module
class Adapter(nn.Module):
    """Residual bottleneck adapter: down-project, ReLU, up-project, dropout, add.

    Args:
        input_dim: Size of the last dimension of the input (and of the output).
        adapter_dim: Size of the bottleneck between the two projections.
        dropout_rate: Dropout probability applied to the up-projected
            activations before the residual addition. Defaults to 0.1.
    """

    def __init__(self, input_dim, adapter_dim=64, dropout_rate=0.1):
        super().__init__()
        self.down_proj = nn.Linear(input_dim, adapter_dim)
        self.up_proj = nn.Linear(adapter_dim, input_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return ``x`` plus the adapter's bottleneck transformation of ``x``."""
        bottleneck = self.activation(self.down_proj(x))
        return x + self.dropout(self.up_proj(bottleneck))

# Integrate Adapter into Foundation Model
class AdapterFoundationModel(FoundationModel):
    """FoundationModel whose encoder layers each carry a residual Adapter.

    Every layer in ``self.model.encoder.layer`` gets an ``adapter`` submodule
    and a forward hook that passes the layer's output hidden states through
    that adapter. Any call into ``self.model`` therefore runs the adapters,
    including calls made by subclasses.

    Args:
        model_name: Pretrained BERT identifier, forwarded to FoundationModel.
        adapter_dim: Bottleneck size of each Adapter.
        dropout_rate: Dropout probability forwarded to FoundationModel.
    """

    def __init__(self, model_name="bert-base-uncased", adapter_dim=64, dropout_rate=0.1):
        super().__init__(model_name, dropout_rate)
        hidden_size = self.model.config.hidden_size
        for layer in self.model.encoder.layer:
            layer.adapter = Adapter(hidden_size, adapter_dim)
            # A forward hook that returns a value replaces the module's
            # output, which is how the adapter is spliced into the layer.
            layer.register_forward_hook(self._apply_adapter)

    @staticmethod
    def _apply_adapter(layer, inputs, output):
        """Forward hook: run the layer's hidden-state output through its adapter."""
        # Depending on the transformers version an encoder layer returns
        # either the hidden-state tensor or a tuple whose first element is
        # that tensor; any remaining tuple entries are passed through as-is.
        if isinstance(output, tuple):
            return (layer.adapter(output[0]),) + tuple(output[1:])
        return layer.adapter(output)

    def forward(self, input_ids, attention_mask):
        """Run the adapter-augmented encoder and return its raw output object."""
        # The adapters are applied inside the encoder by the hooks installed
        # in __init__. The previous post-hoc `layer.adapter(layer.output)`
        # passed a module (not a tensor) to the adapter and ran after the
        # encoder had already produced its output.
        return self.model(input_ids, attention_mask=attention_mask)

# Define the MultiTaskAdapterFoundationModel class for multitask learning with adapters
class MultiTaskAdapterFoundationModel(AdapterFoundationModel):
    """Shared encoder with one linear classification head per task.

    Args:
        model_name: Pretrained BERT identifier forwarded to the parent class.
        tasks: Mapping of task name to number of labels; ``None`` means no tasks.
        adapter_dim: Adapter bottleneck size forwarded to the parent class.
        dropout_rate: Dropout probability forwarded to the parent class.
    """

    def __init__(self, model_name="bert-base-uncased", tasks=None, adapter_dim=64, dropout_rate=0.1):
        super().__init__(model_name, adapter_dim, dropout_rate)
        self.tasks = tasks if tasks else {}
        hidden_size = self.model.config.hidden_size
        self.classifiers = nn.ModuleDict()
        for task_name, label_count in self.tasks.items():
            self.classifiers[task_name] = nn.Linear(hidden_size, label_count)

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Classify a batch for ``task``.

        Returns:
            ``(loss, logits)``; ``loss`` is the cross-entropy against
            ``labels`` or ``None`` when no labels are given.
        """
        encoder_out = self.model(input_ids, attention_mask=attention_mask)
        # The first position of the last hidden state (the CLS token) feeds the head.
        cls_embedding = encoder_out.last_hidden_state[:, 0, :]
        head = self.classifiers[task]
        logits = head(cls_embedding)

        if labels is None:
            return None, logits

        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits.view(-1, head.out_features), labels.view(-1))
        return loss, logits

    def add_task_tokens(self, texts, task):
        """Prefix each text with a task marker and tokenize the results."""
        prefixed = []
        for text in texts:
            prefixed.append(f"[TASK-{task}] {text}")
        return self.encode_text(prefixed)

# Train the multitask model with adapters, mixed precision, TensorBoard logging, and learning rate scheduler
def train_with_scheduler(model, train_data, epochs=5, batch_size=32, learning_rate=5e-5, log_dir="./logs", num_warmup_steps=500, num_training_steps=10000):
    """Train a multitask model with AMP, an LR schedule, logging and checkpoints.

    Args:
        model: Module exposing ``tokenizer`` and callable as
            ``model(input_ids, attention_mask, task, labels=labels)``,
            returning ``(loss, logits)``.
        train_data: Mapping of task name to a list of records with "text" and
            "label" entries. Tasks with no records are skipped.
        epochs: Number of passes over every task's data.
        batch_size: Batch size for each task's DataLoader.
        learning_rate: Peak learning rate given to AdamW.
        log_dir: Directory for TensorBoard event files.
        num_warmup_steps: Warmup length passed to ``get_scheduler``.
        num_training_steps: Total schedule length passed to ``get_scheduler``.

    Side effects:
        Writes per-batch losses to TensorBoard, prints a per-epoch summary,
        and saves ``./checkpoints/model_epoch_<n>.pt`` after each epoch.
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)

    # Mixed precision only applies on CUDA; disabling it elsewhere makes the
    # scaler and autocast explicit no-ops.
    use_amp = device.type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    checkpoint_dir = "./checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The loaders do not change between epochs, so build them once;
    # shuffle=True still reshuffles on every new iteration.
    dataloaders = {
        task: DataLoader(
            TextDataset(task_data, model.tokenizer, for_classification=True),
            batch_size=batch_size,
            shuffle=True,
        )
        for task, task_data in train_data.items()
        if len(task_data) > 0
    }

    writer = SummaryWriter(log_dir=log_dir)
    try:
        model.train()
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            task = logits = labels = None

            for task, train_dataloader in dataloaders.items():
                for batch_idx, batch in enumerate(train_dataloader):
                    optimizer.zero_grad()
                    input_ids, attention_mask, labels = (t.to(device) for t in batch)

                    with autocast(enabled=use_amp):  # Mixed precision context
                        loss, logits = model(input_ids, attention_mask, task, labels=labels)

                    scaler.scale(loss).backward()
                    scale_before = scaler.get_scale()
                    scaler.step(optimizer)
                    scaler.update()
                    # Without this step the warmup schedule pins the learning
                    # rate at its initial value of 0. The scaler lowers its
                    # scale only when it skipped the optimizer step (inf/NaN
                    # gradients); the schedule does not advance in that case.
                    if scaler.get_scale() >= scale_before:
                        scheduler.step()

                    batch_loss = loss.item()
                    total_loss += batch_loss
                    num_batches += 1

                    # Log loss for each batch
                    writer.add_scalar(f"Loss/train_{task}", batch_loss, epoch * len(train_dataloader) + batch_idx)

            if num_batches:
                # Loss is averaged over every batch of every task in this
                # epoch; the task name and samples come from the last batch.
                print(f"Epoch [{epoch + 1}/{epochs}], Task: {task}, Loss: {total_loss / num_batches}")
                print(f"Sample predictions: {logits[:5].cpu().detach().numpy()}")
                print(f"Actual labels: {labels[:5].cpu().numpy()}")

            # Save model checkpoint at each epoch
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pt"))
    finally:
        # Flush and close the event file even if training fails midway.
        writer.close()

# Scheduler for learning rate
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR with linear warmup followed by linear decay.

    The multiplier rises from 0 to 1 over ``num_warmup_steps`` steps, then
    falls linearly and is clamped at 0 from ``num_training_steps`` onward.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            remaining = float(num_training_steps - current_step)
            return max(0.0, remaining / decay_span)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda)

# Evaluation function with metrics
def evaluate_with_metrics(model, test_data, task, batch_size=32):
    """Evaluate ``model`` on ``test_data`` for one task and report metrics.

    Batches yielding ``(input_ids, attention_mask, labels)`` are run through
    the model in eval mode without gradients; predictions are the arg-max of
    the logits. Weighted precision, recall and F1 are printed and returned
    as a ``(precision, recall, f1)`` tuple.
    """
    loader = DataLoader(test_data, batch_size=batch_size)
    model.eval()
    gold = []
    predicted = []

    with torch.no_grad():
        for input_ids, attention_mask, labels in loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            with autocast():  # Mixed precision context for evaluation
                _, logits = model(input_ids, attention_mask, task)

            batch_preds = logits.argmax(dim=-1)
            gold.extend(labels.cpu().numpy())
            predicted.extend(batch_preds.cpu().numpy())

    precision = precision_score(gold, predicted, average='weighted')
    recall = recall_score(gold, predicted, average='weighted')
    f1 = f1_score(gold, predicted, average='weighted')

    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
    return precision, recall, f1

# Synonym replacement for data augmentation
def _first_distinct_synonym(word):
    """Return the first WordNet lemma that differs from ``word``, or None."""
    for synset in wordnet.synsets(word):
        for lemma in synset.lemmas():
            # Multi-word lemma names join their words with underscores.
            candidate = lemma.name().replace("_", " ")
            if candidate.lower() != word.lower():
                return candidate
    return None


def synonym_replacement(text, n=2):
    """Replace up to ``n`` distinct words of ``text`` with WordNet synonyms.

    Words are whitespace-separated tokens, visited in random order. Every
    occurrence of a chosen word is replaced. A word counts towards ``n`` only
    when a synonym different from the word itself exists, so the budget is
    not spent on no-op substitutions.

    Args:
        text: Input sentence.
        n: Maximum number of distinct words to replace; values <= 0 replace
            nothing.

    Returns:
        The augmented text, with tokens joined by single spaces.
    """
    words = text.split()
    if n <= 0 or not words:
        return " ".join(words)

    # De-duplicate so a repeated word is looked up (and counted) once.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        synonym = _first_distinct_synonym(word)
        if synonym is None:
            continue
        replacements[word] = synonym
        if len(replacements) >= n:
            break

    # Substitute in one pass over the original tokens so an inserted synonym
    # can never be replaced again by a later mapping.
    return " ".join(replacements.get(word, word) for word in words)

# Ensure the WordNet corpus queried by synonym_replacement is downloaded
import nltk
nltk.download('wordnet')

# Example usage
# train_data maps each task name to a list of dictionaries with "text" and "label" fields;
# the test data further below is a plain list of such dictionaries.

train_data = {
    "task1": [{"text": "example sentence for task 1", "label": 0}],  # Replace with actual data
    "task2": [{"text": "example sentence for task 2", "label": 1}]   # Replace with actual data
}

tasks = {"task1": 2, "task2": 2}  # Define tasks with number of labels for each

# Initialize the multitask model with adapters and move it to the selected device
multitask_model = MultiTaskAdapterFoundationModel(model_name="bert-base-uncased", tasks=tasks).to(device)

# Train the multitask model with scheduler and logging
# (TensorBoard events go to the default ./logs directory; checkpoints go to ./checkpoints)
train_with_scheduler(multitask_model, train_data, epochs=5, batch_size=32, learning_rate=5e-5, num_warmup_steps=500, num_training_steps=10000)

# Test data, wrapped in a TextDataset that also yields labels
test_data_task1 = [{"text": "example test sentence for task 1", "label": 0}]  # Replace with actual data
test_dataset_task1 = TextDataset(test_data_task1, multitask_model.tokenizer, for_classification=True)

# Evaluate the multitask model on a specific task (prints precision, recall and F1)
evaluate_with_metrics(multitask_model, test_dataset_task1, task="task1")

# Example of synonym replacement
text = "The quick brown fox jumps over the lazy dog."
augmented_text = synonym_replacement(text)
print(f"Original: {text}")
print(f"Augmented: {augmented_text}")