In [None]:
import baltic as bt
import pandas as pd
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from datetime import datetime as dt
from datetime import timedelta
import time
from io import StringIO
import altair as alt
from altair import datum
import arviz as az
from scipy.stats import gaussian_kde

# Turn off Altair's row-count limit so the large per-tree data frames built below can be charted.
alt.data_transformers.disable_max_rows()

In [None]:
def get_taxa_lines(tree_path):
    """Collect every line of a ``.trees`` file that is not a tree-state line.

    The returned text (taxa block, translate table, etc.) is later prepended to
    a single ``tree STATE_...`` line so that each posterior tree can be parsed
    on its own.

    Parameters
    ----------
    tree_path : str
        Path to the ``.trees`` file to read.

    Returns
    -------
    str
        All lines whose lower-cased text does not contain ``"state"``,
        concatenated in file order (line endings preserved).
    """
    kept_lines = []
    # Read the file that was actually requested (not a notebook-level global),
    # using plain text mode: the legacy 'rU' flag was removed in Python 3.11.
    with open(tree_path, 'r') as infile:
        for line in infile:
            # everything that is not a newick tree string is kept
            if 'state' not in line.lower():
                kept_lines.append(line)

    return "".join(kept_lines)

In [None]:
def get_burnin_value(tree_path, burnin_percent):
    """Return how many leading trees of a ``.trees`` file count as burn-in.

    Parameters
    ----------
    tree_path : str
        Path to the ``.trees`` file.
    burnin_percent : float
        Fraction of the sampled trees to discard (e.g. ``0.4``).

    Returns
    -------
    float
        Number of lines containing ``"state"`` (case-insensitive) multiplied by
        ``burnin_percent``. The value is not rounded.
    """
    # Plain 'r' mode: the legacy 'rU' flag raises ValueError on Python >= 3.11.
    with open(tree_path, 'r') as infile:
        # each tree-state line is one sampled tree
        numtrees = sum(1 for line in infile if 'state' in line.lower())

    return numtrees * burnin_percent

In [None]:
#turn a datetime into a decimal year (adapted from a StackOverflow answer)
def toYearFraction(date):
    """Express ``date`` as a decimal year, e.g. 1 Jan 2020 -> ``2020.0``.

    The fractional part is the share of the year's seconds (as measured by
    ``time.mktime`` in local time) that have elapsed at ``date``.
    """
    def _epoch_seconds(moment):
        # seconds since the epoch for a datetime
        return time.mktime(moment.timetuple())

    year_start = dt(year=date.year, month=1, day=1)
    next_year_start = dt(year=date.year + 1, month=1, day=1)
    year_start_seconds = _epoch_seconds(year_start)

    elapsed = _epoch_seconds(date) - year_start_seconds
    year_length = _epoch_seconds(next_year_start) - year_start_seconds

    return date.year + elapsed / year_length

In [None]:
#map a decimal year back onto a calendar date string
def convert_partial_year(number):
    """Convert a decimal year (e.g. ``2020.5``) to a ``'YYYY-MM-DD'`` string.

    The fractional part is scaled by the length of that year in days
    (365 plus whatever ``is_leap`` reports) and added to 1 January.
    """
    whole_year = int(number)
    days_in_year = 365 + is_leap(whole_year)
    offset = timedelta(days=(number - whole_year) * days_in_year)
    return (dt(whole_year, 1, 1) + offset).strftime('%Y-%m-%d')

In [None]:
def is_leap(number):
    """Return ``1`` if ``number`` is a Gregorian leap year, otherwise ``0``.

    The integer return value is added directly to 365 by callers to obtain the
    number of days in the year. Previously only 2020 was recognised, so dates in
    any other leap year were converted with a 365-day year.
    """
    divisible_by_four = number % 4 == 0
    # century years are leap years only when divisible by 400
    skipped_century = number % 100 == 0 and number % 400 != 0
    return 1 if divisible_by_four and not skipped_century else 0

