### Classification-T2 - Phase I

In [1]:
# Load T2 model: build a TransformerAggregator and restore its weights from the
# "kmembert-T2" result folder.
import os

import torch

from kmembert.models import TransformerAggregator
from kmembert.utils import Config, get_root, now

resume = "kmembert-T2"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Minimal configuration: only `resume` is set, which tells `model.resume`
# which result folder to read the checkpoint from.
config = Config()
config.resume = resume

# Aggregator hyper-parameters. They are hard-coded here and repeated as the
# --nhead/--num_layers/--out_dim/--time_dim defaults of the argparse cell,
# so the two places must be kept in sync.
# NOTE(review): presumably they must also match the architecture stored in
# the checkpoint being resumed — confirm against the saved model.
nhead = 8
num_layers = 4
out_dim = 2
time_dim = 8

model = TransformerAggregator(device, config, nhead, num_layers, out_dim, time_dim)
model.resume(config)

[1mResuming with model at kmembert-T2...[0m
[92mSuccessfully loaded
[0m


In [2]:
# ArgParse
# Build a command-line style `args` namespace inside the notebook; it is
# consumed by `create_session(args)` and the dataset/loader cell below.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_folder", type=str, default="ehr",
    help="data folder name")
parser.add_argument("-a", "--aggregator", type=str, default="transformer",
    help="aggregator name", choices=['conflation', 'sanity_check', 'sanity_check_transformer', 'transformer'])
parser.add_argument("-r", "--resume", type=str, default="kmembert-base",
    help="result folder in which the saved checkpoint will be reused")
parser.add_argument("-e", "--epochs", type=int, default=2,
    help="number of epochs")
parser.add_argument("-nr", "--nrows", type=int, default=None,
    help="maximum number of samples for training and validation")
parser.add_argument("-k", "--print_every_k_batch", type=int, default=1,
    help="prints training loss every k batch")
parser.add_argument("-dt", "--days_threshold", type=int, default=365,
    help="days threshold to convert into classification task")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4,
    help="model learning rate")
parser.add_argument("-wg", "--weight_decay", type=float, default=0,
    help="the weight decay for L2 regularization")
parser.add_argument("-p", "--patience", type=int, default=4,
    help="number of decreasing accuracy epochs to stop the training")
parser.add_argument("-me", "--max_ehrs", type=int, default=4,
    help="maximum number of ehrs to be used for multi ehrs prediction")
# The four transformer options below mirror the values hard-coded when the
# model is built in the first cell; keep them in sync.
parser.add_argument("-nh", "--nhead", type=int, default=8,
    help="number of transformer heads")
parser.add_argument("-nl", "--num_layers", type=int, default=4,
    help="number of transformer layers")
parser.add_argument("-od", "--out_dim", type=int, default=2,
    help="transformer out_dim (1 regression or 2 density)")
parser.add_argument("-td", "--time_dim", type=int, default=8,
    help="transformer time_dim")

# Parse an explicit empty list so the kernel's own sys.argv (e.g. "-f <kernel.json>")
# is ignored. `args` must be a list of tokens: a string would be split into characters.
args = parser.parse_args([])

In [3]:
# Load dataset and dataloader
from kmembert.dataset import PredictionsDataset
from torch.utils.data import DataLoader
from kmembert.utils import create_session, get_label_threshold, collate_fn, collate_fn_with_id

# The training cell below always uses `train_loader`, but these aggregators never
# build one. Fail here, before the slow prediction pass, rather than with a
# NameError afterwards.
if args.aggregator in ['conflation', 'sanity_check']:
    raise ValueError(
        f"aggregator '{args.aggregator}' builds no train_loader, but the training cell needs one; "
        "use 'transformer' or 'sanity_check_transformer'"
    )

path_dataset, _, device, config = create_session(args)

assert (768 + args.time_dim) % args.nhead == 0, f'd_model (i.e. 768 + time_dim) must be divisible by nhead. Found time_dim {args.time_dim} and nhead {args.nhead}'

# The model was built from hard-coded hyper-parameters in the first cell, while
# the check above only validates `args`. Make sure both describe the same architecture.
model_dims = (nhead, num_layers, out_dim, time_dim)
args_dims = (args.nhead, args.num_layers, args.out_dim, args.time_dim)
assert model_dims == args_dims, (
    f'model hyper-parameters (nhead, num_layers, out_dim, time_dim) = {model_dims} '
    f'differ from args {args_dims}'
)

config.label_threshold = get_label_threshold(config, path_dataset)

train_dataset, test_dataset = PredictionsDataset.get_train_validation(
    path_dataset, config, output_hidden_states=True, device=device)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
# Shuffling the evaluation set only randomizes the order of the saved predictions
# between runs, so keep it fixed.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

[1m> DEVICE:  cpu[0m
[1m> ROOT:    c:\Users\DIPIAZZA\Documents\CLBProjet\VirtualMachine_T2_Classification_Phase_I\KmemBERT[0m
[1m> SESSION: c:\Users\DIPIAZZA\Documents\CLBProjet\VirtualMachine_T2_Classification_Phase_I\KmemBERT\results\ipykernel_launcher_22-08-02_15h15m13s[0m
str_date:  19991203
str_date:  19990913
str_date:  19991203
str_date:  19991007
str_date:  19991203
str_date:  19991028
str_date:  19991203
str_date:  19991116
str_date:  20010316
str_date:  19991129
str_date:  20030602
str_date:  20000307
str_date:  20010506
str_date:  20000410
str_date:  20010506
str_date:  20000413
str_date:  20010506
str_date:  20000425
str_date:  20010506
str_date:  20000427
str_date:  20001116
str_date:  20000607
str_date:  20010316
str_date:  20000619
str_date:  20001116
str_date:  20000719
str_date:  20001116
str_date:  20000719
str_date:  20001116
str_date:  20000802
str_date:  20030602
str_date:  20000912
str_date:  20030602
str_date:  20000927
str_date:  20030602
str_date:  200010

  0%|          | 0/21 [00:00<?, ?it/s]

[92mSuccessfully loaded
[0m
[1m
Computing Health Bert predictions...[0m


100%|██████████| 21/21 [00:48<00:00,  2.29s/it]
  0%|          | 0/3 [00:00<?, ?it/s]

size: 656 bytes
[92mSuccessfully computed 80 Health Bert outputs
[0m
str_date:  20180110
str_date:  20170614
str_date:  20180110
str_date:  20170619
str_date:  20180110
str_date:  20170620
str_date:  20180110
str_date:  20170626
str_date:  20181120
str_date:  20180129
str_date:  20181120
str_date:  20180302
str_date:  20181120
str_date:  20180306
str_date:  20181120
str_date:  20180313
str_date:  20190128
str_date:  20180831
str_date:  20190128
str_date:  20180904
str_date:  20190128
str_date:  20180907
str_date:  20190128
str_date:  20180917
[1m
Computing Health Bert predictions...[0m


100%|██████████| 3/3 [00:08<00:00,  2.94s/it]

size: 248 bytes
[92mSuccessfully computed 12 Health Bert outputs
[0m





In [4]:
from kmembert.training import train_and_validate

# Fine-tune the resumed aggregator and evaluate it on the test loader after each epoch.
# The returned value (displayed as this cell's output) matches the best "Testing"
# mean loss reported in the training log.
best_val_loss = train_and_validate(
    model, train_loader, test_loader, device, config, config.path_result
)
best_val_loss

[94m
----- STARTING TRAINING -----[0m
ON EST DANS LE SCHEDULER DE INTERFACE.PY
[104m[97m> EPOCH 0/1[0m
[95m    Training | Epoch: 0 - Mean Loss: 0.843885 - Time elapsed: 0m7s[0m
[95m    Testing | Epoch: 0 - Mean Loss: 1.519836 - Time elapsed: 0m0s
[0m
[92m    Best loss so far[0m
    Saving model state...
    Saving predictions...   
[92m   Predictions and Labels saved.   [0m
    (Ended validation)

[104m[97m> EPOCH 1/1[0m
[95m    Training | Epoch: 1 - Mean Loss: 0.826050 - Time elapsed: 0m7s[0m
[95m    Testing | Epoch: 1 - Mean Loss: 1.498594 - Time elapsed: 0m0s
[0m
[92m    Best loss so far[0m
    Saving model state...
    Saving predictions...   
[92m   Predictions and Labels saved.   [0m
    (Ended validation)

[94m-----  Ended Training  -----
[0m
   Saving losses...   
[92m   Losses saved...[0m


1.4985936085383098