# ProteinMPNN Finetuning with Acidophilic Dataset

This notebook walks through finetuning ProteinMPNN on the acidophilic dataset: environment setup, data preparation, model training, evaluation, and saving the model for inference.

## 1. Setup Environment

First, let's make sure we have all the required packages installed and set up our working directory.

In [None]:
# Check if running in Colab and install dependencies if needed
import os
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Clone the repo if in Colab
    !git clone https://github.com/aaronmaiww/ProteinMPNN.git
    %cd ProteinMPNN
    !git checkout finetuning-modification
    
    # Install dependencies
    !pip install torch==2.0.1 numpy python-dateutil

In [None]:
# Set up paths and check directory structure
# Outside Colab, assume we're in the ProteinMPNN directory and add it to the path.
# In Colab, the %cd above already put us in the right directory.
if not IN_COLAB:
    sys.path.append(os.getcwd())

# Import necessary modules
from training.utils import PDB_dataset, StructureDataset, StructureLoader, worker_init_fn
from training.model_utils import featurize, loss_smoothed, loss_nll, get_std_opt, ProteinMPNN

# Display directory structure
!ls -la training/

## 2. Download and Extract Acidophilic Dataset

Now, let's download and extract the acidophilic dataset. This assumes the dataset is available in your GitHub repository or another accessible location.

In [None]:
# Create a temporary directory for the dataset
!mkdir -p /tmp/acidophilic_data

# For demonstration purposes - in a real notebook, you'd download from a URL
if IN_COLAB:
    # In Colab, you might need to download the dataset from a URL
    dataset_url = "YOUR_DATASET_URL_HERE"
    !wget -O /tmp/acidophilic_dataset_converted.zip {dataset_url}
    !unzip -q /tmp/acidophilic_dataset_converted.zip -d /tmp/acidophilic_data
else:
    # Locally, use the existing dataset
    if os.path.exists("acidophilic_dataset_converted.zip"):
        !unzip -q acidophilic_dataset_converted.zip -d /tmp/acidophilic_data
    else:
        print("Dataset not found locally. Please download it or provide a URL.")

In [None]:
# Check the dataset structure
!ls -la /tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted

# Display a sample of the training proteins
!head -n 10 /tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted/train_proteins.txt
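
Before training, it can be worth checking that the three protein lists do not share entries, since an ID that appears in two splits would leak into validation or test. The sketch below is a minimal, framework-free check; the helper names are hypothetical, and the file names are the ones this notebook uses.

```python
from pathlib import Path


def read_protein_list(path):
    """Return the non-empty, stripped lines of a protein list file."""
    with open(path, "r") as handle:
        return [line.strip() for line in handle if line.strip()]


def find_split_overlaps(splits):
    """Map each pair of split names to the IDs they share (clean pairs are omitted)."""
    names = sorted(splits)
    overlaps = {}
    for i, first in enumerate(names):
        for second in names[i + 1:]:
            shared = set(splits[first]) & set(splits[second])
            if shared:
                overlaps[(first, second)] = sorted(shared)
    return overlaps


def load_splits(data_dir):
    """Read the three list files from the dataset root (hypothetical helper)."""
    data_dir = Path(data_dir)
    return {
        "train": read_protein_list(data_dir / "train_proteins.txt"),
        "validation": read_protein_list(data_dir / "validation_proteins.txt"),
        "test": read_protein_list(data_dir / "test_proteins.txt"),
    }
```

Calling `find_split_overlaps(load_splits(...))` on the extracted dataset folder should return an empty dictionary if the splits are disjoint.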

## 3. Define Custom Functions for Acidophilic Dataset

Next, let's define the custom functions needed to work with the acidophilic dataset format.
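
The loader in this section expects each protein under `<DIR>/<split>/pdb/<first two characters>/<pdbid>/`, with one metadata file and one file per chain. As a minimal sketch of that naming scheme (the helper names are hypothetical and not part of the ProteinMPNN code):

```python
def chain_file_paths(data_dir, split, pdbid, chid):
    """Build the metadata and chain file paths for one protein chain.

    Assumed layout: <data_dir>/<split>/pdb/<first two chars>/<pdbid>/<pdbid>[_<chid>].pt
    """
    prefix = f"{data_dir}/{split}/pdb/{pdbid[:2]}/{pdbid}/{pdbid}"
    return f"{prefix}.pt", f"{prefix}_{chid}.pt"


def split_label(label):
    """Split 'PDBID_CHAIN' on the last underscore so IDs may contain underscores."""
    pdbid, chid = label.rsplit("_", 1)
    return pdbid, chid
```

Splitting on the last underscore is a defensive choice here: it keeps the helper usable if a protein ID itself contains an underscore.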

In [None]:
import numpy as np
import torch
import random
import time

def build_custom_training_clusters(params, debug=False):
    """Build training clusters from protein lists instead of list.csv"""
    # Read protein lists
    with open(params['VAL'], 'r') as f:
        val_proteins = [line.strip() for line in f.readlines()]
    
    with open(params['TEST'], 'r') as f:
        test_proteins = [line.strip() for line in f.readlines()]
    
    train_file = os.path.join(params['DIR'], 'train_proteins.txt')
    with open(train_file, 'r') as f:
        train_proteins = [line.strip() for line in f.readlines()]
        
    print(f"Found {len(train_proteins)} train proteins, {len(val_proteins)} validation proteins, {len(test_proteins)} test proteins")
    
    # Create dictionaries in the expected format
    train = {}
    valid = {}
    test = {}
    
    # For train proteins, add them to train dict
    for protein in train_proteins:
        if debug and len(train) >= 20:  # Limit in debug mode
            break
        train[len(train)] = [[f"{protein}_A", protein]]
        
    # For validation proteins, add them to valid dict
    for protein in val_proteins:
        if debug and len(valid) >= 10:  # Limit in debug mode
            break
        valid[len(valid)] = [[f"{protein}_A", protein]]
        
    # For test proteins, add them to test dict
    for protein in test_proteins:
        if debug and len(test) >= 10:  # Limit in debug mode
            break
        test[len(test)] = [[f"{protein}_A", protein]]
        
    if debug:
        print("Train sample:", list(train.items())[:5])
        print("Valid sample:", list(valid.items())[:5])
        
    return train, valid, test

def custom_loader_pdb(item, params):
    """Custom loader for acidophilic dataset structure"""
    pdbid, chid = item[0].rsplit('_', 1)  # Split on the last underscore only
    
    # Determine if it's a train, validation, or test protein
    dataset_type = None
    train_file = os.path.join(params['DIR'], 'train_proteins.txt')
    val_file = os.path.join(params['DIR'], 'validation_proteins.txt')
    test_file = os.path.join(params['DIR'], 'test_proteins.txt')
    
    with open(train_file, 'r') as f:
        train_proteins = [line.strip() for line in f.readlines()]
    with open(val_file, 'r') as f:
        val_proteins = [line.strip() for line in f.readlines()]
    with open(test_file, 'r') as f:
        test_proteins = [line.strip() for line in f.readlines()]
    
    if pdbid in train_proteins:
        dataset_type = "train"
    elif pdbid in val_proteins:
        dataset_type = "validation"
    elif pdbid in test_proteins:
        dataset_type = "test"
    else:
        print(f"Protein {pdbid} not found in any dataset!")
        return {'seq': np.zeros(5)}
    
    # Extract first two characters for directory structure
    first_two = pdbid[:2]
    
    PREFIX = f"{params['DIR']}/{dataset_type}/pdb/{first_two}/{pdbid}/{pdbid}"
    
    # Check if file exists
    if not os.path.isfile(f"{PREFIX}.pt"):
        print(f"File not found: {PREFIX}.pt")
        return {'seq': np.zeros(5)}
    
    # Load metadata
    meta = torch.load(f"{PREFIX}.pt")
    
    # Check if chain file exists
    chain_file = f"{PREFIX}_{chid}.pt"
    if not os.path.isfile(chain_file):
        print(f"Chain file not found: {chain_file}")
        return {'seq': np.zeros(5)}
    
    # Load chain data
    chain = torch.load(chain_file)
    L = len(chain['seq'])
    
    # Return data in the expected format
    return {'seq': chain['seq'],
            'xyz': chain['xyz'],
            'idx': torch.zeros(L).int(),
            'masked': torch.Tensor([0]).int(),
            'label': item[0]}

def load_data_from_loader(data_loader, max_length, num_examples):
    """Load data directly from loader_pdb output and convert to expected format"""
    pdb_dict_list = []
    count = 0
    
    for batch in data_loader:
        # Skip placeholder items returned by the loader when a file is missing
        if 'label' not in batch:
            continue

        # DataLoader wraps every field in a batch dimension (batch_size=1),
        # so extract the single value from each field
        label = batch['label'][0] if isinstance(batch['label'], list) else batch['label']
        seq = batch['seq'][0] if isinstance(batch['seq'], list) else batch['seq']
        xyz = batch['xyz']
        idx = batch['idx'][0] if batch['idx'].dim() > 1 else batch['idx']
        masked = batch['masked'][0] if batch['masked'].dim() > 1 else batch['masked']
        
        print(f"DEBUG: Processing {label}, seq_len={len(seq)}, xyz_shape={xyz.shape}")
        
        if len(seq) <= max_length:
            # Convert to format expected by featurize
            chain_id = label.split('_')[-1]  # Extract chain ID (e.g., 'A')
            
            # Remove terminal His-tags, recording how many residues are cut
            # from each end so the coordinates can be trimmed to match
            sequence = seq
            n_trim, c_trim = 0, 0
            if sequence[-6:] == "HHHHHH":
                sequence = sequence[:-6]
                c_trim = 6
            if sequence[0:6] == "HHHHHH":
                sequence = sequence[6:]
                n_trim = 6
            # Add other His-tag removal patterns if needed
            
            if len(sequence) < 4:
                continue
                
            # Get coordinates and ensure proper shape
            # Handle various possible shapes of coordinates
            if xyz.dim() == 4 and xyz.shape[0] == 1 and xyz.shape[2] == 4:  # [1, L, 4, 3]
                all_atoms = xyz.squeeze(0)
            elif xyz.dim() == 5 and xyz.shape[0] == 1 and xyz.shape[1] == 1:  # [1, 1, L, 4, 3]
                all_atoms = xyz.squeeze(0).squeeze(0)
            else:
                print(f"Unable to handle coordinate shape: {xyz.shape}")
                continue
            
            print(f"DEBUG: {label} - reshaped coords to shape: {all_atoms.shape}, seq length: {len(sequence)}")
            
            # Drop the coordinates of the removed His-tag residues from the same ends,
            # so that sequence and backbone stay in register
            all_atoms = all_atoms[n_trim:len(all_atoms) - c_trim]
            if len(all_atoms) > len(sequence):
                print(f"DEBUG: Trimming coordinates from {len(all_atoms)} to {len(sequence)}")
                all_atoms = all_atoms[:len(sequence)]
            elif len(all_atoms) < len(sequence):
                print(f"Warning: coordinates shorter than sequence for {label}: coords={len(all_atoms)}, seq={len(sequence)}")
                continue
            
            # Create coordinate dictionary
            coords_dict = {
                f'N_chain_{chain_id}': all_atoms[:,0,:].tolist(),
                f'CA_chain_{chain_id}': all_atoms[:,1,:].tolist(),
                f'C_chain_{chain_id}': all_atoms[:,2,:].tolist(),
                f'O_chain_{chain_id}': all_atoms[:,3,:].tolist(),
            }
            
            # Determine masking
            if masked.dim() > 0:
                masked_values = masked.tolist()
            else:
                masked_values = [masked.item()]
            
            if 0 in masked_values:
                masked_list = [chain_id]
                visible_list = []
            else:
                masked_list = []
                visible_list = [chain_id]
            
            # Create final structure in format expected by featurize
            converted_item = {
                f'seq_chain_{chain_id}': sequence,
                f'coords_chain_{chain_id}': coords_dict,
                'masked_list': masked_list,
                'visible_list': visible_list,
                'num_of_chains': 1,
                'seq': sequence,
                'name': label
            }
            
            pdb_dict_list.append(converted_item)
            count += 1
            print(f"DEBUG: Successfully converted {label}")
            
            if count >= num_examples:
                break
    
    return pdb_dict_list
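
When a His-tag is removed from the N-terminus, the coordinates of those residues have to be dropped from the same end; otherwise the sequence and the backbone go out of register. The following is a minimal, framework-free sketch of that bookkeeping. The helper name is hypothetical, and plain Python sequences stand in for the coordinate tensor.

```python
HIS_TAG = "HHHHHH"


def trim_his_tags(sequence, coords):
    """Strip terminal His-tags from the sequence and the matching coordinates.

    `coords` holds one entry per residue; both inputs must have equal length.
    """
    if len(sequence) != len(coords):
        raise ValueError("sequence and coords must have the same length")
    start, end = 0, len(sequence)
    # C-terminal tag: shorten from the end
    if sequence.endswith(HIS_TAG):
        end -= len(HIS_TAG)
    # N-terminal tag: advance the start (checked on what remains)
    if sequence[start:end].startswith(HIS_TAG):
        start += len(HIS_TAG)
    # The same slice is applied to both, so residue i keeps its coordinates
    return sequence[start:end], coords[start:end]
```

The same slice bounds work on a `torch.Tensor` of shape `[L, 4, 3]`, since slicing along the first dimension selects residues.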

## 4. Train the Model

Now, let's set up the training parameters and start the finetuning process.

In [None]:
# Define training parameters
class Args:
    def __init__(self):
        self.path_for_training_data = "/tmp/acidophilic_data/acidophilic_dataset_for_aaron_converted"
        self.path_for_outputs = "./notebook_finetuning_%Y%m%d_%H%M%S"
        self.previous_checkpoint = ""  # Path to pretrained weights to finetune from; leave empty to train from scratch
        self.num_epochs = 5  # Set to a small number for demonstration
        self.save_model_every_n_epochs = 1
        self.reload_data_every_n_epochs = 2
        self.num_examples_per_epoch = 50  # Small number for demonstration
        self.batch_size = 1000  # Token (residue) budget per batch, not a number of proteins
        self.max_protein_length = 1000
        self.hidden_dim = 128
        self.num_encoder_layers = 3
        self.num_decoder_layers = 3
        self.num_neighbors = 48
        self.dropout = 0.1
        self.backbone_noise = 0.2
        self.rescut = 3.5
        self.debug = True  # Set to True for faster debugging
        self.gradient_norm = -1.0
        self.mixed_precision = True

args = Args()
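
In the upstream ProteinMPNN training code, `batch_size` appears to be a budget of residues (tokens) per batch rather than a count of proteins, and proteins of similar length are grouped so the padded batch stays within that budget. The sketch below illustrates the idea only. It is not the `StructureLoader` implementation, and the function name is hypothetical.

```python
def group_by_token_budget(lengths, budget):
    """Group indices so the padded batch size (longest length x count) stays within budget."""
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current = [], []
    for idx in order:
        # Lengths are visited in ascending order, so lengths[idx] is the
        # longest sequence the candidate batch would contain.
        if current and lengths[idx] * (len(current) + 1) > budget:
            batches.append(current)
            current = []
        current.append(idx)
    if current:
        batches.append(current)
    return batches
```

In this sketch, a protein longer than the budget still ends up in a batch of its own. With a budget of 1000 and proteins of up to 1000 residues, most batches will hold only a few proteins.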

In [None]:
# Initialize training
import json
import time
import sys
import glob
import shutil
import warnings
import torch.nn as nn
import torch.nn.functional as F

# Set device
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print(f"Using device: {device}")

# Initialize mixed precision if needed
scaler = torch.cuda.amp.GradScaler() if args.mixed_precision else None

# Set up output folder
base_folder = time.strftime(args.path_for_outputs, time.localtime())
if base_folder[-1] != '/':
    base_folder += '/'
if not os.path.exists(base_folder):
    os.makedirs(base_folder)
subfolders = ['model_weights']
for subfolder in subfolders:
    if not os.path.exists(base_folder + subfolder):
        os.makedirs(base_folder + subfolder)

# Set up logging
logfile = base_folder + 'log.txt'
with open(logfile, 'w') as f:
    f.write('Epoch\tTrain\tValidation\n')

# Set up data parameters
data_path = args.path_for_training_data
params = {
    "VAL": f"{data_path}/validation_proteins.txt",
    "TEST": f"{data_path}/test_proteins.txt",
    "DIR": f"{data_path}",
    "DATCUT": "2030-Jan-01",
    "RESCUT": args.rescut,
    "HOMO": 0.70
}

# DataLoader parameters
LOAD_PARAM = {
    'batch_size': 1,
    'shuffle': True,
    'pin_memory': False,
    'num_workers': 2  # Reduced for Jupyter
}

In [None]:
# Build training clusters and create datasets
print("Building training clusters...")
train, valid, test = build_custom_training_clusters(params, args.debug)

print("\nCreating datasets...")
train_set = PDB_dataset(list(train.keys()), custom_loader_pdb, train, params)
valid_set = PDB_dataset(list(valid.keys()), custom_loader_pdb, valid, params)

# Create data loaders
train_loader = torch.utils.data.DataLoader(train_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)
valid_loader = torch.utils.data.DataLoader(valid_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)

print(f"Training set size: {len(train_set)}")
print(f"Validation set size: {len(valid_set)}")

In [None]:
# Create the model
model = ProteinMPNN(node_features=args.hidden_dim, 
                edge_features=args.hidden_dim, 
                hidden_dim=args.hidden_dim, 
                num_encoder_layers=args.num_encoder_layers, 
                num_decoder_layers=args.num_decoder_layers, 
                k_neighbors=args.num_neighbors, 
                dropout=args.dropout, 
                augment_eps=args.backbone_noise)
model.to(device)

# Load pre-trained weights if specified
if args.previous_checkpoint:
    print(f"Loading weights from {args.previous_checkpoint}")
    checkpoint = torch.load(args.previous_checkpoint, map_location=device)
    # Some checkpoints may lack training metadata, so default to 0
    total_step = checkpoint.get('step', 0)
    epoch = checkpoint.get('epoch', 0)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    optimizer = get_std_opt(model.parameters(), args.hidden_dim, total_step)
    if 'optimizer_state_dict' in checkpoint:
        optimizer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
    total_step = 0
    epoch = 0
    optimizer = get_std_opt(model.parameters(), args.hidden_dim, total_step)

print("Model created and initialized.")
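
The `total_step` passed to `get_std_opt` matters because the optimizer's learning rate depends on the step count. The sketch below assumes a Noam-style schedule (linear warmup followed by inverse square-root decay) with a factor of 2 and 4000 warmup steps. These values are assumptions and should be checked against `training/model_utils.py` in your checkout.

```python
def noam_rate(step, model_size=128, factor=2.0, warmup=4000):
    """Noam schedule: linear warmup, then inverse square-root decay.

    `factor` and `warmup` are assumed defaults, not values confirmed by this notebook.
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


# Compare the rate early in warmup, at the peak, and during decay
for step in (1, 1000, 4000, 16000):
    print(step, noam_rate(step))
```

Under these assumptions, starting from `total_step = 0` means the first updates use a very small learning rate, whereas resuming from a checkpoint's step count continues on the decay part of the curve.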

In [None]:
# Load training and validation data
print("Loading training data...")
pdb_dict_train = load_data_from_loader(train_loader, args.max_protein_length, args.num_examples_per_epoch)
print(f"Loaded {len(pdb_dict_train)} training examples")

print("\nLoading validation data...")
pdb_dict_valid = load_data_from_loader(valid_loader, args.max_protein_length, args.num_examples_per_epoch)
print(f"Loaded {len(pdb_dict_valid)} validation examples")

# Create structure datasets
dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
dataset_valid = StructureDataset(pdb_dict_valid, truncate=None, max_length=args.max_protein_length)

# Create structure loaders
loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)
loader_valid = StructureLoader(dataset_valid, batch_size=args.batch_size)

In [None]:
# Training loop
print("Starting training...")
for e in range(args.num_epochs):
    t0 = time.time()
    e = epoch + e
    model.train()
    train_sum, train_weights = 0., 0.
    train_acc = 0.
    
    print(f"Epoch {e+1}/{epoch + args.num_epochs}")
    
    # Training
    for batch_idx, batch in enumerate(loader_train):
        print(f"  Training batch {batch_idx+1}/{len(loader_train)}")
        X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
        optimizer.zero_grad()
        mask_for_loss = mask*chain_M
        
        if args.mixed_precision:
            with torch.cuda.amp.autocast():
                log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
                _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
    
            scaler.scale(loss_av_smoothed).backward()
              
            if args.gradient_norm > 0.0:
                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)

            scaler.step(optimizer)
            scaler.update()
        else:
            log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
            _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
            loss_av_smoothed.backward()

            if args.gradient_norm > 0.0:
                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_norm)

            optimizer.step()
        
        # Calculate loss and accuracy
        loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
        train_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
        train_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
        train_weights += torch.sum(mask_for_loss).cpu().data.numpy()
        
        total_step += 1
    
    # Validation
    print("  Running validation...")
    model.eval()
    with torch.no_grad():
        validation_sum, validation_weights = 0., 0.
        validation_acc = 0.
        for batch_idx, batch in enumerate(loader_valid):
            X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
            log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
            mask_for_loss = mask*chain_M
            loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
            
            validation_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()
            validation_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()
            validation_weights += torch.sum(mask_for_loss).cpu().data.numpy()
    
    # Calculate metrics
    train_loss = train_sum / train_weights if train_weights > 0 else float('inf')
    train_accuracy = train_acc / train_weights if train_weights > 0 else 0
    train_perplexity = np.exp(train_loss)
    validation_loss = validation_sum / validation_weights if validation_weights > 0 else float('inf')
    validation_accuracy = validation_acc / validation_weights if validation_weights > 0 else 0
    validation_perplexity = np.exp(validation_loss)
    
    # Format metrics for printing
    train_perplexity_ = np.format_float_positional(np.float32(train_perplexity), unique=False, precision=3)     
    validation_perplexity_ = np.format_float_positional(np.float32(validation_perplexity), unique=False, precision=3)
    train_accuracy_ = np.format_float_positional(np.float32(train_accuracy), unique=False, precision=3)
    validation_accuracy_ = np.format_float_positional(np.float32(validation_accuracy), unique=False, precision=3)

    # Calculate time and log results
    t1 = time.time()
    dt = np.format_float_positional(np.float32(t1-t0), unique=False, precision=1) 
    
    # Log results
    log_message = f'epoch: {e+1}, step: {total_step}, time: {dt}, train: {train_perplexity_}, valid: {validation_perplexity_}, train_acc: {train_accuracy_}, valid_acc: {validation_accuracy_}'
    with open(logfile, 'a') as f:
        f.write(f'{log_message}\n')
    print(log_message)
    
    # Save checkpoint
    checkpoint_filename_last = base_folder+'model_weights/epoch_last.pt'
    torch.save({
                'epoch': e+1,
                'step': total_step,
                'num_edges' : args.num_neighbors,
                'noise_level': args.backbone_noise,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.optimizer.state_dict(),
                }, checkpoint_filename_last)

    # Save periodic checkpoints
    if (e+1) % args.save_model_every_n_epochs == 0:
        checkpoint_filename = base_folder+f'model_weights/epoch{e+1}_step{total_step}.pt'
        torch.save({
                'epoch': e+1,
                'step': total_step,
                'num_edges' : args.num_neighbors,
                'noise_level': args.backbone_noise, 
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.optimizer.state_dict(),
                }, checkpoint_filename)

    # Reload data periodically if needed
    if (e+1) % args.reload_data_every_n_epochs == 0 and e+1 < epoch + args.num_epochs:
        print("Reloading training data...")
        pdb_dict_train = load_data_from_loader(train_loader, args.max_protein_length, args.num_examples_per_epoch)
        dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=args.max_protein_length)
        loader_train = StructureLoader(dataset_train, batch_size=args.batch_size)

print("\nTraining completed!")
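
The perplexity and accuracy logged each epoch are masked averages: only positions where `mask_for_loss` is 1 contribute. As a small worked sketch with plain Python lists in place of tensors (the function name is hypothetical):

```python
import math


def masked_metrics(nll, correct, mask):
    """Return (perplexity, accuracy) over positions where mask == 1."""
    weight = sum(mask)
    if weight == 0:
        return float("inf"), 0.0
    mean_nll = sum(loss * m for loss, m in zip(nll, mask)) / weight
    accuracy = sum(c * m for c, m in zip(correct, mask)) / weight
    # Perplexity is the exponential of the mean negative log-likelihood
    return math.exp(mean_nll), accuracy
```

The third position in a call such as `masked_metrics([0.0, math.log(4), 5.0], [1, 0, 1], [1, 1, 0])` is masked out, so its large loss does not affect the result.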

## 5. Evaluate the Model

Let's evaluate the trained model on the validation set.

In [None]:
# Load the latest model for evaluation
latest_checkpoint = torch.load(checkpoint_filename_last, map_location=device)
model.load_state_dict(latest_checkpoint['model_state_dict'])
model.eval()

# Run a comprehensive evaluation on the validation set
print("Running comprehensive evaluation on validation set...")
with torch.no_grad():
    all_losses = []
    all_accuracies = []
    all_perplexities = []
    
    for batch_idx, batch in enumerate(loader_valid):
        X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
        log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
        mask_for_loss = mask*chain_M
        loss, loss_av, true_false = loss_nll(S, log_probs, mask_for_loss)
        
        # Calculate metrics per batch
        batch_loss = torch.sum(loss * mask_for_loss).cpu().data.numpy() / torch.sum(mask_for_loss).cpu().data.numpy()
        batch_accuracy = torch.sum(true_false * mask_for_loss).cpu().data.numpy() / torch.sum(mask_for_loss).cpu().data.numpy()
        batch_perplexity = np.exp(batch_loss)
        
        all_losses.append(batch_loss)
        all_accuracies.append(batch_accuracy)
        all_perplexities.append(batch_perplexity)
        
        print(f"Batch {batch_idx}: Loss = {batch_loss:.4f}, Accuracy = {batch_accuracy:.4f}, Perplexity = {batch_perplexity:.4f}")
    
    # Calculate overall metrics
    avg_loss = np.mean(all_losses)
    avg_accuracy = np.mean(all_accuracies)
    avg_perplexity = np.mean(all_perplexities)
    
    print("\nEvaluation Results:")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Average Accuracy: {avg_accuracy:.4f}")
    print(f"Average Perplexity: {avg_perplexity:.4f}")
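
The averages reported here give every batch equal weight, whereas the training loop's log weights every residue equally. When batches differ in size, the two can diverge noticeably. The sketch below contrasts them, with each batch given as a hypothetical `(nll_sum, n_tokens)` pair.

```python
import math


def mean_of_batch_perplexities(batches):
    """Average exp(mean NLL) over batches; each batch is (nll_sum, n_tokens)."""
    return sum(math.exp(s / n) for s, n in batches) / len(batches)


def token_weighted_perplexity(batches):
    """exp of the NLL averaged over all tokens across batches."""
    total_nll = sum(s for s, _ in batches)
    total_tokens = sum(n for _, n in batches)
    return math.exp(total_nll / total_tokens)


# One small, easy batch and one large, hard batch
example = [(10 * math.log(2), 10), (990 * math.log(8), 990)]
print(mean_of_batch_perplexities(example), token_weighted_perplexity(example))
```

When comparing these evaluation numbers against the per-epoch log, the token-weighted figure is the like-for-like quantity.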

## 6. Visualize Training Progress

Let's visualize the training progress using the log file.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import re

# Parse log file
epochs = []
train_perplexities = []
valid_perplexities = []
train_accuracies = []
valid_accuracies = []

with open(logfile, 'r') as f:
    lines = f.readlines()
    for line in lines[1:]:  # Skip header
        match = re.match(r'epoch: (\d+), step: (\d+), time: ([\d.]+), train: ([\d.]+), valid: ([\d.]+), train_acc: ([\d.]+), valid_acc: ([\d.]+)', line)
        if match:
            # Use a separate name so the training-state variable `epoch` is not overwritten
            epoch_str, _, _, train_perp, valid_perp, train_acc, valid_acc = match.groups()
            epochs.append(int(epoch_str))
            train_perplexities.append(float(train_perp))
            valid_perplexities.append(float(valid_perp))
            train_accuracies.append(float(train_acc))
            valid_accuracies.append(float(valid_acc))

# Plot training progress
plt.figure(figsize=(12, 10))

# Plot perplexity
plt.subplot(2, 1, 1)
plt.plot(epochs, train_perplexities, 'b-', label='Training Perplexity')
plt.plot(epochs, valid_perplexities, 'r-', label='Validation Perplexity')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.title('Training and Validation Perplexity')
plt.legend()
plt.grid(True)

# Plot accuracy
plt.subplot(2, 1, 2)
plt.plot(epochs, train_accuracies, 'b-', label='Training Accuracy')
plt.plot(epochs, valid_accuracies, 'r-', label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig(base_folder + 'training_progress.png')
plt.show()
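
The log lines have a fixed `key: value` layout, so they can also be parsed into dictionaries with named groups, which avoids positional unpacking. This is a minimal sketch; the helper name is hypothetical, and the pattern mirrors the log format written by the training loop in this notebook.

```python
import re

LOG_PATTERN = re.compile(
    r"epoch: (?P<epoch>\d+), step: (?P<step>\d+), time: (?P<time>[\d.]+), "
    r"train: (?P<train>[\d.]+), valid: (?P<valid>[\d.]+), "
    r"train_acc: (?P<train_acc>[\d.]+), valid_acc: (?P<valid_acc>[\d.]+)"
)


def parse_log_line(line):
    """Return the metrics in one log line as a dict, or None if it does not match."""
    match = LOG_PATTERN.match(line)
    if match is None:
        return None
    return {
        key: int(value) if key in ("epoch", "step") else float(value)
        for key, value in match.groupdict().items()
    }
```

Lines that do not follow the layout, such as the tab-separated header or an epoch whose perplexity was logged as `inf`, return `None` and can be skipped.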

## 7. Save Model for Later Use

Let's save the final model in a format that can be easily loaded for inference.

In [None]:
# Save the final model
final_model_path = base_folder + 'model_weights/final_model.pt'
torch.save({
    'epoch': epoch + args.num_epochs,
    'step': total_step,
    'num_edges': args.num_neighbors,
    'noise_level': args.backbone_noise,
    'model_state_dict': model.state_dict(),
}, final_model_path)

print(f"Final model saved to: {final_model_path}")

## 8. Using the Trained Model for Inference

Here's a simple example of how to load the trained model for inference.

In [None]:
# Load the trained model
from training.model_utils import ProteinMPNN

# Create a new model with the same architecture
inference_model = ProteinMPNN(node_features=args.hidden_dim, 
                          edge_features=args.hidden_dim, 
                          hidden_dim=args.hidden_dim, 
                          num_encoder_layers=args.num_encoder_layers, 
                          num_decoder_layers=args.num_decoder_layers, 
                          k_neighbors=args.num_neighbors, 
                          dropout=0.0,  # No dropout during inference 
                          augment_eps=0.0)  # No noise during inference
inference_model.to(device)

# Load the weights from the trained model
checkpoint = torch.load(final_model_path, map_location=device)
inference_model.load_state_dict(checkpoint['model_state_dict'])
inference_model.eval()

print("Model loaded for inference.")

# You can now use the model for inference with your custom inputs
# Example: inference_model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
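
The forward pass returns per-position log-probabilities over the residue alphabet. One simple way to read them is to take the most likely residue at each position. The sketch below assumes the 21-letter alphabet order `ACDEFGHIKLMNPQRSTVWYX` and uses nested Python lists in place of tensors; the function name is hypothetical.

```python
ALPHABET = "ACDEFGHIKLMNPQRSTVWYX"  # Assumed residue order; 'X' stands for unknown


def decode_argmax(log_probs, mask=None):
    """Turn per-position log-probabilities (L x 21 nested lists) into a sequence string."""
    residues = []
    for pos, row in enumerate(log_probs):
        if mask is not None and not mask[pos]:
            residues.append("-")  # Padded or missing position
            continue
        best = max(range(len(row)), key=lambda k: row[k])
        residues.append(ALPHABET[best])
    return "".join(residues)
```

With model outputs, this could be called as `decode_argmax(log_probs[0].tolist(), mask[0].tolist())`. Because the forward pass is conditioned on the input sequence `S`, the result shows which residues the model prefers given the native context. It is not an independently sampled design.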

## 9. Conclusion

In this notebook, we've demonstrated how to finetune ProteinMPNN using the acidophilic dataset. We covered:

1. Setting up the environment
2. Loading and preprocessing the acidophilic dataset
3. Implementing custom functions for the dataset
4. Training the model
5. Evaluating the model performance
6. Visualizing the training progress
7. Saving and loading the model for inference

This finetuned model can now be used for protein sequence design tasks specifically tailored to acidophilic proteins.