In [1]:
from ipynb.fs.full.ExtractData import extract_data, split_data, get_data_for_pretraining, get_data_for_GBERT
from ipynb.fs.full.PreTraining import pretrain
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from ipynb.fs.full.FineTuning import fine_tuning
import time

## GBERT without Ontology Embeddings
One of the claims of *Pre-training of Graph Augmented Transformers for Medication Recommendation* is that the ontology embeddings increase the effectiveness of the graph-augmented model. To test this claim, this notebook trains the model without the ontology embeddings; the resulting Jaccard similarity score, average F1, and precision-recall AUC are then compared to those of the model trained with the ontology embeddings.

In [2]:
# Load the dataset once, then split it into the three id collections that
# every later step in this notebook shares.
data = extract_data()
train_ids, test_ids, eval_ids = split_data(data)


def _build_loader(dataset, sampler_cls, batch_size):
    """Return a DataLoader over `dataset` driven by a fresh `sampler_cls(dataset)`."""
    return DataLoader(dataset, sampler=sampler_cls(dataset), batch_size=batch_size)


# Pre-training loaders (batch size 4): random order for training,
# fixed sequential order for evaluation.
pretrain_data, pretrain_eval_data = get_data_for_pretraining(data, train_ids, eval_ids)
pretrain_dataloader = _build_loader(pretrain_data, RandomSampler, 4)
pretrain_eval_dataloader = _build_loader(pretrain_eval_data, SequentialSampler, 4)

# Fine-tuning loaders (batch size 1): only the training split is shuffled;
# eval and test are walked sequentially.
train_data, eval_data, test_data = get_data_for_GBERT(data, train_ids, test_ids, eval_ids)
train_dataloader = _build_loader(train_data, RandomSampler, 1)
eval_dataloader = _build_loader(eval_data, SequentialSampler, 1)
test_dataloader = _build_loader(test_data, SequentialSampler, 1)

