In [None]:
# important libraries
!pip install --quiet transformers
!pip install --quiet pytorch-lightning
!pip install --quiet torchmetrics

In [1]:
# # setup g-drive
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# import warnings
# warnings.filterwarnings('ignore')

In [1]:
# check for gpu
!nvidia-smi

Mon Jan 17 08:43:33 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl
import torchmetrics as torchmetrics
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.metrics import classification_report, multilabel_confusion_matrix, f1_score

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [None]:
# configuration
%matplotlib inline  
%config InlineBackend.figure_format='retina'

RANDOM_SEED = 42

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

pl.seed_everything(RANDOM_SEED)

In [None]:
# install data
!wget https://raw.githubusercontent.com/chanzuckerberg/DRSM-corpus/main/labeled_data_2022_01_03.tsv

In [9]:
# retrieve data
def get_data(path_to_file: str, test_size=0.15):
    # read tsv file
    df = pd.read_csv(path_to_file, sep="\t")
    df = df[['TRIMMED_TEXT', 'Correct_Label']]
    df.dropna(inplace=True) 

    # label columns
    LABEL_COLUMNS = df['Correct_Label'].unique().tolist()
    # one-hot encoding
    df = pd.get_dummies(df, columns=['Correct_Label'])
    # rename columns
    rename_dict = dict((item1, item2) for item1, item2 in zip(df.columns[1:].tolist(), LABEL_COLUMNS))
    df.rename(columns=rename_dict, inplace=True, errors="raise")
    # split dataset
    train_df, test_df = train_test_split(df, test_size=test_size)

    return train_df, test_df, LABEL_COLUMNS

In [10]:
train_df, test_df, LABEL_COLUMNS = get_data(path_to_file="/content/labeled_data_2022_01_03.tsv")

In [None]:
# initialize SPECTER
model_name = "allenai/specter"
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading:   0%|          | 0.00/321 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/217k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
print("Label wise data samples: ")
train_df[LABEL_COLUMNS].sum().sort_values()

Label wise data samples: 


disease mechanism                                  91
therapeutics in the clinic                        289
irrelevant                                        291
patient-based therapeutics                        968
other                                            2354
clinical characteristics or disease pathology    3588
dtype: int64

In [None]:
MAX_TOKEN_COUNT = 512
N_EPOCHS = 10
BATCH_SIZE = 4

In [None]:
class TopicAnnotationDataset(Dataset):

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: AutoTokenizer,
    max_token_len: int = 512,
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len
    
  def __len__(self):
    return len(self.data)

  def __getitem__(self, index=int):
    
    # extract row using index
    data_row = self.data.iloc[index]

    # get data (text, labels)
    text = data_row['TRIMMED_TEXT']
    labels = data_row[LABEL_COLUMNS]

    # apply tokenization
    inputs = self.tokenizer.encode_plus(
        text, 
        max_length=self.max_token_len,
        padding="max_length", 
        truncation=True, 
        return_tensors="pt", 
    )

    return dict(
        text=text,
        input_ids=inputs["input_ids"].flatten(),
        attention_mask=inputs["attention_mask"].flatten(),
        labels=torch.FloatTensor(labels) 
    )  

In [None]:
train_dataset = TopicAnnotationDataset(
    train_df,
    tokenizer=tokenizer,
)

In [None]:
class TopicAnnotationDataModule(pl.LightningDataModule):

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=128):
    
    super().__init__()
    self.batch_size = batch_size
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.max_token_len = max_token_len

  def setup(self, stage=None):
    
    # setup train data
    self.train_dataset =  TopicAnnotationDataset(
        data=self.train_df,
        tokenizer=self.tokenizer,
        max_token_len=self.max_token_len
    )

    # setup test data
    self.test_dataset = TopicAnnotationDataset(
        data=self.test_df,
        tokenizer=self.tokenizer,
        max_token_len=self.max_token_len
    )

  def train_dataloader(self):

    return DataLoader(
        self.train_dataset,
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=2
    )

  def val_dataloader(self):

    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=2
    )

  def test_dataloader(self):
    
    return DataLoader(
        self.test_dataset,
        batch_size=self.batch_size,
        num_workers=2
    )

In [None]:
data_module = TopicAnnotationDataModule(
    train_df,
    test_df,
    tokenizer,
    batch_size=BATCH_SIZE,
    max_token_len=MAX_TOKEN_COUNT
)

