In [1]:
import os
import pickle
import torch
import math
import time
import itertools
from torch_geometric.loader import DataListLoader as GraphLoader
from torch_geometric.data import Batch
from src.utils.datasets import GDSet, split_dataset
from src.models.transformer import BertForNDP
from src.models.graphtransformer import BertConfig
from sklearn.model_selection import ShuffleSplit
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
path = './data/'

In [3]:
with open(os.path.join(path, 'pretrained_data_pad_v2_1000.pkl'), 'rb') as handle:
    dataset = pickle.load(handle)

In [39]:
file_config = {
    'model_path': './results/finetune/', # where to save model
    'file_name': 'log.txt',  # log path
}

global_params = {
    'max_seq_len': 50,
    'month': 1,
    'gradient_accumulation_steps': 1
}

optim_param = {
    'lr': 3e-5,
    'warmup_proportion': 0.1,
    'weight_decay': 0.01
}

train_params = {
    'batch_size': 64,
    'use_cuda': False,
    'max_len_seq': global_params['max_seq_len'],
    # 'device': "cuda" if torch.cuda.is_available() else "cpu",
    'device': "cpu",
    'data_len' : len(dataset),
    'val_split' : 0.1,
    'test_split' : 0.2,
    'epochs' : 30,
    'action' : 'train'
}

model_config = {
    'vocab_size': 15322, # number of disease + symbols for word embedding
    'edge_relationship_size': 8, # number of vocab for edge_attr
    'hidden_size': 108*5, # word embedding and seg embedding hidden size
    'seg_vocab_size': 2, # number of vocab for seg embedding
    'age_vocab_size': 151, # number of vocab for age embedding
    'feature_dict':15322,
    'type_vocab_size': 11+1, # number of vocab for type embedding + 1 for mask
    'max_position_embedding': train_params['max_len_seq'], # maximum number of tokens
    'hidden_dropout_prob': 0.2, # dropout rate
    'graph_dropout_prob': 0.2, # dropout rate
    'num_hidden_layers': 6, # number of multi-head attention layers required
    'num_attention_heads': 2, # number of attention heads
    'attention_probs_dropout_prob': 0.2, # multi-head attention dropout rate
    'intermediate_size': 512, # the size of the "intermediate" layer in the transformer encoder
    'hidden_act': 'gelu', # The non-linear activation function in the encoder and the pooler "gelu", 'relu', 'swish' are supported
    'initializer_range': 0.02, # parameter weight initializer range
    'number_output' : 100,
    'n_layers' : 3,
    'alpha' : 0.1
}

In [40]:
total_p = train_params | model_config

In [41]:
total_p

{'batch_size': 64,
 'use_cuda': False,
 'max_len_seq': 50,
 'device': 'cpu',
 'data_len': 1001,
 'val_split': 0.1,
 'test_split': 0.2,
 'epochs': 30,
 'action': 'train',
 'vocab_size': 15322,
 'edge_relationship_size': 8,
 'hidden_size': 540,
 'seg_vocab_size': 2,
 'age_vocab_size': 151,
 'feature_dict': 15322,
 'type_vocab_size': 12,
 'max_position_embedding': 50,
 'hidden_dropout_prob': 0.2,
 'graph_dropout_prob': 0.2,
 'num_hidden_layers': 6,
 'num_attention_heads': 2,
 'attention_probs_dropout_prob': 0.2,
 'intermediate_size': 512,
 'hidden_act': 'gelu',
 'initializer_range': 0.02,
 'number_output': 100,
 'n_layers': 3,
 'alpha': 0.1}

In [33]:
import pickle

with open('./data/dict/dic_global_v2.pkl', 'rb') as f:
    dic_global = pickle.load(f)

In [34]:
len(dic_global)

14443

In [35]:
conf = BertConfig(model_config)
behrt = BertForNDP(conf)

behrt = behrt.to(train_params['device'])

#models parameters
transformer_vars = [i for i in behrt.parameters()]
optim_behrt = torch.optim.AdamW(transformer_vars, lr=3e-5)

