## An example geography with two islands

Let's evaluate the accuracy of the waypoints model using a two-island simulation and hex-grid waypoints.

We'll use the real times of ancestral nodes in the graph so that we only have to infer the location of ancestors.

In [None]:
# Enable IPython's autoreload extension in mode 2 for this notebook session.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt

import sys
sys.path.insert(0, "..")

import tsdate
from tspyro.diffusion import make_hex_grid
import tspyro

import tskit
import pyslim

import numpy as np
from sklearn.metrics import mean_squared_log_error, mean_squared_error

import pyro
import torch

# Draw islands
Create the island model and output a greyscale image for SLiMgui to read in.

In [None]:
# Geography shared by the island drawing, the waypoint grid and the plots.
# Rectangular extent of the map, keyed by compass direction.
bounds = {"west": 0, "east": 4.0, "south": 0, "north": 2.24}
# One (x, y) centre and one radius entry per island, in matching order.
island_center = [(1.33, 1.2), (2.66, 1.2)]
island_radius = [0.33] * len(island_center)

In [None]:
# Render both islands as filled black discs on an axis-free canvas, save the
# figure to "two_islands.png", then rewrite that file as a greyscale image.
# NOTE(review): the discs are drawn with twice the ``island_radius`` value and
# the y-limit is 2.4 while ``bounds["north"]`` is 2.24 -- confirm both are
# intentional.
from PIL import Image

circle1 = plt.Circle(island_center[0], island_radius[0] * 2, color='black')
circle2 = plt.Circle(island_center[1], island_radius[1] * 2, color='black')

fig, ax = plt.subplots()  # plt.subplots (not plt.subplot) also returns the figure
ax.set_axis_off()
plt.xlim(0, 4)
plt.ylim(0, 2.4)

ax.add_patch(circle1)
ax.add_patch(circle2)

fig.savefig("two_islands.png", bbox_inches='tight', pad_inches=0)

# Convert to single-channel greyscale ('L' mode). The source file handle is
# closed by the ``with`` block before the same path is overwritten.
with Image.open('two_islands.png') as saved_image:
    greyscale = saved_image.convert('L')
greyscale.save('two_islands.png')

In [None]:
# Load the simulated tree sequence and simplify it, then recapitate the
# simplified sequence and simplify the result once more.
slim_ts = pyslim.load("examples/two_islands.trees")
ts = slim_ts.simplify()
recapitated = ts.recapitate(recombination_rate=1e-8, Ne=50)
recap_ts = recapitated.simplify()

In [None]:
# Gather the first two location coordinates of every node that belongs to an
# individual. Nodes without an individual are printed and contribute no row,
# so row i matches node i only while no such node precedes it.
located = []
for node in recap_ts.nodes():
    if node.individual == -1:
        print(node)
        continue
    x, y = recap_ts.individual(node.individual).location[:2]
    located.append([x, y])
lat_long = np.array(located)

In [None]:
# Scatter the true node locations (column 0 against column 1) with tiny markers.
true_x, true_y = lat_long[:, 0], lat_long[:, 1]
plt.scatter(true_x, true_y, s=0.1)
plt.xlim(0, 4)
plt.ylim(0, 2)

In [None]:
# Match waypoints to model from above
grid_radius = 0.1


def on_land(x, y):
    """Return a boolean tensor telling whether (x, y) lies on any island.

    A point counts as land when, for at least one island, its squared
    distance to that island's centre is smaller than the island's
    ``island_radius`` entry.

    NOTE(review): the radius is compared with the *squared* distance, so the
    effective land radius is sqrt(island_radius), while the plots draw the
    islands with 2 * island_radius -- confirm which extent is intended.
    """
    inside_any = torch.tensor(False)
    for center, radius in zip(island_center, island_radius):
        dx = x - center[0]
        dy = y - center[1]
        inside_any = inside_any | (dx ** 2 + dy ** 2 < radius)
    return inside_any

In [None]:
# Build the hex grid over ``bounds`` with ``on_land`` as the predicate, then
# pull out the waypoint coordinates and the transition matrix.
grid = make_hex_grid(radius=grid_radius, predicate=on_land, **bounds)
waypoints, transition = grid["waypoints"], grid["transition"]

In [None]:
# Scatter the waypoints, coloured by ``transition[0]``, over faint island discs.
plt.title("Probability of transitioning from a given point")
waypoint_x, waypoint_y = waypoints[:, 0], waypoints[:, 1]
plt.scatter(waypoint_x, waypoint_y, c=transition[0])
plt.xlim(bounds["west"], bounds["east"])
plt.ylim(bounds["south"], bounds["north"])

# Islands go behind the points (negative zorder) and are mostly transparent.
current_axes = plt.gca()
for center, radius in zip(island_center, island_radius):
    current_axes.add_patch(
        plt.Circle(center, radius * 2, color='black', alpha=0.1, zorder=-1))

plt.colorbar()
plt.tight_layout()

Note that the spread of lineages is roughly covered by the waypoints.

