In [1]:
import os
import time
import torch
import argparse
import numpy as np
import pandas as pd
from Bio import SeqIO
from torch import Tensor, nn
from datetime import datetime
import matplotlib.pyplot as plt
from torchinfo import summary  # 需要安装 torchinfo 包
from torch.utils.data import random_split, DataLoader, Dataset
from scipy.stats import spearmanr
import warnings

import torch.optim

# Set CUDA_VISIBLE_DEVICES to "0" for this process. This runs after `import torch`
# above and before the first CUDA query in the visible code.
# NOTE(review): presumably this is meant to restrict the run to the first GPU and
# must take effect before CUDA is initialised — confirm the ordering is sufficient.
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

# Suppress every Python warning for the rest of the session (no category filter is given).
warnings.filterwarnings("ignore")

import grelu.lightning
import grelu.data
import grelu.model.models
from grelu.sequence.format import (
    INDEX_TO_BASE_HASH,
    indices_to_one_hot,
    strings_to_indices,
)

In [2]:
class GeneExpressionDataset_self(Dataset):
    """Map-style dataset that pairs sequence strings with their labels.

    Items are encoded lazily on access: the sequence at the requested index
    is passed through ``strings_to_indices`` and the result through
    ``indices_to_one_hot``. Labels are returned exactly as stored.
    """

    def __init__(self, sequences, labels):
        # Both containers are kept by reference; nothing is copied or validated here.
        self.sequences, self.labels = sequences, labels

    def __len__(self):
        # The sample count is taken from the sequence container only.
        return len(self.sequences)

    def __getitem__(self, idx):
        # string -> base indices -> one-hot encoding
        base_indices = strings_to_indices(self.sequences[idx])
        encoded_sequence = indices_to_one_hot(base_indices)
        return encoded_sequence, self.labels[idx]

# Use the GPU when CUDA reports itself as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Run configuration.
# expression_type selects which expression file is loaded further down and is
# also passed to the model constructor.
expression_type = "power"
# folder_path / save_dir / model_file are joined to locate the checkpoint, and
# folder_path / save_dir also hosts the "figure" output directory.
folder_path = "code/model/CNN_Model"
save_dir = "2025_2_4_1_35_power"
model_file = "checkpoint.ckpt"

In [4]:
# Build a freshly initialised model on the selected device and print its layer summary.
model = grelu.model.models.DNA_Conv1D_Model(expression_type).to(device)

# `summary` is already imported with the other dependencies in the first cell,
# so it is not re-imported here.
# The summary is computed for one input of shape (1, 4, 196608).
# NOTE(review): presumably (batch, one-hot bases, sequence length) — confirm.
summary(model, input_size=(1, 4, 196608))

Layer (type:depth-idx)                   Output Shape              Param #
DNA_Conv1D_Model                         [1, 78]                   --
├─Conv1d: 1-1                            [1, 64, 196608]           3,904
├─BatchNorm1d: 1-2                       [1, 64, 196608]           128
├─MaxPool1d: 1-3                         [1, 64, 39321]            --
├─Dropout: 1-4                           [1, 64, 39321]            --
├─Conv1d: 1-5                            [1, 128, 39321]           123,008
├─BatchNorm1d: 1-6                       [1, 128, 39321]           256
├─MaxPool1d: 1-7                         [1, 128, 7864]            --
├─Dropout: 1-8                           [1, 128, 7864]            --
├─Conv1d: 1-9                            [1, 128, 7864]            245,888
├─BatchNorm1d: 1-10                      [1, 128, 7864]            256
├─MaxPool1d: 1-11                        [1, 128, 1572]            --
├─Dropout: 1-12                          [1, 128, 1572]            --

In [3]:

# Pick the expression file that matches the configured transformation.
if expression_type == "power":
    mmc2_gene_expression_file = "data/mmc2_78_gene_expression_power_transformed.npy"
elif expression_type == "identity":
    mmc2_gene_expression_file = "data/mmc2_78_gene_expression_identity_transformed.npy"
else:
    # Without this branch an unknown value would only surface later as a
    # NameError on `mmc2_gene_expression_file`.
    raise ValueError(
        f"Unsupported expression_type {expression_type!r}; expected 'power' or 'identity'."
    )

