In [None]:
import torch

from kmembert.utils import Config
from kmembert.models import HealthBERT

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default kmembert configuration with only the task mode overridden.
# NOTE(review): this Config is independent of the argparse-driven session
# config created further down, so its mode need not match `args.mode`.
config = Config()
config.mode = "density"

model = HealthBERT(device, config)

In [None]:
import argparse


def _str_to_bool(value: str) -> bool:
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including "False"), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Same options as the command-line training script; parsed with an empty
# argument list below so every option takes its default inside the notebook.
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_folder", type=str, default="ehr", 
    help="data folder name")
parser.add_argument("-m", "--mode", type=str, default="regression", choices=['regression', 'density'],
    help="name of the task")
parser.add_argument("-b", "--batch_size", type=int, default=8, 
    help="dataset batch size")
parser.add_argument("-e", "--epochs", type=int, default=2, 
    help="number of epochs")
parser.add_argument("-drop", "--drop_rate", type=float, default=None, 
    help="dropout ratio. By default, None uses p=0.1")
parser.add_argument("-nr", "--nrows", type=int, default=None, 
    help="maximum number of samples for training and validation")
parser.add_argument("-k", "--print_every_k_batch", type=int, default=1, 
    help="prints training loss every k batch")
# `-f` alone yields True (const); `-f False` must yield False, hence _str_to_bool.
parser.add_argument("-f", "--freeze", type=_str_to_bool, default=False, const=True, nargs="?",
    help="whether or not to freeze the Bert part")
parser.add_argument("-dt", "--days_threshold", type=int, default=365, 
    help="days threshold to convert into classification task")
parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4, 
    help="model learning rate")
parser.add_argument("-r_lr", "--ratio_lr_embeddings", type=float, default=1, 
    help="the ratio applied to lr for embeddings layer")
parser.add_argument("-wg", "--weight_decay", type=float, default=0, 
    help="the weight decay for L2 regularization")
parser.add_argument("-v", "--voc_file", type=str, default=None, 
    help="voc file containing camembert added vocabulary")
parser.add_argument("-r", "--resume", type=str, default=None, 
    help="result folder in which the saved checkpoint will be reused")
parser.add_argument("-p", "--patience", type=int, default=4, 
    help="number of decreasing accuracy epochs to stop the training")

# An explicit empty list: use the defaults and ignore the kernel's own sys.argv.
args = parser.parse_args([])

In [None]:
from kmembert.training import train_and_validate
from kmembert.utils import get_label_threshold, create_session
from kmembert.dataset import EHRDataset
from kmembert.models import HealthBERT
import torch
from torch.utils.data import DataLoader

# Check argument compatibility first, so an invalid combination fails fast
# instead of after create_session has already set up the session.
assert not (args.freeze and args.voc_file), "Don't use freeze argument while adding vocabulary. It would not be learned"

path_dataset, _, device, config = create_session(args)

# NOTE(review): presumably derived from config.days_threshold — confirm in get_label_threshold.
config.label_threshold = get_label_threshold(config, path_dataset)

train_dataset, validation_dataset = EHRDataset.get_train_validation(path_dataset, config)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
# Validation does not need random order; a fixed order keeps validation runs reproducible.
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)

# Build the model from the same session config used by the datasets and the
# training loop. The model from the first cell was built from a separate
# Config forced to mode="density", whereas this config presumably carries
# args.mode (default "regression"), so the two could disagree.
model = HealthBERT(device, config)

train_and_validate(model, train_loader, validation_loader, device, config, config.path_result)