# **BERT document classification**
BERT-based models for multi-class and multi-label classification of long documents.

## **Setting: change runtime**
In Colab, set **Runtime type** to **GPU** (Runtime > Change runtime type).

## **Install packages and clone repository**

In [None]:
! pip install bert_document_classification



In [None]:
! pip install torch==1.4.0



In [None]:
ls -l

total 496
-rw-r--r-- 1 root root   7629 Mar 23 15:50 data.py
-rw-r--r-- 1 root root      0 Mar 23 14:24 __init__.py
-rw-r--r-- 1 root root    648 Mar 23 14:24 n2c2_2006_train_config.ini
-rw-r--r-- 1 root root    698 Mar 23 14:24 n2c2_2008_train_config.ini
-rw-r--r-- 1 root root    708 Mar 23 14:40 newstest_train_config.ini
-rw-r--r-- 1 root root   2786 Mar 23 14:24 predict_n2c2_2006.py
-rw-r--r-- 1 root root   2501 Mar 23 14:24 predict_n2c2_2008.py
-rw-r--r-- 1 root root   3246 Mar 23 14:24 predict_newstest_bert.py
-rw-r--r-- 1 root root   3258 Mar 23 14:24 predict_newstest_distilbert.py
-rw-r--r-- 1 root root 294057 Mar 23 16:01 preprint_classification_data.tsv
drwxr-xr-x 2 root root   4096 Mar 23 15:54 __pycache__/
drwxr-xr-x 3 root root   4096 Mar 23 14:24 results_2008_BertLinear_unfreeze_last_layers/
drwxr-xr-x 3 root root   4096 Mar 23 14:24 results_2008_BertLSTM_freeze_bert/
drwxr-xr-x 3 root root   4096 Mar 23 14:24 results_2008_fi

In [None]:
! git clone https://github.com/ArneDefauw/BERT_doc_classification.git

Cloning into 'BERT_doc_classification'...
remote: Enumerating objects: 1260, done.
remote: Counting objects: 100% (1260/1260), done.
remote: Compressing objects: 100% (1189/1189), done.
remote: Total 1260 (delta 115), reused 1160 (delta 68), pack-reused 0
Receiving objects: 100% (1260/1260), 515.04 KiB | 3.28 MiB/s, done.
Resolving deltas: 100% (115/115), done.


In [None]:
ls -l

total 8
drwxr-xr-x 5 root root 4096 Mar 23 01:04 bert_document_classification/
drwxr-xr-x 1 root root 4096 Mar 18 13:36 sample_data/


In [None]:
cd BERT_doc_classification/bert_document_classification/examples/ml4health_2019_replication

/content/BERT_doc_classification/bert_document_classification/examples/ml4health_2019_replication


## **Modify classification interface**

Change the hyperparameters in `newstest_train_config.ini`:

In [None]:
%%writefile newstest_train_config.ini
model_storage_directory: ./test_out_class
batch_size: 10
epochs: 1000
evaluation_interval: 10
checkpoint_interval: 250
;use_tensorboard

bert_model_path: bert-base-uncased
# Alternatives: distilbert-base-uncased, or clinicalBERT weights, e.g.
# clinicalBERT/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000

# To use clinicalBERT, replace bert_model_path with a local path to its weights.
# Find them here: https://github.com/EmilyAlsentzer/clinicalBERT
#bert_model_path: /export/b18/elliot/pretrained_bert_tf/biobert_pretrain_output_all_notes_150000

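# Labels of the original 20 Newsgroups example:
#labels: alt.atheism, talk.religion.misc, comp.graphics, sci.space
# Preprint study types (must match the `study type` values in the TSV exactly):
labels: RCT, observational study, other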
architecture: DocumentBertLSTM
bert_batch_size: 7

device: cuda:0
cuda
learning_rate: 6e-5
weight_decay: 0

Overwriting newstest_train_config.ini
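configargparse accepts several equivalent `key value` forms in a config file, which is why `epochs:1000`, `device cuda:0` and the bare flag `cuda` all end up in the parsed config shown later. The sketch below is a simplified stand-in written for illustration (it is not configargparse's real parser): it shows how each line is read and how the `labels` string becomes a list, as `train_newstest.py` does with `args.labels`. The function names `parse_config` and `split_labels` are hypothetical.

```python
import re

# Simplified approximation of the config syntax above; NOT configargparse's parser.
# A key is followed by ':', '=' or whitespace and then its value; a bare key is a flag.
LINE_RE = re.compile(r"^(?P<key>[^:=;#\s]+)\s*(?:[:=\s]\s*(?P<value>.+?))?\s*$")


def parse_config(text):
    """Return {key: value} for a config string; bare keys become flags (True)."""
    settings = {}
    for raw in text.splitlines():
        line = raw.strip()
        # Blank lines and lines starting with '#' or ';' are comments.
        if not line or line[0] in "#;":
            continue
        match = LINE_RE.match(line)
        if match is None:
            continue
        value = match.group("value")
        settings[match.group("key")] = True if value is None else value
    return settings


def split_labels(value):
    """Turn 'A, B, C' into ['A', 'B', 'C'], stripping stray whitespace."""
    return [label.strip() for label in value.split(",")]


EXAMPLE = """
epochs:1000
;use_tensorboard
device cuda:0
cuda
labels: RCT, observational study, other
"""

cfg = parse_config(EXAMPLE)
labels = split_labels(cfg["labels"])
print(cfg)
print(labels)  # ['RCT', 'observational study', 'other']
```

Note that `;use_tensorboard` is skipped entirely, so TensorBoard logging stays off unless the semicolon is removed.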


Load the preprint data (`preprint_classification_data.tsv`) by modifying `data.py` and `train_newstest.py`.


In [None]:
%%writefile data.py
from pkg_resources import resource_exists, resource_listdir, resource_string, resource_stream,resource_filename
import xml.etree.ElementTree as ET
import numpy
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
import pandas as pd


def load_n2c2_2006_train_dev_split():
    train = list(load_n2c2_2006(partition='train'))
    numpy.random.seed(0)
    numpy.random.shuffle(train)

    labels = {}
    for id, doc, label in train:
        if label not in labels:
            labels[label] = []
        labels[label].append(tuple((id,doc,label)))

    dev = []
    train = []
    for label in labels.keys():
        dev += labels[label][:int(len(labels[label])*.2)]
        train += labels[label][int(len(labels[label])*.2):]

    return train,dev

def load_n2c2_2006(partition='train'):
    """
    Yields (id, doc, label) tuples.
    :param partition: 'train' or 'test'
    :return: generator of (id, doc, label) tuples
    """
    assert partition in ['train', 'test']

    if partition == 'train':
        with open("data/smokers_surrogate_%s_all_version2.xml" % partition) as raw:
            file = raw.read().strip()
        
    elif partition == 'test':
        with open("data/smokers_surrogate_%s_all_groundtruth_version2.xml" % partition) as raw:
            file = raw.read().strip()   
        
    # file = resource_string('clinical_data', 'phenotyping/n2c2_2006/smokers_surrogate_%s_all_version2.xml' % partition).decode('utf-8').strip()
    root = ET.fromstring(file)
    ids = []
    notes = []
    labels = []
    documents = root.findall("./RECORD")
    for document in documents:
        ids.append(document.attrib['ID'])
        notes.append(document.findall('./TEXT')[0].text)
        labels.append(document.findall('./SMOKING')[0].attrib['STATUS'])

    for id, note, label in zip(ids,notes,labels):
        yield (id,note,label)


def load_n2c2_2008_train_dev_split():
    train = list(load_n2c2_2008(partition='train'))

    return train[:int(len(train)*.8)], train[int(len(train)*.8):]


def load_n2c2_2008(partition='train'):
    assert partition in ['train', 'test']
    documents = {} #id : text
    all_diseases = set()
    notes = tuple()
    if partition == 'train':
        with open('data/obesity_patient_records_training.xml') as t1, \
                open('data/obesity_patient_records_training2.xml') as t2:
            notes1 = t1.read().strip()
            notes2 = t2.read().strip()
        notes = (notes1,notes2)
    elif partition == 'test':
        with open('data/obesity_patient_records_test.xml') as t1:
            notes1 = t1.read().strip()
        notes = (notes1,)

    for file in notes:
        root = ET.fromstring(file)
        root = root.findall("./docs")[0]
        for document in root.findall("./doc"):
            assert document.attrib['id'] not in documents
            documents[document.attrib['id']] = {}
            documents[document.attrib['id']]['text'] = document.findall("./text")[0].text

    annotation_files = tuple()
    if partition == 'train':
        with open('data/obesity_standoff_annotations_training.xml') as t1, \
                open('data/obesity_standoff_annotations_training_addendum.xml') as t2, \
                open('data/obesity_standoff_annotations_training_addendum2.xml') as t3, \
                open('data/obesity_standoff_annotations_training_addendum3.xml') as t4:
            train1 = t1.read().strip()
            train2 = t2.read().strip()
            train3 = t3.read().strip()
            train4 = t4.read().strip()
        # train1 = resource_string('clinical_data', 'phenotyping/n2c2_2008/train/obesity_standoff_annotations_training.xml').decode('utf-8').strip()
        # train2 = resource_string('clinical_data', 'phenotyping/n2c2_2008/train/obesity_standoff_annotations_training_addendum.xml').decode('utf-8').strip()
        # train3 = resource_string('clinical_data', 'phenotyping/n2c2_2008/train/obesity_standoff_annotations_training_addendum2.xml').decode('utf-8').strip()
        # train4 = resource_string('clinical_data', 'phenotyping/n2c2_2008/train/obesity_standoff_annotations_training_addendum3.xml').decode('utf-8').strip()
        annotation_files = (train1,train2,train3,train4)
    elif partition == 'test':
        with open('data/obesity_standoff_annotations_test.xml') as t1:
            test1 = t1.read().strip()
        # test1 = resource_string('clinical_data','phenotyping/n2c2_2008/test/obesity_standoff_annotations_test.xml').decode('utf-8').strip()
        annotation_files = (test1,)

    for file in annotation_files:
        root = ET.fromstring(file)
        for diseases_annotation in root.findall("./diseases"):

            annotation_source = diseases_annotation.attrib['source']
            assert isinstance(annotation_source, str)
            for disease in diseases_annotation.findall("./disease"):
                disease_name = disease.attrib['name']
                all_diseases.add(disease_name)
                for annotation in disease.findall("./doc"):
                    doc_id = annotation.attrib['id']
                    assert doc_id in documents
                    if annotation_source not in documents[doc_id]:
                        documents[doc_id][annotation_source] = {}
                    judgment = annotation.attrib['judgment']
                    documents[doc_id][annotation_source][disease_name] = judgment

    all_diseases = list(all_diseases)
    #print(all_diseases)

    for id in documents: #set unlabeled instances to None
        for annotation_type in ('textual', 'intuitive'):
            for disease in all_diseases:
                if not annotation_type in documents[id]:
                    documents[id][annotation_type] = {}
                if not disease in documents[id][annotation_type]:
                    #print(id, annotation_type, disease)
                    documents[id][annotation_type][disease] = None

    for id in documents:
        yield id, documents[id]['text'], documents[id]['textual'], documents[id]['intuitive']
    from pprint import pprint
    #pprint(documents[list(documents.keys())[1]])
    

def load_newstest(random_state=42, categories=('alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space')):

    remove = ()

    data_train = fetch_20newsgroups(subset='train', categories=categories,
                                    shuffle=True, random_state=random_state,
                                    remove=remove)

    data_test = fetch_20newsgroups(subset='test', categories=categories,
                                   shuffle=True, random_state=random_state,
                                   remove=remove)

    target_names = data_train.target_names
    print(  'labels for newstest dataset:', target_names )

    return data_train, data_test, target_names

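def generator_newstest(data, target_names):
    """
    Yields (id, doc, label) tuples.
    :param data: 20 Newsgroups dataset as returned by fetch_20newsgroups
    :param target_names: label names, indexed by the integer targets in data.target
    :return: generator of (id, doc, label) tuples
    """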
    ids=[]
    documents=[]
    labels=[]
    for index, (text, nr_label) in enumerate(zip( data.data, data.target )):
        ids.append( index )
        documents.append( text )
        labels.append( target_names[ nr_label  ]   )

    for id, text, label in zip(ids,documents,labels):
        yield (id,text,label)


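def load_preprints(data_path):
    """Reads the tab-separated preprint file and returns an 80/20 train/test split
    of its `text` and `study type` columns."""
    df = pd.read_csv(data_path, sep='\t')
    data = df[['text', 'study type']]
    data_train, data_test = train_test_split(data, test_size=0.2, random_state=42, shuffle=True)
    return data_train, data_test


def generator_preprints(data):
    """Yields (index, text, study type) tuples, one per DataFrame row."""
    for row in data.itertuples(name=None):
        yield row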



Overwriting data.py
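`load_n2c2_2006_train_dev_split` shuffles the training records, groups them by label and moves the first 20% (rounded down) of each group into the dev set, so every smoking-status class keeps roughly the same share in train and dev. A minimal stdlib sketch of the same per-label split follows; it uses Python's `random` module instead of NumPy, so the shuffle order (and hence the exact split) differs from the original. The name `stratified_split` and the toy labels are illustrative assumptions.

```python
import random
from collections import defaultdict


def stratified_split(records, dev_fraction=0.2, seed=0):
    """Split (id, doc, label) tuples so each label puts ~dev_fraction of its records in dev."""
    records = list(records)
    random.Random(seed).shuffle(records)

    # Group records by their label (third element), preserving first-seen order.
    by_label = defaultdict(list)
    for record in records:
        by_label[record[2]].append(record)

    train, dev = [], []
    for label_records in by_label.values():
        cut = int(len(label_records) * dev_fraction)  # floor, as in the original
        dev += label_records[:cut]
        train += label_records[cut:]
    return train, dev


toy = [(i, "note %d" % i, "SMOKER" if i % 2 else "NON-SMOKER") for i in range(10)]
train, dev = stratified_split(toy)
print(len(train), len(dev))  # 8 2
```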

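`load_preprints` expects a tab-separated file with at least a `text` and a `study type` column, and `generator_preprints` turns each row into an `(index, text, study type)` tuple. To check a TSV's shape without pandas, here is a hedged stdlib approximation; `load_preprints_stdlib` is a hypothetical name. Like scikit-learn's `train_test_split` with a float `test_size`, it rounds the test-set size up, but its shuffle is different, so the rows in each split will not match the pandas version.

```python
import csv
import io
import math
import random


def load_preprints_stdlib(tsv_file, test_size=0.2, seed=42):
    """Return (train_rows, test_rows) of (index, text, study_type) tuples."""
    reader = csv.DictReader(tsv_file, delimiter="\t")
    rows = [(i, row["text"], row["study type"]) for i, row in enumerate(reader)]
    random.Random(seed).shuffle(rows)
    n_test = math.ceil(len(rows) * test_size)  # round the test size up
    return rows[n_test:], rows[:n_test]


SAMPLE_TSV = (
    "text\tstudy type\n"
    "Toy abstract one\tRCT\n"
    "Toy abstract two\tobservational study\n"
    "Toy abstract three\tother\n"
    "Toy abstract four\tRCT\n"
    "Toy abstract five\tother\n"
)

train_rows, test_rows = load_preprints_stdlib(io.StringIO(SAMPLE_TSV))
print(len(train_rows), len(test_rows))  # 4 1
```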

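`load_n2c2_2008` reads each annotation file's `<diseases source="...">` blocks, records one judgment per `(document, source, disease)`, and then fills every pair that was never annotated with `None`, so each document ends up with a complete `textual` and `intuitive` dictionary. The sketch below reproduces that logic on a toy XML string; the `<root>` element name, the disease names and the document ids are hypothetical, and only the `diseases`/`disease`/`doc` structure is taken from the code above.

```python
import xml.etree.ElementTree as ET

TOY_XML = """<root>
  <diseases source="textual">
    <disease name="Asthma"><doc id="1" judgment="Y"/></disease>
  </diseases>
  <diseases source="intuitive">
    <disease name="Asthma"><doc id="1" judgment="Y"/><doc id="2" judgment="N"/></disease>
    <disease name="Obesity"><doc id="2" judgment="Y"/></disease>
  </diseases>
</root>"""


def collect_judgments(annotation_xml, doc_ids):
    """Map doc id -> {source: {disease: judgment or None}}."""
    documents = {doc_id: {} for doc_id in doc_ids}
    all_diseases = set()
    root = ET.fromstring(annotation_xml)
    for diseases in root.findall("./diseases"):
        source = diseases.attrib["source"]
        for disease in diseases.findall("./disease"):
            name = disease.attrib["name"]
            all_diseases.add(name)
            for doc in disease.findall("./doc"):
                doc_id = doc.attrib["id"]
                assert doc_id in documents  # check before indexing
                documents[doc_id].setdefault(source, {})[name] = doc.attrib["judgment"]
    # Fill every (source, disease) pair that was never annotated with None.
    for annotations in documents.values():
        for source in ("textual", "intuitive"):
            per_source = annotations.setdefault(source, {})
            for name in all_diseases:
                per_source.setdefault(name, None)
    return documents


docs = collect_judgments(TOY_XML, ["1", "2"])
print(docs["2"])
```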
In [None]:
%%writefile train_newstest.py
import sys, os, logging, torch, time, configargparse, socket

#appends current directory to sys path allowing data imports.
sys.path.append(os.path.dirname(os.path.realpath(__file__)))

# Path from the original author's environment; adjust or remove it for your own setup.
sys.path.append(  "/notebook/nas-trainings/arne/OCCAM/text_classification_BERT/code_BERT/bert_document_classification"  )


from data import load_newstest, generator_newstest, load_preprints, generator_preprints
from bert_document_classification.document_bert import BertForDocumentClassification

log = logging.getLogger()

def _initialize_arguments(p: configargparse.ArgParser):
    p.add('--model_storage_directory', help='The directory caching all model runs')
    p.add('--bert_model_path', help='Model path to BERT')
    p.add('--labels', help='Comma-separated list of label names to predict over', type=str)
    p.add('--architecture', help='Training architecture', type=str)
    p.add('--freeze_bert', help='Whether to freeze bert', type=bool)

    p.add('--batch_size', help='Batch size for training multi-label document classifier', type=int)
    p.add('--bert_batch_size', help='Batch size for feeding 510 token subsets of documents through BERT', type=int)
    p.add('--epochs', help='Epochs to train', type=int)
    #Optimizer arguments
    p.add('--learning_rate', help='Optimizer step size', type=float)
    p.add('--weight_decay', help='Adam regularization', type=float)

    p.add('--evaluation_interval', help='Evaluate model on test set every evaluation_interval epochs', type=int)
    p.add('--checkpoint_interval', help='Save a model checkpoint to disk every checkpoint_interval epochs', type=int)

    #Non-config arguments
    p.add('--cuda', action='store_true', help='Utilize GPU for training or prediction')
    p.add('--device')
    p.add('--timestamp', help='Run specific signature')
    p.add('--model_directory', help='The directory storing this model run, a sub-directory of model_storage_directory')
    p.add('--use_tensorboard', help='Use tensorboard logging', type=bool)
    args = p.parse_args()

    # Split the comma-separated label string and strip surrounding whitespace
    args.labels = [x.strip() for x in args.labels.split(',')]

    #Set run-specific environment configurations
    args.timestamp = time.strftime("run_%Y_%m_%d_%H_%M_%S") + "_{machine}".format(machine=socket.gethostname())
    args.model_directory = os.path.join(args.model_storage_directory, args.timestamp) #directory
    os.makedirs(args.model_directory, exist_ok=True)

    #Handle logging configurations
    log.handlers.clear()
    formatter = logging.Formatter('%(message)s')
    fh = logging.FileHandler(os.path.join(args.model_directory, "log.txt"))
    fh.setLevel(logging.INFO)
    fh.setFormatter(formatter)
    log.addHandler(fh)
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    ch.setFormatter(formatter)
    log.setLevel(logging.INFO)
    log.addHandler(ch)
    log.info(p.format_values())


    # Set global GPU state (disabled; `device` is set in the config file instead)
    # if torch.cuda.is_available() and args.cuda:
    #     if torch.cuda.device_count() > 1:
    #         log.info("Using %i CUDA devices" % torch.cuda.device_count())
    #     else:
    #         log.info("Using CUDA device:{0}".format(torch.cuda.current_device()))
    # else:
    #     log.info("Not using CUDA :(")
    #     args.device = 'cpu'

    return args


if __name__ == "__main__":
    p = configargparse.ArgParser(default_config_files=["newstest_train_config.ini"])
    args = _initialize_arguments(p)

    torch.cuda.empty_cache()

    # data_train, data_test, target_names =load_newstest( categories=args.labels )
    # train=generator_newstest( data_train, target_names )
    # dev=generator_newstest( data_test, target_names )

    dataset = 'preprint_classification_data.tsv'
    data_train, data_test = load_preprints(dataset)
    train = generator_preprints(data_train)
    dev = generator_preprints(data_test)
    # labels = ['RCT', 'observational study', 'other']
    
    def to_one_hot(status):
        """One-hot encode a study type over args.labels (all zeros if it matches no label)."""
        return [1 if name == status else 0 for name in args.labels]

    train_documents, train_labels = [], []
    for _, text, status in train:
        train_documents.append(text)
        train_labels.append(to_one_hot(status))

    dev_documents, dev_labels = [], []
    for _, text, status in dev:
        dev_documents.append(text)
        dev_labels.append(to_one_hot(status))

    model = BertForDocumentClassification(args=args)
    model.fit((train_documents, train_labels), (dev_documents,dev_labels))
  

Overwriting train_newstest.py
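The training script turns each study type into a one-hot vector over `args.labels` using exact string comparison, so a label that is spelled differently in the config and in the TSV (for example `rct` vs `RCT`) silently becomes an all-zero vector. The sketch below repeats that encoding and adds a strict check that raises instead; `encode_all` and the `strict` option are hypothetical additions, not part of the original script.

```python
LABELS = ("RCT", "observational study", "other")  # as in newstest_train_config.ini


def to_one_hot(status, labels=LABELS):
    """Encode one study-type string the same way the training loop does."""
    return [1 if name == status else 0 for name in labels]


def encode_all(rows, labels=LABELS, strict=True):
    """Split (index, text, status) rows into documents and one-hot label vectors.

    With strict=True, a status missing from `labels` raises ValueError instead of
    becoming an all-zero vector (an extra check, not in the original script).
    """
    documents, vectors = [], []
    for _, text, status in rows:
        vector = to_one_hot(status, labels)
        if strict and sum(vector) != 1:
            raise ValueError(f"unknown label {status!r}; expected one of {labels}")
        documents.append(text)
        vectors.append(vector)
    return documents, vectors


docs, ys = encode_all([(0, "a", "RCT"), (1, "b", "other")])
print(ys)  # [[1, 0, 0], [0, 0, 1]]
```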


## **Fine-tuning with preprint data**
Fine-tune `bert-base-uncased` (the `bert_model_path` set in `newstest_train_config.ini`) on the preprint data.
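The help text for `--bert_batch_size` in `train_newstest.py` describes feeding documents through BERT as 510-token subsets: BERT-base has 512 positions, leaving 510 after one `[CLS]` and one `[SEP]` token. The sketch below illustrates that idea on raw token ids and groups the pieces into batches of `bert_batch_size`. It is an assumption-laden illustration, not the library's code: whether `bert_document_classification` adds special tokens exactly this way, or truncates long documents to a fixed number of pieces, should be checked in `document_bert.py`. The ids 101 and 102 are the `[CLS]` and `[SEP]` ids of the `bert-base-uncased` vocabulary.

```python
def chunk_token_ids(token_ids, cls_id, sep_id, max_len=512):
    """Split token ids into pieces of at most max_len - 2 tokens,
    each wrapped as [CLS] ... [SEP] so it fits BERT's 512 positions."""
    body = max_len - 2  # 510 content tokens per piece
    pieces = []
    for start in range(0, max(len(token_ids), 1), body):
        pieces.append([cls_id] + token_ids[start:start + body] + [sep_id])
    return pieces


def batches(items, batch_size):
    """Group items into lists of batch_size (the last batch may be smaller)."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# A hypothetical 1200-token document becomes 510 + 510 + 180 content tokens.
pieces = chunk_token_ids(list(range(1200)), cls_id=101, sep_id=102)
print([len(p) for p in pieces])  # [512, 512, 182]
print(len(batches(pieces, 7)))   # 1 (bert_batch_size: 7 in the config)
```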

In [None]:
! python train_newstest.py

Config File (newstest_train_config.ini):
  model_storage_directory:./test_out_class
  batch_size:        10
  epochs:            1000
  evaluation_interval:10
  checkpoint_interval:250
  bert_model_path:   bert-base-uncased
  labels:            RCT, observational study, other
  architecture:      DocumentBertLSTM
  bert_batch_size:   7
  device:            cuda:0
  cuda:              true
  learning_rate:     6e-5
  weight_decay:      0

loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /root/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /root/.cache/torch/pytorch_transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.7156163d5fdc189c3016baca0775ffce230789d7fa2a42ef51648

## **Prediction**
Update evaluation metrics
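The library's own evaluation code is not shown in this notebook, so the following is only a hypothetical stdlib helper for scoring the one-hot label vectors used above: it counts true positives, false positives and false negatives per label column and derives micro- and macro-averaged F1, using F1 = 2TP / (2TP + FP + FN). On 0/1 indicator arrays like the example below, the results should agree with scikit-learn's `f1_score(..., average='micro')` and `average='macro'`.

```python
def per_class_counts(y_true, y_pred, n_labels):
    """Count true positives, false positives and false negatives per label column."""
    tp, fp, fn = [0] * n_labels, [0] * n_labels, [0] * n_labels
    for truth, pred in zip(y_true, y_pred):
        for k in range(n_labels):
            if pred[k] and truth[k]:
                tp[k] += 1
            elif pred[k]:
                fp[k] += 1
            elif truth[k]:
                fn[k] += 1
    return tp, fp, fn


def f1(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); defined as 0.0 when the denominator is 0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def micro_macro_f1(y_true, y_pred):
    """Return (micro F1, macro F1) for lists of 0/1 label vectors."""
    n_labels = len(y_true[0])
    tp, fp, fn = per_class_counts(y_true, y_pred, n_labels)
    macro = sum(f1(tp[k], fp[k], fn[k]) for k in range(n_labels)) / n_labels
    micro = f1(sum(tp), sum(fp), sum(fn))
    return micro, macro


y_true = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]]
y_pred = [[1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]]
micro, macro = micro_macro_f1(y_true, y_pred)
print(round(micro, 4), round(macro, 4))  # 0.75 0.5556
```

Macro F1 weights the three study types equally, so a rare class that is never predicted pulls it down much more than micro F1.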