In [1]:
import json
from multiprocessing import Pool, Manager, cpu_count

import pandas as pd

from utils import (
    logging,
    TARGET_PROTEIN,
    SAMPLE_END_DATE,
    MUTATION_PER_SEQ_FILE,
    TRAINING_DATA_FILE,
    DUMMY_SEQ_NAMES_FILE
)
from utils.miscellaneous import mut_seq_info, aa_per_seq


In [2]:
# Load the per-sequence mutation table, then narrow it down one filter at a
# time, logging the mut_seq_info summary of what is left after each step.
logging.info("Load data...")
df: pd.DataFrame = pd.read_feather(MUTATION_PER_SEQ_FILE)

# Parse the dates and keep rows dated strictly before SAMPLE_END_DATE.
df = df.assign(Date=pd.to_datetime(df["Date"]))
sampled_in_time = df["Date"] < SAMPLE_END_DATE
df = df.loc[sampled_in_time]
logging.info(f"{mut_seq_info(df)} before {SAMPLE_END_DATE}")

# Keep mutations whose name contains TARGET_PROTEIN (substring/pattern match).
on_target = df["Mutation"].str.contains(TARGET_PROTEIN)
df = df.loc[on_target]
logging.info(f"{mut_seq_info(df)} are on {TARGET_PROTEIN}")

# Drop every mutation whose name contains "stop".
has_stop = df["Mutation"].str.contains("stop")
df = df.loc[~has_stop]
logging.info(f"{mut_seq_info(df)} are not stop codon")

# Drop every mutation whose name contains "ins".
has_ins = df["Mutation"].str.contains("ins")
df = df.loc[~has_ins]
logging.info(f"{mut_seq_info(df)} are not insertion")


In [3]:
# One row per distinct mutation name, later extended with the parsed
# original residue ("From"), position ("Pos") and replacement ("To").
mut_info = pd.DataFrame(
    df["Mutation"].unique(),
    columns=["Mutation"]
)
# The name is split on "_" and only the second field is parsed
# (assumes names look like "<protein>_<change>" -- TODO confirm upstream).
mut_info_split = mut_info["Mutation"].str.split("_").str
# Raw string: "\w" and "\d" are invalid escape sequences in a normal literal
# and trigger a DeprecationWarning/SyntaxWarning on recent Python versions.
# Columns 1, 2 and 3 of the expanded split hold the three captured groups.
mut_aa = mut_info_split[1].str.split(r"(\w)(\d+)(\w+)", expand=True)
mut_info["From"] = mut_aa[1]
mut_info["Pos"] = mut_aa[2].astype(int)
mut_info["To"] = mut_aa[3]

# Give every mutated position a dense 0-based id ("Pos_id") that follows
# ascending position order, and attach it to each mutation.
all_pos = mut_info["Pos"].unique()
all_pos.sort()
all_pos = pd.DataFrame({
    "Pos": all_pos,
    "Pos_id": range(len(all_pos))
})
mut_info = mut_info.merge(all_pos, on="Pos")

# One row per distinct (From, Pos, Pos_id); "From" is renamed to "To" so the
# table lists the original residue of each position under the "To" column.
pos_info = mut_info[["From", "Pos", "Pos_id"]]
pos_info = pos_info.drop_duplicates().reset_index(drop=True)
pos_info = pos_info.rename(columns={"From": "To"})
logging.info(f"{len(pos_info)} mutated pos")


In [4]:
# Attach the parsed mutation columns from mut_info to every record.
training_data = df.merge(mut_info, on="Mutation")

