In [None]:
import torch
import torch.nn.functional as F
import os
from tqdm import tqdm
import numpy as np
import seaborn as sns
import kipoiseq
import matplotlib.pyplot as plt
import pandas as pd


from dataloaders.h5dataset import GEPBedDataset, MultiSpeciesDataset
from model.modeling_enformer import Enformer, EnformerConfig, from_pretrained

device = torch.device("cuda:0")

human_test_data_path: str = "/home/jiwei_zhu/disk/Enformer/Data/human_test.h5"
human_test_bed_path: str = "/home/jiwei_zhu/disk/Enformer/Data/human_test.bed"
human_genome_path: str = "/home/jiwei_zhu/disk/Enformer/Data/hg38.ml.fa"

mouse_test_data_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mouse_test.h5"
mouse_test_bed_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mouse_test.bed"
mouse_genome_path: str = "/home/jiwei_zhu/disk/Enformer/Data/mm10.fa"

df_human = pd.read_csv(
    "/home/jiwei_zhu/disk/Enformer/enformer_MoE/targets_human_sorted.txt", sep="\t"
)
df_mouse = pd.read_csv(
    "/home/jiwei_zhu/disk/Enformer/enformer_MoE/targets_mouse_sorted.txt", sep="\t"
)

index_human = list(df_human["index"])
track_types_human = {
    "DNASE/ATAC": (0, 684),
    "TF ChIP-seq": (684, 2573),
    "Histone ChIP-seq": (2573, 4675),
    "CAGE": (4675, 5313),
}

index_mouse = list(df_mouse["index"])
track_types_mouse = {
    "DNASE/ATAC": (0, 228),
    "TF ChIP-seq": (228, 519),
    "Histone ChIP-seq": (519, 1286),
    "CAGE": (1286, 1643),
}


human_genome_dict, mouse_genome_dict = {}, {}
test_dataset = MultiSpeciesDataset(
    file_paths=[human_test_data_path, mouse_test_data_path],
    bed_paths=[human_test_bed_path, mouse_test_bed_path],
    seqlen=196608,
    genome_dicts=[human_genome_dict, mouse_genome_dict],
    shift_aug=False,
    rc_aug=False,
)

load preprocess data
load preprocess data


In [3]:
model_path = "/home/jiwei_zhu/disk/Enformer/enformer_ckpt"
model = from_pretrained(model_path, use_tf_gamma=False)
model.to(device)


def pearson_corr_coef(x, y, dim=1, reduce_dims=(-1,)):
    x_centered = x - x.mean(dim=dim, keepdim=True)
    y_centered = y - y.mean(dim=dim, keepdim=True)
    return F.cosine_similarity(x_centered, y_centered, dim=dim).mean(dim=reduce_dims)

In [None]:
pre_humans, pre_mouses = [], []
targets_human, targets_mouse = [], []
p_human_total, p_mouse_total = 0.0, 0.0

len_human, len_mouse = len(test_dataset.human_dataset), len(test_dataset.mouse_dataset)
model.eval()
for idx in tqdm(range(len(test_dataset))):
    data = test_dataset[idx]
    human_x = torch.tensor(data["human_x"], dtype=torch.float32).to(device)
    mouse_x = torch.tensor(data["mouse_x"], dtype=torch.float32).to(device)
    human_labels = torch.tensor(data["human_labels"][:, index_human], dtype=torch.float32).to(device)
    mouse_labels = torch.tensor(data["mouse_labels"][:, index_mouse], dtype=torch.float32).to(device)

    with torch.no_grad():
        if idx < len_human:
            pre_human = model(human_x, head="human")[:, index_human]
            pre_humans.append(pre_human.cpu())
            targets_human.append(human_labels.cpu())
            p_human_total += pearson_corr_coef(pre_human, human_labels).cpu().item()
        if idx < len_mouse:
            pre_mouse = model(mouse_x, head="mouse")[:, index_mouse]
            pre_mouses.append(pre_mouse.cpu())
            targets_mouse.append(mouse_labels.cpu())
            p_mouse_total += pearson_corr_coef(pre_mouse, mouse_labels).cpu().item()

pre_humans = torch.stack(pre_humans)
pre_mouses = torch.stack(pre_mouses)
print(pre_humans.shape, pre_mouses.shape)
torch.save(pre_humans, "./temp/pre_human_baseline.pt")
torch.save(pre_mouses, "./temp/pre_mouse_baseline.pt")

