Skip to content

Commit

Permalink
convert to abs path + csv as input in test
Browse files Browse the repository at this point in the history
  • Loading branch information
bernardo-marques committed Jun 4, 2020
1 parent b1a9a8a commit 0f7bcb9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 18 deletions.
2 changes: 2 additions & 0 deletions deepmedic/frontEnd/configParsing/testConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class TestConfig(Config):
FOLDER_OUTP = "folderForOutput" #MUST BE GIVEN
SAVED_MODEL = "cnnModelFilePath" #MUST BE GIVEN
CHANNELS = "channels" #MUST BE GIVEN

CSV_TEST = "csvTest"

NAMES_FOR_PRED_PER_CASE = "namesForPredictionsPerCase"

Expand Down
48 changes: 39 additions & 9 deletions deepmedic/frontEnd/configParsing/testSessionParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@

from __future__ import absolute_import, print_function, division

from deepmedic.frontEnd.configParsing.utils import getAbsPathEvenIfRelativeIsGiven, parseAbsFileLinesInList, parseFileLinesInList, check_and_adjust_path_to_ckpt
import pandas as pd
import os

from deepmedic.frontEnd.configParsing.utils import getAbsPathEvenIfRelativeIsGiven, parseAbsFileLinesInList, parseFileLinesInList, check_and_adjust_path_to_ckpt, get_paths_from_csv


class TestSessionParameters(object) :
#To be called from outside too.
Expand All @@ -21,6 +25,14 @@ def errorIntNormZScoreTwoAppliesGiven():
"\n\tOtherwise, requires ['apply_to_all_channels': False] if ['apply_per_channel': [..list..] ]"
"\n\tExiting!")
exit(1)

@staticmethod
def errorRequireValidCsvTest():
print(
"ERROR: Test CSV file \"csvTest\" does not exist. Exiting.")
exit(1)

errReqCsvTest = errorRequireValidCsvTest

def __init__(self,
log,
Expand All @@ -42,14 +54,32 @@ def __init__(self,
self.savedModelFilepath = check_and_adjust_path_to_ckpt( self.log, abs_path_to_saved) if abs_path_to_saved is not None else None

#Input:
#[[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
listOfAListPerChannelWithFilepathsOfAllCases = [parseAbsFileLinesInList(getAbsPathEvenIfRelativeIsGiven(channelConfPath, abs_path_to_cfg)) for channelConfPath in cfg[cfg.CHANNELS]]
self.channelsFilepaths = [ list(item) for item in zip(*tuple(listOfAListPerChannelWithFilepathsOfAllCases)) ] # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
self.gtLabelsFilepaths = parseAbsFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.GT_LABELS], abs_path_to_cfg) ) if cfg[cfg.GT_LABELS] is not None else None
self.roiMasksFilepaths = parseAbsFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.ROI_MASKS], abs_path_to_cfg) ) if cfg[cfg.ROI_MASKS] is not None else None

#Output:
self.namesToSavePredictionsAndFeatures = parseFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.NAMES_FOR_PRED_PER_CASE], abs_path_to_cfg) ) if cfg[cfg.NAMES_FOR_PRED_PER_CASE] is not None else None #CAREFUL: different parser! #Optional. Not required if not saving results.
self.csv_test_fname = getAbsPathEvenIfRelativeIsGiven(cfg[cfg.CSV_TEST], abs_path_to_cfg) \
if cfg[cfg.CSV_TEST] is not None else None
if self.csv_test_fname is not None:
try:
self.csv_test = pd.read_csv(self.csv_test_fname)
except FileNotFoundError:
self.errReqCsvTest()
else:
self.csv_test = None

if self.csv_test is None:
#[[case1-ch1, ..., caseN-ch1], [case1-ch2,...,caseN-ch2]]
listOfAListPerChannelWithFilepathsOfAllCases = [parseAbsFileLinesInList(getAbsPathEvenIfRelativeIsGiven(channelConfPath, abs_path_to_cfg)) for channelConfPath in cfg[cfg.CHANNELS]]
self.channelsFilepaths = [ list(item) for item in zip(*tuple(listOfAListPerChannelWithFilepathsOfAllCases)) ] # [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
self.gtLabelsFilepaths = parseAbsFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.GT_LABELS], abs_path_to_cfg) ) if cfg[cfg.GT_LABELS] is not None else None
self.roiMasksFilepaths = parseAbsFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.ROI_MASKS], abs_path_to_cfg) ) if cfg[cfg.ROI_MASKS] is not None else None

