In [1]:
import os
os.chdir("/n/data1/hms/dbmi/oconnor/lab/shz311/pangenome/graph_var")

from graph_var.utils import load_graph_from_pkl, merge_dicts, log_action
from graph_var.graph import PangenomeGraph
from graph_var.evaluating_functions import *
import pandas as pd

In [2]:
version = 'v1'
ref_name = 'CHM13'

# Every versioned input/output directory shares one root and one "<version>[_chm13]" tag.
# NOTE(review): the raw HPRC VCFs and region BEDs used in Part 2 are GRCh38-named even
# when ref_name == 'CHM13' — confirm their coordinates match the graph VCFs.
_pangenome_root = "/n/data1/hms/dbmi/oconnor/lab/shz311/pangenome"
_ref_suffix = '_chm13' if ref_name == 'CHM13' else ''
_run_tag = f"{version}{_ref_suffix}"

graph_obj_dir = f"{_pangenome_root}/Graph_objs_{_run_tag}"
raw_vcf_dir = f"{_pangenome_root}/VCFs_chr"
graph_vcf_dir = f"{_pangenome_root}/VCFs_{_run_tag}"

ref_tree_dir = f"{_pangenome_root}/Data/reference_tree_gfa_{_run_tag}"
gfa_dir = f"{_pangenome_root}/Data/chromosome_gfa_{_run_tag}"
snarl_dir = f"{_pangenome_root}/Data/chr_snarls_{_run_tag}"
bubble_summary_dir = f"{_pangenome_root}/Bubble_summary_{_run_tag}"

var_summary_dir = f"{_pangenome_root}/Stats_chr_{_run_tag}"
data_vis_dir = f"{_pangenome_root}/Data_visualization_{_run_tag}"

region_dir = f"{_pangenome_root}/Region_files"

mode = 'AT'
exclude_terminus = True  # use the graph_chr*_no_terminus.vcf files in Parts 1-2

In [7]:
# Which chromosomes to process: 'autosome' (chr1-chr22), 'X', or 'Y'.
chr_set = 'Y'

_chr_lists_by_set = {
    'autosome': list(range(1, 23)),
    'X': ['X'],
    'Y': ['Y'],
}
if chr_set not in _chr_lists_by_set:
    raise ValueError("chr_set must be one of ['autosome', 'X', 'Y']")

chr_list = _chr_lists_by_set[chr_set]
num_chr = len(chr_list)

Part 1: Summary for all variant edges

In [8]:
def variant_edges_summary_from_dict(var_list: list, var_dict: dict):
    """Count variant edges per variant-type code.

    Args:
        var_list: collection (typically a set) of variant-edge IDs to count.
        var_dict: mapping from edge ID to its variant-type code (e.g. 'SNP', 'INS').

    Returns:
        dict mapping each variant-type code seen in ``var_list`` to its count,
        plus a ``'Total'`` entry equal to ``len(var_list)``.
    """
    summary_dict = dict()
    # Only per-type counts matter, so no sorting is needed (sorting would also
    # fail on IDs of mixed, non-comparable types).
    for edge in var_list:
        var_type = var_dict[edge]
        summary_dict[var_type] = summary_dict.get(var_type, 0) + 1
    summary_dict['Total'] = len(var_list)
    return summary_dict


def prepare_dataframe_dict(var_dict):
    """Render a {variant-type code: count} dict as the fixed 8-row summary table.

    Rows are SNP, MNP, Insertion, Deletion, Replacement, Inversion, Duplication,
    Total. Codes missing from ``var_dict`` count as 0. Codes outside these seven
    types get no row of their own but are still included in 'Total'.
    """
    label_code_pairs = [
        ('SNP', 'SNP'),
        ('MNP', 'MNP'),
        ('Insertion', 'INS'),
        ('Deletion', 'DEL'),
        ('Replacement', 'REP'),
        ('Inversion', 'INV'),
        ('Duplication', 'DUP'),
        ('Total', 'Total'),
    ]
    return pd.DataFrame({
        "Variant Type": [label for label, _ in label_code_pairs],
        "Count": [var_dict.get(code, 0) for _, code in label_code_pairs],
    })


def _parse_info(info: str) -> dict:
    """Parse a VCF INFO string 'K1=V1;K2=V2;FLAG' into a dict (flags without '=' are skipped).

    Splits on the first '=' only, so values that themselves contain '=' are kept intact.
    """
    return dict(field.split('=', 1) for field in info.split(';') if '=' in field)


