Skip to content

Commit

Permalink
Merge pull request #75 from VDBWRAIR/plot_muts2
Browse files Browse the repository at this point in the history
working on graphic to show full dates
  • Loading branch information
averagehat committed Jan 5, 2016
2 parents 5960e7f + 01d1f7c commit 1a4f3e9
Show file tree
Hide file tree
Showing 7 changed files with 5,111 additions and 35 deletions.
114 changes: 79 additions & 35 deletions bio_bits/plot_muts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
import matplotlib.pyplot as plt
import docopt, schema
from operator import itemgetter as get
import csv
import dateparser
import datetime
from time import mktime
try:
#below import is necessary for some reason
from scipy.stats import poisson
Expand All @@ -33,75 +37,115 @@
# Pattern text produced by range_regex for the bounds 1900 and 2015.
# NOTE(review): presumably matches the integers in that range inclusively -- confirm against range_regex.
years = range_regex.range_regex.regex_for_range(1900, 2015)
# Compiled form of the pattern above.
year_regex = re.compile(years)
# NOTE(review): assuming compose is right-to-left function composition (import not visible here),
# hamming(a, b) is sum(map(operator.ne, a, b)): the number of positions at which a and b differ.
hamming = compose(sum, partial(map, operator.ne))
# Convert an object exposing timetuple() into seconds since the epoch; mktime treats the tuple as local time.
timestamp = lambda x: mktime(x.timetuple())
# Colour passed as color= when plotting each named series.
legend = {"queries": 'r', "references": 'b', "interval": 'g'}
'''it seems like pdist gives results that are too small to be useful?'''
# Disabled alternative metric: Hamming distance divided by sequence length.
#def pdist(s1, s2):
# assert len(s1) == len(s2), "All sequences must be the same length! %s %s" % (s1, s2)
# return hamming(s1, s2)/float(len(s1))
class InvalidFastaIdentifier(Exception):
    """Raised when no date can be recovered from a FASTA identifier."""


def extract_date(fasta_id):
    """Return the date encoded after the last ``____`` in a FASTA id.

    :param str fasta_id: identifier of the form ``<name>____<date text>``
    :return: the value parsed by ``dateparser.parse`` from the date text
    :raises InvalidFastaIdentifier: if the ``____`` delimiter is missing or
        the text after it cannot be parsed as a date
    """
    error = InvalidFastaIdentifier(
        "Could not retrieve date from {0}".format(fasta_id))
    if '____' not in fasta_id:
        raise error
    date_text = fasta_id.split('____')[-1]
    try:
        parsed = dateparser.parse(date_text)
    except Exception:
        # had to add 2015 to A/England/50220895/
        raise error
    # dateparser.parse signals "could not parse" by returning None rather
    # than raising; without this check None would flow into the date
    # arithmetic done by callers.
    if parsed is None:
        raise error
    return parsed


def get_seqs_and_dates(fn):
    """Read a FASTA file and return its sequences, dates and ids.

    :param fn: file name or handle passed to ``SeqIO.parse``
    :return: ``(seqs, dates, ids)`` -- a tuple of sequence strings, a list
        with the result of ``extract_date`` for every id, and a tuple of the
        ids, all in file order; three empty containers if the file holds no
        records
    :raises InvalidFastaIdentifier: propagated from ``extract_date``
    """
    records = [(str(rec.seq), rec.id)
               for rec in SeqIO.parse(fn, format="fasta")]
    if not records:
        # zip(*[]) yields nothing to unpack; return empties so callers can
        # report the missing data with their own message.
        return (), [], ()
    seqs, ids = zip(*records)
    # Build a real list: callers take len() of and re-iterate the dates.
    dates = [extract_date(seq_id) for seq_id in ids]
    return seqs, dates, ids

def process(refs_fn, query_fn, save_path=None):
    """Plot mutations of references and queries relative to the earliest reference.

    The reference with the earliest date becomes the base: every reference
    and query sequence is compared to it with ``hamming`` and its date is
    expressed as days since the base date. The results go to ``do_plot``.

    :param refs_fn: FASTA file of reference sequences
    :param query_fn: FASTA file of query sequences
    :param str save_path: forwarded to ``do_plot``
    :raises ValueError: if the reference file contains no sequences
    """
    ref_records = sorted(zip(*get_seqs_and_dates(refs_fn)), key=get(1))
    if not ref_records:
        raise ValueError(
            "No reference sequences found in {0}".format(refs_fn))
    ref_seqs, ref_dates, ref_names = zip(*ref_records)
    super_ref_seq = ref_seqs[0].upper()
    super_ref_date, super_ref_name = ref_dates[0], ref_names[0]
    print(super_ref_name)
    print(super_ref_date)

    def get_relative_info(seqs, dates, names):
        # Upper-case every sequence so that case differences between files
        # are not counted as mutations.
        muts = [hamming(super_ref_seq, seq.upper()) for seq in seqs]
        dists = [(date - super_ref_date).days for date in dates]
        return muts, dists, names

    ref_muts, ref_dists, ref_names = get_relative_info(
        ref_seqs, ref_dates, ref_names)
    query_muts, query_dists, query_names = get_relative_info(
        *get_seqs_and_dates(query_fn))
    do_plot(ref_dists, ref_muts, ref_names,
            query_dists, query_muts, query_names, save_path)

