# Data Split Creation

This notebook creates data splits used to evaluate gRNAde on biologically dissimilar clusters of RNAs.

**Workflow:**
1. Cluster RNA sample sequences into groups based on: 
    - Sequence identity -- CD-HIT (Fu et al., 2012) with identity threshold of 90%.
    - Structural similarity -- US-align with similarity threshold 0.45 (TODO).
2. Order the clusters based on some metric:
    - Avg. of intra-sequence avg. RMSD among available structures
    - Avg. of intra-sequence number of structures available
3. Training, validation, and test splits become progressively harder.
    - Top 100 samples from clusters with highest metric -- test set.
    - Next 100 samples from clusters with highest metric -- validation set.
    - All remaining samples -- training set.
    - For clusters with >20 samples within them -- training set.
    - Very large (> 1000 nts) or very small (< 10 nts) RNAs -- training set.
4. If any samples were not assigned clusters, append them to the training set.

Note that we separate very large RNA samples (> 1000 nts) from clustering and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. Likewise for very short RNA samples (< 10 nts).

In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('../')

import os
import subprocess
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset
import seaborn as sns

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

from src.data_utils import get_avg_rmsds

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
def create_clusters_sequence_identity(
        input_sequences,
        identity_threshold = 0.9,
        word_size = 2,
        input_file = "input",
        output_file = "output"
    ):
    """Cluster sequences by sequence identity using the external `cd-hit-est` program.

    The sequences are written to `input_file` in FASTA format, `cd-hit-est` is
    run on that file, and its `<output_file>.clstr` report is parsed into a
    mapping from sequence ID to cluster ID. All three temporary files are
    removed before returning, even if clustering or parsing fails.

    Reference: https://manpages.ubuntu.com/manpages/impish/man1/cd-hit-est.1.html

    Args:
        input_sequences: Records accepted by `SeqIO.write(..., "fasta")`.
            Each record ID must be an integer string, because the IDs are
            parsed back with `int()` when reading the cluster report.
        identity_threshold: Value passed to `cd-hit-est -c`.
        word_size: Value passed to `cd-hit-est -n`.
        input_file: Path of the temporary FASTA file written for `cd-hit-est`.
        output_file: Path of the `cd-hit-est` output; the cluster report is
            read from `output_file + ".clstr"`.

    Returns:
        A tuple `(clustered_sequences, seq_idx_to_cluster)` where
        `clustered_sequences` is the list of records parsed from `output_file`
        and `seq_idx_to_cluster` maps each integer sequence ID found in the
        cluster report to its integer cluster ID.

    Raises:
        subprocess.CalledProcessError: If `cd-hit-est` exits with a non-zero status.
        ValueError: If a sequence ID in the cluster report is not an integer.
    """
    cluster_file = output_file + ".clstr"

    try:
        # Write input sequences to the temporary input file
        SeqIO.write(input_sequences, input_file, "fasta")

        # Run CD-HIT-EST
        cmd = [
            "cd-hit-est",
            "-i", input_file,
            "-o", output_file,
            "-c", str(identity_threshold), # Sequence identity threshold (e.g., 90%)
            "-n", str(word_size)           # Word size for sequence comparison (this function's default: 2)
        ]
        subprocess.run(cmd, check=True)

        # Read clustered sequences from the temporary output file
        clustered_sequences = list(SeqIO.parse(output_file, "fasta"))

        # Process the clustering output: a line starting with ">" opens a new
        # cluster, every other non-empty line names one member of that cluster.
        seq_idx_to_cluster = {}
        with open(cluster_file, "r") as f:
            current_cluster = None
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines instead of failing with IndexError
                    continue
                if line.startswith(">"):
                    current_cluster = int(line.split(" ")[1])
                else:
                    sequence_id = int(line.split(">")[1].split("...")[0])
                    seq_idx_to_cluster[sequence_id] = current_cluster
    finally:
        # Delete temporary files, including on failure, so a crashed run does
        # not leave stale files behind for the next call to pick up.
        for path in (input_file, output_file, cluster_file):
            if os.path.exists(path):
                os.remove(path)

    return clustered_sequences, seq_idx_to_cluster

In [None]:
# Read the processed dataset from disk and report how many samples it holds
data_list = torch.load(os.path.join("../data/", "processed.pt"))
print(len(data_list))

# One FASTA record per sample; the record ID is the sample's position in
# data_list so that cluster assignments can be mapped back to samples
seq_list = [
    SeqRecord(Seq(sample["seq"]), id=str(position))
    for position, sample in enumerate(data_list)
]

# Intra-sequence avg. RMSD, one entry per sample
rmsd_list = get_avg_rmsds(data_list)

