In [2]:
# Work from the parent directory so the relative paths used later in the
# notebook (e.g. 'metadata/...') are resolved from there.
%cd ../

/mnt/storage7/gary/comp_mut


In [3]:
import json
from collections import defaultdict, Counter
import argparse
import os
from tqdm import tqdm
import sys
import csv
import pathogenprofiler as pp
import tbprofiler
from csv import DictReader
from collections import Counter
import requests
from contextlib import closing
import re
from python_scripts.utils import *

In [4]:
def get_vars_exclude(vars_exclude_file):
    """Build the list of (gene, mutation) pairs that must be ignored.

    Two sources are concatenated:

    1. The rows of ``vars_exclude_file`` - one comma-separated variant per
       line, each turned into a tuple of its fields.
    2. Lineage-specific variants (Fst = 1) downloaded from GitHub, kept only
       for genes present in the module-level ``genes`` collection.

    Parameters
    ----------
    vars_exclude_file : str
        Path to the comma-separated file of variants to exclude.

    Returns
    -------
    list of tuple
        File-derived tuples first, followed by the lineage-specific
        ``(gene, mutation)`` tuples.

    Raises
    ------
    requests.HTTPError
        If the Fst results cannot be downloaded (non-2xx response).
    """

    # URL below is the results of all Fst = 1 variants from https://genomemedicine.biomedcentral.com/articles/10.1186/s13073-020-00817-3
    fst_results_url = 'https://raw.githubusercontent.com/GaryNapier/tb-lineages/main/fst_results_clean_fst_1_for_paper.csv'
    # See https://www.codegrepper.com/code-examples/python/how+to+read+a+csv+file+from+a+url+with+python for pulling data from url
    with closing(requests.get(fst_results_url, stream=True)) as r:
        # Fail loudly on an HTTP error instead of parsing an error page as CSV
        r.raise_for_status()
        f = (line.decode('utf-8') for line in r.iter_lines())
        fst_dict = csv_to_dict_multi(f)

    # Keep only the lineage-specific variants that fall in the genes of interest
    # (`genes` is a module-level variable, not a parameter)
    lin_specific_variants = [
        (gene, reformat_mutations(var['aa_pos']))
        for gene in fst_dict
        if gene in genes
        for var in fst_dict[gene]
    ]

    # Read in variants to be excluded; the context manager closes the file handle
    with open(vars_exclude_file) as exclude_fh:
        vars_exclude = [tuple(line.strip().split(',')) for line in exclude_fh]

    # Concat
    return vars_exclude + lin_specific_variants

In [5]:
def get_counts(in_dict, ref_dict, data_key):
    """Tally one metadata field over the samples of every mutation.

    Parameters
    ----------
    in_dict : dict
        Maps each mutation to an iterable of sample IDs.
    ref_dict : dict
        Maps each sample ID to a dict of metadata fields.
    data_key : str
        Metadata field to look up for every sample.

    Returns
    -------
    dict
        Maps each mutation to a plain dict of ``{field value: sample count}``
        (an empty dict when the mutation has no samples).
    """
    tallies = {}
    for mutation, samp_ids in in_dict.items():
        # Look up the requested field for each sample and count the values
        field_values = (ref_dict[samp_id][data_key] for samp_id in samp_ids)
        tallies[mutation] = dict(Counter(field_values))
    return tallies

In [6]:
def get_unique_mutations(in_dict, gene):
    """Collect the mutations seen in one gene across all samples.

    Parameters
    ----------
    in_dict : dict
        Maps each sample to a dict whose ``'mutations'`` entry is a list of
        dicts carrying at least ``'gene'`` and ``'change'`` keys.
    gene : str
        Gene whose mutations are collected.

    Returns
    -------
    tuple
        ``(table, mutations_set)`` where ``table`` is a ``Counter`` of how
        often each change occurs (built from the sorted list of changes) and
        ``mutations_set`` is the set of distinct changes.
    """
    # Gather every matching change in one flat, sorted list. The previous
    # version appended generator objects and flattened them afterwards.
    changes = sorted(
        mut['change']
        for samp in in_dict
        for mut in in_dict[samp]['mutations']
        if mut['gene'] == gene
    )
    # Make a table (before converting to a set)
    table = Counter(changes)
    # Get unique
    mutations_set = set(changes)
    return (table, mutations_set)