In [None]:
def convert_format(number):
    """Reduce a ``'YYYY-MM-DD'`` date string to its ``'YYYY-MM'`` year-month."""
    parsed = dt.strptime(number, '%Y-%m-%d')
    return parsed.strftime('%Y-%m')

In [None]:
def enumerate_migration_events(tree, traitType):
    """List every branch of ``tree`` along which the discrete trait changes.

    Parameters
    ----------
    tree : object exposing ``Objects``
        Each element must provide ``traits`` (dict), ``parent`` and the parent's
        ``absoluteTime``.
    traitType : str
        Key of the discrete trait inside each node's ``traits`` dict.

    Returns
    -------
    dict
        ``{event_number: {"type", "date", "parent_host", "child_host"}}`` with
        event numbers starting at 1. ``date`` is the parent node's
        ``absoluteTime``.
    """
    events = {}
    n_events = 0

    for branch in tree.Objects:
        child_state = branch.traits[traitType]
        ancestor = branch.parent

        # a parent without the trait is treated as the root; root-to-deme
        # jumps are not written out
        if traitType not in ancestor.traits:
            continue

        parent_state = ancestor.traits[traitType]
        if child_state == parent_state:
            continue

        n_events += 1
        events[n_events] = {
            "type": parent_state + "-to-" + child_state,
            "date": ancestor.absoluteTime,
            "parent_host": parent_state,
            "child_host": child_state,
        }

    return events

### work on persistence times

In [None]:
#express a duration given in (365-day) years as a number of seconds
def convert_persistence(number):
    """Convert a duration in decimal years to seconds, using 365 days per year."""
    return timedelta(days=number * 365).total_seconds()

In [None]:
#this is adapted from Bedford et al in nature where we start at a time and then walk backwards up the tree until the location changes
def estimate_persistence(tree, typeTrait):
    """Measure, for each tip, how long its lineage has stayed in the tip's deme.

    Starting from every leaf, the function walks up through its ancestors until
    it meets one whose trait differs from the tip's trait. The difference
    between the tip's ``absoluteTime`` and that ancestor's ``absoluteTime`` is
    recorded as the persistence. Tips whose lineage never changes trait before
    an ancestor without the trait (or the top of the tree) is reached are not
    reported.

    Parameters
    ----------
    tree : object exposing ``Objects``
        Each element must provide ``traits``, ``parent``, ``branchType``,
        ``absoluteTime`` and, for leaves, ``name``.
    typeTrait : str
        Key of the discrete trait inside each node's ``traits`` dict.

    Returns
    -------
    dict
        ``{counter: {"type", "migration date", "tip date", "persistance",
        "tip_name", "parent_host", "child_host"}}`` with counters starting at 1.
    """
    output_dict = {}
    persistence_counter = 0

    for k in tree.Objects:
        trait = k.traits[typeTrait]
        parent_node = k.parent

        # branches hanging directly off the root are skipped
        if ('root' in parent_node.traits) or (parent_node.traits == {}):
            continue
        # persistence is only defined for tips
        if k.branchType != 'leaf':
            continue

        tip_date = k.absoluteTime
        tip_name = k.name

        # Walk towards the root. The walk ends explicitly when there is no
        # further ancestor or the ancestor carries no trait, instead of relying
        # on a bare ``except`` that would also hide unrelated errors.
        ancestor = parent_node
        while ancestor is not None:
            ancestor_traits = getattr(ancestor, "traits", None) or {}
            if typeTrait not in ancestor_traits:
                break

            parent_trait = ancestor_traits[typeTrait]
            if trait != parent_trait:
                persistence_counter += 1
                migration_date = ancestor.absoluteTime

                # write to output dictionary
                output_dict[persistence_counter] = {
                    "type": parent_trait + "-to-" + trait,
                    "migration date": migration_date,
                    "tip date": tip_date,
                    "persistance": tip_date - migration_date,
                    "tip_name": tip_name,
                    "parent_host": parent_trait,
                    "child_host": trait,
                }
                break

            ancestor = getattr(ancestor, "parent", None)

    return output_dict

