In [None]:
import pandas as pd
import numpy as np
import subprocess
from sklearn.manifold import TSNE
from Bio.Cluster import kcluster
import matplotlib.pyplot as plt
from collections import Counter
import glob
import os

## Prepare for analysis

In [None]:
# Run identifier; it is interpolated into the output/<run>/... paths built in this notebook.
# NOTE: the name `input` shadows Python's built-in input() for the rest of the session.
input="TEVp-240412"
# Read all.csv from this run's opt_binders folder.
input_dataframe=pd.read_csv(f"output/{input}/opt_binders/all.csv")
# Bare last expression: the loaded table is rendered as the cell output.
input_dataframe

In [None]:
# Functions
# add scaffold name column
def add_scaffold_name_column(filtered, prefix):
    """Add a ``scaffold_name`` column derived from each row's ``model_path``.

    For every path, the file name (text after the last "/") is cut at the last
    occurrence of ``prefix`` and the remainder is split on "_". With five or
    more pieces the scaffold name is the first two pieces joined by "_";
    otherwise it is the first piece with everything from the first "." removed.

    Args:
        filtered: DataFrame with a string ``model_path`` column. The column is
            added to this object in place.
        prefix: Text that precedes the scaffold name inside the file name. An
            empty prefix leaves the file name uncut.

    Returns:
        The same DataFrame object, now carrying ``scaffold_name``.
    """
    def _scaffold_from_path(path):
        # Derive the scaffold name for one model path.
        file_name = path.split("/")[-1]
        # str.split("") raises ValueError, so only cut when a prefix is given.
        tail = file_name.split(prefix)[-1] if prefix else file_name
        parts = tail.split("_")
        if len(parts) >= 5:
            return f"{parts[0]}_{parts[1]}"
        return parts[0].split(".")[0]

    names = [_scaffold_from_path(path) for path in filtered["model_path"]]
    # Assign the whole column at once on the frame's own index: a label-based
    # per-row write would hit every row sharing a duplicated index label.
    filtered["scaffold_name"] = pd.Series(names, index=filtered.index, dtype=object)
    return filtered

def repeat_rows_by_column_value(df, column_name, number):
    """Keep at most ``number`` rows for every distinct value of ``column_name``.

    Despite the name, no row is duplicated: for each distinct value (in order
    of first appearance) the first ``number`` matching rows are taken, and the
    per-value selections are stacked one after another.

    Args:
        df: Source DataFrame; it is not modified.
        column_name: Column whose values define the groups. Rows whose value
            is missing (NaN) are not included in the result.
        number: Maximum number of rows kept per value; values <= 0 keep none.

    Returns:
        A new DataFrame with the original columns, dtypes and index labels of
        the selected rows. An empty selection still carries all columns.
    """
    limit = max(number, 0)
    # sort=False keeps groups in order of first appearance; slicing whole
    # sub-frames (instead of rebuilding from row Series) preserves dtypes.
    pieces = [group.iloc[:limit] for _, group in df.groupby(column_name, sort=False)]

    if not pieces:
        # pd.concat([]) raises; return an empty frame that keeps the schema.
        return df.iloc[0:0].copy()

    return pd.concat(pieces)

#best_binders=add_scaffold_name_column(best_binders, input+"_")

## Filter dataframe

In [None]:
#filtered = input_dataframe[(input_dataframe["rmsd"]<1)&(input_dataframe["plddt"]>0.85)&input_dataframe["i_pae"]<6]
# Quality thresholds, each expressed as its own boolean mask.
passes_plddt = input_dataframe["plddt"] > 0.90
passes_i_pae = input_dataframe["i_pae"] < 8
passes_rmsd = input_dataframe["rmsd"] < 1.5
filtered = input_dataframe[passes_plddt & passes_i_pae & passes_rmsd]

# Highest plddt first, so drop_duplicates keeps the best row per model_path.
filtered = filtered.sort_values(by='plddt', ascending=False)
filtered = filtered.drop_duplicates("model_path")
# reset_index() without drop keeps the previous row labels as an "index" column.
filtered = filtered.reset_index()
filtered = add_scaffold_name_column(filtered, input + "_")
filtered

In [None]:
# Tally how many filtered designs belong to each scaffold
scaffold_counts = filtered["scaffold_name"].value_counts()
total_scaffold_instances = scaffold_counts.sum()
total_unique_scaffolds = scaffold_counts.shape[0]

