In [None]:
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm
import numpy as np
import seaborn as sns
import kipoiseq
import matplotlib.pyplot as plt
import pandas as pd
from dataloaders.h5dataset import GEPBedDataset, MultiSpeciesDataset

human_test_data_path: str = "/home/jiwei_zhu/disk/Enformer/Data/human_test.h5"
human_test_bed_path: str = "/home/jiwei_zhu/disk/Enformer/Data/human_test.bed"
human_genome_path: str = "/home/jiwei_zhu/disk/Enformer/Data/hg38.ml.fa"

mouse_test_data_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mouse_test.h5"
mouse_test_bed_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mouse_test.bed"
mouse_genome_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mm10.fa"

df_human = pd.read_csv("/home/jiwei_zhu/disk/Enformer/enformer_MoE/targets_human_sorted.txt", sep="\t")
df_mouse = pd.read_csv("/home/jiwei_zhu/disk/Enformer/enformer_MoE/targets_mouse_sorted.txt", sep="\t")

index_human = list(df_human["index"])
track_types_human = {
    "DNASE/ATAC": (0, 684),
    "TF ChIP-seq": (684, 2573),
    "Histone ChIP-seq": (2573, 4675),
    "CAGE": (4675, 5313),
}
index_mouse = list(df_mouse["index"])
track_types_mouse = {
    "DNASE/ATAC": (0, 228),
    "TF ChIP-seq": (228, 519),
    "Histone ChIP-seq": (519, 1286),
    "CAGE": (1286, 1643),
}

test_dataset = MultiSpeciesDataset(
    file_paths = [human_test_data_path, mouse_test_data_path],
    bed_paths = [human_test_bed_path, mouse_test_bed_path],
    seqlen = 131072,
    genome_paths = [human_genome_path, mouse_genome_path],
    shift_aug = False,
    rc_aug = False,
)

device = torch.device('cuda:0')

load preprocess data: /home/jiwei_zhu/disk/Enformer/Data/data//human_test_196608_False_False.bin
load preprocess data: /home/jiwei_zhu/disk/Enformer/Data/data//mouse_test_196608_False_False.bin


In [None]:
from model.modeling_space import Space, SpaceConfig, TrainingSpace

model_path = "./results/Space_species_tracks/checkpoint-28000"
config = SpaceConfig.from_pretrained(os.path.join(model_path, "config.json"))
model = TrainingSpace(config)
state_dict = torch.load(os.path.join(model_path, "pytorch_model.bin"))

new_state_dict = {}
for key, value in state_dict.items():
    key = key.replace("enformer", "model")
    if key == "model.tracks.output.human.weight":
        key = "model.heads.human.0.weight"
    if key == "model.tracks.output.human.bias":
        key = "model.heads.human.0.bias"
    if key == "model.tracks.output.mouse.weight":
        key = "model.heads.mouse.0.weight"
    if key == "model.tracks.output.mouse.bias":
        key = "model.heads.mouse.0.bias"
    new_state_dict[key] = value
state_dict = new_state_dict

model.load_state_dict(state_dict)
model = model.to(device)
is_species = True
is_tracks = True

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def pearson_corr_coef_per_track(x, y, dim=0):
    x_centered = x - x.mean(dim=dim, keepdim=True)
    y_centered = y - y.mean(dim=dim, keepdim=True)
    corr = F.cosine_similarity(x_centered, y_centered, dim=dim)
    return corr


def pearson_corr_coef(x, y, dim=1, reduce_dims=(-1,)):
    x_centered = x - x.mean(dim=dim, keepdim=True)
    y_centered = y - y.mean(dim=dim, keepdim=True)
    return F.cosine_similarity(x_centered, y_centered, dim=dim).mean(dim=reduce_dims)

In [None]:
depth = 11
species_num_experts = 4
tracks_num_experts = 8

pres_human, pres_mouse = [], []
target_human, target_mouse = [], []
p_human, p_mouse = 0.0, 0.0
gates_human, gates_mouse = np.zeros((depth, species_num_experts)), np.zeros((depth, species_num_experts))
weights_human, weights_mouse = np.zeros((1, tracks_num_experts)), np.zeros((1, tracks_num_experts))
total_gates = np.zeros((8, tracks_num_experts))

