In [1]:
# Standard library
import pickle

# Graph Neural Network Libraries
import dgl
import networkx as nx
import obonet

# PhenoDP and Preprocessing
from PhenoDP import *
from PhenoDP_Preprocess import *

# HPO Encoders
from PSD_HPOEncoder import *
from PCL_HPOEncoder import *

# Transformers and PEFT
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# PyTorch
import torch

# NumPy and Progress Tracking
import numpy as np
from tqdm import tqdm

# HPO Ontology
from pyhpo.ontology import Ontology


In [2]:

# Initialize pyhpo's Ontology; later cells look terms up via Ontology.get_hpo_object.
Ontology()
hp_df = Ontology.to_dataframe()

# Device for the language model. It is reused as `device_map` below so the
# weights and the inputs (moved with `.to(device)` in the embedding helpers)
# are always placed on the same GPU; previously the device string was
# duplicated and could drift out of sync.
device = "cuda:3"
# NOTE(review): machine-specific absolute path; consider making it configurable.
model_name_or_path = '/remote-home/share/data3/ly/phenoDP/new-checkpoint-finetune-with-4-datasets/'

# Load the pre-trained causal LM in half precision, plus its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float16,
    device_map=device,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
def get_average_encoding(text, llm_tokenizer=None, llm_model=None, llm_device=None):
    """
    Compute the mean-pooled last-layer hidden state of `text` under the language model.

    Args:
        text (str): The input text to encode.
        llm_tokenizer: Tokenizer to use. Defaults to the notebook-level `tokenizer`.
        llm_model: Language model to run. Defaults to the notebook-level `model`.
        llm_device: Device the tokenized inputs are moved to. Defaults to the
            notebook-level `device`.

    Returns:
        torch.Tensor: The average encoding of the text, left on `llm_device`
        (for a single string, the batch axis is squeezed away).
    """
    # A later cell rebinds the global `model` to a GCN, so callers can pass the
    # language-model objects explicitly; the globals are only a fallback that
    # keeps existing calls working.
    llm_tokenizer = tokenizer if llm_tokenizer is None else llm_tokenizer
    llm_model = model if llm_model is None else llm_model
    llm_device = device if llm_device is None else llm_device

    inputs = llm_tokenizer(text, return_tensors="pt").to(llm_device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = llm_model(**inputs, output_hidden_states=True)

    # hidden_states[-1] is the last layer; average over the token axis (dim=1),
    # then squeeze drops any size-1 axes (the batch axis for a single text).
    last_hidden_states = outputs.hidden_states[-1]
    return last_hidden_states.mean(dim=1).squeeze()

# Quick sanity check: encode one sample sentence with the language model
# and print the mean-pooled vector.
text = "This is an example text for computing the average vector encoding."
average_encoding = get_average_encoding(text)
print(average_encoding)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


tensor([ 0.5605, -1.1172, -0.8311,  ...,  0.9614,  0.3591, -2.2852],
       device='cuda:3', dtype=torch.float16)


In [4]:
# Cache of HPO term name -> embedding used by get_hpo_embedding. It was never
# defined elsewhere in this notebook; it is recreated whenever this cell runs so
# stale vectors (e.g. from a different model) are not reused.
hpo_embedding_cache = {}


def get_hpo_embedding(hpo_id, Ontology, tokenizer, model, device=None):
    """
    Retrieve the language-model embedding of an HPO term's name.

    Args:
        hpo_id (str): The ID of the HPO term.
        Ontology: The HPO ontology object; must provide `get_hpo_object(hpo_id)`.
        tokenizer: The tokenizer for the model (notebook-level `tokenizer` if None).
        model: The pre-trained language model (notebook-level `model` if None).
        device: Device for the tokenized inputs (notebook-level `device` if None).

    Returns:
        np.ndarray: The 1-D embedding of the term name, or None if the term is
        not found.

    Raises:
        ValueError: If the computed embedding contains NaN values.
    """
    hpo_obj = Ontology.get_hpo_object(hpo_id)
    if not hpo_obj:
        return None

    hpo_name = hpo_obj.name
    # Check the cache to avoid redundant model calls.
    if hpo_name in hpo_embedding_cache:
        return hpo_embedding_cache[hpo_name]

    # Pass the name as a plain string: get_average_embedding already squeezes
    # the batch axis, so its result IS the full vector. Indexing it with [0]
    # (as when the name was wrapped in a list) would keep only one scalar.
    embedding = get_average_embedding(
        hpo_name, llm_tokenizer=tokenizer, llm_model=model, llm_device=device
    )
    hpo_embedding_cache[hpo_name] = embedding
    return embedding


def get_average_embedding(text, llm_tokenizer=None, llm_model=None, llm_device=None):
    """
    Compute the mean-pooled last-layer hidden state of `text` as a NumPy array.

    Args:
        text (str): The input text to encode.
        llm_tokenizer: Tokenizer to use. Defaults to the notebook-level `tokenizer`.
        llm_model: Language model to run. Defaults to the notebook-level `model`.
        llm_device: Device the tokenized inputs are moved to. Defaults to the
            notebook-level `device`.

    Returns:
        np.ndarray: The average embedding of the text (batch axis squeezed away
        for a single text), copied to host memory.

    Raises:
        ValueError: If NaN values are detected in the embedding.
    """
    # A later cell rebinds the global `model` to a GCN; explicit arguments let
    # callers avoid depending on those globals.
    llm_tokenizer = tokenizer if llm_tokenizer is None else llm_tokenizer
    llm_model = model if llm_model is None else llm_model
    llm_device = device if llm_device is None else llm_device

    inputs = llm_tokenizer(text, return_tensors="pt").to(llm_device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        outputs = llm_model(**inputs, output_hidden_states=True)

    # Last layer, averaged over the token axis (dim=1), then moved to the CPU.
    last_hidden_states = outputs.hidden_states[-1]
    average_encoding = last_hidden_states.mean(dim=1).squeeze().cpu().numpy()

    # Half-precision models can overflow; refuse to return a NaN embedding.
    if np.isnan(average_encoding).any():
        raise ValueError(f"NaN detected in embedding for text: {text}")

    return average_encoding


In [5]:
# Load the HPO ontology graph from a local OBO file (a filesystem path, not a URL).
obo_path = '../../hp.obo'
graph = obonet.read_obo(obo_path)

# Expected embedding width. It must equal the language model's hidden size,
# because every node feature below feeds the GCN's input layer.
feature_dimension = 2048

# Embed each term name with the language model and store it as a node feature.
for node in tqdm(graph.nodes(), desc="Processing nodes"):
    try:
        embedding = get_average_embedding(Ontology.get_hpo_object(node).name)
    except ValueError as e:
        # e.g. NaN in the embedding: report which term failed, then abort.
        print(f"Error processing node {node}: {e}")
        raise
    # Fail fast on a size mismatch instead of deep inside GCN training.
    if embedding.shape != (feature_dimension,):
        raise ValueError(
            f"Embedding for node {node} has shape {embedding.shape}, "
            f"expected ({feature_dimension},)"
        )
    graph.nodes[node]['feature'] = embedding


Processing nodes: 100%|██████████| 18281/18281 [16:52<00:00, 18.05it/s]


In [None]:
dgl_graph = nx_to_dgl(graph)

# Size the GCN from the node features actually stored on the graph (the same
# `ndata['feat']` the model is run on afterwards) instead of hard-coding 2048,
# so it cannot drift from the language model's hidden size.
feature_dimension = dgl_graph.ndata['feat'].shape[1]
in_feats = feature_dimension
h_feats = 256
out_feats = feature_dimension  # output width equals input width (presumably feature reconstruction)

# Training hyperparameters.
epochs = 500
learning_rate = 0.001
node_mask_percentage = 0.2
edge_mask_percentage = 0.2

# Seed the RNGs so weight init and (presumably random) masking are reproducible.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Train on CPU (per the original note, the DGL graph API is not used on GPU here).
# A separate name is used so the language model's `device` global, read by the
# embedding helpers, is not overwritten.
gcn_device = torch.device("cpu")
# NOTE: this rebinds the notebook-level `model` (previously the language model);
# the following cell relies on `model` being the trained GCN.
model = GCN(in_feats, h_feats, out_feats).to(gcn_device)

train_model(
    model,
    dgl_graph,
    epochs=epochs,
    lr=learning_rate,
    node_mask_percentage=node_mask_percentage,
    edge_mask_percentage=edge_mask_percentage,
)

  dgl_graph.ndata['feat'] = torch.tensor(features, dtype=torch.float32)
  assert input.numel() == input.storage().size(), "Cannot convert view " \


Epoch 0, Loss: 4.250472068786621
Epoch 10, Loss: 3.58121395111084
Epoch 20, Loss: 3.309152364730835
Epoch 30, Loss: 3.289189577102661
Epoch 40, Loss: 3.1876633167266846
Epoch 50, Loss: 3.1779816150665283
Epoch 60, Loss: 3.1386053562164307
Epoch 70, Loss: 3.1241912841796875
Epoch 80, Loss: 3.100999355316162
Epoch 90, Loss: 3.0129268169403076
Epoch 100, Loss: 2.9341485500335693
Epoch 110, Loss: 2.9474194049835205
Epoch 120, Loss: 2.858182907104492
Epoch 130, Loss: 2.8204548358917236
Epoch 140, Loss: 2.7541074752807617
Epoch 150, Loss: 2.611929178237915
Epoch 160, Loss: 2.6177871227264404
Epoch 170, Loss: 2.5599496364593506
Epoch 180, Loss: 2.5176162719726562
Epoch 190, Loss: 2.478886365890503
Epoch 200, Loss: 2.4200682640075684
Epoch 210, Loss: 2.3500373363494873
Epoch 220, Loss: 2.31904673576355
Epoch 230, Loss: 2.2613582611083984
Epoch 240, Loss: 2.2600603103637695
Epoch 250, Loss: 2.185352087020874
Epoch 260, Loss: 2.143479347229004


In [None]:
# Run the trained GCN once in inference mode; keep its latent node representations.
model.eval()
with torch.no_grad():
    outputs, latent = model(dgl_graph, dgl_graph.ndata['feat'])

# Row i of `latent` is assumed to belong to the i-th node of `graph`
# (i.e. nx_to_dgl preserves NetworkX node order) — TODO confirm in nx_to_dgl.
node_ids = list(graph.nodes)
assert latent.shape[0] == len(node_ids), (
    f"latent has {latent.shape[0]} rows but graph has {len(node_ids)} nodes"
)
# .cpu() makes this safe even if the GCN is later trained on a GPU.
node_embedding_dict = dict(zip(node_ids, latent.cpu().numpy()))

In [None]:
# Persist the node-ID -> latent-vector mapping for downstream use
# (`pickle` is imported in the first cell).
# NOTE(review): '.plk' looks like a typo for '.pkl'; kept so existing readers
# of this file still find it.
with open('./node_embedding_dict_test.plk', 'wb') as out_file:
    pickle.dump(node_embedding_dict, out_file)