In [1]:
import sys
import pandas as pd
import os
import re
import glob
import shutil
import numpy as np
import matplotlib.pyplot as plt
import skimage
from pathlib import Path
import shutil
from base64 import b64encode
from base64 import b64decode

# Default size (in inches) of every matplotlib figure created in this notebook.
plt.rcParams["figure.figsize"] = (60,60)

# `IPython.display` is the public location of these helpers; importing them
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
# make the Jupyter notebook use the full screen width
display(HTML("<style>.container { width:99% !important; }</style>"))

In [2]:
from pkg_resources import resource_exists, resource_listdir, resource_string, resource_stream,resource_filename
import xml.etree.ElementTree as ET
import numpy

#from pytorch_transformers.tokenization_bert import BertTokenizer
#from pytorch_transformers.modeling_utils import  CONFIG_NAME
#from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertConfig, BertModel

from transformers.tokenization_bert import BertTokenizer
from transformers import CONFIG_NAME, WEIGHTS_NAME
from transformers.modeling_bert import BertPreTrainedModel, BertConfig, BertModel


from torch import nn
import torch
from torch.nn import LSTM
import torch, math, logging, os
import sys, os, time, socket
from sklearn.metrics import f1_score, precision_score, recall_score


### Parameters:

In [3]:
# Class names of the smoking-status task. The position of a name in this list
# is the index of that class in the one-hot label vectors built further down.
labels = 'PAST SMOKER, CURRENT SMOKER, NON-SMOKER, UNKNOWN'.split(', ')

device = 'cuda:0'
batch_size = 5  # number of documents per training / evaluation batch

# Directory holding the pretrained BERT weights, configuration and vocabulary.
bert_model_path = '/notebook/nas-trainings/arne/OCCAM/text_classification_BERT/code_BERT/bert_document_classification/examples/ml4health_2019_replication/clinicalBERT/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000/'
#bert_model_path='bert-base-uncased'

# Upper bound on the number of sequences of one document that are fed to BERT.
bert_batch_size = 7

# parameters for the optimizer
weight_decay = 0
learning_rate = 6e-5

# Parent directory under which each run creates its own output directory.
model_storage_directory = 'results_test'

### Model directory

In [4]:
# Set run-specific environment configuration: every run writes into its own
# directory, named after the start time and the host it runs on.
run_started = time.strftime("run_%Y_%m_%d_%H_%M_%S")
timestamp = f"{run_started}_{socket.gethostname()}"
model_directory = os.path.join(model_storage_directory, timestamp)
os.makedirs(model_directory, exist_ok=True)

### Handling logging configurations:

In [5]:
# Route INFO-level messages of the root logger both to <model_directory>/log.txt
# and to the notebook output, printing only the bare message text.
log = logging.getLogger()
log.handlers.clear()  # drop handlers left over from an earlier execution of this cell
log.setLevel(logging.INFO)

formatter = logging.Formatter('%(message)s')
fh = logging.FileHandler(os.path.join(model_directory, "log.txt"))
ch = logging.StreamHandler()
for handler in (fh, ch):
    handler.setLevel(logging.INFO)
    handler.setFormatter(formatter)
    log.addHandler(handler)
#log.info(p.format_values())

### Data loader helper function:

In [6]:
def load_n2c2_2006(partition='train'):
    """
    Generator over the records of one partition of the smoking-status data set.

    The XML file of the partition is read from the ``data/`` directory
    (relative to the current working directory), parsed completely, and then
    one tuple per ``RECORD`` element is yielded.

    :param partition: either 'train' or 'test'
    :return: yields (id, note, label) tuples, where id is the record's ``ID``
             attribute, note the text of its ``TEXT`` element and label the
             ``STATUS`` attribute of its ``SMOKING`` element
    """
    assert partition in ['train', 'test']

    # Only the file of the test partition carries "groundtruth" in its name.
    file_templates = {
        'train': "data/smokers_surrogate_%s_all_version2.xml",
        'test': "data/smokers_surrogate_%s_all_groundtruth_version2.xml",
    }
    with open(file_templates[partition] % partition) as raw:
        root = ET.fromstring(raw.read().strip())

    # All records are extracted before the first one is yielded.
    records = [
        (
            record.attrib['ID'],
            record.findall('./TEXT')[0].text,
            record.findall('./SMOKING')[0].attrib['STATUS'],
        )
        for record in root.findall("./RECORD")
    ]
    yield from records