def comprehensive_summary(graph_vcf_input: Union[str, pd.DataFrame],
                          chrom_label: str,
                          tandem_repeat: bool=False,
                          large_min_length: int=50,
                          common_min_ac: int=5):
    """Summarize graph-VCF variant edges by type across linear/size/frequency strata.

    Each record is assigned to three binary classes:
      * Linear / Off_Linear: Linear when the second comma-separated value of INFO DR is 0.
      * Small / Large: Small when len(REF) + len(ALT) < ``large_min_length`` or INFO NIA is
        non-zero. When REF is '.', INFO NR is used as the reference allele.
      * Common / Uncommon: Common when the allele count >= ``common_min_ac``. The allele
        count is INFO AC for inversions ('INV') and min(AC, RC) otherwise.
    The table reports counts per class and for every 2- and 3-way intersection.

    Args:
        graph_vcf_input: path to a graph VCF (read with read_vcf_line_by_line) or a
            DataFrame with 'ID', 'REF', 'ALT' and 'INFO' columns.
        chrom_label: value written into the CHROM column.
        tandem_repeat: if True, keep only records whose INFO TR_MOTIF is present and not '.'.
        large_min_length: summed allele-length threshold for "Large" (default 50).
        common_min_ac: allele-count threshold for "Common" (default 5).

    Returns:
        8-row DataFrame (one row per variant type plus 'Total') with CHROM,
        'Variant Type' and one '<stratum>_Variants' count column per stratum.

    Raises:
        ValueError: if ``graph_vcf_input`` is neither a str nor a DataFrame.
    """
    variant_sets = defaultdict(set)
    var_dict = dict()

    if isinstance(graph_vcf_input, str):
        row_iter = read_vcf_line_by_line(graph_vcf_input)
    elif isinstance(graph_vcf_input, pd.DataFrame):
        row_iter = graph_vcf_input.to_dict(orient="records")
    else:
        raise ValueError("Input must be either a path to VCF or a pandas DataFrame.")

    for row in row_iter:
        edge = row['ID']
        info_dict = _parse_info(row['INFO'])

        # A missing TR_MOTIF is treated like '.', so VCFs without the field still
        # work when tandem_repeat is False.
        if tandem_repeat and info_dict.get('TR_MOTIF', '.') == '.':
            continue

        var_type = info_dict['VT']
        nearly_identical = int(info_dict['NIA'])
        if var_type == 'INV':
            allele_count = int(info_dict['AC'])
        else:
            allele_count = min(int(info_dict['RC']), int(info_dict['AC']))
        ref_allele = row['REF'] if row['REF'] != '.' else info_dict['NR']
        allele_length = len(ref_allele) + len(row['ALT'])
        on_linear = int(info_dict['DR'].split(',')[1]) == 0

        var_dict[edge] = var_type
        variant_sets['All'].add(edge)
        variant_sets['Linear' if on_linear else 'Off_Linear'].add(edge)
        variant_sets['Small' if allele_length < large_min_length or nearly_identical else 'Large'].add(edge)
        variant_sets['Common' if allele_count >= common_min_ac else 'Uncommon'].add(edge)

    # Pairwise and three-way intersections of the three binary classifications.
    # The defaultdict creates empty sets for classes that received no edges.
    for a in ['Linear', 'Off_Linear']:
        for b in ['Small', 'Large']:
            variant_sets[f'{a}_{b}'] = variant_sets[a].intersection(variant_sets[b])
        for c in ['Common', 'Uncommon']:
            variant_sets[f'{a}_{c}'] = variant_sets[a].intersection(variant_sets[c])
    for b in ['Small', 'Large']:
        for c in ['Common', 'Uncommon']:
            variant_sets[f'{b}_{c}'] = variant_sets[b].intersection(variant_sets[c])
            for a in ['Linear', 'Off_Linear']:
                variant_sets[f'{a}_{b}_{c}'] = variant_sets[a].intersection(variant_sets[f'{b}_{c}'])

    # Accessing 'All' here also guarantees the key exists for empty input.
    all_var_df = prepare_dataframe_dict(variant_edges_summary_from_dict(variant_sets['All'], var_dict))
    summary_dict = {'CHROM': [chrom_label] * len(all_var_df), 'Variant Type': all_var_df['Variant Type']}
    for key, variant_set in variant_sets.items():
        df = prepare_dataframe_dict(variant_edges_summary_from_dict(variant_set, var_dict))
        summary_dict[key + "_Variants"] = df['Count']

    return pd.DataFrame(summary_dict, columns=[
        "CHROM",
        "Variant Type",
        "All_Variants",
        "Linear_Variants",
        "Off_Linear_Variants",
        "Small_Variants",
        "Large_Variants",
        "Common_Variants",
        "Uncommon_Variants",
        "Linear_Small_Variants",
        "Linear_Large_Variants",
        "Off_Linear_Small_Variants",
        "Off_Linear_Large_Variants",
        "Linear_Common_Variants",
        "Linear_Uncommon_Variants",
        "Off_Linear_Common_Variants",
        "Off_Linear_Uncommon_Variants",
        "Small_Common_Variants",
        "Small_Uncommon_Variants",
        "Large_Common_Variants",
        "Large_Uncommon_Variants",
        "Linear_Small_Common_Variants",
        "Linear_Large_Common_Variants",
        "Linear_Small_Uncommon_Variants",
        "Linear_Large_Uncommon_Variants",
        "Off_Linear_Small_Common_Variants",
        "Off_Linear_Large_Common_Variants",
        "Off_Linear_Small_Uncommon_Variants",
        "Off_Linear_Large_Uncommon_Variants"
        ])

In [9]:
# One comprehensive (all-variant) summary CSV per chromosome.
for i in tqdm(chr_list):
    chrom_label = f"chr{i}"
    terminus_tag = '_no_terminus' if exclude_terminus else ''
    graph_vcf_path = f"{graph_vcf_dir}/graph_{chrom_label}{terminus_tag}.vcf"

    comp_var_summary_df = comprehensive_summary(graph_vcf_path, chrom_label)
    comp_var_summary_df.to_csv(
        f"{var_summary_dir}/comprehensive_variant_summary_for_{chrom_label}.csv",
        index=False,
    )

100%|██████████| 1/1 [00:02<00:00,  2.49s/it]


Tandem-repeat variants only (drawn hatch-marked in the plots)

In [10]:
# Same per-chromosome summary, restricted to variants with a tandem-repeat motif.
for i in tqdm(chr_list):
    chrom_label = f"chr{i}"
    terminus_tag = '_no_terminus' if exclude_terminus else ''
    graph_vcf_path = f"{graph_vcf_dir}/graph_{chrom_label}{terminus_tag}.vcf"

    comp_var_summary_df = comprehensive_summary(graph_vcf_path, chrom_label, tandem_repeat=True)
    comp_var_summary_df.to_csv(
        f"{var_summary_dir}/comprehensive_repeated_variant_summary_for_{chrom_label}.csv",
        index=False,
    )

100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


Part 2: Summary for variant edges based on annotated regions

In [5]:
# Region BEDs used to stratify variants below.
# NOTE(review): these BEDs are GRCh38-named while ref_name may be 'CHM13' — confirm the
# graph VCF coordinates match before trusting the region splits.
easy_bed = f"{region_dir}/GRCh38_notinalldifficultregions.bed"  # "easy" regions (per file name: outside all difficult regions)
segdup_bed = f"{region_dir}/GRCh38_segdups.bed"                  # segmental duplications

# Tandem-repeat regions; get_interval_trees_from_bed splits them into one tree per region type.
TRregion_bed = f"{region_dir}/adotto_TRregions_v1.2.1.types.bed"

In [6]:
def process_raw_vcf(vcf_df):
    """Split multi-allelic VCF records into one row per ALT allele.

    Args:
        vcf_df: VCF records with at least '#CHROM', 'POS', 'REF', 'ALT' and 'INFO'
            columns. ALT may list several comma-separated alleles; INFO is expected
            to carry a matching comma-separated ``AC`` entry.

    Returns:
        DataFrame with columns ['#CHROM', 'POS', 'REF', 'ALT', 'AC'], one row per
        (record, ALT allele) pair. AC is kept as the string from INFO, or '.' when
        the record has no AC entry.
    """
    out_cols = ['#CHROM', 'POS', 'REF', 'ALT', 'AC']
    split_rows = []
    for chrom, pos, ref, alt_field, info in zip(vcf_df['#CHROM'], vcf_df['POS'], vcf_df['REF'],
                                                vcf_df['ALT'], vcf_df['INFO']):
        alts = alt_field.split(',')
        # Look AC up by key rather than assuming it is the first INFO field.
        info_dict = dict(field.split('=', 1) for field in info.split(';') if '=' in field)
        acs = info_dict['AC'].split(',') if 'AC' in info_dict else ['.'] * len(alts)
        # zip() pairs each ALT allele with its AC (stops at the shorter list).
        for alt, ac in zip(alts, acs):
            split_rows.append((chrom, pos, ref, alt, ac))
    # Explicit columns so an empty input still yields an (empty) frame with the expected schema.
    return pd.DataFrame(split_rows, columns=out_cols)

