In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colors
from tqdm import tqdm
import glob
import re
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import GridSearchCV

import seaborn as sns
import tensorflow as tf
import keras
from keras.layers import Input, Dense, LeakyReLU
from keras.models import Model, Sequential, load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping

import os
import gzip
import sys

import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

2024-02-04 16:32:27.723570: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-04 16:32:27.723604: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-04 16:32:27.723634: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-02-04 16:32:27.731457: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# LCC Data Import

In [2]:
# Import lcc data files for the wild-type protein into wt_dict, keyed by window size.
lccdata_folder = 'lccdata_files'

wt_files = glob.glob(f'{lccdata_folder}/w*.lccdata')
# Sort on the trailing window number of the file name (wildtype_12.lccdata -> 12) rather than on
# the first digit run of the whole path, so digits in the folder name cannot break the order.
wt_files.sort(key=lambda path: int(re.search(r'(\d+)\.lccdata$', path).group(1)))

window_range = list(range(2, 51))
wt_dict = {}
for file in wt_files:
    # Key each frame by the window size parsed from its own filename. Pairing files with
    # window_range by position would silently shift every later window if one file were missing.
    window = int(re.search(r'(\d+)\.lccdata$', file).group(1))
    if window in window_range:
        # Column 0 is dropped; presumably a frame/time index rather than a position — TODO confirm.
        wt_dict[window] = pd.DataFrame(np.loadtxt(file)).drop(columns=0)

missing_wt_windows = sorted(set(window_range) - set(wt_dict))
if missing_wt_windows:
    print(f"Warning: no wild-type lccdata file for window sizes {missing_wt_windows}")

print(wt_files)

