# Task 3: LOS - Multimodality

## Setting Up the Environment

In [None]:
%load_ext autoreload
%autoreload 2

from IPython.core.magic import register_cell_magic

@register_cell_magic
def skip(line, cell):
    """Cell magic ``%%skip``: ignore the cell body instead of running it.

    Neither the magic's argument line nor the cell source is used, and the
    function returns ``None``, so the cell's code is never executed.
    """

In [None]:
import sys
import os
import pickle

# Repository root (placeholder to fill in). Dataset paths below are built by
# plain string concatenation, so this value must end with "/".
root_path = "..."

from config import device, data_folder, log_folder

# Narrow the shared data/log folders down to this task's sub-folders.
task_dir = "LOS"
data_folder += f"{task_dir}/"
log_folder += f"{task_dir}/"

In [None]:
import src.utils.data_selection as stool
# Pre-computed cross-validation splits for this task. run_kfolds() at the end
# of the notebook needs them, so fail here with a clear message instead of
# leaving `kfolds` undefined (which only surfaced later as a NameError).
# NOTE: pickle can execute arbitrary code on load; only load trusted files.
kfolds_fpath = root_path + "datasets/%s/kfolds.pkl" % (task_dir)
try:
    with open(kfolds_fpath, "rb") as f:
        kfolds = pickle.load(f)
except FileNotFoundError as err:
    raise FileNotFoundError(
        "k-fold split file not found: %s (check root_path / task_dir)" % kfolds_fpath
    ) from err

In [None]:
import json
from datasets.collate_fun import basic_collate_fn


# Private copies of the task-specific folders, read by get_log_dataset().
# NOTE(review): presumably kept so that later rebinding of `log_folder` /
# `data_folder` cannot change where unimodal runs are loaded from — confirm.
_log_folder, _data_folder = log_folder, data_folder

def get_log_dataset(unimodal, unimodal_log):
    """Load the pickled dataloader and the settings recorded by a unimodal run.

    The run directory is ``_log_folder + "<unimodal>_unimodal/<unimodal_log>/"``.
    The ``param`` section of its ``log.json`` supplies the dataset settings,
    the collate-function settings and the decoder input size.

    Args:
        unimodal: Modality name (e.g. ``"static"``, ``"ecg"``). Also the key
            into ``default_dataset`` when the log has no ``DATASET["path"]``.
        unimodal_log: Name of the run sub-directory for that modality.

    Returns:
        Tuple ``(dataloader, dataset_parm, collate_fn_params, embed_dim,
        model_path)``:

        * ``dataloader`` – the object unpickled from ``_data_folder``.
        * ``dataset_parm`` – ``log["param"]["DATASET"]``.
        * ``collate_fn_params`` – ``log["param"]["COLLATE_FN_PARAMS"]``,
          unwrapped to its only element when it has length 1. When the log
          has no such entry, it is ``{"name": basic_collate_fn.__name__}``.
        * ``embed_dim`` – ``log["param"]["DECODER_PARAM"]["in_dim"]``.
        * ``model_path`` – the run directory, with a trailing ``/``.

    Raises:
        FileNotFoundError: if ``log.json`` or the dataloader pickle is missing.
        KeyError: if a required entry is absent from ``log.json``.
    """
    model_path = _log_folder + "%s_unimodal/%s/" % (unimodal, unimodal_log)
    with open(model_path + "log.json", "r") as f:
        log = json.load(f)
    params = log["param"]
    dataset_parm = params["DATASET"]

    # Prefer the pickle file name recorded in the log. Logs without a "path"
    # entry fall back to the modality's default file name.
    if "path" in dataset_parm:
        dataloader_file = dataset_parm["path"].split("/")[-1]
    else:
        dataloader_file = default_dataset[unimodal]
    dataloader_path = _data_folder + dataloader_file

    # Fail with a clear message. Previously a missing file left `dataloader`
    # unbound, which surfaced as an UnboundLocalError at the return statement.
    # NOTE: pickle can execute arbitrary code on load; only load trusted files.
    try:
        with open(dataloader_path, "rb") as f:
            dataloader = pickle.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "dataloader pickle for %r not found: %s" % (unimodal, dataloader_path)
        ) from err

    if "COLLATE_FN_PARAMS" in params:
        collate_fn_params = params["COLLATE_FN_PARAMS"]
        # A single recorded entry is unwrapped so callers receive it directly.
        if len(collate_fn_params) == 1:
            collate_fn_params = collate_fn_params[0]
    else:
        collate_fn_params = {"name": basic_collate_fn.__name__}

    # The unimodal decoder's input size is the encoder embedding size.
    embed_dim = params["DECODER_PARAM"]["in_dim"]
    return dataloader, dataset_parm, collate_fn_params, embed_dim, model_path