#Output:
self.namesToSavePredictionsAndFeatures = parseFileLinesInList( getAbsPathEvenIfRelativeIsGiven(cfg[cfg.NAMES_FOR_PRED_PER_CASE], abs_path_to_cfg) ) if cfg[cfg.NAMES_FOR_PRED_PER_CASE] is not None else None #CAREFUL: different parser! #Optional. Not required if not saving results.
else:
(self.channelsFilepaths,
self.gtLabelsFilepaths,
self.roiMasksFilepaths,
self.namesToSavePredictionsAndFeatures) = get_paths_from_csv(self.csv_test,
os.path.dirname(self.csv_test_fname))

#predictions
self.saveSegmentation = cfg[cfg.SAVE_SEGM] if cfg[cfg.SAVE_SEGM] is not None else True
self.saveProbMapsBoolPerClass = cfg[cfg.SAVE_PROBMAPS_PER_CLASS] if (cfg[cfg.SAVE_PROBMAPS_PER_CLASS] is not None and cfg[cfg.SAVE_PROBMAPS_PER_CLASS] != []) else [True]*num_classes
Expand Down
9 changes: 5 additions & 4 deletions deepmedic/frontEnd/configParsing/trainSessionParams.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from deepmedic.dataManagement.augmentImage import AugmenterAffineParams

import pandas as pd
import os


def get_default(value, default, required=False):
Expand All @@ -36,7 +37,6 @@ class TrainSessionParameters(object):
def getSessionName(sessionName):
return sessionName if sessionName is not None else "trainSession"

# REQUIRED:
@staticmethod
def errorRequireValidCsvTraining():
print(
Expand All @@ -45,7 +45,6 @@ def errorRequireValidCsvTraining():

errReqCsvTrain = errorRequireValidCsvTraining

# REQUIRED:
@staticmethod
def errorRequireValidCsvValidation():
print(
Expand Down Expand Up @@ -249,10 +248,11 @@ def __init__(self,
self.gtLabelsFilepathsTrain = \
parseAbsFileLinesInList(getAbsPathEvenIfRelativeIsGiven(cfg[cfg.GT_LABELS_TR], abs_path_to_cfg))
else:
print(os.path.dirname(self.csv_train_fname))
(self.channelsFilepathsTrain,
self.gtLabelsFilepathsTrain,
self.roiMasksFilepathsTrain,
_) = get_paths_from_csv(self.csv_train)
_) = get_paths_from_csv(self.csv_train, os.path.dirname(self.csv_train_fname))

# [Optionals]
# ~~~~~~~~~Sampling~~~~~~~
Expand Down Expand Up @@ -365,7 +365,8 @@ def __init__(self,
(self.channelsFilepathsVal,
self.gtLabelsFilepathsVal,
self.roiMasksFilepathsVal,
self.namesToSavePredictionsAndFeaturesVal) = get_paths_from_csv(self.csv_val)
self.namesToSavePredictionsAndFeaturesVal) = get_paths_from_csv(self.csv_val,
os.path.dirname(self.csv_val_fname))

else:
self.channelsFilepathsVal = []
Expand Down
18 changes: 13 additions & 5 deletions deepmedic/frontEnd/configParsing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,19 +94,27 @@ def check_and_adjust_path_to_ckpt( log, filepath_to_ckpt ):
return filepath_to_ckpt


def get_paths_from_csv(csv, no_target_okay=False):
def normfullpath(abspath, relpath):
if os.path.isabs(relpath):
return relpath
else:
return os.path.normpath(os.path.join(abspath, relpath))


def get_paths_from_csv(csv, abs_path, no_target_okay=False):
# channels are sorted alphabetically to ensure consistency
c_names = sorted([c for c in list(csv.columns) if c.startswith('channel_')])

if not c_names:
# no channels error raise - move to function later
print('No channel columns on csv. Columns should be named "channel_[channel_name]". Exiting')
exit(1)

# [[case1-ch1, case1-ch2], ..., [caseN-ch1, caseN-ch2]]
channels = [list(item[c_names]) for _, item in csv.iterrows()]
channels = [[normfullpath(abs_path, c) for c in list(item[c_names])] for _, item in csv.iterrows()]

try:
target = [list(csv['gt'])]
target = [normfullpath(abs_path, g) for g in list(csv['gt'])]
except KeyError:
target = None
if not no_target_okay:
Expand All @@ -115,13 +123,13 @@ def get_paths_from_csv(csv, no_target_okay=False):
exit(1)

try:
roi = [list(csv['roi'])]
roi = [normfullpath(abs_path, r) for r in list(csv['roi'])]
except KeyError:
print('No "roi" column in input csv, not using roi masks.')
roi = None

try:
pred = [list(csv['pred'])]
pred = [normfullpath(abs_path, p) for p in list(csv['pred'])]
except KeyError:
pred = None

Expand Down

0 comments on commit 0f7bcb9

Please sign in to comment.