In [62]:
import os
import pickle
import torch
import random
from torch_geometric.loader import DataListLoader as GraphLoader
from torch_geometric.data import Batch
from src.utils.datasets import GDSet, split_dataset
from src.models.transformer import BertForNDP
from src.models.graphtransformer import BertConfig
from sklearn.model_selection import ShuffleSplit
# Set CUDA_LAUNCH_BLOCKING=1 in this process's environment.
# NOTE(review): this is normally a debugging switch (it makes CUDA kernel
# launches synchronous, which slows training) — confirm it is still wanted
# for full fine-tuning runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
# Directory holding the pickled dataset that the following cell opens.
path = './data/'

In [3]:
# Deserialise the dataset object from 'pretrained_data_pad1000.pkl' under `path`.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with open(os.path.join(path, 'pretrained_data_pad1000.pkl'), 'rb') as handle:
    dataset = pickle.load(handle)

In [4]:
# Split sizes: 10% of the samples for validation and 70% for training, both
# truncated to whole samples; the test split takes every remaining sample.
val_l = int(len(dataset) * 0.10)
train_l = int(len(dataset) * 0.70)
test_l = len(dataset) - (train_l + val_l)

# Stored in model_config as 'number_output'.
number_output = 100

In [63]:
# Output locations for the fine-tuning run.
file_config = dict(
    model_path='./results/finetune/',  # where to save the model
    file_name='log.txt',  # log file
)

# Settings shared across the run; 'max_seq_len' is copied into
# train_params['max_len_seq'] further down.
global_params = dict(
    max_seq_len=50,
    month=1,
    gradient_accumulation_steps=1,
)

# Optimiser hyper-parameters.
optim_param = dict(
    lr=3e-5,
    warmup_proportion=0.1,
    weight_decay=0.01,
)

# Run-level settings for fine-tuning: batching, device selection, split sizes
# and the number of epochs.
train_params = {
    'batch_size': 64,
    # Derived from the actual CUDA availability (previously hard-coded True) so
    # the flag can never contradict 'device' on a CPU-only host.
    'use_cuda': torch.cuda.is_available(),
    'max_len_seq': global_params['max_seq_len'],
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    # Dataset and split sizes computed in the earlier cells.
    'data_len' : len(dataset),
    'train_data_len' : train_l,
    'val_data_len' :  val_l,
    'test_data_len' : test_l,
    'epochs' : 30,
    'action' : 'train'
}

# Architecture hyper-parameters handed to BertConfig. Key order and values are
# the same as before; only the construction style and comments changed.
model_config = dict(
    vocab_size=15322,  # number of disease + symbols for word embedding
    edge_relationship_size=12,  # number of vocab for edge_attr
    hidden_size=50 * 5,  # word embedding and seg embedding hidden size
    seg_vocab_size=2,  # number of vocab for seg embedding
    age_vocab_size=151,  # number of vocab for age embedding
    feature_dict=15322,
    type_vocab_size=11 + 1,  # number of vocab for type embedding + 1 for mask
    max_position_embedding=train_params['max_len_seq'],  # maximum number of tokens
    hidden_dropout_prob=0.2,  # dropout rate
    graph_dropout_prob=0.2,  # dropout rate
    num_hidden_layers=6,  # number of multi-head attention layers required
    num_attention_heads=2,  # number of attention heads
    attention_probs_dropout_prob=0.2,  # multi-head attention dropout rate
    intermediate_size=512,  # size of the "intermediate" layer in the transformer encoder
    hidden_act='gelu',  # non-linear activation in the encoder and the pooler; "gelu", 'relu', 'swish' are supported
    initializer_range=0.02,  # parameter weight initializer range
    number_output=number_output,
    n_layers=3,
    alpha=0.1,
)

In [64]:
# Build the model from model_config and move it to the selected device.
conf = BertConfig(model_config)
behrt = BertForNDP(conf)

behrt = behrt.to(train_params['device'])

