In [None]:
# Standard library imports
import os
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# scReGAT package imports
from scregat import get_edge_info

In [None]:
# Load the pickled graph dataset produced upstream.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open("../data/dataset_atac_core_MFG.pkl", "rb") as pkl_in:
    dataset_atac = pickle.load(pkl_in)

In [None]:
# Entries of array_peak whose name starts with "chr"
peaks_list = [name for name in dataset_atac.array_peak if name.startswith("chr")]

# Flatten each graph's feature tensor to 1-D and stack them: one row per graph
data_mt = np.stack(
    [
        torch.flatten(graph.x).numpy()
        for graph in tqdm(dataset_atac.list_graph, desc="Processing Graphs")
    ]
)

In [None]:
def cosine_similarity_matrix(matrix):
    """Return the pairwise cosine similarity between the columns of ``matrix``.

    Parameters
    ----------
    matrix : array-like
        2-D data in which every column is treated as one vector.

    Returns
    -------
    numpy.ndarray
        Square array whose entry ``[a, b]`` is the cosine similarity between
        columns ``a`` and ``b``. An all-zero column has no direction, so its
        similarity to every column (itself included) is reported as 0 rather
        than NaN.
    """
    matrix = np.asarray(matrix)
    column_norms = np.linalg.norm(matrix, axis=0)
    # Substitute 1 for zero norms: the zero column stays zero after the
    # division and the 0/0 -> NaN that would poison the result is avoided.
    safe_norms = np.where(
        column_norms == 0, np.ones_like(column_norms), column_norms
    )
    normalized_matrix = matrix / safe_norms
    similarity_matrix = np.dot(normalized_matrix.T, normalized_matrix)
    return similarity_matrix

In [None]:
# Pairwise cosine similarity between the columns of data_mt
# (square matrix with one row/column per data_mt column).
similarity_matrix = cosine_similarity_matrix(data_mt)

In [None]:
# Size of the first graph's y_exp along axis 0; used further down to restrict
# the similarity search to the first `top_samples` columns.
# NOTE(review): presumably the number of expression/gene nodes — confirm.
top_samples = dataset_atac.list_graph[0].y_exp.shape[0]
top_samples

In [None]:
# Edge table built by the scReGAT helper; later cells use its index and its
# `peak` column.
# NOTE(review): the exact schema is defined by get_edge_info — confirm there.
edge_info = get_edge_info(dataset_atac)

In [None]:
# How many rows of the edge table mention each value of the `peak` column
peak_count = edge_info["peak"].value_counts()

In [None]:
# Read the edge selection table, using the file's first column as row index.
# The first remaining column is used below both as labels matched against
# edge_info's index and as column positions into raw_edge.
edge_index = pd.read_csv("../data/Edge_index_with_Specifici_Hi-C.txt", index_col=0)
edge_index.shape

In [None]:
# Edge index tensor of the first graph; a subset of its columns is kept below.
raw_edge = dataset_atac.list_graph[0].edge_index

In [None]:
# Rows of edge_info whose index label is listed in the first column of
# edge_index are excluded; the remaining rows are counted per `peak` value.
listed_labels = edge_index.iloc[:, 0].values
is_listed = edge_info.index.isin(listed_labels)
edge_peak_count = edge_info.loc[~is_listed, "peak"].value_counts()

In [None]:
# Row positions in array_peak of every name that appears in edge_peak_count
df_index = pd.DataFrame(dataset_atac.array_peak)
has_count = df_index.iloc[:, 0].isin(edge_peak_count.index.values)
peak_index = df_index.index[has_count.to_numpy()].values

In [None]:
pairs = []

# Names of the selected peaks, looked up once before the loop
peak_names = dataset_atac.array_peak[peak_index].tolist()

# Walk the selected peak indices together with their names
for i, peak_name in tqdm(
    zip(peak_index, peak_names),
    total=len(peak_index),
    desc="Building similarity pairs",
):
    # Similarities between column i and the first `top_samples` columns
    candidate_sims = similarity_matrix[i, :top_samples]

    # Number of edges this peak had in edge_peak_count
    n_edges = edge_peak_count.loc[peak_name]

    # Positions of the n_edges largest similarities (descending order)
    best_columns = np.argsort(-candidate_sims)[:n_edges]

    # Store (j, i): j is a selected column position, i is the peak index
    pairs.extend((j, i) for j in best_columns)

In [None]:
# pairs holds (j, i) tuples: row 0 of the new edges gets i, row 1 gets j
index1 = [pair[1] for pair in pairs]
index2 = [pair[0] for pair in pairs]

# Force an integer dtype: with no pairs, torch.tensor([[], []]) would be
# inferred as float32 and torch.cat would then promote the whole edge index
# to floating point.
new_edge = torch.tensor([index1, index2], dtype=torch.long)

# Keep only the raw_edge columns listed in the first column of edge_index,
# then append the similarity-derived edges after them.
kept_columns = edge_index.iloc[:, 0].values
new_edge = torch.cat([raw_edge[:, kept_columns], new_edge], dim=1)

In [None]:
# Give every graph the combined edge set (all graphs share the same tensor object)
for graph in dataset_atac.list_graph:
    graph.edge_index = new_edge

In [None]:
# Write the updated dataset next to the input data.
file_atac_test = os.path.join("../data/", "dataset_atac_cosine_MFG.pkl")
with open(file_atac_test, "wb") as w_pkl:
    # Stream straight to the file instead of building the whole serialized
    # dataset in memory with pickle.dumps first.
    pickle.dump(dataset_atac, w_pkl)