In [None]:
# Overlay the waypoints and the true node locations on faint island discs.
plt.scatter(waypoints[:, 0], waypoints[:, 1])
plt.xlim(bounds["west"], bounds["east"])
plt.ylim(bounds["south"], bounds["north"])

# Islands go behind the points (negative zorder) and are mostly transparent.
current_axes = plt.gca()
for center, radius in zip(island_center, island_radius):
    current_axes.add_patch(
        plt.Circle(center, radius * 2, color='black', alpha=0.1, zorder=-1))

plt.tight_layout()
true_x, true_y = lat_long[:, 0], lat_long[:, 1]
plt.scatter(true_x, true_y, s=0.1, label="real location of ancestors")
plt.legend()

For fun, let's find a lineage that "jumps" between islands.
But curiously, there don't seem to be any.

In [None]:
# Collect the sample nodes whose ancestry, in at least one tree, contains
# located nodes on both sides of x = 2 (roughly midway between the two
# island centres).
# NOTE: could tune the distance between the two islands until we see jumps.
migrants = set()
sample_ids = recap_ts.samples()
for tree in recap_ts.trees():
    for sample in sample_ids:
        # Walk upwards from the sample with a separate cursor so that the
        # sample id itself is what gets recorded. The walk stops at the root
        # (-1) or at the first node id that is not below ts.num_nodes.
        lineage_x = []
        ancestor = sample
        while ancestor != -1 and ancestor < ts.num_nodes:
            lineage_x.append(lat_long[ancestor][0])
            ancestor = tree.parent(ancestor)
        # A lineage "jumps" when it has a node strictly on each side of x = 2.
        if lineage_x and min(lineage_x) < 2 < max(lineage_x):
            migrants.add(int(sample))
print(migrants)

# Testing the four combinations of location models and migration likelihoods

In [None]:
from tspyro import models

In [None]:
# True locations for the rows of lat_long from index recap_ts.num_samples
# onward (the sample nodes are assumed to occupy the first rows).
real_locations_internal = lat_long[recap_ts.num_samples:]

In [None]:
# Build the tsdate prior grid for node dates on the recapitated tree sequence.
priors = tsdate.build_prior_grid(
    recap_ts,
    Ne=10000,
    timepoints=100,
    approximate_priors=True,
    progress=True,
)

In [None]:
# Node times and the first ``num_samples`` rows of lat_long, both converted to
# tensors of torch's default floating-point dtype.
float_dtype = torch.get_default_dtype()
node_times = torch.as_tensor(recap_ts.tables.nodes.time, dtype=float_dtype)
leaf_location = torch.as_tensor(lat_long[:recap_ts.num_samples], dtype=float_dtype)

In [None]:
# Baseline: ancestral geography computed by NaiveModel from the leaf locations.
naive_model = models.NaiveModel(recap_ts, Ne=200, prior=priors)
avg_location = naive_model.get_ancestral_geography(recap_ts, leaf_location)

In [None]:
# Baseline error of the naive estimate. Despite the variable name, this is a
# root-mean-squared error (no logarithm is taken).
naive_mse = mean_squared_error(real_locations_internal, avg_location)
avg_msle = np.sqrt(naive_mse)
print("The accuracy to beat is {:.5f}".format(avg_msle))

In [None]:
from tspyro.models import NaiveModel


