In [1]:
# Uncomment and run to reload libs
# import importlib, pyutils; importlib.reload(pyutils)


import json
from itertools import combinations

import pandas as pd

from pyutils import (
    TOTAL_NUM_FILE,
    FIXATION_LINKAGE_FILE,
    ALL_MUT_SETS_FILE,
    ALL_AA_COMBO_FILE,
    MUT_NODE_FILE,
    MUT_FREQ_FILE,
)


In [2]:
# Totals table, indexed by its first CSV column.
total_num = pd.read_csv(TOTAL_NUM_FILE, index_col=0)

# Fixation-linkage records with "date" parsed to datetimes, oldest first.
fixation_linkage: pd.DataFrame = (
    pd.read_csv(FIXATION_LINKAGE_FILE)
    .assign(date=lambda frame: pd.to_datetime(frame["date"]))
    .sort_values("date")
)

all_aa_combo: pd.DataFrame = pd.read_csv(ALL_AA_COMBO_FILE)

# Mutation-set observations with "Date" parsed to datetimes.
all_mut_sets: pd.DataFrame = pd.read_csv(ALL_MUT_SETS_FILE).assign(
    Date=lambda frame: pd.to_datetime(frame["Date"])
)


In [3]:
# Give every distinct "Mut_set" value (in order of first appearance in
# `all_mut_sets`) an integer id equal to its row number.
mut_set_info = pd.DataFrame({"Mut_set": all_mut_sets["Mut_set"].unique()})
mut_set_info = mut_set_info.assign(Mut_set_id=lambda frame: frame.index)

# Attach the id to both tables (inner join on "Mut_set").
all_aa_combo = pd.merge(all_aa_combo, mut_set_info, on="Mut_set")
all_mut_sets = pd.merge(all_mut_sets, mut_set_info, on="Mut_set")


In [4]:
# Union of every mutation named in any fixation-linkage label.
# Each row of `fixation_linkage` is unpacked as (date, label), so the frame is
# expected to hold exactly those two columns; a label is a ", "-separated list
# of mutation names.  Every row is visited, including the last one.
all_fixed_mut = set()
for _, label in fixation_linkage.values:
    all_fixed_mut.update(label.split(", "))

# One row per fixed mutation.  `sorted` makes the row order (and therefore the
# outcome of ties on (Protein, Pos) below) independent of set iteration order.
all_fixed_mut_set = pd.DataFrame(sorted(all_fixed_mut), columns=["Mut_name"])
mut_str_split = all_fixed_mut_set["Mut_name"].str.split("_").str
all_fixed_mut_set["Protein"] = mut_str_split[0]
# Splitting on a pattern with capture groups keeps the groups in the result:
# column 1 is a capital letter or "ins" (presumably the original state),
# column 2 the position digits and column 3 the new state.  The pattern is a
# raw string because "\d" and "\w" are not valid escapes in a plain literal.
mut_aa: pd.DataFrame = mut_str_split[1].str.split(r"([A-Z]|ins)(\d+)(\w+)", expand=True)
all_fixed_mut_set["Pos"] = mut_aa[2].astype(int)
all_fixed_mut_set["To"] = mut_aa[3]
all_fixed_mut_set = all_fixed_mut_set.sort_values(["Protein", "Pos"]).reset_index(drop=True)

# Start from the per-position states of mutation set 0 and overwrite the
# fixed positions, giving a synthetic genotype that carries all of them.
# The rows are sorted and de-duplicated on (Protein, Pos) because this table
# is later zipped positionally against `aa_pos`, which is built that way.
# NOTE(review): set 0 is assumed to be the intended baseline — confirm it is
# the reference (NaN "Mut_set") entry.
synthetic_aa_pos_state = (
    all_aa_combo.loc[all_aa_combo["Mut_set_id"] == 0, ["Protein", "Pos", "To"]]
    .sort_values(["Protein", "Pos"])
    .drop_duplicates(subset=["Protein", "Pos"], keep="first")
    .set_index(["Protein", "Pos"])
    .copy(deep=True)
)

# NOTE(review): a (Protein, Pos) pair that is absent from the index is
# appended as a new row by this assignment rather than raising.
for p, pos, fixed_aa in all_fixed_mut_set[["Protein", "Pos", "To"]].values:
    synthetic_aa_pos_state.loc[(p, pos), "To"] = fixed_aa


In [5]:
# Mut_set_id -> array of "To" states, one per (Protein, Pos), ordered by
# Protein then Pos with the first row kept for any repeated position.
mut_set_aa_state = {}

mut_set_group: pd.DataFrame
for mut_set_id, mut_set_group in all_aa_combo.groupby("Mut_set_id", sort=True):
    aa_state = (
        mut_set_group[["Protein", "Pos", "To"]]
        .sort_values(["Protein", "Pos"])
        .drop_duplicates(subset=["Protein", "Pos"], keep="first")
    )
    mut_set_aa_state[mut_set_id] = aa_state["To"].values

