In [72]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [73]:
from fastai.conv_learner import *
from fastai.dataset import *

from pathlib import Path
from glob import glob
import tables as tb
import tqdm

In [74]:
import sys
sys.path.insert(0, 'code')
from models import *
from v13_deeplab import *

In [75]:
MODEL_NAME = 'v13'
ORIGINAL_SIZE = 650
sz = 256
bs = 200
num_slice = 9
STRIDE_SZ = 197
PATH = 'data/'

BASE_DIR = "data/train"
BASE_TEST_DIR = "data/test"
WORKING_DIR = "data/working"

# Restore later
IMAGE_DIR = "data/working/images/{}".format('v12')
# IMAGE_DIR = "data/working/images/{}".format('v5')
V5_IMAGE_DIR = "data/working/images/{}".format('v5')

# ---------------------------------------------------------
# Parameters
MIN_POLYGON_AREA = 30  # 30

# ---------------------------------------------------------
# Input files
FMT_TRAIN_SUMMARY_PATH = str(
    Path(BASE_DIR) /
    Path("{prefix:s}_Train/") /
    Path("summaryData/{prefix:s}_Train_Building_Solutions.csv"))
FMT_TRAIN_RGB_IMAGE_PATH = str(
    Path(BASE_DIR) /
    Path("{prefix:s}_Train/") /
    Path("RGB-PanSharpen/RGB-PanSharpen_{image_id:s}.tif"))
FMT_TEST_RGB_IMAGE_PATH = str(
    Path(BASE_TEST_DIR) /
    Path("{prefix:s}_Test/") /
    Path("RGB-PanSharpen/RGB-PanSharpen_{image_id:s}.tif"))
FMT_TRAIN_MSPEC_IMAGE_PATH = str(
    Path(BASE_DIR) /
    Path("{prefix:s}_Train/") /
    Path("MUL-PanSharpen/MUL-PanSharpen_{image_id:s}.tif"))
FMT_TEST_MSPEC_IMAGE_PATH = str(
    Path(BASE_TEST_DIR) /
    Path("{prefix:s}_Test/") /
    Path("MUL-PanSharpen/MUL-PanSharpen_{image_id:s}.tif"))

# ---------------------------------------------------------
# Preprocessing result
FMT_RGB_BANDCUT_TH_PATH = IMAGE_DIR + "/rgb_bandcut.csv"
FMT_MUL_BANDCUT_TH_PATH = IMAGE_DIR + "/mul_bandcut.csv"

# ---------------------------------------------------------
# Image list, Image container and mask container
FMT_VALTRAIN_IM_FOLDER = V5_IMAGE_DIR + "/trn_full_rgb/"
FMT_VALTEST_IM_FOLDER = V5_IMAGE_DIR + "/trn_full_rgb/"

FMT_VALTRAIN_IMAGELIST_PATH = V5_IMAGE_DIR + "/{prefix:s}_valtrain_ImageId.csv"
FMT_VALTEST_IMAGELIST_PATH = V5_IMAGE_DIR + "/{prefix:s}_valtest_ImageId.csv"
FMT_VALTRAIN_IM_STORE = IMAGE_DIR + "/valtrain_{}_im.h5"
FMT_VALTEST_IM_STORE = IMAGE_DIR + "/valtest_{}_im.h5"
# FMT_VALTRAIN_MASK_STORE = IMAGE_DIR + "/valtrain_{}_mask.h5"
# FMT_VALTEST_MASK_STORE = IMAGE_DIR + "/valtest_{}_mask.h5"
FMT_VALTRAIN_MASK_STORE = V5_IMAGE_DIR + "/valtrain_{}_mask.h5"
FMT_VALTEST_MASK_STORE = V5_IMAGE_DIR + "/valtest_{}_mask.h5"
# FMT_VALTRAIN_MUL_STORE = IMAGE_DIR + "/valtrain_{}_mul.h5"
# FMT_VALTEST_MUL_STORE = IMAGE_DIR + "/valtest_{}_mul.h5"
FMT_VALTRAIN_MUL_STORE = V5_IMAGE_DIR + "/valtrain_{}_mul.h5"
FMT_VALTEST_MUL_STORE = V5_IMAGE_DIR + "/valtest_{}_mul.h5"

