<a href="https://colab.research.google.com/github/XiaoleiZ/Reproduce_HilbertCNN/blob/master/Reproduce_HibertCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Paper: AN IMAGE REPRESENTATION BASED CONVOLUTIONAL NETWORK FOR DNA CLASSIFICATION

URL: image representation based convolutional network for dna classification

Summary: Transform a 1D DNA sequence into a 2D image to improve the classification performance of a CNN.

Plan: 

1. Transformation from a 1D sequence to a 2D image

2. CNN classification network



DNA sequence representation

1. Represent a sequence as a list of k-mers

2. Transform each k-mer into a one-hot vector, so that the sequence is represented as a list of one-hot vectors

3. Use a Hilbert curve to assign each element of the list of k-mers to a pixel in an image

In [0]:
import numpy as np

In [0]:
def kmerize(sequence, k=4):
    """Represent a sequence as the list of its overlapping k-mers.

    Args:
        sequence: The sequence to split (a string, or any sliceable object
            that supports ``len``).
        k: Window length, a whole number >= 1. Defaults to 4, the value used
            in the paper.

    Returns:
        A list with the ``len(sequence) - k + 1`` consecutive windows of
        length ``k``, in order of their start position. The list is empty
        when the sequence is shorter than ``k``.

    Raises:
        ValueError: If ``k`` is not a whole number or is smaller than 1.
    """
    try:
        window = int(k)
    except (TypeError, ValueError, OverflowError):
        raise ValueError("k must be an integer larger than zero")
    # int() truncates, so compare back against k to reject values such as 1.5
    if window != k or window < 1:
        raise ValueError("k must be an integer larger than zero")

    # Start positions run 0 .. len-k; the last k-mer is sequence[len-k:len]
    last_start = len(sequence) - window
    return [sequence[start:start + window] for start in range(last_start + 1)]

