In [None]:
# Re-import modules automatically before each cell executes, so edits to imported code take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# default_exp apply
# nbdev directive above: names the module (`apply`) that cells marked for export are written to.
pass # NOTE(review): presumably a workaround for the xpython kernel ("xpython fix"); confirm it is still needed.

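# Apply preprocessing to image databases

> Preprocess image LMDB databases image by image and compute per-pixel normalization statistics for each cross-validation fold.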

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
from ifcimglib import imglmdb, utils, preprocessing, cif2lmdb
import numpy
import matplotlib.pyplot as plt
from tqdm import trange
import pickle
import logging
import lmdb
from pathlib import Path
from tqdm import tqdm
import os
import seaborn
from sklearn.model_selection import PredefinedSplit

In [None]:
# export

def preprocess_db_intra_image(db, preprocessed_output_path, size=70):
    """Preprocess every image of `db` independently and write the result to a new LMDB file.

    For each index ``i``, the image and mask from ``db.get_image(i)`` are processed as follows:
    the image is cast to float32, passed through ``preprocessing.log_transform(image, mask, [1])``
    and ``preprocessing.min_max_normalize(image, mask, "clip")``, and then both image and mask are
    passed through ``preprocessing.crop_and_pad_to_square(..., size)``. The image is stored as
    float16 and the mask as bool, wrapped by ``cif2lmdb.get_instance``/``set_instance_data``, and
    pickled under the key ``i.to_bytes(db.idx_byte_length, "big")``.

    The metadata keys ``__targets__`` (pickled ``db.targets``), ``__len__`` (big-endian
    ``len(db)``) and ``__names__`` (``db.names`` joined with spaces, UTF-8) are written too.

    Args:
        db: source image database; must provide ``targets``, ``names``, ``idx_byte_length``,
            ``len(db)`` and ``get_image(i)`` returning a 3-tuple ``(image, mask, _)``.
        preprocessed_output_path: path of the output LMDB file. Any existing file at this
            path is deleted first.
        size: side length passed to ``crop_and_pad_to_square`` (default 70, the value
            previously hard-coded).
    """
    logger = logging.getLogger(__name__)

    # Always start from an empty output file.
    output_path = Path(preprocessed_output_path)
    if output_path.exists():
        output_path.unlink()

    # subdir=False: the path names the database file itself rather than a directory.
    env = lmdb.open(preprocessed_output_path, lock=False, sync=False, map_size=cif2lmdb.map_size, subdir=False)
    logger.info("Opened lmdb database %s" % preprocessed_output_path)

    # try/finally so the environment is closed even if preprocessing or writing fails.
    try:
        with env.begin(write=True) as txn:
            txn.put(b'__targets__', pickle.dumps(db.targets))
            txn.put(b'__len__', len(db).to_bytes(db.idx_byte_length, "big"))
            txn.put(b'__names__', " ".join(db.names).encode("utf-8"))

            for i in trange(len(db)):
                image, mask, _ = db.get_image(i)
                image = image.astype(numpy.float32)
                image = preprocessing.log_transform(image, mask, [1])
                image = preprocessing.min_max_normalize(image, mask, "clip")
                image = preprocessing.crop_and_pad_to_square(image, size)
                # The mask round-trips through uint8 for cropping and is cast back to bool
                # (NOTE(review): presumably crop_and_pad_to_square does not accept bool arrays; confirm).
                mask = preprocessing.crop_and_pad_to_square(mask.astype(numpy.uint8), size).astype(bool)

                # NOTE(review): image.shape[0] is passed separately from the remaining dims;
                # presumably the channel axis comes first — confirm.
                instance = cif2lmdb.get_instance(image.shape[1:], image.shape[0])
                instance = cif2lmdb.set_instance_data(instance, image.astype(numpy.float16), mask)

                txn.put(i.to_bytes(db.idx_byte_length, byteorder='big'), pickle.dumps(instance))
        # The environment was opened with sync=False, so flush to disk explicitly.
        env.sync()
    finally:
        env.close()

In [None]:
# export
def _per_pixel_mean_std(db, idx):
    """Return ``(stats[0], std)`` from ``preprocessing.compute_per_pixel_stats`` over the samples `idx`.

    ``stats[1]`` is the per-pixel std; entries equal to 0.0 are replaced by 1, so the std can
    safely be used as a divisor. ``stats[0]`` is passed through unchanged
    (NOTE(review): presumably the per-pixel mean — confirm).
    """
    stats = preprocessing.compute_per_pixel_stats(db, None, idx=idx)
    std = numpy.where(stats[1] == 0.0, 1, stats[1])
    return (stats[0], std)


def aggregate_fold_stats(db_paths, cv_pkl_file):
    """Compute per-pixel normalization statistics for every outer and nested CV fold.

    Args:
        db_paths: paths of the preprocessed LMDB databases, opened together (in sorted order)
            with ``imglmdb.multidbwrapper``.
        cv_pkl_file: pickle file containing ``(test_fold, nested_test_folds)``. ``test_fold``
            defines the outer split (via ``PredefinedSplit``); ``nested_test_folds[i]`` defines
            the nested split for outer fold ``i``.

    Returns:
        A list with one independent dict per outer fold::

            {"outer": (stats0, std) over the outer test indices,
             "nested": [{"train": (stats0, std), "val": (stats0, std)}, ...]}

        where each pair comes from ``_per_pixel_mean_std``. The list is also pickled to
        ``os.path.splitext(cv_pkl_file)[0] + "_stats.pkl"``.

    Raises:
        ValueError: if the number of nested fold definitions differs from the number of
            outer folds.
    """
    preprocessed_db = imglmdb.multidbwrapper(sorted(db_paths))
    # NOTE: pickle.load can execute arbitrary code; only pass CV files you produced yourself.
    with open(cv_pkl_file, "rb") as pkl:
        test_fold, nested_test_folds = pickle.load(pkl)

    splitter = PredefinedSplit(test_fold)
    n_splits = splitter.get_n_splits()
    # Without this check, zip() below would silently drop folds when the two lengths disagree.
    if len(nested_test_folds) != n_splits:
        raise ValueError(
            "Expected %d nested fold definitions (one per outer fold), got %d"
            % (n_splits, len(nested_test_folds)))

    # Build a fresh dict per fold; `[{}] * n` would alias one dict, so every fold overwrote the last.
    data = []
    for nested_test_fold, (_, test_idx) in zip(nested_test_folds, splitter.split()):
        # NOTE(review): these stats are taken over the outer *test* indices; confirm this is
        # intended (training-sample stats are the usual choice for normalization).
        fold_stats = {"outer": _per_pixel_mean_std(preprocessed_db, test_idx)}

        # NOTE(review): nested indices are applied directly to the full database; confirm that
        # nested_test_fold is indexed over all samples rather than an outer-train subset.
        nested_splitter = PredefinedSplit(nested_test_fold)
        fold_stats["nested"] = [
            {"train": _per_pixel_mean_std(preprocessed_db, train_idx),
             "val": _per_pixel_mean_std(preprocessed_db, val_idx)}
            for train_idx, val_idx in nested_splitter.split()
        ]
        data.append(fold_stats)

    with open(os.path.splitext(cv_pkl_file)[0] + "_stats.pkl", "wb") as pkl:
        pickle.dump(data, pkl)

    return data

In [None]:
# Preprocess every raw database under the data directory, writing the result next to its source.
data_dir = Path("/home/maximl/scratch/data/wbc/")  # NOTE(review): machine-specific absolute path
# Materialize the list before the loop so databases written by this cell are never picked up as inputs;
# test only the file name, so a parent directory containing "preprocessed" does not skip everything.
raw_db_paths = [f for f in data_dir.rglob("*.lmdb") if "preprocessed" not in f.name]
for db_path in raw_db_paths:
    db = imglmdb.imglmdb(str(db_path))
    preprocess_db_intra_image(db, str(db_path.parent / "CD45_focused_singlets_preprocessed.lmdb"))

100%|██████████| 23030/23030 [00:20<00:00, 1117.98it/s]


In [None]:
# Aggregate per-pixel normalization stats for every CV fold over all preprocessed databases.
data_dir = Path("/home/maximl/scratch/data/wbc/")
preprocessed_db_paths = list(map(str, data_dir.rglob("*preprocessed.lmdb")))
cv_pkl_file = "/home/maximl/scratch/data/wbc/samplesplit_234_nested_3fold.pkl"

data = aggregate_fold_stats(preprocessed_db_paths, cv_pkl_file)

100%|██████████| 71599/71599 [00:14<00:00, 4942.38it/s]
100%|██████████| 71599/71599 [00:13<00:00, 5293.03it/s]
100%|██████████| 65336/65336 [00:11<00:00, 5471.76it/s]
100%|██████████| 65336/65336 [00:12<00:00, 5076.17it/s]
100%|██████████| 32668/32668 [00:06<00:00, 5023.30it/s]
100%|██████████| 32668/32668 [00:06<00:00, 5010.21it/s]
100%|██████████| 65336/65336 [00:11<00:00, 5582.70it/s]
100%|██████████| 65336/65336 [00:13<00:00, 4922.48it/s]
100%|██████████| 32668/32668 [00:06<00:00, 5297.60it/s]
100%|██████████| 32668/32668 [00:06<00:00, 4955.88it/s]
100%|██████████| 65336/65336 [00:11<00:00, 5486.94it/s]
100%|██████████| 65336/65336 [00:13<00:00, 4832.81it/s]
100%|██████████| 32668/32668 [00:05<00:00, 5715.46it/s]
100%|██████████| 32668/32668 [00:06<00:00, 4731.36it/s]
100%|██████████| 68010/68010 [00:13<00:00, 4949.34it/s]
100%|██████████| 68010/68010 [00:13<00:00, 5064.22it/s]
100%|██████████| 67728/67728 [00:12<00:00, 5444.02it/s]
100%|██████████| 67728/67728 [00:13<00:00, 5004.

In [None]:
# Reload the fold statistics that `aggregate_fold_stats` pickled next to `cv_pkl_file`.
# NOTE: pickle.load can execute arbitrary code; only load stats files produced locally.
stats_pkl_path = os.path.splitext(cv_pkl_file)[0] + "_stats.pkl"
with open(stats_pkl_path, "rb") as stats_pkl:
    data = pickle.load(stats_pkl)