FMT_TRAIN_IMAGELIST_PATH = V5_IMAGE_DIR + "/{prefix:s}_train_ImageId.csv"
FMT_TEST_IMAGELIST_PATH = V5_IMAGE_DIR + "/{prefix:s}_test_ImageId.csv"
FMT_TRAIN_IM_STORE = IMAGE_DIR + "/train_{}_im.h5"
FMT_TEST_IM_STORE = IMAGE_DIR + "/test_{}_im.h5"
FMT_TRAIN_MASK_STORE = IMAGE_DIR + "/train_{}_mask.h5"
FMT_TRAIN_MUL_STORE = IMAGE_DIR + "/train_{}_mul.h5"
FMT_TEST_MUL_STORE = IMAGE_DIR + "/test_{}_mul.h5"
FMT_MULMEAN = IMAGE_DIR + "/{}_mulmean.h5"

# ---------------------------------------------------------
# Model files
MODEL_DIR = "data/working/models/{}".format(MODEL_NAME)
FMT_VALMODEL_PATH = MODEL_DIR + "/{}_val_weights.h5"
FMT_FULLMODEL_PATH = MODEL_DIR + "/{}_full_weights.h5"
FMT_VALMODEL_HIST = MODEL_DIR + "/{}_val_hist.csv"
FMT_VALMODEL_EVALHIST = MODEL_DIR + "/{}_val_evalhist.csv"
FMT_VALMODEL_EVALTHHIST = MODEL_DIR + "/{}_val_evalhist_th.csv"

# ---------------------------------------------------------
# Prediction & polygon result
FMT_TESTPRED_PATH = MODEL_DIR + "/{}_pred.h5"
FMT_VALTESTPRED_PATH = MODEL_DIR + "/{}_eval_pred.h5"
FMT_VALTESTPOLY_PATH = MODEL_DIR + "/{}_eval_poly.csv"
FMT_VALTESTTRUTH_PATH = MODEL_DIR + "/{}_eval_poly_truth.csv"
FMT_VALTESTPOLY_OVALL_PATH = MODEL_DIR + "/eval_poly.csv"
FMT_VALTESTTRUTH_OVALL_PATH = MODEL_DIR + "/eval_poly_truth.csv"
FMT_TESTPOLY_PATH = MODEL_DIR + "/{}_poly.csv"
FN_SOLUTION_CSV = "data/output/{}.csv".format(MODEL_NAME)

# ---------------------------------------------------------
# Model related files (others)
FMT_VALMODEL_LAST_PATH = MODEL_DIR + "/{}_val_weights_last.h5"
FMT_FULLMODEL_LAST_PATH = MODEL_DIR + "/{}_full_weights_last.h5"

## Preprocessing

In [102]:
# datapaths = ['data/train/AOI_2_Vegas_Train', 'data/train/AOI_3_Paris_Train', 'data/train/AOI_4_Shanghai_Train', 'data/train/AOI_5_Khartoum_Train']
# !python code/v5_im-full_rgb.py preproc_train {datapaths[1]}

2018-04-12 00:02:02,364 INFO Preproc for training on AOI_3_Paris
2018-04-12 00:02:02,364 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:02:02,364 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:02:02,364 INFO Generate band stats csv (RGB) ... skip
2018-04-12 00:02:02,364 INFO Generate MASK (valtrain) ... skip
2018-04-12 00:02:02,365 INFO Generate MASK (valtest) ... skip
2018-04-12 00:02:02,365 INFO Generate RGB_STORE (valtrain)
2018-04-12 00:02:02,370 INFO prep_rgb_image_store_train for AOI_3_Paris
2018-04-12 00:02:02,372 INFO Image store file: data/working/images/v5/trn_full_rgb/
2018-04-12 00:06:03,924 INFO Generate RGB_STORE (valtest)
2018-04-12 00:06:03,937 INFO prep_rgb_image_store_train for AOI_3_Paris
2018-04-12 00:06:03,939 INFO Image store file: data/working/images/v5/test_full_rgb/
2018-04-12 00:07:45,233 INFO Generate RGBMEAN
2018-04-12 00:08:13,445 INFO Prepare mean image: data/working/images/v5/AOI_3_Paris_immean.h5
2018-04-12 00:08:13,538 INFO Preproc for training 

