This notebook animates the simulated epidemic as it spreads over time.

Large parts of it are adapted from Gytis Dudas' (@evogytis) notebook curonia, part of baltic, which can be found here: https://github.com/evogytis/baltic

In [156]:
import matplotlib as mpl ## matplotlib should not be set to inline mode to accelerate animation rendering and save memory
mpl.use('Agg') ## recommended backend for animations
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from matplotlib.collections import PatchCollection
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import gridspec
import matplotlib.patheffects as path_effects
import matplotlib.animation as animation
from IPython.display import clear_output
from IPython.display import HTML

import requests
import json
from io import StringIO as sio

import numpy as np
import pandas as pd
from scipy.special import binom

import datetime as dt
import time

In [84]:
## default matplotlib font family, weights and size for every figure below
typeface = 'Helvetica Neue'
mpl.rcParams.update({
    'font.weight': 300,  ## light text
    'axes.labelweight': 300,  ## light axis labels
    'font.family': typeface,
    'font.size': 22,
})

In [10]:
import sys
from collections import defaultdict

## make the simulation scripts importable, then load the tree simulator
## (called below as cts.simulate_tree(transm_dict, child_dict, nodes, 0.16, epidemic_len))
sys.path.insert(0, "../Simulation_scripts/no_interventions")
import Tree_simulator as cts

## the simulation log file is read from dropbox_path + results + log_file,
## so the two directory strings must keep their trailing slashes
dropbox_path = "/Users/s1743989/VirusEvolution Dropbox/Verity Hill/Agent_based_model/Looping models/"
results = "Results/Fitted_runs/no_caps/2/log_files/"
log_file = "information_file_for_40.csv"

### Prepping maps and locations

In [157]:
#Some stuff about getting cases in each country which I think I'll leave for now
#Will need to have make the map here - just SLE
address='https://raw.githubusercontent.com/phylogeography/SpreaD3/master/data/geoJSON_maps/subregion/subregion_Western_Africa_subunits.json' ## address of example JSON

fetch_map = requests.get(address) ## fetch Central America geoJSON file from SpreaD3 repo as an example
json_map=json.load(sio(fetch_map.text)) ## import json
json.dump(json_map,open('./WA_map.geojson','w')) ## write to file locally


locations = ["bo", "bombali", "bonthe", 'kailahun', "kambia", 'kenema', 'koinadugu', 'kono', 'moyamba', 'portloko', 'pujehun', 'tonkolili', 'westernarearural', 'westernareaurban']



Done!


In [164]:
features = json_map['features']
location_points = {}  ## location name -> list of (N, 2) arrays of [longitude, latitude] ring coordinates
polygons = {}  ## location name -> list of matplotlib Polygon patches, one per ring

loc_name = 'name'  ## key name for each feature


def _geometry_rings(geometry):
    """Return every coordinate ring of a geoJSON Polygon or MultiPolygon geometry.

    Exterior rings and holes are returned alike, as plain coordinate lists;
    any other geometry type yields no rings.
    """
    if geometry['type'] == 'Polygon':
        return list(geometry['coordinates'])
    if geometry['type'] == 'MultiPolygon':
        return [ring for part in geometry['coordinates'] for ring in part]
    return []


for feature in features:  ## iterate through features (locations)
    location = feature['properties'][loc_name]

    ## Work from the raw coordinate lists: converting a whole MultiPolygon into one numpy
    ## array fails on modern numpy, because its parts have different numbers of points.
    location_points[location] = [
        np.asarray(ring, dtype=float)[:, :2]  ## keep longitude, latitude columns
        for ring in _geometry_rings(feature['geometry'])
    ]

    ## one closed polygon per ring (holes are drawn as separate filled polygons)
    polygons[location] = [Polygon(part, closed=True) for part in location_points[location]]

print('polygons loaded:\n%s' % (polygons.keys()))


SLE_polygons = polygons["Sierra Leone"]

polygons loaded:
dict_keys(['Benin', 'Burkina Faso', "Côte d'Ivoire", 'Cape Verde', 'Ghana', 'Guinea', 'Gambia', 'Guinea-Bissau', 'Liberia', 'Mali', 'Mauritania', 'Niger', 'Nigeria', 'Senegal', 'Sierra Leone', 'Togo'])




