In [None]:
# important libraries
!pip install --quiet transformers
# !pip install --quiet pytorch-lightning
!pip install --quiet git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install --quiet torchmetrics

In [None]:
# setup g-drive
from google.colab import drive
drive.mount('/content/drive')

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
# # check for gpu
# !nvidia-smi

In [5]:
import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import torchmetrics as torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report, multilabel_confusion_matrix, f1_score

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [None]:
# configuration
%matplotlib inline  
%config InlineBackend.figure_format='retina'

RANDOM_SEED = 42

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

pl.seed_everything(RANDOM_SEED)

In [None]:
# install data  
!wget https://raw.githubusercontent.com/chanzuckerberg/DRSM-corpus/main/labeled_data_2022_01_03.tsv

In [9]:
# retrieve data
def get_data(path_to_file: str, test_size=0.15):
    # read tsv file
    df = pd.read_csv(path_to_file, sep="\t")
    df = df[['TRIMMED_TEXT', 'Correct_Label']]
    df.dropna(inplace=True) 

    # label columns
    LABEL_COLUMNS = df['Correct_Label'].unique().tolist()
    # one-hot encoding
    df = pd.get_dummies(df, columns=['Correct_Label'])
    # rename columns
    rename_dict = dict((item1, item2) for item1, item2 in zip(df.columns[1:].tolist(), LABEL_COLUMNS))
    df.rename(columns=rename_dict, inplace=True, errors="raise")
    # split dataset
    train_df, test_df = train_test_split(df, test_size=test_size)

    return train_df, test_df, LABEL_COLUMNS

In [10]:
train_df, test_df, LABEL_COLUMNS = get_data(path_to_file="/content/labeled_data_2022_01_03.tsv")

In [11]:
# initialize PubMedBERT
BERT_MODEL_NAME = "allenai/specter"
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)

Downloading:   0%|          | 0.00/321 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/217k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [12]:
MAX_TOKEN_COUNT = 512
N_EPOCHS = 10
BATCH_SIZE = 4

In [13]:
class TopicAnnotationDataset(Dataset):

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: AutoTokenizer,
    max_token_len: int = 512,
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len
    
  def __len__(self):
    return len(self.data)

  def __getitem__(self, index=int):
    
    # extract row using index
    data_row = self.data.iloc[index]

    # get data (text, labels)
    text = data_row['TRIMMED_TEXT']
    labels = data_row[LABEL_COLUMNS]

    # apply tokenization
    inputs = self.tokenizer.encode_plus(
        text, 
        max_length=self.max_token_len,
        padding="max_length", 
        truncation=True, 
        return_tensors="pt", 
    )

    return dict(
        text=text,
        input_ids=inputs["input_ids"].flatten(),
        attention_mask=inputs["attention_mask"].flatten(),
        labels=torch.FloatTensor(labels) 
    )  

In [14]:
train_dataset = TopicAnnotationDataset(
    train_df,
    tokenizer,
)

In [15]:
class TopicAnnotationDataModule(pl.LightningDataModule):

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=128):
    
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.max_token_len = max_token_len

  def setup(self, stage=None):
    
    # setup train data
    self.train_dataset =  TopicAnnotationDataset(
        data=self.train_df,
        tokenizer=self.tokenizer,
        max_token_len=self.max_token_len
    )

    # setup test data
    self.test_dataset = TopicAnnotationDataset(
        data=self.test_df,
        tokenizer=self.tokenizer,
        max_token_len=self.max_token_len
    )

  def train_dataloader(self):

    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=8
    )

  def val_dataloader(self):

    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=8
    )

  def test_dataloader(self):
    
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=8
    )

In [16]:
data_module = TopicAnnotationDataModule(
    train_df,
    test_df,
    tokenizer,
    batch_size=BATCH_SIZE,
    max_token_len=MAX_TOKEN_COUNT
)

In [17]:
# calculate class-wise weights
label_count = train_df[LABEL_COLUMNS].sum().to_dict()
count = list(label_count.values())
max_val = max(count)
class_weight = [max_val/val for val in count]
# transfer to gpu
class_weight = torch.tensor(class_weight, device="cuda")