In [7]:
def dr_filter(in_dict):
    """Flag a mutation as accepted (1) or rejected (0) from its DR type counts.

    ``in_dict`` holds sample counts per DR type, for example
    ``{'MDR': 15, 'Pre-MDR': 20, 'Pre-XDR': 8, 'XDR': 1, 'Other': 1}``.

    Rules, applied in order:
    * every key contains 'Sensitive' (this also holds for an empty dict): 0
    * there is no 'Sensitive' key: 1
    * otherwise: 0 when sensitive samples make up at least half of all
      samples, else 1
    """
    is_dict(in_dict)

    # Nothing but sensitive entries: reject
    if not any('Sensitive' not in drtype for drtype in in_dict):
        return 0

    # No sensitive samples at all: accept
    if 'Sensitive' not in in_dict:
        return 1

    # Mixed case: compare the sensitive count with everything else
    n_sensitive = in_dict['Sensitive']
    n_other = sum(count for drtype, count in in_dict.items() if drtype != 'Sensitive')
    sensitive_fraction = n_sensitive / (n_sensitive + n_other)

    return 0 if sensitive_fraction >= 0.5 else 1

In [8]:
def lin_filter(in_dict):
    """Flag each mutation by how many entries its lineage record holds.

    ``in_dict`` maps each mutation to a sized collection (e.g. a dict of
    lineage counts). The returned dict gives 1 for a mutation whose
    collection has more than one entry and 0 otherwise.
    """
    is_dict(in_dict)
    return {
        mutation: int(len(lineages) > 1)
        for mutation, lineages in in_dict.items()
    }

In [9]:
def dst_filter(in_dict):
    """Flag a mutation as accepted (1) or rejected (0) from its DST counts.

    ``in_dict`` holds sample counts keyed by DST value, e.g.
    ``{'1': 29, 'NA': 15, '0': 1}``.
    NOTE(review): presumably '1' = resistant, '0' = susceptible and 'NA' =
    not tested - confirm against the metadata file.

    Rules, applied in order:
    * only 'NA' keys (or an empty dict): 1
    * only '1' keys: 1
    * only '0' keys: 0
    * otherwise: the share of '1' among the '0' and '1' samples decides;
      below 0.5 gives 0, anything else gives 1. 'NA' samples are ignored.

    The input dict is never modified: missing '0'/'1' keys are read as 0
    instead of being inserted into the caller's dict. If no sample carries
    a '0' or '1' value the mutation is accepted, matching the all-'NA' rule.
    """
    # Validate up front, as the other filter functions do
    is_dict(in_dict)

    if all(x == 'NA' for x in in_dict):
        return 1

    if all(x == '1' for x in in_dict):
        return 1

    if all(x == '0' for x in in_dict):
        return 0

    # Read the counts without adding keys to the caller's dict
    n_ones = in_dict.get('1', 0)
    n_zeros = in_dict.get('0', 0)
    n_tested = n_zeros + n_ones

    # No '0'/'1' evidence at all: treat like the all-'NA' case
    if n_tested == 0:
        return 1

    proportion = n_ones / n_tested

    return 0 if proportion < 0.5 else 1

In [10]:
def invert_tb_dict(in_dict):
    """Turn a sample -> mutations mapping into a change -> samples mapping.

    ``in_dict`` maps each sample to a dict whose ``'mutations'`` entry is a
    list of dicts with a ``'change'`` key. The result maps every change to
    the list of samples carrying it, in encounter order (a sample appears
    once per matching entry).
    """
    samples_by_change = {}
    for samp, record in in_dict.items():
        for variant in record['mutations']:
            # Start a new list on first sight of the change, then append
            samples_by_change.setdefault(variant['change'], []).append(samp)
    return samples_by_change

