In [35]:
from transformers import AutoConfig, EvalPrediction, AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from datasets import load_dataset, Dataset
from evaluate import load
import numpy as np
# import wandb
import os
import sys
import pandas as pd
import torch
import torch.nn as nn

In [36]:
# Hugging Face Hub id of the checkpoint every object below is loaded from.
model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Checkpoint loaded as BertForSequenceClassification with a 2-label head
# (see the load warning in the output about unused pretraining-head weights).
model_with_head = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)

# Bare encoder without a task head; reused below as the shared backbone.
pre_trained_model = AutoModel.from_pretrained(model_name)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Ber

Insert a linear classifier layer in the middle of the encoder stack

In [37]:
def create_classifier_layer(layer_to_insert, base_model=None, hidden_size=768, num_labels=2):
    """Return the encoder layers of ``base_model`` with a linear layer spliced in.

    The result holds encoder layers ``[:layer_to_insert]``, then a freshly
    initialised ``nn.Linear(hidden_size, num_labels)``, then encoder layers
    ``[layer_to_insert:]``. The encoder layers are the *same* module objects as in
    ``base_model`` (not copies), so changing weights through either structure
    affects both.

    Args:
        layer_to_insert: Insert position, interpreted like a Python slice index.
            Must satisfy ``-n <= layer_to_insert <= n`` where ``n`` is the number
            of encoder layers.
        base_model: Object exposing ``.encoder.layer`` as an ``nn.ModuleList``.
            Defaults to the notebook-level ``pre_trained_model``.
        hidden_size: ``in_features`` of the inserted ``nn.Linear`` (default 768).
        num_labels: ``out_features`` of the inserted ``nn.Linear`` (default 2).

    Returns:
        nn.ModuleList with one more module than ``base_model.encoder.layer``.

    Raises:
        ValueError: if ``layer_to_insert`` is out of range. Plain slicing would
            silently clamp it and append/prepend the new layer instead.
    """
    if base_model is None:
        # Notebook global defined in the model-loading cell.
        base_model = pre_trained_model

    encoder_layers = base_model.encoder.layer
    n_layers = len(encoder_layers)
    if not -n_layers <= layer_to_insert <= n_layers:
        raise ValueError(
            f"layer_to_insert must be in [{-n_layers}, {n_layers}], got {layer_to_insert}"
        )

    classifier = nn.Linear(hidden_size, num_labels)
    # Build the list explicitly rather than relying on ModuleList.__add__.
    return nn.ModuleList([
        *encoder_layers[:layer_to_insert],
        classifier,
        *encoder_layers[layer_to_insert:],
    ])

In [38]:
# Splice the new linear layer in at index 4 (after encoder layers 0-3) and display the stack.
model_with_middle_classifier = create_classifier_layer(4)
model_with_middle_classifier

