In [1]:
from kmembert.utils import Config
from kmembert.models import TransformerAggregator
from kmembert.utils import get_root, now

import os
import torch

# `sys` is needed for sys.argv below; it was never imported, so this cell
# raised NameError on a fresh kernel (Restart & Run All).
import sys

# Name of the result folder to resume from (see the --resume help text in the
# argument parser cell: "result folder in which the saved checkpoint will be reused").
resume = "kmembert-T2"
config = Config()
config.resume = resume

# Build a per-run results folder: <root>/results/<launcher name>_<now()>.
# Inside a notebook sys.argv[0] is the kernel launcher script, not a user script.
main_file = os.path.splitext(os.path.basename(sys.argv[0]))[0]
session_id = f"{main_file}_{now()}"
path_result = os.path.join(get_root(), "results", session_id)
config.path_result = path_result

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Aggregator hyper-parameters, passed positionally to TransformerAggregator below.
nhead, num_layers, out_dim, time_dim = 8, 4, 2, 8

model = TransformerAggregator(device, config, nhead, num_layers, out_dim, time_dim)

In [2]:
import argparse

# Command-line style options for the run. The notebook never receives real CLI
# arguments, so every value used downstream is the default declared here.
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_folder", type=str, default="ehr",
    help="data folder name")
parser.add_argument("-a", "--aggregator", type=str, default="transformer",
    help="aggregator name", choices=['conflation', 'sanity_check', 'sanity_check_transformer', 'transformer'])
parser.add_argument("-r", "--resume", type=str, default="kmembert-base",
    help="result folder in which the saved checkpoint will be reused")
parser.add_argument("-e", "--epochs", type=int, default=2,
    help="number of epochs")
parser.add_argument("-nr", "--nrows", type=int, default=None,
    help="maximum number of samples for training and validation")
parser.add_argument("-k", "--print_every_k_batch", type=int, default=1,
    help="prints training loss every k batch")
parser.add_argument("-dt", "--days_threshold", type=int, default=365,
    help="days threshold to convert into classification task")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4,
    help="model learning rate")
parser.add_argument("-wg", "--weight_decay", type=float, default=0,
    help="the weight decay for L2 regularization")
parser.add_argument("-p", "--patience", type=int, default=4,
    help="number of decreasing accuracy epochs to stop the training")
# Help text typo fixed: "nusmber" -> "number".
parser.add_argument("-me", "--max_ehrs", type=int, default=4,
    help="maximum number of ehrs to be used for multi ehrs prediction")
parser.add_argument("-nh", "--nhead", type=int, default=8,
    help="number of transformer heads")
parser.add_argument("-nl", "--num_layers", type=int, default=4,
    help="number of transformer layers")
parser.add_argument("-od", "--out_dim", type=int, default=2,
    help="transformer out_dim (1 regression or 2 density)")
parser.add_argument("-td", "--time_dim", type=int, default=8,
    help="transformer time_dim")

# Parse an explicit empty argument list so the kernel's own sys.argv is ignored.
# The previous parse_args("") only worked because an empty string iterates to
# an empty list; [] states the intent directly.
args = parser.parse_args([])

In [3]:
# Load dataset and dataloader
from kmembert.dataset import PredictionsDataset
from torch.utils.data import DataLoader
from kmembert.utils import create_session, get_label_threshold, collate_fn

# create_session returns four values; the second one is not needed here.
path_dataset, _, device, config = create_session(args)

# Fail early when the transformer width (768 + time_dim) cannot be split
# evenly across the attention heads.
assert not (768 + args.time_dim) % args.nhead, f'd_model (i.e. 768 + time_dim) must be divisible by nhead. Found time_dim {args.time_dim} and nhead {args.nhead}'

config.label_threshold = get_label_threshold(config, path_dataset)

train_dataset, test_dataset = PredictionsDataset.get_train_validation(
    path_dataset,
    config,
    output_hidden_states=True,
    device=device,
)

# Both loaders share the same settings: one sample per batch, shuffled,
# batched through the project's collate_fn.
loader_options = dict(batch_size=1, shuffle=True, collate_fn=collate_fn)

