In [1]:
import logging

from transformers import BioGptTokenizer, TrainingArguments

import settings
from dataset import MimicDataset, Collator
from settings.utils import Splits

# Configure the root logger explicitly so INFO records from project modules
# are actually emitted. The previous `logging.debug("test")` call only "worked"
# because module-level logging functions implicitly call `basicConfig()` when
# the root logger has no handlers; the debug record itself was always filtered
# out by the INFO level set here.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# basicConfig() is a no-op when handlers already exist, so set the level
# unconditionally as well.
logger.setLevel(logging.INFO)

Device: cpu


In [2]:
# Directory handed to MimicDataset as its cache location.
cache_dir = "./cache-temp"
# Checkpoint identifier used to load the tokenizer; defined in project settings.
model_checkpoint = settings.BIOGPT_CHECKPOINT

In [3]:
# BioGptTokenizer is a concrete (slow) tokenizer class. `use_fast` is an
# AutoTokenizer-only switch and has no effect here, so it is not passed.
tokenizer = BioGptTokenizer.from_pretrained(model_checkpoint)


def preprocess_function(examples, max_length=5):
    """Tokenize the ``"text"`` entry of ``examples`` to a fixed length.

    Args:
        examples: Mapping with a ``"text"`` entry; that value is passed
            straight to the module-level ``tokenizer``.
        max_length: Length each sequence is truncated and padded to.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        Whatever ``tokenizer`` returns for the given text.
    """
    return tokenizer(
        examples["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

In [4]:
# Training split; `cache_dir` is where MimicDataset looks for cached examples.
train_dataset = MimicDataset(
    split=Splits.train.value,
    tokenizer=tokenizer,
    cache_dir=cache_dir,
)

INFO:dataset:Loading 8066 examples from cached directory ./cache-temp


In [5]:
# Dev split. It is given the training split's label-to-index mapping,
# presumably so both splits encode labels the same way (confirm in MimicDataset).
dev_dataset = MimicDataset(
    split=Splits.dev.value,
    tokenizer=tokenizer,
    cache_dir=cache_dir,
    label2idx=train_dataset.label2idx,
)

INFO:dataset:Loading 1573 examples from cached directory ./cache-temp


In [6]:
# Batch collator; the sequence-length limit comes from the project settings.
data_collator = Collator(
    max_seq_length=settings.MAX_SEQ_LENGTH,
    tokenizer=tokenizer,
)

In [7]:
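# NOTE: disabled accuracy metric; `compute_metrics` is not passed to the Trainer constructed later in this notebook.
# import numpy as np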
# import evaluate
#
# accuracy = evaluate.load("accuracy")
#
# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     predictions = np.argmax(predictions, axis=1)
#     return accuracy.compute(predictions=predictions, references=labels)

In [8]:
from model import BioGptTestModel

# Instantiate the project's BioGPT classifier wrapper. Per the load warning this
# cell prints, the classification head (`score.weight`) is newly initialized,
# so the model must be trained before its predictions are meaningful.
model = BioGptTestModel()

Some weights of the model checkpoint at microsoft/biogpt were not used when initializing BioGptForSequenceClassification: ['output_projection.weight']
- This IS expected if you are initializing BioGptForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BioGptForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BioGptForSequenceClassification were not initialized from the model checkpoint at microsoft/biogpt and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from dataset import build_dataloader

# Build one loader per split with the shared collator (train first, then dev).
train_loader, dev_loader = (
    build_dataloader(split_dataset, data_collator)
    for split_dataset in (train_dataset, dev_dataset)
)

In [10]:
from pipeline import Trainer


# Wire the model and both data loaders into the project's Trainer. The label
# mapping built by the training split is passed along as `label2idx`.
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    dev_loader=dev_loader,
    label2idx=train_dataset.label2idx
)

  0%|          | 0/1009 [00:00<?, ?it/s]

In [11]:
# Run training. In the recorded run this executed on CPU (~2.8 s/it over 1009
# steps) and was stopped manually with a KeyboardInterrupt.
trainer.train()

  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
  batch["labels"] = torch.tensor([f["labels"].numpy() for f in features], dtype=torch.float)
  3%|▎         | 29/1009 [01:42<45:59,  2.82s/it] 

KeyboardInterrupt: 

In [None]:
# Run evaluation (the recorded progress bar shows 197 batches).
# NOTE(review): the recorded output dumps a large integer tensor; presumably a
# debug print inside the evaluation path. Confirm and remove it there.
trainer.evaluate()


Evaluating:   0%|          | 0/197 [00:00<?, ?it/s]



tensor([[  0,   1,   3,   4,   7,   8,   9,  10,  11,  15,  16,  18,  20,  24,
          28,  30,  31,  32,  34,  36,  37,  38,  39,  41,  43,  44,  45,  46,
          47,  48],
        [ 50,  51,  53,  54,  57,  58,  59,  60,  61,  65,  66,  68,  70,  74,
          78,  80,  81,  82,  84,  86,  87,  88,  89,  91,  93,  94,  95,  96,
          97,  98],
        [100, 101, 103, 104, 107, 108, 109, 110, 111, 115, 116, 118, 120, 124,
         128, 130, 131, 132, 134, 136, 137, 138, 139, 141, 143, 144, 145, 146,
         147, 148],
        [150, 151, 153, 154, 157, 158, 159, 160, 161, 165, 166, 168, 170, 174,
         178, 180, 181, 182, 184, 186, 187, 188, 189, 191, 193, 194, 195, 196,
         197, 198],
        [200, 201, 203, 204, 207, 208, 209, 210, 211, 215, 216, 218, 220, 224,
         228, 230, 231, 232, 234, 236, 237, 238, 239, 241, 243, 244, 245, 246,
         247, 248],
        [250, 251, 253, 254, 257, 258, 259, 260, 261, 265, 266, 268, 270, 274,
         278, 280, 281, 282, 