In [3]:
import os
import numpy as np
import tensorflow as tf

def load_names(namefile):
    """Loads DeepSEA label names from namefile.

    The file holds label names separated by commas (newlines are not treated
    as separators). Surrounding whitespace of the whole file, such as a
    trailing newline, is stripped so it does not end up inside the last label.

    Args:
      namefile: Path to the comma-separated label-name file.

    Returns:
      A 1-D numpy array of label-name strings, in file order.
    """
    # Context manager so the file handle is closed as soon as it is read.
    with open(namefile) as handle:
        contents = handle.read()
    return np.array(contents.strip().split(","))


def _unique_labels_sorted(positions):
    """Drops repeated label names and sorts the remaining pairs by label name.

    Args:
      positions: Iterable of (index, label_name) pairs.

    Returns:
      List of (index, label_name) pairs keeping the first index seen for each
      distinct label name, sorted alphabetically by label name.
    """
    seen = set()
    unique = []
    for index, label in positions:
        if label not in seen:
            seen.add(label)
            unique.append((index, label))
    return sorted(unique, key=lambda item: item[1])


def get_overlapping_indices_for_cell_type(
        cell_type_1, cell_type_2, namefile="../tfti/deepsea_label_names.txt"):
    """Gets target indices for transcription factors for the intersection
    of cell_type_1 and cell_type_2.

    Label names are "|"-separated, e.g. "GM12878|CTCF|None": the first field
    is the cell type and the second field is the mark.

    Args:
      cell_type_1: Name of cell type 1 as a string. It is compared exactly
        against the cell-type field of each label.
      cell_type_2: Name of cell type 2 as a string.
      namefile: Path to the DeepSEA label-name file. Defaults to the path
        this function previously hard-coded.

    Returns:
      List of indices ~ FOR CELL 1 ~ whose marks overlap with cell 2.
      These indices are listed in alphabetical order of label name, so that
      position k refers to the same mark for both cell types.

    Raises:
      ValueError: If either cell type does not occur in namefile.
      AssertionError: If the de-duplicated labels of the two cell types cannot
        be aligned mark-for-mark.
    """
    names = load_names(namefile)

    # Field 0 is the cell type, field 1 the mark. The previous code built the
    # "valid cell types" from field 1, i.e. from the marks.
    fields = [name.split("|") for name in names]
    valid_cell_types = {label_fields[0] for label_fields in fields}

    # Make sure cell type parameters can be found in our data. The previous
    # `assert(cond, msg)` asserted a non-empty tuple and so could never fail.
    for cell_type in (cell_type_1, cell_type_2):
        if cell_type not in valid_cell_types:
            raise ValueError(
                "{} not in list of valid cell types".format(cell_type))

    # Get positions for both cell lines. Match the cell-type field exactly:
    # a substring test would also pick up labels of other cell types (or
    # marks) whose names merely contain the requested string.
    cell_type_1_pos = [(i, name) for i, name in enumerate(names)
                       if fields[i][0] == cell_type_1]
    cell_type_2_pos = [(i, name) for i, name in enumerate(names)
                       if fields[i][0] == cell_type_2]

    # Marks present for both cell types.
    overlapping_marks = (
        {fields[i][1] for i, _ in cell_type_1_pos}
        & {fields[i][1] for i, _ in cell_type_2_pos})

    # Keep only overlapping marks, drop duplicate label names, and sort by
    # label name for a consistent order across both cell types.
    cell_type_1_items = _unique_labels_sorted(
        (i, name) for i, name in cell_type_1_pos
        if fields[i][1] in overlapping_marks)
    cell_type_2_items = _unique_labels_sorted(
        (i, name) for i, name in cell_type_2_pos
        if fields[i][1] in overlapping_marks)

    # Log the selected (index, label) pairs for cell type 1.
    tf.logging.info("Marks for CellType %s: %s"
                    % (cell_type_1,
                    cell_type_1_items))

    # Verify that TFs match between cell types, position by position.
    # Comparing whole lists also catches length mismatches, which the old
    # per-index loop either ignored or turned into an IndexError.
    marks_1 = [fields[i][1] for i, _ in cell_type_1_items]
    marks_2 = [fields[i][1] for i, _ in cell_type_2_items]
    assert marks_1 == marks_2, (
        "Marks for {} and {} do not align: {} vs {}".format(
            cell_type_1, cell_type_2, marks_1, marks_2))

    # These are the indices we are using for the cell type 1 model.
    return [i for i, _ in cell_type_1_items]


In [5]:
# Load every DeepSEA label, then select the GM12878 labels whose marks are
# also present for H1-hESC (returned in alphabetical label order).
namefile = "../tfti/deepsea_label_names.txt"
example = load_names(namefile=namefile)
gather_indices = get_overlapping_indices_for_cell_type(
    cell_type_1="GM12878", cell_type_2="H1-hESC")
example[gather_indices]

array(['GM12878|ATF2|None', 'GM12878|ATF3|None', 'GM12878|BCL11A|None',
       'GM12878|BRCA1|None', 'GM12878|CEBPB|None', 'GM12878|CHD1|None',
       'GM12878|CHD2|None', 'GM12878|CTCF|None', 'GM12878|DNase|None',
       'GM12878|EZH2|None', 'GM12878|Egr-1|None', 'GM12878|GABP|None',
       'GM12878|JunD|None', 'GM12878|Max|None', 'GM12878|Mxi1|None',
       'GM12878|NRSF|None', 'GM12878|Nrf1|None', 'GM12878|Pol2-4H8|None',
       'GM12878|Pol2|None', 'GM12878|RFX5|None', 'GM12878|RXRA|None',
       'GM12878|Rad21|None', 'GM12878|SIN3A|None', 'GM12878|SIX5|None',
       'GM12878|SP1|None', 'GM12878|SRF|None', 'GM12878|TAF1|None',
       'GM12878|TBP|None', 'GM12878|TCF12|None', 'GM12878|USF-1|None',
       'GM12878|USF2|None', 'GM12878|YY1|None', 'GM12878|Znf143|None',
       'GM12878|c-Myc|None', 'GM12878|p300|None'], dtype='<U37')

