## Imports

In [3]:
import os
import sys
sys.path.append('../')
from sklearn.model_selection import train_test_split

from TumorDetection.data.loader import DataPathLoader
from TumorDetection.data.dataset import TorchDataset
from TumorDetection.utils.dict_classes import DataPathDir, ReportingPathDir, Verbosity, Device
from TumorDetection.models.efsnet import EFSNet
from TumorDetection.models.utils.lightning_model import LightningModel
from TumorDetection.models.utils.trainer import Trainer

## Configuration params

In [7]:
# Identifier and free-text description attached to this training run.
MODEL_NAME = 'EFSNet_clf_seg'
DESCRIPTION = 'EFSNet with classification and binary segmentation.'

# Loss weighting handed to LightningModel in the model-definition cell.
# NOTE(review): presumably CLASS_WEIGHTS holds one weight per class of the
# classification head, and POS_WEIGHT up-weights positive pixels of the binary
# segmentation loss; confirm against LightningModel.
CLASS_WEIGHTS = [1.0, 3.0, 3.0]
POS_WEIGHT = 5

# Run-control flags handed to the trainer call in the training cell.
FROM_CHECKPOINT = False
VALIDATE = True
TEST = True

# Runtime settings looked up from the project's dict_classes helpers.
VERBOSE = Verbosity.get('verbose')
DEVICE = Device.get('device')

## Path Finder

If you are not using `DataPathLoader` (which is specific to the *BUSI dataset*), pass the paths as a sequence of tuples, each containing:  
- the image path,
- a list of the mask paths associated with that image,
- a list of the labels associated with those masks.

In [5]:
# Discover the dataset's image/mask paths from the configured data directory.
data_dir = DataPathDir.get('dir_path')
dp = DataPathLoader(data_dir)
paths = dp()  # the loader is callable and returns the collected paths

## Train Test Datasets

In [None]:
# Hold out a fixed validation set. An integer test_size is an absolute sample
# count (100 samples), not a fraction; random_state=0 makes the shuffle
# reproducible across runs.
tr_paths, val_paths = train_test_split(
    paths,
    test_size=100,
    random_state=0,
    shuffle=True,
)
# Wrap each split in the project's torch Dataset (train first, then validation).
tr_td, val_td = TorchDataset(tr_paths), TorchDataset(val_paths)

## Model definition

In [None]:
# Backbone network, placed on the configured device.
efsnet = EFSNet(device=DEVICE, verbose=VERBOSE)

# Wrap the backbone with run metadata and the loss weighting from the
# configuration cell.
lightningmodel = LightningModel(
    model=efsnet,
    model_name=MODEL_NAME,
    description=DESCRIPTION,
    class_weights=CLASS_WEIGHTS,
    pos_weight=POS_WEIGHT,
    device=DEVICE,
)

## Training

In [None]:
# Build the trainer; its checkpoint directory is <reporting dir>/ckpt.
trainer = Trainer(model_name=MODEL_NAME,
                  ckpt_dir=os.path.join(ReportingPathDir.get('dir_path'), 'ckpt'),
                  verbose=VERBOSE)
# Train the model built in the "Model definition" cell. The name must match
# `lightningmodel` exactly, otherwise this cell raises NameError.
# NOTE(review): the validation split is passed as `test_data`, so the
# validation and test phases presumably evaluate the same held-out samples;
# confirm this is intended.
trainer(model=lightningmodel,
        train_data=tr_td,
        test_data=val_td,
        from_checkpoint=FROM_CHECKPOINT,
        validate_model=VALIDATE,
        test_model=TEST)