# Image Library Tools

For use with a PySpark backend (or fakespark, which is not fully supported yet).

Using Jupyter notebooks with Spark requires the appropriate environment variables to be set when launching `pyspark`, for example:

```bash
PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS=notebook PYSPARK_PYTHON=/Users/mader/anaconda/bin/python /Volumes/ExDisk/spark-2.0.0-bin-hadoop2.7/bin/pyspark
```

In [1]:
import numpy as np
from skimage.io import imread
import pandas as pd
from io import BytesIO, StringIO
from pyspark import SQLContext

from dicom import read_file
class PyqaeContext(object):
    """
    The primary context for performing PYQAE functions

    Wraps a SparkContext and offers helpers that read image / DICOM files
    into RDDs and convert image tables into Spark DataFrames. How unreadable
    files are treated is controlled by ``faulty_io`` (see ``__init__``).
    """
    def __init__(self, cur_sc = None, faulty_io = 'FAIL', retry_att = 5, *args, **kwargs):
        """
        Create or initialize a new Pyqae Context
        
        Parameters
        ----------
        cur_sc : SparkContext
            An existing initialized SparkContext, if none a new one is initialized with the other parameters.
        faulty_io : String
            A string indicating what should happen if a file is missing (FAIL, RETRY, or return an EMPTY value)
        retry_att : Int
            The number of times a retry should be attempted (if faulty_io is in mode RETRY otherwise ignored)
        *args, **kwargs
            Forwarded to the SparkContext constructor when cur_sc is None.
        """
        assert faulty_io in ['FAIL', 'RETRY', 'EMPTY'], "Faulty IO must be in the list of FAIL, RETRY, or EMPTY"
        assert retry_att>0, "Retry attempt must be greater than 0"
        self.faulty_io = faulty_io
        self.retry_att = retry_att
        if cur_sc is None: 
            # imported lazily so an existing context can be wrapped without this import
            from pyspark import SparkContext
            self._cur_sc = SparkContext(*args, **kwargs)
        else:
            self._cur_sc = cur_sc
    
    @staticmethod
    def _wrapIOCalls(method, faulty_io, retry_att):
        """
        A general wrapper for IO calls which should be retried or returned empty 
        
        Parameters
        ----------
        method : Function
            The IO function to wrap
        faulty_io : String
            FAIL returns method unchanged, EMPTY returns None if the call raises,
            RETRY calls method up to retry_att times and lets the last error propagate
        retry_att : Int
            Total number of attempts in RETRY mode (must be > 0)
        
        Returns
        -------
        Function with the same call signature as method
        """
        assert faulty_io in ['FAIL', 'RETRY', 'EMPTY']
        assert retry_att > 0, "Retry attempts should be more than 0, {}".format(retry_att)
        if faulty_io == 'FAIL':
            return method
        
        if faulty_io == 'EMPTY':
            def wrap_method(*args, **kwargs):
                try:
                    return method(*args, **kwargs)
                except Exception:
                    # only ordinary errors become an empty value; SystemExit and
                    # KeyboardInterrupt are allowed to propagate
                    return None
            return wrap_method
        
        def wrap_method(*args, **kwargs):
            # retry_att-1 guarded attempts followed by one unguarded attempt
            for _ in range(retry_att - 1):
                try:
                    return method(*args, **kwargs)
                except Exception:
                    # deliberate: swallow the error and try again
                    pass
            # if it still hasn't passed throw the error
            return method(*args, **kwargs)
        return wrap_method
    
    @staticmethod
    def readBinaryBlobAsImageArray(iblob):
        """Decode the raw bytes of an image file with skimage's imread."""
        return imread(BytesIO(iblob))
    
    @staticmethod
    def readBinaryBlobAsDicomArray(iblob):
        """Parse the raw bytes of a DICOM file with dicom.read_file."""
        sio_blob = BytesIO(iblob)
        return read_file(sio_blob)
    
    @staticmethod
    def imageTableToDataFrame(imt_rdd):
        """
        Convert an RDD of (dict, ndarray) pairs to a DataFrame by adding the
        image (as a nested list) to each dict under the key 'image_data'.
        """
        # dict(mapping, key=value) works on Python 2 and 3 (iteritems is Python 2 only)
        return imt_rdd.map(lambda x: dict(x[0], image_data=x[1].tolist())).toDF()
    
    
    def readImageDirectory(self, path, parts = 100):
        """
        Read a directory of images
        
        Parameters
        ----------
        path : String
            A path with wildcards for the images files can be prefixed with (s3, s3a, or a shared directory)
        parts : Int
            Passed to SparkContext.binaryFiles as the partition hint
        
        Returns
        -------
        RDD of (path, image) pairs
        """
        read_fun = PyqaeContext._wrapIOCalls(PyqaeContext.readBinaryBlobAsImageArray, self.faulty_io, self.retry_att)
        return self._cur_sc.binaryFiles(path, parts).mapValues(read_fun)
    
    def readDicomDirectory(self, path, parts = 100):
        """
        Read a directory of dicom files
        
        Parameters
        ----------
        path : String
            A path with wildcards for the images files can be prefixed with (s3, s3a, or a shared directory)
        parts : Int
            Passed to SparkContext.binaryFiles as the partition hint
        
        Returns
        -------
        RDD of (path, parsed dicom object) pairs
        """
        read_fun = PyqaeContext._wrapIOCalls(PyqaeContext.readBinaryBlobAsDicomArray, self.faulty_io, self.retry_att)
        return self._cur_sc.binaryFiles(path, parts).mapValues(read_fun)
    
    def readImageTable(self, path, col_name, im_path_prefix = '', n_partitions = 100, 
                       read_table_func = pd.read_csv, preproc_func = None):
        """
        Read a table from images from a csv file
        
        Parameters
        ----------
        path : String
            A path to the csv file
        n_partitions: Int
            The number of partitions to have
        col_name : String
            The name of the column containing the path to individual images
        im_path_prefix : String
            The prefix to append to the path in the text file so it is opened correctly (default empty)
        read_table_func: Function (String -> Pandas DataFrame)
            The function to read the table from a file-buffer object (default is the read_csv function)
        preproc_func: Function (ndarray -> ndarray)
            A function to preprocess the image (filtering, resizing, padding, etc)
        
        Returns
        -------
        RDD of (row dict, image) pairs, where the row dict holds the table
        columns for that image and the image has been passed through
        preproc_func when one is given.
        """
        # `os` is not imported at the top of this file, so bring it in locally
        import os
        c_file = self._cur_sc.wholeTextFiles(path,1)
        assert c_file.count()==1, "This function only support a single file at the moment"
        full_table_buffer = StringIO("\n".join(c_file.map(lambda x: x[1]).collect()))
        image_table = read_table_func(full_table_buffer)
        image_paths = [os.path.join(im_path_prefix,cpath) for cpath in image_table[col_name]]
        # read the binary files from a list
        rawimg_rdd = self._cur_sc.binaryFiles(",".join(image_paths),n_partitions)
        read_fun = PyqaeContext._wrapIOCalls(PyqaeContext.readBinaryBlobAsImageArray, self.faulty_io, self.retry_att)
        img_rdd = rawimg_rdd.mapValues(read_fun)
        if preproc_func is None:
            pp_img_rdd = img_rdd
        else:
            # in EMPTY mode an unreadable file comes back as None; pass it through untouched
            pp_img_rdd = img_rdd.mapValues(lambda img: None if img is None else preproc_func(img))
        # add the file prefix so the keys come up in the map operation
        image_paths = ['file:{}'.format(cpath) if cpath.find(':')<0 else cpath for cpath in image_paths]
        # look rows up by index so each path is paired with its own row without
        # depending on the iteration order of the intermediate dict
        row_dicts = image_table.T.to_dict()
        image_list = dict(zip(image_paths, [row_dicts[idx] for idx in image_table.index]))
        
        # return the preprocessed RDD (previously the unprocessed one was returned)
        return pp_img_rdd.map(lambda x: (image_list[x[0]],x[1]))
    
    @staticmethod
    def _imageTableToDataFrame(imd_rdd, cur_sql, img_key_name = "image"):
        """
        Converts an image table to a DataFrame by converting the ndarray into a nested list (inefficient)
        but necessary for JVM compatibility. Written as a staticmethod to encapsulate the serialization.
        #TODO implement ndarray <-> JVM exchange
        
        Parameters
        ----------
        imd_rdd: RDD[(dict[String,_], ndarray)]
            The imageTable (created by readImageTable)
        
        cur_sql: SQLContext
            The SQLContext in which to make the DataFrame (important for making tables later)
        
        img_key_name: String
            The name of the DataFrame column holding the image data
        
        """
        # the column set is taken from the first record only
        first_row_dict, _ = imd_rdd.first()
        im_tbl_keys = list(first_row_dict.keys())
        #TODO handle key missing errors more gracefully
        iml_rdd = imd_rdd.map(lambda kv_pair: [kv_pair[0].get(ikey) for ikey in im_tbl_keys]+[kv_pair[1].tolist()])
        return cur_sql.createDataFrame(iml_rdd, im_tbl_keys+[img_key_name])
    
    def readImageDataFrame(self, path, col_name, im_path_prefix = '', n_partitions = 100, 
                           read_table_func = pd.read_csv, preproc_func = None,
                          sqlContext = None):
        """
        Read a table from images from a csv file and return as a dataframe
        See Help from [[readImageTable]]
        Parameters
        ----------
        
        sqlContext: SQLContext
            The SQL context to use (if one exists) otherwise make a new one
        """
        # preproc_func is now forwarded; it used to be accepted and dropped
        imd_rdd = self.readImageTable(path, col_name, im_path_prefix = im_path_prefix, 
                                   n_partitions = n_partitions, read_table_func = read_table_func,
                                   preproc_func = preproc_func)
        cur_sql = sqlContext if sqlContext is not None else SQLContext(self._cur_sc)
        return PyqaeContext._imageTableToDataFrame(imd_rdd, cur_sql)

