# HMM Islands Modelling and Analysis

## 0. Prerequisites

### 0.1. Dependencies

In [None]:
import re
import pickle as pkl
import numpy as np

from operator import itemgetter
from collections import Counter

from scipy.stats import mannwhitneyu, chi2
from kagami.comm import l, paste, smap, pmap, unpack, fold
from kagami.dtypes import Table

### 0.2. Routines

In [None]:
def _unique(x, ignore = ''):
    """Return the sorted distinct values of *x* with the sentinel *ignore* removed.

    The notebook calls this with ``ignore = ''`` for gene names and
    ``ignore = -1`` for island IDs.
    """
    distinct = np.unique(x)
    keep = distinct != ignore
    return distinct[keep]

In [None]:
def _ordunique(a):
    """Return the distinct elements of *a* in order of first appearance."""
    arr = np.asarray(a)
    first_seen = np.unique(arr, return_index = True)[1]
    first_seen.sort()
    return arr[first_seen]

### 0.3. Load merged data

In [None]:
# Merged per-SNP table; its row index carries the scaffold / pos / gene
# annotations that the cells below read via otab.ridx_.
_otab_path = '../data/temporal/BMT_filter_withoutAF_90percent.filtered.merged_table.hdf'
otab = Table.loadhdf(_otab_path)

## 1. HMM Islands

### 1.1. Identify islands

In [None]:
def _hmm_islands(col, name, state = 3, minlen = 3):
    """Find HMM "islands": runs of at least *minlen* consecutive SNPs in *state*.

    The HMM state column *col* of the global ``otab`` is scanned scaffold by
    scaffold, in order of first appearance. SNPs whose state is -1 are left out
    when looking for runs, so a run may bridge them. Those SNPs are themselves
    labelled -1.

    Parameters
    ----------
    col : str
        Name of the HMM state column in ``otab``.
    name : str
        Column name of the per-SNP island-ID table that is returned.
    state : int, default 3
        HMM state that defines an island.
    minlen : int, default 3
        Minimum number of (non -1) SNPs in a run.

    Returns
    -------
    dict
        Maps each scaffold name to ``(itab, lns, pos)``:
        - ``itab`` is a one-column ``Table`` of island IDs. Non-island SNPs
          are -1, and IDs are numbered consecutively across scaffolds.
        - ``lns`` gives the number of SNPs in each island.
        - ``pos`` gives the sorted SNP positions of each island.
    """
    sfs = otab.ridx_.scaffold
    usf = _ordunique(sfs)

    def _summ(s, lasti):
        # Summarise one scaffold; island IDs start at `lasti`.
        mask = sfs == s
        stab = otab[mask, col]
        itab = Table(-np.ones_like(stab, dtype = int),
                     rownames = otab.rows_[mask], colnames = [name])

        sv = stab.X_[:,0].astype(int)
        called = sv != -1
        valid = sv[called]

        # Locate maximal runs of `state` numerically. A regex over concatenated
        # digits mis-aligns as soon as any state has more than one character.
        hit = np.concatenate(([0], (valid == state).astype(np.int8), [0]))
        edges = np.diff(hit)
        starts, ends = np.flatnonzero(edges == 1), np.flatnonzero(edges == -1)
        longenough = (ends - starts) >= minlen
        starts, ends = starts[longenough], ends[longenough]

        iv = -np.ones(len(valid), dtype = int)
        for k, (a, b) in enumerate(zip(starts, ends)): iv[a:b] = lasti + k
        itab.X_[called] = iv.reshape((-1,1))

        spos = stab[called].ridx_.pos
        lns = [int(b - a) for a, b in zip(starts, ends)]
        pos = [np.sort(spos[a:b]) for a, b in zip(starts, ends)]
        # Next free ID. Unlike max(iv)+1, this also works when a scaffold has no
        # called SNP (empty iv).
        return (itab, lns, pos), lasti + len(starts)

    odct, lid = {}, 0
    for sf in usf:
        odct[sf], lid = _summ(sf, lid)
    return odct

