In [1]:
#######    
### This notebook saves, for every cell in our dataset, the original image,
### its reconstruction, and the most likely predicted structure channel
### for each structure class
#######

#######    
### Load the Model Parts
#######

import argparse

import SimpleLogger as SimpleLogger

import importlib
import numpy as np

import os
import pickle

import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision.utils

#have to do this import to be able to use pyplot in the docker image
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from IPython import display
import time
from model_utils import set_gpu_recursive, load_model, save_state, save_progress, get_latent_embeddings, maybe_save

import torch.backends.cudnn as cudnn
cudnn.benchmark = True

# Model directories produced by training. Only `model_dir` (the structure
# model, first entry) is loaded by this cell.
model_dirs = [
    './test_aaegan/aaegan3Dv4_32D/struct_model',
    './test_aaegan/aaegan3Dv4_32D/ref_model',
]

model_dir = model_dirs[0]

# logger_file = '{0}/logger_tmp.pkl'.format(model_dir)
# Training options pickled next to the model. Use a context manager so the
# file handle is closed. NOTE: unpickling can execute arbitrary code -- only
# load opt.pkl from model directories you trust.
with open('{0}/opt.pkl'.format(model_dir), "rb") as opt_file:
    opt = pickle.load(opt_file)

print(opt)

# Resolve the project modules named in the saved options.
DP = importlib.import_module("data_providers." + opt.dataProvider)
model_provider = importlib.import_module("models." + opt.model_name)
train_module = importlib.import_module("train_modules." + opt.train_module)

# Seed torch (CPU and CUDA) and numpy from the saved options.
torch.manual_seed(opt.myseed)
torch.cuda.manual_seed(opt.myseed)
np.random.seed(opt.myseed)

if not os.path.exists(opt.save_dir):
    os.makedirs(opt.save_dir)

if opt.nepochs_pt2 == -1:
    opt.nepochs_pt2 = opt.nepochs

# Options for constructing the DataProvider.
opts = {}
opts['verbose'] = True
opts['pattern'] = '*.tif_flat.png'
opts['out_size'] = [opt.imsize, opt.imsize]

# Cache the constructed DataProvider on disk, keyed by output image size.
# torch.load unpickles, so the same trust caveat as opt.pkl applies.
data_path = './data_{0}x{1}.pyt'.format(str(opts['out_size'][0]), str(opts['out_size'][1]))
if os.path.exists(data_path):
    dp = torch.load(data_path)
else:
    dp = DP.DataProvider(opt.imdir, opts)
    torch.save(dp, data_path)

# -1 means "use the whole training split".
if opt.ndat == -1:
    opt.ndat = dp.get_n_dat('train')

iters_per_epoch = np.ceil(opt.ndat/opt.batch_size)
            
#######    
### Load REFERENCE MODEL embeddings
#######

# Per-split latent embeddings of every cell under the reference model, cached on disk.
embeddings_path = opt.save_parent + os.sep + 'ref_model' + os.sep + 'embeddings.pkl'
if os.path.exists(embeddings_path):
    embeddings = torch.load(embeddings_path)
else:
    # The reference model's encoder is never loaded in this notebook. The old
    # fallback called get_latent_embeddings(models['enc'], ...) before `models`
    # was assigned. On a fresh kernel that is a NameError. On a re-run it
    # silently embedded with the *structure* model's encoder (left over from
    # load_model below) and cached those wrong embeddings to embeddings_path.
    # Fail loudly instead of corrupting the cache.
    raise FileNotFoundError(
        'Reference embeddings not found at {0}; compute them with the reference '
        'model (the ref_model directory) before running this notebook.'.format(embeddings_path))

models = None
optimizers = None
    
def get_ref(self, inds, train_or_test='train'):
    """Look up cached reference embeddings for a list of indices.

    Meant to be bound onto the DataProvider instance (via types.MethodType
    elsewhere in this notebook), so `self.embeddings` is the dict of per-split
    embeddings attached to it.

    Args:
        inds: sequence of integer indices into the chosen split.
        train_or_test: key of the split in `self.embeddings` ('train' by default).

    Returns:
        The rows of `self.embeddings[train_or_test]` selected by `inds`.
    """
    index_tensor = torch.LongTensor(inds)
    split_embeddings = self.embeddings[train_or_test]
    return split_embeddings[index_tensor]

dp.embeddings = embeddings

# Bind get_ref as a method on this DataProvider instance so that
# dp.get_ref(inds, split) indexes into dp.embeddings[split].
import types  
dp.get_ref = types.MethodType(get_ref, dp)

# Use all three input channels.
opt.channelInds = [0, 1, 2]
dp.opts['channelInds'] = opt.channelInds
opt.nch = len(opt.channelInds)

