In [1]:
from datasets import UMLSKGDataset
dataset = UMLSKGDataset("/share/project/biomed/hcd/UMLS/processed_data/eng_rel_subset.txt")


In [3]:
print(len(dataset.entities))
print(len(dataset.relations))

81495
193


In [50]:
import pickle
with open('./data/kge/ent2idx.pkl','wb') as handler:
    pickle.dump(dataset.entity_to_id, handler)
    
with open('./data/kge/rel2idx.pkl','wb') as handler:
    pickle.dump(dataset.relation_to_id, handler)   

In [19]:
import torch
from transe import TransE
from torch.utils.data import DataLoader
from tqdm import tqdm

In [35]:
import logging
def setup_logger(name, log_file, level=logging.INFO):
    """To setup as many loggers as you want"""
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

    handler = logging.FileHandler(log_file, mode='a')        
    handler.setFormatter(formatter)

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.addHandler(handler)

    consoleHandler = logging.StreamHandler()
    consoleHandler.setFormatter(formatter)

    logger.addHandler(consoleHandler)
    return logger

In [45]:
def train_and_evaluate(model, train_dataloader, valid_dataloader, logger, num_epochs, learning_rate):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
    model.to(device)
    train_losses = []
    val_losses = []
    best_model = None
    best_val = 0
    #print("I am here")
    for epoch in range(num_epochs):
        train_loss = 0.0
        model.train()
        for idx, batch in enumerate(tqdm(train_dataloader, desc="Training")):
            positive_samples, negative_samples = batch
            optimizer.zero_grad()
            loss = model.compute_loss(positive_samples.to(device), negative_samples.to(device))
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        
        train_loss /= len(train_dataloader)
        valid_loss = evaluate(model, valid_dataloader, device)
        #print("printing log")
        logger.info("Epoch {}, Train Loss: {}, Valid Loss: {}".format(epoch,train_loss,valid_loss))
        # Update learning rate scheduler
        scheduler.step(valid_loss)
        #best_val = valid_loss
        if best_val == 0:
            best_val = valid_loss
        if valid_loss < best_val:
            best_model = {'model': model.state_dict(),
              'optimizer': optimizer.state_dict()}
            torch.save(best_model, './model_ckpts/transE/best_model.pt')
        if (epoch+1)%10 == 0:
            checkpoint = {'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}
            torch.save(checkpoint, './model_ckpts/transE/model_ckpt_'+str(epoch)+'.pt')
    return train_losses, val_losses

def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(dataloader,desc = "Validation")):
            positive_samples, negative_samples = batch
            loss = model.compute_loss(positive_samples.to(device), negative_samples.to(device))
            total_loss += loss.item()
    total_loss /= len(dataloader)
    return total_loss

In [46]:
n_entities = len(dataset.entities)
n_rels = len(dataset.relations)
n_embs = 256
margin = 1
model = TransE(n_entities, n_rels, n_embs, margin)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
logger = setup_logger('TransE_logger', './logs/transE.log')
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)
train_loss, val_loss = train_and_evaluate(model, train_dataloader, test_dataloader, logger, num_epochs=50, learning_rate=1e-3)

Training: 100%|██████████| 837/837 [02:39<00:00,  5.25it/s]
Validation: 100%|██████████| 210/210 [00:39<00:00,  5.36it/s]
2023-05-15 21:33:32,217 INFO Epoch 0, Train Loss: 0.45198343822392084, Valid Loss: 0.3522786039681662
2023-05-15 21:33:32,217 INFO Epoch 0, Train Loss: 0.45198343822392084, Valid Loss: 0.3522786039681662
2023-05-15 21:33:32,217 INFO Epoch 0, Train Loss: 0.45198343822392084, Valid Loss: 0.3522786039681662
2023-05-15 21:33:32,217 INFO Epoch 0, Train Loss: 0.45198343822392084, Valid Loss: 0.3522786039681662
Training: 100%|██████████| 837/837 [02:39<00:00,  5.26it/s]
Validation: 100%|██████████| 210/210 [00:39<00:00,  5.36it/s]
2023-05-15 21:36:50,554 INFO Epoch 1, Train Loss: 0.29715290546346024, Valid Loss: 0.3008225204689162
2023-05-15 21:36:50,554 INFO Epoch 1, Train Loss: 0.29715290546346024, Valid Loss: 0.3008225204689162
2023-05-15 21:36:50,554 INFO Epoch 1, Train Loss: 0.29715290546346024, Valid Loss: 0.3008225204689162
2023-05-15 21:36:50,554 INFO Epoch 1, Trai