# Optimise every model parameter. The hyper-parameters are read from
# optim_param (instead of a second hard-coded lr) so that editing the config
# cell actually changes the optimiser; with the current config the values are
# the same as before (lr=3e-5, weight_decay=0.01, which is AdamW's default).
# NOTE(review): optim_param['warmup_proportion'] is not applied anywhere in
# the visible cells — confirm whether a warm-up scheduler is intended.
transformer_vars = list(behrt.parameters())
optim_behrt = torch.optim.AdamW(
    transformer_vars,
    lr=optim_param['lr'],
    weight_decay=optim_param['weight_decay'],
)

In [65]:
# Warm-start the model from the pre-training checkpoint: checkpoint entries
# are re-keyed under 'bert.embeddings.word_embeddings.' and copied into the
# model's state dict wherever such a key exists.
# NOTE(review): torch.load unpickles the file; only load checkpoints that come
# from a trusted source.
pretrained_dict = torch.load("./results/weights/GraphTransformer_pretrain_num_0.pch", map_location=train_params['device'])
model_dict = behrt.state_dict()
# 1. re-key the checkpoint entries and drop those the model does not have
pretrained_dict = {f'bert.embeddings.word_embeddings.{k}': v for k, v in pretrained_dict.items()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Fail loudly when nothing matched: model_dict already contains every model
# key, so load_state_dict below reports "All keys matched successfully" even
# if no pretrained tensor was transferred at all.
if not pretrained_dict:
    raise RuntimeError(
        "No checkpoint key matched the model under the prefix "
        "'bert.embeddings.word_embeddings.'; pretrained weights were not loaded."
    )
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
behrt.load_state_dict(model_dict)

<All keys matched successfully>

In [66]:
# Inspection only: list the parameter names held in the merged state dict.
model_dict.keys()

odict_keys(['bert.embeddings.cls_token', 'bert.embeddings.word_embeddings.conv.module_0.lin_key.weight', 'bert.embeddings.word_embeddings.conv.module_0.lin_key.bias', 'bert.embeddings.word_embeddings.conv.module_0.lin_query.weight', 'bert.embeddings.word_embeddings.conv.module_0.lin_query.bias', 'bert.embeddings.word_embeddings.conv.module_0.lin_value.weight', 'bert.embeddings.word_embeddings.conv.module_0.lin_value.bias', 'bert.embeddings.word_embeddings.conv.module_0.layernorm1.weight', 'bert.embeddings.word_embeddings.conv.module_0.layernorm1.bias', 'bert.embeddings.word_embeddings.conv.module_0.layernorm2.weight', 'bert.embeddings.word_embeddings.conv.module_0.layernorm2.bias', 'bert.embeddings.word_embeddings.conv.module_0.proj.weight', 'bert.embeddings.word_embeddings.conv.module_0.proj.bias', 'bert.embeddings.word_embeddings.conv.module_0.ffn.weight', 'bert.embeddings.word_embeddings.conv.module_0.ffn.bias', 'bert.embeddings.word_embeddings.conv.module_0.ffn2.weight', 'bert.embe

In [None]:
# Wrap each split in GDSet and build its loader; both loaders use the
# configured batch size and keep the stored sample order (shuffle=False).
# NOTE(review): trainDSet / valDSet are not created in the visible cells —
# confirm where the split is performed (train_l / val_l are computed above but
# not applied here).
# NOTE(review): the training loader is not shuffled — confirm this is intended.
trainload, valload = (
    GraphLoader(GDSet(split), batch_size=train_params['batch_size'], shuffle=False)
    for split in (trainDSet, valDSet)
)


In [None]:
# Call train() with both loaders and the selected device; its two return
# values are kept as train_loss and val_loss.
# NOTE(review): `train` is neither defined nor imported in the visible cells —
# confirm where it comes from so the notebook survives Restart & Run All.
train_loss, val_loss = train(trainload, valload, train_params['device'])