# Number of structures stored for each sample
count_list = list(map(lambda sample: len(sample["coords_list"]), data_list))

# Every per-sample list must have exactly one entry per sample
assert len(data_list) == len(seq_list) == len(rmsd_list) == len(count_list)

In [16]:
import pickle

# NOTE(review): absolute, user-specific path -- this cell only runs on a
# machine where this file exists; consider a path relative to the repo.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("/home/dnori/rna-design/src/data/rf2na_split_dataset.pickle", 'rb') as handle:
    dataset = pickle.load(handle)

# The pickle holds one dict per split, keyed by "train" / "test" / "val"
train_dict = dataset["train"]
test_dict = dataset["test"]
val_dict = dataset["val"]
print(len(train_dict), len(test_dict), len(val_dict))

# Flatten the three splits (train, then val, then test) into one list of
# FASTA records. Each record's ID is its running index, and idx_to_pdb maps
# that index (as a string) back to the key the entry had in its split dict.
# NOTE(review): this rebinds `seq_list`, replacing the list built from
# `data_list` in the earlier cell -- confirm that is intended.
seq_list = []
pdb_id_list = []
idx = 0
idx_to_pdb = {}
for split_dict in (train_dict, val_dict, test_dict):
    # A separate loop name keeps the loaded `dataset` dict intact
    for pdb_id, entry in split_dict.items():
        seq_list.append(SeqRecord(Seq(entry["rna_seq"]), id=str(idx)))
        pdb_id_list.append(pdb_id)
        idx_to_pdb[str(idx)] = pdb_id
        idx += 1

1111 16 124


In [17]:
# Cluster at 80% sequence identity (noted by the author as the lowest
# currently possible), comparing sequences with a word size of 2
clustering_options = dict(identity_threshold=0.8, word_size=2)
clustered_sequences, seq_idx_to_cluster = create_clusters_sequence_identity(
    seq_list,
    **clustering_options,
)

Program: CD-HIT, V4.8.1 (+OpenMP), May 15 2023, 22:49:31
Command: cd-hit-est -i input -o output -c 0.9 -n 2

Started: Fri Jan 12 12:04:01 2024
                            Output                              
----------------------------------------------------------------
Your word length is 2, using 5 may be faster!
total seq: 1006
longest and shortest : 242 and 11
Total letters: 44876
Sequences have been sorted

Approximated minimal memory consumption:
Sequence        : 0M
Buffer          : 1 X 17M = 17M
Table           : 1 X 0M = 0M
Miscellaneous   : 0M
Total           : 17M

Table limit with the given memory limit:
Max number of representatives: 4000000
Max number of word counting entries: 97773433

comparing sequences from          0  to       1006
.
     1006  finished        405  clusters

Approximated maximum memory consumption: 17M
writing new database
writing clustering information
program completed !

Total CPU time 0.12


In [18]:
# Re-key the cluster assignments: look up each sequence index in idx_to_pdb
# (keyed by the index as a string) and use that value as the new key
result = {
    idx_to_pdb[str(record_idx)]: cluster_id
    for record_idx, cluster_id in seq_idx_to_cluster.items()
}