def process_ourvcf(graph_vcf_df):
    """Annotate graph-VCF records with variant type, linearity and allele count.

    Derived columns (from INFO):
      * Variant_Type: INFO VT.
      * Linear: True when the second comma-separated value of INFO DR is 0.
      * AC: INFO AC for inversions ('INV'), otherwise min(AC, RC), as int.

    Args:
        graph_vcf_df: VCF records with '#CHROM', 'POS', 'REF', 'ALT' and 'INFO' columns.
            It is not modified.

    Returns:
        New DataFrame with columns ['#CHROM', 'POS', 'REF', 'ALT', 'AC', 'Linear', 'Variant_Type'].
    """
    def parse_info(info):
        # Split on the first '=' only so values containing '=' stay intact.
        return dict(field.split('=', 1) for field in info.split(';') if '=' in field)

    def allele_count(info_dict):
        if info_dict['VT'] == 'INV':
            return int(info_dict['AC'])
        return min(int(info_dict['AC']), int(info_dict['RC']))

    # Parse each INFO string once and derive all three columns from the parsed dicts.
    parsed = graph_vcf_df['INFO'].map(parse_info)
    annotated = graph_vcf_df.assign(
        Variant_Type=parsed.map(lambda d: d['VT']),
        Linear=parsed.map(lambda d: int(d['DR'].split(',')[1]) == 0),
        AC=parsed.map(allele_count),
    )
    return annotated[['#CHROM', 'POS', 'REF', 'ALT', 'AC', 'Linear', 'Variant_Type']]

In [8]:
def _split_by_region(vcf_df, easy_region, segdup_region):
    """Split ``vcf_df`` rows into (easy, segdup, difficult) by point lookups of POS in two interval trees.

    A row is "easy" / "segdup" when its integer POS hits at least one interval of the
    corresponding tree, and "difficult" when it hits neither.
    """
    in_easy = vcf_df['POS'].apply(lambda pos: len(easy_region[int(pos)]) > 0)
    in_segdup = vcf_df['POS'].apply(lambda pos: len(segdup_region[int(pos)]) > 0)
    in_diff = ~in_easy & ~in_segdup
    return vcf_df[in_easy], vcf_df[in_segdup], vcf_df[in_diff]


def get_df_from_vcf_filtered_by_region(chr_id, raw_vcf=True, graph_vcf=True):
    """Load the raw, vcfwave and graph VCFs for one chromosome and split each by region.

    Args:
        chr_id: chromosome name, e.g. "chr1".
        raw_vcf: if True, load and split the raw and vcfwave VCFs (both via process_raw_vcf).
        graph_vcf: if True, load and split the graph VCF (via process_ourvcf).

    Returns:
        dict with keys '{easy,segdup,diff}_{raw,wave,graph}' (in that order of groups:
        raw, wave, graph). Entries for sources that were not requested are None.
    """
    rawvcf_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.raw_{chr_id}.vcf"
    vcfwave_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.vcfbub.a100k.wave_{chr_id}.vcf"
    graph_vcf_path = f"{graph_vcf_dir}/graph_{chr_id}{'_no_terminus' if exclude_terminus else ''}.vcf"

    # NOTE(review): POS (1-based VCF) is used as-is for the point lookup; whether that
    # matches BED (0-based, half-open) intervals depends on get_interval_tree_from_bed — confirm.
    easy_region = get_interval_tree_from_bed(easy_bed, chr_id)
    segdup_region = get_interval_tree_from_bed(segdup_bed, chr_id)

    result = dict.fromkeys([
        "easy_raw", "segdup_raw", "diff_raw",
        "easy_wave", "segdup_wave", "diff_wave",
        "easy_graph", "segdup_graph", "diff_graph",
    ])

    if raw_vcf:
        raw_df = process_raw_vcf(read_vcf_to_dataframe(rawvcf_path))
        result["easy_raw"], result["segdup_raw"], result["diff_raw"] = \
            _split_by_region(raw_df, easy_region, segdup_region)

        wave_df = process_raw_vcf(read_vcf_to_dataframe(vcfwave_path))
        result["easy_wave"], result["segdup_wave"], result["diff_wave"] = \
            _split_by_region(wave_df, easy_region, segdup_region)

    if graph_vcf:
        graph_df = process_ourvcf(read_vcf_to_dataframe(graph_vcf_path))
        result["easy_graph"], result["segdup_graph"], result["diff_graph"] = \
            _split_by_region(graph_df, easy_region, segdup_region)

    return result

def get_df_from_vcf(chr_id, not_in_easy_region=False):
    """Load the graph, raw and vcfwave VCFs for one chromosome as processed DataFrames.

    Args:
        chr_id: chromosome name, e.g. "chr1".
        not_in_easy_region: when True, keep only rows whose POS hits no interval of
            the easy-region BED.

    Returns:
        (graph_df, raw_df, wave_df): graph_df from process_ourvcf; raw_df and
        wave_df from process_raw_vcf.
    """
    rawvcf_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.raw_{chr_id}.vcf"
    vcfwave_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.vcfbub.a100k.wave_{chr_id}.vcf"
    graph_vcf_path = f"{graph_vcf_dir}/graph_{chr_id}{'_no_terminus' if exclude_terminus else ''}.vcf"

    graph_df = process_ourvcf(read_vcf_to_dataframe(graph_vcf_path))
    raw_df = process_raw_vcf(read_vcf_to_dataframe(rawvcf_path))
    wave_df = process_raw_vcf(read_vcf_to_dataframe(vcfwave_path))

    if not_in_easy_region:
        # Build the easy-region interval tree once and reuse it for all three frames.
        easy_region = get_interval_tree_from_bed(easy_bed, chr_id)

        def outside_easy(vcf_df):
            return vcf_df[vcf_df['POS'].apply(lambda pos: len(easy_region[int(pos)]) == 0)]

        graph_df = outside_easy(graph_df)
        raw_df = outside_easy(raw_df)
        wave_df = outside_easy(wave_df)

    return graph_df, raw_df, wave_df

