In [None]:
import functools
import os

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForMaskedLM

# ESM++ is a faithful implementation of ESMC (see its license) that supports batching and
# standard Hugging Face compatibility without requiring the ESM Python package.
# The small version corresponds to the 300-million-parameter version of ESMC.
# "This model is an approximate version of the ESM Cambrian 300M model, with fewer parameters,
# making it more suitable for local execution and fine-tuning."
    
def getSequenceData(first_dir, file_name):
    """
    Read protein sequences and their labels from a FASTA-like text file.

    The file ``<first_dir>/<file_name>.txt`` is read line by line:

    - A line starting with ``>`` is a header. Every character after the ``>``
      is converted with ``int``, giving that record's label list.
    - Every other non-blank line is taken as one sequence.
    - Blank lines (e.g. a trailing newline at the end of the file) are skipped.

    Args:
        first_dir: Directory path
        file_name: File name without the ``.txt`` extension

    Returns:
        Tuple[List[str], List[List[int]]]: Lists of sequences and labels, in file order

    Raises:
        ValueError: If a header contains a character that ``int`` cannot parse.
    """
    data, label = [], []
    path = os.path.join(first_dir, f"{file_name}.txt")

    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # Indexing line[0] on an empty line would raise IndexError.
                continue
            if line.startswith('>'):
                # Convert label string to list of integers, e.g. ">101" -> [1, 0, 1]
                label.append([int(char) for char in line[1:]])
            else:
                data.append(line)

    return data, label

_POOLING_TYPES = ('mean', 'max', 'cls')


@functools.lru_cache(maxsize=1)
def _load_esmpp_model(model_name, model_path, device):
    """Load an ESM++ checkpoint from local disk, move it to ``device`` and set eval mode.

    The cache has a single slot. Repeated calls with the same
    (model_name, model_path, device) reuse the loaded model instead of
    re-reading the checkpoint. Requesting a different checkpoint evicts the
    previous one, so the cache keeps at most one model alive. Note that the
    cached model stays in memory after the caller returns.
    """
    print(f"Loading ESMplusplus_{model_name} model...")
    model = AutoModelForMaskedLM.from_pretrained(model_path, trust_remote_code=True, local_files_only=True)
    model = model.to(device)
    model.eval()
    return model


def _masked_pool(last_hidden_states, attention_mask, pooling_type):
    """Pool (batch, seq, dim) token states into (batch, dim), ignoring masked-out positions."""
    if pooling_type == 'cls':
        # Take the first position of every row. This assumes right-side padding,
        # so index 0 is a real (special) token. TODO confirm for the ESM++ tokenizer.
        return last_hidden_states[:, 0, :]

    # Shape (batch, seq, 1) so the mask broadcasts over the hidden dimension.
    mask = attention_mask.unsqueeze(-1).to(torch.bool)
    if pooling_type == 'mean':
        summed = last_hidden_states.masked_fill(~mask, 0).sum(dim=1)
        # clamp guards against division by zero for a fully-masked row.
        counts = mask.sum(dim=1).clamp(min=1)
        return summed / counts

    # 'max': padded positions get the lowest representable value so they never win.
    lowest = torch.finfo(last_hidden_states.dtype).min
    return last_hidden_states.masked_fill(~mask, lowest).max(dim=1).values


def extract_features(model_name, sequences, pooling_type, device, batch_size=16, model_path=None):
    """
    Extract one fixed-size feature vector per protein sequence using an ESMplusplus model.

    Args:
        model_name: Name of the ESMplusplus model ('small' or 'large'). Used for
            the default checkpoint path and log messages.
        sequences: Non-empty list of protein sequences
        pooling_type: Type of pooling ('mean', 'max', or 'cls'). 'mean' and
            'max' only consider positions where the tokenizer's attention mask
            is set. As a result, padding added to shorter sequences in a batch
            does not change their features.
        device: Device to run the model on
        batch_size: Batch size for processing
        model_path: Local checkpoint directory. Defaults to
            ``E:/ESMplusplus_{model_name}``.

    Returns:
        torch.Tensor: CPU tensor with one row per input sequence

    Raises:
        ValueError: If ``pooling_type`` is unsupported or ``sequences`` is empty.
    """
    # Fail fast, before the expensive model load. Otherwise an unknown pooling
    # type would leave the batch features unbound inside the loop.
    if pooling_type not in _POOLING_TYPES:
        raise ValueError(
            f"Unsupported pooling_type {pooling_type!r}; expected one of {_POOLING_TYPES}"
        )
    if not sequences:
        raise ValueError("sequences is empty; nothing to extract")
    if model_path is None:
        model_path = f'E:/ESMplusplus_{model_name}'

    # Cached: repeated calls for the same checkpoint/device do not reload it.
    model = _load_esmpp_model(model_name, model_path, device)
    tokenizer = model.tokenizer

    all_features = []
    for i in tqdm(range(0, len(sequences), batch_size), desc=f"Extracting {pooling_type} features"):
        batch = sequences[i:i + batch_size]

        tokenized = tokenizer(batch, padding=True, return_tensors='pt')
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        attention_mask = tokenized.get('attention_mask')
        if attention_mask is None:
            # No mask returned: treat every position as a real token.
            # Assumes the tokenizer output has 'input_ids'. TODO confirm.
            attention_mask = torch.ones_like(tokenized['input_ids'])

        with torch.no_grad():
            output = model(**tokenized)
            batch_features = _masked_pool(output.last_hidden_state, attention_mask, pooling_type)

        # Move to CPU so GPU memory does not grow with the number of batches.
        all_features.append(batch_features.cpu())

    features_tensor = torch.cat(all_features, dim=0)
    print(f"Extracted features dimension: {features_tensor.shape[1]}")

    return features_tensor

def main():
    """Extract ESM++ features for every (model, pooling) pair and save one tensor per split.

    Sequences are read from the train/val/test splits under ``dataset/pre``.
    Each resulting tensor is saved to
    ``otherfeatures/ESMplusplus_<model>_<pooling>/<split>_ESMplusplus_<model>_<pooling>.pt``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    base_output_dir = 'otherfeatures'
    os.makedirs(base_output_dir, exist_ok=True)

    data_dir = 'dataset/pre'  # Adjust if needed

    # Each entry: split directory/file stem, lowercase noun and capitalized noun for log lines.
    splits = (
        ('train', 'training', 'Training'),
        ('val', 'validation', 'Validation'),
        ('test', 'test', 'Test'),
    )

    print("Loading datasets...")
    sequences_by_split = {
        split: getSequenceData(os.path.join(data_dir, split), split)[0]
        for split, _, _ in splits
    }
    for split, noun, _ in splits:
        print(f"Loaded {len(sequences_by_split[split])} {noun} sequences")

    for model_name in ('small', 'large'):
        for pooling_type in ('mean', 'max', 'cls'):
            feature_name = f'ESMplusplus_{model_name}_{pooling_type}'
            output_dir = os.path.join(base_output_dir, feature_name)
            os.makedirs(output_dir, exist_ok=True)

            print(f"\n=== Processing {feature_name} ===")

            # Extract, report and persist features for each split in turn.
            for split, noun, title in splits:
                print(f"Processing {noun} set...")
                features = extract_features(model_name, sequences_by_split[split], pooling_type, device)
                print(f"{title} features shape: {features.shape} - [samples × feature_dimension]")
                output_path = os.path.join(output_dir, f'{split}_{feature_name}.pt')
                torch.save(features, output_path)
                print(f"Saved {noun} features to {output_path}")

if __name__ == "__main__":
    # Run the extraction pipeline only when executed as a script, not on import.
    main()