print("Total unique scaffolds:", total_unique_scaffolds)
print("Total scaffold instances:", total_scaffold_instances)
# Header and table in one call; sep="\n" puts the table on the line below the header.
print("\nScaffold counts:", scaffold_counts, sep="\n")

In [None]:
# Upper bound on the number of rows kept for any single scaffold_name value.
designs_per_scaffold = 200

# Rebinds `filtered` to the capped selection; re-running this cell therefore
# operates on the already-capped frame rather than on the filter output.
filtered=repeat_rows_by_column_value(filtered, "scaffold_name", designs_per_scaffold)
#folder=f"output/{input}/opt_binders/test"
#os.makedirs(folder, exist_ok=True)
#for path in filtered["model_path"]:
#    !cp $path $folder


In [None]:
# good_scaffolds = ["2_","55_","61_","54_","24_","13_","11_"]

# # Filter the DataFrame
# #filtered = filtered[filtered['scaffold_name'].isin(good_scaffolds)]
# filtered = filtered[~filtered['scaffold_name'].isin(good_scaffolds)]
# filtered
#for  scaff in good_scaffolds:
#    filtered=filtered[scaff]

# bad_scaffolds = ["model_30","model_42", "model_1-lcb3"]

# Filter the DataFrame
# good = filtered[~filtered['scaffold_name'].isin(bad_scaffolds)]
# good

In [None]:
# Keep only the text after the last "/" of each `seq` entry
# (presumably the binder chain of a multi-chain sequence - confirm).
filtered["seq_split"] = filtered["seq"].map(lambda chains: chains.rsplit("/", 1)[-1])
#filtered["seq_split"].to_list()

## Testing cluster sequences

In [None]:
# Plain Python list of the per-row `seq_split` strings, in row order.
seqs=filtered["seq_split"].to_list()

In [None]:
def length_statistics(input_list):
    """Count how many items of ``input_list`` have each length.

    Args:
        input_list: Iterable of sized items (anything ``len`` accepts).

    Returns:
        A plain dict mapping length -> number of items with that length,
        with keys in order of first occurrence.
    """
    # Counter keeps first-seen key order; dict() returns the plain-dict type.
    return dict(Counter(map(len, input_list)))

# Map of sequence length -> number of sequences with that length.
seqs_len=length_statistics(seqs)
print(seqs_len)

In [None]:
num_clusters=25

seqs=filtered["seq_split"].to_list()
#matrix = np.asarray([np.frombuffer(seq.encode(), dtype=np.uint8) for seq in seqs])
# Right-pad every sequence with 'N' to the longest length so all rows have equal width.
max_length = max(len(seq) for seq in seqs)
padded_seqs = [seq.ljust(max_length, 'N') for seq in seqs]
# One row per sequence, one column per position, holding the character's byte code.
# NOTE(review): distances are therefore computed on raw character codes -
# confirm this is intended rather than e.g. a one-hot encoding.
matrix = np.asarray([np.frombuffer(seq.encode(), dtype=np.uint8) for seq in padded_seqs])
# NOTE(review): no seed is passed here, so cluster ids may differ between runs - confirm.
clusterid, error, nfound = kcluster(matrix, nclusters=num_clusters)

# Apply t-SNE to the matrix to reduce the dimensionality and visualize the sequences.
tsne = TSNE(n_components=2, random_state=42)
embedded_matrix = tsne.fit_transform(matrix)

# Create a scatter plot of the embedded points and label them with cluster IDs.
plt.figure(figsize=(10, 6))
for cluster in range(num_clusters):
    cluster_points = embedded_matrix[clusterid == cluster]
    plt.scatter(cluster_points[:, 0], cluster_points[:, 1], label=f"Cluster {cluster}")

plt.title(f"t-SNE Visualization of {input} best protein sequences")
plt.xlabel("t-SNE Dimension 1")
plt.ylabel("t-SNE Dimension 2")
plt.legend()
#plt.savefig(f"output/{input}/filtered_sequences/tsne_binders.png")
plt.show()


# Print the number of sequences in each cluster.
cluster_counts = Counter(clusterid)
sorted_cluster_counts = dict(sorted(cluster_counts.items()))
for cluster, count in sorted_cluster_counts.items():
    print(f"Cluster {cluster}: {count} sequences")

# Add cluster id to dataframe
filtered["clusterid"]=clusterid
#filtered.to_csv(f"output/{input}/filtered_sequences/2_filtered_binders_clus.csv", index=False)