In [167]:
## Preview the population centroid file: skip its first line (presumably a header)
## and print each remaining tab-separated row as a list of fields.
## NOTE(review): the rows are only printed; the map code later looks up
## pop_centres / popCentres[location], which nothing builds from this file yet.
with open("pop_centroids") as centroid_file:
    next(centroid_file)
    for row in centroid_file:
        fields = row.strip("\n").split("\t")
        print(fields)

FileNotFoundError: [Errno 2] No such file or directory: 'pop_centroids'

### Prepping tree

In [96]:
## Parse the comma-separated simulation log (first line skipped as a header) into:
transm_dict = defaultdict(list)  ## case id -> [column 1 (infector id), int(column 5), int(column 6)]
nodes = []  ## case ids in file order
times = []  ## int(column 7) per row; used below as days after start_date
child_dict = defaultdict(list)  ## column 1 (infector id) -> case ids listing it
location_dict = {}  ## case id -> column 4 (location)

log_path = dropbox_path + results + log_file
with open(log_path) as log_handle:
    next(log_handle)
    for row in log_handle:
        fields = row.strip("\n").split(",")
        case_id, infector = fields[0], fields[1]

        transm_dict[case_id] = [infector, int(fields[5]), int(fields[6])]
        nodes.append(case_id)
        times.append(int(fields[7]))
        location_dict[case_id] = fields[4]
        child_dict[infector].append(case_id)

epidemic_len = max(times)

newick_tree, tree, R0 = cts.simulate_tree(transm_dict, child_dict, nodes, 0.16, epidemic_len)

## give every node and tip a map location, collecting them all in `objects`
objects = []

for tree_node in tree.final_nodes:  ## nodes take the location of their subtree's person
    tree_node.location = location_dict[tree_node.subtree.person.id]
    objects.append(tree_node)

for tip in tree.tips:  ## tips are looked up by their own id
    tip.location = location_dict[tip.id]
    objects.append(tip)
    


### Setting variables (TODO: compute the maximum number of cases, `max_cases`, somewhere here)

In [138]:
def convert_date(x, start, end):
    """Convert a calendar date into a string in another format.

    Parameters
    ----------
    x : str, datetime.date or datetime.datetime
        Date to convert. Strings are parsed with the ``start`` format.
        date/datetime objects are formatted directly and ``start`` is ignored.
        (The animation passes ``datetime.date`` epi weeks, which ``strptime``
        cannot parse.)
    start : str
        ``strptime`` format of ``x`` when ``x`` is a string.
    end : str
        ``strftime`` format of the returned string.

    Returns
    -------
    str
        ``x`` formatted according to ``end``.
    """
    if isinstance(x, dt.date):  ## dt.datetime is a subclass of dt.date
        return x.strftime(end)
    return dt.datetime.strftime(dt.datetime.strptime(x, start), end)

def decimal_date(input_date):
    """Convert a calendar date into decimal years (e.g. 2014-07-02 -> 2014 + 182/365).

    The fraction is the number of whole days elapsed since 1 January, divided
    by the number of days in that year (366 in leap years). Every date of a
    year therefore maps into [year, year + 1). The previous fixed divisor of
    365 mapped 31 December of a leap year onto the next year.

    Parameters
    ----------
    input_date : datetime.date, datetime.datetime or str
        Date to convert. The time of day of a datetime is ignored. Strings
        must be in ``YYYY-MM-DD`` format.

    Returns
    -------
    float
    """
    if isinstance(input_date, str):
        input_date = dt.datetime.strptime(input_date, "%Y-%m-%d").date()

    days_elapsed = input_date.timetuple().tm_yday - 1  ## whole days since 1 January
    days_in_year = dt.date(input_date.year, 12, 31).timetuple().tm_yday  ## 365, or 366 in leap years

    return input_date.year + days_elapsed / days_in_year



In [76]:
## lineages whose location differs from their parent's, i.e. branches that moved between locations
travelers = [
    lineage
    for lineage in objects
    if lineage.node_parent and lineage.location != lineage.node_parent.location
]

In [123]:
## first day of the simulated epidemic; simulation day 0 maps onto this date
start_date = dt.date(2014, 5, 5)

## calendar date of every logged time value (days after start_date)
dates = [start_date + dt.timedelta(days=day_offset) for day_offset in times]