# No training loader is built for the 'conflation' and 'sanity_check'
# aggregators; train_loader is only defined for the other choices.
if args.aggregator not in ('conflation', 'sanity_check'):
    train_loader = DataLoader(train_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

[1m> DEVICE:  cpu[0m
[1m> ROOT:    c:\Users\DIPIAZZA\Documents\CLB Projet\Projet1\Test Load BERTS\KmemBERT[0m
[1m> SESSION: c:\Users\DIPIAZZA\Documents\CLB Projet\Projet1\Test Load BERTS\KmemBERT\results\ipykernel_launcher_22-04-28_09h57m46s[0m
[1m
Loading camembert and its tokenizer...[0m
[1mResuming with model at kmembert-base...[0m


  0%|          | 0/7 [00:00<?, ?it/s]

[92mSuccessfully loaded
[0m
[1m
Computing Health Bert predictions...[0m


100%|██████████| 7/7 [00:00<00:00,  7.01it/s]
  0%|          | 0/3 [00:00<?, ?it/s]

size: 376 bytes
[92mSuccessfully computed 12 Health Bert outputs
[0m
[1m
Computing Health Bert predictions...[0m


100%|██████████| 3/3 [00:00<00:00,  6.00it/s]

size: 248 bytes
[92mSuccessfully computed 6 Health Bert outputs
[0m





In [4]:
from kmembert.training import train_and_validate

# Rebuild the aggregator from the parsed hyper-parameters and the session's
# device/config; this replaces any `model` bound by an earlier cell.
model = TransformerAggregator(
    device,
    config,
    args.nhead,
    args.num_layers,
    args.out_dim,
    args.time_dim,
)

# Kept as the cell's last expression so the value returned by
# train_and_validate is displayed as the cell output.
train_and_validate(
    model,
    train_loader,
    test_loader,
    device,
    config,
    config.path_result,
)

[94m
----- STARTING TRAINING -----[0m
> EPOCH 0
tensor([[1176.],
        [   0.]])
tensor([[-2.2158e+02,  1.1607e+03,  8.3102e+02, -1.0273e+03, -1.1548e+03,
         -3.8240e+02,  1.0379e+02, -2.0070e+02],
        [ 3.1796e-01,  5.8347e-01, -7.2365e-01, -4.8283e-01,  3.3851e-01,
          6.5692e-01, -1.7412e-01, -7.6092e-01]], grad_fn=<CatBackward0>)
    [0-1]  -  Average loss: 0.043412  -  Time elapsed: 0m0s
tensor([[2765.],
        [   0.]])
tensor([[-5.2141e+02,  2.7282e+03,  1.9550e+03, -2.4148e+03, -2.7156e+03,
         -9.0005e+02,  2.4427e+02, -4.7071e+02],
        [ 3.1796e-01,  5.8347e-01, -7.2365e-01, -4.8283e-01,  3.3851e-01,
          6.5692e-01, -1.7412e-01, -7.6092e-01]], grad_fn=<CatBackward0>)
    [1-2]  -  Average loss: 0.233299  -  Time elapsed: 0m0s
tensor([[446.],
        [  0.]])
tensor([[-8.3832e+01,  4.4058e+02,  3.1466e+02, -3.8993e+02, -4.3774e+02,
         -1.4458e+02,  3.9254e+01, -7.6651e+01],
        [ 3.1796e-01,  5.8347e-01, -7.2365e-01, -4.8283e-01,  

-1.4376836220423381

In [5]:
import torch.nn as nn 
# Two sample inputs as a (2, 1) float tensor: one large value and zero.
a = torch.tensor([[1176.0], [0.0]])

# Freshly initialised linear layer mapping 1 input feature to 8 outputs.
# Weights are random, so the displayed numbers differ between runs.
lin = nn.Linear(in_features=1, out_features=8)

# Display the layer output: the zero row shows the bias alone, while the
# large input row is scaled by the weights.
lin(a)

tensor([[-7.5545e+02, -3.0557e+02, -5.5822e+01,  1.0581e+03,  4.1436e+02,
         -2.1913e+02,  7.6069e+02, -1.1082e+02],
        [ 5.5351e-01,  6.5893e-01,  4.8194e-01, -2.2832e-01, -4.5338e-01,
          2.5027e-01, -6.7313e-01,  7.1734e-02]], grad_fn=<AddmmBackward0>)