In [None]:
from PIL import Image
import nd2reader
import os
import cv2
import PIL
import numpy as np
from pims import ND2_Reader
import xml.etree.cElementTree as ET
import re
import pathos.multiprocessing
import multiprocessing
from datetime import datetime
from itertools import count
import dask
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from bokeh.models import HoverTool
import holoviews as hv
hv.notebook_extension()
hv.extension('bokeh')

In [None]:
class ND2_extractor():
    """Export selected positions/channels/frames of an ND2 file as 16-bit TIFFs.

    Workflow: ``lane_info`` groups positions into lanes from the stage
    ``y_um`` coordinate, ``select_cond`` resolves which channels, lanes,
    positions and frames to export, and ``tiff_extractor`` writes one
    position. ``run_extraction`` drives a local pathos process pool;
    ``init_dask`` / ``run_dask`` prepare and drive a SLURM-backed dask cluster.

    Files are written to
    ``<file_directory>/<nd2 stem>/Lane_XX/pos_YYY/<nd2 stem>_pos_YYY_Time_TTTT_c_<channel>.tiff``
    where ``XX`` is the 1-based lane and ``YYY`` the 1-based position within
    that lane.
    """

    def __init__(self, nd2_file, file_directory, xml_file=None, xml_dir=None, output_path=None, frame_start=None, frame_end=None,
                 lanes_to_extract=None, channels_to_extract=None):
        """
        Args:
            nd2_file: ND2 file name, resolved relative to ``file_directory``.
            file_directory: directory holding the ND2 file; the output folder
                ``<file_directory>/<nd2 stem>`` is created inside it.
            xml_file, xml_dir: optional XML source for position names, used by
                ``run_extraction`` (``xml_dir`` defaults to ``file_directory``).
            output_path: stored only; not used by this class.
            frame_start, frame_end: inclusive frame range (default: all frames).
            lanes_to_extract: 1-based lane numbers (default: all lanes).
            channels_to_extract: channel names (default: all channels).
        """
        self.input_path = file_directory
        self.nd2_file = nd2_file
        # splitext instead of a fixed [:-4] slice, which assumed a 4-char extension.
        self.nd2_file_name = os.path.splitext(nd2_file)[0]
        # Resolve once, so later os.chdir calls cannot change which file is opened.
        self.nd2_path = os.path.abspath(os.path.join(file_directory, nd2_file))
        self.xml_file = xml_file
        self.xml_dir = xml_dir
        self.output_path = output_path
        self.main_dir = file_directory + "/" + self.nd2_file_name
        self.nd2_f = nd2_file
        self.file_dir = file_directory

        # Filled in by lane_info / select_cond.
        self.pos_dict = None        # lane -> list of position indices
        self.pos_offset = None      # lane -> offset giving 1-based position-in-lane
        self.lane_dict = None       # position index -> 1-based lane
        self.single_pos = False
        self.channels = None
        self.frames = None
        self.lanes = None
        self.poses_to_extract = None

        # User selections (None means "everything").
        self.frame_start = frame_start
        self.frame_end = frame_end
        self.lanes_to_extract = lanes_to_extract
        self.channels_to_extract = channels_to_extract

        self.nd2 = nd2reader.Nd2(self.nd2_path)        # channel names / extraction
        self.nd2_new = ND2_Reader(self.nd2_path)       # per-position metadata (y_um), frame count

        self.pos_dict_xml = None

    def channel_info(self):
        """Cache the channel names reported by the nd2reader handle in ``self.channels``."""
        self.channels = self.nd2.channels

    def lane_info(self, lane_gap_um=200):
        """Group positions into lanes from their stage ``y_um`` coordinate.

        Sets ``lane_dict`` (position index -> 1-based lane), ``pos_offset``
        (lane -> value subtracted from a position index to get its 1-based
        number within that lane), ``pos_dict`` (lane -> position indices in
        order) and ``lanes`` (number of lanes). A file without an ``'m'``
        axis is treated as a single position in a single lane.

        Args:
            lane_gap_um: a jump in ``y_um`` larger than this between two
                consecutive positions starts a new lane (original constant: 200).
        """
        lane_dict = {0: 1}
        cur_lane = 1
        # Lane 1 starts at index 0, so an offset of -1 numbers it from 1.
        pos_offset = {cur_lane: -1}

        if 'm' in self.nd2_new.axes:
            self.nd2_new.iter_axes = 'm'
            n_positions = len(self.nd2_new)
            y_prev = self.nd2_new[0].metadata['y_um']
            for i in range(1, n_positions):
                y_now = self.nd2_new[i].metadata['y_um']
                if abs(y_now - y_prev) > lane_gap_um:  # a new lane
                    cur_lane += 1
                    # First index of the new lane (i) maps to position 1.
                    pos_offset[cur_lane] = i - 1
                lane_dict[i] = cur_lane
                y_prev = y_now
            self.nd2_new.close()
        else:
            self.single_pos = True

        # Previously left as None for single-position files, which broke select_cond.
        self.lanes = cur_lane
        self.lane_dict = lane_dict
        self.pos_offset = pos_offset

        # Single pass grouping (the former nested comprehension was O(n^2)).
        pos_dict = {}
        for pos, lane in lane_dict.items():
            pos_dict.setdefault(lane, []).append(pos)
        self.pos_dict = pos_dict

    def select_cond(self):
        """Resolve the channels, positions and frame range to extract.

        Defaults: all channels, all lanes, frames ``0 .. n_frames - 1``.
        Runs ``lane_info`` first if it has not been run yet. Sets
        ``poses_to_extract`` (position indices of the selected lanes, in
        lane order) and ``frames``.
        """
        if self.pos_dict is None:
            self.lane_info()

        self.channel_info()
        if self.channels_to_extract is None:
            self.channels_to_extract = [str(x) for x in self.channels]
        if self.lanes_to_extract is None:
            self.lanes_to_extract = list(range(1, self.lanes + 1))

        self.poses_to_extract = [pos for lane in self.lanes_to_extract
                                 for pos in self.pos_dict[lane]]

        self.nd2_new.iter_axes = 't'
        self.frames = len(self.nd2_new)
        if self.frame_start is None:
            self.frame_start = 0
        if self.frame_end is None:
            self.frame_end = self.frames - 1

    def tiff_extractor(self, pos):
        """Write every selected channel/frame of position ``pos`` as a 16-bit TIFF.

        Requires ``lane_info`` and ``select_cond`` to have run (reads
        ``lane_dict``, ``pos_offset``, ``channels_to_extract`` and the
        inclusive ``frame_start``/``frame_end`` range). Opens its own
        nd2reader handle instead of reusing ``self.nd2``. Paths are built with
        ``os.path.join``, so the process working directory is not changed.
        """
        lane_ind = self.lane_dict[pos]
        pos_label = str(pos - self.pos_offset[lane_ind]).zfill(3)
        new_dir = os.path.join(self.main_dir, "Lane_" + str(lane_ind).zfill(2), "pos_" + pos_label)
        os.makedirs(new_dir, exist_ok=True)
        meta_name = os.path.basename(self.nd2_file_name) + "_pos_" + pos_label + "_Time_"

        nd2 = nd2reader.Nd2(self.nd2_path)
        for channel in self.channels_to_extract:
            for image in nd2.select(fields_of_view=pos, channels=channel):
                time_point = image.frame_number
                if not self.frame_start <= time_point <= self.frame_end:
                    continue
                tiff_name = meta_name + str(time_point).zfill(4) + "_c_" + str(channel) + ".tiff"
                # Save as 16-bit ("I;16"); see
                # http://shortrecipes.blogspot.com/2009/01/python-python-imaging-library-16-bit.html
                pixels = image.base.astype(np.uint16)
                out = PIL.Image.frombytes("I;16", (pixels.shape[1], pixels.shape[0]), pixels.tobytes())
                out.save(os.path.join(new_dir, tiff_name))

    def run_extraction(self, cores=None):
        """Extract all selected positions with a local pathos process pool.

        Args:
            cores: pool size (default: ``pathos.multiprocessing.cpu_count()``).
        """
        start_t = datetime.now()
        os.chdir(self.input_path)

        # Position metadata must exist before select_cond (it was called first before).
        if self.xml_file:
            if not self.xml_dir:
                self.xml_dir = self.input_path
            # Previously only called when xml_dir was unset (indentation bug).
            # NOTE(review): pos_info is not defined in this class — confirm where it comes from.
            self.pos_info()
        else:
            self.lane_info()
        self.select_cond()

        os.makedirs(self.main_dir, exist_ok=True)

        poses = self.poses_to_extract
        if cores is None:
            cores = pathos.multiprocessing.cpu_count()
        print(poses, cores)
        # NOTE(review): the pool pickles this bound method together with self,
        # including the open readers in self.nd2 / self.nd2_new — confirm they serialize.
        pool = pathos.multiprocessing.Pool(cores)
        try:
            pool.map(self.tiff_extractor, poses)
        finally:
            pool.close()
            pool.join()

        time_elapsed = datetime.now() - start_t
        print('Time elapsed for extraction (hh:mm:ss.ms) {}'.format(time_elapsed))

    def init_dask(self):
        """Compute lanes/selection and create the output folder.

        Returns:
            ``(poses_to_extract, lane_dict, pos_offset, main_dir,
            channels_to_extract, frame_start, frame_end)`` — the plain values
            a standalone extraction function needs.
        """
        self.lane_info()
        self.select_cond()   # also calls channel_info
        os.makedirs(self.main_dir, exist_ok=True)
        return self.poses_to_extract, self.lane_dict, self.pos_offset, self.main_dir, self.channels_to_extract, self.frame_start, self.frame_end

    def run_dask(self, n_workers=6):
        """Submit one ``tiff_extractor`` task per selected position to a SLURM dask cluster.

        The cluster, client and futures are kept on ``self`` (and returned),
        because dask releases futures whose last reference is dropped.

        Args:
            n_workers: number of workers to start (original constant: 6).

        Returns:
            ``(client, futures)``.
        """
        self.init_dask()
        cluster = SLURMCluster(queue="short", walltime='0:10:00', job_cpu=1, job_mem='1G', cores=1, processes=1, memory='1GB')
        cluster.start_workers(n_workers)
        client = Client(cluster)
        # NOTE(review): mapping a bound method ships self (with its open readers) to workers.
        futures = client.map(self.tiff_extractor, self.poses_to_extract)
        self.cluster = cluster
        self.client = client
        self.futures = futures
        return client, futures

