# Data Split Creation

This notebook creates data splits used to evaluate gRNAde on biologically dissimilar clusters of RNAs.

**Workflow:**
1. Cluster RNA sample sequences into groups based on sequence identity -- CD-HIT (Fu et al., 2012) with an identity threshold of 80% (the `cluster_seqid0.8` labels used below).
2. Order the clusters based on a metric (median intra-sequence RMSD among available structures within the cluster).
3. Training, validation, and test splits become progressively harder.
    - Top 100 samples from clusters with the highest metric -- test set (only clusters containing a single unique sequence/sample).
    - Next 100 samples from clusters with the highest metric -- validation set (only clusters with fewer than 5 unique sequences/samples).
    - All remaining samples -- training set.
    - Very large (> 1000 nts) RNAs -- training set.
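4. Samples that were not assigned a cluster are each treated as their own single-sample cluster before splitting; the very large RNAs that were held out of clustering are appended to the training set at the end.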

Note that we separate very large RNA samples (> 1000 nts) from clustering and directly add these to the training set, as it is unlikely that we want to redesign very large RNAs. We do not process very short RNA samples (< 10 nts).

In [None]:
%load_ext autoreload
%autoreload 2

import sys
# Add the parent directory (relative to the current working directory) to the
# module search path so that modules living one level up can be imported.
sys.path.append('../')

import dotenv
# Load environment variables from the project's .env file; DATA_PATH is read
# from the environment further down in this notebook.
# NOTE(review): this is an absolute, user-specific path, so the notebook will not
# pick up any variables on another machine -- consider a relative or discovered path.
dotenv.load_dotenv("/home/ckj24/rna-inverse-folding/.env")

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

import lovely_tensors as lt
lt.monkey_patch()

In [None]:
# Root directory of the processed dataset, taken from the environment.
# The value is None when DATA_PATH is not set; it is displayed so this is easy to spot.
DATA_PATH = os.getenv("DATA_PATH")
DATA_PATH

In [None]:
# Locations of the two processed artefacts inside DATA_PATH
processed_pt_path = os.path.join(DATA_PATH, "processed.pt")
processed_csv_path = os.path.join(DATA_PATH, "processed_df.csv")

# seq_to_data is iterated below as a mapping: sequence -> per-sequence data dict.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
seq_to_data = torch.load(processed_pt_path)

# Tabular summary of the processed samples
df = pd.read_csv(processed_csv_path)

# Disabled: would replace missing cluster labels in the dataframe with -1
# df["cluster_structsim0.45"] = df["cluster_structsim0.45"].fillna(-1)
# df["cluster_seqid0.8"] = df["cluster_seqid0.8"].fillna(-1)
df

In [None]:
# Number of distinct cluster labels in the dataframe
# (dropna=False: a missing label, if present, is counted as one value)
df["cluster_seqid0.8"].nunique(dropna=False)

In [None]:
cluster_to_seq_idx_list = {}     # cluster id -> positions (in seq_to_data iteration order) of its sequences
excluded_idx = []                # positions of sequences longer than 1000 nts, kept out of val/test

# Fresh ids for sequences labelled -1 start just above the largest cluster id in df
unclustered_id = int(df["cluster_seqid0.8"].max()) + 1

for seq_idx, data in enumerate(tqdm(seq_to_data.values())):
    # very long sequences (generally ribosomal RNAs) bypass clustering entirely
    if len(data["sequence"]) > 1000:
        excluded_idx.append(seq_idx)
        continue

    cluster = int(data["cluster_seqid0.8"])
    if cluster == -1:
        # give each unclustered sequence its own brand-new cluster id
        cluster, unclustered_id = unclustered_id, unclustered_id + 1

    cluster_to_seq_idx_list.setdefault(cluster, []).append(seq_idx)

print(f"Number of excluded sequences (> 1000 nts): {len(excluded_idx)}")

In [None]:
# Cluster sizes: number of sequences in each cluster
cluster_ids = list(cluster_to_seq_idx_list.keys())
cluster_sizes = [len(seq_idx_list) for seq_idx_list in cluster_to_seq_idx_list.values()]

# Materialise the key order once so that a positional index can be mapped back
# to its sequence in O(1), instead of rebuilding the key list for every lookup.
all_sequences = list(seq_to_data.keys())

# Median RMSD for each cluster (in same order as cluster_ids).
# The RMSD values of every sequence in the cluster are pooled before taking the
# median; clusters without any RMSD value get 0.0.
cluster_median_rmsds = []
for cluster, seq_idx_list in cluster_to_seq_idx_list.items():
    rmsds = []
    for seq_idx in seq_idx_list:
        sequence = all_sequences[seq_idx]
        rmsds.extend(seq_to_data[sequence]["rmsds_list"].values())
    cluster_median_rmsds.append(np.median(rmsds) if len(rmsds) > 0 else 0.0)

df_split = pd.DataFrame({
    'Cluster ID': cluster_ids,
    'Cluster size': cluster_sizes,
    'Median intra-sequence RMSD': cluster_median_rmsds,
})
# Sort clusters from highest to lowest median intra-sequence RMSD
df_split = df_split.sort_values(by="Median intra-sequence RMSD", ascending=False)
df_split

In [None]:
# RMSD split preparation

# Split parameters
TEST_SET_SIZE = 100          # stop adding test clusters once this many samples are collected
VAL_SET_SIZE = 100           # stop adding validation clusters once this many samples are collected
VAL_MAX_CLUSTER_SIZE = 5     # validation only takes clusters strictly smaller than this

# Initialize lists for test, validation, and training set indexes
test_idx_list = []
val_idx_list = []
train_idx_list = []

# Training, validation, and test splits become progressively harder.
#     - Top 100 samples from clusters with highest metric -- test set (only cluster_size = 1).
#     - Next 100 samples from clusters with highest metric -- validation set (cluster_size < 5).
#     - All remaining samples -- training set.
#     - Very large (> 1000 nts) RNAs -- training set.

# df_split is already sorted by decreasing median RMSD. Iterate the columns
# directly rather than via iterrows(), which would upcast the integer cluster
# ids to float because the row also holds the float RMSD column.
for cluster, cluster_size in zip(df_split["Cluster ID"], df_split["Cluster size"]):
    seq_idx_list = cluster_to_seq_idx_list[cluster]
    assert cluster_size == len(seq_idx_list)

    # Test set
    if len(test_idx_list) < TEST_SET_SIZE and cluster_size == 1:
        test_idx_list += seq_idx_list

    # Validation set (whole clusters are added, so it may end slightly above the target size)
    elif len(val_idx_list) < VAL_SET_SIZE and cluster_size < VAL_MAX_CLUSTER_SIZE:
        val_idx_list += seq_idx_list

    # Training set
    else:
        train_idx_list += seq_idx_list

# Add the very long sequences that were held out of clustering to the training set
# (a no-op when nothing was excluded).
train_idx_list += excluded_idx

# Every sequence must now belong to exactly one split
num_assigned = len(test_idx_list) + len(val_idx_list) + len(train_idx_list)
assert num_assigned == len(seq_to_data), (
    f"Split sizes sum to {num_assigned}, expected {len(seq_to_data)} sequences"
)

In [None]:
# Persist the three index lists as a single (train, val, test) tuple
split_indices = (train_idx_list, val_idx_list, test_idx_list)
torch.save(split_indices, os.path.join(DATA_PATH, "seqid_split.pt"))

In [None]:
# Label every dataframe row with its split; anything not in val/test stays "train".
# NOTE(review): this uses the split indices as dataframe row labels, i.e. it assumes
# the dataframe rows are in the same order as seq_to_data -- confirm.
df["split"] = "train"
df.loc[val_idx_list, "split"] = "val"
df.loc[test_idx_list, "split"] = "test"

# Columns shown as histograms for each split
hist_columns = ["length", "num_structures", "mean_rmsd", "median_rmsd"]

for split in ["train", "val", "test"]:
    # Select the rows of this split once and reuse them for all statistics
    df_subset = df.loc[df.split == split]
    median_rmsd_col = df_subset["median_rmsd"]

    print(f"Split: {split}")
    print(f"Average median RMSD: {median_rmsd_col.mean():.2f} +- {median_rmsd_col.std():.2f}")
    print(f"Median number of structures: {df_subset['num_structures'].median():.2f}")
    df_subset.hist(column=hist_columns, figsize=(10, 10), bins=100)
    plt.show()