In [125]:
## Weekly time points (epi weeks) driving the animation, starting at start_date.
## Take the ceiling of the number of weeks so that the final, possibly partial, week
## is still covered. round() would drop up to three trailing days of the epidemic.
len_epidemic_weeks = int(np.ceil(max(times) / 7))

print(len_epidemic_weeks)

## len_epidemic_weeks + 1 week boundaries: start_date, start_date + 1 week, ...
epi_weeks = [start_date + dt.timedelta(weeks=week) for week in range(len_epidemic_weeks + 1)]

print(epi_weeks)

12
[datetime.date(2014, 5, 5), datetime.date(2014, 5, 12), datetime.date(2014, 5, 19), datetime.date(2014, 5, 26), datetime.date(2014, 6, 2), datetime.date(2014, 6, 9), datetime.date(2014, 6, 16), datetime.date(2014, 6, 23), datetime.date(2014, 6, 30), datetime.date(2014, 7, 7), datetime.date(2014, 7, 14), datetime.date(2014, 7, 21), datetime.date(2014, 7, 28)]


In [140]:
## Give every node/tip an absolute time in decimal years.
## tree.heights[node] * 365 is taken as days after start_date (presumably heights are in years - confirm).
## Adding a timedelta to a date uses only its whole-day part.
for tree_object in objects:
    offset = dt.timedelta(tree.heights[tree_object] * 365)
    tree_object.absolute_time = decimal_date(start_date + offset)
    
    

## Animation

In [121]:
t0 = time.time()  ## start of timing, reported after rendering
smooth = 2  ## number of frames per epi week (interpolation steps between consecutive weeks)
dpi = 50  ## resolution of the saved frames
Bezier_smooth = 5  ## passed as num= to Bezier() when drawing each migration path
tracking_length = 21  ## days either side of a migration event during which it is drawn
depth = tracking_length / 365.0  ## the same window in decimal years

loc_trait = 'location.states'  ## NOTE(review): not used by the visible animation code
## (a separate dates2 list was previously used to animate only some dates for debugging)

Nframes = len_epidemic_weeks * smooth

animation_duration = 70  ## desired animation length in seconds
## Frames per second needed for that duration. int() alone gives 0 fps whenever there are
## fewer frames than seconds (e.g. 24 frames over 70 s), which is not a valid frame rate.
fps = max(1, int(Nframes / animation_duration))

## TODO: height_normalization, Bezier_control and Bezier (used by animate) still need defining