opt.nClasses = dp.get_n_classes()
opt.nRef = opt.nlatentdim

# Best-effort construction of the trainer. A failure here should not stop the
# export, but it must not be silent either. A bare `except: pass` also
# swallowed KeyboardInterrupt. If trainer() failed, it left `train_module`
# bound to the module object instead of None.
try:
    train_module = importlib.import_module("train_modules." + opt.train_module)
    train_module = train_module.trainer(dp, opt)
except Exception as err:
    print('WARNING: could not construct trainer from train_modules.{0}: {1!r}'.format(
        opt.train_module, err))
    train_module = None

models, optimizers, criterions, logger, opt = load_model(model_provider, opt)

# Inference only: switch both networks out of training mode.
enc = models['enc']
dec = models['dec']
enc.train(False)
dec.train(False)


FileNotFoundError: [Errno 2] No such file or directory: './test_aaegan/aaegan3Dv4_32D/struct_model/opt.pkl'

In [42]:
#######    
### Main Loop
#######

import pdb
from aicsimage.io import omeTifWriter
from imgToProjection import imgtoprojection
from IPython.core.display import display
import PIL.Image
import matplotlib.pyplot as plt
import scipy.misc

import pandas as pd

gpu_id = 0

# Inference only: both networks come out of training mode.
enc, dec = models['enc'], models['dec']
enc.train(False)
dec.train(False)

# Color table handed to imgtoprojection for the original/reconstructed montages.
colormap = 'hsv'
colors = plt.get_cmap(colormap)(np.linspace(0, 1, 4))

# Physical pixel size written into every OME-TIFF.
px_size = [0.3873] * 3

train_or_test_split = ['train', 'test']

# One row per exported cell, appended by the main loop.
img_paths_all = []

save_parent = opt.save_dir + os.sep + 'images_out'
save_out_table = save_parent + os.sep + 'list_of_images.csv'

# Column layout mirrors the order in which the main loop appends to each row.
pred_columns = ['pred_' + label for label in dp.label_names]
column_names = ['orig', 'recon'] + pred_columns + ['train_or_test', 'orig_struct', 'img_index']

if not os.path.exists(save_parent):
    os.makedirs(save_parent)

def convert_image(img):
    """Copy the first batch item of a torch Variable/Tensor into a numpy array.

    The item's last axis is moved to the front: an item with axes
    (a, b, c, d) is returned with axes (d, a, b, c).

    Args:
        img: batched torch Variable/Tensor; only `img.data[0]` is used.

    Returns:
        numpy.ndarray on the CPU.
    """
    first_item = img.data[0]
    as_numpy = first_item.cpu().numpy()
    return as_numpy.transpose(3, 0, 1, 2)

def _flat_projection(volume, **projection_kwargs):
    """Flatten one `convert_image`-style volume into a 2D image for the PNG montage.

    `volume` is 4D with channels on axis 1, as produced by `convert_image`.
    The channel axis is moved to the front before calling `imgtoprojection`.
    The leading axis of the projection is then moved last so projections can
    be concatenated side by side. Extra keyword arguments (e.g. `colors`) are
    forwarded to `imgtoprojection`.
    """
    channels_first = np.transpose(volume, (1, 0, 2, 3))
    projection = imgtoprojection(channels_first, proj_all=True, global_adjust=True, **projection_kwargs)
    return np.transpose(projection, (1, 2, 0))