In [20]:
# Persist the `result` mapping with the highest available pickle protocol
cluster_pickle_path = "/home/dnori/rna-design/src/data/rna_seq_clusters.pkl"
with open(cluster_pickle_path, 'wb') as out_file:
    pickle.dump(result, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [9]:
# Number of entries that received a cluster assignment
print(len(result))

1006


In [None]:
# Sanity check: it seems very short sequences (<10nt) are not being clustered.
# These will be added to the training set after initial splitting.
#
# `idx_not_clustered` is always defined here (possibly empty) because later
# cells append to it and read it; defining it only on failure would make those
# cells raise NameError whenever every sequence happens to be clustered.
clustered_idx = set(seq_idx_to_cluster.keys())

# Which sequence indices are not clustered?
idx_not_clustered = [idx for idx in range(len(data_list)) if idx not in clustered_idx]
print("Number of missing indices after clustering: ", len(idx_not_clustered))

# What are their corresponding sequence lengths? (Guess: sequences are too short?)
# Skipped when nothing is missing, since statistics of an empty list are undefined.
if idx_not_clustered:
    seq_lens = [len(data_list[idx]["seq"]) for idx in idx_not_clustered]
    print("Sequence lengths for missing indices:")
    print(f"    Distribution: {np.mean(seq_lens)} +- {np.std(seq_lens)}")
    print(f"    Max: {np.max(seq_lens)}, Min: {np.min(seq_lens)}")

In [None]:
# seq_idx_to_cluster: (index in data_list: cluster ID)
# (NEW) cluster_to_seq_idx_list: (cluster ID: list of indices in data_list)
cluster_to_seq_idx_list = {}

# Set view of the unassigned indices for fast membership tests; the list
# itself is still extended in place because later cells read it.
unassigned_idx = set(idx_not_clustered)

for seq_idx, cluster in seq_idx_to_cluster.items():
    seq_len = len(seq_list[seq_idx])

    # Very large (> 1000) or very small (< 10) RNAs never join a cluster.
    # The length test and the "already recorded" test are nested on purpose:
    # written as a single `a or b and c` expression, `and` binds tighter than
    # `or`, so the duplicate check would only apply to the short-sequence case.
    if seq_len > 1000 or seq_len < 10:
        if seq_idx not in unassigned_idx:
            unassigned_idx.add(seq_idx)
            idx_not_clustered.append(seq_idx)
    else:
        cluster_to_seq_idx_list.setdefault(cluster, []).append(seq_idx)

print("Number of unassigned indices (not clustered + too large + too small): ", len(idx_not_clustered))

In [None]:
# Cluster IDs and how many sequences each cluster contains
cluster_ids = list(cluster_to_seq_idx_list.keys())
cluster_sizes = [len(members) for members in cluster_to_seq_idx_list.values()]

# Per-cluster aggregates over the member sequences:
# total / mean number of structures, mean RMSD value, and mean sequence length
total_structs_list, avg_structs_list = [], []
avg_rmsds_list, avg_seq_len_list = [], []
for members in cluster_to_seq_idx_list.values():
    struct_counts = [count_list[member] for member in members]
    member_rmsds = [rmsd_list[member] for member in members]
    member_lens = [len(seq_list[member]) for member in members]

    total_structs_list.append(np.sum(struct_counts))
    avg_structs_list.append(np.mean(struct_counts))
    avg_rmsds_list.append(np.mean(member_rmsds))
    avg_seq_len_list.append(np.mean(member_lens))

# One row per cluster; shown as the cell output
df = pd.DataFrame({
    'Cluster ID': cluster_ids,
    'Cluster size': cluster_sizes,
    'Total no. structures': total_structs_list,
    'Avg. sequence length': avg_seq_len_list,
    'Avg. intra-sequence no. structures': avg_structs_list,
    'Avg. intra-sequence avg. RMSD': avg_rmsds_list,
})
df

In [None]:
# RMSD Split

# Pair every cluster ID with its average RMSD value
zipped = zip(cluster_ids, avg_rmsds_list)
# Order the pairs in place by the RMSD value, highest first
sorted_zipped = list(zipped)
sorted_zipped.sort(key=lambda pair: pair[1], reverse=True)
# Split the ordered pairs back into two parallel tuples
sorted_cluster_ids, sorted_avg_rmsds_list = zip(*sorted_zipped)

In [None]:
test_idx_list, val_idx_list, train_idx_list = [], [], []

# Walk the clusters in sorted order and hand each one, as a whole, to a split.
# NOTE(review): the workflow summary at the top of this notebook says clusters
# with >20 samples go to the training set, but the cut-off used here is 25 --
# confirm which value is intended.
for cluster in sorted_cluster_ids:
    members = cluster_to_seq_idx_list[cluster]
    is_small_cluster = len(members) < 25

    if is_small_cluster and len(test_idx_list) < 100:
        # Test set: filled first, while it holds fewer than 100 indices
        destination = test_idx_list
    elif is_small_cluster and len(val_idx_list) < 100:
        # Validation set: filled next, under the same rule
        destination = val_idx_list
    else:
        # Training set: large clusters and everything left over
        destination = train_idx_list

    destination.extend(members)

In [None]:
# Add all the sequences that were not assigned any clusters into the training set.
# The append only happens while the splits do not yet cover every sample, so
# re-running this cell does not add the unassigned indices a second time.
n_assigned = len(test_idx_list) + len(val_idx_list) + len(train_idx_list)
if n_assigned != len(data_list):
    train_idx_list += idx_not_clustered
    n_assigned = len(test_idx_list) + len(val_idx_list) + len(train_idx_list)

# Final invariant: every sample belongs to some split. Raised explicitly
# (rather than with `assert`) so the check is not stripped under `python -O`.
if n_assigned != len(data_list):
    raise AssertionError(
        f"Splits cover {n_assigned} samples but data_list has {len(data_list)}"
    )

In [None]:
# Store the three index lists as one (train, val, test) tuple
split_indices = (train_idx_list, val_idx_list, test_idx_list)
torch.save(split_indices, "../data/seqid_rmsd_split.pt")