In [147]:
def animate(frame):
    """Draw one animation frame on the map axes ``ax1``.

    The frame number is split into an epi-week index ``t = frame // smooth`` and
    a fraction ``tr`` of the way towards the next epi week. Lineage circles,
    migration paths and district colours are interpolated between weeks.

    Uses module-level state: ax1, smooth, epi_weeks, objects, travelers, depth,
    Bezier_smooth, locations, decimal_date, pop_centres, cases_by_location,
    max_cases, country_colour, height_normalization, Bezier_control and Bezier.
    NOTE(review): pop_centres, cases_by_location, max_cases, country_colour,
    height_normalization, Bezier_control and Bezier are not defined in the
    visible cells yet.

    Parameters
    ----------
    frame : int
        Frame number, from 0 to Nframes - 1.
    """
    from collections import Counter

    tr = (frame % smooth) / float(smooth)  ## fraction of the way to the next epi week
    t = int(frame / smooth)  ## index of the current epi week

    #### Primary plotting (map)
    ## Reset lines (except borders) and texts. Artists are removed one by one because
    ## assigning to ax.lines / ax.texts is not allowed in recent matplotlib versions.
    for line in list(ax1.lines):
        if '_border' not in line.get_label():
            line.remove()
    for text in list(ax1.texts):
        text.remove()

    current_time = decimal_date(epi_weeks[t])  ## epi week of the current frame
    if len(epi_weeks) - 1 > t:  ## epi week of the next frame; the last week has no successor
        next_time = decimal_date(epi_weeks[t + 1])
    else:
        next_time = current_time

    delta_time = next_time - current_time  ## interval step size
    cur_slice = current_time + delta_time * tr  ## interpolated time shown in this frame

    ## add text to indicate current time point (epi_weeks holds date objects, so format directly)
    ax1.text(0.05, 0.1, 'Epi week: %s\nDecimal time: %.3f' % (epi_weeks[t].strftime('%Y-%b-%d'), cur_slice), size=40, transform=ax1.transAxes)

    ## lineages (branches) that span the current time slice
    exists = [k for k in objects if k.node_parent and k.node_parent.absolute_time <= cur_slice <= k.absolute_time]
    lineage_counts = Counter(k.location for k in exists)  ## number of lineages present per location

    ## reset every lineage circle, remembering the first circle carrying each label
    circles = {}
    for patch in ax1.patches:
        label = patch.get_label()
        if '_circle' in label:
            circles.setdefault(label, patch)
            patch.set_radius(0)

    for region, size in lineage_counts.items():  ## grow circles where lineages are present
        circle = circles.get('%s_circle' % (region))
        if circle is None:  ## no circle was drawn for this location, so nothing to resize
            continue
        circle.set_radius(0.02 + size * 0.003)

    for lineage in travelers:  ## iterate through travelling lineages
        transition_time = (lineage.absolute_time + lineage.node_parent.absolute_time) / 2.0  ## branch begins travelling mid-branch

        if not (cur_slice - depth < transition_time < cur_slice + depth):  ## only draw transitions within the tracking window
            continue

        frac = 1 - (transition_time - cur_slice) / float(depth)  ## goes from 0.0 to 2.0 across the window

        point_a = pop_centres[lineage.node_parent.location]  ## coordinates of origin and destination
        point_b = pop_centres[lineage.location]
        begin_x, begin_y = point_a
        end_x, end_y = point_b

        fc = 'k'  ## lines are black; per-country colouring is not implemented yet

        distance = float(np.hypot(begin_x - end_x, begin_y - end_y))  ## straight-line distance between locations
        if distance == 0:  ## coincident centres: no path to draw (and the formula below would divide by zero)
            continue

        normalized_height = height_normalization(cur_slice)  ## normalize time of lineage
        adjust_d = -1 + (1 - normalized_height) + 1 / distance ** 0.15 + 0.5  ## adjust Bezier line control point distance

        ## control point at distance adjust_d, perpendicular to the mid-point between the two locations
        n = Bezier_control(point_a, point_b, adjust_d)

        bezier_start = max(frac - 0.5, 0.0)  ## the line's tail trails half a unit behind its head, clipped to [0, 1]
        bezier_end = min(frac, 1.0)

        if bezier_start >= 1.0:  ## line has fully arrived at its destination
            continue

        bezier_line = Bezier([point_a, n, point_b], bezier_start, bezier_end, num=Bezier_smooth)  ## get Bezier line points

        n_points = len(bezier_line)
        for q in range(n_points - 1):  ## segment widths grow towards the head of the line
            x1, y1 = bezier_line[q]
            x2, y2 = bezier_line[q + 1]
            seg_l = (q + 1) / float(n_points)  ## fraction along length of Bezier line

            ax1.plot([x1, x2], [y1, y2], lw=7 * seg_l, alpha=1, color=fc, zorder=99, solid_capstyle='round')  ## plot actual lineage
            ax1.plot([x1, x2], [y1, y2], lw=10 * seg_l, alpha=1, color='w', zorder=98, solid_capstyle='round')  ## white background to help lineages stand out

    ## colour districts by (log-scaled) interpolated case counts
    country_max = 1 + float(max_cases)  ## +1 to match the +1 added to case counts below
    for loc in locations:
        cur_cases = cases_by_location[loc][epi_weeks[t]]
        if len(epi_weeks) - 1 > t:
            next_cases = cases_by_location[loc][epi_weeks[t + 1]]
        else:
            next_cases = cur_cases

        interpolate = 1 + cur_cases + (next_cases - cur_cases) * tr  ## +1 so single cases show up after log normalization

        c = country_colour(np.log10(interpolate) / np.log10(country_max))

        district_patches = [p for p in ax1.patches if p.get_label() == '%s_polygon' % (loc)]
        for district_patch in district_patches:
            district_patch.set_facecolor(c)  ## change the colour of locations based on cases

