### This notebook describes how to load DeepECG-SSL and use it in your own training pipeline

- Make sure you have deployed [`Fairseq-signals`](https://github.com/HeartWise-AI/fairseq-signals) locally.

### Inference


Inference can be run either from the command-line interface or from this notebook.

#### Inference in the notebook

In [None]:
# import useful old code

import sys
import importlib

import torch.nn as nn

from collections import OrderedDict
from typing import Any, Dict, Optional, Union

# Root of your local fairseq-signals checkout. Either edit the default below
# (INIT_TO_FAIRSEQ_SIGNAL_PATH) or export FAIRSEQ_SIGNALS_PATH before starting
# the notebook.
import importlib.util  # `import importlib` alone does not guarantee `.util` is loaded
import os

project_dir = os.environ.get("FAIRSEQ_SIGNALS_PATH", "")
if not project_dir:
    raise RuntimeError(
        "project_dir is not configured: set it to your fairseq-signals checkout "
        "or export FAIRSEQ_SIGNALS_PATH."
    )

# Put the checkout on sys.path; presumably needed so that imports made by
# checkpoint_utils (e.g. of fairseq_signals) resolve — TODO confirm.
root_dir = project_dir
if root_dir not in sys.path:
    sys.path.append(root_dir)

# Load fairseq_signals/utils/checkpoint_utils.py by file path as a standalone
# module bound to the name `checkpoint_utils`.
spec = importlib.util.spec_from_file_location(
    "checkpoint_utils",
    os.path.join(project_dir, "fairseq_signals", "utils", "checkpoint_utils.py"),
)
checkpoint_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(checkpoint_utils)


class WCREcgTransformer(nn.Module):
    """``nn.Module`` wrapper around a model loaded from a fairseq-signals checkpoint.

    The model is loaded with ``checkpoint_utils.load_model_and_task`` and
    ``forward`` returns ``model.get_logits(...)`` of its output.

    Args:
        model_path: Path of the checkpoint to load.
        pretrained_path: If not None, passed as the ``model_path`` config
            override; presumably the pre-trained (SSL) checkpoint referenced by
            a fine-tuned model — TODO confirm.
        overrides: Extra config overrides, either a mapping or an object with
            a ``__dict__`` (e.g. ``argparse.Namespace``). It is copied, never
            mutated.
        task, strict, num_shards, state: Accepted for signature compatibility;
            currently not used.
        suffix: Forwarded to ``load_model_and_task``.
    """

    def __init__(
        self,
        model_path: str,
        pretrained_path: Optional[str] = None,
        overrides: Optional[Dict[str, Any]] = None,
        task=None,
        strict=True,
        suffix="",
        num_shards=1,
        state=None,
    ):
        super().__init__()
        arg_overrides = self._build_overrides(overrides, pretrained_path)
        # load_model_and_task returns (model, saved_cfg, task); only the model is kept.
        model, _, _ = checkpoint_utils.load_model_and_task(
            model_path,
            arg_overrides=arg_overrides,
            suffix=suffix,
        )
        self.model = model

    @staticmethod
    def _build_overrides(overrides, pretrained_path: Optional[str]) -> Dict[str, Any]:
        """Return a fresh override dict built from *overrides* and *pretrained_path*.

        Mappings are copied directly; other objects are read via ``vars()``.
        The input is never modified (``vars()`` returns the live ``__dict__``,
        so it is copied before adding ``model_path``).
        """
        from collections.abc import Mapping

        if overrides is None:
            merged: Dict[str, Any] = {}
        elif isinstance(overrides, Mapping):
            merged = dict(overrides)
        else:
            merged = dict(vars(overrides))
        if pretrained_path is not None:
            merged["model_path"] = pretrained_path
        return merged

    def forward(self, x, padding_mask=None):
        """Run the wrapped model on ``x`` and return its logits.

        Args:
            x: Passed to the model as ``source``.
            padding_mask: Passed to the model unchanged (may be None).
        """
        net_output = self.model(source=x, padding_mask=padding_mask)
        return self.model.get_logits(net_output)


In [None]:
import sys
import os

# Root directory of previous fairseq-signals training outputs; adjust to your setup.
m_root = '/media/data1/achilsowa/results/fairseq/outputs/'

# Checkpoint paths. An empty string is a placeholder you still have to fill in.
m_paths = {
    "SSL": '',  # SSL_PATH
    "FT_AFIB-5": '',  # FT_AFIB-5 checkpoint path
    "FT_FEVG-REG": '',  # FT_FEVG-REG checkpoint path
    "FT_LABELS-77": os.path.join(m_root, "2024-10-08/04-39-01/checkpoint_last-ft-labels-77-bce/checkpoint_best.pt"),
}

# Get the path to the root directory of your project
root_path = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # Adjust according to your folder depth

# Add the root directory to sys.path
if root_path not in sys.path:
    sys.path.append(root_path)

# NOTE: this rebinds WCREcgTransformer, replacing the class defined in the
# previous cell with the version from models.modules.wcr.
from models.modules.wcr import WCREcgTransformer

# Pass None instead of '' for an unset SSL path: the wrapper defined in the
# previous cell only sets the model_path override when pretrained_path is not
# None (presumably the packaged wrapper behaves the same — TODO confirm).
ssl_path = m_paths['SSL'] or None

# Keys must match m_paths exactly. Despite its name, model_ssl is the
# FT_AFIB-5 fine-tuned checkpoint.
model_ssl = WCREcgTransformer(m_paths['FT_AFIB-5'], ssl_path)
model_fevg = WCREcgTransformer(m_paths['FT_FEVG-REG'], ssl_path)

In [None]:
# TODO:
# 1. Download the PTB dataset.
# 2. (next step still to be written)

### Finetuning

In [None]:
import os 

def generate_train_cli(
    devices=1,
    encoder='_last',
    task='labels-77',
    num_labels=77,
    mode='ft',  # possible values are 'ft', 'le', 'e2e'
    is_df=True,
    cls='',
    wd=0,
    criterion='binary_cross_entropy_with_logits',
    data_root='/media/data1/achilsowa/datasets/fairseq/mhi-mimic-code15/manifest/finetune',
    pretrained_dir='/media/data1/achilsowa/results/fairseq/outputs/2024-09-22/03-16-32/checkpoints-all',
):
    """Build a ``fairseq-hydra-train`` command line for fine-tuning.

    Args:
        devices: Value written into ``CUDA_VISIBLE_DEVICES`` (GPU ids, not a count).
        encoder: Checkpoint suffix; loads ``{pretrained_dir}/checkpoint{encoder}.pt``.
            Replaced by ``'_e2e'`` when ``mode == 'e2e'``.
        task: Manifest sub-directory under ``data_root``; also part of the save dir.
        num_labels: Value for ``model.num_labels``.
        mode: ``'ft'`` (fine-tune), ``'le'`` (adds ``model.linear_evaluation=true``)
            or ``'e2e'`` (``model.no_pretrained_weights=true``, no pretrained path).
        is_df: Select ``+task.df_dataset=true`` if True, else ``+task.npy_dataset=true``.
        cls: ``'attn'`` selects ``ecg_transformer_attn_classifier`` and adds an
            ``-attn`` suffix to the save dir; any other value is appended verbatim.
        wd: Weight decay; only emitted when truthy.
        criterion: One of ``'asymmetric'``, ``'binary_focal'``, ``'mse'``,
            ``'binary_cross_entropy_with_logits'``, ``'mlsml'``.
        data_root: Directory containing the per-task fine-tuning manifests.
        pretrained_dir: Directory containing the pretrained encoder checkpoints.

    Returns:
        The full shell command as a single string.

    Raises:
        ValueError: If ``criterion`` is not one of the supported names.
    """
    # Short loss names used in the checkpoint save directory.
    loss_abbreviations = {
        'asymmetric': 'as',
        'binary_focal': 'bf',
        'mse': 'mse',
        'binary_cross_entropy_with_logits': 'bce',
        'mlsml': 'mlsml',
    }
    # Validate up front with a real exception (an `assert` is stripped under -O).
    if criterion not in loss_abbreviations:
        raise ValueError(
            f'Invalid criterion {criterion!r}; expected one of {sorted(loss_abbreviations)}'
        )
    loss = loss_abbreviations[criterion]

    parts = [f'CUDA_VISIBLE_DEVICES={devices} fairseq-hydra-train']
    # Mixed precision is turned off only for the 'mse' criterion.
    parts.append('common.fp16=false' if criterion == 'mse' else 'common.fp16=true')
    parts.append(f'task.data={data_root}/{task}')
    if mode == 'e2e':
        # End-to-end training starts from scratch, so no pretrained checkpoint path.
        parts.append('model.no_pretrained_weights=true')
        encoder = '_e2e'
    else:
        parts.append(f'model.model_path={pretrained_dir}/checkpoint{encoder}.pt')
    if cls == 'attn':
        parts.append('model._name=ecg_transformer_attn_classifier')
    if wd:
        parts.append(f'optimizer.weight_decay={wd}')
    if mode == 'le':
        parts.append('model.linear_evaluation=true')
    parts.append('+task.df_dataset=true' if is_df else '+task.npy_dataset=true')
    parts.append(f'model.num_labels={num_labels}')
    parts.append(f'criterion._name={criterion}')

    cls_suffix = '-attn' if cls == 'attn' else cls
    parts.append(f'checkpoint.save_dir=checkpoint{encoder}-{mode}-{task}-{loss}{cls_suffix}')
    parts.append('--config-dir examples/w2v_cmsc/config/finetuning/ecg_transformer --config-name diagnosis')

    return ' '.join(parts)