# Run aa_per_seq once per accession in a worker pool. Each call receives
# (accession, rows of that accession, pos_info, shared dict); the workers
# presumably record their results in the shared dict -- see aa_per_seq.
with Pool(cpu_count()) as p, Manager() as manager:
    aa_info_dict = manager.dict()
    per_accession_args = (
        (accession, accession_rows, pos_info, aa_info_dict)
        for accession, accession_rows
        in training_data.groupby("Accession", sort=False)
    )
    p.starmap(func=aa_per_seq, iterable=per_accession_args)

    # Read the shared dict while the manager is still open.
    all_mut_comb = tuple(tuple(key) for key in aa_info_dict.keys())
    logging.info(f"{len(all_mut_comb)} mutation combination")

    # The dict values are concatenated into the new training table.
    collected_frames = list(aa_info_dict.values())
    training_data = pd.concat(collected_frames).reset_index(drop=True)
    logging.info(f"{len(training_data)} data points")


In [5]:
# Number the distinct accessions 0..n-1 (in order of first appearance) and
# attach that number to every training row as "Seq_id".
seq_info = pd.DataFrame({"Accession": training_data["Accession"].unique()})
seq_info = seq_info.assign(Seq_id=seq_info.index)
training_data = training_data.merge(seq_info, on="Accession")


In [6]:
# For every mutated position, collect its observed mutations plus one extra
# "unmutated" row whose "To" equals "From" and whose name is "<Pos><From>".
dummy_groups = []
for pos, pos_group in mut_info.groupby("Pos"):
    # Start from the first mutation at this position so the remaining
    # columns (Pos, Pos_id, From) carry over unchanged.
    unmutated = pos_group.iloc[0].to_dict()
    unmutated["To"] = unmutated["From"]
    unmutated["Mutation"] = f"{unmutated['Pos']}{unmutated['From']}"
    # DataFrame.append was removed in pandas 2.0; pd.concat with a one-row
    # frame produces the same result on old and new pandas alike.
    pos_group = pd.concat(
        [pos_group, pd.DataFrame([unmutated])],
        ignore_index=True
    )
    dummy_groups.append(pos_group)

dummy_seqs = pd.concat(dummy_groups).reset_index(drop=True)


In [7]:
# Turn each dummy entry into a training row: the mutation name is used as the
# accession, the lineage is the string "None" and the date is SAMPLE_END_DATE.
# NOTE(review): Seq_id is Pos_id shifted past the real sequence ids, so dummy
# rows at the same position share one Seq_id -- confirm this is intended.
seq_id_offset = len(seq_info)
dummy_rows = pd.DataFrame({
    "Accession": dummy_seqs["Mutation"].to_numpy(),
    "Lineage": "None",
    "Date": SAMPLE_END_DATE,
    "Seq_id": dummy_seqs["Pos_id"].to_numpy() + seq_id_offset,
    "Pos": dummy_seqs["Pos"].to_numpy(),
    "To": dummy_seqs["To"].to_numpy(),
    "Pos_id": dummy_seqs["Pos_id"].to_numpy()
})
training_data = pd.concat([training_data, dummy_rows])
logging.info(f"{len(training_data)} after dummy added")


In [8]:
# From here on the "To" column is called "AA".
training_data = training_data.rename(columns={"To": "AA"})

# Lookup table mapping each distinct AA value to its rank in sorted order.
aa_values = training_data["AA"].unique()
aa_values.sort()
aa_table = pd.DataFrame({
    "AA_idx": range(len(aa_values)),
    "AA": aa_values
})

# Attach the index of its AA value to every training row.
training_data = training_data.merge(aa_table, on="AA")


In [9]:
# reset_index returns a new frame rather than modifying in place, so the
# result has to be assigned back for the reset to take effect. Feather output
# needs a default 0..n-1 index, so the reset happens right before writing.
training_data = training_data.reset_index(drop=True)
training_data.to_feather(TRAINING_DATA_FILE)
logging.info(f"{TRAINING_DATA_FILE} saved!")


In [10]:
# Write the names of all dummy entries (their "Mutation" values) as a JSON
# list to DUMMY_SEQ_NAMES_FILE.
dummy_seq_names = dummy_seqs["Mutation"].tolist()
with open(DUMMY_SEQ_NAMES_FILE, "w") as f:
    json.dump(dummy_seq_names, f)