# **Utils & Libraries**

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install wandb -q
!pip install transformers -q

[K     |████████████████████████████████| 1.8MB 7.9MB/s 
[K     |████████████████████████████████| 102kB 11.5MB/s 
[K     |████████████████████████████████| 133kB 51.7MB/s 
[K     |████████████████████████████████| 174kB 53.2MB/s 
[K     |████████████████████████████████| 71kB 8.5MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 2.3MB 8.4MB/s 
[K     |████████████████████████████████| 3.3MB 51.3MB/s 
[K     |████████████████████████████████| 901kB 29.8MB/s 
[?25h

# **Libraries**

In [2]:
import wandb
import itertools
import numpy as np
import pandas as pd
from tokenizers import Tokenizer
from tokenizers import models as M
from tokenizers import decoders as D
from tokenizers import trainers as T
from tokenizers import normalizers as N
from tokenizers import pre_tokenizers as PRE
from tokenizers.processors import TemplateProcessing

import torch
from tqdm.notebook import tqdm
from transformers import AdamW
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from transformers import BertForPreTraining, BertConfig
from transformers import get_linear_schedule_with_warmup

# **Modules**

In [3]:
class BERTBPETokenizer:
    def __init__(self, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.tokenizer.enable_padding(
            pad_id=tokenizer.token_to_id("[PAD]"), pad_type_id=0, pad_token="[PAD]"
        )
        self.tokenizer.enable_truncation(max_length)

    @classmethod
    def from_pretrained(cls, path, max_length=512):
        tokenizer = Tokenizer.from_file(path)
        return cls(tokenizer, max_length=max_length)

    def __call__(self, data):
        if isinstance(data, str):
            out = self.tokenizer.encode(data)
            input_ids = out.ids
            tokens = out.tokens
            attention_mask = out.attention_mask
            segment_id = out.type_ids

        elif isinstance(data, tuple):
            out = self.tokenizer.encode(data[0], data[1])
            input_ids = out.ids
            tokens = out.tokens
            attention_mask = out.attention_mask
            segment_id = out.type_ids

        elif isinstance(data, list):
            out_list = self.tokenizer.encode_batch(data)
            input_ids, tokens, attention_mask, segment_id = [], [], [], []
            for out in out_list:
                input_ids.append(out.ids)
                tokens.append(out.tokens)
                attention_mask.append(out.attention_mask)
                segment_id.append(out.type_ids)
        return input_ids, segment_id, attention_mask, tokens


def grp_func(df):
    return " ".join(df.sort_values("SEQ_NUM")["ICD9_CODE"].tolist())


def prep_nsp_data(df, journey_col_name, journey_len_col):
    def str_split(a):
        a_tkn = a.split()
        half = int(len(a_tkn) / 2)
        return (" ".join(a_tkn[:half]), " ".join(a_tkn[half:]))

    df = df.sort_values(journey_len_col).reset_index(drop=True)
    delta = (df.shape[0] / 2) - int(df.shape[0] / 2)
    if delta == 0:
        i = int(df.shape[0] / 2)
    else:
        i = int(df.shape[0] / 2) + 1
    df1 = df[: i + 1]
    df1.loc[:, "nsp_label"] = 0
    df2 = df[i + 1 :]
    df2.loc[:, "nsp_label"] = 1

    df1_data = list(
        zip(df1[journey_col_name].tolist(), df1.sample(df1.shape[0])[journey_col_name])
    )
    df1["nsp_data"] = df1_data

    df2_data = df2[journey_col_name].apply(str_split).tolist()
    df2["nsp_data"] = df2_data
    return pd.concat([df1, df2], axis=0)


class BLBPEWholeWordMasker:
    def __init__(self, tokenizer, proba=0.15, ignore_id=-100):
        self.p = proba
        self.mask_id = tokenizer.tokenizer.token_to_id("[MASK]")
        self.ignore_id = ignore_id
        self.vocab_size = range(tokenizer.tokenizer.get_vocab_size())

    @staticmethod
    def _find_whole_words(tokens):
        whole_words = []
        idx_whole_words = []
        loc_idx_lst = []
        for i, t in enumerate(tokens):
            if t not in ["[CLS]", "[SEP]", "[PAD]"]:
                if i == 1:
                    pre_word = t
                    loc_idx_lst.append(i)
                else:
                    if t[0] == "Ġ":
                        idx_whole_words.append(loc_idx_lst)
                        whole_words.append(pre_word)
                        pre_word = t
                        loc_idx_lst = [i]
                    else:
                        pre_word = pre_word + t
                        loc_idx_lst.append(i)
        idx_whole_words.append(loc_idx_lst)
        whole_words.append(pre_word)
        return whole_words, idx_whole_words

    def add_masks(self, token_ids, token_lst):
        whole_words, whole_word_idx_mapping = BLBPEWholeWordMasker._find_whole_words(
            token_lst
        )
        n_masks = round(len(whole_words) * self.p)
        mask_word_indices = np.random.choice(
            range(len(whole_words)), size=n_masks, replace=False
        )

        masked_input = token_ids.copy()
        target = [self.ignore_id] * len(token_ids)
        for idx in sorted(mask_word_indices):
            word_indices = whole_word_idx_mapping[idx]

            action = np.random.choice(
                ["mask", "unchanged", "random"], size=1, p=[0.8, 0.1, 0.1]
            )[0]

            if action == "mask":
                for i in word_indices:
                    masked_input[i] = self.mask_id
                    target[i] = token_ids[i]
            elif action == "random":
                for i in word_indices:
                    masked_input[i] = np.random.choice(self.vocab_size, size=1)[0]
                    target[i] = token_ids[i]
            else:
                for i in word_indices:
                    target[i] = token_ids[i]
        return masked_input, target

    def add_batch_masks(self, token_ids, token_lst):
        target = []
        masked_input = []
        for x, y in zip(token_ids, token_lst):
            a, b = self.add_masks(x, y)
            masked_input.append(a)
            target.append(b)
        return masked_input, target


class BERTDataset(Dataset):
    def __init__(self, sentence_lst, label_lst, bert_tokenizer, mask_proba=0.15):
        self.sentence_lst = sentence_lst
        self.label_lst = label_lst
        self.tokenizer = bert_tokenizer
        self.masker = BLBPEWholeWordMasker(bert_tokenizer, proba=mask_proba)

    def __len__(self):
        return len(self.sentence_lst)

    def __getitem__(self, idx):
        return self.sentence_lst[idx], self.label_lst[idx]

    def dynamic_batching_for_pretraining(self, lst):
        tmp_arr = np.array(lst, dtype=object)
        batch_sen = list(tmp_arr[:, 0])
        ids, sent_ids, attn_mask, tokens = self.tokenizer(batch_sen)
        masked_input, lm_target = self.masker.add_batch_masks(ids, tokens)
        nsp_label = torch.tensor(list(tmp_arr[:, 1]))
        input_dict = {
            "masked_input": torch.tensor(masked_input),
            "sent_ids": torch.tensor(sent_ids),
            "attn_mask": torch.tensor(attn_mask),
            "unmasked_tokens": tokens,
        }
        target_dict = {"nsp_label": nsp_label, "lm_target": torch.tensor(lm_target)}
        return input_dict, target_dict


class Trainer:
    def __init__(self, model, train_data, val_data, optimizer_lr=2e-5):
        self.train_data = train_data
        self.val_data = val_data
        if torch.cuda.is_available():
            self.dev = torch.device("cuda")
        else:
            self.dev = torch.device("cpu")
        print("Using: ", self.dev)

        self.model = model.to(self.dev)
        self.optimizer = Trainer._get_optimizer_with_decay(self.model, lr=optimizer_lr)

    @staticmethod
    def _get_optimizer_with_decay(model, lr=2e-5):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
                "lr": lr
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
                "lr": lr
            }
        ]
        return AdamW(optimizer_grouped_parameters, lr=lr)

    def load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=self.dev)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        val_loss_init = checkpoint["val_loss"]
        print("Initialized from checkpoint")
        return val_loss_init

    def save_checkpoint(self, path_to_save_checkpoint, epoch, trn_loss, val_los):
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "training_loss": trn_loss,
                "val_loss": val_los,
            },
            path_to_save_checkpoint
        )
        print(f"Checkpoint saved")

    def save_model(self, path_to_save_model):
        self.model.save_pretrained(path_to_save_model)
        print("Model and config saved")

    def train_model(
        self,
        n_epochs=10,
        logging_step=5,
        patience=3,
        use_scheduler=False,
        path_to_save_checkpoint="./checkpoint.tar",
        checkpoint_path=None,
        wandb_project=None
    ):

        if wandb_project is not None:
            run = wandb.init(project=wandb_project, reinit=True)

        if checkpoint_path is not None:
            val_loss_init = self.load_checkpoint(checkpoint_path)
            print("Last val loss: ", val_loss_init)
        else:
            val_loss_init = 1000

        if use_scheduler:
            warmup_steps = int(len(self.train_data) * n_epochs * 0.34)
            total_steps = len(self.train_data) * n_epochs
            print(
                f"Using linear lr scheduler with {warmup_steps} warmup steps and {total_steps} total steps"
            )
            linear_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, warmup_steps, total_steps
            )

        count = 0
        min_val_loss = val_loss_init
        trn_loss, val_loss = [], []
        for epoch in tqdm(range(n_epochs)):
            print("Epoch: ", epoch)
            self.model.train()
            trn_loss_per_epoch = []
            for i, (x, y) in enumerate(tqdm(self.train_data)):
                x = {k: v.to(self.dev) for k, v in x.items() if k != "unmasked_tokens"}
                y = {k: v.to(self.dev) for k, v in y.items()}

                output = self.model(
                    input_ids=x["masked_input"],
                    token_type_ids=x["sent_ids"],
                    attention_mask=x["attn_mask"],
                    labels=y["lm_target"],
                    next_sentence_label=y["nsp_label"],
                )
                loss = output.loss
                loss.backward()
                self.optimizer.step()
                if use_scheduler:
                    linear_scheduler.step()
                    lr_0, _ = linear_scheduler.get_last_lr()
                    if wandb_project is not None:
                        wandb.log({"lr": lr_0})
                self.optimizer.zero_grad()

                del x, y, output
                torch.cuda.empty_cache()
                trn_loss_per_epoch.append(float(loss))
                if i % logging_step == 0:
                    if wandb_project is not None:
                        wandb.log({"train_loss": np.mean(trn_loss_per_epoch)})
                    else:
                        print(
                            f"Training loss for {len(trn_loss_per_epoch)} batches: ",
                            np.mean(trn_loss_per_epoch)
                        )
            trn_loss.append(np.mean(trn_loss_per_epoch))
            print(f"Training Loss for epoch {epoch}: ", trn_loss[-1])

            self.model.eval()
            valid_loss_per_epoch = []
            with torch.no_grad():
                # print("***************** Validation *****************")
                for x, y in tqdm(self.val_data):
                    x = {
                        k: v.to(self.dev)
                        for k, v in x.items()
                        if k != "unmasked_tokens"
                    }
                    y = {k: v.to(self.dev) for k, v in y.items()}
                    output = self.model(
                        input_ids=x["masked_input"],
                        token_type_ids=x["sent_ids"],
                        attention_mask=x["attn_mask"],
                        labels=y["lm_target"],
                        next_sentence_label=y["nsp_label"]
                        )
                    valid_loss_per_epoch.append(float(output.loss))
            val_loss.append(np.mean(valid_loss_per_epoch))
            print(f"Validation Loss for epoch {epoch}: ", val_loss[-1])

            if wandb_project is not None:
                wandb.log({"val_loss_epoch": val_loss[-1]})
                wandb.log({"train_loss_epoch": trn_loss[-1]})

            if val_loss[-1] < min_val_loss:
                count = 0
                min_val_loss = val_loss[-1]
                self.save_checkpoint(
                    path_to_save_checkpoint, epoch, trn_loss[-1], val_loss[-1]
                )
            else:
                count = count + 1

            if count > patience:
                tmp = self.load_checkpoint(path_to_save_checkpoint)
                print(
                    "******** Stopping early because of patience! Best checkpoint restored ********"
                )
                print("Validation loss at checkpoint: ", tmp)
                run.finish()
                break
            print("Min val loss till now: ", min_val_loss)

        print("Training completed!")
        run.finish()


# **Data Exploration**

In [8]:
dia = pd.read_csv(r'/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/DIAGNOSES_ICD.csv').dropna()
# pres = pd.read_csv(r'../../../Side_Projects/2_readmission_prediction/data/PRESCRIPTIONS.csv').dropna()
pro = pd.read_csv(r'/content/drive/MyDrive/ColabData/datasets/MIMIC-3/PatientLevel/PROCEDURES_ICD.csv').dropna()
pro['ICD9_CODE'] = pro['ICD9_CODE'].astype(str)

In [None]:
dia.SUBJECT_ID.nunique(), pro.SUBJECT_ID.nunique() #, pres.SUBJECT_ID.nunique()

(46517, 42214)

In [10]:
dia_jour = dia.groupby('SUBJECT_ID').apply(grp_func).reset_index().rename(columns={0:'dia_journey'})
# pres_jour = pres.groupby('SUBJECT_ID').apply(grp_func).reset_index().rename(columns={0:'pres_journey'})
pro_jour = pro.groupby('SUBJECT_ID').apply(grp_func).reset_index().rename(columns={0:'pro_journey'})

In [11]:
dia_jour['dia_journey_len'] = dia_jour['dia_journey'].apply(lambda x: len(x.split()))
dia_jour = dia_jour.sort_values('dia_journey_len')
dia_jour = dia_jour[dia_jour.dia_journey_len>=3]

pro_jour['pro_journey_len'] = pro_jour['pro_journey'].apply(lambda x: len(x.split()))
pro_jour = pro_jour.sort_values('pro_journey_len')
pro_jour = pro_jour[pro_jour.pro_journey_len>=3]

In [12]:
dia_jour.shape, pro_jour.shape

((45501, 3), (28636, 3))

In [13]:
dia_jour.dia_journey_len.describe()

count    45501.000000
mean        14.266829
std         15.737225
min          3.000000
25%          7.000000
50%          9.000000
75%         17.000000
max        540.000000
Name: dia_journey_len, dtype: float64

In [14]:
pro_jour.pro_journey_len.describe()

count    28636.000000
mean         7.680682
std          5.841364
min          3.000000
25%          4.000000
50%          6.000000
75%          9.000000
max         98.000000
Name: pro_journey_len, dtype: float64

In [15]:
data = dia_jour.dia_journey.tolist() + pro_jour.pro_journey.tolist()
data = [x for x in data if pd.notnull(x)==True]

In [16]:
dia_icds = set(itertools.chain.from_iterable([x.split() for x in dia_jour.dia_journey.tolist()]))
pro_icds = set(itertools.chain.from_iterable([x.split() for x in pro_jour.pro_journey.tolist()]))

In [17]:
len(dia_icds.intersection(pro_icds)), len(dia_icds.union(pro_icds))

(517, 8434)

# **Training Tokenizer**

In [None]:
# bert_tokenizer = Tokenizer(M.BPE(unk_token="[UNK]"))
# bert_tokenizer.pre_tokenizer = PRE.ByteLevel()
# bert_tokenizer.normalizer = N.Sequence([N.Lowercase()])

# bert_tokenizer.post_processor = TemplateProcessing(
#     single="[CLS] $A [SEP]",
#     pair="[CLS] $A [SEP] $B:1 [SEP]:1",
#     special_tokens=[
#         ("[CLS]", 1),
#         ("[SEP]", 2),
#     ],
# )

# bert_tokenizer.decoder = D.ByteLevel()

# trainer = T.BpeTrainer(
#     vocab_size=2500, special_tokens=["[UNK]","[CLS]", "[SEP]", "[PAD]", "[MASK]"], show_progress=True
# )

# bert_tokenizer.train_from_iterator(data, trainer=trainer)

# # Saving Tokenizer
# bert_tokenizer.save("/content/drive/MyDrive/ColabData/saved_models/PatientBERT/bert-bpe-icd.json")

In [7]:
bpe_tokenizer = BERTBPETokenizer.from_pretrained(
    "/content/drive/MyDrive/ColabData/saved_models/PatientBERT/bert-bpe-icd.json",
    max_length=150
    )

# **Preparing Data for Pre-Training**

**NSP Data**

In [30]:
dia_jour = prep_nsp_data(dia_jour, 'dia_journey', 'dia_journey_len')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  isetter(loc, value)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [31]:
pro_jour = prep_nsp_data(pro_jour, 'pro_journey', 'pro_journey_len')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  isetter(loc, value)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


In [32]:
dia_jour.head()

Unnamed: 0,SUBJECT_ID,dia_journey,dia_journey_len,nsp_label,nsp_data
0,21831,V3000 V290 V053,3,0,"(V3000 V290 V053, V3001 V290 V053)"
1,6679,V3001 7708 V053,3,0,"(V3001 7708 V053, 80126 E8859 8738)"
2,16226,V3000 V053 V290,3,0,"(V3000 V053 V290, 2252 2449 2720 2441 78039)"
3,5356,V3000 V053 V290,3,0,"(V3000 V053 V290, V3101 7742 76518 76528 7706 ..."
4,27018,V3000 V053 V290,3,0,"(V3000 V053 V290, 1983 1987 431 41401 412 V1011)"


In [34]:
pro_jour.head()

Unnamed: 0,SUBJECT_ID,pro_journey,pro_journey_len,nsp_label,nsp_data
0,46473,9671 8659 881,3,0,"(9671 8659 881, 3734 3727 3726 3778 9961)"
1,46462,3322 3491 3893,3,0,"(3322 3491 3893, 8628 8601 8601 3893 3893)"
2,28797,9604 9671 4513,3,0,"(9604 9671 4513, 3505 3722 3778)"
3,92816,8594 9659 3893,3,0,"(8594 9659 3893, 22 415 4319 9671 966)"
4,92839,9604 9671 9960,3,0,"(9604 9671 9960, 9915 3893 3893 9904 9905)"


In [35]:
full_data = pd.concat([dia_jour.rename(columns={'dia_journey': 'journey', 'dia_journey_len': 'journey_len'})
                       , pro_jour.rename(columns={'pro_journey': 'journey', 'pro_journey_len': 'journey_len'})], axis=0)
full_data = full_data.sample(frac=1.0, random_state=42)

In [36]:
full_data.shape

(74137, 5)

**Creating Dataloader with Masks**

In [44]:
# dia_jour.nsp_data.iloc[0]

In [45]:
# ids, sent_ids, _, tokens = bpe_tokenizer(dia_jour.nsp_data.iloc[:64].tolist())

In [46]:
# masker = BLBPEWholeWordMasker(bpe_tokenizer)

In [47]:
# %%timeit
# _, _ = masker.add_batch_masks(ids, tokens) # Batch of 64

In [48]:
x_train, x_val, y_train, y_val = train_test_split(
    full_data['nsp_data'].tolist(),
    full_data['nsp_label'].tolist(),
    test_size=0.10, random_state=42,
    stratify=full_data['nsp_label'].tolist()
    )
len(x_train), len(x_val)

(66723, 7414)

In [49]:
trn_ds = BERTDataset(x_train, y_train, bpe_tokenizer)
trn_dl = torch.utils.data.DataLoader(trn_ds, batch_size=80, num_workers=0, collate_fn=trn_ds.dynamic_batching_for_pretraining, pin_memory=False)

val_ds = BERTDataset(x_val, y_val, bpe_tokenizer)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=80, num_workers=0, collate_fn=trn_ds.dynamic_batching_for_pretraining, pin_memory=False)

# **Model**

In [41]:
config = BertConfig(vocab_size=2500, pad_token_id=3)
model = BertForPreTraining(config)

# **Training Model**

In [42]:
trainer = Trainer(model, trn_dl, val_dl)

Using:  cuda


In [43]:
trainer.train_model(wandb_project='PatientBERT')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


wandb: Paste an API key from your profile and hit enter: ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

Epoch:  0


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 0:  6.275670718004603


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 0:  5.538558493378342
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  5.538558493378342
Epoch:  1


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 1:  5.30154483475371


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 1:  5.017297949842227
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  5.017297949842227
Epoch:  2


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 2:  4.881688987138029


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 2:  4.695858529818955
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.695858529818955
Epoch:  3


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 3:  4.62248543493762


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 3:  4.42327780108298
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.42327780108298
Epoch:  4


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 4:  4.43511563717962


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 4:  4.285528162474273
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.285528162474273
Epoch:  5


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 5:  4.282015926109817


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 5:  4.220793031877087
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.220793031877087
Epoch:  6


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 6:  4.175850341134443


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 6:  4.077636549549718
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.077636549549718
Epoch:  7


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 7:  4.060327143012406


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 7:  4.055872307028822
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  4.055872307028822
Epoch:  8


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 8:  3.986818653095268


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 8:  3.9459082952109714
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  3.9459082952109714
Epoch:  9


HBox(children=(FloatProgress(value=0.0, max=835.0), HTML(value='')))


Training Loss for epoch 9:  3.91213728510691


HBox(children=(FloatProgress(value=0.0, max=93.0), HTML(value='')))


Validation Loss for epoch 9:  3.8817550084924184
Model & optimizer state dictionaries saved
Checkpoint saved
Min val loss till now:  3.8817550084924184

Training completed!


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
train_loss,3.91281
_runtime,11804.0
_timestamp,1620901309.0
_step,1689.0
val_loss_epoch,3.88176
train_loss_epoch,3.91214


0,1
train_loss,█▇▆▆▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_loss_epoch,█▆▄▃▃▂▂▂▁▁
train_loss_epoch,█▅▄▃▃▂▂▁▁▁


**Saving in Hugginface compatible format** | Saves state dictionary with config file

In [50]:
trainer.save_model('/content/drive/MyDrive/ColabData/saved_models/PatientBERT/mimic-3-bert-base')

Model and config saved


In [58]:
(list(trainer.model.parameters())[0] == list(model.parameters())[0]).all()

tensor(True, device='cuda:0')

# **MASK Filling**

In [51]:
dia_jour.head(2)

Unnamed: 0,SUBJECT_ID,dia_journey,dia_journey_len,nsp_label,nsp_data
0,21831,V3000 V290 V053,3,0,"(V3000 V290 V053, V3001 V290 V053)"
1,6679,V3001 7708 V053,3,0,"(V3001 7708 V053, 80126 E8859 8738)"


In [71]:
k=9
token_ids, seg_ids, attn_mask, tokens = bpe_tokenizer(dia_jour.nsp_data.tolist()[k:k+1])
masker = BLBPEWholeWordMasker(bpe_tokenizer)
masked_input, targets = masker.add_batch_masks(token_ids, tokens)

In [72]:
out = model(
    input_ids=torch.tensor(masked_input).to(torch.device('cuda'))
    , token_type_ids=torch.tensor(seg_ids).to(torch.device('cuda'))
    , attention_mask=torch.tensor(attn_mask).to(torch.device('cuda'))
)

In [73]:
out.prediction_logits.shape

torch.Size([1, 17, 2500])

In [74]:
pred_ids = torch.argmax((out.prediction_logits[[np.array(targets) != -100]]), 1).detach().cpu().numpy()

In [75]:
[bpe_tokenizer.tokenizer.id_to_token(x) for x in pred_ids]

['Ġ41401', 'Ġ42731']

In [76]:
[bpe_tokenizer.tokenizer.id_to_token(x) for x in [t for t in targets[0] if t != -100]]

['Ġ68110', 'Ġ41401']