<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Meta_Learning_for_Fast_Adaptation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer
from transformers import AdamW as TransformersAdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
import nltk

# Download NLTK wordnet data (the corpus queried through `wordnet.synsets`
# in synonym_replacement below).
nltk.download('wordnet')

# Device configuration: the model and every batch in this cell are moved to
# `device` with `.to(device)`.
device = torch.device("cpu")  # Switch to CPU to reduce memory usage

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Each element of ``data`` is a mapping with a ``"text"`` entry and, when
    ``for_classification`` is true, a ``"label"`` entry.

    Args:
        data: Indexable collection of records.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; the
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length handed to the tokenizer for padding/truncation.
        for_classification: When true, items also carry the record's label.
    """

    def __init__(self, data, tokenizer, max_length=64, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` or, for classification,
        ``(input_ids, attention_mask, label)`` for record ``idx``."""
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading axis (assumed to be a batch axis of size 1 for
        # a single text). A bare squeeze() would also collapse the sequence
        # axis whenever max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, item["label"]
        return input_ids, attention_mask

# Define the MAML model class
class MAMLModel(nn.Module):
    """Wrap an encoder and train it with first-order MAML.

    ``forward`` returns the hidden state at sequence position 0 of the wrapped
    model's ``last_hidden_state``; ``fast_adapt`` feeds that vector to
    ``F.cross_entropy`` as if it were class logits.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)`` whose
            output exposes ``last_hidden_state``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Return ``last_hidden_state[:, 0, :]`` of the wrapped model."""
        outputs = self.model(input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]  # Use CLS token

    def clone_parameters(self):
        """Return ``{name: tensor}`` snapshots of all parameters.

        The copies are detached so that holding the snapshot does not keep an
        autograd graph alive.
        """
        return {name: param.detach().clone() for name, param in self.named_parameters()}

    def fast_adapt(self, support_data, query_data, optimizer, n_steps=5, lr_inner=0.01):
        """Run one first-order MAML task update and return the query loss.

        Args:
            support_data: ``(input_ids, attention_mask, targets)`` for the
                inner-loop adaptation.
            query_data: ``(input_ids, attention_mask, targets)`` evaluated
                after adaptation.
            optimizer: Meta-optimizer over this module's parameters.
            n_steps: Number of inner SGD steps on the support batch.
            lr_inner: Inner-loop learning rate.

        Returns:
            The detached scalar query loss measured with the adapted weights.
        """
        support_input, support_attention, support_target = support_data
        query_input, query_attention, query_target = query_data

        original_params = self.clone_parameters()
        try:
            # Inner loop: plain SGD steps taken outside autograd.
            for _ in range(n_steps):
                optimizer.zero_grad()
                logits = self(support_input, support_attention)
                loss = F.cross_entropy(logits, support_target)
                loss.backward()
                with torch.no_grad():
                    for param in self.parameters():
                        if param.grad is not None:  # unused params have no grad
                            param -= lr_inner * param.grad

            # Meta-gradient (first-order): gradient of the query loss taken at
            # the adapted weights. It must be computed before the weights are
            # restored, otherwise backward would mix adapted activations with
            # restored weights.
            optimizer.zero_grad()
            query_logits = self(query_input, query_attention)
            query_loss = F.cross_entropy(query_logits, query_target)
            query_loss.backward()
        finally:
            # Always put the pre-adaptation weights back, even on failure.
            with torch.no_grad():
                for name, param in self.named_parameters():
                    param.copy_(original_params[name])

        # Apply the meta-gradient to the original weights.
        optimizer.step()
        optimizer.zero_grad()
        return query_loss.detach()

# Synonym replacement for data augmentation
def synonym_replacement(text, n=2):
    """Return ``text`` with up to ``n`` distinct words swapped for WordNet synonyms.

    Words are the whitespace-separated tokens of ``text`` and are looked up
    in WordNet exactly as written (attached punctuation included). Candidate
    words are visited in random order. For each one, the first lemma whose
    name differs from the word (case-insensitively) is used, so a word is
    never "replaced" by itself, and underscores in multi-word lemma names are
    turned into spaces. Every occurrence of a chosen word is replaced.

    Args:
        text: Input sentence.
        n: Maximum number of distinct words to replace.

    Returns:
        The augmented sentence, tokens joined by single spaces.
    """
    words = text.split()
    if n <= 0:
        return " ".join(words)

    # Visit each distinct word once, in random order.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        if len(replacements) >= n:
            break
        lowered = word.lower()
        synonym = next(
            (
                lemma.name().replace("_", " ")
                for synset in wordnet.synsets(word)
                for lemma in synset.lemmas()
                if lemma.name().lower() != lowered
            ),
            None,
        )
        if synonym is not None:
            replacements[word] = synonym

    return " ".join(replacements.get(word, word) for word in words)

# Initialize tokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

# Augmenting the dataset with more examples and synonym replacement
# Eight seed sentences: four with label 0 and four with label 1.
texts = [
    {"text": "The quick brown fox jumps over the lazy dog.", "label": 0},
    {"text": "A journey of a thousand miles begins with a single step.", "label": 0},
    {"text": "To be or not to be, that is the question.", "label": 0},
    {"text": "All that glitters is not gold.", "label": 0},
    {"text": "The early bird catches the worm.", "label": 1},
    {"text": "A picture is worth a thousand words.", "label": 1},
    {"text": "Better late than never.", "label": 1},
    {"text": "Actions speak louder than words.", "label": 1}
]

# Augmenting data with synonyms
# Each seed sentence contributes three augmented copies that keep its label,
# so `texts` grows from 8 to 32 records after the extend below.
augmented_texts = []
for text in texts:
    for _ in range(3):  # Create 3 augmented versions of each sentence
        augmented_text = synonym_replacement(text["text"])
        augmented_texts.append({"text": augmented_text, "label": text["label"]})
texts.extend(augmented_texts)

# Shuffle the data to ensure randomness
random.shuffle(texts)

# Split data into training and validation sets
# NOTE: the split happens after augmentation, so variants of the same seed
# sentence can end up on both sides of the split.
train_data, val_data = train_test_split(texts, test_size=0.2, random_state=42)

# Create datasets and dataloaders
train_dataset = TextDataset(train_data, tokenizer, for_classification=True)
val_dataset = TextDataset(val_data, tokenizer, for_classification=True)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Initialize Longformer model
base_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
maml_model = MAMLModel(base_model).to(device)
optimizer = TransformersAdamW(maml_model.parameters(), lr=5e-5)

# Train the MAML model
for epoch in range(3):  # Adjust number of epochs as needed
    # zip() pairs one training batch (support set) with one validation batch
    # (query set) and stops as soon as the shorter loader is exhausted.
    for support_batch, query_batch in zip(train_dataloader, val_dataloader):
        support_input, support_attention, support_target = support_batch
        query_input, query_attention, query_target = query_batch

        # Move both batches to the configured device.
        support_input, support_attention, support_target = support_input.to(device), support_attention.to(device), support_target.to(device)
        query_input, query_attention, query_target = query_input.to(device), query_attention.to(device), query_target.to(device)

        support_data = (support_input, support_attention, support_target)
        query_data = (query_input, query_attention, query_target)

        # The returned query loss is only reported here.
        query_loss = maml_model.fast_adapt(support_data, query_data, optimizer)
        print(f"Epoch {epoch + 1}, Query Loss: {query_loss.item()}")

In [None]:
pip install optuna

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.utils.prune as prune
import optuna

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# Models and batches in this cell are moved to `device` with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Each element of ``data`` is a mapping with a ``"text"`` entry and, when
    ``for_classification`` is true, a ``"label"`` entry.

    Args:
        data: Indexable collection of records.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; the
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length handed to the tokenizer for padding/truncation.
        for_classification: When true, items also carry the record's label.
    """

    def __init__(self, data, tokenizer, max_length=2048, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` or, for classification,
        ``(input_ids, attention_mask, label)`` for record ``idx``."""
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading axis (assumed to be a batch axis of size 1 for
        # a single text). A bare squeeze() would also collapse the sequence
        # axis whenever max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, item["label"]
        return input_ids, attention_mask

# Define LongformerFoundationModel class
class LongformerFoundationModel(nn.Module):
    """Bundle a pretrained Longformer encoder with its tokenizer.

    Args:
        model_name: Checkpoint name passed to both
            ``LongformerModel.from_pretrained`` and
            ``LongformerTokenizer.from_pretrained``.
    """

    def __init__(self, model_name="allenai/longformer-base-4096"):
        super(LongformerFoundationModel, self).__init__()
        self.model = LongformerModel.from_pretrained(model_name)
        self.tokenizer = LongformerTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and return the raw encoder output.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            labels: Unused; kept for signature compatibility.

        Returns:
            Whatever the wrapped encoder returns.
        """
        # Use sliding window attention for long sequences: only the first
        # (CLS) position attends globally. Marking every non-padding token as
        # global would turn the local attention into full attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        return outputs

# Define Adapter class
class Adapter(nn.Module):
    """Residual bottleneck block: down-project, ReLU, up-project, dropout, add.

    Args:
        input_dim: Size of the last dimension of the input (and output).
        adapter_dim: Width of the bottleneck layer.
    """

    def __init__(self, input_dim, adapter_dim=64):
        super().__init__()
        # Layers are created in the same order as before (down, then up) so
        # seeded weight initialisation is unaffected.
        self.down_proj = nn.Linear(input_dim, adapter_dim)
        self.up_proj = nn.Linear(adapter_dim, input_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return ``x`` plus the dropped-out bottleneck transformation of ``x``."""
        bottleneck = self.down_proj(x)
        bottleneck = self.activation(bottleneck)
        residual = self.dropout(self.up_proj(bottleneck))
        return x + residual

# Define MultiTaskAdapterFoundationModel class
class MultiTaskAdapterFoundationModel(LongformerFoundationModel):
    """Longformer encoder with one linear classification head per task.

    Args:
        model_name: Checkpoint name forwarded to the parent constructor.
        tasks: Mapping of task name to number of output labels; ``None`` is
            treated as no tasks.
        adapter_dim: Accepted for interface compatibility; this class does
            not use it.
    """

    def __init__(self, model_name="allenai/longformer-base-4096", tasks=None, adapter_dim=64):
        super().__init__(model_name)
        self.tasks = tasks or {}
        self.classifiers = nn.ModuleDict({
            task: nn.Linear(self.model.config.hidden_size, num_labels) for task, num_labels in self.tasks.items()
        })

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Classify a batch with the head registered for ``task``.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            task: Key into ``self.classifiers``; an unknown key raises
                ``KeyError``.
            labels: Optional class indices; when given, the cross-entropy
                loss is computed.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` without labels.
        """
        # Only the CLS position (index 0), which feeds the classifier, gets
        # global attention; all other tokens keep sliding-window attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        hidden_states = outputs.last_hidden_state
        head = self.classifiers[task]
        logits = head(hidden_states[:, 0, :])  # CLS token
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, head.out_features), labels.view(-1))
        return loss, logits

# Scheduler for learning rate
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build a ``LambdaLR`` with linear warm-up followed by linear decay.

    The multiplier rises from 0 to 1 over ``num_warmup_steps`` steps, then
    falls linearly and is clamped at 0 once ``num_training_steps`` is reached.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        still_warming_up = current_step < num_warmup_steps
        if still_warming_up:
            return float(current_step) / warmup_span
        steps_left = float(num_training_steps - current_step)
        return max(0.0, steps_left / decay_span)

    return LambdaLR(optimizer, lr_lambda)

# Train the multitask model with gradient accumulation
def train_with_scheduler(model, train_data, epochs=5, batch_size=4, accumulation_steps=8, learning_rate=5e-5, log_dir="./logs", num_warmup_steps=500, num_training_steps=10000):
    """Train a multi-task model with gradient accumulation and an LR schedule.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task,
            labels=labels)`` and returning ``(loss, logits)``.
        train_data: Mapping of task name to a dataset yielding
            ``(input_ids, attention_mask, label)`` items.
        epochs: Number of passes over every task's dataset.
        batch_size: Batch size of the per-task ``DataLoader``.
        accumulation_steps: Number of batches whose gradients are summed
            before one optimizer step; values below 1 are treated as 1.
        learning_rate: Peak learning rate for AdamW.
        log_dir: Directory for the TensorBoard ``SummaryWriter``.
        num_warmup_steps: Warm-up steps passed to ``get_scheduler``.
        num_training_steps: Total steps passed to ``get_scheduler``.

    Side effects:
        Puts ``model`` in training mode, writes ``Loss/train`` once per epoch
        and prints the same value. Batches are moved to the module-level
        ``device``.
    """
    accumulation_steps = max(1, accumulation_steps)
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)
    writer = SummaryWriter(log_dir=log_dir)
    model.train()

    def apply_update():
        # One optimizer/scheduler step followed by a gradient reset.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    try:
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for task, task_data in train_data.items():
                task_loader = DataLoader(task_data, batch_size=batch_size, shuffle=True)
                # Start every task with clean gradients so nothing leaks over
                # from the previous task's accumulation group.
                optimizer.zero_grad()
                pending = 0
                for input_ids, attention_mask, labels in task_loader:
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)

                    loss, _ = model(input_ids, attention_mask, task, labels=labels)
                    # Scale so the summed gradient approximates the mean over
                    # one accumulation group.
                    (loss / accumulation_steps).backward()

                    total_loss += loss.item()
                    num_batches += 1
                    pending += 1
                    if pending == accumulation_steps:
                        apply_update()
                        pending = 0

                # Apply a trailing partial group instead of discarding it;
                # otherwise a loader with fewer batches than
                # accumulation_steps would never update the model.
                if pending:
                    apply_update()

            # Average over every batch seen this epoch, across all tasks.
            mean_loss = total_loss / max(1, num_batches)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss}")
    finally:
        # Flush and close the event file even if training fails.
        writer.close()

# Evaluation function
def evaluate_with_metrics(model, test_data, task, batch_size=32):
    """Compute weighted precision, recall and F1 for one task.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task)`` and
            returning ``(loss, logits)``.
        test_data: Dataset yielding ``(input_ids, attention_mask, label)``.
        task: Task name forwarded to the model.
        batch_size: Evaluation batch size.

    Returns:
        ``(precision, recall, f1)``, each weighted by class support.

    Side effects:
        Leaves ``model`` in evaluation mode and prints the three metrics.
        Batches are moved to the module-level ``device``.
    """
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            _, logits = model(input_ids, attention_mask, task)
            predictions = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predictions.cpu().tolist())

    # scikit-learn metrics take (y_true, y_pred) in that order; with
    # average='weighted' the class weights come from y_true, so swapping the
    # arguments changes the score.
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
    return precision, recall, f1

# Synonym replacement for data augmentation
def synonym_replacement(text, n=2):
    """Return ``text`` with up to ``n`` distinct words swapped for WordNet synonyms.

    Words are the whitespace-separated tokens of ``text`` and are looked up
    in WordNet exactly as written (attached punctuation included). Candidate
    words are visited in random order. For each one, the first lemma whose
    name differs from the word (case-insensitively) is used, so a word is
    never "replaced" by itself, and underscores in multi-word lemma names are
    turned into spaces. Every occurrence of a chosen word is replaced.

    Args:
        text: Input sentence.
        n: Maximum number of distinct words to replace.

    Returns:
        The augmented sentence, tokens joined by single spaces.
    """
    words = text.split()
    if n <= 0:
        return " ".join(words)

    # Visit each distinct word once, in random order.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        if len(replacements) >= n:
            break
        lowered = word.lower()
        synonym = next(
            (
                lemma.name().replace("_", " ")
                for synset in wordnet.synsets(word)
                for lemma in synset.lemmas()
                if lemma.name().lower() != lowered
            ),
            None,
        )
        if synonym is not None:
            replacements[word] = synonym

    return " ".join(replacements.get(word, word) for word in words)

# Ensure wordnet is downloaded (queried by synonym_replacement above)
import nltk
nltk.download('wordnet')

# Initialize tokenizer
tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")

# Augmenting the dataset with more examples and synonym replacement
# Eight seed sentences: four with label 0 and four with label 1.
texts = [
    {"text": "The quick brown fox jumps over the lazy dog.", "label": 0},
    {"text": "A journey of a thousand miles begins with a single step.", "label": 0},
    {"text": "To be or not to be, that is the question.", "label": 0},
    {"text": "All that glitters is not gold.", "label": 0},
    {"text": "The early bird catches the worm.", "label": 1},
    {"text": "A picture is worth a thousand words.", "label": 1},
    {"text": "Better late than never.", "label": 1},
    {"text": "Actions speak louder than words.", "label": 1}
]

# Augmenting data with synonyms
# Each seed sentence contributes three augmented copies that keep its label,
# so `texts` grows from 8 to 32 records after the extend below.
augmented_texts = []
for text in texts:
    for _ in range(3):  # Create 3 augmented versions of each sentence
        augmented_text = synonym_replacement(text["text"])
        augmented_texts.append({"text": augmented_text, "label": text["label"]})
texts.extend(augmented_texts)

# Shuffle the data to ensure randomness
random.shuffle(texts)

# Split data into training and validation sets again
# NOTE: the split happens after augmentation, so variants of the same seed
# sentence can end up on both sides of the split.
train_data, val_data = train_test_split(texts, test_size=0.2, random_state=42)

# Create datasets and dataloaders
# TextDataset is used with its default max_length here.
train_dataset = TextDataset(train_data, tokenizer, for_classification=True)
val_dataset = TextDataset(val_data, tokenizer, for_classification=True)
# These two loaders are not passed to train_with_scheduler (it builds its own
# DataLoader per task); they are consumed by the MAML example further down.
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# Train model again with the augmented dataset
# Two heads are created ("classification" with 3 outputs, "sentiment" with 2)
# and both are trained on the same dataset, whose labels are 0 and 1.
model = MultiTaskAdapterFoundationModel(model_name="allenai/longformer-base-4096", tasks={"classification": 3, "sentiment": 2}).to(device)
train_with_scheduler(model, {"classification": train_dataset, "sentiment": train_dataset}, epochs=5, batch_size=4, accumulation_steps=8, learning_rate=5e-5, num_warmup_steps=100, num_training_steps=1000)

# Evaluate model
# Both heads are scored on the same validation dataset.
for task, dataset in {"classification": val_dataset, "sentiment": val_dataset}.items():
    print(f"Evaluating task: {task}")
    evaluate_with_metrics(model, dataset, task)

# Define MAML Model
class MAMLModel(nn.Module):
    """Wrap a model and train it with first-order MAML.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)``. Its
            output must expose either ``logits`` or ``last_hidden_state``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Return the wrapped model's logits.

        Base encoders (for example a bare ``LongformerModel``) have no
        ``logits`` field; for those the hidden state at sequence position 0
        of ``last_hidden_state`` is returned instead.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        logits = getattr(outputs, "logits", None)
        if logits is None:
            return outputs.last_hidden_state[:, 0, :]  # CLS token
        return logits

    def clone_parameters(self):
        """Return ``{name: tensor}`` snapshots of the wrapped model's parameters.

        The copies are detached so that holding the snapshot does not keep an
        autograd graph alive.
        """
        return {name: param.detach().clone() for name, param in self.model.named_parameters()}

    def fast_adapt(self, support_data, query_data, optimizer, n_steps=5, lr_inner=0.01):
        """Run one first-order MAML task update and return the query loss.

        Args:
            support_data: ``(input_ids, attention_mask, targets)`` for the
                inner-loop adaptation.
            query_data: ``(input_ids, attention_mask, targets)`` evaluated
                after adaptation.
            optimizer: Meta-optimizer over this module's parameters.
            n_steps: Number of inner SGD steps on the support batch.
            lr_inner: Inner-loop learning rate.

        Returns:
            The detached scalar query loss measured with the adapted weights.
        """
        support_input, support_attention, support_target = support_data
        query_input, query_attention, query_target = query_data

        original_params = self.clone_parameters()
        try:
            # Support step: plain SGD updates taken outside autograd.
            for _ in range(n_steps):
                optimizer.zero_grad()
                logits = self(support_input, support_attention)
                loss = F.cross_entropy(logits, support_target)
                loss.backward()
                with torch.no_grad():
                    for param in self.model.parameters():
                        # Parameters not used in the forward pass have no grad.
                        if param.grad is not None:
                            param -= lr_inner * param.grad

            # Query step: the meta-gradient is taken at the adapted weights
            # and must be computed before those weights are restored.
            optimizer.zero_grad()
            query_logits = self(query_input, query_attention)
            query_loss = F.cross_entropy(query_logits, query_target)
            query_loss.backward()
        finally:
            # Restore original parameters, even if adaptation failed.
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.copy_(original_params[name])

        # Apply the meta-gradient to the original weights.
        optimizer.step()
        optimizer.zero_grad()
        return query_loss.detach()

# Example usage of MAML Model
# Runs only when the file is executed as a script. It rebinds the module-level
# name `model` to a bare LongformerModel.
if __name__ == "__main__":
    model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
    maml_model = MAMLModel(model).to(device)
    optimizer = AdamW(maml_model.parameters(), lr=5e-5)

    # Assuming `support_data` and `query_data` are DataLoader objects
    # zip() pairs one training batch (support set) with one validation batch
    # (query set) and stops as soon as the shorter loader is exhausted.
    for epoch in range(3):
        for support_batch, query_batch in zip(train_dataloader, val_dataloader):
            support_input, support_attention, support_target = support_batch
            query_input, query_attention, query_target = query_batch

            # Move both batches to the configured device.
            support_input, support_attention, support_target = support_input.to(device), support_attention.to(device), support_target.to(device)
            query_input, query_attention, query_target = query_input.to(device), query_attention.to(device), query_target.to(device)

            support_data = (support_input, support_attention, support_target)
            query_data = (query_input, query_attention, query_target)

            # The returned query loss is only reported here.
            query_loss = maml_model.fast_adapt(support_data, query_data, optimizer)
            print(f"Epoch {epoch + 1}, Query Loss: {query_loss.item()}")

In [None]:
pip install fastapi

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.utils.prune as prune
import optuna
import pandas as pd
from fastapi import FastAPI, Request

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# Models and batches in this cell are moved to `device` with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure wordnet is downloaded (queried by synonym_replacement below)
import nltk
nltk.download('wordnet')

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Each element of ``data`` is a mapping with a ``"text"`` entry and, when
    ``for_classification`` is true, a ``"label"`` entry.

    Args:
        data: Indexable collection of records.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; the
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length handed to the tokenizer for padding/truncation.
        for_classification: When true, items also carry the record's label.
    """

    def __init__(self, data, tokenizer, max_length=2048, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` or, for classification,
        ``(input_ids, attention_mask, label)`` for record ``idx``."""
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading axis (assumed to be a batch axis of size 1 for
        # a single text). A bare squeeze() would also collapse the sequence
        # axis whenever max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, item["label"]
        return input_ids, attention_mask

# Define LongformerFoundationModel class
class LongformerFoundationModel(nn.Module):
    """Bundle a pretrained Longformer encoder with its tokenizer.

    Args:
        model_name: Checkpoint name passed to both
            ``LongformerModel.from_pretrained`` and
            ``LongformerTokenizer.from_pretrained``.
    """

    def __init__(self, model_name="allenai/longformer-base-4096"):
        super(LongformerFoundationModel, self).__init__()
        self.model = LongformerModel.from_pretrained(model_name)
        self.tokenizer = LongformerTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and return the raw encoder output.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            labels: Unused; kept for signature compatibility.

        Returns:
            Whatever the wrapped encoder returns.
        """
        # Use sliding window attention for long sequences: only the first
        # (CLS) position attends globally. Marking every non-padding token as
        # global would turn the local attention into full attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        return outputs

# Define MultiTaskAdapterFoundationModel class
class MultiTaskAdapterFoundationModel(LongformerFoundationModel):
    """Longformer encoder with one linear classification head per task.

    Args:
        model_name: Checkpoint name forwarded to the parent constructor.
        tasks: Mapping of task name to number of output labels; ``None`` is
            treated as no tasks.
        adapter_dim: Accepted for interface compatibility; this class does
            not use it.
    """

    def __init__(self, model_name="allenai/longformer-base-4096", tasks=None, adapter_dim=64):
        super().__init__(model_name)
        self.tasks = tasks or {}
        self.classifiers = nn.ModuleDict({
            task: nn.Linear(self.model.config.hidden_size, num_labels) for task, num_labels in self.tasks.items()
        })

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Classify a batch with the head registered for ``task``.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            task: Key into ``self.classifiers``; an unknown key raises
                ``KeyError``.
            labels: Optional class indices; when given, the cross-entropy
                loss is computed.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` without labels.
        """
        # Only the CLS position (index 0), which feeds the classifier, gets
        # global attention; all other tokens keep sliding-window attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        hidden_states = outputs.last_hidden_state
        head = self.classifiers[task]
        logits = head(hidden_states[:, 0, :])  # CLS token
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, head.out_features), labels.view(-1))
        return loss, logits

# Scheduler for learning rate
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build a ``LambdaLR`` with linear warm-up followed by linear decay.

    The multiplier rises from 0 to 1 over ``num_warmup_steps`` steps, then
    falls linearly and is clamped at 0 once ``num_training_steps`` is reached.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        still_warming_up = current_step < num_warmup_steps
        if still_warming_up:
            return float(current_step) / warmup_span
        steps_left = float(num_training_steps - current_step)
        return max(0.0, steps_left / decay_span)

    return LambdaLR(optimizer, lr_lambda)

# Train the multitask model with gradient accumulation
def train_with_scheduler(model, train_data, epochs=5, batch_size=4, accumulation_steps=8, learning_rate=5e-5, log_dir="./logs", num_warmup_steps=500, num_training_steps=10000):
    """Train a multi-task model with gradient accumulation and an LR schedule.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task,
            labels=labels)`` and returning ``(loss, logits)``.
        train_data: Mapping of task name to a dataset yielding
            ``(input_ids, attention_mask, label)`` items.
        epochs: Number of passes over every task's dataset.
        batch_size: Batch size of the per-task ``DataLoader``.
        accumulation_steps: Number of batches whose gradients are summed
            before one optimizer step; values below 1 are treated as 1.
        learning_rate: Peak learning rate for AdamW.
        log_dir: Directory for the TensorBoard ``SummaryWriter``.
        num_warmup_steps: Warm-up steps passed to ``get_scheduler``.
        num_training_steps: Total steps passed to ``get_scheduler``.

    Side effects:
        Puts ``model`` in training mode, writes ``Loss/train`` once per epoch
        and prints the same value. Batches are moved to the module-level
        ``device``.
    """
    accumulation_steps = max(1, accumulation_steps)
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)
    writer = SummaryWriter(log_dir=log_dir)
    model.train()

    def apply_update():
        # One optimizer/scheduler step followed by a gradient reset.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    try:
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for task, task_data in train_data.items():
                task_loader = DataLoader(task_data, batch_size=batch_size, shuffle=True)
                # Start every task with clean gradients so nothing leaks over
                # from the previous task's accumulation group.
                optimizer.zero_grad()
                pending = 0
                for input_ids, attention_mask, labels in task_loader:
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)

                    loss, _ = model(input_ids, attention_mask, task, labels=labels)
                    # Scale so the summed gradient approximates the mean over
                    # one accumulation group.
                    (loss / accumulation_steps).backward()

                    total_loss += loss.item()
                    num_batches += 1
                    pending += 1
                    if pending == accumulation_steps:
                        apply_update()
                        pending = 0

                # Apply a trailing partial group instead of discarding it;
                # otherwise a loader with fewer batches than
                # accumulation_steps would never update the model.
                if pending:
                    apply_update()

            # Average over every batch seen this epoch, across all tasks.
            mean_loss = total_loss / max(1, num_batches)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss}")
    finally:
        # Flush and close the event file even if training fails.
        writer.close()

# Evaluation function
def evaluate_with_metrics(model, test_data, task, batch_size=32):
    """Compute weighted precision, recall and F1 for one task.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task)`` and
            returning ``(loss, logits)``.
        test_data: Dataset yielding ``(input_ids, attention_mask, label)``.
        task: Task name forwarded to the model.
        batch_size: Evaluation batch size.

    Returns:
        ``(precision, recall, f1)``, each weighted by class support.

    Side effects:
        Leaves ``model`` in evaluation mode and prints the three metrics.
        Batches are moved to the module-level ``device``.
    """
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            _, logits = model(input_ids, attention_mask, task)
            predictions = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predictions.cpu().tolist())

    # scikit-learn metrics take (y_true, y_pred) in that order; with
    # average='weighted' the class weights come from y_true, so swapping the
    # arguments changes the score.
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
    return precision, recall, f1

# Synonym replacement for data augmentation
def synonym_replacement(text, n=2):
    """Return ``text`` with up to ``n`` distinct words swapped for WordNet synonyms.

    Words are the whitespace-separated tokens of ``text`` and are looked up
    in WordNet exactly as written (attached punctuation included). Candidate
    words are visited in random order. For each one, the first lemma whose
    name differs from the word (case-insensitively) is used, so a word is
    never "replaced" by itself, and underscores in multi-word lemma names are
    turned into spaces. Every occurrence of a chosen word is replaced.

    Args:
        text: Input sentence.
        n: Maximum number of distinct words to replace.

    Returns:
        The augmented sentence, tokens joined by single spaces.
    """
    words = text.split()
    if n <= 0:
        return " ".join(words)

    # Visit each distinct word once, in random order.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        if len(replacements) >= n:
            break
        lowered = word.lower()
        synonym = next(
            (
                lemma.name().replace("_", " ")
                for synset in wordnet.synsets(word)
                for lemma in synset.lemmas()
                if lemma.name().lower() != lowered
            ),
            None,
        )
        if synonym is not None:
            replacements[word] = synonym

    return " ".join(replacements.get(word, word) for word in words)

# Load dataset dynamically from CSV
def load_dataset(file_path, tokenizer, max_length=2048, for_classification=False):
    """Load a CSV file into a ``TextDataset``.

    Args:
        file_path: Path of a CSV file with a ``text`` column and, when
            ``for_classification`` is true, a ``label`` column.
        tokenizer: Tokenizer handed to ``TextDataset``.
        max_length: Padding/truncation length handed to ``TextDataset``.
        for_classification: Whether labels are read and returned by the
            dataset.

    Returns:
        A ``TextDataset`` over records of the form
        ``{"text": ..., "label": ...}``; ``label`` is ``None`` when
        ``for_classification`` is false.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    df = pd.read_csv(file_path)

    # Fail early with a clear message instead of part-way through the rows.
    required = ["text", "label"] if for_classification else ["text"]
    missing = [column for column in required if column not in df.columns]
    if missing:
        raise KeyError(f"{file_path} is missing required column(s): {missing}")

    # Column-wise access avoids building one Series per row (iterrows).
    texts = df["text"].tolist()
    labels = df["label"].tolist() if for_classification else [None] * len(texts)
    dataset = [{"text": text, "label": label} for text, label in zip(texts, labels)]
    return TextDataset(dataset, tokenizer, max_length=max_length, for_classification=for_classification)

# Optuna for hyperparameter tuning
def objective(trial):
    """Optuna objective: train one configuration and return its weighted F1.

    Samples a learning rate (log scale), an ``adapter_dim`` and a batch size,
    trains a fresh single-task model on ``train.csv`` for three epochs and
    scores it on ``val.csv``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The F1 score reported by ``evaluate_with_metrics``.
    """
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    adapter_dim = trial.suggest_int("adapter_dim", 16, 128, step=16)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16])

    # Create model and dataset
    model = MultiTaskAdapterFoundationModel(model_name="allenai/longformer-base-4096", tasks={"classification": 3}, adapter_dim=adapter_dim).to(device)
    # Use the tokenizer bundled with the model rather than depending on a
    # module-level `tokenizer` being defined elsewhere.
    train_data = load_dataset("train.csv", model.tokenizer, for_classification=True)
    val_data = load_dataset("val.csv", model.tokenizer, for_classification=True)

    # Train
    train_with_scheduler(model, {"classification": train_data}, epochs=3, batch_size=batch_size, learning_rate=learning_rate, accumulation_steps=2, num_training_steps=1000)

    # Evaluate
    _, _, f1 = evaluate_with_metrics(model, val_data, "classification")
    return f1  # Optimize F1 score

# Run Optuna study
# These are module-level statements, so the 10-trial study runs every time
# this code is executed or imported (including when an ASGI server imports
# the module to find `app`).
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
print("Best trial:", study.best_trial)

# FastAPI deployment
app = FastAPI()

# Load trained model and tokenizer
# NOTE(review): this cell never assigns a module-level `model` (the one built
# inside `objective` is local), so this line relies on a `model` defined by an
# earlier cell -- confirm that is the model intended for serving.
model.eval()

@app.post("/predict")
async def predict(request: Request):
    """Classify the text sent in the request body.

    Expects a JSON body with a ``"text"`` entry and answers with
    ``{"predictions": [...]}``, the argmax class index per input row from the
    ``"classification"`` head. Uses the module-level ``tokenizer``, ``model``
    and ``device``.
    """
    data = await request.json()
    text = data["text"]

    # Tokenize input
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=2048, return_tensors="pt").to(device)
    input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

    # Predict without recording an autograd graph: inference needs no
    # gradients, and tracking them only costs memory on every request.
    with torch.no_grad():
        _, logits = model(input_ids, attention_mask, task="classification")
    predictions = torch.argmax(logits, dim=-1).cpu().numpy()

    return {"predictions": predictions.tolist()}

# To run the FastAPI app, use:
# uvicorn script_name:app

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LongformerModel, LongformerTokenizer, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
from torch.optim.lr_scheduler import LambdaLR
import torch.nn.utils.prune as prune
import optuna
import pandas as pd
from fastapi import FastAPI, Request

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# Models and batches in this cell are moved to `device` with `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure wordnet is downloaded (queried by synonym_replacement below)
import nltk
nltk.download('wordnet')

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Each element of ``data`` is a mapping with a ``"text"`` entry and, when
    ``for_classification`` is true, a ``"label"`` entry.

    Args:
        data: Indexable collection of records.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; the
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length handed to the tokenizer for padding/truncation.
        for_classification: When true, items also carry the record's label.
    """

    def __init__(self, data, tokenizer, max_length=2048, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` or, for classification,
        ``(input_ids, attention_mask, label)`` for record ``idx``."""
        item = self.data[idx]
        encoding = self.tokenizer(
            item["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Drop only the leading axis (assumed to be a batch axis of size 1 for
        # a single text). A bare squeeze() would also collapse the sequence
        # axis whenever max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, item["label"]
        return input_ids, attention_mask

# Define LongformerFoundationModel class
class LongformerFoundationModel(nn.Module):
    """Bundle a pretrained Longformer encoder with its tokenizer.

    Args:
        model_name: Checkpoint name passed to both
            ``LongformerModel.from_pretrained`` and
            ``LongformerTokenizer.from_pretrained``.
    """

    def __init__(self, model_name="allenai/longformer-base-4096"):
        super(LongformerFoundationModel, self).__init__()
        self.model = LongformerModel.from_pretrained(model_name)
        self.tokenizer = LongformerTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and return the raw encoder output.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            labels: Unused; kept for signature compatibility.

        Returns:
            Whatever the wrapped encoder returns.
        """
        # Use sliding window attention for long sequences: only the first
        # (CLS) position attends globally. Marking every non-padding token as
        # global would turn the local attention into full attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        return outputs

# Define MultiTaskAdapterFoundationModel class
class MultiTaskAdapterFoundationModel(LongformerFoundationModel):
    """Longformer encoder with one linear classification head per task.

    Args:
        model_name: Checkpoint name forwarded to the parent constructor.
        tasks: Mapping of task name to number of output labels; ``None`` is
            treated as no tasks.
        adapter_dim: Accepted for interface compatibility; this class does
            not use it.
    """

    def __init__(self, model_name="allenai/longformer-base-4096", tasks=None, adapter_dim=64):
        super().__init__(model_name)
        self.tasks = tasks or {}
        self.classifiers = nn.ModuleDict({
            task: nn.Linear(self.model.config.hidden_size, num_labels) for task, num_labels in self.tasks.items()
        })

    def forward(self, input_ids, attention_mask, task, labels=None):
        """Classify a batch with the head registered for ``task``.

        Args:
            input_ids: Token ids, indexed as ``(batch, sequence)``.
            attention_mask: Padding mask with the same shape as ``input_ids``.
            task: Key into ``self.classifiers``; an unknown key raises
                ``KeyError``.
            labels: Optional class indices; when given, the cross-entropy
                loss is computed.

        Returns:
            ``(loss, logits)`` where ``loss`` is ``None`` without labels.
        """
        # Only the CLS position (index 0), which feeds the classifier, gets
        # global attention; all other tokens keep sliding-window attention.
        global_attention_mask = torch.zeros_like(attention_mask)
        global_attention_mask[:, 0] = 1
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        hidden_states = outputs.last_hidden_state
        head = self.classifiers[task]
        logits = head(hidden_states[:, 0, :])  # CLS token
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, head.out_features), labels.view(-1))
        return loss, logits

# Scheduler for learning rate
def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build a ``LambdaLR`` with linear warm-up followed by linear decay.

    The multiplier rises from 0 to 1 over ``num_warmup_steps`` steps, then
    falls linearly and is clamped at 0 once ``num_training_steps`` is reached.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        still_warming_up = current_step < num_warmup_steps
        if still_warming_up:
            return float(current_step) / warmup_span
        steps_left = float(num_training_steps - current_step)
        return max(0.0, steps_left / decay_span)

    return LambdaLR(optimizer, lr_lambda)

# Train the multitask model with gradient accumulation
def train_with_scheduler(model, train_data, epochs=5, batch_size=4, accumulation_steps=8, learning_rate=5e-5, log_dir="./logs", num_warmup_steps=500, num_training_steps=10000):
    """Train a multi-task model with gradient accumulation and an LR schedule.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task,
            labels=labels)`` and returning ``(loss, logits)``.
        train_data: Mapping of task name to a dataset yielding
            ``(input_ids, attention_mask, label)`` items.
        epochs: Number of passes over every task's dataset.
        batch_size: Batch size of the per-task ``DataLoader``.
        accumulation_steps: Number of batches whose gradients are summed
            before one optimizer step; values below 1 are treated as 1.
        learning_rate: Peak learning rate for AdamW.
        log_dir: Directory for the TensorBoard ``SummaryWriter``.
        num_warmup_steps: Warm-up steps passed to ``get_scheduler``.
        num_training_steps: Total steps passed to ``get_scheduler``.

    Side effects:
        Puts ``model`` in training mode, writes ``Loss/train`` once per epoch
        and prints the same value. Batches are moved to the module-level
        ``device``.
    """
    accumulation_steps = max(1, accumulation_steps)
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
    scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)
    writer = SummaryWriter(log_dir=log_dir)
    model.train()

    def apply_update():
        # One optimizer/scheduler step followed by a gradient reset.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    try:
        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for task, task_data in train_data.items():
                task_loader = DataLoader(task_data, batch_size=batch_size, shuffle=True)
                # Start every task with clean gradients so nothing leaks over
                # from the previous task's accumulation group.
                optimizer.zero_grad()
                pending = 0
                for input_ids, attention_mask, labels in task_loader:
                    input_ids = input_ids.to(device)
                    attention_mask = attention_mask.to(device)
                    labels = labels.to(device)

                    loss, _ = model(input_ids, attention_mask, task, labels=labels)
                    # Scale so the summed gradient approximates the mean over
                    # one accumulation group.
                    (loss / accumulation_steps).backward()

                    total_loss += loss.item()
                    num_batches += 1
                    pending += 1
                    if pending == accumulation_steps:
                        apply_update()
                        pending = 0

                # Apply a trailing partial group instead of discarding it;
                # otherwise a loader with fewer batches than
                # accumulation_steps would never update the model.
                if pending:
                    apply_update()

            # Average over every batch seen this epoch, across all tasks.
            mean_loss = total_loss / max(1, num_batches)
            writer.add_scalar("Loss/train", mean_loss, epoch)
            print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss}")
    finally:
        # Flush and close the event file even if training fails.
        writer.close()

# Evaluation function
def evaluate_with_metrics(model, test_data, task, batch_size=32):
    """Compute weighted precision, recall and F1 for one task.

    Args:
        model: Module called as ``model(input_ids, attention_mask, task)`` and
            returning ``(loss, logits)``.
        test_data: Dataset yielding ``(input_ids, attention_mask, label)``.
        task: Task name forwarded to the model.
        batch_size: Evaluation batch size.

    Returns:
        ``(precision, recall, f1)``, each weighted by class support.

    Side effects:
        Leaves ``model`` in evaluation mode and prints the three metrics.
        Batches are moved to the module-level ``device``.
    """
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            _, logits = model(input_ids, attention_mask, task)
            predictions = torch.argmax(logits, dim=-1)
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(predictions.cpu().tolist())

    # scikit-learn metrics take (y_true, y_pred) in that order; with
    # average='weighted' the class weights come from y_true, so swapping the
    # arguments changes the score.
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1 Score: {f1:.2f}")
    return precision, recall, f1

# Synonym replacement for data augmentation
def synonym_replacement(text, n=2):
    """Return ``text`` with up to ``n`` distinct words swapped for WordNet synonyms.

    Words are the whitespace-separated tokens of ``text`` and are looked up
    in WordNet exactly as written (attached punctuation included). Candidate
    words are visited in random order. For each one, the first lemma whose
    name differs from the word (case-insensitively) is used, so a word is
    never "replaced" by itself, and underscores in multi-word lemma names are
    turned into spaces. Every occurrence of a chosen word is replaced.

    Args:
        text: Input sentence.
        n: Maximum number of distinct words to replace.

    Returns:
        The augmented sentence, tokens joined by single spaces.
    """
    words = text.split()
    if n <= 0:
        return " ".join(words)

    # Visit each distinct word once, in random order.
    candidates = list(dict.fromkeys(words))
    random.shuffle(candidates)

    replacements = {}
    for word in candidates:
        if len(replacements) >= n:
            break
        lowered = word.lower()
        synonym = next(
            (
                lemma.name().replace("_", " ")
                for synset in wordnet.synsets(word)
                for lemma in synset.lemmas()
                if lemma.name().lower() != lowered
            ),
            None,
        )
        if synonym is not None:
            replacements[word] = synonym

    return " ".join(replacements.get(word, word) for word in words)

# Load dataset dynamically from CSV
def load_dataset(file_path, tokenizer, max_length=2048, for_classification=False):
    """Load a CSV file into a ``TextDataset``.

    Args:
        file_path: Path of a CSV file with a ``text`` column and, when
            ``for_classification`` is true, a ``label`` column.
        tokenizer: Tokenizer handed to ``TextDataset``.
        max_length: Padding/truncation length handed to ``TextDataset``.
        for_classification: Whether labels are read and returned by the
            dataset.

    Returns:
        A ``TextDataset`` over records of the form
        ``{"text": ..., "label": ...}``; ``label`` is ``None`` when
        ``for_classification`` is false.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    df = pd.read_csv(file_path)

    # Fail early with a clear message instead of part-way through the rows.
    required = ["text", "label"] if for_classification else ["text"]
    missing = [column for column in required if column not in df.columns]
    if missing:
        raise KeyError(f"{file_path} is missing required column(s): {missing}")

    # Column-wise access avoids building one Series per row (iterrows).
    texts = df["text"].tolist()
    labels = df["label"].tolist() if for_classification else [None] * len(texts)
    dataset = [{"text": text, "label": label} for text, label in zip(texts, labels)]
    return TextDataset(dataset, tokenizer, max_length=max_length, for_classification=for_classification)

# Optuna for hyperparameter tuning
def objective(trial):
    """Optuna objective: train one configuration and return its weighted F1.

    Samples a learning rate (log scale), an ``adapter_dim`` and a batch size,
    trains a fresh single-task model on ``train.csv`` for three epochs and
    scores it on ``val.csv``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The F1 score reported by ``evaluate_with_metrics``.
    """
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
    learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)
    adapter_dim = trial.suggest_int("adapter_dim", 16, 128, step=16)
    batch_size = trial.suggest_categorical("batch_size", [4, 8, 16])

    # Create model and dataset
    model = MultiTaskAdapterFoundationModel(model_name="allenai/longformer-base-4096", tasks={"classification": 3}, adapter_dim=adapter_dim).to(device)
    # Use the tokenizer bundled with the model rather than depending on a
    # module-level `tokenizer` being defined elsewhere.
    train_data = load_dataset("train.csv", model.tokenizer, for_classification=True)
    val_data = load_dataset("val.csv", model.tokenizer, for_classification=True)

    # Train
    train_with_scheduler(model, {"classification": train_data}, epochs=3, batch_size=batch_size, learning_rate=learning_rate, accumulation_steps=2, num_training_steps=1000)

    # Evaluate
    _, _, f1 = evaluate_with_metrics(model, val_data, "classification")
    return f1  # Optimize F1 score

# Run Optuna study
# These are module-level statements, so the 10-trial study runs every time
# this code is executed or imported (including when an ASGI server imports
# the module to find `app`).
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)
print("Best trial:", study.best_trial)

# FastAPI deployment
app = FastAPI()

# Load trained model and tokenizer
# NOTE(review): this cell never assigns a module-level `model` (the one built
# inside `objective` is local), so this line relies on a `model` defined by an
# earlier cell -- confirm that is the model intended for serving.
model.eval()

@app.post("/predict")
async def predict(request: Request):
    """Classify the text sent in the request body.

    Expects a JSON body with a ``"text"`` entry and answers with
    ``{"predictions": [...]}``, the argmax class index per input row from the
    ``"classification"`` head. Uses the module-level ``tokenizer``, ``model``
    and ``device``.
    """
    data = await request.json()
    text = data["text"]

    # Tokenize input
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=2048, return_tensors="pt").to(device)
    input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

    # Predict without recording an autograd graph: inference needs no
    # gradients, and tracking them only costs memory on every request.
    with torch.no_grad():
        _, logits = model(input_ids, attention_mask, task="classification")
    predictions = torch.argmax(logits, dim=-1).cpu().numpy()

    return {"predictions": predictions.tolist()}

# To run the FastAPI app, use:
# uvicorn script_name:app --reload