In [None]:
# calculate class-wise weights
label_count = train_df[LABEL_COLUMNS].sum().to_dict()
count = list(label_count.values())
max_val = max(count)
class_weight = [max_val/val for val in count]
# transfer to gpu
class_weight = torch.tensor(class_weight, device="cuda")

In [None]:
# fully-functional networks module
class FFN(nn.Module):
  def __init__(self, in_feat, out_feat, dropout):
      super(FFN, self).__init__()
      self.in2hid = nn.Linear(in_feat, in_feat)
      self.hid2out = nn.Linear(in_feat, out_feat)

      self.activation = nn.ReLU()
      self.dropout = nn.Dropout(dropout)

  def forward(self, input):
      hid = self.activation(self.dropout(self.in2hid(input)))
      return self.hid2out(hid)

In [None]:
# self-attention module
class Attention(nn.Module):
  def __init__(self, dimensions, attention_type="general"):
    super(Attention, self).__init__()

    if attention_type not in ['dot', 'general']:
      raise ValueError('Invalid attention type selected.')
    
    self.attention_type = attention_type
    if self.attention_type == "general":
      self.linear_in = nn.Linear(dimensions, dimensions, bias=False)

    self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
    self.softmax = nn.Softmax(dim=-1)
    self.tanh = nn.Tanh()

  def forward(self, query, context):
    batch_size, output_len, dimensions = query.size()
    query_len = context.size(1)

    if self.attention_type == "general":
        query = query.reshape(batch_size * output_len, dimensions)
        query = self.linear_in(query)
        query = query.reshape(batch_size, output_len, dimensions)

    # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
    # (batch_size, output_len, query_len)
    attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())

    # Compute weights across every context sequence
    attention_scores = attention_scores.view(batch_size * output_len, query_len)
    attention_weights = self.softmax(attention_scores)
    attention_weights = attention_weights.view(batch_size, output_len, query_len)

    # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
    # (batch_size, output_len, dimensions)
    mix = torch.bmm(attention_weights, context)

    # concat -> (batch_size * output_len, 2*dimensions)
    combined = torch.cat((mix, query), dim=2)
    combined = combined.view(batch_size * output_len, 2 * dimensions)

    # Apply linear_out on every 2nd dimension of concat
    # output -> (batch_size, output_len, dimensions)
    output = self.linear_out(combined).view(batch_size, output_len, dimensions)
    output = self.tanh(output)

    return output, attention_weights

In [None]:
class TopicAnnotationTagger(pl.LightningModule):

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    
    super().__init__()

    # specter embedding model
    self.specter = AutoModel.from_pretrained(model_name, return_dict=True) 

    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss(weight=class_weight)
    self.sequence_length = 512
    self.fc = nn.Linear(self.specter.config.hidden_size, n_classes)
    self.relu = nn.ReLU()
    self.softmax = nn.Softmax(dim=1)
    self.classifier = nn.Linear(self.sequence_length, 1)
    self.sigmoid = nn.Sigmoid()

    self.attention = Attention(self.specter.config.hidden_size)
    self.hidden2labels = nn.Linear(self.specter.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask, labels=None):

    specter_output = self.specter(input_ids, attention_mask)
    # [batch_size x sequence_length x hidden_size]
    specter_output = specter_output.last_hidden_state
    
    specter_output, x = self.attention(specter_output, specter_output)
    specter_output, x = self.attention(specter_output, specter_output)

    # Transform data [batch_size x sequence_length x hidden_size] => [batch_size x sequence_length x n_classes]
    specter_output = self.hidden2labels(specter_output).transpose(1,2)
    # apply sigmoid function to the same
    output = self.sigmoid(self.classifier(specter_output).squeeze(-1))

    loss = 0
    if labels is not None:
      loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):

    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions.append(out_predictions)

    labels = torch.stack(labels).int()
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      auroc = torchmetrics.AUROC(num_classes=len(LABEL_COLUMNS))
      class_roc_auc = auroc(predictions[:, i], labels[:, i])
      self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

  def configure_optimizers(self):

    optimizer = AdamW(self.parameters(), lr=2e-5)

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
        optimizer=optimizer,
        lr_scheduler=dict(
            scheduler=scheduler,
            interval='step'
            )
        )