In [0]:
def create_four_mer_ohv(k=4):
    """Build the one-hot encoding table for every DNA k-mer.

    Args:
        k: Length of the k-mers to enumerate (>= 1). Defaults to 4, which
            yields the 256 possible 4-mers.

    Returns:
        A dict mapping each k-mer string over the alphabet A/C/G/T to a
        float numpy vector of length ``4**k`` that is 1 at the k-mer's index
        and 0 elsewhere. K-mers are enumerated with the last base varying
        fastest (base order A, C, G, T), so for k=4 'AAAA' has index 0 and
        'TTTT' has index 255.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    # Local import so this cell does not depend on an extra import cell
    from itertools import product

    if k < 1:
        raise ValueError("k must be an integer larger than zero")

    bases = ('A', 'C', 'G', 'T')
    # product(..., repeat=k) reproduces the order of k nested loops over bases
    kmers = [''.join(letters) for letters in product(bases, repeat=k)]

    # NOTE: the table holds 4**k vectors of length 4**k, so memory grows as 16**k
    n_kmers = len(kmers)
    encoding = {}
    for index, kmer in enumerate(kmers):
        vector = np.zeros(n_kmers)
        vector[index] = 1
        encoding[kmer] = vector

    return encoding
  

In [0]:
def one_hot_vectorize(kmer_list,four_mer_dict):
    """Translate a list of k-mers into the matching list of one-hot vectors.

    Args:
        kmer_list: The k-mers representing one sequence, in order.
        four_mer_dict: Lookup table, built beforehand, that maps every
            possible k-mer (4-mers are used here) to its one-hot vector.

    Returns:
        A list holding, for each k-mer in ``kmer_list``, the value stored for
        it in ``four_mer_dict``. A k-mer missing from the table raises
        ``KeyError``.
    """
    return [four_mer_dict[kmer] for kmer in kmer_list]


In [0]:
def hilbert_curve (n):
    """Build the Hilbert space-filling curve on an n x n grid.

    The curve has order log2(n) and maps a distance along a 1-D path to a
    cell of a 2-D array.

    Args:
        n: Number of cells in each dimension; must be a power of two
            (1, 2, 4, 8, ...).

    Returns:
        An integer numpy array of shape (n, n). The value at [row, col] is
        the 1-D position (0 .. n*n-1) that the curve assigns to that cell.

    Raises:
        ValueError: If ``n`` is not a positive power of two. Without this
            check such values would recurse without ever reaching a base case.
    """
    side = int(n)
    # side & (side - 1) is zero only for powers of two
    if side != n or side < 1 or side & (side - 1):
        raise ValueError("n must be a power of two (1, 2, 4, 8, ...)")

    if side == 1:
        return np.array([[0]], int)
    if side == 2:
        return np.array([[0,3],[1,2]],int)

    unit = hilbert_curve(side // 2)
    # the number of cells in each quadrant
    step = side * side // 4
    # layout of the four quadrants given in the paper
    ########################
    ### a # d #
    ###   #   #
    ########################
    ### b # c #
    ###   #   #
    ########################

    # a: rotate 90 degrees counter-clockwise, then flip vertically
    #    (equivalent to transposing the unit curve)
    a = np.flipud(np.rot90(unit))
    # b, c: same orientation as the unit curve, shifted along the path
    b = unit + step
    c = unit + step * 2
    # d: rotate 90 degrees counter-clockwise, then flip horizontally
    d = np.fliplr(np.rot90(unit)) + step * 3

    # stack the four quadrants together: (a over b) to the left of (d over c)
    left_half = np.concatenate((a, b), axis=0)
    right_half = np.concatenate((d, c), axis=0)
    return np.concatenate((left_half, right_half), axis=1)

In [0]:
def pixelize_seqs(seqs,k=4):
    """Turn a batch of sequences into Hilbert-curve images.

    Each sequence is split into k-mers, every k-mer is one-hot encoded, and
    the j-th k-mer is written to the pixel that holds curve position j.

    Args:
        seqs: A list of sequences, each a string.
        k: k-mer length passed to ``kmerize`` and ``create_four_mer_ohv``.
            Defaults to 4.

    Returns:
        A float numpy array of shape (len(seqs), 32, 16, depth), where depth
        is the number of entries in the one-hot table (256 for 4-mers).
        Pixels with no k-mer stay all-zero. K-mers whose position has no
        pixel on the cropped curve are dropped.

    A k-mer that is missing from the one-hot table (for example one
    containing a character other than A/C/G/T) raises ``KeyError``.
    """
    # Crop the 32x32 curve to its left 16 columns
    curve = hilbert_curve(2**5)[:, 0:16]
    n_height, n_width = curve.shape

    # The encoding table is identical for every sequence, so build it once
    four_mer_dict = create_four_mer_ohv(k)
    n_depth = len(four_mer_dict)

    # Curve position -> (row, col), computed once instead of scanning the
    # grid for every k-mer
    pixel_of = {int(position): rc for rc, position in np.ndenumerate(curve)}

    img_list = np.zeros((len(seqs), n_height, n_width, n_depth))

    for i, sequence in enumerate(seqs):
        ohv_list = one_hot_vectorize(kmerize(sequence, k), four_mer_dict)
        for j, ohv in enumerate(ohv_list):
            pixel = pixel_of.get(j)
            if pixel is None:
                # position j is not on the cropped curve
                continue
            row, col = pixel
            # Index row and column separately: passing the (rows, cols) tuple
            # as a single index would select whole rows of the image instead
            # of one pixel.
            img_list[i, row, col] = ohv

    return img_list

In [0]:
def read_data(file):
    """Read a three-lines-per-record sequence file.

    Every record consists of three consecutive non-blank lines:
    a header whose first character (">") is dropped, the sequence, and the
    label. Surrounding whitespace is stripped from every line, and blank
    lines are skipped so that a trailing empty line does not produce a bogus
    record or shift the name/sequence/label cycle.

    Args:
        file: Path of the text file to read.

    Returns:
        A tuple ``(name, seq, label)`` of three lists of strings, one entry
        per record, in file order.
    """
    name = []
    seq = []
    label = []
    slot = 0  # 0 = header line, 1 = sequence line, 2 = label line
    # Stream the file and append in place; rebuilding the lists with
    # `list + [item]` copies them on every line.
    with open(file) as f:
        for raw_line in f:
            text = raw_line.strip()
            if not text:
                continue
            if slot == 0:
                name.append(text[1:])  # remove ">"
            elif slot == 1:
                seq.append(text)
            else:
                label.append(text)
            slot = (slot + 1) % 3

    return name, seq, label

In [0]:
# Load the H4 data set; 'H4.txt' is resolved relative to the working directory
name,seq,label=read_data('H4.txt')
# Encode every sequence as a Hilbert-curve image of one-hot 4-mer vectors
imgs=pixelize_seqs(seq,k=4)

In [1]:
# Sanity check: expected shape is (number of sequences, 32, 16, 256)
imgs.shape


NameError: ignored