In [None]:
def _nested_events_to_frame(events_by_tree):
    """Flatten ``{tree_number: {event_number: record}}`` into one DataFrame."""
    flat = {(tree_number, event_number): record
            for tree_number, events in events_by_tree.items()
            for event_number, record in events.items()}
    frame = pd.DataFrame.from_dict(flat, orient='index').reset_index()
    return frame.rename(columns={'level_0': 'tree_number', 'level_1': 'migration_event_number'})


#counts all migration events and records parent and child nodes
def run_mig_counts(all_trees, traitType, most_recent_tip=2022.7438):
    """Scan a ``.trees`` file and tabulate migrations and persistence per tree.

    Relies on two notebook-level globals that must be set before the call:
    ``burnin`` (trees with a running count <= ``burnin`` are skipped) and
    ``taxa_lines`` (text prepended to each tree line before parsing).

    Parameters
    ----------
    all_trees : str
        Path to the ``.trees`` file.
    traitType : str
        Name of the discrete trait annotated on the tree nodes.
    most_recent_tip : float, optional
        Decimal date handed to ``tree.setAbsoluteTime``. Defaults to the value
        that was previously hard-coded.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``migrations_df`` and ``persistence_df``; both carry ``tree_number`` and
        ``migration_event_number`` columns followed by the per-event fields.
    """
    start_time = time.time()

    tree_counter = 0
    trees_processed = 0
    migrations_dict = {}
    persistence_dict = {}

    with open(all_trees, "r") as infile:
        for line in infile:
            if 'tree STATE_' not in line:
                continue
            tree_counter += 1
            if tree_counter <= burnin:
                continue

            temp_tree = StringIO(taxa_lines + line)
            # raw string: '\-' is an invalid escape sequence in a normal literal
            tree = bt.loadNexus(temp_tree, tip_regex=r'_([0-9\-]+)$')
            tree.setAbsoluteTime(most_recent_tip)
            trees_processed += 1

            # iterate through the tree and pull out all migration events
            migrations_dict[tree_counter] = enumerate_migration_events(tree, traitType)
            persistence_dict[tree_counter] = estimate_persistence(tree, traitType)

    # print the amount of time this took
    total_time_seconds = time.time() - start_time
    total_time_minutes = total_time_seconds/60
    print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")

    # one row per (tree, event) for each of the two nested dictionaries
    migrations_df = _nested_events_to_frame(migrations_dict)
    persistence_df = _nested_events_to_frame(persistence_dict)

    return(migrations_df, persistence_df)

## now regional

In [None]:
# Path to the .trees file analysed in this (regional) section.
trees =  "../beast_results/contextual_trees_downsampled.trees"

In [None]:
# Globals consumed by run_mig_counts: `taxa_lines` is prepended to each tree line and
# trees whose running count is <= `burnin` are skipped.
all_trees = trees
burnin_percent = 0.6  # fraction of the sampled trees to discard as burn-in
taxa_lines = get_taxa_lines(all_trees)
burnin = get_burnin_value(all_trees, burnin_percent)
print(burnin)

In [None]:
# Parse every post-burn-in tree and tabulate trait changes ("typeTrait") and tip persistence.
migrations_df, persistence_df = run_mig_counts(all_trees, traitType = "typeTrait")
 

In [None]:
# Preview only the first rows rather than rendering the whole table.
persistence_df.head()


In [None]:
#persistence_df = persistence_df[(persistence_df.type == "no-to-yes") |(persistence_df.type == "yes-to-no")  ]
#migrations_df = migrations_df[(migrations_df.type == "no-to-yes") |(migrations_df.type == "yes-to-no")  ]


