In [None]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig
import numpy as np
import os
import logging
import json


# Load run configuration; `working_dir` is the root under which this notebook
# writes all of its outputs.
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

path = config['working_dir']

output_dir = os.path.join(path, 'output_bert_no_dim_new_new3')
print('output_dir: ', output_dir)
os.makedirs(output_dir, exist_ok=True)

# logging.basicConfig() is a silent no-op when the root logger already has
# handlers (common inside Jupyter kernels), in which case the log file would
# never be created. Attach a FileHandler to a dedicated logger instead; the
# guard keeps re-runs of this cell from stacking duplicate handlers (which
# would write every message several times).
log_file_path = os.path.abspath(os.path.join(output_dir, 'txt_embedding_generation.log'))
logger = logging.getLogger('txt_embedding_generation')
logger.setLevel(logging.INFO)
if not any(
    isinstance(handler, logging.FileHandler) and handler.baseFilename == log_file_path
    for handler in logger.handlers
):
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)

In [None]:
# Hugging Face Hub id of the BioBERT checkpoint used for text embeddings.
BIOBERT_MODEL_NAME = "dmis-lab/biobert-base-cased-v1.2"

file_path = config['nodes_file_path']
print('nodes_file_path:', file_path)

# Node table: one row per node; only these four columns are loaded.
df = pd.read_csv(file_path, sep="\t", usecols=["name", "type", "Description", "Sequence"])

# BioBERT is a standard BERT checkpoint, so no custom modelling code is needed.
# trust_remote_code=True would let the Hub repository execute arbitrary Python
# on this machine, so it is deliberately not passed.
biobert_tokenizer = AutoTokenizer.from_pretrained(BIOBERT_MODEL_NAME)
biobert_model = AutoModel.from_pretrained(BIOBERT_MODEL_NAME)
biobert_model.eval()  # inference only: make sure dropout is disabled

# Placeholder column filled in later with per-row embedding lists.
df["embedding"] = None

In [None]:

def get_text_embedding(text):
    if pd.isna(text):  
        return None
    try:
        inputs = biobert_tokenizer(text, return_tensors='pt')['input_ids']
        hidden_states = biobert_model(inputs)[0]  # [1, sequence_length, 768]

        embedding_mean = torch.mean(hidden_states[0], dim=0)  # Mean pooling

        return embedding_mean.detach().cpu().numpy().tolist() 
    except Exception as e:
        logger.error(f"Error: len {len(text)}  {text}. \nErrore: {e}")
        return None

def save_batch(df, file_path):
    """Append the 'name', 'type' and 'embedding' columns of `df` to a TSV file.

    The header row is written only when the file does not exist yet (or is
    empty), so repeated calls accumulate rows in call order. Existing rows are
    never re-read or re-written: re-parsing would be O(n) per call and would let
    pandas re-type values (e.g. a name "007" would be rewritten as "7").

    Args:
        df: DataFrame containing at least 'name', 'type' and 'embedding'.
        file_path: Destination TSV path; missing parent directories are created.
    """
    directory = os.path.dirname(file_path)
    if directory:  # a bare file name has no directory to create
        os.makedirs(directory, exist_ok=True)

    write_header = not os.path.exists(file_path) or os.path.getsize(file_path) == 0

    # newline="" lets pandas control line endings when writing to a handle.
    with open(file_path, "a", newline="", encoding="utf-8") as handle:
        df[["name", "type", "embedding"]].to_csv(
            handle, sep="\t", index=False, header=write_header
        )
        # Make the checkpoint durable; portable, unlike os.sync() (POSIX-only).
        handle.flush()
        os.fsync(handle.fileno())

def load_processed_indices(file_path):
    """Return the set of integer row positions recorded in a progress file.

    The file holds one integer per line (blank lines are ignored). A missing
    file means nothing has been processed yet.

    Args:
        file_path: Path of the newline-separated index file.

    Returns:
        set[int] of the recorded positions.
    """
    if not os.path.exists(file_path):
        return set()
    # Read the stored numbers themselves; parsing the file as a TSV would treat
    # the first number as a header and only yield the frame's RangeIndex.
    with open(file_path, "r", encoding="utf-8") as handle:
        return {int(stripped) for stripped in (line.strip() for line in handle) if stripped}

In [None]:

# Artefacts of the text-embedding stage, all stored under output_dir:
#   text_embeddings.tsv         per-row embeddings, appended batch by batch
#   text_embeddings_filled.tsv  the same rows with missing embeddings filled in
#   processed_text_indices.txt  progress file used to resume an interrupted run
text_output_path, text_output_filled_path, processed_indices_path = (
    os.path.join(output_dir, file_name)
    for file_name in (
        "text_embeddings.tsv",
        "text_embeddings_filled.tsv",
        "processed_text_indices.txt",
    )
)

