In [6]:
import requests
import re
import json
import math

import tsinfer
from ete3 import Tree, TreeStyle, TextFace

In [7]:
# Dataset label -> Google Drive file id. Each id is interpolated into the
# "uc?export=download" URL when the dataset is fetched.
google_files = dict(
    focal10="1Pw5s3-_p0Ap8MUs6F2OUHRUXwphTs7pc",
    filtered="1jjRkDBO2QmAmsCZ6cgMjMY6g17pecp2G",
    full="1NZAp11dUnfq9VmXs4gN5qZ3d6wJiHt-Y",
    allsnps="1AF4lZHOO2IrtZJVZGPe4QaEKlk6ttq7B",
)

In [23]:
# For every dataset: stream its genotype table from Google Drive, load it into
# a tsinfer sample file, infer a tree sequence, save it, and draw the single
# local tree that covers FOCAL_POSITION (as SVG via tree.draw and as PNG via
# visualize_tree).
#
# NOTE: visualize_tree() is defined in a later cell of this notebook; that
# cell must have been executed before this one or the call below raises
# NameError.

# Genomic position whose local tree is written out for each dataset.
FOCAL_POSITION = 33952619

for dataset, doc in google_files.items():
    url = "https://drive.google.com/uc?export=download&id={}".format(doc)
    print(url)
    chrom = None     # chromosome seen so far; every SNP must be on the same one
    line_num = 0     # number of non-empty lines read (the header is line 1)
    positions = []   # positions of the sites actually added to sample_data

    # The context manager releases the streamed connection even if parsing fails.
    with requests.get(url, stream=True) as response:
        # Fail early on an HTTP error instead of trying to parse an error page.
        response.raise_for_status()
        # Without a declared encoding, iter_lines(decode_unicode=True) yields
        # bytes, which would break the str-based parsing below.
        if response.encoding is None:
            response.encoding = "utf-8"

        with tsinfer.SampleData(path="{}.samples".format(dataset)) as sample_data:
            for line in response.iter_lines(decode_unicode=True):
                # filter out keep-alive new lines
                if not line:
                    continue
                line_num += 1
                data = line.split()  # split on whitespace

                if line_num == 1:
                    # Header line: the first field is the position column,
                    # the remaining fields are (possibly quoted) sample names.
                    for name in data[1:]:
                        sample_data.add_sample({'name': name.strip("\"")})
                    continue

                # match e.g. chr20:33896756-33896757
                snp_position = re.match(r'([^:]+):(\d+)-(\d+)', data[1])
                # check we have a sensible position
                assert snp_position, "SNP position is {}".format(data[1])
                snp_chrom = snp_position.group(1)
                start = int(snp_position.group(2))
                end = int(snp_position.group(3))
                # check all are on the same chromosome
                assert (chrom is None) or (chrom == snp_chrom), \
                    "Different chromosomes, {} vs {}".format(chrom, snp_chrom)
                chrom = snp_chrom
                # check these are single SNPs
                assert start + 1 == end, \
                    "start+1 != end ({} vs {})".format(snp_position.group(2), snp_position.group(3))

                # Only biallelic 0/1 genotype rows are loaded; anything else
                # is reported and skipped.
                genotypes = [int(i) for i in data[2:]]
                if all(0 <= g <= 1 for g in genotypes):
                    sample_data.add_site(start, ["0", "1"], genotypes)
                    # Record the position only once the site is really added,
                    # so the summary below describes what was inferred from.
                    positions.append(start)
                else:
                    print("Problem in line {} of {} (SNP {} {})".format(line_num, doc, data[0], data[1]))

    inferred_ts = tsinfer.infer(sample_data)
    print("Generated {} trees from {} to {} for {} SNPs".format(
        inferred_ts.num_trees, min(positions), max(positions), len(positions)))
    inferred_ts.dump("{}.trees".format(dataset))

    # Label sample nodes with the name stored in their JSON metadata.
    node_labels = {
        n.id: json.loads(n.metadata)['name']
        for n in inferred_ts.nodes() if n.is_sample()
    }
    for tree in inferred_ts.trees():
        low, high = tree.interval
        # only output the tree at the locus of interest
        if not (low <= FOCAL_POSITION < high):
            continue
        identifier = "{}-{}_{}".format(low, high, dataset)
        tree.draw(
            path=identifier + ".svg",
            format="svg",
            width=1000*math.log(inferred_ts.num_samples),
            height=1000,
            node_labels=node_labels,
        )
        visualize_tree(tree, identifier)

https://drive.google.com/uc?export=download&id=1Pw5s3-_p0Ap8MUs6F2OUHRUXwphTs7pc
Generated 64 trees from 33887954 to 34025982 for 222 SNPs
https://drive.google.com/uc?export=download&id=1jjRkDBO2QmAmsCZ6cgMjMY6g17pecp2G


KeyboardInterrupt: 

In [21]:
def visualize_tree(tree, tree_name):
    """Render ``tree`` as a circular PNG image using ete3.

    Args:
        tree: Object exposing a ``newick()`` method whose return value is
            passed to ``ete3.Tree`` (in this notebook, a local tree from the
            inferred tree sequence).
        tree_name: Output path without extension; the image is written to
            ``tree_name + ".png"``.

    Returns:
        None. The only effect is the rendered image file.
    """
    ete_tree = Tree(tree.newick())

    # Hide node circles
    for node in ete_tree.traverse():
        node.img_style['size'] = 0

    style = TreeStyle()
    style.mode = "c"        # circular layout
    style.arc_start = 0     # 0 degrees = 3 o'clock
    style.arc_span = 360    # use the full circle
    style.show_leaf_name = True
    style.show_scale = False

    # Draw Tree
    ete_tree.render(tree_name + ".png", dpi=300, w=2400, tree_style=style)
