In [1]:
import json
from multiprocessing import Pool, cpu_count, Manager, managers

import pandas as pd

from utils import (
    logging,
    TARGET_PROTEIN,
    SAMPLE_END_DATE,
    MUTATION_PER_SEQ_FILE,
    TRAINING_DATA_FILE,
    DUMMY_SEQ_NAMES_FILE
)
from utils.miscellaneous import mut_seq_info, aa_per_seq


In [2]:
# Load the per-sequence mutation table and narrow it down step by step,
# logging what mut_seq_info reports after every filter.
logging.info("Load data...")
df: pd.DataFrame = pd.read_feather(MUTATION_PER_SEQ_FILE)

# Keep only records dated strictly before the sampling cutoff.
df["Date"] = pd.to_datetime(df["Date"])
before_cutoff = df["Date"] < SAMPLE_END_DATE
df = df[before_cutoff]
logging.info(f"{mut_seq_info(df)} before {SAMPLE_END_DATE}")

# Keep only mutations whose name matches the target protein.
# NOTE(review): str.contains treats TARGET_PROTEIN as a regular expression
# by default - confirm it never contains regex metacharacters.
on_target = df["Mutation"].str.contains(TARGET_PROTEIN)
df = df[on_target]
logging.info(f"{mut_seq_info(df)} are on {TARGET_PROTEIN}")

# Discard mutations whose name contains "stop".
has_stop = df["Mutation"].str.contains("stop")
df = df[~has_stop]
logging.info(f"{mut_seq_info(df)} are not stop codon")

# Discard mutations whose name contains "ins".
has_ins = df["Mutation"].str.contains("ins")
df = df[~has_ins]
logging.info(f"{mut_seq_info(df)} are not insertion")


In [3]:
# One row per distinct mutation string that survived the filters.
mut_info = pd.DataFrame(
    df["Mutation"].unique(),
    columns=["Mutation"]
)

# The second "_"-separated field of the mutation name is parsed as
# <one word char><digits><word chars>, giving From / Pos / To.
# A raw string is used so that \w and \d reach the regex engine untouched
# (the non-raw form is an invalid escape sequence in Python), and
# str.extract returns the capture groups of the first match directly.
# A name that does not match yields NaN and makes astype(int) raise.
aa_change = mut_info["Mutation"].str.split("_").str[1]
mut_aa = aa_change.str.extract(r"(\w)(\d+)(\w+)")
mut_info["From"] = mut_aa[0]
mut_info["Pos"] = mut_aa[1].astype(int)
mut_info["To"] = mut_aa[2]

# Number the distinct positions 0..n-1 in ascending order of position.
sorted_pos = sorted(mut_info["Pos"].unique())
all_pos = pd.DataFrame({
    "Pos": sorted_pos,
    "Pos_id": range(len(sorted_pos))
})
mut_info = mut_info.merge(all_pos, on="Pos")

# One row per distinct (From, Pos, Pos_id); "From" is renamed to "To" so the
# table uses the same column name as the mutated residues.
pos_info = (
    mut_info[["From", "Pos", "Pos_id"]]
    .drop_duplicates()
    .reset_index(drop=True)
    .rename(columns={"From": "To"})
)
logging.info(f"{len(pos_info)} mutated pos")


In [4]:
# Attach the parsed mutation columns (From/Pos/To/Pos_id) to every record.
training_data: pd.DataFrame = df.merge(mut_info, on="Mutation")

manager: managers.SyncManager
with Pool(cpu_count()) as pool, Manager() as manager:
    # Manager-backed dict that is passed to every aa_per_seq call.
    # NOTE(review): presumably aa_per_seq stores one DataFrame per mutation
    # combination in it - confirm against utils.miscellaneous.
    all_mut_comb = manager.dict()

    # One task per Accession group: (accession, group rows, pos_info, dict).
    per_accession_args = (
        (accession, accession_rows, pos_info, all_mut_comb)
        for accession, accession_rows
        in training_data.groupby("Accession", sort=False)
    )
    mut_sets = pd.DataFrame.from_records(
        pool.starmap(aa_per_seq, per_accession_args)
    )

    # Count the dict entries before they are flattened into a single frame.
    logging.info(f"{len(all_mut_comb)} mutation combination")
    all_mut_comb = pd.concat(
        list(all_mut_comb.values()),
        ignore_index=True
    )
    logging.info(f"{len(all_mut_comb)} data points")