In [None]:
# Island summaries from the BM and MT HMM state columns.
bmhdct, mthdct = (
    _hmm_islands('BM_HMM_State', 'BM_HMM_Island'),
    _hmm_islands('MT_HMM_State', 'MT_HMM_Island'),
)

Save the island dictionaries to disk (and reload them) to avoid re-running the island detection

In [None]:
oprefix = '../data/temporal/BMT_filter_withoutAF_90percent.filtered.'
# To refresh the cache, uncomment these two lines and run them once:
# with open(oprefix + 'BM_hmmdct.pkl', 'wb') as f: pkl.dump(bmhdct, f)
# with open(oprefix + 'MT_hmmdct.pkl', 'wb') as f: pkl.dump(mthdct, f)

# Reload the cached island dictionaries. This replaces the values computed above.
# NOTE: pickle can execute code on load; only use trusted, locally written files.
with open(oprefix + 'BM_hmmdct.pkl', 'rb') as f:
    bmhdct = pkl.load(f)
with open(oprefix + 'MT_hmmdct.pkl', 'rb') as f:
    mthdct = pkl.load(f)

Insert the island IDs into the merged table

In [None]:
# Stack the per-scaffold island-ID tables (element 0 of each summary) row-wise.
_rowcat = lambda x,y: x.append(y, axis = 0)

itabs = [summ[0] for summ in bmhdct.values()]
bmitab = fold(itabs, _rowcat)
print(f'BM islands number = {len(_unique(bmitab.X_[:,0], -1))}')

itabs = [summ[0] for summ in mthdct.values()]
mtitab = fold(itabs, _rowcat)
print(f'MT islands number = {len(_unique(mtitab.X_[:,0], -1))}')

In [None]:
# Re-order both island-ID tables to otab's rows, then attach them as new columns.
_rows = otab.rows_
otab = otab.append(bmitab[_rows], axis = 1)
otab = otab.append(mtitab[_rows], axis = 1)

### 1.2. Island summary

In [None]:
# Per-SNP island IDs: column 0 = BM, column 1 = MT, -1 = not in an island.
_icols = ['BM_HMM_Island', 'MT_HMM_Island']
iids = otab[:, _icols].X_.astype(int)

In [None]:
_nin = (iids != -1).sum(axis = 0)  # per column: SNPs assigned to some island
print(f'number of HGD SNPs in islands = {l(_nin)}')

In [None]:
fsts = otab[:,['BM_Fst', 'MT_Fst']].X_
# Mean Fst over the SNPs inside an island of the same comparison.
_inbm, _inmt = iids[:,0] != -1, iids[:,1] != -1
print(f'BM mean Fst in islands = {fsts[_inbm,0].mean()}')
print(f'MT mean Fst in islands = {fsts[_inmt,1].mean()}')

In [None]:
def _lens(idx):
    """Return the number of SNPs in each island, ordered by ascending island ID.

    *idx* holds one island ID per SNP, with -1 for SNPs outside any island.
    A single ``np.unique(..., return_counts=True)`` pass replaces one full
    scan of *idx* per island.
    """
    uid, counts = np.unique(idx, return_counts = True)
    return counts[uid != -1]
lens = pmap(iids.T, _lens)  # lens[0]: BM island SNP counts, lens[1]: MT
_bml, _mtl = lens[0], lens[1]
print(f'BM island SNPs mean = {np.mean(_bml)}, std = {np.std(_bml)}')
print(f'MT island SNPs mean = {np.mean(_mtl)}, std = {np.std(_mtl)}')

In [None]:
def _size(dct):
    """Positional extent (max - min + 1) of every island in an island dict."""
    return np.hstack([[np.max(p) - np.min(p) + 1 for p in summ[2]] for summ in dct.values()])

bmhsize, mthsize = _size(bmhdct), _size(mthdct)
print(f'mean island size = {[np.mean(bmhsize), np.mean(mthsize)]}')

## 2. Outliers

### 2.1. Overall outliers