In [None]:
# Number of migration events into (imports) and out of (exports) each deme, per tree.
imports = migrations_df.groupby(["child_host", "tree_number"])['migration_event_number'].count().reset_index()
exports = migrations_df.groupby(["parent_host", "tree_number"])['migration_event_number'].count().reset_index()

imports.index = imports.child_host
exports.index = exports.parent_host

# {deme: [array of per-tree counts]}. Each dictionary is built from its own
# frame, so a deme that only imports (or only exports) no longer raises a
# KeyError or stays empty, and boolean masks always yield an array even when a
# deme has a single row.
mig_dict_imports = {
    place: [imports.loc[(imports["child_host"] == place).to_numpy(), "migration_event_number"].to_numpy()]
    for place in imports["child_host"].unique()
}
mig_dict_exports = {
    place: [exports.loc[(exports["parent_host"] == place).to_numpy(), "migration_event_number"].to_numpy()]
    for place in exports["parent_host"].unique()
}

In [None]:

# Violin fill colour for each trait value.
colors = dict(
    yes="#E67932",
    no="#5AA5AB",
    other="#511EA8",
)


In [None]:
# How often each "<parent>-to-<child>" transition type occurs across all trees.
persistence_df["type"].value_counts()

In [None]:
# Mean persistence per deme and tree.
persist = persistence_df.groupby(["child_host", "tree_number"])['persistance'].mean().reset_index()
# convert_persistence returns seconds; dividing by 86400 expresses the value in days
persist['persistance'] = persist['persistance'].map(convert_persistence).div(86400)
persist.index = persist.child_host

# {deme: [array of per-tree mean persistence in days]}. A boolean mask is used
# instead of a label lookup so a deme with a single row still yields an array
# (a label lookup returns a scalar there, which has no `.values`).
persist_dict = {
    place: [persist.loc[(persist["child_host"] == place).to_numpy(), "persistance"].to_numpy()]
    for place in persist["child_host"].unique()
}
    


In [None]:
# Demes for which a persistence estimate exists.
persist_dict.keys()

In [None]:
# Three-panel violin figure: introductions, exportations and mean persistence per deme.
# (`plt` is already imported in the first cell, so it is not re-imported here.)
fig, axs = plt.subplots(1, 3, figsize=(75, 18), facecolor='w')

# Display name for each trait value. Looking labels up by key keeps the tick
# labels aligned with whatever demes are actually present; a fixed label list
# fails as soon as its length differs from the number of violins.
region_labels = {"no": "KC", "other": "other", "yes": "HCT"}

panels = [
    (axs[0], mig_dict_imports, "Number of Introductions"),
    (axs[1], mig_dict_exports, "Number of Exportations"),
    (axs[2], persist_dict, "Average Persistence times (days)"),
]

for ax, samples_by_place, ylabel in panels:
    positions = []
    for index, place in enumerate(samples_by_place):
        samples = samples_by_place[place]

        # report the quartiles of each distribution
        for q in (0.25, 0.5, 0.75):
            print(place, np.quantile(samples, q))

        violin_plot = ax.violinplot(samples, positions=[index], widths=0.8,
                                    showmedians=True, bw_method=0.75, showextrema=False)
        violin_plot["bodies"][0].set_facecolor(colors[place])
        violin_plot["bodies"][0].set_edgecolor(colors[place])
        violin_plot["bodies"][0].set_alpha(.8)
        violin_plot["cmedians"].set_edgecolor("black")
        positions.append(index)

    ax.set_xticks(positions)
    ax.set_xticklabels([region_labels.get(place, place) for place in samples_by_place],
                       fontsize=65, rotation=45, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=65, fontweight='bold')
    ax.tick_params(axis='y', labelsize=70)

# persistence cannot be negative, so anchor that panel at zero
axs[2].set_ylim(bottom=0)

# Add publication style lettering in the upper left corner
fig.text(0.1, 0.93, "B", fontsize=70, fontweight='bold')
fig.text(0.37, 0.93, "C", fontsize=70, fontweight='bold')
fig.text(0.66, 0.93, "D", fontsize=70, fontweight='bold')