In [None]:
### Not sure I'm going to use this, I think I'm more interested in just lineage mvmt rather than the tree too
## NOTE(review): this cell is not runnable on its own. It is the rest of curonia's animate()
## body (tree panel on ax2, cases panel on ax3), still indented as if it were inside that function.
## It uses names that only exist inside animate() (frame, t, cur_slice, update) and names
## not defined in the visible cells (dates2, ll, locTrait, location_to_country, colours,
## normalized_coords, bt).
    
    #### Secondary plotting (tree)
    Ls2=[x for x in ax2.lines if 'Colour' not in str(x.get_label())] ## fetch all the lines with labels in tree plot
    ## branches drawn so far are labelled 'partial_<index>' (cut at the time arrow) or
    ## 'finished_<index>' (time arrow has passed their end); see the plotting loop below
    partials=[x for x in ax2.lines if 'partial' in str(x.get_label())]
    finished_lines=[x for x in ax2.lines if 'finished' in str(x.get_label())]
    finished_points=[x for x in ax2.collections if 'finished' in str(x.get_label())]
    
    finished_labels=[str(x.get_label()) for x in finished_lines]+[str(x.get_label()) for x in finished_points]
    partial_labels=[str(x.get_label()) for x in partials]
    
    if frame%update==0: ## progress bar
        ## every `update` frames: redraw a text progress bar with elapsed time and an ETA
        clear_output()
        timeElapsed=(time.time() - t0)/60.0
        progress=int((frame*(50/float(Nframes))))
        percentage=frame/float(Nframes)*100
        rate=timeElapsed/float(frame)
        ETA=rate*(Nframes-frame)
        sys.stdout.write("[%-50s] %6.2f%%  frame: %5d %10s  time: %5.2f min  ETA: %5.2f min (%6.5f s/operation) %s %s %s" % ('='*progress,percentage,frame,dates2[t],timeElapsed,ETA,rate,len(partials),len(finished_lines),len(finished_points)))
        sys.stdout.flush()

        
    ####
    ## COMMENT this bit out if you don't want the tree to appear out of the time arrow
    ####
    ## For every tree object: a branch crossed by the time arrow is drawn (or extended) up to
    ## cur_slice as 'partial_<index>'. Once the arrow passes the branch end it is completed and
    ## relabelled 'finished_<index>', and a tip point pair (leaves) or a vertical bar
    ## spanning the children (nodes) is added.
    for ap in ll.Objects:
        idx='%s'%(ap.index)
        xp=ap.parent.absoluteTime

        x=ap.absoluteTime
        y=ap.y

        location=ap.traits[locTrait]
        country=location_to_country[location]
        cmap=colours[country]
        c=cmap(normalized_coords[location])
        
        if xp<=cur_slice<x: ## branch is intersected
            if 'partial_%s'%(idx) in partial_labels: ## if branch was drawn before
                l=[w for w in partials if 'partial_%s'%(idx)==str(w.get_label())][-1]
                l.set_data([xp,cur_slice],[y,y])
            else: ## branch is intersected, but not drawn before
                ax2.plot([xp,cur_slice],[y,y],lw=branchWidth,color=c,zorder=99,label='partial_%s'%(ap.index))
                
        if x<=cur_slice: ## time arrow passed branch - add it to finished class
            if 'partial_%s'%(idx) in partial_labels:
                l=[w for w in partials if 'partial_%s'%(idx)==str(w.get_label())][-1]
                l.set_data([xp,x],[y,y])
                l.set_label('finished_%s'%(idx))
                finished_labels.append('finished_%s'%(idx))
                
            if 'finished_%s'%(idx) not in finished_labels:
                ax2.plot([xp,x],[y,y],lw=branchWidth,color=c,zorder=99,label='finished_%s'%(ap.index))
                
            if 'partial_%s'%(idx) in partial_labels or 'finished_%s'%(idx) not in finished_labels:
                if isinstance(ap,bt.leaf):
                    ## coloured tip on top of a slightly larger black point (outline effect)
                    ax2.scatter(x,y,s=tipSize,facecolor=c,edgecolor='none',zorder=102,label='finished_%s'%(ap.index))
                    ax2.scatter(x,y,s=tipSize+30,facecolor='k',edgecolor='none',zorder=101,label='finished_%s'%(ap.index))
                elif isinstance(ap,bt.node):
                    ## vertical bar joining the first and last child
                    yl=ap.children[0].y
                    yr=ap.children[-1].y
                    ax2.plot([x,x],[yl,yr],lw=branchWidth,color=c,zorder=99,label='finished_%s'%(ap.index))
    ####
    ## COMMENT this bit out if you don't want the tree to appear out of the time arrow
    ####
                
    ## move the tree panel's time arrow (the line labelled 'time') to the current time
    for l in Ls2:
        if 'time' in l.get_label():
            l.set_data([cur_slice,cur_slice],[0,1]) ## adjust time arrow
            
        #### 
        ## UNCOMMENT this bit if you'd like lineages to be coloured over time
        ####