# Sequence and gene-location inputs do not depend on the transformation.
mmc2_gene_sequence_file = "data/mmc2_78_gene_sequence.fasta"
mmc2_gene_location_file = "data/mmc2_78_gene_location.csv"

# Keep only the genes whose 'len' value is strictly below the threshold.
threshold = 100_000
df_mmc2_gene_location = pd.read_csv(mmc2_gene_location_file)

# Values of the first column of the unfiltered table; later used to label the figures.
# NOTE(review): this is the gene-location table — confirm that its first column
# really holds one cell-type name per model output column.
cell_type_list = df_mmc2_gene_location.iloc[:, 0].values

# Non-numeric lengths become NaN; NaN never compares below the threshold,
# so those rows are dropped by the filter.
df_mmc2_gene_location['len'] = pd.to_numeric(df_mmc2_gene_location['len'], errors='coerce')
is_short_enough = df_mmc2_gene_location['len'].lt(threshold)
df_mmc2_gene_location = df_mmc2_gene_location.loc[is_short_enough]
selected_gene = df_mmc2_gene_location["gene_id"].tolist()

# Load the expression file and index it by gene.
# The code below treats the file as a single stored object with "genes" and
# "expression" entries, paired up position by position.
# NOTE: allow_pickle=True lets NumPy unpickle arbitrary Python objects, which can
# execute code on load — only use it with trusted files.
mmc2_gene_expression = np.load(mmc2_gene_expression_file, allow_pickle=True)
_expression_record = mmc2_gene_expression.item()
mmc2_gene_expression_dict = {
    gene_name: gene_expression
    for gene_name, gene_expression in zip(
        _expression_record["genes"], _expression_record["expression"]
    )
}

# Read every FASTA record into a {record id: sequence string} lookup.
# A record id that appears more than once keeps its last sequence.
mmc2_gene_sequence_dict = {
    record.id: str(record.seq)
    for record in SeqIO.parse(mmc2_gene_sequence_file, "fasta")
}

# Fail early, naming every selected gene that lacks a sequence or an expression
# entry, instead of stopping at the first bare KeyError inside the loop.
missing_genes = [
    candidate
    for candidate in selected_gene
    if candidate not in mmc2_gene_sequence_dict or candidate not in mmc2_gene_expression_dict
]
if missing_genes:
    raise KeyError(f"Selected genes without sequence or expression data: {missing_genes}")

# Column-oriented container; the three lists stay aligned index by index.
dataset_dict = {
                'gene': [],
                'sequence': [],
                'expression': []
            }

for gene in selected_gene:
    # Look both values up before appending so the lists never get out of step.
    sequence = mmc2_gene_sequence_dict[gene]
    expression = mmc2_gene_expression_dict[gene]

    dataset_dict['gene'].append(gene)
    dataset_dict['sequence'].append(sequence)
    dataset_dict['expression'].append(expression)

# Build the frame only after the lists are filled; it was previously created
# from the still-empty dict and therefore never contained any rows.
dataset_df = pd.DataFrame(dataset_dict)

# Evaluation loader over every selected gene.
# The loader is only used for inference here, so shuffling is disabled: batches
# then follow the order of dataset_dict['gene'], which keeps the collected
# outputs traceable to their genes and identical from run to run.
total_dataset = GeneExpressionDataset_self(dataset_dict['sequence'], dataset_dict['expression'])
total_loader = DataLoader(total_dataset, batch_size=32, shuffle=False)

def calculate_spearman(real_activate, fake_activate):
    """Compute one Spearman correlation per output column.

    Args:
        real_activate: Sequence of 1-D tensors, one per sample, holding the
            reference values. They are stacked into a (samples, columns) matrix.
        fake_activate: Sequence of 1-D tensors of the same layout holding the
            values to compare against the reference.

    Returns:
        A list with one correlation coefficient per column, in column order.
        The p-values returned by ``spearmanr`` are discarded.

    Raises:
        RuntimeError: If either sequence is empty or its tensors cannot be
            stacked (raised by ``torch.stack``).
    """
    # detach() and cpu() make the NumPy conversion safe for tensors that still
    # track gradients or live on an accelerator; both are no-ops otherwise.
    real_matrix = torch.stack(list(real_activate)).detach().cpu().numpy()
    fake_matrix = torch.stack(list(fake_activate)).detach().cpu().numpy()

    spearman_coefficients = []
    for col in range(real_matrix.shape[1]):
        coefficient, _ = spearmanr(real_matrix[:, col], fake_matrix[:, col])
        spearman_coefficients.append(coefficient)

    return spearman_coefficients