def get_vcf_df_separated_by_regions(chr_id):
    """Partition one chromosome's graph VCF by tandem-repeat region type and segdup overlap.

    For each region type returned by get_interval_trees_from_bed(TRregion_bed, ...), two
    buckets are produced: '<type>_segdup' and '<type>_nonsegdup'. A record whose POS hits
    several region types appears in each of them. Records hitting no region type go to
    'Unspecified_segdup' / 'Unspecified_nonsegdup'.

    Args:
        chr_id: chromosome name, e.g. "chr1".

    Returns:
        dict mapping bucket name to the matching rows of the raw graph-VCF DataFrame
        (as returned by read_vcf_to_dataframe; process_ourvcf is not applied).
    """
    graph_vcf_path = f"{graph_vcf_dir}/graph_{chr_id}{'_no_terminus' if exclude_terminus else ''}.vcf"
    graph_vcf_df = read_vcf_to_dataframe(graph_vcf_path)

    region_intervaltree_dict = get_interval_trees_from_bed(TRregion_bed, chr_id)
    segdup_region = get_interval_tree_from_bed(segdup_bed, chr_id)

    # Compute every interval lookup once per position rather than once per bucket.
    positions = graph_vcf_df['POS'].map(int)
    in_segdup = positions.map(lambda pos: len(segdup_region[pos]) > 0).astype(bool)
    in_any_tr = pd.Series(False, index=graph_vcf_df.index)

    TR_region_dict = {}
    for region, tree in region_intervaltree_dict.items():
        in_region = positions.map(lambda pos: len(tree[pos]) > 0).astype(bool)
        in_any_tr = in_any_tr | in_region
        TR_region_dict[region + "_segdup"] = graph_vcf_df[in_region & in_segdup]
        TR_region_dict[region + "_nonsegdup"] = graph_vcf_df[in_region & ~in_segdup]

    TR_region_dict['Unspecified_segdup'] = graph_vcf_df[~in_any_tr & in_segdup]
    TR_region_dict['Unspecified_nonsegdup'] = graph_vcf_df[~in_any_tr & ~in_segdup]
    return TR_region_dict

Region-based variant type summary

In [9]:
def region_df_var_summary(region_df):
    """Count the rows of ``region_df`` per 'Variant_Type' value and add a 'Total' entry.

    An empty frame yields {'Total': 0}; it need not have a 'Variant_Type' column.
    """
    counts = {}
    if len(region_df) > 0:
        for variant_type in region_df['Variant_Type']:
            counts[variant_type] = counts.get(variant_type, 0) + 1
    counts['Total'] = len(region_df)
    return counts

def prepare_dataframe_dict(var_dict):
    """Render a {variant-type code: count} dict as the fixed 8-row summary table.

    Codes missing from ``var_dict`` count as 0.

    NOTE(review): this redefines (and shadows) the Part 1 helper of the same name with
    identical behavior — keep the two in sync, or drop one.
    """
    code_for_label = {
        'SNP': 'SNP',
        'MNP': 'MNP',
        'Insertion': 'INS',
        'Deletion': 'DEL',
        'Replacement': 'REP',
        'Inversion': 'INV',
        'Duplication': 'DUP',
        'Total': 'Total',
    }
    labels = list(code_for_label)
    counts = [var_dict.get(code_for_label[label], 0) for label in labels]
    return pd.DataFrame({"Variant Type": labels, "Count": counts})

In [31]:
# Variant-type counts of graph-VCF variants per region (easy / segdup / difficult), per chromosome.
easy_list = []
segdup_list = []
diff_list = []
_region_count_lists = (
    ('easy_graph', easy_list),
    ('segdup_graph', segdup_list),
    ('diff_graph', diff_list),
)

for i in tqdm(chr_list):
    var_dicts = get_df_from_vcf_filtered_by_region(f"chr{i}", raw_vcf=False)
    for region_key, count_list in _region_count_lists:
        region_summary = prepare_dataframe_dict(region_df_var_summary(var_dicts[region_key]))
        count_list.extend(region_summary['Count'].to_list())

100%|██████████| 1/1 [00:09<00:00,  9.01s/it]


In [32]:
# Per-chromosome variant-type counts by region, plus an aggregated 'all' block.
_variant_type_labels = ['SNP', 'MNP', 'Insertion', 'Deletion', 'Replacement', 'Inversion', 'Duplication', 'Total']
_n_types = len(_variant_type_labels)

var_region_df = pd.DataFrame({
    # Label rows with the real chromosome names (so chrX / chrY runs are not labelled chr1).
    'CHROM': [f'chr{chr_list[i // _n_types]}' for i in range(_n_types * num_chr)],
    'Variant_type': _variant_type_labels * num_chr,
    'Easy_region': easy_list,
    'Segdup_region': segdup_list,
    'Difficult_region': diff_list,
})

all_df = (
    var_region_df.drop(columns=['CHROM'])
    .groupby('Variant_type', as_index=False).sum()
    .set_index('Variant_type')
    .reindex(_variant_type_labels)
    .reset_index()
)
all_df['CHROM'] = 'all'

var_region_concated_df = pd.concat([var_region_df, all_df])

In [33]:
# Persist the per-chromosome + 'all' region summary for plotting.
var_region_concated_df.to_csv(
    f"{data_vis_dir}/variant_summary_by_region_{chr_set}.csv",
    index=False,
)

Region-based SNP comparison between our graph VCF and the vcfwave / raw VCFs

In [34]:
# Region-based SNP comparison of our graph VCF against the raw and vcfwave VCFs.
# Each list below gets one entry per (chromosome, region) in the order easy, segdup, diff.
complex_count_list_raw = []
complex_count_list_wave = []
nonlinear_count_list = []

shared_count_list_raw = []
shared_count_list_wave = []

our_count_list_raw = []
our_count_list_wave = []
raw_count_list = []
wave_count_list = []


def _snp_rows(vcf_df, is_graph):
    """Return the SNP rows of a region frame.

    Graph rows are selected by their 'Variant_Type' annotation. Raw / vcfwave rows are
    selected by having a 1-bp REF and a 1-bp ALT. The masks are vectorized, so empty
    frames are handled without special cases.
    """
    if is_graph:
        mask = vcf_df['Variant_Type'] == 'SNP'
    else:
        mask = (vcf_df['REF'].str.len() == 1) & (vcf_df['ALT'].str.len() == 1)
    return vcf_df[mask]


def _region_snp_counts(graph_snps, raw_snps, wave_snps):
    """Shared / private / linear / off-linear SNP counts for one region of one chromosome."""
    merge_keys = ['#CHROM', 'POS', 'REF', 'ALT']
    n_graph = len(graph_snps)
    n_linear = len(graph_snps[graph_snps['Linear'] == 1])
    n_offlinear = len(graph_snps[graph_snps['Linear'] == 0])
    # A graph SNP is "shared" when (#CHROM, POS, REF, ALT) also appears in the other VCF.
    shared_raw = len(pd.merge(graph_snps, raw_snps, how='inner', on=merge_keys))
    shared_wave = len(pd.merge(graph_snps, wave_snps, how='inner', on=merge_keys))
    return {
        'shared_raw': shared_raw,
        'shared_wave': shared_wave,
        'our_raw': n_graph - shared_raw,
        'our_wave': n_graph - shared_wave,
        'raw_only': len(raw_snps) - shared_raw,
        'wave_only': len(wave_snps) - shared_wave,
        # NOTE(review): the shared counts include any matched off-linear SNPs, so
        # "linear minus shared" can undercount linear-only SNPs — confirm intended.
        'linear_raw': n_linear - shared_raw,
        'linear_wave': n_linear - shared_wave,
        'offlinear': n_offlinear,
    }


