In [2]:
from perceiver.dna_tokenizer import *

KMER = 6

In [3]:
from textwrap import wrap

def read_genome(filename, label):
    rawfile = open(filename, mode="r")
    instances = []
    new_instance = [ "", label ]

    for line in rawfile:
        if line[0] == '>':
            if len( new_instance[0] ) > 0:
                if len( new_instance[0] ) % KMER != 0:
                    new_instance[0] += "A" * ( KMER - len( new_instance[0] ) % KMER )
                new_instance[0] = wrap(new_instance[0], KMER)
                new_instance[0] = ' '.join( new_instance[0] )
                instances.append( new_instance )
            new_instance = [ "", label ]
        else:
            new_instance[0] += line.strip()

    return instances


In [4]:
alpha_samples = read_genome("data/coronavirus/alpha.fna", label=0)
mers_samples = read_genome("data/coronavirus/mers.fna", label=1)
covid_samples = read_genome("data/coronavirus/SARS-Cov-2.fasta", label=2)

In [5]:
len( covid_samples )

299

In [6]:
len( mers_samples )

259

In [7]:
len( alpha_samples )

112

In [5]:
len( covid_samples ) + len( mers_samples ) + len( alpha_samples )

670

In [6]:
import numpy as np
np.array( covid_samples )[0][1]

'2'

In [7]:
import numpy as np
import tensorflow as tf

viral_squences = np.array( alpha_samples + mers_samples + covid_samples )
concat = tf.data.Dataset.from_tensor_slices((viral_squences[:, 0],viral_squences[:, 1]))
#concat = concat.map(one_hot)
concat = concat.shuffle( len(concat) )

In [11]:
viral_squences.ndim

2

In [12]:
import tensorflow as tf

tfdt = tf.data.Dataset.from_tensor_slices((viral_squences[:, 0],viral_squences[:, 1]))

In [13]:
stfdt = tfdt.shuffle(670)

In [8]:
counts = [0, 0, 0]
labels = [b"covid", b"mers", b"alpha"]
max_length = 0

for gene, label in concat.take(670):  # only take first element of dataset
    numpy_gene = gene.numpy()
    #counts[ labels.index(numpy_label) ] += 1
    if max_length < len(numpy_gene):
        max_length = len(numpy_gene)

(max_length - KMER)/7 #counts

5024.0

In [13]:
for gene, label in concat.take(1):  # only take first element of dataset
    print(label.numpy().decode("utf-8"))

1


In [14]:
dna_tokenizer.to_int(numpy_gene.decode("utf-8").split(' '))

<tf.Tensor: shape=(4948,), dtype=int32, numpy=array([ 799, 3436, 1041, ...,  890, 2274,    6], dtype=int32)>

In [38]:
#reload(perceiver.dna_tokenizer)

#from perceiver.dna_tokenizer import *

import functools
MAX_SEQ_LEN = 5024


# use decorator to input default max_len=MAX_SEQ_LEN
def kmerlist_padding(max_len):
    def wrapper_converter(func):
        @functools.wraps(func)
        def wrapper(gene_str):
            seq = func(gene_str)
            padded_seq = np.concatenate( [seq, np.repeat( dna_tokenizer.pad_token, (max_len - seq.shape[0]) )] )
            return padded_seq
        return wrapper

    return wrapper_converter

@kmerlist_padding(max_len=MAX_SEQ_LEN)
def string_to_kmerlist(gene_str):
    gene_seq = gene_str.decode("utf-8").split(sep=' ')
    return dna_tokenizer.to_int(gene_seq)

def tokenizing_input(gene_str, label_str):
    gene = tf.numpy_function(func=string_to_kmerlist, inp=[gene_str], Tout=tf.int64)
    label = tf.numpy_function(func=lambda x: tf.cast(int(x.decode("utf-8")), tf.int32), inp=[label_str], Tout=tf.int32)

    return ( gene, label )    # convert label to int

transformed = concat.take(10).map(tokenizing_input)
transformed
#npf_conversion(concat.take(2))

<MapDataset shapes: (<unknown>, <unknown>), types: (tf.int64, tf.int32)>

In [40]:
for input, label in transformed.take(1):
    print(input.numpy())
    print(label.numpy())


[1104 1454 3912 ...    0    0    0]
2


In [43]:
input.numpy()[-100:]

array([2253, 3991,  264, 3420, 3675, 2779,  714, 1107, 2348, 1508,  781,
         95, 2888, 1162, 3622, 3383, 3387,  155,  108,  537, 2079, 1292,
        891, 1063, 1282, 3883, 3090, 3728,  539, 1585, 1006, 3010, 3378,
       1747, 3369, 1806,  490, 4050, 3712, 1101, 3129, 2710, 1914,   26,
        346, 3389, 2336, 2595, 3163,   84, 1434, 3894, 1830,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
          0])

In [35]:
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenizer implementation mapping strings to their UTF-8 bytes."""

from typing import Union
import numpy as np
import tensorflow as tf


class DNATokenizer:
  """Tokenizes string to utf-8 bytes."""

  def __init__(self, vocab_file):
    self._num_reserved_tokens = 6  # PAD, BOS, EOS, MASK, CLS, SEP
    self._vocabs = np.array( [ line.strip() for line in open(vocab_file) ] )

  def to_string(self, inputs: np.ndarray) -> str:
    return self._vocabs[ inputs.argmax(axis=-1) ]

  def to_int(self, inputs: Union[list, np.ndarray]) -> np.ndarray:
    if isinstance(inputs, list):
      inputs = np.array(inputs)
    encoded = np.where( inputs[:, None] == dna_tokenizer._vocabs[None, :] )[1]

    return encoded #.astype(np.int32)

  @property
  def vocab_size(self) -> int:
    return 4102

  @property
  def pad_token(self) -> int:
    return 0

  @property
  def bos_token(self) -> int:
    return 1

  @property
  def eos_token(self) -> int:
    return 2

  @property
  def mask_token(self) -> int:
    return 3

  @property
  def cls_token(self) -> int:
    return 4

  @property
  def sep_token(self) -> int:
    return 5

dna_tokenizer = DNATokenizer(vocab_file="tokenization/vocab_6mer.txt")

In [72]:
import jax
jax.device_count()



1

In [71]:
dataset = tf.data.Dataset.from_tensor_slices( [[1,2,3,4], [5,6,7,8]] )

dataset_filter = dataset.map(lambda x: tf.gather(x, [0, 2], axis=0))
result = list(dataset_filter.as_numpy_iterator())
print(result)

TypeError: in user code:


    TypeError: <lambda>() takes 1 positional argument but 2 were given