In [6]:
# Gather labels in ascending index order. argsort_indices records the
# permutation: gather_indices_sorted == gather_indices[argsort_indices].
gather_indices_sorted = np.sort(gather_indices)
argsort_indices = np.argsort(gather_indices)

# Keep targets and latents corresponding to GM12878 (LCL cell line).
targets = example[gather_indices_sorted]
latents = targets.copy()
# Independent copy rebuilt from Python strings, so numpy re-infers a string
# width sized to exactly these labels.
targets_copy = np.array(targets.tolist())

targets

array(['GM12878|DNase|None', 'GM12878|CTCF|None', 'GM12878|EZH2|None',
       'GM12878|ATF2|None', 'GM12878|ATF3|None', 'GM12878|BCL11A|None',
       'GM12878|CEBPB|None', 'GM12878|Egr-1|None', 'GM12878|GABP|None',
       'GM12878|NRSF|None', 'GM12878|p300|None', 'GM12878|Pol2-4H8|None',
       'GM12878|Pol2|None', 'GM12878|Rad21|None', 'GM12878|RXRA|None',
       'GM12878|SIX5|None', 'GM12878|SP1|None', 'GM12878|SRF|None',
       'GM12878|TAF1|None', 'GM12878|TCF12|None', 'GM12878|USF-1|None',
       'GM12878|YY1|None', 'GM12878|BRCA1|None', 'GM12878|CHD1|None',
       'GM12878|CHD2|None', 'GM12878|JunD|None', 'GM12878|Max|None',
       'GM12878|Mxi1|None', 'GM12878|Nrf1|None', 'GM12878|RFX5|None',
       'GM12878|SIN3A|None', 'GM12878|TBP|None', 'GM12878|USF2|None',
       'GM12878|Znf143|None', 'GM12878|c-Myc|None'], dtype='<U37')

In [7]:
# This is how we are currently re-alphabetizing (incorrect).
# Gathering with argsort_indices applies the sorting permutation a second
# time instead of undoing it, so the result below is NOT alphabetical.
np.take(targets, argsort_indices)


array(['GM12878|GABP|None', 'GM12878|Egr-1|None', 'GM12878|NRSF|None',
       'GM12878|DNase|None', 'GM12878|CTCF|None', 'GM12878|EZH2|None',
       'GM12878|ATF3|None', 'GM12878|p300|None', 'GM12878|Pol2-4H8|None',
       'GM12878|SIX5|None', 'GM12878|c-Myc|None', 'GM12878|SRF|None',
       'GM12878|TAF1|None', 'GM12878|YY1|None', 'GM12878|USF-1|None',
       'GM12878|CHD1|None', 'GM12878|CHD2|None', 'GM12878|JunD|None',
       'GM12878|Max|None', 'GM12878|Nrf1|None', 'GM12878|RFX5|None',
       'GM12878|TBP|None', 'GM12878|ATF2|None', 'GM12878|BCL11A|None',
       'GM12878|CEBPB|None', 'GM12878|Pol2|None', 'GM12878|Rad21|None',
       'GM12878|RXRA|None', 'GM12878|SP1|None', 'GM12878|TCF12|None',
       'GM12878|BRCA1|None', 'GM12878|Mxi1|None', 'GM12878|SIN3A|None',
       'GM12878|USF2|None', 'GM12878|Znf143|None'], dtype='<U37')

In [10]:
# This is how we should be doing it: invert the sort by scattering.
# targets[k] is the label at alphabetical position argsort_indices[k], so
# writing targets into targets_copy at argsort_indices restores that order.
inds = list(range(targets.size))
np.put(targets_copy, argsort_indices, targets[inds])
targets_copy


array(['GM12878|ATF2|None', 'GM12878|ATF3|None', 'GM12878|BCL11A|None',
       'GM12878|BRCA1|None', 'GM12878|CEBPB|None', 'GM12878|CHD1|None',
       'GM12878|CHD2|None', 'GM12878|CTCF|None', 'GM12878|DNase|None',
       'GM12878|EZH2|None', 'GM12878|Egr-1|None', 'GM12878|GABP|None',
       'GM12878|JunD|None', 'GM12878|Max|None', 'GM12878|Mxi1|None',
       'GM12878|NRSF|None', 'GM12878|Nrf1|None', 'GM12878|Pol2-4H8|None',
       'GM12878|Pol2|None', 'GM12878|RFX5|None', 'GM12878|RXRA|None',
       'GM12878|Rad21|None', 'GM12878|SIN3A|None', 'GM12878|SIX5|None',
       'GM12878|SP1|None', 'GM12878|SRF|None', 'GM12878|TAF1|None',
       'GM12878|TBP|None', 'GM12878|TCF12|None', 'GM12878|USF-1|None',
       'GM12878|USF2|None', 'GM12878|YY1|None', 'GM12878|Znf143|None',
       'GM12878|c-Myc|None', 'GM12878|p300|None'], dtype='<U21')