for i in tqdm(chr_list):
    var_dicts = get_df_from_vcf_filtered_by_region(f"chr{i}")
    for region in ('easy', 'segdup', 'diff'):
        counts = _region_snp_counts(
            _snp_rows(var_dicts[f'{region}_graph'], is_graph=True),
            _snp_rows(var_dicts[f'{region}_raw'], is_graph=False),
            _snp_rows(var_dicts[f'{region}_wave'], is_graph=False),
        )
        complex_count_list_raw.append(counts['linear_raw'])
        complex_count_list_wave.append(counts['linear_wave'])
        nonlinear_count_list.append(counts['offlinear'])

        shared_count_list_raw.append(counts['shared_raw'])
        our_count_list_raw.append(counts['our_raw'])
        raw_count_list.append(counts['raw_only'])

        shared_count_list_wave.append(counts['shared_wave'])
        our_count_list_wave.append(counts['our_wave'])
        wave_count_list.append(counts['wave_only'])

100%|██████████| 1/1 [00:22<00:00, 22.96s/it]


In [35]:
# Assemble the per-chromosome region table and append an 'all' block summed over chromosomes.
_region_order = ['easy', 'segdup', 'difficult']
_n_regions = len(_region_order)

region_df = pd.DataFrame({
    'CHROM': [f'chr{chr_list[i // _n_regions]}' for i in range(_n_regions * num_chr)],
    'Region': _region_order * num_chr,
    'Shared_raw': shared_count_list_raw,
    'Shared_wave': shared_count_list_wave,
    'Ourvcf_only_raw': our_count_list_raw,
    'Ourvcf_only_wave': our_count_list_wave,
    'Rawvcf_only': raw_count_list,
    'Vcfwave_only': wave_count_list,
    'Ourvcf_linear_raw': complex_count_list_raw,
    'Ourvcf_linear_wave': complex_count_list_wave,
    'Ourvcf_offlinear': nonlinear_count_list
})

all_df = (
    region_df.drop(columns=['CHROM'])
    .groupby('Region', as_index=False).sum()
    .set_index('Region')
    .reindex(_region_order)
    .reset_index()
)
all_df['CHROM'] = 'all'

region_concated_df = pd.concat([region_df, all_df])

In [36]:
# Persist the SNP region comparison (per chromosome + 'all').
region_concated_df.to_csv(
    f"{data_vis_dir}/snp_region_summary_ourvcf_vs_vcfwave_{chr_set}.csv",
    index=False,
)

In [37]:
# Display the per-chromosome + 'all' SNP region comparison table.
region_concated_df

Unnamed: 0,CHROM,Region,Shared_raw,Shared_wave,Ourvcf_only_raw,Ourvcf_only_wave,Rawvcf_only,Vcfwave_only,Ourvcf_linear_raw,Ourvcf_linear_wave,Ourvcf_offlinear
0,chrY,easy,3196,3204,37,29,0,6,9,1,28
1,chrY,segdup,12832,12226,1506,2112,0,122,258,864,1248
2,chrY,difficult,11435,8565,13196,16066,0,2819,1544,4414,11652
0,all,easy,3196,3204,37,29,0,6,9,1,28
1,all,segdup,12832,12226,1506,2112,0,122,258,864,1248
2,all,difficult,11435,8565,13196,16066,0,2819,1544,4414,11652


SNP summary by allele-count range

In [38]:
# When True, the allele-count analysis below drops variants inside the easy region.
not_in_easy_region = True

In [39]:
# SNP comparison stratified by allele-count (AC) range.
# Each list below gets one entry per (chromosome, AC bin).
simple_count_list_raw = []
complex_count_list_raw = []
simple_count_list_wave = []
complex_count_list_wave = []

our_count_list_raw = []
our_count_list_wave = []
nonlinear_count_list = []

def snp_list_by_ac(chr_id, ac_bins=((1, 1), (2, 4), (5, 18), (19, None)), exclude_easy=None):
    """Count graph-VCF SNPs per allele-count bin and how many match the raw / vcfwave VCFs.

    Args:
        chr_id: chromosome name, e.g. "chr1".
        ac_bins: inclusive (low, high) AC ranges; ``high=None`` means unbounded.
            The default reproduces the 1 / 2-4 / 5-18 / 19+ bins.
        exclude_easy: forwarded to get_df_from_vcf as ``not_in_easy_region``.
            When None, the notebook-level ``not_in_easy_region`` flag is used.

    Returns:
        Seven lists with one entry per bin, in this order:
        simple_counts_raw, simple_counts_wave   -- linear SNPs whose (#CHROM, POS, REF, ALT)
                                                   appears in the raw / vcfwave VCF
        complex_counts_raw, complex_counts_wave -- linear SNPs minus those matches
        our_counts_raw, our_counts_wave         -- all SNPs in the bin minus those matches
        nonlinear_counts                        -- SNPs with Linear == 0
    """
    if exclude_easy is None:
        exclude_easy = not_in_easy_region
    graph_vcf, raw_vcf, wave_vcf = get_df_from_vcf(chr_id, exclude_easy)
    graph_snps = graph_vcf[graph_vcf['Variant_Type'] == 'SNP']
    merge_keys = ['#CHROM', 'POS', 'REF', 'ALT']

    simple_counts_raw, simple_counts_wave = [], []
    complex_counts_raw, complex_counts_wave = [], []
    our_counts_raw, our_counts_wave = [], []
    nonlinear_counts = []

    for low, high in ac_bins:
        in_bin = graph_snps['AC'] >= low
        if high is not None:
            in_bin = in_bin & (graph_snps['AC'] <= high)
        bin_snps = graph_snps[in_bin]
        linear_snps = bin_snps[bin_snps['Linear'] == 1]

        # Only the linear SNPs of the bin are matched against the raw / vcfwave VCFs.
        n_shared_raw = len(pd.merge(linear_snps, raw_vcf, on=merge_keys))
        n_shared_wave = len(pd.merge(linear_snps, wave_vcf, on=merge_keys))

        simple_counts_raw.append(n_shared_raw)
        simple_counts_wave.append(n_shared_wave)
        complex_counts_raw.append(len(linear_snps) - n_shared_raw)
        complex_counts_wave.append(len(linear_snps) - n_shared_wave)
        our_counts_raw.append(len(bin_snps) - n_shared_raw)
        our_counts_wave.append(len(bin_snps) - n_shared_wave)
        nonlinear_counts.append(len(bin_snps[bin_snps['Linear'] == 0]))

    return simple_counts_raw, simple_counts_wave, complex_counts_raw, complex_counts_wave, our_counts_raw, our_counts_wave, nonlinear_counts

for i in tqdm(chr_list):
    simple_counts_raw, simple_counts_wave, complex_counts_raw, complex_counts_wave, our_counts_raw, our_counts_wave, nonlinear_counts = snp_list_by_ac(f"chr{i}")
    simple_count_list_raw += simple_counts_raw
    complex_count_list_raw += complex_counts_raw
    simple_count_list_wave += simple_counts_wave
    complex_count_list_wave += complex_counts_wave
    our_count_list_raw += our_counts_raw
    our_count_list_wave += our_counts_wave
    nonlinear_count_list += nonlinear_counts