In [11]:
def tb_data(filename, suffix, samps_list, genes, vars_exclude, min_freq=0.7, min_samples=2):
    """Load per-sample JSON results and index the samples by variant.

    Parameters
    ----------
    filename : str
        Directory holding one JSON result file per sample.
    suffix : str
        File name suffix appended to the sample ID (e.g. '.results.json').
    samps_list : iterable of str
        Sample IDs to load; samples without a file are skipped.
    genes : collection of str
        Only variants in these genes are kept.
    vars_exclude : collection of tuple
        ``(gene, change)`` pairs to ignore.
    min_freq : float, optional
        Variants with a frequency below this value are dropped (default 0.7).
    min_samples : int, optional
        Variants seen in fewer samples than this are removed (default 2).

    Returns
    -------
    tuple
        ``(all_data, meta_dict)``: ``all_data`` maps ``(gene, change)`` to the
        list of samples carrying it; ``meta_dict`` maps each loaded sample to
        its ``wgs_id``, ``main_lin``, ``sublin`` and standardised ``drtype``.

    Notes
    -----
    Relies on the module-level ``standardise_drtype`` mapping, which must be
    loaded before this function is called.
    """
    all_data = defaultdict(list)
    meta_dict = defaultdict(dict)

    for samp in tqdm(samps_list):
        # Load file for each samp, skip if can't find
        file = "%s/%s%s" % (filename, samp, suffix)
        if not os.path.isfile(file):
            continue

        # Context manager closes the handle; previously one was leaked per sample
        with open(file) as json_fh:
            tmp_data = json.load(json_fh)

        # Remove mixed samps
        if ";" in tmp_data['sublin']: continue

        # Get metadata
        meta_dict[samp] = {
            'wgs_id':samp,
            # 'inh_dst':meta_dict[samp]['isoniazid'],
            'main_lin':tmp_data['main_lin'],
            'sublin':tmp_data['sublin'],
            # 'country_code':meta_dict[samp]['country_code'],
            'drtype':standardise_drtype[tmp_data['drtype']]}

        # Loop over all variants of interest
        for var in tmp_data["dr_variants"] + tmp_data["other_variants"]:
            if var['gene'] not in genes: continue
            if var['freq'] < min_freq: continue
            if var['type'] == 'synonymous_variant': continue
            # Store key as tuple of gene and mutation and append the sample. Ignore if in exclude list.
            key = (var['gene'], var['change'])
            if key in vars_exclude: continue
            all_data[key].append(samp)

    # Remove variants seen in too few samples (iterate over a copy of the keys
    # because entries are deleted while looping)
    for var in list(all_data):
        if len(all_data[var]) < min_samples:
            del all_data[var]

    return (all_data, meta_dict)

In [12]:
def merge_metadata(meta_dict, meta_dict_csv, drug_of_interest):
    """Add DST and country code from the CSV metadata to each sample record.

    For every sample in ``meta_dict`` the value of the ``drug_of_interest``
    column is stored under ``'dst'`` and the ``'country_code'`` column is
    copied across. ``meta_dict`` is updated in place and also returned.
    A sample missing from ``meta_dict_csv`` raises ``KeyError``.
    """
    for samp, record in meta_dict.items():
        csv_row = meta_dict_csv[samp]
        record['dst'] = csv_row[drug_of_interest]
        record['country_code'] = csv_row['country_code']
    return meta_dict

In [13]:
def get_all_counts(all_vars_dict, variant_list, meta_dict):
    """Count samples by DR type, lineage and DST for the variants of interest.

    Parameters
    ----------
    all_vars_dict : dict
        Keys are (gene, variant) tuples, e.g. ``('ahpC', 'c.-51G>A')``;
        values are lists of samples, e.g. ``['ERR1465765', 'SAMEA3715554']``.
    variant_list : list of tuple
        The (gene, variant) pairs to report on, e.g.
        ``[('ahpC', 'c.-51G>A'), ('ahpC', 'c.-72C>T')]``.
    meta_dict : dict
        Per-sample metadata, e.g.
        ``{'SAMEA2534433': {'wgs_id': 'SAMEA2534433', 'main_lin': 'lineage2',
        'sublin': 'lineage2.2.1', 'drtype': 'Other', 'dst': '0',
        'country_code': 'af'}}``.

    Returns
    -------
    dict
        Three entries - ``'drtype_counts'``, ``'lin_counts'`` and
        ``'dst_counts'`` - each mapping the (gene, variant) pairs to a dict of
        counts, e.g. ``{('ahpC', 'c.-51G>A'): {'MDR-TB': 16, 'Other': 1}}``.
        The lineage counts are passed through ``resolve_lineages`` first.
    """
    # Restrict the sample lists to the requested variants
    samples_by_variant = {}
    for variant in variant_list:
        samples_by_variant[variant] = all_vars_dict[variant]

    # DR type counts
    counts_by_drtype = get_counts(samples_by_variant, meta_dict, 'drtype')

    # Sub-lineage counts, aggregated per variant by resolve_lineages
    counts_by_lineage = {
        variant: resolve_lineages(sublin_counts)
        for variant, sublin_counts in get_counts(samples_by_variant, meta_dict, 'sublin').items()
    }

    # DST counts
    counts_by_dst = get_counts(samples_by_variant, meta_dict, 'dst')

    return {
        'drtype_counts': counts_by_drtype,
        'lin_counts': counts_by_lineage,
        'dst_counts': counts_by_dst,
    }

