## 0. Imports

In [None]:
import os
import sys
sys.path.append('../')
from sklearn.model_selection import train_test_split

from TumorDetection.data.loader import DataPathLoader
from TumorDetection.data.dataset import TorchDatasetClfSeg, TorchDatasetSeg
from TumorDetection.utils.dict_classes import DataPathDir, ReportingPathDir, Verbosity, Device
from TumorDetection.models.efsnet import EFSNetClfSeg, EFSNetSeg
from TumorDetection.models.utils.lightning_model import LightningModelClfSeg, LightningModelSeg
from TumorDetection.models.utils.trainer import Trainer

## 1. Classification - Segmentation

### Configuration params

In [None]:
# Run configuration for the joint classification + binary-segmentation model.
MODEL_NAME = 'EFSNet_clf_seg'
DESCRIPTION = 'EFSNet with classification and binary segmentation.'
# Per-class weights for the classification loss (three classes -- presumably the
# dataset's diagnostic categories; TODO confirm against the label encoding).
CLASS_WEIGHTS = [1., 3., 3.]
# Presumably the positive-pixel weight of the binary segmentation loss --
# confirm in LightningModelClfSeg.
POS_WEIGHT = 5
# Maximum number of training epochs, consumed by the Trainer cell of this
# section (same value as the segmentation section).
EPOCHS = 2500
BATCH_SIZE = 64
# Integer test_size -> absolute number of samples held out by train_test_split;
# also used as the validation batch size when training.
TEST_SIZE = 100
# Toggles for the Trainer call: resume from checkpoint, run validation, run test.
FROM_CHECKPOINT = False
VALIDATE = True
TEST = True
VERBOSE = Verbosity.get('verbose')
DEVICE = Device.get('device')

### Path Finder

If you are not using DataPathLoader (*built for the BUSI dataset*), pass the paths as tuples containing:  
- the image path
- the list of mask paths associated with the image
- the list of labels associated with those masks.

In [None]:
# Collect the per-sample path tuples from the configured data directory.
dp = DataPathLoader(
    DataPathDir.get('dir_path')
)
paths = dp()

### Train Test Datasets

In [None]:
# Hold out TEST_SIZE samples for validation; the fixed seed keeps the split
# reproducible across runs.
tr_paths, val_paths = train_test_split(
    paths,
    test_size=TEST_SIZE,
    random_state=0,
    shuffle=True,
)

# Training set: random crop / rotation / contrast / brightness / flip augmentation.
tr_td = TorchDatasetClfSeg(
    tr_paths,
    crop_prob=0.5,
    rotation_degrees=180,
    range_contrast=(0.75, 1.25),
    range_brightness=(0.75, 1.25),
    vertical_flip_prob=0.25,
    horizontal_flip_prob=0.25,
)

# Validation set: every augmentation parameter is None (presumably disabling
# it -- confirm in TorchDatasetClfSeg).
no_augmentation = dict.fromkeys((
    'crop_prob',
    'rotation_degrees',
    'range_contrast',
    'range_brightness',
    'vertical_flip_prob',
    'horizontal_flip_prob',
))
val_td = TorchDatasetClfSeg(val_paths, **no_augmentation)

### Model definition

In [None]:
# Build the classification + segmentation network, then wrap it in its
# Lightning module together with the run metadata and loss weights.
clf_seg_network = EFSNetClfSeg(device=DEVICE, verbose=VERBOSE)
lighningmodel = LightningModelClfSeg(
    model=clf_seg_network,
    model_name=MODEL_NAME,
    description=DESCRIPTION,
    class_weights=CLASS_WEIGHTS,
    pos_weight=POS_WEIGHT,
    device=DEVICE,
)

### Training
*For better performance, use the associated script `train_efsnet_clfseg.py`.*

In [None]:
# Train the classification + segmentation model; VALIDATE / TEST toggle the
# validation and test runs. Checkpoint directory: <reporting dir>/ckpt.
trainer = Trainer(model_name=MODEL_NAME,
                  max_epochs=EPOCHS,
                  ckpt_dir=os.path.join(ReportingPathDir.get('dir_path'), 'ckpt'),
                  verbose=VERBOSE)
trainer(model=lighningmodel,
        train_batch_size=BATCH_SIZE,
        # Validation batch size equals the held-out split size, so the split is
        # presumably evaluated as a single batch.
        val_batch_size=TEST_SIZE,
        train_data=tr_td,
        test_data=val_td,
        # Driven by the config flag so resuming can be toggled in one place.
        from_checkpoint=FROM_CHECKPOINT,
        validate_model=VALIDATE,
        test_model=TEST)

## 2. Segmentation

### Configuration Params

In [None]:
# Run configuration for the segmentation-only model (EFSNetSeg / LightningModelSeg).
MODEL_NAME = 'EFSNet_seg'
# This section trains EFSNetSeg, which has no classification branch, so the
# description must not claim classification.
DESCRIPTION = 'EFSNet with multiclass segmentation'
VERBOSE = Verbosity.get('verbose')
DEVICE = Device.get('device')
# Maximum number of training epochs.
EPOCHS = 2500
# Per-class weights passed to the segmentation loss (three classes -- TODO
# confirm the class order against the mask label encoding).
CLASS_WEIGHT = [1., 5., 5.]
BATCH_SIZE = 64
# Integer test_size -> absolute number of samples held out by train_test_split.
TEST_SIZE = 100

### Data

In [None]:
# Re-discover the dataset paths and reproduce the same seeded train/validation split.
dp = DataPathLoader(DataPathDir.get('dir_path'))
paths = dp()
tr_paths, val_paths = train_test_split(paths,
                                       shuffle=True,
                                       random_state=0,
                                       test_size=TEST_SIZE)

# Augmentation settings applied to the training set only.
train_augmentations = dict(
    crop_prob=0.5,
    rotation_degrees=180,
    range_contrast=(0.75, 1.25),
    range_brightness=(0.75, 1.25),
    vertical_flip_prob=0.25,
    horizontal_flip_prob=0.25,
)
tr_td = TorchDatasetSeg(tr_paths, **train_augmentations)

# Validation set: the same parameters, each set to None (presumably disabling
# the augmentation -- confirm in TorchDatasetSeg).
val_td = TorchDatasetSeg(val_paths,
                         **{param: None for param in train_augmentations})

### Model definition

In [None]:
# Build the segmentation network, then wrap it in its Lightning module with
# the run metadata and per-class loss weights.
seg_network = EFSNetSeg(device=DEVICE, verbose=VERBOSE)
lighningmodel = LightningModelSeg(
    model=seg_network,
    model_name=MODEL_NAME,
    description=DESCRIPTION,
    class_weights=CLASS_WEIGHT,
    device=DEVICE,
)

### Train
*For better performance, use the associated script `train_efsnet_seg.py`.*

In [None]:
# Train the segmentation model, with validation and test runs enabled.
# Checkpoint directory: <reporting dir>/ckpt.
ckpt_dir = os.path.join(ReportingPathDir.get('dir_path'), 'ckpt')
trainer = Trainer(
    model_name=MODEL_NAME,
    max_epochs=EPOCHS,
    ckpt_dir=ckpt_dir,
    verbose=VERBOSE,
)
trainer(
    model=lighningmodel,
    train_data=tr_td,
    test_data=val_td,
    train_batch_size=BATCH_SIZE,
    val_batch_size=TEST_SIZE,
    from_checkpoint=False,  # resuming is hard-coded off in this section
    validate_model=True,
    test_model=True,
)