targets_human = torch.stack(targets_human)
targets_mouse = torch.stack(targets_mouse)
print(targets_human.shape, targets_mouse.shape)
torch.save(targets_human, "./temp/target_human_baseline.pt")
torch.save(targets_mouse, "./temp/target_mouse_baseline.pt")


p_human_total /= len_human
p_mouse_total /= len_mouse
print(f"human: {p_human_total}")
print(f"mouse: {p_mouse_total}")

100%|██████████| 2017/2017 [53:09<00:00,  1.58s/it]


torch.Size([1937, 896, 5313]) torch.Size([2017, 896, 1643])
torch.Size([1937, 896, 5313]) torch.Size([2017, 896, 1643])
human: 0.6163953073134297
mouse: 0.6912565573820928


In [None]:
import torch

pre_humans = torch.load("./temp/pres_human_baseline.pt")
pre_mouses = torch.load("./temp/pres_mouse_baseline.pt")
targets_human = torch.load("./temp/target_human_baseline.pt")
targets_mouse = torch.load("./temp/target_mouse_baseline.pt")
print(pre_humans.shape)
print(pre_mouses.shape)
print(targets_human.shape)
print(targets_mouse.shape)

torch.Size([1937, 896, 5313])
torch.Size([2017, 896, 1643])
torch.Size([1937, 896, 5313])
torch.Size([2017, 896, 1643])


In [12]:
def pearsonr(x, y):
    # 计算均值
    x_mean = x.mean(dim=0, keepdim=True)
    y_mean = y.mean(dim=0, keepdim=True)

    # 中心化
    x_centered = x - x_mean
    y_centered = y - y_mean

    # 计算协方差
    covariance = (x_centered * y_centered).sum(dim=0)

    # 计算标准差
    x_std = torch.sqrt((x_centered**2).sum(dim=0))
    y_std = torch.sqrt((y_centered**2).sum(dim=0))

    # 计算 Pearson 相关性
    corr = covariance / (x_std * y_std)

    return corr

pre_human_flatten = pre_humans[:1937].view(-1, 5313)
targets_human_flatten = targets_human[:1937].view(-1, 5313)
tracks_pre_human = {key: pre_human_flatten[:, start:end] for key, (start, end) in track_types_human.items()}
tracks_target_human = {key: targets_human_flatten[:, start:end] for key, (start, end) in track_types_human.items()}

pre_mouse_flatten = pre_mouses.view(-1, 1643)
targets_mouse_flatten = targets_mouse.view(-1, 1643)
tracks_pre_mouse = {key: pre_mouse_flatten[:, start:end] for key, (start, end) in track_types_mouse.items()}
tracks_target_mouse = {key: targets_mouse_flatten[:, start:end] for key, (start, end) in track_types_mouse.items()}

corr_human = {}
for track in track_types_human.keys():
    p, l = tracks_pre_human[track], tracks_target_human[track]
    if track == "CAGE":
        p, l = torch.log(p + 1), torch.log(l + 1)
    corr_human[track] = pearsonr(p, l)
corr_mouse = {}  
for track in track_types_mouse.keys():
    p, l = tracks_pre_mouse[track], tracks_target_mouse[track]
    if track == "CAGE":
        p, l = torch.log(p + 1), torch.log(l + 1)
    corr_mouse[track] = pearsonr(p, l)

In [None]:
import pickle
with open('./temp/corr_human_baseline.pkl', 'wb') as f:
    pickle.dump(corr_human, f)

with open('./temp/corr_mouse_baseline.pkl', 'wb') as f:
    pickle.dump(corr_mouse, f)

In [10]:
# corr_human["DNASE/ATAC"] = torch.cat((corr_human["DNASE"], corr_human["ATAC"]), dim=0)
print("Human:")
for key, value in corr_human.items():
    print(f"{key}: {value.shape}")

# corr_mouse["DNASE/ATAC"] = torch.cat((corr_mouse["DNASE"], corr_mouse["ATAC"]), dim=0)
print("Mouse:")
for key, value in corr_mouse.items():
    print(f"{key}: {value.shape}")

Human:
DNASE: torch.Size([674])
ATAC: torch.Size([10])
TF ChIP-seq: torch.Size([1889])
Histone ChIP-seq: torch.Size([2102])
CAGE: torch.Size([638])
Mouse:
DNASE: torch.Size([101])
ATAC: torch.Size([127])
TF ChIP-seq: torch.Size([291])
Histone ChIP-seq: torch.Size([767])
CAGE: torch.Size([357])


In [None]:
print(len(test_dataset.human_dataset))