In [1]:
from transformers import AutoConfig, AutoTokenizer, AutoModel
from data.acronymDataset import AcronymDataset
from evaluate import load
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pick the best available accelerator instead of hard-coding Apple's MPS backend,
# so the notebook also runs on CUDA machines or plain CPU (on an MPS-capable Mac
# this still selects 'mps', exactly as before).
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model_name = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
config = AutoConfig.from_pretrained(model_name)
# Tie the tokenizer's length limit to the encoder's positional-embedding size.
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=config.max_position_embeddings)
# Loading this checkpoint into a bare encoder drops the pre-training `cls.*`
# heads, so the "Some weights ... were not used" warning is expected here.
pre_trained_model = AutoModel.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Test with the data: load the acronym dataset (reproducibly seeded) and preprocess it.

In [3]:
import random

# Seed every RNG the pipeline may draw from (Python, NumPy, torch), not just
# torch, so any random split / weight initialisation downstream is reproducible.
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

file_path = 'data/acronym_data.txt'
# AcronymDataset may reuse a cached copy of the parsed file (the recorded run
# printed "[INFO] Dataset already been loaded, using the cached dataset..").
dataset = AcronymDataset(file_path=file_path, tokenizer=tokenizer)
data = dataset.data

[INFO] Dataset already been loaded, using the cached dataset..


In [6]:
# Run the dataset's preprocessing step (in the recorded run this performed a
# `Map` over 37331 examples). NOTE: `preprocss_dataset` is the spelling the
# AcronymDataset API actually exposes (the call succeeded) - do not "fix" it here.
dataset.preprocss_dataset()

Map:   0%|          | 0/37331 [00:00<?, ? examples/s]

                                                                   

In [7]:
from models.multiHeadModel import MultiHeadModel
from models.heads import ClassificationHead

# Place the heads on whatever device the encoder was loaded onto, instead of
# repeating a hard-coded device string in every cell.
encoder_device = next(pre_trained_model.parameters()).device

in_features = config.hidden_size
# Single-logit head; trained with BCEWithLogitsLoss in the training config.
binari_head = ClassificationHead(in_features=in_features, out_features=1).to(encoder_device)
# Four-way head: constructed but NOT registered in `classifiers` below yet.
four_labels_head = ClassificationHead(in_features=in_features, out_features=4).to(encoder_device)

# NOTE(review): these keys presumably have to match the keys of `heads_props`
# passed to `train` ("binari_head") - keep them in sync.
classifiers = torch.nn.ModuleDict({
    "binari_head": binari_head,
    # "four_labels_head": four_labels_head
})

In [8]:
# Wrap the pre-trained encoder and the `classifiers` ModuleDict into a single
# module; its .parameters() are what the optimizer in the next cell trains.
multi_head_model = MultiHeadModel(pre_trained_model, classifiers)

In [9]:
%load_ext autoreload
%autoreload 2

from train import train
train_loader1, _ = dataset.get_dataloaders(train_size=0.9, batch_size=8)
# train_loader2, _ = dataset.get_dataloaders(train_size=0.9, batch_size=32)

optim = torch.optim.AdamW(multi_head_model.parameters(), lr=0.001)

train_args = {
    "epochs": 1,
    "device": "mps",
    "optim": optim
}

heads_props = {
    "binari_head": {
        "train_loader": train_loader1,
        "loss_weight": 1.0,
        "loss_func": torch.nn.BCEWithLogitsLoss()
    },
    # "four_labels_head": {
    #     "train_loader": train_loader2,
    #     "loss_weight": 0.8

    # }
}

In [12]:
# Run the multi-head training loop imported from the `train` module; it prints
# the per-step loss tensor. The recorded run below was stopped manually
# (KeyboardInterrupt) after ~3 minutes, still in epoch 0/1, so no trained
# result exists yet - re-run to completion before using the model.
train(multi_head_model, heads_props, train_args)

  0%|          | 0/1 [00:00<?, ?it/s]

tensor(0.6596, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6061, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6616, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.7258, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.5480, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.7695, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.7753, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.7227, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.4851, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.5486, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.6665, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.5475, device='mps:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
tensor(0.4904, device='mps:0', grad_fn=<

  0%|          | 0/1 [03:09<?, ?it/s]


KeyboardInterrupt: 