In [7]:
def _split_documents_and_labels(records, label_names):
    """
    Split (id, text, status) records into document texts and one-hot labels.

    :param records: iterable of (id, text, status) tuples
    :param label_names: list of class names; position i corresponds to entry i
                        of every label vector
    :return: (documents, one_hot_labels), two lists of equal length. A status
             that does not occur in `label_names` gives an all-zero vector.
    """
    documents, one_hot_labels = [], []
    for _, text, status in records:
        documents.append(text)
        one_hot_labels.append([1 if name == status else 0 for name in label_names])
    return documents, one_hot_labels


#load the data:
train, dev = load_n2c2_2006(partition='train'), load_n2c2_2006(partition='test')

# Both partitions go through the same encoding, so the label layout of the
# training and the development data cannot drift apart.
train_documents, train_labels = _split_documents_and_labels(train, labels)
dev_documents, dev_labels = _split_documents_and_labels(dev, labels)
    

### The model

In [8]:
class DocumentBertLSTM(BertPreTrainedModel):
    """
    Document classifier built from three stages: BERT encodes every sequence
    of a document, an LSTM runs over the resulting per-sequence vectors, and a
    small head turns the last LSTM output into one score per label.
    """

    def __init__(self, bert_model_config: BertConfig):
        """
        :param bert_model_config: BERT configuration; besides the usual fields
               it must provide `bert_batch_size` and `num_labels`
        """
        super().__init__(bert_model_config)
        hidden_size = bert_model_config.hidden_size
        dropout_prob = bert_model_config.hidden_dropout_prob

        self.bert = BertModel(bert_model_config)
        # Upper bound on the number of sequences per document passed to BERT.
        self.bert_batch_size = self.bert.config.bert_batch_size
        self.dropout = nn.Dropout(p=dropout_prob)
        self.lstm = LSTM(hidden_size, hidden_size)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout_prob),
            nn.Linear(hidden_size, bert_model_config.num_labels),
            nn.Tanh(),
        )
        self.init_weights()

    #input_ids, token_type_ids, attention_masks
    def forward(self, document_batch: torch.Tensor, document_sequence_lengths: list, device='cuda'):
        """
        Score a batch of encoded documents.

        :param document_batch: tensor indexed as [document, sequence, k, token]
               where k = 0, 1, 2 selects the input ids, token type ids and
               attention mask (the layout written by `encode_documents`)
        :param document_sequence_lengths: accepted for interface compatibility;
               not used by this implementation
        :param device: device on which the buffer for the BERT outputs is allocated
        :return: tensor with one row per document and one column per label;
                 values lie in (-1, 1) because the head ends in a Tanh
        """
        num_documents = document_batch.shape[0]
        # Only the first `bert_batch_size` sequences of a document are encoded,
        # so the tail of a longer document is dropped.
        num_sequences = min(document_batch.shape[1], self.bert_batch_size)

        # Buffer for the BERT outputs: (documents, sequences, hidden_size).
        bert_output = torch.zeros(
            size=(num_documents, num_sequences, self.bert.config.hidden_size),
            dtype=torch.float,
            device=device,
        )

        for doc_id in range(num_documents):
            sequences = document_batch[doc_id][:self.bert_batch_size]
            # Element [1] of the BERT output is used as the vector of a sequence
            # (the pooled output in the `transformers` BertModel API).
            sequence_vectors = self.bert(
                sequences[:, 0],
                token_type_ids=sequences[:, 1],
                attention_mask=sequences[:, 2],
            )[1]
            bert_output[doc_id][:self.bert_batch_size] = self.dropout(sequence_vectors)

        # The LSTM expects (sequences, documents, hidden_size).
        self.lstm.flatten_parameters()
        lstm_output, _ = self.lstm(bert_output.permute(1, 0, 2))

        # Classify from the LSTM output at the last sequence position.
        prediction = self.classifier(lstm_output[-1])
        assert prediction.shape[0] == num_documents
        return prediction

    def _set_bert_trainable(self, trainable: bool):
        """Set `requires_grad` of every parameter of the BERT encoder."""
        for param in self.bert.parameters():
            param.requires_grad = trainable

    def freeze_bert_encoder(self):
        """Exclude the BERT encoder parameters from gradient computation."""
        self._set_bert_trainable(False)

    def unfreeze_bert_encoder(self):
        """Include the BERT encoder parameters in gradient computation again."""
        self._set_bert_trainable(True)