#         else:
#             ## fetch all line data
#             d_xs,d_ys=l.get_data()
            
#             ## extract x coordinate
#             start,end=d_xs
            
#             ## if time arrow passed end point of line - delete line
#             if end<cur_slice:
#                 ax2.lines.remove(l)
                
#             ## if time arrow passed start of line - adjust start of line
#             elif start<cur_slice:
#                 l.set_data([cur_slice,end],d_ys)
    
#     ## iterate over collections (scatter points) in tree plot
#     Ps2=[x for x in ax2.collections if 'Colour' not in str(x.get_label())]
    
#     for p in Ps2:
#         ## fetch coordinates
#         coords=p.get_offsets()
#         ## only alter points with 1 coordinate
#         if len(coords)==1:
#             ## remove black and white point if time arrow has passed
#             if coords[0][0]<=float(cur_slice):
#                 ax2.collections.remove(p)
        #### 
        ## UNCOMMENT this bit if you'd like lineages to be coloured over time
        ####
    
    ### Tertiary plotting (cases)
    ## Move the cases panel's time arrow. For the other lines, points left of the arrow are
    ## slid onto it, linearly interpolating y towards the next point.
    ## NOTE(review): d comes from get_xydata() and is edited in place without set_data();
    ## confirm the edits actually reach the drawn line.
    Ls3=[x for x in ax3.lines if 'Colour' not in str(x.get_label())] ## fetch all the lines with labels in cases plot
    
    for l in Ls3:
        if 'time' in l.get_label():
            l.set_data([cur_slice,cur_slice],[0,1]) ## adjust time arrow
        else:
            d=l.get_xydata() ## fetch all line data
            
            for e in range(len(d)-1): ## iterate over points
                x_now=d[:,0][e] ## get coordinates of current and next positions
                x_nex=d[:,0][e+1]

                y_now=d[:,1][e]
                y_nex=d[:,1][e+1]
                
                if x_now<cur_slice: ## if beginning of line passed time arrow
                    d[:,0][e]=cur_slice # adjust coordinate so it's sitting on top of time arrow
                    d[:,1][e]=y_now+((y_nex-y_now)/(x_nex-x_now))*(cur_slice-x_now) 

In [None]:


## This part will initialise the map, case numbers, and tree (in grey, if so set up),
## then render every frame of the animation to a PNG file.
## NOTE(review): several names below are carried over from the curonia notebook and are
## not defined in the visible cells yet: global_border, column, location_to_country, colours,
## required_countries, pop_centres, cases_by_location, ll, xlimits, ylimits, local_output.
## location_points is keyed by the geoJSON feature names (countries such as 'Sierra Leone'),
## not by the district names in `locations`, so district outlines still need loading.


def _to_decimal(date_value):
    """Return decimal_date() of a date, parsing 'YYYY-MM-DD' strings first."""
    if isinstance(date_value, str):
        date_value = dt.datetime.strptime(date_value, '%Y-%m-%d').date()
    return decimal_date(date_value)


def _add_month_ticks(ax, labels):
    """Put a tick mid-month for every 'YYYY-MM-01' label and shade every other month on ``ax``."""
    ax.set_xticks([_to_decimal(x) + 1 / 24.0 for x in labels])
    ax.set_xticklabels([convert_date(x, '%Y-%m-%d', '%b\n%Y') if x.split('-')[1] == '01' else convert_date(x, '%Y-%m-%d', '%b') for x in labels])
    for month_start in labels[::2]:
        ax.axvspan(_to_decimal(month_start), _to_decimal(month_start) + 1 / 12., facecolor='k', edgecolor='none', alpha=0.04)