100%|██████████| 1/1 [00:21<00:00, 21.61s/it]


In [40]:
# Assemble the per-chromosome AC-range table and append an 'all' block summed over chromosomes.
_ac_range_labels = ["1", "2-4", "5-18", "19+"]
_n_ranges = len(_ac_range_labels)

range_df = pd.DataFrame({
    'CHROM': [f'chr{chr_list[i // _n_ranges]}' for i in range(_n_ranges * num_chr)],
    'Range': _ac_range_labels * num_chr,
    'Shared_raw': simple_count_list_raw,
    'Shared_wave': simple_count_list_wave,
    'Ourvcf_raw': our_count_list_raw,
    'Ourvcf_wave': our_count_list_wave,
    'Ourvcf_linear_raw': complex_count_list_raw,
    'Ourvcf_linear_wave': complex_count_list_wave,
    'Ourvcf_offlinear': nonlinear_count_list,
})

all_df = (
    range_df.drop(columns=['CHROM'])
    .groupby('Range', as_index=False).sum()
    .set_index('Range')
    .reindex(_ac_range_labels)
    .reset_index()
)
all_df['CHROM'] = 'all'

range_concated_df = pd.concat([range_df, all_df])

In [41]:
# Persist the SNP allele-count range comparison (per chromosome + 'all').
range_concated_df.to_csv(
    f"{data_vis_dir}/snp_ac_range_summary_ourvcf_vs_vcfwave_{chr_set}.csv",
    index=False,
)

In [42]:
# Display the per-chromosome + 'all' SNP allele-count range table.
range_concated_df

Unnamed: 0,CHROM,Range,Shared_raw,Shared_wave,Ourvcf_raw,Ourvcf_wave,Ourvcf_linear_raw,Ourvcf_linear_wave,Ourvcf_offlinear
0,chrY,1,10134,8339,7566,9361,635,2430,6931
1,chrY,2-4,10739,9375,5979,7343,988,2352,4991
2,chrY,5-18,3394,3077,1157,1474,179,496,978
3,chrY,19+,0,0,0,0,0,0,0
0,all,1,10134,8339,7566,9361,635,2430,6931
1,all,2-4,10739,9375,5979,7343,988,2352,4991
2,all,5-18,3394,3077,1157,1474,179,496,978
3,all,19+,0,0,0,0,0,0,0


Summarize variants by tandem-repeat (TR) region type, using the TR region BED file

In [7]:
import pandas as pd
from functools import reduce