plt.savefig('../figures/imports_exports_persistence.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Decimal-year event dates -> 'YYYY-MM-DD' strings -> 'YYYY-MM' bins.
migrations_df['calendar_date'] = migrations_df["date"].map(convert_partial_year)
migrations_df['year-month'] = migrations_df['calendar_date'].map(convert_format)

In [None]:

def return_proportions_dataframe(input_df, time_unit):
    """Count, per tree and migration direction, the events in each time bin.

    One row is produced for every combination of tree number, migration type
    and time bin that occurs anywhere in ``input_df``, including combinations
    with zero events.

    Parameters
    ----------
    input_df : pandas.DataFrame
        Must contain ``tree_number``, ``type`` and the ``time_unit`` column.
    time_unit : str
        Name of the column holding the time bin (e.g. ``"year-month"``).

    Returns
    -------
    pandas.DataFrame
        Columns ``migration_direction``, ``time_unit``, ``tree_number``,
        ``total_transitions`` (events of that direction in the tree),
        ``transitions_in_time_interval`` and
        ``proportion_transitions_in_time_interval`` (0 when the tree has no
        event of that direction). Rows carry a default integer index.
    """
    columns = ["migration_direction", "time_unit", "tree_number",
               "total_transitions", "transitions_in_time_interval",
               "proportion_transitions_in_time_interval"]

    tree_numbers = input_df['tree_number'].unique()
    directions = input_df['type'].unique()
    intervals = input_df[time_unit].unique()

    # Count once with groupby instead of re-filtering the frame in every loop
    # iteration; missing combinations simply fall back to 0 below.
    totals = input_df.groupby(['tree_number', 'type']).size().to_dict()
    per_interval = input_df.groupby(['tree_number', 'type', time_unit]).size().to_dict()

    # Collect plain records and build the frame once: DataFrame.append was
    # removed in pandas 2.0 and growing a frame row by row is quadratic.
    records = []
    for tree_number in tree_numbers:
        for direction in directions:
            total_transitions = totals.get((tree_number, direction), 0)
            for item in intervals:
                transitions_in_time_unit = per_interval.get((tree_number, direction, item), 0)
                if total_transitions != 0:
                    prop_transitions_in_time_unit = transitions_in_time_unit / total_transitions
                else:
                    prop_transitions_in_time_unit = 0

                records.append({
                    "migration_direction": direction,
                    "time_unit": item,
                    "tree_number": tree_number,
                    "total_transitions": total_transitions,
                    "transitions_in_time_interval": transitions_in_time_unit,
                    "proportion_transitions_in_time_interval": prop_transitions_in_time_unit,
                })

    return pd.DataFrame(records, columns=columns)

In [None]:
# Build the per-tree, per-month table of migration counts and proportions, timing the call.
start_time = time.time()

mig = return_proportions_dataframe(migrations_df, "year-month")

total_time_seconds = time.time() - start_time
total_time_minutes = total_time_seconds/60
print(total_time_minutes)  # elapsed time in minutes

mig.head()

In [None]:
# Number of rows per (direction, month) pair.
mig.groupby(["migration_direction", "time_unit"])["total_transitions"].count()

In [None]:
# Relabel the raw trait transitions with readable region names. Series.replace
# assigns through the frame itself; the previous chained indexing
# (`mig.col[mask] = value`) triggers SettingWithCopyWarning and does not update
# the frame under pandas Copy-on-Write.
mig["migration_direction"] = mig["migration_direction"].replace({
    "no-to-yes": "KC to HCT",
    "yes-to-no": "HCT to KC",
})


In [None]:
# Monthly migration counts: confidence-interval error bars, mean points and a mean line,
# all drawn from `mig` on the same 800x300 canvas.
counts_canvas = alt.Chart(mig).properties(width=800, height=300)

error_bars = counts_canvas.mark_errorbar(extent='ci').encode(
    x=alt.X('time_unit:T', axis=alt.Axis(title="", grid=True, tickCount="month", format="%B %Y")),
    y=alt.Y('transitions_in_time_interval:Q', title="Number of migration events ", axis=alt.Axis(grid=False)),
    color=alt.Color("migration_direction:N"),
)

points = counts_canvas.mark_point(filled=True, opacity=1, width=5).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('transitions_in_time_interval:Q', aggregate='mean'),
    color=alt.Color("migration_direction:N"),
)

