## Python code to run EEMS on the simulated data (some code adapted from https://github.com/jhmarcus/feems-analysis/)

In [1]:
import numpy as np
import fiona
from shapely.geometry import Polygon, Point, shape, MultiPoint
from shapely.affinity import translate
import networkx as nx
import scipy.sparse as sp
from copy import deepcopy
import subprocess
import os

# feems
from feems import SpatialGraph, Viz, Objective
from feems.joint_ver import Joint_SpatialGraph, Joint_Objective, loss_wrapper
from feems.sim import setup_graph, setup_graph_long_range, simulate_genotypes, simulate_genotypes_w_admixture
from feems.spatial_graph import query_node_attributes
from feems.objective import comp_mats
from feems.cross_validation import run_cv
from feems.helper_funcs import * 

In [16]:
def prepare_input(coord, ggrid, translated, buffer=0, outer=None):
    """Prepare the graph inputs (outer polygon, edges, demes) for feems/EEMS.

    Adapted from Ben Peters' eems-around-the-world repo.

    Arguments
    ---------
    coord : ndarray
        n x 2 matrix of sample coordinates
    ggrid : str
        path to global grid shape file
    translated : bool
        if True, shift the outer polygon's x coordinates by +360 before
        intersecting it with the grid, then shift the returned outer polygon,
        grid and deme points back by -360
    buffer : float
        buffer applied to the convex hull of the sample points; only used
        when ``outer`` is None
    outer : array-like, optional
        q x 2 matrix of coordinates of the outer polygon. A float copy is
        used internally, so the caller's array is never modified.

    Returns
    -------
    res : tuple
        ``(outer, edges, grid, ipmap)``: the q x 2 outer polygon, the array
        of edges as pairs of 1-based deme ids, the d x 2 array of deme
        coordinates (row i is deme id i), and for each sample the 0-based
        index of its closest deme.

    Raises
    ------
    AssertionError
        if no grid vertex falls inside the outer polygon.
    """
    if outer is None:
        # No outer polygon given: use the (buffered) convex hull of the samples.
        points = MultiPoint([(x, y) for x, y in coord])
        xy = points.convex_hull.buffer(buffer).exterior.xy
        outer = np.array([xy[0].tolist(), xy[1].tolist()]).T
    else:
        # Work on a float copy so the +/-360 shifts never leak into the
        # caller's array.
        outer = np.array(outer, dtype=float)

    if translated:
        outer[:, 0] = outer[:, 0] + 360.0

    # Intersect the outer polygon with the discrete global grid. Tiles are kept
    # when they touch the polygon or its copy shifted by -360 (presumably to
    # catch tiles on either side of the longitude wrap -- confirm).
    bpoly = Polygon(outer)
    bpoly2 = translate(bpoly, xoff=-360.0)
    tiles = [t for t in load_tiles(ggrid) if bpoly.intersects(t) or bpoly2.intersects(t)]
    pts, rev_pts, e = create_tile_dict(tiles, bpoly)

    # Deme coordinates, ordered by deme id.
    grid = np.array([rev_pts[i] for i in range(len(rev_pts))])

    assert grid.shape[0] != 0, "grid is empty; try changing the translation"

    # Undo the translation on everything we return / use downstream.
    if translated:
        grid[:, 0] = grid[:, 0] - 360.0
        outer[:, 0] = outer[:, 0] - 360.0
        pts = [Point(x, y) for x, y in grid]

    edges = np.array(list(e))
    ipmap = get_closest_point_to_sample(pts, coord)

    return outer, edges, grid, ipmap

def load_tiles(s):
    """Read every feature geometry from the vector file at path ``s``.

    Returns a list of shapely geometries built from each feature's
    ``'geometry'`` entry. The underlying fiona collection is closed before
    returning (the previous version left it open).
    """
    with fiona.open(s) as tiles:
        return [shape(t['geometry']) for t in tiles]

def get_closest_point_to_sample(points, samples):
    """Map each sample to the index of its nearest point.

    Arguments
    ---------
    points : sequence
        candidate locations (anything ``shapely.geometry.Point`` accepts,
        e.g. the deme Points produced by ``create_tile_dict``)
    samples : ndarray
        n x 2 matrix of sample coordinates
    Returns
    -------
    ndarray
        length-n array; entry i is the 0-based index into ``points`` of the
        point closest to ``samples[i]`` (first one on ties, per np.argmin)
    """
    # Build the candidate geometries once rather than once per sample.
    candidates = [Point(p) for p in points]

    def nearest(s):
        origin = Point(s)
        return np.argmin([origin.distance(c) for c in candidates])

    # Search each distinct location once; duplicate samples reuse the answer.
    closest = {tuple(s): nearest(s) for s in unique2d(samples)}

    return np.array([closest[tuple(s)] for s in samples])

def create_tile_dict(tiles, bpoly):
    """Turn grid tiles into a deme graph restricted to ``bpoly``.

    Each tile's exterior vertices are rounded to 3 decimals and passed
    through ``wrap_america``. A vertex becomes a deme (with the next free
    0-based id) the first time it is seen, if it intersects ``bpoly``.
    Consecutive vertices along a tile's exterior ring become an edge when
    both endpoints are demes.

    Returns
    -------
    pts : list of shapely Point
        deme locations, indexed by deme id
    rev_pts : dict
        deme id -> (x, y) coordinate tuple
    edges : set of tuple
        (i, j) pairs of 1-based deme ids with i <= j
    """
    ids = dict()      # (x, y) -> deme id, only for vertices inside bpoly
    rev_pts = dict()  # deme id -> (x, y)
    inside = dict()   # (x, y) -> whether the vertex intersects bpoly (memoized)
    edges = set()

    for poly in tiles:
        x, y = poly.exterior.xy
        ring = [wrap_america(p) for p in zip(np.round(x, 3), np.round(y, 3))]

        for p in ring:
            if p not in inside:
                inside[p] = bpoly.intersects(Point(p))
                if inside[p]:
                    ids[p] = len(ids)
                    rev_pts[ids[p]] = p

        # Exterior rings are closed (last vertex repeats the first), so the
        # consecutive pairs walk every side of the tile. For triangular tiles
        # this is exactly the three sides.
        for pi, pj in zip(ring, ring[1:]):
            if inside[pi] and inside[pj]:
                lo, hi = sorted((ids[pi], ids[pj]))
                edges.add((lo + 1, hi + 1))  # ids are stored 1-based

    pts = [Point(rev_pts[i]) for i in range(len(rev_pts))]
    return pts, rev_pts, edges

