In [None]:
import pysam
import collections as c
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import numpy as np
import cv2 as cv

%matplotlib inline

import time
import os
import time
import datetime
from multiprocessing import Pool
import PIL

In [None]:
from scipy.sparse import coo_array

In [None]:
TILE_DIR = "HDMI_Tiles_Data"  # holds circle_info.pickle and the per-tile *_barcodes.pickle files

In [None]:
# Raw listing of TILE_DIR: bare file names, in arbitrary order.
tiles = os.listdir(path=TILE_DIR)

In [None]:
# Keep the first two '_'-separated fields of every file name that is not a subset file
# (e.g. "1_1101_barcodes.pickle" -> "1_1101").
tiles = ['_'.join(fname.split('_', 2)[:2]) for fname in tiles if 'subset' not in fname]

In [None]:
# Calibration: one circle was measured both on the H&E image (in nm) and in
# tile-coordinate space; the ratio converts tile-coordinate units to nanometres.
measured_size_px, measured_size_nm = 2323, 79763
nm_per_coord = measured_size_nm / measured_size_px


In [None]:
# Multiplier applied to the stored circle centroids to bring them into tile coordinates.
circle_resize_scale = 25

In [None]:
# Circle-detection results keyed by tile id; an entry may be None when detection failed.
# NOTE: pickle.load can execute arbitrary code - only load trusted files.
circle_info_path = f'{TILE_DIR}/circle_info.pickle'
with open(circle_info_path, 'rb') as fh:
    ci_info = pickle.load(fh)

In [None]:
def get_final_centroid(data, circle_resize_scale = 25):
    """Average the usable circle-pattern centroids of one tile, in tile coordinates.

    A centroid is used only when its circle array has exactly 8 columns
    (``circlesN.shape[1] == 8``). Each coordinate is multiplied by
    ``circle_resize_scale`` *before* truncation, matching how the stitching loop
    computes tile centres; truncating first would quantise the result to steps of
    ``circle_resize_scale``.

    Args:
        data: one ``ci_info`` entry with ``circles1``/``circles2`` arrays and
            ``centroid1``/``centroid2`` pairs; may be None when detection failed.
        circle_resize_scale: multiplier from centroid units to tile coordinates.

    Returns:
        ``(mean of scaled centroid[1], mean of scaled centroid[0])`` truncated to int.

    Raises:
        ValueError: if ``data`` is None or neither centroid is usable.
    """
    if data is None:
        raise ValueError('No centroids available for use')

    # NOTE(review): index 0 is collected as "x" and returned second; confirm the axis convention.
    final_coords_x = []
    final_coords_y = []
    for circles_key, centroid_key in (('circles1', 'centroid1'), ('circles2', 'centroid2')):
        if data[circles_key].shape[1] == 8:
            final_coords_x.append(data[centroid_key][0] * circle_resize_scale)
            final_coords_y.append(data[centroid_key][1] * circle_resize_scale)
    if not final_coords_x:
        raise ValueError('No centroids available for use')

    return (np.average(final_coords_y).astype(int), np.average(final_coords_x).astype(int))
        

In [None]:
# These distances were measured on the H&E image and the EM images (in nm) and are
# converted to tile-coordinate units with nm_per_coord (truncated to int).
nm_dist_between_circles_y = 310781
nm_dist_between_centroids_y = 4 * nm_dist_between_circles_y
nm_dist_between_centroids_x = 1155750

coords_dist_between_circles_y, coords_dist_between_centroids_y, coords_dist_between_centroids_x = (
    int(dist_nm / nm_per_coord)
    for dist_nm in (nm_dist_between_circles_y, nm_dist_between_centroids_y, nm_dist_between_centroids_x)
)

In [None]:
data_dir = 'output'  # directory with the whitelist* files and per-sample STARsolo outputs

In [None]:
# List the whitelist files with Python instead of an IPython `! ls` capture: portable,
# no shell involved, and an empty list (rather than shell output) when nothing matches.
import glob

samples = sorted(glob.glob(f"{data_dir}/whitelist*"))

In [None]:
# Reduce each path to its file name up to the first '.': "dir/name.ext" -> "name".
samples = [path.rsplit('/', 1)[-1].partition('.')[0] for path in samples]

In [None]:
# Strip a fixed-width wrapper from each name: the first 10 characters (presumably the
# "whitelist_" prefix matched by the glob) and the last 5 - TODO confirm the suffix.
_prefix_len, _suffix_len = 10, 5
samples = [name[_prefix_len:-_suffix_len] for name in samples]