band4 = counts_canvas.mark_line(interpolate='monotone', opacity=0.5).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('mean(transitions_in_time_interval)'),
    color=alt.Color('migration_direction:N'),
)

error_bars + points + band4

In [None]:
# Row count per migration-direction label.
mig.migration_direction.value_counts()

In [None]:
# NOTE: `domain` and `range_` are assigned here but no colour scale in this cell uses them.
domain = ['WA to HCT', 'HCT to WA']
range_ = ['#9461bd', "#2ca02c"]

# Monthly share of each direction's migration events: CI error bars, mean points and a mean line.
proportions_canvas = alt.Chart(mig).properties(width=800, height=300)

error_bars = proportions_canvas.mark_errorbar(extent='ci').encode(
    x=alt.X('time_unit:T', axis=alt.Axis(title="", grid=False, tickCount="month", format="%B %Y")),
    y=alt.Y('proportion_transitions_in_time_interval:Q', title="Proportion of all migration events ",
            axis=alt.Axis(grid=False, format='%')),
    color=alt.Color("migration_direction:N",
                    legend=alt.Legend(title="Migration Direction", orient="left", offset=-245,
                                      labelFontSize=16, titleFontSize=16)),
)

points = proportions_canvas.mark_point(filled=True, opacity=1, width=5).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('proportion_transitions_in_time_interval:Q', aggregate='mean'),
    color=alt.Color("migration_direction:N"),
)

lineplot4 = proportions_canvas.mark_line(interpolate='monotone', opacity=0.35).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('mean(proportion_transitions_in_time_interval)'),
    color=alt.Color('migration_direction:N'),
)

ave = error_bars + points + lineplot4
ave.configure_axis(labelFontSize=16, titleFontSize=14)

In [None]:
# Keep months after December 2019 ('YYYY-MM' strings compare chronologically).
mig_short = mig[mig["time_unit"] > "2019-12"]

In [None]:
domain = ['KC to HCT', 'HCT to KC']
range_ = ['#4A8CC2', "#E29D39"]

# Same three layers as the full-period chart, restricted to the KC <-> HCT directions
# in `mig_short`; the explicit colour scale is attached to the error-bar layer.
short_canvas = (
    alt.Chart(mig_short)
    .properties(width=800, height=300)
    .transform_filter((datum.migration_direction == "HCT to KC") | (datum.migration_direction == "KC to HCT"))
)

error_bars = short_canvas.mark_errorbar(extent='ci').encode(
    x=alt.X('time_unit:T', axis=alt.Axis(title="", grid=False, tickCount="month", format="%B %Y")),
    y=alt.Y('proportion_transitions_in_time_interval:Q', title="Proportion of all migration events ",
            axis=alt.Axis(grid=False, format='%')),
    color=alt.Color("migration_direction:N",
                    scale=alt.Scale(domain=domain, range=range_),
                    legend=alt.Legend(title="Migration Direction", orient="left", offset=-245,
                                      labelFontSize=16, titleFontSize=16)),
)

points = short_canvas.mark_point(filled=True, opacity=1, width=5).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('proportion_transitions_in_time_interval:Q', aggregate='mean'),
    color=alt.Color("migration_direction:N"),
)

lineplot4 = short_canvas.mark_line(interpolate='monotone', opacity=0.35).encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('mean(proportion_transitions_in_time_interval)'),
    color=alt.Color('migration_direction:N'),
)

