In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from transformers import AutoModelForMaskedLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader, TensorDataset
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Configuration: fix the RNG seeds and choose the compute device once,
# so every class below shares the same `SEED` and `device`.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# 1. Correct ESM Implementation for Mutation Scoring
class ESMPredictor:
    """Scores single-point mutations with an ESM-2 masked language model.

    The mutated position is replaced by the tokenizer's mask token and the
    model's predicted distribution at that position is used to compare the
    mutant residue against the wild-type residue.
    """

    def __init__(self):
        """Load the ``facebook/esm2_t6_8M_UR50D`` tokenizer and model onto ``device``."""
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device)
        self.model.eval()  # the model is only used for inference

    def get_mutation_score(self, sequence_wt, mutation):
        """Return P(mutant residue) - P(wild-type residue) at the masked site.

        Args:
            sequence_wt: Wild-type sequence as a string.
            mutation: Mutation code of the form ``<wt><pos><mt>``, e.g.
                ``"A23C"``; ``pos`` is 1-based.

        Returns:
            A Python float; positive when the model gives the mutant residue
            a higher probability than the wild-type residue.

        Raises:
            ValueError: If the position is not an integer or lies outside
                ``1..len(sequence_wt)``.
        """
        wt, pos, mt = mutation[0], int(mutation[1:-1]), mutation[-1]

        # An out-of-range position would still slice "successfully" and yield
        # a malformed masked sequence, so reject it explicitly.
        if not 1 <= pos <= len(sequence_wt):
            raise ValueError(
                f"Mutation {mutation!r}: position {pos} is outside the "
                f"sequence (length {len(sequence_wt)})"
            )

        # Replace the (1-based) mutated position with the mask token.
        masked_seq = sequence_wt[:pos-1] + self.tokenizer.mask_token + sequence_wt[pos:]

        # Tokenize and locate the mask inside the token ids; the tokenizer may
        # add special tokens, so the index is looked up rather than assumed.
        inputs = self.tokenizer(masked_seq, return_tensors="pt").to(device)
        mask_token_index = torch.where(inputs["input_ids"] == self.tokenizer.mask_token_id)[1].item()

        with torch.no_grad():
            logits = self.model(**inputs).logits

        wt_token = self.tokenizer.convert_tokens_to_ids(wt)
        mt_token = self.tokenizer.convert_tokens_to_ids(mt)

        # One softmax over the vocabulary at the masked position serves both lookups.
        probs = F.softmax(logits[0, mask_token_index], dim=-1)
        wt_prob = probs[wt_token].item()
        mt_prob = probs[mt_token].item()

        return mt_prob - wt_prob  # Positive means more likely

# 2. Feature Engineering
class FeatureExtractor:
    """Turns mutation codes into a 4-column numeric feature matrix.

    Columns, in order: relative position in the wild-type sequence, index of
    the wild-type residue and index of the mutant residue (both divided by
    the alphabet size), and the ESM mutation score.
    """

    def __init__(self, sequence_wt):
        """Store the wild-type sequence and create the ESM scorer."""
        self.sequence_wt = sequence_wt
        self.esm_predictor = ESMPredictor()
        self.amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
        self.aa_to_idx = {aa: i for i, aa in enumerate(self.amino_acids)}

    def get_features(self, df):
        """Build the feature matrix for every entry of ``df['mutant']``.

        Args:
            df: Object whose ``'mutant'`` item iterates over mutation codes
                of the form ``<wt><pos><mt>`` (e.g. a pandas DataFrame).

        Returns:
            ``numpy`` array of shape ``(n_mutants, 4)``; ``(0, 4)`` when
            there are no mutants.

        Raises:
            KeyError: If a residue letter is not in ``self.amino_acids``.
            ValueError: If the position part of a code is not an integer.
        """
        import warnings  # local import keeps this class self-contained

        features = []
        failed = []  # mutants whose ESM scoring raised
        for mut in df['mutant']:
            wt, pos, mt = mut[0], int(mut[1:-1]), mut[-1]

            # Positional features
            pos_features = [
                pos / len(self.sequence_wt),
                self.aa_to_idx[wt] / len(self.amino_acids),
                self.aa_to_idx[mt] / len(self.amino_acids)
            ]

            # ESM score; a scoring failure falls back to 0 so one bad mutant
            # does not abort the run. Only `Exception` is caught so that
            # KeyboardInterrupt / SystemExit still stop the loop.
            try:
                esm_score = self.esm_predictor.get_mutation_score(self.sequence_wt, mut)
            except Exception as exc:
                failed.append((mut, exc))
                esm_score = 0  # Fallback if scoring fails

            features.append(pos_features + [esm_score])

        if failed:
            first_mut, first_exc = failed[0]
            warnings.warn(
                f"ESM scoring failed for {len(failed)} mutant(s); score set to 0. "
                f"First failure: {first_mut!r}: {first_exc!r}",
                RuntimeWarning,
            )

        if not features:
            # Keep the result two-dimensional even when there is nothing to score.
            return np.empty((0, 4))
        return np.array(features)