In [36]:
pretrained_dict = torch.load("./results/weights/GraphTransformer_pretrain_num_0_v2.pch", map_location=train_params['device'])
model_dict = behrt.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {f'bert.embeddings.word_embeddings.{k}': v for k, v in pretrained_dict.items()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
behrt.load_state_dict(model_dict)

RuntimeError: Error(s) in loading state_dict for BertForNDP:
	size mismatch for bert.embeddings.word_embeddings.embed.weight: copying a param with shape torch.Size([15322, 108]) from checkpoint, the shape in current model is torch.Size([14443, 108]).

In [26]:
trainDSet, valDSet, testDSet = split_dataset(dataset, train_params, random_seed=0)
trainload = GraphLoader(GDSet(trainDSet), batch_size=train_params['batch_size'], shuffle=False)
valload = GraphLoader(GDSet(valDSet), batch_size=train_params['batch_size'], shuffle=False)
valload = GraphLoader(GDSet(valDSet), batch_size=train_params['batch_size'], shuffle=False)

In [27]:
def save_model(_model_dict, file_name):
    torch.save(_model_dict, file_name)

In [28]:
def run_epoch(e, trainload, device):
    tr_loss = 0
    start = time.time()
    behrt.train()
    for step, data in enumerate(trainload):
        optim_behrt.zero_grad()
        batched_data = Batch()
        data_x, data_y = zip(*data)
        labels = torch.stack(data_y)
        graph_batch = batched_data.from_data_list(list(itertools.chain.from_iterable(data_x)))
        graph_batch = graph_batch.to(device)
        nodes = graph_batch.x
        edge_index = graph_batch.edge_index
        edge_attr = graph_batch.edge_attr
        batch = graph_batch.batch
        age_ids = torch.reshape(graph_batch.age, [graph_batch.age.shape[0] // 50, 50])
        time_ids = torch.reshape(graph_batch.time, [graph_batch.time.shape[0] // 50, 50])
        type_ids = torch.reshape(graph_batch.type, [graph_batch.type.shape[0] // 50, 50])
        posi_ids = torch.reshape(graph_batch.posi_ids, [graph_batch.posi_ids.shape[0] // 50, 50])
                
        loss, logits = behrt(nodes, edge_index, edge_attr, batch, \
                             age_ids, time_ids, type_ids, posi_ids, labels)
        loss.backward()
        tr_loss += loss.item()
        if step % 500 == 0:
            print(loss.item())
        optim_behrt.step()
        del loss
    cost = time.time() - start
    return tr_loss, cost

In [29]:
def train(trainload, valload, device):
    best_val = math.inf
    for e in range(train_params["epochs"]):
        print("Epoch n" + str(e))
        train_loss, train_time_cost = run_epoch(e, trainload, device)
        val_loss, val_time_cost,pred, label, mask = eval(valload, False, device)
        train_loss = (train_loss * train_params['batch_size']) / len(trainload)
        val_loss = (val_loss * train_params['batch_size']) / len(valload)
        print('TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
        print("Epoch n" + str(e) + '\n TRAIN {}\t{} secs\n'.format(train_loss, train_time_cost))
        print('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost) + '\n\n\n')
        print('EVAL {}\t{} secs\n'.format(val_loss, val_time_cost))
        if val_loss < best_val:
            print("** ** * Saving fine - tuned model ** ** * ")
            model_to_save = behrt.module if hasattr(behrt, 'module') else behrt
            save_model(model_to_save.state_dict(), './finetune/GraphTransformer_finetune_num_test.pch')
            best_val = val_loss
    return train_loss, val_loss


In [30]:
def eval(_valload, device):
    val_loss = 0
    start = time.time()
    behrt.eval()
    for step, data in enumerate(_valload):
        optim_behrt.zero_grad()
        batched_data = Batch()
        data_x, data_y = zip(*data)
        labels = torch.stack(data_y).to(device)
        graph_batch = batched_data.from_data_list(list(itertools.chain.from_iterable(data_x)))
        graph_batch = graph_batch.to(device)
        nodes = graph_batch.x
        edge_index = graph_batch.edge_index
        edge_attr = graph_batch.edge_attr
        batch = graph_batch.batch
        age_ids = torch.reshape(graph_batch.age, [graph_batch.age.shape[0] // 50, 50])
        time_ids = torch.reshape(graph_batch.time, [graph_batch.time.shape[0] // 50, 50])
        type_ids = torch.reshape(graph_batch.type, [graph_batch.type.shape[0] // 50, 50])
        posi_ids = torch.reshape(graph_batch.posi_ids, [graph_batch.posi_ids.shape[0] // 50, 50])
        loss, logits = behrt(nodes, edge_index, edge_attr, batch, \
                             age_ids, time_ids, type_ids,posi_ids, data_y)

        val_loss += loss.item()
        del loss

    print("TOTAL LOSS", (val_loss * train_params['batch_size']) / len(_valload))

    cost = time.time() - start
    return val_loss, cost, logits, labels

In [31]:
train_loss, val_loss = train(trainload, valload, train_params['device'])

Epoch n0


IndexError: index out of range in self