Generating samples for single_visit_fn: 100%|█| 1000/1000 [00:00<00:00, 33469.55
Generating samples for multi_visit_fn: 100%|█| 1000/1000 [00:00<00:00, 3343.47it


In [3]:
# Ablation run: alternate pre-training and fine-tuning with useGraph=False.
# NOTE(review): useGraph=False is presumably the switch that disables the
# ontology embeddings, and usePretrainedModel=True presumably resumes from
# MODEL_FILE -- confirm against PreTraining / FineTuning.
#
# Round 1 starts from scratch (usePretrainedModel=False); every later round
# passes usePretrainedModel=True with the same checkpoint file name.
MODEL_FILE = "gbert-with-out-ontology.bin"
NUM_ROUNDS = 15  # 1 from-scratch round + 14 warm-started rounds
SEPARATOR = "**********************************************************************************"

for round_no in range(1, NUM_ROUNDS + 1):
    warm_start = round_no > 1

    # A separator goes between consecutive stages, so none before the first one.
    if warm_start:
        print(SEPARATOR)

    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a long training run.
    start_time = time.perf_counter()
    pretrain(data, pretrain_dataloader, pretrain_eval_dataloader, MODEL_FILE,
             usePretrainedModel=warm_start, useGraph=False)
    print("Pretrain ", round_no, " took ", (time.perf_counter() - start_time), " seconds")

    print(SEPARATOR)

    start_time = time.perf_counter()
    fine_tuning(data, train_dataloader, eval_dataloader, test_dataloader, MODEL_FILE,
                useGraph=False)
    print("Fine Tuning ", round_no, " took ", (time.perf_counter() - start_time), " seconds")

***** Running Pre-training *****
***** Running training *****
train/loss: 0.6644576866375772  epoch:  1
***** Running eval *****
eval_dx2dx/ jaccard :  0.0  epoch:  1
eval_dx2dx/ f1 :  0.0  epoch:  1
eval_dx2dx/ prauc :  0.10916817263701889  epoch:  1
eval_rx2dx/ jaccard :  0.0  epoch:  1
eval_rx2dx/ f1 :  0.0  epoch:  1
eval_rx2dx/ prauc :  0.10623206074252775  epoch:  1
eval_dx2rx/ jaccard :  0.3359731530715006  epoch:  1
eval_dx2rx/ f1 :  0.49642426493046643  epoch:  1
eval_dx2rx/ prauc :  0.605191267349722  epoch:  1
eval_rx2rx/ jaccard :  0.3514345495866209  epoch:  1
eval_rx2rx/ f1 :  0.5133133964002615  epoch:  1
eval_rx2rx/ prauc :  0.6020259717438728  epoch:  1
***** Running training *****
train/loss: 0.5408142971935455  epoch:  2
***** Running eval *****
eval_dx2dx/ jaccard :  0.014454403403865721  epoch:  2
eval_dx2dx/ f1 :  0.025129426922799247  epoch:  2
eval_dx2dx/ prauc :  0.11076209013476683  epoch:  2
eval_rx2dx/ jaccard :  0.016758635740117222  epoch:  2
eval_rx2dx/ f



train/loss: 0.34007909893989563  epoch:  1
***** Running eval *****
eval/ jaccard :  0.3667756156934033  epoch:  1
eval/ f1 :  0.5315524756968472  epoch:  1
eval/ prauc :  0.6904598732539874  epoch:  1
***** Running training *****
train/loss: 0.2977063413709402  epoch:  2
***** Running eval *****
eval/ jaccard :  0.3901745756769321  epoch:  2
eval/ f1 :  0.555143204597405  epoch:  2
eval/ prauc :  0.6822517410860742  epoch:  2
***** Running training *****
train/loss: 0.3014288434589451  epoch:  3
***** Running eval *****
eval/ jaccard :  0.32038435768393264  epoch:  3
eval/ f1 :  0.47864501716160057  epoch:  3
eval/ prauc :  0.6608648131081851  epoch:  3
***** Running training *****
train/loss: 0.29081322422081773  epoch:  4
***** Running eval *****
eval/ jaccard :  0.35023302697299097  epoch:  4
eval/ f1 :  0.510194363653986  epoch:  4
eval/ prauc :  0.6799148481625755  epoch:  4
***** Running training *****
train/loss: 0.29095925407653506  epoch:  5
***** Running eval *****
eval/ jac

train/loss: 0.4757234454154968  epoch:  4
***** Running eval *****
eval_dx2dx/ jaccard :  0.027853998146678313  epoch:  4
eval_dx2dx/ f1 :  0.04651859488964612  epoch:  4
eval_dx2dx/ prauc :  0.14375184942233984  epoch:  4
eval_rx2dx/ jaccard :  0.018722417742025585  epoch:  4
eval_rx2dx/ f1 :  0.02967862041936116  epoch:  4
eval_rx2dx/ prauc :  0.10788920791557217  epoch:  4
eval_dx2rx/ jaccard :  0.32979859431840636  epoch:  4
eval_dx2rx/ f1 :  0.4807428199387888  epoch:  4
eval_dx2rx/ prauc :  0.6045403552666401  epoch:  4
eval_rx2rx/ jaccard :  0.3774816131662093  epoch:  4
eval_rx2rx/ f1 :  0.5324307425574503  epoch:  4
eval_rx2rx/ prauc :  0.6436989021130791  epoch:  4
***** Running training *****
train/loss: 0.4725539769852561  epoch:  5
***** Running eval *****
eval_dx2dx/ jaccard :  0.028073803148437274  epoch:  5
eval_dx2dx/ f1 :  0.04851358001008666  epoch:  5
eval_dx2dx/ prauc :  0.12382982369965766  epoch:  5
eval_rx2dx/ jaccard :  0.016192535174016656  epoch:  5
eval_rx2d

train/loss: 0.48640580510979065  epoch:  2
***** Running eval *****
eval_dx2dx/ jaccard :  0.03262956683435986  epoch:  2
eval_dx2dx/ f1 :  0.05278713437930612  epoch:  2
eval_dx2dx/ prauc :  0.12244573073878667  epoch:  2
eval_rx2dx/ jaccard :  0.014417052325549058  epoch:  2
eval_rx2dx/ f1 :  0.024115789977303667  epoch:  2
eval_rx2dx/ prauc :  0.11445221534327808  epoch:  2
eval_dx2rx/ jaccard :  0.3578457703963484  epoch:  2
eval_dx2rx/ f1 :  0.5123175259083329  epoch:  2
eval_dx2rx/ prauc :  0.6153415643449849  epoch:  2
eval_rx2rx/ jaccard :  0.3963684837888209  epoch:  2
eval_rx2rx/ f1 :  0.5520926781157778  epoch:  2
eval_rx2rx/ prauc :  0.6411813495343569  epoch:  2
***** Running training *****
train/loss: 0.4768701491601159  epoch:  3
***** Running eval *****
eval_dx2dx/ jaccard :  0.03357944067820611  epoch:  3
eval_dx2dx/ f1 :  0.054269427417575566  epoch:  3
eval_dx2dx/ prauc :  0.12346947821395941  epoch:  3
eval_rx2dx/ jaccard :  0.02521801585417925  epoch:  3
eval_rx2dx

train/loss: 0.2847511551596902  epoch:  2
***** Running eval *****
eval/ jaccard :  0.3867947013982644  epoch:  2
eval/ f1 :  0.5497733239933503  epoch:  2
eval/ prauc :  0.6790388680538598  epoch:  2
***** Running training *****
train/loss: 0.28536796011030674  epoch:  3
***** Running eval *****
eval/ jaccard :  0.37462658749548405  epoch:  3
eval/ f1 :  0.536283739439364  epoch:  3
eval/ prauc :  0.6600875459312431  epoch:  3
***** Running training *****
train/loss: 0.2713024098087441  epoch:  4
***** Running eval *****
eval/ jaccard :  0.3500181416829865  epoch:  4
eval/ f1 :  0.5067416672250825  epoch:  4
eval/ prauc :  0.6347854396482842  epoch:  4
***** Running training *****
train/loss: 0.25973210060461  epoch:  5
***** Running eval *****
eval/ jaccard :  0.3978327686140678  epoch:  5
eval/ f1 :  0.5610864488953142  epoch:  5
eval/ prauc :  0.6709520717028287  epoch:  5
Fine Tuning  6  took  44.45882296562195  seconds
*************************************************************

train/loss: 0.4624091895145663  epoch:  4
***** Running eval *****
eval_dx2dx/ jaccard :  0.03450257638880171  epoch:  4
eval_dx2dx/ f1 :  0.05664783654518577  epoch:  4
eval_dx2dx/ prauc :  0.11896448991982604  epoch:  4
eval_rx2dx/ jaccard :  0.026497796229596998  epoch:  4
eval_rx2dx/ f1 :  0.04243020575159166  epoch:  4
eval_rx2dx/ prauc :  0.1081189968398806  epoch:  4
eval_dx2rx/ jaccard :  0.35902012885740753  epoch:  4
eval_dx2rx/ f1 :  0.5171616585506601  epoch:  4
eval_dx2rx/ prauc :  0.6017846878921324  epoch:  4
eval_rx2rx/ jaccard :  0.4300397474395693  epoch:  4
eval_rx2rx/ f1 :  0.5900955591678193  epoch:  4
eval_rx2rx/ prauc :  0.6567679487189442  epoch:  4
***** Running training *****
train/loss: 0.45432719103457253  epoch:  5
***** Running eval *****
eval_dx2dx/ jaccard :  0.029664043328617343  epoch:  5
eval_dx2dx/ f1 :  0.04835685576426317  epoch:  5
eval_dx2dx/ prauc :  0.12396700981474809  epoch:  5
eval_rx2dx/ jaccard :  0.030156648087404935  epoch:  5
eval_rx2dx

train/loss: 0.4680543690919876  epoch:  2
***** Running eval *****
eval_dx2dx/ jaccard :  0.047541060718105876  epoch:  2
eval_dx2dx/ f1 :  0.07862341627589547  epoch:  2
eval_dx2dx/ prauc :  0.125112846118618  epoch:  2
eval_rx2dx/ jaccard :  0.03058155243594387  epoch:  2
eval_rx2dx/ f1 :  0.04973877104690783  epoch:  2
eval_rx2dx/ prauc :  0.11752707799673358  epoch:  2
eval_dx2rx/ jaccard :  0.35190776474748203  epoch:  2
eval_dx2rx/ f1 :  0.5100986819173151  epoch:  2
eval_dx2rx/ prauc :  0.6156518481929062  epoch:  2
eval_rx2rx/ jaccard :  0.36139910846382745  epoch:  2
eval_rx2rx/ f1 :  0.5142397698701373  epoch:  2
eval_rx2rx/ prauc :  0.6405458949244621  epoch:  2
***** Running training *****
train/loss: 0.46093086762861774  epoch:  3
***** Running eval *****
eval_dx2dx/ jaccard :  0.04518383841587952  epoch:  3
eval_dx2dx/ f1 :  0.07437241757847192  epoch:  3
eval_dx2dx/ prauc :  0.12221616146841988  epoch:  3
eval_rx2dx/ jaccard :  0.035163586312823784  epoch:  3
eval_rx2dx/

train/loss: 0.27525381252846937  epoch:  2
***** Running eval *****
eval/ jaccard :  0.3888167730530365  epoch:  2
eval/ f1 :  0.5506436388086864  epoch:  2
eval/ prauc :  0.6568424852540053  epoch:  2
***** Running training *****
train/loss: 0.2625687499107285  epoch:  3
***** Running eval *****
eval/ jaccard :  0.3968589301393899  epoch:  3
eval/ f1 :  0.5615704853256019  epoch:  3
eval/ prauc :  0.673578205434532  epoch:  3
***** Running training *****
train/loss: 0.25408981511877343  epoch:  4
***** Running eval *****
eval/ jaccard :  0.3507896819137908  epoch:  4
eval/ f1 :  0.5118853955388978  epoch:  4
eval/ prauc :  0.6387373906273625  epoch:  4
***** Running training *****
train/loss: 0.2583286216990514  epoch:  5
***** Running eval *****
eval/ jaccard :  0.38973103989644386  epoch:  5
eval/ f1 :  0.5534900512796176  epoch:  5
eval/ prauc :  0.6915062235298385  epoch:  5
Fine Tuning  11  took  51.82049608230591  seconds
*********************************************************

train/loss: 0.43876589834690094  epoch:  4
***** Running eval *****
eval_dx2dx/ jaccard :  0.03969878781325612  epoch:  4
eval_dx2dx/ f1 :  0.06542913815026492  epoch:  4
eval_dx2dx/ prauc :  0.1395491977630132  epoch:  4
eval_rx2dx/ jaccard :  0.023879053045719714  epoch:  4
eval_rx2dx/ f1 :  0.03918831256865886  epoch:  4
eval_rx2dx/ prauc :  0.10392718946052913  epoch:  4
eval_dx2rx/ jaccard :  0.3711844113563118  epoch:  4
eval_dx2rx/ f1 :  0.5295843593033907  epoch:  4
eval_dx2rx/ prauc :  0.6407991021421511  epoch:  4
eval_rx2rx/ jaccard :  0.3584924897347817  epoch:  4
eval_rx2rx/ f1 :  0.5010775750411679  epoch:  4
eval_rx2rx/ prauc :  0.6647940966713516  epoch:  4
***** Running training *****
train/loss: 0.4351756019455394  epoch:  5
***** Running eval *****
eval_dx2dx/ jaccard :  0.028387533309747557  epoch:  5
eval_dx2dx/ f1 :  0.04786871418278368  epoch:  5
eval_dx2dx/ prauc :  0.11765672380551918  epoch:  5
eval_rx2dx/ jaccard :  0.022805015629089702  epoch:  5
eval_rx2dx/

train/loss: 0.4475806798946344  epoch:  2
***** Running eval *****
eval_dx2dx/ jaccard :  0.05115907748722975  epoch:  2
eval_dx2dx/ f1 :  0.09062488184943687  epoch:  2
eval_dx2dx/ prauc :  0.13788618912859568  epoch:  2
eval_rx2dx/ jaccard :  0.04189474196502097  epoch:  2
eval_rx2dx/ f1 :  0.06932140160840998  epoch:  2
eval_rx2dx/ prauc :  0.1203952076780775  epoch:  2
eval_dx2rx/ jaccard :  0.3509381042432583  epoch:  2
eval_dx2rx/ f1 :  0.5082738259523663  epoch:  2
eval_dx2rx/ prauc :  0.6003237902370371  epoch:  2
eval_rx2rx/ jaccard :  0.4136189762602417  epoch:  2
eval_rx2rx/ f1 :  0.5696973029705219  epoch:  2
eval_rx2rx/ prauc :  0.6432612116454799  epoch:  2
***** Running training *****
train/loss: 0.4346271086680262  epoch:  3
***** Running eval *****
eval_dx2dx/ jaccard :  0.039683941380631166  epoch:  3
eval_dx2dx/ f1 :  0.06702107574835348  epoch:  3
eval_dx2dx/ prauc :  0.13622499858620746  epoch:  3
eval_rx2dx/ jaccard :  0.028150506799744276  epoch:  3
eval_rx2dx/ f