In [2]:
# Wrap an existing SparkContext and load every PNG matching the glob.
# NOTE(review): `sc` is not defined in this notebook; presumably it is provided by the pyspark launcher
pq_context = PyqaeContext(sc)
im_files = pq_context.readImageDirectory('/Volumes/WinDisk/openi_images/*.png')
# peek at the first record as (path, image shape)
im_files.mapValues(lambda x: x.shape).first()

('file:/Volumes/WinDisk/openi_images/00000-CXR1005.png', (420, 512, 3))

In [3]:
# Build an image table from the CSV: the 'local_path' column holds the image
# paths, each prefixed with im_path_prefix before loading
pq_context2 = PyqaeContext(sc)
dim_files = pq_context2.readImageTable('/Volumes/WinDisk/openi_db_path.csv',
                         'local_path', im_path_prefix = '/Volumes/WinDisk/')
# peek at the first record as (row dict, image shape)
dim_files.mapValues(lambda x: x.shape).first()

({'Unnamed: 0': 0,
  'abstract': '<p><b>Comparison: </b>None.</p><p><b>Indication: </b>Pruritic.</p><p><b>Findings: </b>Cardiac and mediastinal contours are within normal limits. The lungs are clear. Bony structures are intact.</p><p><b>Impression: </b>No acute findings.</p>',
  'caption': 'Chest, 2 views, frontal and lateral',
  'image_id': 'F1',
  'local_path': 'openi_images/00000-CXR1005.png',
  'major': 'normal',
  'minor': nan,
  'problem': 'normal',
  'row': 0,
  'uid': 'CXR1005',
  'url': '/imgs/512/203/1005/CXR1005_IM-0006-1001.png'},
 (420, 512, 3))