In [None]:
pvals = otab[:,['BM_Waples_Test_P', 'MT_Waples_Test_P']].X_.T
# An outlier is a SNP whose Waples-test p-value is present and below 0.01.
bmoutls, mtoutls = [(~np.isnan(p)) & (p < 0.01) for p in pvals]
print(f'BM outliers = {bmoutls.sum()}')
print(f'MT outliers = {mtoutls.sum()}')

In [None]:
# Distinct non-empty gene names carrying at least one outlier SNP.
bmsgids = _unique(otab[bmoutls].ridx_.gene, '')
mtsgids = _unique(otab[mtoutls].ridx_.gene, '')
print(f'BM genes with outlier = {len(bmsgids)}')
print(f'MT genes with outlier = {len(mtsgids)}')

### 2.2. Reversal outliers

In [None]:
afs = otab[:,['B_AFs','M_AFs','T_AFs']].X_
# Polarise each SNP so the B frequency is at most 0.5 by flipping all three
# columns together. NaN rows are left untouched.
# NOTE(review): flips in place; if X_ is a view of otab's storage this also
# changes otab -- confirm.
majors = afs[:,0] > 0.5
_flipped = 1 - afs[majors]
afs[majors] = _flipped

In [None]:
_b, _m, _t = afs[:,0], afs[:,1], afs[:,2]

# Reversal: M is a strict local peak or trough between B and T.
rloci = ((_b < _m) & (_m > _t)) | ((_b > _m) & (_m < _t))
print(f'reversal loci = {np.sum(rloci)}')

# Directional: B, M, T strictly increasing or strictly decreasing.
dloci = ((_b < _m) & (_m < _t)) | ((_b > _m) & (_m > _t))
print(f'directional loci = {np.sum(dloci)}')

In [None]:
# Outliers in both comparisons, split by trajectory shape.
outls = bmoutls & mtoutls

orloci = rloci & outls
print(f'reversal outlier loci = {orloci.sum()}')
odloci = dloci & outls
print(f'directional outlier loci = {odloci.sum()}')

In [None]:
# SNPs lying inside both a BM island and an MT island.
iolaps = np.all(otab[:,['BM_HMM_Island', 'MT_HMM_Island']].X_.astype(int) != -1, axis = 1)

iorloci = rloci & iolaps & outls
print(f'reversal outlier loci in island overlapping regions = {np.sum(iorloci)}')
iodloci = dloci & iolaps & outls
print(f'directional outlier loci in island overlapping regions = {np.sum(iodloci)}')

In [None]:
# Island IDs per SNP, transposed: row 0 = BM, row 1 = MT (-1 = no island).
iids = otab[:,['BM_HMM_Island', 'MT_HMM_Island']].X_.astype(int).T

bmiids, mtiids = smap(iids, lambda x: _unique(x[iorloci], -1))
print(f'BM islands with reversal outlier loci = {len(bmiids)}')
print(f'MT islands with reversal outlier loci = {len(mtiids)}')

# Genes lying in any island that contains a reversal outlier locus.
# np.isin makes one pass instead of one scan per island. Unlike
# fold(np.union1d), it also works when no such island exists: the result is
# simply empty.
bmgids, mtgids = (
    _unique(otab.ridx_.gene[np.isin(ids, uid)], '')
    for ids, uid in zip(iids, (bmiids, mtiids))
)
print(f'BM islands with reversal outlier loci contain genes = {len(bmgids)}')
print(f'MT islands with reversal outlier loci contain genes = {len(mtgids)}')

### 2.3. Stats

Test whether BM island sizes are significantly larger than MT island sizes

In [None]:
def _neuc(x):
    """Inclusive positional extent of *x*: max - min + 1."""
    return np.max(x) - np.min(x) + 1

def _island_size(dval):
    """Map island ID to positional extent for one scaffold summary.

    *dval* is an ``(itab, lns, pos)`` tuple as built by ``_hmm_islands``.
    Island IDs (excluding -1) are taken in ascending order and paired with the
    position arrays in ``pos``.
    """
    ids = np.unique(dval[0].X_[:,0])
    ids = ids[ids != -1]
    spans = dval[2]
    assert len(ids) == len(spans)
    return dict(zip(ids, map(_neuc, spans)))