# Fallback pickle file names (under the task data folder), used by
# get_log_dataset() when a run's log.json records no DATASET "path".
default_dataset = dict(
    static="static.pkl",
    labs="labs.pkl",
    vitals_num="vitals_segs.pkl",
    vitals_cat="vitals_cat.pkl",
    ecg="ecg.pkl",
)

# Run sub-directory (under "<modality>_unimodal/") of the pretrained unimodal
# model to load for each modality. The values are placeholders to fill in.
unimodal_list = dict(
    static="...",
    labs="...",
    vitals_num="...",
    vitals_cat="...",
    ecg="...",
)

In [None]:
from datasets.LOS.static import StaticLoader
from datasets.LOS.labs import LabsDataset
from datasets.LOS.vitalsigns import VitalsNumLoader, VitalsCatLoader, multiscale_vitalsigns
from datasets.LOS.ecg import ECGLoader

# Per-modality settings, gathered in the order the datasets are later passed
# to MultiModalDataset: static, labs, vitals_num, vitals_cat, ecg.
collate_fn_params = []
unimodal_model_params_list = []


def _register_unimodal(collate_params, embed_dim, model_path):
    """Record one modality's collate settings and its pretrained-encoder spec."""
    collate_fn_params.append(collate_params)
    unimodal_model_params_list.append({"model_path": model_path, "out_dim": embed_dim})


# --- static ---
static, _dataset_parm, _collate_fn_params, _embed_dim, _model_path = get_log_dataset(
    "static", unimodal_list["static"]
)
static_dataset = static.get_dataset()
_register_unimodal(_collate_fn_params, _embed_dim, _model_path)

# --- labs ---
labs, _dataset_parm, _collate_fn_params, _embed_dim, _model_path = get_log_dataset(
    "labs", unimodal_list["labs"]
)
labs_dataset = labs.get_dataset(only_valid=False)
_register_unimodal(_collate_fn_params, _embed_dim, _model_path)

# --- numeric vital signs: one segmented dataset per recorded window size ---
vitals_segs, _dataset_parm, _collate_fn_params, _embed_dim, _model_path = get_log_dataset(
    "vitals_num", unimodal_list["vitals_num"]
)
multiscale = []
for vitals_segs_win_size in _dataset_parm["win_size"]:
    # presumably reconfigures the loader in place for this window size
    vitals_segs.set_seg_data(vitals_segs_win_size, _dataset_parm["segs_num"])
    vitals_segs_dataset = vitals_segs.get_dataset(only_valid=False, flatten=False)
    multiscale.append(vitals_segs_dataset)
# NOTE(review): only the first two window sizes are combined; any further
# entries in "win_size" are computed but ignored, and fewer than two raises
# IndexError.
vitals_num_dataset = multiscale_vitalsigns(multiscale[0], multiscale[1])
_register_unimodal(_collate_fn_params, _embed_dim, _model_path)

# --- categorical vital signs ---
vitals_cat, _dataset_parm, _collate_fn_params, _embed_dim, _model_path = get_log_dataset(
    "vitals_cat", unimodal_list["vitals_cat"]
)
vitals_cat_dataset = vitals_cat.get_dataset(only_valid=False)
_register_unimodal(_collate_fn_params, _embed_dim, _model_path)

# --- ECG ---
ecg, _dataset_parm, _collate_fn_params, _embed_dim, _model_path = get_log_dataset(
    "ecg", unimodal_list["ecg"]
)
ecg_dataset = ecg.get_dataset(only_valid=False)
_register_unimodal(_collate_fn_params, _embed_dim, _model_path)

