In [None]:
import hail as hl
from hail.linalg import BlockMatrix
import numpy as np

In [None]:
# Directory holding every input and output file used by this notebook.
WD = 'data/'

# ---- inputs ----
# PLINK fileset prefix; the .bed/.bim/.fam companions share it.
GT_PLINK = WD + '1000G.EUR.QC.1.mini'
GT_BED, GT_BIM, GT_FAM = [GT_PLINK + ext for ext in ('.bed', '.bim', '.fam')]

SNP_LIST = WD + 'hackathon.1.SNPlist'        # rsids selecting the H variants
ANNOT = WD + 'hackathon.1.annot.tsv'         # per-variant annotation matrix

# ---- outputs ----
GT_ROW_HT = WD + 'hackathon.1.row.ht'          # per-variant genotype stats table
LD_SCORES_BM = WD + 'hackathon.1.ldscores.bm'  # LD scores as a BlockMatrix
LD_SCORES = WD + 'hackathon.1.l2.ldscores'     # exported copy of LD_SCORES_BM

def compute_row_intervals(a, b, radius):
    """
    For each position in `a`, find the half-open index range of `b` whose
    values lie within `radius` of it.

    a is ndarray of floats (documented as non-decreasing; this implementation
      does not actually require `a` to be sorted)
    b is ndarray of non-decreasing floats
    radius is non-negative float

    starts and stops are ndarrays of int of length a.size
    starts[i] is minimum index j in b such that b[j] >= a[i] - radius
    stops[i] is one greater than the maximum index j in b such that b[j] <= a[i] + radius

    Raises ValueError if radius is negative (or NaN).
    """
    # `not radius >= 0` also rejects NaN; a real exception (unlike `assert`)
    # is not stripped when Python runs with -O. radius == 0 is valid, as the
    # contract above states.
    if not radius >= 0:
        raise ValueError('radius must be non-negative, got {!r}'.format(radius))

    a = np.asarray(a)
    b = np.asarray(b)

    # Binary search over the sorted `b`:
    #   side='left'  -> first index j with b[j] >= value  (inclusive lower bound)
    #   side='right' -> first index j with b[j] >  value  (one past last b[j] <= value)
    starts = np.searchsorted(b, a - radius, side='left').astype(int)
    stops = np.searchsorted(b, a + radius, side='right').astype(int)

    return starts, stops

In [None]:
# ---- load inputs ----
# PLINK genotypes from the .bed/.bim/.fam fileset.
gt_mt = hl.import_plink(GT_BED, GT_BIM, GT_FAM)

# Raw .bim file as a table keyed by its second column (f1), so rows can be
# looked up by variant rsid further down.
bim_ht = hl.import_table(GT_BIM, impute=True, no_header=True).key_by('f1')

# Headerless SNP list, keyed by column f1 (matched against variant rsids later).
h_rsid_ht = hl.import_table(SNP_LIST, key='f1', no_header=True)

In [None]:
# Per-variant row annotations on the genotype matrix:
#   cm    - column f2 of the .bim file, looked up by rsid;
#   stats - summary statistics of the alt-allele count across samples;
#   keep  - True when that count has non-zero standard deviation; such
#           variants are the only ones that can be standardised (the stdev is
#           a divisor below).
gt_mt = gt_mt.annotate_rows(cm=bim_ht[gt_mt.rsid].f2,
                            stats=hl.agg.stats(gt_mt.GT.n_alt_alleles()))
gt_mt = gt_mt.annotate_rows(keep=gt_mt.stats.stdev > 0.0)
gt_mt.rows().write(GT_ROW_HT, overwrite=True)

annot_mt = hl.import_matrix_table(ANNOT,
                                  row_fields={
                                      'CHR': hl.tstr,
                                      'RSID': hl.tstr,
                                      'a': hl.tstr,
                                      'r': hl.tstr,
                                      'CM': hl.tfloat,
                                      'BP': hl.tint,},
                                  entry_type=hl.tfloat64)

# Key annotation rows by (locus, alleles), like the genotype rows, so they
# can be joined against the GT_ROW_HT table.
annot_mt = (annot_mt.annotate_rows(
    locus=hl.locus(annot_mt.CHR, annot_mt.BP),
    alleles=[annot_mt.r, annot_mt.a]))

annot_mt = annot_mt.partition_rows_by('locus', 'locus', 'alleles')

gt_ht = hl.read_table(GT_ROW_HT)

# very soon, once tables are ordered, sort won't be necessary here
g_pos = np.sort(np.array(gt_ht.filter(gt_ht.keep).cm.collect()))
h_pos = np.sort(np.array(gt_ht.filter(hl.is_defined(h_rsid_ht[gt_ht.rsid]) & gt_ht.keep).cm.collect()))