In [4]:
# Same CSV as above, but returned as a Spark DataFrame created in the
# supplied sqlContext; the image files are read with n_partitions = 2000
pq_context2 = PyqaeContext(sc)
df_files = pq_context2.readImageDataFrame('/Volumes/WinDisk/openi_db_path.csv',
                         'local_path', im_path_prefix = '/Volumes/WinDisk/', 
                                          n_partitions = 2000,
                                          sqlContext = sqlContext)
# fetch the first row
f_row = df_files.first()

# display the DataFrame schema summary
df_files

DataFrame[minor: double, Unnamed: 0: bigint, url: string, uid: string, major: string, image_id: string, caption: string, problem: string, abstract: string, local_path: string, row: bigint, image: array<array<array<bigint>>>]

In [5]:
# save a 10% sample (drawn without replacement) of the table as parquet,
# renaming the 'Unnamed: 0' column to 'id' first
# NOTE(review): the sample is written to a path named "full_open_db.pqt"
df_files.withColumnRenamed('Unnamed: 0','id').sample(False, 0.1).write.parquet("/Volumes/WinDisk/full_open_db.pqt")

Py4JJavaError: An error occurred while calling o121.parquet.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$$anonfun$run$1.apply$mcV$sp(InsertIntoHadoopFsRelationCommand.scala:149)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$$anonfun$run$1.apply(InsertIntoHadoopFsRelationCommand.scala:115)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$$anonfun$run$1.apply(InsertIntoHadoopFsRelationCommand.scala:115)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:57)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:115)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:60)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:58)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.doExecute(commands.scala:74)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:115)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:115)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:136)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:133)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:114)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:86)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:86)
	at org.apache.spark.sql.execution.datasources.DataSource.write(DataSource.scala:487)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:211)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:194)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:478)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:497)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:280)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:211)
	at java.lang.Thread.run(Thread.java:745)