plt.close('all')  ## free figures left by earlier runs (plt.clf()/plt.cla() would create an extra empty figure)
plt.figure(figsize=(32, 18), facecolor='w')  ## start figure

gs = gridspec.GridSpec(2, 2, width_ratios=[18, 14], height_ratios=[14, 4], hspace=0.05555, wspace=0.05882)  ## define subplots

ax1 = plt.subplot(gs[0:, 0])  ## ax1 is map
ax2 = plt.subplot(gs[0, 1])  ## ax2 is tree
ax3 = plt.subplot(gs[1, 1])  ## ax3 is cases

for l, local_border in enumerate(global_border):  ## plot the international borders
    border_x = column(local_border, 0)
    border_y = column(local_border, 1)
    ax1.plot(border_x, border_y, lw=5, color='w', zorder=96, label='%d_border_bg' % (l))
    ax1.plot(border_x, border_y, lw=2, color='k', zorder=97, label='%d_border' % (l))

for loc in locations:  ## iterate over locations, plot the initial setup
    country = location_to_country[loc]
    c = colours[country](0)  ## zero cases colour

    if country in required_countries:
        ## circle tracking the number of lineages at the location; animate() resizes it by its label
        N_lineages = plt.Circle(pop_centres[loc], radius=0, label='%s_circle' % (loc), facecolor='indianred', edgecolor='k', lw=1, zorder=100)
        ax1.add_patch(N_lineages)

        for part in location_points[loc]:  ## plot every part of each location (islands, etc)
            poly = plt.Polygon(part, facecolor=c, edgecolor='grey', lw=1, label='%s_polygon' % (loc), closed=True, zorder=95)
            ax1.add_patch(poly)

for side in ('top', 'right', 'left', 'bottom'):  ## remove borders and axis labels
    ax1.spines[side].set_visible(False)
ax1.tick_params(size=0)
ax1.set_xticklabels([])
ax1.set_yticklabels([])

ax1.set_ylim(ylimits)  ## set plot limits
ax1.set_xlim(xlimits)

xlabels = ['2013-%02d-01' % x for x in range(12, 13)]  ## setup time labels
xlabels += ['2014-%02d-01' % x for x in range(1, 13)]
xlabels += ['2015-%02d-01' % x for x in range(1, 13)]
xlabels += ['2016-%02d-01' % x for x in range(1, 3)]

################
## Secondary plot begins - CASES
################
week_times = [decimal_date(week) for week in epi_weeks]  ## case counts are per epi week, as in animate()
for c, country in enumerate(required_countries):  ## iterate through countries
    country_colour_map = colours[country]
    ## cases in the country at each epi week
    ys = [sum(cases_by_location[loc][week] for loc in locations if location_to_country[loc] == country) for week in epi_weeks]

    grey_colour = mpl.cm.Greys((c + 1) / float(len(required_countries) + 2))

    ax3.plot(week_times, ys, lw=3.3, color=grey_colour, zorder=2, label='BW')  ## plot the same cases, one in full colour and one in grey on top to obscure colour
    ax3.plot(week_times, ys, lw=3, color=country_colour_map(0.6), zorder=1, label='Colour')

ax3.axvline(decimal_date(epi_weeks[0]), color='k', lw=3, label='time', zorder=100)  ## add time arrow to indicate current time

_add_month_ticks(ax3, xlabels)  ## add ticks, tick labels and month markers

ax3.xaxis.tick_bottom()  ## make cases plot pretty
ax3.yaxis.tick_left()
ax3.spines['top'].set_visible(False)
ax3.spines['right'].set_visible(False)

ax3.yaxis.set_major_locator(mpl.ticker.MultipleLocator(100))
ax3.set_xlim(_to_decimal('2013-12-01'), decimal_date(epi_weeks[-1]))  ## end at the last animated epi week
ax3.set_ylim(0, 700)

ax3.tick_params(which='both', direction='out')
ax3.tick_params(axis='x', size=0, labelsize=18)
ax3.tick_params(axis='y', which='major', size=8, labelsize=30)
ax3.tick_params(axis='y', which='minor', size=5)
ax3.set_xticklabels([])
################
## Secondary plot ends - CASES
################