In [18]:
class TopicAnnotationTagger(pl.LightningModule):

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    
    super().__init__()
    self.bert = AutoModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
    
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss(weight=class_weight)
    self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.relu = nn.ReLU()
    self.softmax = nn.Softmax(dim=1)
    self.classifier = nn.Linear(self.bert.config.hidden_size, 1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, input_ids, attention_mask, labels=None):

    bert_outputs = self.bert(input_ids, attention_mask=attention_mask)
    # last_hidden_state (batch_size x sequence_length x hidden_size)
    encoded_output = bert_outputs.last_hidden_state

    # linear transformation and apply relu 
    # (batch_size x sequence_length x hidden_size) => (batch_size, sequence_length x n_classes)
    energy = self.relu(self.fc(encoded_output))
    
    # calculate attention weights 
    # (batch_size x sequence_length x n_classes)
    attention = self.softmax(energy)

    # apply attention weigths to encoded_output 
    # (batch_size x sequence_length x n_classes) * (batch_size x sequence_length x hidden_size) => (batch_size x n_classes x hidden_size)
    context_vector = torch.einsum("nsk,nsl->nkl", attention, encoded_output)
    # apply sigmoid function to context vector
    output = self.sigmoid(self.classifier(context_vector).squeeze(-1))
    # output = self.classifier(output.pooler_output)
    # output = torch.sigmoid(output)
    loss = 0
    if labels is not None:
      loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):

    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions.append(out_predictions)

    labels = torch.stack(labels).int()
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      auroc = torchmetrics.AUROC(num_classes=len(LABEL_COLUMNS))
      class_roc_auc = auroc(predictions[:, i], labels[:, i])
      self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

  def configure_optimizers(self):

    optimizer = AdamW(self.parameters(), lr=2e-5)

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
        optimizer=optimizer,
        lr_scheduler=dict(
            scheduler=scheduler,
            interval='step'
            )
        )

In [19]:
# calculate warmup steps and total training steps
steps_per_epoch=len(train_df) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(3790, 18950)

----
#### Train

In [None]:
# model = TopicAnnotationTagger(
#     n_classes=len(LABEL_COLUMNS),
#     n_warmup_steps=warmup_steps,
#     n_training_steps=total_training_steps
# )

In [None]:
# cd '/content/drive/MyDrive/Nidhir_Akash_Biocreative/'

In [23]:
# # set-up checkpoints annd logs directory
# checkpoint_callback = ModelCheckpoint(
#     dirpath="model_4_checkpoints",
#     filename="model_4_checkpoints",
#     save_top_k=1,
#     verbose=True,
#     monitor="val_loss",
#     mode="min"
# )
# logger = TensorBoardLogger("model_4_logs", name="topic-annotations")
# # early-stopping criterion
# early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

In [None]:
# # setup trainer and add training arguments
# trainer = pl.Trainer(
#     logger=logger,
#     checkpoint_callback=True,
#     callbacks=[checkpoint_callback, early_stopping_callback],
#     max_epochs=N_EPOCHS,
#     gpus=1,
#     progress_bar_refresh_rate=30
# )

In [None]:
# trainer.fit(model, data_module)

### Evaluation

In [26]:
path_to_model_checkpoints = '/content/drive/MyDrive/Nidhir_Akash_Biocreative/model_4_checkpoints/model_4_checkpoints.ckpt'
trained_model = TopicAnnotationTagger.load_from_checkpoint(path_to_model_checkpoints, n_classes=6)
trained_model.freeze()

In [27]:
# transfer model to gpu if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = trained_model.to(device)

# create test data
test_dataset = TopicAnnotationDataset(
    test_df,
    tokenizer,
    max_token_len=MAX_TOKEN_COUNT
)

In [28]:
predictions = []
labels = []

# generate predictions 
for item in tqdm(test_dataset):

  _, prediction = trained_model(
      item["input_ids"].unsqueeze(dim=0).to(device),
      item["attention_mask"].unsqueeze(dim=0).to(device),
  )
  predictions.append(prediction.flatten())
  labels.append(item["labels"].int())

predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

  0%|          | 0/1338 [00:00<?, ?it/s]

In [29]:
# calculate accuracy
threshold = 0.5
f1_score(labels, predictions>threshold, average=None)

array([0.92983939, 0.94209354, 0.26086957, 0.56862745, 0.78350515,
       0.87234043])

In [30]:
# AUROC
print("AUROC per tag")
for i, name in enumerate(LABEL_COLUMNS):
  auroc = torchmetrics.AUROC(num_classes=len(LABEL_COLUMNS))
  tag_auroc = auroc(predictions[:, i], labels[:, i])
  print(f"{name}: {tag_auroc}")

AUROC per tag
clinical characteristics or disease pathology: 0.9777472615242004
other: 0.991519033908844
disease mechanism: 0.9375695586204529
therapeutics in the clinic: 0.9445956945419312
irrelevant: 0.9948046803474426
patient-based therapeutics: 0.9857392311096191


In [31]:
# classification report
y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

y_pred = np.where(y_pred > 0.5, upper, lower)

print(
    classification_report(
        y_true,
        y_pred,
        digits=4,
        target_names=LABEL_COLUMNS,
        zero_division=0
    )
)

                                               precision    recall  f1-score   support

clinical characteristics or disease pathology     0.9061    0.9549    0.9298       576
                                        other     0.9358    0.9484    0.9421       446
                            disease mechanism     0.5000    0.1765    0.2609        17
                   therapeutics in the clinic     0.5800    0.5577    0.5686        52
                                   irrelevant     0.8261    0.7451    0.7835        51
                   patient-based therapeutics     0.9111    0.8367    0.8723       196

                                    micro avg     0.9001    0.9021    0.9011      1338
                                    macro avg     0.7765    0.7032    0.7262      1338
                                 weighted avg     0.8959    0.9021    0.8974      1338
                                  samples avg     0.8901    0.9021    0.8941      1338