In [103]:
# !python code/v5_im-full_rgb.py preproc_train {datapaths[3]}

2018-04-12 00:08:16,929 INFO Preproc for training on AOI_5_Khartoum
2018-04-12 00:08:16,929 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:08:16,929 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:08:16,929 INFO Generate band stats csv (RGB) ... skip
2018-04-12 00:08:16,929 INFO Generate MASK (valtrain) ... skip
2018-04-12 00:08:16,929 INFO Generate MASK (valtest) ... skip
2018-04-12 00:08:16,929 INFO Generate RGB_STORE (valtrain)
2018-04-12 00:08:16,933 INFO prep_rgb_image_store_train for AOI_5_Khartoum
2018-04-12 00:08:16,934 INFO Image store file: data/working/images/v5/trn_full_rgb/
2018-04-12 00:11:36,933 INFO Generate RGB_STORE (valtest)
2018-04-12 00:11:36,947 INFO prep_rgb_image_store_train for AOI_5_Khartoum
2018-04-12 00:11:36,949 INFO Image store file: data/working/images/v5/test_full_rgb/
2018-04-12 00:12:59,733 INFO Generate RGBMEAN
2018-04-12 00:13:25,436 INFO Prepare mean image: data/working/images/v5/AOI_5_Khartoum_immean.h5
2018-04-12 00:13:25,475 INFO Preproc f

In [104]:
# !python code/v5_im-full_rgb.py preproc_train {datapaths[0]}
# !python code/v5_im-full_rgb.py preproc_train {datapaths[2]}

2018-04-12 00:13:28,875 INFO Preproc for training on AOI_2_Vegas
2018-04-12 00:13:28,875 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:13:28,875 INFO Generate IMAGELIST csv ... skip
2018-04-12 00:13:28,875 INFO Generate band stats csv (RGB) ... skip
2018-04-12 00:13:28,875 INFO Generate MASK (valtrain) ... skip
2018-04-12 00:13:28,875 INFO Generate MASK (valtest) ... skip
2018-04-12 00:13:28,875 INFO Generate RGB_STORE (valtrain)
2018-04-12 00:13:28,879 INFO prep_rgb_image_store_train for AOI_2_Vegas
2018-04-12 00:13:28,882 INFO Image store file: data/working/images/v5/trn_full_rgb/
2018-04-12 00:27:09,650 INFO Generate RGB_STORE (valtest)
2018-04-12 00:27:09,655 INFO prep_rgb_image_store_train for AOI_2_Vegas
2018-04-12 00:27:09,657 INFO Image store file: data/working/images/v5/test_full_rgb/
2018-04-12 00:32:47,330 INFO Generate RGBMEAN
2018-04-12 00:34:23,289 INFO Prepare mean image: data/working/images/v5/AOI_2_Vegas_immean.h5
2018-04-12 00:34:23,969 INFO Preproc for training 

In [79]:
# for d in datapaths: print(d, end=' ')

In [80]:
# !parallel python code/v5_im-full_rgb.py preproc_train {} ::: data/train/AOI_2_Vegas_Train data/train/AOI_3_Paris_Train data/train/AOI_4_Shanghai_Train data/train/AOI_5_Khartoum_Train

In [81]:
# for train_path in ['data/train/AOI_2_Vegas_Train', 'data/train/AOI_3_Paris_Train', 'data/train/AOI_4_Shanghai_Train', 'data/train/AOI_5_Khartoum_Train']:
#     !python code/v12_im_deeplab.py preproc_train {train_path}