['lccdata_files/wildtype_2.lccdata', 'lccdata_files/wildtype_3.lccdata', 'lccdata_files/wildtype_4.lccdata', 'lccdata_files/wildtype_5.lccdata', 'lccdata_files/wildtype_6.lccdata', 'lccdata_files/wildtype_7.lccdata', 'lccdata_files/wildtype_8.lccdata', 'lccdata_files/wildtype_9.lccdata', 'lccdata_files/wildtype_10.lccdata', 'lccdata_files/wildtype_11.lccdata', 'lccdata_files/wildtype_12.lccdata', 'lccdata_files/wildtype_13.lccdata', 'lccdata_files/wildtype_14.lccdata', 'lccdata_files/wildtype_15.lccdata', 'lccdata_files/wildtype_16.lccdata', 'lccdata_files/wildtype_17.lccdata', 'lccdata_files/wildtype_18.lccdata', 'lccdata_files/wildtype_19.lccdata', 'lccdata_files/wildtype_20.lccdata', 'lccdata_files/wildtype_21.lccdata', 'lccdata_files/wildtype_22.lccdata', 'lccdata_files/wildtype_23.lccdata', 'lccdata_files/wildtype_24.lccdata', 'lccdata_files/wildtype_25.lccdata', 'lccdata_files/wildtype_26.lccdata', 'lccdata_files/wildtype_27.lccdata', 'lccdata_files/wildtype_28.lccdata', 'lccdata

In [3]:
# Import lcc data files for the D132H mutant protein into D132H_dict, keyed by window size.
m_files = glob.glob(f'{lccdata_folder}/m*.lccdata')
# Sort on the trailing window number (myc_091-160_D132-H_12.lccdata -> 12).
m_files.sort(key=lambda path: int(re.search(r'(\d+)\.lccdata$', path).group(1)))

# Same window sizes as the wild-type import; redefined so this cell is self-describing.
window_range = list(range(2, 51))
D132H_dict = {}
for file in m_files:
    # Key by the window size in the filename instead of by list position, so a missing file
    # cannot silently shift the labels of all following windows.
    window = int(re.search(r'(\d+)\.lccdata$', file).group(1))
    if window in window_range:
        # Column 0 is dropped, matching the wild-type import.
        D132H_dict[window] = pd.DataFrame(np.loadtxt(file)).drop(columns=0)

missing_mutant_windows = sorted(set(window_range) - set(D132H_dict))
if missing_mutant_windows:
    print(f"Warning: no mutant lccdata file for window sizes {missing_mutant_windows}")

print(m_files)

['lccdata_files/myc_091-160_D132-H_2.lccdata', 'lccdata_files/myc_091-160_D132-H_3.lccdata', 'lccdata_files/myc_091-160_D132-H_4.lccdata', 'lccdata_files/myc_091-160_D132-H_5.lccdata', 'lccdata_files/myc_091-160_D132-H_6.lccdata', 'lccdata_files/myc_091-160_D132-H_7.lccdata', 'lccdata_files/myc_091-160_D132-H_8.lccdata', 'lccdata_files/myc_091-160_D132-H_9.lccdata', 'lccdata_files/myc_091-160_D132-H_10.lccdata', 'lccdata_files/myc_091-160_D132-H_11.lccdata', 'lccdata_files/myc_091-160_D132-H_12.lccdata', 'lccdata_files/myc_091-160_D132-H_13.lccdata', 'lccdata_files/myc_091-160_D132-H_14.lccdata', 'lccdata_files/myc_091-160_D132-H_15.lccdata', 'lccdata_files/myc_091-160_D132-H_16.lccdata', 'lccdata_files/myc_091-160_D132-H_17.lccdata', 'lccdata_files/myc_091-160_D132-H_18.lccdata', 'lccdata_files/myc_091-160_D132-H_19.lccdata', 'lccdata_files/myc_091-160_D132-H_20.lccdata', 'lccdata_files/myc_091-160_D132-H_21.lccdata', 'lccdata_files/myc_091-160_D132-H_22.lccdata', 'lccdata_files/myc_0

# 1.1 RF Model and Feature Importance

In [4]:
# Build and train random forest classifier with selected hyperparameters + extract feature importance

def RFC(window, wt, mutant, protein_length=70, residue_offset=90, random_state=None):
    """Train a random forest separating wild-type from mutant LCC samples for one window size.

    Rows of ``wt`` are labelled 0 and rows of ``mutant`` 1. A stratified 20% of the pooled rows is
    held out for validation. Per-feature importances are mapped to residue positions and written to
    ``position_importance_folder/pos_imp<window>.csv``.

    Parameters
    ----------
    window : int
        Window size the LCC data was computed with; used to map feature columns to positions.
    wt, mutant : pandas.DataFrame
        Feature matrices (one row per sample). Each must have exactly ``protein_length - window``
        columns.
    protein_length : int, default 70
        Length of the analysed fragment.
    residue_offset : int, default 90
        Added to window-local positions to obtain residue numbers (fragment 91-160 by default).
    random_state : int or None, default None
        Seed forwarded to ``train_test_split`` and ``RandomForestClassifier``. ``None`` keeps the
        previous non-deterministic behaviour; pass an int for reproducible results.

    Returns
    -------
    accuracy : float
        Validation-set accuracy.
    confusion : numpy.ndarray
        Confusion matrix on the validation set.
    rf_pos_importance : pandas.DataFrame
        Columns ``'Position'`` and ``'Importance:'``, one row per feature column in column order
        (default RangeIndex). Feature ``i`` (0-based) gets
        ``Position = residue_offset + 1 + window / 2 + i``.

    Raises
    ------
    ValueError
        If ``wt`` or ``mutant`` does not have ``protein_length - window`` columns. This is checked
        before the (slow) model fit.
    """
    n_positions = protein_length - window
    if wt.shape[1] != n_positions or mutant.shape[1] != n_positions:
        raise ValueError(
            f"window {window}: expected {n_positions} feature columns (protein_length - window), "
            f"got wt={wt.shape[1]}, mutant={mutant.shape[1]}"
        )

    # Label wild-type rows 0 and mutant rows 1, then pool them.
    X_train_full = pd.concat([wt, mutant])
    y_train_full = np.concatenate((np.zeros(len(wt)), np.ones(len(mutant))))

    # Stratified split keeps the class ratio identical in the training and validation sets.
    X_train, X_valid, y_train, y_valid = train_test_split(
        X_train_full, y_train_full, stratify=y_train_full, test_size=0.2, random_state=random_state
    )

    rnd_clf = RandomForestClassifier(
        n_estimators=400, max_leaf_nodes=32, n_jobs=-1, random_state=random_state
    )
    rnd_clf.fit(X_train, y_train)

    y_pred_rf = rnd_clf.predict(X_valid)
    accuracy = accuracy_score(y_valid, y_pred_rf)
    confusion = confusion_matrix(y_valid, y_pred_rf)

    # Same values as np.arange(1 + window/2, protein_length + 1 - window + window/2) + residue_offset,
    # but built from an integer range so the length is exactly n_positions.
    positions = np.arange(n_positions) + 1 + window / 2 + residue_offset

    rf_pos_importance = pd.DataFrame(
        {'Position': positions, 'Importance:': rnd_clf.feature_importances_}
    )

    directory = 'position_importance_folder'
    os.makedirs(directory, exist_ok=True)
    rf_pos_importance.to_csv(os.path.join(directory, f'pos_imp{window}.csv'))

    return accuracy, confusion, rf_pos_importance

In [7]:
# Fit one random-forest model per window size. Results are stored per window:
#   acc_dict[w] -> validation accuracy
#   conf_mat[w] -> validation confusion matrix
#   rf_dict[w]  -> per-position importance table
acc_dict, rf_dict, conf_mat = {}, {}, {}
n_sim = 10  # NOTE(review): not referenced in this cell

for window in tqdm(window_range):
    (
        acc_dict[window],
        conf_mat[window],
        rf_dict[window],
    ) = RFC(window, wt_dict[window], D132H_dict[window])

100%|███████████████████████████████████████████| 49/49 [11:50<00:00, 14.51s/it]


In [11]:
# Weight per-window importances by model accuracy and window coverage; derive a global threshold.
def calculate_adjusted_importance(acc_dict, rf_dict, max_positions=68, quantile=0.95, protein_length=70):
    """Add an accuracy- and coverage-weighted importance column per window and compute a threshold.

    For each window ``w`` the importance table gets a column
    ``'Adjusted Importance:' = 'Importance:' * acc_dict[w] * (protein_length - w) / max_positions``.
    The returned threshold is the ``quantile`` of the *unadjusted* ``'Importance:'`` values pooled
    over all windows. The adjusted column is not used to compute it.

    Parameters
    ----------
    acc_dict : dict
        Validation accuracy per window size; must contain every key of ``rf_dict``.
    rf_dict : dict
        Importance table (DataFrame with an ``'Importance:'`` column) per window size. Not modified.
    max_positions : int, default 68
        Normaliser for the number of positions (68 = 70 - 2, the smallest window used here).
    quantile : float, default 0.95
        Quantile of the pooled unadjusted importances used as the global threshold.
    protein_length : int, default 70
        Fragment length; ``protein_length - w`` is taken as the number of positions for window ``w``.

    Returns
    -------
    adjusted_importances : dict
        Copies of the input tables with the extra ``'Adjusted Importance:'`` column.
    global_threshold : float
        The requested quantile of all unadjusted importances.

    Raises
    ------
    ValueError
        If there are no importance values at all (empty ``rf_dict`` or only empty tables).
    KeyError
        If ``acc_dict`` lacks a window present in ``rf_dict``.
    """
    adjusted_importances = {}
    global_importances = []
    for window, data in rf_dict.items():
        num_positions = protein_length - window
        adjustment_factor = acc_dict[window] * (num_positions / max_positions)

        # Work on a copy so the caller's importance tables are left untouched.
        adjusted = data.copy()
        adjusted['Adjusted Importance:'] = adjusted['Importance:'] * adjustment_factor
        adjusted_importances[window] = adjusted

        # The threshold is deliberately based on the unadjusted scores.
        global_importances.extend(data['Importance:'].tolist())

    if not global_importances:
        # np.quantile on an empty list fails with an obscure IndexError; be explicit instead.
        raise ValueError("no importance values available to compute a global threshold")

    global_threshold = np.quantile(global_importances, quantile)

    return adjusted_importances, global_threshold

In [16]:
import os

def filter_lcdata_with_preview(window, adjusted_rf_data, wt, D132H, global_threshold,
                               wt_directory='wt_filtered', d132h_directory='D132H_filtered'):
    '''Keep only the LCC columns whose unadjusted importance reaches the threshold, and save them.

    Row ``i`` of ``adjusted_rf_data`` is taken to describe column ``i`` of ``wt`` and ``D132H``.
    Columns whose ``'Importance:'`` (not ``'Adjusted Importance:'``) is ``>= global_threshold`` are
    kept. Selection is positional, so the importance table's index labels do not matter.

    Parameters
    ----------
    window : int
        Window size; used in messages and output file names.
    adjusted_rf_data : pandas.DataFrame
        Importance table with ``'Position'`` and ``'Importance:'`` columns, one row per feature column.
    wt, D132H : pandas.DataFrame
        Wild-type and mutant LCC data for this window.
    global_threshold : float
        Minimum unadjusted importance for a position to be kept.
    wt_directory, d132h_directory : str
        Output directories (created if missing).

    Returns
    -------
    (filtered_wt, filtered_D132H, index_keep) if at least one position passes, where ``index_keep``
    is the list of kept positional column indices; otherwise ``(None, None, [])`` and no files are
    written.

    Raises
    ------
    ValueError
        If the number of importance rows does not match the column count of ``wt`` or ``D132H``.
    '''
    n_rows = len(adjusted_rf_data)
    if n_rows != wt.shape[1] or n_rows != D132H.shape[1]:
        raise ValueError(
            f"window {window}: {n_rows} importance rows but wt has {wt.shape[1]} and "
            f"D132H has {D132H.shape[1]} columns"
        )

    # Boolean mask over importance rows == over data columns (positional).
    keep_mask = (adjusted_rf_data['Importance:'] >= global_threshold).to_numpy()

    if not keep_mask.any():
        # No positions above the threshold: report it and skip file generation.
        print(f"No positions above the threshold for window size {window}. No file generated.")
        return None, None, []

    # Preview of the positions above the threshold
    print(f"Window Size {window}, Positions above threshold:")
    print(adjusted_rf_data.loc[keep_mask, 'Position'])

    # Positional indices; unlike index labels these are valid for iloc whatever the table's index is.
    index_keep = np.flatnonzero(keep_mask).tolist()

    filtered_wt = wt.iloc[:, index_keep]
    filtered_D132H = D132H.iloc[:, index_keep]

    os.makedirs(wt_directory, exist_ok=True)
    os.makedirs(d132h_directory, exist_ok=True)

    wt_file_path = os.path.join(wt_directory, f"wt_filtered_{window}.lccdata")
    d132h_file_path = os.path.join(d132h_directory, f"D132H_filtered_{window}.lccdata")

    # Tab-separated, with the row index written as the first column.
    filtered_wt.to_csv(wt_file_path, sep='\t', mode='w')
    filtered_D132H.to_csv(d132h_file_path, sep='\t', mode='w')

    return filtered_wt, filtered_D132H, index_keep

In [17]:
# Calculate adjusted importances and the global threshold (default: 95th percentile of raw importances)
adjusted_importances, global_threshold = calculate_adjusted_importance(acc_dict, rf_dict)

# Window sizes 2-50 inclusive: wt_dict / D132H_dict are only filled for range(2, 51), so the
# previous range(2, 52) only ever produced a spurious "not available" line for window 51.
selected_windows = list(range(2, 51))

# Keep every window's result; filt_wt / filt_m / idx alone only hold the last window processed.
filtered_results = {}
for window in selected_windows:
    # Ensure that window is in adjusted_importances and original data dictionaries before proceeding
    if window in adjusted_importances and window in wt_dict and window in D132H_dict:
        filt_wt, filt_m, idx = filter_lcdata_with_preview(
            window, adjusted_importances[window], wt_dict[window], D132H_dict[window], global_threshold
        )
        filtered_results[window] = (filt_wt, filt_m, idx)
    else:
        print(f"Data for window size {window} is not available.")

Window Size 2, Positions above threshold:
27    119.0
39    131.0
Name: Position, dtype: float64
Window Size 3, Positions above threshold:
7      99.5
26    118.5
Name: Position, dtype: float64
Window Size 4, Positions above threshold:
6    99.0
Name: Position, dtype: float64
No positions above the threshold for window size 5. No file generated.
Window Size 6, Positions above threshold:
11    105.0
21    115.0
22    116.0
Name: Position, dtype: float64
Window Size 7, Positions above threshold:
11    105.5
21    115.5
Name: Position, dtype: float64
Window Size 8, Positions above threshold:
20    115.0
Name: Position, dtype: float64
No positions above the threshold for window size 9. No file generated.
Window Size 10, Positions above threshold:
31    127.0
Name: Position, dtype: float64
Window Size 11, Positions above threshold:
7     103.5
17    113.5
30    126.5
Name: Position, dtype: float64
Window Size 12, Positions above threshold:
16    113.0
29    126.0
Name: Position, dtype: floa

In [18]:
def count_positions_from_filtered_files(wt_directory='wt_filtered', d132h_directory='D132H_filtered', window_range=range(2, 52)):
    """Count the saved positions per window from the filtered wild-type files.

    For each window, ``<wt_directory>/wt_filtered_<window>.lccdata`` is opened and its data columns
    are counted (the first column is the index written by ``DataFrame.to_csv``). A missing file
    counts as 0 positions.

    Parameters
    ----------
    wt_directory : str
        Directory holding the filtered wild-type files.
    d132h_directory : str
        Currently unused; kept for interface compatibility. In this notebook the filtering step writes
        the wild-type and D132H files with the same column selection.
    window_range : iterable of int
        Window sizes to look up.

    Returns
    -------
    counts : dict
        Window size -> number of saved positions.
    total_positions : int
        Sum of all counts.
    """
    counts = {}
    total_positions = 0
    for window in window_range:
        wt_file_path = os.path.join(wt_directory, f"wt_filtered_{window}.lccdata")
        try:
            # Only the header row is needed to count columns, so skip parsing the data rows.
            header = pd.read_csv(wt_file_path, sep='\t', nrows=0)
        except FileNotFoundError:
            # If the file doesn't exist, no positions were above threshold for this window
            counts[window] = 0
            continue
        count = len(header.columns) - 1  # exclude the index column
        counts[window] = count
        total_positions += count
    return counts, total_positions


In [19]:
# Summarise how many positions were written per window size (reads the filtered wild-type files).
counts, total_positions_saved = count_positions_from_filtered_files()

# One line per window size, in ascending window order...
for window, count in sorted(counts.items()):
    print(f"Window Size {window}: {count} Positions above threshold")

# ...followed by the grand total.
print(f"\n# Positions saved in total: {total_positions_saved}")


Window Size 2: 2 Positions above threshold
Window Size 3: 2 Positions above threshold
Window Size 4: 1 Positions above threshold
Window Size 5: 0 Positions above threshold
Window Size 6: 3 Positions above threshold
Window Size 7: 2 Positions above threshold
Window Size 8: 1 Positions above threshold
Window Size 9: 0 Positions above threshold
Window Size 10: 1 Positions above threshold
Window Size 11: 3 Positions above threshold
Window Size 12: 2 Positions above threshold
Window Size 13: 2 Positions above threshold
Window Size 14: 3 Positions above threshold
Window Size 15: 4 Positions above threshold
Window Size 16: 4 Positions above threshold
Window Size 17: 3 Positions above threshold
Window Size 18: 2 Positions above threshold
Window Size 19: 0 Positions above threshold
Window Size 20: 2 Positions above threshold
Window Size 21: 3 Positions above threshold
Window Size 22: 2 Positions above threshold
Window Size 23: 2 Positions above threshold
Window Size 24: 1 Positions above thresh