### Set up the model

#### load the config file

In [9]:
# Load the BERT configuration. For a local model directory the configuration
# file is looked up under the library's default name first and under
# 'bert_config.json' second; anything that is not an existing path is handed
# to `from_pretrained` as a model name.
if os.path.exists(bert_model_path):
    for config_filename in (CONFIG_NAME, 'bert_config.json'):
        config_path = os.path.join(bert_model_path, config_filename)
        if os.path.exists(config_path):
            print( f"loading {bert_model_path}" )
            config = BertConfig.from_json_file(config_path)
            break
    else:
        # neither candidate file exists in the model directory
        raise ValueError("Cannot find a configuration for the BERT based model you are attempting to load.")
else:
    config = BertConfig.from_pretrained(bert_model_path)

# Task-specific settings that DocumentBertLSTM reads from the configuration.
config.num_labels = len(labels)
config.bert_batch_size = bert_batch_size

print(config)

loading /notebook/nas-trainings/arne/OCCAM/text_classification_BERT/code_BERT/bert_document_classification/examples/ml4health_2019_replication/clinicalBERT/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000/
BertConfig {
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bert_batch_size": 7,
  "bos_token_id": 0,
  "do_sample": false,
  "eos_token_ids": 0,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_beams": 1,
  "num_hidden_layers": 12,
  "num_labels": 4,
  "num_return_sequences": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  

### Train

#### 1) Helper functions for encoding documents (tokenization and token-to-id conversion) and for saving checkpoints

In [10]:
# The tokenizer annotation is quoted so that this helper can be defined
# without the tokenizer class being resolved at definition time.
def encode_documents(documents: list, tokenizer: "BertTokenizer", max_input_length=512):
    """
    Returns a len(documents) * max_sequences_per_document * 3 * max_input_length
    tensor where len(documents) is the batch dimension and the others encode
    the BERT input: index 0 of the third dimension holds the input ids,
    index 1 the token type ids (always 0) and index 2 the attention mask.

    Every document is tokenized, cut to its first 10200 tokens and split into
    consecutive sequences of at most ``max_input_length - 2`` tokens; each
    sequence is wrapped in [CLS] ... [SEP] and padded with zeros up to
    ``max_input_length``. Sequences a document does not need stay all-zero.

    This is the input to any of the document bert architectures.

    :param documents: a non-empty list of text documents
    :param tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``
    :param max_input_length: length of one encoded sequence, including the
                             [CLS] and [SEP] tokens
    :return: (encoded documents, LongTensor with the number of sequences per
             document); a document without any token is reported with length 1
    :raises AssertionError: if a document needs more than 20 sequences
    """
    # two positions of every sequence are taken by [CLS] and [SEP]
    tokens_per_sequence = max_input_length - 2

    tokenized_documents = [tokenizer.tokenize(document)[:10200] for document in documents]  #added by AD (only take first 10200 tokens of each documents as input)
    max_sequences_per_document = math.ceil(max(len(x) / tokens_per_sequence for x in tokenized_documents))
    assert max_sequences_per_document <= 20, "Your document is to large, arbitrary size when writing"

    # The last dimension follows max_input_length instead of a fixed 512, so
    # the function also works for other sequence lengths.
    output = torch.zeros(size=(len(documents), max_sequences_per_document, 3, max_input_length), dtype=torch.long)
    document_seq_lengths = []  # number of sequences generated per document
    for doc_index, tokenized_document in enumerate(tokenized_documents):
        max_seq_index = 0
        for seq_index, start in enumerate(range(0, len(tokenized_document), tokens_per_sequence)):
            tokens = ["[CLS]"] + tokenized_document[start:start + tokens_per_sequence] + ["[SEP]"]

            input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
            padding = [0] * (max_input_length - len(input_ids))
            attention_masks = [1] * len(input_ids) + padding  # 1 = real token, 0 = padding
            input_ids = input_ids + padding
            input_type_ids = [0] * max_input_length  # single-segment input

            assert (len(input_ids) == max_input_length
                    and len(attention_masks) == max_input_length
                    and len(input_type_ids) == max_input_length)

            output[doc_index][seq_index] = torch.tensor(
                [input_ids, input_type_ids, attention_masks], dtype=torch.long)
            max_seq_index = seq_index
        document_seq_lengths.append(max_seq_index + 1)
    return output, torch.LongTensor(document_seq_lengths)

#helper function to save checkpoints

def save_checkpoint( model: DocumentBertLSTM, tokenizer: BertTokenizer , checkpoint_path: str):
    """
    Store the model weights, the model configuration and the tokenizer
    vocabulary in a new directory.

    :param model: model to store; an ``nn.DataParallel`` wrapper is unwrapped first
    :param tokenizer: tokenizer whose vocabulary is saved next to the weights
    :param checkpoint_path: directory to create; it must not exist yet
    :raises ValueError: if `checkpoint_path` already exists
    :return: None
    """
    # never write into (and thereby overwrite) an earlier checkpoint
    if os.path.exists(checkpoint_path):
        raise ValueError("Attempting to save checkpoint to an existing directory")
    os.mkdir(checkpoint_path)
    log.info("Saving checkpoint: %s" % checkpoint_path )

    # the state dict is taken from the wrapped module, not from the wrapper
    net = model.module if isinstance(model, nn.DataParallel) else model
    weights_file = os.path.join(checkpoint_path, WEIGHTS_NAME)
    config_file = os.path.join(checkpoint_path, CONFIG_NAME)

    torch.save(net.state_dict(), weights_file)   # fine-tuned parameters
    net.config.to_json_file(config_file)         # model configuration
    tokenizer.save_vocabulary(checkpoint_path)   # exact vocabulary utilized


#### 2) Train and predict functions

In [11]:
def train( model: DocumentBertLSTM , optimizer: torch.optim.Adam , tokenizer: BertTokenizer, train: tuple, dev: tuple, batch_size: int , output_path:str, labels:list, epochs=10 , device='cuda:0', checkpoint_every=250, eval_every=10 ):
    """
    Train `model` on the training documents with a class-weighted
    BCE-with-logits loss, periodically evaluating and checkpointing.

    :param model: the document classifier to train
    :param optimizer: optimizer over the parameters of `model`
    :param tokenizer: tokenizer used to encode the training and development documents
    :param train: tuple (documents, one-hot labels) used for training
    :param dev: tuple (documents, one-hot labels) used for evaluation
    :param batch_size: number of documents per batch
    :param output_path: directory receiving checkpoints and evaluation files
    :param labels: list of label names, in the order of the label vectors
    :param epochs: number of passes over the training data
    :param device: device on which the model is trained
    :param checkpoint_every: save a checkpoint every this many epochs (positive int)
    :param eval_every: evaluate on `dev` every this many epochs (positive int)
    :return: None
    """
    model.train()

    train_documents, train_labels = train  #train is tuple
    dev_documents, dev_labels = dev

    # Encode with the tokenizer argument, not with a notebook-level global, so
    # that training and evaluation are guaranteed to use the same tokenizer.
    document_representations, document_sequence_lengths = encode_documents(train_documents, tokenizer)

    correct_output = torch.FloatTensor(train_labels)

    assert document_representations.shape[0] == correct_output.shape[0]

    model.to( device=device )

    num_documents = correct_output.shape[0]

    # Per-label weight of the positive class: (#documents / #positives) - 1.
    # The count is clamped to 1 because a label without any positive example
    # would otherwise divide by zero and produce an infinite weight.
    positives_per_label = torch.sum(correct_output, dim=0).clamp(min=1)
    loss_weight = ((num_documents / positives_per_label) - 1).to(device=device)
    loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=loss_weight)

    # A trailing partial batch counts as a batch, which also keeps the count
    # above zero when there are fewer documents than `batch_size`.
    num_batches = math.ceil(num_documents / batch_size)

    for epoch in range(1, epochs + 1):
        # visit the documents in a new random order every epoch
        permutation = torch.randperm(num_documents)
        document_representations = document_representations[permutation]
        document_sequence_lengths = document_sequence_lengths[permutation]
        correct_output = correct_output[permutation]

        epoch_loss = 0.0
        for i in range(0, num_documents, batch_size):
            batch_document_tensors = document_representations[i:i + batch_size].to(device)
            batch_document_sequence_lengths = document_sequence_lengths[i:i + batch_size]
            batch_predictions = model(batch_document_tensors, batch_document_sequence_lengths, device=device)
            batch_correct_output = correct_output[i:i + batch_size].to(device=device)

            loss = loss_function(batch_predictions, batch_correct_output)
            epoch_loss += float(loss.item())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        epoch_loss /= num_batches  # mean loss per batch

        log.info('Epoch %i Completed: %f' % (epoch, epoch_loss))

        if epoch % checkpoint_every == 0:
            save_checkpoint( model, tokenizer, os.path.join( output_path , "checkpoint_%s" % epoch ) )

        # Evaluate on development data; predict() puts the model back into
        # training mode after an evaluation, so the next epoch trains normally.
        if epoch % eval_every == 0:
            predict( model, tokenizer, (dev_documents, dev_labels), batch_size , epoch, output_path, labels, device=device )
        