# (Protein, Pos) tuples taken from the last group processed above; the arrays
# in `mut_set_aa_state` are later matched against this list by position.
# aa_pos = (aa_state["Protein"] + "_" + aa_state["Pos"].astype(str)).values
aa_pos = list(map(tuple, aa_state[["Protein", "Pos"]].values))


In [6]:
# The reference entry is the single mutation set whose "Mut_set" is missing;
# the tuple unpacking fails unless exactly one such id exists.
is_reference = mut_set_info["Mut_set"].isna()
(ref_mut_id,) = mut_set_info.loc[is_reference, "Mut_set_id"].unique()
ref_mut_set = mut_set_aa_state[ref_mut_id]

mut_set_diff = {}   # tuple of (protein, pos, aa) differing from the reference -> Mut_set_id
mut_set_count = {}  # Mut_set_id -> its rows of `all_mut_sets`

for i, mut_set_group in all_mut_sets.groupby("Mut_set_id", sort=False):
    mut_set = mut_set_aa_state[i]
    # Keep only the positions whose state differs from the reference.
    aa_diff = [
        (*pos, aa)
        for pos, ref_aa, aa in zip(aa_pos, ref_mut_set, mut_set)
        if ref_aa != aa
    ]
    mut_set_count[i] = mut_set_group
    mut_set_diff[tuple(aa_diff)] = i


In [7]:
# Mutations separating the synthetic genotype from the reference, as
# (protein, pos, aa) tuples in `aa_pos` order.
aa_diff = [
    (*pos, aa)
    for pos, ref_aa, aa in zip(aa_pos, ref_mut_set, synthetic_aa_pos_state["To"].values)
    if ref_aa != aa
]

# mut_node: one flat list per node, [mut_set_id, parent_id, protein, pos, aa, ...],
#           with one (parent_id, protein, pos, aa) group per direct parent.
# mut_freq: one {"mut_set_id", "date", "ratio"} record per observed (set, date).
mut_node = []
mut_freq = []
prev_mut_list = []
# Ids for combinations that were never observed start after the largest real
# id.  Series.max() returns a NumPy integer, which json.dump cannot serialize,
# so it is converted to a builtin int here.
ghost_mut_set_id = int(mut_set_info["Mut_set_id"].max()) + 1
for n in range(len(aa_diff) + 1):
    # Direct lookup of the previous level's ids by their mutation tuple,
    # instead of scanning every previous combination for each node.
    parent_id_of = {prev_muts: parent_id for prev_muts, parent_id in prev_mut_list}
    mut_list = []
    for muts in combinations(aa_diff, n):
        mut_set_id = int(mut_set_diff.get(muts, ghost_mut_set_id))
        mut_set_group = mut_set_count.get(mut_set_id, ())
        if len(mut_set_group):
            date_num = mut_set_group["Date"].value_counts()
            for c_date, num in date_num.items():
                c_date_str = c_date.strftime("%Y-%m-%d")
                total = total_num.loc[c_date_str, "Total"]
                mut_freq.append({
                    "mut_set_id": mut_set_id,
                    "date": c_date_str,
                    "ratio": num / total
                })
        else:
            ghost_mut_set_id += 1  # the hypothetic mutant never appeared in the population

        # A parent is this combination minus exactly one mutation.  Dropping
        # the last mutation first visits the parents in the order in which
        # `combinations` produced them on the previous level.
        parent_link = []
        for k in reversed(range(n)):
            parent_muts = muts[:k] + muts[k + 1:]
            parent_link.extend((parent_id_of[parent_muts], *muts[k]))
        mut_list.append([muts, mut_set_id])
        # Sanity check: the mutations recorded on the parent links are exactly
        # the mutations of this node.
        recon_muts = sorted(tuple(parent_link[j + 1:j + 4]) for j in range(0, len(parent_link), 4))
        assert sorted(muts) == recon_muts
        mut_node.append([mut_set_id, *parent_link])
    print(len(mut_list))
    prev_mut_list = mut_list

mut_freq = pd.DataFrame.from_records(mut_freq)


1
13
78
286
715
1287
1716
1716
1287
715
286
78
13
1


In [8]:
# Persist the per-date frequencies as CSV (no index column).
mut_freq.to_csv(MUT_FREQ_FILE, index=False)


def _to_builtin(obj):
    """Fallback for json.dump: turn NumPy-style scalars into builtin values.

    json.dump calls this only for objects it cannot serialize itself; anything
    without an ``item()`` method is rejected with the usual TypeError.
    """
    if hasattr(obj, "item"):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Persist the node/parent-link lists as JSON.  Ids derived from pandas
# reductions can be NumPy integers, hence the fallback converter.
with open(MUT_NODE_FILE, "w") as f:
    json.dump(mut_node, f, default=_to_builtin)