ave = error_bars + points + lineplot4
ave.configure_axis(labelFontSize=16, titleFontSize=14)

## repeat for temporal subsampling

In [None]:
# Path to the .trees file analysed in this (temporal subsampling) section.
trees =  "../beast_results/temporal_925_skyline.align_925_temporal.trees"

In [None]:
# Globals consumed by run_mig_counts: `taxa_lines` is prepended to each tree line and
# trees whose running count is <= `burnin` are skipped.
all_trees = trees
burnin_percent = 0.4  # fraction of the sampled trees to discard as burn-in
taxa_lines = get_taxa_lines(all_trees)
burnin = get_burnin_value(all_trees, burnin_percent)
print(burnin)

In [None]:
# Parse every post-burn-in tree and tabulate changes of the "max" trait and tip persistence.
migrations_df, persistence_df = run_mig_counts(all_trees, traitType = "max")


In [None]:
# Number of migration events into (imports) and out of (exports) each deme, per tree.
imports = migrations_df.groupby(["child_host", "tree_number"])['migration_event_number'].count().reset_index()
exports = migrations_df.groupby(["parent_host", "tree_number"])['migration_event_number'].count().reset_index()

imports.index = imports.child_host
exports.index = exports.parent_host

# {deme: [array of per-tree counts]}. Each dictionary is built from its own
# frame, so a deme that only imports (or only exports) no longer raises a
# KeyError or stays empty, and boolean masks always yield an array even when a
# deme has a single row.
mig_dict_imports = {
    place: [imports.loc[(imports["child_host"] == place).to_numpy(), "migration_event_number"].to_numpy()]
    for place in imports["child_host"].unique()
}
mig_dict_exports = {
    place: [exports.loc[(exports["parent_host"] == place).to_numpy(), "migration_event_number"].to_numpy()]
    for place in exports["parent_host"].unique()
}

In [None]:
# Mean persistence per deme and tree.
persist = persistence_df.groupby(["child_host", "tree_number"])['persistance'].mean().reset_index()
# convert_persistence returns seconds; dividing by 86400 expresses the value in days
persist['persistance'] = persist['persistance'].map(convert_persistence).div(86400)
persist.index = persist.child_host

# {deme: [array of per-tree mean persistence in days]}. A boolean mask is used
# instead of a label lookup so a deme with a single row still yields an array
# (a label lookup returns a scalar there, which has no `.values`).
persist_dict = {
    place: [persist.loc[(persist["child_host"] == place).to_numpy(), "persistance"].to_numpy()]
    for place in persist["child_host"].unique()
}
    


In [None]:
# Three-panel violin figure for the temporal subsample: introductions, exportations
# and mean persistence per deme.
# (`plt` is already imported in the first cell, so it is not re-imported here.)
fig, axs = plt.subplots(1, 3, figsize=(75, 18), facecolor='w')

# Display name for each trait value. Looking labels up by key keeps the tick
# labels aligned with whatever demes are actually present; a fixed label list
# fails as soon as its length differs from the number of violins.
region_labels = {"no": "WA", "yes": "HCT"}

# (axis, data, y-label, KDE bandwidth, print quartiles?) -- the introductions
# panel uses a narrower bandwidth and does not print its quartiles.
panels = [
    (axs[0], mig_dict_imports, "Number of Introductions", 0.2, False),
    (axs[1], mig_dict_exports, "Number of Exportations", 0.75, True),
    (axs[2], persist_dict, "Average Persistence times (days)", 0.75, True),
]