In [14]:
# def main(args):

# mutations_file = args.mutations_file
# metadata_file = args.metadata_file
# id_key = args.id_key

# mutations_file = "metadata/novel_ahpc_mutations.txt"
# metadata_file = "../metadata/tb_data_18_02_2021.csv"
# id_key = "wgs_id"
# mutaions_key = "wgs_id" # column name of mutation names e.g. "c.-101A>G"

# --- Input files (paths are relative to the working directory) ---
# CSV of ahpC GLM results; its keys become the ahpC variants of interest
ahpc_glm_results_file = "metadata/ahpc_model_results.csv"
# Sample metadata CSV; supplies the sample list, DST values and country codes
metadata_file = "../metadata/tb_data_18_02_2021.csv"
# tbdb CSV, read into tbdb_dict
tbdb_file = "../tbdb/tbdb.csv"
# JSON mapping used to standardise the 'drtype' value of each sample
drtypes_file = "../pipeline/db/dr_types.json"
# Directory of per-sample result JSON files (alternatives kept commented out)
# tbprofiler_results_dir = '/mnt/storage7/jody/tb_ena/tbprofiler/freebayes/results/'
# tbprofiler_results_dir = '/mnt/storage7/jody/tb_ena/tbprofiler/gatk/results'
tbprofiler_results_dir = '/mnt/storage7/jody/tb_ena/tbprofiler/freebayes/results/'
# metadata_id_key = "wgs_id"
# Appended to the sample ID to form the result file name
suffix = ".results.json"
# Only variants in these genes are kept
genes = ('ahpC', 'katG', 'fabG1')
# Comma-separated file of variants to exclude
vars_exclude_file = 'metadata/var_exclude_katg_comp_mut.csv'
# Column of the metadata CSV used as the DST value
drug_of_interest = 'isoniazid'

In [15]:
# -------------
# READ IN DATA
# -------------

# ----------------------------------------------------
# NEED TO CHANGE VAR NAMES TO SOMETHING GENERALISABLE
# ----------------------------------------------------

# Read in ahpc GLM results file
with open(ahpc_glm_results_file, 'r') as f:
    ahpc_glm_dict = csv_to_dict(f)

# Convert to list of (gene, variant) tuples
ahpc_glm_list = [('ahpC', var) for var in ahpc_glm_dict]

# ----------------------------------------------------
# NEED TO CHANGE VAR NAMES TO SOMETHING GENERALISABLE
# ----------------------------------------------------

# Read in metadata
with open(metadata_file) as mf:
    meta_dict_csv = csv_to_dict(mf)

# Pull samples
samples = list(meta_dict_csv.keys())

# Read in tbdb file
with open(tbdb_file, 'r') as f:
    tbdb_dict = csv_to_dict_multi(f)

# Read in DR types from json (context manager so the handle is closed).
# tb_data() reads this mapping as a global, so it must be loaded before the call below.
with open(drtypes_file) as f:
    standardise_drtype = json.load(f)

# Read in variants to exclude
vars_exclude = get_vars_exclude(vars_exclude_file)

# Read in all the json data for (samples with) ahpC/katG only (>0.7 freq) and the metadata for those samples
all_data, meta_dict = tb_data(tbprofiler_results_dir, suffix, samples, genes, vars_exclude)


100%|██████████| 32735/32735 [00:08<00:00, 3718.54it/s]


In [23]:
# ---------
# TESTING
# ---------

# file = "%s/%s%s" % (tbprofiler_results_dir, 'SAMEA2534433', suffix)
# data = json.load(open(file))

# DR TYPES - how to load?


[('ahpC', 'c.-51G>A'),
 ('ahpC', 'c.-72C>T'),
 ('ahpC', 'c.-76T>A'),
 ('ahpC', 'p.Pro44Arg')]

In [17]:
# -------------
# Wrangle data 
# -------------