def merge_dfs(dfs):
    """Combine same-shaped per-chromosome summary tables into one table.

    Integer columns are summed element-wise across all frames (pandas aligns
    on the row index). Non-integer columns are taken from the first frame,
    except the first non-integer column, which is dropped. Label columns come
    first in the result, followed by the summed integer columns.

    Args:
        dfs: non-empty list of DataFrames sharing the same columns.

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    first = dfs[0]
    integer_columns = first.select_dtypes(include='int').columns
    # Skip the leading non-integer column (per-chromosome label in practice — TODO confirm).
    label_columns = first.select_dtypes(exclude='int').columns[1:]

    totals = first[integer_columns]
    for frame in dfs[1:]:
        totals = totals + frame[integer_columns]

    labels = first[label_columns].reset_index(drop=True)
    return pd.concat([labels, totals.reset_index(drop=True)], axis=1)

In [8]:
# This section runs over the autosomes. Keep chr_set and num_chr in sync with
# chr_list so that file names built from chr_set (e.g. the TR-region Excel
# summary) describe the chromosomes the data actually covers.
chr_set = 'autosome'
chr_list = list(range(1, 23))
num_chr = len(chr_list)

In [9]:
# For every chromosome, split its variants by TR region, summarise each region,
# and collect the summaries per region; then merge each region across chromosomes.
TR_regions_dicts = defaultdict(list)

for i in tqdm(chr_list):
    chrom_name = f"chr{i}"
    region_dfs = get_vcf_df_separated_by_regions(chrom_name)

    TR_regions_dict = {}
    for region_name, region_df in region_dfs.items():
        TR_regions_dict[region_name] = comprehensive_summary(region_df, chrom_name)

    for region, var_df in TR_regions_dict.items():
        TR_regions_dicts[region].append(var_df)

TR_regions_dict_all = {region: merge_dfs(var_dfs)
                       for region, var_dfs in TR_regions_dicts.items()}

  0%|          | 0/22 [00:00<?, ?it/s]

100%|██████████| 22/22 [2:41:08<00:00, 439.47s/it]  


In [10]:
# Write one worksheet per TR region containing its cross-chromosome summary.
file_path = (f"{var_summary_dir}/TR_regions_var_comprehensive_summary_all_chrs_"
             f"{chr_set}.xlsx")
sheets = TR_regions_dict_all

with pd.ExcelWriter(file_path, engine="openpyxl") as writer:
    for sheet, df in sheets.items():
        df.to_excel(writer, index=False, sheet_name=sheet)

Part 3: Triallelic/Multiallelic bubble analysis

1. Generate a table mapping each bubble ID to the variant edges it contains

In [None]:
# Build the per-chromosome bubble summary (bubble -> contained variant edges).
# Use the same graph VCF as the downstream cells, which read
# graph_chr*{_no_terminus}.vcf and look up every `Within` variant in it, so the
# exclude_terminus setting must apply here too.
for chrom in chr_list:
    print(f"Processing chr{chrom}...")
    snarl_path = f"{snarl_dir}/chr{chrom}.snarls"
    graph_vcf_path = f"{graph_vcf_dir}/graph_chr{chrom}{'_no_terminus' if exclude_terminus else ''}.vcf"

    write_bubble_summary_result(chr_name=f"chr{chrom}",
                                snarl_path=snarl_path,
                                vcf_path=graph_vcf_path,
                                output_dir=bubble_summary_dir)

Processing chr1...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr2...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr3...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr4...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr5...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr6...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr7...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr8...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing chr9...
Assigning node to bubbles...
Conducting bubble summary...
Writing bubble summary to CSV...
Processing

2. Generate a per-bubble table comparing the number of graph variants with the allele counts (and allele lengths) reported in the raw and vcfwave VCFs

In [25]:
# The bubble-level analyses below run over the 22 autosomes.
chr_list = [chrom for chrom in range(1, 23)]

In [35]:
# For each chromosome, write a per-bubble table comparing our graph variants
# (count and ALT lengths) against the raw VCF's and vcfwave's allele counts.
for chrom in tqdm(chr_list):
    summary_path = f"{bubble_summary_dir}/bubble_variant_counts_chr{chrom}_AT.tsv"
    graph_vcf_path = f"{graph_vcf_dir}/graph_chr{chrom}{'_no_terminus' if exclude_terminus else ''}.vcf"
    vcf_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.raw_chr{chrom}.vcf"
    vcfwave_path = f"{raw_vcf_dir}/hprc-v1.1-mc-grch38.vcfbub.a100k.wave_chr{chrom}.vcf"

    AT_csv = pd.read_csv(summary_path, sep='\t')

    # Graph-VCF records keyed like the entries of the AT summary's `Within`
    # column: [POS, REF, ALT, parsed INFO dict].
    vcf_dict = {extract_bubble_ids(row['ID'], symbol=True):
                    [row['POS'], row['REF'], row['ALT'],
                     get_info_dict(row['INFO'])]
                for row in read_vcf_line_by_line(graph_vcf_path)}

    AT_bubble_level_dict = {}
    AT_allele_count_dict = {}
    AT_allele_len_dict = {}

    for row in AT_csv.itertuples():
        bubble_id = tuple(sorted(eval(row.Bubble)))
        variants = eval(row.Within)

        AT_bubble_level_dict[bubble_id] = int(row.Level)
        AT_allele_count_dict[bubble_id] = len(variants)

        # ALT length of each graph variant in the bubble ('.' counts as 0).
        var_len_list = [len(vcf_dict[var][2]) if vcf_dict[var][2] != '.' else 0
                        for var in variants]
        AT_allele_len_dict[bubble_id] = ((sum(var_len_list) / len(var_len_list), max(var_len_list))
                                         if var_len_list else (0, 0))

    # Single pass over the raw VCF: per bubble, the number of ALT alleles and
    # the (mean, max) ALT allele length.
    vcf_allele_count_dict = {}
    vcf_allele_length_dict = {}
    for row in read_vcf_line_by_line(vcf_path):
        raw_bubble = tuple(sorted(extract_bubble_ids(row['ID'])))
        alt_lengths = [len(alt) for alt in row['ALT'].split(',')]
        vcf_allele_count_dict[raw_bubble] = len(alt_lengths)
        vcf_allele_length_dict[raw_bubble] = (sum(alt_lengths) / len(alt_lengths), max(alt_lengths))

    # Number of vcfwave records per bubble, keyed by the ID prefix before the first '_'.
    vcfwave_csv = read_vcf_to_dataframe(vcfwave_path)
    vcfwave_allele_count_dict = vcfwave_csv['ID'].apply(lambda x:
                tuple(sorted(extract_bubble_ids(x.split('_')[0])))).value_counts().to_dict()

    bubble_allelic_records = []
    for bubble in AT_allele_count_dict:
        raw_avg_len, raw_max_len = vcf_allele_length_dict.get(bubble, ('.', '.'))
        bubble_allelic_records.append({
            'Bubble': tuple(sorted(bubble, key=lambda x: (len(x), x))),
            'Level': AT_bubble_level_dict[bubble],
            'num_variants': AT_allele_count_dict[bubble],
            'avg_allele_length': AT_allele_len_dict[bubble][0],
            'max_allele_length': AT_allele_len_dict[bubble][1],
            'raw_vcf_allele_count': vcf_allele_count_dict.get(bubble, '.'),
            'vcfwave_allele_count': vcfwave_allele_count_dict.get(bubble, '.'),
            'raw_vcf_avg_allele_length': raw_avg_len,
            'raw_vcf_max_allele_length': raw_max_len,
        })
    bubble_allelic_df = pd.DataFrame(bubble_allelic_records)
    bubble_allelic_df.to_csv(f"{bubble_summary_dir}/bubble_allele_summary_chr{chrom}.tsv", sep='\t')

100%|██████████| 22/22 [1:18:32<00:00, 214.22s/it]


3. Extract triallelic variants (the variants of top-level bubbles that contain exactly two variant edges)

In [None]:
triallelic_var_dfs = []

# For each chromosome, keep the graph-VCF records whose bubble ID belongs to a
# top-level (Level 0) bubble containing exactly two variant edges.
for i in tqdm(chr_list):
    graph_vcf_path = f"{graph_vcf_dir}/graph_chr{i}{'_no_terminus' if exclude_terminus else ''}.vcf"
    graph_vcf_df = read_vcf_to_dataframe(graph_vcf_path)

    bubble_summary_path = f"{bubble_summary_dir}/bubble_variant_counts_chr{i}_AT.tsv"
    bubble_summary_df = pd.read_csv(bubble_summary_path, sep = '\t')

    # One pass over the two needed columns: no per-row .iloc Series and a
    # single eval of each `Within` string.
    triallelic_vars = set()
    for level, within in zip(bubble_summary_df['Level'], bubble_summary_df['Within']):
        if int(level) != 0:
            continue
        bubble_variants = eval(within)
        if len(bubble_variants) == 2:
            triallelic_vars.update(bubble_variants)

    bubble_keys = graph_vcf_df['ID'].apply(lambda x: extract_bubble_ids(x, symbol=True))
    triallelic_var_df = graph_vcf_df[bubble_keys.isin(triallelic_vars)]

    triallelic_var_df.to_csv(f"{bubble_summary_dir}/triallelic_variants_chr{i}.tsv", sep='\t')

Processing chr1...
Processing chr2...
Processing chr3...
Processing chr4...
Processing chr5...
Processing chr6...
Processing chr7...
Processing chr8...
Processing chr9...
Processing chr10...
Processing chr11...
Processing chr12...
Processing chr13...
Processing chr14...
Processing chr15...
Processing chr16...
Processing chr17...
Processing chr18...
Processing chr19...
Processing chr20...
Processing chr21...
Processing chr22...


In [6]:
# Tally the triallelic graph variants by the `VT` field of their INFO column,
# across all chromosomes, then show the per-type counts and the grand total.
triallelic_var_summary = {}

for i in chr_list:
    triallelic_var_path = f"{bubble_summary_dir}/triallelic_variants_chr{i}.tsv"
    triallelic_var_df = pd.read_csv(triallelic_var_path, sep='\t')
    for info in triallelic_var_df['INFO']:
        variant_type = get_info_dict(info)['VT']
        triallelic_var_summary[variant_type] = triallelic_var_summary.get(variant_type, 0) + 1

print(triallelic_var_summary)
print(sum(triallelic_var_summary.values()))
print(sum(triallelic_var_summary.values()))

{'deletion': 356663, 'insertion': 392492, 'snp': 726627, 'mnp': 53508, 'replacement': 17726}
1547016


In [3]:
import pandas as pd

# NOTE(review): hard-coded denominator; its source is not shown in this
# notebook section. TODO confirm which population of variants it counts.
OFF_LINEAR_DENOMINATOR = 3472802

# A triallelic variant counts as off-linear here when its REF is '.'.
off_linear_count = 0
for i in chr_list:
    triallelic_path = f"{bubble_summary_dir}/triallelic_variants_chr{i}.tsv"
    triallelic_df = pd.read_csv(triallelic_path, sep='\t')
    off_linear_count += (triallelic_df['REF'] == '.').sum()

print(off_linear_count / OFF_LINEAR_DENOMINATOR)

0.035435075192884594


In [4]:
# Total number of off-linear (REF == '.') triallelic variants across chromosomes.
off_linear_count

123059

4. Categorize triallelic bubbles

In [7]:
# Classify each triallelic candidate bubble with the graph's
# classify_triallelic_bubble and write one TSV of bubble types per chromosome.
for i in tqdm(chr_list):
    AT_csv = pd.read_csv(f"{bubble_summary_dir}/bubble_variant_counts_chr{i}_AT.tsv", sep='\t')
    gfa_path = f"{gfa_dir}/chr{i}.gfa"
    pkl_path = f"{graph_obj_dir}/chr{i}.pkl"

    # Prefer the pickled graph; otherwise parse the GFA and also write its reference tree.
    if os.path.exists(pkl_path):
        G = load_graph_from_pkl(pkl_path, compressed=False)
    else:
        G = PangenomeGraph.from_gfa_line_by_line(gfa_path, compressed=False)
        write_dfs_tree_to_gfa(G, f"{ref_tree_dir}/chr{i}_ref_tree.gfa")

    # Candidates: top-level (Level 0) bubbles containing exactly two variant edges.
    bubble_candidate = AT_csv.apply(lambda x: len(eval(x['Within'])) == 2 and int(x['Level']) == 0, axis=1)

    bubble_ids = AT_csv['Bubble'][bubble_candidate].apply(lambda x: tuple(eval(x)))
    bubble_vars = AT_csv['Within'][bubble_candidate].apply(lambda x: list(eval(x)))

    bubble_var_dict = dict(zip(bubble_ids, bubble_vars))

    failed_bubble = []
    error_dict = defaultdict(list)
    bubble_type_dict = {}

    for bubble, bubble_var in zip(bubble_ids, bubble_vars):
        try:
            bubble_type_dict[bubble] = G.classify_triallelic_bubble([bubble[0], bubble[1]], bubble_var)
        except Exception as e:
            # Best effort: keep going, but record the failure so it can be reported.
            failed_bubble.append(bubble)
            error_dict['Error'].append(e)
            error_dict['Bubble'].append(bubble)

    # Failed bubbles are absent from the TSV; say so instead of dropping them silently.
    if failed_bubble:
        tqdm.write(f"chr{i}: could not classify {len(failed_bubble)} of {len(bubble_ids)} "
                   f"candidate bubbles (first: {failed_bubble[0]} -> {error_dict['Error'][0]!r})")

    bubble_triallelic_df = pd.DataFrame({'Bubble': list(bubble_type_dict.keys()), 'Bubble_type': list(bubble_type_dict.values())})
    bubble_triallelic_df.to_csv(f"{bubble_summary_dir}/bubble_triallelic_chr{i}.tsv", sep='\t')

100%|██████████| 3/3 [05:30<00:00, 110.02s/it]


5. Classify superbubbles

In [8]:
# Superbubble classification below runs over the 22 autosomes.
chr_list = [chrom for chrom in range(1, 23)]

In [4]:
# For every top-level (Level 0) bubble, record its leftmost variant POS, its
# type ('insertion' / 'deletion' / 'neither'), and its longest REF+ALT length.
for i in tqdm(chr_list):
    AT_path = f"{bubble_summary_dir}/bubble_variant_counts_chr{i}_AT.tsv"
    vcf_path = f"{graph_vcf_dir}/graph_chr{i}{'_no_terminus' if exclude_terminus else ''}.vcf"
    gfa_path = f"{gfa_dir}/chr{i}.gfa"
    pkl_path = f"{graph_obj_dir}/chr{i}.pkl"

    if os.path.exists(pkl_path):
        G = load_graph_from_pkl(pkl_path, compressed=False)
    else:
        G = PangenomeGraph.from_gfa_line_by_line(gfa_path, compressed=False)

    AT_df = pd.read_csv(AT_path, sep='\t')
    AT_df = AT_df[AT_df['Level'] == 0]

    # child -> parent in the reference tree, with the last two characters of
    # each node name stripped (presumably an orientation suffix — TODO confirm).
    parent_dict = {v[:-2]: u[:-2] for u, v in G.reference_tree.edges}

    # Graph-VCF records keyed by the sorted, suffix-stripped bubble-id tuple:
    # [POS, REF, ALT, parsed INFO dict].
    vcf_dict = {tuple(sorted(map(lambda x: x[:-2], extract_bubble_ids(row['ID'], symbol=True)))):
                    [row['POS'], row['REF'], row['ALT'],
                     get_info_dict(row['INFO'])]
                for row in read_vcf_line_by_line(vcf_path)}

    bubble_pos_list = []  # leftmost POS among the bubble's variants (-1 if none)
    bubble_list = []
    bubble_type_list = []
    bubble_max_allele_len_list = []
    for row in AT_df.itertuples():
        bubble_id = eval(row.Bubble)
        variants = {tuple(sorted(map(lambda x: x[:-2], variant))) for variant in eval(row.Within)}
        var_len_list = []
        pos_list = []

        bubble_list.append(bubble_id)

        for var in variants:
            pos, ref, alt, info = vcf_dict[var]
            if ref == '.':
                # Off-linear variant: take the reference allele from INFO['NR'].
                ref = info['NR']
            var_len_list.append((len(ref) if ref != '.' else 0) + (len(alt) if alt != '.' else 0))
            # POS may come back from the line reader as a string; cast so that
            # min() below is numeric rather than lexicographic.
            pos_list.append(int(pos))
        bubble_max_allele_len_list.append(max(var_len_list) if var_len_list else -1)
        bubble_pos_list.append(min(pos_list) if pos_list else -1)

        # 'insertion': one endpoint is the other's parent in the reference tree.
        # .get() so endpoints without a parent (e.g. the tree root) don't raise.
        # 'deletion': the endpoint pair itself is one of the bubble's variants.
        if (parent_dict.get(bubble_id[0]) == bubble_id[1]
                or parent_dict.get(bubble_id[1]) == bubble_id[0]):
            bubble_type_list.append('insertion')
        elif tuple(sorted(bubble_id)) in variants:
            bubble_type_list.append('deletion')
        else:
            bubble_type_list.append('neither')

    bubble_type_df = pd.DataFrame({"POS": bubble_pos_list,
                                   "Bubble": bubble_list,
                                   "Type": bubble_type_list,
                                   "Max_allele_length": bubble_max_allele_len_list})
    bubble_type_df.to_csv(f"{bubble_summary_dir}/superbubble_type_chr{i}.tsv", sep='\t', index=False)

  0%|          | 0/22 [00:00<?, ?it/s]

100%|██████████| 22/22 [2:31:02<00:00, 411.92s/it]  