model.eval()
len_human, len_mouse = len(test_dataset.human_dataset), len(test_dataset.mouse_dataset)
print(f"Human: {len_human}, Mouse: {len_mouse}")
for idx in tqdm(range(len(test_dataset))):
    data = test_dataset[idx]
    human_x = torch.tensor(data["human_x"], dtype=torch.float32).to(device)
    mouse_x = torch.tensor(data["mouse_x"], dtype=torch.float32).to(device)
    human_labels = torch.tensor(data["human_labels"][:, index_human], dtype=torch.float32).to(device)
    mouse_labels = torch.tensor(data["mouse_labels"][:, index_mouse], dtype=torch.float32).to(device)

    with torch.no_grad():
        predictions = model(human_x, mouse_x)
        pre_human = predictions["human"]["out"][:, index_human]
        pre_mouse = predictions["mouse"]["out"][:, index_mouse]
        
        if idx < len_human:
            p_human += pearson_corr_coef(pre_human, human_labels).cpu().item()
            pres_human.append(pre_human.cpu())
            target_human.append(human_labels.cpu())
        if idx < len_mouse:
            p_mouse += pearson_corr_coef(pre_mouse, mouse_labels).cpu().item()
            pres_mouse.append(pre_mouse.cpu())
            target_mouse.append(mouse_labels.cpu())
        if is_species:
            gates_human += np.array([tensor.cpu().numpy() for tensor in predictions["human"]["species"]["gates"]])
            gates_mouse += np.array([tensor.cpu().numpy() for tensor in predictions["mouse"]["species"]["gates"]])
        if is_tracks:
            tracks_gates = {f"{species} {key}":  value for species in ["human", "mouse"] for key, value in predictions[species]["tracks"]["gates"].items()}
            tracks_gates = {key: value.cpu().numpy() for key, value in tracks_gates.items()}
            tracks_gates = np.stack(list(tracks_gates.values()))
            total_gates += tracks_gates
            weights_human += predictions["human"]["tracks"]["weights"].cpu().numpy()
            weights_mouse += predictions["mouse"]["tracks"]["weights"].cpu().numpy()

p_human /= len_human
p_mouse /= len_mouse
print(f"Human: {p_human}")
print(f"Mouse: {p_mouse}")
if is_species:
    gates_human /= gates_human.sum(axis=1, keepdims=True)
    gates_mouse /= gates_mouse.sum(axis=1, keepdims=True)
    species_gates = np.stack((gates_human, gates_mouse), axis=1)
    np.save("./temp/species_gates.npy", species_gates)
if is_tracks:
    np.save("./temp/tracks_gates.npy", total_gates)

pres_human = torch.stack(pres_human)
target_human = torch.stack(target_human)
pres_mouse = torch.stack(pres_mouse)
target_mouse = torch.stack(target_mouse)
torch.save(pres_human, "./temp/pres_human.pt")
torch.save(target_human, "./temp/target_human.pt")
torch.save(pres_mouse, "./temp/pres_mouse.pt")
torch.save(target_mouse, "./temp/target_mouse.pt")
print(pres_human.shape, target_human.shape, pres_mouse.shape, target_mouse.shape)


Human: 1937, Mouse: 2017


100%|██████████| 2017/2017 [12:08<00:00,  2.77it/s]


Human: 0.6062219410165414
Mouse: 0.7155115181523855
torch.Size([1937, 896, 5313]) torch.Size([1937, 896, 5313]) torch.Size([2017, 896, 1643]) torch.Size([2017, 896, 1643])


In [None]:
import torch

pres_human = torch.load("./temp/pres_human.pt")
pres_mouse = torch.load("./temp/pres_mouse.pt")
target_human = torch.load("./temp/target_human.pt")
target_mouse = torch.load("./temp/target_mouse.pt")
print(pres_human.shape, pres_mouse.shape)

In [6]:
def pearsonr(x, y):
    # 计算均值
    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)

    # 中心化
    x_centered = x - x_mean
    y_centered = y - y_mean

    # 计算协方差
    covariance = (x_centered * y_centered).sum(dim=0)

    # 计算标准差
    x_std = torch.sqrt((x_centered**2).sum(dim=0))
    y_std = torch.sqrt((y_centered**2).sum(dim=0))

    # 计算 Pearson 相关性
    corr = covariance / (x_std * y_std)

    return corr

