In [None]:
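#######
### This notebook scores, for each cell in our dataset, how well its structure
### channel is reconstructed from every structure embedding in the dataset,
### to find the most likely predicted structure channels for each cell.
#######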

#######    
### Load the Model Parts
#######

import argparse
import SimpleLogger as SimpleLogger

import importlib
import numpy as np

import os
import pickle

import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils

# Select the non-interactive 'Agg' backend before pyplot is imported; this is
# needed to use pyplot inside the docker image (presumably it has no display).
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from IPython import display
import time
from model_utils import set_gpu_recursive, load_model, save_state, save_progress, get_latent_embeddings, maybe_save

import torch.backends.cudnn as cudnn
# Let cuDNN benchmark convolution algorithms and cache the fastest one; this
# pays off when input shapes stay the same from call to call.
cudnn.benchmark = True

import pdb

# Directory of the trained structure model; its saved training options
# (opt.pkl) configure everything below.
parent_dir = './test_aaegan/aaegan3Dv4_32D-relu_v3'

model_dir = parent_dir + os.sep + 'struct_model'

# NOTE: pickle.load can execute arbitrary code -- only load opt.pkl files you trust.
# The context manager closes the file handle once the options are read.
with open('{0}/opt.pkl'.format(model_dir), "rb") as opt_file:
    opt = pickle.load(opt_file)

print(opt)

# Project modules named by the saved options
DP = importlib.import_module("data_providers." + opt.dataProvider)
model_provider = importlib.import_module("models." + opt.model_name)
train_module = importlib.import_module("train_modules." + opt.train_module)

# Seed the torch (CPU and CUDA) and numpy RNGs with the saved seed
torch.manual_seed(opt.myseed)
torch.cuda.manual_seed(opt.myseed)
np.random.seed(opt.myseed)

if not os.path.exists(opt.save_dir):
    os.makedirs(opt.save_dir)

# -1 is a sentinel meaning "reuse opt.nepochs"
if opt.nepochs_pt2 == -1:
    opt.nepochs_pt2 = opt.nepochs

# Options handed to the data provider.
opts = {
    'verbose': True,
    'pattern': '*.tif_flat.png',
    'out_size': [opt.imsize, opt.imsize],
}

# The data provider is cached on disk, keyed by output size, so later runs can
# torch.load it instead of rebuilding it from opt.imdir.
data_path = './data_%sx%s.pyt' % (opts['out_size'][0], opts['out_size'][1])
if not os.path.exists(data_path):
    dp = DP.DataProvider(opt.imdir, opts)
    torch.save(dp, data_path)
else:
    dp = torch.load(data_path)

dp.opts['dtype'] = 'float'

# -1 means "use every training image".
if opt.ndat == -1:
    opt.ndat = dp.get_n_dat('train')

iters_per_epoch = np.ceil(opt.ndat / opt.batch_size)
            
#######
### Load REFERENCE MODEL embeddings
#######

# Latent embeddings from the reference model, precomputed and saved to disk.
embeddings_path = opt.save_parent + os.sep + 'ref_model' + os.sep + 'embeddings.pkl'
if os.path.exists(embeddings_path):
    embeddings = torch.load(embeddings_path)
else:
    # The reference encoder is never loaded in this notebook (no `models` dict
    # exists at this point), so the embeddings cannot be recomputed here.
    # Fail with a clear message instead of a NameError/TypeError.
    raise IOError('Reference-model embeddings not found at ' + embeddings_path +
                  '; compute them with the reference model before running this notebook.')

models = None
optimizers = None
    
def get_ref(self, inds, train_or_test='train'):
    """Return the stored embedding rows `inds` of one data split.

    Meant to be bound onto a data provider instance, so `self` must carry an
    `embeddings` dict keyed by split name.

    inds: sequence of integer row indices into the split.
    train_or_test: split key into self.embeddings ('train' by default).
    """
    row_index = torch.LongTensor(inds)
    split_embeddings = self.embeddings[train_or_test]
    return split_embeddings[row_index]

import types

# Attach the reference embeddings to the data provider and bind get_ref as a
# method of this particular instance.
dp.embeddings = embeddings
dp.get_ref = types.MethodType(get_ref, dp)

# Channels 0, 1 and 2 are used; the data provider shares the same list object.
opt.channelInds = [0, 1, 2]
opt.nch = len(opt.channelInds)
dp.opts['channelInds'] = opt.channelInds

opt.nClasses = dp.get_n_classes()
opt.nRef = opt.nlatentdim

# Best-effort: instantiate the trainer for this model type. It is not used by
# the analysis cells below, so a failure is reported instead of stopping the
# run (and no longer silently swallowed, nor does it trap KeyboardInterrupt).
try:
    train_module = None
    train_module = importlib.import_module("train_modules." + opt.train_module)
    train_module = train_module.trainer(dp, opt)
except Exception as trainer_err:
    print('Warning: could not build trainer from train_modules.{0}: {1!r}'.format(opt.train_module, trainer_err))

# Fill in defaults for options the loaded opt does not define.
opt.critRecon = getattr(opt, 'critRecon', 'BCELoss')
opt.dtype = getattr(opt, 'dtype', 'float')

# GPU ids handed to load_model (overrides whatever the saved opt contained).
opt.gpu_ids = [1, 2, 3]
models, optimizers, criterions, logger, opt = load_model(model_provider, opt)

# Keep only the encoder and decoder, both switched to evaluation mode.
enc = models['enc']
dec = models['dec']
enc.train(False)
dec.train(False)

models = optimizers = None

print('Done loading model.')

# Structure-model latent embeddings, cached on disk. This rebinds `embeddings`;
# the reference-model embeddings remain available as dp.embeddings.
opt.batch_size = 200
enc.gpu_ids = dec.gpu_ids = opt.gpu_ids

embeddings_path = os.sep.join([opt.save_dir, 'embeddings_struct.pkl'])
if not os.path.exists(embeddings_path):
    # NOTE(review): the meaning of the trailing `1` is defined by
    # model_utils.get_latent_embeddings, which is not visible here -- confirm.
    embeddings = get_latent_embeddings(enc, dp, opt, 1)
    torch.save(embeddings, embeddings_path)
else:
    embeddings = torch.load(embeddings_path)

print('Done loading embeddings.')

In [None]:
#######    
### Main Loop
#######

import pdb
from aicsimage.io import omeTifWriter
from imgToProjection import imgtoprojection
from IPython.core.display import display
import PIL.Image
import matplotlib.pyplot as plt
import scipy.misc

import pandas as pd

# Only the decoder is used from here on; release the encoder.
enc = None

opt.batch_size = 300
gpu_id = opt.gpu_ids[0]

loss = nn.MSELoss()

# Every candidate structure embedding: train split first, then test.
embeddings_all = torch.cat((embeddings['train'], embeddings['test']), 0)

# Per-row bookkeeping for embeddings_all: split name, index within that split,
# and dataset-wide image index.
dat_train_test = (['train'] * len(embeddings['train'])
                  + ['test'] * len(embeddings['test']))
dat_dp_inds = np.hstack((np.arange(len(embeddings['train'])),
                         np.arange(len(embeddings['test'])))).astype('int')
dat_inds = np.concatenate((dp.data['train']['inds'], dp.data['test']['inds']))

# One error column per candidate structure embedding.
err_cols = ['err_%s_%s_%s' % (train_or_test, i, img_index)
            for train_or_test, i, img_index in zip(dat_train_test, dat_dp_inds, dat_inds)]

colormap = 'hsv'
colors = plt.get_cmap(colormap)(np.linspace(0, 1, 4))

px_size = [0.3873] * 3

train_or_test_split = ['test', 'train']

img_paths_all = list()

# Output root for the per-cell error CSVs (note the trailing separator).
save_parent = os.sep.join([opt.save_dir, 'var_test', ''])
if not os.path.exists(save_parent):
    os.makedirs(save_parent)

def convert_image(img):
    """Turn the first item of a batch into a numpy array with its last axis first.

    img: a Variable/tensor whose .data[0] is a 4-D tensor with axes (a, b, c, d).
    Returns a numpy array with axes (d, a, b, c).
    """
    volume = img.data[0].cpu().numpy()
    return volume.transpose(3, 0, 1, 2)

# For every cell (train and test), decode its structure channel from each
# candidate structure embedding in `embeddings_all` and save one row of
# reconstruction errors to CSV. Cells whose CSV already exists are skipped,
# so the loop can be resumed after an interruption.

# The batches over candidate embeddings are the same for every cell: build them once.
nembeddings = embeddings_all.size()[0]
inds = list(range(0, nembeddings))
data_iter = [inds[j:j + opt.batch_size] for j in range(0, len(inds), opt.batch_size)]

for train_or_test, i, img_index in zip(dat_train_test, dat_dp_inds, dat_inds):
    print(str(i) + '/' + str(len(dat_dp_inds)))

    img_class = dp.image_classes[img_index]
    img_class_onehot = dp.get_classes([i], train_or_test, 'onehot')

    # File stem of the image: basename without its last extension
    img_name = dp.get_image_paths([i], train_or_test)[0]
    img_name = os.path.basename(img_name)
    img_name = img_name[0:img_name.rfind('.')]

    save_dir = save_parent + os.sep + train_or_test + os.sep + img_name
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    err_save_path = save_dir + os.sep + img_name + '.csv'
    if os.path.exists(err_save_path):
        continue

    # Load the image and keep only channel 1, the channel compared below
    img_in = dp.get_images([i], train_or_test)
    img_in = Variable(img_in.cuda(gpu_id), volatile=True)
    img_in_struct = torch.index_select(img_in, 1, torch.LongTensor([1]).cuda(gpu_id))

    # NOTE(review): `embeddings` holds the structure-model embeddings here (it was
    # rebound when the structure embeddings were loaded), while the reference-model
    # embeddings live in dp.embeddings / dp.get_ref. Confirm z[1] should not be
    # dp.get_ref([i], train_or_test) instead.
    shape_embedding = embeddings[train_or_test][i]

    # Class conditioning: onehot - 1 is 0 for the cell's own class and, assuming
    # a 0/1 one-hot, -1 elsewhere; scaled to 0 / -25.
    img_class_onehot_log = (img_class_onehot - 1) * 25

    errors = list()
    z = [None] * 3

    for j, batch_inds in enumerate(data_iter):
        print('cell: ' + str(i) + ', ' + str(j) + '/' + str(len(data_iter)))

        batch_size = len(batch_inds)

        # z = [class, shape, structure]; class and shape are repeated so every
        # candidate structure embedding in the batch is decoded for this cell
        z[0] = Variable(img_class_onehot_log.repeat(batch_size, 1).float(), volatile=True).cuda(gpu_id)
        z[1] = Variable(shape_embedding.repeat(batch_size, 1), volatile=True).cuda(gpu_id)

        struct_embeddings = torch.index_select(embeddings_all, 0, torch.LongTensor(batch_inds))
        z[2] = Variable(struct_embeddings, volatile=True).cuda(gpu_id)

        imgs_out = dec(z)
        imgs_out = torch.index_select(imgs_out, 1, torch.LongTensor([1]).cuda(gpu_id))

        # One MSE per decoded image; img_in_struct[0] drops the batch axis so the
        # target has exactly the shape of a single decoded image
        for img in imgs_out:
            errors.append(loss(img, img_in_struct[0]).data[0])

    tot_inten = torch.sum(img_in_struct).data[0]

    assert len(errors) == len(err_cols), 'expected one error per candidate embedding'

    df = pd.DataFrame([[img_index, i, img_class, img_name, tot_inten, train_or_test] + errors], columns=['img_index', 'data_provider_index', 'label', 'path', 'tot_inten', 'train_or_test'] + err_cols)
    df.to_csv(err_save_path)
        


0/15827
1/15827
2/15827
3/15827
4/15827
5/15827
6/15827
7/15827
8/15827
9/15827
10/15827
11/15827
12/15827
13/15827
14/15827
15/15827
16/15827
17/15827
18/15827
19/15827
20/15827
21/15827
22/15827
23/15827
24/15827
25/15827
26/15827
27/15827
28/15827
29/15827
30/15827
31/15827
32/15827
33/15827
34/15827
35/15827
36/15827
37/15827
38/15827
39/15827
40/15827
41/15827
42/15827
43/15827
44/15827
45/15827
46/15827
47/15827
48/15827
49/15827
50/15827
51/15827
52/15827
53/15827
54/15827
55/15827
56/15827
57/15827
58/15827
59/15827
60/15827
61/15827
62/15827
63/15827
64/15827
65/15827
66/15827
67/15827
68/15827
69/15827
70/15827
71/15827
72/15827
73/15827
74/15827
75/15827
76/15827
77/15827
78/15827
79/15827
80/15827
81/15827
82/15827
83/15827
84/15827
85/15827
86/15827
87/15827
88/15827
89/15827
90/15827
91/15827
92/15827
93/15827
94/15827
95/15827
96/15827
97/15827
98/15827
99/15827
100/15827
101/15827
102/15827
103/15827
104/15827
105/15827
106/15827
107/15827
108/15827
109/15827
110/15827


831/15827
832/15827
833/15827
834/15827
835/15827
836/15827
837/15827
838/15827
839/15827
840/15827
841/15827
842/15827
843/15827
844/15827
845/15827
846/15827
847/15827
848/15827
849/15827
850/15827
851/15827
852/15827
853/15827
854/15827
855/15827
856/15827
857/15827
858/15827
859/15827
860/15827
861/15827
862/15827
863/15827
864/15827
865/15827
866/15827
867/15827
868/15827
869/15827
870/15827
871/15827
872/15827
873/15827
874/15827
875/15827
876/15827
877/15827
878/15827
879/15827
880/15827
881/15827
882/15827
883/15827
884/15827
885/15827
886/15827
887/15827
888/15827
889/15827
890/15827
891/15827
892/15827
893/15827
894/15827
895/15827
896/15827
897/15827
898/15827
899/15827
900/15827
901/15827
902/15827
903/15827
904/15827
905/15827
906/15827
907/15827
908/15827
909/15827
910/15827
911/15827
912/15827
913/15827
914/15827
915/15827
916/15827
917/15827
918/15827
919/15827
920/15827
921/15827
922/15827
923/15827
924/15827
925/15827
926/15827
927/15827
928/15827
929/15827
930/15827


1619/15827
1620/15827
1621/15827
1622/15827
1623/15827
1624/15827
1625/15827
1626/15827
1627/15827
1628/15827
1629/15827
1630/15827
1631/15827
1632/15827
1633/15827
1634/15827
1635/15827
1636/15827
1637/15827
1638/15827
1639/15827
1640/15827
1641/15827
1642/15827
1643/15827
1644/15827
1645/15827
1646/15827
1647/15827
1648/15827
1649/15827
1650/15827
1651/15827
1652/15827
1653/15827
1654/15827
1655/15827
1656/15827
1657/15827
1658/15827
1659/15827
1660/15827
1661/15827
1662/15827
1663/15827
1664/15827
1665/15827
1666/15827
1667/15827
1668/15827
1669/15827
1670/15827
1671/15827
1672/15827
1673/15827
1674/15827
1675/15827
1676/15827
1677/15827
1678/15827
1679/15827
1680/15827
1681/15827
1682/15827
1683/15827
1684/15827
1685/15827
1686/15827
1687/15827
1688/15827
1689/15827
1690/15827
1691/15827
1692/15827
1693/15827
1694/15827
1695/15827
1696/15827
1697/15827
1698/15827
1699/15827
1700/15827
1701/15827
1702/15827
1703/15827
1704/15827
1705/15827
1706/15827
1707/15827
1708/15827
1709/15827

2381/15827
2382/15827
2383/15827
2384/15827
2385/15827
2386/15827
2387/15827
2388/15827
2389/15827
2390/15827
2391/15827
2392/15827
2393/15827
2394/15827
2395/15827
2396/15827
2397/15827
2398/15827
2399/15827
2400/15827
2401/15827
2402/15827
2403/15827
2404/15827
2405/15827
2406/15827
2407/15827
2408/15827
2409/15827
2410/15827
2411/15827
2412/15827
2413/15827
2414/15827
2415/15827
2416/15827
2417/15827
2418/15827
2419/15827
2420/15827
2421/15827
2422/15827
2423/15827
2424/15827
2425/15827
2426/15827
2427/15827
2428/15827
2429/15827
2430/15827
2431/15827
2432/15827
2433/15827
2434/15827
2435/15827
2436/15827
2437/15827
2438/15827
2439/15827
2440/15827
2441/15827
2442/15827
2443/15827
2444/15827
2445/15827
2446/15827
2447/15827
2448/15827
2449/15827
2450/15827
2451/15827
2452/15827
2453/15827
2454/15827
2455/15827
2456/15827
2457/15827
2458/15827
2459/15827
2460/15827
2461/15827
2462/15827
2463/15827
2464/15827
2465/15827
2466/15827
2467/15827
2468/15827
2469/15827
2470/15827
2471/15827

3147/15827
3148/15827
3149/15827
3150/15827
3151/15827
3152/15827
3153/15827
3154/15827
3155/15827
3156/15827
3157/15827
3158/15827
3159/15827
3160/15827
3161/15827
3162/15827
3163/15827
3164/15827
3165/15827
3166/15827
3167/15827
3168/15827
3169/15827
3170/15827
3171/15827
3172/15827
3173/15827
3174/15827
3175/15827
3176/15827
3177/15827
3178/15827
3179/15827
3180/15827
3181/15827
3182/15827
3183/15827
3184/15827
3185/15827
3186/15827
3187/15827
3188/15827
3189/15827
3190/15827
3191/15827
3192/15827
3193/15827
3194/15827
3195/15827
3196/15827
3197/15827
3198/15827
3199/15827
3200/15827
3201/15827
3202/15827
3203/15827
3204/15827
3205/15827
3206/15827
3207/15827
3208/15827
3209/15827
3210/15827
3211/15827
3212/15827
3213/15827
3214/15827
3215/15827
3216/15827
3217/15827
3218/15827
3219/15827
3220/15827
3221/15827
3222/15827
3223/15827
3224/15827
3225/15827
3226/15827
3227/15827
3228/15827
3229/15827
3230/15827
3231/15827
3232/15827
3233/15827
3234/15827
3235/15827
3236/15827
3237/15827

cell: 3745, 38/53
cell: 3745, 39/53
cell: 3745, 40/53
cell: 3745, 41/53
cell: 3745, 42/53
cell: 3745, 43/53
cell: 3745, 44/53
cell: 3745, 45/53
cell: 3745, 46/53
cell: 3745, 47/53
cell: 3745, 48/53
cell: 3745, 49/53
cell: 3745, 50/53
cell: 3745, 51/53
cell: 3745, 52/53
3746/15827
cell: 3746, 0/53
cell: 3746, 1/53
cell: 3746, 2/53
cell: 3746, 3/53
cell: 3746, 4/53
cell: 3746, 5/53
cell: 3746, 6/53
cell: 3746, 7/53
cell: 3746, 8/53
cell: 3746, 9/53
cell: 3746, 10/53
cell: 3746, 11/53
cell: 3746, 12/53
cell: 3746, 13/53
cell: 3746, 14/53
cell: 3746, 15/53
cell: 3746, 16/53
cell: 3746, 17/53
cell: 3746, 18/53
cell: 3746, 19/53
cell: 3746, 20/53
cell: 3746, 21/53
cell: 3746, 22/53
cell: 3746, 23/53
cell: 3746, 24/53
cell: 3746, 25/53
cell: 3746, 26/53
cell: 3746, 27/53
cell: 3746, 28/53
cell: 3746, 29/53
cell: 3746, 30/53
cell: 3746, 31/53
cell: 3746, 32/53
cell: 3746, 33/53
cell: 3746, 34/53
cell: 3746, 35/53
cell: 3746, 36/53
cell: 3746, 37/53
cell: 3746, 38/53
cell: 3746, 39/53
cell: 374

cell: 3755, 16/53
cell: 3755, 17/53
cell: 3755, 18/53
cell: 3755, 19/53
cell: 3755, 20/53
cell: 3755, 21/53
cell: 3755, 22/53
cell: 3755, 23/53
cell: 3755, 24/53
cell: 3755, 25/53
cell: 3755, 26/53
cell: 3755, 27/53
cell: 3755, 28/53
cell: 3755, 29/53
cell: 3755, 30/53
cell: 3755, 31/53
cell: 3755, 32/53
cell: 3755, 33/53
cell: 3755, 34/53
cell: 3755, 35/53
cell: 3755, 36/53
cell: 3755, 37/53
cell: 3755, 38/53
cell: 3755, 39/53
cell: 3755, 40/53
cell: 3755, 41/53
cell: 3755, 42/53
cell: 3755, 43/53
cell: 3755, 44/53
cell: 3755, 45/53
cell: 3755, 46/53
cell: 3755, 47/53
cell: 3755, 48/53
cell: 3755, 49/53
cell: 3755, 50/53
cell: 3755, 51/53
cell: 3755, 52/53
3756/15827
cell: 3756, 0/53
cell: 3756, 1/53
cell: 3756, 2/53
cell: 3756, 3/53
cell: 3756, 4/53
cell: 3756, 5/53
cell: 3756, 6/53
cell: 3756, 7/53
cell: 3756, 8/53
cell: 3756, 9/53
cell: 3756, 10/53
cell: 3756, 11/53
cell: 3756, 12/53
cell: 3756, 13/53
cell: 3756, 14/53
cell: 3756, 15/53
cell: 3756, 16/53
cell: 3756, 17/53
cell: 375

cell: 3765, 46/53
cell: 3765, 47/53
cell: 3765, 48/53
cell: 3765, 49/53
cell: 3765, 50/53
cell: 3765, 51/53
cell: 3765, 52/53
3766/15827
3767/15827
cell: 3767, 0/53
cell: 3767, 1/53
cell: 3767, 2/53
cell: 3767, 3/53
cell: 3767, 4/53
cell: 3767, 5/53
cell: 3767, 6/53
cell: 3767, 7/53
cell: 3767, 8/53
cell: 3767, 9/53
cell: 3767, 10/53
cell: 3767, 11/53
cell: 3767, 12/53
cell: 3767, 13/53
cell: 3767, 14/53
cell: 3767, 15/53
cell: 3767, 16/53
cell: 3767, 17/53
cell: 3767, 18/53
cell: 3767, 19/53
cell: 3767, 20/53
cell: 3767, 21/53
cell: 3767, 22/53
cell: 3767, 23/53
cell: 3767, 24/53
cell: 3767, 25/53
cell: 3767, 26/53
cell: 3767, 27/53
cell: 3767, 28/53
cell: 3767, 29/53
cell: 3767, 30/53
cell: 3767, 31/53
cell: 3767, 32/53
cell: 3767, 33/53
cell: 3767, 34/53
cell: 3767, 35/53
cell: 3767, 36/53
cell: 3767, 37/53
cell: 3767, 38/53
cell: 3767, 39/53
cell: 3767, 40/53
cell: 3767, 41/53
cell: 3767, 42/53
cell: 3767, 43/53
cell: 3767, 44/53
cell: 3767, 45/53
cell: 3767, 46/53
cell: 3767, 47/5

In [None]:
# Scratch check: element type of dat_dp_inds (built with .astype('int')).
type(dat_dp_inds[0])


In [None]:
# Gather the per-cell error CSVs written by the main loop into one table,
# cached at save_all_path.
save_all_path = save_parent + os.sep + 'all_dat.csv'

if not os.path.exists(save_all_path):
    csv_list = list()

    # For train or test
    for train_or_test in train_or_test_split:
        ndat = dp.get_n_dat(train_or_test)
        # For each cell in the data split
        for i in range(0, ndat):
            print(str(i) + '/' + str(ndat))

            # Derive the image name exactly as the main loop does (basename
            # without its last extension) so the CSV paths line up.
            img_name = dp.get_image_paths([i], train_or_test)[0]
            img_name = os.path.basename(img_name)
            img_name = img_name[0:img_name.rfind('.')]

            # Read-only pass: do not create directories here
            save_dir = save_parent + os.sep + train_or_test + os.sep + img_name
            err_save_path = save_dir + os.sep + img_name + '.csv'
            if os.path.exists(err_save_path):
                # index_col=0 drops the unnamed index column written by to_csv
                csv_list.append(pd.read_csv(err_save_path, index_col=0))
            else:
                print('Missing ' + err_save_path)

    if not csv_list:
        raise IOError('No per-cell error CSVs found under ' + save_parent)

    # Stack the one-row frames into one table (pd.DataFrame(list_of_frames)
    # does not do this).
    errors_all = pd.concat(csv_list, ignore_index=True)

    errors_all.to_csv(save_all_path, index=False)
else:
    errors_all = pd.read_csv(save_all_path)


ulabels = np.unique(errors_all['label'])

In [None]:
from matplotlib import pyplot as plt
%matplotlib inline

plt.figure(num=None, figsize=(10, 5), dpi=80, facecolor='w', edgecolor='k')

errors = errors_all.filter(regex='recon_err')
errors_mean = errors.median(axis=1)

errors_mean = np.divide(errors_mean, errors_all['tot_inten'])

min_bin = np.min(errors_mean)
max_bin = np.max(errors_mean)

bins = np.linspace(min_bin, max_bin, 250)

c = 0

for train_or_test in train_or_test_split:
    c+=1
    plt.subplot(len(train_or_test_split), 1, c)
    
    train_inds = errors_all['train_or_test'] == train_or_test
    
    for label in ulabels:
        label_inds = errors_all['label'] == label
        
        inds = np.logical_and(train_inds, label_inds)
        
        legend_key = label
        plt.hist(errors_mean[inds], bins, alpha=0.5, label=legend_key, normed=True)
        
    
plt.legend(loc='upper right')
plt.show()

In [None]:
from data_providers.DataProvider3D import load_h5 
from model_utils import tensor2img
from IPython.core.display import display
import PIL.Image

def get_images(dp, paths):
    """Load the h5 volumes at `paths` into one zero-initialised batch tensor.

    Only the channels listed in dp.opts['channelInds'] are kept. The result has
    shape (len(paths), len(channelInds), *dp.imsize[1:]) and is converted to
    torch.HalfTensor when dp.opts['dtype'] == 'half'.
    """
    # dp.imsize with its leading (channel) entry replaced by the number of
    # selected channels, prefixed by the batch size
    shape = list(dp.imsize)
    shape[0] = len(dp.opts['channelInds'])
    shape = [len(paths)] + shape

    images = torch.zeros(tuple(shape))

    if dp.opts['dtype'] == 'half':
        images = images.type(torch.HalfTensor)

    for slot, h5_path in enumerate(paths):
        volume = torch.from_numpy(load_h5(h5_path))
        channel_index = torch.LongTensor(dp.opts['channelInds'])
        images[slot] = volume.index_select(0, channel_index).clone()

    return images



# For each structure label show, per split, the 10 lowest-error cells (left)
# and the 10 highest-error cells (right), separated by a white border.
for label in ulabels:
    print(label)
    label_inds = errors_all['label'] == label

    imgs_flat = list()
    for train_or_test in train_or_test_split:
        train_inds = errors_all['train_or_test'] == train_or_test
        inds = np.where(np.logical_and(train_inds, label_inds))

        # No cells of this label in this split: nothing to display
        if len(inds[0]) == 0:
            continue

        # inds[0] holds row positions, so sort a plain positional array and
        # select rows with iloc instead of relying on the index labels.
        inds_sorted = np.argsort(np.asarray(errors_mean)[inds[0]])

        errors_sub = errors_all.iloc[inds[0][inds_sorted]]

        im_paths = [dp.image_paths[i] for i in errors_sub.iloc[0:10]['img_index']]
        img_out = get_images(dp, im_paths)
        img_flat_low_err = tensor2img(img_out)

        im_paths = [dp.image_paths[i] for i in errors_sub.iloc[-10:]['img_index']]
        img_out = get_images(dp, im_paths)
        img_flat_hi_err = tensor2img(img_out)

        imsize = img_flat_low_err.shape
        border = np.ones([imsize[0], 10, 3])

        img_flat = np.concatenate([img_flat_low_err, border, img_flat_hi_err], axis=1)
        imgs_flat.append(img_flat)

    if imgs_flat:
        display(PIL.Image.fromarray(np.uint8(np.concatenate(imgs_flat) * 255)))


In [None]:
# Number of per-split montages built for the last label in the montage loop.
len(imgs_flat)

In [None]:
# Scratch check: current value of the loop variable `train_or_test`.
train_or_test

In [None]:
# Scratch check: sorted normalised errors for the `inds` / `inds_sorted` left
# over from the montage loop.
errors_mean[inds[0]].iloc[inds_sorted]

In [None]:
# Scratch check: channel 1 of tensor2img() for the last image batch loaded by
# the montage loop.
tensor2img(img_out)[:,:,1]