In [1]:
# Widen the notebook's ".container" element to the full browser width.
# `display` and `HTML` are imported from the public `IPython.display` module;
# the `IPython.core.display` path is deprecated for these names and newer
# IPython releases no longer export `display` from it.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import torch
import kornia
import numpy as np
import gc
import os
import sys
import re
%matplotlib notebook
import matplotlib.pyplot as plt

# Put the absolute path of the current working directory on sys.path (once),
# presumably so the project imports that follow can be resolved.
module_path = os.path.abspath(os.curdir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from helpers import Trainer
from helpers import Logger
from helpers import Trial
from mibi_dataloader import MIBIData
from modules import SelfSupervisedEstimator
from criteria import SelfSupervisedEstimatorLoss
import utils

In [2]:
def get_idx(string):
    """Return the integer stored between the first and second '-' of ``string``.

    The name is split on every '-' and the second piece is converted with
    ``int``; e.g. ``'estimator-3'`` gives ``3``. A name without any '-'
    raises ``IndexError``; a non-numeric second piece raises ``ValueError``.
    """
    return int(string.split('-')[1])

def gdt_estimator(samples):
    """Run ``utils.estimate_lambda`` on channel 0 of every sample.

    Args:
        samples: 4-D torch tensor indexed as ``samples[i, 0, :, :]``.

    Returns:
        A CPU tensor of zeros-initialised shape ``samples.shape`` whose i-th
        entry holds the first value returned by ``utils.estimate_lambda`` for
        sample i. The second returned value is discarded.
    """
    # Geometric ladder: start at 10 and keep multiplying by 0.8 until the
    # newest value is no longer above 0.7.
    ladder = [10]
    while ladder[-1] > .7:
        ladder.append(ladder[-1] * 0.8)

    # Element-wise inverse square of the ladder, handed over as a numpy array.
    inv_sq = (torch.tensor(ladder).float() ** -2).numpy()
    ladder = np.array(ladder)

    count = samples.size(0)
    l_ests = torch.zeros(samples.shape)
    for idx in range(count):
        estimate, _unused = utils.estimate_lambda(samples[idx, 0, :, :], ladder, inv_sq)
        l_ests[idx, :, :, :] = estimate
        # '\r' rewrites the same console line with the percentage done so far.
        progress = str(100 * idx / count)
        print('\rProcessing.......' + progress + '%', end='')
    return l_ests

def net_estimator(samples, network):
    """Run ``network.process`` on every sample and collect the results.

    The network is moved to the GPU, each sample is sent through it as a
    batch of one with gradients disabled, negative outputs are set to 0,
    and the result is stored in a CPU tensor shaped like ``samples``.

    Args:
        samples: 4-D torch tensor; ``samples[i]`` is one sample.
        network: object exposing ``cuda()`` and ``process(batch)``.

    Returns:
        Tensor of shape ``samples.shape`` holding the clipped outputs.
    """
    network.cuda()
    count = samples.size(0)
    l_ests = torch.zeros(samples.shape)
    with torch.no_grad():
        for idx in range(count):
            batch = samples[idx, :, :, :].unsqueeze(0).cuda()
            limg = network.process(batch)
            # Zero out negative predictions in place.
            limg[limg < 0] = 0
            l_ests[idx, :, :, :] = limg
            progress = str(100 * idx / count)
            print('\rProcessing.......' + progress + '%', end='')
    return l_ests

def gdt_synthesize_dataset(source_data_dir, main_dir, number):
    """Build the first synthetic lambda dataset with ``gdt_estimator``.

    Loads a ``MIBIData`` set from ``source_data_dir``, runs ``gdt_estimator``
    over its images, wraps the result in a new ``MIBIData`` and pickles it to
    ``main_dir + 'synthetic_datasets/synthetic_lambda-1'``.

    Args:
        source_data_dir: folder passed to ``MIBIData(folder=...)``.
        main_dir: base path; concatenated directly with the output sub-path.
        number: forwarded to ``MIBIData(number=...)``.

    Returns:
        The newly created ``MIBIData`` holding the estimated images.
    """
    # Same crop/scale/stride settings are used for the source and the output.
    window = dict(crop=31, scale=1, stride=8)

    print('    Loading empirical MIBI data...')
    empirical_ds = MIBIData(folder=source_data_dir, **window, number=number)

    print('    Creating synthetic lambda data using gdt-estimator...')
    estimates = gdt_estimator(empirical_ds.images)
    lambda_1_ds = MIBIData(images=estimates, labels='None', source=empirical_ds.source, **window)

    print('    Saving synthetic lambda data to disk...')
    out_path = main_dir + 'synthetic_datasets/synthetic_lambda-1'
    lambda_1_ds.pickle(out_path)
    print('    Initial synthetic lambda data saved.')
    return lambda_1_ds

def net_synthesize_dataset(source_data_dir, main_dir, network, number, index):
    """Build a synthetic lambda dataset with ``net_estimator`` and pickle it.

    Loads a ``MIBIData`` set from ``source_data_dir``, runs ``net_estimator``
    with ``network`` over its images, wraps the result in a new ``MIBIData``
    and pickles it to ``main_dir + 'synthetic_datasets/synthetic_lambda-<index>'``.

    Args:
        source_data_dir: folder passed to ``MIBIData(folder=...)``.
        main_dir: base path; concatenated directly with the output sub-path.
        network: estimator handed to ``net_estimator``.
        number: forwarded to ``MIBIData(number=...)``.
        index: suffix used in the pickled dataset's name.

    Returns:
        The newly created ``MIBIData`` holding the estimated images.
    """
    # Same crop/scale/stride settings are used for the source and the output.
    window = dict(crop=31, scale=1, stride=8)

    print('    Loading empirical MIBI data...')
    empirical_ds = MIBIData(folder=source_data_dir, **window, number=number)

    print('    Creating synthetic lambda data using net-estimator...')
    estimates = net_estimator(empirical_ds.images, network)
    lambda_i_ds = MIBIData(images=estimates, labels='None', source=empirical_ds.source, **window)

    print('    Saving synthetic lambda data to disk...')
    out_path = main_dir + 'synthetic_datasets/synthetic_lambda-' + str(index)
    lambda_i_ds.pickle(out_path)
    return lambda_i_ds
    
def find_latest(thing, directory):
    """Find the highest-numbered model or dataset stored under ``directory``.

    Args:
        thing: ``'model'`` to look in ``directory + 'models'`` or ``'dataset'``
            to look in ``directory + 'synthetic_datasets'``.
        directory: base path. It is concatenated directly with the sub-folder
            name, so it must already end with a path separator.

    Returns:
        ``(index, name)`` of the entry whose ``get_idx`` value is largest, or
        ``(0, '')`` when the folder is empty.

    Raises:
        ValueError: if ``thing`` is neither ``'model'`` nor ``'dataset'``.
    """
    # Compare by value: `is` tests object identity, which only happens to work
    # for interned string literals and triggers a SyntaxWarning on Python 3.8+.
    if thing == 'model':
        subfolder = 'models'
    elif thing == 'dataset':
        subfolder = 'synthetic_datasets'
    else:
        # Fail explicitly instead of printing and then crashing on an
        # unassigned local further down.
        raise ValueError('Invalid argument in "find_latest", must search for a "model" or a "dataset".')

    elements = os.listdir(directory + subfolder)
    if len(elements) == 0:
        return 0, ''

    element_idxs = [get_idx(name) for name in elements]
    last_element_index = max(element_idxs)
    element_name = elements[element_idxs.index(last_element_index)]
    return last_element_index, element_name

# train the network on the dataset using the training parameters
def train_estimator(estimator, dataset, estimator_train_args, model_dir=None):
    """Train ``estimator`` on ``dataset`` with a fresh Trainer/Logger/criterion.

    Args:
        estimator: network to train; it is moved to the GPU first.
        dataset: object exposing ``set_crop``; cropped to
            ``estimator_train_args['crop']`` before training.
        estimator_train_args: keyword arguments forwarded to
            ``Trainer.train``. The mapping is copied, so the caller's dict is
            left untouched; the copy always carries ``'continue': False``.
        model_dir: directory handed to ``Trainer.train``. When omitted, the
            previous behaviour is kept: the notebook-level global ``main_dir``
            is read and ``'estimator/models/'`` is appended to it.
            NOTE(review): with the notebook's ``main_dir`` already ending in
            ``'estimator/'`` this default resolves to
            ``'.../estimator/estimator/models/'``, which differs from the
            ``main_dir + 'models/'`` folder used by ``bootstrap_estimator`` -
            confirm this is intended.

    Returns:
        The same ``estimator`` object after training.
    """
    torch.cuda.empty_cache()
    estimator.cuda()
    estimator_logger = Logger(['loss'])
    estimator_trainer = Trainer()
    estimator_criterion = SelfSupervisedEstimatorLoss()

    # Work on a private copy so the caller's configuration is not mutated.
    train_args = dict(estimator_train_args)
    train_args['continue'] = False
    dataset.set_crop(train_args['crop'])

    print('Training lambda-estimator...')
    if model_dir is None:
        # Backward-compatible default: relies on the global `main_dir`.
        model_dir = main_dir + 'estimator/models/'
    estimator_trainer.train(estimator, dataset, estimator_criterion, estimator_logger, model_dir, **train_args)
    return estimator
    

# returns a model and a dataset, does some housekeeping
def load_model_and_dataset(main_dir, source_data_dir, number):
    """Load (or create) the latest estimator and the dataset to train it on.

    Index convention used throughout: dataset ``k + 1`` is the one synthesized
    by model ``k`` (dataset 1 comes from the GDT estimator, i.e. "model 0").

    Args:
        main_dir: base path containing the ``models`` and
            ``synthetic_datasets`` folders; must end with a path separator
            because ``find_latest`` concatenates onto it.
        source_data_dir: folder with the empirical data, forwarded to the
            synthesize helpers.
        number: forwarded to the synthesize helpers.

    Returns:
        ``(estimator, model_idx, dataset)`` where ``model_idx`` is the index
        of the loaded model (0 when a fresh network was initialised).
    """
    torch.cuda.empty_cache()

    model_idx, model_name = find_latest('model', main_dir)
    if model_idx == 0:
        print('    No pretrained model found, initializing random network.')
        estimator = SelfSupervisedEstimator()
    else:
        print('    Pretrained model found, loading network.')
        estimator = SelfSupervisedEstimator.load_model(main_dir + 'models/', model_name)

    ds_idx, ds_name = find_latest('dataset', main_dir)

    if ds_idx == 0:
        print('    No synthetic data, generating synthetic data...')
        if model_idx == 0:
            print('    No pretrained model found, using GDT estimator...')
            dataset = gdt_synthesize_dataset(source_data_dir, main_dir, number)
        else:
            print('    Using pretrained model to synthesize data...')
            dataset = net_synthesize_dataset(source_data_dir, main_dir, estimator, number, model_idx + 1)
    else:
        if model_idx < ds_idx:
            # A dataset newer than the latest model already exists: reuse it.
            dataset = MIBIData.depickle(os.path.join(main_dir, 'synthetic_datasets', ds_name))
        else:  # model_idx >= ds_idx
            # Save under model_idx + 1, matching the branch above. Using
            # model_idx here would overwrite the dataset the current model was
            # trained on and would keep the reuse branch from ever triggering
            # after an interrupted run.
            dataset = net_synthesize_dataset(source_data_dir, main_dir, estimator, number, model_idx + 1)

    return estimator, model_idx, dataset

    
def bootstrap_estimator(main_dir, source_data_dir, number, estimator_train_args, N):
    """Repeat the load -> train -> save cycle ``N`` times.

    Every round obtains the latest estimator and a dataset through
    ``load_model_and_dataset``, trains with ``train_estimator`` and saves the
    result to ``main_dir + 'models/'`` as ``'estimator-<model_idx + 1>'``.

    Args:
        main_dir: base path holding the ``models`` folder.
        source_data_dir: forwarded to ``load_model_and_dataset``.
        number: forwarded to ``load_model_and_dataset``.
        estimator_train_args: forwarded to ``train_estimator``.
        N: number of rounds to run.
    """
    print('Bootstrapping lambda-estimator function...')
    for _round in range(N):
        loaded = load_model_and_dataset(main_dir, source_data_dir, number)
        estimator, model_idx, dataset = loaded
        estimator = train_estimator(estimator, dataset, estimator_train_args)
        next_name = 'estimator-' + str(model_idx + 1)
        estimator.save_model(main_dir + 'models/', next_name)
        print('Finished iteration.')
        print()
    print('Finished with bootstrapping protocol.')

In [3]:
# Locations used by the bootstrap run. Both end with '/' because the helper
# functions build sub-paths by plain string concatenation.
# NOTE(review): these are absolute, machine-specific paths.
main_dir = '/home/hazmat/GitHub/Denoisotron/estimator/'
source_data_dir = '/home/hazmat/GitHub/Denoisotron/data/traindat/'

# Forwarded to MIBIData(number=...) when the empirical data is loaded.
number = 10000

# Number of bootstrap rounds.
N = 10

# Keyword arguments forwarded to Trainer.train; 'crop' is additionally passed
# to dataset.set_crop before training starts.
estimator_train_args = {
    'lr': 0.0001,
    'batch_size': 100,
    'epochs': 1,
    'report': 5,
    'crop': 121,
    'clip': 1,
    'decay': 0,
    'restart': False,
    'epoch_frac': 0.1,
}

In [4]:
# Run N bootstrap rounds (load/synthesize data -> train -> save) using the
# paths and training arguments defined in the configuration cell.
# The recorded output reports about 2790 s of training for one round.
bootstrap_estimator(main_dir, source_data_dir, number, estimator_train_args, N)

Bootstrapping lambda-estimator function...
    Pretrained model found, loading network.
    Loading empirical MIBI data...
Loading.......99.9899989999%998%%%%900
There are  35996400 samples
    Creating synthetic lambda data using net-estimator...
Processing.......99.9899989999%900%%%%
There are  35996400 samples
    Saving synthetic lambda data to disk...
2304
Training lambda-estimator...
2304
Epoch:0 > < 0.014462148648064546                                                                                                                  
trained in 2790.41073012352 seconds
Finished iteration.

    Pretrained model found, loading network.
    Loading empirical MIBI data...
Loading.......99.9899989999%998%%%%900
There are  35996400 samples
    Creating synthetic lambda data using net-estimator...
Processing.......99.9899989999%900%%%%
There are  35996400 samples
    Saving synthetic lambda data to disk...
2304
Training lambda-estimator...
2304
Epoch:0 > < 0.013910954567653932           