class ConditionedModel(NaiveModel):
    """NaiveModel with the ``internal_time`` site conditioned on known times.

    The tree sequence must be passed as the ``ts`` keyword argument; the times
    of its nodes from index ``ts.num_samples`` onward are stored and used as
    the conditioning data on every forward pass.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        tree_seq = kwargs["ts"]
        known_times = tree_seq.tables.nodes.time[tree_seq.num_samples:]
        self.internal_times = torch.as_tensor(
            known_times, dtype=torch.get_default_dtype())

    def forward(self, *args, **kwargs):
        # Only "internal_time" is fixed; "migration_scale" is left free.
        observed = {"internal_time": self.internal_times}
        with pyro.condition(data=observed):
            return super().forward(*args, **kwargs)

In [None]:
# Rebuild the hex grid used by the waypoint migration model.
grid_radius = 0.1
grid = make_hex_grid(radius=grid_radius, predicate=on_land, **bounds)
waypoints, transition = grid["waypoints"], grid["transition"]


In [None]:
steps = 2000
log_every = 100


def _internal_rmse(inferred_location):
    """Return the RMSE between true and inferred internal-node locations.

    Only rows ``recap_ts.num_samples`` up to ``ts.num_nodes`` of the inferred
    locations are compared against ``real_locations_internal``.
    """
    inferred_internal = inferred_location[recap_ts.num_samples:ts.num_nodes]
    return np.sqrt(mean_squared_error(real_locations_internal, inferred_internal))


# Method 1: euclidean_migration and mean_field_location
_, tspyro_location_1, migration_scale_1, guide_1, losses_1 = models.fit_guide(
    recap_ts,
    leaf_location,
    priors,
    migration_likelihood=models.euclidean_migration,
    location_model=models.mean_field_location,
    steps=steps,
    log_every=log_every,
    Model=ConditionedModel,
)
accuracy_1 = _internal_rmse(tspyro_location_1)

# Method 2: euclidean_migration and ReparamLocation
_, tspyro_location_2, migration_scale_2, guide_2, losses_2 = models.fit_guide(
    recap_ts,
    leaf_location,
    priors,
    migration_likelihood=models.euclidean_migration,
    location_model=models.ReparamLocation(recap_ts, leaf_location[:ts.num_samples]),
    steps=steps,
    log_every=log_every,
    Model=ConditionedModel,
)
accuracy_2 = _internal_rmse(tspyro_location_2)

# The two waypoint fits below run for fewer steps and log more often.
waypoints_steps = 200
log_every = 10

# Method 3: WayPointMigration and mean_field_location
_, tspyro_location_3, migration_scale_3, guide_3, losses_3 = models.fit_guide(
    recap_ts,
    leaf_location,
    priors,
    migration_likelihood=models.WayPointMigration(transition, waypoints, grid_radius),
    location_model=models.mean_field_location,
    steps=waypoints_steps,
    log_every=log_every,
    Model=ConditionedModel,
)
accuracy_3 = _internal_rmse(tspyro_location_3)

# Method 4: WayPointMigration and ReparamLocation
_, tspyro_location_4, migration_scale_4, guide_4, losses_4 = models.fit_guide(
    recap_ts,
    leaf_location,
    priors,
    migration_likelihood=models.WayPointMigration(transition, waypoints, grid_radius),
    location_model=models.ReparamLocation(recap_ts, leaf_location[:ts.num_samples]),
    steps=waypoints_steps,
    log_every=log_every,
    Model=ConditionedModel,
)
accuracy_4 = _internal_rmse(tspyro_location_4)

In [None]:
# RMSE of methods 1-4, printed in that order on a single line.
rmse_by_method = (accuracy_1, accuracy_2, accuracy_3, accuracy_4)
print(*rmse_by_method)

In [None]:
# Loss curves of the four fits on one set of axes
# (recorded with a learning rate of 0.005 for the reparam model).
for loss_curve in (losses_1, losses_2, losses_3, losses_4):
    plt.plot(loss_curve)

In [None]:
# Compare the true locations with the naive average-of-children estimate.
true_x, true_y = lat_long[:, 0], lat_long[:, 1]
naive_x, naive_y = avg_location[:, 0], avg_location[:, 1]
plt.scatter(true_x, true_y, s=1, label="real location of ancestors")
plt.scatter(naive_x, naive_y, s=1,
            label="average of children location (initalization)")
plt.legend()

In [None]:
# True locations, model 1's inferred internal-node locations and the
# waypoints, drawn over faint island discs.
inferred_internal = tspyro_location_1[recap_ts.num_samples:ts.num_nodes]
plt.scatter(lat_long[:, 0], lat_long[:, 1], s=1, label="real location of ancestors")
plt.scatter(inferred_internal[:, 0], inferred_internal[:, 1], s=1,
            label="inferred location of model 1")
plt.scatter(waypoints[:, 0], waypoints[:, 1])

# Islands go behind the points (negative zorder) and are mostly transparent.
current_axes = plt.gca()
for center, radius in zip(island_center, island_radius):
    current_axes.add_patch(
        plt.Circle(center, radius * 2, color='black', alpha=0.1, zorder=-1))

plt.legend()

In [None]:
from tspyro import viz

In [None]:
# Stack the leaf locations on top of the naive internal estimates so that
# plot_diff receives a single array, then draw the comparison.
naive_all_nodes = np.concatenate([np.array(leaf_location), avg_location])
fig, ax = plt.subplots()
viz.plot_diff(ts, lat_long, naive_all_nodes,
              waypoints=waypoints, title="Average of children", ax=ax)

In [None]:
# True vs. inferred locations for method 1. The title names the actual model:
# this fit used euclidean_migration, not the waypoint likelihood.
fig, ax = plt.subplots()
viz.plot_diff(ts, lat_long, tspyro_location_1, waypoints=waypoints,
              title="Euclidean migration, mean-field location", ax=ax)

In [None]:
# True vs. inferred locations for method 2. The title names the actual model:
# this fit used euclidean_migration, not the waypoint likelihood.
fig, ax = plt.subplots()
viz.plot_diff(ts, lat_long, tspyro_location_2, waypoints=waypoints,
              title="Euclidean migration, ReparamLocation", ax=ax)

In [None]:
# True vs. inferred locations for method 3. The title includes the location
# model so this figure can be told apart from the other waypoint fit.
fig, ax = plt.subplots()
viz.plot_diff(ts, lat_long, tspyro_location_3, waypoints=waypoints,
              title="Waypoint migration, mean-field location", ax=ax)

In [None]:
# True vs. inferred locations for method 4. The title includes the location
# model so this figure can be told apart from the other waypoint fit.
fig, ax = plt.subplots()
viz.plot_diff(ts, lat_long, tspyro_location_4, waypoints=waypoints,
              title="Waypoint migration, ReparamLocation", ax=ax)