In [None]:
from datasets.LOS.multimodal import MultiModalDataset
from datasets.LOS.static import StaticLoader

# Sample ids and regression targets are taken from the static modality.
ids = static.get_ids()
targets_df = static.get_targets()
targets = targets_df.targets.values
targets_num = 1  # single regression output; used as the decoder's out_dim


def _inputs_only(samples):
    """Strip the trailing entries from every sample of a unimodal dataset.

    Three-element samples collapse to their first element. Any other sample
    keeps everything except its last two entries. (Presumably the stripped
    entries are target/id data — TODO confirm against the dataset classes.)
    """
    return [sample[0] if len(sample) == 3 else sample[:-2] for sample in samples]


_static = _inputs_only(static_dataset)
_labs = _inputs_only(labs_dataset)
_vitals_num = _inputs_only(vitals_num_dataset)
_vitals_cat = _inputs_only(vitals_cat_dataset)
_ecg = _inputs_only(ecg_dataset)

dataset = MultiModalDataset(ids, targets, _static, _labs, _vitals_num, _vitals_cat, _ecg)

In [None]:
from itertools import product
from models.multimodal import BiModalAttn
from blocks.mlp import MLPDecoderReg, MLP
from models.multimodal import create_multimodal_model
from datasets.collate_fun import CreateCustomDataset
from training_evaluation import run_kfolds
import math

# ---- experiment configuration ----
k_models = True        # forwarded to create_multimodal_model()
EMBED_DIM = 512
BATCH_SIZE = 64
KFOLDS = "kfolds"      # label stored in train_param; the splits object is `kfolds`
BI_MODAL = True
# NOTE(review): SHARED_LAYER is not read anywhere in this notebook;
# SHARED_LAYER_PARAM is always passed to the model regardless of its value.
SHARED_LAYER = True
LR = 0.0001

# Encoders to fuse; presumably indices into unimodal_model_params_list.
encoders_i = [0, 1, 2, 3, 4]

BiModalAttn_param = dict(
    embed_size=EMBED_DIM,
    num_blocks=1,
    num_heads=64,
    drop_prob=0.1,
    fusion_type="add",
)

shared_layer_param = dict(
    in_dim=EMBED_DIM,
    hidden_dim=[EMBED_DIM],
    drop_prob=0.05,
    BatchNorm=False,
)

# Decoder input: one EMBED_DIM vector per encoder, plus one per encoder pair
# when bimodal attention is enabled.
DECODER_IN_DIM = int(
    EMBED_DIM * len(encoders_i)
    + (EMBED_DIM * math.comb(len(encoders_i), 2) if BI_MODAL else 0)
)

MLP_decoder_param = dict(
    in_dim=DECODER_IN_DIM,
    out_dim=targets_num,
    hidden_dim=[DECODER_IN_DIM // 2],
    drop_prob=0.1,
)

train_param = dict(
    MODEL_NAME="multimodal",
    KFOLDS=KFOLDS,
    ENCODERS_I=encoders_i,
    ENCODERS_PARAM=unimodal_model_params_list,
    INTER_MODEL=BiModalAttn.__name__,
    INTER_MODEL_PARAM=BiModalAttn_param,
    SHARED_LAYER_PARAM=shared_layer_param,
    DECODER_MODEL=MLPDecoderReg.__name__,
    DECODER_PARAM=MLP_decoder_param,
    EMBED_DIM=EMBED_DIM,
    BATCH_SIZE=BATCH_SIZE,
    LR=LR,
    MAX_EPOCHS=20,
    OPTIMIZER="Adam",
    COLLATE_FN_PARAMS=collate_fn_params,
    RESET_MODEL_PARAMS=False,
)

# ---- build the model and run k-fold training / evaluation ----
model = create_multimodal_model(train_param, device, k_models=k_models)

# NOTE(review): "classfication" is passed exactly as spelled here; confirm it
# matches CreateCustomDataset's parameter name before "fixing" the typo.
collate_batch = CreateCustomDataset(
    len(collate_fn_params), train_param["COLLATE_FN_PARAMS"], classfication=False
)
log = run_kfolds(
    train_param,
    model,
    dataset,
    kfolds,
    log_folder=log_folder,
    collate_fun=collate_batch,
    classification=False,
)