# For train or test
for train_or_test in train_or_test_split:
    ndat = dp.get_n_dat(train_or_test)
    # For each cell in the data split
    for i in range(0, ndat):
        # Progress counter. Use a literal '/' rather than os.sep: this is a
        # "done/total" display, not a path.
        print('{0}/{1}'.format(i, ndat))

        img_index = dp.data[train_or_test]['inds'][i]
        img_class = dp.image_classes[img_index]
        img_name = os.path.basename(dp.image_paths[img_index])[0:-3]

        save_dir = save_parent + os.sep + train_or_test + os.sep + img_name
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        #Load the image
        img_in = dp.get_images([i], train_or_test)
        img_in = Variable(img_in.cuda(gpu_id))

        #pass forward through the model
        z = enc(img_in)
        img_recon = dec(z)

        pred_imgs = list()
        img_paths = list()

        #print original images
        img_orig = convert_image(img_in)
        channel_names = ['memb', img_class, 'dna']
        img_name = save_dir + os.sep + 'img' + str(img_index) + '.ome.tif'
        with omeTifWriter.OmeTifWriter(img_name, overwrite_file=True) as w:
            w.save(img_orig, channel_names=channel_names, pixels_physical_size=px_size)

        pred_imgs.append(img_orig)
        img_paths.append(img_name)

        #print reconstructed images
        img_recon = convert_image(img_recon)
        channel_names_recon = ['memb_recon', img_class + '_recon', 'dna_recon']
        img_name = save_dir + os.sep + 'img' + str(img_index) + '_' + img_class + '-recon.ome.tif'
        with omeTifWriter.OmeTifWriter(img_name, overwrite_file=True) as w:
            w.save(img_recon, channel_names=channel_names_recon, pixels_physical_size=px_size)

        pred_imgs.append(img_recon)
        img_paths.append(img_name)
        channel_names += channel_names_recon

        #for each structure type
        for j in range(0, dp.get_n_classes()):
            pred_class_name = dp.label_names[j]

            img_name = save_dir + os.sep + 'img' + str(img_index) + '_' + img_class + '-pred_' + pred_class_name + '.ome.tif'

            #Set the class label in log(one-hot) form
            z[0].data[0] = torch.zeros(z[0].size()).cuda(gpu_id)
            z[0].data[0][j] = 1
            z[0].data[0] = (z[0].data[0] - 1) * 25

            #Reference variable is set as z[1]

            #Set the structure variation variable to most probable
            z[-1] = torch.zeros(z[-1].size()).cuda(gpu_id)

            #generate image with these settings
            img_recon = dec(z)

            #convert the image and get only the GFP channel
            img_recon = convert_image(img_recon)
            img_recon = np.expand_dims(img_recon[:,1,:,:],1)

            #save the gfp channel
            with omeTifWriter.OmeTifWriter(img_name, overwrite_file=True) as w:
                w.save(img_recon, channel_names=[pred_class_name + '_pred'], pixels_physical_size=px_size)

            channel_names.append(pred_class_name + ' pred')

            pred_imgs.append(img_recon)
            img_paths.append(img_name)

        # Row layout must match column_names.
        img_paths += [train_or_test, img_class, img_index]
        img_paths_all.append(img_paths)

        pred_imgs_all = np.concatenate(pred_imgs,1)

        # save the all-channels image (orig, recon, and predicted structures)
        img_name = save_dir + os.sep + 'img' + str(img_index) + '_' + img_class + '-pred_all.ome.tif'
        with omeTifWriter.OmeTifWriter(img_name, overwrite_file=True) as w:
            w.save(pred_imgs_all, channel_names=channel_names, pixels_physical_size=px_size)

        # save flat images.
        # pred_imgs[0] is the original, already converted by convert_image above.
        # The old code re-converted img_in and then indexed [0], which dropped an
        # axis, so the 4-axis transpose raised.
        # Original and reconstruction use the shared `colors`. The single-channel
        # structure predictions are projected without an explicit `colors`.
        images_proj = [_flat_projection(pred_imgs[0], colors=colors),
                       _flat_projection(pred_imgs[1], colors=colors)]
        images_proj += [_flat_projection(pred_img) for pred_img in pred_imgs[2:]]
        images_proj = np.concatenate(images_proj, 1)

        # NOTE(review): scipy.misc.imsave is removed in newer SciPy releases.
        # If migrating, check how the replacement scales float data.
        scipy.misc.imsave(save_dir + os.sep + 'img' + str(img_index) + '_' + img_class + '-pred_all.png', images_proj)

#save the list of all images
img_paths_all_df = pd.DataFrame(img_paths_all, columns=column_names)
img_paths_all_df.to_csv(save_out_table)


0/5770
1/5770
2/5770
3/5770
4/5770
5/5770
6/5770
7/5770
8/5770
9/5770
10/5770
11/5770
12/5770
13/5770
14/5770
15/5770
16/5770
17/5770
18/5770
19/5770
20/5770
21/5770
22/5770
23/5770
24/5770
25/5770
26/5770
27/5770
28/5770
29/5770
30/5770
31/5770
32/5770
33/5770
34/5770
35/5770
36/5770
37/5770
38/5770
39/5770
40/5770
41/5770
42/5770
43/5770
44/5770
45/5770
46/5770
47/5770
48/5770
49/5770
50/5770
51/5770
52/5770
53/5770
54/5770
55/5770
56/5770
57/5770
58/5770
59/5770
60/5770
61/5770
62/5770
63/5770
64/5770
65/5770
66/5770
67/5770
68/5770
69/5770
70/5770
71/5770
72/5770
73/5770
74/5770
75/5770
76/5770
77/5770
78/5770
79/5770
80/5770
81/5770
82/5770
83/5770
84/5770
85/5770
86/5770
87/5770
88/5770
89/5770
90/5770
91/5770
92/5770
93/5770
94/5770
95/5770
96/5770
97/5770
98/5770
99/5770
100/5770
101/5770
102/5770
103/5770
104/5770
105/5770
106/5770
107/5770
108/5770
109/5770
110/5770
111/5770
112/5770
113/5770
114/5770
115/5770
116/5770
117/5770
118/5770
119/5770
120/5770
121/5770
122/5770
123