Caused by: org.apache.spark.SparkException: Job 9 cancelled because Stage 9 was cancelled
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1450)
	at org.apache.spark.scheduler.DAGScheduler.handleJobCancellation(DAGScheduler.scala:1389)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply$mcVI$sp(DAGScheduler.scala:1377)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1376)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleStageCancellation$1.apply(DAGScheduler.scala:1376)
	at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
	at scala.collection.mutable.ArrayOps$ofInt.foreach(ArrayOps.scala:234)
	at org.apache.spark.scheduler.DAGScheduler.handleStageCancellation(DAGScheduler.scala:1376)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1632)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1618)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1607)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:632)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1871)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1884)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:1904)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand$$anonfun$run$1.apply$mcV$sp(InsertIntoHadoopFsRelationCommand.scala:143)
	... 30 more


In [6]:
# save the full table as parquet ('Unnamed: 0' renamed to 'id')
df_files.withColumnRenamed('Unnamed: 0','id').write.parquet("/Volumes/WinDisk/full_open_db.pqt")
# save a 10% sample (drawn without replacement) of the table as a second, smaller parquet dataset
df_files.withColumnRenamed('Unnamed: 0','id').sample(False, 0.1).write.parquet("/Volumes/WinDisk/small_open_db.pqt")

In [None]:
from glob import glob
# pick the first DICOM file of the series
cf_file = glob('/Users/mader/Dropbox/4Quant/Projects/TumorSegmentation/10092825/0013_t1_tse_tra_+c/*.dcm')[0]
# DICOM is a binary format: open in 'rb' mode and keep the bytes in a BytesIO
# buffer (text mode + StringIO decodes/corrupts the data, and parsing only the
# first "line" of it cannot yield a valid file)
with open(cf_file,'rb') as ifile:
    sdata = BytesIO(ifile.read())

# parse the whole buffer
read_file(sdata)