bmldcts, mtldcts = smap((bmhdct, mthdct), lambda x: smap(x.values(), _island_size))
# Merge the per-scaffold dicts into one island-ID -> extent dict per comparison.
bmldct = {k: v for d in bmldcts for k, v in d.items()}
mtldct = {k: v for d in mtldcts for k, v in d.items()}

In [None]:
# Extents of the islands that contain reversal outlier loci.
# Plain lookups replace itemgetter(*ids)(d). That call raises for zero ids and
# returns a bare scalar (a 0-d array here) for exactly one id.
bmilens, mtilens = (
    np.array([d[i] for i in ids])
    for ids, d in ((bmiids, bmldct), (mtiids, mtldct))
)
print(f'Length of BM islands with reversal outlier loci = {np.sum(bmilens)}')
print(f'Length of MT islands with reversal outlier loci = {np.sum(mtilens)}')

In [None]:
def _mwu(x, y):
    """One-sided Mann-Whitney U test with alternative: x tends to exceed y."""
    return mannwhitneyu(x, y, alternative = 'greater')

print(f'BM islands size larger than MT = {_mwu(bmilens, mtilens)}')

Test whether the number of reversal outliers is significantly larger than expected by chance (permutation test)

In [None]:
N = otab.nrow
# Allele counts stacked to shape (N, 3, 2): axis 1 = population (B, M, T),
# axis 2 = (Counts0, Counts1).
_count_cols = (['B_Counts0','B_Counts1'], ['M_Counts0','M_Counts1'], ['T_Counts0','T_Counts1'])
cnts = np.stack([otab[:,c].X_ for c in _count_cols], axis = 1)

def _chi2(counts):
    """Pearson chi-square test (1 d.f.) on a batch of 2x2 contingency tables.

    Parameters
    ----------
    counts : array-like, shape (n, 4)
        Each row is ``(a1, b1, a2, b2)``: the two allele counts of the first
        population followed by those of the second.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Upper-tail p-values.
    """
    # Work in floats: the squared cross-product overflows int64 for large counts.
    a1, b1, a2, b2 = np.asarray(counts, dtype = float).T
    s1, s2 = a1 + b1, a2 + b2
    stats = (s1+s2)/(s1*s2) * np.power(a1*b2-a2*b1,2)/((a1+a2)*(b1+b2))
    # sf() keeps precision in the far tail, where 1 - cdf() rounds to exactly 0.
    return chi2.sf(stats, 1)

def _perm(seed):
    """Run one permutation replicate of the reversal-outlier count.

    For every SNP, the three population rows (B, M, T) of ``cnts`` are shuffled
    jointly, so each population keeps its own pair of allele counts. The B-M and
    M-T chi-square tests are then re-run on pseudo-counts (+1), and reversal
    loci are counted among SNPs significant (p < 0.01) in both tests.

    Parameters
    ----------
    seed : int
        Seeds this replicate's RNG; it is the replicate index when mapped over
        ``range(nperm)``. Replicates are therefore reproducible and distinct
        from each other, even if ``pmap`` runs them in forked workers that
        would otherwise share the global NumPy RNG state.

    Returns
    -------
    (int, numpy.ndarray)
        The number of permuted reversal outlier loci, and an (N, 5) array. Its
        columns are the permuted allele-0 pseudo-counts (B, M, T) followed by
        the BM and MT p-values.

    NOTE(review): the observed statistic uses Waples-test p-values, while this
    null uses a chi-square test on pseudo-counts. The two tests are not the same.
    """
    rng = np.random.default_rng(seed)
    # Random population order per SNP, moving whole (allele0, allele1) rows.
    # The previous np.apply_along_axis over axis 1 permuted each allele column
    # independently, mixing counts from different populations.
    order = np.argsort(rng.random((N, 3)), axis = 1)
    pcnt = np.take_along_axis(cnts, order[:, :, None], axis = 1) + 1 # avoid zero division
    bmpvals = _chi2(pcnt[:,:2,:].reshape((N,4)))
    mtpvals = _chi2(pcnt[:,1:,:].reshape((N,4)))

    pouts = np.logical_and(bmpvals < 0.01, mtpvals < 0.01)

    lcnts = pcnt[:,:,0]
    # Classify the trajectory on allele *frequencies*, as the observed rloci do
    # with AFs. Raw counts are not comparable across populations whose sample
    # sizes differ.
    freqs = lcnts / pcnt.sum(axis = 2)
    rlocs = np.logical_and(pouts, np.logical_or(
        np.logical_and(freqs[:,0] < freqs[:,1], freqs[:,1] > freqs[:,2]),
        np.logical_and(freqs[:,0] > freqs[:,1], freqs[:,1] < freqs[:,2]),
    ))
    return np.sum(rlocs), np.hstack([lcnts, bmpvals.reshape((-1,1)), mtpvals.reshape((-1,1))])

