In [1]:
from transformers import AutoConfig, AutoTokenizer, AutoModel, BertForSequenceClassification
from data.acronymDataset import AcronymDataset
from evaluate import load
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=config.max_position_embeddings)
pre_trained_model = AutoModel.from_pretrained(model_name).to('mps')

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test with the data

In [15]:
# %load_ext autoreload
# %autoreload 2
torch.manual_seed(5)
file_path = 'data/acronym_data.txt'
dataset = AcronymDataset(file_path=file_path, tokenizer=tokenizer)
data = dataset.data

In [16]:
dataset.preprocss_dataset()

                                                                   

In [33]:
train_loader, val_loader = dataset.get_dataloaders(train_size=0.9, batch_size=32)

In [35]:
batch = next(iter(train_loader)).to('mps')

In [36]:
# %load_ext autoreload
# %autoreload 2
from models.multiHeadModel import MultiHeadModel
from models.heads import ClassificationHead

In [37]:
in_features = config.hidden_size
two_labels_head = ClassificationHead(in_features=in_features, out_features=2).to('mps')
four_labels_head = ClassificationHead(in_features=in_features, out_features=4)

classifiers = {
    "two_labels_head": two_labels_head,
    "four_labels_head": four_labels_head
}

In [38]:
multi_head_model = MultiHeadModel(pre_trained_model, classifiers)
multi_head_model.eval()

MultiHeadModel(
  (base_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [39]:

with torch.no_grad():
    output = multi_head_model(batch, "two_labels_head")

output

  outputs = self.softmax(outputs)


tensor([[0.5303, 0.4697],
        [0.5329, 0.4671],
        [0.4897, 0.5103],
        [0.4436, 0.5564],
        [0.5175, 0.4825],
        [0.5449, 0.4551],
        [0.4937, 0.5063],
        [0.4788, 0.5212],
        [0.5149, 0.4851],
        [0.5900, 0.4100],
        [0.5228, 0.4772],
        [0.6227, 0.3773],
        [0.5340, 0.4660],
        [0.4660, 0.5340],
        [0.6012, 0.3988],
        [0.4931, 0.5069],
        [0.5604, 0.4396],
        [0.4873, 0.5127],
        [0.4541, 0.5459],
        [0.4920, 0.5080],
        [0.5474, 0.4526],
        [0.4705, 0.5295],
        [0.5807, 0.4193],
        [0.5773, 0.4227],
        [0.5777, 0.4223],
        [0.5863, 0.4137],
        [0.5365, 0.4635],
        [0.5073, 0.4927],
        [0.5101, 0.4899],
        [0.5746, 0.4254],
        [0.5234, 0.4766],
        [0.4840, 0.5160]], device='mps:0')

In [41]:
metric = load("accuracy")
labels = batch['labels']
predictions = np.argmax(output.cpu().numpy(), axis=-1)
res = metric.compute(predictions=predictions, references=labels)

res

{'accuracy': 0.34375}

In [51]:
%load_ext autoreload
%autoreload 2

from utils.train import train
train_loader1, _ = dataset.get_dataloaders(train_size=0.9, batch_size=32)
train_loader2, _ = dataset.get_dataloaders(train_size=0.9, batch_size=16)

train_args = {
    "epochs": 2
}

heads_props = {
    "two_labels_head": {
        "train_loader": train_loader1
    },
    "four_labels_head": {
        "train_loader": train_loader2
    }
}

train(multi_head_model, heads_props, train_args)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  0%|          | 0/2 [00:00<?, ?it/s]

torch.Size([32, 291]) torch.Size([16, 245])
torch.Size([32, 441]) torch.Size([16, 282])
torch.Size([32, 495]) torch.Size([16, 328])
torch.Size([32, 368]) torch.Size([16, 512])
torch.Size([32, 424]) torch.Size([16, 355])
torch.Size([32, 512]) torch.Size([16, 296])
torch.Size([32, 388]) torch.Size([16, 308])
torch.Size([32, 264]) torch.Size([16, 512])
torch.Size([32, 269]) torch.Size([16, 283])
torch.Size([32, 350]) torch.Size([16, 364])
torch.Size([32, 306]) torch.Size([16, 302])
torch.Size([32, 358]) torch.Size([16, 345])
torch.Size([32, 431]) torch.Size([16, 349])
torch.Size([32, 414]) torch.Size([16, 321])
torch.Size([32, 512]) torch.Size([16, 303])
torch.Size([32, 277]) torch.Size([16, 274])
torch.Size([32, 362]) torch.Size([16, 294])
torch.Size([32, 362]) torch.Size([16, 371])
torch.Size([32, 277]) torch.Size([16, 294])
torch.Size([32, 479]) torch.Size([16, 343])
torch.Size([32, 413]) torch.Size([16, 338])
torch.Size([32, 304]) torch.Size([16, 358])
torch.Size([32, 349]) torch.Size

  0%|          | 0/2 [00:04<?, ?it/s]

torch.Size([32, 295]) torch.Size([16, 236])





KeyboardInterrupt: 