# 3. Model Architecture
class FitnessPredictor(nn.Module):
    """Small MLP regressor: two hidden layers (256, 128) with batch norm.

    Args:
        input_size: Number of input features per sample.
    """

    def __init__(self, input_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(256, 128)
        self.bn2 = nn.BatchNorm1d(128)
        self.fc3 = nn.Linear(128, 1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch,)`` predictions."""
        x = F.leaky_relu(self.bn1(self.fc1(x)), 0.1)
        x = self.dropout(F.leaky_relu(self.bn2(self.fc2(x)), 0.1))
        # Drop only the trailing feature dimension. A bare squeeze() would
        # also remove the batch dimension when the batch holds one sample,
        # returning a 0-d tensor that cannot be iterated or zipped.
        return self.fc3(x).squeeze(-1)

# 4. Active Learning Pipeline
class ActiveLearningPipeline:
    """One round of active learning: load data, fit a model, pick new queries.

    Files read from the working directory: ``train.csv``, ``test.csv`` and,
    if present, ``queried_mutants.txt``. Files written: ``best_model_b.pth``,
    ``query.txt`` and ``queried_mutants.txt``.
    """

    def __init__(self, sequence_wt):
        """Store the wild-type sequence and build the feature extractor."""
        self.sequence_wt = sequence_wt
        self.feature_extractor = FeatureExtractor(sequence_wt)
        self.queried_mutants = set()

    def load_data(self):
        """Read the train/test tables and the set of already queried mutants.

        Both tables get a ``sequence`` column derived from their ``mutant``
        column via ``get_mutated_sequence``.
        """
        self.train_df = pd.read_csv('train.csv')
        self.train_df['sequence'] = self.train_df['mutant'].apply(
            lambda x: get_mutated_sequence(x, self.sequence_wt))

        self.test_df = pd.read_csv('test.csv')
        self.test_df['sequence'] = self.test_df['mutant'].apply(
            lambda x: get_mutated_sequence(x, self.sequence_wt))

        if os.path.exists('queried_mutants.txt'):
            with open('queried_mutants.txt', 'r') as f:
                # Skip blank lines so the empty string never counts as a mutant.
                self.queried_mutants.update(
                    name for name in (line.strip() for line in f) if name)

    def train_model(self):
        """Fit a ``FitnessPredictor`` on ``train_df`` and return the best epoch.

        Holds out 20% of the rows for validation, standardises the features
        (the fitted scaler is kept in ``self.scaler``), trains for 100 epochs
        and checkpoints the weights with the highest validation Spearman
        correlation to ``best_model_b.pth``.

        Returns:
            The model, in eval mode, loaded with the checkpointed weights.
        """
        X = self.feature_extractor.get_features(self.train_df)
        y = self.train_df['DMS_score'].values

        # Train/validation split
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=0.2, random_state=SEED)

        # Feature scaling (fitted on the training part only)
        self.scaler = StandardScaler()
        X_train = self.scaler.fit_transform(X_train)
        X_val = self.scaler.transform(X_val)

        # Initialize model
        model = FitnessPredictor(X_train.shape[1]).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        criterion = nn.HuberLoss()

        # The validation inputs never change, so build the tensor once.
        val_inputs = torch.FloatTensor(X_val).to(device)

        # Training loop
        best_corr = -1
        checkpoint_written = False
        for epoch in range(100):
            model.train()
            for i in range(0, len(X_train), 32):
                rows_X = X_train[i:i+32]
                if len(rows_X) < 2:
                    # BatchNorm1d cannot compute batch statistics from a
                    # single sample in training mode and raises ValueError,
                    # so a trailing one-row batch is skipped.
                    continue
                batch_X = torch.FloatTensor(rows_X).to(device)
                batch_y = torch.FloatTensor(y_train[i:i+32]).to(device)

                optimizer.zero_grad()
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            # Validation
            model.eval()
            with torch.no_grad():
                # reshape(-1): a one-row validation set may come back 0-d.
                val_preds = model(val_inputs).cpu().numpy().reshape(-1)
            corr = spearmanr(y_val, val_preds)[0]

            improved = corr > best_corr  # False when corr is NaN
            if improved:
                best_corr = corr
            if improved or not checkpoint_written:
                # Always checkpoint at least once per run; otherwise a run
                # with no comparable correlation would load a stale file
                # from an earlier run (or fail because none exists).
                torch.save(model.state_dict(), 'best_model_b.pth')
                checkpoint_written = True

            print(f"Epoch {epoch+1}: Val Corr = {corr:.4f}, Best = {best_corr:.4f}")

        model.load_state_dict(torch.load('best_model_b.pth'))
        return model

    def _score_test_set(self, model):
        """Return ``(scaled test features, 1-D predictions)`` for ``test_df``."""
        X_test = self.scaler.transform(
            self.feature_extractor.get_features(self.test_df))

        model.eval()  # predictions must not depend on dropout / batch statistics
        with torch.no_grad():
            test_preds = model(torch.FloatTensor(X_test).to(device)).cpu().numpy()

        # Flatten so a single test row (possibly a 0-d result) is still iterable.
        return X_test, np.reshape(test_preds, -1)

    def select_queries(self, model, n_queries=100):
        """Pick the ``n_queries`` best-scoring test mutants not yet queried.

        Each candidate is scored as ``0.7 * prediction + 0.3 * esm`` where
        ``esm`` is the last feature column after scaling. The selection is
        written to ``query.txt`` and the updated set of queried mutants to
        ``queried_mutants.txt`` (sorted, one per line).

        Args:
            model: Trained predictor; called on the scaled test features.
            n_queries: Maximum number of mutants to select.

        Returns:
            List of selected mutant codes, best first.
        """
        X_test, test_preds = self._score_test_set(model)

        candidates = []
        for i, (mut, pred) in enumerate(zip(self.test_df['mutant'], test_preds)):
            if mut not in self.queried_mutants:
                esm_score = X_test[i, -1]  # ESM score is the last feature
                combined_score = 0.7 * pred + 0.3 * esm_score
                candidates.append((mut, combined_score))

        # Select top candidates
        candidates.sort(key=lambda x: x[1], reverse=True)
        selected = [mut for mut, _ in candidates[:n_queries]]

        # Update queried mutants
        self.queried_mutants.update(selected)
        with open('query.txt', 'w') as f:
            for mut in selected:
                f.write(f"{mut}\n")
        with open('queried_mutants.txt', 'w') as f:
            # Sorted so the file is identical from run to run.
            for mut in sorted(self.queried_mutants):
                f.write(f"{mut}\n")

        return selected

    def run_cycle(self):
        """Run one full cycle: load data, train, then select and save queries."""
        print('Loading data...')
        self.load_data()
        print('Training Model...')
        model = self.train_model()
        print('Selecting Queries...')
        self.select_queries(model)

        # Test-set predictions are not exported here. Scoring the test set a
        # second time only to discard the result repeated the whole ESM
        # feature pass; call `_score_test_set(model)` if a predictions file
        # is needed.

# Helper function
def get_mutated_sequence(mut, sequence_wt):
    """Apply a single substitution to the wild-type sequence.

    Args:
        mut: Mutation code ``<wt><pos><mt>`` (e.g. ``"A23C"``); ``pos`` is
            1-based. The wild-type letter is parsed but not checked.
        sequence_wt: Wild-type sequence string.

    Returns:
        ``sequence_wt`` with the residue at ``pos`` replaced by ``mt``.
    """
    original_residue = mut[0]  # read for parsing only; not used below
    position = int(mut[1:-1])
    replacement = mut[-1]

    before = sequence_wt[:position - 1]
    after = sequence_wt[position:]
    return before + replacement + after

# Main execution
def read_wildtype_sequence(fasta_path='sequence.fasta'):
    """Return the sequence of the first record in a FASTA file.

    All sequence lines of the first record are joined, so a sequence that is
    wrapped over several lines is read in full rather than cut off after its
    first line.

    Args:
        fasta_path: Path to the FASTA file.

    Returns:
        The sequence as a single string without whitespace.

    Raises:
        ValueError: If the file contains no sequence lines.
    """
    residues = []
    with open(fasta_path, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith('>'):
                if residues:
                    break  # header of a second record: stop after the first
                continue
            residues.append(line)

    if not residues:
        raise ValueError(f"No sequence data found in {fasta_path!r}")
    return ''.join(residues)


if __name__ == "__main__":
    sequence_wt = read_wildtype_sequence('sequence.fasta')

    pipeline = ActiveLearningPipeline(sequence_wt)
    pipeline.run_cycle()

cuda
Loading data...
Training Model...
Epoch 1: Val Corr = 0.1257, Best = 0.1257
Epoch 2: Val Corr = 0.1777, Best = 0.1777
Epoch 3: Val Corr = 0.1150, Best = 0.1777
Epoch 4: Val Corr = 0.2927, Best = 0.2927
Epoch 5: Val Corr = 0.1896, Best = 0.2927
Epoch 6: Val Corr = 0.2973, Best = 0.2973
Epoch 7: Val Corr = 0.2901, Best = 0.2973
Epoch 8: Val Corr = 0.2985, Best = 0.2985
Epoch 9: Val Corr = 0.3034, Best = 0.3034
Epoch 10: Val Corr = 0.2999, Best = 0.3034
Epoch 11: Val Corr = 0.3060, Best = 0.3060
Epoch 12: Val Corr = 0.2956, Best = 0.3060
Epoch 13: Val Corr = 0.2749, Best = 0.3060
Epoch 14: Val Corr = 0.3194, Best = 0.3194
Epoch 15: Val Corr = 0.2822, Best = 0.3194
Epoch 16: Val Corr = 0.2771, Best = 0.3194
Epoch 17: Val Corr = 0.3101, Best = 0.3194
Epoch 18: Val Corr = 0.2842, Best = 0.3194
Epoch 19: Val Corr = 0.2900, Best = 0.3194
Epoch 20: Val Corr = 0.2963, Best = 0.3194
Epoch 21: Val Corr = 0.3291, Best = 0.3291
Epoch 22: Val Corr = 0.3084, Best = 0.3291
Epoch 23: Val Corr = 0.2