ModuleList(
  (0-3): 4 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=768, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=768, out_features=3072, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=3072, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (4): Lin

Test the tokenizer with the acronym data

In [39]:
file_path = 'data/acronym_data.txt'
data = []

# Each non-blank line is a '|'-separated record. Only these fields are used
# (meanings inferred from how they are used below - TODO confirm against the data spec):
#   fields[1]  replacement text spliced into the sentence (presumably the acronym expansion)
#   fields[3]  start character offset of the span to replace
#   fields[4]  end character offset of the span to replace
#   fields[6]  the original sentence
# An explicit encoding makes decoding independent of the OS locale.
# NOTE(review): errors='ignore' drops undecodable bytes, which would shift the character
# offsets in fields[3]/fields[4] on affected lines - verify the file decodes cleanly.
with open(file_path, "r", encoding="utf-8", errors='ignore') as file:
    for line_number, line in enumerate(file, start=1):
        record = line.strip()
        if not record:
            continue  # blank line: no record, and indexing fields[6] would raise IndexError
        fields = record.split('|')
        if len(fields) < 7:
            raise ValueError(
                f"{file_path}:{line_number}: expected at least 7 '|'-separated fields, "
                f"got {len(fields)}"
            )

        # Build the sentence pair: the original and the one with the span replaced.
        source_sentence = fields[6]
        span_start, span_end = int(fields[3]), int(fields[4])
        compare_sentence = source_sentence[:span_start] + fields[1] + source_sentence[span_end:]

        data.append({
            'source_sentence': source_sentence,
            'compare_sentence': compare_sentence,
            # Every pair is labelled 1; this file yields no negative examples.
            'label': 1,
        })

# Column-oriented view for Dataset.from_dict. A fixed key list (instead of reading the
# keys from data[0]) keeps this working, with empty columns, when the file has no records.
columns = ('source_sentence', 'compare_sentence', 'label')
data_dict = {key: [row[key] for row in data] for key in columns}
dataset = Dataset.from_dict(data_dict)


In [40]:
# Encode the first (source, compare) pair together as one text-pair input.
first_pair = dataset[0]
tokens = tokenizer(
    first_pair['source_sentence'],
    first_pair['compare_sentence'],
    return_tensors='pt',
)
tokens

{'input_ids': tensor([[    2,    41,     9,     7,  9533,     7,     9,    41,    41,     9,
             7,  9533,     7,     9,    41,  1977,    43,  3230,    17,  2476,
            17,  4156, 12413,  9532,    23,    16,  8202,    22,    17,    20,
            17,    20,    17,    22,    16,  2746,  3549,  1942,  1920,  7877,
          4921,  6581,  2465,  1927,  3605, 10120,  7633,  3328,  3993,    26,
            43,    18,    55,    18,  1920,  2774,  3731,  2162,    43,  3171,
          4403,  1927,  4537,  2430,    43, 12749,  2186,    18,  4693,  2430,
          2252,  3489,  2019,  3741,  6248,  1930,  1982, 18313,  1988,  4693,
          2430,    43, 12749,  2186, 17002,  2256, 20898,  3368,  1036,    18,
          1920,  2774,  2019,  1988,  2367,  2430,  2252, 22134,  1919,  1942,
          4978,    43,    46,    10,    45,  1930,  1982,  3385, 10318,  5682,
         21898,  1945,  1942,  4087,  4080,    43, 29670,    16,  2406,    16,
          2774, 11390,  2019,  1988,  

In [91]:
class MultiHeadPubMedBert(nn.Module):
    """Shared encoder with several interchangeable classification heads.

    ``forward`` runs ``base_model`` once and passes its raw output to the head
    selected by index.

    Args:
        base_model: Encoder invoked as ``base_model(**tokens)``; its output is
            handed unchanged to the selected head.
        classifier_heads: Sequence of ``nn.Module`` heads, addressed by index.
    """

    def __init__(self, base_model: nn.Module, classifier_heads: list):
        super(MultiHeadPubMedBert, self).__init__()
        self.base_model = base_model
        # NOTE(review): not used in forward (the heads apply their own dropout);
        # kept so existing references to `.dropout` keep working.
        self.dropout = nn.Dropout(0.1, inplace=False)
        # nn.ModuleList (not a plain list) registers the heads as submodules, so their
        # parameters show up in .parameters()/state_dict() and follow .to()/.train()/.eval().
        self.heads = nn.ModuleList(classifier_heads)

    def forward(self, tokens, head_to_use):
        """Encode ``tokens`` and return the output of head number ``head_to_use``.

        Args:
            tokens: Mapping of keyword arguments for ``base_model``, such as the
                encoding returned by the tokenizer.
            head_to_use: Integer index into ``classifier_heads``.

        Returns:
            Whatever the selected head returns.
        """
        encoded = self.base_model(**tokens)
        return self.heads[head_to_use](encoded)

In [99]:
class _MeanPooledHead(nn.Module):
    """Shared head: dropout -> mean over tokens -> linear -> sigmoid.

    Args:
        out_features: Number of output scores.
        in_features: Width of the encoder hidden states (default 768).
        dropout: Dropout probability applied to the hidden states.
    """

    def __init__(self, out_features: int, in_features: int = 768, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        self.mean = torch.mean
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)
        # NOTE(review): sigmoid yields independent per-output scores in (0, 1). If these
        # heads are trained with nn.CrossEntropyLoss, return raw logits instead.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, attention_mask=None):
        """Score the encoder output ``x``.

        Args:
            x: Encoder output whose element 0 holds the token hidden states; they are
                pooled over dim 1 and fed to the linear layer on the last dim
                (presumably (batch, seq, hidden) - confirm for other encoders).
            attention_mask: Optional (batch, seq) mask with 1 for real tokens and 0 for
                padding. When given, padding positions are excluded from the mean;
                when omitted, every position is averaged.

        Returns:
            Tensor of sigmoid scores with ``out_features`` entries per example.
        """
        hidden = self.dropout(x[0])
        if attention_mask is None:
            pooled = self.mean(hidden, dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids 0/0 for a row whose mask is all zeros
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.sigmoid(self.linear(pooled))


class BinariHead(_MeanPooledHead):
    """Two-output head (class name kept as-is for compatibility)."""

    def __init__(self, in_features: int = 768):
        super(BinariHead, self).__init__(out_features=2, in_features=in_features)


class MultiClassHead(_MeanPooledHead):
    """Four-output head."""

    def __init__(self, in_features: int = 768):
        super(MultiClassHead, self).__init__(out_features=4, in_features=in_features)

In [101]:
# Head index 0 -> BinariHead (2 outputs), index 1 -> MultiClassHead (4 outputs).
classifiers = [BinariHead(), MultiClassHead()]
multi_head_model = MultiHeadPubMedBert(
    base_model=pre_trained_model,
    classifier_heads=classifiers,
)

In [102]:
# Forward the tokenized pair through head 0 (BinariHead, 2 sigmoid scores).
# NOTE(review): the heads are freshly constructed modules and nothing in this notebook
# switches them to eval mode, so their dropout is active and re-running this cell
# gives slightly different values.
output = multi_head_model(tokens, head_to_use=0)
output

tensor([[0.5180, 0.5014]], grad_fn=<SigmoidBackward0>)

In [103]:
# Same encoder pass, routed to head 1 (MultiClassHead, 4 sigmoid scores).
# NOTE(review): head dropout is active here too (no eval() call in this notebook),
# so the scores vary between runs.
output = multi_head_model(tokens, head_to_use=1)
output

tensor([[0.5240, 0.4672, 0.4792, 0.5236]], grad_fn=<SigmoidBackward0>)