923/5770
924/5770
925/5770
926/5770
927/5770
928/5770
929/5770
930/5770
931/5770
932/5770
933/5770
934/5770
935/5770
936/5770
937/5770
938/5770
939/5770
940/5770
941/5770
942/5770
943/5770
944/5770
945/5770
946/5770
947/5770
948/5770
949/5770
950/5770
951/5770
952/5770
953/5770
954/5770
955/5770
956/5770
957/5770
958/5770
959/5770
960/5770
961/5770
962/5770
963/5770
964/5770
965/5770
966/5770
967/5770
968/5770
969/5770
970/5770
971/5770
972/5770
973/5770
974/5770
975/5770
976/5770
977/5770
978/5770
979/5770
980/5770
981/5770
982/5770
983/5770
984/5770
985/5770
986/5770
987/5770
988/5770
989/5770
990/5770
991/5770
992/5770
993/5770
994/5770
995/5770
996/5770
997/5770
998/5770
999/5770
1000/5770
1001/5770
1002/5770
1003/5770
1004/5770
1005/5770
1006/5770
1007/5770
1008/5770
1009/5770
1010/5770
1011/5770
1012/5770
1013/5770
1014/5770
1015/5770
1016/5770
1017/5770
1018/5770
1019/5770
1020/5770
1021/5770
1022/5770
1023/5770
1024/5770
1025/5770
1026/5770
1027/5770
1028/5770
1029/5770
1030/57

2784/5770
2785/5770
2786/5770
2787/5770
2788/5770
2789/5770
2790/5770
2791/5770
2792/5770
2793/5770
2794/5770
2795/5770
2796/5770
2797/5770
2798/5770
2799/5770
2800/5770
2801/5770
2802/5770
2803/5770
2804/5770
2805/5770
2806/5770
2807/5770
2808/5770
2809/5770
2810/5770
2811/5770
2812/5770
2813/5770
2814/5770
2815/5770
2816/5770
2817/5770
2818/5770
2819/5770
2820/5770
2821/5770
2822/5770
2823/5770
2824/5770
2825/5770
2826/5770
2827/5770
2828/5770
2829/5770
2830/5770
2831/5770
2832/5770
2833/5770
2834/5770
2835/5770
2836/5770
2837/5770
2838/5770
2839/5770
2840/5770
2841/5770
2842/5770
2843/5770
2844/5770
2845/5770
2846/5770
2847/5770
2848/5770
2849/5770
2850/5770
2851/5770
2852/5770
2853/5770
2854/5770
2855/5770
2856/5770
2857/5770
2858/5770
2859/5770
2860/5770
2861/5770
2862/5770
2863/5770
2864/5770
2865/5770
2866/5770
2867/5770
2868/5770
2869/5770
2870/5770
2871/5770
2872/5770
2873/5770
2874/5770
2875/5770
2876/5770
2877/5770
2878/5770
2879/5770
2880/5770
2881/5770
2882/5770
2883/5770


4770/5770
4771/5770
4772/5770
4773/5770
4774/5770
4775/5770
4776/5770
4777/5770
4778/5770
4779/5770
4780/5770
4781/5770
4782/5770
4783/5770
4784/5770
4785/5770
4786/5770
4787/5770
4788/5770
4789/5770
4790/5770
4791/5770
4792/5770
4793/5770
4794/5770
4795/5770
4796/5770
4797/5770
4798/5770
4799/5770
4800/5770
4801/5770
4802/5770
4803/5770
4804/5770
4805/5770
4806/5770
4807/5770
4808/5770
4809/5770
4810/5770
4811/5770
4812/5770
4813/5770
4814/5770
4815/5770
4816/5770
4817/5770
4818/5770
4819/5770
4820/5770
4821/5770
4822/5770
4823/5770
4824/5770
4825/5770
4826/5770
4827/5770
4828/5770
4829/5770
4830/5770
4831/5770
4832/5770
4833/5770
4834/5770
4835/5770
4836/5770
4837/5770
4838/5770
4839/5770
4840/5770
4841/5770
4842/5770
4843/5770
4844/5770
4845/5770
4846/5770
4847/5770
4848/5770
4849/5770
4850/5770
4851/5770
4852/5770
4853/5770
4854/5770
4855/5770
4856/5770
4857/5770
4858/5770
4859/5770
4860/5770
4861/5770
4862/5770
4863/5770
4864/5770
4865/5770
4866/5770
4867/5770
4868/5770
4869/5770


297/305
298/305
299/305
300/305
301/305
302/305
303/305
304/305
