In [None]:
import os
import logging
import pandas as pd
import torch
from transformers import BertConfig, AutoModel, AutoTokenizer
import numpy as np
import json

# Load the run configuration (working directory and input file locations).
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

path = config['working_dir']

# Every file this notebook writes (embeddings and log) goes under this directory.
output_dir = os.path.join(path, 'output_bert_no_dim_new_new2')
print('output_dir:', output_dir)
os.makedirs(output_dir, exist_ok=True)

# Base name (without extension) shared by the output files.
output_file_name = "dna_embeddings"
print('output_file_name:', output_file_name)

# Send log records to a file inside the output directory.
log_file_path = os.path.join(output_dir, 'dna_embedding_generation.log')
logging.basicConfig(
    filename=log_file_path,
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger()

# Tab-separated node table that is read later in the notebook.
nodes_file_path = config['nodes_file_path']
print('nodes_file_path:', nodes_file_path)

## Generate DNABERT embeddings

In [None]:
# Load the pretrained DNABERT-2 model and its tokenizer.
dnabert_checkpoint = "zhihan1996/DNABERT-2-117M"

# The model configuration gets its own name so that the `config` dict loaded
# from config.json is not overwritten by an object of a different kind.
bert_config = BertConfig.from_pretrained(dnabert_checkpoint)

# trust_remote_code=True lets transformers run the custom model/tokenizer code
# shipped with this checkpoint.
dnabert_model = AutoModel.from_pretrained(dnabert_checkpoint, trust_remote_code=True, config=bert_config)
dnabert_tokenizer = AutoTokenizer.from_pretrained(dnabert_checkpoint, trust_remote_code=True)

In [None]:
def calculate_embedding(seq):
    """Return the mean-pooled DNABERT embedding of a single sequence chunk.

    Parameters
    ----------
    seq : str or missing value
        Sequence text passed to ``dnabert_tokenizer``. A missing value
        (None / NaN) is skipped.

    Returns
    -------
    list of float or None
        The model's hidden states averaged over the token axis, as a plain
        Python list. None when ``seq`` is missing or when tokenization or the
        model call raises (the error is logged, not propagated).
    """
    if pd.isna(seq):
        return None
    try:
        inputs = dnabert_tokenizer(seq, return_tensors='pt')["input_ids"]

        # Inference only: disabling gradient tracking avoids building an
        # autograd graph for every chunk, without changing the output values.
        with torch.no_grad():
            hidden_states = dnabert_model(inputs)[0]  # [1, sequence_length, 768]

        # embedding with mean pooling
        embedding_mean = torch.mean(hidden_states[0], dim=0)

        return embedding_mean.detach().cpu().numpy().tolist()
    except Exception as e:
        logger.error(f"Error: {e}")
        return None

In [None]:
def calculate_final_embedding(sequences):
    """Average the embeddings of all chunks belonging to one sequence.

    Falsy chunks are skipped. Chunks for which ``calculate_embedding`` returns
    None, or raises (the error is logged), contribute nothing.

    Returns the element-wise mean of the collected chunk embeddings as a list,
    or None when no chunk produced an embedding.
    """
    chunk_vectors = []

    for chunk in sequences:
        if chunk:
            try:
                vector = calculate_embedding(chunk)
            except Exception as e:
                logger.error(f"Error: {chunk}: {e}")
            else:
                if vector is not None:
                    chunk_vectors.append(vector)

    if len(chunk_vectors) == 0:
        return None

    # Element-wise mean across chunks, converted back to a plain list.
    return np.mean(chunk_vectors, axis=0).tolist()

In [None]:
# Load the node table and report how many distinct names each node type has.
df = pd.read_csv(nodes_file_path, sep='\t')
len0 = df.groupby('type')['name'].nunique()
print("\nCount of names by type :",len0)

# Keep only the node types that are embedded with DNABERT. The explicit
# .copy() makes df_dna an independent frame, so the column assignments below
# neither raise SettingWithCopyWarning nor depend on view/copy semantics.
df_dna = df[df["type"].isin(["Gene", "miRNA"])].copy()

# Replace every 'U' with 'T' (literal replacement, not a regex).
df_dna['Sequence'] = df_dna['Sequence'].str.replace('U', 'T', regex=False)

# Sequence length; rows without a string sequence get 0.
df_dna['len_seq'] = df_dna['Sequence'].apply(lambda x: len(x) if isinstance(x, str) else 0)

# Split each sequence into consecutive chunks of at most chunk_size characters;
# rows without a string sequence get an empty list.
chunk_size = 512
df_dna['seq_list'] = df_dna['Sequence'].apply(
    lambda seq: [seq[i:i + chunk_size] for i in range(0, len(seq), chunk_size)]
    if isinstance(seq, str) else []
)
df_dna['seq_list_len'] = df_dna['seq_list'].apply(len)

print(df_dna.shape)
df_dna.head()

In [None]:
# Distinct names per node type among the rows whose sequence length is 0.
rows_without_sequence = df_dna.loc[df_dna['len_seq'] == 0]
len0 = rows_without_sequence.groupby('type')['name'].nunique()
print("\nCount of names with seq_len=0 by type :", len0)

In [None]:
output_file_path = os.path.join(output_dir, f'{output_file_name}.tsv')
output_file_path_discarded = os.path.join(output_dir, f"{output_file_name}_discarded.tsv")


def _append_rows(file_path, rows):
    """Append each tuple in ``rows`` to ``file_path`` as one tab-separated line."""
    with open(file_path, 'a') as f:
        for row in rows:
            f.write("\t".join(str(value) for value in row) + "\n")


# Resume support: the number of rows already in the output file is the number
# of df_dna rows (in order) that were processed by a previous run.
if os.path.exists(output_file_path):
    df_existing = pd.read_csv(output_file_path, sep='\t')
    processed_count = df_existing.shape[0]
else:
    with open(output_file_path, 'w') as f:
        f.write("name\ttype\tlen_seq\tembedding\n")
    processed_count = 0

# Rows are buffered and appended to the output file batch_size at a time.
batch_size = 5
embeddings_batch = []
total_rows = df_dna.shape[0]

for _, row in df_dna.iloc[processed_count:].iterrows():
    sequence_name = row['name']
    sequence_type = row['type']
    sequence_list = row['seq_list']
    sequence_len = row['len_seq']

    embedding = calculate_final_embedding(sequence_list)

    if embedding is None:
        # Record rows that produced no embedding in a separate file.
        # NOTE(review): this header declares an 'embedding' column, but the
        # rows appended below only have three fields - confirm which is intended.
        if not os.path.exists(output_file_path_discarded):
            with open(output_file_path_discarded, 'w') as f:
                f.write("name\ttype\tlen_seq\tembedding\n")
        _append_rows(output_file_path_discarded, [(sequence_name, sequence_type, sequence_len)])

    # Rows without an embedding are still written to the main file (as "None"),
    # which keeps its row count aligned with df_dna for the resume logic above.
    embeddings_batch.append((sequence_name, sequence_type, sequence_len, embedding))
    processed_count += 1

    # An f-string is used instead of '+' concatenation so that a non-string
    # name (e.g. NaN or a number) cannot raise TypeError.
    progress_message = f"{processed_count}/{total_rows} name={sequence_name}"
    print(progress_message)
    logger.info(progress_message)

    if len(embeddings_batch) >= batch_size:
        _append_rows(output_file_path, embeddings_batch)
        embeddings_batch = []
        # torch.cuda.empty_cache()

# Write whatever is left in the last, partially filled batch.
if embeddings_batch:
    _append_rows(output_file_path, embeddings_batch)

## Fill missing embeddings

In [None]:
import pandas as pd
import numpy as np

def replace_null_embeddings_with_type_mean(df):
    """Fill missing embeddings with the mean embedding of the row's ``type``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``type`` column and an ``embedding`` column. An
        embedding may be a list, a NumPy array, the string form of a Python
        list (as read back from a TSV file), or a missing value (None / NaN).

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with its ``embedding`` column replaced:
        list/array embeddings become plain lists, and missing ones are
        replaced by the element-wise mean of the available embeddings of the
        same ``type``. A missing embedding whose type has no available
        embedding stays NaN.

    Raises
    ------
    ValueError, SyntaxError
        If a string embedding is not a valid Python literal.
    """
    import ast  # local import: only needed to parse embeddings stored as text

    def parse(value):
        # Every value is checked individually; deciding from the first row
        # alone would leave all strings unparsed when that row is missing.
        if isinstance(value, str):
            # literal_eval only accepts Python literals, so text read from a
            # file cannot execute code the way eval() would.
            value = ast.literal_eval(value)
        if isinstance(value, list):
            return np.array(value)
        return value

    def is_missing(value):
        if value is None:
            return True
        try:
            return bool(pd.isna(value))
        except (TypeError, ValueError):
            # pd.isna returned something that is not a single truth value.
            return False

    node_types = list(df['type'])
    parsed = [parse(value) for value in df['embedding']]

    # Collect the available embeddings of each type (rows with a missing type
    # are left out), then average them element-wise.
    vectors_by_type = {}
    for node_type, vector in zip(node_types, parsed):
        if isinstance(vector, np.ndarray) and not is_missing(node_type):
            vectors_by_type.setdefault(node_type, []).append(vector)
    type_mean_embeddings = {
        node_type: np.mean(np.stack(vectors), axis=0)
        for node_type, vectors in vectors_by_type.items()
    }

    def resolve(node_type, vector):
        if isinstance(vector, np.ndarray):
            return vector.tolist()
        if is_missing(vector):
            mean_vector = type_mean_embeddings.get(node_type)
            # No embedding of this type is available: keep the value missing
            # instead of failing on a float that has no .tolist().
            return np.nan if mean_vector is None else mean_vector.tolist()
        return vector

    df['embedding'] = pd.Series(
        [resolve(node_type, vector) for node_type, vector in zip(node_types, parsed)],
        index=df.index,
        dtype=object,
    )

    return df

In [None]:
# Reload the embeddings written to disk and fill the missing ones per type.
embeddings_path = os.path.join(output_dir, f"{output_file_name}.tsv")
df_output = pd.read_csv(embeddings_path, sep='\t')
df_output = replace_null_embeddings_with_type_mean(df_output)

# Show any rows whose embedding is still missing after the fill.
still_missing = df_output['embedding'].isna()
print(df_output[still_missing])

# Save the filled table next to the original output.
filled_path = os.path.join(output_dir, f"{output_file_name}_filled.tsv")
df_output.to_csv(filled_path, sep='\t', index=False)