In [None]:
# Input location. Override with the ND2_DIR / ND2_FILE environment variables
# to point the notebook at another acquisition without editing this cell.
file_directory = os.environ.get("ND2_DIR", r"/n/scratch2/sw260/20181127")
nd2_file = os.environ.get("ND2_FILE", "SB1_SJC110_EZRDM_L35W1.5.nd2")
# Later cells open nd2_file by its bare name, relative to this directory.
os.chdir(file_directory)

In [None]:
# Open the file once to check it is readable and show the reader's repr.
nd2_preview = nd2reader.Nd2(nd2_file)
nd2_preview

In [None]:
# Restrict this run to lane 9; lanes_to_extract=None would select every lane.
lanes_to_extract = [9]
new_extractor = ND2_extractor(
    nd2_file,
    file_directory,
    lanes_to_extract=lanes_to_extract,
)

In [None]:
# Compute lanes/selection once and pull out plain values that can be shipped to workers.
(poses_to_extract, lane_dict, pos_offset, main_dir,
 channels_to_extract, frame_start, frame_end) = new_extractor.init_dask()


In [None]:
# Show the resolved selection before submitting any work.
(poses_to_extract, lane_dict, pos_offset, main_dir,
 channels_to_extract, frame_start, frame_end)

In [None]:
n_workers_init = 10
# Each worker: 1 core, 1 process, 500MB (job_cpu / job_mem are the SLURM job request).
cluster = SLURMCluster(
    queue="short",
    walltime='0:30:00',
    job_cpu=1,
    job_mem='500MB',
    cores=1,
    processes=1,
    memory='500MB',
)
cluster.start_workers(n_workers_init)
client = Client(cluster)