In [5]:
# For every distinct Mut_set keep a single record: the first row after
# sorting that group's rows by Date in ascending order.
earliest_records = [
    same_set.sort_values("Date").iloc[0].to_dict()
    for _, same_set in mut_sets.groupby("Mut_set", sort=False)
]
selected_mut_sets = pd.DataFrame.from_records(earliest_records)

# Join the per-combination data points with the chosen record of their
# mutation set; the join key is not needed afterwards.
training_data = (
    all_mut_comb
    .merge(selected_mut_sets, on="Mut_set")
    .drop(columns="Mut_set")
)

# Give each distinct Accession an integer Seq_id in order of first appearance.
seq_info = pd.DataFrame({"Accession": training_data["Accession"].unique()})
seq_info["Seq_id"] = seq_info.index
training_data = training_data.merge(seq_info, on="Accession")


In [6]:
# Build the dummy-sequence table: for every mutated position, all of its
# observed mutation rows followed by one extra "unmutated" row whose To
# residue equals the From residue and whose Mutation name is "<Pos><From>".
pos_group: pd.DataFrame
dummy_seqs = []
for pos, pos_group in mut_info.groupby("Pos"):
    unmutated = pos_group.iloc[0].to_dict()
    unmutated["To"] = unmutated["From"]
    unmutated["Mutation"] = f"{unmutated['Pos']}{unmutated['From']}"
    # DataFrame.append was removed in pandas 2.0, so the pieces are collected
    # and concatenated once below. Row order is unchanged: the group's rows
    # first, then its unmutated row.
    dummy_seqs.append(pos_group)
    dummy_seqs.append(pd.DataFrame([unmutated], columns=pos_group.columns))

dummy_seqs: pd.DataFrame = pd.concat(dummy_seqs, ignore_index=True)


In [7]:
# Dummy rows reuse the dummy mutation name as their Accession and get Seq_id
# values placed after the real ones (Pos_id shifted by the number of real
# sequences), so rows that share a position also share a Seq_id.
seq_id_offset = len(seq_info)
dummy_rows = pd.DataFrame({
    "Accession": dummy_seqs["Mutation"].values,
    "Lineage": "None",
    "Date": SAMPLE_END_DATE,
    "Seq_id": dummy_seqs["Pos_id"].values + seq_id_offset,
    "Pos": dummy_seqs["Pos"].values,
    "To": dummy_seqs["To"].values,
    "Pos_id": dummy_seqs["Pos_id"].values
})

# Stack the dummy rows underneath the real training rows.
training_data = pd.concat([training_data, dummy_rows])
logging.info(f"{len(training_data)} after dummy added")


In [8]:
# From here on the residue column is called "AA".
training_data = training_data.rename(columns={"To": "AA"})

# Lookup table mapping each distinct AA value, in sorted order, to AA_idx.
aa_values = training_data["AA"].unique()
aa_values.sort()
aa_table = pd.DataFrame({
    "AA_idx": range(len(aa_values)),
    "AA": aa_values
})

# Attach the integer AA_idx to every training row.
training_data = training_data.merge(aa_table, on="AA")


In [9]:
# reset_index returns a new frame, so the result must be assigned; the
# original call discarded it. to_feather needs a default index, so this
# guarantees one regardless of what the earlier cells left behind.
training_data = training_data.reset_index(drop=True)
training_data.to_feather(TRAINING_DATA_FILE)
logging.info(f"{TRAINING_DATA_FILE} saved!")

# Store the dummy-sequence names (the Mutation column of dummy_seqs) as a
# JSON list; tolist() hands json.dump plain Python objects.
with open(DUMMY_SEQ_NAMES_FILE, "w") as f:
    json.dump(dummy_seqs["Mutation"].tolist(), f)
