In [27]:
import pandas as pd
from pathlib import Path
from tqdm.notebook import tqdm
from collections import defaultdict

In [40]:
splits = ["train", "dev", "test"]
splits_dfs = defaultdict(list)

for split in splits:
    print(f"Split: {split}")
    for path in tqdm(Path(f"data/{split}").iterdir()):
        if not path.is_file():
            continue

        splits_dfs[split].append(pd.read_csv(path))
    splits_dfs[split] = pd.concat(splits_dfs[split])

Split: train


0it [00:00, ?it/s]

Split: dev


0it [00:00, ?it/s]

Split: test


0it [00:00, ?it/s]

In [41]:
# Use top 64 classes from train, similar top classes in dev/test
n_classes = 64
top_n_train = splits_dfs["train"]["family_accession"].value_counts().head(n_classes)
top_n_dev = splits_dfs["dev"]["family_accession"].value_counts().head(2 + n_classes)
top_n_test = splits_dfs["test"]["family_accession"].value_counts().head(2 + n_classes)
print(len(top_n_train.index.intersection(top_n_dev.index)))
print(len(top_n_train.index.intersection(top_n_test.index)))

64
64


In [44]:
rare_acids = ["X", "U", "B", "O", "Z"]


def preprocess_df(df, max_seq_len=256, pad_token="-"):
    # filter for top n classes in train/dev, truncate and pad
    df = df[df["family_accession"].isin(top_n_train.index)]
    df = df[~df["sequence"].str.contains("|".join(rare_acids), regex=True)]
    df["sequence"] = df["sequence"].str.slice(0, max_seq_len)
    df["sequence"] = df["sequence"].str.pad(
        width=max_seq_len, side="right", fillchar=pad_token
    )
    return df


for split in splits:
    splits_dfs[split] = preprocess_df(splits_dfs[split])
    display(splits_dfs[split].info())

<class 'pandas.core.frame.DataFrame'>
Index: 66226 entries, 8 to 13353
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   family_id         66226 non-null  object
 1   sequence_name     66226 non-null  object
 2   family_accession  66226 non-null  object
 3   aligned_sequence  66226 non-null  object
 4   sequence          66226 non-null  object
dtypes: object(5)
memory usage: 3.0+ MB


None

<class 'pandas.core.frame.DataFrame'>
Index: 8249 entries, 29 to 12755
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   family_id         8249 non-null   object
 1   sequence_name     8249 non-null   object
 2   family_accession  8249 non-null   object
 3   aligned_sequence  8249 non-null   object
 4   sequence          8249 non-null   object
dtypes: object(5)
memory usage: 386.7+ KB


None

<class 'pandas.core.frame.DataFrame'>
Index: 8242 entries, 1 to 12754
Data columns (total 5 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   family_id         8242 non-null   object
 1   sequence_name     8242 non-null   object
 2   family_accession  8242 non-null   object
 3   aligned_sequence  8242 non-null   object
 4   sequence          8242 non-null   object
dtypes: object(5)
memory usage: 386.3+ KB


None

In [45]:
total = sum([len(df) for df in splits_dfs.values()])
for split in splits:
    print(f"Split proportion {len(splits_dfs[split]) / total * 100:.2f}")

Split proportion 80.06
Split proportion 9.97
Split proportion 9.96


In [46]:
save_path = "data/processed_data"
for split, df in splits_dfs.items():
    df.to_parquet(f"{save_path}/{split}.parquet")