def Measure_Metric(model, total_loader):
    """Run the model over a loader and score its outputs per column.

    The model is switched to evaluation mode (and left in it) and the loop
    runs with gradient tracking disabled.

    Args:
        model: Module called as ``model(inputs)`` on each batch.
        total_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        A tuple ``(spearman_coefficients, real_activate, fake_activate)``:
        the per-column coefficients from ``calculate_spearman``, the list of
        per-sample targets, and the list of per-sample model outputs moved to
        the CPU. Both lists follow the loader's iteration order.
    """
    model.eval()

    # Send inputs to wherever the model's weights live instead of relying on a
    # notebook-level global; a parameterless model falls back to the CPU.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    real_activate = []
    fake_activate = []
    with torch.no_grad():
        for inputs, targets in total_loader:
            # as_tensor accepts tensors and array-likes alike and avoids the
            # extra copy (and warning) that torch.tensor(existing_tensor) causes.
            inputs = torch.as_tensor(inputs, dtype=torch.float32).to(target_device)
            outputs = model(inputs)
            # Iterating a batch yields one row per sample.
            real_activate.extend(targets)
            fake_activate.extend(outputs.to('cpu'))

    spearman_coefficients = calculate_spearman(real_activate, fake_activate)

    return spearman_coefficients, real_activate, fake_activate

In [4]:
# Rebuild the architecture, restore the trained weights and score the full dataset.
model = grelu.model.models.DNA_Conv1D_Model(expression_type).to(device)

checkpoint_path = os.path.join(folder_path, save_dir, model_file)
# map_location remaps the stored tensors onto the device in use, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

spearmanr_coefficients, real_activate, fake_activate = Measure_Metric(model, total_loader)

def plt_figure(save_figure_path, cell_type, real_col, fake_col, spearmanr, axis_limits=(0, 200)):
    """Save a scatter plot of real versus predicted values for one column.

    Args:
        save_figure_path: Output directory; created if it does not exist.
        cell_type: Label used in the plot title and in the file name.
        real_col: Values plotted on the x axis ("real").
        fake_col: Values plotted on the y axis ("fake").
        spearmanr: Correlation coefficient shown in the title and file name,
            formatted with five decimals. Inside this function the name refers
            to the number, not to the scipy function of the same name.
        axis_limits: ``(low, high)`` range applied to both axes. Defaults to
            ``(0, 200)``; pass ``None`` to keep matplotlib's automatic limits.

    The image is written to ``{save_figure_path}/{cell_type}_{spearmanr:.5f}.png``.
    """
    os.makedirs(save_figure_path, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        ax.scatter(real_col, fake_col, color="blue", alpha=0.5)

        ax.set_title(f"{cell_type}_{spearmanr:.5f}", fontsize=16)
        ax.set_xlabel("real", fontsize=14)
        ax.set_ylabel("fake", fontsize=14)
        if axis_limits is not None:
            ax.set_xlim(*axis_limits)
            ax.set_ylim(*axis_limits)

        fig.savefig(f'{save_figure_path}/{cell_type}_{spearmanr:.5f}.png')
    finally:
        # Close this figure even when drawing or saving fails, so figures do
        # not pile up in memory when the function is called in a loop.
        plt.close(fig)


# One scatter plot per output column, written below the checkpoint directory.
real_tensor = torch.stack(real_activate)
fake_tensor = torch.stack(fake_activate)
figure_dir = os.path.join(folder_path, save_dir, "figure")

for index in range(real_tensor.shape[1]):
    real_col = real_tensor[:, index].numpy()
    fake_col = fake_tensor[:, index].numpy()

    # Deliberately not named `spearmanr`: a module-level variable of that name
    # would overwrite the scipy function imported in the first cell and break
    # any later call to calculate_spearman.
    coefficient = spearmanr_coefficients[index]
    # NOTE(review): assumes cell_type_list has one entry per output column, in
    # the same order as the model outputs — confirm against the location file.
    cell_type = cell_type_list[index]
    plt_figure(figure_dir, cell_type, real_col, fake_col, coefficient)
