# AAV Dataset Tutorial: Embedding and Feature Extraction with PLMFit

## Introduction

This notebook demonstrates how to use `plmfit` as a library to process protein sequences from the AAV (Adeno-Associated Virus) dataset. We'll cover:
1. Loading and preparing the AAV dataset
2. Extracting protein sequence embeddings using pre-trained language models
3. Training a downstream model on these embeddings for property prediction

The AAV dataset contains engineered variants of the adeno-associated virus (AAV) capsid protein, which is widely used in gene therapy. By analyzing these variants, we can predict their properties and understand sequence-function relationships.

### 1. Setup

### Importing Required Libraries

We'll start by importing the necessary modules and setting up our environment. The key components we'll use are:

- `extract_embeddings`: For generating protein sequence embeddings
- `feature_extraction`: For training downstream prediction models
- `Logger`: For tracking experiments and results
- `utils`: For data loading and helper functions
- `DefaultArgs`: For managing configuration parameters

```python
import sys
import os
# Add the project root to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

from plmfit.functions.extract_embeddings import extract_embeddings
from plmfit.functions.feature_extraction import feature_extraction
from plmfit.logger import Logger
from plmfit.shared_utils import utils
from plmfit.args_parser import DefaultArgs
```

### 2. Data Loading and Preparation

### Understanding the AAV Dataset
The AAV dataset contains engineered variants of the AAV capsid protein, each with:

- A unique sequence identifier
- The protein sequence
- Associated functional measurements

Let's load and examine the dataset:

```python
data_path = '../tutorial/aav_data_sample.csv'
aav_sample_df = utils.load_dataset(data_path)
print(f'Dataset shape: {aav_sample_df.shape}')
aav_sample_df.head()
```

### 3. Embedding Extraction

### Why Extract Embeddings?
Protein language models (like ESM) learn rich, contextual representations of protein sequences during pre-training. These embeddings capture:
- Evolutionary information
- Structural properties
- Functional characteristics

We'll use `esm2_t6_8M_UR50D`, the ESM-2 model with 8M parameters. It is the smallest checkpoint in the ESM-2 family, which keeps extraction fast and memory-light.
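ESM-2 checkpoint names encode the architecture: `esm2_t6_8M_UR50D` has 6 transformer layers and about 8M parameters, and `UR50D` tags the UniRef50-based training data. The sketch below parses that naming convention with a regular expression; `parse_esm2_name` is a hypothetical helper, not a PLMFit or fair-esm function. Knowing the layer count helps when reading `args.layer = "quarter1"`, which points to an early layer of this 6-layer network.

```python
import re


def parse_esm2_name(name: str) -> dict:
    """Split an ESM-2 checkpoint name into layer count, size tag and data tag.

    Hypothetical helper: assumes names follow the esm2_t<layers>_<size>_<data> pattern.
    """
    match = re.fullmatch(r"esm2_t(\d+)_(\d+[MB])_(\w+)", name)
    if match is None:
        raise ValueError(f"unrecognized ESM-2 checkpoint name: {name}")
    layers, size, data = match.groups()
    return {"layers": int(layers), "params": size, "training_data": data}


print(parse_esm2_name("esm2_t6_8M_UR50D"))
# {'layers': 6, 'params': '8M', 'training_data': 'UR50D'}
```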

```python
# Initialize default arguments
args = DefaultArgs()

# Configure embedding extraction parameters
args.plm = "esm2_t6_8M_UR50D"  # Smallest ESM-2 checkpoint (6 layers, 320-dim embeddings)
args.layer = "quarter1"  # Early layer for general features
args.reduction = "mean"  # Average over sequence positions
args.batch_size = 32  # Sequences per batch; lower this if you run out of memory
args.output_dir = "../tutorial"
args.experiment_dir = f"{args.output_dir}/aav_tutorial"
args.experiment_name = f"aav_sample_{args.plm}_{args.layer}_{args.reduction}"

# Initialize logger for tracking
logger = Logger(experiment_name=args.experiment_name, base_dir=args.experiment_dir)

# Where the extracted embeddings are saved; reused when training the head below
embeddings_path = f"{args.experiment_dir}/{args.experiment_name}.pt"

print("Starting embedding extraction...")
extract_embeddings(args, logger, data=aav_sample_df)
print("Embedding extraction complete.")
```

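### What's Happening

1. The model passes each protein sequence through the ESM-2 network.
2. For each sequence, it takes the hidden representations from the layer selected by `args.layer`: one vector per residue.
3. These per-residue vectors are averaged across the sequence length (`args.reduction = "mean"`), giving one fixed-size vector per sequence.
4. The resulting embeddings are saved to disk at `embeddings_path`.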
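The averaging step (`args.reduction = "mean"`) can be illustrated with a minimal NumPy sketch. It assumes per-residue hidden states of shape `(batch, seq_len, dim)` plus a padding mask; `mean_pool` is a hypothetical name, and whether PLMFit excludes ESM's special start and end tokens before averaging is not shown in this tutorial, so treat this as an illustration rather than PLMFit's code.

```python
import numpy as np


def mean_pool(hidden_states: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average per-residue vectors into one vector per sequence.

    hidden_states: (batch, seq_len, dim) outputs of the chosen layer
    mask: (batch, seq_len), 1 for real residues and 0 for padding
    """
    mask = mask[..., None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)         # padding contributes zero
    counts = np.clip(mask.sum(axis=1), 1.0, None)       # residues per sequence, avoid /0
    return summed / counts                              # (batch, dim)


# Two toy sequences of lengths 3 and 2, padded to length 3, with dim=320
rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 3, 320))
mask = np.array([[1, 1, 1], [1, 1, 0]])
pooled = mean_pool(hidden, mask)
print(pooled.shape)  # (2, 320)
```

Masking matters because sequences in a batch are padded to a common length; averaging the padding positions would dilute the embeddings of shorter sequences.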

### 4. Feature Extraction

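Now that each sequence has a fixed-size embedding, we'll train a small neural network (a prediction head) to predict a protein property from these embeddings. The language model is not involved in this step: only the head is trained, on top of the precomputed embeddings.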

```python
# Configure model architecture and training parameters
args = DefaultArgs()  # Fresh defaults; only the fields set below are customized
args.head_config = {
    "architecture_parameters": {
        "network_type": "mlp",  # Multi-layer perceptron
        "output_dim": 1,  # Single output for regression
        "hidden_dim": 128,  # Hidden layer size
        "task": "regression",  # Predicting continuous values
        "hidden_activation": "relu",  # Non-linearity
        "hidden_dropout": 0.25,  # Dropout rate to reduce overfitting
    },
    "training_parameters": {
        "learning_rate": 0.00005,  # Small learning rate
        "epochs": 200,  # Maximum number of training epochs
        "batch_size": 64,  # Number of samples per batch
        "loss_f": "mse",  # Mean squared error loss
        "optimizer": "adam",  # Adaptive learning rate optimizer
        "val_split": 0.2,  # 20% of data for validation
        "weight_decay": 0.01,  # L2 regularization
        "early_stopping": 30,  # Stop if no improvement for 30 epochs
    },
}

# Set up paths and run feature extraction
args.split = "sampled"  # Use the sampled split (PLMFit will generate it in the background)
args.ray_tuning = "False"  # Disable hyperparameter tuning for this example
args.embeddings_path = embeddings_path  # Path to our saved embeddings

print("Starting feature extraction and model training...")
feature_extraction(args, logger, data=aav_sample_df, head_config=args.head_config)
print("Feature extraction and training complete!")
```
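The `"val_split": 0.2` setting holds out a fifth of the training data for validation. A minimal sketch of a random hold-out split follows; `train_val_split` is a hypothetical helper, and PLMFit's own `"sampled"` split is generated internally, so its exact procedure (seeding, stratification) may differ.

```python
import numpy as np


def train_val_split(n_samples: int, val_split: float = 0.2, seed: int = 42):
    """Return shuffled (train, validation) index arrays for a random hold-out split."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)          # random order of all sample indices
    n_val = int(round(n_samples * val_split))     # size of the validation set
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = train_val_split(100, val_split=0.2)
print(len(train_idx), len(val_idx))  # 80 20
```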

### Understanding the Model Architecture
The downstream model consists of:
- Input Layer: Takes the 320-dimensional protein embeddings
- Hidden Layer: 128 neurons with ReLU activation
- Dropout: 25% dropout for regularization
- Output Layer: Single neuron for regression prediction
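To make the layer sizes concrete, here is a NumPy sketch of the same forward pass (320 → 128 with ReLU and 25% dropout → 1). It is an illustration, not PLMFit's implementation; the helper names `init_head` and `forward` and the initialization scale are assumptions. The parameter count follows from the layer sizes: 320 × 128 + 128 in the hidden layer plus 128 × 1 + 1 in the output layer, 41,217 in total.

```python
import numpy as np

IN_DIM, HIDDEN_DIM, OUT_DIM = 320, 128, 1


def init_head(rng):
    """Randomly initialize weights and biases for a one-hidden-layer MLP (assumed scale)."""
    return {
        "W1": rng.normal(scale=0.02, size=(IN_DIM, HIDDEN_DIM)), "b1": np.zeros(HIDDEN_DIM),
        "W2": rng.normal(scale=0.02, size=(HIDDEN_DIM, OUT_DIM)), "b2": np.zeros(OUT_DIM),
    }


def forward(params, x, rng=None, dropout=0.25):
    """Embeddings (batch, 320) -> predictions (batch, 1); dropout only if rng is given."""
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # hidden layer + ReLU
    if rng is not None:
        # Training mode: inverted dropout keeps each unit with probability 1 - dropout
        keep = rng.random(h.shape) >= dropout
        h = h * keep / (1.0 - dropout)
    return h @ params["W2"] + params["b2"]  # single linear output for regression


rng = np.random.default_rng(0)
params = init_head(rng)
n_params = sum(p.size for p in params.values())
print(n_params)  # 41217
preds = forward(params, rng.normal(size=(4, IN_DIM)))
print(preds.shape)  # (4, 1)
```

Because inverted dropout rescales activations during training, the inference pass (no `rng`) needs no correction.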

### Training Process
- The model is trained using the Adam optimizer with a small learning rate.
- Early stopping prevents overfitting by monitoring the validation loss.
- The best model weights are saved based on validation performance.
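The early-stopping rule (`"early_stopping": 30`) can be sketched as a scan over per-epoch validation losses: remember the epoch with the lowest loss, where the best weights would be saved, and stop once 30 epochs pass without a new best. `early_stopping_epochs` is a hypothetical helper; PLMFit's exact criterion (for example, any minimum-improvement threshold) is not shown in this tutorial.

```python
def early_stopping_epochs(val_losses, patience=30):
    """Apply patience-based early stopping to a sequence of per-epoch validation losses.

    Returns (best_epoch, stop_epoch), both 0-indexed: best_epoch is where the best
    weights would be checkpointed, stop_epoch is the last epoch actually trained.
    """
    best_loss, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            # New best validation loss: the weights would be saved here
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # `patience` consecutive epochs without improvement: stop training
            return best_epoch, epoch
    # Reached the epoch limit before patience ran out
    return best_epoch, len(val_losses) - 1


# Toy run: loss improves for 10 epochs, then plateaus for the rest of 200 epochs
losses = [1.0 - 0.05 * i for i in range(10)] + [0.6] * 190
print(early_stopping_epochs(losses, patience=30))  # (9, 39)
```

In this toy run, training stops after 40 of the 200 allowed epochs, and the weights from epoch 9 are the ones kept.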

### Next Steps
After running this pipeline, you might want to:

- Visualize the extracted embeddings using t-SNE or UMAP
- Interpret which sequence features contribute most to the predictions
- Try different model architectures or hyperparameters
- Fine-tune the protein language model itself with PLMFit, using parameter-efficient fine-tuning methods such as LoRA
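t-SNE and UMAP require extra packages (`sklearn.manifold.TSNE` in scikit-learn, or the `umap-learn` package). As a dependency-light first look, the sketch below swaps in PCA, computed with NumPy's SVD, to project embeddings to two dimensions. Random data stands in for the real embeddings; loading them from `embeddings_path` with `torch.load` and converting them to an `(n_sequences, 320)` array is an assumption about the saved file's format.

```python
import numpy as np


def pca_2d(embeddings: np.ndarray) -> np.ndarray:
    """Project (n_samples, dim) embeddings onto their top two principal components."""
    centered = embeddings - embeddings.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
toy = rng.normal(size=(50, 320))  # stand-in for real 320-dim ESM-2 embeddings
coords = pca_2d(toy)
print(coords.shape)  # (50, 2)
```

The two columns of `coords` can be passed to any scatter-plot function, colored by the measured property, to check whether the embeddings separate high- and low-scoring variants.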

### Troubleshooting
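If you encounter any issues:

- Ensure all file paths are correct
- Check that you have sufficient GPU memory
- Verify that all required packages are installed
- Reduce the batch size if you run into memory errors (`args.batch_size` for embedding extraction, `"batch_size"` under `training_parameters` for the downstream model)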
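If you would rather not tune the batch size by hand, a generic retry wrapper can halve it after each out-of-memory failure. `run_with_batch_backoff` is a hypothetical helper, not part of PLMFit; it assumes out-of-memory errors surface as a `RuntimeError` whose message contains "out of memory", as PyTorch's CUDA errors do. To use it for extraction, `run` would set `args.batch_size` to the value it receives and then call `extract_embeddings`.

```python
def run_with_batch_backoff(run, batch_size, min_batch_size=1, is_oom=None):
    """Call run(batch_size); on an out-of-memory error, halve the batch size and retry.

    Returns (result, batch_size_used). Errors that are not out-of-memory, or an
    out-of-memory error at min_batch_size, are re-raised unchanged.
    """
    if is_oom is None:
        # Assumption: OOM errors are RuntimeErrors mentioning "out of memory"
        def is_oom(err):
            return "out of memory" in str(err).lower()
    while True:
        try:
            return run(batch_size), batch_size
        except RuntimeError as err:
            if not is_oom(err) or batch_size <= min_batch_size:
                raise
            batch_size = max(min_batch_size, batch_size // 2)


# Toy stand-in for a GPU job that only fits with batches of 8 or fewer
def fake_run(batch_size):
    if batch_size > 8:
        raise RuntimeError("CUDA out of memory (simulated)")
    return f"ok with {batch_size}"


print(run_with_batch_backoff(fake_run, 32))  # ('ok with 8', 8)
```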