In [2]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
from tqdm import tqdm, trange
from ipynb.fs.full.Bert import BertModel, LayerNorm, FuseEmbeddings
from ipynb.fs.full.PreTraining import t2n, metric_report
import numpy as np
import os

## Fine Tuning
Fine-tuning is the last step in G-BERT. This code is taken largely unchanged from the G-BERT GitHub repository.

In [3]:
class MappingHead(nn.Module):
    """Single square linear projection followed by a ReLU.

    Args:
        hidden_size: Width of both the input and the output features.
            Defaults to 300, the value that was previously hard-coded, so
            ``MappingHead()`` builds exactly the same layer as before.
    """

    def __init__(self, hidden_size=300):
        super().__init__()
        # Keep the attribute name and Sequential layout: the parameter names
        # ("dense.0.weight", "dense.0.bias") are part of saved checkpoints.
        self.dense = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())

    def forward(self, input):
        """Apply the projection and ReLU to ``input`` (last dim = hidden_size)."""
        return self.dense(input)

In [4]:
class FineTuning(nn.Module):
    """BERT encoder with two mapping heads and a multi-label classifier.

    The classifier consumes the concatenation of three pooled vectors
    (3 x 300 = 900 features) and emits one logit per entry of
    ``data["multi_visit_drugs"]``.

    Args:
        data: Mapping that provides at least ``"vocab"``, ``"all_conditions"``,
            ``"all_drugs"`` and ``"multi_visit_drugs"``.
        useGraph: Forwarded unchanged to ``BertModel``.
    """

    def __init__(self, data, useGraph):
        super().__init__()
        # NOTE(review): 300 and 0.4 are presumably the hidden size and the
        # dropout rate of BertModel -- confirm against its signature.
        self.bert = BertModel(len(data["vocab"]), 300, 0.4, useGraph, data["all_conditions"], data["all_drugs"])
        self.dense = nn.ModuleList([MappingHead(), MappingHead()])
        self.cls = nn.Sequential(nn.Linear(900, 600), nn.ReLU(), nn.Linear(600, len(data["multi_visit_drugs"])))

        self.apply(self.init_bert_weights)

    def init_bert_weights(self, module):
        '''
        Initialise a single sub-module; meant to be passed to ``nn.Module.apply``.

        Linear/Embedding weights are drawn from N(0, 0.02), LayerNorm is reset
        to weight 1 / bias 0, and Linear biases are zeroed.

        Taken from https://github.com/huggingface/transformers/blob/78b7debf56efb907c6af767882162050d4fbb294/src/transformers/modeling_utils.py#L1596
        '''
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, input_ids, rx_labels):
        """Compute the multi-label loss and logits for one sequence of visits.

        Args:
            input_ids: 2-D tensor of token ids; the pooled encoder output is
                reshaped to ``(2, adm, H)``, so the row count is expected to
                be ``2 * adm``.
            rx_labels: Float tensor of targets with one row per prediction;
                row ``i`` is paired with the pooled vector at index ``i + 1``,
                so it must have at most ``adm - 1`` rows (and at least one).

        Returns:
            ``(loss, rx_logits)`` where ``loss`` is the scalar
            binary-cross-entropy-with-logits loss and ``rx_logits`` has the
            same number of rows as ``rx_labels``.
        """
        num_rows, seq_len = input_ids.size(0), input_ids.size(1)

        # One row of zeros and one row of ones, tiled num_rows // 2 times
        # (at least once), i.e. the type id alternates 0, 1, 0, 1, ... per row.
        token_types_ids = torch.cat(
            [torch.zeros((1, seq_len)), torch.ones((1, seq_len))], dim=0
        ).long().to(input_ids.device)
        token_types_ids = token_types_ids.repeat(max(1, num_rows // 2), 1)

        # bert_pool: (2*adm, H)
        _, bert_pool = self.bert(input_ids, token_types_ids)
        # NOTE(review): the token-type ids above alternate row by row, whereas
        # this view treats the first half of the rows as group 0 and the
        # second half as group 1. One of the two layouts is likely wrong for
        # any given dataset ordering -- confirm against the dataset that
        # builds input_ids before changing either.
        bert_pool = bert_pool.view(2, -1, bert_pool.size(1))  # (2, adm, H)
        dx_bert_pool = self.dense[0](bert_pool[0])  # (adm, H)
        rx_bert_pool = self.dense[1](bert_pool[1])  # (adm, H)

        # mean and concat for rx prediction task
        rx_logits = []
        for i in range(rx_labels.size(0)):
            # Running mean over entries 0..i of each group.
            dx_mean = torch.mean(dx_bert_pool[0:i+1, :], dim=0, keepdim=True)
            rx_mean = torch.mean(rx_bert_pool[0:i+1, :], dim=0, keepdim=True)

            # History means plus the group-0 vector of the *next* entry.
            concat = torch.cat(
                [dx_mean, rx_mean, dx_bert_pool[i+1, :].unsqueeze(dim=0)], dim=-1)
            rx_logits.append(self.cls(concat))

        rx_logits = torch.cat(rx_logits, dim=0)

        loss = F.binary_cross_entropy_with_logits(rx_logits, rx_labels)
        return loss, rx_logits

    @staticmethod
    def from_pretrained(data, useGraph, outputFileName):
        """Build a ``FineTuning`` model and load matching weights from a file.

        Args:
            data: Passed to the ``FineTuning`` constructor.
            useGraph: Passed to the ``FineTuning`` constructor.
            outputFileName: Path of a ``state_dict`` checkpoint saved with
                ``torch.save``.

        Returns:
            The model, on CPU, with every parameter whose name matches a
            checkpoint entry overwritten. Parameters absent from the
            checkpoint keep their fresh initialisation; the differences are
            printed rather than raised.
        """
        # Instantiate model.
        model = FineTuning(data, useGraph)

        weights_path = os.path.join("", outputFileName)
        # map_location="cpu" lets a checkpoint written on a GPU machine be
        # loaded where CUDA is unavailable; the caller moves the model later.
        state_dict = torch.load(weights_path, map_location="cpu")

        # Rename legacy LayerNorm parameter names (gamma/beta -> weight/bias).
        # Iterate over a snapshot because the dict is mutated in the loop.
        for key in list(state_dict.keys()):
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                state_dict[new_key] = state_dict.pop(key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(
                prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')

        # Surface what the loader found instead of discarding it; a head that
        # silently stays randomly initialised is otherwise very hard to spot.
        if missing_keys:
            print("Weights of FineTuning not initialized from checkpoint:", missing_keys)
        if unexpected_keys:
            print("Checkpoint weights not used by FineTuning:", unexpected_keys)
        if error_msgs:
            print("Errors while loading checkpoint into FineTuning:\n\t" + "\n\t".join(error_msgs))

        return model


In [5]:
def fine_tuning(data, train_dataloader, eval_dataloader, test_dataloader, outputFileName, useGraph,
                epochs=5, lr=5e-4, no_cuda=False):
    """Fine-tune a ``FineTuning`` model and keep the best checkpoint.

    The model is initialised from the checkpoint at ``outputFileName``; after
    each epoch it is evaluated on ``eval_dataloader`` and, whenever the
    ``'prauc'`` metric improves, its ``state_dict`` is written back to the
    same ``outputFileName`` (overwriting the file that was loaded).

    Args:
        data: Passed to ``FineTuning.from_pretrained``.
        train_dataloader: Iterable of ``(input_ids, dx_labels, rx_labels)``
            tensors; dimension 0 of each tensor is squeezed before use.
        eval_dataloader: Iterable with the same structure, used for model
            selection.
        test_dataloader: Currently unused; kept for interface compatibility.
        outputFileName: Checkpoint path that is both read and overwritten.
        useGraph: Passed to ``FineTuning.from_pretrained``.
        epochs: Number of training epochs (default 5, the former constant).
        lr: Adam learning rate (default 5e-4, the former constant).
        no_cuda: Force CPU even when CUDA is available. Replaces a reference
            to an undefined global ``args.no_cuda``.
    """
    print("***** Running Fine Tuning *****")
    model = FineTuning.from_pretrained(data, useGraph, outputFileName)

    device = torch.device("cuda" if torch.cuda.is_available() and not no_cuda else "cpu")
    model.to(device)

    # Only save the model itself (unwrap a wrapper exposing `.module`).
    model_to_save = model.module if hasattr(model, 'module') else model
    dx_output_model_file = os.path.join('', outputFileName)

    optimizer = Adam(model.parameters(), lr=lr)

    rx_acc_best = 0
    acc_name = 'prauc'
    global_step = 0  # counts completed epochs

    for _ in range(epochs):
        print("***** Running training *****")
        tr_loss = 0
        nb_tr_steps = 0
        model.train()
        for batch in train_dataloader:
            # Move to the device and drop dimension 0 of every tensor.
            input_ids, _dx_labels, rx_labels = (t.to(device).squeeze(dim=0) for t in batch)
            loss, rx_logits = model(input_ids, rx_labels)
            loss.backward()

            tr_loss += loss.item()
            nb_tr_steps += 1

            optimizer.step()
            optimizer.zero_grad()

        global_step += 1
        # max(..., 1) avoids a ZeroDivisionError on an empty training loader.
        print('train/loss:', tr_loss / max(nb_tr_steps, 1), " epoch: ", global_step)

        print("***** Running eval *****")
        model.eval()
        rx_y_preds = []
        rx_y_trues = []
        for eval_input in eval_dataloader:
            # squeeze(dim=0) matches the training loop; a bare squeeze() would
            # also drop any other size-1 dimension.
            input_ids, _dx_labels, rx_labels = (t.to(device).squeeze(dim=0) for t in eval_input)
            with torch.no_grad():
                loss, rx_logits = model(input_ids, rx_labels)
                rx_y_preds.append(t2n(torch.sigmoid(rx_logits)))
                rx_y_trues.append(t2n(rx_labels))

        if not rx_y_preds:
            # np.concatenate would fail on an empty list; nothing to score.
            print("eval/ skipped: eval_dataloader produced no batches", " epoch: ", global_step)
            continue

        rx_acc_container = metric_report(np.concatenate(rx_y_preds, axis=0), np.concatenate(rx_y_trues, axis=0))
        for k, v in rx_acc_container.items():
            print("eval/", k, ": ", v, " epoch: ", global_step)

        if rx_acc_container[acc_name] > rx_acc_best:
            rx_acc_best = rx_acc_container[acc_name]
            # save model
            torch.save(model_to_save.state_dict(), dx_output_model_file)
