In [None]:
import os
import sys
import re
import importlib
import pandas as pd
import yaml
import matplotlib.pyplot as plt
import torch.nn as nn

from pathlib import Path
from torch.utils.data import DataLoader

from IPython.core.debugger import set_trace

%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# Make the local `deeplearner` checkout importable. The relative path is resolved against the
# kernel's current working directory, so the notebook must be started from the expected folder.
module_path = os.path.abspath('../../deeplearner')

# Put the checkout first on sys.path (highest import priority). Drop any existing copies first
# so re-running this cell does not keep appending duplicate entries.
sys.path[:] = [module_path] + [entry for entry in sys.path if entry != module_path]

In [None]:
from deeplearner.datatorch import *
from deeplearner.utils import *
from deeplearner.compiler import *
from deeplearner.losses import *
from deeplearner.models.unet_att_d import unet_att_d

In [None]:
# Path to the experiment's YAML config. The default is the original author's absolute path;
# set the DEEPLEARNER_CONFIG environment variable to use another file without editing the notebook.
yaml_config_path = os.environ.get(
    "DEEPLEARNER_CONFIG",
    "/home/airg/skhallaghi/deeplearner_normalization_test/deeplearner/config/default_config.yaml",
)

# yaml.safe_load is the shorthand for yaml.load(..., Loader=yaml.SafeLoader): plain YAML types only.
with open(yaml_config_path, "r") as cfg:
    config = yaml.safe_load(cfg)

# safe_load yields None for an empty document (or a list/scalar for other layouts). Later cells
# index `config` by key, so fail here with a clear message instead of an obscure TypeError.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected a mapping at the top level of {yaml_config_path!r}, "
        f"got {type(config).__name__}"
    )

In [None]:
# Instantiate unet_att_d. Every constructor argument is taken from the YAML config,
# except skip-connection attention, which this notebook turns off explicitly.
model_kwargs = dict(
    n_classes=config["n_classes"],
    in_channels=config["channels"],
    filter_config=config["stage_width"],
    block_num=config["block_num"],
    dropout_rate=config["train_dropout_rate"],
    dropout_type=config["dropout_type"],
)
model = unet_att_d(**model_kwargs, use_skipAtt=False)

In [None]:
# Build the project's ModelCompiler around `model`; every keyword argument comes from the YAML config.
compiler_kwargs = {
    "working_dir": config["working_dir"],
    "out_dir": config["out_dir"],
    "buffer": config["one_side_buffer"],
    "class_mapping": config["class_mapping"],
    "gpuDevices": config["gpu_devices"],
    "params_init": config["params_init_path"],
    "freeze_params": config["freeze_layer_ls"],
}
compiled_model = ModelCompiler(model, **compiler_kwargs)

In [None]:
# Select and build the loss function from the `criterion` section of the YAML config.
# All hyper-parameters below are read unconditionally, so each key must be present in the
# config even when the chosen loss does not use it.
criterion_name = config['criterion']['name']
weight = config['criterion']['weight']
ignore_index = config['criterion']['ignore_index']
gamma = config['criterion']['gamma']
alpha = config['criterion']['alpha']

if criterion_name == 'TverskyFocalLoss':
    criterion = TverskyFocalLoss(weight=weight, ignore_index=ignore_index, alpha=alpha, gamma=gamma)
elif criterion_name == "LocallyWeightedTverskyFocalLoss":
    # NOTE(review): `weight` is not forwarded to the locally weighted variants — presumably
    # intentional; confirm against the loss class definitions.
    criterion = LocallyWeightedTverskyFocalLoss(ignore_index=ignore_index, alpha=alpha, gamma=gamma)
elif criterion_name == "LocallyWeightedTverskyFocalCELoss":
    # This variant takes the Tversky parameters under prefixed names.
    criterion = LocallyWeightedTverskyFocalCELoss(ignore_index=ignore_index, tversky_alpha=alpha,
                                                  tversky_gamma=gamma)
else:
    # Report the offending value and the supported options so a config typo is easy to spot.
    raise ValueError(
        f"Invalid 'criterion_name': {criterion_name!r}. Expected one of "
        "'TverskyFocalLoss', 'LocallyWeightedTverskyFocalLoss', "
        "'LocallyWeightedTverskyFocalCELoss'."
    )