nperm = 100  # number of permutation replicates
perms = pmap(range(nperm), _perm)  # one (count, detail) tuple per replicate index

In [None]:
mrevs = np.array([rep[0] for rep in perms])  # reversal outlier count per replicate
nobs = np.sum(orloci)
# Permutation p-value with the +1 correction: the observed data count as one
# replicate, so the estimate can never be exactly 0 (Phipson & Smyth, 2010).
pperc = (np.sum(mrevs >= nobs) + 1) / (nperm + 1)
print(f'on {nperm} permutations p-value = {pperc} ({int(np.round(np.mean(mrevs)))} vs {nobs})')

## 3. Island Overlapping

### 3.1. Identify Overlaps

In [None]:
# NOTE: from here on, bmiids / mtiids are per-SNP island IDs, not the unique-ID
# lists of section 2.2. The unique outlier-bearing IDs are obmuiid / omtuiid.
_ids = otab[:,['BM_HMM_Island', 'MT_HMM_Island']].X_.astype(int)
bmiids, mtiids = _ids[:,0], _ids[:,1]
obmuiid = _unique(bmiids[iorloci], -1)
omtuiid = _unique(mtiids[iorloci], -1)

In [None]:
# SNPs inside both an outlier-bearing BM island and an outlier-bearing MT island.
# np.isin makes one pass instead of one comparison per island. It also returns
# an all-False mask, rather than a scalar, when either ID list is empty.
oiolaps = np.logical_and(np.isin(bmiids, obmuiid), np.isin(mtiids, omtuiid))
assert oiolaps[iorloci].all()

In [None]:
sfs = otab.ridx_.scaffold
usf = _ordunique(sfs)
olp = oiolaps.astype(int) - 1   # 0 for candidate overlap SNPs, -1 otherwise

def _smx(us, mi = 0, rl = None):
    """Label contiguous runs of candidate SNPs within each scaffold.

    For each scaffold in *us*, in order, the scaffold's slice of ``olp`` is
    copied. Each maximal run of non -1 entries gets a consecutive ID, starting
    at *mi* and continuing across scaffolds; -1 entries stay -1.

    Returns *rl* (a fresh list by default) extended with one label array per
    scaffold.

    The loop is iterative on purpose. A recursive version needs one stack frame
    per scaffold and hits RecursionError beyond about 1000 scaffolds. The
    default ``rl = None`` avoids a mutable default list shared across calls.
    """
    rl = [] if rl is None else rl
    for sf in us:
        mids = olp[sfs == sf]                 # boolean indexing returns a copy
        inrun = mids != -1
        # A run starts at a candidate SNP that follows a non-candidate (or the scaffold start).
        starts = inrun & ~np.concatenate(([False], inrun[:-1]))
        mids[inrun] = np.cumsum(starts)[inrun] - 1 + mi
        mi += int(np.sum(starts))
        rl.append(mids)
    return rl

# Scatter each scaffold's labels back onto its own rows instead of using
# np.hstack. The result then stays aligned with otab even if a scaffold's rows
# are not contiguous.
olpids = np.full(olp.shape[0], -1, dtype = olp.dtype)
for sf, mids in zip(usf, _smx(usf, mi = 0, rl = [])):
    olpids[sfs == sf] = mids
uolpid = _unique(olpids, -1)
print(f'{len(uolpid)} overlapping regions found')