for ax, samples_by_place, ylabel, bandwidth, report_quartiles in panels:
    positions = []
    for index, place in enumerate(samples_by_place):
        samples = samples_by_place[place]

        if report_quartiles:
            for q in (0.25, 0.5, 0.75):
                print(place, np.quantile(samples, q))

        violin_plot = ax.violinplot(samples, positions=[index], widths=0.8,
                                    showmedians=True, bw_method=bandwidth, showextrema=False)
        violin_plot["bodies"][0].set_facecolor(colors[place])
        violin_plot["bodies"][0].set_edgecolor(colors[place])
        violin_plot["bodies"][0].set_alpha(.8)
        violin_plot["cmedians"].set_edgecolor("black")
        positions.append(index)

    ax.set_xticks(positions)
    ax.set_xticklabels([region_labels.get(place, place) for place in samples_by_place],
                       fontsize=65, rotation=45, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=65, fontweight='bold')
    ax.tick_params(axis='y', labelsize=70)

# persistence cannot be negative, so anchor that panel at zero
axs[2].set_ylim(bottom=0)

# Add publication style lettering in the upper left corner
fig.text(0.1, 0.93, "B", fontsize=70, fontweight='bold')
fig.text(0.37, 0.93, "C", fontsize=70, fontweight='bold')
fig.text(0.66, 0.93, "D", fontsize=70, fontweight='bold')

plt.savefig('../figures/temporal_925_imports_exports_persistence.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Decimal-year event dates -> 'YYYY-MM-DD' strings -> 'YYYY-MM' bins.
migrations_df['calendar_date'] = migrations_df["date"].map(convert_partial_year)
migrations_df['year-month'] = migrations_df['calendar_date'].map(convert_format)

In [None]:
# Build the per-tree, per-month table of migration counts and proportions, timing the call.
start_time = time.time()

mig = return_proportions_dataframe(migrations_df, "year-month")

total_time_seconds = time.time() - start_time
total_time_minutes = total_time_seconds/60
print(total_time_minutes)  # elapsed time in minutes

mig.head()

In [None]:
# Relabel the raw trait transitions with readable region names. Series.replace
# assigns through the frame itself; the previous chained indexing
# (`mig.col[mask] = value`) triggers SettingWithCopyWarning and does not update
# the frame under pandas Copy-on-Write.
mig["migration_direction"] = mig["migration_direction"].replace({
    "no-to-yes": "WA to HCT",
    "yes-to-no": "HCT to WA",
})

In [None]:
# Monthly share of each direction's migration events (ordinal month axis):
# CI error bars, mean points and a mean line on one 800x300 canvas.
temporal_canvas = alt.Chart(mig).properties(width=800, height=300)

error_bars = temporal_canvas.mark_errorbar(extent='ci').encode(
    x=alt.X('time_unit:O', title="date"),
    y=alt.Y('proportion_transitions_in_time_interval:Q'),
    color=alt.Color("migration_direction:N"),
)

points = temporal_canvas.mark_point(filled=True, opacity=1, width=5).encode(
    x=alt.X('time_unit:O'),
    y=alt.Y('proportion_transitions_in_time_interval:Q', aggregate='mean'),
    color=alt.Color("migration_direction:N"),
)

lineplot4 = temporal_canvas.mark_line(interpolate='monotone', opacity=0.35).encode(
    x=alt.X('time_unit:O'),
    y=alt.Y('mean(proportion_transitions_in_time_interval)'),
    color=alt.Color('migration_direction:N'),
)

ave = error_bars + points + lineplot4
ave

In [None]:
# Mean monthly migration counts as a line with a confidence-interval band.
# Both layers end up 800x300 (the width passed to alt.Chart was overridden by .properties).
band_canvas = alt.Chart(mig).properties(width=800, height=300)

lineplot4 = band_canvas.mark_line(interpolate='monotone').encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('mean(transitions_in_time_interval)'),
    color=alt.Color('migration_direction:N'),
)

band4 = band_canvas.mark_errorband(extent='ci', interpolate='monotone').encode(
    x=alt.X('time_unit:T'),
    y=alt.Y('transitions_in_time_interval'),
    color=alt.Color('migration_direction:N'),
)

lineplot4 + band4