In [1]:
import pandas as pd
import numpy as np
import pickle
from pyhpo import Ontology
from PCL_HPOEncoder import *
import torch
import torch.nn.functional as F
from torch import optim
import random


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load the HPO ontology from the local './HPO_2025_3_3/' data folder. Later
# cells query the result through `Ontology.omim_diseases` and
# `Ontology.get_hpo_object`, so this cell has to run before them.
# NOTE(review): the folder name suggests an HPO release snapshot -- confirm
# which release it holds.
Ontology('./HPO_2025_3_3/')

<pyhpo.ontology.OntologyClass at 0x7f5189122518>

In [3]:

# Read the pickled `node_embedding` mapping from disk; its keys are used further
# down to decide which HPO term ids are kept for each disease.
# NOTE(review): pickle.load can execute arbitrary code while deserialising --
# only point this at files from a trusted source.
with open('./node_embedding_dict_test.plk', 'rb') as embedding_file:
    node_embedding = pickle.load(embedding_file)


In [4]:

# Map every OMIM disease id to the ids of its annotated HPO terms that have an
# entry in `node_embedding`, then collect the diseases with enough such terms.
MIN_HPO_TERMS = 5  # diseases with fewer embedded terms are left out of disease_db

disease_list = list(Ontology.omim_diseases)
hps_list = node_embedding.keys()  # keys view, so `in` is a hash lookup

disease_dict = dict()
for disease in disease_list:
    # Resolve each annotated term exactly once, then keep only the ids that
    # have an embedding. Iteration order over `disease.hpo` is left untouched.
    resolved_ids = (Ontology.get_hpo_object(term).id for term in disease.hpo)
    disease_dict[disease.id] = [term_id for term_id in resolved_ids if term_id in hps_list]

d_count = []  # never filled in the visible cells; kept so the name stays defined

# Disease ids whose filtered term list has at least MIN_HPO_TERMS entries,
# in the insertion order of `disease_dict`.
disease_db = [
    disease_id
    for disease_id, kept_ids in disease_dict.items()
    if len(kept_ids) >= MIN_HPO_TERMS
]


In [5]:

# Seed the RNGs of every random-number library this notebook imports, so that a
# Restart & Run All starts from the same generator state each time (model
# construction below and any sampling that runs afterwards in this kernel).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Hyper-parameters, passed positionally to PCL_HPOEncoder in this exact order.
# NOTE(review): their meaning is inferred from the names only -- confirm
# against the PCL_HPOEncoder constructor signature.
input_dim = 256
num_heads = 8
num_layers = 3
hidden_dim = 512
output_dim = 1
max_seq_length = 128

model = PCL_HPOEncoder(input_dim, num_heads, num_layers, hidden_dim, output_dim, max_seq_length)
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
# Training configuration.
n_s = 2000        # sample count passed to get_training_sample
num_epochs = 200
batch_size = 100

# Prefer GPU index 3, but do not crash on hosts that lack it: fall back to the
# first GPU, and to the CPU when CUDA is not available at all.
if torch.cuda.is_available() and torch.cuda.device_count() > 3:
    device = 'cuda:3'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
model.to(device)

# Two paired views (inputs + masks) are produced once and reused every epoch.
inputs_list, mask_list = get_training_sample(disease_db, disease_dict, node_embedding, n_s)
inputs1 = inputs_list[0].to(device)
inputs2 = inputs_list[1].to(device)
masks1 = mask_list[0].to(device)
masks2 = mask_list[1].to(device)

# The model is called with float masks; the conversion does not depend on the
# batch, so it is done once here instead of on every step.
masks1_float = masks1.float()
masks2_float = masks2.float()

# Batch over the rows that actually exist (capped at n_s) so that a shorter
# result from get_training_sample never produces empty batches.
num_samples = min(n_s, inputs1.size(0), inputs2.size(0))
if num_samples == 0:
    raise ValueError('get_training_sample returned no training samples')
num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_steps = 0

    # Batches are contiguous, fixed slices; the last one may be shorter.
    for start_idx in range(0, num_samples, batch_size):
        end_idx = min(start_idx + batch_size, num_samples)

        inputs1_batch = inputs1[start_idx:end_idx]
        inputs2_batch = inputs2[start_idx:end_idx]
        mask1_batch = masks1_float[start_idx:end_idx]
        mask2_batch = masks2_float[start_idx:end_idx]

        # Only the first return value of the model feeds the loss; the second
        # one is not used in this loop.
        cls_embedding1, emb1 = model(inputs1_batch, mask1_batch)
        cls_embedding2, emb2 = model(inputs2_batch, mask2_batch)

        # info_nce_loss takes just the two embedding batches, so no explicit
        # label matrix is built here.
        loss = info_nce_loss(cls_embedding1, cls_embedding2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_steps += 1

    # Mean per-step loss for the epoch.
    print(f'Epoch {epoch + 1}, Loss: {total_loss / total_steps}')

Epoch 1, Loss: 4.627202296257019
Epoch 2, Loss: 4.240150046348572
Epoch 3, Loss: 3.989004373550415
Epoch 4, Loss: 3.7461729288101195
Epoch 5, Loss: 3.5224256038665773
Epoch 6, Loss: 3.333056557178497
Epoch 7, Loss: 3.1676364183425902
Epoch 8, Loss: 2.9961713194847106
Epoch 9, Loss: 2.8208709239959715
Epoch 10, Loss: 2.6738101482391357
Epoch 11, Loss: 2.51875821352005
Epoch 12, Loss: 2.390128219127655
Epoch 13, Loss: 2.2942910194396973
Epoch 14, Loss: 2.1770077109336854
Epoch 15, Loss: 2.0273988127708433
Epoch 16, Loss: 1.9442080080509185
Epoch 17, Loss: 1.8290502667427062
Epoch 18, Loss: 1.7930613100528716
Epoch 19, Loss: 1.7642766654491424
Epoch 20, Loss: 1.6556639671325684
Epoch 21, Loss: 1.588652390241623
Epoch 22, Loss: 1.518041294813156
Epoch 23, Loss: 1.476681661605835
Epoch 24, Loss: 1.4305603206157684
Epoch 25, Loss: 1.3992465555667877
Epoch 26, Loss: 1.4107238173484802
Epoch 27, Loss: 1.3659060537815093
Epoch 28, Loss: 1.3196948409080504
Epoch 29, Loss: 1.2966501116752625
Epoc

In [7]:

# Move the model to the CPU first, then write its state dict to disk.
checkpoint_path = './transformer_encoder_infoNCE_test.pth'
model.to('cpu')
cpu_state_dict = model.state_dict()
torch.save(cpu_state_dict, checkpoint_path)