In [None]:
# Keep only overlap regions containing at least one reversal outlier locus in an
# island-overlapping region. All other regions are reset to -1 in place.
_hit = _unique(olpids[iorloci], -1)
olpids[~np.isin(olpids, _hit)] = -1
uolpid = _unique(olpids, -1)
print(f'{len(uolpid)} overlapping regions with outliers')

In [None]:
# NOTE(review): unlike the island tables, no rownames are given here;
# presumably append(axis=1) then aligns by position -- confirm.
_olpcol = Table(olpids.reshape(-1,1), colnames = ['Island_Overlaps'])
otab = otab.append(_olpcol, axis = 1)

### 3.2. Overlap summary

In [None]:
# Mask of SNPs that belong to any retained overlap region.
olaps = np.isin(olpids, uolpid)

In [None]:
# Islands touched by a retained overlap region, and their summed extents.
# Direct lookups replace itemgetter(*ids), which raises on an empty ID list.
olpbmiids = np.unique(bmiids[olaps])
print(f'BM islands with reversal outliers overlapping with MT islands = {olpbmiids.shape[0]}')
olpbmlens = np.sum([bmldct[i] for i in olpbmiids])
print(f'Length of overlapping BM islands = {olpbmlens}')

olpmtiids = np.unique(mtiids[olaps])
print(f'MT islands with reversal outliers overlapping with BM islands = {olpmtiids.shape[0]}')
olpmtlens = np.sum([mtldct[i] for i in olpmtiids])
print(f'Length of overlapping MT islands = {olpmtlens}')

In [None]:
# Positional extent of each retained overlap region. This reuses _neuc (max - min + 1,
# defined in section 2.3) instead of re-binding `_size`. Section 1.2 binds `_size`
# to a different, dict-based function, and that cell would break if re-run
# after this one.
pos = otab.ridx_.pos
olplens = np.array([_neuc(pos[olpids == i]) for i in uolpid])
print(f'Length of overlapping region = {np.sum(olplens)}')

In [None]:
# Distinct non-empty genes inside retained overlap regions.
_olpgenes = otab.ridx_.gene[olaps]
olpgids = _unique(_olpgenes, '')
print(f'Genes in overlapping region = {len(olpgids)}')

In [None]:
# Genes anywhere in the BM islands that take part in a retained overlap region.
# np.isin makes one pass, and an empty ID list yields an empty mask, not a scalar.
olpbmgids = _unique(otab.ridx_.gene[np.isin(bmiids, olpbmiids)], '')
print(f'Genes in overlapped BM islands = {olpbmgids.shape[0]}')

In [None]:
# Genes in overlapped BM islands that are not themselves inside an overlap region.
hhkgids = np.setdiff1d(olpbmgids, olpgids)
print(f'Hitchhiking genes = {len(hhkgids)}') # Life, the Universe and Everything

In [None]:
# Number of distinct genes in each retained overlap region, tallied by frequency.
gcnts = np.array([len(_unique(otab.ridx_.gene[olpids == i], '')) for i in uolpid])
for ngenes, nregions in Counter(gcnts).most_common():
    print(f'{nregions} overlap region(s) contain {ngenes} genes')

In [None]:
def _host_sizes(ids):
    """Return the island IDs hosting each overlap region and each island's SNP count.

    The vstack below presumably relies on exactly one host island per region
    (the arrays must have equal lengths).
    """
    hosts = np.hstack([_unique(ids[olpids == oi], -1) for oi in uolpid])
    return hosts, np.array([np.sum(ids == h) for h in hosts])

olpbmids, obmsizes = _host_sizes(bmiids)
olpmtids, omtsizes = _host_sizes(mtiids)

olpsizes = np.vstack([obmsizes, omtsizes])
# Coverage: region SNP count relative to the smaller of its two host islands.
_nolp = np.array([np.sum(olpids == oi) for oi in uolpid]).astype(float)
ccs = _nolp / olpsizes.min(axis = 0)

print(f'{np.min(ccs)} |-- {np.mean(ccs)} | {np.median(ccs)}  --| {np.max(ccs)}')