# Add the DST value for the drug of interest and the country code from the
# metadata CSV to each sample's record (meta_dict is updated in place).
meta_dict = merge_metadata(meta_dict, meta_dict_csv, drug_of_interest)

In [18]:
# ---------------------------------------------------------------------
# Classify UNKNOWN ahpC and filter 
# GLM is only first step in identifying 'interesting' ahpC mutations
# Need to check against tbprofiler results for each mutation 
# e.g. if the mutation is lineage specific, then filter out
# ---------------------------------------------------------------------

# Get counts of samples by DR type, lin and DST for each mutation.
# The result has the keys 'drtype_counts', 'lin_counts' and 'dst_counts'.
ahpc_glm_counts = get_all_counts(all_data, ahpc_glm_list, meta_dict)





In [21]:

def filter_counts(counts_dict):
    """Apply the DR type, lineage and DST filters to every mutation.

    Parameters
    ----------
    counts_dict : dict
        Output of get_all_counts: count type -> (gene, variant) -> counts, e.g.
        ``{'drtype_counts': {('ahpC', 'c.-51G>A'): {'MDR-TB': 16, 'Other': 1}},
        'lin_counts': {('ahpC', 'c.-51G>A'): {'lineage2.2.1': 19, 'lineage4.6.1.2': 2}},
        'dst_counts': {('ahpC', 'c.-51G>A'): {'1': 29, 'NA': 15, '0': 1}}}``.

    Returns
    -------
    dict
        ``'drtype_filter'``, ``'lin_filter'`` and ``'dst_filter'``, each
        mapping the (gene, variant) pairs to the 0/1 flag of that filter.
    """
    # DR type filter, one flag per mutation
    drtype_flags = {
        mutation: dr_filter(drtype_counts)
        for mutation, drtype_counts in counts_dict['drtype_counts'].items()
    }

    # Lineage filter works on the whole mapping at once
    lineage_flags = lin_filter(counts_dict['lin_counts'])

    # DST filter, one flag per mutation
    dst_flags = {
        mutation: dst_filter(dst_counts)
        for mutation, dst_counts in counts_dict['dst_counts'].items()
    }

    return {
        'drtype_filter': drtype_flags,
        'lin_filter': lineage_flags,
        'dst_filter': dst_flags,
    }

    
    
    # TODO: Put together and add up scores
    # TODO: Remove from ahpc mutations list if 0 


# Show the 0/1 flag of each filter for the GLM-selected ahpC mutations
filter_counts(ahpc_glm_counts)


{'drtype_filter': {('ahpC', 'c.-51G>A'): 1,
  ('ahpC', 'c.-72C>T'): 1,
  ('ahpC', 'c.-76T>A'): 1,
  ('ahpC', 'p.Pro44Arg'): 1},
 'lin_filter': {('ahpC', 'c.-51G>A'): 1,
  ('ahpC', 'c.-72C>T'): 1,
  ('ahpC', 'c.-76T>A'): 0,
  ('ahpC', 'p.Pro44Arg'): 0},
 'dst_filter': {('ahpC', 'c.-51G>A'): 1,
  ('ahpC', 'c.-72C>T'): 1,
  ('ahpC', 'c.-76T>A'): 1,
  ('ahpC', 'p.Pro44Arg'): 1}}

{'gyrB': [{'Gene': 'gyrB',
   'Mutation': 'p.Glu540Asp',
   'Drug': 'moxifloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': '10.1128/AAC.00825-17;10.1128/JCM.06860-11'},
  {'Gene': 'gyrB',
   'Mutation': 'p.Ala504Thr',
   'Drug': 'ciprofloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': ''},
  {'Gene': 'gyrB',
   'Mutation': 'p.Ala504Val',
   'Drug': 'ciprofloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': ''},
  {'Gene': 'gyrB',
   'Mutation': 'p.Arg446Cys',
   'Drug': 'ciprofloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': ''},
  {'Gene': 'gyrB',
   'Mutation': 'p.Arg446His',
   'Drug': 'ciprofloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': ''},
  {'Gene': 'gyrB',
   'Mutation': 'p.Arg446Leu',
   'Drug': 'ciprofloxacin',
   'Confers': 'resistance',
   'Interaction': '',
   'Literature': ''},
  {'Gene': 'gyrB',
   'Mutation': 'p.Asn499Asp',
   'Drug'