In [None]:
# Display the dask client connected to the SLURM cluster created in the previous cell.
client

In [None]:

def extract_tiff(pos,lane_dict,pos_offset,main_dir,file_directory,nd2_file,channels_to_extract,frame_start,frame_end):
    nd2 = os.path.join(file_directory,nd2_file)
    nd2 = nd2reader.Nd2(nd2)
    nd2_file_name = nd2_file[:-4]
    lane_ind = lane_dict[pos]
    pos_off = pos_offset[lane_ind]
    new_dir = main_dir + "/Lane_" + str(lane_ind).zfill(2) + "/pos_" + str(pos - pos_off).zfill(3) + "/"
    try:
        os.makedirs(new_dir)
    except OSError:
        pass
    os.chdir(new_dir)

    meta_name = nd2_file_name + "_pos_" + str(pos - pos_off).zfill(3) + "_Time_"
    for channel_to_extract in channels_to_extract:
        for image in nd2.select(fields_of_view=pos, channels=channel_to_extract):
            time_point = image.frame_number
            if time_point>=frame_start and time_point<=frame_end:
                tiff_name = meta_name + str(time_point).zfill(4) + "_c_" + channel_to_extract + ".tiff"

                # save file in 16-bit
                # thanks to http://shortrecipes.blogspot.com/2009/01/python-python-imaging-library-16-bit.html
                image = image.base.astype(np.uint16)
                out = PIL.Image.frombytes("I;16", (image.shape[1], image.shape[0]), image.tobytes())
                out.save(tiff_name)
            
    


In [None]:
# pos = 1
# nd2 = os.path.join(file_directory,nd2_file)
# nd2 = nd2reader.Nd2(nd2)
# nd2_file_name = nd2_file[:-4]
# lane_ind = lane_dict[pos]
# pos_off = pos_offset[lane_ind]
# new_dir = main_dir + "/Lane_" + str(lane_ind).zfill(2) + "/pos_" + str(pos - pos_off).zfill(3) + "/"
# try:
#     os.makedirs(new_dir)
# except OSError:
#     pass
# os.chdir(new_dir)

In [None]:
def run_extraction(pos):
    """Run ``extract_tiff`` for one position using the notebook-level selection."""
    return extract_tiff(
        pos, lane_dict, pos_offset, main_dir, file_directory,
        nd2_file, channels_to_extract, frame_start, frame_end,
    )


# Keep every future referenced so dask does not release the submitted tasks.
fut = client.map(run_extraction, poses_to_extract)
futures = [fut]
all_futures = [future for batch in futures for future in batch]

dask.distributed.progress(all_futures)



In [None]:
# futures = []

    
# fut = client.map(run_extraction, poses_to_extract)
    

# futures.append(fut)


# dask.distributed.progress(futures)