In [None]:
# Standard library imports
import os
import random
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import scanpy as sc
from scipy.stats import hypergeom
import torch

# scReGAT package imports
from scregat import prepare_model_input, sum_counts, plot_edge, ATACGraphDataset

In [None]:
# Load the scATAC/scRNA inputs and build the graph dataset

In [None]:
# Paths to the scATAC and scRNA .h5ad inputs, relative to the notebook's
# working directory. ATAC_h5ad_file is also passed to prepare_model_input.
ATAC_h5ad_file = "../data/scATAC_MFG.h5ad"
RNA_h5ad_file = "../data/scRNA_MFG.h5ad"

In [None]:
# Read the scATAC .h5ad file with scanpy; the bare name on the last line
# makes the notebook display the loaded object.
adata_atac = sc.read_h5ad(ATAC_h5ad_file)
adata_atac

In [None]:
# Read the scRNA .h5ad file with scanpy; the bare name on the last line
# makes the notebook display the loaded object.
adata_rna = sc.read_h5ad(RNA_h5ad_file)
adata_rna

In [None]:
# Distinct cell-type labels present in the scATAC observations.
adata_atac.obs["celltype"].unique()

In [None]:
# Distinct cell-type labels present in the scRNA observations.
adata_rna.obs["celltype"].unique()

In [None]:
# Cast the cell-type labels to object dtype before aggregating
# (presumably sum_counts does not accept a categorical column -- TODO confirm).
# Re-running this cell is harmless: the cast is a no-op the second time.
adata_rna.obs['celltype'] = adata_rna.obs['celltype'].astype('object')

# Aggregate the scRNA data per cell type.
df_rna = sum_counts(
    adata_rna,
    by='celltype',
    marker_gene_num=300,
)

In [None]:
# Display the table returned by sum_counts.
df_rna

In [None]:
# Build dataset_atac from the scATAC data and the per-cell-type RNA table.
# Later cells read its list_graph, df_rna, df_tf and array_peak attributes.
dataset_atac = prepare_model_input(
    adata_atac=adata_atac,
    path_data_root='./',
    file_atac=ATAC_h5ad_file,
    df_rna_celltype=df_rna,
    path_eqtl='../data/all_tissue_SNP_Gene.txt',
    Hi_C_file_suffix="_" + "brain",
    hg19tohg38=False,
    min_percent=0.01,
    use_additional_tf=True,
    # NOTE(review): keyword spelling kept exactly as originally passed;
    # confirm it matches the prepare_model_input signature.
    tissue_cuttof=10,
)

In [None]:
# Inspect the first graph of the freshly built dataset.
dataset_atac.list_graph[0]

In [None]:
# Snapshot the dataset to disk before the tissue-specific TF-gene pairs are
# added. NOTE: the last cell of this notebook writes to the same path, so this
# snapshot is overwritten once that cell runs.
file_atac_test = os.path.join('../data/', 'dataset_atac_core_MFG.pkl')
with open(file_atac_test, 'wb') as w_pkl:
    # Stream straight into the file: pickle.dumps would first build the whole
    # serialized dataset as one bytes object and leave it bound in the kernel.
    pickle.dump(dataset_atac, w_pkl)

### Add tissue-specific TF-gene pairs

In [None]:
# Load the brain TF-gene table (first CSV column becomes the index) and give
# its three remaining columns fixed names; a different column count raises.
df = (
    pd.read_csv("../data/TF_Gene_tissue_Brain.csv", index_col=0)
    .set_axis(['TF', 'TargetGene', 'tissue_count'], axis="columns")
)

In [None]:
# Alias, not a copy: df_tf and df refer to the same DataFrame object.
df_tf = df

In [None]:
# Column labels of the dataset's RNA table, as a plain Python list.
gene_list = dataset_atac.df_rna.columns.tolist()

In [None]:
import itertools

In [None]:
# Keep only the TF-gene rows whose TF and target gene both appear in gene_list.
set_gene = set(gene_list)
tf_base_filtered = df_tf[df_tf['TF'].isin(set_gene) & df_tf['TargetGene'].isin(set_gene)]
tf_map_gene = set(tf_base_filtered['TF'].unique())
target_map_gene = set(tf_base_filtered['TargetGene'].unique())

# Unique (TF, TargetGene) pairs. Both members of every pair are already in
# set_gene, so each pair is by construction an element of set_gene x set_gene.
# The full Cartesian product of genes is therefore NOT materialized: doing so
# costs memory and time quadratic in the number of genes, and intersecting
# with it returns exactly these pairs.
tf_base_tuples = set(zip(tf_base_filtered['TF'], tf_base_filtered['TargetGene']))
map_pair = tf_base_tuples

# Sort so the row order (and the edge order derived from it later) is
# reproducible instead of depending on set iteration order.
map_pair_list = sorted(map_pair)
df_tf_new = pd.DataFrame(map_pair_list, columns=['TF', 'TargetGene'])

# Merge with the pairs already stored on the dataset and drop exact repeats.
df_tf_all = pd.concat([df_tf_new, dataset_atac.df_tf])
df_tf_all = df_tf_all.drop_duplicates()

In [None]:
# Replace the dataset's TF table with the merged one; the edge construction
# in the next cell reads dataset_atac.df_tf.
dataset_atac.df_tf = df_tf_all

In [None]:
# Position of every entry of dataset_atac.array_peak within that array.
peak_index_dict = {name: position for position, name in enumerate(dataset_atac.array_peak)}

# Translate the 'TF' and 'TargetGene' columns of dataset_atac.df_tf into
# positions. A name that is absent from array_peak raises KeyError.
index_1 = [peak_index_dict[tf_name] for tf_name in dataset_atac.df_tf['TF'].values]
index_2 = [peak_index_dict[target_name] for target_name in dataset_atac.df_tf['TargetGene'].values]

# One row per pair: column 0 is the TF position, column 1 the target position.
tf_edge_vec = torch.tensor(np.column_stack([index_1, index_2]))

# Every graph receives the same tensor object (shared, not copied).
for graph in dataset_atac.list_graph:
    graph.edge_tf = tf_edge_vec


In [None]:
# Inspect the first graph again, now that edge_tf has been attached.
dataset_atac.list_graph[0]

In [None]:
# Persist the dataset, now carrying the merged TF table and edge_tf tensors.
file_atac_test = os.path.join('../data/', 'dataset_atac_core_MFG.pkl')
with open(file_atac_test, 'wb') as w_pkl:
    # Stream straight into the file: pickle.dumps would first build the whole
    # serialized dataset as one bytes object and leave it bound in the kernel.
    pickle.dump(dataset_atac, w_pkl)