def _standardized_genotypes(mt, row_info_ht):
    """Return a BlockMatrix of standardised genotypes for the kept rows of `mt`.

    mt          -- genotype MatrixTable with a GT entry field, keyed like row_info_ht.
    row_info_ht -- table carrying `keep` and `stats` row fields (the GT_ROW_HT table).

    Rows whose `keep` is not True are dropped. Each remaining entry is
    (n_alt_alleles - mean) / stdev; entries where that is missing (e.g. a
    missing genotype) become 0.0.
    """
    mt = mt.annotate_rows(aux=row_info_ht[mt.row_key])
    mt = mt.filter_rows(mt.aux.keep)
    return BlockMatrix.from_entry_expr(
        hl.or_else((mt.GT.n_alt_alleles() - mt.aux.stats.mean) / mt.aux.stats.stdev, 0.0))


# A: annotation values of the kept variants
annot_mt = annot_mt.filter_rows(gt_ht[annot_mt.row_key].keep)
A = BlockMatrix.from_entry_expr(annot_mt.x)

# G: every kept variant
g_mt = hl.import_plink(GT_BED, GT_BIM, GT_FAM)
G = _standardized_genotypes(g_mt, gt_ht)

# H: kept variants that also appear in the SNP list
h_mt = hl.import_plink(GT_BED, GT_BIM, GT_FAM)
h_mt = h_mt.filter_rows(hl.is_defined(h_rsid_ht[h_mt.rsid]))
H = _standardized_genotypes(h_mt, gt_ht)

# Invariants: interval indices computed from g_pos/h_pos address rows of G/H,
# and A is multiplied against R2 = H @ G.T, so these sizes must agree.
# A mismatch would otherwise misalign windows silently.
assert G.n_rows == g_pos.size, 'G rows ({}) != g_pos ({})'.format(G.n_rows, g_pos.size)
assert H.n_rows == h_pos.size, 'H rows ({}) != h_pos ({})'.format(H.n_rows, h_pos.size)
assert A.n_rows == G.n_rows, 'A rows ({}) != G rows ({})'.format(A.n_rows, G.n_rows)

In [None]:
# Half-width of the LD window, in the units of the `cm` positions collected
# into h_pos / g_pos (column f2 of the .bim file).
radius = 5

# For each H variant i, the G columns starts[i]:stops[i] lie within `radius`.
starts, stops = compute_row_intervals(h_pos, g_pos, radius)

n = gt_mt.count_cols()
if n <= 2:
    # c = 1 / (n - 2) is undefined (n == 2) or negative (n < 2).
    raise ValueError('need more than 2 samples to adjust r2, got {}'.format(n))
c = 1 / (n - 2)
# Rows of H and G are mean-centred and divided by their stdev, so
# (H @ G.T) / n approximates the correlation r of each H/G variant pair.
# Squaring and bias-adjusting gives
#   r2_adj = r2 * (1 + c) - c  ==  r2 - (1 - r2) / (n - 2)
R2 = ((H @ G.T) ** 2) * ((1 + c) / n ** 2) - c

# Keep only each row's window; .tolist() hands plain Python ints to Hail.
R2_sparse = R2.sparsify_row_intervals(starts.tolist(), stops.tolist())

# LD scores: for each H variant, windowed r2 summed against each annotation column.
L2 = R2_sparse @ A

L2.write(LD_SCORES_BM, force_row_major=True)

BlockMatrix.export(LD_SCORES_BM, LD_SCORES)

In [None]:
import matplotlib.pyplot as plt

# Sparsity pattern of the windowed R2 matrix: True wherever an entry is non-zero.
r2_nonzero_mask = R2_sparse.to_numpy() != 0
plt.matshow(r2_nonzero_mask)
plt.show()

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

# R2_sparse already holds (bias-adjusted) squared correlations, so plot
# log10(r2) directly; squaring again would display r^4 under an r2 label.
# Entries <= 0 (zeros outside the window, and slightly negative adjusted
# values) have no logarithm: np.ma.log10 masks them, so they render as the
# background colour instead of raising divide/invalid warnings.
log10_r2 = np.ma.log10(R2_sparse.to_numpy())

fig = plt.figure(figsize=(10, 12))
gs = gridspec.GridSpec(1, 1)
ax1 = plt.subplot(gs[0, 0])
im = ax1.matshow(log10_r2, vmin=-10, vmax=0, cmap=plt.get_cmap('hot_r'))

ax1.set_title('Hailathon')
ax1.set_ylabel('LD SNP')
ax1.xaxis.set_ticks_position('bottom')
ax1.set_xlabel('SNP')
ax1.set_aspect('auto')
cbar1 = fig.colorbar(im, ax=ax1)
# Raw string: '\l' is an invalid escape sequence in a normal literal.
cbar1.set_label(r'$\log_{10}(r2)$')
fig.tight_layout()