def predict( model: DocumentBertLSTM, tokenizer: BertTokenizer, data, batch_size:int, epoch:int, output_path:str, labels:list, threshold=0, device='cuda:0' ):
    """
    Predict labels for documents and, when gold labels are supplied, evaluate
    the predictions.

    :param model: the document classifier
    :param tokenizer: tokenizer used to encode the documents
    :param data: either a list of documents (prediction only) or a tuple
                 (documents, one-hot labels) (prediction plus evaluation)
    :param batch_size: number of documents per forward pass
    :param epoch: epoch number, used in the log message and the result file name
    :param output_path: directory receiving ``eval_<epoch>.csv``
    :param labels: list of label names, in the order of the label vectors
    :param threshold: a model score strictly above this value counts as a positive
    :param device: device on which the model is run
    :return: for a list of documents, a 0/1 tensor of shape
             (len(labels), number of documents). For a (documents, labels)
             tuple nothing is returned: per-label precision, recall and F1 are
             logged and written to the result file, and the model is switched
             back to training mode.
    :raises TypeError: if `data` is neither a list nor a 2-tuple
    """
    if isinstance(data, list):
        documents, gold_labels = data, None
    elif isinstance(data, tuple) and len(data) == 2:
        log.info('Evaluating on Epoch %i' % (epoch))
        documents, gold_labels = data
        assert labels is not None
    else:
        raise TypeError("data must be a list of documents or a (documents, labels) tuple")

    document_representations, document_sequence_lengths = encode_documents(documents, tokenizer)
    # gold labels are transposed to (labels, documents) to match the predictions
    correct_output = None if gold_labels is None else torch.FloatTensor(gold_labels).transpose(0,1)

    model.to(device=device )

    #put in eval mode:
    model.eval()
    with torch.no_grad():
        predictions = torch.empty((document_representations.shape[0], len(labels)))
        for i in range(0, document_representations.shape[0], batch_size):
            batch_document_tensors = document_representations[i:i + batch_size].to(device=device)
            batch_document_sequence_lengths = document_sequence_lengths[i:i + batch_size]

            predictions[i:i + batch_size] = model(batch_document_tensors, batch_document_sequence_lengths, device=device)

    # binarise all scores at once, then switch to a (labels, documents) layout
    predictions = (predictions > threshold).float().transpose(0, 1)

    if correct_output is None:
        return predictions.cpu()

    assert correct_output.shape == predictions.shape
    precisions = []
    recalls = []
    fmeasures = []

    for label_idx in range(predictions.shape[0]):
        correct = correct_output[label_idx].cpu().view(-1).numpy()
        predicted = predictions[label_idx].cpu().view(-1).numpy()

        precisions.append(precision_score(correct, predicted, average='binary', pos_label=1))
        recalls.append(recall_score(correct, predicted, average='binary', pos_label=1))
        fmeasures.append(f1_score(correct, predicted, average='binary', pos_label=1))
        log.info('F1\t%s\t%f' % (labels[label_idx], fmeasures[-1]))

    # NOTE(review): both scores are computed on the flattened 0/1 matrix, i.e.
    # over all (label, document) cells treated as one binary problem, not as
    # multi-label averages - confirm that this is the intended metric.
    flat_correct = correct_output.reshape(-1).numpy()
    flat_predicted = predictions.reshape(-1).numpy()
    micro_f1 = f1_score(flat_correct, flat_predicted, average='micro')
    macro_f1 = f1_score(flat_correct, flat_predicted, average='macro')

    # tab-separated table: one column per label, one row per metric
    label_indices = range(predictions.shape[0])
    with open(os.path.join( output_path , "eval_%s.csv" % epoch), 'w') as eval_results:
        eval_results.write('Metric\t' + '\t'.join([labels[label_idx] for label_idx in label_indices]) + '\n')
        eval_results.write('Precision\t' + '\t'.join([str(precisions[label_idx]) for label_idx in label_indices]) + '\n')
        eval_results.write('Recall\t' + '\t'.join([str(recalls[label_idx]) for label_idx in label_indices]) + '\n')
        eval_results.write('F1\t' + '\t'.join([str(fmeasures[label_idx]) for label_idx in label_indices]) + '\n')
        eval_results.write('Micro-F1\t' + str(micro_f1) + '\n')
        eval_results.write('Macro-F1\t' + str(macro_f1) + '\n')

    # the training loop relies on the model being back in training mode
    model.train()