################
## Tertiary plot begins - TREE
################
tipSize = 20
branchWidth = 2

posteriorCutoff = 0.0

####
## UNCOMMENT if you'd like the tree to be plotted in grey initially and get coloured over time
####
## iterate over objects in tree
# for k in ll.Objects:
#     location=k.traits[locTrait]
#     country=location_to_country[location]
#     cmap=colours[country]
#     c=cmap(normalized_coords[location])

#     countryColour=mpl.cm.Greys
#     grey_colour=countryColour((required_countries.index(country)+1)/float(len(required_countries)+2))

#     y=k.y
#     yp=k.parent.y

#     x=k.absoluteTime
#     xp=k.parent.absoluteTime

#     if isinstance(k,bt.leaf):
#         ## plot BW tree on top
#         ax2.scatter(x,y,s=tipSize,facecolor=grey_colour,edgecolor='none',zorder=102,label='LeafBW_%d'%(k.index))
#         ax2.scatter(x,y,s=tipSize+30,facecolor='k',edgecolor='k',zorder=100,label='Colour')
#         ax2.plot([xp,x],[y,y],color=grey_colour,lw=branchWidth,zorder=99,label='LeafBranchBW_%d'%(k.index))

#         ## plot colour tree underneath
#         ax2.scatter(x,y,s=tipSize,facecolor=c,edgecolor='none',zorder=101,label='LeafColour_%d'%(k.index))
#         ax2.plot([xp,x],[y,y],color=c,lw=branchWidth,zorder=98,label='LeafBranchColour_%d'%(k.index))

#     elif isinstance(k,bt.node):
#         yl=k.children[0].y
#         yr=k.children[-1].y

#         if xp==0.0:
#             xp=x

#         ls='-'
#         if k.traits['posterior']<posteriorCutoff:
#             ls='--'

#         ax2.plot([xp,x],[y,y],color=grey_colour,lw=branchWidth,ls=ls,zorder=99,label='NodeBranchBW_%d'%(k.index))
#         ax2.plot([x,x],[yl,yr],color=grey_colour,lw=branchWidth,ls=ls,zorder=99,label='NodeHbarBW_%d'%(k.index))

#         ax2.plot([xp,x],[y,y],color=c,lw=branchWidth,ls=ls,zorder=98,label='NodeBranchColour_%d'%(k.index))
#         ax2.plot([x,x],[yl,yr],color=c,lw=branchWidth,ls=ls,zorder=98,label='NodeHbarColour_%d'%(k.index))
####
## UNCOMMENT if you'd like the tree to be plotted in grey initially and get coloured over time
####

ax2.axvline(decimal_date(epi_weeks[0]), color='k', lw=3, label='time', zorder=200)  ## add time arrow to indicate current time

_add_month_ticks(ax2, xlabels)  ## add ticks, tick labels and month markers

ax2.xaxis.tick_bottom()  ## make tree plot pretty
ax2.yaxis.tick_left()
for side in ('top', 'right', 'left', 'bottom'):
    ax2.spines[side].set_visible(False)

ax2.tick_params(axis='x', size=0)
ax2.tick_params(axis='y', size=0)
ax2.set_xticklabels([])
ax2.set_yticklabels([])

ax2.set_xlim(_to_decimal('2013-12-01'), decimal_date(epi_weeks[-1]))  ## end at the last animated epi week
ax2.set_ylim(-5, len(ll.Objects) / 2.0 + 6)
################
## Tertiary plot ends - TREE
################

for i in range(Nframes):  ## iterate through each frame
    animate(i)  ## animate modifies the map
    ## save individual frames for stitching up using 3rd party software (e.g. FFMpeg)
    plt.savefig(local_output + 'EBOV_animation/ani_frame_%05d.png' % (i), format='png', bbox_inches='tight', dpi=dpi)

print('\n\nDONE!')

## (curonia note) Expect a HUGE slow down around August 2014 (about 0.02 s/frame) due to lots of EBOV movement
print('\nTime taken: %.2f minutes' % ((time.time() - t0) / 60.0))

## int() alone would give 0 fps whenever there are fewer frames than seconds; clamp to at least 1
fps = max(1, int(Nframes / animation_duration))
print('Recommended fps to get animation %d seconds long: %d' % (animation_duration, fps))
plt.show()