In [None]:
# calculate warmup steps and total training steps
steps_per_epoch=len(train_df) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(3790, 18950)

----
#### Train

In [None]:
# model = TopicAnnotationTagger(
#     n_classes=len(LABEL_COLUMNS),
#     n_warmup_steps=warmup_steps,
#     n_training_steps=total_training_steps
# )

In [None]:
# cd '/content/drive/MyDrive/Nidhir_Akash_Biocreative/'

/content/drive/MyDrive/Nidhir_Akash_Biocreative


In [None]:
# # set-up checkpoints annd logs directory
# checkpoint_callback = ModelCheckpoint(
#     dirpath="model_1_checkpoints",
#     filename="model_1_checkpoints",
#     save_top_k=1,
#     verbose=True,
#     monitor="val_loss",
#     mode="min"
# )
# logger = TensorBoardLogger("model_1_logs", name="topic-annotations")

# # early-stopping criterion
# early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

In [None]:
# # setup trainer and add training arguments
# trainer = pl.Trainer(
#     logger=logger,
#     checkpoint_callback=True,
#     callbacks=[checkpoint_callback, early_stopping_callback],
#     max_epochs=N_EPOCHS,
#     gpus=1,
#     progress_bar_refresh_rate=30
# )

In [None]:
# train model
# trainer.fit(model, data_module)

-----
#### Evaluation

In [None]:
path_to_model_checkpoints = '/content/drive/MyDrive/Nidhir_Akash_Biocreative/model_1_checkpoints/model_1_checkpoints.ckpt'
trained_model = TopicAnnotationTagger.load_from_checkpoint(path_to_model_checkpoints, n_classes=6)
trained_model.freeze()

Downloading:   0%|          | 0.00/419M [00:00<?, ?B/s]

In [None]:
# transfer model to gpu if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = trained_model.to(device)

# create test data
test_dataset = TopicAnnotationDataset(
    test_df,
    tokenizer,
    max_token_len=MAX_TOKEN_COUNT
)

In [None]:
predictions = []
labels = []

# generate predictions 
for item in tqdm(test_dataset):

  _, prediction = trained_model(
      item["input_ids"].unsqueeze(dim=0).to(device),
      item["attention_mask"].unsqueeze(dim=0).to(device),
  )
  predictions.append(prediction.flatten())
  labels.append(item["labels"].int())

predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

  0%|          | 0/1338 [00:00<?, ?it/s]

In [None]:
# calculate accuracy
threshold = 0.5
f1_score(labels, predictions>threshold, average=None)

array([0.93197279, 0.94831461, 0.45454545, 0.55238095, 0.8173913 ,
       0.87765957])

In [None]:
# AUROC
print("AUROC per tag")
for i, name in enumerate(LABEL_COLUMNS):
  auroc = torchmetrics.AUROC(num_classes=len(LABEL_COLUMNS))
  tag_auroc = auroc(predictions[:, i], labels[:, i])
  print(f"{name}: {tag_auroc}")

AUROC per tag
clinical characteristics or disease pathology: 0.9780116081237793
other: 0.9886510372161865
disease mechanism: 0.9023022055625916
therapeutics in the clinic: 0.9065976142883301
irrelevant: 0.9955818057060242
patient-based therapeutics: 0.9858331680297852




In [None]:
# classification report
y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

y_pred = np.where(y_pred > 0.5, upper, lower)

print(
    classification_report(
        y_true,
        y_pred,
        digits=4,
        target_names=LABEL_COLUMNS,
        zero_division=0
    )
)

                                               precision    recall  f1-score   support

clinical characteristics or disease pathology     0.9133    0.9514    0.9320       576
                                        other     0.9505    0.9462    0.9483       446
                            disease mechanism     1.0000    0.2941    0.4545        17
                   therapeutics in the clinic     0.5472    0.5577    0.5524        52
                                   irrelevant     0.7344    0.9216    0.8174        51
                   patient-based therapeutics     0.9167    0.8418    0.8777       196

                                    micro avg     0.9034    0.9088    0.9061      1338
                                    macro avg     0.8437    0.7521    0.7637      1338
                                 weighted avg     0.9062    0.9088    0.9043      1338
                                  samples avg     0.9028    0.9088    0.9048      1338