pre_human_flatten = pres_human.view(-1, 5313)
targets_human_flatten = target_human.view(-1, 5313)
tracks_pre_human = {key: pre_human_flatten[:, start:end] for key, (start, end) in track_types_human.items()}
tracks_target_human = {key: targets_human_flatten[:, start:end] for key, (start, end) in track_types_human.items()}

pre_mouse_flatten = pres_mouse.view(-1, 1643)
targets_mouse_flatten = target_mouse.view(-1, 1643)
tracks_pre_mouse = {key: pre_mouse_flatten[:, start:end] for key, (start, end) in track_types_mouse.items()}
tracks_target_mouse = {key: targets_mouse_flatten[:, start:end] for key, (start, end) in track_types_mouse.items()}

corr_human = {}
for track in track_types_human.keys():
    p, l = tracks_pre_human[track], tracks_target_human[track]
    if track == "CAGE":
        p, l = torch.log(p + 1), torch.log(l + 1)
    corr_human[track] = pearsonr(p, l)
corr_mouse = {}  
for track in track_types_mouse.keys():
    p, l = tracks_pre_mouse[track], tracks_target_mouse[track]
    if track == "CAGE":
        p, l = torch.log(p + 1), torch.log(l + 1)
    corr_mouse[track] = pearsonr(p, l)

print("Human:")
for key, value in corr_human.items():
    print(f"{key}: {value.mean()}")

print("Mouse:")
for key, value in corr_mouse.items():
    print(f"{key}: {value.mean()}")


Human:
DNASE/ATAC: 0.8182075023651123
TF ChIP-seq: 0.5456007719039917
Histone ChIP-seq: 0.6641464829444885
CAGE: 0.6319071054458618
Mouse:
DNASE/ATAC: 0.7835639119148254
TF ChIP-seq: 0.5928989052772522
Histone ChIP-seq: 0.7853987812995911
CAGE: 0.6393424868583679


In [None]:
import pickle
with open("./temp/corr_human_space.pkl", "wb") as f:
    pickle.dump(corr_human, f)

with open("./temp/corr_mouse_space.pkl", "wb") as f:
    pickle.dump(corr_mouse, f)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
species_gates = np.load("./temp/species_gates.npy")
is_species = True
show_gates = 100 * species_gates
plt.figure(figsize=(4, 2))
sns.heatmap(
    show_gates[-1],
    cmap="Blues",
    annot=True,
    fmt=".2f",
    cbar=False,
    yticklabels=["Human", "Mouse"],
)

plt.savefig("./temp/species_gates.png", dpi=300)

In [None]:
if is_species:
    # 绘制SpeciesMoE门控网络图
    gates_path = "./version/5/4.28000.png"
    show_gates = 100 * species_gates

    fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(10, 8)) 
    axes = axes.flatten()

    for i in range(11):  
        sns.heatmap(show_gates[i], ax=axes[i], cmap="Blues", annot=True, fmt=".2f", cbar=False, yticklabels=["Human", "Mouse"])
        axes[i].set_title(f"Transformer Block {i+1}")
        axes[i].set_xlabel("Experts")
        axes[i].set_ylabel("Species")

    # 隐藏多余的子图
    for j in range(11, 12): 
        axes[j].axis("off")

    plt.tight_layout()
    plt.savefig(gates_path)
    plt.show()

In [None]:
tracks_gates = np.load("./temp/tracks_gates.npy")
is_tracks = True
if is_tracks:
    temp1, temp2 =tracks_gates[0:4], tracks_gates[4:8]
    tracks_gates = temp1 + temp2
    tracks_gates = tracks_gates / tracks_gates.sum(axis=1, keepdims=True)

    # 绘制门控网络图
    gates_path = "./version/5/4.28000_tracks.png"
    show_gates = 100 * tracks_gates

    plt.figure(figsize=(8, 4))
    labels = [
        "DNASE/ATAC",
        "TF ChIP-seq",
        "Histone ChIP-seq",
        "CAGE",
    ]
    sns.heatmap(show_gates, annot=True, fmt=".2f", cmap="Blues", linewidths=0.5, cbar=False, yticklabels=labels)

    plt.savefig("./temp/tracks_gate.png", dpi=500, bbox_inches="tight")
    # plt.savefig(gates_path)
    plt.show()