# Calculate average cluster metrics
# numeric_only=True: the frame also holds string columns (seq, model_path,
# scaffold_name, ...), on which mean() raises TypeError in pandas >= 2.0.
average_metrics_by_cluster = filtered.groupby('clusterid').mean(numeric_only=True)
#average_metrics_by_cluster.to_csv(f"output/{input}/filtered_sequences/2_cluster_average.csv", index=False)
average_metrics_by_cluster

## Prepare metrics command

In [None]:
save_path = f"output/{input}/opt_binders/filtered.csv" # Save filtered
metric_path = f"output/{input}/opt_binders/metrics.csv" # Save filtered with metrics

# Build the table to persist: either the current selection alone (first run)
# or the previously saved table extended with rows whose model_path is new.
if not os.path.exists(save_path):
    existing_dataframe = filtered
else:
    print("reading existant dataframe...")
    existing_dataframe = pd.read_csv(save_path)
    already_saved = filtered["model_path"].isin(existing_dataframe["model_path"])
    filtered_new = filtered[~already_saved]
    print(f"existing dataframe of len: {len(existing_dataframe)}, new filtered: {len(filtered_new)}")
    existing_dataframe = pd.concat([existing_dataframe, filtered_new], ignore_index=True)
    print(f"final length: {len(existing_dataframe)}")
    existing_dataframe = existing_dataframe.sort_values(by='plddt', ascending=False)

# The same table is written to both locations, overwriting whatever is there.
for csv_path in (save_path, metric_path):
    existing_dataframe.to_csv(csv_path, index=False)

In [None]:
# Show the current value of save_path.
print(save_path)

In [None]:
#test_path="/home/tsatler/RFdif/ClusterProteinDesign/scripts/binder_design/output/{input}/opt_binders/binders/test"
#os.makedirs(test_path, exist_ok=True)

#for pdb in filtered["model_path"].head(10):
#    !cp $pdb $test_path

## Prepare input files for analysis script

In [None]:
save_directory = f"output/{input}/opt_binders/analysis_input"

# exist_ok=True replaces the check-then-create pair and is safe on re-runs.
os.makedirs(save_directory, exist_ok=True)

# Remove batch files left over from an earlier run. The job-submission step
# globs every *txt file in this directory, so a stale higher-numbered batch
# would otherwise be submitted again.
for stale_file in glob.glob(os.path.join(save_directory, "model_paths_*.txt")):
    os.remove(stale_file)

batch_size = 1000

# Split the model_paths into batches
model_paths = existing_dataframe["model_path"]
# .iloc makes the slicing explicitly positional; the frame's index labels are
# not in row order after sorting.
batches = [model_paths.iloc[i:i + batch_size] for i in range(0, len(model_paths), batch_size)]

# Save each batch as a separate TXT file (one path per line, no trailing newline)
for i, batch in enumerate(batches):
    # Dedicated name so the notebook-level `save_path` (filtered.csv) is not overwritten.
    batch_file_path = os.path.join(save_directory, "model_paths_" + str(i) + ".txt")
    with open(batch_file_path, "w") as file:
        file.write("\n".join(batch))

## Run analysis script

In [None]:
# Sorted so the submission order is the same on every run.
input_files=sorted(glob.glob(f"{save_directory}/*txt"))
# Budget of simultaneously running array tasks, shared evenly between the array jobs.
# NOTE(review): 300 is presumably the per-user limit on the cluster - confirm.
total_task_limit=300
# max(..., 1) guards against division by zero when no input file exists, and
# against a "%0" throttle when there are more input files than the budget.
array_limit=max(1, total_task_limit//max(len(input_files), 1))
target_chain="A"
binder_chain="B"
xml_file="helper_scripts/metrics_calc.xml"

commands=[]

for input_file in input_files:
    with open(input_file, "r") as file:
        lines = file.readlines()
    if not lines:
        # An empty file would produce the invalid range "--array=0--1".
        print(f"Skipping empty input file: {input_file}")
        continue
    # Array indices are 0-based: one task per line of the input file.
    array_number = len(lines)-1

    bash_arguments=f"--output=/dev/null --array=0-{array_number}%{array_limit}"
    script_arguments=f"{input_file} {target_chain} {binder_chain} {metric_path} {xml_file}"

    command = f"sbatch {bash_arguments} helper_scripts/binder_analysis.sh {script_arguments}"
    print(command)
    commands.append(command)

print(f"This will run {len(commands)} array scripts")

In [None]:
# Run the array bash script
for command in commands:
    subprocess.run(command, shell=True)

In [None]:
# IPython shell escape: list the current user's jobs in the Slurm queue.
!squeue --me