In [None]:
from io import BytesIO
# load every .dcm file in the folder as raw bytes, parse each with read_file, and show the first (path, parsed file) pair
sc.binaryFiles('/Users/mader/Dropbox/4Quant/Projects/TumorSegmentation/10092825/0002_t2_blade_tra/*.dcm').mapValues(lambda x: read_file(BytesIO(x))).first()

In [None]:
# same DICOM folder, read through the PyqaeContext helper instead of raw binaryFiles
pq_context3 = PyqaeContext(sc)
dcm_files = pq_context3.readDicomDirectory('/Users/mader/Dropbox/4Quant/Projects/TumorSegmentation/10092825/0002_t2_blade_tra/*.dcm')
# fetch the first two (path, parsed file) pairs
dcm_files.take(2)

In [None]:
# number of records in the image table RDD
dim_files.count()

In [None]:
# convert the image table RDD to a DataFrame via the static helper (called through the instance)
d_table = pq_context2.imageTableToDataFrame(dim_files)
d_table

In [None]:
# print each 'local_path' value of the table's head next to its full record
# NOTE(review): `test_table` is not defined in any visible cell — confirm where it comes from
for cpath, c_record in zip(test_table.head()['local_path'],test_table.head().to_records()):
    print(cpath,c_record)

In [None]:
# IPython `?` help lookup for sc.binaryFiles
?sc.binaryFiles

In [None]:
# IPython `?` help lookup for sc.binaryRecords
?sc.binaryRecords

In [None]:
import requests
# IPython `?` help lookup for requests.get
?requests.get

In [None]:
# scratch cell: a 3x3x3 array of zeros
a=np.zeros((3,3,3))
import urllib
# NOTE(review): on Python 3 `import urllib` alone does not guarantee the `parse`
# submodule is loaded, and Python 2's urllib has no `parse` attribute — this
# looks like a probe for which urllib flavour is available; confirm
urllib.parse

In [None]:
import requests
import json
import numpy as np
# Python 2 ships `urlparse` as a top-level module; on Python 3 the same
# functions live in `urllib.parse`. Catch only ImportError so that any other
# failure is not silently masked by the fallback.
try:
    import urlparse
except ImportError: # for python 3
    from urllib import parse as urlparse
class DRESTAccess(object):
    """
    A distributed access to a REST interface

    Holds a base URL, a fetch path and default query arguments, and can issue
    one request locally (pull_results) or many through Spark (parallel_pull).
    """
    
    def __init__(self, base_url, fetch_path, def_args, verbose = False):
        """
        Parameters
        ----------
        base_url : String
            The root URL of the service
        fetch_path : String
            The path joined onto base_url for each request
        def_args : list of (key, value) pairs
            Default query parameters sent with every request
        verbose : Boolean
            If True, pull_results prints the URL and parameters of each request
        """
        self.base_url = base_url
        self.fetch_path = fetch_path
        self.def_args = def_args
        self.verbose = verbose
        
    @staticmethod
    def jsonRequest(req_url, args):
        """
        GET req_url with args as query parameters and return the decoded JSON body.
        Raises ValueError if the response status is not ok.
        """
        response = requests.get(req_url, params=args)
        if response.ok:
            return response.json()
        raise ValueError("{} could not be processed correctly".format(req_url),args)
    
    @staticmethod
    def bufferRequest(req_url, args):
        """
        GET req_url with args as query parameters and return the raw body as a
        file-like buffer. Raises ValueError if the response status is not ok.
        """
        response = requests.get(req_url, params=args)
        if response.ok:
            # response.content is bytes, which io.StringIO rejects; use a bytes buffer
            return BytesIO(response.content)
        raise ValueError("{} could not be processed correctly".format(req_url),args)
    
    def pull_results(self, **args):
        """
        Issue a single JSON request; keyword arguments override the defaults in def_args.
        """
        full_url = urlparse.urljoin(self.base_url,self.fetch_path)
        # defaults first so that call-specific arguments win (works on Python 2 and 3)
        new_param = dict(self.def_args)
        new_param.update(args)
        if self.verbose:
            print(full_url, new_param)
        return DRESTAccess.jsonRequest(full_url, new_param)
    
    def parallel_pull(self, sc, arg_list, parts = 10):
        """
        Run pull_results once per dict in arg_list on the SparkContext sc,
        distributing the requests over `parts` partitions. Returns an RDD of results.
        """
        return sc.parallelize(arg_list, parts).map(lambda x: self.pull_results(**x))
    

