# Create a model-ready dataset from cisTopic output on `pbmc-granulocyte-sorted-3k_10x-Multiome`
Adam Klie (last updated: *09/20/2023*)
***
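This notebook shows how to convert a pycisTopic run into model-ready inputs: a binarized region × topic label matrix, the list of regions, and their genomic sequences, each saved as a `.npy` file.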

In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import pyranges as pr
from pycisTopic.topic_binarization import binarize_topics, smooth_topics_f
from pycisTopic.topic_qc import evaluate_models, compute_topic_metrics

In [None]:
# Define arguments (to become command-line options). Each value can be overridden with an
# environment variable so the notebook runs elsewhere without editing; the defaults are the
# paths of the original run.
data_dir = os.environ.get("CISTOPIC_DATA_DIR", "/cellar/users/aklie/data/ml4gland/collabs/er_stress_regulation/test")
dataset_name = os.environ.get("CISTOPIC_DATASET_NAME", "test")
output_dir = os.environ.get("CISTOPIC_OUTPUT_DIR", "/cellar/users/aklie/data/ml4gland/collabs/er_stress_regulation/test")

In [None]:
# Load the pickled pycisTopic object written by the upstream run
def load_cisTopic_obj(file_name):
    """Unpickle a cisTopic object and reset its selected model's topic assignments.

    Parameters
    ----------
    file_name : str or path-like
        Path to the pickle file.

    Returns
    -------
    object
        The unpickled object, with ``selected_model.topic_ass`` replaced by an empty dict.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load files from a trusted source.
    """
    with open(file_name, mode="rb") as handle:
        loaded = pickle.load(handle)
    # Blank out the stored topic assignments (presumably to save memory — TODO confirm
    # nothing downstream needs them).
    loaded.selected_model.topic_ass = {}
    return loaded
cistopic_pkl_path = os.path.join(data_dir, f"{dataset_name}.pycisTopic_obj.pkl")
cistopic_obj = load_cisTopic_obj(cistopic_pkl_path)

In [None]:
# Binarize topics with Otsu thresholding; the per-topic plots are saved to a PDF
binarization_pdf = os.path.join(output_dir, "region_topic_binarization.pdf")
region_bin_topics_otsu = binarize_topics(
    cistopic_obj,
    method="otsu",
    plot=True,
    num_columns=5,
    save=binarization_pdf,
)

# Keep only the model's regions whose name contains "chr"
topic_region_names = cistopic_obj.selected_model.topic_region.index
all_regions = topic_region_names[topic_region_names.str.contains("chr")]

In [None]:
def get_per_regions_topic_membership(region_dict):
    """Collect which topics each region was binarized into.

    Parameters
    ----------
    region_dict : dict
        Maps a topic name to a pandas object whose index holds that topic's regions.

    Returns
    -------
    topic_regions_pd : pd.Index
        Union of the region indices of all topics (each region once).
    topic_region_mp : dict
        Region -> list of topic names containing it, in ``region_dict`` iteration order.
    topic_regions_lst : list
        Every topic's regions concatenated, duplicates kept (one entry per
        region/topic membership).
    """
    union_index = pd.Index([])
    membership = {}
    flat_regions = []
    for topic_name, topic_regions in region_dict.items():
        idx = topic_regions.index
        flat_regions.extend(idx)
        union_index = union_index.union(idx)
        for region_name in idx:
            if region_name not in membership:
                membership[region_name] = []
            membership[region_name].append(topic_name)
    return union_index, membership, flat_regions
# Map every region to the topics it was binarized into
topic_regions_pd, topic_region_mp, topic_regions_lst = get_per_regions_topic_membership(
    region_dict=region_bin_topics_otsu
)

In [None]:
def get_nontopic_regions(all_regions, topic_regions_pd):
    """Return the entries of ``all_regions`` that belong to no topic.

    Parameters
    ----------
    all_regions : pd.Index
        Candidate regions; their order (and any duplicates) is preserved.
    topic_regions_pd : collection
        Regions that belong to at least one topic.
    """
    in_some_topic = all_regions.isin(topic_regions_pd)
    return all_regions[~in_some_topic]
# Regions that passed the "chr" filter but were not binarized into any topic
non_topic_regions = get_nontopic_regions(all_regions=all_regions, topic_regions_pd=topic_regions_pd)

In [None]:
def create_binarized_matrix(all_regions, topic_region_mp, n_topics):
    """Build a region x topic 0/1 label matrix.

    Parameters
    ----------
    all_regions : iterable
        Regions in the desired row order.
    topic_region_mp : dict
        Region -> list of topic names of the form ``"Topic<k>"`` with 1-based ``k``.
        Regions absent from the mapping get an all-zero row.
    n_topics : int
        Number of columns; each ``k`` is expected to lie in ``1..n_topics``.

    Returns
    -------
    np.ndarray of float64, shape (len(all_regions), n_topics)
        Entry ``[i, k - 1]`` is 1.0 when row ``i``'s region belongs to ``Topic<k>``.
    """
    labels = np.zeros((len(all_regions), n_topics))
    for row_idx, region in enumerate(all_regions):
        if region not in topic_region_mp:
            continue
        # "Topic7" -> column 6: topic numbers are 1-based, columns 0-based
        cols = [int(name.split("Topic")[-1]) - 1 for name in topic_region_mp[region]]
        labels[row_idx, cols] = 1
    return labels
# Take the number of topics from the trained model instead of hard-coding it
# (assumes topic_region has one column per topic, Topic1..TopicN — TODO confirm).
n_topics = cistopic_obj.selected_model.topic_region.shape[1]
arr = create_binarized_matrix(all_regions, topic_region_mp, n_topics)
arr.shape

In [None]:
def check_topic_binarization(region_dict, arr, all_regions, non_topic_regions, topic_regions_pd, topic_regions_lst):
    """Sanity-check the binarized label matrix against the per-topic region sets.

    Parameters
    ----------
    region_dict : dict
        Topic name (``"Topic<k>"``, 1-based ``k``) -> pandas object indexed by that
        topic's regions.
    arr : np.ndarray
        Region x topic 0/1 matrix whose rows follow ``all_regions`` and whose column
        ``k - 1`` corresponds to ``Topic<k>``.
    all_regions : pd.Index
        Row labels of ``arr``.
    non_topic_regions : pd.Index
        Regions of ``all_regions`` that belong to no topic.
    topic_regions_pd : pd.Index
        Union of all topic regions.
    topic_regions_lst : list
        Topic regions with one entry per region/topic membership.

    Raises
    ------
    AssertionError
        If any invariant does not hold.
    """
    assert arr.shape[0] == len(all_regions), "Number of matrix rows does not match the number of regions"
    for topic, regions in region_dict.items():
        assert non_topic_regions.isin(regions.index).sum() == 0, f"Topic {topic} contains non-topic regions"
        # Find the topic's column with the same "Topic<k>" -> k - 1 rule used to build the
        # matrix, rather than assuming region_dict iterates in column order.
        col = int(topic.split("Topic")[-1]) - 1
        assert 0 <= col < arr.shape[1], f"Topic {topic} has no column in the matrix"
        assert arr[:, col].sum() == len(regions), f"Number of regions in {topic} does not match the number of 1s in its column"
    row_sums = arr.sum(axis=1)
    assert np.all(all_regions[row_sums == 0].isin(non_topic_regions)), "Regions that are 0 across all topics are not all non-topic regions"
    # Complementary check: every region with at least one 1 must be a topic region.
    assert np.all(all_regions[row_sums > 0].isin(topic_regions_pd)), "Regions that are not 0 across all topics are not all topic regions"
    assert arr.sum() == len(topic_regions_lst), "Number of 1s in the matrix does not match the number of topic regions"
check_topic_binarization(
    region_dict=region_bin_topics_otsu,
    arr=arr,
    all_regions=all_regions,
    non_topic_regions=non_topic_regions,
    topic_regions_pd=topic_regions_pd,
    topic_regions_lst=topic_regions_lst,
)

In [None]:
def save_seqdata_files(
    bin_mtx,
    regions,
    output_dir,
    dataset_name,
    genome_fasta="/cellar/users/aklie/data/ml4gland/genomes/hg38/hg38.fa",
):
    """Write labels, region names and region sequences as ``.npy`` files.

    Writes ``{dataset_name}_labels.npy`` (``bin_mtx``), ``{dataset_name}_regions.npy``
    (``regions``) and ``{dataset_name}_seqs.npy`` (one sequence per region) into
    ``output_dir``, with row ``i`` of every file referring to ``regions[i]``.

    Parameters
    ----------
    bin_mtx : np.ndarray
        Label matrix whose rows follow ``regions``.
    regions : pd.Index of str
        Region names formatted ``"chrom:start-end"``.
    output_dir : str
        Destination directory; created if it does not exist.
    dataset_name : str
        Filename prefix.
    genome_fasta : str, optional
        FASTA file the sequences are extracted from (defaults to the original hg38 path).
    """
    # "chr1:100-200" -> ["chr1", "100", "200"]; rsplit keeps any "-" inside the
    # chromosome name attached to the chromosome.
    region_split = [region.replace(":", "-").rsplit("-", 2) for region in regions]
    region_df = pd.DataFrame(region_split, columns=["Chromosome", "Start", "End"])
    region_df = region_df.astype({"Start": int, "End": int})
    # Record each interval's position in `regions` so the original order can be restored.
    region_df["RegionIndex"] = np.arange(len(region_df))
    pr_obj = pr.PyRanges(region_df)
    # PyRanges stores intervals grouped by chromosome, so its internal order can differ
    # from `regions`. Attaching the sequences as a column keeps each one with its interval.
    pr_obj.Sequence = pr.get_fasta(pr_obj, genome_fasta)
    seqs = pr_obj.df.sort_values("RegionIndex")["Sequence"].to_numpy()
    os.makedirs(output_dir, exist_ok=True)
    np.save(os.path.join(output_dir, dataset_name + "_labels.npy"), bin_mtx)
    np.save(os.path.join(output_dir, dataset_name + "_regions.npy"), regions)
    np.save(os.path.join(output_dir, dataset_name + "_seqs.npy"), seqs)
save_seqdata_files(
    bin_mtx=arr,
    regions=all_regions,
    output_dir=output_dir,
    dataset_name=dataset_name,
)

# DONE!

---