In [None]:
import scipy.io as sio

In [None]:
import gzip
from tqdm import tqdm, trange

In [None]:
# Inspect the sample names parsed from the whitelist files
samples

In [None]:
# Same directory as `data_dir`; alias it so the two names cannot drift apart.
data_folder = data_dir

In [None]:
# Display the sample list once more before the per-sample processing loop
samples

In [None]:
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded

BIN_SIZE = 25  # side length, in tile-coordinate units, of the square bins in the PNG
_MAX_TILE_Y = 99  # tile ids zero-pad the y index to two digits


def _tile_id(lane, surface, x_tile, y_tile):
    """Return the tile id used in TILE_DIR file names and ``ci_info`` keys, e.g. ``1_1101``."""
    return f"{lane}_{surface}{x_tile}{y_tile:02}"


def _load_tile_barcodes(tile_id):
    """Load the pickled barcode -> (x, y) mapping for one tile.

    NOTE: pickle.load can execute arbitrary code; only read trusted tile files.
    """
    with open(f"{TILE_DIR}/{tile_id}_barcodes.pickle", "rb") as tile_fh:
        return pickle.load(tile_fh)


def _tile_center(tile_ci):
    """Return ``(cx, cy)``, the midpoint of the two circle centroids scaled by
    ``circle_resize_scale`` (cx from centroid index 1, cy from index 0), truncated to int."""

    def scaled_mid(axis):
        return int(
            (tile_ci["centroid1"][axis] * circle_resize_scale + tile_ci["centroid2"][axis] * circle_resize_scale)
            / 2
        )

    return scaled_mid(1), scaled_mid(0)


def _nearest_valid_ci(lane, surface, x_tile, y_tile, step):
    """Return the first non-None ``ci_info`` entry found walking from ``y_tile`` in
    direction ``step`` (+1 or -1) along swath ``x_tile``; None if the walk leaves 0..99."""
    y = y_tile + step
    while 0 <= y <= _MAX_TILE_Y:
        ci = ci_info.get(_tile_id(lane, surface, x_tile, y))
        if ci is not None:
            return ci
        y += step
    return None


def _interpolated_ci(lane, surface, x_tile, y_tile):
    """Estimate the centroids of a tile whose circle detection failed.

    Averages the nearest usable tiles after and before it in the same swath. If only
    one side has data, that side is used alone. The search is bounded; an unbounded
    search never terminates when one side has no usable tile.

    Raises:
        ValueError: if no tile in the swath has usable circle info.
    """
    neighbours = [
        ci
        for ci in (
            _nearest_valid_ci(lane, surface, x_tile, y_tile, +1),
            _nearest_valid_ci(lane, surface, x_tile, y_tile, -1),
        )
        if ci is not None
    ]
    if not neighbours:
        raise ValueError(
            f"No circle info in swath {x_tile} to interpolate tile {_tile_id(lane, surface, x_tile, y_tile)}"
        )
    return {
        key: tuple(sum(ci[key][axis] for ci in neighbours) / len(neighbours) for axis in (0, 1))
        for key in ("centroid1", "centroid2")
    }