KeyboardInterrupt: 

In [18]:
import pandas as pd
df = pd.read_csv('./data/UMLS/kg.txt',sep='\t', header= None, names=['h','r','t'])
# Group by 'Column1' and count the occurrences


In [24]:
r_counts = df.groupby('r').transform('size')
# Boolean indexing to filter rows
filtered_df = df[r_counts >= 10000]
# Display the filtered DataFrame
print(len(filtered_df),len(df))

778355 1090990


In [1]:
import json
with open('./data/UMLS/full_cui2def.json','r') as f:
    cui2def = json.load(f)



In [2]:
updated_cui2def = {}
total_defs = 0
total_updated_defs = 0
for cui, def_list in cui2def.items():
    total_defs += len(def_list)
    if cui not in updated_cui2def:
        updated_cui2def[cui] = []
        unique_defs = set() 
        for definition in def_list:
            unique_defs.add(definition.lower())
        for definition in unique_defs:
            updated_cui2def[cui].append(definition)
        total_updated_defs += len(unique_defs)

print(len(updated_cui2def),len(cui2def))
print(total_defs, total_updated_defs)

227408 227408
284363 284106


In [3]:
with open('./data/UMLS/cui2def.json','w') as f:
    json.dump(updated_cui2def, f)



In [4]:
import json
with open('./data/UMLS/cui2def.json','r') as f:
    cui2def = json.load(f)

for idx,(k,v) in enumerate(cui2def.items()):
    print(k,v)
    if idx == 10:
        break


C0000039 ['synthetic phospholipid used in liposomes and lipid bilayers to study biological membranes. it is also a major constituent of pulmonary surfactants.']
C0000052 ['in glycogen or amylopectin synthesis, the enzyme that catalyzes the transfer of a segment of a 1,4-alpha-glucan chain to a primary hydroxy group in a similar glucan chain. ec 2.4.1.18.']
C0000084 ['found in various tissues, particularly in four blood-clotting proteins including prothrombin, in kidney protein, in bone protein, and in the protein present in various ectopic calcifications.']
C0000096 ['a potent cyclic nucleotide phosphodiesterase inhibitor; due to this action, the compound increases cyclic amp and cyclic gmp in tissue and thereby activates cyclic nucleotide-regulated protein kinases']
C0000097 ['a dopaminergic neurotoxic compound which produces irreversible clinical, chemical, and pathological alterations that mimic those found in parkinson disease.', '1-methyl-4-phenyl-1,2,5,6-tetrahydropyridine, a tox

In [8]:
with open('./data/UMLS/cui2syn.json','r') as f:
    cui2syn = json.load(f)


In [9]:
updated_cui2syn = {}
total_syn = 0
total_updated_syn = 0
for cui, syn_list in cui2syn.items():
    total_syn += len(syn_list)
    if cui not in updated_cui2syn:
        updated_cui2syn[cui] = []
        unique_syns = set()
        for syn in syn_list:
            unique_syns.add(syn.lower())
        for syn in unique_syns:
            updated_cui2syn[cui].append(syn)
        total_updated_syn += len(unique_syns)
print(total_syn, total_updated_syn)
print(len(cui2syn),len(updated_cui2syn))
        

9508795 9508795
4548855 4548855


In [10]:
with open('./data/UMLS/cui2syn.json','w') as f:
    json.dump(updated_cui2syn,f)

In [14]:
with open('./data/UMLS/dictionary.txt','w') as f:
    for idx, (cui, defs) in enumerate(updated_cui2def.items()):
        print(idx, end='\r')
        if cui in updated_cui2syn:
            for syn in updated_cui2syn[cui]:
                for definition in defs:
                    
                    f.write(syn + '\t' + definition + '\n')

57

UnicodeEncodeError: 'latin-1' codec can't encode character '\u03b3' in position 0: ordinal not in range(256)