In [None]:
# Tokenizer and document classifier are both initialised from the same
# pretrained BERT location.
bert_tokenizer = BertTokenizer.from_pretrained(bert_model_path)
bert_doc_classifier = DocumentBertLSTM.from_pretrained(bert_model_path, config=config)

# Switch off gradients for the BERT encoder parameters.
bert_doc_classifier.freeze_bert_encoder()

# All model parameters are handed to the optimizer; you could limit this to
# the lstm parameters instead (as opposed to using torch.no_grad() in the
# DocumentBertLSTM class).
optimizer = torch.optim.Adam(
    bert_doc_classifier.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

train(
    bert_doc_classifier,
    optimizer,
    bert_tokenizer,
    (train_documents, train_labels),
    (dev_documents, dev_labels),
    batch_size,
    model_directory,
    labels,
    epochs=250,
    device=device,
)

Model name '/notebook/nas-trainings/arne/OCCAM/text_classification_BERT/code_BERT/bert_document_classification/examples/ml4health_2019_replication/clinicalBERT/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '/notebook/nas-trainings/arne/OCCAM/text_classification_BERT/code_BERT/bert_document_classification/examples/ml4health_2019_replication/clinicalBERT/pretrained_bert_tf/biobert_pretrain_output

In [None]:
# Recompute the per-label positive-class weight, (#rows / #positives) - 1,
# outside of the training function so that it can be inspected.
correct_output = torch.FloatTensor(train_labels)
positives_per_label = torch.sum(correct_output, dim=0)
loss_weight = (correct_output.shape[0] / positives_per_label - 1).to(device=device)
loss_weight

In [None]:
# Ratio of the number of rows to the number of positives, per label.
correct_output.shape[0] / correct_output.sum(dim=0)

In [None]:
# Number of positive rows per label.
correct_output.sum(dim=0)