def _build_bc_map(lane, surface, start_tile_y, end_tile_y):
    """Stitch the per-tile barcode coordinates of one lane/surface into one frame.

    The reference tile (swath 1, row ``start_tile_y``) keeps its raw coordinates.
    Every other tile is translated by the difference between its circle centre and the
    reference's centre, plus whole multiples of the swath (x) and row (y) spacing.
    Even swaths on surface 2 and odd swaths on surface 1 also move down one circle.

    NOTE(review): the reference tile is read at y = start_tile_y, but the loop visits
    y = start_tile_y + 1 .. end_tile_y + 1. The "skip reference" test only fires when
    start_tile_y == 0, and the row term (y_tile - start_tile_y - 1) gives row
    start_tile_y + 1 offset 0. For start_tile_y >= 1 the reference tile and tile
    (1, start_tile_y + 1) therefore get the same row offset. The surface-1 shift also
    applies to swath-1 loop tiles but not to the reference tile. Both behaviours are
    kept unchanged; confirm whether start_tile_y is 0- or 1-based.

    Returns:
        dict mapping the first 31 characters of each barcode to an (x, y) int pair.
    """
    start_tile_id = _tile_id(lane, surface, 1, start_tile_y)
    tile_1_data = _load_tile_barcodes(start_tile_id)
    tile_1_cx, tile_1_cy = _tile_center(ci_info[start_tile_id])

    bcs = list(tile_1_data.keys())
    xs = [int(coords[0]) for coords in tile_1_data.values()]
    ys = [int(coords[1]) for coords in tile_1_data.values()]

    for _y_tile in range(start_tile_y, end_tile_y + 1):
        for _x_tile in range(6):
            if _y_tile == 0 and _x_tile == 0:
                continue
            x_tile = _x_tile + 1
            y_tile = _y_tile + 1

            cur_tile_id = _tile_id(lane, surface, x_tile, y_tile)
            cur_tile_data = _load_tile_barcodes(cur_tile_id)
            cur_tile_ci = ci_info[cur_tile_id]
            if cur_tile_ci is None:
                cur_tile_ci = _interpolated_ci(lane, surface, x_tile, y_tile)
            cur_tile_cx, cur_tile_cy = _tile_center(cur_tile_ci)

            x_shift = (x_tile - 1) * coords_dist_between_centroids_x - (cur_tile_cx - tile_1_cx)
            y_shift = (y_tile - start_tile_y - 1) * coords_dist_between_centroids_y - (cur_tile_cy - tile_1_cy)
            # Even swaths on the thick surface (2) and odd swaths on the thin surface (1)
            # sit one circle lower.
            if (x_tile % 2 == 0 and surface == 2) or (x_tile % 2 == 1 and surface == 1):
                y_shift += coords_dist_between_circles_y

            bcs += list(cur_tile_data.keys())
            for x, y in cur_tile_data.values():
                xs.append(int(x) + x_shift)
                ys.append(int(y) + y_shift)

    return {bc[:31]: xy for bc, xy in zip(bcs, zip(xs, ys))}


def _load_or_build_bc_map(lane, surface, start_tile_y, end_tile_y):
    """Return the stitched barcode map, reading/writing a pickle cache in data_folder."""
    cache_path = f"{data_folder}/bc_coords_{lane}_{surface}_tiles_{start_tile_y}_to_{end_tile_y}.pickle"
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as cache_fh:
            return pickle.load(cache_fh)
    bc_map = _build_bc_map(lane, surface, start_tile_y, end_tile_y)
    with open(cache_path, "wb") as cache_fh:
        pickle.dump(bc_map, cache_fh)
    return bc_map


def _write_gem(path, header, gene_list, mtx_csr, hdmi_bcs, bc_map):
    """Write a gzipped GEM table with one line per non-zero (gene, located barcode) entry.

    Rows of ``mtx_csr`` are genes; column indices index ``hdmi_bcs``. Entries whose
    barcode has no coordinates in ``bc_map`` are skipped.
    """
    with gzip.open(path, mode="wt") as gem_fh:
        gem_fh.write(header)
        for gene_idx in trange(len(gene_list)):
            gene = gene_list[gene_idx]
            # Slice the CSR row directly instead of materialising a row matrix via getrow().
            row = slice(mtx_csr.indptr[gene_idx], mtx_csr.indptr[gene_idx + 1])
            for cell_idx, gene_count in zip(mtx_csr.indices[row], mtx_csr.data[row]):
                try:
                    x, y = bc_map[hdmi_bcs[cell_idx]]
                except (KeyError, IndexError):
                    continue
                gem_fh.write(f"{gene}\t{x}\t{y}\t{gene_count}\t{gene_count}\n")