def wrap_america(tile):
    """Shift a coordinate whose x (longitude) is below -40 east by 360.

    Presumably this keeps the Americas contiguous with the rest of the grid
    by moving them to longitudes > 180 -- confirm against the grid used.

    Arguments
    ---------
    tile : sequence
        coordinate whose first two entries are (x, y)
    Returns
    -------
    tuple of float
        ``(x + 360, y)`` if ``x < -40``, otherwise ``(x, y)``
    """
    # Plain arithmetic gives the same result as building a shapely Point and
    # translating it, without the geometry overhead (this runs per vertex).
    x, y = float(tile[0]), float(tile[1])
    if x < -40:
        x += 360.0
    return x, y

def unique2d(a):
    """Return the distinct rows of an n x 2 array.

    Each row (x, y) is encoded as the complex number x + iy so that np.unique
    can deduplicate in one pass. The result is ordered by x and then y, and
    each row is taken from its first occurrence in ``a``.
    """
    xs, ys = a.T
    keys = xs + 1j * ys
    _, first_idx = np.unique(keys, return_index=True)
    return a[first_idx]

In [41]:
# n_rows, n_columns = 4, 4
# graph_def, coord_def, grid_def, edge_def = setup_graph(n_rows=n_rows, n_columns=n_columns, corridor_w=1., barrier_w=1, barrier_prob=1.0, corridor_left_prob=1., corridor_right_prob=1., barrier_startpt=5, barrier_endpt=10, n_samples_per_node=30)
# outer, edges, demes, ipmap = prepare_input(coord_def, "/Users/vivaswat/feems/feems/data/grid_100.shp", False, 0, None)
# np.savetxt('/Users/vivaswat/feems/feems/data/sims/infiles/simgrd.demes',grid_def,fmt='%f')
# np.savetxt('/Users/vivaswat/feems/feems/data/sims/infiles/simgrd.edges',edge_def,fmt='%d')
# np.savetxt('/Users/vivaswat/feems/feems/data/sims/infiles/simgrd.coord',coord_def,fmt='%f')
# np.savetxt('/Users/vivaswat/feems/feems/data/sims/infiles/simgrd.outer',[[-0.1,-0.1],[n_columns+0.1,-0.1],[n_columns+0.1,n_rows+0.1],[-0.1,n_rows+0.1],[-0.1,-0.1]],fmt='%f')
# eems.plots(c("chain19"), "/Users/vivaswat/feems/feems/data/sims/outfiles/simgrd/simgrd", longlat = TRUE, add.abline=T,add.r.squared=T,add.grid=T,add.demes=T)

In [2]:
import shutil

os.chdir("/Users/vivaswat/feems/feems/data/sims/paramfiles")
# The per-replicate .diffs inputs were generated earlier with:
# for n in range(12):
    # gen_test_1e = np.loadtxt("/Users/vivaswat/feems/docs/notebooks/results/gentest{}.csv".format(n+11),delimiter=',') ##simulate_genotypes(graph_def, target_n_snps=1500, n_print=600, mu=1)
    # np.savetxt('/Users/vivaswat/feems/feems/data/sims/infiles/simgrd{}.diffs'.format(n+11),squareform(pdist(gen_test_1e,"sqeuclidean"))/gen_test_1e.shape[1],fmt='%f')

PARAMS_FILE = "parsimgrd-chain1.ini"
RUNEEMS_SNPS = os.path.expanduser("~/Documents/eems/runeems_snps/src/runeems_snps")
MCMC_ROOT = "/Users/vivaswat/feems/feems/data/sims/outfiles/simgrd"

# Read the parameter template once; only its mcmcpath line changes per run.
with open(PARAMS_FILE, "r") as fh:
    param_lines = fh.readlines()

# Locate the mcmcpath entry by key, falling back to the second line (the
# previously hard-coded position) if no such key is present.
mcmc_line = next(
    (i for i, line in enumerate(param_lines) if line.lstrip().startswith("mcmcpath")),
    1,
)

for n in range(11, 23):
    # Stage replicate n's diffs under the fixed name ../infiles/simgrd.diffs
    # (presumably the datapath the params file points at -- confirm).
    # copyfile raises if the replicate file is missing instead of silently
    # re-running EEMS on the previous replicate's data.
    shutil.copyfile("../infiles/simgrd{}.diffs".format(n), "../infiles/simgrd.diffs")

    param_lines[mcmc_line] = "mcmcpath = {}/chain{}\n".format(MCMC_ROOT, n)
    with open(PARAMS_FILE, "w") as fh:
        fh.writelines(param_lines)

    # Run without a shell. EEMS output is discarded, so report failures
    # explicitly rather than swallowing them.
    ret = subprocess.call(
        [RUNEEMS_SNPS, "--params", PARAMS_FILE],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.STDOUT,
    )
    if ret != 0:
        print("runeems_snps failed for chain{} (exit code {})".format(n, ret))