In [None]:
# https://curiousily.com/posts/multi-label-text-classification-with-bert-and-pytorch-lightning/

In [1]:
from google.colab import drive
drive.mount('/content/drive')

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [1]:
!nvidia-smi

In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.9.2-py3-none-any.whl (2.6 MB)
[K     |████████████████████████████████| 2.6 MB 5.4 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 40.7 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 50.1 MB/s 
Collecting huggingface-hub==0.0.12
  Downloading huggingface_hub-0.0.12-py3-none-any.whl (37 kB)
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 60.2 MB/s 
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled P

In [None]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-1.4.2-py3-none-any.whl (916 kB)
[K     |████████████████████████████████| 916 kB 5.3 MB/s 
Collecting torchmetrics>=0.4.0
  Downloading torchmetrics-0.5.0-py3-none-any.whl (272 kB)
[K     |████████████████████████████████| 272 kB 16.3 MB/s 
Collecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2021.7.0-py3-none-any.whl (118 kB)
[K     |████████████████████████████████| 118 kB 18.8 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 18.3 MB/s 
[?25hCollecting pyDeprecate==0.3.1
  Downloading pyDeprecate-0.3.1-py3-none-any.whl (10 kB)
Collecting aiohttp
  Downloading aiohttp-3.7.4.post0-cp37-cp37m-manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 30.6 MB/s 
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.6.3-cp37-cp37m-manylinux2014_x86_64.whl (294 kB)
[K     |████████████████████████████████| 294 kB 40.

In [None]:
import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

%matplotlib inline  
%config InlineBackend.figure_format='retina'

RANDOM_SEED = 42

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

pl.seed_everything(RANDOM_SEED)

Global seed set to 42


42

In [None]:
# import dataset
df = pd.read_csv('/content/drive/MyDrive/Biocreative/Biocreative/preprocessed_version1.csv')
df = df.dropna()
df.head()

Unnamed: 0,abstract,Case Report,Diagnosis,Epidemic Forecasting,Mechanism,Prevention,Transmission,Treatment
0,December 2019 new highly contagious infectious...,0,0,0,1,0,0,1
1,"novel coronavirus disease COVID-19 , transmitt...",0,0,0,0,1,0,1
2,BACKGROUND December 2019 novel coronavirus SAR...,1,0,0,0,0,0,0
3,coronavirus disease 2019 COVID-19 pandemic imp...,0,0,0,0,1,0,0
4,OBJECTIVES Sofosbuvir daclatasvir direct-actin...,0,0,0,0,0,0,1


In [None]:
train_df, val_df = train_test_split(df, test_size=0.1)
train_df.shape, val_df.shape

((22461, 8), (2496, 8))

In [None]:
# preprocess dataset
# df = df.drop(['pmid',	'journal',	'title', 'doi',	'label', 'keywords', 'pub_type', 'authors'], axis=1)
# df.head()

In [None]:
LABEL_COLUMNS = df.columns.tolist()[1:]

In [None]:
print(LABEL_COLUMNS)

['Case Report', 'Diagnosis', 'Epidemic Forecasting', 'Mechanism', 'Prevention', 'Transmission', 'Treatment']


In [None]:
BERT_MODEL_NAME = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/337 [00:00<?, ?B/s]

In [None]:
MAX_TOKEN_COUNT = 512

In [None]:
class TopicAnnotationDataset(Dataset):

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: BertTokenizer,
    max_token_len: int = 128
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index=int):
    
    data_row = self.data.iloc[index]

    abstract_text = data_row.abstract
    labels = data_row[LABEL_COLUMNS]

    encoding = self.tokenizer.encode_plus(
        abstract_text,
        add_special_tokens=True,
        max_length=self.max_token_len,
        return_token_type_ids=False,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    return dict(
        abstract_text=abstract_text,
        input_ids=encoding["input_ids"].flatten(),
        attention_mask=encoding["attention_mask"].flatten(),
        labels=torch.FloatTensor(labels) 
    )  

In [None]:
train_dataset = TopicAnnotationDataset(
    train_df,
    tokenizer,
    max_token_len=MAX_TOKEN_COUNT
)

In [None]:
class TopicAnnotationDataModule(pl.LightningDataModule):

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=128):
    
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.max_token_len = max_token_len

  def setup(self, stage=None):
    
    self.train_dataset =  TopicAnnotationDataset(
        self.train_df,
        self.tokenizer,
        self.max_token_len
    )

    self.test_dataset = TopicAnnotationDataset(
        self.test_df,
        self.tokenizer,
        self.max_token_len
    )

  def train_dataloader(self):

    return DataLoader(
      self.train_dataset,
      batch_size=self.batch_size,
      shuffle=True,
      num_workers=2
    )

  def val_dataloader(self):

    return DataLoader(
      self.test_dataset,
      batch_size=self.batch_size,
      num_workers=2
    )

  def test_dataloader(self):
    
    return DataLoader(
      self.test_dataset,
      batch_size=self.batch_size,
      num_workers=2
    )

In [None]:
N_EPOCHS = 10
BATCH_SIZE = 16

data_module = TopicAnnotationDataModule(
    train_df,
    val_df,
    tokenizer,
    batch_size=BATCH_SIZE,
    max_token_len=MAX_TOKEN_COUNT
)

In [None]:
class TopicAnnotationTagger(pl.LightningModule):

  def __init__(self, n_classes: int, label_embed: torch.Tensor, n_training_steps=None, n_warmup_steps=None):
    
    super().__init__()
    self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
    # self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()
    self.hidden_size = self.bert.config.hidden_size
    self.d_a = 200
    self.n_classes = n_classes

    self.linear_first = torch.nn.Linear(self.hidden_size, self.d_a)
    self.linear_second = torch.nn.Linear(self.d_a, n_classes)

    self.weight1 = torch.nn.Linear(self.hidden_size, 1)
    self.weight2 = torch.nn.Linear(self.hidden_size, 1)

    self.output_layer = torch.nn.Linear(self.hidden_size, n_classes)
    self.embedding_dropout = torch.nn.Dropout(p=0.3)

    self.label_embed = label_embed

  def forward(self, input_ids, attention_mask, labels=None):

    # step1 bert_model
    output = self.bert(input_ids, attention_mask=attention_mask)

    encoded_output = output.last_hidden_state
    batch_size = encoded_output.shape[0]

    # step2 self-attention
    selfatt = torch.tanh(self.linear_first(encoded_output))
    selfatt = self.linear_second(selfatt)
    selfatt = F.softmax(selfatt, dim=1)
    selfatt = selfatt.transpose(1,2)
    self_att = torch.bmm(selfatt, encoded_output) 

    # step3 label-attention
    M = torch.bmm(self.label_embed.expand(batch_size, self.n_classes, self.hidden_size), encoded_output.transpose(1, 2))
    label_att = torch.bmm(M, encoded_output)

    # step4 Adaptive fusion
    alpha = torch.sigmoid(self.weight1(label_att))
    beta = torch.sigmoid(self.weight1(self_att))
    ## alpha + beta = 1
    alpha = alpha/(alpha+beta)
    beta = 1-alpha
    
    doc = alpha*label_att + beta*self_att

    # step5 pooled version
    avg_sentence_embeddings = torch.sum(doc, 1)/self.n_classes

    output = torch.sigmoid(self.output_layer(avg_sentence_embeddings))
    loss = 0
    if labels is not None:
      loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):

    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions.append(out_predictions)

    labels = torch.stack(labels).int()
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      class_roc_auc = auroc(predictions[:, i], labels[:, i])
      self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

  def configure_optimizers(self):

    optimizer = AdamW(self.parameters(), lr=2e-5)

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
        optimizer=optimizer,
      lr_scheduler=dict(
        scheduler=scheduler,
        interval='step'
      )
    )

In [None]:
steps_per_epoch=len(train_df) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS

In [None]:
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(2806, 14030)

In [None]:
label_embed = torch.load('/content/drive/MyDrive/Biocreative/Biocreative/Label_embed.pt',  map_location=torch.device('cuda'))
label_embed.shape

torch.Size([7, 768])

In [None]:
model = TopicAnnotationTagger(
    n_classes=len(LABEL_COLUMNS),
    label_embed=label_embed,
    n_warmup_steps=warmup_steps,
    n_training_steps=total_training_steps
)

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
%pwd

'/content'

In [None]:
cd '/content/drive/MyDrive/Biocreative/Biocreative'

/content/drive/MyDrive/Biocreative/Biocreative


In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath="lsan-checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
    monitor="val_loss",
    mode="min"
)

In [None]:
logger = TensorBoardLogger("lsan-lightning_logs", name="topic-annotations")

In [None]:
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

In [None]:
trainer = pl.Trainer(
    logger=logger,
    checkpoint_callback=True,
    callbacks=[checkpoint_callback, early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type      | Params
------------------------------------------------
0 | bert              | BertModel | 109 M 
1 | criterion         | BCELoss   | 0     
2 | linear_first      | Linear    | 153 K 
3 | linear_second     | Linear    | 1.4 K 
4 | weight1           | Linear    | 769   
5 | weight2           | Linear    | 769   
6 | output_layer      | Linear    | 5.4 K 
7 | embedding_dropout | Dropout   | 0     
------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
438.577   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: -1it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

  stream(template_mgs % msg_args)
Epoch 0, global step 1403: val_loss reached 0.66797 (best 0.66797), saving model to "/content/drive/MyDrive/Biocreative/Biocreative/lsan-checkpoints/best-checkpoint.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 1, global step 2807: val_loss reached 0.41326 (best 0.41326), saving model to "/content/drive/MyDrive/Biocreative/Biocreative/lsan-checkpoints/best-checkpoint.ckpt" as top 1


Validating: 0it [00:00, ?it/s]

Epoch 2, global step 4211: val_loss reached 0.26633 (best 0.26633), saving model to "/content/drive/MyDrive/Biocreative/Biocreative/lsan-checkpoints/best-checkpoint.ckpt" as top 1


### Test

In [None]:
trained_model = TopicAnnotationTagger.load_from_checkpoint('/content/drive/MyDrive/Biocreative/Biocreative/checkpoints/best-checkpoint-v1.ckpt', n_classes=7)

In [None]:
trained_model.freeze()

### Evaluation

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = trained_model.to(device)

val_dataset = TopicAnnotationDataset(
    val_df,
    tokenizer,
    max_token_len=MAX_TOKEN_COUNT
)

In [None]:
predictions = []
labels = []

for item in tqdm(val_dataset):

  _, prediction = trained_model(
      item["input_ids"].unsqueeze(dim=0).to(device),
      item["attention_mask"].unsqueeze(dim=0).to(device)    
  )
  
  predictions.append(prediction.flatten())
  labels.append(item["labels"].int())

predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

#### Accuracy

In [None]:
THRESHOLD = 0.9

In [None]:
# calculating the accuracy of the model
accuracy(predictions, labels, threshold=THRESHOLD)

#### ROC for each tag

In [None]:
print("AUROC per tag")
for i, name in enumerate(LABEL_COLUMNS):
  tag_auroc = auroc(predictions[:, i], labels[:, i], pos_label=1)
  print(f"{name}: {tag_auroc}")

#### classification report for each class

In [None]:
y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

y_pred = np.where(y_pred > THRESHOLD, upper, lower)

print(
    classification_report(
        y_true,
        y_pred,
        target_names=LABEL_COLUMNS,
        zero_division=0
    )
)