### Overload

In [82]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [92]:
def get_data(area_id, is_test):
    prefix = area_id_to_prefix(area_id)
    fn_train = FMT_VALTEST_IMAGELIST_PATH.format(prefix=prefix) if is_test else FMT_VALTRAIN_IMAGELIST_PATH.format(prefix=prefix)
    df_train = pd.read_csv(fn_train)
    
    fn_im = FMT_VALTEST_MASK_STORE.format(prefix) if is_test else FMT_VALTRAIN_MASK_STORE.format(prefix)
    y_val = []
    with tb.open_file(fn_im, 'r') as f:
        for image_id in tqdm.tqdm(df_train.ImageId.tolist(), total=df_train.shape[0]):
            fn = '/' + image_id
            im = np.array(f.get_node(fn))[None]
            y_val.append(im)
            
    fn_im = FMT_VALTEST_IM_FOLDER if is_test else FMT_VALTRAIN_IM_FOLDER
    X_val = []
    for image_id in tqdm.tqdm(df_train.ImageId.tolist(), total=df_train.shape[0]):
        im = plt.imread(fn_im + image_id + '.png')
#             print(im.shape)
        X_val.append(im)

    X_val, y_val = np.array(X_val), np.array(y_val)
    return X_val, y_val

In [99]:
# memory dataset
def get_dataset(datapath):
    area_id = directory_name_to_area_id(datapath)
    prefix = area_id_to_prefix(area_id)
    trn_x, trn_y = get_data(area_id, False)
#     print(trn_x.shape, trn_y.shape)
#     trn_x = np.moveaxis(trn_x, 1, -1).astype('float') # --> [bs, h, w, ch]
#     trn_y = np.moveaxis(trn_y, 1, -1).astype('float')

    print(trn_x.shape, trn_y.shape)
    trn_y = np.broadcast_to(trn_y, [trn_y.shape[0], 3, ORIGINAL_SIZE, ORIGINAL_SIZE])

    val_x, val_y = get_data(area_id, True)
#     val_x = val_x[:,:3]
#     val_x = np.moveaxis(val_x, 1, -1).astype('float')
#     val_y = val_y[:,:3]
#     val_y = np.moveaxis(val_y, 1, -1).astype('float')
    val_y = np.broadcast_to(val_y, [val_y.shape[0], 3, ORIGINAL_SIZE, ORIGINAL_SIZE])
               
    return (trn_x,trn_y), (val_x,val_y)


In [100]:
datapaths = ['data/train/AOI_3_Paris_Train', 'data/train/AOI_2_Vegas_Train', 'data/train/AOI_4_Shanghai_Train', 'data/train/AOI_5_Khartoum_Train']

In [101]:
(trn_x,trn_y), (val_x,val_y) = get_dataset(datapaths[0])

100%|██████████| 803/803 [00:04<00:00, 187.01it/s]
100%|██████████| 803/803 [00:02<00:00, 320.84it/s]
  3%|▎         | 11/345 [00:00<00:04, 82.76it/s]

(803, 256, 256, 4) (803, 1, 650, 650)


100%|██████████| 345/345 [00:03<00:00, 96.92it/s]
  0%|          | 0/345 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: 'data/working/images/v5/trn_full_rgb/AOI_3_Paris_img1669.png'

