# ProteinMPNN Finetuning with Acidophilic Dataset - Simplified

This notebook demonstrates how to finetune ProteinMPNN on the acidophilic dataset using the modular functions in `training/finetuning.py`.

## 1. Setup Environment

First, let's make sure we have all the required packages installed and set up our working directory.

In [None]:
# Check if running in Colab and install dependencies if needed
import os
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repo if in Colab
    !git clone https://github.com/aaronmaiww/ProteinMPNN.git
    %cd ProteinMPNN
    !git checkout finetuning-modification
    
    # Install dependencies (the dateutil module is provided by the python-dateutil package)
    !pip install torch==2.0.1 numpy python-dateutil
else:
    # Add the current directory to the path if not in Colab
    sys.path.append(os.getcwd())
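You can confirm that the packages used later in the notebook are importable with the standard library's `importlib.util.find_spec`. This is a minimal sketch: the module list comes from this notebook's own imports, and the helper name `missing_packages` is hypothetical. Note that the `dateutil` module is installed by the `python-dateutil` pip package.

```python
import importlib.util


def missing_packages(module_names):
    """Return the names in module_names that Python cannot locate for import."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Top-level modules imported elsewhere in this notebook
required = ["torch", "numpy", "dateutil", "matplotlib"]
missing = missing_packages(required)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are importable.")
```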

## 2. Download and Extract Acidophilic Dataset

Let's download and extract the acidophilic dataset.

In [None]:
# Create a temporary directory for the dataset
!mkdir -p /tmp/acidophilic_data

if IN_COLAB:
    # In Colab, you might need to download the dataset from a URL
    dataset_url = "YOUR_DATASET_URL_HERE"
    !wget -O /tmp/acidophilic_dataset_converted.zip {dataset_url}
    !unzip -q /tmp/acidophilic_dataset_converted.zip -d /tmp/acidophilic_data
else:
    # Locally, use the existing dataset
    if os.path.exists("acidophilic_dataset_converted.zip"):
        !unzip -q acidophilic_dataset_converted.zip -d /tmp/acidophilic_data
    else:
        print("acidophilic_dataset_converted.zip not found in the current directory. Download it (or set dataset_url) before continuing.")

In [None]:
# Check the dataset structure
!ls -la /tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted
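The `ls` output shows what was extracted, but a programmatic check can flag missing pieces before training starts. This sketch assumes the converted dataset follows the upstream ProteinMPNN training-data layout (`list.csv`, `valid_clusters.txt`, `test_clusters.txt`, and a `pdb/` directory). That layout is an assumption, so adjust the hypothetical `EXPECTED_ENTRIES` if your archive differs.

```python
import os

# Assumed layout, mirroring upstream ProteinMPNN training data; adjust if needed
EXPECTED_ENTRIES = ("list.csv", "valid_clusters.txt", "test_clusters.txt", "pdb")


def missing_dataset_entries(root, expected=EXPECTED_ENTRIES):
    """Return the expected files/directories that do not exist under root."""
    return [name for name in expected if not os.path.exists(os.path.join(root, name))]


data_root = "/tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted"
missing_entries = missing_dataset_entries(data_root)
if missing_entries:
    print(f"Missing from {data_root}: {missing_entries}")
else:
    print("Dataset layout looks complete.")
```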

## 3. Import Required Functions

Now, let's import the functions we need from the training module.

In [None]:
# Import functions from finetuning.py
from training.finetuning import setup_acidophilic_finetuning, train_and_validate, run_acidophilic_finetuning
from training.utils import StructureDataset, StructureLoader
from training.model_utils import ProteinMPNN

import torch
import time
import numpy as np
import matplotlib.pyplot as plt

## 4. Define Training Parameters

Let's set up the training parameters for our finetuning process.

In [None]:
# Define training parameters as a class (for compatibility with the functions)
class Args:
    def __init__(self):
        self.path_for_training_data = "/tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted"
        self.path_for_outputs = "./notebook_finetuning_%Y%m%d_%H%M%S"  # strftime codes give a timestamped folder
        self.previous_checkpoint = ""  # Optional: path to a checkpoint to resume from
        self.num_epochs = 5  # Small value for demonstration
        self.save_model_every_n_epochs = 1
        self.reload_data_every_n_epochs = 2
        self.num_examples_per_epoch = 50  # Small value for demonstration
        self.batch_size = 1000  # Residues (tokens) per batch, not number of proteins
        self.max_protein_length = 1000
        self.hidden_dim = 128
        self.num_encoder_layers = 3
        self.num_decoder_layers = 3
        self.num_neighbors = 48  # k nearest neighbors in the structure graph
        self.dropout = 0.1
        self.backbone_noise = 0.2  # Std (Å) of Gaussian noise added to backbone coordinates
        self.rescut = 3.5  # Resolution cutoff (Å) for training structures
        self.debug = True  # Minimal data loading for quick test runs
        self.gradient_norm = -1.0  # Gradient clipping norm; a negative value disables clipping
        self.mixed_precision = True
        self.verbose = True
        self.num_workers = 2  # Reduced for Jupyter

args = Args()
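`path_for_outputs` contains `strftime` codes, so each run writes to a timestamped folder. Upstream ProteinMPNN's `training.py` expands the pattern with `time.strftime` and appends a trailing slash; assuming `setup_acidophilic_finetuning` does the same, the resulting `base_folder` would look like the output of this hypothetical helper:

```python
import time


def resolve_output_folder(pattern, when=None):
    """Expand strftime codes in pattern and make sure the folder ends with '/'.

    Hypothetical helper: it mirrors upstream ProteinMPNN's handling of
    path_for_outputs; the setup function used in this notebook may differ.
    """
    when = time.localtime() if when is None else when
    folder = time.strftime(pattern, when)
    return folder if folder.endswith("/") else folder + "/"


print(resolve_output_folder("./notebook_finetuning_%Y%m%d_%H%M%S"))
```

The trailing slash is why later cells can build paths by plain string concatenation such as `base_folder + 'training_progress.png'`.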

## 5. Training - Option 1: Using run_acidophilic_finetuning

The simplest way to train is to call `run_acidophilic_finetuning`, which handles the entire process.

In [None]:
# Optional: storage plus a callback that records per-epoch metrics for plotting
training_history = {
    'epochs': [],
    'train_loss': [],
    'train_accuracy': [],
    'train_perplexity': [],
    'valid_loss': [],
    'valid_accuracy': [],
    'valid_perplexity': []
}

def history_callback(epoch, train_metrics, valid_metrics, model, optimizer):
    training_history['epochs'].append(epoch + 1)
    training_history['train_loss'].append(train_metrics['loss'])
    training_history['train_accuracy'].append(train_metrics['accuracy'])
    training_history['train_perplexity'].append(train_metrics['perplexity'])
    training_history['valid_loss'].append(valid_metrics['loss'])
    training_history['valid_accuracy'].append(valid_metrics['accuracy'])
    training_history['valid_perplexity'].append(valid_metrics['perplexity'])
    
    # You could add more functionality here, like early stopping

# Run the training
results = run_acidophilic_finetuning(args, callbacks=[history_callback])
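The comment in `history_callback` mentions early stopping. Below is a minimal sketch of such a callback, assuming callbacks receive the same `(epoch, train_metrics, valid_metrics, model, optimizer)` arguments as `history_callback`. `BestModelTracker` is a hypothetical helper: it keeps a copy of the best weights by validation loss and sets a `should_stop` flag after `patience` epochs without improvement. This notebook does not show whether `run_acidophilic_finetuning` can act on that flag, so treat it as advisory.

```python
import copy


class BestModelTracker:
    """Hypothetical callback: track the best epoch by validation loss."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = None
        self.best_state = None
        self.epochs_without_improvement = 0
        self.should_stop = False

    def __call__(self, epoch, train_metrics, valid_metrics, model, optimizer):
        if valid_metrics['loss'] < self.best_loss:
            self.best_loss = valid_metrics['loss']
            self.best_epoch = epoch + 1  # 1-based, like history_callback
            # Deep copy so later training steps cannot overwrite the saved weights
            self.best_state = copy.deepcopy(model.state_dict())
            self.epochs_without_improvement = 0
            self.should_stop = False
        else:
            self.epochs_without_improvement += 1
            self.should_stop = self.epochs_without_improvement >= self.patience
```

For example, create `tracker = BestModelTracker(patience=2)`, pass `callbacks=[history_callback, tracker]`, and afterwards restore the best weights into a `ProteinMPNN` instance with `load_state_dict(tracker.best_state)`. In the step-by-step loop you can call `tracker(...)` yourself after each epoch and `break` when `tracker.should_stop` is true.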

## 6. Training - Option 2: Step-by-Step Approach

Alternatively, you can write the epoch loop yourself, which gives finer control over each step, such as how metrics are stored and when the training data is reloaded.

In [None]:
# Setup
setup = setup_acidophilic_finetuning(args)

model = setup['model']
optimizer = setup['optimizer']
scaler = setup['scaler']
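train_loader = setup['train_loader']  # Source loader, used when reloading training data
valid_loader = setup['valid_loader']
loader_train = setup['loader_train']  # Batched structures consumed by train_and_validate
loader_valid = setup['loader_valid']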
device = setup['device']
base_folder = setup['base_folder']
logfile = setup['logfile']
total_step = setup['total_step']
epoch = setup['epoch']
params = setup['params']

# Create storage for metrics
step_by_step_history = {
    'epochs': [],
    'train_loss': [],
    'train_accuracy': [],
    'train_perplexity': [],
    'valid_loss': [],
    'valid_accuracy': [],
    'valid_perplexity': []
}

# Imported once here; used below to reload the training data
from training.finetuning import load_data_from_loader

# Training loop: one train_and_validate call per epoch
for e in range(args.num_epochs):
    total_step, train_metrics, valid_metrics = train_and_validate(
        model=model,
        optimizer=optimizer,
        scaler=scaler,
        loader_train=loader_train,
        loader_valid=loader_valid,
        device=device,
        epoch=epoch + e,
        total_step=total_step,
        base_folder=base_folder,
        logfile=logfile,
        args=args,
        mixed_precision=args.mixed_precision,
        gradient_norm=args.gradient_norm,
        save_checkpoints=True,
        callbacks=None,
        verbose=args.verbose
    )

    # Store metrics for plotting
    step_by_step_history['epochs'].append(epoch + e + 1)
    for split, metrics in (('train', train_metrics), ('valid', valid_metrics)):
        for name in ('loss', 'accuracy', 'perplexity'):
            step_by_step_history[f'{split}_{name}'].append(metrics[name])

    # Reload training data every reload_data_every_n_epochs (skipped after the last epoch)
    if (e + 1) % args.reload_data_every_n_epochs == 0 and (e + 1) < args.num_epochs:
        print("Reloading training data...")
        pdb_dict_train = load_data_from_loader(train_loader, args.max_protein_length, args.num_examples_per_epoch)
        dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
        loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
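The reload step rebuilds a `StructureLoader` with `batch_size=args.batch_size`. In upstream ProteinMPNN this value is a budget of residues (tokens) per batch rather than a number of proteins: structures are sorted by length and grouped so the padded batch stays within the budget. The sketch below is a simplified, hypothetical illustration of that idea (`token_budget_batches` is not part of the repository), not the repository's `StructureLoader` code.

```python
def token_budget_batches(lengths, batch_size):
    """Group protein indices so that (batch size) x (longest length) <= batch_size.

    A protein longer than the budget still gets a batch of its own.
    """
    # Visit proteins from shortest to longest (stable for equal lengths)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for idx in order:
        # Lengths arrive in ascending order, so lengths[idx] is the padded
        # length the batch would have if this protein were added.
        if current and lengths[idx] * (len(current) + 1) > batch_size:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches


print(token_budget_batches([50, 300, 120, 700, 40], 1000))
```

Under this scheme, with `batch_size = 1000` any protein longer than 500 residues cannot share a batch, so with `max_protein_length = 1000` some batches hold a single structure.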

## 7. Visualize Training Progress

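Let's plot the per-epoch perplexity and accuracy collected during training. The figure is saved to `base_folder`, which is defined in the Option 2 setup cell; if you only ran Option 1, set `base_folder` to your output folder first.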

In [None]:
# Choose which history to visualize
history = training_history  # or step_by_step_history

# Plot training progress
plt.figure(figsize=(12, 10))

# Plot perplexity
plt.subplot(2, 1, 1)
plt.plot(history['epochs'], history['train_perplexity'], 'b-', label='Training Perplexity')
plt.plot(history['epochs'], history['valid_perplexity'], 'r-', label='Validation Perplexity')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.title('Training and Validation Perplexity')
plt.legend()
plt.grid(True)

# Plot accuracy
plt.subplot(2, 1, 2)
plt.plot(history['epochs'], history['train_accuracy'], 'b-', label='Training Accuracy')
plt.plot(history['epochs'], history['valid_accuracy'], 'r-', label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(base_folder, 'training_progress.png'))
plt.show()
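The plots show trends; a numeric summary helps decide which saved checkpoint to keep, since the parameters above save a model every epoch (`save_model_every_n_epochs = 1`). A minimal sketch of a hypothetical helper that reads the history dictionaries defined earlier:

```python
def best_epoch(history, metric='valid_perplexity'):
    """Return (epoch, value) for the epoch with the lowest value of metric.

    Ties resolve to the earliest epoch.
    """
    values = history[metric]
    if not values:
        raise ValueError(f"history[{metric!r}] is empty")
    best_idx = min(range(len(values)), key=values.__getitem__)
    return history['epochs'][best_idx], values[best_idx]
```

In the notebook, call `best_epoch(history)` for the lowest validation perplexity, or pass `'valid_loss'` as the metric.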

## 8. Evaluate the Model

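Let's reload the final checkpoint and evaluate it on the validation set. This cell uses `base_folder`, `device`, and `loader_valid` from the Option 2 setup cell.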

In [None]:
from training.model_utils import featurize, loss_nll

# Load the final checkpoint; map_location lets a GPU-saved checkpoint load on CPU
final_model_path = os.path.join(base_folder, 'model_weights', 'final_model.pt')
checkpoint = torch.load(final_model_path, map_location=device)

# Rebuild the architecture with the same hyperparameters used for training
eval_model = ProteinMPNN(
    node_features=args.hidden_dim,
    edge_features=args.hidden_dim,
    hidden_dim=args.hidden_dim,
    num_encoder_layers=args.num_encoder_layers,
    num_decoder_layers=args.num_decoder_layers,
    k_neighbors=args.num_neighbors,
    dropout=0.0,  # No dropout during evaluation
    augment_eps=0.0  # No backbone noise during evaluation
)
eval_model.load_state_dict(checkpoint['model_state_dict'])
eval_model.to(device)
eval_model.eval()

validation_sum, validation_acc, validation_weights = 0.0, 0.0, 0.0
with torch.no_grad():
    for batch_idx, batch in enumerate(loader_valid):
        X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
        log_probs = eval_model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
        # Score only residues that are present and belong to the predicted chains
        mask_for_loss = mask * chain_M
        loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)

        batch_weight = torch.sum(mask_for_loss).item()
        batch_correct = torch.sum(true_false * mask_for_loss).item()
        validation_sum += torch.sum(loss * mask_for_loss).item()
        validation_acc += batch_correct
        validation_weights += batch_weight

        if batch_weight > 0:
            print(f"Batch {batch_idx}: Accuracy = {batch_correct / batch_weight:.4f}")

# Overall metrics, weighted by the number of scored residues
validation_loss = validation_sum / validation_weights
validation_accuracy = validation_acc / validation_weights
validation_perplexity = np.exp(validation_loss)

print("\nFinal Evaluation Results:")
print(f"Loss: {validation_loss:.4f}")
print(f"Accuracy: {validation_accuracy:.4f}")
print(f"Perplexity: {validation_perplexity:.4f}")
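The overall metrics above are residue-weighted: every scored residue counts equally, so the final accuracy is generally not the mean of the per-batch accuracies printed in the loop. A small, framework-free sketch of the same arithmetic (the helper name `aggregate_metrics` is hypothetical):

```python
import math


def aggregate_metrics(batch_stats):
    """Combine per-batch (loss_sum, num_correct, num_residues) tuples.

    Each batch contributes in proportion to its number of scored residues,
    matching the validation_sum / validation_weights arithmetic above.
    """
    total_loss = sum(stats[0] for stats in batch_stats)
    total_correct = sum(stats[1] for stats in batch_stats)
    total_residues = sum(stats[2] for stats in batch_stats)
    if total_residues == 0:
        raise ValueError("no scored residues")
    mean_loss = total_loss / total_residues
    return {
        'loss': mean_loss,
        'accuracy': total_correct / total_residues,
        'perplexity': math.exp(mean_loss),  # perplexity = exp(mean NLL per residue)
    }
```

For two batches `(20.0, 6, 10)` and `(2.0, 1, 2)`, the weighted accuracy is 7/12 (about 0.583), while the mean of the per-batch accuracies (0.600 and 0.500) is 0.550.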

## 9. Conclusion

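In this notebook, we've demonstrated how to finetune ProteinMPNN on the acidophilic dataset using the modular functions in `training/finetuning.py`. Because setup, single-epoch training, and the full run are separate functions, the same code supports both the one-call workflow (Option 1) and a hand-written training loop (Option 2).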

The functions we used include:
- `setup_acidophilic_finetuning`: Sets up everything needed for training
- `train_and_validate`: Trains and validates the model for one epoch
- `run_acidophilic_finetuning`: Runs the complete finetuning process

These functions can be used in various contexts, including scripts, notebooks, and larger workflows.