def _write_distance_csv(rows, save_path=None):
    """Write (name, days, distance) rows as CSV to ``save_path + '.csv'`` or stdout."""
    if save_path:
        csv_path = save_path + '.csv'
        # The csv module wants a binary handle on Python 2 and a text handle
        # opened with newline='' on Python 3.
        if sys.version_info[0] < 3:
            fh = open(csv_path, 'wb')
        else:
            fh = open(csv_path, 'w', newline='')
    else:
        fh = sys.stdout
    try:
        fh.write('name,dates,p-dist\n')
        csv.writer(fh).writerows(rows)
    finally:
        # Never close the interpreter's stdout.
        if fh is not sys.stdout:
            fh.close()


def do_plot(x1, y1, ref_names, x2, y2, query_names, save_path=None):
    '''
    Scatter-plot references and queries and dump the plotted values as CSV.

    :param iterable x1: reference dates distances
    :param iterable y1: reference p-distances
    :param iterable ref_names: reference names, parallel to x1/y1
    :param iterable x2: query dates differences
    :param iterable y2: query p-distances
    :param iterable query_names: query names, parallel to x2/y2
    :param str save_path: path to save the image (the CSV goes to
        save_path + '.csv'); if None the plot is shown and the CSV is
        written to stdout
    '''
    # Materialise the inputs: they are measured, iterated and plotted more
    # than once below.
    x1, y1, x2, y2 = list(x1), list(y1), list(x2), list(y2)
    assert len(x1) > 0, "No reference dates to use"
    assert len(y1) > 0, "No reference p-distances to use"
    assert len(x2) > 0, "No query dates to use"
    assert len(y2) > 0, "No query p-distances to use"
    plt.figure()
    ax = plt.subplot(111)
    max_x = max(max(x1), max(x2))

    ref_info = list(zip(ref_names, x1, y1))
    query_info = list(zip(query_names, x2, y2))
    # Largest distance first.
    all_info = sorted(ref_info + query_info, key=lambda row: row[2], reverse=True)
    _write_distance_csv(all_info, save_path)

    plot_muts(ax, x1, y1, plotkwargs=dict(label='references', color=legend['references'], marker='s'), polyfit=True, max_x=max_x, dist=None)
    plot_muts(ax, x2, y2, plotkwargs=dict(label='queries', color=legend['queries']), dist=None)
    # Shrink the axes to 80% width so the legend fits to the right of the plot:
    # http://stackoverflow.com/questions/4700614/how-to-put-the-legend-out-of-the-plot
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.xlabel("days since Base reference")
    plt.ylabel("p-distance")
    if save_path:
        plt.savefig(save_path)
    else:
        plt.show()

def plot_muts(ax, x, y, color, label=None, dist=DISTRIBUTION, polyfit=False, max_x=None):
#problem was didn't account for +b
'''if norm distribution, probably have to scale (via passing loc= and scale=)'''
ax.scatter(x, y, color=color, label=label)
def plot_muts(ax, x, y, dist=DISTRIBUTION, polyfit=False, max_x=None, plotkwargs=dict(marker='o')):
'''
Plot x and y
if norm distribution, probably have to scale (via passing loc= and scale=)
problem was didn't account for +b
'''
if max_x and isinstance(max_x, datetime.datetime):
max_x = timestamp(max_x)
ax.scatter(x, y, **plotkwargs)#color=color, label=label, marker=marker)
if polyfit:
''' this forces a polyfit with y-intercept at zero, necessary because
we necessarily start with 0 mutations from the query at year 0.'''
we necessarily start with 0 mutations from the query at date 0.'''
x = np.array(x)[:,np.newaxis]
m, _, _, _ = np.linalg.lstsq(x, y)
x, y = np.linspace(0,max_x, 100), m*np.linspace(0,max_x, 100)
ax.plot(x, y, color='y', label='Best Fit')
x, y = np.linspace(0,max_x,100), m*np.linspace(0,max_x,100)
ax.plot(x, y, color='y', label='Best Fit', linewidth=2)
if dist:
"""see http://stackoverflow.com/a/14814711/3757222"""
R = dist.interval(0.95, y)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ funcy
tabulate
range_regex
matplotlib
dateparser

0 comments on commit 1a4f3e9

Please sign in to comment.