def _bin_counts_by_location(barcodes, counts, bc_map, bin_size):
    """Sum per-barcode counts into square ``bin_size`` x ``bin_size`` spatial bins.

    The canvas spans 0..max coordinate found in ``bc_map``. Trailing rows/columns that
    do not fill a whole bin are dropped. Counts accumulate in int64, so there is no
    int8 wrap-around and no full-resolution dense canvas. Barcodes sharing a
    coordinate are summed.

    Returns:
        (binned, n_missing, n_outside): the 2-D int64 bin array, the number of
        barcodes absent from ``bc_map``, and the number whose coordinates are negative
        or fall in a dropped partial bin.
    """
    max_x = max(x for x, _ in bc_map.values())
    max_y = max(y for _, y in bc_map.values())
    n_bins_x = max((max_x + 1) // bin_size, 0)
    n_bins_y = max((max_y + 1) // bin_size, 0)
    binned = np.zeros((n_bins_x, n_bins_y), dtype=np.int64)

    located = [(bc_map[bc], count) for bc, count in zip(barcodes, counts) if bc in bc_map]
    n_missing = len(barcodes) - len(located)
    if not located:
        return binned, n_missing, 0

    xy = np.array([coords for coords, _ in located], dtype=np.int64)
    values = np.array([count for _, count in located], dtype=np.int64)
    bin_x = xy[:, 0] // bin_size
    bin_y = xy[:, 1] // bin_size
    inside = (xy[:, 0] >= 0) & (xy[:, 1] >= 0) & (bin_x < n_bins_x) & (bin_y < n_bins_y)
    np.add.at(binned, (bin_x[inside], bin_y[inside]), values[inside])
    return binned, n_missing, int((~inside).sum())


def _to_uint8_image(binned, upper_percentile=99.9995):
    """Scale binned counts to 0..255 for an 8-bit greyscale image.

    Values are divided by the ``upper_percentile``-th percentile (np.percentile takes
    q in 0..100), so a few extreme bins do not wash out the rest; larger values
    saturate at 255. Falls back to the maximum when that percentile is 0.
    """
    scaled = binned.astype(np.float64)
    if scaled.size:
        scale = np.percentile(scaled, upper_percentile)
        if scale <= 0:
            scale = scaled.max()
        if scale > 0:
            scaled = (scaled / scale) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)


for sample in samples:
    _split = sample.split("_")
    sample_name = "_".join(_split[:-7])
    lane, surface, _, start_tile_y, _, end_tile_y = _split[-6:]
    lane = int(lane)
    surface = int(surface)
    start_tile_y = int(start_tile_y)
    end_tile_y = int(end_tile_y)

    bc_map = _load_or_build_bc_map(lane, surface, start_tile_y, end_tile_y)

    solo_raw_dir = f"{data_folder}/{sample_name}/{sample_name}_Solo.out/GeneFull/raw"
    hdmis = pd.read_csv(f"{solo_raw_dir}/barcodes.tsv", header=None)
    mtx = sio.mmread(f"{solo_raw_dir}/matrix.mtx")
    with open(f"{solo_raw_dir}/features.tsv", "r") as fh:
        # Column 1 is the gene name; strip the newline in case it is the last column.
        gene_list = [line.rstrip("\n").split("\t")[1] for line in fh]

    mtx_csr = mtx.tocsr()
    hdmi_bcs = list(hdmis[0].values)

    # Gem creation

    header = f"""#FileFormat=GEMv0.1
#SortedBy=None
#BinSize=1
#STOmicsChip={sample}
#OffsetX=0
#OffsetY=0
geneID\tx\ty\tMIDCount\tExonCount\n"""

    _write_gem(f"{data_folder}/{sample_name}/{sample_name}.gem.gz", header, gene_list, mtx_csr, hdmi_bcs, bc_map)
    print("Finished writing GEM file")

    # Per-barcode QC (columns are barcodes): total counts and number of detected genes.
    hdmi_sums = np.ravel(mtx_csr.sum(axis=0))
    genes_detected = np.ravel((mtx_csr > 0).sum(axis=0))
    hdmi_mask = hdmi_sums > 0
    # Barcodes dropped by hdmi_mask sum to 0, so for non-negative counts this equals the
    # per-gene sum over the kept barcodes.
    gene_sums = pd.DataFrame(
        index=gene_list, data=np.ravel(mtx_csr.sum(axis=1)), columns=["Sum"]
    ).sort_values(by="Sum", ascending=False)
    del mtx_csr
    del mtx

    filtered_bcs = hdmis[0][hdmi_mask].tolist()
    filtered_sums = hdmi_sums[hdmi_mask]
    hdmi_stats = pd.DataFrame(
        {"Total_Counts": filtered_sums, "Total_Genes": genes_detected[hdmi_mask]},
        index=filtered_bcs,
    )

    hdmi_stats.plot(kind="box", title=sample_name)
    hdmi_stats.plot(kind="scatter", x="Total_Counts", y="Total_Genes", title=sample_name)
    plt.xlim(0, 200)
    plt.ylim(0, 200)

    gene_binned, n_missing, n_outside = _bin_counts_by_location(filtered_bcs, filtered_sums, bc_map, BIN_SIZE)
    del bc_map

    print(f"Missing BCs: {n_missing}")
    if n_outside:
        print(f"BCs outside the binned area: {n_outside}")

    PIL.Image.fromarray(_to_uint8_image(gene_binned), "L").save(
        f"{data_folder}/{sample_name}/GeneFull_HDMI_Locs_Global_HDMIs.png"
    )