In [1]:
import torch
from ipynb.fs.full.Bert import BertModel, LayerNorm
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score
import numpy as np
import os

## Pre-Training
Pre-training uses modified versions of the standard BERT pre-training tasks. Instead of the masked language model and next-sentence prediction tasks, G-BERT uses a self-prediction task (predicting conditions from the condition codes, and drugs from the drug codes) and a dual-prediction task (predicting drugs from conditions, and conditions from drugs). This code is taken largely from the G-BERT GitHub repository.

In [2]:
def t2n(x):
    """Return ``x`` as a NumPy array, detached from autograd and moved to host memory."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def multi_label_metric(y_gt, y_pred, y_prob):
    """Per-sample multi-label metrics, averaged over all samples.

    Every argument is indexed row by row (``y[b]`` is one sample), so they are
    expected to be 2-D arrays with one row per sample.

    Args:
        y_gt: ground-truth labels; entries equal to 1 are positives.
        y_pred: binarised predictions; entries equal to 1 are predicted positives.
        y_prob: predicted scores, used only for PR-AUC.

    Returns:
        ``(jaccard, prauc, avg_precision, avg_recall, avg_f1)``: each is the mean
        over samples of the corresponding per-sample score. A sample whose
        denominator is empty (no labels and/or no predictions) scores 0.

    Note:
        Only the returned metrics are computed. ROC-AUC in particular is not:
        sklearn's ``roc_auc_score`` raises ``ValueError`` for any sample whose
        labels are all 0 (or all 1).
    """

    def jaccard(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = set(np.where(y_gt[b] == 1)[0])
            out_list = set(np.where(y_pred[b] == 1)[0])
            union = out_list | target
            # Empty union (no true and no predicted labels) would divide by zero.
            score.append(len(out_list & target) / len(union) if union else 0)
        return np.mean(score)

    def average_prc(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            prc_score = 0 if len(out_list) == 0 else len(inter) / len(out_list)
            score.append(prc_score)
        return score

    def average_recall(y_gt, y_pred):
        score = []
        for b in range(y_gt.shape[0]):
            target = np.where(y_gt[b] == 1)[0]
            out_list = np.where(y_pred[b] == 1)[0]
            inter = set(out_list) & set(target)
            recall_score = 0 if len(target) == 0 else len(inter) / len(target)
            score.append(recall_score)
        return score

    def average_f1(average_prc, average_recall):
        score = []
        for prc, rec in zip(average_prc, average_recall):
            if prc + rec == 0:
                score.append(0)
            else:
                score.append(2 * prc * rec / (prc + rec))
        return score

    def precision_auc(y_gt, y_prob):
        all_micro = []
        for b in range(len(y_gt)):
            all_micro.append(average_precision_score(
                y_gt[b], y_prob[b], average='macro'))
        return np.mean(all_micro)

    prauc = precision_auc(y_gt, y_prob)
    ja = jaccard(y_gt, y_pred)
    avg_prc = average_prc(y_gt, y_pred)
    avg_recall = average_recall(y_gt, y_pred)
    avg_f1 = average_f1(avg_prc, avg_recall)

    return ja, prauc, np.mean(avg_prc), np.mean(avg_recall), np.mean(avg_f1)


def metric_report(y_pred, y_true, therhold=0.5):
    """Threshold predicted probabilities and summarise multi-label metrics.

    Args:
        y_pred: array of predicted probabilities, same shape as ``y_true``.
            It is not modified; a separate thresholded array feeds the
            set-based metrics.
        y_true: array of 0/1 ground-truth labels.
        therhold: decision threshold. Scores strictly greater than it count as
            positive. The (misspelled) name is kept for keyword callers.

    Returns:
        dict with keys ``'jaccard'``, ``'f1'`` (mean per-sample F1) and ``'prauc'``.
    """
    y_prob = np.asarray(y_pred)
    # Build the 0/1 prediction matrix as a new array rather than writing into
    # the caller's buffer, and in a single, order-independent step.
    y_bin = (y_prob > therhold).astype(y_prob.dtype)

    ja, prauc, avg_p, avg_r, avg_f1 = multi_label_metric(y_true, y_bin, y_prob)

    return {
        'jaccard': ja,
        'f1': avg_f1,
        'prauc': prauc,
    }

In [3]:
class BERT_Pretrain(nn.Module):
    """G-BERT pre-training model: a shared BERT encoder plus four prediction heads.

    The condition sequence and the drug sequence are encoded separately by the
    same ``BertModel``. Four two-layer MLP heads then predict:

    * conditions from conditions (``cls_conditions_1``) and drugs from drugs
      (``cls_drugs_2``): the self-prediction task;
    * conditions from drugs (``cls_drugs_1``) and drugs from conditions
      (``cls_conditions_2``): the dual-prediction task.
    """

    def __init__(self, data, hidden_size, dropout_prob, useGraph):
        """
        Args:
            data: dict with at least the keys "all_conditions", "all_drugs" and
                "vocab"; their lengths size the output heads and the vocabulary.
            hidden_size: width of the BERT pooled output and of the MLP heads.
            dropout_prob: dropout probability forwarded to ``BertModel``.
            useGraph: flag forwarded to ``BertModel`` (presumably toggles the
                ontology-graph embedding; TODO confirm in the Bert notebook).
        """
        super().__init__()
        self.all_conditions_size = len(data["all_conditions"])
        self.all_drugs_size = len(data["all_drugs"])

        self.bert = BertModel(len(data["vocab"]), hidden_size, dropout_prob, useGraph,
                              data["all_conditions"], data["all_drugs"])

        # Registration order (and hence state_dict keys / init order) matches
        # earlier checkpoints: conditions_1, drugs_1, conditions_2, drugs_2.
        self.cls_conditions_1 = self._mlp_head(hidden_size, self.all_conditions_size)  # condition -> condition
        self.cls_drugs_1 = self._mlp_head(hidden_size, self.all_conditions_size)       # drug -> condition
        self.cls_conditions_2 = self._mlp_head(hidden_size, self.all_drugs_size)       # condition -> drug
        self.cls_drugs_2 = self._mlp_head(hidden_size, self.all_drugs_size)            # drug -> drug

        self.apply(self.init_bert_weights)

    @staticmethod
    def _mlp_head(hidden_size, out_size):
        """Linear -> ReLU -> Linear classifier head (submodules named '0', '1', '2')."""
        return nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Linear(hidden_size, out_size))

    def init_bert_weights(self, module):
        '''
        Initialise ``module`` in place: N(0, 0.02) weights for Linear/Embedding,
        zero bias for Linear, and identity (weight 1, bias 0) for LayerNorm.

        Taken from https://github.com/huggingface/transformers/blob/78b7debf56efb907c6af767882162050d4fbb294/src/transformers/modeling_utils.py#L1596
        '''
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, inputs, conditions_labels=None, drugs_labels=None):
        """Run both encoders and the four heads.

        Args:
            inputs: token ids indexed as (B, 2, max_len). ``inputs[:, 0]`` is the
                condition sequence and ``inputs[:, 1]`` the drug sequence.
            conditions_labels: optional multi-hot condition targets, (B, n_conditions).
            drugs_labels: optional multi-hot drug targets, (B, n_drugs).

        Returns:
            Without both label tensors: a 4-tuple of sigmoid probabilities
            ``(condition->condition, drug->condition, condition->drug, drug->drug)``.
            With both: ``(loss, *those four)``, where ``loss`` is the sum of the
            four binary cross-entropy losses computed on the raw logits.
        """
        # Second argument to BertModel (presumably token-type ids): all zeros.
        token_types = torch.zeros(inputs.size(0), inputs.size(2), dtype=torch.long, device=inputs.device)
        _, conditions_bert_pool = self.bert(inputs[:, 0, :], token_types)
        _, drugs_bert_pool = self.bert(inputs[:, 1, :], token_types)

        conditions2condition = self.cls_conditions_1(conditions_bert_pool)
        drug2condition = self.cls_drugs_1(drugs_bert_pool)
        condition2drug = self.cls_conditions_2(conditions_bert_pool)
        drug2drug = self.cls_drugs_2(drugs_bert_pool)

        probs = tuple(torch.sigmoid(logits) for logits in
                      (conditions2condition, drug2condition, condition2drug, drug2drug))

        if drugs_labels is None or conditions_labels is None:
            return probs

        loss = F.binary_cross_entropy_with_logits(conditions2condition, conditions_labels) + \
            F.binary_cross_entropy_with_logits(drug2condition, conditions_labels) + \
            F.binary_cross_entropy_with_logits(condition2drug, drugs_labels) + \
            F.binary_cross_entropy_with_logits(drug2drug, drugs_labels)
        return (loss,) + probs

    @staticmethod
    def from_pretrained(data, useGraph, outputFileName, hidden_size=300, dropout_prob=0.4):
        """Build a ``BERT_Pretrain`` and load weights saved at ``outputFileName``.

        The checkpoint is loaded onto the CPU first (``map_location='cpu'``), so a
        GPU-saved file also loads on CPU-only machines. The caller moves the
        model to its device afterwards. Legacy 'gamma'/'beta' parameter names are
        renamed to 'weight'/'bias'. Missing or unexpected keys are reported.
        Shape mismatches raise instead of silently leaving random weights.

        Raises:
            RuntimeError: if any tensor in the checkpoint could not be copied
                into the model (e.g. a size mismatch).
        """
        model = BERT_Pretrain(data, hidden_size, dropout_prob, useGraph)

        state_dict = torch.load(outputFileName, map_location="cpu")

        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')

        model_name = model.__class__.__name__
        if missing_keys:
            print("Weights of {} not initialized from pretrained model: {}".format(model_name, missing_keys))
        if unexpected_keys:
            print("Weights from pretrained model not used in {}: {}".format(model_name, unexpected_keys))
        if error_msgs:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                model_name, "\n\t".join(error_msgs)))
        return model
        

In [4]:
def _pretrain_train_epoch(model, train_dataloader, optimizer, device):
    """Run one optimisation pass over ``train_dataloader``; return the mean batch loss.

    Each batch is unpacked as ``(input_ids, dx_labels, rx_labels)``. Returns NaN
    (instead of dividing by zero) if the loader yields no batches.
    """
    model.train()
    total_loss, num_steps = 0.0, 0
    for batch in train_dataloader:
        input_ids, dx_labels, rx_labels = (t.to(device) for t in batch)
        loss = model(input_ids, dx_labels, rx_labels)[0]
        loss.backward()
        total_loss += loss.item()
        num_steps += 1
        optimizer.step()
        optimizer.zero_grad()
    return total_loss / num_steps if num_steps else float('nan')


def _pretrain_evaluate(model, eval_dataloader, device):
    """Score the four prediction directions on ``eval_dataloader``.

    Returns a dict mapping 'dx2dx', 'rx2dx', 'dx2rx', 'rx2rx' (in that order) to
    the metric dict produced by ``metric_report``.
    """
    model.eval()
    # Key order matches the order of the model's 4-tuple output.
    preds = {'dx2dx': [], 'rx2dx': [], 'dx2rx': [], 'rx2rx': []}
    dx_y_trues, rx_y_trues = [], []
    with torch.no_grad():
        for batch in eval_dataloader:
            input_ids, dx_labels, rx_labels = (t.to(device) for t in batch)
            for name, probs in zip(preds, model(input_ids)):
                preds[name].append(t2n(probs))
            dx_y_trues.append(t2n(dx_labels))
            rx_y_trues.append(t2n(rx_labels))

    dx_true = np.concatenate(dx_y_trues, axis=0)
    rx_true = np.concatenate(rx_y_trues, axis=0)
    # '*2dx' heads predict conditions (dx labels); '*2rx' heads predict drugs (rx labels).
    targets = {'dx2dx': dx_true, 'rx2dx': dx_true, 'dx2rx': rx_true, 'rx2rx': rx_true}
    return {name: metric_report(np.concatenate(chunks, axis=0), targets[name])
            for name, chunks in preds.items()}


def pretrain(data, train_dataloader, eval_dataloader, outputFileName, usePretrainedModel, useGraph,
             num_epochs=5, lr=5e-4):
    """Pre-train ``BERT_Pretrain`` and checkpoint the best model by rx2rx PR-AUC.

    Args:
        data: dataset dict forwarded to ``BERT_Pretrain``.
        train_dataloader / eval_dataloader: iterables of
            ``(input_ids, dx_labels, rx_labels)`` tensor batches.
        outputFileName: path the best ``state_dict`` is written to (and read
            from when ``usePretrainedModel`` is true).
        usePretrainedModel: resume from ``outputFileName`` instead of a fresh model.
        useGraph: forwarded to ``BERT_Pretrain``.
        num_epochs: number of train+eval epochs (default 5).
        lr: Adam learning rate (default 5e-4).
    """
    print("***** Running Pre-training *****")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if usePretrainedModel:
        model = BERT_Pretrain.from_pretrained(data, useGraph, outputFileName)
    else:
        # hidden_size=300, dropout_prob=0.4: keep in sync with from_pretrained.
        model = BERT_Pretrain(data, 300, 0.4, useGraph)
    model.to(device)

    # If the model is wrapped (e.g. nn.DataParallel exposes .module), save the inner model.
    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join('', outputFileName)

    optimizer = Adam(model.parameters(), lr=lr)

    acc_name = 'prauc'
    best_rx2rx_score = 0
    for epoch in range(1, num_epochs + 1):
        print("***** Running training *****")
        train_loss = _pretrain_train_epoch(model, train_dataloader, optimizer, device)
        print('train/loss:', train_loss, " epoch: ", epoch)

        print("***** Running eval *****")
        eval_results = _pretrain_evaluate(model, eval_dataloader, device)
        for name, acc_container in eval_results.items():
            for k, v in acc_container.items():
                print('eval_' + name + '/', k, ": ", v, " epoch: ", epoch)

        # Checkpoint selection is driven by the rx2rx (drug -> drug) score only.
        if eval_results['rx2rx'][acc_name] > best_rx2rx_score:
            best_rx2rx_score = eval_results['rx2rx'][acc_name]
            torch.save(model_to_save.state_dict(), output_model_file)
        