In [None]:
# Node types whose Description is embedded with BioBERT.
TEXT_NODE_TYPES = ["Phenotype", "Disease", "Genomic feature"]
# Rows embedded between two checkpoints to disk.
BATCH_SIZE = 10


def _checkpoint_text_batch(frame, positions):
    """Append the rows at `positions` (0-based positions in `frame`) to the
    embedding TSV and record those positions in the progress file."""
    if not positions:
        return
    save_batch(frame.iloc[positions], text_output_path)
    with open(processed_indices_path, "a") as handle:
        handle.write("".join(f"{position}\n" for position in positions))
    logger.info(f"Saved text embeddings for positions {positions[0]}-{positions[-1]}")


# reset_index makes each row's label equal to its position, so the numbers in
# the progress file, the .at[] writes and the .iloc[] slices all agree.
# NOTE: progress files written by earlier versions of this cell used the
# original df labels; clear output_dir before resuming such a run.
df_text = df[df["type"].isin(TEXT_NODE_TYPES)].reset_index(drop=True)

processed_indices = load_processed_indices(processed_indices_path)

existing_text_df = (
    pd.read_csv(text_output_path, sep="\t") if os.path.exists(text_output_path) else None
)

if existing_text_df is not None and len(existing_text_df) == len(df_text):
    # Every row is already on disk: reuse it instead of recomputing.
    df_text = existing_text_df
else:
    if existing_text_df is None:
        # Fresh run: any progress entries without a matching output file are
        # stale, so start the progress file over as well.
        processed_indices = set()
        with open(processed_indices_path, "w"):
            pass

    pending_positions = []
    for position, description in enumerate(df_text["Description"]):
        if position in processed_indices:
            continue  # already saved by an earlier (interrupted) run
        df_text.at[position, "embedding"] = get_text_embedding(description)
        pending_positions.append(position)
        if len(pending_positions) == BATCH_SIZE:
            _checkpoint_text_batch(df_text, pending_positions)
            pending_positions = []

    # Flush the final, possibly partial, batch.
    _checkpoint_text_batch(df_text, pending_positions)


In [None]:
def replace_null_embeddings_with_type_mean(df):
    """Fill missing embeddings with the element-wise mean embedding of rows of the same type.

    Args:
        df: DataFrame with a 'type' column and an 'embedding' column whose
            cells are list strings as written to TSV (e.g. "[0.1, 0.2]"),
            lists/arrays, or missing (NaN/None). Every cell is inspected, so a
            missing value in the first row no longer prevents parsing.

    Returns:
        The same DataFrame object (modified in place), whose 'embedding'
        column now holds plain Python lists of floats. Rows whose type has no
        embedded row at all cannot be filled and are left as None.
    """
    import ast  # local import: only needed here to parse list strings safely

    def _to_vector(value):
        # ast.literal_eval accepts only Python literals; eval() would execute
        # whatever code happened to be in the TSV file.
        if isinstance(value, str):
            return np.asarray(ast.literal_eval(value), dtype=float)
        if isinstance(value, (list, tuple, np.ndarray)):
            return np.asarray(value, dtype=float)
        return None  # NaN / None (and any other scalar) counts as missing

    vectors = [_to_vector(value) for value in df['embedding']]
    types = df['type'].tolist()

    # Gather the available vectors per type, then average them element-wise.
    vectors_by_type = {}
    for node_type, vector in zip(types, vectors):
        if vector is not None:
            vectors_by_type.setdefault(node_type, []).append(vector)
    type_mean_embeddings = {
        node_type: np.mean(np.stack(group), axis=0)
        for node_type, group in vectors_by_type.items()
    }

    filled = []
    for node_type, vector in zip(types, vectors):
        if vector is None:
            vector = type_mean_embeddings.get(node_type)
        filled.append(None if vector is None else vector.tolist())

    # dtype=object keeps each list as a single cell rather than a 2-D block.
    df['embedding'] = pd.Series(filled, index=df.index, dtype=object)
    return df

# Reload from disk so this cell behaves the same whether the embeddings were
# just computed or reused from an earlier run (the TSV stores them as list strings).
df_text = pd.read_csv(text_output_path, sep="\t")

df_text_filled = replace_null_embeddings_with_type_mean(df_text)
# Show any rows that still have no embedding after filling, for inspection.
print(df_text_filled[df_text_filled['embedding'].isna()])

if os.path.exists(text_output_filled_path):
    # save_batch appends, so writing again would duplicate every row; make the
    # skip visible instead of silently doing nothing.
    print(f"{text_output_filled_path} already exists; not overwriting it.")
else:
    try:
        save_batch(df_text_filled, text_output_filled_path)
    except Exception as e:
        # Keep the cell best-effort, but record the traceback and surface the
        # failure in the notebook instead of only in the log file.
        logger.exception(f"Error: {e}")
        print(f"Error: {e}")