In [None]:
class ArraysSingleDataset(BaseDataset):
    def __init__(self, x, y, transform):
        self.x = x; self.y = y
        self.num_groups = len(x)
        self.sz = x[0].shape[1]
        self.ns = np.array([o.shape[0] for o in x])
        self.cum_ns = np.cumsum(self.ns * num_slice)
        super().__init__(transform)

        
    def get_im(self, i, is_y):
        idx_file, idx_im = self.get_file_idx(i)
        if is_y:
            im = self.y[idx_file][idx_im//num_slice]
        else:
            im = self.x[idx_file][idx_im//num_slice]
        slice_pos = idx_im % num_slice
        a = np.sqrt(num_slice)
        cut_i = slice_pos // a
        cut_j = slice_pos % a
        stride = (self.sz - sz) // a
        cut_x = int(cut_j * stride)
        cut_y = int(cut_i * stride)
        return im[cut_x:cut_x + sz, cut_y:cut_y + sz]
        
            
    def get_x(self, i): return self.get_im(i, False)
    def get_y(self, i): return self.get_im(i, True)
        
    def get_file_idx(self, i):
        idx_file = np.argmax(i + 1 <= self.cum_ns)
        if idx_file == 0:
            idx_im = i
        else:
            idx_im = i - self.cum_ns[idx_file - 1]
        return idx_file, idx_im
    
    def get_n(self): return self.cum_ns[-1]
    
    def get_sz(self): return self.sz
        
    def get_c(self): return 1
    def denorm(self, arr):
        """Reverse the normalization done to a batch of images.

        Arguments:
            arr: of shape/size (N,3,sz,sz)
        """
        if type(arr) is not np.ndarray: arr = to_np(arr)
        if len(arr.shape)==3: arr = arr[None]
#         return np.clip(self.transform.denorm(np.rollaxis(arr,1,4)), 0, 1)
        return self.transform.denorm(np.rollaxis(arr,1,4))

In [None]:
cut_base = 8
class UpsampleModel():
    def __init__(self,model,name='upsample'):
        self.model,self.name = model,name

    def get_layer_groups(self, precompute):
        c = list(children(self.model.module))
        return [c[:cut_base],
               c[cut_base:]]

In [None]:
def jaccard_coef(y_true, y_pred, thresh=0.5):
    smooth = 1e-12
    ma = torch.max(y_pred)
    mi = torch.min(y_pred)
    y_pred = to_np((y_pred - mi) / (ma - mi) > thresh)
    y_true = np.round(to_np(y_true))
    intersection = np.sum(y_true * y_pred)
    sum_ = np.sum(y_true + y_pred)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return jac.mean()


def jaccard_coef_int(y_true, y_pred):
    smooth = 1e-12
    y_true = torch.round(y_true)
    y_pred_pos = torch.round(torch.clamp(y_pred, 0, 1))
    intersection = torch.sum(y_true * y_pred_pos)
    sum_ = torch.sum(y_true + y_pred_pos)
    jac = (intersection + smooth) / (sum_ - intersection + smooth)
    return jac.mean()

In [None]:
def get_mul_mean_stat(area_id):
    prefix = area_id_to_prefix(area_id)

    with tb.open_file(FMT_MULMEAN.format(prefix), 'r') as f:
        im_mean = np.array(f.get_node('/mulmean'))[:3]
    
    mean = [np.mean(im_mean[i]) for i in range(3)]
    std = [np.std(im_mean[i]) for i in range(3)]
    return np.stack([np.array(mean), np.array(std)])

def get_md_model(datapaths, device_ids=range(7)):
    aug_tfms = transforms_top_down
    for o in aug_tfms: o.tfm_y = TfmType.CLASS
        
    area_ids = [directory_name_to_area_id(datapath) for datapath in datapaths]
    stats = np.mean([get_mul_mean_stat(area_id) for area_id in area_ids], axis=0)
    tfms = tfms_from_stats(stats, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
    
    datasets = ImageData.get_ds(ArraysSingleDataset, (trn_x, trn_y), (val_x, val_y), tfms)
    md = ImageData('data', datasets, bs, num_workers=int(np.ceil(bs / 3)), classes=None)
    denorm = md.trn_ds.denorm

    if not Path(MODEL_DIR).exists():
        Path(MODEL_DIR).mkdir(parents=True)

    net = to_gpu(UNet16(pretrained='vgg'))
    net = nn.DataParallel(net, device_ids)
    models = UpsampleModel(net)
    return md, models, denorm

def expanded_loss(pred, target):
#     pred = torch.clamp(pred, 0, 1)
    return F.binary_cross_entropy_with_logits(pred[:,0], target)

In [None]:
md, model, denorm = get_md_model([datapaths[0]])

In [None]:
learn=ConvLearner(md, model)
learn.opt_fn=optim.Adam
learn.crit=expanded_loss
learn.metrics=[jaccard_coef]

In [None]:
learn.model.module.load_state_dict(torch.load('data/models/unfreezed_1.h5'))

In [None]:
learn.lr_find()
learn.sched.plot()

In [None]:
learn.sched

In [None]:
learn.sched.plot_loss()

In [None]:
lr = 1e-5
learn.freeze_to(1)
learn.fit(lr,1,cycle_len=8,use_clr=(20,8))

In [None]:
learn.save('freezed_1')

In [None]:
x, y = md.trn_dl.get_batch(range(9))

In [None]:
for i in range(1, 10):
    plt.subplot(3, 6, i*2)
    plt.imshow(denorm(x[i-1])[0])
    plt.subplot(3, 6, i*2-1)
    plt.imshow(y[i-1])
#     plt.imshow(to_np(learn.model(V(x[i-1][None]))).squeeze())

In [None]:
for i in range(1, 10):
    plt.subplot(3, 6, i*2)
    plt.imshow(denorm(x[i-1])[0])
    plt.subplot(3, 6, i*2-1)
#     plt.imshow(y[i-1])
    pred = to_np(learn.model(V(x[i-1][None]))).squeeze()
#     pred = np.clip(pred, 0, 1)
    plt.imshow(pred)

In [None]:
learn.unfreeze()
learn.bn_freeze(True)
lrs = np.array([lr/3,lr]) / 5

In [None]:
learn.fit(lrs,5,cycle_len=40,use_clr=(20,8))

In [None]:
md.bs

In [None]:
learn.sched.plot_lr()

In [None]:
learn.sched.plot_loss()

In [None]:
learn.save('unfreezed_2')

In [None]:
learn.load('unfreezed_2')

In [None]:
x, y = md.trn_dl.get_batch(np.arange(1, 10) * 4)
preds = learn.model(V(x))

In [None]:
for i in range(1, 10):
    plt.subplot(3, 6, i*2)
    plt.imshow(denorm(x[i-1])[0])
    plt.subplot(3, 6, i*2-1)
    plt.imshow(y[i-1])
#     plt.imshow(to_np(learn.model(V(x[i-1][None]))).squeeze())

In [None]:
learn.predict()

In [None]:
for i in range(1, 10):
    plt.subplot(3, 6, i*2)
    plt.imshow(denorm(x[i-1])[0])
    plt.subplot(3, 6, i*2-1)
#     plt.imshow(y[i-1])
    pred = to_np(learn.model(V(x[i-1][None]))).squeeze()
#     pred = np.clip(pred, 0, 1)
    plt.imshow(pred)

In [None]:
plt.imshow(to_np(y[0]))

In [None]:
plt.imshow(to_np(pred))

In [None]:
t = to_np(learn.model(V(x[1-1][None]))).squeeze()
ma = np.max(t)
mi = np.min(t)
ta = (t - mi) / (ma - mi)
print(ta)

In [None]:
plt.imshow(denorm(x[1-1])[0])

In [None]:
plt.imshow(t>0.5)

In [None]:
smooth = 1e-12
intersection = torch.sum(y_true * y_pred)
sum_ = torch.sum(y_true + y_pred)
jac = (intersection + smooth) / (sum_ - intersection + smooth)
print(jac)

## Ditched experiments

In [None]:
class H5Dataset(BaseDataset):
    def __init__(self, idxs, y, transform, datapaths=datapaths, is_rgb=True):
        area_ids = [directory_name_to_area_id(datapath) for datapath in datapaths]
        self.prefixes = [area_id_to_prefix(area_id) for area_id in area_ids]
        self.is_rgb = is_rgb
        self.file_lists = [FMT_VALTRAIN_IMAGELIST_PATH.format(prefix=prefix) for prefix in self.prefixes] +\
            [FMT_VALTEST_IMAGELIST_PATH.format(prefix=prefix) for prefix in self.prefixes]
        self.x_h5_lists = [FMT_VALTRAIN_MUL_STORE.format(prefix) for prefix in self.prefixes] +\
            [FMT_VALTEST_MUL_STORE.format(prefix) for prefix in self.prefixes]
        self.y_h5_lists = [FMT_VALTRAIN_MASK_STORE.format(prefix) for prefix in self.prefixes] +\
            [FMT_VALTEST_MASK_STORE.format(prefix) for prefix in self.prefixes]
        self.idxs = idxs # idx of trn or val. 0 ... len-1. Generate by permutation
        if transform is not None:
            super().__init__(transform)
        self.ys = y
        
        # open all files
        self.x_h5_lists_open = [tb.open_file(o) for o in self.x_h5_lists]
        self.df_lists = [pd.read_csv(o) for o in self.file_lists]

        # choose next h5 after one is exhausted
        self.ns = []
        for file_list in self.file_lists:
            df = pd.read_csv(file_list)
            self.ns.append(df.shape[0])
        self.ns = np.array(self.ns) # number of pre-crop images
        self.cum_ns = np.cumsum(self.ns * num_slice)
        
    @staticmethod
    def load_y(datapaths=datapaths):
        dummy_dataset = H5Dataset(None, None, None, datapaths=datapaths)
            
        y = []
        print('Loading masks...')
        for idx_file, df_list in enumerate(dummy_dataset.df_lists):
            with tb.open_file(dummy_dataset.y_h5_lists[idx_file]) as f:
                ys = []
                for idx_im in tqdm.tqdm(range(dummy_dataset.ns[idx_file] * num_slice), total=dummy_dataset.ns[idx_file] * num_slice):
                    slice_pos = idx_im % num_slice
                    im = np.array(f.get_node('/' + df_list.iloc[idx_im // num_slice][0] + '_' + str(slice_pos)))
                    im = np.broadcast_to(im[...,None], (256, 256, 3))
                    ys.append(im.astype('float'))
            y += ys
        return np.array(y)

        
    def get_sz(self): return self.transform.sz
    
    def get_file_idx(self,i):
        idx_file = np.argmax(i + 1 <= self.cum_ns)
        if idx_file == 0:
            idx_im = i
        else:
            idx_im = i - self.cum_ns[idx_file - 1]
        return idx_file, idx_im
    
    def get_x(self, i):
        idx_file, idx_im = self.get_file_idx(i)
        h5_list_open = self.x_h5_lists_open
        f = h5_list_open[idx_file]
        df_list = self.df_lists[idx_file]
        slice_pos = idx_im % num_slice
        
        im = np.array(f.get_node('/' + df_list.iloc[idx_im // num_slice][0] + '_' + str(slice_pos)))
        if self.is_rgb:
            # Or other bands
            im = im[...,:3]
        return im.astype('float')
    
    def get_y(self, i):
        return self.ys[i]
        
    def get_c(self): return 1
        
    def get_n(self): return self.idxs.shape[0]

#     def resize_imgs(self, targ, new_path):
#         dest = resize_imgs(self.fnames, targ, self.path, new_path)
#         return self.__class__(self.fnames, self.y, self.transform, dest)

    def denorm(self,arr):
        """Reverse the normalization done to a batch of images.

        Arguments:
            arr: of shape/size (N,3,sz,sz)
        """
        if type(arr) is not np.ndarray: arr = to_np(arr)
        if len(arr.shape)==3: arr = arr[None]
        return self.transform.denorm(np.rollaxis(arr,1,4))

    @staticmethod
    def get_ns(datapaths=datapaths):
        return H5Dataset(None, None, None, datapaths=datapaths).cum_ns[-1]

In [None]:
# # rgb scaled
# def get_rgb_scaled(datapath):
#     area_id = directory_name_to_area_id(datapath)
#     prefix = area_id_to_prefix(area_id)
    
#     X_val = []
#     fn_im = FMT_VALTEST_IM_STORE.format(prefix)
#     with tb.open_file(fn_im, 'r') as f:
#         for idx, image_id in enumerate(df_test.ImageId.tolist()):
#             im = np.array(f.get_node('/' + image_id))
#             im = np.swapaxes(im, 0, 2)
#             im = np.swapaxes(im, 1, 2)
#             X_val.append(im)
#     X_val = np.array(X_val)

#     y_val = []
#     fn_mask = FMT_VALTEST_MASK_STORE.format(prefix)
#     with tb.open_file(fn_mask, 'r') as f:
#         for idx, image_id in enumerate(df_test.ImageId.tolist()):
#             mask = np.array(f.get_node('/' + image_id))
#             mask = (mask > 0.5).astype(np.uint8)
#             y_val.append(mask)
#     y_val = np.array(y_val)
#     y_val = y_val.reshape((-1, 1, INPUT_SIZE, INPUT_SIZE))
#     return X_val, y_val

In [None]:
# (trn_x,trn_y), (val_x,val_y) = get_dataset(datapath)

In [None]:
def merge_file_list():
    df = pd.DataFrame()
    # clear file
    df.to_csv(FMT_VALTRAIN_IMAGELIST_PATH_ALL)
    for datapath in datapaths:
        area_id = directory_name_to_area_id(datapath)        
        prefix = area_id_to_prefix(area_id)
        fn_train = FMT_VALTRAIN_IMAGELIST_PATH.format(prefix=prefix)
        df_train = pd.read_csv(fn_train)
#         fn_a = FMT_VALTRAIN_IMAGELIST_PATH_ALL
#         fn_im = FMT_VALTRAIN_MUL_STORE.format(prefix)
        fn_test = FMT_VALTEST_IMAGELIST_PATH.format(prefix=prefix)
        df_test = pd.read_csv(fn_test)
        df = df.append(df_train).append(df_test)
        
    df.to_csv(FMT_VALTRAIN_IMAGELIST_PATH_ALL)
        

In [None]:
import traceback
def merge_im(rgb=False):
    if rgb:
        fn_store_w = FMT_VALTRAIN_IM_STORE_ALL
    else:
        fn_store_w = FMT_VALTRAIN_MUL_STORE_ALL
        
    with tb.open_file(fn_store_w, 'w') as fw:
        for datapath in datapaths:
            try:
                area_id = directory_name_to_area_id(datapath)        
                prefix = area_id_to_prefix(area_id)

                # valtrain + valtest
                for (fn_store, fn_list) in [(FMT_VALTRAIN_MUL_STORE.format(prefix), FMT_VALTRAIN_IMAGELIST_PATH.format(prefix=prefix)),
                                         (FMT_VALTEST_MUL_STORE.format(prefix), FMT_VALTEST_IMAGELIST_PATH.format(prefix=prefix))]:
                    df_list = pd.read_csv(fn_list, index_col='ImageId')
                    with tb.open_file(fn_store, 'r') as fr:
                        for idx, image_id in tqdm.tqdm(enumerate(df_list.index), total=df_list.shape[0]):
                            for slice_pos in range(9):
                                slice_id = image_id + '_' + str(slice_pos)
                                im = np.array(fr.get_node('/' + slice_id))
                                im = np.swapaxes(im, 0, 2)
                                im = np.swapaxes(im, 1, 2)
                                if rgb:
                                    im = im[:3,...]
                                atom = tb.Atom.from_dtype(im.dtype)
                                filters = tb.Filters(complib='blosc', complevel=9)
                                ds = fw.create_carray(fw.root, slice_id, atom, im.shape,
                                                     filters=filters)
                                ds[:] = im
            except Exception as e:
                traceback.print_exc()
                print(datapath, e)

In [None]:
# merge_file_list()
# df = pd.read_csv(FMT_VALTRAIN_IMAGELIST_PATH_ALL)
# df.head()

In [None]:
# merge_im(True)
# merge_im(False) # Too big