class OpenIDB(DRESTAccess):
    """
    Access to the Open-i (openi.nlm.nih.gov) retrieve.php search endpoint,
    fetching result pages of step_count records in parallel.
    """
    def __init__(self, step_count = 50):
        """
        Parameters
        ----------
        step_count : Int
            Number of records requested per page
        """
        self.step_count = step_count
        DRESTAccess.__init__(self,
                         base_url = "https://openi.nlm.nih.gov", 
                        fetch_path = "retrieve.php",
                        def_args = [])
    
    @staticmethod
    def _page_ranges(total, step_count):
        """Split the 1-based range [1, total] into consecutive, non-overlapping (m, n) windows."""
        return [(m, min(m + step_count - 1, total)) for m in range(1, total + 1, step_count)]
    
    def db_query(self, sc, **args):
        """
        Run a query and return an RDD with every entry of the result 'list's.
        
        A first request with m=1, n=1 is used to read the 'total' count, then
        one request per page is issued through parallel_pull.
        """
        base_args = dict(args)
        test_query = self.pull_results(**dict(base_args, m=1, n=1))
        total = int(test_query['total'])
        # assumes m and n are 1-based, inclusive start/end indices (the m=1, n=1
        # probe above relies on the same reading) — TODO confirm against the API docs.
        # Windows therefore must not share an endpoint, otherwise boundary records
        # are fetched twice.
        page_args = [dict(base_args, m=m, n=n) 
                     for m, n in OpenIDB._page_ranges(total, self.step_count)]
        qry_rdd = self.parallel_pull(sc, page_args)
        return qry_rdd.flatMap(lambda x: x['list'])
    
    @staticmethod
    def format_entry(ie):
        """Flatten one raw result entry into a dict of plain string fields."""
        return {
            'uid': ie['uid'],
            'major': ";".join(ie['MeSH']['major']), 
            'minor': ";".join(ie['MeSH']['minor']), 
            'problem': ie['Problems'],
            'abstract': ie['abstract'],
            'caption': ie['image']['caption'],
            'image_id': ie['image']['id'],
            'url': ie['imgLarge']
        }
    
    def get_collection(self,sc, coll='cxr', it='xg', lic='byncnd', **args):
        """
        Fetch an entire collection of images as a dataframe
        """
        # the named filters override any same-named entries in args
        study_results = self.db_query(sc, **dict(args, coll=coll, it=it, lic=lic))
        return study_results.map(OpenIDB.format_entry).toDF()

In [None]:
# create the Open-i client with the default page size
odb = OpenIDB()
# single-record request, kept for manual testing:
#odb.pull_results(m=1, n=1, coll='cxr', it='xg', lic='byncnd')

In [None]:
# query with the coll / it / lic filters and collect the raw entries as an RDD
all_results = odb.db_query(sc, coll='cxr', it='xg', lic='byncnd')

In [None]:
# inspect the first raw entry
all_results.first()

In [None]:
# fetch the collection with get_collection's default filters, formatted as a DataFrame
nw_results = odb.get_collection(sc)

In [None]:
# inspect the first row of the DataFrame
nw_results.first()

In [None]:
# head() of the DataFrame
nw_results.head()

In [None]:
# expose the DataFrame to SQL queries under the table name "LungStudy"
nw_results.registerTempTable("LungStudy")

In [None]:
# NOTE(review): "Hey" is not a SQL statement, so this looks like a placeholder —
# presumably a query against the "LungStudy" temp table was intended; confirm
sqlContext.sql("Hey")